package org.neo4j.gds.embeddings.node2vec;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.PrimitiveIterator;
import java.util.SplittableRandom;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.collection.primitive.PrimitiveLongCollections;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.BitUtil;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.FloatVector;
import org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/Node2VecModel.class */
public class Node2VecModel {
    /**
     * Producer of negative sample ids; created in the constructor from
     * {@code randomWalkProbabilities.negativeSamplingDistribution()} and handed to every training task.
     */
    private final NegativeSampleProducer negativeSamples;
    /** Center embedding vectors; updated by the training tasks and returned as the embeddings of {@link #train()}. */
    private final HugeObjectArray<FloatVector> centerEmbeddings;
    /** Context embedding vectors; updated by the training tasks but not part of the returned result. */
    private final HugeObjectArray<FloatVector> contextEmbeddings;
    /** Supplies learning rates, iteration count, concurrency, window size, negative sampling rate and embedding dimension. */
    private final Node2VecBaseConfig config;
    /** The walks that {@link #train()} partitions into training tasks. */
    private final CompressedRandomWalks walks;
    /** Supplies the sample count used for batch sizing and the positive sampling probabilities. */
    private final RandomWalkProbabilities randomWalkProbabilities;
    /** Receives sub-task boundaries, the per-iteration volume and the per-iteration loss log line. */
    private final ProgressTracker progressTracker;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/Node2VecModel$FloatConsumer.class */
    /**
     * Fixed-capacity accumulator that narrows {@code double} values to {@code float} and stores them
     * in insertion order. It is used as the mutable container of a {@code DoubleStream.collect} call.
     *
     * <p>The capacity is never grown: writing past it fails with the usual array bounds exception.
     */
    public static class FloatConsumer {
        /** Backing storage; only the first {@link #index} slots hold added values. */
        float[] values;
        /** Number of values written so far, i.e. the next free slot. */
        int index;

        /**
         * @param i capacity of the backing array
         */
        FloatConsumer(int i) {
            values = new float[i];
        }

        /**
         * Appends one value, narrowed to {@code float}.
         *
         * @param d the value to append
         */
        void add(double d) {
            // The post-increment advances the cursor before the store, exactly as an explicit
            // read / bump / write sequence would.
            values[index++] = (float) d;
        }

        /**
         * Appends every value recorded by another consumer, in its order.
         *
         * @param floatConsumer the consumer whose recorded values are copied
         */
        void addAll(FloatConsumer floatConsumer) {
            int count = floatConsumer.index;
            System.arraycopy(floatConsumer.values, 0, values, index, count);
            index += count;
        }
    }

    /**
     * Outcome of a training run, as returned by {@link Node2VecModel#train()}.
     */
    @ValueClass
    /* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/Node2VecModel$Result.class */
    public interface Result {
        /** The embedding vectors; {@code train()} passes the model's center embeddings here. */
        HugeObjectArray<FloatVector> embeddings();

        /** The loss values; {@code train()} appends one entry per iteration, in iteration order. */
        List<Double> lossPerIteration();
    }

    /* loaded from: input_file:org/neo4j/gds/embeddings/node2vec/Node2VecModel$TrainingTask.class */
    /**
     * One unit of training work. It drains (center, context) pairs from its positive sample producer and,
     * for every pair, performs one positive update followed by {@code negativeSamplingRate} negative
     * updates whose context ids come from the negative sample producer.
     *
     * <p>The embedding arrays are the instances handed in by the caller and are updated in place;
     * this class performs no synchronization of its own around those updates.
     *
     * <p>The loss accumulated here is the negative-sampling negative log-likelihood that the
     * embedding updates descend on: {@code -log(sigmoid(a))} for a positive pair and
     * {@code -log(sigmoid(-a))} for a negative pair, where {@code a} is the inner product of the
     * center and context vectors.
     */
    private static final class TrainingTask implements Runnable {
        /** Lower bound used inside the logarithm so that a saturated sigmoid cannot produce an infinite loss. */
        private static final double LOSS_EPSILON = 1e-10;

        private final HugeObjectArray<FloatVector> centerEmbeddings;
        private final HugeObjectArray<FloatVector> contextEmbeddings;
        private final PositiveSampleProducer positiveSampleProducer;
        private final NegativeSampleProducer negativeSampleProducer;
        /** Scratch buffer holding the update that is added to the center vector. */
        private final FloatVector centerGradientBuffer;
        /** Scratch buffer holding the update that is added to the context vector. */
        private final FloatVector contextGradientBuffer;
        private final int negativeSamplingRate;
        private final float learningRate;
        private double lossSum;

