package org.neo4j.gds.kmeans;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.ExecutorService;
import org.jetbrains.annotations.NotNull;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.nodeproperties.ValueType;
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.concurrency.Concurrency;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.kmeans.KmeansSampler;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/kmeans/Kmeans.class */
public final class Kmeans extends Algorithm<KmeansResult> {
    static final String KMEANS_DESCRIPTION = "The Kmeans  algorithm clusters nodes into different communities based on Euclidean distance";

    /** Community id written into the working community buffer before a restart assigns nodes. */
    private static final int UNASSIGNED = -1;

    private static final String DIMENSION_MISMATCH_MESSAGE = "All property arrays for K-Means should have the same number of dimensions";
    private static final String NAN_INPUT_MESSAGE = "Input for K-Means should not contain any NaN values";

    private HugeIntArray bestCommunities;
    private final Graph graph;
    private final KmeansParameters parameters;
    private final Concurrency concurrency;
    private final ExecutorService executorService;
    private final SplittableRandom random;
    private final NodePropertyValues nodePropertyValues;
    private final int dimensions;
    private double[][] bestCentroids;
    private HugeDoubleArray distanceFromCentroid;
    private final KmeansIterationStopper kmeansIterationStopper;
    private HugeDoubleArray silhouette;
    private double averageSilhouette;
    private double bestDistance;
    private long[] nodesInCluster;

    /**
     * Creates a K-Means instance for the node property named by {@code kmeansParameters.nodeProperty()}.
     *
     * @throws IllegalArgumentException if the graph returns no property values for that property name
     */
    public static Kmeans createKmeans(Graph graph, KmeansParameters kmeansParameters, KmeansContext kmeansContext, TerminationFlag terminationFlag) {
        String nodeProperty = kmeansParameters.nodeProperty();
        NodePropertyValues nodeProperties = graph.nodeProperties(nodeProperty);
        if (nodeProperties == null) {
            throw new IllegalArgumentException("Property '" + nodeProperty + "' does not exist for all nodes");
        }
        return new Kmeans(
            kmeansContext.progressTracker(),
            kmeansContext.executor(),
            graph,
            kmeansParameters,
            getSplittableRandom(kmeansParameters.randomSeed()),
            nodeProperties,
            terminationFlag
        );
    }

    /**
     * Allocates the best-solution buffers (sized by the node count) and derives the vector dimensionality
     * from the property value of node 0.
     */
    private Kmeans(ProgressTracker progressTracker, ExecutorService executorService, Graph graph, KmeansParameters kmeansParameters, SplittableRandom splittableRandom, NodePropertyValues nodePropertyValues, TerminationFlag terminationFlag) {
        super(progressTracker);
        this.executorService = executorService;
        this.graph = graph;
        this.random = splittableRandom;
        this.bestCommunities = HugeIntArray.newArray(graph.nodeCount());
        this.nodePropertyValues = nodePropertyValues;
        // NOTE(review): reads node 0, so this presumably assumes a non-empty graph — confirm for empty graphs.
        this.dimensions = nodePropertyValues.doubleArrayValue(0L).length;
        this.kmeansIterationStopper = new KmeansIterationStopper(kmeansParameters.deltaThreshold(), kmeansParameters.maxIterations(), graph.nodeCount());
        this.distanceFromCentroid = HugeDoubleArray.newArray(graph.nodeCount());
        this.parameters = kmeansParameters;
        this.concurrency = kmeansParameters.concurrency();
        this.nodesInCluster = new long[kmeansParameters.k()];
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs K-Means, with optional restarts, and returns the best solution found, i.e. the one with the smallest
     * summed distance from centroids.
     *
     * <p>NOTE(review): the decompiler renamed this method from {@code compute} after merging it with its bridge
     * method. The original override is presumably named {@code compute()}; confirm before relying on this name.
     *
     * @throws IllegalArgumentException if seeded centroids or node vectors contain NaN, or a node lacks the property
     * @throws IllegalStateException if seeded centroids or node vectors have a dimensionality different from node 0's
     */
    public KmeansResult m44compute() {
        progressTracker.beginSubTask();
        checkInputValidity();

        long nodeCount = graph.nodeCount();
        if (parameters.k() > nodeCount) {
            return assignEveryNodeToOwnCluster();
        }

        HugeIntArray currentCommunities = HugeIntArray.newArray(nodeCount);
        HugeDoubleArray currentDistanceFromCentroid = HugeDoubleArray.newArray(nodeCount);
        bestDistance = Double.POSITIVE_INFINITY;
        bestCommunities.setAll(nodeId -> UNASSIGNED);

        if (parameters.numberOfRestarts() == 1) {
            kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, 0);
        } else {
            // Each restart gets its own progress sub task only when there is more than one restart.
            for (int restart = 0; restart < parameters.numberOfRestarts(); restart++) {
                progressTracker.beginSubTask();
                kMeans(nodeCount, currentCommunities, currentDistanceFromCentroid, restart);
                progressTracker.endSubTask();
            }
        }

        if (parameters.computeSilhouette()) {
            calculateSilhouette();
        }
        progressTracker.endSubTask();
        return ImmutableKmeansResult.of(bestCommunities, distanceFromCentroid, bestCentroids, bestDistance, silhouette, averageSilhouette);
    }

