package de.lmu.ifi.dbs.elki.algorithm.projection;

import de.lmu.ifi.dbs.elki.algorithm.projection.TSNE;
import de.lmu.ifi.dbs.elki.data.DoubleVector;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
import de.lmu.ifi.dbs.elki.data.type.VectorFieldTypeInformation;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.datastore.DataStoreFactory;
import de.lmu.ifi.dbs.elki.database.datastore.WritableDataStore;
import de.lmu.ifi.dbs.elki.database.ids.DBIDArrayIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
import de.lmu.ifi.dbs.elki.database.relation.MaterializedRelation;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
import de.lmu.ifi.dbs.elki.logging.statistics.LongStatistic;
import de.lmu.ifi.dbs.elki.math.MathUtil;
import de.lmu.ifi.dbs.elki.utilities.Priority;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.io.FormatUtil;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.ParameterConstraint;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * t-SNE projection that approximates the repulsive forces with the Barnes-Hut
 * technique (van der Maaten, 2014).
 * <p>
 * Every iteration, the current low-dimensional solution is indexed in a
 * {@link QuadTree}. Cells that are small compared to their distance from the
 * query point are treated as a single point at their center of mass. The
 * threshold is controlled by the parameter theta.
 *
 * @param <O> Object type
 */
@Priority(199)
@Reference(authors = "L. J. P. van der Maaten", title = "Accelerating t-SNE using Tree-Based Algorithms", booktitle = "Journal of Machine Learning Research 15", url = "http://dl.acm.org/citation.cfm?id=2697068", bibkey = "DBLP:journals/jmlr/Maaten14")
public class BarnesHutTSNE<O> extends TSNE<O> {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(BarnesHutTSNE.class);

    /** Perplexity tolerance (not referenced within this class). */
    protected static final double PERPLEXITY_ERROR = 1.0E-4d;

    /** Perplexity search iteration limit (not referenced within this class). */
    protected static final int PERPLEXITY_MAXITER = 25;

    /**
     * Cells whose squared bounding-box diagonal is at most this value are not
     * split further; their points are stored directly in a leaf.
     */
    private static final double QUADTREE_MIN_RESOLUTION = 1.0E-10d;

    /** Squared theta, compared against squared cell size / squared distance. */
    protected double sqtheta;

    /**
     * Parameterization class.
     *
     * @param <O> Object type
     */
    public static class Parameterizer<O> extends TSNE.Parameterizer<O> {
        /** Parameter for the Barnes-Hut approximation quality (theta). */
        public static final OptionID THETA_ID = new OptionID("tsne.theta", "Approximation quality parameter");

        /** Approximation quality parameter; default 0.5, must be non-negative. */
        public double theta;

        @Override
        public void makeOptions(Parameterization parameterization) {
            super.makeOptions(parameterization);
            DoubleParameter thetaP = new DoubleParameter(THETA_ID);
            thetaP.setDefaultValue(0.5);
            thetaP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE);
            if (parameterization.grab(thetaP)) {
                this.theta = thetaP.getValue();
            }
        }

        @Override
        protected Class<?> getDefaultAffinity() {
            return NearestNeighborAffinityMatrixBuilder.class;
        }