        /**
         * @param centerEmbeddings       center vectors, updated in place
         * @param contextEmbeddings      context vectors, updated in place
         * @param positiveSampleProducer source of (center, context) pairs
         * @param negativeSampleProducer source of negative context ids
         * @param learningRate           step size applied to every update of this task
         * @param negativeSamplingRate   number of negative updates per positive pair
         * @param embeddingDimension     length of the two scratch buffers
         */
        private TrainingTask(
            HugeObjectArray<FloatVector> centerEmbeddings,
            HugeObjectArray<FloatVector> contextEmbeddings,
            PositiveSampleProducer positiveSampleProducer,
            NegativeSampleProducer negativeSampleProducer,
            float learningRate,
            int negativeSamplingRate,
            int embeddingDimension
        ) {
            this.centerEmbeddings = centerEmbeddings;
            this.contextEmbeddings = contextEmbeddings;
            this.positiveSampleProducer = positiveSampleProducer;
            this.negativeSampleProducer = negativeSampleProducer;
            this.learningRate = learningRate;
            this.negativeSamplingRate = negativeSamplingRate;
            this.centerGradientBuffer = new FloatVector(embeddingDimension);
            this.contextGradientBuffer = new FloatVector(embeddingDimension);
        }

        /**
         * Consumes positive pairs until the producer reports none are left. The producer writes the
         * center id into slot 0 and the context id into slot 1 of the reused buffer.
         */
        @Override // java.lang.Runnable
        public void run() {
            long[] pair = new long[2];
            while (this.positiveSampleProducer.next(pair)) {
                trainSample(pair[0], pair[1], true);
                for (int negative = 0; negative < this.negativeSamplingRate; negative++) {
                    trainSample(pair[0], this.negativeSampleProducer.next(), false);
                }
            }
        }

        /**
         * Applies a single gradient-descent step for one (center, context) pair and adds the pair's
         * loss to the running total.
         *
         * @param center   index of the center vector
         * @param context  index of the context vector
         * @param positive {@code true} for an observed pair, {@code false} for a negative sample
         */
        private void trainSample(long center, long context, boolean positive) {
            FloatVector centerEmbedding = this.centerEmbeddings.get(center);
            FloatVector contextEmbedding = this.contextEmbeddings.get(context);

            // L_pos = -log sigmoid( a)  ; dL/da = -sigmoid(-a)
            // L_neg = -log sigmoid(-a)  ; dL/da =  sigmoid( a)
            float affinity = centerEmbedding.innerProduct(contextEmbedding);
            float signedAffinity = positive ? -affinity : affinity;

            // Probability mass the model currently puts on the wrong label for this pair:
            // sigmoid(-a) for a positive pair, sigmoid(a) for a negative one.
            double errorProbability = Sigmoid.sigmoid(signedAffinity);

            // The loss is -log of the probability of the correct label, i.e. -log(1 - errorProbability).
            // Previously the gradient coefficient itself was accumulated here, which is not a loss:
            // it could be negative and its positive-sample part grew as the fit improved.
            this.lossSum -= Math.log(Math.max(1.0d - errorProbability, LOSS_EPSILON));

            // Negative gradient w.r.t. the affinity; the embedding update below is unchanged.
            float scalar = (float) (positive ? errorProbability : -errorProbability);
            float step = scalar * this.learningRate;

            // Both updates are computed from the pre-update vectors before either one is applied.
            FloatVectorOperations.scale(contextEmbedding.data(), step, this.centerGradientBuffer.data());
            FloatVectorOperations.scale(centerEmbedding.data(), step, this.contextGradientBuffer.data());
            FloatVectorOperations.addInPlace(centerEmbedding.data(), this.centerGradientBuffer.data());
            FloatVectorOperations.addInPlace(contextEmbedding.data(), this.contextGradientBuffer.data());
        }

        /**
         * @return the loss accumulated over all samples this task has processed so far
         */
        double lossSum() {
            return this.lossSum;
        }
    }

