package org.neo4j.gds.wcc;

import com.carrotsearch.hppc.LongIntHashMap;
import com.carrotsearch.hppc.cursors.LongIntCursor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.RelationshipConsumer;
import org.neo4j.gds.api.RelationshipIterator;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.dss.DisjointSetStruct;
import org.neo4j.gds.core.utils.paged.dss.HugeAtomicDisjointSetStruct;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/wcc/Wcc.class */
/**
 * Weakly connected components computed on top of a {@link DisjointSetStruct} (union-find).
 *
 * <p>Two strategies are implemented in this class:
 * <ul>
 *   <li><b>Directed / thresholded</b>: the node id space is cut into batches of {@code batchSize} nodes and one
 *       task per batch unions the endpoints of every relationship of its nodes.</li>
 *   <li><b>Undirected without threshold</b>: a first pass unions only the first {@code NEIGHBOR_ROUNDS}
 *       relationships of every node, the most frequent component among {@code SAMPLING_SIZE} randomly picked
 *       nodes is determined, and a second pass processes the remaining relationships of all nodes that are not
 *       in that component.</li>
 * </ul>
 */
public class Wcc extends Algorithm<DisjointSetStruct> {
    /** Number of relationships per node that the sampling pass unions and the linking pass skips. */
    private static final int NEIGHBOR_ROUNDS = 2;
    /** Number of random node picks used to estimate the largest component. */
    private static final int SAMPLING_SIZE = 1024;
    /** Algorithm configuration (concurrency, seeding, threshold). */
    private final WccBaseConfig config;
    /** Seed component ids read from the configured seed property; {@code null} unless the config is incremental. */
    private final NodePropertyValues initialComponents;
    /** Executor that runs the union tasks. */
    private final ExecutorService executor;
    /** Node count of the graph, captured at construction time. */
    private final long nodeCount;
    /** Number of consecutive node ids handled by one directed union task. */
    private final long batchSize;
    /** Number of directed union tasks; used to pre-size the task list. */
    private final int threadSize;
    /** The input graph; not final so that {@code release()} can drop the reference. */
    private Graph graph;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/gds/wcc/Wcc$DirectedUnionTask.class */
    /**
     * Unions the endpoints of every relationship of the nodes in one batch.
     *
     * <p>A batch covers the node ids {@code [offset, min(offset + batchSize, nodeCount))}. Relationships are
     * iterated through a {@code concurrentCopy()} of the enclosing algorithm's graph that is taken per task.
     */
    public class DirectedUnionTask implements Runnable, RelationshipConsumer {
        /** Shared disjoint-set structure that receives the unions. */
        final DisjointSetStruct struct;
        /** Relationship iterator owned by this task. */
        final RelationshipIterator rels;
        /** First node id of the batch (inclusive). */
        private final long offset;
        /** End of the batch (exclusive). */
        private final long end;

        /**
         * @param components shared disjoint-set structure
         * @param batchStart first node id of the batch
         */
        DirectedUnionTask(DisjointSetStruct components, long batchStart) {
            this.struct = components;
            this.rels = Wcc.this.graph.concurrentCopy();
            this.offset = batchStart;
            this.end = Math.min(batchStart + Wcc.this.batchSize, Wcc.this.nodeCount);
        }

        /** Processes every node of the batch in ascending id order and reports its degree as progress. */
        @Override // java.lang.Runnable
        public void run() {
            for (long nodeId = this.offset; nodeId < this.end; nodeId++) {
                compute(nodeId);
                // Only look at the termination flag on every 10000th node id.
                if (nodeId % 10000 == 0) {
                    Wcc.this.terminationFlag.assertRunning();
                }
                Wcc.this.progressTracker.logProgress(Wcc.this.graph.degree(nodeId));
            }
        }

        /** Feeds all relationships of {@code nodeId} into {@link #accept(long, long)}. */
        void compute(long nodeId) {
            this.rels.forEachRelationship(nodeId, this);
        }