    /**
     * Handles the case where more clusters are requested than there are nodes.
     *
     * <p>Every node becomes its own community with distance 0, and its own property vector becomes the centroid.
     * This method closes the sub task opened by the caller. Silhouette is not computed on this path.
     */
    private KmeansResult assignEveryNodeToOwnCluster() {
        progressTracker.logWarning("Number of requested clusters is larger than the number of nodes.");
        bestCommunities.setAll(nodeId -> (int) nodeId);
        distanceFromCentroid.setAll(nodeId -> 0.0d);
        progressTracker.endSubTask();

        // k is an int and k > nodeCount on this path, so the node count fits in an int.
        int nodeCount = (int) graph.nodeCount();
        bestCentroids = new double[nodeCount][];
        for (int nodeId = 0; nodeId < nodeCount; nodeId++) {
            bestCentroids[nodeId] = nodePropertyValues.doubleArrayValue(nodeId);
        }
        return ImmutableKmeansResult.of(bestCommunities, distanceFromCentroid, bestCentroids, 0.0d, silhouette, 0.0d);
    }

    /**
     * Runs a single K-Means restart and updates the best solution if this restart is the first one or improves on
     * the best solution so far.
     *
     * @param nodeCount number of nodes in the graph
     * @param currentCommunities working community buffer; reset to {@link #UNASSIGNED} at the start of the restart
     * @param currentDistanceFromCentroid working distance buffer, shared with the tasks and the sampler
     * @param restartIteration zero-based index of this restart
     */
    private void kMeans(long nodeCount, HugeIntArray currentCommunities, HugeDoubleArray currentDistanceFromCentroid, int restartIteration) {
        ClusterManager clusterManager = ClusterManager.createClusterManager(nodePropertyValues, dimensions, parameters.k());
        currentCommunities.setAll(nodeId -> UNASSIGNED);

        List<KmeansTask> tasks = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> KmeansTask.createTask(
                parameters.samplerType(),
                clusterManager,
                nodePropertyValues,
                currentCommunities,
                currentDistanceFromCentroid,
                parameters.k(),
                dimensions,
                partition
            ),
            Optional.of(minBatchSize(nodeCount))
        );
        KmeansSampler sampler = KmeansSampler.createSampler(
            parameters.samplerType(),
            random,
            clusterManager,
            nodeCount,
            parameters.k(),
            concurrency,
            currentDistanceFromCentroid,
            executorService,
            tasks,
            progressTracker
        );
        assert tasks.size() <= concurrency.value();

        initializeCentroids(clusterManager, sampler);

