package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans;

import de.lmu.ifi.dbs.elki.algorithm.AbstractNumberVectorDistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.DistanceBasedAlgorithm;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization;
import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.RandomlyChosenInitialMeans;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.HierarchicalClassLabel;
import de.lmu.ifi.dbs.elki.data.NumberVector;
import de.lmu.ifi.dbs.elki.data.SparseNumberVector;
import de.lmu.ifi.dbs.elki.data.model.KMeansModel;
import de.lmu.ifi.dbs.elki.data.model.Model;
import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil;
import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancefunction.NumberVectorDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.IndefiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.logging.statistics.LongStatistic;
import de.lmu.ifi.dbs.elki.math.linearalgebra.VMath;
import de.lmu.ifi.dbs.elki.utilities.datastructures.arrays.DoubleIntegerArrayQuickSort;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import net.jafama.FastMath;

/* loaded from: input_file:de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.class */
public abstract class AbstractKMeans<V extends NumberVector, M extends Model> extends AbstractNumberVectorDistanceBasedAlgorithm<V, Clustering<M>> implements KMeans<V, M> {
    /** Number of cluster means to produce; assigned by the constructor and by setK(int). */
    protected int k;
    /** Iteration limit; the constructor stores Integer.MAX_VALUE when it is given a non-positive value. */
    protected int maxiter;
    /** Strategy used by initialMeans(...) to choose the starting means. */
    protected KMeansInitialization initializer;

    /* loaded from: input_file:de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans$Instance.class */
    protected static abstract class Instance {
        /** Current cluster means, one row per cluster; initialised from the constructor argument. */
        double[][] means;
        /** Member id sets, one per cluster; cleared and refilled by assignToNearestCluster(). */
        protected List<ModifiableDBIDs> clusters;
        /** Cluster index stored per object; created with -1 as the value argument (presumably the default for unassigned objects - confirm against DataStoreUtil). */
        protected WritableIntegerDataStore assignment;
        /** Per-cluster accumulator; assignToNearestCluster() adds each object's distance to its nearest mean. */
        protected double[] varsum;
        /** The relation whose vectors are clustered. */
        protected Relation<? extends NumberVector> relation;
        /** Number of calls made to distance(...). */
        private long diststat = 0;
        /** Distance function used by distance(...). */
        private final NumberVectorDistanceFunction<?> df;
        /** Number of clusters, taken from the length of the initial means array. */
        protected final int k;
        /** Value of df.isSquared() captured in the constructor. */
        protected final boolean isSquared;
        /** Prefix for statistics keys: the runtime class name with "$Instance" removed. */
        protected String key;
        /** Decompiler-exposed assertion switch; set in the static initializer below. */
        static final /* synthetic */ boolean $assertionsDisabled;

