package org.neo4j.graphalgo.beta.modularity;

import com.carrotsearch.hppc.BitSet;
import com.carrotsearch.hppc.cursors.LongLongCursor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableDouble;
import org.jetbrains.annotations.Nullable;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.NodeProperties;
import org.neo4j.graphalgo.beta.k1coloring.ImmutableK1ColoringStreamConfig;
import org.neo4j.graphalgo.beta.k1coloring.K1Coloring;
import org.neo4j.graphalgo.beta.k1coloring.K1ColoringFactory;
import org.neo4j.graphalgo.core.concurrency.ParallelUtil;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeAtomicDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeDoubleArray;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.core.utils.paged.HugeLongLongMap;
import org.neo4j.graphalgo.core.utils.paged.PageFiller;
import org.neo4j.graphalgo.utils.CloseableThreadLocal;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/graphalgo/beta/modularity/ModularityOptimization.class */
public final class ModularityOptimization extends Algorithm<ModularityOptimization, ModularityOptimization> {
    /** Iteration cap handed to the K1Coloring pre-pass in {@link #computeColoring()}. */
    private static final int K1COLORING_MAX_ITERATIONS = 5;

    private final int concurrency;
    private final int maxIterations;
    private final long nodeCount;
    private final long batchSize;
    private final double tolerance;
    private final Graph graph;
    // Optional; null means every node starts in its own community (see init()).
    private final NodeProperties seedProperty;
    private final ExecutorService executor;
    private final AllocationTracker tracker;
    private int iterationCounter;
    private boolean didConverge = false;
    // Half of the summed per-node relationship weights, computed in init().
    private double totalNodeWeight = 0.0d;
    // Starts below any real result so the first iteration always counts as an improvement.
    private double modularity = -1.0d;
    private BitSet colorsUsed;
    private HugeLongArray colors;
    private HugeLongArray currentCommunities;
    private HugeLongArray nextCommunities;
    // Internal (dense) community id -> seed/derived community id; only set when seeding.
    private HugeLongArray reverseSeedCommunityMapping;
    private HugeDoubleArray cumulativeNodeWeights;
    private HugeDoubleArray nodeCommunityInfluences;
    private HugeAtomicDoubleArray communityWeights;
    private HugeAtomicDoubleArray communityWeightUpdates;

    /**
     * Creates a modularity optimization run over {@code graph}.
     *
     * @param graph             graph to optimize; its node count sizes every internal array
     * @param maxIterations     upper bound on optimization iterations; must be at least 1
     * @param tolerance         minimum modularity gain an iteration must achieve to keep iterating
     * @param seedProperty      optional initial community per node; negative (or missing) values mark a node as unseeded
     * @param concurrency       number of workers used for the parallel phases
     * @param minBatchSize      batch-size hint forwarded to {@code ParallelUtil.adjustedBatchSize}
     *                          (presumably its lower bound — confirm against ParallelUtil)
     * @param executorService   executor that runs the per-color optimization tasks
     * @param progressLogger    logger receiving progress messages
     * @param allocationTracker tracker passed to every huge-array allocation
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public ModularityOptimization(Graph graph, int maxIterations, double tolerance, @Nullable NodeProperties seedProperty, int concurrency, int minBatchSize, ExecutorService executorService, ProgressLogger progressLogger, AllocationTracker allocationTracker) {
        // Fail fast, before deriving anything from the arguments.
        if (maxIterations < 1) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Need to run at least one iteration, but got %d", maxIterations));
        }
        this.graph = graph;
        this.nodeCount = graph.nodeCount();
        this.maxIterations = maxIterations;
        this.tolerance = tolerance;
        this.seedProperty = seedProperty;
        this.executor = executorService;
        this.concurrency = concurrency;
        this.progressLogger = progressLogger;
        this.tracker = allocationTracker;
        // Upper bound keeps the value safe for the (int) cast in computeColoring().
        this.batchSize = ParallelUtil.adjustedBatchSize(this.nodeCount, concurrency, minBatchSize, Integer.MAX_VALUE);
    }

    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    /**
     * Runs the optimization.
     *
     * <p>Colors the graph, initializes communities (seeded or one per node), then iterates. Each
     * iteration processes every used color in ascending order and recomputes modularity. Iteration
     * stops once modularity fails to improve by more than {@code tolerance} (converged) or after
     * {@code maxIterations} iterations.
     *
     * @return this instance, for chaining
     */
    public ModularityOptimization m4compute() {
        this.progressLogger.logMessage(":: Start");
        this.progressLogger.logMessage(":: Initialization :: Start");
        computeColoring();
        initSeeding();
        init();
        this.progressLogger.logMessage(":: Initialization :: Finished");

        this.iterationCounter = 0;
        while (this.iterationCounter < this.maxIterations) {
            int iteration = this.iterationCounter + 1;
            this.progressLogger.logMessage(StringFormatting.formatWithLocale(":: Iteration %d :: Start", iteration));
            this.nodeCommunityInfluences.fill(0.0d);
            // Nodes sharing a color are optimized together; colors are visited in ascending order.
            for (long color = this.colorsUsed.nextSetBit(0); color != -1; color = this.colorsUsed.nextSetBit(color + 1)) {
                assertRunning();
                optimizeForColor(color);
            }
            boolean converged = !updateModularity();
            this.progressLogger.logMessage(StringFormatting.formatWithLocale(":: Iteration %d :: Finished", iteration));
            this.iterationCounter++;
            if (converged) {
                this.didConverge = true;
                break;
            }
            this.progressLogger.reset(this.graph.relationshipCount());
        }
        this.progressLogger.logMessage(":: Finished");
        return this;
    }

