package org.nd4j.jita.memory.impl;

import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/memory/impl/CudaDirectProvider.class */
/**
 * {@link MemoryProvider} that allocates and releases memory directly through {@link NativeOps}.
 *
 * <p>Every {@link #malloc} call issues a fresh native host or device allocation and every
 * {@link #free} call releases it immediately; no caching is done here. {@code malloc} and
 * {@code free} are {@code synchronized} on this instance.
 */
public class CudaDirectProvider implements MemoryProvider {
    /** Free device memory (50 MiB) that {@link #pingDeviceForFreeMemory} keeps in reserve. */
    protected static final long DEVICE_RESERVED_SPACE = 50L * 1024 * 1024;

    private static final Logger log = LoggerFactory.getLogger(CudaDirectProvider.class);

    /** Native operations used for every allocation, deallocation and free-memory query. */
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    // NOTE(review): not read or written anywhere in this class.
    private AtomicLong emergencyCounter = new AtomicLong(0);

    /**
     * Allocates memory for {@code allocationPoint} in the requested location.
     *
     * <p>For {@link AllocationStatus#HOST} a single host buffer is allocated; the returned pair's
     * host and device pointers both reference that buffer, and the pair is attached to the point.
     *
     * <p>For {@link AllocationStatus#DEVICE} a device buffer is allocated and stored as the device
     * pointer of the point's existing {@link PointersPair}, or of a new pair if the point has none.
     *
     * @param allocationShape  shape used to compute the allocation size (clamped to at least 1)
     * @param allocationPoint  point whose pointers/status/device id are updated
     * @param allocationStatus target location; only HOST and DEVICE are supported
     * @return the pointers pair, or {@code null} if the DEVICE allocation returned no pointer
     * @throws RuntimeException      if the HOST allocation returned no pointer
     * @throws IllegalStateException if {@code allocationStatus} is neither HOST nor DEVICE
     */
    @Override // org.nd4j.jita.memory.MemoryProvider
    public synchronized PointersPair malloc(AllocationShape allocationShape, AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
        switch (allocationStatus) {
            case HOST:
                return allocateHost(allocationShape, allocationPoint);
            case DEVICE:
                return allocateDevice(allocationShape, allocationPoint);
            default:
                throw new IllegalStateException("Unsupported location for malloc: [" + allocationStatus + "]");
        }
    }

    /** Allocates a host buffer and attaches it to {@code point} as both its host and device pointer. */
    private PointersPair allocateHost(AllocationShape shape, AllocationPoint point) {
        long bytes = requiredBytes(shape);
        Pointer hostMemory = nativeOps.mallocHost(bytes, 0);
        if (hostMemory == null) {
            throw new RuntimeException("Can't allocate [HOST] memory: " + bytes + "; threadId: " + Thread.currentThread().getId());
        }
        CudaPointer base = new CudaPointer(hostMemory);
        PointersPair pair = new PointersPair();
        pair.setDevicePointer(new CudaPointer(base, bytes));
        pair.setHostPointer(new CudaPointer(base, bytes));
        point.setPointers(pair);
        point.setAllocationStatus(AllocationStatus.HOST);
        return pair;
    }

    /** Allocates a device buffer and records it as the device pointer of {@code point}'s pointers pair. */
    private PointersPair allocateDevice(AllocationShape shape, AllocationPoint point) {
        long bytes = requiredBytes(shape);
        Pointer deviceMemory = nativeOps.mallocDevice(bytes, (Pointer) null, 0);
        if (deviceMemory == null) {
            return null;
        }
        CudaPointer base = new CudaPointer(deviceMemory);
        PointersPair pair = point.getPointers();
        if (pair == null) {
            // NOTE(review): unlike the HOST path, this new pair is only returned, not attached to
            // the point via setPointers — confirm callers attach it.
            pair = new PointersPair();
        }
        pair.setDevicePointer(new CudaPointer(base, bytes));
        point.setAllocationStatus(AllocationStatus.DEVICE);
        point.setDeviceId(AtomicAllocator.getInstance().getDeviceId());
        return pair;
    }

    /**
     * Returns the allocation size for {@code shape}, clamped to at least 1
     * (presumably so a zero-length shape never requests a zero-byte native allocation).
     */
    private static long requiredBytes(AllocationShape shape) {
        return Math.max(1L, AllocationUtils.getRequiredMemory(shape));
    }

    /**
     * Releases the native memory held by {@code allocationPoint} according to its current status.
     *
     * <p>Uses the same {@link #nativeOps} instance as {@link #malloc}, so memory is freed through
     * the ops that allocated it.
     *
     * @param allocationPoint point whose HOST or DEVICE memory is released
     * @throws RuntimeException      if the native free call reports failure (returns 0)
     * @throws IllegalStateException if the point's status is neither HOST nor DEVICE
     */
    @Override // org.nd4j.jita.memory.MemoryProvider
    public synchronized void free(AllocationPoint allocationPoint) {
        AllocationStatus status = allocationPoint.getAllocationStatus();
        switch (status) {
            case HOST:
                if (nativeOps.freeHost(allocationPoint.getPointers().getHostPointer()) == 0) {
                    throw new RuntimeException("Can't deallocate [HOST] memory...");
                }
                return;
            case DEVICE:
                if (nativeOps.freeDevice(allocationPoint.getPointers().getDevicePointer(), new CudaPointer(0L)) == 0) {
                    throw new RuntimeException("Can't deallocate [DEVICE] memory...");
                }
                return;
            default:
                throw new IllegalStateException("Can't free memory on target [" + status + "]");
        }
    }

    /**
     * Checks whether at least {@link #DEVICE_RESERVED_SPACE} would remain free on the device after
     * an allocation of {@code j}.
     *
     * @param num device id — NOTE(review): currently ignored; free memory is queried with a device
     *            argument of -1 (presumably "current device") — confirm
     * @param j   requested allocation size, in the same units as {@code getDeviceFreeMemory}
     *            (presumably bytes)
     * @return {@code true} if the reserve would still be available after the allocation
     */
    @Override // org.nd4j.jita.memory.MemoryProvider
    public boolean pingDeviceForFreeMemory(Integer num, long j) {
        long freeMemory = nativeOps.getDeviceFreeMemory(new CudaPointer(-1L));
        return freeMemory - j >= DEVICE_RESERVED_SPACE;
    }
}
