package org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression;

import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.ml.Objective;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.LogisticLoss;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/pipeline/logisticRegression/LinkLogisticRegressionObjective.class */
/**
 * Training objective for link-prediction logistic regression.
 *
 * <p>For a batch, the loss is the sum of:
 * <ul>
 *   <li>the {@link LogisticLoss} of {@code sigmoid(features * weights^T)} against the batch targets, and</li>
 *   <li>the squared L2 norm of the weights, scaled by {@code batch.size() * penalty / trainSize}.</li>
 * </ul>
 * The penalty scaling presumably makes the penalty sum to {@code penalty * ||w||^2} over one full pass
 * of the training set. This assumes the batches partition the training set; TODO confirm against the caller.
 */
public class LinkLogisticRegressionObjective implements Objective<LinkLogisticRegressionData> {
    private final LinkLogisticRegressionData modelData;
    private final double penalty;
    private final HugeObjectArray<double[]> linkFeatures;
    private final HugeDoubleArray targets;

    /**
     * @param modelData    model data whose weights are optimized
     * @param penalty      L2 regularization strength
     * @param linkFeatures feature vector per link, looked up by the ids yielded by {@link Batch#nodeIds()}
     * @param targets      target value per link, looked up by the same ids as {@code linkFeatures}
     */
    public LinkLogisticRegressionObjective(
        LinkLogisticRegressionData modelData,
        double penalty,
        HugeObjectArray<double[]> linkFeatures,
        HugeDoubleArray targets
    ) {
        this.modelData = modelData;
        this.penalty = penalty;
        this.linkFeatures = linkFeatures;
        this.targets = targets;
    }

    /** Returns the single trainable weight tensor of the model. */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.modelData.weights());
    }

    /**
     * Builds the loss computation graph for one batch.
     *
     * @param batch     ids of the links in this batch
     * @param trainSize divisor of the penalty scale; presumably the total number of training examples
     * @return logistic loss on the batch plus the scaled L2 penalty on the weights
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        Constant<Matrix> batchFeatures = features(batch, this.linkFeatures);
        Weights<Matrix> weights = this.modelData.weights();
        Variable<Matrix> predictions =
            new Sigmoid<>(MatrixMultiplyWithTransposedSecondOperand.of(batchFeatures, weights));
        Constant<Vector> batchTargets = makeTargetsArray(batch);
        // Double arithmetic: weight the penalty by this batch's share of trainSize (see class Javadoc).
        double penaltyScale = (batch.size() * this.penalty) / trainSize;
        return new ElementSum(List.of(
            new LogisticLoss(weights, predictions, batchFeatures, batchTargets),
            new ConstantScale<>(new L2NormSquared(weights), penaltyScale)
        ));
    }

    /** Returns the model data this objective was constructed with. */
    @Override
    public LinkLogisticRegressionData modelData() {
        return this.modelData;
    }

    /**
     * Gathers the target of every id in {@code batch} into a vector, in batch iteration order.
     *
     * @param batch ids to look up in {@code targets}
     * @return constant vector of length {@code batch.size()}
     */
    Constant<Vector> makeTargetsArray(Batch batch) {
        Vector batchTargets = new Vector(batch.size());
        int row = 0;
        for (long id : batch.nodeIds()) {
            batchTargets.setDataAt(row++, this.targets.get(id));
        }
        return new Constant<>(batchTargets);
    }

    /**
     * Stacks the feature vectors of every id in {@code batch} into a matrix, one row per id in batch
     * iteration order. The column count is taken from the first entry of {@code featureArray}.
     *
     * @param batch        ids whose feature vectors are copied
     * @param featureArray feature vector per id; must be non-empty
     * @return constant matrix of shape {@code batch.size() x featureArray.get(0).length}
     */
    Constant<Matrix> features(Batch batch, HugeObjectArray<double[]> featureArray) {
        assert featureArray.size() > 0 : "feature array must not be empty";
        int featureCount = featureArray.get(0L).length;
        Matrix batchMatrix = new Matrix(batch.size(), featureCount);
        int row = 0;
        for (long id : batch.nodeIds()) {
            batchMatrix.setRow(row++, featureArray.get(id));
        }
        return new Constant<>(batchMatrix);
    }
}
