package org.deeplearning4j.plot;

import com.google.common.primitives.Ints;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.SpecifiedIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Exact t-SNE: embeds the rows of a matrix into a low-dimensional space.
 *
 * <p>The structure mirrors the reference {@code tsne.py} implementation (hence the {@code x2p} /
 * {@code hBeta} names). {@link #x2p} binary-searches a per-point Gaussian precision so that each
 * conditional distribution matches the requested perplexity. {@link #calculate} then runs
 * gradient descent with momentum and per-coordinate gains, minimising KL(P || Q). Q uses a
 * Student-t kernel. All pairwise matrices are rows x rows, so memory and time are quadratic in
 * the number of points.
 *
 * <p>Not thread-safe: {@link #calculate} writes the embedding to {@link #Y} and overwrites
 * {@link #momentum}.
 */
public class Tsne {
    /** Number of gradient-descent iterations run by {@link #calculate}. */
    protected int maxIter;
    /** Stored from the constructor; not read by any method of this class. */
    protected double realMin;
    /** Momentum used while the iteration index is below {@link #switchMomentumIteration}. */
    protected double initialMomentum;
    /** Momentum used from {@link #switchMomentumIteration} onwards. */
    protected double finalMomentum;
    /** Lower bound applied to every per-coordinate gain after each update. */
    protected double minGain;
    /** Current momentum; overwritten on every iteration of {@link #calculate}. */
    protected double momentum;
    /** Iteration at which momentum switches from initial to final. */
    protected int switchMomentumIteration;
    /** Whether input is normalised (directly, or via the PCA call) before embedding. */
    protected boolean normalize;
    /** Whether to project input onto at most 50 principal components first. */
    protected boolean usePca;
    /** Iteration at which the early-exaggeration factor (x4 applied in {@link #x2p}) is removed. */
    protected int stopLyingIteration;
    /** Tolerance on |H - log(perplexity)| in the per-point precision search. */
    protected double tolerance;
    /** Step size multiplied into the gain-scaled gradient. */
    protected double learningRate;
    /** Not read by this class. */
    protected AdaGrad adaGrad;
    /** Stored from the constructor; not read by this class's optimisation loop. */
    protected boolean useAdaGrad;
    /** Target perplexity used by {@link #plot}. */
    protected double perplexity;
    /** Embedding produced by the most recent {@link #calculate} call (rows x dims). */
    protected INDArray Y;
    protected static final Logger logger = LoggerFactory.getLogger(Tsne.class);

    /** Fluent builder for {@link Tsne}; each setter returns this builder. */
    public static class Builder {
        // Several defaults are float literals widened to double (e.g. 0.8f, 0.1f, 1e-12f);
        // the exact widened values are kept so built instances behave as before.
        protected int maxIter = 1000;
        protected double realMin = 9.999999960041972E-13d;
        protected double initialMomentum = 0.5d;
        protected double finalMomentum = 0.800000011920929d;
        protected double momentum = 0.5d;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 9.999999747378752E-6d;
        protected double learningRate = 0.10000000149011612d;
        protected boolean useAdaGrad = false;
        protected double perplexity = 30.0d;
        protected double minGain = 0.10000000149011612d;

        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /** Creates a {@link Tsne} from the values configured on this builder. */
        public Tsne build() {
            return new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.minGain,
                    this.momentum, this.switchMomentumIteration, this.normalize, this.usePca,
                    this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity);
        }
    }

    /**
     * Stores every hyper-parameter, then calls the {@link #init()} hook.
     *
     * <p>Note: {@code init()} is overridable and runs before any subclass constructor body.
     */
    public Tsne(int maxIter, double realMin, double initialMomentum, double finalMomentum, double minGain,
                double momentum, int switchMomentumIteration, boolean normalize, boolean usePca,
                int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad,
                double perplexity) {
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.minGain = minGain;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.normalize = normalize;
        this.usePca = usePca;
        this.stopLyingIteration = stopLyingIteration;
        this.tolerance = tolerance;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.perplexity = perplexity;
        init();
    }

    /** Extension hook invoked at the end of the constructor; does nothing here. */
    protected void init() {
    }

    /**
     * Runs t-SNE on the rows of {@code iNDArray}.
     *
     * <p>The caller's matrix is not modified by the normalisation step (a copy is normalised).
     *
     * @param iNDArray input data, one point per row
     * @param i number of output dimensions
     * @param d target perplexity
     * @return the embedding (rows x {@code i}); also stored in {@link #Y}
     */
    public INDArray calculate(INDArray iNDArray, int i, double d) {
        INDArray x = iNDArray;
        if (this.usePca) {
            x = PCA.pca(x, Math.min(50, x.columns()), this.normalize);
        } else if (this.normalize) {
            // Copy first: the in-place ops below would otherwise rescale the caller's data.
            x = x.dup();
            x.subi(x.min(new int[]{Integer.MAX_VALUE}));
            x.divi(x.max(new int[]{Integer.MAX_VALUE}));
            x.subiRowVector(x.mean(new int[]{0}));
        }

        int n = x.rows();
        this.Y = Nd4j.randn(n, i, Nd4j.getRandom());
        INDArray update = Nd4j.zeros(n, i);   // momentum-carried step ("iY")
        INDArray gains = Nd4j.ones(n, i);     // per-coordinate adaptive gains
        boolean exaggerationRemoved = false;
        logger.debug("Y:Shape is = " + Arrays.toString(this.Y.shape()));

        INDArray p = x2p(x, this.tolerance, d);
        for (int iter = 0; iter < this.maxIter; iter++) {
            // Unnormalised Student-t kernel: 1 / (1 + ||y_a||^2 + ||y_b||^2 - 2 y_a.y_b).
            INDArray sumY = Transforms.pow(this.Y, 2).sum(new int[]{1}).transpose();
            INDArray num = this.Y.mmul(this.Y.transpose()).muli(-2).addiRowVector(sumY)
                    .transpose().addiRowVector(sumY).addi(1).rdivi(1);
            // A point is not its own neighbour: q_aa must be 0 (as in the reference algorithm),
            // otherwise every self-pair adds 1 to the normaliser below.
            for (int k = 0; k < n; k++) {
                num.putScalar(k, k, 0.0d);
            }
            INDArray q = num.div(Double.valueOf(num.sumNumber().doubleValue()));
            BooleanIndexing.applyWhere(q, Conditions.lessThan(Double.valueOf(1.0E-12d)),
                    new Value(Double.valueOf(1.0E-12d)));

            INDArray pq = p.sub(q).muli(num);
            if (logger.isDebugEnabled()) {
                // Guarded: the extra sum is only needed for the log message.
                logger.debug("PQ shape is: " + Arrays.toString(pq.shape()));
                logger.debug("PQ.sum(1) shape is: " + Arrays.toString(pq.sum(new int[]{1}).shape()));
            }
            INDArray grad = diag(pq.sum(new int[]{1})).subi(pq).mmul(this.Y).muli(4);

            this.momentum = iter < this.switchMomentumIteration ? this.initialMomentum : this.finalMomentum;

            // Gains grow by 0.2 where the gradient sign differs from the previous step's sign,
            // and shrink by a factor 0.8 where it agrees.
            INDArray signDiffers = grad.cond(Conditions.greaterThan(0)).neqi(update.cond(Conditions.greaterThan(0)));
            INDArray signAgrees = grad.cond(Conditions.greaterThan(0)).eqi(update.cond(Conditions.greaterThan(0)));
            gains = gains.add(Double.valueOf(0.2d)).muli(signDiffers)
                    .addi(gains.mul(Double.valueOf(0.8d)).muli(signAgrees));
            BooleanIndexing.applyWhere(gains, Conditions.lessThan(Double.valueOf(this.minGain)),
                    new Value(Double.valueOf(this.minGain)));

            INDArray step = gains.mul(grad);
            step.muli(Double.valueOf(this.learningRate));
            update.muli(Double.valueOf(this.momentum)).subi(step);
            logger.info("Iteration [" + iter + "] error is: ["
                    + p.mul(Transforms.log(p.div(q), false)).sumNumber().doubleValue() + "]");

            this.Y.addi(update);
            // Re-centre the embedding on the origin.
            this.Y.subi(Nd4j.tile(this.Y.mean(new int[]{0}), new int[]{this.Y.rows(), 1}));

            // Undo the x4 early exaggeration applied in x2p, exactly once.
            if (!exaggerationRemoved && (iter > this.maxIter / 2 || iter >= this.stopLyingIteration)) {
                p.divi(4);
                exaggerationRemoved = true;
            }
        }
        return this.Y;
    }

    /**
     * Builds a square matrix whose diagonal holds the entries of a vector.
     *
     * <p>If the input has more rows than columns, entries are read from column 0. Otherwise they
     * are read from row 0. The result is {@code max(rows, columns)} square; off-diagonal entries
     * are left as created by {@code Nd4j.create}.
     */
    public INDArray diag(INDArray iNDArray) {
        boolean columnVector = iNDArray.rows() > iNDArray.columns();
        int size = Math.max(iNDArray.columns(), iNDArray.rows());
        INDArray result = Nd4j.create(size, size);
        for (int k = 0; k < size; k++) {
            double value = columnVector ? iNDArray.getDouble(k, 0) : iNDArray.getDouble(0, k);
            result.putScalar(k, k, value);
        }
        return result;
    }

    /**
     * Embeds {@code iNDArray} using {@link #perplexity} and appends one CSV line per labelled
     * point to {@code str}.
     *
     * <p>Each line has the form {@code c0,c1,...,label \n}. Rows whose label is {@code null}, or
     * which have no label at all, are skipped. The file is written with the platform default
     * charset (via {@link FileWriter}).
     *
     * @throws IOException if the file cannot be opened or written; the file is closed regardless
     */
    public void plot(INDArray iNDArray, int i, List<String> list, String str) throws IOException {
        calculate(iNDArray, i, this.perplexity);
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(str), true))) {
            for (int row = 0; row < this.Y.rows() && row < list.size(); row++) {
                String label = list.get(row);
                if (label == null) {
                    continue;
                }
                INDArray coords = this.Y.getRow(row);
                StringBuilder line = new StringBuilder();
                for (int c = 0; c < coords.length(); c++) {
                    if (c > 0) {
                        line.append(",");
                    }
                    line.append(coords.getDouble(c));
                }
                line.append(",").append(label).append(" ").append("\n");
                writer.write(line.toString());
            }
        }
    }

    /**
     * Computes Gaussian affinities for one row of squared distances at precision {@code d}.
     *
     * @param iNDArray distances from one point to all other points
     * @param d precision (beta)
     * @return the entropy H of the distribution, paired with the normalised probabilities
     */
    public Pair<Double, INDArray> hBeta(INDArray iNDArray, double d) {
        INDArray probs = Transforms.exp(iNDArray.neg().muli(Double.valueOf(d)));
        double total = probs.sumNumber().doubleValue();
        double entropy = FastMath.log(total) + (d * iNDArray.mul(probs).sumNumber().doubleValue()) / total;
        probs.divi(Double.valueOf(total));
        return new Pair<>(Double.valueOf(entropy), probs);
    }

    /**
     * Computes the symmetric input affinity matrix P.
     *
     * <p>For each row, a precision is binary-searched (at most 50 steps) until the entropy is
     * within {@code tolerance} of log(perplexity). The result is then symmetrised and normalised,
     * multiplied by 4 (early exaggeration, undone in {@link #calculate}), and clamped below at
     * 1e-12.
     */
    private INDArray x2p(INDArray x, double tolerance, double perplexity) {
        int n = x.rows();
        INDArray p = Nd4j.zeros(n, n);
        INDArray beta = Nd4j.ones(n, 1);
        double logU = Math.log(perplexity);

        INDArray sumX = Transforms.pow(x, 2).sum(new int[]{1});
        logger.debug("sumX shape: " + Arrays.toString(sumX.shape()));
        // X.Xt is computed once and reused (previously computed twice, with a debug line
        // mutating the first copy only to log its shape).
        INDArray times = x.mmul(x.transpose()).muli(-2);
        logger.debug("times shape: " + Arrays.toString(times.shape()));
        // Squared Euclidean distances: ||x_a||^2 + ||x_b||^2 - 2 x_a.x_b.
        INDArray dist = times.transpose().addColumnVector(sumX).addRowVector(sumX.transpose());
        logger.debug("prodSum shape: " + Arrays.toString(dist.shape()));

        logger.info("Calculating probabilities of data similarities...");
        logger.debug("Tolerance: " + tolerance);
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0 && i > 0) {
                logger.info("Handled [" + i + "] records out of [" + n + "]");
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            // Every column except i: a point's distance to itself is excluded.
            INDArrayIndex[] others = {
                    new SpecifiedIndex(Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, n)))
            };
            INDArray di = dist.slice(i).get(others);
            Pair<Double, INDArray> hBeta = hBeta(di, beta.getDouble(i));
            double hDiff = hBeta.getFirst().doubleValue() - logU;
            for (int tries = 0; Math.abs(hDiff) > tolerance && tries < 50; tries++) {
                if (hDiff > 0.0d) {
                    // Entropy too high: increase precision.
                    betaMin = beta.getDouble(i);
                    if (Double.isInfinite(betaMax)) {
                        beta.putScalar(i, beta.getDouble(i) * 2.0d);
                    } else {
                        beta.putScalar(i, (beta.getDouble(i) + betaMax) / 2.0d);
                    }
                } else {
                    // Entropy too low: decrease precision.
                    betaMax = beta.getDouble(i);
                    if (Double.isInfinite(betaMin)) {
                        beta.putScalar(i, beta.getDouble(i) / 2.0d);
                    } else {
                        beta.putScalar(i, (beta.getDouble(i) + betaMin) / 2.0d);
                    }
                }
                hBeta = hBeta(di, beta.getDouble(i));
                hDiff = hBeta.getFirst().doubleValue() - logU;
            }
            p.slice(i).put(others, hBeta.getSecond());
        }
        logger.info("Mean value of sigma " + Transforms.sqrt(beta.rdiv(1)).mean(new int[]{Integer.MAX_VALUE}));

        BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(Double.valueOf(1.0E-12d)));
        INDArray symmetric = p.add(p.transpose());
        symmetric.divi(Double.valueOf(symmetric.sumNumber().doubleValue() + 1.0E-6d));
        symmetric.muli(4);
        BooleanIndexing.applyWhere(symmetric, Conditions.lessThan(Double.valueOf(1.0E-12d)),
                new Value(Double.valueOf(1.0E-12d)));
        return symmetric;
    }
}
