package org.deeplearning4j.nn.layers;

import java.util.Arrays;
import lombok.NonNull;
import org.bytedeco.cuda.cudnn.cudnnContext;
import org.bytedeco.cuda.cudnn.cudnnTensorStruct;
import org.bytedeco.cuda.global.cudart;
import org.bytedeco.cuda.global.cudnn;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper.class */
public abstract class BaseCudnnHelper {
    private static final Logger log = LoggerFactory.getLogger(BaseCudnnHelper.class);
    protected static final int TENSOR_FORMAT = 0;
    protected final DataType nd4jDataType;
    protected final int dataType;
    protected final int dataTypeSize;
    protected final Pointer alpha;
    protected final Pointer beta;
    protected SizeTPointer sizeInBytes = new SizeTPointer(1);

    /* renamed from: org.deeplearning4j.nn.layers.BaseCudnnHelper$1, reason: invalid class name */
    /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$1.class */
    static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];

        static {
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.INT.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.HALF.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$CudnnContext.class */
    protected static class CudnnContext extends cudnnContext {

        /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$CudnnContext$Deallocator.class */
        protected static class Deallocator extends CudnnContext implements Pointer.Deallocator {
            Deallocator(CudnnContext cudnnContext) {
                super(cudnnContext);
            }

            public void deallocate() {
                destroyHandles();
            }
        }

        public CudnnContext() {
            Nd4j.create(1);
            AtomicAllocator.getInstance();
        }

        public CudnnContext(CudnnContext cudnnContext) {
            super(cudnnContext);
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public void createHandles() {
            BaseCudnnHelper.checkCudnn(cudnn.cudnnCreate(this));
        }

        /* JADX INFO: Access modifiers changed from: protected */
        public void destroyHandles() {
            BaseCudnnHelper.checkCudnn(cudnn.cudnnDestroy(this));
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$DataCache.class */
    protected static class DataCache extends Pointer {

        /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$DataCache$Deallocator.class */
        static class Deallocator extends DataCache implements Pointer.Deallocator {
            Deallocator(DataCache dataCache) {
                super(dataCache);
            }

            public void deallocate() {
                BaseCudnnHelper.checkCuda(cudart.cudaFree(this));
                setNull();
            }
        }

        /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$DataCache$HostDeallocator.class */
        static class HostDeallocator extends DataCache implements Pointer.Deallocator {
            HostDeallocator(DataCache dataCache) {
                super(dataCache);
            }

            public void deallocate() {
                BaseCudnnHelper.checkCuda(cudart.cudaFreeHost(this));
                setNull();
            }
        }

        public DataCache() {
        }

        public DataCache(long j) {
            this.position = 0L;
            this.capacity = j;
            this.limit = j;
            int cudaMalloc = cudart.cudaMalloc(this, j);
            if (cudaMalloc == 0) {
                deallocator(new Deallocator(this));
                return;
            }
            BaseCudnnHelper.log.warn("Cannot allocate " + j + " bytes of device memory (CUDA error = " + cudaMalloc + "), proceeding with host memory");
            BaseCudnnHelper.checkCuda(cudart.cudaMallocHost(this, j));
            deallocator(new HostDeallocator(this));
        }

        public DataCache(DataCache dataCache) {
            super(dataCache);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$TensorArray.class */
    protected static class TensorArray extends PointerPointer<cudnnTensorStruct> {

        /* loaded from: input_file:org/deeplearning4j/nn/layers/BaseCudnnHelper$TensorArray$Deallocator.class */
        static class Deallocator extends TensorArray implements Pointer.Deallocator {
            Pointer owner;

            Deallocator(TensorArray tensorArray, Pointer pointer) {
                this.address = tensorArray.address;
                this.capacity = tensorArray.capacity;
                this.owner = pointer;
            }

            public void deallocate() {
                for (int i = BaseCudnnHelper.TENSOR_FORMAT; !isNull() && i < this.capacity; i++) {
                    BaseCudnnHelper.checkCudnn(cudnn.cudnnDestroyTensorDescriptor(get(cudnnTensorStruct.class, i)));
                }
                if (this.owner != null) {
                    this.owner.deallocate();
                    this.owner = null;
                }
                setNull();
            }
        }

        public TensorArray() {
        }

        public TensorArray(long j) {
            PointerPointer pointerPointer = new PointerPointer(j);
            pointerPointer.deallocate(false);
            this.address = pointerPointer.address();
            this.limit = pointerPointer.limit();
            this.capacity = pointerPointer.capacity();
            cudnnTensorStruct cudnntensorstruct = new cudnnTensorStruct();
            for (int i = BaseCudnnHelper.TENSOR_FORMAT; i < this.capacity; i++) {
                BaseCudnnHelper.checkCudnn(cudnn.cudnnCreateTensorDescriptor(cudnntensorstruct));
                put(i, cudnntensorstruct);
            }
            deallocator(new Deallocator(this, pointerPointer));
        }

        public TensorArray(TensorArray tensorArray) {
            super(tensorArray);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public static void checkCuda(int i) {
        if (i != 0) {
            throw new RuntimeException("CUDA error = " + i + ": " + cudart.cudaGetErrorString(i).getString());
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public static void checkCudnn(int i) {
        if (i != 0) {
            throw new RuntimeException("cuDNN status = " + i + ": " + cudnn.cudnnGetErrorString(i).getString());
        }
    }

    public BaseCudnnHelper(@NonNull DataType dataType) {
        if (dataType == null) {
            throw new NullPointerException("dataType is marked @NonNull but is null");
        }
        this.nd4jDataType = dataType;
        this.dataType = dataType == DataType.DOUBLE ? 1 : dataType == DataType.FLOAT ? TENSOR_FORMAT : 2;
        this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2;
        this.alpha = this.dataType == 1 ? new DoublePointer(new double[]{1.0d}) : new FloatPointer(new float[]{1.0f});
        this.beta = this.dataType == 1 ? new DoublePointer(new double[]{0.0d}) : new FloatPointer(new float[]{TENSOR_FORMAT});
    }

    public static int toCudnnDataType(DataType dataType) {
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$buffer$DataType[dataType.ordinal()]) {
            case 1:
                return 1;
            case 2:
                return TENSOR_FORMAT;
            case 3:
                return 4;
            case 4:
                return 2;
            default:
                throw new RuntimeException("Cannot convert type: " + dataType);
        }
    }

    public boolean checkSupported() {
        return true;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public static int[] adaptForTensorDescr(int[] iArr) {
        if (iArr.length >= 4) {
            return iArr;
        }
        int[] iArr2 = new int[4];
        int i = TENSOR_FORMAT;
        while (i < iArr.length) {
            iArr2[i] = iArr[i];
            i++;
        }
        while (i < 4) {
            iArr2[i] = 1;
            i++;
        }
        return iArr2;
    }
}
