package org.neo4j.gds.traversal;

import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.beta.k1coloring.ColoringStep;
import org.neo4j.gds.config.SourceNodesConfig;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.QueueBasedSpliterator;
import org.neo4j.gds.degree.DegreeCentrality;
import org.neo4j.gds.degree.ImmutableDegreeCentralityConfig;
import org.neo4j.gds.ml.core.EmbeddingUtils;

/* loaded from: input_file:org/neo4j/gds/traversal/RandomWalk.class */
/**
 * Generates biased random walks over a {@link Graph}.
 *
 * <p>Walks are produced by {@code config.concurrency()} {@link RandomWalkTask}s running on {@link Pools#DEFAULT}
 * and handed to the caller lazily through a bounded queue exposed as a {@link Stream}. Each step samples a neighbour
 * with probability proportional to relationship weight and then applies the return / in-out bias
 * ({@code returnFactor}, {@code inOutFactor}) by rejection sampling.
 */
public final class RandomWalk extends Algorithm<RandomWalk, Stream<long[]>> {

    /** Upper bound on rejection-sampling attempts per step before falling back to an unbiased neighbour. */
    private static final int MAX_TRIES = 100;

    /**
     * Timeout handed to {@link QueueBasedSpliterator}. Kept separate from {@link #MAX_TRIES} so tuning the sampling
     * bound cannot silently change how long the consumer waits. NOTE(review): presumably seconds — confirm against
     * QueueBasedSpliterator.
     */
    private static final int QUEUE_POLL_TIMEOUT = 100;

    private final Graph graph;
    private final RandomWalkBaseConfig config;
    private final AllocationTracker allocationTracker;

    /**
     * Supplies, per node, the total weight of its relationships: the weighted degree when the graph has a
     * relationship property, otherwise the plain degree. {@code randomNeighbour} scales a uniform draw by this value.
     */
    @FunctionalInterface
    public interface CumulativeWeightSupplier {
        double forNode(long nodeId);
    }

    /**
     * Source of start nodes shared by all {@link RandomWalkTask}s of one computation, so {@link #nextNode()} is called
     * concurrently; both implementations below use an atomic cursor.
     */
    @FunctionalInterface
    public interface NextNodeSupplier {
        /** Returned by {@link #nextNode()} once all start nodes have been handed out. */
        long NO_MORE_NODES = -1;

        /** @return the next start node, or {@link #NO_MORE_NODES} when exhausted */
        long nextNode();

        /** Hands out every node id {@code 0 .. numberOfNodes - 1} exactly once. */
        public static class GraphNodeSupplier implements NextNodeSupplier {
            private final long numberOfNodes;
            private final AtomicLong nextNodeId = new AtomicLong(0);

            GraphNodeSupplier(long numberOfNodes) {
                this.numberOfNodes = numberOfNodes;
            }

            @Override
            public long nextNode() {
                long candidate = nextNodeId.getAndIncrement();
                return candidate < numberOfNodes ? candidate : NO_MORE_NODES;
            }
        }

        /** Hands out a fixed list of (already mapped) node ids exactly once each. */
        public static final class ListNodeSupplier implements NextNodeSupplier {
            private final List<Long> nodes;
            private final AtomicInteger nextIndex = new AtomicInteger(0);

            /**
             * Maps the configured source nodes to internal graph ids.
             * NOTE(review): if {@code toMappedNodeId} can return -1 for unknown ids, that value collides with
             * {@link #NO_MORE_NODES} — confirm source nodes are validated upstream.
             */
            static ListNodeSupplier of(SourceNodesConfig sourceNodesConfig, Graph graph) {
                Objects.requireNonNull(graph, "graph");
                List<Long> mappedNodes = sourceNodesConfig.sourceNodes()
                    .stream()
                    .map(graph::toMappedNodeId)
                    .collect(Collectors.toList());
                return new ListNodeSupplier(mappedNodes);
            }

            private ListNodeSupplier(List<Long> nodes) {
                this.nodes = nodes;
            }

            @Override
            public long nextNode() {
                int index = nextIndex.getAndIncrement();
                return index < nodes.size() ? nodes.get(index) : NO_MORE_NODES;
            }
        }
    }

