package org.neo4j.gds.ml.nodemodels.logisticregression;

import java.util.Iterator;
import java.util.List;
import org.neo4j.gds.ml.Objective;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.functions.Constant;
import org.neo4j.gds.ml.core.functions.ConstantScale;
import org.neo4j.gds.ml.core.functions.CrossEntropyLoss;
import org.neo4j.gds.ml.core.functions.ElementSum;
import org.neo4j.gds.ml.core.functions.L2NormSquared;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/logisticregression/NodeLogisticRegressionObjective.class */
/**
 * Training objective for node logistic regression.
 *
 * <p>The loss of a batch is the cross-entropy between the predictor's predictions and the
 * class targets read from {@code targetProperty}, plus an L2 penalty on the model weights.
 * The penalty weight is {@code penalty * batch.size() / trainSize}.
 */
public class NodeLogisticRegressionObjective implements Objective<NodeLogisticRegressionData> {
    private final String targetProperty;
    private final Graph graph;
    private final double penalty;
    private final NodeLogisticRegressionPredictor predictor;

    /**
     * Estimates the memory, in bytes, of the variables created when computing the loss of one batch.
     *
     * <p>Two scenarios are estimated and the larger one is returned:
     * <ul>
     *   <li>one where the loss terms are counted twice and an extra weights-sized buffer is added
     *       (presumably a training pass holding gradients — NOTE(review): inferred from the term
     *       multiplicities);</li>
     *   <li>one where the loss terms are counted once (presumably plain loss evaluation).</li>
     * </ul>
     *
     * @param batchSize        number of rows of the batch feature matrix and length of the targets
     * @param numberOfFeatures number of columns of the feature and weight matrices
     * @param numberOfClasses  number of rows of the weight matrix
     * @return estimated size in bytes
     */
    public static long sizeOfBatchInBytes(int batchSize, int numberOfFeatures, int numberOfClasses) {
        long weightsSized = Weights.sizeInBytes(numberOfClasses, numberOfFeatures);
        long targets = costOfMakeTargets(batchSize);
        long matrixProduct = MatrixMultiplyWithTransposedSecondOperand.sizeInBytes(
            Dimensions.matrix(batchSize, numberOfFeatures),
            Dimensions.matrix(numberOfClasses, numberOfFeatures)
        );
        long crossEntropy = CrossEntropyLoss.sizeInBytes();
        long l2NormApply = L2NormSquared.sizeInBytesOfApply();
        long predictions = NodeLogisticRegressionPredictor.sizeOfPredictionsVariableInBytes(
            batchSize,
            numberOfFeatures,
            numberOfClasses
        );

        // Terms counted once in both scenarios: targets, two product-sized buffers, predictions.
        long sharedTerms = targets + 2 * matrixProduct + predictions;
        // Loss-related terms: one cross-entropy plus three terms sized like an L2-norm application.
        long lossTerms = crossEntropy + 3 * l2NormApply;

        long withGradients = sharedTerms + 2 * lossTerms + weightsSized;
        long lossOnly = sharedTerms + lossTerms;
        return Math.max(withGradients, lossOnly);
    }

    /**
     * Estimates the memory, in bytes, of the targets built for a batch of the given size.
     *
     * @param batchSize number of nodes in the batch
     * @return size in bytes of a {@code batchSize x 1} matrix
     */
    static long costOfMakeTargets(int batchSize) {
        return Matrix.sizeInBytes(batchSize, 1);
    }

    /**
     * @param graph          graph providing the target node property
     * @param predictor      predictor whose model data (weights, class id map) is trained
     * @param targetProperty name of the node property holding each node's class
     * @param penalty        L2 regularization strength
     */
    public NodeLogisticRegressionObjective(
        Graph graph,
        NodeLogisticRegressionPredictor predictor,
        String targetProperty,
        double penalty
    ) {
        this.predictor = predictor;
        this.targetProperty = targetProperty;
        this.graph = graph;
        this.penalty = penalty;
    }

    /**
     * Returns the trainable weights, which are the single weights variable of the model data.
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(modelData().weights());
    }

    /**
     * Builds the loss variable for one batch: cross-entropy plus the scaled L2 penalty.
     *
     * @param batch     nodes whose predictions and targets enter the loss
     * @param trainSize divisor of the penalty weight {@code penalty * batch.size()}; presumably the
     *                  total number of training nodes, so that the full penalty is applied once per
     *                  epoch — TODO confirm against callers
     * @return the scalar loss variable
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        // Predictions are built before targets, matching the original argument evaluation order.
        Variable<Scalar> unpenalizedLoss = new CrossEntropyLoss(
            this.predictor.predictionsVariable(this.graph, batch),
            makeTargets(batch)
        );
        Variable<Scalar> penaltyTerm = new ConstantScale(
            new L2NormSquared(modelData().weights()),
            batch.size() * this.penalty / trainSize
        );
        return new ElementSum(List.of(unpenalizedLoss, penaltyTerm));
    }

    /**
     * Builds the target vector for a batch, one entry per node in iteration order.
     *
     * <p>Each node's target property is read as a double, truncated to {@code long} and mapped
     * through the model's class id map. NOTE(review): this assumes the property holds integral
     * class values; a fractional or NaN value is truncated by the cast rather than rejected.
     */
    private Constant<Vector> makeTargets(Batch batch) {
        double[] targets = new double[batch.size()];
        LocalIdMap classIdMap = modelData().classIdMap();
        NodeProperties targetValues = this.graph.nodeProperties(this.targetProperty);
        int index = 0;
        for (long nodeId : batch.nodeIds()) {
            targets[index++] = classIdMap.toMapped((long) targetValues.doubleValue(nodeId));
        }
        return Constant.vector(targets);
    }

    /**
     * Returns the model data of the underlying predictor.
     */
    @Override
    public NodeLogisticRegressionData modelData() {
        return this.predictor.modelData();
    }
}
