package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.deeplearning4j.optimize.solvers.StochasticHessianFree;
import org.deeplearning4j.optimize.solvers.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.solvers.VectorizedNonZeroStoppingConjugateGradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/optimizers/BackPropOptimizer.class */
/**
 * Fine-tunes a {@link BaseMultiLayerNetwork} with backpropagation.
 *
 * <p>{@link #optimize(TrainingEvaluator, int, boolean)} supports two modes:
 * <ul>
 *   <li>line search: delegate to a solver selected from the network's configured
 *       optimization algorithm, with this object serving as the
 *       {@link OptimizableByGradientValueMatrix} objective;</li>
 *   <li>plain backprop: call {@code backPropStep()} on the network repeatedly, either for a
 *       fixed number of epochs or until the score stops improving.</li>
 * </ul>
 *
 * <p>Instance fields are intentionally left non-final: changing their modifiers would change
 * the default serialVersionUID of this {@link Serializable} class.
 */
public class BackPropOptimizer implements Serializable, OptimizableByGradientValueMatrix {
    private static final Logger log = LoggerFactory.getLogger(BackPropOptimizer.class);

    /**
     * Smallest score decrease that still counts as progress in convergence mode. Written as a
     * float literal so the widened double value matches the historical threshold exactly.
     */
    private static final double MIN_IMPROVEMENT = 1e-6f;

    private BaseMultiLayerNetwork network;
    /** Learning rate handed to the output layer when parameters are set. */
    private double lr;
    /** Epoch budget for plain backprop and for output-layer retraining. */
    private int epochs;
    /** Cached parameter count; -1 until first computed. */
    private int length = -1;
    /** Last iteration reported by a solver; stored only, not read in this class. */
    private int currentIteration = -1;

    /**
     * @param baseMultiLayerNetwork network to train
     * @param d learning rate used when retraining the output layer in {@link #setParameters}
     * @param i maximum number of backprop epochs
     */
    public BackPropOptimizer(BaseMultiLayerNetwork baseMultiLayerNetwork, double d, int i) {
        this.network = baseMultiLayerNetwork;
        this.lr = d;
        this.epochs = i;
    }

    @Override
    public void setCurrentIteration(int i) {
        this.currentIteration = i;
    }

    /**
     * Trains the wrapped network.
     *
     * @param trainingEvaluator evaluator handed to the line-search solver; unused when
     *     {@code lineSearch} is false
     * @param numIterations iteration budget for the line-search solver; unused when
     *     {@code lineSearch} is false (plain backprop uses the constructor's epoch budget)
     * @param lineSearch true to delegate to a solver, false to run plain backprop
     */
    public void optimize(TrainingEvaluator trainingEvaluator, int numIterations, boolean lineSearch) {
        if (lineSearch) {
            optimizeWithSolver(trainingEvaluator, numIterations);
            return;
        }
        double initialScore = this.network.score();
        log.info("BEGIN BACKPROP WITH SCORE OF {}", initialScore);
        if (this.network.isForceNumEpochs()) {
            trainForFixedEpochs();
        } else {
            trainUntilConverged(initialScore);
        }
    }

    /** Runs the solver that matches the network's configured optimization algorithm. */
    private void optimizeWithSolver(TrainingEvaluator trainingEvaluator, int numIterations) {
        NeuralNetwork.OptimizationAlgorithm algorithm =
                this.network.getDefaultConfiguration().getOptimizationAlgo();
        if (algorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            VectorizedNonZeroStoppingConjugateGradient solver =
                    new VectorizedNonZeroStoppingConjugateGradient(this);
            solver.setTrainingEvaluator(trainingEvaluator);
            solver.setMaxIterations(numIterations);
            solver.optimize(numIterations);
        } else if (algorithm == NeuralNetwork.OptimizationAlgorithm.HESSIAN_FREE) {
            StochasticHessianFree solver = new StochasticHessianFree(this, this.network);
            solver.setTrainingEvaluator(trainingEvaluator);
            solver.setMaxIterations(numIterations);
            solver.optimize(numIterations);
        } else {
            // Every other algorithm (including a null/unset one) falls back to gradient ascent.
            VectorizedDeepLearningGradientAscent solver =
                    new VectorizedDeepLearningGradientAscent(this);
            solver.setTrainingEvaluator(trainingEvaluator);
            solver.optimize(numIterations);
        }
    }

