package org.neo4j.gds.similarity.nodesim;

import com.carrotsearch.hppc.BitSet;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.RelationshipConsumer;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
import org.neo4j.gds.core.utils.progress.BatchingProgressLogger;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.SimilarityGraphBuilder;
import org.neo4j.gds.similarity.SimilarityGraphResult;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.filtering.NodeFilter;

/* loaded from: input_file:org/neo4j/gds/similarity/nodesim/NodeSimilarity.class */
/**
 * Node Similarity algorithm.
 *
 * <p>{@code prepare()} caches, for every node that passes the configured degree cutoffs, the node's
 * target ids (and its relationship weights when a relationship weight property is configured), and
 * records which of those nodes pass the source and target node filters. Similarities between source
 * and target nodes are then computed with the configured {@link MetricSimilarityComputer}; pairs whose
 * similarity is {@code NaN} are dropped. Depending on the configuration the results are kept in full,
 * reduced per source node (topK), reduced globally (topN), or both, and are returned either as a
 * stream or as a similarity graph.
 */
public class NodeSimilarity extends Algorithm<NodeSimilarityResult> {
    private final Graph graph;
    // Target-id arrays are only sorted explicitly when the graph has more than one relationship type.
    private final boolean sortVectors;
    private final NodeSimilarityBaseConfig config;
    // Nodes that passed the degree filter and the source (resp. target) node filter; filled by prepare().
    private final BitSet sourceNodes;
    private final BitSet targetNodes;
    private final NodeFilter sourceNodeFilter;
    private final NodeFilter targetNodeFilter;
    private final ExecutorService executorService;
    private final int concurrency;
    private final MetricSimilarityComputer similarityComputer;
    // Per-node target ids; null for nodes rejected by the degree filter. Filled by prepare().
    private HugeObjectArray<long[]> neighbors;
    // Per-node relationship weights; only allocated (and filled by prepare()) when weighted.
    private HugeObjectArray<double[]> weights;
    private final boolean weighted;

    /**
     * Relationship consumer that counts a node's relationships, ignoring self-loops and any
     * relationship whose target equals the previously visited target.
     *
     * <p>NOTE(review): this only collapses all parallel relationships if targets are visited in
     * sorted order — confirm against the graph's adjacency ordering.
     */
    public static final class DegreeComputer implements RelationshipConsumer {
        long lastTarget = -1;
        int degree = 0;

        private DegreeComputer() {
        }

        /**
         * Counts {@code target} unless it is a self-loop or repeats the previous target.
         *
         * @return always {@code true}, so relationship iteration never stops early
         */
        public boolean accept(long source, long target) {
            if (source != target && lastTarget != target) {
                degree++;
            }
            lastTarget = target;
            return true;
        }

        /** Clears the counter so the instance can be reused for the next node. */
        void reset() {
            lastTarget = -1L;
            degree = 0;
        }
    }

    /** Receives one computed similarity for a (source, target) node pair. */
    @FunctionalInterface
    public interface SimilarityConsumer {
        void accept(long source, long target, double similarity);
    }

    /**
     * Creates an unfiltered instance whose similarity computer is built from the configured metric
     * and similarity cutoff, running with the configured concurrency.
     */
    public static NodeSimilarity create(
        Graph graph,
        NodeSimilarityBaseConfig config,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        return new NodeSimilarity(
            graph,
            config,
            config.similarityMetric().build(config.similarityCutoff()),
            config.concurrency(),
            executorService,
            progressTracker
        );
    }

    /** Creates an instance without source/target node filters. */
    public NodeSimilarity(
        Graph graph,
        NodeSimilarityBaseConfig config,
        MetricSimilarityComputer similarityComputer,
        int concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        this(
            graph,
            config,
            similarityComputer,
            NodeFilter.noOp,
            NodeFilter.noOp,
            concurrency,
            executorService,
            progressTracker
        );
    }

    /**
     * Creates an instance that only compares nodes accepted by {@code sourceNodeFilter} against
     * nodes accepted by {@code targetNodeFilter}.
     */
    public NodeSimilarity(
        Graph graph,
        NodeSimilarityBaseConfig config,
        MetricSimilarityComputer similarityComputer,
        NodeFilter sourceNodeFilter,
        NodeFilter targetNodeFilter,
        int concurrency,
        ExecutorService executorService,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.graph = graph;
        // With several relationship types a node's targets are presumably not globally sorted,
        // so prepare() sorts them explicitly in that case.
        this.sortVectors = graph.schema().relationshipSchema().availableTypes().size() > 1;
        this.sourceNodeFilter = sourceNodeFilter;
        this.targetNodeFilter = targetNodeFilter;
        this.concurrency = concurrency;
        this.config = config;
        this.similarityComputer = similarityComputer;
        this.executorService = executorService;
        this.sourceNodes = new BitSet(graph.nodeCount());
        this.targetNodes = new BitSet(graph.nodeCount());
        this.weighted = config.hasRelationshipWeightProperty();
    }

