package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GravesLSTM.class */
/**
 * LSTM recurrent layer backed by the {@code org.deeplearning4j.nn.conf.layers.GravesLSTM} configuration.
 *
 * <p>All forward-pass and backpropagation computation is delegated to {@code LSTMHelpers}; this class
 * only looks up the layer's parameters and manages the recurrent state carried between calls:
 * <ul>
 *   <li>{@link #rnnTimeStep} reads its starting state from {@code stateMap} and writes the final
 *       state back into {@code stateMap};</li>
 *   <li>{@link #rnnActivateUsingStoredState} and truncated-BPTT backprop read their starting state
 *       from {@code stateMap} and (optionally) record the final state in {@code tBpttStateMap}.</li>
 * </ul>
 * Both maps are keyed by {@link #STATE_KEY_PREV_ACTIVATION} and {@link #STATE_KEY_PREV_MEMCELL}.
 */
public class GravesLSTM extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GravesLSTM> {
    /** State-map key for the last time step's output activations. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";
    /** State-map key for the last time step's memory-cell state. */
    public static final String STATE_KEY_PREV_MEMCELL = "prevMem";

    // Keys used with getParam(...) and passed through to LSTMHelpers.
    // NOTE(review): these must match the keys produced by this layer's parameter initializer — confirm.
    private static final String INPUT_WEIGHT_KEY = "W";
    private static final String RECURRENT_WEIGHT_KEY = "RW";
    private static final String BIAS_KEY = "b";

    /**
     * @param conf layer configuration, forwarded to the superclass constructor
     */
    public GravesLSTM(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * @param conf  layer configuration, forwarded to the superclass constructor
     * @param input array forwarded unchanged to the superclass constructor
     */
    public GravesLSTM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Not supported: this layer has no layerwise-pretraining gradient.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient gradient() {
        throw new UnsupportedOperationException("gradient() method for layerwise pretraining: not supported for LSTMs (pretraining not possible)");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Full (non-truncated) backpropagation: runs a fresh forward pass that starts from no stored
     * state, then delegates the gradient computation to {@code LSTMHelpers}.
     *
     * @param epsilon error signal passed unchanged to {@code LSTMHelpers.backpropGradientHelper}
     * @return the gradient/array pair produced by {@code LSTMHelpers}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        return backpropGradientHelper(epsilon, false, -1);
    }

    /**
     * Truncated-BPTT backpropagation: the forward pass resumes from the state in {@code stateMap},
     * and its final state is recorded in {@code tBpttStateMap}.
     *
     * @param epsilon             error signal passed unchanged to {@code LSTMHelpers}
     * @param tbpttBackwardLength backward length passed unchanged to {@code LSTMHelpers}
     * @return the gradient/array pair produced by {@code LSTMHelpers}
     */
    @Override
    public Pair<Gradient, INDArray> tbpttBackpropGradient(INDArray epsilon, int tbpttBackwardLength) {
        return backpropGradientHelper(epsilon, true, tbpttBackwardLength);
    }

    /** Shared implementation of full and truncated backprop. */
    private Pair<Gradient, INDArray> backpropGradientHelper(INDArray epsilon, boolean truncatedBPTT,
                                                            int tbpttBackwardLength) {
        INDArray inputWeights = getParam(INPUT_WEIGHT_KEY);
        INDArray recurrentWeights = getParam(RECURRENT_WEIGHT_KEY);

        FwdPassReturn fwdPass;
        if (truncatedBPTT) {
            // Resume from the previously stored state, and remember where this segment ended so the
            // next truncated-BPTT segment can continue from it.
            fwdPass = activateHelper(true, this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                    this.stateMap.get(STATE_KEY_PREV_MEMCELL), true);
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        } else {
            // Full BPTT: forward pass with no carried-over state.
            fwdPass = activateHelper(true, null, null, true);
        }
        // NOTE(review): the literal `true` argument appears to select the forward direction in
        // LSTMHelpers — confirm against LSTMHelpers.backpropGradientHelper.
        return LSTMHelpers.backpropGradientHelper(this.conf, this.input, recurrentWeights, inputWeights, epsilon,
                truncatedBPTT, tbpttBackwardLength, fwdPass, true,
                INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY);
    }

    /**
     * Returns this layer's activations (training mode); this implementation delegates to
     * {@link #activate(INDArray, boolean)} rather than computing a separate pre-activation value.
     */
    @Override
    public INDArray preOutput(INDArray x) {
        return activate(x, true);
    }

    /**
     * Returns this layer's activations; delegates to {@link #activate(INDArray, boolean)}.
     */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return activate(x, training);
    }

    /**
     * Sets {@code x} as the layer input and runs a forward pass that starts from no stored state.
     */
    @Override
    public INDArray activate(INDArray x, boolean training) {
        setInput(x);
        return activateHelper(training, null, null, false).fwdPassOutput;
    }

