package org.deeplearning4j.optimize.optimizers;

import org.deeplearning4j.nn.gradient.OutputLayerGradient;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/optimize/optimizers/OutputLayerOptimizer.class */
/**
 * Exposes the weights and biases of an {@link OutputLayer} through the
 * {@link OptimizableByGradientValueMatrix} contract.
 *
 * <p>Element-wise accessors use a single flat index space: indices
 * {@code [0, W.length())} address the weight array returned by
 * {@code getW()}, and every index at or beyond {@code W.length()} addresses
 * the bias array returned by {@code getB()} at offset
 * {@code index - W.length()}.
 */
public class OutputLayerOptimizer implements OptimizableByGradientValueMatrix {
    /** Layer whose weights and biases are read and written. */
    private final OutputLayer logReg;
    /** Value handed to the layer's {@code getGradient} call. */
    private final double lr;
    /**
     * Last iteration index received; {@code -1} until one is supplied.
     * It is stored but never read inside this class.
     */
    private int currIteration = -1;

    /**
     * @param outputLayer layer to expose for optimization
     * @param d           value passed to {@code outputLayer.getGradient} on every gradient request
     */
    public OutputLayerOptimizer(OutputLayer outputLayer, double d) {
        this.logReg = outputLayer;
        this.lr = d;
    }

    /**
     * Records the current iteration index.
     *
     * @param i iteration index to remember
     */
    @Override
    public void setCurrentIteration(int i) {
        this.currIteration = i;
    }

    /**
     * @return the number of weight elements plus the number of bias elements
     */
    @Override
    public int getNumParameters() {
        return this.logReg.getW().length() + this.logReg.getB().length();
    }

    /**
     * Fills {@code dArr} with the parameters at flat indices
     * {@code 0 .. dArr.length - 1}.
     *
     * @param dArr destination buffer; its length decides how many parameters are read
     */
    public void getParameters(double[] dArr) {
        for (int index = 0; index < dArr.length; index++) {
            dArr[index] = getParameter(index);
        }
    }

    /**
     * Reads one parameter by flat index (weights first, then biases).
     *
     * @param i flat parameter index
     * @return the addressed weight or bias as a {@code double}
     */
    @Override
    public double getParameter(int i) {
        int weightCount = this.logReg.getW().length();
        if (i >= weightCount) {
            return asDouble(this.logReg.getB().getScalar(i - weightCount).element());
        }
        return asDouble(this.logReg.getW().getScalar(i).element());
    }

    /**
     * Writes {@code dArr[k]} to flat parameter index {@code k} for every
     * position of {@code dArr}.
     *
     * @param dArr values to store, ordered weights first, then biases
     */
    public void setParameters(double[] dArr) {
        for (int index = 0; index < dArr.length; index++) {
            setParameter(index, dArr[index]);
        }
    }

    /**
     * Writes one parameter by flat index (weights first, then biases).
     *
     * @param i flat parameter index
     * @param d value to store
     */
    @Override
    public void setParameter(int i, double d) {
        int weightCount = this.logReg.getW().length();
        if (i >= weightCount) {
            this.logReg.getB().putScalar(i - weightCount, d);
        } else {
            this.logReg.getW().putScalar(i, d);
        }
    }

    /**
     * Requests a gradient from the layer and copies it into {@code dArr},
     * weight-gradient entries first, then bias-gradient entries.
     *
     * @param dArr destination buffer; its length decides how many entries are copied
     */
    public void getValueGradient(double[] dArr) {
        OutputLayerGradient gradient = this.logReg.getGradient(this.lr);
        // Nothing in the loop modifies the layer or the gradient, so resolve these once.
        int weightCount = this.logReg.getW().length();
        INDArray weightGradient = gradient.getwGradient();
        INDArray biasGradient = gradient.getbGradient();
        for (int index = 0; index < dArr.length; index++) {
            if (index < weightCount) {
                dArr[index] = asDouble(weightGradient.getScalar(index).element());
            } else {
                dArr[index] = asDouble(biasGradient.getScalar(index - weightCount).element());
            }
        }
    }

    /**
     * @return the negation of the layer's {@code score()}
     */
    @Override
    public double getValue() {
        return -this.logReg.score();
    }

    /**
     * @return whatever the layer's {@code params()} returns
     */
    @Override
    public INDArray getParameters() {
        return this.logReg.params();
    }

    /**
     * Hands {@code iNDArray} to the layer as its new parameters. When the
     * layer configuration reports {@code isConstrainGradientToUnitNorm()},
     * the array is first divided by its {@code normmax(Integer.MAX_VALUE)}.
     *
     * @param iNDArray new parameter values
     */
    @Override
    public void setParameters(INDArray iNDArray) {
        if (this.logReg.conf().isConstrainGradientToUnitNorm()) {
            // NOTE(review): divi presumably rescales the caller's array in place, and an
            // all-zero array would make the divisor zero -- confirm both against ND4J.
            iNDArray.divi(iNDArray.normmax(Integer.MAX_VALUE));
        }
        this.logReg.setParams(iNDArray);
    }

    /**
     * Records the iteration index, requests a gradient from the layer and
     * returns the weight and bias gradients flattened into one array
     * (weights first).
     *
     * @param i current iteration index
     * @return flattened weight gradient followed by bias gradient
     * @throws IllegalStateException if either gradient's length differs from
     *                               the length of the parameter array it belongs to
     */
    @Override
    public INDArray getValueGradient(int i) {
        this.currIteration = i;
        OutputLayerGradient gradient = this.logReg.getGradient(this.lr);

        int weightCount = this.logReg.getW().length();
        INDArray weightGradient = gradient.getwGradient();
        if (weightCount != weightGradient.length()) {
            throw new IllegalStateException("Illegal length for gradient: weights have "
                    + weightCount + " elements but the weight gradient has " + weightGradient.length());
        }

        int biasCount = this.logReg.getB().length();
        INDArray biasGradient = gradient.getbGradient();
        if (biasCount != biasGradient.length()) {
            throw new IllegalStateException("Illegal length for gradient: biases have "
                    + biasCount + " elements but the bias gradient has " + biasGradient.length());
        }

        return Nd4j.toFlattened(new INDArray[]{weightGradient, biasGradient});
    }

    /**
     * Unboxes a scalar element as a {@code double}. The cast goes through
     * {@link Number} rather than {@link Double} so that any numeric box is
     * accepted, not only {@code Double}.
     */
    private static double asDouble(Object boxed) {
        return ((Number) boxed).doubleValue();
    }
}
