package org.nd4j.jita.allocator.context.impl;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.cublas;
import org.bytedeco.javacpp.cusolver;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.CUcontext;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;

/* loaded from: input_file:org/nd4j/jita/allocator/context/impl/BasicContextPool.class */
/**
 * {@link ContextPool} implementation that hands out one {@link CudaContext} per calling thread.
 *
 * <p>Contexts are cached by thread id in {@link #contextsPool} and additionally indexed per device
 * in {@link #contextsForDevices}. The first context created for a device also creates that
 * device's cuBLAS and cuSolver handles, which are stored in {@link #cublasPool} and
 * {@link #solverPool} and attached to every later context created for the same device.
 *
 * <p>Creation of new contexts is serialized through {@link #lock}; lookups of already cached
 * contexts do not take the lock.
 */
public class BasicContextPool implements ContextPool {
    private static final Logger log = LoggerFactory.getLogger(BasicContextPool.class);

    /** Number of contexts per device after which existing contexts are reused instead of created. */
    protected static final int MAX_STREAMS_PER_DEVICE = Integer.MAX_VALUE - 1;

    /** Device id -> CUcontext. Only read in this class (see {@link #getCuContextForDevice}). */
    protected volatile Map<Integer, CUcontext> cuPool = new ConcurrentHashMap<>();

    /** Device id -> cuBLAS handle shared by all contexts of that device. */
    protected volatile Map<Integer, cublasHandle_t> cublasPool = new ConcurrentHashMap<>();

    /** Device id -> cuSolver handle shared by all contexts of that device. */
    protected volatile Map<Integer, cusolverDnHandle_t> solverPool = new ConcurrentHashMap<>();

    /** Thread id -> context handed out to that thread. */
    protected volatile Map<Long, CudaContext> contextsPool = new ConcurrentHashMap<>();

    /** Device id -> (creation index -> context) for every context created on that device. */
    protected volatile Map<Integer, Map<Integer, CudaContext>> contextsForDevices = new ConcurrentHashMap<>();

    /** Single-permit semaphore guarding context creation. */
    protected Semaphore lock = new Semaphore(1);

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    /**
     * Tells whether a context has already been handed out to the given thread.
     *
     * @param j thread id
     * @return {@code true} if {@link #contextsPool} holds an entry for that thread id
     */
    public boolean containsContextForThread(long j) {
        return this.contextsPool.containsKey(Long.valueOf(j));
    }

    /**
     * Convenience alias for {@link #acquireContextForDevice(Integer)}.
     *
     * @param num device id
     * @return the context bound to the calling thread
     */
    public CudaContext getContextForDevice(Integer num) {
        return acquireContextForDevice(num);
    }