    /**
     * Worker that pulls start nodes from a shared {@link NextNodeSupplier}, generates {@code walksPerNode} walks for
     * each node with non-zero degree, and publishes them to the shared queue in batches.
     */
    public static final class RandomWalkTask implements Runnable {
        private final Graph graph;
        private final BlockingQueue<long[]> walks;
        private final NextNodeSupplier nextNodeSupplier;
        private final double normalizedReturnProbability;
        private final double normalizedSameDistanceProbability;
        private final double normalizedInOutProbability;
        private final long randomSeed;
        private final ProgressTracker progressTracker;
        private final CumulativeWeightSupplier cumulativeWeightSupplier;
        private final RandomWalkBaseConfig config;
        private final Random random = new Random();
        // Mutable holders so the forEachRelationship callback in randomNeighbour can write results back.
        private final MutableDouble currentWeight = new MutableDouble(0.0);
        private final MutableLong sampledNeighbour = new MutableLong(NextNodeSupplier.NO_MORE_NODES);
        // Walks are batched locally and handed to the shared queue in chunks.
        // NOTE(review): the decompiler presumably resolved the capacity literal to this unrelated k1coloring
        // constant; consider a dedicated constant.
        private final long[][] buffer = new long[ColoringStep.INITIAL_FORBIDDEN_COLORS][];
        private final MutableInt bufferPosition = new MutableInt(0);

        static RandomWalkTask of(
            NextNodeSupplier nextNodeSupplier,
            CumulativeWeightSupplier cumulativeWeightSupplier,
            Graph graph,
            RandomWalkBaseConfig config,
            BlockingQueue<long[]> walks,
            long randomSeed,
            ProgressTracker progressTracker
        ) {
            // Unnormalized acceptance weights: 1/returnFactor for going back, 1 for staying at the same distance,
            // 1/inOutFactor for moving away. Dividing by the largest makes the most likely move always accepted.
            double returnProbability = 1.0 / config.returnFactor();
            double inOutProbability = 1.0 / config.inOutFactor();
            double maxProbability = Math.max(Math.max(returnProbability, 1.0), inOutProbability);
            return new RandomWalkTask(
                nextNodeSupplier,
                cumulativeWeightSupplier,
                config,
                walks,
                returnProbability / maxProbability,
                1.0 / maxProbability,
                inOutProbability / maxProbability,
                graph,
                randomSeed,
                progressTracker
            );
        }

        private RandomWalkTask(
            NextNodeSupplier nextNodeSupplier,
            CumulativeWeightSupplier cumulativeWeightSupplier,
            RandomWalkBaseConfig config,
            BlockingQueue<long[]> walks,
            double normalizedReturnProbability,
            double normalizedSameDistanceProbability,
            double normalizedInOutProbability,
            Graph graph,
            long randomSeed,
            ProgressTracker progressTracker
        ) {
            this.nextNodeSupplier = nextNodeSupplier;
            this.cumulativeWeightSupplier = cumulativeWeightSupplier;
            this.graph = graph;
            this.config = config;
            this.walks = walks;
            this.normalizedReturnProbability = normalizedReturnProbability;
            this.normalizedSameDistanceProbability = normalizedSameDistanceProbability;
            this.normalizedInOutProbability = normalizedInOutProbability;
            this.randomSeed = randomSeed;
            this.progressTracker = progressTracker;
        }

        /**
         * Processes start nodes until the supplier is exhausted or this thread is interrupted, then flushes any
         * walks still buffered.
         */
        @Override
        public void run() {
            while (!Thread.currentThread().isInterrupted()) {
                long nodeId = nextNodeSupplier.nextNode();
                if (nodeId == NextNodeSupplier.NO_MORE_NODES) {
                    break;
                }
                if (graph.degree(nodeId) > 0) {
                    walksFrom(nodeId);
                }
                progressTracker.logProgress();
            }
            flushBuffer();
        }

