package org.deeplearning4j.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.distancefunction.DistanceFunction;
import org.nd4j.linalg.distancefunction.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/clustering/KMeansClustering.class */
/**
 * Online k-means clustering over {@link INDArray} feature vectors.
 *
 * <p>The first {@code 10 * nbCluster} vectors passed to {@link #update(INDArray)} are buffered and
 * then used to seed the centroids k-means++ style: the first centroid is drawn uniformly at random,
 * and each further centroid is drawn with probability proportional to its squared distance to the
 * nearest centroid already chosen. After seeding, each {@link #update(INDArray)} assigns the vector
 * to its nearest centroid and moves that centroid towards it with a running-mean step.
 *
 * <p>Distances are computed by reflectively constructing {@code clazz} with a single
 * {@link INDArray} argument (the base vector) and applying it to the other vector.
 *
 * <p>Instances are not synchronized; confine each instance to a single thread.
 */
public class KMeansClustering implements Serializable {
    private static final long serialVersionUID = 338231277453149972L;
    private static final Logger log = LoggerFactory.getLogger(KMeansClustering.class);

    /** Number of buffered vectors per cluster required before the centroids are seeded. */
    private static final int INIT_SAMPLES_PER_CLUSTER = 10;

    /** Number of vectors assigned to each centroid so far; {@code null} until seeded. */
    private List<Long> counts;
    /** One centroid per row; {@code null} until seeding starts. */
    private INDArray centroids;
    /** Vectors buffered until there are enough to seed the centroids. */
    private List<INDArray> initFeatures;
    /** Distance function implementation, instantiated per distance computation. */
    private Class<? extends DistanceFunction> clazz;
    /** Worker pool used only while seeding; transient, recreated lazily when needed. */
    private transient ExecutorService exec;
    /** Number of clusters (k). */
    private Integer nbCluster;

    /**
     * Creates an untrained clusterer.
     *
     * @param num number of clusters; must be at least 1
     * @param cls distance function class; must expose a public {@code (INDArray)} constructor
     * @throws IllegalArgumentException if {@code num} is null or below 1, or {@code cls} is null
     */
    public KMeansClustering(Integer num, Class<? extends DistanceFunction> cls) {
        if (num == null || num.intValue() < 1) {
            throw new IllegalArgumentException("Number of clusters must be >= 1, got " + num);
        }
        if (cls == null) {
            throw new IllegalArgumentException("Distance function class must not be null");
        }
        this.counts = null;
        this.initFeatures = new ArrayList<INDArray>();
        this.nbCluster = num;
        this.clazz = cls;
    }

    /**
     * Creates an untrained clusterer using {@link EuclideanDistance}.
     *
     * @param num number of clusters; must be at least 1
     */
    public KMeansClustering(Integer num) {
        this(num, EuclideanDistance.class);
    }

    /**
     * Returns the index of the centroid nearest to the given vector.
     *
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public Integer classify(INDArray iNDArray) {
        if (isReady()) {
            return nearestCentroid(iNDArray);
        }
        throw new IllegalStateException("KMeans is not ready yet");
    }

    /**
     * Feeds one vector to the model.
     *
     * <p>Before seeding, the vector is buffered and {@code null} is returned. Afterwards, the vector
     * is assigned to its nearest centroid. That centroid is moved towards the vector by
     * {@code (x - c) / n}, where {@code n} is the centroid's updated assignment count.
     *
     * @return the assigned cluster index, or {@code null} while still buffering
     */
    public Integer update(INDArray iNDArray) {
        if (!isReady()) {
            initIfPossible(iNDArray);
            log.info("Initializing feature vector with length of {}", iNDArray.length());
            return null;
        }
        Integer classify = classify(iNDArray);
        int index = classify.intValue();
        long count = this.counts.get(index).longValue() + 1;
        this.counts.set(index, Long.valueOf(count));
        // NOTE(review): relies on getRow returning a view, so addi mutates the centroid in place.
        INDArray centroid = this.centroids.getRow(index);
        centroid.addi(iNDArray.sub(centroid).mul(Double.valueOf(1.0d / count)));
        return classify;
    }

    /**
     * Returns a {@code 1 x nbCluster} row whose i-th entry is the distance from centroid i to the
     * given vector.
     *
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public INDArray distribution(INDArray iNDArray) {
        if (!isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        int k = this.nbCluster.intValue();
        INDArray distances = Nd4j.create(1, k);
        for (int i = 0; i < k; i++) {
            distances.putScalar(i, getDistance(this.centroids.getRow(i), iNDArray));
        }
        return distances;
    }

    /**
     * Computes the distance between two vectors using the configured distance function class.
     * The function is constructed with {@code iNDArray} as its base and applied to
     * {@code iNDArray2}.
     *
     * @throws RuntimeException if the distance function cannot be instantiated reflectively
     */
    public double getDistance(INDArray iNDArray, INDArray iNDArray2) {
        DistanceFunction function;
        try {
            function = this.clazz.getConstructor(INDArray.class).newInstance(iNDArray);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException("Cannot instantiate distance function " + this.clazz.getName(), e);
        }
        return ((Float) function.apply(iNDArray2)).floatValue();
    }

