package org.neo4j.gds.pricesteiner;

import com.carrotsearch.hppc.BitSet;
import java.util.function.LongToDoubleFunction;
import org.agrona.collections.MutableLong;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/pricesteiner/StrongPruning.class */
/**
 * Strong pruning phase of the prize-collecting Steiner tree computation.
 *
 * <p>Leaves of the tree provided by {@link TreeStructure} are peeled off one at a time. Each
 * peeled node accumulates a value: its own prize plus, for every previously peeled neighbour that
 * stayed attached to it, that neighbour's value minus the connecting edge cost. A node stays
 * attached to its remaining neighbour only if the edge cost is strictly smaller than its
 * accumulated value; otherwise it becomes the root of a separate subtree. Afterwards only the
 * subtree rooted at the node with the highest accumulated value is kept, and all other subtrees
 * are pruned.
 *
 * <p>In the parent array, {@code -1} marks a root and {@code -2} marks a node that is not part of
 * the result. Note that {@code treeStructure.degrees()} is consumed (zeroed) while peeling.
 */
public class StrongPruning {
    /** Parent marker for the root of a (sub)tree. */
    private static final long ROOT = -1L;
    /** Parent marker for a node that is not part of the resulting tree. */
    private static final long PRUNED = -2L;
    /** Sentinel meaning "no node found". */
    private static final long NO_NODE = -1L;

    private final TreeStructure treeStructure;
    private final BitSet activeOriginalNodes;
    private final LongToDoubleFunction prizes;
    private final HugeLongArray parentArray;
    private final HugeDoubleArray parentCostArray;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private long effectiveNodeCount;
    private double sumOfPrizes;
    private double totalWeight;

    public StrongPruning(TreeStructure treeStructure, BitSet bitSet, LongToDoubleFunction longToDoubleFunction, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.treeStructure = treeStructure;
        this.activeOriginalNodes = bitSet;
        this.prizes = longToDoubleFunction;
        this.parentArray = HugeLongArray.newArray(treeStructure.originalNodeCount());
        this.parentCostArray = HugeDoubleArray.newArray(treeStructure.originalNodeCount());
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        // Every node starts outside the result; peeling marks visited nodes as roots/children.
        this.parentArray.fill(PRUNED);
        this.effectiveNodeCount = this.activeOriginalNodes.cardinality();
    }

    /**
     * Runs the pruning phase, filling the parent/cost arrays and the aggregate statistics
     * returned by {@link #resultTree()}.
     */
    public void performPruning() {
        progressTracker.beginSubTask("Pruning Phase");
        if (activeOriginalNodes.cardinality() == 1) {
            // A single active node is the whole tree; use the long overload so ids beyond int range work.
            long onlyNode = activeOriginalNodes.nextSetBit(0L);
            parentArray.set(onlyNode, ROOT);
            effectiveNodeCount = 1L;
            sumOfPrizes = prizes.applyAsDouble(onlyNode);
        } else {
            // The queue is reused by the subtree pruning below.
            HugeLongArray queue = HugeLongArray.newArray(activeOriginalNodes.cardinality());
            long bestRoot = peelLeaves(queue);
            // The best node may have been attached to a parent after it was selected; cut that
            // edge so its subtree survives when all other roots (including its ancestor) are pruned.
            detachFromParent(bestRoot);
            pruneUnnecessarySubTrees(bestRoot, queue, parentArray);
        }
        progressTracker.endSubTask("Pruning Phase");
    }

