package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/GRU.class */
/**
 * Gated Recurrent Unit (GRU) recurrent layer.
 *
 * <p>The layer reads three parameters: {@code "W"} (input weights), {@code "RW"} (recurrent
 * weights) and {@code "b"} (biases). Each is treated as three equally wide column blocks of
 * width {@code RW.size(0)}; in the forward pass they are used, in order, for the reset gate,
 * the update gate and the candidate activation.
 *
 * <p><b>Note:</b> both constructors unconditionally throw
 * {@link UnsupportedOperationException}, so this layer cannot currently be instantiated. The
 * stated reason is that the backprop implementation is incorrect in this version.
 */
public class GRU extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.layers.GRU> {
    /** Key under which the last time step's output activations are kept in {@code stateMap}. */
    public static final String STATE_KEY_PREV_ACTIVATION = "prevAct";

    /** Message thrown by both constructors; kept in one place so the two cannot drift apart. */
    private static final String DISABLED_MESSAGE = "GRU layer disabled: Backprop implementation is incorrect in this version. Consider using GravesLSTM instead";

    /**
     * Always fails: the layer is disabled in this version.
     *
     * @throws UnsupportedOperationException always, after the superclass constructor has run
     */
    public GRU(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        throw new UnsupportedOperationException(DISABLED_MESSAGE);
    }