    /** Runs K1Coloring on the graph and stores the per-node colors plus the set of used colors. */
    private void computeColoring() {
        ImmutableK1ColoringStreamConfig coloringConfig = ImmutableK1ColoringStreamConfig.builder()
            .concurrency(this.concurrency)
            .maxIterations(K1COLORING_MAX_ITERATIONS)
            .batchSize((int) this.batchSize)
            .build();
        K1Coloring coloring = (K1Coloring) new K1ColoringFactory()
            .build(this.graph, coloringConfig, this.tracker, this.progressLogger.getLog())
            .withTerminationFlag(this.terminationFlag);
        this.colors = coloring.m2compute();
        this.colorsUsed = coloring.usedColors();
    }

    /**
     * Allocates {@code currentCommunities}. When a seed property is present, maps each node's
     * community to a dense internal id (0, 1, 2, ...) and records the reverse mapping.
     *
     * <p>Nodes with a negative seed get the derived key
     * {@code originalNodeId + maxSeed + 1}. This keeps them strictly above every real seed
     * (seeds lie in {@code [0, maxSeed]}), so an unseeded node can never start out merged with a
     * seeded community.
     */
    private void initSeeding() {
        this.currentCommunities = HugeLongArray.newArray(this.nodeCount, this.tracker);
        if (this.seedProperty == null) {
            return;
        }

        long maxSeedCommunity = this.seedProperty.getMaxPropertyValue().orElse(0L);
        HugeLongLongMap communityMapping = new HugeLongLongMap(this.nodeCount, this.tracker);
        long lastInternalCommunity = -1;

        for (long nodeId = 0; nodeId < this.nodeCount; nodeId++) {
            long seedCommunity = (long) this.seedProperty.nodeProperty(nodeId, -1.0d);
            if (seedCommunity < 0) {
                // "+ 1" avoids colliding with a node seeded with exactly maxSeedCommunity.
                seedCommunity = this.graph.toOriginalNodeId(nodeId) + maxSeedCommunity + 1;
            }
            long internalCommunity = communityMapping.getOrDefault(seedCommunity, -1L);
            if (internalCommunity < 0) {
                internalCommunity = ++lastInternalCommunity;
                communityMapping.addTo(seedCommunity, internalCommunity);
            }
            this.currentCommunities.set(nodeId, internalCommunity);
        }

        this.reverseSeedCommunityMapping = HugeLongArray.newArray(communityMapping.size(), this.tracker);
        for (LongLongCursor entry : communityMapping) {
            this.reverseSeedCommunityMapping.set(entry.value, entry.key);
        }
    }

    /**
     * Allocates the working arrays and, in parallel, computes each node's summed relationship weight
     * (fallback weight 1.0). The result seeds {@code cumulativeNodeWeights} and the per-community
     * {@code communityWeights}. Without a seed property, every node first becomes its own community.
     * Finally sets {@code totalNodeWeight} to half the grand total and mirrors
     * {@code currentCommunities} into {@code nextCommunities}.
     */
    private void init() {
        this.nextCommunities = HugeLongArray.newArray(this.nodeCount, this.tracker);
        this.cumulativeNodeWeights = HugeDoubleArray.newArray(this.nodeCount, this.tracker);
        this.nodeCommunityInfluences = HugeDoubleArray.newArray(this.nodeCount, this.tracker);
        this.communityWeights = HugeAtomicDoubleArray.newArray(this.nodeCount, this.tracker);
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount, this.tracker);

