package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.NonZeroStoppingConjugateGradient;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapts a {@link BaseNeuralNetwork} to MALLET's {@code Optimizable.ByGradientValue} interface so
 * that a conjugate-gradient optimizer can adjust the network's parameters in place.
 *
 * <p>The parameters are exposed as one flat vector laid out as: every entry of {@code W}, then
 * every entry of {@code vBias}, then every entry of {@code hBias}. Entries inside each matrix are
 * addressed with jblas linear indices ({@code DoubleMatrix.get(int)} / {@code put(int, double)}).
 *
 * <p>Subclasses supply {@link #getValueGradient(double[])}. This class also implements
 * {@link NeuralNetEpochListener} and can plot the network gradient every
 * {@code network.getRenderEpochs()} epochs.
 */
public abstract class NeuralNetworkOptimizer implements Optimizable.ByGradientValue, Serializable, NeuralNetEpochListener {
    private static final long serialVersionUID = 4455143696487934647L;

    /**
     * Upper bound passed to {@code opt.setMaxIterations} on every {@link #train} call. Kept static
     * so it does not change the serialized form of this class.
     */
    protected static final int DEFAULT_MAX_ITERATIONS = 10000;

    /** Network whose {@code W}, {@code vBias} and {@code hBias} are read and written in place. */
    protected BaseNeuralNetwork network;
    /** Learning rate; stored here but not read by this class (presumably used by subclasses). */
    protected double lr;
    /**
     * Extra training arguments. Index 2 must hold the iteration count passed to
     * {@code opt.optimize(int)}; the whole array is also forwarded to
     * {@code network.getGradient} when plotting.
     */
    protected Object[] extraParams;
    protected static Logger log = LoggerFactory.getLogger(NeuralNetworkOptimizer.class);
    /** Created lazily by {@link #train}; transient, so it is null again after deserialization. */
    protected transient NonZeroStoppingConjugateGradient opt;
    /** Convergence tolerance applied to the optimizer at the start of each {@link #train} call. */
    protected double tolerance = 1.0E-4d;
    /** Error history; not populated by this class (presumably filled by subclasses). */
    protected List<Double> errors = new ArrayList<Double>();
    /** Minimum learning rate; not read by this class (presumably used by subclasses). */
    protected double minLearningRate = 0.001d;

    /**
     * @param baseNeuralNetwork network whose parameters will be optimized
     * @param d learning rate
     * @param objArr extra arguments; index 2 must hold the iteration count used by {@link #train}
     */
    public NeuralNetworkOptimizer(BaseNeuralNetwork baseNeuralNetwork, double d, Object[] objArr) {
        this.network = baseNeuralNetwork;
        this.lr = d;
        this.extraParams = objArr;
    }

    /**
     * Runs conjugate-gradient optimization over the network's parameters for the number of
     * iterations stored in {@code extraParams[2]}.
     *
     * <p>NOTE(review): {@code doubleMatrix} is not read by this method; the objective comes from
     * the network itself, so presumably the caller has already set the network's input — confirm.
     *
     * @param doubleMatrix training input (currently unused here)
     * @throws IllegalStateException if {@code extraParams[2]} is missing or not a {@link Number}
     */
    public void train(DoubleMatrix doubleMatrix) {
        int iterations = requireIterationCount();
        if (this.opt == null) {
            // `this` is handed over both as the Optimizable and (presumably) as the epoch listener.
            this.opt = new NonZeroStoppingConjugateGradient(this, this);
        }
        this.opt.setTolerance(this.tolerance);
        this.opt.setMaxIterations(DEFAULT_MAX_ITERATIONS);
        this.opt.optimize(iterations);
    }

    /** Reads the iteration count from {@code extraParams[2]}, failing with a descriptive message. */
    private int requireIterationCount() {
        if (this.extraParams == null || this.extraParams.length < 3) {
            throw new IllegalStateException("extraParams must contain the iteration count at index 2");
        }
        Object value = this.extraParams[2];
        if (!(value instanceof Number)) {
            throw new IllegalStateException("extraParams[2] must be a Number (iteration count) but was " + value);
        }
        return ((Number) value).intValue();
    }

    /**
     * Plots the current network gradient when {@code i} is a multiple of the network's render
     * interval. A non-positive render interval disables plotting.
     *
     * @param i index of the epoch that just finished
     */
    @Override // org.deeplearning4j.optimize.NeuralNetEpochListener
    public void epochDone(int i) {
        int renderEpochs = this.network.getRenderEpochs();
        // The > 0 guard also prevents a modulo by zero; epoch 0 is covered since 0 % n == 0.
        if (renderEpochs > 0 && i % renderEpochs == 0) {
            new NeuralNetPlotter().plotNetworkGradient(this.network, this.network.getGradient(this.extraParams));
        }
    }

    /** @return the (mutable) error history list owned by this optimizer */
    public List<Double> getErrors() {
        return this.errors;
    }

    /** @return total number of parameters: entries of {@code W} + {@code hBias} + {@code vBias} */
    public int getNumParameters() {
        return this.network.W.length + this.network.hBias.length + this.network.vBias.length;
    }

    /**
     * Copies parameters {@code 0 .. dArr.length - 1} of the flat parameter vector into
     * {@code dArr}.
     *
     * @param dArr destination buffer; expected to be {@link #getNumParameters()} long
     */
    public void getParameters(double[] dArr) {
        for (int i = 0; i < dArr.length; i++) {
            dArr[i] = getParameter(i);
        }
    }

    /**
     * Returns one entry of the flat parameter vector (layout: W, then vBias, then hBias).
     *
     * @param i flat parameter index
     * @return the parameter value
     */
    public double getParameter(int i) {
        int weightCount = this.network.W.length;
        if (i < weightCount) {
            return this.network.W.get(i);
        }
        int visibleCount = this.network.vBias.length;
        if (i < weightCount + visibleCount) {
            return this.network.vBias.get(i - weightCount);
        }
        return this.network.hBias.get(i - weightCount - visibleCount);
    }

    /**
     * Writes {@code dArr[i]} to flat parameter {@code i} for every index of {@code dArr}.
     *
     * @param dArr source values; expected to be {@link #getNumParameters()} long
     */
    public void setParameters(double[] dArr) {
        for (int i = 0; i < dArr.length; i++) {
            setParameter(i, dArr[i]);
        }
    }

    /**
     * Writes one entry of the flat parameter vector (layout: W, then vBias, then hBias).
     *
     * @param i flat parameter index
     * @param d new value
     */
    public void setParameter(int i, double d) {
        int weightCount = this.network.W.length;
        if (i < weightCount) {
            this.network.W.put(i, d);
            return;
        }
        int visibleCount = this.network.vBias.length;
        if (i < weightCount + visibleCount) {
            this.network.vBias.put(i - weightCount, d);
        } else {
            this.network.hBias.put(i - weightCount - visibleCount, d);
        }
    }

    /**
     * Fills {@code dArr} with the gradient of {@link #getValue()}; presumably in the same
     * W / vBias / hBias layout as {@link #getParameters(double[])} — implementations must confirm.
     *
     * @param dArr destination buffer for the gradient
     */
    public abstract void getValueGradient(double[] dArr);

    /**
     * Returns the objective for the optimizer: the negated reconstruction cross-entropy of the
     * network. The optimizer maximizes this value, which minimizes the cross-entropy.
     *
     * @return {@code -network.getReConstructionCrossEntropy()}
     */
    public double getValue() {
        return -this.network.getReConstructionCrossEntropy();
    }
}
