package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.util.OptimizerMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/StochasticHessianFree.class */
/**
 * Optimizer that, on each {@link #optimize(int)} call, runs a preconditioned
 * conjugate-gradient (CG) solve driven by the network's
 * {@code getBackPropRGradient} products, back-tracks over the recorded CG
 * iterates, performs a back-tracking line search, updates the network's
 * damping and finally writes the new parameters into the network.
 *
 * <p>The class keeps mutable state between calls (notably the last CG
 * solution {@code ch}); nothing in it is synchronized.
 */
public class StochasticHessianFree implements OptimizerMatrix {
    private static final Logger log = LoggerFactory.getLogger(StochasticHessianFree.class);

    /** Step size used by the constructors that do not take one; equals {@code (double) 0.01f}. */
    private static final double DEFAULT_INITIAL_STEP_SIZE = 0.009999999776482582d;
    /** Initial value of {@link #tolerance}; equals {@code (double) 1e-5f}. */
    private static final double DEFAULT_TOLERANCE = 9.999999747378752E-6d;
    /** Initial value of the damping "decrease" factor; equals {@code (double) 0.99f}. */
    private static final double DEFAULT_DECREASE = 0.9900000095367432d;
    /** Initial value of {@link #maxIterations}. */
    private static final int DEFAULT_MAX_ITERATIONS = 10000;

    /** Maximum number of rate reductions attempted by {@link #lineSearch}. */
    private static final int MAX_LINE_SEARCH_STEPS = 60;
    /** The line search logs its progress once every this many attempts. */
    private static final int LINE_SEARCH_LOG_INTERVAL = 10;
    /** Constant multiplied with the rate inside the line-search acceptance test. */
    private static final double LINE_SEARCH_C = 0.009999999776482582d;
    /** Factor the rate is multiplied by after each rejected line-search attempt. */
    private static final double LINE_SEARCH_SHRINK = 0.800000011920929d;

    boolean converged = false;
    OptimizableByGradientValueMatrix optimizable;
    TrainingEvaluator eval;
    double initialStepSize;
    double tolerance = DEFAULT_TOLERANCE;
    double gradientTolerance = 0.0d;
    private BaseMultiLayerNetwork network;
    int maxIterations = DEFAULT_MAX_ITERATIONS;
    private String myName = "";
    /** Last CG solution; lazily created by {@link #setup()} and reused as the next starting point. */
    private INDArray ch;
    /** Negated first element of {@code network.getBackPropGradient2()}, refreshed in {@link #optimize(int)}. */
    private INDArray gradient;
    private INDArray xi;
    private IterationListener listener;
    /** Factor {@code ch} is scaled by at the start of every {@link #optimize(int)} call. */
    private double pi = 0.5d;
    private double decrease = DEFAULT_DECREASE;
    private double boost = 1.0d / this.decrease;
    private double f = 1.0d;
    private double score;
    private double step;

    /**
     * @param optimizableByGradientValueMatrix the function being optimized
     * @param d                                the initial step size
     * @param baseMultiLayerNetwork            the network whose parameters are updated
     */
    public StochasticHessianFree(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, double d, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this.initialStepSize = d;
        this.optimizable = optimizableByGradientValueMatrix;
        this.network = baseMultiLayerNetwork;
    }

    /**
     * Uses the default initial step size and registers an iteration listener.
     *
     * @param optimizableByGradientValueMatrix the function being optimized
     * @param iterationListener                listener stored on this optimizer
     * @param baseMultiLayerNetwork            the network whose parameters are updated
     */
    public StochasticHessianFree(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, IterationListener iterationListener, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this(optimizableByGradientValueMatrix, DEFAULT_INITIAL_STEP_SIZE, baseMultiLayerNetwork);
        this.listener = iterationListener;
    }

    /**
     * @param optimizableByGradientValueMatrix the function being optimized
     * @param d                                the initial step size
     * @param iterationListener                listener stored on this optimizer
     * @param baseMultiLayerNetwork            the network whose parameters are updated
     */
    public StochasticHessianFree(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, double d, IterationListener iterationListener, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this(optimizableByGradientValueMatrix, d, baseMultiLayerNetwork);
        this.listener = iterationListener;
    }

    /**
     * Uses the default initial step size and no listener.
     *
     * @param optimizableByGradientValueMatrix the function being optimized
     * @param baseMultiLayerNetwork            the network whose parameters are updated
     */
    public StochasticHessianFree(OptimizableByGradientValueMatrix optimizableByGradientValueMatrix, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this(optimizableByGradientValueMatrix, DEFAULT_INITIAL_STEP_SIZE, baseMultiLayerNetwork);
    }

    /** Allocates a zero row vector for {@code ch} and loads {@code xi} from {@code network.pack()}. */
    void setup() {
        this.ch = Nd4j.zeros(1, this.optimizable.getNumParameters());
        this.xi = this.network.pack();
    }

