package org.neo4j.gds.algorithms.machinelearning;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.DoubleArrayList;
import com.carrotsearch.hppc.predicates.LongLongPredicate;
import java.util.List;
import java.util.Objects;
import java.util.stream.LongStream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.SetBitsIterable;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.similarity.nodesim.TopKMap;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.AutoCloseableThreadLocal;
import org.neo4j.gds.utils.CloseableThreadLocal;

/* loaded from: input_file:org/neo4j/gds/algorithms/machinelearning/TopKMapComputer.class */
/**
 * Scores candidate links between a set of source nodes and a set of target nodes using a
 * knowledge-graph-embedding score function, and keeps the best {@code |topK|} targets per source.
 *
 * <p>A (source, target) pair is a candidate when the two nodes differ and the graph does not
 * already contain a relationship from source to target. Pairs whose score is {@code NaN} are
 * skipped.
 */
public class TopKMapComputer extends Algorithm<KGEPredictResult> {
    private final Graph graph;
    private final ProgressTracker progressTracker;
    private final BitSet sourceNodes;
    private final BitSet targetNodes;
    private final String nodeEmbeddingProperty;
    private final DoubleArrayList relationshipTypeEmbedding;
    private final Concurrency concurrency;
    private final int topK;
    private final ScoreFunction scoreFunction;
    private final boolean higherIsBetter;

    /**
     * @param graph                     graph that supplies node properties and existing relationships
     * @param sourceNodes               bit set whose set bits are the source node ids
     * @param targetNodes               bit set whose set bits are the target node ids
     * @param nodeEmbeddingProperty     name of the node property holding the node embeddings
     * @param relationshipTypeEmbedding embedding of the relationship type; copied into a primitive list
     * @param scoreFunction             score function used to create the link scorers
     * @param topK                      number of targets kept per source; only its absolute value is used
     * @param concurrency               concurrency used when consuming the source nodes in parallel
     * @param progressTracker           tracker that receives sub-task and progress events
     * @param terminationFlag           flag checked before each source node is processed
     */
    public TopKMapComputer(Graph graph, BitSet sourceNodes, BitSet targetNodes, String nodeEmbeddingProperty, List<Double> relationshipTypeEmbedding, ScoreFunction scoreFunction, int topK, Concurrency concurrency, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        super(progressTracker);
        this.graph = graph;
        this.progressTracker = progressTracker;
        this.sourceNodes = sourceNodes;
        this.targetNodes = targetNodes;
        this.nodeEmbeddingProperty = nodeEmbeddingProperty;
        this.relationshipTypeEmbedding = DoubleArrayList.from(
            relationshipTypeEmbedding.stream().mapToDouble(Double::doubleValue).toArray()
        );
        this.concurrency = concurrency;
        this.topK = topK;
        this.scoreFunction = scoreFunction;
        // DISTMULT is the only score function for which the top-K map is told that higher is better.
        this.higherIsBetter = scoreFunction == ScoreFunction.DISTMULT;
        this.terminationFlag = terminationFlag;
    }

    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    /**
     * Scores every candidate (source, target) pair and collects the best targets per source.
     *
     * <p>Source nodes are consumed in parallel. Each worker thread uses its own link scorer and
     * its own concurrent copy of the graph; both thread-local holders are closed when the method
     * exits, whether normally or by exception.
     *
     * @return the result wrapping the populated top-K map
     */
    public KGEPredictResult m2compute() {
        this.progressTracker.beginSubTask(estimateWorkload());
        TopKMap topKMap = new TopKMap(this.sourceNodes.capacity(), this.sourceNodes, Math.abs(this.topK), this.higherIsBetter);
        NodePropertyValues embeddings = this.graph.nodeProperties(this.nodeEmbeddingProperty);

        try (AutoCloseableThreadLocal<LinkScorer> threadLocalScorer = AutoCloseableThreadLocal.withInitial(
            () -> LinkScorerFactory.create(this.scoreFunction, embeddings, this.relationshipTypeEmbedding))) {
            // Nested so that the graph copies are released before progress is logged and before
            // the scorers are closed, and so that they are also released when scoring fails.
            try (CloseableThreadLocal<Graph> threadLocalGraph = CloseableThreadLocal.withInitial(this.graph::concurrentCopy)) {
                ParallelUtil.parallelStreamConsume(
                    new SetBitsIterable(this.sourceNodes).stream(),
                    this.concurrency,
                    this.terminationFlag,
                    sources -> sources.forEach(source -> {
                        this.terminationFlag.assertRunning();
                        LongLongPredicate candidateLink = isCandidateLink(threadLocalGraph.get());
                        LinkScorer scorer = threadLocalScorer.get();
                        scorer.init(source);
                        targetNodesStream()
                            // Source and target need distinct names: the predicate must see the pair.
                            .filter(target -> candidateLink.apply(source, target))
                            .forEach(target -> {
                                double score = scorer.computeScore(target);
                                if (!Double.isNaN(score)) {
                                    topKMap.put(source, target, score);
                                }
                            });
                    })
                );
            }
            this.progressTracker.logProgress();
        }

        this.progressTracker.endSubTask();
        return KGEPredictResult.of(topKMap);
    }

    /** Returns a fresh stream over the ids of all set bits in {@code targetNodes}, starting at 0. */
    private LongStream targetNodesStream() {
        return new SetBitsIterable(this.targetNodes, 0L).stream();
    }

    /** Returns the number of source nodes multiplied by the number of target nodes. */
    private long estimateWorkload() {
        return this.sourceNodes.cardinality() * this.targetNodes.cardinality();
    }

    /**
     * Builds the predicate that decides whether a (source, target) pair should be scored.
     *
     * @param graph graph used for the relationship existence check
     * @return predicate that is true when the nodes differ and {@code graph.exists(source, target)} is false
     */
    private LongLongPredicate isCandidateLink(Graph graph) {
        return (source, target) -> source != target && !graph.exists(source, target);
    }
}