    /**
     * Runs the algorithm: prepares the neighbour vectors, then computes the similarities either as a
     * stream or as a similarity graph, depending on {@code config.computeToStream()}.
     *
     * @return a result holding exactly one of the stream or the graph
     */
    public NodeSimilarityResult compute() {
        progressTracker.beginSubTask();
        prepare();
        if (config.computeToStream()) {
            Stream<SimilarityResult> similarities = computeToStream();
            progressTracker.endSubTask();
            return ImmutableNodeSimilarityResult.of(Optional.of(similarities), Optional.empty());
        }
        SimilarityGraphResult similarityGraph = computeToGraph();
        progressTracker.endSubTask();
        return ImmutableNodeSimilarityResult.of(Optional.empty(), Optional.of(similarityGraph));
    }

    /** Computes the similarities as a stream, choosing the strategy from topK/topN/parallelism. */
    private Stream<SimilarityResult> computeToStream() {
        terminationFlag.assertRunning();
        // topN without topK is answered directly by one global TopNList.
        if (config.hasTopN() && !config.hasTopK()) {
            return computeTopN();
        }
        return config.isParallel() ? computeParallel() : computeSimilarityResultStream();
    }

    /**
     * Computes the similarities as a graph. With topK and no topN the TopKMap is wrapped directly;
     * otherwise the similarity stream is materialized through a {@link SimilarityGraphBuilder}.
     */
    private SimilarityGraphResult computeToGraph() {
        Graph similarityGraph;
        boolean isTopKGraph = false;
        if (config.hasTopK() && !config.hasTopN()) {
            terminationFlag.assertRunning();
            TopKMap topKMap = config.isParallel() ? computeTopKMapParallel() : computeTopKMap();
            similarityGraph = new TopKGraph(graph, topKMap);
            isTopKGraph = true;
        } else {
            similarityGraph = new SimilarityGraphBuilder(graph, concurrency, executorService, terminationFlag)
                .build(computeToStream());
        }
        return new SimilarityGraphResult(similarityGraph, sourceNodes.cardinality(), isTopKGraph);
    }

    /**
     * Fills {@code neighbors} (and {@code weights} when weighted) for every node passing the degree
     * filter, and marks those nodes in {@code sourceNodes}/{@code targetNodes} according to the
     * source/target node filters. Nodes rejected by the degree filter keep a {@code null} entry.
     */
    private void prepare() {
        progressTracker.beginSubTask();
        neighbors = HugeObjectArray.newArray(long[].class, graph.nodeCount());
        if (weighted) {
            weights = HugeObjectArray.newArray(double[].class, graph.nodeCount());
        }
        // Both helpers are reused across nodes and reset before each node.
        DegreeComputer degreeComputer = new DegreeComputer();
        VectorComputer vectorComputer = VectorComputer.of(graph, weighted);
        DegreeFilter degreeFilter = new DegreeFilter(config.degreeCutoff(), config.upperDegreeCutoff());
        neighbors.setAll(nodeId -> {
            graph.forEachRelationship(nodeId, degreeComputer);
            int degree = degreeComputer.degree;
            degreeComputer.reset();
            vectorComputer.reset(degree);
            progressTracker.logProgress(graph.degree(nodeId));
            if (!degreeFilter.apply(degree)) {
                return null;
            }
            if (sourceNodeFilter.test(nodeId)) {
                sourceNodes.set(nodeId);
            }
            if (targetNodeFilter.test(nodeId)) {
                targetNodes.set(nodeId);
            }
            vectorComputer.forEachRelationship(nodeId);
            if (sortVectors) {
                vectorComputer.sortTargetIds();
            }
            if (weighted) {
                weights.set(nodeId, vectorComputer.getWeights());
            }
            // NOTE(review): assumes reset(degree) sizes the buffer exactly to the node's targets — confirm.
            return vectorComputer.targetIds.buffer;
        });
        progressTracker.endSubTask();
    }

    /** Single-threaded strategy: all pairs, or topK (optionally followed by topN). */
    private Stream<SimilarityResult> computeSimilarityResultStream() {
        if (!config.hasTopK()) {
            return computeAll();
        }
        TopKMap topKMap = computeTopKMap();
        return config.hasTopN() ? computeTopN(topKMap) : topKMap.stream();
    }

