package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.solvers.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.solvers.VectorizedNonZeroStoppingConjugateGradient;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.OptimizerMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/optimizers/NeuralNetworkOptimizer.class */
/**
 * Base optimizer that exposes a {@link NeuralNetwork} through the
 * {@link OptimizableByGradientValueMatrix} interface so that an {@link OptimizerMatrix} solver
 * (conjugate gradient or gradient ascent, selected by {@link #optimizationAlgorithm}) can train it.
 *
 * <p>{@link #extraParams} holds the caller-supplied extra parameters followed by one reserved
 * trailing slot. {@link #getValueGradient(int)} writes the current iteration number into that
 * slot before asking the network for its gradient.
 */
public abstract class NeuralNetworkOptimizer implements OptimizableByGradientValueMatrix, Serializable, IterationListener {
    private static final long serialVersionUID = 4455143696487934647L;

    /** Iteration budget used by {@link #train(INDArray)} when the caller does not supply one. */
    private static final int DEFAULT_NUM_ITERATIONS = 1000;

    /** Position of the caller-supplied iteration budget inside {@link #extraParams}. */
    private static final int NUM_ITERATIONS_INDEX = 2;

    protected NeuralNetwork network;
    protected double lr;
    /** Caller-supplied extra parameters plus one trailing slot reserved for the iteration number. */
    protected Object[] extraParams;
    protected static Logger log = LoggerFactory.getLogger(NeuralNetworkOptimizer.class);
    /** Created lazily by {@link #train(INDArray)}; transient, so it is recreated after deserialization. */
    protected transient OptimizerMatrix opt;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;
    protected LossFunctions.LossFunction lossFunction;
    // Float literal widened to double: exactly 9.999999747378752E-6, the value this field always had.
    protected double tolerance = 1e-5f;
    protected List<Double> errors = new ArrayList<>();
    protected NeuralNetPlotter plotter = new NeuralNetPlotter();
    /** Maximum step size for gradient ascent; only applied when strictly positive. */
    protected double maxStep = -1.0d;
    protected int currIteration = -1;

    /**
     * Creates an optimizer for the given network.
     *
     * @param neuralNetwork the network whose parameters are optimized
     * @param learningRate learning rate handed to {@link NeuralNetwork#backProp}
     * @param extraParams optional caller-supplied extra parameters (may be {@code null}); if it has
     *     more than {@value #NUM_ITERATIONS_INDEX} entries, entry {@value #NUM_ITERATIONS_INDEX} is
     *     read by {@link #train(INDArray)} as the iteration budget
     * @param optimizationAlgorithm which solver to build
     * @param lossFunction loss function stored for use by subclasses
     */
    public NeuralNetworkOptimizer(NeuralNetwork neuralNetwork, double learningRate, Object[] extraParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, LossFunctions.LossFunction lossFunction) {
        this.network = neuralNetwork;
        this.lr = learningRate;
        // Copy the caller's params and append one null slot reserved for the iteration number.
        this.extraParams = extraParams == null
                ? new Object[1]
                : Arrays.copyOf(extraParams, extraParams.length + 1);
        this.optimizationAlgorithm = optimizationAlgorithm;
        this.lossFunction = lossFunction;
    }

    /** Builds {@link #opt} for the configured algorithm and applies tolerance and max step size. */
    private void createOptimizationAlgorithm() {
        if (optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            opt = new VectorizedNonZeroStoppingConjugateGradient(this, this);
            opt.setTolerance(tolerance);
            return;
        }
        VectorizedDeepLearningGradientAscent ascent = new VectorizedDeepLearningGradientAscent(this, this);
        opt = ascent;
        opt.setTolerance(tolerance);
        if (maxStep > 0.0d) {
            ascent.setMaxStepSize(maxStep);
        }
    }

    /** Returns the network's current parameter vector. */
    @Override
    public INDArray getParameters() {
        return network.params();
    }

