package org.deeplearning4j.cuda.convolution;

import com.jakewharton.byteunits.BinaryByteUnit;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnActivationStruct;
import org.bytedeco.cuda.cudnn.cudnnConvolutionStruct;
import org.bytedeco.cuda.cudnn.cudnnFilterStruct;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.OneTimeLogger;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.class */
public class CudnnConvolutionHelper extends BaseCudnnHelper implements ConvolutionHelper {
    /** SLF4J logger for this helper; used for trace-level algorithm/workspace diagnostics. */
    private static final Logger log = LoggerFactory.getLogger(CudnnConvolutionHelper.class);
    /** cuDNN handle plus the tensor/filter/convolution/activation descriptors; created once in the constructor. */
    private CudnnConvolutionContext cudnnContext;

    /* renamed from: org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper$1, reason: invalid class name */
    /* loaded from: input_file:org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper$1.class */
    /**
     * Compiler-generated lookup tables backing the {@code switch} statements over the
     * {@link ConvolutionLayer.FwdAlgo}, {@link ConvolutionLayer.BwdDataAlgo} and
     * {@link ConvolutionLayer.BwdFilterAlgo} enums in this file.
     *
     * <p>Each table is sized to {@code values().length} of its enum and maps
     * {@code constant.ordinal()} to a 1-based case label. Every assignment is wrapped in its own
     * {@code try/catch (NoSuchFieldError)}: if a constant cannot be resolved at run time its slot
     * keeps the array default of 0, which matches no case label and therefore falls through to
     * the {@code default} branch of the consuming switch.
     *
     * <p>The label numbers written here must stay in sync with the {@code case} labels of the
     * switches that read these tables.
     */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo;
        static final /* synthetic */ int[] $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo;
        static final /* synthetic */ int[] $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo = new int[ConvolutionLayer.FwdAlgo.values().length];

