package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/StochasticHessianFree.class */
/**
 * Hessian-free style optimizer that repeatedly runs a preconditioned conjugate-gradient
 * (CG) solve, back-tracks over the CG iterates, performs a line search and then updates the
 * network parameters.
 *
 * <p>The optimizer only does work when the wrapped model is a {@link MultiLayerNetwork};
 * for any other model {@link #optimize()} returns {@code true} immediately.
 *
 * <p>Logging goes through {@code log}, which is not declared in this class
 * (presumably inherited from {@code BaseOptimizer} — confirm).
 */
public class StochasticHessianFree extends BaseOptimizer {
    private static final Logger logger = LoggerFactory.getLogger(StochasticHessianFree.class);

    /** Maximum number of step-size reductions attempted by {@link #lineSearch}. */
    private static final int MAX_LINE_SEARCH_STEPS = 60;
    /** The line search logs its progress once every this many steps. */
    private static final int LINE_SEARCH_LOG_INTERVAL = 10;
    /** Factor the step size is multiplied by after each rejected line-search step. */
    private static final double LINE_SEARCH_SHRINK = 0.800000011920929d;
    /** Coefficient applied to the step size inside the line-search acceptance test. */
    private static final double LINE_SEARCH_SLOPE = 0.009999999776482582d;

    boolean converged = false;
    double initialStepSize = 1.0d;
    double tolerance = 9.999999747378752E-6d;
    double gradientTolerance = 0.0d;
    private MultiLayerNetwork network;
    int maxIterations = 10000;
    /** Current CG solution; carried over (scaled by {@link #pi}) between optimize iterations. */
    private INDArray ch;
    /** Negated first component of the network's back-prop gradient for the current iteration. */
    private INDArray gradient;
    /** Network parameters captured at the start of {@link #optimize()}. */
    private INDArray xi;
    /** Factor applied to {@link #ch} before each CG run. */
    private double pi = 0.5d;
    /** Passed to {@code network.dampingUpdate} together with {@link #boost}. */
    private double decrease = 0.9900000095367432d;
    /** Reciprocal of {@link #decrease}. */
    private double boost = 1.0d / this.decrease;
    /** Network score captured at the start of {@link #optimize()}. */
    private double score;