        /** Generates {@code walksPerNode} walks starting at {@code nodeId} into the local buffer. */
        private void walksFrom(long nodeId) {
            // Re-seed per start node so its walks depend only on (seed, node), not on which task picked it up.
            random.setSeed(randomSeed + nodeId);
            int walksPerNode = config.walksPerNode();
            for (int i = 0; i < walksPerNode; i++) {
                buffer[bufferPosition.getAndIncrement()] = walk(nodeId);
                if (bufferPosition.intValue() == buffer.length) {
                    flushBuffer();
                }
            }
        }

        /**
         * Walks up to {@code walkLength} nodes from {@code startNode}. The returned array is shorter when a step
         * yields {@link NextNodeSupplier#NO_MORE_NODES}.
         */
        private long[] walk(long startNode) {
            int walkLength = config.walkLength();
            long[] path = new long[walkLength];
            path[0] = startNode;
            for (int step = 1; step < walkLength; step++) {
                // The first step has no predecessor, so no return/in-out bias applies.
                long next = step == 1
                    ? randomNeighbour(startNode)
                    : walkOneStep(path[step - 2], path[step - 1]);
                if (next == NextNodeSupplier.NO_MORE_NODES) {
                    long[] truncated = new long[step];
                    System.arraycopy(path, 0, truncated, 0, step);
                    return truncated;
                }
                path[step] = next;
            }
            return path;
        }

        /**
         * Picks the next node after {@code currentNode}, biased by whether the candidate is {@code previousNode},
         * a neighbour of it, or neither.
         *
         * @return the chosen node, or {@link NextNodeSupplier#NO_MORE_NODES} if {@code currentNode} has no
         *     relationships or no neighbour could be sampled
         */
        private long walkOneStep(long previousNode, long currentNode) {
            int degree = graph.degree(currentNode);
            if (degree == 0) {
                return NextNodeSupplier.NO_MORE_NODES;
            }
            if (degree == 1) {
                // A single candidate: the bias cannot change the outcome.
                return randomNeighbour(currentNode);
            }
            for (int attempt = 0; attempt < MAX_TRIES; attempt++) {
                long candidate = randomNeighbour(currentNode);
                double draw = random.nextDouble();
                double acceptance;
                if (candidate == previousNode) {
                    acceptance = normalizedReturnProbability;
                } else if (isNeighbour(previousNode, candidate)) {
                    acceptance = normalizedSameDistanceProbability;
                } else {
                    acceptance = normalizedInOutProbability;
                }
                if (draw < acceptance) {
                    return candidate;
                }
            }
            // Give up on the bias after MAX_TRIES rejections and take an unbiased weighted sample.
            return randomNeighbour(currentNode);
        }

        /**
         * Samples a neighbour of {@code nodeId} with probability proportional to relationship weight.
         * 1.0 is passed as the fallback weight (presumably used when the graph has no relationship property).
         *
         * @return the sampled neighbour, or {@link NextNodeSupplier#NO_MORE_NODES} if the scaled draw exceeds the
         *     summed weights seen while iterating
         */
        private long randomNeighbour(long nodeId) {
            double threshold = cumulativeWeightSupplier.forNode(nodeId) * random.nextDouble();
            currentWeight.setValue(0.0);
            sampledNeighbour.setValue(NextNodeSupplier.NO_MORE_NODES);
            graph.forEachRelationship(nodeId, 1.0, (source, target, weight) -> {
                if (threshold > currentWeight.addAndGet(weight)) {
                    return true;
                }
                sampledNeighbour.setValue(target);
                return false;
            });
            return sampledNeighbour.longValue();
        }

        private boolean isNeighbour(long source, long target) {
            return graph.exists(source, target);
        }

