package org.neo4j.gds.ml.linkmodels.logisticregression;

import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.LogisticLoss;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixConstant;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.ml.Batch;
import org.neo4j.gds.ml.Objective;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.utils.StringFormatting;

/**
 * Training objective for link-level logistic regression.
 *
 * <p>The model data created by the constructor holds a single zero-initialised weight matrix of
 * shape {@code 1 x (nodePropertyKeys.size() + 1)}. For each batch, {@link #loss(Batch, long)}
 * combines the following into a {@link LogisticLoss}:
 * <ul>
 *   <li>the batch's feature matrix, produced by the base class;</li>
 *   <li>the predictions for those features;</li>
 *   <li>a column vector of 0/1 targets, read from the relationship property of every relationship
 *       of every node in the batch.</li>
 * </ul>
 */
public class LinkLogisticRegressionObjective extends LinkLogisticRegressionBase implements Objective<LinkLogisticRegressionData> {

    /**
     * Value passed to {@code forEachRelationship} alongside each node id when reading targets.
     * NOTE(review): presumably the fallback used for relationships without a property value —
     * confirm against the Graph API. Because it is neither 0 nor 1, any relationship reported with
     * this value is rejected by {@link #targetValue(double)}.
     */
    private static final double RELATIONSHIP_PROPERTY_FALLBACK = -0.66d;

    private final Graph graph;

    /**
     * Creates the objective with freshly initialised (all-zero) weights.
     *
     * @param nodePropertyKeys    node property keys stored in the model data; their count plus one
     *                            fixes the number of features (and weights)
     * @param linkFeatureCombiner combiner stored in the model data
     * @param graph               graph whose relationships supply the training targets
     */
    public LinkLogisticRegressionObjective(List<String> nodePropertyKeys, LinkFeatureCombiner linkFeatureCombiner, Graph graph) {
        super(makeData(nodePropertyKeys, linkFeatureCombiner));
        this.graph = graph;
    }

    /** Builds the initial model data: zero weights plus the given keys and combiner. */
    private static LinkLogisticRegressionData makeData(List<String> nodePropertyKeys, LinkFeatureCombiner linkFeatureCombiner) {
        return LinkLogisticRegressionData.builder()
            .weights(initWeights(nodePropertyKeys))
            .linkFeatureCombiner(linkFeatureCombiner)
            .nodePropertyKeys(nodePropertyKeys)
            .numberOfFeatures(computeNumberOfFeatures(nodePropertyKeys))
            .build();
    }

    /**
     * Returns the number of features: one per node property key plus one.
     * NOTE(review): the extra feature is presumably a bias term — confirm.
     */
    private static int computeNumberOfFeatures(List<String> nodePropertyKeys) {
        return nodePropertyKeys.size() + 1;
    }

    /** Creates a {@code 1 x numberOfFeatures} weight matrix filled with zeros. */
    private static Weights<Matrix> initWeights(List<String> nodePropertyKeys) {
        int numberOfFeatures = computeNumberOfFeatures(nodePropertyKeys);
        return new Weights<>(new Matrix(new double[numberOfFeatures], 1, numberOfFeatures));
    }

    /** Returns the trainable parameters: the single weight matrix of the model data. */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.modelData.weights());
    }

    /**
     * Builds the logistic loss for one batch.
     *
     * <p>The target vector has one row per relationship visited. Its row count is the summed degree
     * of the batch's nodes.
     *
     * @param batch   nodes whose relationships form the training examples
     * @param ignored not read by this objective
     * @return the logistic loss over the batch
     * @throws IllegalArgumentException if a visited relationship property is neither 0 nor 1
     */
    @Override
    public Variable<Scalar> loss(Batch batch, long ignored) {
        MatrixConstant features = features(this.graph, batch);
        Variable<Matrix> predictions = predictions(features);
        int rows = totalDegree(batch);
        MatrixConstant targets = new MatrixConstant(makeTargetsArray(batch, rows), rows, 1);
        return new LogisticLoss(this.modelData.weights(), predictions, features, targets);
    }

    /** Returns the model data holding the weights being trained. */
    @Override
    public LinkLogisticRegressionData modelData() {
        return this.modelData;
    }

    /** Sums {@code graph.degree} over the batch's nodes; used as the number of target rows. */
    private int totalDegree(Batch batch) {
        MutableInt total = new MutableInt();
        batch.nodeIds().forEach(nodeId -> total.add(this.graph.degree(nodeId.longValue())));
        return total.intValue();
    }

    /**
     * Reads one 0/1 target per relationship, in visiting order, for every node in the batch.
     *
     * @param batch nodes whose relationships are visited
     * @param rows  size of the returned array; expected to equal the number of relationships visited
     * @return the target values
     */
    private double[] makeTargetsArray(Batch batch, int rows) {
        // Traverse through a concurrent copy rather than the shared graph.
        // NOTE(review): presumably so that loss() can be evaluated from several threads — confirm.
        Graph localGraph = this.graph.concurrentCopy();
        double[] targets = new double[rows];
        MutableInt row = new MutableInt();
        batch.nodeIds().forEach(nodeId ->
            localGraph.forEachRelationship(nodeId.longValue(), RELATIONSHIP_PROPERTY_FALLBACK, (sourceId, targetId, property) -> {
                targets[row.intValue()] = targetValue(property);
                row.increment();
                return true;
            })
        );
        return targets;
    }

    /**
     * Validates a relationship property value as a binary target.
     *
     * <p>Exact {@link Double#compare} semantics are used, so {@code -0.0} is rejected.
     *
     * @param property the relationship property value
     * @return {@code 1.0} or {@code 0.0}
     * @throws IllegalArgumentException if the value is neither exactly 1.0 nor exactly 0.0
     */
    private static double targetValue(double property) {
        if (Double.compare(property, 1.0d) == 0) {
            return 1.0d;
        }
        if (Double.compare(property, 0.0d) == 0) {
            return 0.0d;
        }
        // %s, not %d: the arguments are doubles, and %d would make the formatter itself throw
        // IllegalFormatConversionException instead of reporting the bad value.
        throw new IllegalArgumentException(StringFormatting.formatWithLocale(
            "The relationship property must have value %s or %s but it has %s",
            0.0d, 1.0d, property
        ));
    }
}
