package org.deeplearning4j.nn.layers.recurrent;

import java.util.HashMap;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cuda;
import org.bytedeco.javacpp.cudnn;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper.class */
public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnLSTMHelper.class);
    /** Number of stacked RNN layers configured on the cuDNN RNN descriptor. */
    protected static final int NUM_LAYERS = 1;
    /** Dropout value handed to cudnnSetDropoutDescriptor. */
    protected static final float DROPOUT = 0.0f;
    /** Direction flag; presumably marks this helper as unidirectional-only — confirm against the upstream source. */
    protected static final boolean BIDIRECTIONAL = false;
    /** RNN mode passed to cudnnSetRNNDescriptor_v6; 2 presumably corresponds to CUDNN_LSTM in cudnnRNNMode_t — confirm. */
    protected static final int RNN_MODE = 2;
    /** Number of linear layers visited per RNN layer when weights/biases are copied to or from the cuDNN buffer. */
    protected static final int NUM_LINEAR_LAYERS = 8;
    /** cuDNN handle plus the descriptors shared by activate(...) and backpropGradient(...). */
    private CudnnLSTMContext cudnnContext = new CudnnLSTMContext();
    // Per-time-step tensor descriptor arrays; activate(...) re-creates them when the sequence length exceeds their capacity.
    private BaseCudnnHelper.TensorArray xDesc = new BaseCudnnHelper.TensorArray();
    private BaseCudnnHelper.TensorArray yDesc = new BaseCudnnHelper.TensorArray();
    private BaseCudnnHelper.TensorArray dxDesc = new BaseCudnnHelper.TensorArray();
    private BaseCudnnHelper.TensorArray dyDesc = new BaseCudnnHelper.TensorArray();
    /** Buffer passed to cudnnSetDropoutDescriptor as the dropout state storage. */
    private BaseCudnnHelper.DataCache stateSpace = new BaseCudnnHelper.DataCache();
    /** Reserve buffer written by cudnnRNNForwardTraining and read by the backward calls. */
    private BaseCudnnHelper.DataCache reserveSpace = new BaseCudnnHelper.DataCache();
    /** Packed cuDNN weight buffer; backpropGradient(...) also reuses it for the weight gradients. */
    private BaseCudnnHelper.DataCache weightsSpace = new BaseCudnnHelper.DataCache();
    /** Consulted in activate(...) before the dropout descriptor is (re)initialized. */
    private boolean initializedDropoutDescriptor = false;

    /* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper$CudnnLSTMContext.class */
    private static class CudnnLSTMContext extends BaseCudnnHelper.CudnnContext {
        private cudnn.cudnnTensorStruct hxDesc;
        private cudnn.cudnnTensorStruct cxDesc;
        private cudnn.cudnnTensorStruct hyDesc;
        private cudnn.cudnnTensorStruct cyDesc;
        private cudnn.cudnnTensorStruct dhxDesc;
        private cudnn.cudnnTensorStruct dcxDesc;
        private cudnn.cudnnTensorStruct dhyDesc;
        private cudnn.cudnnTensorStruct dcyDesc;
        private cudnn.cudnnFilterStruct wDesc;
        private cudnn.cudnnFilterStruct dwDesc;
        private cudnn.cudnnFilterStruct linLayerMatDesc;
        private cudnn.cudnnFilterStruct linLayerBiasDesc;
        private cudnn.cudnnRNNStruct rnnDesc;
        private cudnn.cudnnDropoutStruct dropoutDesc;
        private cudnn.cudnnActivationStruct activationDesc;

        /* loaded from: input_file:org/deeplearning4j/nn/layers/recurrent/CudnnLSTMHelper$CudnnLSTMContext$Deallocator.class */
        private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator {
            Deallocator(CudnnLSTMContext cudnnLSTMContext) {
                super(cudnnLSTMContext);
            }

            public void deallocate() {
                destroyHandles();
            }
        }

        public CudnnLSTMContext() {
            this.hxDesc = new cudnn.cudnnTensorStruct();
            this.cxDesc = new cudnn.cudnnTensorStruct();
            this.hyDesc = new cudnn.cudnnTensorStruct();
            this.cyDesc = new cudnn.cudnnTensorStruct();
            this.dhxDesc = new cudnn.cudnnTensorStruct();
            this.dcxDesc = new cudnn.cudnnTensorStruct();
            this.dhyDesc = new cudnn.cudnnTensorStruct();
            this.dcyDesc = new cudnn.cudnnTensorStruct();
            this.wDesc = new cudnn.cudnnFilterStruct();
            this.dwDesc = new cudnn.cudnnFilterStruct();
            this.linLayerMatDesc = new cudnn.cudnnFilterStruct();
            this.linLayerBiasDesc = new cudnn.cudnnFilterStruct();
            this.rnnDesc = new cudnn.cudnnRNNStruct();
            this.dropoutDesc = new cudnn.cudnnDropoutStruct();
            this.activationDesc = new cudnn.cudnnActivationStruct();
            createHandles();
            deallocator(new Deallocator(this));
        }

        public CudnnLSTMContext(CudnnLSTMContext cudnnLSTMContext) {
            super(cudnnLSTMContext);
            this.hxDesc = new cudnn.cudnnTensorStruct();
            this.cxDesc = new cudnn.cudnnTensorStruct();
            this.hyDesc = new cudnn.cudnnTensorStruct();
            this.cyDesc = new cudnn.cudnnTensorStruct();
            this.dhxDesc = new cudnn.cudnnTensorStruct();
            this.dcxDesc = new cudnn.cudnnTensorStruct();
            this.dhyDesc = new cudnn.cudnnTensorStruct();
            this.dcyDesc = new cudnn.cudnnTensorStruct();
            this.wDesc = new cudnn.cudnnFilterStruct();
            this.dwDesc = new cudnn.cudnnFilterStruct();
            this.linLayerMatDesc = new cudnn.cudnnFilterStruct();
            this.linLayerBiasDesc = new cudnn.cudnnFilterStruct();
            this.rnnDesc = new cudnn.cudnnRNNStruct();
            this.dropoutDesc = new cudnn.cudnnDropoutStruct();
            this.activationDesc = new cudnn.cudnnActivationStruct();
            this.hxDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.hxDesc);
            this.cxDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.cxDesc);
            this.hyDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.hyDesc);
            this.cyDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.cyDesc);
            this.dhxDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.dhxDesc);
            this.dcxDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.dcxDesc);
            this.dhyDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.dhyDesc);
            this.dcyDesc = new cudnn.cudnnTensorStruct(cudnnLSTMContext.dcyDesc);
            this.wDesc = new cudnn.cudnnFilterStruct(cudnnLSTMContext.wDesc);
            this.dwDesc = new cudnn.cudnnFilterStruct(cudnnLSTMContext.dwDesc);
            this.linLayerMatDesc = new cudnn.cudnnFilterStruct(cudnnLSTMContext.linLayerMatDesc);
            this.linLayerBiasDesc = new cudnn.cudnnFilterStruct(cudnnLSTMContext.linLayerBiasDesc);
            this.rnnDesc = new cudnn.cudnnRNNStruct(cudnnLSTMContext.rnnDesc);
            this.dropoutDesc = new cudnn.cudnnDropoutStruct(cudnnLSTMContext.dropoutDesc);
            this.activationDesc = new cudnn.cudnnActivationStruct(cudnnLSTMContext.activationDesc);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.deeplearning4j.nn.layers.BaseCudnnHelper.CudnnContext
        public void createHandles() {
            super.createHandles();
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.hxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.cxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.hyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.cyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dhxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dcxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dhyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dcyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(this.wDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(this.dwDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(this.linLayerMatDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(this.linLayerBiasDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateRNNDescriptor(this.rnnDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateDropoutDescriptor(this.dropoutDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnCreateActivationDescriptor(this.activationDesc));
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // org.deeplearning4j.nn.layers.BaseCudnnHelper.CudnnContext
        public void destroyHandles() {
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyActivationDescriptor(this.activationDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyDropoutDescriptor(this.dropoutDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyRNNDescriptor(this.rnnDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(this.wDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(this.dwDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(this.linLayerMatDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(this.linLayerBiasDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.hxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.cxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.hyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.cyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dhxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dcxDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dhyDesc));
            CudnnLSTMHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dcyDesc));
            super.destroyHandles();
        }
    }

    /**
     * Returns {@code iNDArray} itself when it is not a view, has 'c' ordering and passes
     * {@link Shape#strideDescendingCAscendingF}; otherwise returns a 'c'-ordered duplicate.
     *
     * @param iNDArray array to normalise
     * @return the same array, or a {@code dup('c')} copy of it
     */
    private static INDArray toCOrder(INDArray iNDArray) {
        final boolean usableAsIs = !iNDArray.isView()
                && iNDArray.ordering() == 'c'
                && Shape.strideDescendingCAscendingF(iNDArray);
        return usableAsIs ? iNDArray : iNDArray.dup('c');
    }

    /**
     * Reports whether this helper can be used for the given LSTM configuration. Every unmet
     * requirement is logged as a warning, so several warnings may be emitted by one call.
     *
     * @param iActivation  gate activation function; must be an {@link ActivationSigmoid}
     * @param iActivation2 layer activation function; must be an {@link ActivationTanH}
     * @param z            {@code true} when the layer has peephole connections (not supported)
     * @return {@code true} only when the no-arg {@code checkSupported()} passes and all three
     *         requirements above are met
     */
    public boolean checkSupported(IActivation iActivation, IActivation iActivation2, boolean z) {
        final boolean baseSupported = checkSupported();
        final boolean sigmoidGates = iActivation instanceof ActivationSigmoid;
        final boolean tanhLayer = iActivation2 instanceof ActivationTanH;

        if (!sigmoidGates) {
            log.warn("Not supported: Gate activation functions != ActivationSigmoid");
        }
        if (!tanhLayer) {
            log.warn("Not supported: Layer activation functions != ActivationTanH");
        }
        if (z) {
            log.warn("Not supported: LSTM layers with peephole connections");
        }
        return baseSupported && sigmoidGates && tanhLayer && !z;
    }

    /**
     * Runs the cuDNN backward-data and backward-weights passes and copies the resulting weight and
     * bias gradients out of the packed cuDNN buffer into the gradient views supplied by the caller.
     *
     * <p>Relies on state prepared by {@code activate(...)}: the descriptors, {@code weightsSpace},
     * {@code reserveSpace} and the helper workspace registered on {@code layerWorkspaceMgr}.
     *
     * @param neuralNetConfiguration not used by this implementation
     * @param iActivation            not used by this implementation
     * @param iNDArray               layer input; permuted with (2, 0, 1) before being handed to cuDNN,
     *                               so presumably [miniBatch, inputSize, timeSteps] — confirm against caller
     * @param iNDArray2              recurrent weights; {@code size(0)} is taken as the hidden size
     * @param iNDArray3              input weights; {@code size(0)} is taken as the feature size of the returned epsilon
     * @param iNDArray4              incoming epsilon; permuted with (2, 0, 1); rank &lt; 3 is treated as one time step
     * @param z                      whether truncated BPTT is applied
     * @param i                      truncated BPTT backward length (only read when {@code z} is true)
     * @param fwdPassReturn          forward-pass result ({@code fwdPassOutput}, {@code prevAct}, {@code prevMemCell} are read)
     * @param z2                     not used by this implementation
     * @param str                    key of the input-weight gradient view in {@code map}
     * @param str2                   key of the recurrent-weight gradient view in {@code map}
     * @param str3                   key of the bias gradient view in {@code map}
     * @param map                    gradient views that receive the results
     * @param iNDArray5              not used by this implementation
     * @param z3                     not used by this implementation
     * @param layerWorkspaceMgr      workspace manager; must hold the cuDNN helper workspace set by {@code activate(...)}
     * @return the gradients keyed by {@code str}/{@code str2}/{@code str3} and the outgoing epsilon,
     *         permuted back with (1, 2, 0)
     * @throws IllegalStateException if no cuDNN helper workspace is registered on {@code layerWorkspaceMgr}
     */
    public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration neuralNetConfiguration, IActivation iActivation, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z, int i, FwdPassReturn fwdPassReturn, boolean z2, String str, String str2, String str3, Map<String, INDArray> map, INDArray iNDArray5, boolean z3, LayerWorkspaceMgr layerWorkspaceMgr) {
        // Fail early with a clear message instead of a bare NullPointerException further down.
        final BaseCudnnHelper.DataCache workSpace = (BaseCudnnHelper.DataCache) layerWorkspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workSpace == null) {
            throw new IllegalStateException("No cuDNN helper workspace registered on the LayerWorkspaceMgr; activate(...) must run with the same workspace manager before backpropGradient(...)");
        }

        final long hiddenLayerSize = iNDArray2.size(0);
        final long prevLayerSize = iNDArray3.size(0);
        final long inputLayerSize = iNDArray.size(1);
        final long miniBatchSize = iNDArray4.size(0);
        long timeSeriesLength = iNDArray4.rank() < 3 ? 1L : iNDArray4.size(2);

        // cuDNN consumes time-major, contiguous 'c'-ordered buffers.
        final INDArray x = toCOrder(iNDArray.permute(new int[]{2, 0, 1}));
        final INDArray dy = toCOrder(iNDArray4.permute(new int[]{2, 0, 1}));
        final INDArray dx = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, new long[]{timeSeriesLength, miniBatchSize, prevLayerSize}, 'c');
        final INDArray inputWeightGrad = map.get(str);
        final INDArray recurrentWeightGrad = map.get(str2);
        final INDArray biasGrad = map.get(str3);
        final INDArray outputActivations = toCOrder(fwdPassReturn.fwdPassOutput.permute(new int[]{2, 0, 1}));
        final INDArray prevMemCell = toCOrder(fwdPassReturn.prevMemCell);
        final INDArray prevAct = toCOrder(fwdPassReturn.prevAct);

        Nd4j.getExecutioner().commit();

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final INDArray[] touchedArrays = {x, dy, dx, outputActivations, prevMemCell, prevAct, inputWeightGrad, recurrentWeightGrad, biasGrad};
        final CudaContext cudaContext = allocator.getFlowController().prepareActionAllWrite(touchedArrays);
        final Pointer xData = allocator.getPointer(x, cudaContext);
        final Pointer dyData = allocator.getPointer(dy, cudaContext);
        final Pointer dxData = allocator.getPointer(dx, cudaContext);
        final Pointer outputActivationsData = allocator.getPointer(outputActivations, cudaContext);
        final Pointer prevMemCellData = allocator.getPointer(prevMemCell, cudaContext);
        final Pointer prevActData = allocator.getPointer(prevAct, cudaContext);
        final Pointer inputWeightGradData = allocator.getPointer(inputWeightGrad, cudaContext);
        final Pointer recurrentWeightGradData = allocator.getPointer(recurrentWeightGrad, cudaContext);
        final Pointer biasGradData = allocator.getPointer(biasGrad, cudaContext);

        final cuda.CUstream_st stream = new cuda.CUstream_st(cudaContext.getOldStream());
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, stream));

        if (z) {
            // Skip the leading time steps so that only the last `i` steps are back-propagated.
            final long skippedRows = Math.max(0L, timeSeriesLength - i) * miniBatchSize;
            // x holds inputLayerSize values per row, while dy and the activations hold hiddenLayerSize;
            // the previous version used hiddenLayerSize for all three, which mis-positions x whenever
            // the input and hidden sizes differ.
            xData.position(skippedRows * inputLayerSize * this.dataTypeSize);
            dyData.position(skippedRows * hiddenLayerSize * this.dataTypeSize);
            outputActivationsData.position(skippedRows * hiddenLayerSize * this.dataTypeSize);
            timeSeriesLength = Math.min(timeSeriesLength, (long) i);
        }

        final cudnn.cudnnTensorStruct firstStepDesc = this.xDesc.get(cudnn.cudnnTensorStruct.class, 0L);

        checkCudnn(cudnn.cudnnRNNBackwardData(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength,
                this.yDesc, outputActivationsData, this.dyDesc, dyData,
                this.cudnnContext.dhyDesc, (Pointer) null, this.cudnnContext.dcyDesc, (Pointer) null,
                this.cudnnContext.wDesc, this.weightsSpace,
                this.cudnnContext.hxDesc, prevActData, this.cudnnContext.cxDesc, prevMemCellData,
                this.dxDesc, dxData,
                this.cudnnContext.dhxDesc, (Pointer) null, this.cudnnContext.dcxDesc, (Pointer) null,
                workSpace, workSpace.limit(), this.reserveSpace, this.reserveSpace.limit()));

        // The weight buffer is reused as the weight-gradient buffer, so it is cleared before cuDNN writes into it.
        checkCuda(cuda.cudaMemsetAsync(this.weightsSpace, 0, this.weightsSpace.limit(), stream));

        checkCudnn(cudnn.cudnnRNNBackwardWeights(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength,
                this.xDesc, xData, this.cudnnContext.hxDesc, prevActData, this.yDesc, outputActivationsData,
                workSpace, workSpace.limit(), this.cudnnContext.dwDesc, this.weightsSpace,
                this.reserveSpace, this.reserveSpace.limit()));

        // Output arguments required by cudnnGetFilterNdDescriptor; their values are not used afterwards.
        final int[] dataTypeOut = new int[1];
        final int[] formatOut = new int[1];
        final int[] nbDimsOut = new int[1];
        final int[] filterDimsOut = new int[3];
        final Pointer linLayerMat = new Pointer();
        final Pointer linLayerBias = new Pointer();

        // cuDNN linear layer IDs 0-3 address the input weights, 4-7 the recurrent weights. Within each
        // group, ID k is stored at block index gateSlot[k % 4] of the corresponding gradient array.
        final int[] gateSlot = {3, 1, 0, 2};
        final int deviceToDevice = 3; // cudaMemcpyKind value for device-to-device copies

        for (int layer = 0; layer < NUM_LAYERS; layer++) {
            for (int linLayer = 0; linLayer < NUM_LINEAR_LAYERS; linLayer++) {
                if (linLayer >= 2 * gateSlot.length) {
                    throw new IllegalStateException("Unexpected cuDNN linear layer ID: " + linLayer);
                }
                checkCudnn(cudnn.cudnnGetRNNLinLayerMatrixParams(this.cudnnContext, this.cudnnContext.rnnDesc, layer, firstStepDesc, this.cudnnContext.wDesc, this.weightsSpace, linLayer, this.cudnnContext.linLayerMatDesc, linLayerMat));
                checkCudnn(cudnn.cudnnGetFilterNdDescriptor(this.cudnnContext.linLayerMatDesc, 3, dataTypeOut, formatOut, nbDimsOut, filterDimsOut));
                checkCudnn(cudnn.cudnnGetRNNLinLayerBiasParams(this.cudnnContext, this.cudnnContext.rnnDesc, layer, firstStepDesc, this.cudnnContext.wDesc, this.weightsSpace, linLayer, this.cudnnContext.linLayerBiasDesc, linLayerBias));
                checkCudnn(cudnn.cudnnGetFilterNdDescriptor(this.cudnnContext.linLayerBiasDesc, 3, dataTypeOut, formatOut, nbDimsOut, filterDimsOut));

                final boolean inputSide = linLayer < gateSlot.length;
                final Pointer target = inputSide ? inputWeightGradData : recurrentWeightGradData;
                final long rows = inputSide ? inputLayerSize : hiddenLayerSize;
                final int slot = gateSlot[linLayer % gateSlot.length];

                checkCuda(cuda.cudaMemcpyAsync(target.position(slot * rows * hiddenLayerSize * this.dataTypeSize), linLayerMat, rows * hiddenLayerSize * this.dataTypeSize, deviceToDevice, stream));
                if (inputSide) {
                    // Only the biases of the first four linear layers are copied out.
                    checkCuda(cuda.cudaMemcpyAsync(biasGradData.position(slot * hiddenLayerSize * this.dataTypeSize), linLayerBias, hiddenLayerSize * this.dataTypeSize, deviceToDevice, stream));
                }
            }
        }

        allocator.getFlowController().registerActionAllWrite(cudaContext, touchedArrays);

        final DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put(str, inputWeightGrad);
        gradient.gradientForVariable().put(str2, recurrentWeightGrad);
        gradient.gradientForVariable().put(str3, biasGrad);
        return new Pair<>(gradient, dx.permute(new int[]{1, 2, 0}));
    }

    /**
     * Runs the cuDNN LSTM forward pass (training or inference variant).
     *
     * <p>Before the cuDNN call this method sizes the per-time-step descriptors, the dropout state, the
     * packed weight buffer, the helper workspace and the training reserve buffer, and copies the weights
     * and biases into the packed buffer.
     *
     * @param layer                  not used by this implementation
     * @param neuralNetConfiguration not used by this implementation
     * @param iActivation            not used by this implementation
     * @param iNDArray               layer input; permuted with (2, 0, 1) before being handed to cuDNN,
     *                               so presumably [miniBatch, inputSize, timeSteps] — confirm against caller;
     *                               rank &lt; 3 is treated as one time step
     * @param iNDArray2              recurrent weights; {@code size(0)} is taken as the hidden size
     * @param iNDArray3              input weights
     * @param iNDArray4              biases
     * @param z                      {@code true} to run cudnnRNNForwardTraining, {@code false} for cudnnRNNForwardInference
     * @param iNDArray5              previous output activations, passed to cuDNN as the initial hidden state
     * @param iNDArray6              previous memory cell state, passed to cuDNN as the initial cell state
     * @param z2                     not used by this implementation
     * @param z3                     not used by this implementation
     * @param str                    not used by this implementation
     * @param iNDArray7              not used by this implementation
     * @param z4                     not used by this implementation
     * @param layerWorkspaceMgr      workspace manager providing the output array and the cuDNN helper workspace
     * @return forward-pass result holding the output (permuted back with (1, 2, 0)), the final
     *         activations and cell state, and the 'c'-ordered previous activations and cell state
     */
    public FwdPassReturn activate(Layer layer, NeuralNetConfiguration neuralNetConfiguration, IActivation iActivation, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z, INDArray iNDArray5, INDArray iNDArray6, boolean z2, boolean z3, String str, INDArray iNDArray7, boolean z4, LayerWorkspaceMgr layerWorkspaceMgr) {
        final long timeSeriesLength = iNDArray.rank() < 3 ? 1L : iNDArray.size(2);
        final long hiddenLayerSize = iNDArray2.size(0);
        final long miniBatchSize = iNDArray.size(0);
        final long inputLayerSize = iNDArray.size(1);

        // cuDNN consumes time-major, contiguous 'c'-ordered buffers.
        final INDArray x = toCOrder(iNDArray.permute(new int[]{2, 0, 1}));
        final INDArray prevAct = toCOrder(iNDArray5);
        final INDArray prevMemCell = toCOrder(iNDArray6);
        final INDArray outputActivations = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new long[]{timeSeriesLength, miniBatchSize, hiddenLayerSize}, 'c');
        final INDArray finalMemCellState = Nd4j.createUninitialized(new long[]{miniBatchSize, hiddenLayerSize}, 'c');
        final INDArray finalStepActivations = Nd4j.createUninitialized(new long[]{miniBatchSize, hiddenLayerSize}, 'c');

        Nd4j.getExecutioner().commit();

        // One descriptor per time step for x/y and their gradients.
        this.xDesc = grownTensorArray(this.xDesc, timeSeriesLength);
        this.yDesc = grownTensorArray(this.yDesc, timeSeriesLength);
        this.dxDesc = grownTensorArray(this.dxDesc, timeSeriesLength);
        this.dyDesc = grownTensorArray(this.dyDesc, timeSeriesLength);

        final int[] xDims = {(int) miniBatchSize, (int) inputLayerSize, 1};
        final int[] xStrides = {xDims[2] * xDims[1], xDims[2], 1};
        final int[] yDims = {(int) miniBatchSize, (int) hiddenLayerSize, 1};
        final int[] yStrides = {yDims[2] * yDims[1], yDims[2], 1};
        for (int t = 0; t < timeSeriesLength; t++) {
            checkCudnn(cudnn.cudnnSetTensorNdDescriptor(this.xDesc.get(cudnn.cudnnTensorStruct.class, t), this.dataType, 3, xDims, xStrides));
            checkCudnn(cudnn.cudnnSetTensorNdDescriptor(this.dxDesc.get(cudnn.cudnnTensorStruct.class, t), this.dataType, 3, xDims, xStrides));
            checkCudnn(cudnn.cudnnSetTensorNdDescriptor(this.yDesc.get(cudnn.cudnnTensorStruct.class, t), this.dataType, 3, yDims, yStrides));
            checkCudnn(cudnn.cudnnSetTensorNdDescriptor(this.dyDesc.get(cudnn.cudnnTensorStruct.class, t), this.dataType, 3, yDims, yStrides));
        }

        // Hidden/cell state descriptors all share the shape [NUM_LAYERS, miniBatch, hiddenSize].
        final int[] stateDims = {NUM_LAYERS, (int) miniBatchSize, (int) hiddenLayerSize};
        final int[] stateStrides = {stateDims[2] * stateDims[1], stateDims[2], 1};
        final cudnn.cudnnTensorStruct[] stateDescs = {
                this.cudnnContext.hxDesc, this.cudnnContext.cxDesc, this.cudnnContext.hyDesc, this.cudnnContext.cyDesc,
                this.cudnnContext.dhxDesc, this.cudnnContext.dcxDesc, this.cudnnContext.dhyDesc, this.cudnnContext.dcyDesc};
        for (cudnn.cudnnTensorStruct stateDesc : stateDescs) {
            checkCudnn(cudnn.cudnnSetTensorNdDescriptor(stateDesc, this.dataType, 3, stateDims, stateStrides));
        }

        // Dropout state: the descriptor is initialized once, and again only if its state buffer was
        // replaced. The previous version never set the flag and re-initialized on every call.
        checkCudnn(cudnn.cudnnDropoutGetStatesSize(this.cudnnContext, this.sizeInBytes));
        final long stateSize = this.sizeInBytes.get(0L);
        final BaseCudnnHelper.DataCache previousStateSpace = this.stateSpace;
        this.stateSpace = sizedDataCache(this.stateSpace, stateSize);
        if (!this.initializedDropoutDescriptor || this.stateSpace != previousStateSpace) {
            checkCudnn(cudnn.cudnnSetDropoutDescriptor(this.cudnnContext.dropoutDesc, this.cudnnContext, DROPOUT, this.stateSpace, stateSize, Nd4j.getRandom().getSeed()));
            this.initializedDropoutDescriptor = true;
        }

        checkCudnn(cudnn.cudnnSetRNNDescriptor_v6(this.cudnnContext, this.cudnnContext.rnnDesc, (int) hiddenLayerSize, NUM_LAYERS, this.cudnnContext.dropoutDesc,
                0, // input mode; 0 is CUDNN_LINEAR_INPUT in the cuDNN enum
                0, // direction; 0 is CUDNN_UNIDIRECTIONAL in the cuDNN enum
                RNN_MODE,
                0, // algorithm; 0 is CUDNN_RNN_ALGO_STANDARD in the cuDNN enum
                this.dataType));

        // Packed weight buffer, described to cuDNN as a [numElements, 1, 1] filter.
        final cudnn.cudnnTensorStruct firstStepDesc = this.xDesc.get(cudnn.cudnnTensorStruct.class, 0L);
        checkCudnn(cudnn.cudnnGetRNNParamsSize(this.cudnnContext, this.cudnnContext.rnnDesc, firstStepDesc, this.sizeInBytes, this.dataType));
        final long weightsSize = this.sizeInBytes.get(0L);
        this.weightsSpace = sizedDataCache(this.weightsSpace, weightsSize);
        // Divide before narrowing so that byte sizes above Integer.MAX_VALUE do not wrap first.
        final int[] weightDims = {(int) (weightsSize / this.dataTypeSize), 1, 1};
        checkCudnn(cudnn.cudnnSetFilterNdDescriptor(this.cudnnContext.wDesc, this.dataType, 0 /* tensor format; 0 is CUDNN_TENSOR_NCHW */, 3, weightDims));
        checkCudnn(cudnn.cudnnSetFilterNdDescriptor(this.cudnnContext.dwDesc, this.dataType, 0 /* tensor format; 0 is CUDNN_TENSOR_NCHW */, 3, weightDims));

        // Helper workspace lives on the workspace manager so that it can be shared and reused.
        checkCudnn(cudnn.cudnnGetRNNWorkspaceSize(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength, this.xDesc, this.sizeInBytes));
        final long workSpaceSize = this.sizeInBytes.get(0L);
        BaseCudnnHelper.DataCache workSpace = (BaseCudnnHelper.DataCache) layerWorkspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workSpace == null || workSpaceSize > workSpace.capacity()) {
            if (log.isTraceEnabled()) {
                if (workSpace == null) {
                    log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", Long.valueOf(workSpaceSize), StringUtils.TraditionalBinaryPrefix.long2String(workSpaceSize, "B", 2));
                } else {
                    log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", new Object[]{Long.valueOf(workSpace.capacity()), StringUtils.TraditionalBinaryPrefix.long2String(workSpace.capacity(), "B", 2), Long.valueOf(workSpaceSize), StringUtils.TraditionalBinaryPrefix.long2String(workSpaceSize, "B", 2)});
                }
            }
            if (workSpace != null) {
                workSpace.deallocate();
            }
            workSpace = new BaseCudnnHelper.DataCache(workSpaceSize);
            layerWorkspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace);
        }
        workSpace.limit(workSpaceSize);

        checkCudnn(cudnn.cudnnGetRNNTrainingReserveSize(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength, this.xDesc, this.sizeInBytes));
        this.reserveSpace = sizedDataCache(this.reserveSpace, this.sizeInBytes.get(0L));

        final AtomicAllocator allocator = AtomicAllocator.getInstance();
        final INDArray[] touchedArrays = {x, iNDArray3, iNDArray2, iNDArray4, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations};
        final CudaContext cudaContext = allocator.getFlowController().prepareActionAllWrite(touchedArrays);
        final Pointer xData = allocator.getPointer(x, cudaContext);
        final Pointer inputWeightsData = allocator.getPointer(iNDArray3, cudaContext);
        final Pointer recurrentWeightsData = allocator.getPointer(iNDArray2, cudaContext);
        final Pointer biasesData = allocator.getPointer(iNDArray4, cudaContext);
        final Pointer prevActData = allocator.getPointer(prevAct, cudaContext);
        final Pointer prevMemCellData = allocator.getPointer(prevMemCell, cudaContext);
        final Pointer outputActivationsData = allocator.getPointer(outputActivations, cudaContext);
        final Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, cudaContext);
        final Pointer finalStepActivationsData = allocator.getPointer(finalStepActivations, cudaContext);

        final cuda.CUstream_st stream = new cuda.CUstream_st(cudaContext.getOldStream());
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, stream));

        // Clear the packed buffer; biases of linear layers 4-7 are never copied in and therefore stay zero.
        checkCuda(cuda.cudaMemsetAsync(this.weightsSpace, 0, this.weightsSpace.limit(), stream));

        // Output arguments required by cudnnGetFilterNdDescriptor; their values are not used afterwards.
        final int[] dataTypeOut = new int[1];
        final int[] formatOut = new int[1];
        final int[] nbDimsOut = new int[1];
        final int[] filterDimsOut = new int[3];
        final Pointer linLayerMat = new Pointer();
        final Pointer linLayerBias = new Pointer();

        // cuDNN linear layer IDs 0-3 address the input weights, 4-7 the recurrent weights. Within each
        // group, ID k is read from block index gateSlot[k % 4] of the corresponding weight array.
        final int[] gateSlot = {3, 1, 0, 2};
        final int deviceToDevice = 3; // cudaMemcpyKind value for device-to-device copies

        for (int layer = 0; layer < NUM_LAYERS; layer++) {
            for (int linLayer = 0; linLayer < NUM_LINEAR_LAYERS; linLayer++) {
                if (linLayer >= 2 * gateSlot.length) {
                    throw new IllegalStateException("Unexpected cuDNN linear layer ID: " + linLayer);
                }
                checkCudnn(cudnn.cudnnGetRNNLinLayerMatrixParams(this.cudnnContext, this.cudnnContext.rnnDesc, layer, firstStepDesc, this.cudnnContext.wDesc, this.weightsSpace, linLayer, this.cudnnContext.linLayerMatDesc, linLayerMat));
                checkCudnn(cudnn.cudnnGetFilterNdDescriptor(this.cudnnContext.linLayerMatDesc, 3, dataTypeOut, formatOut, nbDimsOut, filterDimsOut));
                checkCudnn(cudnn.cudnnGetRNNLinLayerBiasParams(this.cudnnContext, this.cudnnContext.rnnDesc, layer, firstStepDesc, this.cudnnContext.wDesc, this.weightsSpace, linLayer, this.cudnnContext.linLayerBiasDesc, linLayerBias));
                checkCudnn(cudnn.cudnnGetFilterNdDescriptor(this.cudnnContext.linLayerBiasDesc, 3, dataTypeOut, formatOut, nbDimsOut, filterDimsOut));

                final boolean inputSide = linLayer < gateSlot.length;
                final Pointer source = inputSide ? inputWeightsData : recurrentWeightsData;
                final long rows = inputSide ? inputLayerSize : hiddenLayerSize;
                final int slot = gateSlot[linLayer % gateSlot.length];

                checkCuda(cuda.cudaMemcpyAsync(linLayerMat, source.position(slot * rows * hiddenLayerSize * this.dataTypeSize), rows * hiddenLayerSize * this.dataTypeSize, deviceToDevice, stream));
                if (inputSide) {
                    checkCuda(cuda.cudaMemcpyAsync(linLayerBias, biasesData.position(slot * hiddenLayerSize * this.dataTypeSize), hiddenLayerSize * this.dataTypeSize, deviceToDevice, stream));
                }
            }
        }

        if (z) {
            checkCudnn(cudnn.cudnnRNNForwardTraining(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength,
                    this.xDesc, xData, this.cudnnContext.hxDesc, prevActData, this.cudnnContext.cxDesc, prevMemCellData,
                    this.cudnnContext.wDesc, this.weightsSpace, this.yDesc, outputActivationsData,
                    this.cudnnContext.hyDesc, finalStepActivationsData, this.cudnnContext.cyDesc, finalMemCellStateData,
                    workSpace, workSpace.limit(), this.reserveSpace, this.reserveSpace.limit()));
        } else {
            checkCudnn(cudnn.cudnnRNNForwardInference(this.cudnnContext, this.cudnnContext.rnnDesc, (int) timeSeriesLength,
                    this.xDesc, xData, this.cudnnContext.hxDesc, prevActData, this.cudnnContext.cxDesc, prevMemCellData,
                    this.cudnnContext.wDesc, this.weightsSpace, this.yDesc, outputActivationsData,
                    this.cudnnContext.hyDesc, finalStepActivationsData, this.cudnnContext.cyDesc, finalMemCellStateData,
                    workSpace, workSpace.limit()));
        }

        allocator.getFlowController().registerActionAllWrite(cudaContext, touchedArrays);

        final FwdPassReturn result = new FwdPassReturn();
        result.fwdPassOutput = outputActivations.permute(new int[]{1, 2, 0});
        result.lastAct = finalStepActivations;
        result.lastMemCell = finalMemCellState;
        result.prevAct = prevAct;
        result.prevMemCell = prevMemCell;
        return result;
    }

    /** Returns {@code array} if it can hold {@code required} descriptors; otherwise deallocates it and returns a new one of that size. */
    private static BaseCudnnHelper.TensorArray grownTensorArray(BaseCudnnHelper.TensorArray array, long required) {
        if (required <= array.capacity()) {
            return array;
        }
        array.deallocate();
        return new BaseCudnnHelper.TensorArray(required);
    }

    /** Returns a cache with capacity for {@code requiredBytes} (replacing {@code cache} if it is too small) and its limit set to {@code requiredBytes}. */
    private static BaseCudnnHelper.DataCache sizedDataCache(BaseCudnnHelper.DataCache cache, long requiredBytes) {
        BaseCudnnHelper.DataCache sized = cache;
        if (requiredBytes > sized.capacity()) {
            sized.deallocate();
            sized = new BaseCudnnHelper.DataCache(requiredBytes);
        }
        sized.limit(requiredBytes);
        return sized;
    }

    /**
     * Reports the capacity of the buffers owned by this helper.
     *
     * @return a new map from buffer name to the value of that buffer's {@code capacity()}
     */
    public Map<String, Long> helperMemoryUse() {
        // Typed map instead of the raw HashMap, so the method body compiles without unchecked warnings.
        final Map<String, Long> memoryUse = new HashMap<>();
        // NOTE(review): "stateStace" looks like a typo for "stateSpace"; the key is kept unchanged because
        // callers may look it up by this exact name.
        memoryUse.put("stateStace", this.stateSpace.capacity());
        memoryUse.put("reserveSpace", this.reserveSpace.capacity());
        memoryUse.put("weightsSpace", this.weightsSpace.capacity());
        return memoryUse;
    }
}
