package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import com.carrotsearch.hppc.cursors.LongCursor;
import java.util.Iterator;
import java.util.SplittableRandom;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

/**
 * Runnable that performs the neighbor-join step for every node of one {@link Partition}.
 *
 * <p>For each node {@code v} in the partition:
 * <ol>
 *   <li>if {@code v} has an old-neighbor list, it is extended with a random sample of its reverse old neighbors;</li>
 *   <li>if {@code v} has a new-neighbor list, it is extended with a random sample of its reverse new neighbors, and
 *       every new neighbor is then joined with {@code v}, with every other new neighbor and with every old neighbor;</li>
 *   <li>{@code randomJoins} additional joins of {@code v} with uniformly drawn other nodes are attempted.</li>
 * </ol>
 * This looks like the local-join step of NN-Descent style KNN refinement (TODO confirm against the driver).
 *
 * <p>Thread-safety: each {@link NeighborList} is mutated only while holding its monitor, and every instance owns its
 * own {@link SplittableRandom} (obtained via {@code split()} in {@link Factory#create}). {@link #updateCount()} is an
 * unsynchronized field; presumably the caller reads it only after {@link #run()} has finished.
 */
final class JoinNeighbors implements Runnable {
    private final SplittableRandom random;
    private final SimilarityFunction similarityFunction;
    private final NeighborFilter neighborFilter;
    private final Neighbors allNeighbors;
    private final HugeObjectArray<LongArrayList> allOldNeighbors;
    private final HugeObjectArray<LongArrayList> allNewNeighbors;
    private final HugeObjectArray<LongArrayList> allReverseOldNeighbors;
    private final HugeObjectArray<LongArrayList> allReverseNewNeighbors;
    private final int sampledK;
    private final int randomJoins;
    private final ProgressTracker progressTracker;
    // Size of allNewNeighbors; used as the id range for random joins (presumably the total node count).
    private final long nodeCount;
    private final Partition partition;
    private final double perturbationRate;
    // Sum of the values returned by the (non-random) joins performed in run().
    private long updateCount = 0;

    /**
     * Creates {@link JoinNeighbors} tasks that share the algorithm parameters but each get their own random
     * generator split off from {@code splittableRandom}.
     *
     * <p>{@code SplittableRandom} instances are not thread-safe, and {@link #create} mutates the shared generator
     * via {@code split()}, so {@code create} should be called from a single thread.
     */
    static class Factory {
        private final SimilarityFunction similarityFunction;
        private final int sampledK;
        private final double perturbationRate;
        private final int randomJoins;
        private final SplittableRandom splittableRandom;
        private final ProgressTracker progressTracker;

        Factory(
            SimilarityFunction similarityFunction,
            int sampledK,
            double perturbationRate,
            int randomJoins,
            SplittableRandom splittableRandom,
            ProgressTracker progressTracker
        ) {
            this.similarityFunction = similarityFunction;
            this.sampledK = sampledK;
            this.perturbationRate = perturbationRate;
            this.randomJoins = randomJoins;
            this.splittableRandom = splittableRandom;
            this.progressTracker = progressTracker;
        }

        /**
         * Builds a task for {@code partition} with a freshly split random generator.
         */
        JoinNeighbors create(
            Partition partition,
            Neighbors neighbors,
            HugeObjectArray<LongArrayList> allOldNeighbors,
            HugeObjectArray<LongArrayList> allNewNeighbors,
            HugeObjectArray<LongArrayList> allReverseOldNeighbors,
            HugeObjectArray<LongArrayList> allReverseNewNeighbors,
            NeighborFilter neighborFilter
        ) {
            return new JoinNeighbors(
                partition,
                neighbors,
                allOldNeighbors,
                allNewNeighbors,
                allReverseOldNeighbors,
                allReverseNewNeighbors,
                neighborFilter,
                this.similarityFunction,
                this.sampledK,
                this.perturbationRate,
                this.randomJoins,
                this.splittableRandom.split(),
                this.progressTracker
            );
        }
    }