    /** Returns the centroid matrix (one centroid per row), or {@code null} before seeding. */
    public INDArray getCentroids() {
        return this.centroids;
    }

    /** Returns the index of the row in {@link #getCentroids()} nearest to the given vector. */
    protected Integer nearestCentroid(INDArray iNDArray) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        int rows = this.centroids.rows();
        for (int i = 0; i < rows; i++) {
            double distance = getDistance(this.centroids.getRow(i), iNDArray);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return Integer.valueOf(best);
    }

    /** Returns {@code true} once the centroids and their counts have been seeded. */
    protected boolean isReady() {
        return this.counts != null && this.centroids != null;
    }

    /**
     * Buffers a vector and seeds the centroids once {@code 10 * nbCluster} vectors have been
     * collected.
     */
    protected void initIfPossible(INDArray iNDArray) {
        this.initFeatures.add(iNDArray);
        log.info("Added feature vector of length {}", iNDArray.length());
        if (this.initFeatures.size() >= INIT_SAMPLES_PER_CLUSTER * this.nbCluster.intValue()) {
            initCentroids();
        }
    }

    /**
     * Seeds {@code nbCluster} centroids from the buffered vectors using k-means++ sampling.
     * The chosen vectors are removed from the buffer, and the buffer is cleared on success.
     *
     * @throws IllegalStateException if fewer than {@code nbCluster} vectors are buffered
     */
    protected void initCentroids() {
        int k = this.nbCluster.intValue();
        if (this.initFeatures.size() < k) {
            throw new IllegalStateException("Need at least " + k + " buffered vectors to seed, have "
                    + this.initFeatures.size());
        }
        Random random = new Random();
        List<Long> newCounts = new ArrayList<Long>(k);
        for (int i = 0; i < k; i++) {
            newCounts.add(Long.valueOf(0L));
        }

        INDArray first = this.initFeatures.remove(random.nextInt(this.initFeatures.size())).linearView();
        this.centroids = Nd4j.create(k, first.columns());
        this.centroids.putRow(0, first);
        log.info("Added initial centroid");

        try {
            for (int c = 1; c < k; c++) {
                int n = this.initFeatures.size();
                // Prefix sums of squared distances to the c centroids chosen so far.
                INDArray dxs = computeDxs(c);
                double threshold = random.nextDouble() * dxs.getDouble(n - 1);
                // Fall back to the last candidate so rounding can never leave a row unset.
                int chosen = n - 1;
                for (int j = 0; j < n; j++) {
                    if (dxs.getDouble(j) >= threshold) {
                        chosen = j;
                        break;
                    }
                }
                this.centroids.putRow(c, this.initFeatures.remove(chosen).linearView());
            }
        } finally {
            // The pool is only needed for seeding; its threads would otherwise keep the JVM alive.
            shutdownExecutor();
        }
        // Publish counts last so isReady() only turns true once every centroid row is filled.
        this.counts = newCounts;
        this.initFeatures.clear();
    }

    /**
     * Returns a {@code 1 x initFeatures.size()} row of prefix sums of squared distances from each
     * buffered vector to its nearest row in {@link #getCentroids()}.
     */
    protected INDArray computeDxs() {
        return computeDxs(this.centroids.rows());
    }

    /**
     * Like {@link #computeDxs()}, but only the first {@code numCentroids} centroid rows are
     * considered. Per-vector distances are computed on the worker pool; prefix sums are then
     * accumulated sequentially in buffer order.
     */
    private INDArray computeDxs(final int numCentroids) {
        int n = this.initFeatures.size();
        ExecutorService executor = executor();
        List<Future<Double>> pending = new ArrayList<Future<Double>>(n);
        for (final INDArray feature : this.initFeatures) {
            pending.add(executor.submit(new Callable<Double>() {
                @Override
                public Double call() {
                    double distance = minDistance(feature, numCentroids);
                    return Double.valueOf(distance * distance);
                }
            }));
        }

        INDArray dxs = Nd4j.create(1, n);
        double cumulative = 0.0d;
        try {
            for (int i = 0; i < n; i++) {
                cumulative += pending.get(i).get().doubleValue();
                dxs.putScalar(i, cumulative);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing seeding distances", e);
        } catch (ExecutionException e) {
            throw new RuntimeException("Failed to compute seeding distance", e.getCause());
        }
        return dxs;
    }

    /** Returns the smallest distance from {@code point} to the first {@code limit} centroid rows. */
    private double minDistance(INDArray point, int limit) {
        double best = Double.POSITIVE_INFINITY;
        for (int i = 0; i < limit; i++) {
            best = Math.min(best, getDistance(this.centroids.getRow(i), point));
        }
        return best;
    }

    /** Returns the seeding worker pool, creating it on first use (it is transient and not serialized). */
    private ExecutorService executor() {
        if (this.exec == null) {
            this.exec = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        }
        return this.exec;
    }

    /** Stops the seeding worker pool, if any, and forgets it. */
    private void shutdownExecutor() {
        if (this.exec != null) {
            this.exec.shutdownNow();
            this.exec = null;
        }
    }

    /** Discards all learned state and buffered vectors, returning to the untrained state. */
    public void reset() {
        this.counts = null;
        this.centroids = null;
        this.initFeatures = new ArrayList<INDArray>();
        shutdownExecutor();
    }
}