    /**
     * Always fails: the layer is disabled in this version.
     *
     * @throws UnsupportedOperationException always, after the superclass constructor has run
     */
    public GRU(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
        throw new UnsupportedOperationException(DISABLED_MESSAGE);
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Model
    public Gradient gradient() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Runs a training-mode forward pass on the current input and then walks the time steps
     * backwards, accumulating gradients for {@code "W"}, {@code "RW"} and {@code "b"}.
     *
     * <p>NOTE(review): the constructors of this class state that this backprop implementation
     * is incorrect. The arithmetic below is preserved as-is (only names and layout were
     * cleaned up); do not treat it as a reference derivation.
     *
     * @param iNDArray the incoming epsilon; when its rank is below 3 it is treated as a single
     *                 time step
     * @return the parameter gradients paired with an array of shape
     *         {@code [epsilon.size(0), W.size(0), timeSeriesLength]} to pass to the layer below
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray) {
        INDArray epsilon = iNDArray;

        // Forward pass results: [0] outputs, [1] pre-activations, [2] gate/candidate activations.
        INDArray[] forward = activateHelper(true, null);
        INDArray outputActivations = forward[0];
        INDArray preActivations = forward[1];
        INDArray gateActivations = forward[2];

        INDArray inputWeights = getParam("W");
        INDArray recurrentWeights = getParam("RW");
        int layerSize = recurrentWeights.size(0);
        int prevLayerSize = inputWeights.size(0);
        int miniBatchSize = epsilon.size(0);
        boolean is2dEpsilon = epsilon.rank() < 3;
        int timeSeriesLength = is2dEpsilon ? 1 : epsilon.size(2);

        // Column blocks of the weights: reset | update | candidate.
        INDArray wReset = inputWeights.get(columnBlock(0, layerSize));
        INDArray wUpdate = inputWeights.get(columnBlock(layerSize, 2 * layerSize));
        INDArray wCandidate = inputWeights.get(columnBlock(2 * layerSize, 3 * layerSize));
        INDArray rwReset = recurrentWeights.get(columnBlock(0, layerSize));
        INDArray rwUpdate = recurrentWeights.get(columnBlock(layerSize, 2 * layerSize));
        INDArray rwCandidate = recurrentWeights.get(columnBlock(2 * layerSize, 3 * layerSize));
        INDArray rwResetDiag = Nd4j.diag(rwReset).transpose();
        INDArray rwCandidateDiag = Nd4j.diag(rwCandidate).transpose();

        // Gradient accumulators, laid out with the same three column blocks as the parameters.
        INDArray biasGradients = Nd4j.zeros(new int[]{1, 3 * layerSize});
        INDArray inputWeightGradients = Nd4j.zeros(new int[]{prevLayerSize, 3 * layerSize});
        INDArray recurrentWeightGradients = Nd4j.zeros(new int[]{layerSize, 3 * layerSize});
        INDArray epsilonNext = Nd4j.zeros(new int[]{miniBatchSize, prevLayerSize, timeSeriesLength});
        // Output delta of the step processed in the previous loop iteration (i.e. time t + 1).
        INDArray deltaOutNext = Nd4j.zeros(miniBatchSize, layerSize);

        int lastStep = timeSeriesLength - 1;
        for (int t = lastStep; t >= 0; t--) {
            INDArray prevOut = t == 0
                    ? Nd4j.zeros(miniBatchSize, layerSize)
                    : timeSlice(outputActivations, t - 1);
            INDArray gateActs = is2dEpsilon ? gateActivations : timeSlice(gateActivations, t);
            INDArray preActs = is2dEpsilon ? preActivations : timeSlice(preActivations, t);

            // Values for time t + 1; zeros at the final step, where there is no later step.
            INDArray gateActsNext;
            INDArray preActsNext;
            if (t == lastStep) {
                gateActsNext = Nd4j.zeros(miniBatchSize, 3 * layerSize);
                preActsNext = Nd4j.zeros(miniBatchSize, 3 * layerSize);
            } else {
                gateActsNext = timeSlice(gateActivations, t + 1);
                preActsNext = timeSlice(preActivations, t + 1);
            }

            INDArray resetDeriv = gruTransformDerivative("sigmoid", preActs.get(columnBlock(0, layerSize)));
            INDArray deltaOut = (is2dEpsilon ? epsilon : timeSlice(epsilon, t)).dup();

            if (t < lastStep) {
                // Add the contributions flowing back from time t + 1 into this step's delta.
                INDArray currOut = is2dEpsilon ? outputActivations : timeSlice(outputActivations, t);
                INDArray resetNext = gateActsNext.get(columnBlock(0, layerSize));
                INDArray updateNext = gateActsNext.get(columnBlock(layerSize, 2 * layerSize));
                INDArray candidateNext = gateActsNext.get(columnBlock(2 * layerSize, 3 * layerSize));
                INDArray resetPreNext = preActsNext.get(columnBlock(0, layerSize));
                INDArray updatePreNext = preActsNext.get(columnBlock(layerSize, 2 * layerSize));
                INDArray candidatePreNext = preActsNext.get(columnBlock(2 * layerSize, 3 * layerSize));
                INDArray resetDerivNext = gruTransformDerivative("sigmoid", resetPreNext);
                INDArray updateDerivNext = gruTransformDerivative("sigmoid", updatePreNext);
                INDArray candidateDerivNext = gruTransformDerivative(this.conf.getLayer().getActivationFunction(), candidatePreNext);

                deltaOut.addi(updateNext.mul(deltaOutNext));
                deltaOut.addi(currOut.sub(candidateNext)
                        .muli(updateDerivNext)
                        .muli(rwUpdate.mmul(deltaOutNext.transpose()).transpose()));
                deltaOut.addi(updateNext.rsub(Double.valueOf(1.0d))
                        .muli(candidateDerivNext)
                        .muli(resetNext.add(currOut.mul(resetDerivNext).muliRowVector(rwResetDiag)))
                        .muli(rwCandidate.mmul(deltaOutNext.transpose()).transpose()));
            }

            // Per-block deltas for this time step.
            INDArray updateDeriv = gruTransformDerivative("sigmoid", preActs.get(columnBlock(layerSize, 2 * layerSize)));
            INDArray deltaUpdate = deltaOut.mul(updateDeriv)
                    .muli(prevOut.sub(gateActs.get(columnBlock(2 * layerSize, 3 * layerSize))));
            INDArray candidateDeriv = gruTransformDerivative(this.conf.getLayer().getActivationFunction(), preActs.get(columnBlock(2 * layerSize, 3 * layerSize)));
            INDArray deltaCandidate = deltaOut.mul(candidateDeriv)
                    .muli(gateActs.get(columnBlock(layerSize, 2 * layerSize)).rsub(Double.valueOf(1.0d)));
            INDArray deltaReset = deltaCandidate.mulRowVector(rwCandidateDiag).muli(prevOut).muli(resetDeriv);

            // Input-weight gradients.
            INDArray layerInput = is2dEpsilon ? this.input : timeSlice(this.input, t);
            inputWeightGradients.get(columnBlock(0, layerSize))
                    .addi(deltaReset.transpose().mmul(layerInput).transpose());
            inputWeightGradients.get(columnBlock(layerSize, 2 * layerSize))
                    .addi(deltaUpdate.transpose().mmul(layerInput).transpose());
            inputWeightGradients.get(columnBlock(2 * layerSize, 3 * layerSize))
                    .addi(deltaCandidate.transpose().mmul(layerInput).transpose());

            // Recurrent-weight gradients; skipped at t == 0, where prevOut is all zeros.
            if (t > 0) {
                recurrentWeightGradients.get(columnBlock(0, layerSize))
                        .addi(deltaReset.transpose().mmul(prevOut).transpose());
                recurrentWeightGradients.get(columnBlock(layerSize, 2 * layerSize))
                        .addi(deltaUpdate.transpose().mmul(prevOut).transpose());
                recurrentWeightGradients.get(columnBlock(2 * layerSize, 3 * layerSize))
                        .addi(deltaCandidate.transpose().mmul(prevOut.mul(gateActs.get(columnBlock(0, layerSize)))).transpose());
            }

            // Bias gradients: deltas summed along dimension 0.
            biasGradients.get(biasBlock(0, layerSize)).addi(deltaReset.sum(new int[]{0}));
            biasGradients.get(biasBlock(layerSize, 2 * layerSize)).addi(deltaUpdate.sum(new int[]{0}));
            biasGradients.get(biasBlock(2 * layerSize, 3 * layerSize)).addi(deltaCandidate.sum(new int[]{0}));

            // Epsilon handed to the layer below for this time step.
            timeSlice(epsilonNext, t).assign(
                    wReset.mmul(deltaReset.transpose()).transpose()
                            .addi(wUpdate.mmul(deltaUpdate.transpose()).transpose())
                            .addi(wCandidate.mmul(deltaCandidate.transpose()).transpose()));

            deltaOutNext = deltaOut;
        }

        DefaultGradient result = new DefaultGradient();
        result.setGradientFor("W", inputWeightGradients);
        result.setGradientFor("RW", recurrentWeightGradients);
        result.setGradientFor("b", biasGradients);
        return new Pair<>(result, epsilonNext);
    }