        /**
         * Sets up the working state for one k-means run.
         *
         * @param relation the vectors to cluster
         * @param numberVectorDistanceFunction distance function used for all distance computations
         * @param dArr initial means; its length determines the number of clusters
         */
        public Instance(Relation<? extends NumberVector> relation, NumberVectorDistanceFunction<?> numberVectorDistanceFunction, double[][] dArr) {
            this.relation = relation;
            this.df = numberVectorDistanceFunction;
            this.isSquared = numberVectorDistanceFunction.isSquared();
            this.means = dArr;
            this.k = dArr.length;
            // Size hint for every member set: twice the average number of objects per cluster.
            final int expectedSize = (int) ((relation.size() * 2.0d) / this.k);
            final List<ModifiableDBIDs> memberSets = new ArrayList<>(this.k);
            for (int remaining = this.k; remaining > 0; remaining--) {
                memberSets.add(DBIDUtil.newHashSet(expectedSize));
            }
            this.clusters = memberSets;
            // NOTE(review): the literal 3 is passed as the hints argument - presumably a combination of storage hint flags; confirm.
            this.assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), 3, -1);
            this.varsum = new double[this.k];
            final String className = getClass().getName();
            this.key = className.replace("$Instance", "");
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Computes the distance between two vectors with the configured distance
         * function, counting the call in the distance-computation statistic.
         *
         * @param numberVector first vector
         * @param numberVector2 second vector
         * @return the value returned by the distance function
         */
        public double distance(NumberVector numberVector, NumberVector numberVector2) {
            ++this.diststat;
            final double result = this.df.distance(numberVector, numberVector2);
            return result;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Runs the main loop: calls {@code iterate} with the iteration numbers
         * 1, 2, ... until it returns 0 or the iteration limit is exceeded, and
         * reports progress and statistics to the logger.
         *
         * @param i maximum number of iterations to perform
         */
        public void run(int i) {
            final Logging logger = getLogger();
            final IndefiniteProgress progress = logger.isVerbose() ? new IndefiniteProgress("Iteration") : null;
            int iteration = 0;
            while (++iteration <= i) {
                logger.incrementProcessed(progress);
                // Keep the result: it is both the stop criterion and the logged reassignment count.
                final int changed = iterate(iteration);
                if (changed == 0) {
                    break;
                }
                if (logger.isStatistics()) {
                    final String prefix = this.key + HierarchicalClassLabel.DEFAULT_SEPARATOR_STRING + iteration;
                    logger.statistics(new LongStatistic(prefix + ".reassignments", Math.abs(changed)));
                    final double sum = VMath.sum(this.varsum);
                    // Only report the variance sum when it has actually been accumulated.
                    if (sum > 0.0d) {
                        logger.statistics(new DoubleStatistic(prefix + ".variance-sum", sum));
                    }
                }
            }
            logger.setCompleted(progress);
            // When the limit was hit (no early break), iteration equals i + 1 at this point.
            logger.statistics(new LongStatistic(this.key + ".iterations", iteration));
            logger.statistics(new LongStatistic(this.key + ".distance-computations", this.diststat));
        }

        /**
         * Performs one iteration of the concrete k-means variant.
         *
         * @param i iteration number; run(int) passes 1 for the first iteration
         * @return 0 to make run(int) stop; any other value lets the loop continue
         */
        protected abstract int iterate(int i);

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Writes, for every cluster, the vector sum scaled by the reciprocal of
         * the current cluster size into the destination row.
         *
         * @param dArr destination rows, overwritten
         * @param dArr2 per-cluster vector sums
         */
        public void meansFromSums(double[][] dArr, double[][] dArr2) {
            for (int cluster = 0; cluster < this.k; cluster++) {
                final double[] target = dArr[cluster];
                final double[] sums = dArr2[cluster];
                // NOTE(review): an empty cluster yields a factor of 1/0 here - confirm callers exclude that case.
                final double factor = 1.0d / this.clusters.get(cluster).size();
                VMath.overwriteTimes(target, sums, factor);
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Copies the first k rows of one means array into another, element by element.
         *
         * @param dArr source rows
         * @param dArr2 destination rows; each must be at least as long as the matching source row
         */
        public void copyMeans(double[][] dArr, double[][] dArr2) {
            for (int cluster = 0; cluster < this.k; cluster++) {
                final double[] source = dArr[cluster];
                System.arraycopy(source, 0, dArr2[cluster], 0, source.length);
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Reassigns every object of the relation to the mean with the smallest
         * distance. Member sets and the per-cluster distance sums are rebuilt
         * from scratch; ties keep the lowest cluster index.
         *
         * @return number of objects whose stored assignment differed from the new one
         */
        public int assignToNearestCluster() {
            if (!($assertionsDisabled || this.k == this.means.length)) {
                throw new AssertionError();
            }
            int changed = 0;
            Arrays.fill(this.varsum, 0.0d);
            for (ModifiableDBIDs members : this.clusters) {
                members.clear();
            }
            for (DBIDIter it = this.relation.iterDBIDs(); it.valid(); it.advance()) {
                double best = Double.POSITIVE_INFINITY;
                final NumberVector vec = this.relation.get(it);
                int bestCluster = 0;
                for (int cluster = 0; cluster < this.k; cluster++) {
                    final double dist = distance(vec, DoubleVector.wrap(this.means[cluster]));
                    // Strict comparison: on equal distances the earlier cluster wins.
                    if (dist < best) {
                        bestCluster = cluster;
                        best = dist;
                    }
                }
                this.varsum[bestCluster] += best;
                this.clusters.get(bestCluster).add(it);
                // putInt's return value is compared with the new index to detect a changed assignment.
                final int previous = this.assignment.putInt(it, bestCluster);
                if (previous != bestCluster) {
                    changed++;
                }
            }
            return changed;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Recomputes half the pairwise distances between the current means and,
         * for each mean, the smallest such half-distance. For a squared distance
         * function the square root is taken before halving.
         *
         * @param dArr output: per-mean minimum half-distance to any other mean
         *        (stays at positive infinity when there is only one mean)
         * @param dArr2 output: symmetric matrix of half-distances; the diagonal is not written
         */
        public void recomputeSeperation(double[] dArr, double[][] dArr2) {
            final int count = this.means.length;
            final boolean squared = this.df.isSquared();
            if (!($assertionsDisabled || dArr.length == count)) {
                throw new AssertionError();
            }
            Arrays.fill(dArr, Double.POSITIVE_INFINITY);
            for (int a = 1; a < count; a++) {
                final DoubleVector meanA = DoubleVector.wrap(this.means[a]);
                for (int b = 0; b < a; b++) {
                    final double raw = distance(meanA, DoubleVector.wrap(this.means[b]));
                    final double half = 0.5d * (squared ? FastMath.sqrt(raw) : raw);
                    dArr2[b][a] = half;
                    dArr2[a][b] = half;
                    if (half < dArr[a]) {
                        dArr[a] = half;
                    }
                    if (half < dArr[b]) {
                        dArr[b] = half;
                    }
                }
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Computes how far each mean moved between two sets of means.
         * For a squared distance function the square root is taken, so the
         * stored values are in the non-squared scale.
         *
         * @param dArr previous means
         * @param dArr2 new means; must have the same number of rows
         * @param dArr3 output: distance moved by each mean
         * @return the largest distance moved by any mean (0 when there are no means)
         */
        public double movedDistance(double[][] dArr, double[][] dArr2, double[] dArr3) {
            if (!$assertionsDisabled && dArr2.length != dArr.length) {
                throw new AssertionError();
            }
            if (!$assertionsDisabled && dArr3.length != dArr.length) {
                throw new AssertionError();
            }
            final boolean squared = this.df.isSquared();
            double max = 0.0d;
            for (int cluster = 0; cluster < dArr.length; cluster++) {
                final double raw = distance(DoubleVector.wrap(dArr[cluster]), DoubleVector.wrap(dArr2[cluster]));
                final double moved = squared ? FastMath.sqrt(raw) : raw;
                // Store the per-mean distance (the original assigned the array to its own element).
                dArr3[cluster] = moved;
                if (moved > max) {
                    max = moved;
                }
            }
            return max;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Builds the clustering result from the current member sets, means and
         * accumulated per-cluster distance sums. Empty clusters trigger a
         * warning but are still added to the result.
         *
         * @return clustering with one top-level cluster per member set
         */
        public Clustering<KMeansModel> buildResult() {
            final Clustering<KMeansModel> result = new Clustering<>("k-Means Clustering", "kmeans-clustering");
            final int count = this.clusters.size();
            for (int cluster = 0; cluster < count; cluster++) {
                final ModifiableDBIDs members = this.clusters.get(cluster);
                if (members.isEmpty()) {
                    getLogger().warning("K-Means produced an empty cluster - bad initialization?");
                }
                final KMeansModel model = new KMeansModel(this.means[cluster], this.varsum[cluster]);
                result.addToplevelCluster(new Cluster<>(members, model));
            }
            return result;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Builds the clustering result, optionally recomputing the per-cluster
         * sum of distances between members and their mean. Empty clusters are
         * left out of the result.
         *
         * @param z when true, distance sums are recomputed and (if the logger has
         *        statistics enabled) the total and the distance-computation count are logged;
         *        when false, every model is created with a sum of 0
         * @param relation relation used to look up the member vectors
         * @return clustering containing one top-level cluster per non-empty member set
         */
        public Clustering<KMeansModel> buildResult(boolean z, Relation<? extends NumberVector> relation) {
            double total = 0.0d;
            final Clustering<KMeansModel> result = new Clustering<>("k-Means Clustering", "kmeans-clustering");
            for (int cluster = 0; cluster < this.clusters.size(); cluster++) {
                final ModifiableDBIDs members = this.clusters.get(cluster);
                if (members.isEmpty()) {
                    continue;
                }
                double clusterSum = 0.0d;
                if (z) {
                    final DoubleVector mean = DoubleVector.wrap(this.means[cluster]);
                    for (DBIDIter it = members.iter(); it.valid(); it.advance()) {
                        clusterSum += distance(mean, relation.get(it));
                    }
                    total += clusterSum;
                }
                result.addToplevelCluster(new Cluster<>(members, new KMeansModel(this.means[cluster], clusterSum)));
            }
            final Logging logger = getLogger();
            if (!z || !logger.isStatistics()) {
                return result;
            }
            logger.statistics(new DoubleStatistic(this.key + ".variance-sum", total));
            logger.statistics(new LongStatistic(this.key + ".distance-computations", this.diststat));
            return result;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Asks the distance function whether it reports squared distances.
         *
         * @return the distance function's isSquared() value
         */
        public boolean isSquared() {
            final boolean squared = df.isSquared();
            return squared;
        }

        /**
         * Supplies the logger used for progress, warnings and statistics.
         *
         * @return the logger of the concrete implementation
         */
        abstract Logging getLogger();

        // Decompiled form of the assertion switch: assertions in this class are
        // active exactly when they are enabled for AbstractKMeans.
        static {
            $assertionsDisabled = !AbstractKMeans.class.desiredAssertionStatus();
        }
    }

    /* loaded from: input_file:de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans$Parameterizer.class */
    public static abstract class Parameterizer<V extends NumberVector> extends AbstractNumberVectorDistanceBasedAlgorithm.Parameterizer<V> {
        /** Value of the k option, set by getParameterK. */
        protected int k;
        /** Value of the maximum-iterations option, set by getParameterMaxIter. */
        protected int maxiter;
        /** Initialization strategy instantiated by getParameterInitialization. */
        protected KMeansInitialization initializer;
        /** Whether the variance-statistics flag was given; set by getParameterVarstat. */
        protected boolean varstat = false;

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Reads the options in a fixed order: k, initialization, distance
         * function, maximum iterations. The variance-statistics flag is not read
         * here; getParameterVarstat must be called separately where wanted.
         *
         * @param parameterization source of the option values
         */
        @Override // de.lmu.ifi.dbs.elki.algorithm.AbstractNumberVectorDistanceBasedAlgorithm.Parameterizer, de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer
        public void makeOptions(Parameterization parameterization) {
            getParameterK(parameterization);
            getParameterInitialization(parameterization);
            getParameterDistanceFunction(parameterization);
            getParameterMaxIter(parameterization);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Multi-variable type inference failed */
        /**
         * Reads the number of clusters (constrained to be at least one) and
         * stores it in {@code k} when the option could be grabbed.
         *
         * @param parameterization source of the option values
         */
        public void getParameterK(Parameterization parameterization) {
            final IntParameter kParam = (IntParameter) new IntParameter(KMeans.K_ID).addConstraint((ParameterConstraint) CommonConstraints.GREATER_EQUAL_ONE_INT);
            if (!parameterization.grab(kParam)) {
                return;
            }
            final Integer value = (Integer) kParam.getValue();
            this.k = value.intValue();
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Reads the distance function (default: squared Euclidean) and warns
         * when something other than Euclidean or squared Euclidean is chosen.
         * Variants that need the triangle inequality get a dedicated warning
         * when the chosen function is not metric.
         *
         * @param parameterization source of the option values
         */
        public void getParameterDistanceFunction(Parameterization parameterization) {
            final ObjectParameter distanceParam = new ObjectParameter(DistanceBasedAlgorithm.DISTANCE_FUNCTION_ID, (Class<?>) PrimitiveDistanceFunction.class, (Class<?>) SquaredEuclideanDistanceFunction.class);
            if (!parameterization.grab(distanceParam)) {
                return;
            }
            this.distanceFunction = (NumberVectorDistanceFunction) distanceParam.instantiateClass(parameterization);
            final boolean needsWarning = this.distanceFunction != null
                && !(this.distanceFunction instanceof SquaredEuclideanDistanceFunction)
                && !(this.distanceFunction instanceof EuclideanDistanceFunction);
            if (!needsWarning) {
                return;
            }
            if (needsMetric() && !this.distanceFunction.isMetric()) {
                Logging.getLogger(getClass()).warning("This k-means variants requires the triangle inequality, and thus should only be used with squared Euclidean distance!");
            } else {
                Logging.getLogger(getClass()).warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!");
            }
        }

        /**
         * Tells getParameterDistanceFunction whether to issue the
         * triangle-inequality warning for non-metric distance functions.
         *
         * @return false in this base implementation
         */
        protected boolean needsMetric() {
            return false;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Reads the initialization strategy (default: randomly chosen initial
         * means) and instantiates it when the option could be grabbed.
         *
         * @param parameterization source of the option values
         */
        public void getParameterInitialization(Parameterization parameterization) {
            final ObjectParameter initParam = new ObjectParameter(KMeans.INIT_ID, (Class<?>) KMeansInitialization.class, (Class<?>) RandomlyChosenInitialMeans.class);
            if (!parameterization.grab(initParam)) {
                return;
            }
            final Object created = initParam.instantiateClass(parameterization);
            this.initializer = (KMeansInitialization) created;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Multi-variable type inference failed */
        /**
         * Reads the maximum number of iterations (default 0, constrained to be
         * non-negative) and stores it when the option could be grabbed.
         *
         * @param parameterization source of the option values
         */
        public void getParameterMaxIter(Parameterization parameterization) {
            final IntParameter maxiterParam = (IntParameter) new IntParameter(KMeans.MAXITER_ID, 0).addConstraint((ParameterConstraint) CommonConstraints.GREATER_EQUAL_ZERO_INT);
            if (!parameterization.grab(maxiterParam)) {
                return;
            }
            final Integer value = (Integer) maxiterParam.getValue();
            this.maxiter = value.intValue();
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Reads the variance-statistics flag; {@code varstat} becomes true only
         * when the flag could be grabbed and is set.
         *
         * @param parameterization source of the option values
         */
        public void getParameterVarstat(Parameterization parameterization) {
            final Flag varstatFlag = new Flag(KMeans.VARSTAT_ID);
            if (parameterization.grab(varstatFlag)) {
                this.varstat = varstatFlag.isTrue();
            } else {
                this.varstat = false;
            }
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /**
         * Creates the algorithm instance; implemented by the parameterizer of
         * each concrete k-means variant.
         *
         * @return the k-means algorithm instance
         */
        @Override // de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer
        public abstract AbstractKMeans<V, ?> makeInstance();
    }

    /**
     * Creates the algorithm with its basic configuration.
     *
     * @param numberVectorDistanceFunction distance function, passed to the superclass
     * @param i number of clusters
     * @param i2 maximum number of iterations; a value of 0 or less means Integer.MAX_VALUE
     * @param kMeansInitialization strategy for choosing the initial means
     */
    public AbstractKMeans(NumberVectorDistanceFunction<? super V> numberVectorDistanceFunction, int i, int i2, KMeansInitialization kMeansInitialization) {
        super(numberVectorDistanceFunction);
        this.k = i;
        if (i2 > 0) {
            this.maxiter = i2;
        } else {
            this.maxiter = Integer.MAX_VALUE;
        }
        this.initializer = kMeansInitialization;
    }

    /**
     * Describes the accepted input: a single type that combines the number
     * vector field type with the input restriction of the distance function.
     *
     * @return one-element array holding the combined type restriction
     */
    @Override // de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm, de.lmu.ifi.dbs.elki.algorithm.Algorithm
    public TypeInformation[] getInputTypeRestriction() {
        final CombinedTypeInformation combined = new CombinedTypeInformation(TypeUtil.NUMBER_VECTOR_FIELD, getDistanceFunction().getInputTypeRestriction());
        return TypeUtil.array(combined);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Chooses the initial means with the configured initializer and reports
     * the time this took as a logger statistic.
     *
     * @param database database passed on to the initializer
     * @param relation relation passed on to the initializer
     * @return the initial means returned by the initializer
     */
    public double[][] initialMeans(Database database, Relation<V> relation) {
        final String timerKey = this.initializer.getClass() + ".time";
        final Duration timer = getLogger().newDuration(timerKey).begin();
        final double[][] chosen = this.initializer.chooseInitialMeans(database, relation, this.k, getDistanceFunction());
        getLogger().statistics(timer.end());
        return chosen;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the mean of every cluster, using the sparse code path when the
     * relation's data type is a sparse vector field and the dense one otherwise.
     *
     * @param list member ids of each cluster
     * @param dArr previous means; the row is reused for clusters without members
     * @param relation relation holding the vectors
     * @return the new means, one row per cluster
     */
    @SuppressWarnings("unchecked")
    public static double[][] means(List<? extends DBIDs> list, double[][] dArr, Relation<? extends NumberVector> relation) {
        if (TypeUtil.SPARSE_VECTOR_FIELD.isAssignableFromType(relation.getDataTypeInformation())) {
            // The type check above establishes that the relation holds sparse vectors;
            // the cast only narrows the static type, which the compiler cannot verify.
            final Relation<? extends SparseNumberVector> sparse = (Relation<? extends SparseNumberVector>) relation;
            return sparseMeans(list, dArr, sparse);
        }
        return denseMeans(list, dArr, relation);
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    private static double[][] denseMeans(List<? extends DBIDs> list, double[][] dArr, Relation<? extends NumberVector> relation) {
        ?? r0 = new double[dArr.length];
        for (int i = 0; i < r0.length; i++) {
            DBIDs dBIDs = list.get(i);
            if (dBIDs.isEmpty()) {
                r0[i] = dArr[i];
            } else {
                DBIDIter iter = dBIDs.iter();
                double[] array = relation.get(iter).toArray();
                iter.advance();
                while (iter.valid()) {
                    plusEquals(array, relation.get(iter));
                    iter.advance();
                }
                r0[i] = VMath.timesEquals(array, 1.0d / dBIDs.size());
            }
        }
        return r0;
    }

    /**
     * Adds a vector to an array in place, for every index of the array.
     *
     * @param dArr array that is modified
     * @param numberVector vector whose values are added
     */
    public static void plusEquals(double[] dArr, NumberVector numberVector) {
        final int dims = dArr.length;
        for (int dim = 0; dim < dims; dim++) {
            dArr[dim] += numberVector.doubleValue(dim);
        }
    }

    /**
     * Subtracts a vector from an array in place, for every index of the array.
     *
     * @param dArr array that is modified
     * @param numberVector vector whose values are subtracted
     */
    public static void minusEquals(double[] dArr, NumberVector numberVector) {
        final int dims = dArr.length;
        for (int dim = 0; dim < dims; dim++) {
            dArr[dim] -= numberVector.doubleValue(dim);
        }
    }

    /**
     * Adds a vector to one array and subtracts it from another, in place.
     * The loop runs over the indices of the first array.
     *
     * @param dArr array the vector is added to
     * @param dArr2 array the vector is subtracted from
     * @param numberVector vector to move between the two sums
     */
    public static void plusMinusEquals(double[] dArr, double[] dArr2, NumberVector numberVector) {
        final int dims = dArr.length;
        for (int dim = 0; dim < dims; dim++) {
            final double value = numberVector.doubleValue(dim);
            dArr[dim] += value;
            dArr2[dim] -= value;
        }
    }

    /* JADX WARN: Type inference failed for: r0v3, types: [double[], double[][]] */
    private static double[][] sparseMeans(List<? extends DBIDs> list, double[][] dArr, Relation<? extends SparseNumberVector> relation) {
        int length = dArr.length;
        ?? r0 = new double[length];
        for (int i = 0; i < length; i++) {
            DBIDs dBIDs = list.get(i);
            if (dBIDs.isEmpty()) {
                r0[i] = dArr[i];
            } else {
                DBIDIter iter = dBIDs.iter();
                double[] array = relation.get(iter).toArray();
                iter.advance();
                while (iter.valid()) {
                    SparseNumberVector sparseNumberVector = relation.get(iter);
                    int iter2 = sparseNumberVector.iter();
                    while (true) {
                        int i2 = iter2;
                        if (sparseNumberVector.iterValid(i2)) {
                            int iterDim = sparseNumberVector.iterDim(i2);
                            array[iterDim] = array[iterDim] + sparseNumberVector.iterDoubleValue(i2);
                            iter2 = sparseNumberVector.iterAdvance(i2);
                        }
                    }
                    iter.advance();
                }
                r0[i] = VMath.timesEquals(array, 1.0d / dBIDs.size());
            }
        }
        return r0;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * For every mean, fills its row of the index matrix with the indices of
     * all other means and sorts that row together with the corresponding
     * distances via DoubleIntegerArrayQuickSort.
     *
     * @param dArr square matrix of distances between means
     * @param iArr output: per mean, the indices of the other means (k - 1 entries used per row)
     */
    public static void nearestMeans(double[][] dArr, int[][] iArr) {
        final int count = dArr.length;
        final double[] others = new double[count - 1];
        for (int self = 0; self < count; self++) {
            // Copy the row without its diagonal entry.
            System.arraycopy(dArr[self], 0, others, 0, self);
            System.arraycopy(dArr[self], self + 1, others, self, (count - self) - 1);
            for (int pos = 0; pos < others.length; pos++) {
                // Positions at or after the removed entry map to the next original index.
                iArr[self][pos] = pos < self ? pos : pos + 1;
            }
            DoubleIntegerArrayQuickSort.sort(others, iArr[self], count - 1);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Moves a mean towards a vector: adds (vector - mean) * (d / i) to the
     * mean in place. Does nothing when {@code i} is 0.
     *
     * @param dArr mean that is updated in place
     * @param numberVector vector the mean is moved towards
     * @param i divisor of the step factor
     * @param d numerator of the step factor
     */
    public static void incrementalUpdateMean(double[] dArr, NumberVector numberVector, int i, double d) {
        if (i != 0) {
            final double[] delta = VMath.minusEquals(numberVector.toArray(), dArr);
            VMath.plusTimesEquals(dArr, delta, d / i);
        }
    }

    /**
     * Replaces the number of clusters.
     *
     * @param i new number of clusters
     */
    @Override // de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans
    public void setK(int i) {
        k = i;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Replaces the distance function.
     *
     * @param numberVectorDistanceFunction new distance function
     */
    @Override // de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans
    public void setDistanceFunction(NumberVectorDistanceFunction<? super V> numberVectorDistanceFunction) {
        distanceFunction = numberVectorDistanceFunction;
    }

    /**
     * Replaces the strategy that chooses the initial means.
     *
     * @param kMeansInitialization new initialization strategy
     */
    @Override // de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeans
    public void setInitializer(KMeansInitialization kMeansInitialization) {
        initializer = kMeansInitialization;
    }

    /**
     * Bridge method surfaced by the decompiler: delegates to the superclass
     * implementation and casts its result to Clustering.
     *
     * @param database database to run on
     * @return the superclass result, cast to Clustering
     */
    @Override // de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm, de.lmu.ifi.dbs.elki.algorithm.Algorithm
    public /* bridge */ /* synthetic */ Clustering run(Database database) {
        return (Clustering) super.run(database);
    }
}
