package org.deeplearning4j.cuda.dropout;

import com.jakewharton.byteunits.BinaryByteUnit;
import java.util.Collections;
import java.util.Map;
import org.bytedeco.cuda.cudart.CUstream_st;
import org.bytedeco.cuda.cudnn.cudnnDropoutStruct;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.deeplearning4j.cuda.BaseCudnnHelper;
import org.deeplearning4j.nn.conf.dropout.DropoutHelper;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.class */
/**
 * cuDNN-backed {@link DropoutHelper}.
 *
 * <p>The forward pass delegates to {@code cudnnDropoutForward}; the backward pass delegates to
 * {@code cudnnDropoutBackward}. Two buffers are cached across calls and reallocated only when cuDNN
 * reports that a larger one is needed:
 * <ul>
 *   <li>the RNG state buffer ({@link #rngStates});</li>
 *   <li>the reserve-space buffer ({@link #mask}), which the forward pass fills and the backward pass
 *       reads.</li>
 * </ul>
 *
 * <p>NOTE(review): {@code equals}/{@code hashCode} are computed from mutable fields, so instances
 * should not be used as hash keys while they are in use. Thread-safety is not established here.
 */
public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper {
    private static final Logger log = LoggerFactory.getLogger(CudnnDropoutHelper.class);

    /** cuDNN handle plus the tensor and dropout descriptors used by this helper. */
    private CudnnDropoutContext cudnnContext;
    /** True once {@code dropoutDesc} has been configured against the current {@link #rngStates}. */
    private boolean initializedDescriptor;
    /** Buffer handed to cuDNN as the dropout RNG state. */
    private BaseCudnnHelper.DataCache rngStates;
    /** Reserve space written by {@link #applyDropout} and read back by {@link #backprop}. */
    private BaseCudnnHelper.DataCache mask;
    /** Out-parameter for {@code cudnnDropoutGetStatesSize}; allocated lazily. */
    private SizeTPointer stateSizeBytesPtr;
    /** Out-parameter for {@code cudnnDropoutGetReserveSpaceSize}; allocated lazily. */
    private SizeTPointer reserveSizeBytesPtr;
    /** Drop probability the dropout descriptor was last configured with. */
    private float lastInitializedP;

    /**
     * cuDNN handle extended with the descriptors dropout needs.
     *
     * <p>It holds the x/y tensor descriptors for the forward pass and the dy/dx tensor descriptors
     * for the backward pass. It also holds the dropout descriptor itself.
     */
    public static class CudnnDropoutContext extends BaseCudnnHelper.CudnnContext {
        private cudnnTensorStruct xTensorDesc;
        private cudnnTensorStruct dxTensorDesc;
        private cudnnTensorStruct yTensorDesc;
        private cudnnTensorStruct dyTensorDesc;
        private cudnnDropoutStruct dropoutDesc;

        /** Deallocator that destroys the native handles when the owning context is released. */
        private static class Deallocator extends CudnnDropoutContext implements Pointer.Deallocator {
            Deallocator(CudnnDropoutContext cudnnDropoutContext) {
                super(cudnnDropoutContext);
            }

            public void deallocate() {
                destroyHandles();
            }
        }

        /** Creates a fresh context, creates all native handles and registers their deallocator. */
        public CudnnDropoutContext() {
            this.xTensorDesc = new cudnnTensorStruct();
            this.dxTensorDesc = new cudnnTensorStruct();
            this.yTensorDesc = new cudnnTensorStruct();
            this.dyTensorDesc = new cudnnTensorStruct();
            this.dropoutDesc = new cudnnDropoutStruct();
            createHandles();
            deallocator(new Deallocator(this));
        }

        /**
         * Creates a view sharing the native handles of {@code cudnnDropoutContext}.
         *
         * <p>No new handles are created.
         */
        public CudnnDropoutContext(CudnnDropoutContext cudnnDropoutContext) {
            super(cudnnDropoutContext);
            this.xTensorDesc = new cudnnTensorStruct(cudnnDropoutContext.xTensorDesc);
            this.dxTensorDesc = new cudnnTensorStruct(cudnnDropoutContext.dxTensorDesc);
            this.yTensorDesc = new cudnnTensorStruct(cudnnDropoutContext.yTensorDesc);
            this.dyTensorDesc = new cudnnTensorStruct(cudnnDropoutContext.dyTensorDesc);
            this.dropoutDesc = new cudnnDropoutStruct(cudnnDropoutContext.dropoutDesc);
        }

        /** Creates the cuDNN handle (via super) followed by all tensor and dropout descriptors. */
        @Override // org.deeplearning4j.cuda.BaseCudnnHelper.CudnnContext
        public void createHandles() {
            super.createHandles();
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.xTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dxTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.yTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(this.dyTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnCreateDropoutDescriptor(this.dropoutDesc));
        }

        /**
         * Destroys all tensor and dropout descriptors.
         *
         * <p>The cuDNN handle is destroyed last (via super).
         */
        @Override // org.deeplearning4j.cuda.BaseCudnnHelper.CudnnContext
        public void destroyHandles() {
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.xTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dxTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.yTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(this.dyTensorDesc));
            CudnnDropoutHelper.checkCudnn(cudnn.cudnnDestroyDropoutDescriptor(this.dropoutDesc));
            super.destroyHandles();
        }
    }

    /**
     * Creates a helper for the given data type.
     *
     * <p>A new cuDNN context is allocated immediately. The dropout descriptor is configured lazily
     * on the first {@link #applyDropout} call.
     */
    public CudnnDropoutHelper(DataType dataType) {
        super(dataType);
        this.cudnnContext = new CudnnDropoutContext();
        this.initializedDescriptor = false;
    }

    /** Returns an empty map; this helper does not report its cached buffer sizes. */
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    @Override // org.deeplearning4j.cuda.BaseCudnnHelper
    public boolean checkSupported() {
        return true;
    }

    /**
     * Applies cuDNN dropout to {@code inputActivations}, writing the result into {@code resultArray}.
     *
     * <p>The RNG state and reserve-space buffers are (re)allocated if cuDNN reports they are too
     * small. The dropout descriptor is (re)initialized when the state buffer changed or the drop
     * probability differs from the last call.
     *
     * @param inputActivations       tensor read by the forward pass (cuDNN "x")
     * @param resultArray            tensor written by the forward pass (cuDNN "y")
     * @param dropoutInputRetainProb retain probability; {@code 1 - dropoutInputRetainProb} is passed
     *                               to cuDNN as its dropout probability
     */
    public void applyDropout(INDArray inputActivations, INDArray resultArray, double dropoutInputRetainProb) {
        float dropProb = (float) (1.0d - dropoutInputRetainProb);
        setTensorDescriptor(this.cudnnContext.xTensorDesc, inputActivations);
        setTensorDescriptor(this.cudnnContext.yTensorDesc, resultArray);

        if (this.stateSizeBytesPtr == null) {
            this.stateSizeBytesPtr = new SizeTPointer(1L);
        }
        if (this.reserveSizeBytesPtr == null) {
            this.reserveSizeBytesPtr = new SizeTPointer(1L);
        }
        checkCudnn(cudnn.cudnnDropoutGetStatesSize(this.cudnnContext, this.stateSizeBytesPtr));
        long stateSizeBytes = this.stateSizeBytesPtr.get();
        checkCudnn(cudnn.cudnnDropoutGetReserveSpaceSize(this.cudnnContext.xTensorDesc, this.reserveSizeBytesPtr));
        long reserveSizeBytes = this.reserveSizeBytesPtr.get();

        if (this.rngStates == null || this.rngStates.capacity() < stateSizeBytes) {
            if (log.isTraceEnabled()) {
                if (this.rngStates == null) {
                    log.trace("CudnnDropoutHelper: Allocating intial RNG states workspace of size {} ({})", Long.valueOf(stateSizeBytes), BinaryByteUnit.format(stateSizeBytes, "#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})", new Object[]{Long.valueOf(this.rngStates.capacity()), BinaryByteUnit.format(this.rngStates.capacity(), "#.00"), Long.valueOf(stateSizeBytes), BinaryByteUnit.format(stateSizeBytes, "#.00")});
                }
            }
            if (this.rngStates != null) {
                this.rngStates.deallocate();
            }
            this.rngStates = new BaseCudnnHelper.DataCache(stateSizeBytes);
            // The descriptor points at the old state buffer; force it to be re-set below.
            this.initializedDescriptor = false;
        }
        if (this.mask == null || this.mask.capacity() < reserveSizeBytes) {
            if (log.isTraceEnabled()) {
                if (this.mask == null) {
                    log.trace("CudnnDropoutHelper: Allocating intial mask array of size {} ({})", Long.valueOf(reserveSizeBytes), BinaryByteUnit.format(reserveSizeBytes, "#.00"));
                } else {
                    log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})", new Object[]{Long.valueOf(this.mask.capacity()), BinaryByteUnit.format(this.mask.capacity(), "#.00"), Long.valueOf(reserveSizeBytes), BinaryByteUnit.format(reserveSizeBytes, "#.00")});
                }
            }
            if (this.mask != null) {
                this.mask.deallocate();
            }
            this.mask = new BaseCudnnHelper.DataCache(reserveSizeBytes);
        }

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        // resultArray is the array this op writes, so it is the "result"; the input is an operand.
        CudaContext context = allocator.getFlowController().prepareAction(resultArray, new INDArray[]{inputActivations});
        // Bind the handle to this op's stream BEFORE any cuDNN call that launches work:
        // cudnnSetDropoutDescriptor runs an RNG-initialization kernel on the handle's stream, and the
        // forward pass must be ordered after it.
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));

        if (!this.initializedDescriptor || dropProb != this.lastInitializedP) {
            if (log.isTraceEnabled()) {
                log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor");
            }
            long seed = Nd4j.getRandom().nextLong();
            this.lastInitializedP = dropProb;
            checkCudnn(cudnn.cudnnSetDropoutDescriptor(this.cudnnContext.dropoutDesc, this.cudnnContext, dropProb, this.rngStates, this.rngStates.capacity(), seed));
            this.initializedDescriptor = true;
        }

        Pointer xPtr = allocator.getPointer(inputActivations, context);
        Pointer yPtr = allocator.getPointer(resultArray, context);
        checkCudnn(cudnn.cudnnDropoutForward(this.cudnnContext, this.cudnnContext.dropoutDesc, this.cudnnContext.xTensorDesc, xPtr, this.cudnnContext.yTensorDesc, yPtr, this.mask, this.mask.capacity()));
        allocator.registerAction(context, resultArray, new INDArray[]{inputActivations});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
    }

    /**
     * Back-propagates through the most recent {@link #applyDropout} call.
     *
     * <p>The reserve space (mask) written by the forward pass is reused.
     *
     * @param gradAtOutput gradient w.r.t. the dropout output (cuDNN "dy"); read
     * @param gradAtInput  gradient w.r.t. the dropout input (cuDNN "dx"); written
     * @throws IllegalStateException if called before any forward pass has produced a mask
     */
    public void backprop(INDArray gradAtOutput, INDArray gradAtInput) {
        // Fail before any device bookkeeping so no prepared action is left unregistered.
        if (this.mask == null) {
            throw new IllegalStateException("CudnnDropoutHelper: backprop called before applyDropout - no dropout mask available");
        }
        setTensorDescriptor(this.cudnnContext.dyTensorDesc, gradAtOutput);
        setTensorDescriptor(this.cudnnContext.dxTensorDesc, gradAtInput);

        AtomicAllocator allocator = AtomicAllocator.getInstance();
        // gradAtInput is the array this op writes, so it is the "result"; gradAtOutput is an operand.
        CudaContext context = allocator.getFlowController().prepareAction(gradAtInput, new INDArray[]{gradAtOutput});
        // Same as the forward pass: run the backward kernel on this op's stream.
        checkCudnn(cudnn.cudnnSetStream(this.cudnnContext, new CUstream_st(context.getCublasStream())));
        Pointer dyPtr = allocator.getPointer(gradAtOutput, context);
        Pointer dxPtr = allocator.getPointer(gradAtInput, context);
        checkCudnn(cudnn.cudnnDropoutBackward(this.cudnnContext, this.cudnnContext.dropoutDesc, this.cudnnContext.dyTensorDesc, dyPtr, this.cudnnContext.dxTensorDesc, dxPtr, this.mask, this.mask.capacity()));
        allocator.registerAction(context, gradAtInput, new INDArray[]{gradAtOutput});
        if (CudaEnvironment.getInstance().getConfiguration().isDebug()) {
            context.syncOldStream();
        }
    }

    /**
     * Configures {@code descriptor} with the cuDNN-adapted shape and strides of {@code array}.
     *
     * <p>The helper's data type is used.
     */
    private void setTensorDescriptor(cudnnTensorStruct descriptor, INDArray array) {
        int[] shape = adaptForTensorDescr(ArrayUtil.toInts(array.shape()));
        int[] stride = adaptForTensorDescr(ArrayUtil.toInts(array.stride()));
        checkCudnn(cudnn.cudnnSetTensorNdDescriptor(descriptor, this.dataType, shape.length, shape, stride));
    }

    public CudnnDropoutContext getCudnnContext() {
        return this.cudnnContext;
    }

    public boolean isInitializedDescriptor() {
        return this.initializedDescriptor;
    }

    public BaseCudnnHelper.DataCache getRngStates() {
        return this.rngStates;
    }

    public BaseCudnnHelper.DataCache getMask() {
        return this.mask;
    }

    public SizeTPointer getStateSizeBytesPtr() {
        return this.stateSizeBytesPtr;
    }

    public SizeTPointer getReserveSizeBytesPtr() {
        return this.reserveSizeBytesPtr;
    }

    public float getLastInitializedP() {
        return this.lastInitializedP;
    }

    public void setCudnnContext(CudnnDropoutContext cudnnDropoutContext) {
        this.cudnnContext = cudnnDropoutContext;
    }

    public void setInitializedDescriptor(boolean z) {
        this.initializedDescriptor = z;
    }

    public void setRngStates(BaseCudnnHelper.DataCache dataCache) {
        this.rngStates = dataCache;
    }

    public void setMask(BaseCudnnHelper.DataCache dataCache) {
        this.mask = dataCache;
    }

    public void setStateSizeBytesPtr(SizeTPointer sizeTPointer) {
        this.stateSizeBytesPtr = sizeTPointer;
    }

    public void setReserveSizeBytesPtr(SizeTPointer sizeTPointer) {
        this.reserveSizeBytesPtr = sizeTPointer;
    }

    public void setLastInitializedP(float f) {
        this.lastInitializedP = f;
    }

    /**
     * Field-wise equality over all mutable state.
     *
     * <p>The compared fields are the descriptor flag, last drop probability, context, buffers and
     * size pointers.
     */
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CudnnDropoutHelper)) {
            return false;
        }
        CudnnDropoutHelper other = (CudnnDropoutHelper) obj;
        if (!other.canEqual(this) || isInitializedDescriptor() != other.isInitializedDescriptor() || Float.compare(getLastInitializedP(), other.getLastInitializedP()) != 0) {
            return false;
        }
        CudnnDropoutContext thisContext = getCudnnContext();
        CudnnDropoutContext otherContext = other.getCudnnContext();
        if (thisContext == null) {
            if (otherContext != null) {
                return false;
            }
        } else if (!thisContext.equals(otherContext)) {
            return false;
        }
        BaseCudnnHelper.DataCache thisRngStates = getRngStates();
        BaseCudnnHelper.DataCache otherRngStates = other.getRngStates();
        if (thisRngStates == null) {
            if (otherRngStates != null) {
                return false;
            }
        } else if (!thisRngStates.equals(otherRngStates)) {
            return false;
        }
        BaseCudnnHelper.DataCache thisMask = getMask();
        BaseCudnnHelper.DataCache otherMask = other.getMask();
        if (thisMask == null) {
            if (otherMask != null) {
                return false;
            }
        } else if (!thisMask.equals(otherMask)) {
            return false;
        }
        SizeTPointer thisStateSize = getStateSizeBytesPtr();
        SizeTPointer otherStateSize = other.getStateSizeBytesPtr();
        if (thisStateSize == null) {
            if (otherStateSize != null) {
                return false;
            }
        } else if (!thisStateSize.equals(otherStateSize)) {
            return false;
        }
        SizeTPointer thisReserveSize = getReserveSizeBytesPtr();
        SizeTPointer otherReserveSize = other.getReserveSizeBytesPtr();
        return thisReserveSize == null ? otherReserveSize == null : thisReserveSize.equals(otherReserveSize);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof CudnnDropoutHelper;
    }

    /** Hash consistent with {@link #equals(Object)}; depends on mutable state. */
    public int hashCode() {
        int result = (((1 * 59) + (isInitializedDescriptor() ? 79 : 97)) * 59) + Float.floatToIntBits(getLastInitializedP());
        CudnnDropoutContext context = getCudnnContext();
        result = (result * 59) + (context == null ? 43 : context.hashCode());
        BaseCudnnHelper.DataCache states = getRngStates();
        result = (result * 59) + (states == null ? 43 : states.hashCode());
        BaseCudnnHelper.DataCache reserve = getMask();
        result = (result * 59) + (reserve == null ? 43 : reserve.hashCode());
        SizeTPointer stateSize = getStateSizeBytesPtr();
        result = (result * 59) + (stateSize == null ? 43 : stateSize.hashCode());
        SizeTPointer reserveSize = getReserveSizeBytesPtr();
        return (result * 59) + (reserveSize == null ? 43 : reserveSize.hashCode());
    }

    public String toString() {
        return "CudnnDropoutHelper(cudnnContext=" + getCudnnContext() + ", initializedDescriptor=" + isInitializedDescriptor() + ", rngStates=" + getRngStates() + ", mask=" + getMask() + ", stateSizeBytesPtr=" + getStateSizeBytesPtr() + ", reserveSizeBytesPtr=" + getReserveSizeBytesPtr() + ", lastInitializedP=" + getLastInitializedP() + ")";
    }
}
