package org.neo4j.gds.embeddings.graphsage;

import java.util.stream.IntStream;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.Dimensions;
import org.neo4j.gds.ml.core.RelationshipWeights;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.functions.SingleParentVariable;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Scalar;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Unsupervised GraphSAGE loss over a batch of stacked node embeddings.
 *
 * <p>The rows of the parent embedding matrix are split into {@value #SAMPLING_BUCKETS} equally
 * sized buckets of {@code b = rows / 3} rows each: rows {@code [0, b)} are the batch nodes, rows
 * {@code [b, 2b)} their positive samples and rows {@code [2b, 3b)} their negative samples
 * (offset {@code NEGATIVE_NODES_OFFSET * b}). Leftover rows, when the row count is not a multiple
 * of 3, do not contribute to the loss. For a batch node {@code u} with positive {@code v} and
 * negative {@code n} the loss term is
 *
 * <pre>
 *   -w(u, v) * log(sigmoid(z_u . z_v)) - Q * log(sigmoid(-(z_u . z_n)))
 * </pre>
 *
 * <p>where {@code w} is the relationship weight between the node ids at those rows (1.0 when the
 * weight is NaN) raised to {@link #ALPHA}, and {@code Q} is the negative sampling factor. The
 * returned scalar is the sum of these terms over all batch nodes.
 */
public class GraphSageLoss extends SingleParentVariable<Scalar> {
    private static final int NEGATIVE_NODES_OFFSET = 2;
    private static final int SAMPLING_BUCKETS = 3;
    private final RelationshipWeights relationshipWeights;
    private final Variable<Matrix> combinedEmbeddings;
    // Node ids indexed by the same row indices as the embedding matrix.
    // NOTE(review): assumes batch.length >= 2 * (rows / 3) so positive rows can be looked up — confirm with caller.
    private final long[] batch;
    private final int negativeSamplingFactor;
    // Exponent applied to relationship weights; 1.0 means weights are used as-is.
    private static final double ALPHA = 1.0d;

    /**
     * Creates the loss node.
     *
     * <p>The decompiler reports that this constructor was originally package-private; it is kept
     * public so existing callers continue to compile.
     *
     * @param relationshipWeights    source of weights for (batch node, positive sample) pairs
     * @param combinedEmbeddings     the single parent variable holding the stacked embeddings
     * @param batch                  node ids, indexed like the rows of the embedding matrix
     * @param negativeSamplingFactor multiplier applied to the negative-sample term
     */
    public GraphSageLoss(
        RelationshipWeights relationshipWeights,
        Variable<Matrix> combinedEmbeddings,
        long[] batch,
        int negativeSamplingFactor
    ) {
        super(combinedEmbeddings, Dimensions.scalar());
        this.relationshipWeights = relationshipWeights;
        this.combinedEmbeddings = combinedEmbeddings;
        this.batch = batch;
        this.negativeSamplingFactor = negativeSamplingFactor;
    }

    /**
     * Computes the loss summed over all batch nodes.
     *
     * @param computationContext context providing the parent's embedding matrix
     * @return the summed loss as a scalar
     */
    public Scalar apply(ComputationContext computationContext) {
        Matrix embeddings = computationContext.data(this.combinedEmbeddings);
        int batchSize = embeddings.rows() / SAMPLING_BUCKETS;
        int negativeOffset = NEGATIVE_NODES_OFFSET * batchSize;
        double loss = IntStream.range(0, batchSize)
            .mapToDouble(nodeIdx -> lossForNode(
                embeddings,
                nodeIdx,
                nodeIdx + batchSize,
                nodeIdx + negativeOffset
            ))
            .sum();
        return new Scalar(loss);
    }

    /**
     * Decompiler-generated alias of {@link #apply(ComputationContext)}, kept for source
     * compatibility.
     *
     * @param computationContext context providing the parent's embedding matrix
     * @return the summed loss as a scalar
     * @deprecated call {@link #apply(ComputationContext)} instead
     */
    @Deprecated
    public Scalar m18apply(ComputationContext computationContext) {
        return apply(computationContext);
    }

    /** Loss contribution of one batch node, given the rows of its positive and negative samples. */
    private double lossForNode(Matrix embeddings, int nodeIdx, int positiveIdx, int negativeIdx) {
        double weightFactor = relationshipWeightFactor(this.batch[nodeIdx], this.batch[positiveIdx]);
        double positiveTerm = -weightFactor * Math.log(Sigmoid.sigmoid(affinity(embeddings, nodeIdx, positiveIdx)));
        double negativeTerm = this.negativeSamplingFactor
            * Math.log(Sigmoid.sigmoid(-affinity(embeddings, nodeIdx, negativeIdx)));
        return positiveTerm - negativeTerm;
    }

    /**
     * Returns the weight between two node ids raised to {@link #ALPHA}; a NaN weight is treated as
     * 1.0.
     */
    private double relationshipWeightFactor(long sourceId, long targetId) {
        double weight = this.relationshipWeights.weight(sourceId, targetId);
        if (Double.isNaN(weight)) {
            weight = 1.0d;
        }
        return Math.pow(weight, ALPHA);
    }

    /**
     * Dot product of rows {@code row1} and {@code row2} over the parent's second dimension.
     */
    private double affinity(Matrix matrix, int row1, int row2) {
        int dimension = this.combinedEmbeddings.dimension(1);
        double dot = 0.0d;
        for (int col = 0; col < dimension; col++) {
            dot += matrix.dataAt(row1, col) * matrix.dataAt(row2, col);
        }
        return dot;
    }

    /**
     * Computes the gradient of the loss with respect to the stacked embedding matrix.
     *
     * <p>For each batch node {@code u} with positive {@code v} and negative {@code n}, the rows
     * written are {@code u}, {@code v} and {@code n}. Each of these rows belongs to exactly one
     * node per bucket, so every entry is written once (rows beyond {@code 3 * (rows / 3)} are left
     * as created by {@code createWithSameDimensions}).
     *
     * @param variable           must be the single parent embedding variable
     * @param computationContext context providing the parent's embedding matrix
     * @return a matrix with the same dimensions as the embeddings holding the gradient
     * @throws IllegalStateException if {@code variable} is not this node's parent
     */
    public Matrix gradient(Variable<?> variable, ComputationContext computationContext) {
        if (variable != this.combinedEmbeddings) {
            throw new IllegalStateException(StringFormatting.formatWithLocale("This variable only has a single parent. Expected %s but got %s", new Object[]{this.combinedEmbeddings, variable}));
        }
        Matrix embeddings = computationContext.data(this.combinedEmbeddings);
        Matrix gradient = embeddings.createWithSameDimensions();
        int batchSize = embeddings.rows() / SAMPLING_BUCKETS;
        int negativeOffset = NEGATIVE_NODES_OFFSET * batchSize;
        int cols = embeddings.cols();
        for (int nodeIdx = 0; nodeIdx < batchSize; nodeIdx++) {
            int positiveIdx = nodeIdx + batchSize;
            int negativeIdx = nodeIdx + negativeOffset;
            // logisticFunction(a) = sigmoid(-a): derivative factor of -log(sigmoid(a)) w.r.t. a (up to sign).
            double positiveLogistic = logisticFunction(affinity(embeddings, nodeIdx, positiveIdx));
            // logisticFunction(-a) = sigmoid(a): derivative factor of -log(sigmoid(-a)) w.r.t. a.
            double negativeLogistic = logisticFunction(-affinity(embeddings, nodeIdx, negativeIdx));
            // Hoisted out of the column loop: the weight lookup does not depend on the column.
            double positiveScale =
                relationshipWeightFactor(this.batch[nodeIdx], this.batch[positiveIdx]) * positiveLogistic;
            for (int col = 0; col < cols; col++) {
                computeGradientForEmbeddingIdx(
                    embeddings,
                    gradient,
                    nodeIdx,
                    positiveIdx,
                    negativeIdx,
                    positiveScale,
                    negativeLogistic,
                    col
                );
            }
        }
        return gradient;
    }

    /**
     * Writes the gradient entries in column {@code col} for one (node, positive, negative) triple.
     *
     * @param positiveScale    relationship weight factor times {@code sigmoid(-(z_u . z_v))}
     * @param negativeLogistic {@code sigmoid(z_u . z_n)}
     */
    private void computeGradientForEmbeddingIdx(
        Matrix embeddings,
        Matrix gradient,
        int nodeIdx,
        int positiveIdx,
        int negativeIdx,
        double positiveScale,
        double negativeLogistic,
        int col
    ) {
        double nodeValue = embeddings.dataAt(nodeIdx, col);
        gradient.setDataAt(
            nodeIdx,
            col,
            ((-embeddings.dataAt(positiveIdx, col)) * positiveScale)
                + (this.negativeSamplingFactor * embeddings.dataAt(negativeIdx, col) * negativeLogistic)
        );
        gradient.setDataAt(positiveIdx, col, (-nodeValue) * positiveScale);
        gradient.setDataAt(negativeIdx, col, this.negativeSamplingFactor * nodeValue * negativeLogistic);
    }

    /**
     * Returns {@code 1 / (1 + e^x)}, i.e. the logistic sigmoid evaluated at {@code -x}.
     *
     * <p>Uses the literal 1.0 rather than {@link #ALPHA}: the two only coincided by accident.
     */
    private double logisticFunction(double x) {
        return 1.0d / (1.0d + Math.exp(x));
    }

    /**
     * Raw-typed alias of {@link #gradient(Variable, ComputationContext)} produced by the
     * decompiler from a synthetic bridge method; kept for source compatibility.
     *
     * @deprecated call {@link #gradient(Variable, ComputationContext)} instead
     */
    @Deprecated
    @SuppressWarnings("rawtypes")
    public Tensor m17gradient(Variable variable, ComputationContext computationContext) {
        return gradient((Variable<?>) variable, computationContext);
    }
}
