package org.neo4j.gds.traversal;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.EmbeddingUtils;
import org.neo4j.gds.ml.core.samplers.RandomWalkSampler;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/traversal/RandomWalkCountingNodeVisits.class */
/**
 * Algorithm that performs random walks on a {@link Graph} and accumulates, in a
 * {@link HugeAtomicLongArray}, how often each id returned by the walks was encountered.
 *
 * <p>The work is distributed over {@code concurrency.value()} {@link RandomWalkTask}s which all
 * share one {@link NextNodeSupplier} and one result array, and which are executed through
 * {@link RunWithConcurrency} on the supplied {@link ExecutorService}.
 *
 * <p>Instances are obtained through {@link #create}, which also validates relationship weights.
 */
public final class RandomWalkCountingNodeVisits extends Algorithm<HugeAtomicLongArray> {
    private final Concurrency concurrency;
    private final ExecutorService executorService;
    private final Graph graph;
    private final long randomSeed;
    private final WalkParameters walkParameters;
    private final List<Long> sourceNodes;

    /**
     * Validates the graph's relationship weights (if any) and builds the algorithm.
     *
     * @param graph            graph to walk on
     * @param concurrency      number of tasks to run, also used for weight validation
     * @param walkParameters   walk length, walks per node, return factor and in-out factor
     * @param sourceNodes      nodes handed to the next-node supplier
     * @param maybeRandomSeed  seed to use; when empty a seed is drawn from a fresh {@link Random}
     * @param progressTracker  tracker that receives sub-task and per-node progress events
     * @param executorService  executor used for validation and for running the walk tasks
     * @param terminationFlag  flag polled by the tasks to stop early
     * @return a ready-to-run algorithm instance
     */
    public static RandomWalkCountingNodeVisits create(
        Graph graph,
        Concurrency concurrency,
        WalkParameters walkParameters,
        List<Long> sourceNodes,
        Optional<Long> maybeRandomSeed,
        ProgressTracker progressTracker,
        ExecutorService executorService,
        TerminationFlag terminationFlag
    ) {
        if (graph.hasRelationshipProperty()) {
            // Only checked when the graph actually carries a relationship property.
            EmbeddingUtils.validateRelationshipWeightPropertyValue(
                graph,
                concurrency,
                weight -> weight >= 0.0d,
                "RandomWalk only supports non-negative weights.",
                executorService
            );
        }
        long seed = maybeRandomSeed.orElseGet(() -> new Random().nextLong());
        return new RandomWalkCountingNodeVisits(
            graph,
            concurrency,
            executorService,
            walkParameters,
            sourceNodes,
            seed,
            progressTracker,
            terminationFlag
        );
    }

    private RandomWalkCountingNodeVisits(
        Graph graph,
        Concurrency concurrency,
        ExecutorService executorService,
        WalkParameters walkParameters,
        List<Long> sourceNodes,
        long randomSeed,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        super(progressTracker);
        this.concurrency = concurrency;
        this.executorService = executorService;
        this.graph = graph;
        this.walkParameters = walkParameters;
        this.sourceNodes = sourceNodes;
        this.randomSeed = randomSeed;
        // NOTE(review): terminationFlag is not declared in this class; presumably a field
        // inherited from Algorithm - confirm.
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs the random walks and returns the visit counters.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code compute} (merged with its
     * bridge method); the generated name is kept so existing references keep resolving.
     *
     * @return an array of {@code graph.nodeCount()} counters, initialised to zero and incremented
     *         by the walk tasks
     */
    public HugeAtomicLongArray m124compute() {
        progressTracker.beginSubTask("RandomWalk");

        HugeAtomicLongArray visitCounts = HugeAtomicLongArray.of(
            graph.nodeCount(),
            ParalleLongPageCreator.of(concurrency, ignored -> 0L)
        );
        RandomWalkTaskSupplier taskSupplier = createRandomWalkTaskSupplier(visitCounts);

        // One task per unit of concurrency; each call to the supplier yields a fresh task.
        List<RandomWalkTask> tasks = IntStream
            .range(0, concurrency.value())
            .mapToObj(ignored -> taskSupplier.get())
            .toList();

        RunWithConcurrency.builder()
            .executor(executorService)
            .concurrency(concurrency)
            .tasks(tasks)
            .terminationFlag(terminationFlag)
            .mayInterruptIfRunning(true)
            .run();

        progressTracker.endSubTask("RandomWalk");
        return visitCounts;
    }