        @Override
        public BarnesHutTSNE<O> makeInstance() {
            return new BarnesHutTSNE<>(this.affinity, this.dim, this.finalMomentum, this.learningRate, this.iterations, this.random, this.keep, this.theta);
        }
    }

    /**
     * Space-partitioning tree over the current solution, used for the Barnes-Hut
     * approximation.
     * <p>
     * Each node splits every (non-constant) dimension at the midpoint of its
     * bounding box. A node therefore has up to 2^d children, where d is the
     * dimensionality; this is a quad-tree in 2D.
     */
    public static class QuadTree {
        /** Center of mass of all points in this subtree. */
        public double[] center;

        /** Points stored directly in this node (may be {@code null}). */
        public double[][] points;

        /** Squared length of the bounding-box diagonal of this node. */
        public double squareSize;

        /** Number of points in this subtree. */
        public int weight;

        /** Child nodes (may be {@code null}). */
        public QuadTree[] children;

        private QuadTree(double[][] points, QuadTree[] children, double[] center, int weight, double squareSize) {
            this.center = center;
            this.points = points;
            this.weight = weight;
            this.squareSize = squareSize;
            this.children = children;
        }

        /**
         * Build a tree over the given points.
         * <p>
         * The outer array is copied (shallow), so the caller's ordering is not
         * modified. The point vectors themselves are shared, not copied.
         *
         * @param dim Dimensionality
         * @param data Points
         * @return Root node
         */
        public static QuadTree build(int dim, double[][] data) {
            return build(dim, data.clone(), 0, data.length);
        }

        /**
         * Build a (sub-)tree over {@code data[begin, end)}. That range is reordered
         * in place.
         *
         * @param dims Dimensionality
         * @param data Point array (reordered in place)
         * @param begin First index (inclusive)
         * @param end Last index (exclusive)
         * @return Tree node
         */
        private static QuadTree build(int dims, double[][] data, int begin, int end) {
            final double[] minmax = computeExtend(dims, data, begin, end);
            final double squareSize = computeSquareSize(minmax);
            final double[] center = computeCenterofMass(dims, data, begin, end);
            final int size = end - begin;
            // Tiny cells (e.g. duplicates) become leaves holding their points.
            if (squareSize <= QUADTREE_MIN_RESOLUTION) {
                return new QuadTree(Arrays.copyOfRange(data, begin, end), null, center, size, squareSize);
            }
            ArrayList<double[]> singletons = new ArrayList<>();
            ArrayList<QuadTree> children = new ArrayList<>();
            splitRecursively(data, begin, end, 0, dims, minmax, singletons, children);
            double[][] points = singletons.isEmpty() ? null : singletons.toArray(new double[singletons.size()][]);
            QuadTree[] nodes = children.isEmpty() ? null : children.toArray(new QuadTree[children.size()]);
            return new QuadTree(points, nodes, center, size, squareSize);
        }

        /**
         * Partition {@code data[begin, end)} along dimensions {@code initdim}
         * onwards, at the midpoints given by the node extent {@code minmax}.
         * <p>
         * Single points are added to {@code singletons}. Once all dimensions have
         * been split, each non-empty part becomes a child tree.
         *
         * @param data Point array (reordered in place)
         * @param begin First index (inclusive)
         * @param end Last index (exclusive)
         * @param initdim First dimension to try splitting
         * @param dims Dimensionality
         * @param minmax Node extent, interleaved as min0, max0, min1, max1, ...
         * @param singletons Output for single points
         * @param children Output for child nodes
         */
        private static void splitRecursively(double[][] data, int begin, int end, int initdim, int dims, double[] minmax, ArrayList<double[]> singletons, ArrayList<QuadTree> children) {
            final int len = end - begin;
            if (len <= 1) {
                if (len == 1) {
                    singletons.add(data[begin]);
                }
                return;
            }
            // Find the next dimension that can be split.
            int cur = initdim;
            double mid = Double.NaN;
            for (; cur < dims; cur++) {
                final double min = minmax[cur << 1], max = minmax[(cur << 1) + 1];
                mid = 0.5 * (min + max);
                // Require min < mid < max: then both halves of the whole node are
                // non-empty. If mid rounded onto max, every point would stay on one
                // side and the same set could be rebuilt endlessly.
                if (min < mid && mid < max) {
                    break;
                }
            }
            if (cur == dims) {
                if (initdim > 0) {
                    // An earlier dimension was already split and the remaining ones are
                    // constant in this node, so this subset is an ordinary child cell.
                    // It is strictly smaller than the node, so recursion terminates.
                    children.add(build(dims, data, begin, end));
                    return;
                }
                // No dimension is splittable at all, which is only possible through
                // floating-point rounding of extremely narrow ranges.
                LOG.warning("Unexpected all-constant split.");
                children.add(new QuadTree(Arrays.copyOfRange(data, begin, end), null, computeCenterofMass(dims, data, begin, end), len, 0.0d));
                return;
            }
            // Hoare-style partition: values <= mid go left, values >= mid go right.
            int l = begin, r = end - 1;
            while (l <= r) {
                while (l <= r && data[l][cur] <= mid) {
                    ++l;
                }
                while (l <= r && data[r][cur] >= mid) {
                    --r;
                }
                if (l < r) {
                    assert data[l][cur] > mid;
                    assert data[r][cur] < mid;
                    final double[] tmp = data[r];
                    data[r] = data[l];
                    data[l] = tmp;
                    ++l;
                    --r;
                }
            }
            assert l == end || data[l][cur] >= mid;
            assert l == begin || data[l - 1][cur] <= mid;
            final int next = cur + 1;
            if (next < dims) {
                // Continue splitting the halves along the following dimensions.
                if (begin < l) {
                    splitRecursively(data, begin, l, next, dims, minmax, singletons, children);
                }
                if (l < end) {
                    splitRecursively(data, l, end, next, dims, minmax, singletons, children);
                }
                return;
            }
            // All dimensions split: each non-empty part becomes a child tree.
            if (begin < l) {
                children.add(build(dims, data, begin, l));
            }
            if (l < end) {
                children.add(build(dims, data, l, end));
            }
        }

        /**
         * Compute the center of mass (mean) of {@code data[begin, end)}.
         * <p>
         * For a single point, the point array itself is returned, not a copy.
         */
        private static double[] computeCenterofMass(int dim, double[][] data, int begin, int end) {
            final int size = end - begin;
            if (size == 1) {
                return data[begin];
            }
            double[] center = new double[dim];
            for (int j = begin; j < end; j++) {
                final double[] row = data[j];
                for (int k = 0; k < dim; k++) {
                    center[k] += row[k];
                }
            }
            final double norm = 1.0d / size;
            for (int k = 0; k < dim; k++) {
                center[k] *= norm;
            }
            return center;
        }

        /**
         * Compute the bounding box of {@code data[begin, end)}.
         *
         * @return Interleaved array: min0, max0, min1, max1, ...
         */
        private static double[] computeExtend(int dim, double[][] data, int begin, int end) {
            double[] minmax = new double[dim << 1];
            for (int p = 0; p < minmax.length; p += 2) {
                minmax[p] = Double.POSITIVE_INFINITY;
                minmax[p + 1] = Double.NEGATIVE_INFINITY;
            }
            for (int j = begin; j < end; j++) {
                final double[] row = data[j];
                for (int k = 0, p = 0; k < dim; k++, p += 2) {
                    final double v = row[k];
                    minmax[p] = MathUtil.min(minmax[p], v);
                    minmax[p + 1] = MathUtil.max(minmax[p + 1], v);
                }
            }
            return minmax;
        }

        /** Squared diagonal of the bounding box: sum of squared per-dimension extents. */
        private static double computeSquareSize(double[] minmax) {
            double sum = 0.0d;
            for (int p = 0; p < minmax.length - 1; p += 2) {
                final double width = minmax[p + 1] - minmax[p];
                sum += width * width;
            }
            return sum;
        }

        @Override
        public String toString() {
            // points and children may legitimately be null (see build).
            return "QuadTree[center=" + FormatUtil.format(this.center) + ", weight=" + this.weight //
                + ", points=" + (this.points != null ? this.points.length : 0) //
                + ", children=" + (this.children != null ? this.children.length : 0) //
                + ", sqSize=" + this.squareSize + "]";
        }
    }

    /**
     * Constructor.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param finalMomentum Final momentum
     * @param learningRate Learning rate
     * @param maxIterations Number of iterations
     * @param random Random generator
     * @param keep Keep the original data relation
     * @param theta Barnes-Hut approximation quality parameter
     */
    public BarnesHutTSNE(AffinityMatrixBuilder<? super O> affinity, int dim, double finalMomentum, double learningRate, int maxIterations, RandomFactory random, boolean keep, double theta) {
        // The gradient computed in computeGradient omits the constant factor 4 of
        // the t-SNE gradient; presumably the learning rate is scaled to compensate.
        super(affinity, dim, finalMomentum, learningRate * 4.0d, maxIterations, random, keep);
        this.sqtheta = theta * theta;
    }

    /**
     * Run the projection.
     *
     * @param database Database
     * @param relation Input relation
     * @return Projected relation
     */
    public Relation<DoubleVector> run(Database database, Relation<O> relation) {
        // 4.0 is passed to the affinity builder; it is presumably the early
        // exaggeration factor that optimizetSNE later undoes with scale(0.25).
        AffinityMatrix neighbors = this.affinity.computeAffinityMatrix(relation, 4.0d);
        double[][] solution = randomInitialSolution(neighbors.size(), this.dim, this.random.getSingleThreadedRandom());
        this.projectedDistances = 0L;
        optimizetSNE(neighbors, solution);
        LOG.statistics(new LongStatistic(getClass().getName() + ".projected-distances", this.projectedDistances));
        removePreviousRelation(relation);
        DBIDs ids = relation.getDBIDs();
        // NOTE(review): 30 is a storage-hint bit mask (inlined by the decompiler).
        WritableDataStore<DoubleVector> proj = DataStoreFactory.FACTORY.makeStorage(ids, 30, DoubleVector.class);
        VectorFieldTypeInformation<DoubleVector> otype = new VectorFieldTypeInformation<>(DoubleVector.FACTORY, this.dim);
        for (DBIDArrayIter it = neighbors.iterDBIDs(); it.valid(); it.advance()) {
            proj.put(it, DoubleVector.wrap(solution[it.getOffset()]));
        }
        return new MaterializedRelation<>("tSNE", "t-SNE", otype, proj, ids);
    }

    /**
     * Iteratively optimize the solution in place.
     * <p>
     * Per point, a working array stores 3 * dim values. The first dim are the
     * gradient, and the last dim are initialized to 1. The middle and last parts
     * are presumably the previous update and the gains consumed by
     * {@code updateSolution} (not visible here).
     *
     * @param affinityMatrix Input affinities
     * @param solution Solution, modified in place
     */
    @Override
    protected void optimizetSNE(AffinityMatrix affinityMatrix, double[][] solution) {
        final int size = affinityMatrix.size();
        // Check in long arithmetic; an int product could overflow and pass the check.
        if ((long) size * 3L * this.dim > 2147483642L) {
            throw new AbortException("Memory exceeds Java array size limit.");
        }
        final int dim3 = this.dim * 3;
        double[] meta = new double[size * dim3];
        for (int off = 2 * this.dim; off < meta.length; off += dim3) {
            Arrays.fill(meta, off, off + this.dim, 1.0d);
        }
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Iterative Optimization", this.iterations, LOG) : null;
        Duration timer = LOG.isStatistics() ? LOG.newDuration(getClass().getName() + ".runtime.optimization").begin() : null;
        for (int it = 0; it < this.iterations; it++) {
            computeGradient(affinityMatrix, solution, meta);
            updateSolution(solution, meta, it);
            // After 50 iterations, scale affinities by 1/4, presumably ending early
            // exaggeration.
            if (it == 50) {
                affinityMatrix.scale(0.25d);
            }
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        if (timer != null) {
            LOG.statistics(timer.end());
        }
    }

    /**
     * Compute the (unscaled) gradient into the first dim entries of each point's
     * 3 * dim slot in {@code grad}.
     *
     * @param affinityMatrix Affinities
     * @param solution Current solution
     * @param grad Working array (stride 3 * dim)
     */
    private void computeGradient(AffinityMatrix affinityMatrix, double[][] solution, double[] grad) {
        final int dim3 = 3 * this.dim;
        for (int off = 0; off < grad.length; off += dim3) {
            Arrays.fill(grad, off, off + this.dim, 0.0d);
        }
        final QuadTree tree = QuadTree.build(this.dim, solution);
        // Accumulate the normalization Z with a negative sign. Scaling by 1/z then
        // both normalizes and negates the repulsive part of the gradient.
        double z = 0.0d;
        for (int i = 0, off = 0; i < solution.length; i++, off += dim3) {
            z -= computeRepulsiveForces(grad, off, solution[i], tree);
        }
        final double scale = 1.0d / z;
        for (int off = 0; off < grad.length; off += dim3) {
            for (int k = 0; k < this.dim; k++) {
                grad[off + k] *= scale;
            }
        }
        computeAttractiveForces(grad, affinityMatrix, solution);
    }

    /**
     * Add the attractive forces p_ij / (1 + ||y_i - y_j||^2) * (y_i - y_j).
     *
     * @param attr Working array (stride 3 * dim)
     * @param affinityMatrix Sparse affinities
     * @param solution Current solution
     */
    private void computeAttractiveForces(double[] attr, AffinityMatrix affinityMatrix, double[][] solution) {
        final int dim3 = 3 * this.dim;
        for (int i = 0, off = 0; off < attr.length; i++, off += dim3) {
            final double[] solI = solution[i];
            for (int offj = affinityMatrix.iter(i); affinityMatrix.iterValid(i, offj); offj = affinityMatrix.iterAdvance(i, offj)) {
                final double[] solJ = solution[affinityMatrix.iterDim(i, offj)];
                final double a = affinityMatrix.iterValue(i, offj) / (1.0d + sqDist(solI, solJ));
                for (int k = 0; k < this.dim; k++) {
                    attr[off + k] += a * (solI[k] - solJ[k]);
                }
            }
        }
    }

    /**
     * Add the (unnormalized) repulsive forces acting on one point.
     * <p>
     * NOTE: the query point itself, when stored in the tree, contributes 1 to the
     * returned sum and zero force.
     *
     * @param repForces Working array
     * @param off Offset of this point's gradient in {@code repForces}
     * @param solI Query point
     * @param node Tree node
     * @return Contribution to the normalization sum Z
     */
    private double computeRepulsiveForces(double[] repForces, int off, double[] solI, QuadTree node) {
        final double[] center = node.center;
        final double dist = sqDist(solI, center);
        // Barnes-Hut criterion in squared form: approximate the whole cell by its
        // center of mass if it holds one point or is small relative to its distance.
        if (node.weight == 1 || node.squareSize / dist < this.sqtheta) {
            final double u = 1.0d / (1.0d + dist);
            final double z = node.weight * u;
            final double a = z * u;
            for (int k = 0; k < this.dim; k++) {
                repForces[off + k] += a * (solI[k] - center[k]);
            }
            return z;
        }
        double z = 0.0d;
        if (node.points != null) {
            for (double[] point : node.points) {
                final double u = 1.0d / (1.0d + sqDist(solI, point));
                final double a = u * u;
                for (int k = 0; k < this.dim; k++) {
                    repForces[off + k] += a * (solI[k] - point[k]);
                }
                z += u;
            }
        }
        if (node.children != null) {
            for (QuadTree child : node.children) {
                z += computeRepulsiveForces(repForces, off, solI, child);
            }
        }
        return z;
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(this.affinity.getInputTypeRestriction());
    }

    @Override
    public Logging getLogger() {
        return LOG;
    }
}