    /**
     * Trains the network on {@code input}: builds the solver if needed, runs it, then calls
     * {@link NeuralNetwork#backProp} with the same iteration budget.
     *
     * @param input training data passed to {@link NeuralNetwork#setInput}
     * @throws IllegalArgumentException if the caller-supplied iteration budget is not a number
     */
    public void train(INDArray input) {
        if (opt == null) {
            createOptimizationAlgorithm();
        }
        network.setInput(input);
        int numIterations = resolveNumIterations();
        opt.setMaxIterations(numIterations);
        opt.optimize(numIterations);
        network.backProp(lr, numIterations, extraParams);
    }

    /**
     * Reads the caller-supplied iteration budget from {@link #extraParams}, falling back to
     * {@link #DEFAULT_NUM_ITERATIONS} when it is absent or {@code null}.
     */
    private int resolveNumIterations() {
        // The last slot of extraParams is reserved for the iteration number, so only the first
        // length - 1 entries came from the caller. Counting that slot would read it as the
        // budget whenever exactly two params were supplied (NPE, or a stale iteration number).
        int callerSupplied = extraParams.length - 1;
        if (callerSupplied <= NUM_ITERATIONS_INDEX) {
            return DEFAULT_NUM_ITERATIONS;
        }
        Object value = extraParams[NUM_ITERATIONS_INDEX];
        if (value == null) {
            return DEFAULT_NUM_ITERATIONS;
        }
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("extraParams[" + NUM_ITERATIONS_INDEX
                    + "] must be a number of iterations, got " + value.getClass().getName());
        }
        return ((Number) value).intValue();
    }

    /**
     * Plots the network gradient every {@code renderWeightIterations} iterations, as configured
     * on the network; plotting is disabled when that setting is not positive.
     */
    @Override
    public void iterationDone(int iteration) {
        int renderEvery = network.conf().getRenderWeightIterations();
        if (renderEvery > 0 && iteration % renderEvery == 0) {
            plotter.plotNetworkGradient(network, network.getGradient(extraParams), 100);
        }
    }

    /** Returns the number of parameters reported by the network. */
    @Override
    public int getNumParameters() {
        return network.numParams();
    }

    /** Single-parameter access is not supported; use {@link #getParameters()}. */
    @Override
    public double getParameter(int index) {
        throw new UnsupportedOperationException("getParameter(int) is not supported; use getParameters()");
    }

    /**
     * Installs {@code params} as the network's parameters. When the network configuration has
     * {@code constrainGradientToUnitNorm} set, {@code params} is first divided in place by its
     * {@code normmax} (so the caller's array is modified).
     */
    @Override
    public void setParameters(INDArray params) {
        if (network.conf().isConstrainGradientToUnitNorm()) {
            // NOTE(review): Integer.MAX_VALUE is assumed to mean "over all dimensions" in ND4J;
            // an all-zero array would presumably yield NaNs here (0 / 0) — confirm.
            params.divi(params.normmax(Integer.MAX_VALUE));
        }
        network.setParams(params);
    }

    /** Single-parameter mutation is not supported; use {@link #setParameters(INDArray)}. */
    @Override
    public void setParameter(int index, double value) {
        throw new UnsupportedOperationException("setParameter(int, double) is not supported; use setParameters(INDArray)");
    }

    /**
     * Computes the network gradient and flattens its weight, visible-bias and hidden-bias parts
     * (in that order) into a single array.
     *
     * @param iteration current iteration; when {@code >= 1} it is stored in the reserved last slot
     *     of {@link #extraParams} before the gradient is requested
     */
    @Override
    public INDArray getValueGradient(int iteration) {
        if (iteration >= 1) {
            extraParams[extraParams.length - 1] = Integer.valueOf(iteration);
        }
        NeuralNetworkGradient gradient = network.getGradient(extraParams);
        return Nd4j.toFlattened(Arrays.asList(
                gradient.getwGradient(), gradient.getvBiasGradient(), gradient.gethBiasGradient()));
    }

    /**
     * Returns the negated network score. Presumably negated because the solvers maximize this
     * value while {@code score()} is a loss — confirm against the solver implementations.
     */
    @Override
    public double getValue() {
        return -network.score();
    }

    /** Records {@code iteration} in {@link #currIteration}; values below 1 are ignored. */
    @Override
    public void setCurrentIteration(int iteration) {
        if (iteration < 1) {
            return;
        }
        currIteration = iteration;
    }
}