    /**
     * Returns the context bound to the calling thread, creating (or, once
     * {@link #MAX_STREAMS_PER_DEVICE} contexts exist for the device, reusing) one if the thread
     * has none yet.
     *
     * <p>If the calling thread already has a cached context it is returned as-is, regardless of
     * the requested device.
     *
     * @param num device id the new context should be created on
     * @return the context bound to the calling thread
     * @throws RuntimeException wrapping any exception raised while creating the context; if the
     *     thread is interrupted while waiting for the lock its interrupt flag is restored first
     */
    @Override // org.nd4j.jita.allocator.context.ContextPool
    public CudaContext acquireContextForDevice(Integer num) {
        final Long threadId = Long.valueOf(Thread.currentThread().getId());

        // Fast path: single lookup, so a concurrent removal cannot make us return null.
        final CudaContext cached = this.contextsPool.get(threadId);
        if (cached != null) {
            return cached;
        }

        // Tracks whether we actually hold the permit. Releasing a Semaphore permit that was
        // never acquired would raise the permit count above 1 and break mutual exclusion.
        boolean locked = false;
        try {
            this.lock.acquire();
            locked = true;

            Map<Integer, CudaContext> deviceContexts = this.contextsForDevices.get(num);
            if (deviceContexts == null) {
                deviceContexts = new ConcurrentHashMap<>();
                this.contextsForDevices.put(num, deviceContexts);
            }

            if (deviceContexts.size() >= MAX_STREAMS_PER_DEVICE) {
                // Device is saturated: bind this thread to a randomly chosen existing context.
                final Integer reusedIndex = Integer.valueOf(RandomUtils.nextInt(0, MAX_STREAMS_PER_DEVICE));
                log.debug("Reusing context: {}", reusedIndex);
                this.nativeOps.setDevice(new CudaPointer(num.intValue()));
                final CudaContext reused = deviceContexts.get(reusedIndex);
                this.contextsPool.put(threadId, reused);
                return reused;
            }

            log.debug("Creating new context...");
            final CudaContext context = createNewStream(num);
            getDeviceBuffers(context, num.intValue());

            if (deviceContexts.size() == 0) {
                // First context on this device: create the per-device library handles, each
                // paired with the stream of a freshly created context.
                log.debug("Creating new cuBLAS handle for device [{}]...", num);
                final cudaStream_t cublasStream = createNewStream(num).getOldStream();
                final cublasHandle_t cublasHandle = createNewCublasHandle(cublasStream);
                context.setHandle(cublasHandle);
                context.setCublasStream(cublasStream);
                this.cublasPool.put(num, cublasHandle);

                log.debug("Creating new cuSolver handle for device [{}]...", num);
                final cudaStream_t solverStream = createNewStream(num).getOldStream();
                final cusolverDnHandle_t solverHandle = createNewSolverHandle(solverStream);
                context.setSolverHandle(solverHandle);
                context.setSolverStream(solverStream);
                this.solverPool.put(num, solverHandle);
            } else {
                log.debug("Reusing blas here...");
                context.setHandle(this.cublasPool.get(num));
                log.debug("Reusing solver here...");
                context.setSolverHandle(this.solverPool.get(num));
            }

            context.syncOldStream();
            this.contextsPool.put(threadId, context);
            deviceContexts.put(Integer.valueOf(deviceContexts.size()), context);
            return context;
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            if (locked) {
                this.lock.release();
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Selects the given device and creates a new {@link CudaContext} with its old stream
     * initialized.
     *
     * @param num device id
     * @return the newly created context
     */
    public CudaContext createNewStream(Integer num) {
        log.trace("Creating new stream for thread: [{}], device: [{}]...", Long.valueOf(Thread.currentThread().getId()), num);
        this.nativeOps.setDevice(new CudaPointer(num.intValue()));
        CudaContext cudaContext = new CudaContext();
        cudaContext.initOldStream();
        return cudaContext;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a new cuBLAS handle via {@code cublasCreate_v2}.
     *
     * @return wrapper around the created cuBLAS context
     * @throws IllegalStateException if {@code cublasCreate_v2} returns a non-zero status
     */
    public cublasHandle_t createNewCublasHandle() {
        cublas.cublasContext cublasContext = new cublas.cublasContext();
        int status = cublas.cublasCreate_v2(cublasContext);
        if (status != 0) {
            throw new IllegalStateException("Can't create new cuBLAS handle! cuBLAS errorCode: [" + status + "]");
        }
        return new cublasHandle_t(cublasContext);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a new cuBLAS handle. The stream argument is not used by this implementation.
     *
     * @param cudastream_t ignored
     * @return wrapper around the created cuBLAS context
     * @throws IllegalStateException if the handle cannot be created
     */
    public cublasHandle_t createNewCublasHandle(cudaStream_t cudastream_t) {
        return createNewCublasHandle();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a new cuSolver dense handle via {@code cusolverDnCreate}.
     *
     * @return wrapper around the created cuSolver context
     * @throws IllegalStateException if {@code cusolverDnCreate} returns a non-zero status
     */
    public cusolverDnHandle_t createNewSolverHandle() {
        cusolver.cusolverDnContext solverContext = new cusolver.cusolverDnContext();
        int status = cusolver.cusolverDnCreate(solverContext);
        if (status != 0) {
            throw new IllegalStateException("Can't create new cuSolver handle! cusolverDn errorCode: [" + status + "] from cusolverDnCreate()");
        }
        return new cusolverDnHandle_t(solverContext);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a new cuSolver handle. The stream argument is not used by this implementation.
     *
     * @param cudastream_t ignored
     * @return wrapper around the created cuSolver context
     * @throws IllegalStateException if the handle cannot be created
     */
    public cusolverDnHandle_t createNewSolverHandle(cudaStream_t cudastream_t) {
        return createNewSolverHandle();
    }

    /**
     * Placeholder for CUcontext creation; this implementation creates nothing.
     *
     * @param num device id (unused)
     * @return always {@code null}
     */
    protected CUcontext createNewContext(Integer num) {
        return null;
    }

    /**
     * No-op in this implementation.
     *
     * @param i device id (unused)
     */
    public synchronized void resetPool(int i) {
    }

    /**
     * Looks up the CUcontext registered for a device.
     *
     * @param num device id
     * @return the entry from {@link #cuPool}, or {@code null} if none is registered; this class
     *     itself never adds entries to that map
     */
    public CUcontext getCuContextForDevice(Integer num) {
        return this.cuPool.get(num);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Allocates the reduction, allocation, scalar and special buffers for a context and attaches
     * them to it.
     *
     * @param cudaContext context that receives the buffers; its old stream is used for the memsets
     * @param i device id passed to the device allocations
     * @throws IllegalStateException if any allocation returns {@code null}
     */
    public void getDeviceBuffers(CudaContext cudaContext, int i) {
        NativeOps ops = NativeOpsHolder.getInstance().getDeviceNativeOps();

        // Buffer sizes in bytes; the factor 8 is the per-element size used throughout.
        final int elementSize = 8;
        final int reductionBytes = 16385 * elementSize * 2;
        final long allocationBytes = 1048576L;
        final int scalarBytes = 1 * elementSize;
        final int specialBytes = 1048576 * elementSize;
        final int specialZeroedBytes = 65536 * elementSize;

        Pointer reductionBuffer = ops.mallocDevice(reductionBytes, new CudaPointer(i), 0);
        if (reductionBuffer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] reduction buffer memory!");
        }
        ops.memsetAsync(reductionBuffer, 0, reductionBytes, 0, cudaContext.getOldStream());
        cudaContext.syncOldStream();

        Pointer allocationBuffer = ops.mallocDevice(allocationBytes, new CudaPointer(i), 0);
        if (allocationBuffer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] allocation buffer memory!");
        }

        Pointer scalarBuffer = ops.mallocHost(scalarBytes, 0);
        if (scalarBuffer == null) {
            throw new IllegalStateException("Can't allocate [HOST] scalar buffer memory!");
        }

        cudaContext.setBufferScalar(scalarBuffer);
        cudaContext.setBufferAllocation(allocationBuffer);
        cudaContext.setBufferReduction(reductionBuffer);

        Pointer specialBuffer = ops.mallocDevice(specialBytes, new CudaPointer(i), 0);
        if (specialBuffer == null) {
            throw new IllegalStateException("Can't allocate [DEVICE] special buffer memory!");
        }
        // NOTE(review): only the first 65536 * 8 bytes of the 1048576 * 8 byte buffer are zeroed,
        // and this method does not sync the stream afterwards - confirm both are intentional.
        ops.memsetAsync(specialBuffer, 0, specialZeroedBytes, 0, cudaContext.getOldStream());
        cudaContext.setBufferSpecial(specialBuffer);
    }

    /**
     * Wraps the calling thread's context for the given device in a {@link ContextPack}.
     *
     * @param num device id
     * @return a new pack built from {@link #acquireContextForDevice(Integer)}
     */
    @Override // org.nd4j.jita.allocator.context.ContextPool
    public ContextPack acquireContextPackForDevice(Integer num) {
        return new ContextPack(acquireContextForDevice(num));
    }
}
