package org.neo4j.gds.embeddings.node2vec;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.BitUtil;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.FloatVector;
import org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains node2vec embeddings with a skip-gram objective and negative sampling over a set of
 * precomputed random walks.
 *
 * <p>Every node owns two vectors: a center embedding (the one returned by {@link #train()}) and
 * a context embedding. Training runs {@code iterations} epochs. Within an epoch, the walks are
 * split into degree-based partitions. Each partition is processed by a {@link TrainingTask}, and
 * the tasks run concurrently. The tasks write into the shared embedding vectors without any
 * locking.
 */
public class Node2VecModel {
    private final HugeObjectArray<FloatVector> centerEmbeddings;
    private final HugeObjectArray<FloatVector> contextEmbeddings;
    private final double initialLearningRate;
    private final double minLearningRate;
    private final int iterations;
    private final int embeddingDimension;
    private final int windowSize;
    private final int negativeSamplingRate;
    private final EmbeddingInitializer embeddingInitializer;
    private final Concurrency concurrency;
    private final CompressedRandomWalks walks;
    private final RandomWalkProbabilities randomWalkProbabilities;
    private final ProgressTracker progressTracker;
    private final long randomSeed;

    // Added inside log(...) so a sigmoid of exactly 0 or 1 never produces log(0).
    private static final double EPSILON = 1.0E-10d;

    /**
     * Mutable accumulator that collects a stream of doubles into a fixed-size {@code float[]}.
     * It is used as the container of a {@code DoubleStream.collect} call.
     */
    public static class FloatConsumer {
        float[] values;
        int index;

        FloatConsumer(int capacity) {
            this.values = new float[capacity];
        }

        /** Stores {@code value}, narrowed to float, in the next free slot. */
        void add(double value) {
            this.values[this.index++] = (float) value;
        }

        /** Appends everything {@code other} has collected so far (acts as the stream combiner). */
        void addAll(FloatConsumer other) {
            System.arraycopy(other.values, 0, this.values, this.index, other.index);
            this.index += other.index;
        }
    }

    /**
     * Trains on the positive pairs of one partition of walks.
     *
     * <p>For each positive pair, it also trains {@code negativeSamplingRate} negative pairs. It
     * accumulates the loss of every trained pair in {@link #lossSum()}.
     */
    private static final class TrainingTask implements Runnable {
        private final HugeObjectArray<FloatVector> centerEmbeddings;
        private final HugeObjectArray<FloatVector> contextEmbeddings;
        private final PositiveSampleProducer positiveSampleProducer;
        private final NegativeSampleProducer negativeSampleProducer;
        private final FloatVector centerGradientBuffer;
        private final FloatVector contextGradientBuffer;
        private final int negativeSamplingRate;
        private final float learningRate;
        private final ProgressTracker progressTracker;
        private double lossSum;

        /**
         * @param negativeSamplingDistribution data handed to the {@link NegativeSampleProducer}
         * @param negativeSamplingSeed seed for the negative sampler; it is used as given, so a
         *     fixed value yields a reproducible stream of negative samples
         */
        private TrainingTask(
            HugeObjectArray<FloatVector> centerEmbeddings,
            HugeObjectArray<FloatVector> contextEmbeddings,
            PositiveSampleProducer positiveSampleProducer,
            HugeLongArray negativeSamplingDistribution,
            float learningRate,
            int negativeSamplingRate,
            int embeddingDimension,
            ProgressTracker progressTracker,
            long negativeSamplingSeed
        ) {
            this.centerEmbeddings = centerEmbeddings;
            this.contextEmbeddings = contextEmbeddings;
            this.positiveSampleProducer = positiveSampleProducer;
            // Do not mix in the current thread's id: tasks are built on the caller's thread, so
            // that would give every task the same, run-dependent seed.
            this.negativeSampleProducer = new NegativeSampleProducer(negativeSamplingDistribution, negativeSamplingSeed);
            this.learningRate = learningRate;
            this.negativeSamplingRate = negativeSamplingRate;
            this.centerGradientBuffer = new FloatVector(embeddingDimension);
            this.contextGradientBuffer = new FloatVector(embeddingDimension);
            this.progressTracker = progressTracker;
        }

        @Override
        public void run() {
            // The producer fills this as {center node, context node}.
            long[] pair = new long[2];
            while (this.positiveSampleProducer.next(pair)) {
                long center = pair[0];
                trainSample(center, pair[1], true);
                for (int k = 0; k < this.negativeSamplingRate; k++) {
                    trainSample(center, this.negativeSampleProducer.next(), false);
                }
                this.progressTracker.logProgress();
            }
        }

        /**
         * Performs one SGD step on a (center, context) pair.
         *
         * <p>The loss of a positive pair is {@code -log(sigmoid(c . x))}. The loss of a negative
         * pair is {@code -log(1 - sigmoid(c . x))}.
         */
        private void trainSample(long center, long context, boolean positive) {
            FloatVector centerVector = this.centerEmbeddings.get(center);
            FloatVector contextVector = this.contextEmbeddings.get(context);

            double affinity = Sigmoid.sigmoid(centerVector.innerProduct(contextVector));
            double complement = 1.0d - affinity;
            this.lossSum -= positive
                ? Math.log(affinity + EPSILON)
                : Math.log(complement + EPSILON);

            // Negative derivative of the loss w.r.t. the inner product, times the learning rate:
            // (1 - sigmoid) for positive pairs, -sigmoid for negative pairs.
            float scale = (positive ? (float) complement : -(float) affinity) * this.learningRate;

            // Fill both gradient buffers before touching either vector, so each update is based
            // on the other vector's value from before this step.
            FloatVectorOperations.scale(contextVector.data(), scale, this.centerGradientBuffer.data());
            FloatVectorOperations.scale(centerVector.data(), scale, this.contextGradientBuffer.data());
            FloatVectorOperations.addInPlace(centerVector.data(), this.centerGradientBuffer.data());
            FloatVectorOperations.addInPlace(contextVector.data(), this.contextGradientBuffer.data());
        }

        double lossSum() {
            return this.lossSum;
        }
    }

    /**
     * Creates a model whose hyper-parameters are taken from {@code trainParameters}.
     *
     * @param toOriginalNodeId maps a node index to the id that seeds that node's initial vectors
     *     (presumably the original node id; confirm with callers)
     * @param nodeCount number of nodes, i.e. the size of both embedding arrays
     * @param randomSeed seed for initialization and sampling; a random one is drawn if absent
     */
    public Node2VecModel(
        LongUnaryOperator toOriginalNodeId,
        long nodeCount,
        TrainParameters trainParameters,
        Concurrency concurrency,
        Optional<Long> randomSeed,
        CompressedRandomWalks walks,
        RandomWalkProbabilities randomWalkProbabilities,
        ProgressTracker progressTracker
    ) {
        this(
            toOriginalNodeId,
            nodeCount,
            trainParameters.initialLearningRate(),
            trainParameters.minLearningRate(),
            trainParameters.iterations(),
            trainParameters.windowSize(),
            trainParameters.negativeSamplingRate(),
            trainParameters.embeddingDimension(),
            trainParameters.embeddingInitializer(),
            concurrency,
            randomSeed,
            walks,
            randomWalkProbabilities,
            progressTracker
        );
    }

    /**
     * Creates a model and initializes its center and context embeddings.
     *
     * <p>NOTE: both arrays are built with the same per-node seeds and the same bound, so a node's
     * center and context vectors start out equal.
     */
    Node2VecModel(
        LongUnaryOperator toOriginalNodeId,
        long nodeCount,
        double initialLearningRate,
        double minLearningRate,
        int iterations,
        int windowSize,
        int negativeSamplingRate,
        int embeddingDimension,
        EmbeddingInitializer embeddingInitializer,
        Concurrency concurrency,
        Optional<Long> randomSeed,
        CompressedRandomWalks walks,
        RandomWalkProbabilities randomWalkProbabilities,
        ProgressTracker progressTracker
    ) {
        this.initialLearningRate = initialLearningRate;
        this.minLearningRate = minLearningRate;
        this.iterations = iterations;
        this.embeddingDimension = embeddingDimension;
        this.windowSize = windowSize;
        this.negativeSamplingRate = negativeSamplingRate;
        this.embeddingInitializer = embeddingInitializer;
        this.concurrency = concurrency;
        this.walks = walks;
        this.randomWalkProbabilities = randomWalkProbabilities;
        this.progressTracker = progressTracker;
        // Must be set before initializeEmbeddings, which reads it.
        this.randomSeed = randomSeed.orElseGet(() -> new SplittableRandom().nextLong());

        // Reseeded per node inside initializeEmbeddings, so its own initial seed is irrelevant.
        Random random = new Random();
        this.centerEmbeddings = initializeEmbeddings(toOriginalNodeId, nodeCount, embeddingDimension, random);
        this.contextEmbeddings = initializeEmbeddings(toOriginalNodeId, nodeCount, embeddingDimension, random);
    }

    /**
     * Runs {@code iterations} training epochs over all walks.
     *
     * <p>The learning rate starts at {@code initialLearningRate}. Each epoch lowers it by
     * {@code (initialLearningRate - minLearningRate) / iterations}, and it never drops below
     * {@code minLearningRate}.
     *
     * @return the center embeddings and the summed loss of every epoch, in epoch order
     */
    public Node2VecResult train() {
        this.progressTracker.beginSubTask();
        double learningRateDecay = (this.initialLearningRate - this.minLearningRate) / this.iterations;
        List<Double> lossPerIteration = new ArrayList<>();

        for (int iteration = 0; iteration < this.iterations; iteration++) {
            this.progressTracker.beginSubTask();
            this.progressTracker.setVolume(this.walks.size());

            float learningRate = (float) Math.max(
                this.minLearningRate,
                this.initialLearningRate - iteration * learningRateDecay
            );

            List<TrainingTask> tasks = PartitionUtils.degreePartitionWithBatchSize(
                this.walks.size(),
                this.walks::walkLength,
                BitUtil.ceilDiv(this.randomWalkProbabilities.sampleCount(), this.concurrency.value()),
                partition -> {
                    // The seed is deterministic for a given randomSeed and differs between
                    // partitions, so they do not all draw identical random streams.
                    long partitionSeed = this.randomSeed + partition.startNode();
                    PositiveSampleProducer positiveSamples = new PositiveSampleProducer(
                        this.walks.iterator(partition.startNode(), partition.nodeCount()),
                        this.randomWalkProbabilities.positiveSamplingProbabilities(),
                        this.windowSize,
                        Optional.of(partitionSeed)
                    );
                    return new TrainingTask(
                        this.centerEmbeddings,
                        this.contextEmbeddings,
                        positiveSamples,
                        this.randomWalkProbabilities.negativeSamplingDistribution(),
                        learningRate,
                        this.negativeSamplingRate,
                        this.embeddingDimension,
                        this.progressTracker,
                        partitionSeed
                    );
                }
            );

            RunWithConcurrency.builder().concurrency(this.concurrency).tasks(tasks).run();

            double loss = tasks.stream().mapToDouble(TrainingTask::lossSum).sum();
            this.progressTracker.logInfo(StringFormatting.formatWithLocale("Loss %.4f", loss));
            lossPerIteration.add(loss);
            this.progressTracker.endSubTask();
        }

        this.progressTracker.endSubTask();
        return new Node2VecResult(this.centerEmbeddings, lossPerIteration);
    }

    /**
     * Allocates {@code nodeCount} vectors of length {@code embeddingDimension}, each filled with
     * values drawn uniformly from {@code [-bound, bound)}.
     *
     * <p>{@code bound} is 1.0 for UNIFORM and {@code 0.5 / embeddingDimension} for NORMALIZED.
     *
     * @throws IllegalStateException if the configured initializer is neither UNIFORM nor NORMALIZED
     */
    private HugeObjectArray<FloatVector> initializeEmbeddings(
        LongUnaryOperator toOriginalNodeId,
        long nodeCount,
        int embeddingDimension,
        Random random
    ) {
        double bound;
        switch (this.embeddingInitializer) {
            case UNIFORM:
                bound = 1.0d;
                break;
            case NORMALIZED:
                bound = 0.5d / embeddingDimension;
                break;
            default:
                throw new IllegalStateException("Missing implementation for: " + this.embeddingInitializer);
        }

        HugeObjectArray<FloatVector> embeddings = HugeObjectArray.newArray(FloatVector.class, nodeCount);
        for (long nodeId = 0; nodeId < nodeCount; nodeId++) {
            // Reseed per node, so a node's initial vector depends only on its mapped id and the
            // model seed, not on the nodes initialized before it.
            random.setSeed(toOriginalNodeId.applyAsLong(nodeId) + this.randomSeed);
            float[] values = random
                .doubles(embeddingDimension, -bound, bound)
                .collect(() -> new FloatConsumer(embeddingDimension), FloatConsumer::add, FloatConsumer::addAll)
                .values;
            embeddings.set(nodeId, new FloatVector(values));
        }
        return embeddings;
    }
}