        /**
         * Relationship callback: merges the sets of both endpoints.
         *
         * @return always {@code true}
         */
        public boolean accept(long source, long target) {
            this.struct.union(source, target);
            return true;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/gds/wcc/Wcc$DirectedUnionWithThresholdTask.class */
    /**
     * Variant of {@link DirectedUnionTask} that only unions the endpoints of relationships whose property
     * value is not less than or equal to the threshold.
     */
    public class DirectedUnionWithThresholdTask extends DirectedUnionTask implements RelationshipWithPropertyConsumer {
        /** Property values at or below this value do not trigger a union. */
        private final double threshold;

        /**
         * @param thresholdValue property threshold
         * @param components     shared disjoint-set structure
         * @param batchStart     first node id of the batch
         */
        DirectedUnionWithThresholdTask(double thresholdValue, DisjointSetStruct components, long batchStart) {
            super(components, batchStart);
            this.threshold = thresholdValue;
        }

        /**
         * Iterates the relationships of {@code nodeId} together with their property value, passing
         * {@code Wcc.defaultWeight(threshold)} as the fallback value.
         */
        @Override // org.neo4j.gds.wcc.Wcc.DirectedUnionTask
        void compute(long nodeId) {
            double fallbackWeight = Wcc.defaultWeight(this.threshold);
            this.rels.forEachRelationship(nodeId, fallbackWeight, this);
        }

        /**
         * Relationship callback: unions the endpoints unless {@code weight <= threshold}.
         *
         * @return always {@code true}
         */
        public boolean accept(long source, long target, double weight) {
            // Deliberately the negation of "<=": a NaN weight compares false and therefore still unions.
            boolean passesThreshold = !(weight <= this.threshold);
            if (passesThreshold) {
                this.struct.union(source, target);
            }
            return true;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/wcc/Wcc$UndirectedSamplingTask.class */
    /**
     * Sampling pass of the undirected strategy: for every node of a partition, unions the node with the
     * targets of at most {@code Wcc.NEIGHBOR_ROUNDS} of its relationships.
     */
    public static final class UndirectedSamplingTask implements Runnable, RelationshipConsumer {
        /** Per-task {@code concurrentCopy()} of the input graph. */
        private final Graph graph;
        /** Node range handled by this task. */
        private final Partition partition;
        /** Shared disjoint-set structure that receives the unions. */
        private final DisjointSetStruct components;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;
        /** Relationships still to be accepted for the node currently being processed. */
        private long limit;

        UndirectedSamplingTask(Graph graph, Partition partition, DisjointSetStruct disjointSetStruct, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.partition = partition;
            this.components = disjointSetStruct;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        /**
         * Visits the partition's nodes in ascending id order, sampling relationships of each one and logging
         * {@code min(NEIGHBOR_ROUNDS, degree)} as progress.
         */
        @Override // java.lang.Runnable
        public void run() {
            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            for (long nodeId = startNode; nodeId < endNode; nodeId++) {
                reset();
                this.graph.forEachRelationship(nodeId, this);
                // Only look at the termination flag on every 10000th node id.
                if (nodeId % 10000 == 0) {
                    this.terminationFlag.assertRunning();
                }
                this.progressTracker.logProgress(Math.min(Wcc.NEIGHBOR_ROUNDS, this.graph.degree(nodeId)));
            }
        }

        /**
         * Relationship callback: unions both endpoints and consumes one unit of the per-node budget.
         *
         * @return {@code false} once the budget is used up, which ends the iteration for the current node
         */
        public boolean accept(long source, long target) {
            this.components.union(source, target);
            this.limit--;
            return this.limit != 0;
        }

        /** Restores the per-node relationship budget to {@code Wcc.NEIGHBOR_ROUNDS}. */
        public void reset() {
            // Tied to the shared constant so the budget and the progress reported in run() cannot drift apart.
            this.limit = Wcc.NEIGHBOR_ROUNDS;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/gds/wcc/Wcc$UndirectedUnionTask.class */
    /**
     * Linking pass of the undirected strategy: for every node of a partition that is not in the component
     * to skip and has more than {@code Wcc.NEIGHBOR_ROUNDS} relationships, unions the node with the targets
     * of all relationships delivered after the first {@code NEIGHBOR_ROUNDS}.
     *
     * <p>NOTE(review): this presumably relies on relationships being delivered in the same order as in the
     * sampling pass, so that the skipped ones are exactly those already unioned there — confirm.
     */
    public static final class UndirectedUnionTask implements Runnable, RelationshipConsumer {
        /** Per-task {@code concurrentCopy()} of the input graph. */
        private final Graph graph;
        /** Set id whose members are not processed by this task. */
        private final long skipComponent;
        /** Node range handled by this task. */
        private final Partition partition;
        /** Shared disjoint-set structure that receives the unions. */
        private final DisjointSetStruct components;
        private final ProgressTracker progressTracker;
        private final TerminationFlag terminationFlag;
        /** Number of relationships seen so far for the node currently being processed. */
        private long skip;

        UndirectedUnionTask(Graph graph, Partition partition, long j, DisjointSetStruct disjointSetStruct, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
            this.graph = graph.concurrentCopy();
            this.skipComponent = j;
            this.partition = partition;
            this.components = disjointSetStruct;
            this.progressTracker = progressTracker;
            this.terminationFlag = terminationFlag;
        }

        /**
         * Visits the partition's nodes in ascending id order. For each processed node,
         * {@code degree - NEIGHBOR_ROUNDS} is logged as progress.
         */
        @Override // java.lang.Runnable
        public void run() {
            long startNode = this.partition.startNode();
            long endNode = startNode + this.partition.nodeCount();
            for (long nodeId = startNode; nodeId < endNode; nodeId++) {
                if (this.components.setIdOf(nodeId) != this.skipComponent) {
                    int degree = this.graph.degree(nodeId);
                    if (degree > Wcc.NEIGHBOR_ROUNDS) {
                        reset();
                        this.graph.forEachRelationship(nodeId, this);
                        this.progressTracker.logProgress(degree - Wcc.NEIGHBOR_ROUNDS);
                    }
                }
                // Checked for every 10000th node id regardless of whether the node was processed, so a
                // partition consisting mostly of skipped or low-degree nodes still reacts to termination.
                if (nodeId % 10000 == 0) {
                    this.terminationFlag.assertRunning();
                }
            }
        }

        /**
         * Relationship callback: ignores the first {@code NEIGHBOR_ROUNDS} relationships of the current node
         * and unions the endpoints of every later one.
         *
         * @return always {@code true}
         */
        public boolean accept(long source, long target) {
            this.skip++;
            if (this.skip > Wcc.NEIGHBOR_ROUNDS) {
                this.components.union(source, target);
            }
            return true;
        }

        /** Resets the per-node relationship counter. */
        public void reset() {
            this.skip = 0L;
        }
    }

    /**
     * Builds the memory estimation for this algorithm, which consists of a single component named
     * {@code "dss"} taken from {@code HugeAtomicDisjointSetStruct.memoryEstimation(z)}.
     *
     * @param z forwarded unchanged to {@code HugeAtomicDisjointSetStruct.memoryEstimation}
     *          (presumably the "incremental" flag — confirm against that class)
     * @return the estimation, labelled with the simple class name of {@code Wcc}
     */
    public static MemoryEstimation memoryEstimation(boolean z) {
        String estimationName = Wcc.class.getSimpleName();
        return MemoryEstimations
            .builder(estimationName)
            .add("dss", HugeAtomicDisjointSetStruct.memoryEstimation(z))
            .build();
    }

    /**
     * Creates the algorithm and derives the batching parameters used by the directed strategy.
     *
     * @param graph           input graph
     * @param executorService executor on which the union tasks run
     * @param i               third argument of {@code ParallelUtil.adjustedBatchSize}
     *                        (presumably the minimum batch size — confirm)
     * @param wccBaseConfig   algorithm configuration; when it is incremental, the seed property values are
     *                        read from the graph
     * @param progressTracker progress tracker handed to the base class
     * @throws IllegalArgumentException if the resulting number of tasks does not fit into an {@code int}
     */
    public Wcc(Graph graph, ExecutorService executorService, int i, WccBaseConfig wccBaseConfig, ProgressTracker progressTracker) {
        super(progressTracker);
        this.graph = graph;
        this.config = wccBaseConfig;
        if (wccBaseConfig.isIncremental()) {
            this.initialComponents = graph.nodeProperties(wccBaseConfig.seedProperty());
        } else {
            this.initialComponents = null;
        }
        this.executor = executorService;
        this.nodeCount = graph.nodeCount();
        this.batchSize = ParallelUtil.adjustedBatchSize(this.nodeCount, wccBaseConfig.concurrency(), i, (long) Integer.MAX_VALUE);

        long requiredTasks = ParallelUtil.threadCount(this.batchSize, this.nodeCount);
        if (requiredTasks > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Too many nodes (%d) to run union find with the given concurrency (%d) and batchSize (%d)",
                this.nodeCount,
                wccBaseConfig.concurrency(),
                this.batchSize
            ));
        }
        this.threadSize = (int) requiredTasks;
    }

    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    /**
     * Runs the algorithm inside one progress sub-task.
     *
     * <p>A fresh {@link HugeAtomicDisjointSetStruct} is created (seeded with the initial components when the
     * config is incremental). Undirected graphs without a threshold use the sampling strategy; everything
     * else uses the batch-wise directed strategy.
     *
     * @return the disjoint-set structure holding the computed components
     */
    public DisjointSetStruct m88compute() {
        this.progressTracker.beginSubTask();

        long graphNodeCount = this.graph.nodeCount();
        HugeAtomicDisjointSetStruct components;
        if (this.config.isIncremental()) {
            components = new HugeAtomicDisjointSetStruct(graphNodeCount, this.initialComponents, this.config.concurrency());
        } else {
            components = new HugeAtomicDisjointSetStruct(graphNodeCount, this.config.concurrency());
        }

        boolean useSamplingStrategy = this.graph.isUndirected() && !this.config.hasThreshold();
        if (useSamplingStrategy) {
            computeUndirected(components);
        } else {
            computeDirected(components);
        }

        this.progressTracker.endSubTask();
        return components;
    }

    /** Drops this algorithm's reference to the input graph by setting the field to {@code null}. */
    public void release() {
        graph = null;
    }

    /**
     * @return the threshold exactly as reported by the configuration
     */
    public double threshold() {
        double configuredThreshold = config.threshold();
        return configuredThreshold;
    }

    /**
     * Directed strategy: creates one union task per batch of {@code batchSize} consecutive node ids and runs
     * all of them on the executor.
     *
     * <p>When the config has a threshold, the tasks only union relationships that pass it.
     *
     * @param disjointSetStruct shared structure that receives the unions
     */
    private void computeDirected(DisjointSetStruct disjointSetStruct) {
        List<Runnable> tasks = new ArrayList<>(this.threadSize);
        // The task flavour is the same for every batch, so decide once.
        boolean thresholded = this.config.hasThreshold();
        for (long batchStart = 0; batchStart < this.nodeCount; batchStart += this.batchSize) {
            if (thresholded) {
                tasks.add(new DirectedUnionWithThresholdTask(threshold(), disjointSetStruct, batchStart));
            } else {
                tasks.add(new DirectedUnionTask(disjointSetStruct, batchStart));
            }
        }
        ParallelUtil.run(tasks, this.executor);
    }

    /**
     * Undirected strategy: range-partitions the nodes by the configured concurrency, runs the sampling pass,
     * picks the (estimated) largest component and finally links everything outside of that component.
     *
     * @param disjointSetStruct shared structure that receives the unions
     */
    private void computeUndirected(DisjointSetStruct disjointSetStruct) {
        List<Partition> partitions = PartitionUtils.rangePartition(
            this.config.concurrency(),
            this.graph.nodeCount(),
            Function.identity(),
            Optional.empty()
        );
        sampleSubgraph(disjointSetStruct, partitions);
        long largestComponent = findLargestComponent(disjointSetStruct);
        linkRemaining(disjointSetStruct, partitions, largestComponent);
    }

    /**
     * Runs one {@link UndirectedSamplingTask} per partition on the executor.
     *
     * @param disjointSetStruct shared structure that receives the unions
     * @param list              node partitions, one task is created for each
     */
    private void sampleSubgraph(DisjointSetStruct disjointSetStruct, List<Partition> list) {
        List<UndirectedSamplingTask> tasks = list
            .stream()
            .map(partition -> new UndirectedSamplingTask(this.graph, partition, disjointSetStruct, this.progressTracker, this.terminationFlag))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, this.executor);
    }

    /**
     * Estimates the largest component by looking up the set id of {@code SAMPLING_SIZE} randomly chosen
     * nodes and returning the set id that was hit most often.
     *
     * @param disjointSetStruct structure to query for set ids
     * @return the most frequently sampled set id, or {@code -1} when the graph has no nodes
     */
    private long findLargestComponent(DisjointSetStruct disjointSetStruct) {
        // SplittableRandom.nextLong(bound) rejects a non-positive bound, so there is nothing to sample
        // from an empty graph.
        if (this.nodeCount <= 0) {
            return -1L;
        }

        SplittableRandom random = new SplittableRandom();
        LongIntHashMap hitsPerComponent = new LongIntHashMap();
        for (int sample = 0; sample < SAMPLING_SIZE; sample++) {
            long sampledNode = random.nextLong(this.nodeCount);
            hitsPerComponent.addTo(disjointSetStruct.setIdOf(sampledNode), 1);
        }

        long largestComponent = -1L;
        int mostHits = -1;
        for (LongIntCursor entry : hitsPerComponent) {
            // Strict comparison: on a tie the entry encountered first during iteration wins.
            if (entry.value > mostHits) {
                mostHits = entry.value;
                largestComponent = entry.key;
            }
        }
        return largestComponent;
    }

    /**
     * Runs one {@link UndirectedUnionTask} per partition on the executor.
     *
     * @param disjointSetStruct shared structure that receives the unions
     * @param list              node partitions, one task is created for each
     * @param j                 set id of the component whose nodes the tasks skip
     */
    private void linkRemaining(DisjointSetStruct disjointSetStruct, List<Partition> list, long j) {
        List<UndirectedUnionTask> tasks = list
            .stream()
            .map(partition -> new UndirectedUnionTask(this.graph, partition, j, disjointSetStruct, this.progressTracker, this.terminationFlag))
            .collect(Collectors.toList());
        ParallelUtil.run(tasks, this.executor);
    }

    /**
     * Fallback property value handed to the thresholded relationship iteration.
     *
     * <p>The result is the smallest {@code double} strictly greater than {@code d}, so for every finite
     * {@code d} it compares greater than the threshold. Adding a fixed {@code 1.0} does not guarantee that:
     * for magnitudes of 2^53 and above the sum rounds back to {@code d}.
     *
     * @param d the threshold
     * @return {@code Math.nextUp(d)}; {@code d} itself for positive infinity and NaN
     */
    private static double defaultWeight(double d) {
        return Math.nextUp(d);
    }
}