    /**
     * Builds the memory estimation for the model: two per-node components, "center embeddings" and
     * "context embeddings", each estimated as a {@code HugeObjectArray} whose elements are float
     * arrays of the configured embedding dimension.
     *
     * @param node2VecBaseConfig configuration providing the embedding dimension
     * @return the combined estimation, named after {@code Node2Vec}
     */
    public static MemoryEstimation memoryEstimation(Node2VecBaseConfig node2VecBaseConfig) {
        long bytesPerVector = MemoryUsage.sizeOfFloatArray(node2VecBaseConfig.embeddingDimension());
        return MemoryEstimations
            .builder(Node2Vec.class.getSimpleName())
            .perNode("center embeddings", nodeCount -> HugeObjectArray.memoryEstimation(nodeCount, bytesPerVector))
            .perNode("context embeddings", nodeCount -> HugeObjectArray.memoryEstimation(nodeCount, bytesPerVector))
            .build();
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates a model with randomly initialised center and context embeddings.
     *
     * <p>Both embedding arrays draw from one random generator, center embeddings first. The generator
     * is seeded with the configured random seed if present, otherwise with a freshly drawn value.
     *
     * @param j                       number of embedding vectors to allocate per array
     *                                (presumably the node count; confirm against the caller)
     * @param node2VecBaseConfig      training configuration
     * @param compressedRandomWalks   the walks to train on
     * @param randomWalkProbabilities sampling probabilities derived from the walks
     * @param progressTracker         tracker used to report training progress
     */
    public Node2VecModel(long j, Node2VecBaseConfig node2VecBaseConfig, CompressedRandomWalks compressedRandomWalks, RandomWalkProbabilities randomWalkProbabilities, ProgressTracker progressTracker) {
        this.config = node2VecBaseConfig;
        this.walks = compressedRandomWalks;
        this.randomWalkProbabilities = randomWalkProbabilities;
        this.progressTracker = progressTracker;
        this.negativeSamples = new NegativeSampleProducer(randomWalkProbabilities.negativeSamplingDistribution());

        long seed = node2VecBaseConfig.randomSeed().orElseGet(() -> new SplittableRandom().nextLong());
        SplittableRandom random = new SplittableRandom(seed);

        // Order matters: both arrays consume the same generator, center embeddings first.
        this.centerEmbeddings = initializeEmbeddings(j, node2VecBaseConfig.embeddingDimension(), random);
        this.contextEmbeddings = initializeEmbeddings(j, node2VecBaseConfig.embeddingDimension(), random);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Trains the embeddings for the configured number of iterations.
     *
     * <p>Each iteration partitions the walks into {@code TrainingTask}s, runs them with the configured
     * concurrency, sums the tasks' losses, logs the sum as {@code "Loss %.4f"} and records it. The
     * learning rate of iteration {@code i} is
     * {@code max(minLearningRate, initialLearningRate - i * (initialLearningRate - minLearningRate) / iterations)}.
     *
     * @return the center embeddings together with the summed loss of every iteration
     */
    public Result train() {
        this.progressTracker.beginSubTask();

        // Amount by which the learning rate drops from one iteration to the next.
        double learningRateDecay = (this.config.initialLearningRate() - this.config.minLearningRate()) / this.config.iterations();
        List<Double> lossPerIteration = new ArrayList<>();

        for (int iteration = 0; iteration < this.config.iterations(); iteration++) {
            this.progressTracker.beginSubTask();
            this.progressTracker.setVolume(this.walks.size());

            float learningRate = (float) Math.max(
                this.config.minLearningRate(),
                this.config.initialLearningRate() - (iteration * learningRateDecay)
            );

            // NOTE(review): the upper bound size() - 1 presumably relies on range() being
            // end-inclusive - confirm against PrimitiveLongCollections.range.
            List<TrainingTask> tasks = PartitionUtils.degreePartitionWithBatchSize(
                PrimitiveLongCollections.range(0L, this.walks.size() - 1),
                this.walks::walkLength,
                BitUtil.ceilDiv(this.randomWalkProbabilities.sampleCount(), this.config.concurrency()),
                partition -> new TrainingTask(
                    this.centerEmbeddings,
                    this.contextEmbeddings,
                    new PositiveSampleProducer(
                        this.walks.iterator(partition.startNode(), partition.nodeCount()),
                        this.randomWalkProbabilities.positiveSamplingProbabilities(),
                        this.config.windowSize(),
                        this.progressTracker
                    ),
                    this.negativeSamples,
                    learningRate,
                    this.config.negativeSamplingRate(),
                    this.config.embeddingDimension()
                )
            );

            RunWithConcurrency.builder().concurrency(this.config.concurrency()).tasks(tasks).run();

            double loss = tasks.stream().mapToDouble(TrainingTask::lossSum).sum();
            this.progressTracker.logInfo(StringFormatting.formatWithLocale("Loss %.4f", loss));
            lossPerIteration.add(loss);

            this.progressTracker.endSubTask();
        }

        this.progressTracker.endSubTask();
        return ImmutableResult.of(this.centerEmbeddings, lossPerIteration);
    }

    /**
     * Allocates an embedding array and fills every slot with a fresh vector whose components are
     * drawn from the given generator in the range {@code [-1.0, 1.0)}, narrowed to {@code float}.
     *
     * @param nodeCount          number of vectors to create
     * @param embeddingDimension number of components per vector
     * @param random             generator the components are drawn from; its state advances
     * @return the filled array
     */
    private static HugeObjectArray<FloatVector> initializeEmbeddings(long nodeCount, int embeddingDimension, SplittableRandom random) {
        HugeObjectArray<FloatVector> embeddings = HugeObjectArray.newArray(FloatVector.class, nodeCount);
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            FloatConsumer components = random
                .doubles(embeddingDimension, -1.0d, 1.0d)
                .collect(() -> new FloatConsumer(embeddingDimension), FloatConsumer::add, FloatConsumer::addAll);
            embeddings.set(nodeId, new FloatVector(components.values));
        }
        return embeddings;
    }
}
