package org.deeplearning4j.nn.layers.normalization;

import java.util.HashMap;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseCudnnHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper.class */
public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper {
    /** Logger used for the "not supported" warning emitted by {@code checkSupported(double, boolean)}. */
    private static final Logger log = LoggerFactory.getLogger(CudnnBatchNormalizationHelper.class);
    /**
     * Batch-normalization mode constant. The cuDNN calls in this class pass the literal {@code 1}
     * as their mode argument rather than reading this field.
     * NOTE(review): 1 presumably corresponds to CUDNN_BATCHNORM_SPATIAL - confirm against the cuDNN enum.
     */
    protected final int batchNormMode = 1;
    /** cuDNN handle plus the tensor descriptors that are re-configured on every forward/backward call. */
    private CudnnBatchNormalizationContext cudnnContext;
    /**
     * Buffer handed to cuDNN as the second-to-last argument of the training forward pass and of the
     * backward pass. Stays {@code null} until {@code preOutput} has been called with training enabled.
     */
    private INDArray meanCache;
    /**
     * Buffer handed to cuDNN as the last argument of the training forward pass and of the backward
     * pass. Stays {@code null} until {@code preOutput} has been called with training enabled.
     */
    private INDArray varCache;
    /** Epsilon from the most recent {@code preOutput}/{@code backpropGradient} call; read by {@code getVarCache}. */
    private double eps;

    /* loaded from: input_file:org/deeplearning4j/nn/layers/normalization/CudnnBatchNormalizationHelper$CudnnBatchNormalizationContext.class */
    /**
     * cuDNN context for batch normalization. On top of whatever the base
     * {@link BaseCudnnHelper.CudnnContext} manages, it owns four tensor descriptors:
     * source, destination, delta and the shared gamma/beta descriptor.
     */
    private static class CudnnBatchNormalizationContext extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct srcTensorDesc;
        private cudnnTensorStruct dstTensorDesc;
        private cudnnTensorStruct deltaTensorDesc;
        private cudnnTensorStruct gammaBetaTensorDesc;

        /**
         * Deallocator registered with the owning context. It is built through the copy
         * constructor, so it refers to the same descriptors as the context it was created
         * from and destroys them when {@link #deallocate()} is invoked.
         */
        private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator {
            Deallocator(CudnnBatchNormalizationContext cudnnBatchNormalizationContext) {
                super(cudnnBatchNormalizationContext);
            }

            @Override
            public void deallocate() {
                destroyHandles();
            }
        }

        /** Creates fresh descriptors, initialises them through cuDNN and registers a deallocator. */
        public CudnnBatchNormalizationContext() {
            this.srcTensorDesc = new cudnnTensorStruct();
            this.dstTensorDesc = new cudnnTensorStruct();
            this.deltaTensorDesc = new cudnnTensorStruct();
            this.gammaBetaTensorDesc = new cudnnTensorStruct();
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Copy constructor: wraps the descriptors of {@code cudnnBatchNormalizationContext}
         * instead of creating new ones. No handles are created and no deallocator is registered.
         *
         * @param cudnnBatchNormalizationContext the context whose descriptors are shared
         */
        public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext cudnnBatchNormalizationContext) {
            super(cudnnBatchNormalizationContext);
            // Assigned exactly once; the previous placeholder allocations were overwritten immediately.
            this.srcTensorDesc = new cudnnTensorStruct(cudnnBatchNormalizationContext.srcTensorDesc);
            this.dstTensorDesc = new cudnnTensorStruct(cudnnBatchNormalizationContext.dstTensorDesc);
            this.deltaTensorDesc = new cudnnTensorStruct(cudnnBatchNormalizationContext.deltaTensorDesc);
            this.gammaBetaTensorDesc = new cudnnTensorStruct(cudnnBatchNormalizationContext.gammaBetaTensorDesc);
        }

        /** Creates the base handles first, then the four tensor descriptors. */
        @Override
        public void createHandles() {
            super.createHandles();
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.gammaBetaTensorDesc));
        }

        /** Destroys the four tensor descriptors, then the base handles (reverse of creation). */
        @Override
        public void destroyHandles() {
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.srcTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dstTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.deltaTensorDesc));
            CudnnBatchNormalizationHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.gammaBetaTensorDesc));
            super.destroyHandles();
        }
    }

    /**
     * Creates the helper and its cuDNN batch-normalization context.
     *
     * @param dataType data type forwarded to the {@link BaseCudnnHelper} constructor
     */
    public CudnnBatchNormalizationHelper(DataType dataType) {
        super(dataType);
        // batchNormMode is a final field with its own initializer, so it must not be
        // assigned again here (doing so is a compile-time error).
        this.cudnnContext = new CudnnBatchNormalizationContext();
    }

    /**
     * Reports whether this helper can be used with the given epsilon.
     *
     * <p>The no-argument {@code checkSupported()} is always evaluated first. An epsilon
     * below {@code 1.0E-5} is rejected with a warning regardless of that result.
     *
     * @param d the batch-normalization epsilon to validate
     * @param z not consulted by this implementation
     * @return {@code false} when {@code d < 1.0E-5}, otherwise the result of {@code checkSupported()}
     */
    public boolean checkSupported(double d, boolean z) {
        final boolean baseSupported = checkSupported();
        if (d < 1.0E-5d) {
            log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + d + " < 1.0E-5)");
            return false;
        }
        return baseSupported;
    }

    /**
     * Runs the cuDNN batch-normalization backward pass.
     *
     * <p>Argument roles below follow the positions at which the arrays are handed to
     * {@code cudnnBatchNormalizationBackward} and the gradient keys they are stored under.
     * The method relies on {@code meanCache}/{@code varCache} having been populated by an
     * earlier training-mode {@code preOutput} call.
     *
     * @param iNDArray          layer input; its four dimensions size every 4d descriptor
     * @param iNDArray2         incoming gradient; duplicated in 'c' order if its strides are not the default
     * @param iArr              shape used for the gamma/beta descriptor (missing 3rd/4th entries default to 1)
     * @param iNDArray3         scale (gamma) parameter
     * @param iNDArray4         array receiving the scale gradient; registered under {@code "gamma"}
     * @param iNDArray5         array receiving the bias gradient; registered under {@code "beta"}
     * @param d                 epsilon; stored in {@code eps} and forwarded to cuDNN
     * @param layerWorkspaceMgr workspace manager used to allocate the returned input gradient
     * @return the gradient map together with the gradient with respect to the input
     */
    public Pair<Gradient, INDArray> backpropGradient(INDArray iNDArray, INDArray iNDArray2, int[] iArr, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, double d, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.eps = d;

        final int miniBatch = (int) iNDArray.size(0);
        final int depth = (int) iNDArray.size(1);
        final int height = (int) iNDArray.size(2);
        final int width = (int) iNDArray.size(3);

        // For half-precision input the parameter arrays are processed as FLOAT copies;
        // the results are copied back into the caller's arrays at the end.
        final boolean isHalf = iNDArray.dataType() == DataType.HALF;
        INDArray scale = iNDArray3;
        INDArray scaleGrad = iNDArray4;
        INDArray biasGrad = iNDArray5;
        if (isHalf) {
            scale = iNDArray3.castTo(DataType.FLOAT);
            scaleGrad = iNDArray4.castTo(DataType.FLOAT);
            biasGrad = iNDArray5.castTo(DataType.FLOAT);
        }

        DefaultGradient gradients = new DefaultGradient();

        INDArray outputGrad = iNDArray2;
        if (!Shape.hasDefaultStridesForShape(outputGrad)) {
            outputGrad = outputGrad.dup('c');
        }

        int[] srcStride = ArrayUtil.toInts(iNDArray.stride());
        int[] deltaStride = ArrayUtil.toInts(outputGrad.stride());

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }

        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType,
                miniBatch, depth, height, width,
                srcStride[0], srcStride[1], srcStride[2], srcStride[3]));
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.deltaTensorDesc, this.dataType,
                miniBatch, depth, height, width,
                deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3]));

        INDArray inputGrad = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, iNDArray.dataType(),
                new long[]{miniBatch, depth, height, width}, 'c');
        int[] dstStride = ArrayUtil.toInts(inputGrad.stride());
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType,
                miniBatch, depth, height, width,
                dstStride[0], dstStride[1], dstStride[2], dstStride[3]));

        final int gammaBetaDim2 = iArr.length > 2 ? iArr[2] : 1;
        final int gammaBetaDim3 = iArr.length > 3 ? iArr[3] : 1;
        checkCudnn(cudnn.cudnnSetTensor4dDescriptor(this.cudnnContext.gammaBetaTensorDesc, 0,
                toCudnnDataType(scale.data().dataType()), iArr[0], iArr[1], gammaBetaDim2, gammaBetaDim3));

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(
                new INDArray[]{iNDArray, outputGrad, inputGrad, scale, scaleGrad, biasGrad});
        Pointer srcPtr = allocator.getPointer(iNDArray, context);
        Pointer deltaPtr = allocator.getPointer(outputGrad, context);
        Pointer dstPtr = allocator.getPointer(inputGrad, context);
        Pointer scalePtr = allocator.getPointer(scale, context);
        Pointer scaleGradPtr = allocator.getPointer(scaleGrad, context);
        Pointer biasGradPtr = allocator.getPointer(biasGrad, context);
        Pointer meanCachePtr = allocator.getPointer(this.meanCache, context);
        Pointer varCachePtr = allocator.getPointer(this.varCache, context);

        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));
        // NOTE(review): the fourth scaling factor is this.alpha, not this.beta - confirm this is
        // intended given how the parameter-gradient arrays are initialised by the caller.
        checkCudnn(cudnn.cudnnBatchNormalizationBackward(this.cudnnContext, 1,
                this.alpha, this.beta, this.alpha, this.alpha,
                this.cudnnContext.srcTensorDesc, srcPtr,
                this.cudnnContext.deltaTensorDesc, deltaPtr,
                this.cudnnContext.dstTensorDesc, dstPtr,
                this.cudnnContext.gammaBetaTensorDesc, scalePtr, scaleGradPtr, biasGradPtr,
                d, meanCachePtr, varCachePtr));

        allocator.getFlowController().registerActionAllWrite(context,
                new INDArray[]{iNDArray, outputGrad, inputGrad, scale, scaleGrad, biasGrad});

        gradients.setGradientFor("gamma", scaleGrad);
        gradients.setGradientFor("beta", biasGrad);

        context.syncOldStream();

        if (isHalf) {
            // Copy the FLOAT working arrays back into the caller's half-precision arrays.
            iNDArray3.assign(scale.castTo(DataType.HALF));
            iNDArray4.assign(scaleGrad.castTo(DataType.HALF));
            iNDArray5.assign(biasGrad.castTo(DataType.HALF));
        }
        return new Pair<>(gradients, inputGrad);
    }

    /**
     * Runs the cuDNN batch-normalization forward pass (training or inference).
     *
     * <p>Argument roles below follow the positions at which the arrays are handed to the
     * cuDNN forward calls. In training mode the {@code meanCache}/{@code varCache} buffers
     * are (re)allocated outside any workspace when missing or too small, and are passed to
     * cuDNN as the trailing two arguments.
     *
     * @param iNDArray          layer input; its four dimensions size every 4d descriptor
     * @param z                 {@code true} selects the training call, {@code false} the inference call
     * @param iArr              shape used for the gamma/beta descriptor (missing 3rd/4th entries default to 1)
     * @param iNDArray2         scale (gamma) parameter
     * @param iNDArray3         bias (beta) parameter
     * @param iNDArray4         mean array passed to cuDNN; its length also sizes the caches
     * @param iNDArray5         variance array passed to cuDNN
     * @param d                 not used by this implementation; the averaging factor passed to
     *                          cuDNN is the literal 0.0 (NOTE(review): presumably the decay - confirm)
     * @param d2                epsilon; stored in {@code eps} and forwarded to cuDNN
     * @param layerWorkspaceMgr workspace manager used to allocate the returned activations
     * @return the normalized activations, same shape as the input, 'c' order
     */
    public INDArray preOutput(INDArray iNDArray, boolean z, int[] iArr, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, double d, double d2, LayerWorkspaceMgr layerWorkspaceMgr) {
        this.eps = d2;

        // For half-precision input the parameter/statistics arrays are processed as FLOAT copies.
        final boolean isHalf = iNDArray.dataType() == DataType.HALF;
        INDArray scale = iNDArray2;
        INDArray bias = iNDArray3;
        INDArray runningMean = iNDArray4;
        INDArray runningVar = iNDArray5;
        if (isHalf) {
            scale = scale.castTo(DataType.FLOAT);
            bias = bias.castTo(DataType.FLOAT);
            runningMean = runningMean.castTo(DataType.FLOAT);
            runningVar = runningVar.castTo(DataType.FLOAT);
        }

        final int miniBatch = (int) iNDArray.size(0);
        final int depth = (int) iNDArray.size(1);
        final int height = (int) iNDArray.size(2);
        final int width = (int) iNDArray.size(3);

        int[] srcStride = ArrayUtil.toInts(iNDArray.stride());
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.srcTensorDesc, this.dataType,
                miniBatch, depth, height, width,
                srcStride[0], srcStride[1], srcStride[2], srcStride[3]));

        INDArray activations = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, iNDArray.dataType(),
                new long[]{miniBatch, depth, height, width}, 'c');
        int[] dstStride = ArrayUtil.toInts(activations.stride());
        checkCudnn(cudnn.cudnnSetTensor4dDescriptorEx(this.cudnnContext.dstTensorDesc, this.dataType,
                miniBatch, depth, height, width,
                dstStride[0], dstStride[1], dstStride[2], dstStride[3]));

        checkCudnn(cudnn.cudnnSetTensor4dDescriptor(this.cudnnContext.gammaBetaTensorDesc, 0,
                toCudnnDataType(runningMean.data().dataType()),
                iArr[0], iArr[1], iArr.length > 2 ? iArr[2] : 1, iArr.length > 3 ? iArr[3] : 1));

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        CudaContext context = allocator.getFlowController().prepareActionAllWrite(
                new INDArray[]{iNDArray, activations, scale, bias, runningMean, runningVar});
        Pointer srcPtr = allocator.getPointer(iNDArray, context);
        Pointer dstPtr = allocator.getPointer(activations, context);
        Pointer scalePtr = allocator.getPointer(scale, context);
        Pointer biasPtr = allocator.getPointer(bias, context);
        Pointer meanPtr = allocator.getPointer(runningMean, context);
        Pointer varPtr = allocator.getPointer(runningVar, context);

        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            Nd4j.getExecutioner().flushQueue();
        }
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));

        if (z) {
            final long cacheLength = runningMean.length();
            // Both caches use the same rule: created with the input's data type and, for
            // half-precision input, converted to FLOAT to match the FLOAT statistics arrays.
            if (this.meanCache == null || this.meanCache.length() < cacheLength) {
                this.meanCache = createDetachedCache(iNDArray.dataType(), cacheLength, isHalf);
            }
            if (this.varCache == null || this.varCache.length() < cacheLength) {
                this.varCache = createDetachedCache(iNDArray.dataType(), cacheLength, isHalf);
            }
            Pointer meanCachePtr = allocator.getPointer(this.meanCache, context);
            Pointer varCachePtr = allocator.getPointer(this.varCache, context);
            checkCudnn(cudnn.cudnnBatchNormalizationForwardTraining(this.cudnnContext, 1, this.alpha, this.beta,
                    this.cudnnContext.srcTensorDesc, srcPtr,
                    this.cudnnContext.dstTensorDesc, dstPtr,
                    this.cudnnContext.gammaBetaTensorDesc, scalePtr, biasPtr,
                    0.0d, meanPtr, varPtr, d2, meanCachePtr, varCachePtr));
        } else {
            checkCudnn(cudnn.cudnnBatchNormalizationForwardInference(this.cudnnContext, 1, this.alpha, this.beta,
                    this.cudnnContext.srcTensorDesc, srcPtr,
                    this.cudnnContext.dstTensorDesc, dstPtr,
                    this.cudnnContext.gammaBetaTensorDesc, scalePtr, biasPtr, meanPtr, varPtr, d2));
        }

        allocator.getFlowController().registerActionAllWrite(context,
                new INDArray[]{iNDArray, activations, scale, bias, runningMean, runningVar});

        // A single unconditional sync; the former debug-only sync directly before it was redundant.
        context.syncOldStream();

        if (z) {
            // The caches were written on the device by cuDNN; mark them so for the allocator.
            allocator.getAllocationPoint(this.meanCache).tickDeviceWrite();
            allocator.getAllocationPoint(this.varCache).tickDeviceWrite();
        }
        if (z && isHalf) {
            // NOTE(review): these assignments target this method's FLOAT working arrays, not the
            // arrays passed in by the caller (unlike the write-back in backpropGradient) - confirm
            // whether the caller's arrays were meant to be updated here.
            runningMean.assign(runningMean.castTo(DataType.HALF));
            runningVar.assign(runningVar.castTo(DataType.HALF));
            scale.assign(scale.castTo(DataType.HALF));
            bias.assign(bias.castTo(DataType.HALF));
        }
        return activations;
    }

    /**
     * Allocates an uninitialized 1d cache array while scoped out of any workspace,
     * optionally converting it to FLOAT (also outside any workspace).
     *
     * @param dataType    data type the array is created with
     * @param length      number of elements
     * @param castToFloat whether to convert the new array to {@link DataType#FLOAT}
     * @return the newly allocated array
     */
    private static INDArray createDetachedCache(DataType dataType, long length, boolean castToFloat) {
        INDArray cache;
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            cache = Nd4j.createUninitialized(dataType, new long[]{length});
        }
        if (castToFloat) {
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                cache = cache.castTo(DataType.FLOAT);
            }
        }
        return cache;
    }

    /**
     * Returns the cached mean buffer.
     *
     * @param dataType requested data type
     * @return a HALF conversion of {@code meanCache} when {@code dataType} is
     *         {@link DataType#HALF}; otherwise {@code meanCache} itself (no copy)
     */
    public INDArray getMeanCache(DataType dataType) {
        if (dataType == DataType.HALF) {
            return this.meanCache.castTo(DataType.HALF);
        }
        return this.meanCache;
    }

    /**
     * Derives a new array from the cached variance buffer as {@code 1 / (v * v) - eps},
     * where {@code v} is {@code varCache} and {@code eps} is the epsilon of the last
     * forward/backward call. {@code varCache} itself is not modified.
     * NOTE(review): this inverts {@code v = 1 / sqrt(var + eps)}, i.e. it presumably assumes
     * the cache holds cuDNN's saved inverse standard deviation - confirm.
     *
     * @param dataType requested data type; for {@link DataType#HALF} the cache is converted
     *                 to HALF before the computation and the result is converted to HALF again
     * @return the computed array
     */
    public INDArray getVarCache(DataType dataType) {
        final boolean half = dataType == DataType.HALF;
        final INDArray saved = half ? this.varCache.castTo(DataType.HALF) : this.varCache;
        // mul() allocates a new array, so the in-place rdivi/subi never touch the cache.
        final INDArray variance = saved.mul(saved).rdivi(Double.valueOf(1.0d)).subi(Double.valueOf(this.eps));
        return half ? variance.castTo(DataType.HALF) : variance;
    }

    /**
     * Reports the size of the helper-owned cache arrays.
     *
     * @return a map with the keys {@code "meanCache"} and {@code "varCache"}; each value is
     *         the array length multiplied by its buffer's element size, or 0 if the cache
     *         has not been allocated
     */
    public Map<String, Long> helperMemoryUse() {
        Map<String, Long> memoryUse = new HashMap<>();
        memoryUse.put("meanCache", cacheBytes(this.meanCache));
        memoryUse.put("varCache", cacheBytes(this.varCache));
        return memoryUse;
    }

    /** Returns length times element size for {@code cache}, or 0 when {@code cache} is null. */
    private static long cacheBytes(INDArray cache) {
        return cache == null ? 0L : cache.length() * cache.data().getElementSize();
    }
}