        // One graph copy per worker thread, closed once the parallel pass is done.
        try (CloseableThreadLocal<Graph> graphCopies = CloseableThreadLocal.withInitial(this.graph::concurrentCopy)) {
            double summedNodeWeights = ParallelUtil.parallelStream(
                LongStream.range(0L, this.nodeCount),
                this.concurrency,
                nodes -> nodes.mapToDouble(nodeId -> {
                    if (this.seedProperty == null) {
                        this.currentCommunities.set(nodeId, nodeId);
                    }
                    MutableDouble weightAccumulator = new MutableDouble(0.0d);
                    graphCopies.get().forEachRelationship(nodeId, 1.0d, (sourceNodeId, targetNodeId, relationshipWeight) -> {
                        weightAccumulator.add(relationshipWeight);
                        return true;
                    });
                    double nodeWeight = weightAccumulator.doubleValue();
                    this.communityWeights.update(this.currentCommunities.get(nodeId), communityWeight -> communityWeight + nodeWeight);
                    this.cumulativeNodeWeights.set(nodeId, nodeWeight);
                    return nodeWeight;
                }).reduce(0.0d, Double::sum) // identity form: an empty graph sums to 0 instead of throwing
            );
            this.totalNodeWeight = summedNodeWeights / 2.0d;
        }
        this.currentCommunities.copyTo(this.nextCommunities, this.nodeCount);
    }

    /**
     * Runs one optimization pass for all nodes of {@code color}. It then publishes the new community
     * assignment, folds the accumulated per-community weight deltas into {@code communityWeights},
     * and starts a fresh zeroed delta array for the next pass.
     */
    private void optimizeForColor(long color) {
        ParallelUtil.runWithConcurrency(this.concurrency, createModularityOptimizationTasks(color), this.executor);
        this.nextCommunities.copyTo(this.currentCommunities, this.nodeCount);
        ParallelUtil.parallelStreamConsume(LongStream.range(0L, this.nodeCount), this.concurrency, communities ->
            communities.forEach(community -> {
                double delta = this.communityWeightUpdates.get(community);
                this.communityWeights.update(community, weight -> weight + delta);
            })
        );
        this.communityWeightUpdates = HugeAtomicDoubleArray.newArray(this.nodeCount, PageFiller.allZeros(this.concurrency), this.tracker);
    }

    /** Splits {@code [0, nodeCount)} into {@code batchSize}-sized ranges, one task per range. */
    private Collection<ModularityOptimizationTask> createModularityOptimizationTasks(long color) {
        Collection<ModularityOptimizationTask> tasks = new ArrayList<>(this.concurrency);
        for (long batchStart = 0; batchStart < this.nodeCount; batchStart += this.batchSize) {
            long batchEnd = Math.min(batchStart + this.batchSize, this.nodeCount);
            tasks.add(new ModularityOptimizationTask(this.graph, batchStart, batchEnd, color, this.totalNodeWeight, this.colors, this.currentCommunities, this.nextCommunities, this.cumulativeNodeWeights, this.nodeCommunityInfluences, this.communityWeights, this.communityWeightUpdates, getProgressLogger()));
        }
        return tasks;
    }

    /**
     * Recomputes modularity.
     *
     * @return {@code true} if modularity increased by more than {@code tolerance}, i.e. another
     *         iteration is worthwhile
     */
    private boolean updateModularity() {
        double previousModularity = this.modularity;
        this.modularity = calculateModularity();
        return this.modularity > previousModularity && Math.abs(this.modularity - previousModularity) > this.tolerance;
    }

    /**
     * Computes {@code sum(nodeCommunityInfluences) / 2m - sum(communityWeight^2) / (2m)^2}, with
     * {@code m = totalNodeWeight}. Yields NaN when {@code totalNodeWeight} is 0.
     */
    private double calculateModularity() {
        double doubledTotalWeight = 2.0d * this.totalNodeWeight;
        double influenceSum = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            nodes -> nodes.mapToDouble(this.nodeCommunityInfluences::get).reduce(0.0d, Double::sum)
        );
        double squaredCommunityWeightSum = ParallelUtil.parallelStream(
            LongStream.range(0L, this.nodeCount),
            this.concurrency,
            communities -> communities.mapToDouble(community -> Math.pow(this.communityWeights.get(community), 2.0d)).reduce(0.0d, Double::sum)
        );
        return (influenceSum / doubledTotalWeight) - (squaredCommunityWeightSum / Math.pow(doubledTotalWeight, 2.0d));
    }

    /* renamed from: me, reason: merged with bridge method [inline-methods] */
    public ModularityOptimization m3me() {
        return this;
    }

    /**
     * Releases the working arrays. {@code currentCommunities} and the reverse seed mapping are kept
     * so {@link #getCommunityId(long)} still works afterwards.
     */
    public void release() {
        this.nextCommunities.release();
        this.communityWeights.release();
        this.communityWeightUpdates.release();
        this.cumulativeNodeWeights.release();
        this.nodeCommunityInfluences.release();
        this.colors.release();
        this.colorsUsed = null;
    }

    /**
     * Returns the community of {@code nodeId}. When seeding was used, this is the community id from
     * the seed property; for unseeded nodes it is the derived key assigned in initSeeding().
     */
    public long getCommunityId(long nodeId) {
        long internalCommunity = this.currentCommunities.get(nodeId);
        if (this.seedProperty == null || this.reverseSeedCommunityMapping == null) {
            return internalCommunity;
        }
        return this.reverseSeedCommunityMapping.get(internalCommunity);
    }

    /** Returns the number of iterations actually run. */
    public int getIterations() {
        return this.iterationCounter;
    }

    /** Returns the modularity after the last iteration (initially -1). */
    public double getModularity() {
        return this.modularity;
    }

    /** Returns whether iteration stopped because the modularity gain fell within tolerance. */
    public boolean didConverge() {
        return this.didConverge;
    }

    public double getTolerance() {
        return this.tolerance;
    }
}