        progressTracker.beginSubTask();
        int iteration = 0;
        long swaps;
        do {
            progressTracker.beginSubTask();
            swaps = 0;
            // The assignment step is skipped in the first iteration unless sampling is UNIFORM.
            // NOTE(review): presumably the other samplers already assign nodes while sampling; verify in KmeansSampler.
            if (iteration > 0 || parameters.samplerType() == KmeansSampler.SamplerType.UNIFORM) {
                RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).executor(executorService).run();
                for (KmeansTask task : tasks) {
                    swaps += task.getSwaps();
                }
            }
            recomputeCentroids(clusterManager, tasks);
            progressTracker.endSubTask();
            iteration++;
        } while (!kmeansIterationStopper.shouldQuit(swaps, iteration));
        progressTracker.endSubTask();

        double distance = calculatedistancePhase(tasks);
        updateBestSolution(restartIteration, clusterManager, distance, currentCommunities, currentDistanceFromCentroid);
    }

    /**
     * Returns the minimum partition batch size for the given node count.
     *
     * <p>The result is the node count divided by the concurrency, saturated at {@link Integer#MAX_VALUE}. This
     * avoids the negative value that casting a large node count to int would produce.
     */
    private int minBatchSize(long nodeCount) {
        return (int) Math.min(Integer.MAX_VALUE, nodeCount / concurrency.value());
    }

    /** Seeds the centroids from the user-provided centroids if present, otherwise via the sampler. */
    private void initializeCentroids(ClusterManager clusterManager, KmeansSampler kmeansSampler) {
        progressTracker.beginSubTask();
        if (parameters.isSeeded()) {
            clusterManager.assignSeededCentroids(parameters.seedCentroids());
        } else {
            kmeansSampler.performInitialSampling();
        }
        progressTracker.endSubTask();
    }

    /** Resets the cluster manager, folds in every task's contribution, then normalizes the clusters. */
    private void recomputeCentroids(ClusterManager clusterManager, List<KmeansTask> tasks) {
        clusterManager.reset();
        for (KmeansTask task : tasks) {
            clusterManager.updateFromTask(task);
        }
        clusterManager.normalizeClusters();
    }

    /** Returns a random source seeded with {@code randomSeed} if present, otherwise an unseeded one. */
    @NotNull
    private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
        return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Validates the seeded centroids (if any) and every node's property vector.
     *
     * @throws IllegalStateException if a vector's length differs from {@link #dimensions}
     * @throws IllegalArgumentException if a vector contains NaN or a node has no property value
     */
    private void checkInputValidity() {
        if (parameters.isSeeded()) {
            for (List<Double> seedCentroid : parameters.seedCentroids()) {
                if (seedCentroid.size() != dimensions) {
                    throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
                }
                for (Double coordinate : seedCentroid) {
                    if (Double.isNaN(coordinate)) {
                        throw new IllegalArgumentException(NAN_INPUT_MESSAGE);
                    }
                }
            }
        }
        // The value type is the same for every node, so it is read once instead of once per node.
        boolean isFloatArray = nodePropertyValues.valueType() == ValueType.FLOAT_ARRAY;
        ParallelUtil.parallelForEachNode(graph.nodeCount(), concurrency, TerminationFlag.RUNNING_TRUE, nodeId -> {
            if (isFloatArray) {
                validateFloatVector(nodePropertyValues.floatArrayValue(nodeId));
            } else {
                validateDoubleVector(nodePropertyValues.doubleArrayValue(nodeId));
            }
        });
    }

    /** Validates a single float-typed node vector: present, of the expected length, and NaN-free. */
    private void validateFloatVector(float[] vector) {
        if (vector == null) {
            throw missingPropertyException();
        }
        if (vector.length != dimensions) {
            throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
        }
        for (float value : vector) {
            if (Float.isNaN(value)) {
                throw new IllegalArgumentException(NAN_INPUT_MESSAGE);
            }
        }
    }

    /** Validates a single double-typed node vector: present, of the expected length, and NaN-free. */
    private void validateDoubleVector(double[] vector) {
        if (vector == null) {
            throw missingPropertyException();
        }
        if (vector.length != dimensions) {
            throw new IllegalStateException(DIMENSION_MISMATCH_MESSAGE);
        }
        for (double value : vector) {
            if (Double.isNaN(value)) {
                throw new IllegalArgumentException(NAN_INPUT_MESSAGE);
            }
        }
    }

    private IllegalArgumentException missingPropertyException() {
        return new IllegalArgumentException("Property '" + parameters.nodeProperty() + "' does not exist for all nodes");
    }

    /**
     * Computes per-node silhouette values for the best solution and accumulates {@link #averageSilhouette}.
     *
     * <p>The average is the sum of the tasks' {@code getAverageSilhouette()} values.
     * NOTE(review): presumably each task already scales its contribution by the node count; confirm in SilhouetteTask.
     */
    private void calculateSilhouette() {
        long nodeCount = graph.nodeCount();
        progressTracker.beginSubTask();
        silhouette = HugeDoubleArray.newArray(nodeCount);
        List<SilhouetteTask> tasks = PartitionUtils.rangePartition(
            concurrency,
            nodeCount,
            partition -> SilhouetteTask.createTask(
                nodePropertyValues,
                bestCommunities,
                silhouette,
                parameters.k(),
                dimensions,
                nodesInCluster,
                partition,
                progressTracker
            ),
            Optional.of(minBatchSize(nodeCount))
        );
        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).executor(executorService).run();
        for (SilhouetteTask task : tasks) {
            averageSilhouette += task.getAverageSilhouette();
        }
        progressTracker.endSubTask();
    }

    /**
     * Switches every task to the distance phase, runs them, and returns the sum of their
     * {@code getDistanceFromCentroidNormalized()} values.
     */
    private double calculatedistancePhase(List<KmeansTask> tasks) {
        for (KmeansTask task : tasks) {
            task.switchToPhase(TaskPhase.DISTANCE);
        }
        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).executor(executorService).run();
        double totalDistance = 0.0d;
        for (KmeansTask task : tasks) {
            totalDistance += task.getDistanceFromCentroidNormalized();
        }
        return totalDistance;
    }

    /**
     * Records the current restart as the best solution if it is the first restart or strictly improves on
     * {@link #bestDistance}.
     *
     * <p>The working buffers are reset and overwritten by every subsequent restart. They may therefore only be
     * adopted by reference when there is a single restart. Otherwise their contents are copied into the dedicated
     * best-solution arrays, so that a later, worse restart cannot corrupt the stored solution.
     */
    private void updateBestSolution(int restartIteration, ClusterManager clusterManager, double averageDistanceFromCentroid, HugeIntArray currentCommunities, HugeDoubleArray currentDistanceFromCentroid) {
        boolean isFirstRestart = restartIteration < 1;
        // Written as !(a < b) rather than a >= b so that a NaN distance never replaces an existing solution.
        if (!isFirstRestart && !(averageDistanceFromCentroid < bestDistance)) {
            return;
        }
        bestDistance = averageDistanceFromCentroid;
        if (parameters.numberOfRestarts() <= 1) {
            // No later restart will reuse the working buffers, so adopting them avoids a copy.
            bestCommunities = currentCommunities;
            distanceFromCentroid = currentDistanceFromCentroid;
        } else {
            ParallelUtil.parallelForEachNode(graph.nodeCount(), concurrency, terminationFlag, nodeId -> {
                bestCommunities.set(nodeId, currentCommunities.get(nodeId));
                distanceFromCentroid.set(nodeId, currentDistanceFromCentroid.get(nodeId));
            });
        }
        bestCentroids = clusterManager.getCentroids();
        if (parameters.computeSilhouette()) {
            nodesInCluster = clusterManager.getNodesInCluster();
        }
    }
}
