package org.deeplearning4j.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.deeplearning4j.distancefunction.DistanceFunction;
import org.deeplearning4j.distancefunction.EuclideanDistance;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/clustering/KMeansClustering.class */
/**
 * Sequential (online) k-means clustering over row vectors.
 *
 * <p>The first {@code INIT_SAMPLES_PER_CLUSTER * nbCluster} vectors passed to
 * {@link #update(DoubleMatrix)} are only buffered. They are then used to seed the centroids with
 * k-means++: the first centroid is drawn uniformly at random. Each further centroid is drawn with
 * probability proportional to the squared distance to the nearest centroid chosen so far. Every
 * later vector is assigned to its nearest centroid, and that centroid is moved towards the vector by
 * a running-mean update.
 *
 * <p>Distances are computed by instantiating the configured {@link DistanceFunction} class
 * reflectively. That class must therefore declare a public constructor taking a single
 * {@link DoubleMatrix}.
 *
 * <p>Instances hold mutable state without any synchronization. Confine an instance to one thread or
 * synchronize externally.
 */
public class KMeansClustering implements Serializable {
    private static final long serialVersionUID = 338231277453149972L;
    private static final Logger log = LoggerFactory.getLogger(KMeansClustering.class);

    /** Number of buffered vectors per cluster collected before the centroids are seeded. */
    private static final int INIT_SAMPLES_PER_CLUSTER = 10;

    /** Number of vectors assigned to each centroid since seeding; {@code null} until seeded. */
    private List<Long> counts;
    /** One centroid per row ({@code nbCluster} rows); {@code null} until seeding starts. */
    private DoubleMatrix centroids;
    /** Vectors buffered for seeding; cleared once the centroids have been initialised. */
    private List<DoubleMatrix> initFeatures;
    /** Distance implementation, instantiated reflectively for every distance computation. */
    private Class<? extends DistanceFunction> clazz;
    /** Number of clusters (k). */
    private Integer nbCluster;

    /**
     * Creates an unseeded model.
     *
     * @param num number of clusters; must be positive
     * @param cls distance implementation; must declare a public {@code (DoubleMatrix)} constructor
     * @throws IllegalArgumentException if {@code num} is null or not positive, or {@code cls} is null
     */
    public KMeansClustering(Integer num, Class<? extends DistanceFunction> cls) {
        if (num == null || num.intValue() <= 0) {
            throw new IllegalArgumentException("Number of clusters must be positive but was " + num);
        }
        if (cls == null) {
            throw new IllegalArgumentException("Distance function class must not be null");
        }
        this.counts = null;
        this.initFeatures = new ArrayList<DoubleMatrix>();
        // Previously the class argument was silently dropped, so every distance computation failed.
        this.clazz = cls;
        this.nbCluster = num;
    }

    /**
     * Creates an unseeded model that uses {@link EuclideanDistance}.
     *
     * @param num number of clusters; must be positive
     */
    public KMeansClustering(Integer num) {
        this(num, EuclideanDistance.class);
    }

    /**
     * Returns the index of the centroid nearest to the given vector without updating the model.
     *
     * @param doubleMatrix the vector to classify
     * @return index of the nearest centroid
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public Integer classify(DoubleMatrix doubleMatrix) {
        if (isReady()) {
            return nearestCentroid(doubleMatrix);
        }
        throw new IllegalStateException("KMeans is not ready yet");
    }

    /**
     * Feeds one vector to the model.
     *
     * <p>Before the centroids are seeded, the vector is only buffered, and {@code null} is returned.
     * Seeding runs once enough vectors have been buffered. After seeding, the vector is assigned to
     * its nearest centroid and that centroid's count is incremented. The centroid is then moved to
     * the running mean {@code c + (x - c) / count}.
     *
     * @param doubleMatrix the vector to learn from (assumed to be a row vector of the same width as
     *     the seeding vectors)
     * @return index of the cluster the vector was assigned to, or {@code null} while still buffering
     */
    public Integer update(DoubleMatrix doubleMatrix) {
        if (!isReady()) {
            initIfPossible(doubleMatrix);
            log.info("Initializing feature vector with length of {}", doubleMatrix.length);
            return null;
        }
        int cluster = classify(doubleMatrix).intValue();
        long count = this.counts.get(cluster).longValue() + 1;
        this.counts.set(cluster, Long.valueOf(count));
        DoubleMatrix centroid = this.centroids.getRow(cluster);
        this.centroids.putRow(cluster, centroid.add(doubleMatrix.sub(centroid).mul(1.0d / count)));
        return Integer.valueOf(cluster);
    }

    /**
     * Computes the distance from the given vector to every centroid.
     *
     * @param doubleMatrix the vector to measure
     * @return a {@code 1 x nbCluster} matrix whose i-th entry is the distance to centroid i
     * @throws IllegalStateException if the centroids have not been seeded yet
     */
    public DoubleMatrix distribution(DoubleMatrix doubleMatrix) {
        if (!isReady()) {
            throw new IllegalStateException("KMeans is not ready yet");
        }
        int k = this.nbCluster.intValue();
        DoubleMatrix distances = new DoubleMatrix(1, k);
        for (int i = 0; i < k; i++) {
            distances.put(i, getDistance(this.centroids.getRow(i), doubleMatrix));
        }
        return distances;
    }

