package org.nd4j.linalg.jcublas.buffer.allocation;

import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.allocation.MemoryStrategy;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;

/* loaded from: input_file:org/nd4j/linalg/jcublas/buffer/allocation/PinnedMemoryStrategy.class */
/**
 * {@link MemoryStrategy} whose allocations are page-locked (pinned) host memory obtained from
 * {@code cudaHostAlloc}.
 *
 * <p>Pointer bookkeeping lives in the buffer's pointer table. Entries are looked up by the
 * current thread's name plus an integer key, so each thread sees its own allocations.
 *
 * <p>NOTE(review): {@link #alloc} hands out pinned <em>host</em> memory. {@link #copyToHost}
 * then copies from it using the device-to-host memcpy kind. This presumably relies on unified
 * virtual addressing; confirm against the target CUDA setup.
 */
public class PinnedMemoryStrategy implements MemoryStrategy {

    /** Numeric value of CUDA's {@code cudaMemcpyDeviceToHost} memcpy kind. */
    private static final int CUDA_MEMCPY_DEVICE_TO_HOST = 2;

    /** CUDA runtime status code for success ({@code cudaSuccess}). */
    private static final int CUDA_SUCCESS = 0;

    /**
     * Copies the current thread's allocation registered under {@code i} into the buffer's host
     * pointer, waits for the copy to complete, and returns that host pointer.
     *
     * @param dataBuffer buffer to copy; must be a {@link JCudaBuffer}
     * @param i key under which the current thread's pointer info was registered
     * @return the buffer's host pointer, now holding the copied data
     * @throws IllegalStateException if no pointer info is registered for this thread and key
     */
    public Object copyToHost(DataBuffer dataBuffer, int i) {
        JCudaBuffer jCudaBuffer = (JCudaBuffer) dataBuffer;
        BaseCudaDataBuffer.DevicePointerInfo devicePointerInfo = lookup(jCudaBuffer, i);
        if (devicePointerInfo == null) {
            throw new IllegalStateException("No pointer info registered for thread "
                    + Thread.currentThread().getName() + " and key " + i);
        }
        // alloc() below records the length as an element count, but cudaMemcpyAsync expects
        // bytes. Widen to long so large buffers do not overflow.
        long byteCount = (long) devicePointerInfo.getLength() * dataBuffer.getElementSize();
        JCuda.cudaMemcpyAsync(jCudaBuffer.getHostPointer(), devicePointerInfo.getPointer(),
                byteCount, CUDA_MEMCPY_DEVICE_TO_HOST, ContextHolder.getInstance().getCudaStream());
        // The copy is asynchronous: block until the stream drains so callers never read a
        // partially filled host buffer.
        JCuda.cudaStreamSynchronize(ContextHolder.getInstance().getCudaStream());
        return jCudaBuffer.getHostPointer();
    }

    /**
     * Allocates {@code dataBuffer.getElementSize() * i3} bytes of pinned host memory. Returns
     * them wrapped in a {@link BaseCudaDataBuffer.DevicePointerInfo} that records {@code i3} as
     * its length.
     *
     * @param dataBuffer buffer the allocation is for; only its element size is read here
     * @param i forwarded unchanged as the third {@code DevicePointerInfo} constructor argument
     * @param i2 forwarded unchanged as the fourth {@code DevicePointerInfo} constructor argument
     * @param i3 number of elements to allocate
     * @return the new {@link BaseCudaDataBuffer.DevicePointerInfo}
     * @throws IllegalStateException if {@code cudaHostAlloc} reports an error
     */
    public Object alloc(DataBuffer dataBuffer, int i, int i2, int i3) {
        Pointer pointer = new Pointer();
        BaseCudaDataBuffer.DevicePointerInfo devicePointerInfo =
                new BaseCudaDataBuffer.DevicePointerInfo(pointer, i3, i, i2);
        // Widen before multiplying: an int product overflows for allocations of 2 GiB or more.
        long byteCount = (long) dataBuffer.getElementSize() * i3;
        int status = JCuda.cudaHostAlloc(pointer, byteCount, JCuda.cudaHostAllocDefault);
        if (status != CUDA_SUCCESS) {
            // Without this check a failed allocation would be returned as a usable pointer.
            throw new IllegalStateException("cudaHostAlloc of " + byteCount + " bytes failed: "
                    + JCuda.cudaGetErrorString(status));
        }
        return devicePointerInfo;
    }

    /**
     * Releases the current thread's pinned allocation registered under {@code i} and marks it
     * as freed. Does nothing if no allocation is registered or it was already released.
     *
     * @param dataBuffer buffer owning the allocation; must be a {@link JCudaBuffer}
     * @param i key under which the current thread's pointer info was registered
     */
    public void free(DataBuffer dataBuffer, int i) {
        BaseCudaDataBuffer.DevicePointerInfo devicePointerInfo = lookup((JCudaBuffer) dataBuffer, i);
        // Skip missing entries, and never call cudaFreeHost twice on the same pointer.
        if (devicePointerInfo == null || devicePointerInfo.isFreed()) {
            return;
        }
        JCuda.cudaFreeHost(devicePointerInfo.getPointer());
        devicePointerInfo.setFreed(true);
    }

    /**
     * Looks up the current thread's pointer info for {@code key} in the buffer's pointer table.
     * The callers' null checks assume the table yields null for a missing entry; confirm this
     * against the table implementation.
     */
    private static BaseCudaDataBuffer.DevicePointerInfo lookup(JCudaBuffer buffer, int key) {
        return (BaseCudaDataBuffer.DevicePointerInfo) buffer.getPointersToContexts()
                .get(Thread.currentThread().getName(), Integer.valueOf(key));
    }
}