    /**
     * Sets {@code x} as the layer input and runs a forward pass with {@code training = true},
     * starting from no stored state.
     */
    @Override
    public INDArray activate(INDArray x) {
        setInput(x);
        return activateHelper(true, null, null, false).fwdPassOutput;
    }

    /** Runs a forward pass on the current input, starting from no stored state. */
    @Override
    public INDArray activate(boolean training) {
        return activateHelper(training, null, null, false).fwdPassOutput;
    }

    /**
     * Runs a forward pass on the current input with {@code training = false}, starting from no
     * stored state.
     */
    @Override
    public INDArray activate() {
        return activateHelper(false, null, null, false).fwdPassOutput;
    }

    /**
     * Single entry point for all forward passes.
     *
     * @param training              forwarded to {@code LSTMHelpers.activateHelper}
     * @param prevOutputActivations starting output activations, or {@code null} for none
     * @param prevMemCellState      starting memory-cell state, or {@code null} for none
     * @param forBackprop           {@code true} on the backprop code paths, {@code false} otherwise
     */
    private FwdPassReturn activateHelper(boolean training, INDArray prevOutputActivations,
                                         INDArray prevMemCellState, boolean forBackprop) {
        // Arguments are evaluated left to right: recurrent weights, input weights, then biases.
        // NOTE(review): the literal `true` appears to select the forward direction — confirm in LSTMHelpers.
        return LSTMHelpers.activateHelper(this, this.conf, this.input,
                getParam(RECURRENT_WEIGHT_KEY), getParam(INPUT_WEIGHT_KEY), getParam(BIAS_KEY),
                training, prevOutputActivations, prevMemCellState, forBackprop, true, INPUT_WEIGHT_KEY);
    }

    /** Same as {@link #activate()}. */
    @Override
    public INDArray activationMean() {
        return activate();
    }

    /** @return {@link Layer.Type#RECURRENT} */
    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * L2 penalty {@code 0.5 * l2 * (||RW||_2^2 + ||W||_2^2)} over the recurrent and input weights.
     * Biases are not included. Returns 0 when regularization is disabled or {@code l2 <= 0}.
     */
    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL2() <= 0.0d) {
            return 0.0d;
        }
        double recurrentNorm2 = getParam(RECURRENT_WEIGHT_KEY).norm2Number().doubleValue();
        double inputNorm2 = getParam(INPUT_WEIGHT_KEY).norm2Number().doubleValue();
        double l2 = this.conf.getLayer().getL2();
        return 0.5d * l2 * ((recurrentNorm2 * recurrentNorm2) + (inputNorm2 * inputNorm2));
    }

    /**
     * L1 penalty {@code l1 * (||RW||_1 + ||W||_1)} over the recurrent and input weights.
     * Biases are not included. Returns 0 when regularization is disabled or {@code l1 <= 0}.
     */
    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL1() <= 0.0d) {
            return 0.0d;
        }
        double recurrentNorm1 = getParam(RECURRENT_WEIGHT_KEY).norm1Number().doubleValue();
        double inputNorm1 = getParam(INPUT_WEIGHT_KEY).norm1Number().doubleValue();
        return this.conf.getLayer().getL1() * (recurrentNorm1 + inputNorm1);
    }

    /**
     * Stateful inference step: runs a forward pass ({@code training = false}) starting from the state
     * in {@code stateMap}, then overwrites {@code stateMap} with the final state so the next call
     * continues from here.
     */
    @Override
    public INDArray rnnTimeStep(INDArray x) {
        setInput(x);
        FwdPassReturn fwdPass = activateHelper(false, this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.stateMap.get(STATE_KEY_PREV_MEMCELL), false);
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
        this.stateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        return fwdPass.fwdPassOutput;
    }

    /**
     * Runs a forward pass starting from the state in {@code stateMap} without modifying it.
     *
     * @param storeLastForTBPTT when {@code true}, the final state is written to {@code tBpttStateMap}
     */
    @Override
    public INDArray rnnActivateUsingStoredState(INDArray x, boolean training, boolean storeLastForTBPTT) {
        setInput(x);
        FwdPassReturn fwdPass = activateHelper(training, this.stateMap.get(STATE_KEY_PREV_ACTIVATION),
                this.stateMap.get(STATE_KEY_PREV_MEMCELL), false);
        if (storeLastForTBPTT) {
            this.tBpttStateMap.put(STATE_KEY_PREV_ACTIVATION, fwdPass.lastAct);
            this.tBpttStateMap.put(STATE_KEY_PREV_MEMCELL, fwdPass.lastMemCell);
        }
        return fwdPass.fwdPassOutput;
    }
}
