package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.FloatVector;
import org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations;

/* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/TrainingTask.class */
/**
 * Trains node2vec embeddings by stochastic gradient steps.
 *
 * <p>For every (center, context) pair delivered by the {@link PositiveSampleProducer}, one
 * positive update is applied, followed by {@code negativeSamplingRate} negative updates that pair
 * the same center node with nodes drawn from the {@link NegativeSampleProducer}. Embedding vectors
 * are modified in place right after each sample; this class itself performs no synchronization on
 * the shared embedding arrays. The accumulated loss is available through {@link #lossSum()}.
 */
final class TrainingTask implements Runnable {
    /** Added inside every logarithm so that a probability of exactly 0 never yields an infinite loss. */
    private static final double LOG_EPSILON = 1.0E-10d;

    private final HugeObjectArray<FloatVector> centerEmbeddings;
    private final HugeObjectArray<FloatVector> contextEmbeddings;
    private final PositiveSampleProducer positiveSampleProducer;
    private final NegativeSampleProducer negativeSampleProducer;
    private final FloatVector centerGradientBuffer;
    private final FloatVector contextGradientBuffer;
    private final int negativeSamplingRate;
    private final float learningRate;
    private double lossSum;

    /**
     * Creates a task over the given embedding arrays and sample sources.
     *
     * @param centerEmbeddings       per-node center vectors, updated in place
     * @param contextEmbeddings      per-node context vectors, updated in place
     * @param positiveSampleProducer source of (center, context) node-id pairs
     * @param negativeSampleProducer source of node ids used as negative context
     * @param learningRate           factor applied to every gradient step
     * @param negativeSamplingRate   number of negative samples drawn per positive pair
     * @param embeddingDimension     length of the two reusable gradient buffers
     *                               (presumably must equal the embedding vector length — confirm)
     */
    public TrainingTask(
        HugeObjectArray<FloatVector> centerEmbeddings,
        HugeObjectArray<FloatVector> contextEmbeddings,
        PositiveSampleProducer positiveSampleProducer,
        NegativeSampleProducer negativeSampleProducer,
        float learningRate,
        int negativeSamplingRate,
        int embeddingDimension
    ) {
        this.centerEmbeddings = centerEmbeddings;
        this.contextEmbeddings = contextEmbeddings;
        this.positiveSampleProducer = positiveSampleProducer;
        this.negativeSampleProducer = negativeSampleProducer;
        this.learningRate = learningRate;
        this.negativeSamplingRate = negativeSamplingRate;
        this.centerGradientBuffer = new FloatVector(embeddingDimension);
        this.contextGradientBuffer = new FloatVector(embeddingDimension);
    }

    /**
     * Drains the positive sample producer, applying one positive and
     * {@code negativeSamplingRate} negative updates per pair.
     */
    @Override
    public void run() {
        // Reused across iterations: slot 0 holds the center id, slot 1 the context id.
        long[] pair = new long[2];
        while (positiveSampleProducer.next(pair)) {
            long centerId = pair[0];
            trainPositiveSample(centerId, pair[1]);
            for (int drawn = 0; drawn < negativeSamplingRate; drawn++) {
                trainNegativeSample(centerId, negativeSampleProducer.next());
            }
        }
    }

    /** Applies one positive-sample update to the center vector of {@code j} and context vector of {@code j2}. */
    void trainPositiveSample(long j, long j2) {
        FloatVector center = centerEmbeddings.get(j);
        FloatVector context = contextEmbeddings.get(j2);
        float step = computePositiveGradient(center, context);
        updateEmbeddings(center, context, step, centerGradientBuffer, contextGradientBuffer);
    }

    /** Applies one negative-sample update to the center vector of {@code j} and context vector of {@code j2}. */
    void trainNegativeSample(long j, long j2) {
        FloatVector center = centerEmbeddings.get(j);
        FloatVector context = contextEmbeddings.get(j2);
        float step = computeNegativeGradient(center, context);
        updateEmbeddings(center, context, step, centerGradientBuffer, contextGradientBuffer);
    }

    /**
     * Adds {@code -log(sigmoid(center·context))} to the running loss and returns the scaled step
     * {@code (1 - sigmoid(center·context)) * learningRate}.
     */
    float computePositiveGradient(FloatVector floatVector, FloatVector floatVector2) {
        double probability = Sigmoid.sigmoid(floatVector.innerProduct(floatVector2));
        lossSum -= Math.log(probability + LOG_EPSILON);
        return ((float) (1.0d - probability)) * learningRate;
    }

    /**
     * Adds {@code -log(1 - sigmoid(center·context))} to the running loss and returns the scaled
     * step {@code -sigmoid(center·context) * learningRate}.
     */
    float computeNegativeGradient(FloatVector floatVector, FloatVector floatVector2) {
        double probability = Sigmoid.sigmoid(floatVector.innerProduct(floatVector2));
        lossSum -= Math.log((1.0d - probability) + LOG_EPSILON);
        return -((float) probability) * learningRate;
    }

    /**
     * Moves each vector along the other one, scaled by {@code step}.
     *
     * <p>Both scale operations run before either in-place addition, so the center update is
     * derived from the context vector's old value and vice versa (assuming {@code scale} only
     * writes its third argument — confirm against {@code FloatVectorOperations}).
     *
     * @param floatVector  center vector, modified in place
     * @param floatVector2 context vector, modified in place
     * @param step         already learning-rate-scaled gradient factor
     * @param floatVector3 scratch buffer receiving the center delta
     * @param floatVector4 scratch buffer receiving the context delta
     */
    void updateEmbeddings(
        FloatVector floatVector,
        FloatVector floatVector2,
        float step,
        FloatVector floatVector3,
        FloatVector floatVector4
    ) {
        FloatVectorOperations.scale(floatVector2.data(), step, floatVector3.data());
        FloatVectorOperations.scale(floatVector.data(), step, floatVector4.data());
        FloatVectorOperations.addInPlace(floatVector.data(), floatVector3.data());
        FloatVectorOperations.addInPlace(floatVector2.data(), floatVector4.data());
    }

    /** Returns the loss accumulated by all samples trained so far by this task. */
    public double lossSum() {
        return lossSum;
    }
}
