package org.deeplearning4j.plot;

import com.google.common.primitives.Ints;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.indexing.functions.Zero;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/**
 * Dense t-SNE embedding of the rows of a data matrix.
 *
 * <p>{@link #calculate} builds an n x n matrix of pairwise squared distances, calibrates per-row
 * Gaussian affinities to the requested perplexity, and then runs gradient descent with momentum and
 * per-coordinate gains (optionally AdaGrad) on a low-dimensional embedding exposed via {@link #getY()}.
 * The embedding, its update increments and the gains are stored on the instance and reused by later
 * calls (they are only initialized when {@code null}).
 */
public class Tsne implements Serializable {
    /** Upper bound on bisection steps used to calibrate each row's precision (beta). */
    private static final int MAX_BETA_SEARCH_TRIES = 50;

    protected int maxIter = 1000;
    protected double realMin = Nd4j.EPS_THRESHOLD;
    protected double initialMomentum = 0.5d;
    protected double finalMomentum = 0.8d;
    protected double minGain = 0.01d;
    protected double momentum = initialMomentum;
    protected int switchMomentumIteration = 100;
    protected boolean normalize = true;
    protected boolean usePca = false;
    protected int stopLyingIteration = 250;
    protected double tolerance = 1.0E-5d;
    protected double learningRate = 500.0d;
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad = true;
    protected double perplexity = 30.0d;
    /** Per-coordinate adaptive gains; lazily created in {@link #gradient}. */
    protected INDArray gains;
    /** Momentum-accumulated update applied to {@link #y}; lazily created in {@link #gradient}. */
    protected INDArray yIncs;
    /** Current embedding; randomly initialized by {@link #calculate} when {@code null}. */
    protected INDArray y;
    protected transient IterationListener iterationListener;
    /** Classpath location of a tsne.py script; not read anywhere in this class. */
    protected static ClassPathResource r = new ClassPathResource("/scripts/tsne.py");
    protected static final Logger log = LoggerFactory.getLogger(Tsne.class);

    /** Fluent builder for {@link Tsne}; every setter returns this builder. */
    public static class Builder {
        protected int maxIter = 1000;
        // The non-round values below appear to be float-precision constants widened to double
        // (~1e-12, ~0.8, ~1e-5, ~0.1); they are kept exactly to preserve existing behavior.
        protected double realMin = 9.999999960041972E-13d;
        protected double initialMomentum = 0.5d;
        protected double finalMomentum = 0.800000011920929d;
        protected double momentum = 0.5d;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 9.999999747378752E-6d;
        protected double learningRate = 0.10000000149011612d;
        protected boolean useAdaGrad = false;
        protected double perplexity = 30.0d;
        protected double minGain = 0.10000000149011612d;

        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /** Creates a {@link Tsne} configured with this builder's values. */
        public Tsne build() {
            return new Tsne(maxIter, realMin, initialMomentum, finalMomentum, momentum,
                    switchMomentumIteration, normalize, usePca, stopLyingIteration, tolerance,
                    learningRate, useAdaGrad, perplexity, minGain);
        }
    }

    /** Creates an instance with the field-initializer defaults declared above. */
    public Tsne() {
    }

    /** Creates an instance with every hyper-parameter given explicitly (used by {@link Builder}). */
    public Tsne(int maxIter, double realMin, double initialMomentum, double finalMomentum,
            double momentum, int switchMomentumIteration, boolean normalize, boolean usePca,
            int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad,
            double perplexity, double minGain) {
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.normalize = normalize;
        this.usePca = usePca;
        this.stopLyingIteration = stopLyingIteration;
        this.tolerance = tolerance;
        this.learningRate = learningRate;
        this.useAdaGrad = useAdaGrad;
        this.perplexity = perplexity;
        this.minGain = minGain;
    }