    /**
     * Creates the optimizer and, when {@code model} is a {@link MultiLayerNetwork},
     * initialises the CG state via {@link #setup()}.
     */
    public StochasticHessianFree(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, model);
        setup();
    }

    /**
     * Same as the four-argument constructor, additionally forwarding the termination
     * conditions ({@code collection2}) to the superclass.
     */
    public StochasticHessianFree(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Collection<TerminationCondition> collection2, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, collection2, model);
        setup();
    }

    /**
     * Binds {@link #network} to the model and allocates a zero CG solution with one row and
     * as many columns as the packed parameter vector. Does nothing when the model is not a
     * {@link MultiLayerNetwork}.
     */
    void setup() {
        if (this.model instanceof MultiLayerNetwork) {
            this.network = (MultiLayerNetwork) this.model;
            this.xi = this.network.pack();
            this.ch = Nd4j.zeros(1, this.xi.length());
        }
    }

    /** @return the current value of the {@code converged} flag */
    public boolean isConverged() {
        return this.converged;
    }

    /**
     * Runs up to {@code i} preconditioned conjugate-gradient iterations.
     *
     * <p>{@code iNDArray2} is updated <em>in place</em>; a copy of it is recorded after every
     * completed iteration. The loop stops early as soon as the curvature along the current
     * search direction is negative, because continuing would step along a direction in which
     * the quadratic model is unbounded below.
     *
     * @param iNDArray  right-hand side subtracted from the initial product to form the residual
     * @param iNDArray2 initial solution; modified in place
     * @param iNDArray3 preconditioner; residuals are divided element-wise by it
     * @param i         maximum number of CG iterations
     * @return the indices of the completed iterations and the matching solution snapshots;
     *         both lists are empty when no iteration completed
     */
    public Pair<List<Integer>, List<INDArray>> conjGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int i) {
        List<Integer> iterationIndices = new ArrayList<>();
        List<INDArray> iterates = new ArrayList<>();

        // NOTE(review): getBackPropRGradient presumably returns a curvature-vector product — confirm.
        INDArray residual = this.network.getBackPropRGradient(iNDArray2).subi(iNDArray);
        INDArray preconditioned = residual.div(iNDArray3);
        double residualDot = residual.mul(preconditioned).sum(Integer.MAX_VALUE).getDouble(0);
        INDArray direction = preconditioned.neg();

        for (int iteration = 0; iteration < i; iteration++) {
            INDArray curvatureProduct = this.network.getBackPropRGradient(direction);
            double curvature = curvatureProduct.mul(direction).sum(Integer.MAX_VALUE).getDouble(0);
            if (curvature < 0.0d) {
                log.info("Negative slope: " + curvature + " breaking");
                // The message announces a break; actually stop instead of stepping along
                // a negative-curvature direction.
                break;
            }

            double stepLength = residualDot / curvature;
            iNDArray2.addi(direction.mul(stepLength));
            // In-place update: 'residual' keeps referring to the same array.
            residual.addi(curvatureProduct.mul(stepLength));
            preconditioned = residual.div(iNDArray3);

            double previousResidualDot = residualDot;
            residualDot = residual.mul(preconditioned).sum(Integer.MAX_VALUE).getDouble(0);
            direction = preconditioned.neg().addi(direction.mul(residualDot / previousResidualDot));

            iterationIndices.add(iteration);
            iterates.add(iNDArray2.dup());
        }
        return new Pair<>(iterationIndices, iterates);
    }

    /**
     * Runs CG starting from the current {@link #ch} against {@link #gradient} and stores the
     * final iterate back into {@link #ch}.
     *
     * <p>If CG produced no iterate at all (it stopped before completing its first iteration),
     * a copy of the current {@link #ch} is used as the single iterate so that callers always
     * receive a non-empty list.
     *
     * @param iNDArray preconditioner forwarded to {@link #conjGradient}
     * @param i        maximum number of CG iterations
     * @return the final iterate, the list of all iterates, and the final iterate again
     */
    private Triple<INDArray, List<INDArray>, INDArray> runConjugateGradient(INDArray iNDArray, int i) {
        Pair<List<Integer>, List<INDArray>> cgResult = conjGradient(this.gradient, this.ch, iNDArray, i);
        List<INDArray> iterates = cgResult.getSecond();
        if (iterates.isEmpty()) {
            iterates.add(this.ch.dup());
        }
        this.ch = iterates.get(iterates.size() - 1);
        return new Triple<>(this.ch, iterates, this.ch);
    }

    /**
     * Back-tracking line search: starting from a rate of 1, the rate is repeatedly multiplied
     * by {@link #LINE_SEARCH_SHRINK} until the acceptance test passes or
     * {@link #MAX_LINE_SEARCH_STEPS} reductions have been made.
     *
     * <p>NOTE(review): the acceptance threshold is computed as
     * {@code sum(gradient * direction * (score + c * rate))}; this is reproduced unchanged —
     * confirm it is the intended sufficient-decrease condition.
     *
     * @param d         score at the full step
     * @param iNDArray  base point the scaled direction is added to when re-scoring
     * @param iNDArray2 search direction
     * @return the accepted rate, or {@code 0} when no rate was accepted within the step limit
     */
    public double lineSearch(double d, INDArray iNDArray, INDArray iNDArray2) {
        double rate = 1.0d;
        int step = 0;
        while (step < MAX_LINE_SEARCH_STEPS) {
            // Previously guarded by the constant expression "10 % 60 == 0", which is never true.
            if (step % LINE_SEARCH_LOG_INTERVAL == 0) {
                log.info("Iteration " + step + " on line search with current rate of " + rate);
            }
            double threshold = this.gradient.mul(iNDArray2)
                    .muli(this.score + (LINE_SEARCH_SLOPE * rate))
                    .sum(Integer.MAX_VALUE)
                    .getDouble(0);
            if (d <= threshold) {
                break;
            }
            rate *= LINE_SEARCH_SHRINK;
            step++;
            d = this.network.score(iNDArray.add(iNDArray2.mul(rate)));
        }
        if (step == MAX_LINE_SEARCH_STEPS) {
            rate = 0.0d;
            log.info("Went too far...reverting rate to 0");
        }
        return rate;
    }

    /**
     * Walks backwards over the CG iterates looking for one that scores better than either the
     * final step or the network's current score.
     *
     * <p>The scan starts at the second-to-last iterate and stops before index 0. When a
     * qualifying iterate is found, the iterate <em>following</em> it is selected and the
     * qualifying score is returned; otherwise the first iterate is selected together with the
     * score of {@code iNDArray} added to the current parameters.
     *
     * @param list     CG iterates; must be non-empty
     * @param iNDArray final CG step, used to compute the baseline score
     * @return the selected iterate and the associated score
     */
    public Pair<INDArray, Double> cgBackTrack(List<INDArray> list, INDArray iNDArray) {
        INDArray params = this.network.params();
        double bestScore = this.network.score(iNDArray.add(params));
        double currentScore = this.network.score();

        int index = list.size() - 2;
        while (index > 0) {
            double candidateScore = this.network.score(params.add(list.get(index)));
            if (candidateScore < bestScore || candidateScore < currentScore) {
                index++;
                bestScore = candidateScore;
                break;
            }
            index--;
        }
        // list.size() == 1 makes the starting index -1; clamp to the first iterate.
        if (index < 0) {
            index = 0;
        }
        return new Pair<>(list.get(index), Double.valueOf(bestScore));
    }

    /**
     * Runs {@code conf.getNumIterations()} outer iterations of: gradient computation, CG solve,
     * CG back-tracking, line search, damping update and parameter update.
     *
     * <p>NOTE(review): {@link #xi} and {@link #score} are captured once before the loop and not
     * refreshed inside it, and {@link #lineSearch} receives {@link #gradient} as its base
     * point. This is reproduced unchanged — confirm it is intended.
     *
     * @return always {@code true}
     */
    @Override // org.deeplearning4j.optimize.solvers.BaseOptimizer, org.deeplearning4j.optimize.api.ConvexOptimizer
    public boolean optimize() {
        if (!(this.model instanceof MultiLayerNetwork) || this.converged) {
            return true;
        }
        this.score = this.network.score();
        this.xi = this.network.params();

        int numIterations = this.conf.getNumIterations();
        for (int iteration = 0; iteration < numIterations; iteration++) {
            Pair<INDArray, INDArray> gradientAndPreconditioner = this.network.getBackPropGradient2();
            this.gradient = gradientAndPreconditioner.getFirst().neg();
            INDArray preconditioner = gradientAndPreconditioner.getSecond();

            if (this.ch == null) {
                setup();
            }
            // Decay the previous CG solution before using it as the starting point.
            this.ch.muli(this.pi);

            Triple<INDArray, List<INDArray>, INDArray> cgResult = runConjugateGradient(preconditioner, numIterations);
            Pair<INDArray, Double> backTracked = cgBackTrack(cgResult.getSecond(), cgResult.getFirst());
            INDArray step = backTracked.getFirst();

            double reductionRatio = this.network.reductionRatio(step, this.network.score(), backTracked.getSecond().doubleValue(), this.gradient);
            double stepScore = this.network.score(step);
            double rate = lineSearch(stepScore, this.gradient, step);

            this.network.dampingUpdate(reductionRatio, this.boost, this.decrease);
            this.network.setParameters(this.xi.add(step.mul(rate)));
            log.info("Score at iteration " + iteration + " was " + stepScore);
        }
        return true;
    }
}
