package org.nd4j.jita.memory;

import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.BasicMemoryManager;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/memory/CudaMemoryManager.class */
/**
 * CUDA implementation of the ND4J memory manager.
 *
 * <p>Raw HOST/DEVICE blocks are obtained from, and returned to, the native ops layer
 * ({@link NativeOpsHolder}). Array/buffer-level operations (copy, zero-fill, explicit release) go
 * through {@link AtomicAllocator} and update each buffer's {@link AllocationPoint} host/device
 * bookkeeping ({@code tickHostWrite}/{@code tickDeviceWrite}).
 */
public class CudaMemoryManager extends BasicMemoryManager {
    private static final Logger log = LoggerFactory.getLogger(CudaMemoryManager.class);

    /**
     * Allocates a raw block of memory directly from the native layer.
     *
     * @param bytes      size of the block in bytes
     * @param kind       where to allocate: {@link MemoryKind#HOST} or {@link MemoryKind#DEVICE}
     * @param initialize if {@code true}, the block is zero-filled before it is returned; for DEVICE
     *                   memory the fill runs on the context's special stream, which is then
     *                   synchronized
     * @return pointer to the new block
     * @throws RuntimeException          if the native allocation yields no memory (null object or
     *                                   null address), or {@code kind} is neither HOST nor DEVICE
     * @throws ND4JIllegalStateException if zero-filling a DEVICE block fails
     */
    public Pointer allocate(long bytes, MemoryKind kind, boolean initialize) {
        // Result intentionally unused; the call is kept for its side effect.
        // NOTE(review): presumably forces allocator initialization before native allocations - confirm.
        AtomicAllocator.getInstance();

        if (kind == MemoryKind.HOST) {
            Pointer hostPtr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocHost(bytes, 0);
            if (isNullPointer(hostPtr)) {
                throw new RuntimeException("Failed to allocate " + bytes + " bytes from HOST memory");
            }
            if (initialize) {
                Pointer.memset(hostPtr, 0, bytes);
            }
            return hostPtr;
        }

        if (kind != MemoryKind.DEVICE) {
            throw new RuntimeException("Unknown MemoryKind requested: " + kind);
        }

        Pointer devicePtr = NativeOpsHolder.getInstance().getDeviceNativeOps().mallocDevice(bytes, (Pointer) null, 0);
        if (isNullPointer(devicePtr)) {
            throw new RuntimeException("Failed to allocate " + bytes + " bytes from DEVICE ["
                    + Nd4j.getAffinityManager().getDeviceForCurrentThread() + "] memory");
        }
        if (initialize) {
            CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
            // memsetAsync reports failure by returning 0.
            if (NativeOpsHolder.getInstance().getDeviceNativeOps()
                    .memsetAsync(devicePtr, 0, bytes, 0, context.getSpecialStream()) == 0) {
                throw new ND4JIllegalStateException("memset failed on device_"
                        + Nd4j.getAffinityManager().getDeviceForCurrentThread());
            }
            // The memset is asynchronous; wait so the caller never sees uninitialized memory.
            context.getSpecialStream().synchronize();
        }
        return devicePtr;
    }

    /**
     * Immediately releases the memory backing the given arrays.
     *
     * <p>Pending ops are committed first. {@code null} entries and views are skipped (presumably
     * because a view shares its parent's buffer). For each remaining array:
     * <ul>
     *   <li>HOST status: the host copy is freed;</li>
     *   <li>DEVICE status: both the device and the host copies are freed;</li>
     *   <li>DEALLOCATED status: nothing is freed.</li>
     * </ul>
     * The allocation point is then marked {@link AllocationStatus#DEALLOCATED}.
     *
     * @param arrays arrays whose memory should be released
     * @throws RuntimeException if an array has any other allocation status; the message contains
     *                          the index of the offending argument
     */
    public void collect(INDArray... arrays) {
        Nd4j.getExecutioner().commit();
        AtomicAllocator allocator = AtomicAllocator.getInstance();
        for (int index = 0; index < arrays.length; index++) {
            INDArray array = arrays[index];
            if (array == null || array.isView()) {
                continue;
            }
            AllocationPoint point = allocator.getAllocationPoint(array);
            AllocationStatus status = point.getAllocationStatus();
            if (status == AllocationStatus.HOST) {
                allocator.getMemoryHandler().free(point, AllocationStatus.HOST);
            } else if (status == AllocationStatus.DEVICE) {
                allocator.getMemoryHandler().free(point, AllocationStatus.DEVICE);
                allocator.getMemoryHandler().free(point, AllocationStatus.HOST);
            } else if (status != AllocationStatus.DEALLOCATED) {
                throw new RuntimeException("Unknown AllocationStatus: " + status + " for argument: " + index);
            }
            point.setAllocationStatus(AllocationStatus.DEALLOCATED);
        }
    }

    /**
     * Asks the allocator's memory provider to purge its cache.
     */
    public synchronized void purgeCaches() {
        AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().purgeCache();
    }