    /**
     * Wires up the supplier that creates the walk tasks: a shared next-node supplier, the
     * cumulative weights, and a graph supplier backed by {@code graph.concurrentCopy()}.
     *
     * @param visitCounts array the created tasks write their counts into
     * @return supplier producing one {@link RandomWalkTask} per call
     */
    private RandomWalkTaskSupplier createRandomWalkTaskSupplier(HugeAtomicLongArray visitCounts) {
        NextNodeSupplier nodeSupplier = RandomWalkCompanion.nextNodeSupplier(graph, sourceNodes);
        RandomWalkSampler.CumulativeWeightSupplier weights =
            RandomWalkCompanion.cumulativeWeights(graph, concurrency, executorService, progressTracker);
        return new RandomWalkTaskSupplier(
            graph::concurrentCopy,
            nodeSupplier,
            weights,
            walkParameters,
            randomSeed,
            visitCounts,
            progressTracker,
            terminationFlag
        );
    }

    /**
     * Worker that repeatedly pulls a node from the shared {@link NextNodeSupplier}, runs
     * {@code walksPerNode} walks from it, and increments the result counter for every id the
     * walks return.
     */
    public static final class RandomWalkTask implements Runnable {
        private final Graph graph;
        private final NextNodeSupplier nextNodeSupplier;
        private final RandomWalkSampler sampler;
        private final int walksPerNode;
        private final HugeAtomicLongArray result;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;

        private RandomWalkTask(
            Graph graph,
            NextNodeSupplier nextNodeSupplier,
            RandomWalkSampler sampler,
            int walksPerNode,
            HugeAtomicLongArray result,
            ProgressTracker progressTracker,
            TerminationFlag terminationFlag
        ) {
            this.graph = graph;
            this.nextNodeSupplier = nextNodeSupplier;
            this.sampler = sampler;
            this.walksPerNode = walksPerNode;
            this.result = result;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        /**
         * Processes nodes until the termination flag stops reporting {@code running()} or the
         * node supplier returns {@code -1}. Progress is logged once per processed node, whether
         * or not any walk was started from it.
         */
        @Override
        public void run() {
            while (terminationFlag.running()) {
                long startNode = nextNodeSupplier.nextNode();
                if (startNode == -1) {
                    // presumably the supplier's "no more nodes" marker - verify against NextNodeSupplier
                    return;
                }
                // Nodes with degree 0 are skipped: no walk is started from them.
                if (graph.degree(startNode) != 0) {
                    sampler.prepareForNewNode(startNode);
                    for (int walk = 0; walk < walksPerNode; walk++) {
                        for (long visited : sampler.walk(startNode)) {
                            result.getAndAdd(visited, 1L);
                        }
                    }
                }
                progressTracker.logProgress();
            }
        }
    }

    /**
     * Factory for {@link RandomWalkTask}s. Every call to {@link #get()} obtains a graph from the
     * graph supplier and creates a new {@link RandomWalkSampler} for it; the next-node supplier,
     * result array, progress tracker and termination flag are shared by all created tasks.
     */
    public static class RandomWalkTaskSupplier implements Supplier<RandomWalkTask> {
        private final Supplier<Graph> graphSupplier;
        private final NextNodeSupplier nextNodeSupplier;
        private final RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier;
        private final WalkParameters walkParameters;
        private final long randomSeed;
        private final HugeAtomicLongArray result;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;

        RandomWalkTaskSupplier(
            Supplier<Graph> graphSupplier,
            NextNodeSupplier nextNodeSupplier,
            RandomWalkSampler.CumulativeWeightSupplier cumulativeWeightSupplier,
            WalkParameters walkParameters,
            long randomSeed,
            HugeAtomicLongArray result,
            ProgressTracker progressTracker,
            TerminationFlag terminationFlag
        ) {
            this.graphSupplier = graphSupplier;
            this.nextNodeSupplier = nextNodeSupplier;
            this.cumulativeWeightSupplier = cumulativeWeightSupplier;
            this.walkParameters = walkParameters;
            this.randomSeed = randomSeed;
            this.result = result;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        /**
         * Creates a new task with its own graph instance and sampler.
         *
         * @return a task that is ready to be run
         */
        @Override
        public RandomWalkTask get() {
            Graph taskGraph = graphSupplier.get();
            RandomWalkSampler taskSampler = RandomWalkSampler.create(
                taskGraph,
                cumulativeWeightSupplier,
                walkParameters.walkLength(),
                walkParameters.returnFactor(),
                walkParameters.inOutFactor(),
                randomSeed
            );
            return new RandomWalkTask(
                taskGraph,
                nextNodeSupplier,
                taskSampler,
                walkParameters.walksPerNode(),
                result,
                progressTracker,
                terminationFlag
            );
        }
    }
}
