package org.neo4j.graphalgo.impl.similarity;

import com.carrotsearch.hppc.LongHashSet;
import com.carrotsearch.hppc.cursors.LongCursor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import java.util.stream.Stream;
import org.neo4j.graphalgo.Orientation;
import org.neo4j.graphalgo.RelationshipType;
import org.neo4j.graphalgo.annotation.ValueClass;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.IdMapGraph;
import org.neo4j.graphalgo.api.RelationshipIterator;
import org.neo4j.graphalgo.core.Aggregation;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.huge.HugeGraph;
import org.neo4j.graphalgo.core.loading.GraphStore;
import org.neo4j.graphalgo.core.loading.HugeGraphUtil;
import org.neo4j.graphalgo.core.loading.IdMap;
import org.neo4j.graphalgo.core.loading.IdMapBuilder;
import org.neo4j.graphalgo.core.loading.IdsAndProperties;
import org.neo4j.graphalgo.core.loading.NodeImporter;
import org.neo4j.graphalgo.core.loading.NodesBatchBuffer;
import org.neo4j.graphalgo.core.loading.NodesBatchBufferBuilder;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArrayBuilder;
import org.neo4j.graphalgo.impl.similarity.SimilarityInput;
import org.neo4j.graphalgo.results.SimilarityResult;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.roaringbitmap.RoaringBitmap;

/* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm.class */
public final class ApproxNearestNeighborsAlgorithm<INPUT extends SimilarityInput> extends SimilarityAlgorithm<ApproxNearestNeighborsAlgorithm<INPUT>, INPUT> {
    /** Relationship type under which the "out" topology is stored in, and read back from, the per-iteration graph stores. */
    private static final RelationshipType ANN_OUT_GRAPH = RelationshipType.of("ANN_OUT");
    /** Relationship type under which the "in" topology is stored in, and read back from, the per-iteration graph stores. */
    private static final RelationshipType ANN_IN_GRAPH = RelationshipType.of("ANN_IN");
    /** Algorithm configuration; read for topK, similarityCutoff, sampling, p, maxIterations, concurrency, precision and randomSeed. */
    private final ApproximateNearestNeighborsConfig config;
    /** Wrapped similarity algorithm; input preparation, decoder factories and the similarity computer are delegated to it. */
    private final SimilarityAlgorithm<?, INPUT> algorithm;
    /** Receives one info message per iteration with the number of changes. */
    private final Log log;
    /** Shared cursor from which InitTask and ComputeTask workers claim input indices; reset to 0 whenever a new batch of those tasks is created. */
    private final AtomicLong nodeQueue;
    /** Number of the most recently executed iteration; exposed through {@link #iterations()}. */
    private final AtomicInteger actualIterations;
    /** Random source seeded from {@code config.randomSeed()}; the same instance is handed to every task. */
    private final Random random;
    /** Executor passed to ParallelUtil and to the relationship builders. */
    private final ExecutorService executor;
    /** Allocation tracker passed to the relationship builders and to graph store construction. */
    private final AllocationTracker tracker;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm$ComputeTask.class */
    /**
     * Worker that claims node indices from the shared {@code nodeQueue} and scores candidate pairs
     * found in each claimed node's neighborhoods.
     *
     * <p>For every claimed node, the neighbors read through the "new" relationship iterators are
     * compared pairwise with each other, and each of them is additionally compared with every
     * neighbor read through the "old" relationship iterators. Results are collected in task-local
     * top-K consumers, which are merged into the shared consumers via {@link #mergeInto}.
     */
    public class ComputeTask implements NeighborhoodTask {
        private final INPUT[] inputs;
        private final SimilarityComputer<INPUT> similarityComputer;
        private final RleDecoder rleDecoder;
        private final AnnTopKConsumer[] localTopKConsumers;
        private final RelationshipIterator oldOutRelationships;
        private final RelationshipIterator oldInRelationships;
        private final RelationshipIterator newOutRelationships;
        private final RelationshipIterator newInRelationships;
        private final double sampleRate;

