package de.jungblut.classification.regression;

import de.jungblut.math.DoubleMatrix;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.activation.ActivationFunctionSelector;
import de.jungblut.math.dense.DenseDoubleMatrix;
import de.jungblut.math.loss.LogLoss;
import de.jungblut.math.loss.LossFunction;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import java.util.Arrays;

/**
 * Cost function for binary logistic regression, for use with a {@link CostFunction} minimizer.
 *
 * <p>For a parameter vector {@code theta} it computes:
 * <ul>
 *   <li>the hypothesis {@code h = sigmoid(X * theta)};</li>
 *   <li>the cost {@code J = LogLoss(y, h) / m};</li>
 *   <li>the gradient {@code grad = X^T * (h - y) / m}.</li>
 * </ul>
 *
 * <p>When {@code lambda != 0}, an L2 penalty is added:
 * <ul>
 *   <li>{@code lambda / (2m) * sum_{j>=1} theta_j^2} to the cost;</li>
 *   <li>its derivative {@code lambda / m * theta_j} (for {@code j >= 1}) to the gradient.</li>
 * </ul>
 * The first component of {@code theta} (presumably the bias/intercept column) is not regularized.
 */
public final class LogisticRegressionCostFunction implements CostFunction {
    private static final LossFunction ERROR_FUNCTION = new LogLoss();
    private final DoubleMatrix x;
    private final DoubleMatrix xTransposed;
    private final DoubleMatrix y;
    private final int m;
    private final double lambda;

    /**
     * Creates the cost function and caches the transpose of the feature matrix.
     *
     * @param x      feature matrix with one example per row; its row count is used as {@code m}.
     * @param y      labels. They must have the same shape as the hypothesis matrix built in
     *               {@link #evaluateCost(DoubleVector)}, which is a single row with one prediction
     *               per example (presumably 1 x m; verify against callers).
     * @param lambda L2 regularization strength; {@code 0} disables regularization.
     */
    public LogisticRegressionCostFunction(DoubleMatrix x, DoubleMatrix y, double lambda) {
        this.x = x;
        this.y = y;
        this.lambda = lambda;
        this.m = x.getRowCount();
        // Cached once because every gradient evaluation needs X^T.
        this.xTransposed = x.transpose();
    }

    /**
     * Evaluates the (optionally L2-regularized) logistic regression cost and its gradient.
     *
     * @param theta the current parameter vector; index 0 is excluded from regularization.
     * @return the cost and the gradient with respect to {@code theta}.
     */
    @Override // de.jungblut.math.minimize.CostFunction
    public CostGradientTuple evaluateCost(DoubleVector theta) {
        // h = sigmoid(X * theta), wrapped as a single-row matrix so it can be compared with y.
        DoubleVector activation = ActivationFunctionSelector.SIGMOID.get().apply(x.multiplyVectorRow(theta));
        DoubleMatrix hypothesis = new DenseDoubleMatrix(Arrays.asList(activation));

        double cost = ERROR_FUNCTION.calculateLoss(y, hypothesis) / m;
        DoubleVector residual = hypothesis.subtract(y).getRowVector(0);
        DoubleVector gradient = xTransposed.multiplyVectorRow(residual).divide(m);

        if (lambda != 0d) {
            // Gradient of lambda/(2m) * theta_j^2 is lambda/m * theta_j; the bias is not penalized.
            DoubleVector regularization = theta.multiply(lambda / m);
            regularization.set(0, 0d);
            gradient = gradient.add(regularization);
            // The penalty must be the antiderivative of the gradient term above:
            // lambda/(2m) * sum_{j>=1} theta_j^2.
            // Otherwise the cost and the gradient disagree, which misleads line searches.
            double bias = theta.get(0);
            cost += lambda * (theta.pow(2d).sum() - bias * bias) / (2d * m);
        }
        return new CostGradientTuple(cost, gradient);
    }
}