        /**
         * Publishes buffered walks to the shared queue, blocking while it is full. On interrupt, the interrupt status
         * is restored and the remaining buffered walks are dropped, since every further {@code put} would fail too.
         */
        private void flushBuffer() {
            int pending = bufferPosition.intValue();
            bufferPosition.setValue(0);
            for (int i = 0; i < pending; i++) {
                try {
                    walks.put(buffer[i]);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }
    }

    private RandomWalk(
        Graph graph,
        RandomWalkBaseConfig config,
        AllocationTracker allocationTracker,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.graph = graph;
        this.config = config;
        this.allocationTracker = allocationTracker;
    }

    /**
     * Creates the algorithm. When the graph has a relationship property, first checks that all weights are
     * non-negative, since weighted sampling cannot handle negative weights.
     */
    public static RandomWalk create(
        Graph graph,
        RandomWalkBaseConfig config,
        AllocationTracker allocationTracker,
        ProgressTracker progressTracker
    ) {
        if (graph.hasRelationshipProperty()) {
            EmbeddingUtils.validateRelationshipWeightPropertyValue(
                graph,
                config.concurrency(),
                weight -> weight >= 0.0,
                "Node2Vec only supports non-negative weights.",
                Pools.DEFAULT
            );
        }
        return new RandomWalk(graph, config, allocationTracker, progressTracker);
    }

    /**
     * Starts walk generation in the background and returns a lazy stream over the produced walks.
     * (Decompiler-renamed {@code compute()}; the name is kept because other decompiled classes call it this way.)
     *
     * <p>NOTE(review): if the consumer stops reading before the end marker arrives, producers blocked in
     * {@code BlockingQueue.put} are never released — confirm how QueueBasedSpliterator and stream close handle this.
     */
    public Stream<long[]> m61compute() {
        progressTracker.beginSubTask("RandomWalk");
        BlockingQueue<long[]> walks = new ArrayBlockingQueue<>(config.walkBufferSize());
        // Identity-compared end-of-stream marker shared by the producer thread and the spliterator.
        long[] tombstone = new long[0];

        CumulativeWeightSupplier cumulativeWeightSupplier = graph.hasRelationshipProperty()
            ? cumulativeWeights()::get
            : graph::degree;

        long randomSeed = config.randomSeed().orElseGet(() -> new Random().nextLong());

        NextNodeSupplier nextNodeSupplier = config.sourceNodes() == null || config.sourceNodes().isEmpty()
            ? new NextNodeSupplier.GraphNodeSupplier(graph.nodeCount())
            : NextNodeSupplier.ListNodeSupplier.of(config, graph);

        List<Runnable> tasks = IntStream.range(0, config.concurrency())
            .<Runnable>mapToObj(i -> RandomWalkTask.of(
                nextNodeSupplier,
                cumulativeWeightSupplier,
                graph.concurrentCopy(),
                config,
                walks,
                randomSeed,
                progressTracker
            ))
            .collect(Collectors.toList());

        progressTracker.beginSubTask("create walks");
        new Thread(() -> runTasksAndSignalEnd(tasks, walks, tombstone)).start();

        return StreamSupport.stream(
            new QueueBasedSpliterator<>(walks, tombstone, terminationFlag, QUEUE_POLL_TIMEOUT),
            false
        );
    }

    /**
     * Runs the walk tasks to completion and always publishes {@code tombstone}, even if a task fails, so the
     * consuming stream sees end-of-input.
     */
    private void runTasksAndSignalEnd(List<Runnable> tasks, BlockingQueue<long[]> walks, long[] tombstone) {
        try {
            ParallelUtil.runWithConcurrency(config.concurrency(), tasks, terminationFlag, Pools.DEFAULT);
            progressTracker.endSubTask("create walks");
            progressTracker.endSubTask("RandomWalk");
        } finally {
            try {
                walks.put(tombstone);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Computes per-node cumulative relationship weights with {@link DegreeCentrality}.
     * NOTE(review): the "DUMMY" property name presumably only switches DegreeCentrality into weighted mode, since
     * the graph already carries its relationship property — confirm.
     */
    private DegreeCentrality.DegreeFunction cumulativeWeights() {
        return new DegreeCentrality(
            graph,
            Pools.DEFAULT,
            ImmutableDegreeCentralityConfig.builder()
                .concurrency(config.concurrency())
                .relationshipWeightProperty("DUMMY")
                .build(),
            progressTracker,
            allocationTracker
        ).m10compute();
    }

    /** Decompiler-renamed {@code me()}: returns this instance. */
    public RandomWalk m60me() {
        return this;
    }

    /** Nothing to release; all state is owned by the per-call tasks and queue. */
    public void release() {
    }
}