    JoinNeighbors(
        Partition partition,
        Neighbors allNeighbors,
        HugeObjectArray<LongArrayList> allOldNeighbors,
        HugeObjectArray<LongArrayList> allNewNeighbors,
        HugeObjectArray<LongArrayList> allReverseOldNeighbors,
        HugeObjectArray<LongArrayList> allReverseNewNeighbors,
        NeighborFilter neighborFilter,
        SimilarityFunction similarityFunction,
        int sampledK,
        double perturbationRate,
        int randomJoins,
        SplittableRandom random,
        ProgressTracker progressTracker
    ) {
        this.random = random;
        this.similarityFunction = similarityFunction;
        this.neighborFilter = neighborFilter;
        this.allNeighbors = allNeighbors;
        this.nodeCount = allNewNeighbors.size();
        this.allOldNeighbors = allOldNeighbors;
        this.allNewNeighbors = allNewNeighbors;
        this.allReverseOldNeighbors = allReverseOldNeighbors;
        this.allReverseNewNeighbors = allReverseNewNeighbors;
        this.sampledK = sampledK;
        this.randomJoins = randomJoins;
        this.partition = partition;
        this.progressTracker = progressTracker;
        this.perturbationRate = perturbationRate;
    }

    /**
     * Processes every node of the partition and logs the partition size as progress once at the end.
     */
    @Override
    public void run() {
        long startNode = partition.startNode();
        long endNode = startNode + partition.nodeCount();

        for (long nodeId = startNode; nodeId < endNode; nodeId++) {
            // Per-node lists may be absent (null); only present lists are extended.
            LongArrayList oldNeighbors = allOldNeighbors.get(nodeId);
            if (oldNeighbors != null) {
                combineNeighbors(allReverseOldNeighbors.get(nodeId), oldNeighbors);
            }

            LongArrayList newNeighbors = allNewNeighbors.get(nodeId);
            if (newNeighbors != null) {
                combineNeighbors(allReverseNewNeighbors.get(nodeId), newNeighbors);
                updateCount += joinNewNeighbors(nodeId, oldNeighbors, newNeighbors);
            }

            // Results of random joins are deliberately not added to updateCount.
            randomJoins(nodeCount, nodeId);
        }
        progressTracker.logProgress(partition.nodeCount());
    }

    /**
     * Joins every new neighbor of {@code nodeId} with {@code nodeId}, with every later new neighbor, and with every
     * old neighbor.
     *
     * @param nodeId       the node whose neighborhood is being explored
     * @param oldNeighbors old neighbors of {@code nodeId}, or {@code null} if it has none
     * @param newNeighbors new neighbors of {@code nodeId}
     * @return the sum of the values returned by the individual joins
     */
    private long joinNewNeighbors(long nodeId, @Nullable LongArrayList oldNeighbors, LongArrayList newNeighbors) {
        long updates = 0;

        // Read buffer/count directly: only the first elementsCount slots of buffer are valid.
        long[] newElements = newNeighbors.buffer;
        int newCount = newNeighbors.elementsCount;
        boolean symmetric = similarityFunction.isSymmetric();

        for (int i = 0; i < newCount; i++) {
            long newNeighbor = newElements[i];
            assert newNeighbor != nodeId;

            updates += join(newNeighbor, nodeId);

            // new x new: each unordered pair is visited once (j > i).
            for (int j = i + 1; j < newCount; j++) {
                updates += joinBothWays(newNeighbor, newElements[j], symmetric);
            }

            // new x old
            if (oldNeighbors != null) {
                for (LongCursor cursor : oldNeighbors) {
                    updates += joinBothWays(newNeighbor, cursor.value, symmetric);
                }
            }
        }
        return updates;
    }