        /**
         * @param inputs             all similarity inputs, indexed by node index
         * @param similarityComputer computes the similarity of two inputs
         * @param decoderFactory     supplies the decoder used by this task
         * @param consumerCount      number of task-local top-K consumers to create
         * @param oldStore           store whose ANN_OUT/ANN_IN graphs back the "old" iterators
         * @param newStore           store whose ANN_OUT/ANN_IN graphs back the "new" iterators
         * @param sampleRate         sample size handed to {@code ANNUtils.sampleNeighbors}
         */
        ComputeTask(INPUT[] inputs, SimilarityComputer<INPUT> similarityComputer, Supplier<RleDecoder> decoderFactory, int consumerCount, GraphStore oldStore, GraphStore newStore, double sampleRate) {
            this.inputs = inputs;
            this.similarityComputer = similarityComputer;
            this.rleDecoder = decoderFactory.get();
            this.localTopKConsumers = AnnTopKConsumer.initializeTopKConsumers(consumerCount, config.topK());
            this.oldOutRelationships = relationshipsOf(oldStore, ANN_OUT_GRAPH);
            this.oldInRelationships = relationshipsOf(oldStore, ANN_IN_GRAPH);
            this.newOutRelationships = relationshipsOf(newStore, ANN_OUT_GRAPH);
            this.newInRelationships = relationshipsOf(newStore, ANN_IN_GRAPH);
            this.sampleRate = sampleRate;
        }

        /** Returns a concurrent copy of the graph that the store exposes for the given relationship type. */
        private RelationshipIterator relationshipsOf(GraphStore store, RelationshipType type) {
            return store.getGraph(new RelationshipType[]{type}).concurrentCopy();
        }

        /**
         * Processes queue entries until the queue is exhausted or the algorithm is no longer running.
         */
        @Override
        public void run() {
            for (long nodeId = nodeQueue.getAndIncrement();
                 nodeId < inputs.length && running();
                 nodeId = nodeQueue.getAndIncrement()) {
                LongHashSet oldNeighbors = getNeighbors(nodeId, oldOutRelationships, oldInRelationships);
                long[] newNeighbors = getNeighbors(nodeId, newOutRelationships, newInRelationships).toArray();

                for (int first = 0; first < newNeighbors.length; first++) {
                    int sourceIndex = Math.toIntExact(newNeighbors[first]);
                    INPUT source = inputs[sourceIndex];

                    // new x new: every unordered pair is scored once
                    for (int second = first + 1; second < newNeighbors.length; second++) {
                        int targetIndex = Math.toIntExact(newNeighbors[second]);
                        recordSimilarity(sourceIndex, source, targetIndex, inputs[targetIndex]);
                    }

                    // new x old: self-pairs are skipped
                    for (LongCursor oldNeighbor : oldNeighbors) {
                        int targetIndex = Math.toIntExact(oldNeighbor.value);
                        INPUT target = inputs[targetIndex];
                        if (sourceIndex != targetIndex) {
                            recordSimilarity(sourceIndex, source, targetIndex, target);
                        }
                    }
                }
            }
        }

        /**
         * Scores one pair and, when the computer returns a result, offers it to the source's local
         * consumer and its reverse to the target's local consumer.
         */
        private void recordSimilarity(int sourceIndex, INPUT source, int targetIndex, INPUT target) {
            SimilarityResult result = similarityComputer.similarity(rleDecoder, source, target, config.similarityCutoff());
            if (result == null) {
                return;
            }
            localTopKConsumers[sourceIndex].applyAsInt(result);
            localTopKConsumers[targetIndex].applyAsInt(result.reverse());
        }

        /**
         * Collects the union of a node's neighbors over the "in" iterator (sampled when
         * {@code config.sampling()} is enabled) and over the "out" iterator (never sampled).
         */
        private LongHashSet getNeighbors(long nodeId, RelationshipIterator outRelationships, RelationshipIterator inRelationships) {
            long[] incoming = findNeighbors(nodeId, inRelationships).toArray();
            if (config.sampling()) {
                incoming = ANNUtils.sampleNeighbors(incoming, sampleRate, random);
            }
            LongHashSet outgoing = findNeighbors(nodeId, outRelationships);

            LongHashSet merged = new LongHashSet();
            merged.addAll(incoming);
            merged.addAll(outgoing);
            return merged;
        }