    /** Parallel strategy: all pairs, or topK (optionally followed by topN). */
    private Stream<SimilarityResult> computeParallel() {
        if (!config.hasTopK()) {
            return computeAllParallel();
        }
        TopKMap topKMap = computeTopKMapParallel();
        return config.hasTopN() ? computeTopN(topKMap) : topKMap.stream();
    }

    /**
     * Lazily streams every non-NaN similarity between source and target nodes.
     *
     * <p>NOTE(review): the returned stream is lazy, so the similarity computations, and their
     * progress logging, happen only after the sub-task has already been ended here.
     */
    private Stream<SimilarityResult> computeAll() {
        progressTracker.beginSubTask(calculateWorkload());
        Stream<SimilarityResult> similarities = loggableAndTerminatableSourceNodeStream()
            .boxed()
            .flatMap(this::computeSimilaritiesForNode);
        progressTracker.endSubTask();
        return similarities;
    }

    /**
     * Parallel variant of {@link #computeAll()}: the source node stream is handed to
     * {@link ParallelUtil#parallelStream} with the configured concurrency. Unlike the sequential
     * variant, no progress sub-task is opened here.
     */
    private Stream<SimilarityResult> computeAllParallel() {
        return ParallelUtil.parallelStream(
            loggableAndTerminatableSourceNodeStream(),
            concurrency,
            nodeIds -> nodeIds.boxed().flatMap(this::computeSimilaritiesForNode)
        );
    }