    /**
     * Joins {@code first} and {@code second} in both directions; skipped (returns 0) when both ids are equal.
     * Uses a single similarity computation when the similarity function is symmetric, otherwise two directed joins
     * ({@code first -> second} first, then {@code second -> first}).
     */
    private long joinBothWays(long first, long second, boolean symmetric) {
        if (first == second) {
            return 0L;
        }
        if (symmetric) {
            return joinSymmetric(first, second);
        }
        long updates = join(first, second);
        updates += join(second, first);
        return updates;
    }

    /**
     * Appends a random sample of {@code reverseNeighbors} to {@code neighbors}. Each reverse neighbor is kept with
     * probability {@code min(sampledK, size) / size}. No de-duplication is performed, so an id already present in
     * {@code neighbors} may be appended again.
     *
     * @param reverseNeighbors candidates to sample from; {@code null} means nothing to add
     * @param neighbors        list that receives the sampled ids
     */
    private void combineNeighbors(@Nullable LongArrayList reverseNeighbors, LongArrayList neighbors) {
        if (reverseNeighbors == null) {
            return;
        }
        int reverseCount = reverseNeighbors.size();
        for (LongCursor cursor : reverseNeighbors) {
            if (random.nextInt(reverseCount) < sampledK) {
                neighbors.add(cursor.value);
            }
        }
    }

    /**
     * Attempts {@code randomJoins} joins of {@code nodeId} with ids drawn uniformly from {@code [0, nodeCount)}
     * excluding {@code nodeId} itself.
     */
    private void randomJoins(long nodeCount, long nodeId) {
        // SplittableRandom.nextLong(bound) rejects bound <= 0; with fewer than two nodes there is no other node.
        if (nodeCount < 2) {
            return;
        }
        for (int i = 0; i < randomJoins; i++) {
            // Draw from [0, nodeCount - 1) and shift past nodeId to cover all other nodes uniformly.
            long candidate = random.nextLong(nodeCount - 1);
            if (candidate >= nodeId) {
                candidate++;
            }
            join(nodeId, candidate);
        }
    }

    /**
     * Computes the similarity of {@code first} and {@code second} once and offers each node to the other's
     * neighbor list. Only {@code first}'s counter is incremented, so a symmetric join counts as one considered pair.
     *
     * @return the sum of the two {@code NeighborList.add} results, or 0 if the pair is excluded by the filter
     */
    private long joinSymmetric(long first, long second) {
        assert first != second;

        if (neighborFilter.excludeNodePair(first, second)) {
            return 0L;
        }

        double similarity = similarityFunction.computeSimilarity(first, second);
        long updates = 0L;

        NeighborList firstNeighbors = allNeighbors.getAndIncrementCounter(first);
        synchronized (firstNeighbors) {
            updates += firstNeighbors.add(second, similarity, random, perturbationRate);
        }

        NeighborList secondNeighbors = allNeighbors.get(second);
        synchronized (secondNeighbors) {
            updates += secondNeighbors.add(first, similarity, random, perturbationRate);
        }

        return updates;
    }

    /**
     * Offers {@code candidate} to {@code base}'s neighbor list using {@code similarity(base, candidate)}.
     *
     * @return the result of {@code NeighborList.add}, or 0 if the pair is excluded by the filter
     */
    private long join(long base, long candidate) {
        assert base != candidate;

        if (neighborFilter.excludeNodePair(base, candidate)) {
            return 0L;
        }

        double similarity = similarityFunction.computeSimilarity(base, candidate);
        NeighborList neighbors = allNeighbors.getAndIncrementCounter(base);
        synchronized (neighbors) {
            return neighbors.add(candidate, similarity, random, perturbationRate);
        }
    }

    /**
     * Returns the join counter maintained by the shared {@link Neighbors} instance (it is read from
     * {@code Neighbors.joinCounter()}, not tracked per task).
     */
    long nodePairsConsidered() {
        return allNeighbors.joinCounter();
    }

    /**
     * Returns the sum of the non-random join results accumulated by {@link #run()}.
     */
    long updateCount() {
        return updateCount;
    }
}