    @Override // org.deeplearning4j.util.OptimizerMatrix
    public boolean isConverged() {
        return this.converged;
    }

    /** Runs {@link #optimize(int)} with the configured maximum number of iterations. */
    @Override // org.deeplearning4j.util.OptimizerMatrix
    public boolean optimize() {
        return optimize(this.maxIterations);
    }

    /**
     * Runs up to {@code i} preconditioned conjugate-gradient iterations.
     *
     * <p>{@code iNDArray2} is updated in place; a copy of it is recorded after
     * every completed iteration. The loop stops early when the curvature along
     * the current search direction is not positive, because dividing by it
     * would produce a step of the wrong sign (or Inf/NaN for zero). If that
     * happens before any iterate was recorded, the unchanged starting point is
     * recorded so the returned lists are non-empty whenever {@code i > 0}.
     *
     * @param iNDArray  right-hand side of the system (the caller in this class passes the negated gradient)
     * @param iNDArray2 starting point; mutated in place
     * @param iNDArray3 preconditioner; the residual is divided by it element-wise
     * @param i         maximum number of iterations
     * @return the recorded iteration indices and, in the same order, copies of the iterate
     */
    public Pair<List<Integer>, List<INDArray>> conjGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int i) {
        List<Integer> iterationIndices = new ArrayList<>();
        List<INDArray> iterates = new ArrayList<>();

        // presumably getBackPropRGradient(v) is the curvature-matrix/vector product — TODO confirm
        INDArray residual = this.network.getBackPropRGradient(iNDArray2).subi(iNDArray);
        INDArray preconditioned = residual.div(iNDArray3);
        double delta = residual.mul(preconditioned).sum(Integer.MAX_VALUE).getDouble(0);
        INDArray direction = preconditioned.neg();

        for (int iteration = 0; iteration < i; iteration++) {
            INDArray curvatureProduct = this.network.getBackPropRGradient(direction);
            double curvature = curvatureProduct.mul(direction).sum(Integer.MAX_VALUE).getDouble(0);
            if (curvature <= 0.0d) {
                log.info("Negative slope: {} breaking", curvature);
                if (iterates.isEmpty()) {
                    // Keep the result usable for callers that read the last iterate.
                    iterationIndices.add(iteration);
                    iterates.add(iNDArray2.dup());
                }
                break;
            }

            double objective = 0.5d * Nd4j.getBlasWrapper().dot(iNDArray.neg().add(residual).transpose(), iNDArray2);
            log.info("Iteration on conjugate gradient {} with value {}", iteration, objective);

            double alpha = delta / curvature;
            iNDArray2.addi(direction.mul(alpha));

            INDArray nextResidual = residual.add(curvatureProduct.mul(alpha));
            INDArray nextPreconditioned = nextResidual.div(iNDArray3);
            double previousDelta = delta;
            delta = nextResidual.mul(nextPreconditioned).sum(Integer.MAX_VALUE).getDouble(0);

            direction = nextPreconditioned.neg().add(direction.mul(delta / previousDelta));
            residual = nextResidual;

            iterationIndices.add(iteration);
            iterates.add(iNDArray2.dup());
        }
        return new Pair<>(iterationIndices, iterates);
    }

    /**
     * Runs {@link #conjGradient} from the current {@code ch} against the stored
     * gradient and replaces {@code ch} with the last recorded iterate.
     *
     * @param iNDArray preconditioner forwarded to {@link #conjGradient}
     * @param i        maximum number of CG iterations
     * @return the new {@code ch}, all recorded iterates, and the new {@code ch} again
     */
    private Triple<INDArray, List<INDArray>, INDArray> runConjugateGradient(INDArray iNDArray, int i) {
        List<INDArray> iterates = conjGradient(this.gradient, this.ch, iNDArray, i).getSecond();
        this.ch = iterates.get(iterates.size() - 1);
        return new Triple<>(this.ch, iterates, this.ch);
    }

    /**
     * Back-tracking search for a step rate along {@code iNDArray2}.
     *
     * <p>Starting from rate 1, the rate is multiplied by
     * {@link #LINE_SEARCH_SHRINK} until the trial score passes the acceptance
     * test or {@link #MAX_LINE_SEARCH_STEPS} attempts have been used, in which
     * case 0 is returned.
     *
     * @param d         score of the first candidate
     * @param iNDArray  base point the scaled direction is added to when re-scoring
     * @param iNDArray2 search direction
     * @return the accepted rate, or {@code 0.0} if no rate was accepted
     */
    public double lineSearch(double d, INDArray iNDArray, INDArray iNDArray2) {
        double rate = 1.0d;
        double trialScore = d;
        int attempts = 0;
        while (attempts < MAX_LINE_SEARCH_STEPS) {
            if (attempts % LINE_SEARCH_LOG_INTERVAL == 0) {
                log.info("Iteration {} on line search with current rate of {}", attempts, rate);
            }
            // NOTE(review): the threshold scales gradient*direction by (score + c*rate) before
            // summing; kept exactly as found — confirm this is the intended acceptance test.
            double threshold = this.gradient.mul(iNDArray2)
                    .mul(this.score + (LINE_SEARCH_C * rate))
                    .sum(Integer.MAX_VALUE)
                    .getDouble(0);
            if (trialScore <= threshold) {
                break;
            }
            rate *= LINE_SEARCH_SHRINK;
            attempts++;
            trialScore = this.network.score(iNDArray.add(iNDArray2.mul(rate)));
        }
        if (attempts == MAX_LINE_SEARCH_STEPS) {
            rate = 0.0d;
            log.info("Went too far...reverting rate to 0");
        }
        return rate;
    }

    /**
     * Walks backwards over the CG iterates, from the second-to-last down to
     * index 1, scoring {@code params + iterate} for each, and stops at the
     * first one that scores below either the score of {@code params + iNDArray}
     * or the network's current score.
     *
     * @param list     CG iterates in the order they were produced; must not be empty
     * @param iNDArray candidate that provides the initial score to beat
     * @return the selected iterate and the associated score
     */
    public Pair<INDArray, Double> cgBackTrack(List<INDArray> list, INDArray iNDArray) {
        INDArray params = this.network.params();
        double bestScore = this.network.score(iNDArray.add(params));
        double currentMinimum = this.network.score();

        int index;
        for (index = list.size() - 2; index > 0; index--) {
            double trialScore = this.network.score(params.add(list.get(index)));
            if (trialScore < bestScore || trialScore < currentMinimum) {
                // NOTE(review): the index is advanced past the accepted trial, so the iterate
                // returned is the one AFTER it — kept as found; confirm intent.
                index++;
                bestScore = trialScore;
                log.info("Breaking on new score {} with iteration {} with current minimum of {}", trialScore, index, currentMinimum);
                break;
            }
            log.info("Trial {} with trial score of {}", index, trialScore);
        }

        // A list with a single element starts the loop at -1.
        if (index < 0) {
            index = 0;
        }
        return new Pair<>(list.get(index), bestScore);
    }

    /**
     * Performs one optimization pass: refreshes score, parameters and
     * gradient from the network, runs CG with at most {@code i} iterations,
     * back-tracks, line-searches, updates damping and sets the new parameters.
     *
     * @param i maximum number of CG iterations
     * @return always {@code true}
     */
    @Override // org.deeplearning4j.util.OptimizerMatrix
    public boolean optimize(int i) {
        this.myName = Thread.currentThread().getName();
        if (this.converged) {
            return true;
        }

        this.score = this.network.score();
        this.xi = this.network.params();
        Pair<INDArray, INDArray> gradientAndPreconditioner = this.network.getBackPropGradient2();
        this.gradient = gradientAndPreconditioner.getFirst().neg();
        INDArray preconditioner = gradientAndPreconditioner.getSecond();

        if (this.ch == null) {
            setup();
        }
        // Shrink the previous CG solution before reusing it as the starting point.
        this.ch.muli(this.pi);

        Triple<INDArray, List<INDArray>, INDArray> cgResult = runConjugateGradient(preconditioner, i);
        Pair<INDArray, Double> backTracked = cgBackTrack(cgResult.getSecond(), cgResult.getFirst());
        INDArray direction = backTracked.getFirst();

        double reductionRatio = this.network.reductionRatio(direction, this.network.score(), backTracked.getSecond().doubleValue(), this.gradient);
        // NOTE(review): the negated gradient is passed as the line search's base point, while
        // the update below is applied to xi — kept as found; confirm intent.
        this.step = lineSearch(this.network.score(direction), this.gradient, direction);
        this.network.dampingUpdate(reductionRatio, this.boost, this.decrease);
        this.network.setParameters(this.xi.add(direction.mul(this.f * this.step)));
        return true;
    }

    @Override // org.deeplearning4j.util.OptimizerMatrix
    public void setTrainingEvaluator(TrainingEvaluator trainingEvaluator) {
        this.eval = trainingEvaluator;
    }

    /** Clears the cached parameter vector {@code xi}. */
    public void reset() {
        this.xi = null;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    @Override // org.deeplearning4j.util.OptimizerMatrix
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /**
     * Stores the tolerance on this optimizer.
     *
     * @param d the new tolerance
     */
    @Override // org.deeplearning4j.util.OptimizerMatrix
    public void setTolerance(double d) {
        this.tolerance = d;
    }
}