    /** Delegates to {@link #activate(INDArray, boolean)} in training mode. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        return activate(iNDArray, true);
    }

    /** Delegates to {@link #activate(INDArray, boolean)}. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray, boolean z) {
        return activate(iNDArray, z);
    }

    /**
     * Stores the given input on the layer and returns the output activations.
     *
     * @param iNDArray the layer input
     * @param z        training flag, forwarded to {@code setInput} and to the forward pass
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray, boolean z) {
        setInput(iNDArray, z);
        return activateHelper(z, null)[0];
    }

    /** Stores the given input on the layer and returns the training-mode output activations. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        setInput(iNDArray);
        return activateHelper(true, null)[0];
    }

    /** Returns the output activations for the input already stored on the layer. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate(boolean z) {
        return activateHelper(z, null)[0];
    }

    /** Returns the non-training output activations for the input already stored on the layer. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        return activateHelper(false, null)[0];
    }

    /**
     * Forward pass over every time step of {@code this.input}.
     *
     * <p>Per step, with {@code x} the input slice and {@code h} the previous output:
     * <pre>
     *   r = sigmoid(x * W_r + h * RW_r + b_r)
     *   u = sigmoid(x * W_u + h * RW_u + b_u)
     *   c = act(x * W_c + (h .* r) * RW_c + b_c)     // act = configured activation function
     *   out = u .* h + (1 - u) .* c
     * </pre>
     *
     * @param training              whether this is a training pass (only affects the
     *                              drop-connect call)
     * @param prevOutputActivations output activations preceding the first time step, or
     *                              {@code null} to start from zeros
     * @return a three-element array: {@code [0]} output activations
     *         {@code [miniBatch, layerSize, T]}, {@code [1]} pre-activations and {@code [2]}
     *         activations of the three blocks, both {@code [miniBatch, 3 * layerSize, T]}
     */
    private INDArray[] activateHelper(boolean training, INDArray prevOutputActivations) {
        INDArray inputWeights = getParam("W");
        INDArray recurrentWeights = getParam("RW");
        INDArray biases = getParam("b");
        boolean is2dInput = this.input.rank() < 3;
        int timeSeriesLength = is2dInput ? 1 : this.input.size(2);
        int layerSize = recurrentWeights.size(0);
        int miniBatchSize = this.input.size(0);

        // Column blocks of the parameters: reset | update | candidate.
        INDArray wReset = inputWeights.get(columnBlock(0, layerSize));
        INDArray wUpdate = inputWeights.get(columnBlock(layerSize, 2 * layerSize));
        INDArray wCandidate = inputWeights.get(columnBlock(2 * layerSize, 3 * layerSize));
        INDArray rwReset = recurrentWeights.get(columnBlock(0, layerSize));
        INDArray rwUpdate = recurrentWeights.get(columnBlock(layerSize, 2 * layerSize));
        INDArray rwCandidate = recurrentWeights.get(columnBlock(2 * layerSize, 3 * layerSize));
        INDArray bReset = biases.get(biasBlock(0, layerSize));
        INDArray bUpdate = biases.get(biasBlock(layerSize, 2 * layerSize));
        INDArray bCandidate = biases.get(biasBlock(2 * layerSize, 3 * layerSize));

        // NOTE(review): anything returned by applyDropConnect is discarded here, and the "W"
        // blocks above were sliced before this call. Whether drop-connect influences this
        // forward pass therefore depends on applyDropConnect mutating the parameter in place
        // - confirm against Dropout.applyDropConnect.
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0d) {
            Dropout.applyDropConnect(this, "W");
        }