    /**
     * Copies the contents of {@code srcBuffer} into {@code dstBuffer}.
     *
     * <ul>
     *   <li>Destination compressed, source not: the source is first synchronized to the host if its
     *       host copy is stale. Its host memory is then copied into the destination's address.</li>
     *   <li>Source compressed, destination not: the source's address memory is copied into the
     *       destination's address, and the destination is marked as written on the host.</li>
     *   <li>Both compressed: a plain memcpy between their addresses.</li>
     *   <li>Neither compressed: delegated to {@link AtomicAllocator#memcpy}.</li>
     * </ul>
     * In the compressed cases the copy size is {@code srcBuffer.length() * srcBuffer.getElementSize()}
     * bytes.
     *
     * @param dstBuffer buffer to copy into
     * @param srcBuffer buffer to copy from
     */
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        boolean dstCompressed = dstBuffer instanceof CompressedDataBuffer;
        boolean srcCompressed = srcBuffer instanceof CompressedDataBuffer;

        if (dstCompressed && !srcCompressed) {
            AllocationPoint srcPoint = AtomicAllocator.getInstance().getAllocationPoint(srcBuffer);
            long bytes = srcBuffer.getElementSize() * srcBuffer.length();
            if (!srcPoint.isActualOnHostSide()) {
                AtomicAllocator.getInstance().synchronizeHostData(srcBuffer);
            }
            Pointer.memcpy(dstBuffer.addressPointer(), AtomicAllocator.getInstance().getHostPointer(srcBuffer), bytes);
        } else if (!dstCompressed && srcCompressed) {
            AllocationPoint dstPoint = AtomicAllocator.getInstance().getAllocationPoint(dstBuffer);
            Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(),
                    srcBuffer.getElementSize() * srcBuffer.length());
            // The destination was written through its address pointer, i.e. on the host side.
            dstPoint.tickHostWrite();
        } else if (dstCompressed) {
            // Both buffers are compressed.
            Pointer.memcpy(dstBuffer.addressPointer(), srcBuffer.addressPointer(),
                    srcBuffer.length() * srcBuffer.getElementSize());
        } else {
            // Neither buffer is compressed.
            AtomicAllocator.getInstance().memcpy(dstBuffer, srcBuffer);
        }
    }

    /**
     * Frees a raw block through the native layer.
     *
     * <p>DEVICE blocks go to {@code freeDevice}, and HOST blocks go to {@code freeHost}. Any other kind is
     * ignored.
     *
     * @param pointer block to free
     * @param kind    memory kind the block was allocated from
     */
    public void release(Pointer pointer, MemoryKind kind) {
        if (kind == MemoryKind.DEVICE) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pointer, (Pointer) null);
        } else if (kind == MemoryKind.HOST) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pointer);
        }
    }

    /**
     * Sets the auto-GC window in the base manager and mirrors it to the CUDA environment's
     * no-GC window setting.
     *
     * @param windowMs window length, passed to {@code setNoGcWindowMs}
     */
    public void setAutoGcWindow(int windowMs) {
        super.setAutoGcWindow(windowMs);
        CudaEnvironment.getInstance().getConfiguration().setNoGcWindowMs(windowMs);
    }

    /**
     * Fills the given array with zeros.
     *
     * <p>Views are zeroed element-wise via {@code assign(0.0)}, followed by a commit. For non-views,
     * the whole backing buffer is cleared on the side indicated by the allocation status:
     * <ul>
     *   <li>DEVICE: an async memset on the context's old stream, which is then synchronized; the
     *       device copy is marked as written;</li>
     *   <li>HOST: pending ops are committed, then a plain host memset; the host copy is marked as
     *       written.</li>
     * </ul>
     * Arrays with any other status are left untouched.
     *
     * @param array array to zero
     * @throws ND4JIllegalStateException if the device-side memset fails
     */
    public void memset(INDArray array) {
        if (array.isView()) {
            // Only the viewed region may be cleared, not the whole shared buffer.
            array.assign(Double.valueOf(0.0d));
            Nd4j.getExecutioner().commit();
            return;
        }
        Nd4j.getExecutioner().push();
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(array);
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
            // Checked exactly like in allocate(): memsetAsync reports failure by returning 0.
            if (NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(
                    AtomicAllocator.getInstance().getPointer(array, context), 0, bufferBytes(array), 0,
                    context.getOldStream()) == 0) {
                throw new ND4JIllegalStateException("memset failed on device_"
                        + Nd4j.getAffinityManager().getDeviceForCurrentThread());
            }
            context.getOldStream().synchronize();
            point.tickDeviceWrite();
            return;
        }
        if (point.getAllocationStatus() == AllocationStatus.HOST) {
            Nd4j.getExecutioner().commit();
            Pointer.memset(AtomicAllocator.getInstance().getHostPointer(array), 0, bufferBytes(array));
            point.tickHostWrite();
        }
    }

    /** Returns the size in bytes of the array's entire backing buffer (length times element size). */
    private static long bufferBytes(INDArray array) {
        return array.data().length() * Nd4j.sizeOfDataType(array.data().dataType());
    }

    /** Returns true when a native allocation produced no usable memory (null object or address 0). */
    private static boolean isNullPointer(Pointer pointer) {
        return pointer == null || pointer.isNull();
    }
}