    /**
     * Peels leaves off the tree, computing each node's accumulated value and parent link.
     *
     * @param queue work queue with capacity for every active node
     * @return the peeled node with the highest accumulated value (earliest peeled wins ties),
     *     or {@code -1} if no node was peeled
     */
    private long peelLeaves(HugeLongArray queue) {
        Graph tree = treeStructure.tree();
        HugeLongArray degrees = treeStructure.degrees();
        HugeDoubleArray values = HugeDoubleArray.newArray(treeStructure.originalNodeCount());

        long queueEnd = 0;
        for (long node = 0; node < degrees.size(); node++) {
            if (degrees.get(node) == 1) {
                queue.set(queueEnd++, node);
            }
        }

        long queueStart = 0;
        long bestRoot = NO_NODE;
        while (queueStart < queueEnd) {
            terminationFlag.assertRunning();
            long node = queue.get(queueStart++);

            double prize = prizes.applyAsDouble(node);
            values.addTo(node, prize);
            // Degree 0 marks the node as peeled so neighbours no longer see it as a parent.
            degrees.set(node, 0L);
            progressTracker.logProgress();
            sumOfPrizes += prize;

            // The only neighbour that is not yet peeled (if any) is the parent candidate.
            MutableLong parent = new MutableLong(NO_NODE);
            MutableDouble parentCost = new MutableDouble(-1.0d);
            tree.forEachRelationship(node, 1.0d, (source, target, weight) -> {
                if (degrees.get(target) <= 0) {
                    return true;
                }
                parent.set(target);
                parentCost.setValue(weight);
                return false;
            });

            parentArray.set(node, ROOT);
            if (bestRoot == NO_NODE || Double.compare(values.get(bestRoot), values.get(node)) < 0) {
                bestRoot = node;
            }

            long parentNode = parent.get();
            if (parentNode != NO_NODE) {
                double edgeCost = parentCost.doubleValue();
                double nodeValue = values.get(node);
                // Attach only if the subtree pays for its connecting edge.
                if (Double.compare(edgeCost, nodeValue) < 0) {
                    values.addTo(parentNode, nodeValue - edgeCost);
                    parentArray.set(node, parentNode);
                    parentCostArray.set(node, edgeCost);
                    totalWeight += edgeCost;
                }
                degrees.addTo(parentNode, -1L);
                if (degrees.get(parentNode) == 1) {
                    queue.set(queueEnd++, parentNode);
                }
            }
        }
        return bestRoot;
    }

    /**
     * Makes {@code node} a root by removing the edge to its parent (and its cost from the total
     * weight). Does nothing if {@code node} is {@code -1} or already has no parent.
     */
    private void detachFromParent(long node) {
        if (node < 0) {
            return;
        }
        if (parentArray.get(node) >= 0) {
            totalWeight -= parentCostArray.get(node);
            parentCostArray.set(node, 0.0d);
            parentArray.set(node, ROOT);
        }
    }

    /**
     * Prunes every subtree whose root (parent {@code -1}) is not {@code bestRoot}.
     *
     * @param bestRoot root of the subtree to keep
     * @param queue scratch queue reused for the traversal of each pruned subtree
     * @param parents parent array; pruned nodes are set to {@code -2}
     */
    void pruneUnnecessarySubTrees(long bestRoot, HugeLongArray queue, HugeLongArray parents) {
        long nodeCount = treeStructure.tree().nodeCount();
        for (long node = 0; node < nodeCount; node++) {
            if (parents.get(node) == ROOT && node != bestRoot) {
                pruneSubtree(node, queue, parents);
            }
        }
    }

    /** Returns the pruned tree: parent links, parent edge costs, node count, total weight and prize sum. */
    public PrizeSteinerTreeResult resultTree() {
        return new PrizeSteinerTreeResult(parentArray, parentCostArray, effectiveNodeCount, totalWeight, sumOfPrizes);
    }

    /**
     * Removes {@code subtreeRoot} and all its descendants (nodes reachable via child links) from
     * the result, updating node count, prize sum and total weight accordingly.
     */
    private void pruneSubtree(long subtreeRoot, HugeLongArray queue, HugeLongArray parents) {
        terminationFlag.assertRunning();
        Graph tree = treeStructure.tree();
        MutableLong queueEnd = new MutableLong(0L);
        queue.set(queueEnd.getAndIncrement(), subtreeRoot);
        long queueStart = 0;
        while (queueStart < queueEnd.get()) {
            long node = queue.get(queueStart++);
            progressTracker.logProgress();
            parents.set(node, PRUNED);
            effectiveNodeCount--;
            sumOfPrizes -= prizes.applyAsDouble(node);
            tree.forEachRelationship(node, 1.0d, (source, target, weight) -> {
                // Only descend into children, i.e. neighbours whose parent is the current node.
                if (parents.get(target) == source) {
                    queue.set(queueEnd.getAndIncrement(), target);
                    totalWeight -= weight;
                }
                return true;
            });
        }
    }
}