        /**
         * Applies each task-local consumer to the consumer at the same index of the given array.
         *
         * @return the sum of the values returned by the individual {@code apply} calls
         */
        @Override
        public int mergeInto(AnnTopKConsumer[] globalConsumers) {
            int changes = 0;
            int index = 0;
            for (AnnTopKConsumer global : globalConsumers) {
                changes += global.apply(localTopKConsumers[index]);
                index++;
            }
            return changes;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm$InitTask.class */
    /**
     * Worker that seeds the shared top-K consumers: for every node index claimed from the shared
     * {@code nodeQueue} it scores the node against a set of randomly selected other inputs.
     */
    public class InitTask implements Runnable {
        private final INPUT[] inputs;
        private final AnnTopKConsumer[] topKConsumers;
        private final RleDecoder rleDecoder;
        private final SimilarityComputer<INPUT> similarityComputer;

        /**
         * @param inputs             all similarity inputs, indexed by node index
         * @param topKConsumers      shared consumers; this task writes to the slot of each node it claims
         * @param decoderFactory     supplies the decoder used by this task
         * @param similarityComputer computes the similarity of two inputs
         */
        InitTask(INPUT[] inputs, AnnTopKConsumer[] topKConsumers, Supplier<RleDecoder> decoderFactory, SimilarityComputer<INPUT> similarityComputer) {
            this.inputs = inputs;
            this.topKConsumers = topKConsumers;
            this.rleDecoder = decoderFactory.get();
            this.similarityComputer = similarityComputer;
        }

        /**
         * Processes queue entries until the queue is exhausted or the algorithm is no longer running.
         */
        @Override
        public void run() {
            for (long claimed = nodeQueue.getAndIncrement();
                 claimed < inputs.length && running();
                 claimed = nodeQueue.getAndIncrement()) {
                int nodeIndex = Math.toIntExact(claimed);
                AnnTopKConsumer consumer = topKConsumers[nodeIndex];
                INPUT source = inputs[nodeIndex];

                int neighborCount = Math.abs(config.topK());
                Iterator<Integer> candidates = ANNUtils.selectRandomNeighbors(neighborCount, inputs.length, nodeIndex, random).iterator();
                while (candidates.hasNext()) {
                    INPUT candidate = inputs[candidates.next().intValue()];
                    SimilarityResult result = similarityComputer.similarity(rleDecoder, source, candidate, config.similarityCutoff());
                    if (result != null) {
                        consumer.applyAsInt(result);
                    }
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm$NeighborhoodTask.class */
    /**
     * A runnable unit of neighborhood work whose results can afterwards be folded into a shared
     * array of top-K consumers.
     */
    public interface NeighborhoodTask extends Runnable {
        /**
         * Merges the results gathered by this task into the given consumers.
         *
         * @param globalConsumers the consumers to merge into
         * @return a count that the enclosing algorithm sums over all tasks and logs as the number
         *         of changes in the iteration
         */
        int mergeInto(AnnTopKConsumer[] globalConsumers);
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm$NewOldGraph.class */
    /**
     * View over a graph that splits each node's relationship targets into "old" ones (already
     * recorded in the node's visited bitmap) and "new" ones (not yet recorded).
     */
    public static class NewOldGraph {
        private final Graph graph;
        private final RoaringBitmap[] visitedRelationships;

        /**
         * @param graph                graph whose relationships are inspected
         * @param visitedRelationships one bitmap per node, indexed by node id
         */
        NewOldGraph(Graph graph, RoaringBitmap[] visitedRelationships) {
            this.graph = graph;
            this.visitedRelationships = visitedRelationships;
        }

        /** Returns the targets of {@code nodeId} that are contained in the node's visited bitmap. */
        LongHashSet findOldNeighbors(long nodeId) {
            return collectNeighbors(nodeId, true);
        }

        /** Returns the targets of {@code nodeId} that are not contained in the node's visited bitmap. */
        LongHashSet findNewNeighbors(long nodeId) {
            return collectNeighbors(nodeId, false);
        }

        /**
         * Gathers the targets of {@code nodeId} whose membership in the node's visited bitmap equals
         * {@code visited}. Every relationship of the node is inspected; the traversal is never cut short.
         */
        private LongHashSet collectNeighbors(long nodeId, boolean visited) {
            LongHashSet matches = new LongHashSet();
            RoaringBitmap seen = visitedRelationships[(int) nodeId];
            graph.forEachRelationship(nodeId, (source, target) -> {
                if (seen.contains((int) target) == visited) {
                    matches.add(target);
                }
                return true;
            });
            return matches;
        }
    }

    /**
     * Pair of relationship builders, one created with {@code Orientation.NATURAL} and one with
     * {@code Orientation.REVERSE}, that are always fed together and finally turned into a
     * {@link GraphStore} holding an {@code ANN_OUT} and an {@code ANN_IN} topology.
     */
    @ValueClass
    public interface RelationshipImporter {
        /** Builder whose result is stored under {@code ANN_OUT}. */
        HugeGraphUtil.RelationshipsBuilder outImporter();

        /** Builder whose result is stored under {@code ANN_IN}. */
        HugeGraphUtil.RelationshipsBuilder inImporter();

        /**
         * Adds one relationship per similarity result held by the given consumers. Results in which
         * either item is {@code -1}, or in which both items are equal, are skipped.
         */
        default void consume(AnnTopKConsumer[] topKConsumers) {
            for (int index = 0; index < topKConsumers.length; index++) {
                topKConsumers[index].stream().forEach(result -> {
                    long source = result.item1;
                    long target = result.item2;
                    boolean bothPresent = source != -1 && target != -1;
                    if (bothPresent && source != target) {
                        addRelationship(source, target);
                    }
                });
            }
        }

        /**
         * Registers {@code source -> target} with the out builder and the mirrored pair
         * {@code target -> source} with the in builder.
         */
        default void addRelationship(long source, long target) {
            outImporter().addFromInternal(source, target);
            inImporter().addFromInternal(target, source);
        }

        /**
         * Builds both relationship sets and wraps their topologies in a graph store that has no
         * node properties and no relationship properties.
         */
        default GraphStore buildGraphStore(IdMap idMap, AllocationTracker allocationTracker) {
            HugeGraph.Relationships outgoing = outImporter().build();
            HugeGraph.Relationships incoming = inImporter().build();

            Map topologies = new HashMap();
            topologies.put(ANN_OUT_GRAPH, outgoing.topology());
            topologies.put(ANN_IN_GRAPH, incoming.topology());

            return GraphStore.of(idMap, Collections.emptyMap(), topologies, Collections.emptyMap(), allocationTracker);
        }

        /** Creates an importer with a fresh NATURAL and a fresh REVERSE builder over the given id map. */
        static RelationshipImporter of(IdMap idMap, ExecutorService executorService, AllocationTracker allocationTracker) {
            HugeGraphUtil.RelationshipsBuilder natural = new HugeGraphUtil.RelationshipsBuilder(idMap, Orientation.NATURAL, false, Aggregation.NONE, executorService, allocationTracker);
            HugeGraphUtil.RelationshipsBuilder reverse = new HugeGraphUtil.RelationshipsBuilder(idMap, Orientation.REVERSE, false, Aggregation.NONE, executorService, allocationTracker);
            return ImmutableRelationshipImporter.of(natural, reverse);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/neo4j/graphalgo/impl/similarity/ApproxNearestNeighborsAlgorithm$SetupTask.class */
    /**
     * Worker that, for a contiguous range of node ids, routes each node's old neighbors to the old
     * importer and its (optionally sampled) new neighbors to the new importer, and records the
     * forwarded new neighbors in the node's bitmap.
     */
    public class SetupTask implements Runnable {
        private final NewOldGraph graph;
        private final RelationshipImporter oldImporter;
        private final RelationshipImporter newImporter;
        private final double sampleSize;
        private final RoaringBitmap[] visitedRelationships;
        private final long startNodeId;
        private final long nodeCount;

        /**
         * @param graph                source of old and new neighbors
         * @param visitedRelationships bitmaps, indexed by node id, that receive the forwarded new neighbors
         * @param oldImporter          receives relationships to old neighbors
         * @param newImporter          receives relationships to new neighbors
         * @param sampleSize           sample size handed to {@code ANNUtils.sampleNeighbors}
         * @param startNodeId          first node id of the range (inclusive)
         * @param nodeCount            number of node ids in the range
         */
        SetupTask(NewOldGraph graph, RoaringBitmap[] visitedRelationships, RelationshipImporter oldImporter, RelationshipImporter newImporter, double sampleSize, long startNodeId, long nodeCount) {
            this.graph = graph;
            this.visitedRelationships = visitedRelationships;
            this.oldImporter = oldImporter;
            this.newImporter = newImporter;
            this.sampleSize = sampleSize;
            this.startNodeId = startNodeId;
            this.nodeCount = nodeCount;
        }

        /**
         * Walks the assigned range; stops early once the algorithm is no longer running.
         */
        @Override
        public void run() {
            long endExclusive = startNodeId + nodeCount;
            for (long nodeId = startNodeId; nodeId < endExclusive && running(); nodeId++) {
                int nodeIndex = Math.toIntExact(nodeId);

                for (LongCursor oldNeighbor : graph.findOldNeighbors(nodeId)) {
                    oldImporter.addRelationship(nodeId, oldNeighbor.value);
                }

                long[] newNeighbors = graph.findNewNeighbors(nodeId).toArray();
                if (config.sampling()) {
                    newNeighbors = ANNUtils.sampleNeighbors(newNeighbors, sampleSize, random);
                }

                for (long neighbor : newNeighbors) {
                    newImporter.addRelationship(nodeId, neighbor);
                }
                // Only the neighbors that were actually forwarded are marked.
                for (long neighbor : newNeighbors) {
                    visitedRelationships[nodeIndex].add(Math.toIntExact(neighbor));
                }
            }
        }
    }

    /**
     * Creates the algorithm around an existing similarity algorithm.
     *
     * @param approximateNearestNeighborsConfig configuration; also passed to the superclass and used to seed the random source
     * @param similarityAlgorithm               algorithm that input preparation and similarity computation are delegated to
     * @param graphDatabaseAPI                  database handle passed to the superclass
     * @param log                               log receiving per-iteration messages
     * @param executorService                   executor for the parallel tasks
     * @param allocationTracker                 tracker for graph allocations
     */
    public ApproxNearestNeighborsAlgorithm(ApproximateNearestNeighborsConfig approximateNearestNeighborsConfig, SimilarityAlgorithm<?, INPUT> similarityAlgorithm, GraphDatabaseAPI graphDatabaseAPI, Log log, ExecutorService executorService, AllocationTracker allocationTracker) {
        super(approximateNearestNeighborsConfig, graphDatabaseAPI);

        // collaborators handed in by the caller
        this.config = approximateNearestNeighborsConfig;
        this.algorithm = similarityAlgorithm;
        this.log = log;
        this.executor = executorService;
        this.tracker = allocationTracker;

        // state owned by this instance
        this.nodeQueue = new AtomicLong();
        this.actualIterations = new AtomicInteger();
        this.random = new Random(approximateNearestNeighborsConfig.randomSeed());
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Delegates input preparation to the wrapped similarity algorithm.
     *
     * @return whatever the wrapped algorithm produces for the given arguments
     */
    @Override
    public INPUT[] prepareInputs(Object obj, SimilarityConfig similarityConfig) {
        final INPUT[] prepared = algorithm.prepareInputs(obj, similarityConfig);
        return prepared;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Delegates creation of the decoder factory to the wrapped similarity algorithm.
     *
     * @return the factory returned by the wrapped algorithm for {@code i}
     */
    @Override
    public Supplier<RleDecoder> createDecoderFactory(int i) {
        final Supplier<RleDecoder> factory = algorithm.createDecoderFactory(i);
        return factory;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Delegates creation of the input-specific decoder factory to the wrapped similarity algorithm.
     *
     * @return the factory returned by the wrapped algorithm for the given inputs
     */
    @Override
    public Supplier<RleDecoder> inputDecoderFactory(INPUT[] inputArr) {
        final Supplier<RleDecoder> factory = algorithm.inputDecoderFactory(inputArr);
        return factory;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Delegates creation of the similarity computer to the wrapped similarity algorithm.
     *
     * @return the computer returned by the wrapped algorithm for the given arguments
     */
    @Override
    public SimilarityComputer<INPUT> similarityComputer(Double d, int[] iArr, int[] iArr2) {
        final SimilarityComputer<INPUT> computer = algorithm.similarityComputer(d, iArr, iArr2);
        return computer;
    }

    /**
     * Runs the iterative approximate-nearest-neighbor refinement and streams the results held by
     * the per-node top-K consumers.
     *
     * <p>Steps: (1) InitTasks seed one top-K consumer per input with randomly selected candidates;
     * (2) in each iteration the current consumer contents are turned into a graph, SetupTasks
     * split every node's neighbors into old and new ones and import them into two separate stores,
     * ComputeTasks score candidate pairs from those stores, and their local results are merged
     * back into the shared consumers; (3) the loop ends after {@code config.maxIterations()}
     * iterations or as soon as {@code shouldTerminate} accepts the number of changes.
     *
     * <p>The parameters {@code iArr}, {@code iArr2} and {@code d} are not read by this implementation.
     *
     * @param inputArr           similarity inputs, indexed by node index
     * @param similarityComputer computes the similarity of two inputs
     * @param supplier           factory for the decoders used by the worker tasks
     * @param i                  top-K size used for the shared consumers
     * @return a stream over the results of all shared top-K consumers
     */
    @Override
    protected Stream<SimilarityResult> similarityStream(INPUT[] inputArr, int[] iArr, int[] iArr2, SimilarityComputer<INPUT> similarityComputer, Supplier<RleDecoder> supplier, double d, int i) {
        double sampleSize = Math.min(this.config.p(), 1.0d) * Math.abs(this.config.topK());
        int inputSize = inputArr.length;

        AnnTopKConsumer[] topKConsumers = AnnTopKConsumer.initializeTopKConsumers(inputSize, i);
        ParallelUtil.runWithConcurrency(this.config.concurrency(), createInitTasks(inputArr, topKConsumers, supplier, similarityComputer), this.executor);

        IdsAndProperties nodes = buildNodes(inputArr);
        RoaringBitmap[] visitedRelationships = ANNUtils.initializeRoaringBitmaps(inputSize);
        RoaringBitmap[] tempVisitedRelationships = ANNUtils.initializeRoaringBitmaps(inputSize);

        for (int iteration = 1; iteration <= this.config.maxIterations(); iteration++) {
            // Fold what the previous iteration marked into the accumulated bitmaps, then start fresh.
            for (int node = 0; node < inputSize; node++) {
                visitedRelationships[node] = RoaringBitmap.or(visitedRelationships[node], tempVisitedRelationships[node]);
            }
            tempVisitedRelationships = ANNUtils.initializeRoaringBitmaps(inputSize);

            // Graph of the current top-K state.
            RelationshipImporter currentImporter = RelationshipImporter.of(nodes.idMap(), this.executor, this.tracker);
            currentImporter.consume(topKConsumers);
            IdMapGraph currentGraph = currentImporter.buildGraphStore(nodes.idMap(), this.tracker).getUnion();

            RelationshipImporter oldImporter = RelationshipImporter.of(nodes.idMap(), this.executor, this.tracker);
            RelationshipImporter newImporter = RelationshipImporter.of(nodes.idMap(), this.executor, this.tracker);
            // The setup tasks run with a concurrency of 1 and all write to the same two importers.
            ParallelUtil.runWithConcurrency(1, setupTasks(sampleSize, inputSize, visitedRelationships, tempVisitedRelationships, currentGraph, oldImporter, newImporter), this.executor);

            GraphStore oldGraphStore = oldImporter.buildGraphStore(nodes.idMap(), this.tracker);
            GraphStore newGraphStore = newImporter.buildGraphStore(nodes.idMap(), this.tracker);

            // computeTasks forwards its first store to ComputeTask's old* iterators and its second
            // store to the new* iterators, so the store built from oldImporter has to come first.
            // Passing them the other way round makes ComputeTask take its "new" neighbors from the
            // store that only holds already-visited relationships.
            Collection<NeighborhoodTask> computeTasks = computeTasks(sampleSize, inputArr, similarityComputer, oldGraphStore, supplier, newGraphStore);
            ParallelUtil.runWithConcurrency(this.config.concurrency(), computeTasks, this.executor);

            int changes = mergeConsumers(computeTasks, topKConsumers);
            this.log.info("ANN: Changes in iteration %d: %d", iteration, changes);
            this.actualIterations.set(iteration);

            if (shouldTerminate(changes, inputSize, this.config.topK())) {
                break;
            }
        }

        return Arrays.stream(topKConsumers).flatMap(consumer -> consumer.stream());
    }

    /**
     * Splits the node range {@code [0, inputSize)} into consecutive batches and creates one
     * {@link SetupTask} per batch. The number of tasks is {@code inputSize / batchSize + 1}, so the
     * last task covers the remainder, which may be empty.
     *
     * @param sampleSize   sample size forwarded to every task
     * @param inputSize    number of nodes
     * @param visited      bitmaps used by each task's {@link NewOldGraph} to tell old from new neighbors
     * @param tempVisited  bitmaps in which the tasks record the new neighbors they forward
     * @param graph        graph whose relationships are inspected
     * @param oldImporter  importer receiving old neighbors
     * @param newImporter  importer receiving new neighbors
     * @return the created tasks, ordered by ascending start node
     */
    private Collection<Runnable> setupTasks(double sampleSize, int inputSize, RoaringBitmap[] visited, RoaringBitmap[] tempVisited, Graph graph, RelationshipImporter oldImporter, RelationshipImporter newImporter) {
        int batchSize = ParallelUtil.adjustedBatchSize(inputSize, this.config.concurrency(), 100);
        int taskCount = inputSize / batchSize + 1;
        List<Runnable> tasks = new ArrayList<>(taskCount);

        long startNode = 0L;
        int batch = 0;
        while (batch < taskCount) {
            long nodesInBatch = Math.min(batchSize, inputSize - batch * batchSize);
            NewOldGraph newOldGraph = new NewOldGraph(graph, visited);
            tasks.add(new SetupTask(newOldGraph, tempVisited, oldImporter, newImporter, sampleSize, startNode, nodesInBatch));
            startNode += nodesInBatch;
            batch++;
        }
        return tasks;
    }

    /**
     * Resets the shared node queue and creates one {@link InitTask} per configured unit of
     * concurrency. All tasks share the given inputs and consumers.
     */
    private List<Runnable> createInitTasks(INPUT[] inputs, AnnTopKConsumer[] topKConsumers, Supplier<RleDecoder> decoderFactory, SimilarityComputer<INPUT> similarityComputer) {
        this.nodeQueue.set(0L);
        List<Runnable> tasks = new ArrayList<>();
        int created = 0;
        while (created < this.config.concurrency()) {
            tasks.add(new InitTask(inputs, topKConsumers, decoderFactory, similarityComputer));
            created++;
        }
        return tasks;
    }

    /**
     * Resets the shared node queue and creates one {@link ComputeTask} per configured unit of
     * concurrency.
     *
     * @param sampleSize         sample size forwarded to every task
     * @param inputs             similarity inputs; their count is also the size of each task's local consumer array
     * @param similarityComputer computes the similarity of two inputs
     * @param oldRelationships   store forwarded as the source of each task's "old" relationship iterators
     * @param decoderFactory     factory for the task decoders
     * @param newRelationships   store forwarded as the source of each task's "new" relationship iterators
     */
    private Collection<NeighborhoodTask> computeTasks(double sampleSize, INPUT[] inputs, SimilarityComputer<INPUT> similarityComputer, GraphStore oldRelationships, Supplier<RleDecoder> decoderFactory, GraphStore newRelationships) {
        this.nodeQueue.set(0L);
        List<NeighborhoodTask> tasks = new ArrayList<>();
        int created = 0;
        while (created < this.config.concurrency()) {
            tasks.add(new ComputeTask(inputs, similarityComputer, decoderFactory, inputs.length, oldRelationships, newRelationships, sampleSize));
            created++;
        }
        return tasks;
    }

    /**
     * Imports the ids of all inputs through a node batch buffer and builds an id map from them.
     * The largest id seen (never less than 0) is passed on to {@code IdMapBuilder.build}; the
     * resulting {@link IdsAndProperties} carries no properties.
     */
    private IdsAndProperties buildNodes(INPUT[] inputs) {
        HugeLongArrayBuilder idArrayBuilder = HugeLongArrayBuilder.of(inputs.length, AllocationTracker.EMPTY);
        NodeImporter importer = new NodeImporter(idArrayBuilder);
        long highestId = 0;
        NodesBatchBuffer buffer = new NodesBatchBufferBuilder()
                .nodeLabelIds(new LongHashSet())
                .capacity(inputs.length)
                .hasLabelInformation(false)
                .readProperty(false)
                .build();

        for (int index = 0; index < inputs.length; index++) {
            INPUT input = inputs[index];
            if (input.getId() > highestId) {
                highestId = input.getId();
            }
            buffer.add(input.getId(), -1L, (long[]) null);
            // Flush whenever the buffer reports it is full.
            if (buffer.isFull()) {
                importer.importNodes(buffer, (NodeImporter.PropertyReader) null);
                buffer.reset();
            }
        }
        // Flush whatever is left after the last input.
        importer.importNodes(buffer, (NodeImporter.PropertyReader) null);

        IdMap idMap = IdMapBuilder.build(idArrayBuilder, (Map) null, highestId, 1, AllocationTracker.EMPTY);
        return new IdsAndProperties(idMap, Collections.emptyMap());
    }

    /**
     * Lets every task merge its local results into the shared consumers.
     *
     * @return the sum of the values returned by the tasks' {@code mergeInto} calls
     */
    private int mergeConsumers(Iterable<NeighborhoodTask> tasks, AnnTopKConsumer[] topKConsumers) {
        int totalChanges = 0;
        for (NeighborhoodTask task : tasks) {
            totalChanges += task.mergeInto(topKConsumers);
        }
        return totalChanges;
    }

    /**
     * Decides whether the iteration loop may stop.
     *
     * @param changes   number of changes reported for the iteration
     * @param nodeCount number of inputs
     * @param topK      configured top-K value; only its magnitude is used
     * @return {@code true} when there were no changes, or when the number of changes is below
     *         {@code nodeCount * |topK| * config.precision()}
     */
    private boolean shouldTerminate(int changes, int nodeCount, int topK) {
        if (changes == 0) {
            return true;
        }
        // Widen before multiplying: nodeCount * |topK| can exceed Integer.MAX_VALUE for large
        // inputs, and Math.abs(int) stays negative for Integer.MIN_VALUE.
        double threshold = (double) nodeCount * Math.abs((long) topK) * this.config.precision();
        return changes < threshold;
    }

    /**
     * Collects the target ids of all relationships that the iterator reports for {@code nodeId}.
     * The traversal is never cut short; duplicates collapse in the returned set.
     */
    private LongHashSet findNeighbors(long nodeId, RelationshipIterator relationships) {
        final LongHashSet targets = new LongHashSet();
        relationships.forEachRelationship(nodeId, (source, target) -> {
            targets.add(target);
            return true;
        });
        return targets;
    }

    /**
     * Returns the number of the most recently executed iteration, as recorded by
     * {@code similarityStream}; the counter is not reset by this class.
     */
    public int iterations() {
        final int lastIteration = actualIterations.get();
        return lastIteration;
    }
}