    /** Runs exactly {@code epochs} backprop steps, periodically clearing AdaGrad history. */
    private void trainForFixedEpochs() {
        int resetEvery = this.network.getDefaultConfiguration().getResetAdaGradIterations();
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            // A reset period of 0 used to throw ArithmeticException (modulo by zero); treat it
            // as "never reset". Other values (including negatives) behave as before.
            if (resetEvery != 0 && epoch % resetEvery == 0) {
                this.network.getOutputLayer().getAdaGrad().historicalGradient = null;
            }
            this.network.backPropStep();
            log.info("Iteration {} error {}", epoch, this.network.score());
        }
    }

    /**
     * Runs backprop steps while the score keeps decreasing, for at most {@code epochs} steps.
     * Stops early when an improvement is smaller than {@link #MIN_IMPROVEMENT}. On any step
     * that does not improve the score, it restores the best snapshot seen and stops.
     */
    private void trainUntilConverged(double initialScore) {
        double bestScore = initialScore;
        BaseMultiLayerNetwork best = this.network.m15clone();
        for (int step = 1; step <= this.epochs; step++) {
            this.network.backPropStep();
            double score = this.network.score();
            if (score < bestScore) {
                if (bestScore - score < MIN_IMPROVEMENT) {
                    log.info("Not enough of a change on back prop...breaking");
                    return;
                }
                bestScore = score;
                log.info("New score {}", bestScore);
                best = this.network.m15clone();
                continue;
            }
            // No improvement; this also covers a NaN score, which fails every comparison.
            // Always roll back. Previously a non-improving step on the final epoch was kept,
            // and a NaN score kept training without ever reverting.
            if (step >= this.epochs) {
                log.info("Hit max number of epochs...breaking");
            }
            this.network.update(best);
            log.info("Reverting to best score {}", bestScore);
            return;
        }
        log.info("Backprop number of iterations max hit; converging");
    }

    /** Returns the negated network score, so a lower score gives a higher value. */
    @Override
    public double getValue() {
        return -this.network.score();
    }

    /**
     * Returns the number of network parameters. The count is computed once and cached; it is
     * not refreshed if the network's parameter count later changes.
     */
    @Override
    public int getNumParameters() {
        if (this.length < 0) {
            this.length = getParameters().length();
        }
        return this.length;
    }

    /**
     * Intentionally a no-op: per-element writes are not supported by this implementation.
     * NOTE(review): callers should use {@link #setParameters(INDArray)} instead.
     */
    @Override
    public void setParameter(int i, double d) {
    }

    /** Returns the network's parameters as reported by {@code network.params()}. */
    @Override
    public INDArray getParameters() {
        return this.network.params();
    }

    /**
     * Stub: always returns 0.0 regardless of {@code i}; per-element reads are not implemented.
     * NOTE(review): use {@link #getParameters()} for real values.
     */
    @Override
    public double getParameter(int i) {
        return 0.0d;
    }

    /**
     * Installs {@code iNDArray} as the network's parameters, then retrains the output layer via
     * {@code trainTillConvergence(lr, epochs)}. Each call is therefore a full output-layer
     * training run, not a simple setter.
     */
    @Override
    public void setParameters(INDArray iNDArray) {
        this.network.setParameters(iNDArray);
        this.network.getOutputLayer().trainTillConvergence(this.lr, this.epochs);
    }

    /**
     * Returns the first element of the network's backprop gradient pair (presumably the
     * flattened gradient; confirm against {@code getBackPropGradient2}). The iteration
     * argument is ignored.
     */
    @Override
    public INDArray getValueGradient(int i) {
        return this.network.getBackPropGradient2().getFirst();
    }
}
