package org.neo4j.gds.ml.nodemodels.logisticregression;

import java.util.Iterator;
import java.util.List;
import java.util.function.LongToDoubleFunction;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.LogisticLoss;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixConstant;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.ml.Batch;
import org.neo4j.gds.ml.Objective;
import org.neo4j.gds.ml.nodemodels.NodeFeaturesSupport;
import org.neo4j.graphalgo.api.Graph;

/**
 * Training objective for node-level logistic regression.
 *
 * <p>The trainable model data (weights plus the node property keys used as features) is held by a
 * {@link NodeLogisticRegressionPredictor}. This class adds the loss computation, reading per-node
 * feature values and target values from {@code graph}.
 */
public class NodeLogisticRegressionObjective implements Objective<NodeLogisticRegressionData> {
    private final String targetPropertyKey;
    private final Graph graph;
    private final NodeLogisticRegressionPredictor predictor;

    /**
     * Creates an objective whose weights are initialised to zero.
     *
     * @param nodePropertyKeys  node property keys used as input features
     * @param targetPropertyKey node property holding each node's target value
     * @param graph             graph from which features and targets are read
     */
    public NodeLogisticRegressionObjective(
        List<String> nodePropertyKeys,
        String targetPropertyKey,
        Graph graph
    ) {
        this.predictor = new NodeLogisticRegressionPredictor(makeData(nodePropertyKeys));
        this.targetPropertyKey = targetPropertyKey;
        this.graph = graph;
    }

    /**
     * Returns the trainable weights of this objective.
     *
     * @return a single-element list containing the model's weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(modelData().weights());
    }

    /**
     * Builds the logistic loss over the nodes of {@code batch}.
     *
     * @param batch   nodes to compute the loss for
     * @param ignored not used by this objective
     * @return a variable that evaluates to the scalar loss
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long ignored) {
        MatrixConstant targets = makeTargets(batch);
        MatrixConstant features = NodeFeaturesSupport.features(graph, batch, modelData().nodePropertyKeys());
        return new LogisticLoss(
            modelData().weights(),
            predictor.predictionsVariable(graph, batch),
            features,
            targets
        );
    }

    /**
     * Returns the model data held by the underlying predictor.
     *
     * @return the current model data
     */
    @Override
    public NodeLogisticRegressionData modelData() {
        return predictor.modelData();
    }

    /** Builds model data with zero-initialised weights for the given feature keys. */
    private static NodeLogisticRegressionData makeData(List<String> nodePropertyKeys) {
        return NodeLogisticRegressionData.builder()
            .weights(initWeights(nodePropertyKeys))
            .nodePropertyKeys(nodePropertyKeys)
            .build();
    }

    /**
     * Creates a zero-filled 1 x (features + 1) weight matrix.
     *
     * <p>NOTE(review): the extra column is presumably a bias term; confirm against
     * {@code LogisticLoss} and the predictor.
     */
    private static Weights<Matrix> initWeights(List<String> nodePropertyKeys) {
        int columns = nodePropertyKeys.size() + 1;
        return new Weights<>(new Matrix(new double[columns], 1, columns));
    }

    /**
     * Reads the target property of every node in {@code batch} into a {@code size x 1} column matrix.
     * Values are stored in the order given by {@code batch.nodeIds()}.
     */
    private MatrixConstant makeTargets(Batch batch) {
        int size = batch.size();
        double[] targets = new double[size];
        // Bound method reference: the property lookup is evaluated once per batch,
        // not once per node as it was when called inside the loop.
        LongToDoubleFunction targetValue = graph.nodeProperties(targetPropertyKey)::doubleValue;
        int row = 0;
        for (long nodeId : batch.nodeIds()) {
            targets[row++] = targetValue.applyAsDouble(nodeId);
        }
        return new MatrixConstant(targets, size, 1);
    }
}