    /**
     * Computes the distance between two vectors with the configured distance class.
     *
     * <p>A new distance-function instance is created on every call, with {@code base} as its
     * constructor argument, and is then applied to {@code other}.
     *
     * @throws RuntimeException wrapping any failure to instantiate or apply the distance function
     */
    private double getDistance(DoubleMatrix base, DoubleMatrix other) {
        try {
            return ((Double) this.clazz.getConstructor(DoubleMatrix.class).newInstance(base).apply(other)).doubleValue();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the centroid matrix, one centroid per row.
     *
     * <p>This is the live internal matrix, not a copy, so mutating it changes the model. It is
     * {@code null} before seeding.
     */
    public DoubleMatrix getCentroids() {
        return this.centroids;
    }

    /**
     * Returns the index of the centroid row closest to the given vector, scanning all centroid rows.
     *
     * @param doubleMatrix the vector to locate
     * @return index of the nearest centroid; 0 if no distance is smaller than {@link Double#MAX_VALUE}
     */
    protected Integer nearestCentroid(DoubleMatrix doubleMatrix) {
        return Integer.valueOf(nearestCentroidAmong(doubleMatrix, this.centroids.rows));
    }

    /**
     * Returns the index of the closest of the first {@code numCentroids} centroid rows.
     *
     * <p>During seeding, only already-chosen rows must be considered; the remaining rows are
     * still all-zero placeholders.
     */
    private int nearestCentroidAmong(DoubleMatrix vector, int numCentroids) {
        int best = 0;
        double bestDistance = Double.MAX_VALUE;
        for (int i = 0; i < numCentroids; i++) {
            double distance = getDistance(this.centroids.getRow(i), vector);
            if (distance < bestDistance) {
                bestDistance = distance;
                best = i;
            }
        }
        return best;
    }

    /** Returns true once the centroids have been fully seeded and counts initialised. */
    protected boolean isReady() {
        return (this.counts != null) && (this.centroids != null);
    }

    /**
     * Buffers a vector for seeding and seeds the centroids once
     * {@code INIT_SAMPLES_PER_CLUSTER * nbCluster} vectors are buffered.
     */
    protected void initIfPossible(DoubleMatrix doubleMatrix) {
        this.initFeatures.add(doubleMatrix);
        log.info("Added feature vector of length {}", doubleMatrix.length);
        if (this.initFeatures.size() >= INIT_SAMPLES_PER_CLUSTER * this.nbCluster.intValue()) {
            initCentroids();
        }
    }

    /**
     * Seeds the centroids from the buffered vectors using k-means++ and clears the buffer.
     *
     * <p>{@code counts} is assigned last, so {@link #isReady()} only becomes true after every
     * centroid has been chosen.
     *
     * @throws IllegalStateException if fewer than {@code nbCluster} vectors are buffered
     */
    protected void initCentroids() {
        int k = this.nbCluster.intValue();
        if (this.initFeatures.size() < k) {
            throw new IllegalStateException("Need at least " + k + " buffered feature vectors to seed "
                    + k + " centroids but have " + this.initFeatures.size());
        }
        Random random = new Random();
        DoubleMatrix first = this.initFeatures.remove(random.nextInt(this.initFeatures.size()));
        this.centroids = new DoubleMatrix(k, first.columns);
        this.centroids.putRow(0, first);
        log.info("Added initial centroid");
        for (int chosen = 1; chosen < k; chosen++) {
            // Cumulative squared distances to the nearest of the `chosen` centroids picked so far.
            DoubleMatrix cumulative = computeDxs(chosen);
            double target = random.nextDouble() * cumulative.get(cumulative.length - 1);
            // Fallback in case no entry reaches the target (e.g. NaN distances).
            int pick = cumulative.length - 1;
            for (int i = 0; i < cumulative.length; i++) {
                if (cumulative.get(i) >= target) {
                    pick = i;
                    break;
                }
            }
            this.centroids.putRow(chosen, this.initFeatures.remove(pick));
        }
        this.initFeatures.clear();
        List<Long> initialCounts = new ArrayList<Long>(k);
        for (int i = 0; i < k; i++) {
            initialCounts.add(Long.valueOf(0L));
        }
        this.counts = initialCounts;
    }

    /**
     * Computes, for every buffered vector, the running sum of squared distances to its nearest
     * centroid, considering all centroid rows.
     *
     * @return an {@code n x 1} matrix whose i-th entry is the sum over buffered vectors {@code 0..i}
     */
    protected DoubleMatrix computeDxs() {
        return computeDxs(this.centroids.rows);
    }

    /**
     * Computes the cumulative squared nearest-centroid distances, considering only the first
     * {@code numCentroids} centroid rows.
     *
     * <p>A {@code double} accumulator and an {@code n x 1} shape are used so that the last entry is
     * the true total.
     */
    private DoubleMatrix computeDxs(int numCentroids) {
        int n = this.initFeatures.size();
        DoubleMatrix cumulative = new DoubleMatrix(n, 1);
        double total = 0.0d;
        for (int i = 0; i < n; i++) {
            DoubleMatrix feature = this.initFeatures.get(i);
            int nearest = nearestCentroidAmong(feature, numCentroids);
            double distance = getDistance(feature, this.centroids.getRow(nearest));
            total += distance * distance;
            cumulative.put(i, total);
        }
        return cumulative;
    }

    /** Discards the centroids, the counts and any buffered vectors, returning to the unseeded state. */
    public void reset() {
        this.counts = null;
        this.centroids = null;
        this.initFeatures = new ArrayList<DoubleMatrix>();
    }
}
