package org.neo4j.gds.embeddings.graphsage;

import java.util.stream.IntStream;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Sigmoid;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.SingleParentVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/**
 * Unsupervised GraphSAGE loss computed over a batch of combined embeddings.
 *
 * <p>The parent variable's rows are split into three equally sized sections of
 * {@code b = rows / 3} rows each:
 * <ul>
 *   <li>rows {@code [0, b)} hold the embeddings of the batch nodes,</li>
 *   <li>rows {@code [b, 2b)} hold the embeddings of their positive samples,</li>
 *   <li>rows {@code [2b, 3b)} hold the embeddings of their negative samples.</li>
 * </ul>
 *
 * <p>For each batch row {@code u} with positive row {@code p} and negative row {@code n}
 * the loss contribution is
 * {@code -w(u, p)^ALPHA * log(sigmoid(z_u . z_p)) - Q * log(sigmoid(-z_u . z_n))},
 * where {@code Q} is the negative sampling factor and a NaN weight is treated as {@code 1.0}.
 */
public class GraphSageLoss extends SingleParentVariable<Scalar> {
    private static final int ROWS_INDEX = 0;
    private static final int COLUMNS_INDEX = 1;
    /** Number of equally sized row sections in the combined embeddings matrix. */
    private static final int SECTIONS = 3;
    /** Offset, in batch-sized sections, from a batch row to its positive-sample row. */
    private static final int POSITIVE_NODES_OFFSET = 1;
    /** Offset, in batch-sized sections, from a batch row to its negative-sample row. */
    private static final int NEGATIVE_NODES_OFFSET = 2;
    /** Exponent applied to the relationship weight in the positive term of the loss. */
    private static final double ALPHA = 1.0d;

    private final RelationshipWeightsFunction relationshipWeightsFunction;
    private final Variable<Matrix> combinedEmbeddings;
    private final int negativeSamplingFactor;

    /**
     * Creates the loss variable.
     *
     * <p>NOTE(review): decompiler metadata indicated this constructor was package-private in the
     * original source; it is kept public here so existing callers keep compiling.
     *
     * @param relationshipWeightsFunction supplies the weight used for the positive term
     * @param combinedEmbeddings          the three-section embeddings matrix (see class doc)
     * @param negativeSamplingFactor      multiplier {@code Q} of the negative term
     */
    public GraphSageLoss(
        RelationshipWeightsFunction relationshipWeightsFunction,
        Variable<Matrix> combinedEmbeddings,
        int negativeSamplingFactor
    ) {
        super(combinedEmbeddings, Dimensions.scalar());
        this.relationshipWeightsFunction = relationshipWeightsFunction;
        this.combinedEmbeddings = combinedEmbeddings;
        this.negativeSamplingFactor = negativeSamplingFactor;
    }

    /**
     * Evaluates the loss, summed over all batch rows.
     *
     * @param computationContext context holding the parent's evaluated embeddings
     * @return the total loss as a scalar
     */
    @Override
    public Scalar apply(ComputationContext computationContext) {
        Tensor<?> embeddings = computationContext.data(parent());
        int batchSize = embeddings.dimension(ROWS_INDEX) / SECTIONS;
        double loss = IntStream.range(0, batchSize)
            .mapToDouble(nodeId -> nodeLoss(embeddings, nodeId, batchSize))
            .sum();
        return new Scalar(loss);
    }

    /** Loss contribution of a single batch row together with its positive and negative samples. */
    private double nodeLoss(Tensor<?> embeddings, int nodeId, int batchSize) {
        int positiveNodeId = nodeId + POSITIVE_NODES_OFFSET * batchSize;
        int negativeNodeId = nodeId + NEGATIVE_NODES_OFFSET * batchSize;
        double positiveAffinity = affinity(embeddings, nodeId, positiveNodeId);
        double negativeAffinity = affinity(embeddings, nodeId, negativeNodeId);
        return -relationshipWeightFactor(nodeId, positiveNodeId) * Math.log(Sigmoid.sigmoid(positiveAffinity))
            - negativeSamplingFactor * Math.log(Sigmoid.sigmoid(-negativeAffinity));
    }