    /**
     * Computes the entropy and normalized Gaussian-kernel probabilities for one row of distances.
     *
     * @param d    distances from one point to the others (the caller removes the self-distance)
     * @param beta precision of the Gaussian kernel
     * @return pair (H, P): H = log(sum(e)) + beta * sum(d * e) / sum(e) with e = exp(-d * beta),
     *         and P = e / sum(e)
     */
    public Pair<INDArray, INDArray> hBeta(INDArray d, double beta) {
        INDArray p = Transforms.exp(d.neg().muli(beta));
        INDArray sumP = p.sum(new int[]{Integer.MAX_VALUE});
        // Both reductions must collapse to a scalar; a per-dimension sum here would turn H into a vector.
        INDArray weighted = d.mul(p).sum(new int[]{Integer.MAX_VALUE}).muli(beta).divi(sumP);
        INDArray h = Transforms.log(sumP).addi(weighted);
        p.divi(sumP);
        return new Pair<>(h, p);
    }

    /**
     * Converts pairwise squared distances into a symmetric joint-probability matrix P.
     *
     * <p>For each row, beta is found by bisection (at most {@link #MAX_BETA_SEARCH_TRIES} steps) so that
     * the entropy returned by {@link #hBeta} is within {@code tolerance} of {@code log(u)}. NaNs are then
     * replaced by {@code realMin}, P is symmetrized (P + P^T), normalized to sum to 1, and floored at
     * {@code Nd4j.EPS_THRESHOLD}.
     *
     * @param d square matrix of pairwise squared distances
     * @param u target perplexity
     * @return the symmetric, normalized probability matrix
     */
    public INDArray computeGaussianPerplexity(INDArray d, double u) {
        int n = d.rows();
        INDArray p = Nd4j.zeros(n, n);
        INDArray betas = Nd4j.ones(n, 1);
        double logU = Math.log(u);
        log.info("Calculating probabilities of data similarities..");
        for (int i = 0; i < n; i++) {
            if (i % 500 == 0 && i > 0) {
                log.info("Handled {} records", i);
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            // Every column of row i except the diagonal (self-distance) entry.
            INDArrayIndex[] offDiagonal = {
                new NDArrayIndex(Ints.concat(ArrayUtil.range(0, i), ArrayUtil.range(i + 1, d.columns())))
            };
            INDArray row = d.slice(i).get(offDiagonal);
            Pair<INDArray, INDArray> hp = hBeta(row, betas.getDouble(i));
            INDArray hDiff = hp.getFirst().sub(logU);
            for (int tries = 0;
                    BooleanIndexing.and(Transforms.abs(hDiff), Conditions.greaterThan(this.tolerance))
                            && tries < MAX_BETA_SEARCH_TRIES;
                    tries++) {
                double beta = betas.getDouble(i);
                // The bound must be the beta that was just evaluated (before the update); recording
                // the updated value instead makes the midpoint collapse and the search stall.
                if (BooleanIndexing.and(hDiff, Conditions.greaterThan(0))) {
                    // Entropy too high: increase beta.
                    betaMin = beta;
                    betas.putScalar(i, Double.isInfinite(betaMax) ? beta * 2.0d : (beta + betaMax) / 2.0d);
                } else {
                    // Entropy too low: decrease beta.
                    betaMax = beta;
                    betas.putScalar(i, Double.isInfinite(betaMin) ? beta / 2.0d : (beta + betaMin) / 2.0d);
                }
                hp = hBeta(row, betas.getDouble(i));
                hDiff = hp.getFirst().subi(logU);
            }
            p.slice(i).put(offDiagonal, hp.getSecond());
        }
        log.info("Mean value of sigma {}", Transforms.sqrt(betas.rdiv(1)).mean(new int[]{Integer.MAX_VALUE}));
        BooleanIndexing.applyWhere(p, Conditions.isNan(), new Value(this.realMin));
        INDArray symmetric = p.add(p.transpose());
        symmetric.divi(symmetric.sum(new int[]{Integer.MAX_VALUE}));
        BooleanIndexing.applyWhere(symmetric, Conditions.lessThan(Double.valueOf(Nd4j.EPS_THRESHOLD)),
                new Value(Double.valueOf(Nd4j.EPS_THRESHOLD)));
        return symmetric;
    }

    /**
     * Runs t-SNE on the rows of {@code x} and returns the embedding.
     *
     * <p>The caller's matrix is copied first, so optional PCA and normalization never modify it.
     * P is multiplied by 4 for the first iterations and divided back at {@code stopLyingIteration}.
     * Note that {@code momentum} is switched to {@code finalMomentum} at {@code switchMomentumIteration}
     * and is not reset between calls.
     *
     * @param x          data matrix, one sample per row
     * @param nDims      embedding dimensionality (clamped to the number of input columns)
     * @param perplexity target perplexity for {@link #computeGaussianPerplexity}
     * @return the embedding {@link #y} (rows x nDims)
     */
    public INDArray calculate(INDArray x, int nDims, double perplexity) {
        // Work on a copy: the in-place normalization below (and possibly PCA) would otherwise
        // rescale the caller's matrix.
        INDArray data = x.dup();
        if (this.usePca) {
            data = PCA.pca(data, Math.min(50, data.columns()), this.normalize);
        }
        if (this.normalize) {
            data.subi(data.min(new int[]{Integer.MAX_VALUE}));
            data = data.divi(data.max(new int[]{Integer.MAX_VALUE}));
            data = data.subiRowVector(data.mean(new int[]{0}));
        }
        if (nDims > data.columns()) {
            nDims = data.columns();
        }
        // Pairwise squared distances: |a|^2 - 2 a.b + |b|^2
        INDArray sumX = Transforms.pow(data, 2).sum(new int[]{1});
        INDArray distances = data.mmul(data.transpose()).muli(-2).addRowVector(sumX).transpose().addRowVector(sumX);
        if (this.y == null) {
            this.y = Nd4j.randn(data.rows(), nDims, Nd4j.getRandom()).muli(0.001f);
        }
        INDArray p = computeGaussianPerplexity(distances, perplexity);
        p.muli(4);
        if (this.useAdaGrad && this.adaGrad == null) {
            this.adaGrad = new AdaGrad(this.learningRate);
        }
        for (int iter = 0; iter < this.maxIter; iter++) {
            step(p, iter);
            if (iter == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (iter == this.stopLyingIteration) {
                p.divi(4);
            }
            if (this.iterationListener != null) {
                this.iterationListener.iterationDone(null, iter);
            }
        }
        return this.y;
    }

    /**
     * Computes the KL cost and updates {@link #gains} and {@link #yIncs} for the current embedding.
     *
     * @param p joint-probability matrix of the input data
     * @return pair (cost, yIncs), where cost = sum(P * log(P / Q)) and yIncs is the stored increment
     */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        INDArray sumY = Transforms.pow(this.y, 2).sum(new int[]{1});
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros(this.y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones(this.y.shape());
        }
        // num[i][j] = 1 / (1 + squared distance between embedded points i and j)
        INDArray num = this.y.mmul(this.y.transpose()).muli(-2)
                .addiRowVector(sumY).transpose().addiRowVector(sumY)
                .addi(1).rdivi(1);
        int n = this.y.rows();
        Nd4j.doAlongDiagonal(num, new Zero());
        INDArray q = num.div(num.sum(new int[]{Integer.MAX_VALUE}));
        BooleanIndexing.applyWhere(q, Conditions.lessThan(this.realMin), new Value(this.realMin));
        INDArray dY = getYGradient(n, p.sub(q), num);

        // Gains grow by 0.2 where gradient and previous increment differ in sign, shrink by 0.8 otherwise.
        INDArray signDiffers = dY.cond(Conditions.greaterThan(0)).neqi(this.yIncs.cond(Conditions.greaterThan(0)));
        INDArray signAgrees = dY.cond(Conditions.greaterThan(0)).eqi(this.yIncs.cond(Conditions.greaterThan(0)));
        this.gains = this.gains.add(0.2d).muli(signDiffers).addi(this.gains.mul(0.8d).muli(signAgrees));
        BooleanIndexing.applyWhere(this.gains, Conditions.lessThan(this.minGain), new Value(this.minGain));

        INDArray update = this.gains.mul(dY);
        if (this.useAdaGrad) {
            update = this.adaGrad.getGradient(update, 0);
        } else {
            update.muli(this.learningRate);
        }
        this.yIncs.muli(this.momentum).subi(update);
        double cost = p.mul(Transforms.log(p.div(q), false)).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return new Pair<>(cost, this.yIncs);
    }

    /**
     * Computes the embedding gradient row by row.
     *
     * @param n   number of rows of {@link #y} to process
     * @param pq  P - Q
     * @param num unnormalized affinities of the embedding (zero diagonal)
     * @return matrix shaped like {@link #y}; row i is sum_j (pq[i][j] * num[i][j]) * (y_i - y_j)
     */
    public INDArray getYGradient(int n, INDArray pq, INDArray num) {
        INDArray dY = Nd4j.create(this.y.shape());
        for (int i = 0; i < n; i++) {
            INDArray weights = pq.getRow(i).mul(num.getRow(i));
            INDArray tiledWeights = Nd4j.tile(weights, new int[]{this.y.columns(), 1}).transpose();
            INDArray differences = this.y.getRow(i).broadcast(this.y.shape()).sub(this.y);
            dY.putRow(i, tiledWeights.mul(differences).sum(new int[]{0}));
        }
        return dY;
    }

    /**
     * Performs one optimization step: applies the increment from {@link #gradient} once, then
     * re-centers the embedding so every column has zero mean.
     *
     * @param p joint-probability matrix of the input data
     * @param i iteration number (used for logging only)
     */
    public void step(INDArray p, int i) {
        Pair<Double, INDArray> costAndIncrement = gradient(p);
        INDArray increment = costAndIncrement.getSecond();
        log.info("Cost at iteration {} was {}", i, costAndIncrement.getFirst());
        // Apply the update exactly once; the increment previously got added twice per step.
        this.y.addi(increment);
        this.y.subiRowVector(this.y.mean(new int[]{0}));
    }

    /** Runs {@link #calculate} and appends labelled coordinates to {@code coords.csv}. */
    public void plot(INDArray matrix, int nDims, List<String> labels) throws IOException {
        plot(matrix, nDims, labels, "coords.csv");
    }

    /**
     * Runs {@link #calculate} with {@link #perplexity} and appends one line per labelled row to {@code path}.
     *
     * <p>Each line is the comma-separated coordinates, a comma, the label, a space and a newline.
     * Rows whose label is {@code null}, or that have no corresponding label, are skipped. The file
     * is opened in append mode and is always closed, even if a write fails.
     */
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        calculate(matrix, nDims, this.perplexity);
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(new File(path), true))) {
            for (int rowIdx = 0; rowIdx < this.y.rows() && rowIdx < labels.size(); rowIdx++) {
                String label = labels.get(rowIdx);
                if (label == null) {
                    continue;
                }
                INDArray row = this.y.getRow(rowIdx);
                StringBuilder line = new StringBuilder();
                for (int col = 0; col < row.length(); col++) {
                    if (col > 0) {
                        line.append(",");
                    }
                    line.append(row.getDouble(col));
                }
                line.append(",").append(label).append(" ").append("\n");
                writer.write(line.toString());
            }
            writer.flush();
        }
    }

    public INDArray getY() {
        return this.y;
    }

    public void setY(INDArray y) {
        this.y = y;
    }

    public IterationListener getIterationListener() {
        return this.iterationListener;
    }

    public void setIterationListener(IterationListener iterationListener) {
        this.iterationListener = iterationListener;
    }
}