    /**
     * Single-threaded topK computation. Without node filters, each unordered pair is computed once
     * and stored in both directions; with filters, every (source, target) pair with
     * source != target is computed.
     */
    private TopKMap computeTopKMap() {
        progressTracker.beginSubTask(calculateWorkload());
        TopKMap topKMap = newTopKMap();
        boolean symmetric = isUnfiltered();
        loggableAndTerminatableSourceNodeStream().forEach(sourceNodeId -> {
            if (symmetric) {
                targetNodesStream(sourceNodeId + 1).forEach(targetNodeId ->
                    computeSimilarityFor(sourceNodeId, targetNodeId, (source, target, similarity) -> {
                        topKMap.put(source, target, similarity);
                        topKMap.put(target, source, similarity);
                    }));
            } else {
                targetNodesStream()
                    .filter(targetNodeId -> targetNodeId != sourceNodeId)
                    .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put));
            }
        });
        progressTracker.endSubTask();
        return topKMap;
    }

    /**
     * Parallel topK computation. Each source node writes only its own entries, so the symmetric
     * shortcut is not used; every (source, target) pair with source != target is computed.
     */
    private TopKMap computeTopKMapParallel() {
        progressTracker.beginSubTask(calculateWorkload());
        TopKMap topKMap = newTopKMap();
        ParallelUtil.parallelStreamConsume(
            loggableAndTerminatableSourceNodeStream(),
            concurrency,
            terminationFlag,
            nodeIds -> nodeIds.forEach(sourceNodeId -> targetNodesStream()
                // Compare source with target; the decompiled `j != j` was always false.
                .filter(targetNodeId -> targetNodeId != sourceNodeId)
                .forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topKMap::put)))
        );
        progressTracker.endSubTask();
        return topKMap;
    }

    /**
     * Single-threaded global topN without topK. Without node filters, each unordered pair is
     * offered once; with filters, every (source, target) pair with source != target is offered.
     */
    private Stream<SimilarityResult> computeTopN() {
        progressTracker.beginSubTask(calculateWorkload());
        TopNList topNList = new TopNList(config.normalizedN());
        boolean symmetric = isUnfiltered();
        loggableAndTerminatableSourceNodeStream().forEach(sourceNodeId -> {
            LongStream candidates = symmetric
                ? targetNodesStream(sourceNodeId + 1)
                : targetNodesStream().filter(targetNodeId -> targetNodeId != sourceNodeId);
            candidates.forEach(targetNodeId -> computeSimilarityFor(sourceNodeId, targetNodeId, topNList::add));
        });
        progressTracker.endSubTask();
        return topNList.stream();
    }

    /** Reduces an already computed TopKMap to the global topN. */
    private Stream<SimilarityResult> computeTopN(TopKMap topKMap) {
        TopNList topNList = new TopNList(config.normalizedN());
        topKMap.forEach(topNList::add);
        return topNList.stream();
    }

    /**
     * Creates a TopKMap over the source nodes with |normalizedK| entries per node. A positive K
     * keeps the highest similarities (descending order), a negative K keeps the lowest.
     */
    private TopKMap newTopKMap() {
        int k = config.normalizedK();
        return new TopKMap(
            neighbors.size(),
            sourceNodes,
            Math.abs(k),
            k > 0 ? SimilarityResult.DESCENDING : SimilarityResult.ASCENDING
        );
    }

    /**
     * True when neither node filter is set. Only then do sourceNodes and targetNodes coincide, which
     * is what allows visiting only pairs with target > source.
     */
    private boolean isUnfiltered() {
        return sourceNodeFilter.equals(NodeFilter.noOp) && targetNodeFilter.equals(NodeFilter.noOp);
    }

    /** Streams the ids of source nodes, starting at {@code offset}. */
    private LongStream sourceNodesStream(long offset) {
        return new SetBitsIterable(sourceNodes, offset).stream();
    }

    private LongStream sourceNodesStream() {
        return sourceNodesStream(0L);
    }

    /** Source node ids, with periodic termination checks attached. */
    private LongStream loggableAndTerminatableSourceNodeStream() {
        return checkProgress(sourceNodesStream());
    }

    /** Streams the ids of target nodes, starting at {@code offset}. */
    private LongStream targetNodesStream(long offset) {
        return new SetBitsIterable(targetNodes, offset).stream();
    }

    private LongStream targetNodesStream() {
        return targetNodesStream(0L);
    }

    /**
     * Lazily streams the non-NaN similarities of {@code sourceNodeId} against its candidate targets:
     * only targets with a larger id when no node filter is set (so each unordered pair appears once),
     * otherwise every target other than the source itself.
     */
    private Stream<SimilarityResult> computeSimilaritiesForNode(long sourceNodeId) {
        LongStream candidates = isUnfiltered()
            ? targetNodesStream(sourceNodeId + 1)
            : targetNodesStream().filter(targetNodeId -> targetNodeId != sourceNodeId);
        return candidates
            .mapToObj(targetNodeId -> {
                // Holder stays null when the similarity is NaN and the consumer is never invoked.
                SimilarityResult[] holder = {null};
                computeSimilarityFor(sourceNodeId, targetNodeId, (source, target, similarity) ->
                    holder[0] = new SimilarityResult(source, target, similarity));
                return holder[0];
            })
            .filter(Objects::nonNull);
    }

    /**
     * Computes the similarity of one pair (weighted when configured) and passes it to
     * {@code consumer} unless it is {@code NaN}.
     */
    private void computeSimilarityFor(long source, long target, SimilarityConsumer consumer) {
        long[] sourceVector = neighbors.get(source);
        long[] targetVector = neighbors.get(target);
        double similarity = weighted
            ? computeWeightedSimilarity(sourceVector, targetVector, weights.get(source), weights.get(target))
            : computeSimilarity(sourceVector, targetVector);
        if (Double.isNaN(similarity)) {
            return;
        }
        consumer.accept(source, target, similarity);
    }

    /** Delegates to the weighted similarity computer and logs one unit of progress. */
    private double computeWeightedSimilarity(
        long[] sourceVector,
        long[] targetVector,
        double[] sourceWeights,
        double[] targetWeights
    ) {
        double similarity = similarityComputer.computeWeightedSimilarity(
            sourceVector, targetVector, sourceWeights, targetWeights);
        progressTracker.logProgress();
        return similarity;
    }

    /** Delegates to the unweighted similarity computer and logs one unit of progress. */
    private double computeSimilarity(long[] sourceVector, long[] targetVector) {
        double similarity = similarityComputer.computeSimilarity(sourceVector, targetVector);
        progressTracker.logProgress();
        return similarity;
    }

    /**
     * Attaches a termination check to the node stream, run for ids with no bits in common with
     * {@code MAXIMUM_LOG_INTERVAL}.
     *
     * <p>NOTE(review): if MAXIMUM_LOG_INTERVAL is a power of two rather than a 2^k-1 mask, this fires
     * for runs of consecutive ids instead of every 2^k-th id — confirm intent.
     */
    private LongStream checkProgress(LongStream nodeIds) {
        return nodeIds.peek(nodeId -> {
            if ((nodeId & BatchingProgressLogger.MAXIMUM_LOG_INTERVAL) == 0) {
                terminationFlag.assertRunning();
            }
        });
    }

    /**
     * Estimates the number of similarity computations for progress tracking: |sources| x |targets|,
     * halved for single-threaded runs without node filters (only pairs with target > source are
     * visited there).
     */
    private long calculateWorkload() {
        long workload = sourceNodes.cardinality() * targetNodes.cardinality();
        if (concurrency == 1 && isUnfiltered()) {
            workload /= 2;
        }
        return workload;
    }
}