    /**
     * Returns {@code weight^ALPHA} for the given pair of batch rows, treating NaN as {@code 1.0}.
     *
     * <p>NOTE(review): the arguments are row indices in the combined batch, not graph node ids;
     * confirm that {@link RelationshipWeightsFunction} expects batch-local indices.
     */
    private double relationshipWeightFactor(int nodeId, int positiveNodeId) {
        double weight = relationshipWeightsFunction.weight(nodeId, positiveNodeId);
        // A missing (NaN) weight is treated as an unweighted relationship.
        return Math.pow(Double.isNaN(weight) ? 1.0d : weight, ALPHA);
    }

    /** Dot product of the embedding rows {@code rowA} and {@code rowB} of a row-major tensor. */
    private double affinity(Tensor<?> tensor, int rowA, int rowB) {
        int embeddingSize = combinedEmbeddings.dimension(COLUMNS_INDEX);
        double dot = 0.0d;
        for (int column = 0; column < embeddingSize; column++) {
            dot += tensor.dataAt(rowA * embeddingSize + column) * tensor.dataAt(rowB * embeddingSize + column);
        }
        return dot;
    }

    /**
     * Computes the gradient of the loss with respect to the combined embeddings.
     *
     * <p>Each batch row writes only its own node, positive and negative rows. These rows are
     * disjoint across batch rows, so plain assignment (rather than accumulation) is sufficient.
     * Rows beyond {@code 3 * (rows / 3)} are left at zero.
     *
     * @param variable           the parent variable being differentiated
     * @param computationContext context holding the parent's evaluated embeddings
     * @return a matrix with the same shape as the parent's data
     */
    @Override
    public Matrix gradient(Variable<?> variable, ComputationContext computationContext) {
        Tensor<?> embeddingData = computationContext.data(variable);
        double[] embeddings = embeddingData.data();
        int totalBatchSize = embeddingData.dimension(ROWS_INDEX);
        int batchSize = totalBatchSize / SECTIONS;
        int embeddingSize = embeddingData.dimension(COLUMNS_INDEX);
        int rowStride = variable.dimension(COLUMNS_INDEX);
        double[] gradient = new double[totalBatchSize * embeddingSize];

        for (int nodeId = 0; nodeId < batchSize; nodeId++) {
            int positiveNodeId = nodeId + POSITIVE_NODES_OFFSET * batchSize;
            int negativeNodeId = nodeId + NEGATIVE_NODES_OFFSET * batchSize;
            double positiveLogistic = logisticFunction(affinity(embeddingData, nodeId, positiveNodeId));
            double negativeLogistic = logisticFunction(-affinity(embeddingData, nodeId, negativeNodeId));
            // Loop-invariant across columns: look the weight up once per batch row.
            double positiveScale = relationshipWeightFactor(nodeId, positiveNodeId) * positiveLogistic;
            for (int column = 0; column < embeddingSize; column++) {
                partialComputeGradient(
                    embeddings,
                    gradient,
                    nodeId,
                    positiveNodeId,
                    negativeNodeId,
                    rowStride,
                    positiveScale,
                    negativeLogistic,
                    column
                );
            }
        }
        return new Matrix(gradient, totalBatchSize, embeddingSize);
    }

    /**
     * Writes one column of the gradient for a batch row and its positive and negative rows.
     *
     * @param positiveScale    {@code w(u, p)^ALPHA * logisticFunction(z_u . z_p)}
     * @param negativeLogistic {@code logisticFunction(-z_u . z_n)}
     */
    private void partialComputeGradient(
        double[] embeddings,
        double[] gradient,
        int nodeId,
        int positiveNodeId,
        int negativeNodeId,
        int rowStride,
        double positiveScale,
        double negativeLogistic,
        int column
    ) {
        int nodeIndex = nodeId * rowStride + column;
        int positiveIndex = positiveNodeId * rowStride + column;
        int negativeIndex = negativeNodeId * rowStride + column;
        gradient[nodeIndex] = -embeddings[positiveIndex] * positiveScale
            + negativeSamplingFactor * embeddings[negativeIndex] * negativeLogistic;
        gradient[positiveIndex] = -embeddings[nodeIndex] * positiveScale;
        gradient[negativeIndex] = negativeSamplingFactor * embeddings[nodeIndex] * negativeLogistic;
    }

    /**
     * Returns {@code 1 / (1 + e^x)}, which equals {@code sigmoid(-x) = 1 - sigmoid(x)}.
     * This is the magnitude of the derivative of {@code -log(sigmoid(x))} with respect to {@code x}.
     */
    private double logisticFunction(double x) {
        return 1.0d / (1.0d + Math.exp(x));
    }
}
