package org.nd4j.jita.allocator.context.impl;

import java.lang.ref.ReferenceQueue;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;

/* loaded from: input_file:org/nd4j/jita/allocator/context/impl/LimitedContextPool.class */
/**
 * A {@link BasicContextPool} that pre-creates a fixed number of {@link CudaContext}s per device and
 * binds each acquired context to the thread that acquired it.
 *
 * <p>Every acquired context gets a {@link GarbageResourceReference} to the acquiring thread,
 * registered with one of four reference queues chosen at random. Each queue is drained by a daemon
 * {@link ResourceGarbageCollectorThread}, which puts the context back into its device pool and
 * unbinds it from the thread. (Presumably the reference is enqueued once the owning thread has been
 * garbage collected — confirm against {@code GarbageResourceReference}.)
 *
 * <p>When a device pool is empty, acquiring threads call {@code System.gc()} and wait. While the
 * size counter is below three times the configured pool size, the pool is grown by 16 contexts at
 * a time; beyond that, threads keep waiting until a context is returned.
 */
public class LimitedContextPool extends BasicContextPool {
    /** Number of reference queues, and collector threads, created by the constructor. */
    private static final int COLLECTOR_THREADS = 4;
    /** Number of contexts created per growth step when a device pool is exhausted. */
    private static final int POOL_GROWTH_STEP = 16;
    /** Growth stops once {@link #currentPoolSize} reaches this multiple of the configured pool size. */
    private static final int POOL_GROWTH_LIMIT_FACTOR = 3;

    /**
     * Idle contexts per device id. Concurrent because the collector threads, which are started
     * before the constructor fills this map, read it without any other synchronization.
     */
    protected Map<Integer, LinkedBlockingQueue<CudaContext>> pool = new ConcurrentHashMap<>();
    /** Context currently bound to each thread id. */
    protected Map<Long, CudaContext> acquired = new ConcurrentHashMap<>();
    /**
     * Size counter used to cap pool growth. Initialised to the configured per-device pool size and
     * increased by {@link #POOL_GROWTH_STEP} for every growth step.
     */
    protected AtomicInteger currentPoolSize = new AtomicInteger(0);
    /** Collector threads, keyed by index {@code 0 .. COLLECTOR_THREADS-1}. */
    protected Map<Integer, ResourceGarbageCollectorThread> collectors = new HashMap<>();
    /** Reference queue drained by the collector thread with the same index. */
    protected Map<Integer, ReferenceQueue<Thread>> queueMap = new HashMap<>();

    /**
     * Daemon thread that returns contexts to their device pool when their
     * {@link GarbageResourceReference} appears on the associated reference queue.
     */
    public class ResourceGarbageCollectorThread extends Thread {
        private final ReferenceQueue<Thread> queue;

        /**
         * @param i              index used in the thread name
         * @param referenceQueue queue to drain; must not be null
         * @throws NullPointerException if {@code referenceQueue} is null
         */
        public ResourceGarbageCollectorThread(int i, @NonNull ReferenceQueue<Thread> referenceQueue) {
            if (referenceQueue == null) {
                throw new NullPointerException("queue");
            }
            this.queue = referenceQueue;
            setDaemon(true);
            setName("ResourceGC thread " + i);
        }

        /**
         * Blocks on the reference queue and recycles each dequeued context. Exits, with the
         * interrupt flag restored, when the thread is interrupted.
         */
        @Override
        public void run() {
            while (true) {
                GarbageResourceReference reference;
                try {
                    // Block until a reference is enqueued, instead of polling and sleeping.
                    reference = (GarbageResourceReference) queue.remove();
                } catch (InterruptedException e) {
                    // An interrupt is an explicit stop request: keep the flag and terminate.
                    Thread.currentThread().interrupt();
                    return;
                }

                CudaContext context = reference.getContext();
                Long threadId = Long.valueOf(reference.getThreadId());
                LimitedContextPool.this.pool.get(Integer.valueOf(reference.getDeviceId())).add(context);
                // Only unbind when the id still maps to *this* context: Java may reuse the id of a
                // terminated thread, and a new thread with that id may already hold another context.
                LimitedContextPool.this.acquired.remove(threadId, context);
            }
        }
    }

    /**
     * Starts the collector threads, then creates the configured number of contexts on every
     * available device.
     */
    public LimitedContextPool() {
        int poolSize = CudaEnvironment.getInstance().getConfiguration().getPoolSize();
        for (int index = 0; index < COLLECTOR_THREADS; index++) {
            ReferenceQueue<Thread> referenceQueue = new ReferenceQueue<>();
            ResourceGarbageCollectorThread collector = new ResourceGarbageCollectorThread(index, referenceQueue);
            collector.start();
            this.collectors.put(Integer.valueOf(index), collector);
            this.queueMap.put(Integer.valueOf(index), referenceQueue);
        }
        fillPoolWithResources(poolSize, false);
        this.currentPoolSize.set(poolSize);
    }

    /**
     * Adds {@code i} new contexts to the pool of the device reported by
     * {@code AtomicAllocator.getInstance().getDeviceId()}.
     *
     * @param i number of contexts to create
     */
    protected void addResourcesToPool(int i) {
        addResourcesToPool(i, AtomicAllocator.getInstance().getDeviceId().intValue());
    }

