package org.neo4j.gds.ml.nodemodels.multiclasslogisticregression;

import java.util.Iterator;
import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.ConstantScale;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.ElementSum;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.L2NormSquared;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixConstant;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MultiClassCrossEntropyLoss;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.embeddings.graphsage.subgraph.LocalIdMap;
import org.neo4j.gds.ml.Batch;
import org.neo4j.gds.ml.Objective;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/multiclasslogisticregression/MultiClassNLRObjective.class */
/**
 * Training objective for the multi-class model in this package.
 *
 * <p>For a batch of nodes, the loss is the {@link MultiClassCrossEntropyLoss} of the predictor's
 * predictions against the nodes' target classes. An L2 penalty on the model weights is added,
 * scaled by {@code batch.size() * penalty / trainSize}.
 */
public class MultiClassNLRObjective implements Objective<MultiClassNLRData> {
    private final String targetPropertyKey;
    private final Graph graph;
    private final double penalty;
    private final MultiClassNLRPredictor predictor;

    /**
     * Builds the objective and a predictor with zero-initialised weights.
     *
     * @param featureProperties property keys handed to the predictor; presumably the feature
     *                          properties (TODO confirm). The weight matrix gets
     *                          {@code featureProperties.size() + 1} columns.
     * @param targetPropertyKey node property holding each node's class; read as a double and
     *                          truncated to a long class id
     * @param graph             graph whose nodes supply the targets and predictions
     * @param penalty           strength of the L2 weight penalty added in {@link #loss}
     */
    public MultiClassNLRObjective(List<String> featureProperties, String targetPropertyKey, Graph graph, double penalty) {
        this.predictor = new MultiClassNLRPredictor(makeData(featureProperties, targetPropertyKey, graph), featureProperties);
        this.targetPropertyKey = targetPropertyKey;
        this.graph = graph;
        this.penalty = penalty;
    }

    /** Creates the initial model data: the class id mapping plus zero-filled weights sized from it. */
    private static MultiClassNLRData makeData(List<String> featureProperties, String targetPropertyKey, Graph graph) {
        LocalIdMap classIdMap = makeClassIdMap(targetPropertyKey, graph);
        int numberOfClasses = classIdMap.originalIds().length;
        return MultiClassNLRData.builder()
            .classIdMap(classIdMap)
            .weights(initWeights(featureProperties, numberOfClasses))
            .build();
    }

    /**
     * Registers the target class of every node in {@code graph} in a fresh {@link LocalIdMap}.
     * Each class then has a local (mapped) id.
     */
    private static LocalIdMap makeClassIdMap(String targetPropertyKey, Graph graph) {
        LocalIdMap classIdMap = new LocalIdMap();
        // Resolve the property once instead of once per node.
        NodeProperties targetProperty = graph.nodeProperties(targetPropertyKey);
        graph.forEachNode(nodeId -> {
            classIdMap.toMapped((long) targetProperty.doubleValue(nodeId));
            return true; // keep iterating over all nodes
        });
        return classIdMap;
    }

    /**
     * Creates a zero-filled weight matrix.
     *
     * <p>Assumes {@code Matrix.fill(value, rows, cols)}, which gives one row per class and
     * {@code featureProperties.size() + 1} columns. The extra column is presumably a bias term
     * (TODO confirm).
     */
    private static Weights<Matrix> initWeights(List<String> featureProperties, int numberOfClasses) {
        return new Weights<>(Matrix.fill(0.0d, numberOfClasses, featureProperties.size() + 1));
    }

    /** Returns the trainable weights: the single weight matrix of the current model data. */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(modelData().weights());
    }

    /**
     * Builds the loss expression for one batch.
     *
     * @param batch     nodes whose predictions are compared with their target classes
     * @param trainSize divisor of the penalty scale; presumably the total number of training
     *                  examples, so that the penalty shares add up over a full pass (TODO confirm)
     * @return cross-entropy of the batch plus the L2 penalty term
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long trainSize) {
        MultiClassCrossEntropyLoss crossEntropy =
            new MultiClassCrossEntropyLoss(this.predictor.predictionsVariable(this.graph, batch), makeTargets(batch));
        ConstantScale penaltyTerm =
            new ConstantScale(new L2NormSquared(modelData().weights()), (batch.size() * this.penalty) / trainSize);
        return new ElementSum(List.of(crossEntropy, penaltyTerm));
    }

    /**
     * Builds a {@code batch.size() x 1} constant holding the mapped class id of each batch node,
     * in the batch's iteration order.
     */
    private MatrixConstant makeTargets(Batch batch) {
        int batchSize = batch.size();
        double[] targets = new double[batchSize];
        LocalIdMap classIdMap = modelData().classIdMap();
        NodeProperties targetProperty = this.graph.nodeProperties(this.targetPropertyKey);
        int row = 0;
        for (long nodeId : batch.nodeIds()) {
            targets[row++] = classIdMap.toMapped((long) targetProperty.doubleValue(nodeId));
        }
        return new MatrixConstant(targets, batchSize, 1);
    }

    /** Returns the predictor's current model data (class id map and weights). */
    @Override
    public MultiClassNLRData modelData() {
        return this.predictor.modelData();
    }
}