        static {
            // --- FwdAlgo: labels 1..9 ---
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.IMPLICIT_GEMM.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.IMPLICIT_PRECOMP_GEMM.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.GEMM.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.DIRECT.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.FFT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.FFT_TILING.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.WINOGRAD.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.WINOGRAD_NONFUSED.ordinal()] = 8;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[ConvolutionLayer.FwdAlgo.COUNT.ordinal()] = 9;
            } catch (NoSuchFieldError e9) {
            }
            // --- BwdDataAlgo: labels 1..7 ---
            $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo = new int[ConvolutionLayer.BwdDataAlgo.values().length];
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.ALGO_0.ordinal()] = 1;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.ALGO_1.ordinal()] = 2;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.FFT.ordinal()] = 3;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.FFT_TILING.ordinal()] = 4;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.WINOGRAD.ordinal()] = 5;
            } catch (NoSuchFieldError e14) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.WINOGRAD_NONFUSED.ordinal()] = 6;
            } catch (NoSuchFieldError e15) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[ConvolutionLayer.BwdDataAlgo.COUNT.ordinal()] = 7;
            } catch (NoSuchFieldError e16) {
            }
            // --- BwdFilterAlgo: labels 1..8 ---
            $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo = new int[ConvolutionLayer.BwdFilterAlgo.values().length];
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.ALGO_0.ordinal()] = 1;
            } catch (NoSuchFieldError e17) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.ALGO_1.ordinal()] = 2;
            } catch (NoSuchFieldError e18) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.FFT.ordinal()] = 3;
            } catch (NoSuchFieldError e19) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.ALGO_3.ordinal()] = 4;
            } catch (NoSuchFieldError e20) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.WINOGRAD.ordinal()] = 5;
            } catch (NoSuchFieldError e21) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.WINOGRAD_NONFUSED.ordinal()] = 6;
            } catch (NoSuchFieldError e22) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.FFT_TILING.ordinal()] = 7;
            } catch (NoSuchFieldError e23) {
            }
            try {
                $SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[ConvolutionLayer.BwdFilterAlgo.COUNT.ordinal()] = 8;
            } catch (NoSuchFieldError e24) {
            }
        }
    }

    /* loaded from: input_file:org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper$CudnnConvolutionContext.class */
    /**
     * cuDNN context for convolution: the base cuDNN handle plus the tensor, filter, convolution
     * and activation descriptors used by the forward and backward passes.
     *
     * <p>The descriptor wrappers are created by field initialisers, which run in every constructor
     * right after the superclass constructor returns. The no-arg constructor then creates the
     * native descriptors and registers a {@link Deallocator}; the copy constructor instead
     * replaces each wrapper with one constructed from the corresponding field of the source.
     */
    private static class CudnnConvolutionContext extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct dstTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct biasTensorDesc = new cudnnTensorStruct();
        private cudnnTensorStruct deltaTensorDesc = new cudnnTensorStruct();
        private cudnnFilterStruct filterDesc = new cudnnFilterStruct();
        private cudnnConvolutionStruct convDesc = new cudnnConvolutionStruct();
        private cudnnActivationStruct activationDesc = new cudnnActivationStruct();

        /**
         * Deallocation hook: a copy of the owning context whose only job is to destroy the
         * handles when {@link #deallocate()} is invoked.
         */
        private static class Deallocator extends CudnnConvolutionContext implements Pointer.Deallocator {
            Deallocator(CudnnConvolutionContext owner) {
                super(owner);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        /** Creates all handles/descriptors and registers a deallocator for them. */
        public CudnnConvolutionContext() {
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor: builds this context from {@code cudnnConvolutionContext}, wrapping each
         * of its descriptors in a new struct object. No new descriptors are created here.
         *
         * @param cudnnConvolutionContext the context to copy from
         */
        public CudnnConvolutionContext(CudnnConvolutionContext cudnnConvolutionContext) {
            super(cudnnConvolutionContext);
            srcTensorDesc = new cudnnTensorStruct(cudnnConvolutionContext.srcTensorDesc);
            dstTensorDesc = new cudnnTensorStruct(cudnnConvolutionContext.dstTensorDesc);
            biasTensorDesc = new cudnnTensorStruct(cudnnConvolutionContext.biasTensorDesc);
            deltaTensorDesc = new cudnnTensorStruct(cudnnConvolutionContext.deltaTensorDesc);
            filterDesc = new cudnnFilterStruct(cudnnConvolutionContext.filterDesc);
            convDesc = new cudnnConvolutionStruct(cudnnConvolutionContext.convDesc);
            activationDesc = new cudnnActivationStruct(cudnnConvolutionContext.activationDesc);
        }

        /**
         * Creates the base handles first, then the four tensor descriptors, the filter descriptor,
         * the convolution descriptor and the activation descriptor; each status code is checked.
         */
        @Override
        public void createHandles() {
            super.createHandles();
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(srcTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(dstTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(biasTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(deltaTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateFilterDescriptor(filterDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateConvolutionDescriptor(convDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnCreateActivationDescriptor(activationDesc));
        }

        /**
         * Destroys the descriptors owned by this class (activation, convolution, filter, then the
         * tensor descriptors) and finally delegates to the superclass for the base handles.
         */
        @Override
        public void destroyHandles() {
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyActivationDescriptor(activationDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyConvolutionDescriptor(convDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyFilterDescriptor(filterDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(srcTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(dstTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(biasTensorDesc));
            CudnnConvolutionHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(deltaTensorDesc));
            super.destroyHandles();
        }
    }

    /* loaded from: input_file:org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper$CudnnForwardArgs.class */
    /**
     * Mutable value holder for a prepared cuDNN call: the (possibly replaced) input array, the
     * original input, the padding and output size to use, and flags recording whether an extra
     * row (bottom) or column (right) of padding was applied manually.
     *
     * <p>{@code equals}, {@code hashCode} and {@code toString} cover all six properties and read
     * them through the getters.
     */
    public static class CudnnForwardArgs {
        private boolean manualPadBottom;
        private boolean manualPadRight;
        private INDArray input;
        private INDArray origInput;
        private int[] padding;
        private int[] outSize;

        /**
         * @param z         value for {@code manualPadBottom}
         * @param z2        value for {@code manualPadRight}
         * @param iNDArray  value for {@code input}
         * @param iNDArray2 value for {@code origInput}
         * @param iArr      value for {@code padding} (stored by reference)
         * @param iArr2     value for {@code outSize} (stored by reference)
         */
        public CudnnForwardArgs(boolean z, boolean z2, INDArray iNDArray, INDArray iNDArray2, int[] iArr, int[] iArr2) {
            manualPadBottom = z;
            manualPadRight = z2;
            input = iNDArray;
            origInput = iNDArray2;
            padding = iArr;
            outSize = iArr2;
        }

        // ---- accessors ----

        public boolean isManualPadBottom() {
            return manualPadBottom;
        }

        public boolean isManualPadRight() {
            return manualPadRight;
        }

        public INDArray getInput() {
            return input;
        }

        public INDArray getOrigInput() {
            return origInput;
        }

        public int[] getPadding() {
            return padding;
        }

        public int[] getOutSize() {
            return outSize;
        }

        // ---- mutators ----

        public void setManualPadBottom(boolean z) {
            manualPadBottom = z;
        }

        public void setManualPadRight(boolean z) {
            manualPadRight = z;
        }

        public void setInput(INDArray iNDArray) {
            input = iNDArray;
        }

        public void setOrigInput(INDArray iNDArray) {
            origInput = iNDArray;
        }

        public void setPadding(int[] iArr) {
            padding = iArr;
        }

        public void setOutSize(int[] iArr) {
            outSize = iArr;
        }

        /**
         * Two instances are equal when both flags match, both arrays are equal per
         * {@code INDArray.equals} (or both null), and padding/outSize match element-wise.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof CudnnForwardArgs)) {
                return false;
            }
            CudnnForwardArgs that = (CudnnForwardArgs) obj;
            if (!that.canEqual(this)) {
                return false;
            }
            if (isManualPadBottom() != that.isManualPadBottom()) {
                return false;
            }
            if (isManualPadRight() != that.isManualPadRight()) {
                return false;
            }

            INDArray thisInput = getInput();
            INDArray thatInput = that.getInput();
            boolean sameInput = (thisInput == null) ? (thatInput == null) : thisInput.equals(thatInput);
            if (!sameInput) {
                return false;
            }

            INDArray thisOrig = getOrigInput();
            INDArray thatOrig = that.getOrigInput();
            boolean sameOrig = (thisOrig == null) ? (thatOrig == null) : thisOrig.equals(thatOrig);
            if (!sameOrig) {
                return false;
            }

            if (!Arrays.equals(getPadding(), that.getPadding())) {
                return false;
            }
            return Arrays.equals(getOutSize(), that.getOutSize());
        }

        /** Hook that lets subclasses veto equality with instances of this class. */
        protected boolean canEqual(Object obj) {
            return obj instanceof CudnnForwardArgs;
        }

        /** Polynomial hash (multiplier 59) over the six properties, in declaration order. */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (isManualPadBottom() ? 79 : 97);
            result = result * 59 + (isManualPadRight() ? 79 : 97);
            INDArray in = getInput();
            result = result * 59 + (in == null ? 43 : in.hashCode());
            INDArray orig = getOrigInput();
            result = result * 59 + (orig == null ? 43 : orig.hashCode());
            result = result * 59 + Arrays.hashCode(getPadding());
            result = result * 59 + Arrays.hashCode(getOutSize());
            return result;
        }

        @Override
        public String toString() {
            return "CudnnConvolutionHelper.CudnnForwardArgs(manualPadBottom=" + isManualPadBottom()
                    + ", manualPadRight=" + isManualPadRight()
                    + ", input=" + getInput()
                    + ", origInput=" + getOrigInput()
                    + ", padding=" + Arrays.toString(getPadding())
                    + ", outSize=" + Arrays.toString(getOutSize())
                    + ")";
        }
    }

    /**
     * Creates the helper for the given data type and eagerly allocates its cuDNN context
     * (handle and descriptors).
     *
     * @param dataType data type forwarded to the {@link BaseCudnnHelper} constructor
     */
    public CudnnConvolutionHelper(DataType dataType) {
        super(dataType);
        cudnnContext = new CudnnConvolutionContext();
    }

    /**
     * Backward pass of a 2D convolution through cuDNN: computes the bias gradient, the weight
     * gradient and the gradient with respect to the layer input.
     *
     * <p>All cuDNN work is done on NCHW-ordered data. When {@code cNN2DFormat} is NHWC the input
     * and delta are permuted to NCHW on entry and the returned input gradient is permuted back.
     *
     * <p>Parameter names are decompiler-generated; roles below are those visible from how each
     * argument is used in this method.
     *
     * @param iNDArray          layer input
     * @param iNDArray2         weights; dimensions 0..3 feed the filter descriptor
     * @param iNDArray3         not referenced in this method
     * @param iNDArray4         delta (gradient arriving from the layer above); copied if its strides are unsupported
     * @param iArr              forwarded to {@code getCudnnForwardArgs} and to error reporting only (presumably kernel size)
     * @param iArr2             passed as the two stride arguments of the convolution descriptor
     * @param iArr3             passed as the two padding arguments of the convolution descriptor
     * @param iNDArray5         output buffer for the bias gradient; registered under key {@code "b"}
     * @param iNDArray6         output buffer for the weight gradient; registered under key {@code "W"}
     * @param iActivation       not referenced in this method
     * @param algoMode          algorithm selection mode; only {@code USER_SPECIFIED} with both algorithms non-null bypasses the cuDNN query
     * @param bwdFilterAlgo     requested backward-filter algorithm (may be null)
     * @param bwdDataAlgo       requested backward-data algorithm (may be null)
     * @param convolutionMode   forwarded to {@code getCudnnForwardArgs} and to error reporting
     * @param iArr4             passed as the two dilation arguments of the convolution descriptor
     * @param cNN2DFormat       layout of {@code iNDArray} and {@code iNDArray4} as supplied by the caller
     * @param layerWorkspaceMgr source of the output array and holder of the shared cuDNN scratch workspace
     * @return pair of (gradient holding "b" and "W", gradient with respect to the input in the caller's layout)
     * @throws IllegalArgumentException if a user-specified algorithm is not covered by the switch tables
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int[] iArr, int[] iArr2, int[] iArr3, INDArray iNDArray5, INDArray iNDArray6, IActivation iActivation, ConvolutionLayer.AlgoMode algoMode, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        // z remembers that the caller's data was NHWC so the result can be permuted back at the end.
        boolean z = false;
        if (cNN2DFormat == CNN2DFormat.NHWC) {
            // NHWC -> NCHW
            iNDArray = iNDArray.permute(new int[]{0, 3, 1, 2});
            iNDArray4 = iNDArray4.permute(new int[]{0, 3, 1, 2});
            z = true;
        }
        // size = input dim 0; size2..size5 = weight dims 0..3 (used below as filter K, C, H, W).
        long size = iNDArray.size(0);
        long size2 = iNDArray2.size(0);
        long size3 = iNDArray2.size(1);
        long size4 = iNDArray2.size(2);
        long size5 = iNDArray2.size(3);
        // Always NCHW here: any NHWC input was permuted above.
        CudnnForwardArgs cudnnForwardArgs = getCudnnForwardArgs(iNDArray, iArr, iArr2, iArr3, iArr4, convolutionMode, null, CNN2DFormat.NCHW);
        // From here on work with the input returned by getCudnnForwardArgs (may differ from iNDArray).
        INDArray input = cudnnForwardArgs.getInput();
        long size6 = input.size(2);
        long size7 = input.size(3);
        long[] stride = input.stride();
        int[] outSize = cudnnForwardArgs.getOutSize();
        int i = outSize[0];
        int i2 = outSize[1];
        // Copy delta when its strides are neither descending (c) nor ascending (f) —
        // presumably because cuDNN cannot describe such a layout.
        if (!Shape.strideDescendingCAscendingF(iNDArray4)) {
            iNDArray4 = iNDArray4.dup();
        }
        long[] stride2 = iNDArray4.stride();
        // Single-element arrays act as out-parameters for the chosen cuDNN algorithm ids:
        // iArr5 = backward-filter algorithm, iArr6 = backward-data algorithm.
        int[] iArr5 = new int[1];
        int[] iArr6 = new int[1];
        // Flush any queued ops before handing buffers to cuDNN.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        // NOTE(review): the long -> int casts below silently truncate dimensions/strides above Integer.MAX_VALUE.
        // Describe the input tensor (shape + explicit strides).
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType, (int) size, (int) size3, (int) size6, (int) size7, (int) stride[0], (int) stride[1], (int) stride[2], (int) stride[3]), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Describe the delta tensor: [size, size2, outH, outW] with delta's own strides.
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.deltaTensorDesc, this.dataType, (int) size, (int) size2, i, i2, (int) stride2[0], (int) stride2[1], (int) stride2[2], (int) stride2[3]), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Convolution descriptor: padding (iArr3), stride (iArr2), dilation (iArr4).
        // The literal 1 is the cuDNN convolution-mode constant inlined by the compiler — presumably CUDNN_CROSS_CORRELATION; confirm.
        checkCudnn(false, "cudnnSetConvolution2dDescriptor", cudnn.cudnnSetConvolution2dDescriptor(this.cudnnContext.convDesc, iArr3[0], iArr3[1], iArr2[0], iArr2[1], iArr4[0], iArr4[1], 1, this.dataType), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Filter descriptor from the weight shape. The literal 0 is the tensor-format constant — presumably CUDNN_TENSOR_NCHW; confirm.
        checkCudnn(false, "cudnnSetFilter4dDescriptor", cudnn.cudnnSetFilter4dDescriptor(this.cudnnContext.filterDesc, this.dataType, 0, (int) size2, (int) size3, (int) size4, (int) size5), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        if (algoMode != ConvolutionLayer.AlgoMode.USER_SPECIFIED || bwdFilterAlgo == null || bwdDataAlgo == null) {
            // Let cuDNN choose both algorithms. Preference argument is 0 for NO_WORKSPACE and 1 otherwise
            // (presumably the *_NO_WORKSPACE / *_PREFER_FASTEST constants); memory limit is 0.
            checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", cudnn.cudnnGetConvolutionBackwardFilterAlgorithm(this.cudnnContext, this.cudnnContext.srcTensorDesc, this.cudnnContext.deltaTensorDesc, this.cudnnContext.convDesc, this.cudnnContext.filterDesc, algoMode == ConvolutionLayer.AlgoMode.NO_WORKSPACE ? 0 : 1, 0L, iArr5), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
            checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", cudnn.cudnnGetConvolutionBackwardDataAlgorithm(this.cudnnContext, this.cudnnContext.filterDesc, this.cudnnContext.deltaTensorDesc, this.cudnnContext.convDesc, this.cudnnContext.srcTensorDesc, algoMode == ConvolutionLayer.AlgoMode.NO_WORKSPACE ? 0 : 1, 0L, iArr6), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        } else {
            // User-specified algorithms: translate each enum constant to its cuDNN algorithm id.
            // Case labels are the 1-based values stored in the AnonymousClass1 switch maps;
            // the assigned ints are presumably the inlined CUDNN_CONVOLUTION_BWD_FILTER_ALGO_* constants.
            switch (AnonymousClass1.$SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdFilterAlgo[bwdFilterAlgo.ordinal()]) {
                case 1:
                    iArr5[0] = 0;
                    break;
                case 2:
                    iArr5[0] = 1;
                    break;
                case 3:
                    iArr5[0] = 2;
                    break;
                case 4:
                    iArr5[0] = 3;
                    break;
                case 5:
                    iArr5[0] = 4;
                    break;
                case 6:
                    iArr5[0] = 5;
                    break;
                case 7:
                    iArr5[0] = 6;
                    break;
                case 8:
                    iArr5[0] = 7;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo);
            }
            // Same translation for the backward-data algorithm (presumably CUDNN_CONVOLUTION_BWD_DATA_ALGO_*).
            switch (AnonymousClass1.$SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$BwdDataAlgo[bwdDataAlgo.ordinal()]) {
                case 1:
                    iArr6[0] = 0;
                    break;
                case 2:
                    iArr6[0] = 1;
                    break;
                case 3:
                    iArr6[0] = 2;
                    break;
                case 4:
                    iArr6[0] = 3;
                    break;
                case 5:
                    iArr6[0] = 4;
                    break;
                case 6:
                    iArr6[0] = 5;
                    break;
                case 7:
                    iArr6[0] = 6;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo);
            }
        }
        if (log.isTraceEnabled()) {
            // NOTE(review): indexes the enums by raw cuDNN algorithm id, which relies on the enum
            // declaration order matching cuDNN's numbering.
            log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", new Object[]{algoMode, ConvolutionLayer.BwdFilterAlgo.values()[iArr5[0]], ConvolutionLayer.BwdDataAlgo.values()[iArr6[0]]});
        }
        // Output buffer for the input gradient: same shape as the (possibly padded) input, c-order,
        // with the weights' data type; contents are uninitialised until cuDNN writes them.
        INDArray createUninitialized = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray2.dataType(), new long[]{(int) size, (int) size3, (int) size6, (int) size7}, 'c');
        long[] stride3 = createUninitialized.stride();
        AtomicAllocator atomicAllocator = AtomicAllocator.getInstance();
        // Announce the upcoming access to all six arrays; paired with registerActionAllWrite after the cuDNN calls.
        CudaContext prepareActionAllWrite = atomicAllocator.getFlowController().prepareActionAllWrite(new INDArray[]{input, iNDArray2, iNDArray6, iNDArray5, iNDArray4, createUninitialized});
        // Pointers: pointer=input, pointer2=weights, pointer3=weight grad, pointer4=bias grad,
        // pointer5=delta, pointer6=input grad.
        Pointer pointer = atomicAllocator.getPointer(input, prepareActionAllWrite);
        Pointer pointer2 = atomicAllocator.getPointer(iNDArray2, prepareActionAllWrite);
        Pointer pointer3 = atomicAllocator.getPointer(iNDArray6, prepareActionAllWrite);
        Pointer pointer4 = atomicAllocator.getPointer(iNDArray5, prepareActionAllWrite);
        Pointer pointer5 = atomicAllocator.getPointer(iNDArray4, prepareActionAllWrite);
        Pointer pointer6 = atomicAllocator.getPointer(createUninitialized, prepareActionAllWrite);
        // Run cuDNN on the stream of the context obtained above.
        checkCudnn(false, "cudnnSetStream", cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(prepareActionAllWrite.getCublasStream())), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Describe the input-gradient tensor using the freshly allocated array's strides.
        checkCudnn(false, "cudnnSetTensor4dDescriptorEx", cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType, (int) size, (int) size3, (int) size6, (int) size7, (int) stride3[0], (int) stride3[1], (int) stride3[2], (int) stride3[3]), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Query scratch-space requirements. this.sizeInBytes is reused as the out-parameter for both
        // queries, so the filter result (j) is read out before the data query overwrites it.
        checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", cudnn.cudnnGetConvolutionBackwardFilterWorkspaceSize(this.cudnnContext, this.cudnnContext.srcTensorDesc, this.cudnnContext.deltaTensorDesc, this.cudnnContext.convDesc, this.cudnnContext.filterDesc, iArr5[0], this.sizeInBytes), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        long j = this.sizeInBytes.get(0L);
        checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", cudnn.cudnnGetConvolutionBackwardDataWorkspaceSize(this.cudnnContext, this.cudnnContext.filterDesc, this.cudnnContext.deltaTensorDesc, this.cudnnContext.convDesc, this.cudnnContext.dstTensorDesc, iArr6[0], this.sizeInBytes), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        BaseCudnnHelper.DataCache dataCache = (BaseCudnnHelper.DataCache) layerWorkspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        long j2 = this.sizeInBytes.get(0L);
        // (Re)allocate the shared scratch workspace when there is none yet or it is smaller than
        // either requirement; the new one is sized to the larger of the two and stored back on the manager.
        if (dataCache == null || j > dataCache.capacity() || j2 > dataCache.capacity()) {
            long max = Math.max(j, j2);
            if (log.isTraceEnabled()) {
                if (dataCache == null) {
                    log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})", Long.valueOf(max), BinaryByteUnit.format(max, "#.00"));
                } else {
                    log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", new Object[]{Long.valueOf(dataCache.capacity()), BinaryByteUnit.format(dataCache.capacity(), "#.00"), Long.valueOf(max), BinaryByteUnit.format(max, "#.00")});
                }
            }
            if (dataCache != null) {
                dataCache.deallocate();
            }
            dataCache = new BaseCudnnHelper.DataCache(max);
            layerWorkspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, dataCache);
        }
        // Bias descriptor: shape [1, size2, 1, 1].
        checkCudnn(false, "cudnnSetTensor4dDescriptor", cudnn.cudnnSetTensor4dDescriptor(this.cudnnContext.biasTensorDesc, 0, this.dataType, 1, (int) size2, 1, 1), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // 1) bias gradient: delta -> iNDArray5
        checkCudnn(false, "cudnnConvolutionBackwardBias", cudnn.cudnnConvolutionBackwardBias(this.cudnnContext, this.alpha, this.cudnnContext.deltaTensorDesc, pointer5, this.beta, this.cudnnContext.biasTensorDesc, pointer4), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // 2) weight gradient: (input, delta) -> iNDArray6
        checkCudnn(false, "cudnnConvolutionBackwardFilter", cudnn.cudnnConvolutionBackwardFilter(this.cudnnContext, this.alpha, this.cudnnContext.srcTensorDesc, pointer, this.cudnnContext.deltaTensorDesc, pointer5, this.cudnnContext.convDesc, iArr5[0], dataCache, dataCache.capacity(), this.beta, this.cudnnContext.filterDesc, pointer3), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // 3) input gradient: (weights, delta) -> createUninitialized
        checkCudnn(false, "cudnnConvolutionBackwardData", cudnn.cudnnConvolutionBackwardData(this.cudnnContext, this.alpha, this.cudnnContext.filterDesc, pointer2, this.cudnnContext.deltaTensorDesc, pointer5, this.cudnnContext.convDesc, iArr6[0], dataCache, dataCache.capacity(), this.beta, this.cudnnContext.dstTensorDesc, pointer6), input, iNDArray2, null, iNDArray4, iArr, iArr2, iArr3, algoMode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, iArr4);
        // Close the access bracket opened by prepareActionAllWrite.
        atomicAllocator.getFlowController().registerActionAllWrite(prepareActionAllWrite, new INDArray[]{input, iNDArray2, iNDArray6, iNDArray5, iNDArray4, createUninitialized});
        DefaultGradient defaultGradient = new DefaultGradient();
        defaultGradient.setGradientFor("b", iNDArray5);
        defaultGradient.setGradientFor("W", iNDArray6, 'c');
        // In debug configuration, synchronise the stream before returning.
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            prepareActionAllWrite.syncOldStream();
        }
        // If an extra bottom row / right column was padded manually, drop it from the input gradient
        // so the result matches the un-padded input.
        if (cudnnForwardArgs.isManualPadBottom() || cudnnForwardArgs.isManualPadRight()) {
            INDArrayIndex[] iNDArrayIndexArr = new INDArrayIndex[4];
            iNDArrayIndexArr[0] = NDArrayIndex.all();
            iNDArrayIndexArr[1] = NDArrayIndex.all();
            iNDArrayIndexArr[2] = NDArrayIndex.interval(0L, createUninitialized.size(2) - (cudnnForwardArgs.isManualPadBottom() ? 1 : 0));
            iNDArrayIndexArr[3] = NDArrayIndex.interval(0L, createUninitialized.size(3) - (cudnnForwardArgs.isManualPadRight() ? 1 : 0));
            createUninitialized = createUninitialized.get(iNDArrayIndexArr);
        }
        if (z) {
            // NCHW -> NHWC, restoring the caller's layout.
            createUninitialized = createUninitialized.permute(new int[]{0, 2, 3, 1});
        }
        return new Pair<>(defaultGradient, createUninitialized);
    }

    /**
     * Computes the convolution pre-activations (convolution result plus bias) using cuDNN.
     *
     * <p>All cuDNN work is done on NCHW data: when {@code cNN2DFormat} is NHWC the input is
     * permuted to NCHW first and the result is permuted back to NHWC before it is returned.
     *
     * <p>Steps, in order: prepare the input via {@link #getCudnnForwardArgs}, configure the source,
     * filter, convolution and destination descriptors, pick a forward algorithm, make sure the
     * shared cuDNN workspace is large enough, run {@code cudnnConvolutionForward}, then add the
     * bias with {@code cudnnAddTensor}.
     *
     * @param iNDArray          input activations, laid out as described by {@code cNN2DFormat}
     * @param iNDArray2         weights; dimensions 0..3 are read as output depth, input depth,
     *                          kernel height and kernel width
     * @param iNDArray3         bias, added to the convolution result
     * @param iArr              kernel size (reported as "kernel" in error messages)
     * @param iArr2             strides, indices 0 and 1 are used
     * @param iArr3             padding, indices 0 and 1 are used
     * @param algoMode          algorithm selection mode; only {@code USER_SPECIFIED} together with a
     *                          non-null {@code fwdAlgo} bypasses cuDNN's own selection
     * @param fwdAlgo           forward algorithm to use in {@code USER_SPECIFIED} mode, may be null
     * @param convolutionMode   convolution mode, forwarded to {@link #getCudnnForwardArgs}
     * @param iArr4             dilation, indices 0 and 1 are used
     * @param cNN2DFormat       layout of {@code iNDArray}
     * @param layerWorkspaceMgr workspace manager used for the output array and the cuDNN workspace
     * @return the pre-activations, in the same layout as the input
     * @throws RuntimeException         if any checked cuDNN call returns a non-zero status
     * @throws IllegalArgumentException if {@code fwdAlgo} is not covered by the algorithm switch
     */
    public INDArray preOutput(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionLayer.AlgoMode algoMode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] iArr4, CNN2DFormat cNN2DFormat, LayerWorkspaceMgr layerWorkspaceMgr) {
        // NHWC -> NCHW on the way in; undone on the way out.
        boolean inputWasNhwc = cNN2DFormat == CNN2DFormat.NHWC;
        if (inputWasNhwc) {
            iNDArray = iNDArray.permute(new int[]{0, 3, 1, 2});
        }

        long miniBatch = iNDArray.size(0);
        long outDepth = iNDArray2.size(0);
        long inDepth = iNDArray2.size(1);
        long kernelHeight = iNDArray2.size(2);
        long kernelWidth = iNDArray2.size(3);

        // NCHW is hard-coded here because of the permute above.
        CudnnForwardArgs forwardArgs = getCudnnForwardArgs(iNDArray, iArr, iArr2, iArr3, iArr4, convolutionMode, null, CNN2DFormat.NCHW);
        INDArray input = forwardArgs.getInput();
        long inHeight = input.size(2);
        long inWidth = input.size(3);
        long[] srcStride = input.stride();
        int[] outSize = forwardArgs.getOutSize();

        // flushQueue() is a GridExecutioner method, so the cast is required after the instanceof check.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        INDArray output = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray2.dataType(),
                new long[]{(int) miniBatch, (int) outDepth, outSize[0], outSize[1]});

        int status = cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType,
                (int) miniBatch, (int) inDepth, (int) inHeight, (int) inWidth,
                (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        status = cudnn.cudnnSetFilter4dDescriptor(this.cudnnContext.filterDesc, this.dataType, cudnn.CUDNN_TENSOR_NCHW,
                (int) outDepth, (int) inDepth, (int) kernelHeight, (int) kernelWidth);
        checkCudnn(true, "cudnnSetFilter4dDescriptor", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        status = cudnn.cudnnSetConvolution2dDescriptor(this.cudnnContext.convDesc, iArr3[0], iArr3[1], iArr2[0], iArr2[1],
                iArr4[0], iArr4[1], cudnn.CUDNN_CROSS_CORRELATION, this.dataType);
        checkCudnn(true, "cudnnSetConvolution2dDescriptor", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        // Single-element array: cuDNN writes the chosen algorithm id into it.
        int[] algo = new int[1];
        long[] dstStride = output.stride();
        status = cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType,
                (int) miniBatch, (int) outDepth, outSize[0], outSize[1],
                (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]);
        checkCudnn(true, "cudnnSetTensor4dDescriptorEx", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        if (algoMode != ConvolutionLayer.AlgoMode.USER_SPECIFIED || fwdAlgo == null) {
            // Let cuDNN choose; the preference depends on whether a workspace may be used.
            int preference = algoMode == ConvolutionLayer.AlgoMode.NO_WORKSPACE
                    ? cudnn.CUDNN_CONVOLUTION_FWD_NO_WORKSPACE
                    : cudnn.CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
            int selectionStatus = cudnn.cudnnGetConvolutionForwardAlgorithm(this.cudnnContext, this.cudnnContext.srcTensorDesc,
                    this.cudnnContext.filterDesc, this.cudnnContext.convDesc, this.cudnnContext.dstTensorDesc, preference, 0L, algo);
            if (selectionStatus != cudnn.CUDNN_STATUS_SUCCESS) {
                // Selection failure is not fatal: fall back to IMPLICIT_GEMM and report it once.
                OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM", new Object[0]);
                algoMode = ConvolutionLayer.AlgoMode.USER_SPECIFIED;
                fwdAlgo = ConvolutionLayer.FwdAlgo.IMPLICIT_GEMM;
                algo[0] = cudnn.CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
            }
        } else {
            // $SwitchMap is the compiler-generated enum switch table defined outside this method.
            // The case indices are kept untouched because the enum constant behind each index is
            // not visible from here; case k maps to cuDNN algorithm id k - 1.
            switch (AnonymousClass1.$SwitchMap$org$deeplearning4j$nn$conf$layers$ConvolutionLayer$FwdAlgo[fwdAlgo.ordinal()]) {
                case 1:
                    algo[0] = 0;
                    break;
                case 2:
                    algo[0] = 1;
                    break;
                case 3:
                    algo[0] = 2;
                    break;
                case 4:
                    algo[0] = 3;
                    break;
                case 5:
                    algo[0] = 4;
                    break;
                case 6:
                    algo[0] = 5;
                    break;
                case 7:
                    algo[0] = 6;
                    break;
                case 8:
                    algo[0] = 7;
                    break;
                case 9:
                    algo[0] = 8;
                    break;
                default:
                    throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo);
            }
        }

        if (log.isTraceEnabled()) {
            log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", algoMode, ConvolutionLayer.FwdAlgo.values()[algo[0]]);
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(output, new INDArray[]{input, iNDArray2, iNDArray3});
        Pointer srcData = allocator.getPointer(input, context);
        Pointer filterData = allocator.getPointer(iNDArray2, context);
        Pointer biasData = allocator.getPointer(iNDArray3, context);
        Pointer dstData = allocator.getPointer(output, context);

        status = cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream()));
        checkCudnn(true, "cudnnSetStream", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        // cuDNN writes the required workspace size into this.sizeInBytes.
        status = cudnn.cudnnGetConvolutionForwardWorkspaceSize(this.cudnnContext, this.cudnnContext.srcTensorDesc,
                this.cudnnContext.filterDesc, this.cudnnContext.convDesc, this.cudnnContext.dstTensorDesc, algo[0], this.sizeInBytes);
        checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        // Reuse the workspace cached in the workspace manager unless it is missing or too small.
        BaseCudnnHelper.DataCache workspace = (BaseCudnnHelper.DataCache) layerWorkspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY);
        if (workspace == null || this.sizeInBytes.get(0L) > workspace.capacity()) {
            if (log.isTraceEnabled()) {
                if (workspace == null) {
                    log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})", Long.valueOf(this.sizeInBytes.get()), BinaryByteUnit.format(this.sizeInBytes.get(), "#.00"));
                } else {
                    log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", new Object[]{Long.valueOf(workspace.capacity()), BinaryByteUnit.format(workspace.capacity(), "#.00"), Long.valueOf(this.sizeInBytes.get()), BinaryByteUnit.format(this.sizeInBytes.get(), "#.00")});
                }
            }
            if (workspace != null) {
                workspace.deallocate();
            }
            workspace = new BaseCudnnHelper.DataCache(this.sizeInBytes.get(0L));
            layerWorkspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workspace);
        }

        status = cudnn.cudnnConvolutionForward(this.cudnnContext, this.alpha, this.cudnnContext.srcTensorDesc, srcData,
                this.cudnnContext.filterDesc, filterData, this.cudnnContext.convDesc, algo[0], workspace, workspace.capacity(),
                this.beta, this.cudnnContext.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnConvolutionForward", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        // Bias is described as a 1 x outDepth x 1 x 1 tensor and accumulated into the output.
        status = cudnn.cudnnSetTensor4dDescriptor(this.cudnnContext.biasTensorDesc, cudnn.CUDNN_TENSOR_NCHW, this.dataType, 1, (int) outDepth, 1, 1);
        checkCudnn(true, "cudnnSetTensor4dDescriptor", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        // Both scaling factors are this.alpha here (not alpha/beta as in the convolution call).
        status = cudnn.cudnnAddTensor(this.cudnnContext, this.alpha, this.cudnnContext.biasTensorDesc, biasData,
                this.alpha, this.cudnnContext.dstTensorDesc, dstData);
        checkCudnn(true, "cudnnAddTensor", status, input, iNDArray2, iNDArray3, null, iArr, iArr2, iArr3, algoMode, fwdAlgo, null, null, convolutionMode, iArr4);

        allocator.registerAction(context, output, new INDArray[]{input, iNDArray2, iNDArray3});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }

        // NCHW -> NHWC if the caller supplied NHWC data.
        if (inputWasNhwc) {
            output = output.permute(new int[]{0, 2, 3, 1});
        }
        return output;
    }

    /**
     * Throws a descriptive {@link RuntimeException} when a cuDNN call reported a non-zero status;
     * does nothing when the status is zero.
     *
     * <p>The message contains the status code and cuDNN's error string, the pass direction, the
     * failing step, the array shapes and the convolution configuration.
     *
     * @param z               true for the forward pass, false for the backward pass
     * @param str             name of the cuDNN step that produced the status
     * @param i               status code returned by the cuDNN call
     * @param iNDArray        input array (its shape is reported)
     * @param iNDArray2       weights array (its shape is reported)
     * @param iNDArray3       bias array, may be null
     * @param iNDArray4       gradient array; only dereferenced for the backward pass
     * @param iArr            kernel
     * @param iArr2           stride
     * @param iArr3           padding
     * @param algoMode        algorithm mode in effect
     * @param fwdAlgo         forward algorithm, reported for the forward pass only
     * @param bwdFilterAlgo   backward filter algorithm, reported for the backward pass only
     * @param bwdDataAlgo     backward data algorithm, reported for the backward pass only
     * @param convolutionMode convolution mode in effect
     * @param iArr4           dilation
     * @throws RuntimeException if {@code i} is non-zero
     */
    private void checkCudnn(boolean z, String str, int i, INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, int[] iArr, int[] iArr2, int[] iArr3, ConvolutionLayer.AlgoMode algoMode, ConvolutionLayer.FwdAlgo fwdAlgo, ConvolutionLayer.BwdFilterAlgo bwdFilterAlgo, ConvolutionLayer.BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] iArr4) {
        if (i == 0) {
            return;
        }

        StringBuilder message = new StringBuilder("CuDNN error = ");
        message.append(i);
        message.append(": ");
        message.append(cudnn.cudnnGetErrorString(i).getString());
        message.append(" during ");
        if (z) {
            message.append("forward pass");
        } else {
            message.append("backward pass");
        }
        message.append(" - step ").append(str);
        message.append(": inputShape=").append(Arrays.toString(iNDArray.shape()));
        message.append(", weightsShape=").append(Arrays.toString(iNDArray2.shape()));

        // A missing bias is rendered as the text "null" (StringBuilder's handling of a null String).
        String biasShape = null;
        if (iNDArray3 != null) {
            biasShape = Arrays.toString(iNDArray3.shape());
        }
        message.append(", biasShape=").append(biasShape);

        if (!z) {
            message.append(", gradientShape=").append(Arrays.toString(iNDArray4.shape()));
        }

        message.append(", kernel=").append(Arrays.toString(iArr));
        message.append(", stride=").append(Arrays.toString(iArr2));
        message.append(", padding=").append(Arrays.toString(iArr3));
        message.append(", dilation=").append(Arrays.toString(iArr4));
        message.append(", AlgoMode=").append(algoMode);

        if (z) {
            message.append(", fwdAlgo=").append(fwdAlgo);
        } else {
            message.append(", bwdFilterAlgo=").append(bwdFilterAlgo);
            message.append(", bwdDataAlgo=").append(bwdDataAlgo);
        }

        message.append(", convolutionMode=").append(convolutionMode);
        throw new RuntimeException(message.toString());
    }

    /**
     * Applies an activation function to {@code iNDArray} in place using cuDNN.
     *
     * <p>The activation is chosen by the name returned from {@code iActivation.toString()}.
     * Supported names are {@code identity} (no-op), {@code sigmoid}, {@code relu}, {@code tanh},
     * {@code softmax} and {@code logsoftmax}. For any other name nothing is computed and
     * {@code null} is returned.
     *
     * <p>This method does not configure {@code cudnnContext.dstTensorDesc}; it relies on that
     * descriptor already describing {@code iNDArray} (within this class, {@code preOutput} sets it
     * for the array it returns).
     *
     * @param iNDArray    array to activate; overwritten with the result
     * @param iActivation activation function, identified by its {@code toString()} value
     * @param z           not used by this implementation
     * @return {@code iNDArray} after in-place activation, or {@code null} if the activation is
     *         not supported here
     * @throws RuntimeException if a cuDNN call fails (raised by {@code checkCudnn})
     */
    public INDArray activate(INDArray iNDArray, IActivation iActivation, boolean z) {
        // flushQueue() is a GridExecutioner method, so the cast is required after the instanceof check.
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueue();
        }

        INDArray result = iNDArray;
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareAction(iNDArray, new INDArray[0]);
        Pointer data = allocator.getPointer(iNDArray, context);
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));

        switch (iActivation.toString()) {
            case "identity":
                // Nothing to compute: the input already is the result.
                break;
            case "sigmoid":
                applyCudnnActivationInPlace(cudnn.CUDNN_ACTIVATION_SIGMOID, data);
                break;
            case "relu":
                applyCudnnActivationInPlace(cudnn.CUDNN_ACTIVATION_RELU, data);
                break;
            case "tanh":
                applyCudnnActivationInPlace(cudnn.CUDNN_ACTIVATION_TANH, data);
                break;
            case "softmax":
                checkCudnn(cudnn.cudnnSoftmaxForward(this.cudnnContext, cudnn.CUDNN_SOFTMAX_ACCURATE, cudnn.CUDNN_SOFTMAX_MODE_CHANNEL,
                        this.alpha, this.cudnnContext.dstTensorDesc, data, this.beta, this.cudnnContext.dstTensorDesc, data));
                break;
            case "logsoftmax":
                checkCudnn(cudnn.cudnnSoftmaxForward(this.cudnnContext, cudnn.CUDNN_SOFTMAX_LOG, cudnn.CUDNN_SOFTMAX_MODE_CHANNEL,
                        this.alpha, this.cudnnContext.dstTensorDesc, data, this.beta, this.cudnnContext.dstTensorDesc, data));
                break;
            default:
                // Unsupported activation: signal it to the caller by returning null.
                result = null;
                break;
        }

        // Note: for an unsupported activation the action is registered with a null result.
        allocator.registerAction(context, result, new INDArray[0]);
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
        return result;
    }

    /**
     * Configures the shared activation descriptor for {@code activationMode} and runs the cuDNN
     * activation forward pass with {@code data} as both source and destination.
     */
    private void applyCudnnActivationInPlace(int activationMode, Pointer data) {
        checkCudnn(cudnn.cudnnSetActivationDescriptor(this.cudnnContext.activationDesc, activationMode, cudnn.CUDNN_PROPAGATE_NAN, 0.0d));
        checkCudnn(cudnn.cudnnActivationForward(this.cudnnContext, this.cudnnContext.activationDesc, this.alpha,
                this.cudnnContext.dstTensorDesc, data, this.beta, this.cudnnContext.dstTensorDesc, data));
    }

    /**
     * Prepares the input array, padding and output size for a cuDNN forward call.
     *
     * <p>The input is copied to a 'c'-ordered array when it is a view or does not have default
     * strides. For {@link ConvolutionMode#Same} the top/left padding is computed from the output
     * size; when it differs from the bottom/right padding, the input is copied into a new array
     * that is one element larger along the height and/or width axis, and the returned flags record
     * which axes were extended. The extra cells hold negative infinity for max pooling, otherwise
     * whatever {@code Nd4j.create} initialises them to.
     *
     * @param iNDArray        input activations
     * @param iArr            kernel size
     * @param iArr2           strides
     * @param iArr3           padding; replaced by the computed top/left padding in Same mode
     * @param iArr4           dilation
     * @param convolutionMode convolution mode
     * @param poolingType     pooling type, or null when not pooling
     * @param cNN2DFormat     layout of {@code iNDArray}; NCHW puts height/width on axes 2/3,
     *                        anything else on axes 1/2
     * @return the prepared arguments
     */
    public static CudnnForwardArgs getCudnnForwardArgs(INDArray iNDArray, int[] iArr, int[] iArr2, int[] iArr3, int[] iArr4, ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat cNN2DFormat) {
        if (iNDArray.isView() || !Shape.hasDefaultStridesForShape(iNDArray)) {
            iNDArray = iNDArray.dup('c');
        }

        boolean nchw = cNN2DFormat == CNN2DFormat.NCHW;
        int heightAxis = nchw ? 2 : 1;
        int widthAxis = nchw ? 3 : 2;
        long inHeight = iNDArray.size(heightAxis);
        long inWidth = iNDArray.size(widthAxis);

        // NOTE(review): every return below passes the same array as both the 3rd and 4th
        // constructor argument. Confirm against CudnnForwardArgs whether the 4th was meant to be
        // the caller's original (pre-dup, pre-padding) input.
        if (convolutionMode != ConvolutionMode.Same) {
            int[] plainOutSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, iArr3, convolutionMode, iArr4, cNN2DFormat);
            return new CudnnForwardArgs(false, false, iNDArray, iNDArray, iArr3, plainOutSize);
        }

        int[] sameOutSize = ConvolutionUtils.getOutputSize(iNDArray, iArr, iArr2, (int[]) null, convolutionMode, iArr4, cNN2DFormat);
        int[] topLeftPad = ConvolutionUtils.getSameModeTopLeftPadding(sameOutSize, new int[]{(int) inHeight, (int) inWidth}, iArr, iArr2, iArr4);
        int[] bottomRightPad = ConvolutionUtils.getSameModeBottomRightPadding(sameOutSize, new int[]{(int) inHeight, (int) inWidth}, iArr, iArr2, iArr4);
        if (Arrays.equals(topLeftPad, bottomRightPad)) {
            // Symmetric padding: no manual padding needed.
            return new CudnnForwardArgs(false, false, iNDArray, iNDArray, topLeftPad, sameOutSize);
        }

        // Asymmetric padding: grow the input by one along each axis whose two pads differ.
        boolean padBottom = topLeftPad[0] != bottomRightPad[0];
        boolean padRight = topLeftPad[1] != bottomRightPad[1];
        long paddedHeight = inHeight + (padBottom ? 1 : 0);
        long paddedWidth = inWidth + (padRight ? 1 : 0);

        long[] paddedShape;
        INDArrayIndex[] originalRegion;
        if (nchw) {
            paddedShape = new long[]{iNDArray.size(0), iNDArray.size(1), paddedHeight, paddedWidth};
            originalRegion = new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0L, inHeight), NDArrayIndex.interval(0L, inWidth)};
        } else {
            paddedShape = new long[]{iNDArray.size(0), paddedHeight, paddedWidth, iNDArray.size(3)};
            originalRegion = new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.interval(0L, inHeight), NDArrayIndex.interval(0L, inWidth), NDArrayIndex.all()};
        }

        INDArray padded;
        if (poolingType == PoolingType.MAX) {
            // Negative infinity so the padding cells can never win a max.
            padded = Nd4j.valueArrayOf(paddedShape, Double.NEGATIVE_INFINITY, iNDArray.dataType());
        } else {
            padded = Nd4j.create(iNDArray.dataType(), paddedShape);
        }
        padded.put(originalRegion, iNDArray);

        return new CudnnForwardArgs(padBottom, padRight, padded, padded, topLeftPad, sameOutSize);
    }

    /**
     * Reports the memory used by this helper.
     *
     * @return always an empty, immutable map; this helper reports no memory figures
     */
    public Map<String, Long> helperMemoryUse() {
        final Map<String, Long> nothingReported = Collections.emptyMap();
        return nothingReported;
    }
}