    /**
     * Adds {@code count} new contexts, sharing one new cuBLAS handle, to the pool of
     * {@code deviceId}. The caller is expected to have made {@code deviceId} the current device.
     *
     * @param count    number of contexts to create
     * @param deviceId device whose pool receives the contexts
     */
    protected void addResourcesToPool(int count, int deviceId) {
        Integer device = Integer.valueOf(deviceId);
        cublasHandle_t handle = createNewCublasHandle();
        LinkedBlockingQueue<CudaContext> deviceQueue = this.pool.get(device);
        for (int n = 0; n < count; n++) {
            deviceQueue.add(createPooledContext(device, handle));
        }
    }

    /**
     * (Re)creates the pool of every available device with {@code i} fresh contexts. Each device
     * pool gets its own cuBLAS handle.
     *
     * @param i number of contexts per device
     * @param z when true, the device that was current according to {@code AtomicAllocator} is made
     *          current again afterwards; when false, the last device iterated stays current
     */
    protected synchronized void fillPoolWithResources(int i, boolean z) {
        List<Integer> availableDevices = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices();
        int originalDevice = z ? AtomicAllocator.getInstance().getDeviceId().intValue() : 0;
        NativeOps deviceNativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        for (Integer device : availableDevices) {
            deviceNativeOps.setDevice(new CudaPointer(device.intValue()));
            LinkedBlockingQueue<CudaContext> deviceQueue = new LinkedBlockingQueue<>();
            this.pool.put(device, deviceQueue);
            cublasHandle_t handle = createNewCublasHandle();
            for (int n = 0; n < i; n++) {
                deviceQueue.add(createPooledContext(device, handle));
            }
        }
        if (z) {
            deviceNativeOps.setDevice(new CudaPointer(originalDevice));
        }
    }

    /** Creates one stream-backed context for {@code device}, wired to the given cuBLAS handle. */
    private CudaContext createPooledContext(Integer device, cublasHandle_t handle) {
        CudaContext context = createNewStream(device);
        context.initOldStream();
        getDeviceBuffers(context, device.intValue());
        context.setHandle(handle);
        context.syncOldStream();
        return context;
    }

    /**
     * Returns the context bound to the calling thread if it belongs to {@code num}. Otherwise takes
     * a context from the device pool, binding it to the calling thread, and blocks until one is
     * available.
     *
     * @param num device id
     * @return a context for {@code num}, bound to the current thread
     * @throws IllegalArgumentException if there is no pool for {@code num}
     * @throws RuntimeException         wrapping {@link InterruptedException} if the thread is
     *                                  interrupted while waiting (the interrupt flag is restored)
     */
    @Override
    public CudaContext acquireContextForDevice(Integer num) {
        long threadId = Thread.currentThread().getId();
        CudaContext current = this.acquired.get(Long.valueOf(threadId));
        if (current != null && num.intValue() == current.getDeviceId()) {
            return current;
        }

        this.nativeOps.setDevice(new CudaPointer(num.intValue()));
        LinkedBlockingQueue<CudaContext> deviceQueue = this.pool.get(num);
        if (deviceQueue == null) {
            throw new IllegalArgumentException("No context pool for device " + num);
        }

        CudaContext context = deviceQueue.poll();
        while (context == null) {
            // Encourage collection of dead owner threads so their contexts get recycled.
            System.gc();
            try {
                context = deviceQueue.poll(1L, TimeUnit.SECONDS);
                if (context == null && !tryGrowPool(num.intValue())) {
                    logger.warn("Can't allocate new context, sleeping...");
                    System.gc();
                    Thread.sleep(500L);
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
        return bindToCurrentThread(context, num.intValue(), threadId);
    }

    /**
     * Grows the pool of {@code deviceId} by {@link #POOL_GROWTH_STEP} contexts if the size counter
     * is still below the limit. Capacity is reserved atomically first, so concurrent callers cannot
     * all pass the check and overshoot the cap.
     *
     * @return true if the pool was grown, false if the limit has been reached
     */
    private boolean tryGrowPool(int deviceId) {
        int limit = CudaEnvironment.getInstance().getConfiguration().getPoolSize() * POOL_GROWTH_LIMIT_FACTOR;
        int observed;
        do {
            observed = this.currentPoolSize.get();
            if (observed >= limit) {
                return false;
            }
        } while (!this.currentPoolSize.compareAndSet(observed, observed + POOL_GROWTH_STEP));

        try {
            addResourcesToPool(POOL_GROWTH_STEP, deviceId);
        } catch (RuntimeException e) {
            // Creation failed: release the reservation so the counter reflects reality.
            this.currentPoolSize.addAndGet(-POOL_GROWTH_STEP);
            throw e;
        }
        return true;
    }

    /**
     * Registers {@code context} for recycling when the current thread goes away (via a randomly
     * chosen collector queue), records it as the thread's context, and tags it with
     * {@code deviceId}.
     */
    private CudaContext bindToCurrentThread(CudaContext context, int deviceId, long threadId) {
        int collectorIndex = RandomUtils.nextInt(0, this.collectors.size());
        context.attachReference(new GarbageResourceReference(
                Thread.currentThread(), this.queueMap.get(Integer.valueOf(collectorIndex)), context, deviceId));
        this.acquired.put(Long.valueOf(threadId), context);
        context.setDeviceId(deviceId);
        return context;
    }

    /** Wraps {@link #acquireContextForDevice(Integer)} in a {@link ContextPack}. */
    @Override
    public ContextPack acquireContextPackForDevice(Integer num) {
        return new ContextPack(acquireContextForDevice(num));
    }

    /** Same as {@link #acquireContextForDevice(Integer)}. */
    @Override
    public CudaContext getContextForDevice(Integer num) {
        return acquireContextForDevice(num);
    }
}
