package org.neo4j.gds.models.logisticregression;

import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.gradientdescent.Objective;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.ReducedCrossEntropyLoss;
import org.neo4j.gds.ml.core.functions.Softmax;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.models.Features;

/* loaded from: input_file:org/neo4j/gds/models/logisticregression/LogisticRegressionObjective.class */
public class LogisticRegressionObjective implements Objective<LogisticRegressionData> {
    private final LogisticRegressionClassifier classifier;
    private final double penalty;
    private final Features features;
    private final HugeLongArray labels;

    /**
     * Estimates the memory, in bytes, needed to evaluate the objective for one batch.
     *
     * @param isReduced        when {@code true}, one class fewer is used for the weight, matrix-multiply and
     *                         prediction size estimates (presumably the reduced logistic-regression formulation)
     * @param batchSize        number of examples in the batch
     * @param numberOfFeatures number of features per example
     * @param numberOfClasses  number of distinct classes
     * @return the estimated number of bytes
     */
    public static long sizeOfBatchInBytes(boolean isReduced, int batchSize, int numberOfFeatures, int numberOfClasses) {
        int normalizedNumberOfClasses = isReduced ? numberOfClasses - 1 : numberOfClasses;

        long weightGradient = Weights.sizeInBytes(normalizedNumberOfClasses, numberOfFeatures);
        long targets = Matrix.sizeInBytes(batchSize, 1);
        long weightedFeatures = MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(
            batchSize,
            normalizedNumberOfClasses
        );
        long softmax = Softmax.sizeInBytes(batchSize, numberOfClasses);
        long unpenalizedLoss = ReducedCrossEntropyLoss.sizeInBytes();
        // One term each for the L2NormSquared, ConstantScale and ElementSum nodes built by loss();
        // all three are estimated with the same L2NormSquared size.
        long l2Norm = L2NormSquared.sizeInBytesOfApply();
        long constantScale = l2Norm;
        long elementSum = l2Norm;
        long predictions = LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(
            batchSize,
            numberOfFeatures,
            numberOfClasses,
            normalizedNumberOfClasses
        );

        long sharedTerms = targets + weightedFeatures + softmax + predictions;
        long lossTerms = unpenalizedLoss + l2Norm + constantScale + elementSum;

        // NOTE(review): the doubled loss terms plus weightGradient presumably account for gradients
        // computed during training; confirm against the computation-graph implementation.
        long withGradientTerms = sharedTerms + 2 * lossTerms + weightGradient;
        long withoutGradientTerms = sharedTerms + lossTerms;

        // Assuming all size estimates are non-negative, the first value is never smaller; max is kept as a safeguard.
        return Math.max(withGradientTerms, withoutGradientTerms);
    }

    /**
     * @param classifier the classifier whose model data is optimised
     * @param penalty    coefficient applied to the squared L2 norm of the weights
     * @param features   feature source used to build batch feature matrices; must not be empty
     * @param labels     label per node id, mapped to class indices via the classifier's class id map
     */
    public LogisticRegressionObjective(
        LogisticRegressionClassifier classifier,
        double penalty,
        Features features,
        HugeLongArray labels
    ) {
        this.classifier = classifier;
        this.penalty = penalty;
        this.features = features;
        this.labels = labels;
        assert features.size() > 0 : "features must not be empty";
    }

    /**
     * Returns the trainable parameters: the weight matrix followed by the bias.
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        LogisticRegressionData data = modelData();
        return List.of(data.weights(), data.bias());
    }

    /**
     * Builds the batch loss: the cross-entropy loss plus the batch share of the L2 penalty.
     *
     * @param batch     the examples to evaluate
     * @param trainSize number of training examples the per-batch penalty is normalised against
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        return new ElementSum(List.of(crossEntropyLoss(batch), penaltyForBatch(batch, trainSize)));
    }

    /**
     * Returns the squared L2 norm of the weights scaled by {@code batch.size() * penalty / trainSize}.
     * If the batches partition the {@code trainSize} training examples, the per-batch penalties
     * sum to {@code penalty * ||W||^2} (assumption; depends on how callers form batches).
     *
     * @param batch     the current batch
     * @param trainSize number of training examples used for normalisation
     */
    public ConstantScale<Scalar> penaltyForBatch(Batch batch, long trainSize) {
        double scale = (batch.size() * this.penalty) / trainSize;
        return new ConstantScale<>(new L2NormSquared(modelData().weights()), scale);
    }

    /**
     * Builds the (reduced) cross-entropy loss of the classifier's predictions for {@code batch}
     * against the batch's mapped labels.
     */
    public ReducedCrossEntropyLoss crossEntropyLoss(Batch batch) {
        LogisticRegressionData data = modelData();
        Constant<Vector> batchLabels = batchLabelVector(batch, this.classifier.classIdMap());
        Constant<Matrix> batchFeatures = LogisticRegressionClassifier.batchFeatureMatrix(batch, this.features);
        return new ReducedCrossEntropyLoss(
            this.classifier.predictionsVariable(batchFeatures),
            data.weights(),
            data.bias(),
            batchFeatures,
            batchLabels
        );
    }

    /**
     * Returns the classifier's model data (weights and bias).
     */
    @Override
    public LogisticRegressionData modelData() {
        return this.classifier.data();
    }

    /**
     * Builds a constant vector whose i-th entry is the mapped class id (via {@code localIdMap})
     * of the label of the i-th node in {@code batch}.
     */
    Constant<Vector> batchLabelVector(Batch batch, LocalIdMap localIdMap) {
        Vector batchLabels = new Vector(batch.size());
        MutableInt position = new MutableInt();
        batch.nodeIds().forEach(nodeId -> {
            long label = this.labels.get(nodeId.longValue());
            batchLabels.setDataAt(position.getAndIncrement(), localIdMap.toMapped(label));
        });
        return new Constant<>(batchLabels);
    }
}