        INDArray outputActivations = Nd4j.zeros(new int[]{miniBatchSize, layerSize, timeSeriesLength});
        INDArray preActivations = Nd4j.zeros(new int[]{miniBatchSize, 3 * layerSize, timeSeriesLength});
        INDArray gateActivations = Nd4j.zeros(new int[]{miniBatchSize, 3 * layerSize, timeSeriesLength});

        if (prevOutputActivations == null) {
            prevOutputActivations = Nd4j.zeros(miniBatchSize, layerSize);
        }

        for (int t = 0; t < timeSeriesLength; t++) {
            INDArray stepInput = is2dInput ? this.input : timeSlice(this.input, t);
            if (t > 0) {
                prevOutputActivations = timeSlice(outputActivations, t - 1);
            }
            INDArray stepPreActs = Nd4j.zeros(miniBatchSize, 3 * layerSize);
            INDArray stepGateActs = Nd4j.zeros(miniBatchSize, 3 * layerSize);

            // Reset gate.
            INDArray resetPre = stepInput.mmul(wReset)
                    .addi(prevOutputActivations.mmul(rwReset))
                    .addiRowVector(bReset);
            INDArray reset = gruTransform("sigmoid", resetPre);
            stepPreActs.get(columnBlock(0, layerSize)).assign(resetPre);
            stepGateActs.get(columnBlock(0, layerSize)).assign(reset);

            // Update gate.
            INDArray updatePre = stepInput.mmul(wUpdate)
                    .addi(prevOutputActivations.mmul(rwUpdate))
                    .addiRowVector(bUpdate);
            INDArray update = gruTransform("sigmoid", updatePre);
            stepPreActs.get(columnBlock(layerSize, 2 * layerSize)).assign(updatePre);
            stepGateActs.get(columnBlock(layerSize, 2 * layerSize)).assign(update);

            // Candidate activation: the previous output is gated by the reset gate first.
            INDArray candidatePre = stepInput.mmul(wCandidate)
                    .addi(prevOutputActivations.mul(reset).mmul(rwCandidate))
                    .addiRowVector(bCandidate);
            INDArray candidate = gruTransform(this.conf.getLayer().getActivationFunction(), candidatePre);
            stepPreActs.get(columnBlock(2 * layerSize, 3 * layerSize)).assign(candidatePre);
            stepGateActs.get(columnBlock(2 * layerSize, 3 * layerSize)).assign(candidate);

            // Output: blend of previous output and candidate, weighted by the update gate.
            INDArray stepOutput = update.mul(prevOutputActivations).addi(update.rsub(1).mul(candidate));

            timeSlice(preActivations, t).assign(stepPreActs);
            timeSlice(gateActivations, t).assign(stepGateActs);
            timeSlice(outputActivations, t).assign(stepOutput);
        }
        return new INDArray[]{outputActivations, preActivations, gateActivations};
    }

    /** Returns the same value as {@link #activate()}. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        return activate();
    }

    /** Identifies this layer as {@link Layer.Type#RECURRENT}. */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * L2 penalty over the {@code "RW"} and {@code "W"} parameters (biases are not included).
     *
     * @return {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}, or {@code 0} when regularization is
     *         disabled or the L2 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getL2() <= 0.0d) {
            return 0.0d;
        }
        double recurrentSumSquares = Transforms.pow(getParam("RW"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        double inputSumSquares = Transforms.pow(getParam("W"), 2).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return 0.5d * this.conf.getL2() * (recurrentSumSquares + inputSumSquares);
    }

    /**
     * L1 penalty over the {@code "RW"} and {@code "W"} parameters (biases are not included).
     *
     * @return {@code l1 * (sum(|RW|) + sum(|W|))}, or {@code 0} when regularization is disabled
     *         or the L1 coefficient is not positive
     */
    @Override // org.deeplearning4j.nn.layers.BaseLayer, org.deeplearning4j.nn.api.Layer
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getL1() <= 0.0d) {
            return 0.0d;
        }
        double recurrentSumAbs = Transforms.abs(getParam("RW")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        double inputSumAbs = Transforms.abs(getParam("W")).sum(new int[]{Integer.MAX_VALUE}).getDouble(0);
        return this.conf.getL1() * (recurrentSumAbs + inputSumAbs);
    }

    /**
     * Runs a non-training forward pass on the given input, starting from the output activations
     * stored by the previous call (if any), and stores a copy of the last time step's output
     * for the next call.
     *
     * @param iNDArray the input for this call; it is stored on the layer before the forward
     *                 pass (previously the argument was ignored and whatever input had been set
     *                 earlier was used instead)
     * @return the output activations for every time step of the input
     */
    @Override // org.deeplearning4j.nn.layers.recurrent.BaseRecurrentLayer
    public INDArray rnnTimeStep(INDArray iNDArray) {
        setInput(iNDArray);
        INDArray outputActivations = activateHelper(false, this.stateMap.get(STATE_KEY_PREV_ACTIVATION))[0];
        int lastStep = outputActivations.size(2) - 1;
        // dup() so the stored state does not alias the array handed back to the caller.
        this.stateMap.put(STATE_KEY_PREV_ACTIVATION, timeSlice(outputActivations, lastStep).dup());
        return outputActivations;
    }

    /** Builds a fresh index selecting all rows and columns {@code [from, to)}. */
    private static INDArrayIndex[] columnBlock(int from, int to) {
        return new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(from, to)};
    }

    /** Builds a fresh index selecting row 0 and columns {@code [from, to)}. */
    private static INDArrayIndex[] biasBlock(int from, int to) {
        return new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)};
    }

    /** Returns the tensor along dimensions {@code (1, 0)} at index {@code t}. */
    private static INDArray timeSlice(INDArray array, int t) {
        return array.tensorAlongDimension(t, new int[]{1, 0});
    }

    /** Applies the named transform to a copy of {@code values}, leaving {@code values} untouched. */
    private static INDArray gruTransform(String transformName, INDArray values) {
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(transformName, values.dup()));
    }

    /** Applies the derivative of the named transform to a copy of {@code values}. */
    private static INDArray gruTransformDerivative(String transformName, INDArray values) {
        return Nd4j.getExecutioner().execAndReturn(
                Nd4j.getOpFactory().createTransform(transformName, values.dup()).derivative());
    }
}
