package org.nd4j.jita.memory.impl;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.memory.MemoryProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/memory/impl/CudaCachingZeroProvider.class */
/**
 * Memory provider that recycles host allocations.
 *
 * <p>Host chunks released through {@link #free(AllocationPoint)} are zero-filled and parked in a
 * per-shape queue instead of being returned to the direct provider. Later host
 * {@link #malloc(AllocationShape, AllocationPoint, AllocationStatus)} requests for the same shape
 * are served from that queue when possible. Everything that is not a cacheable host request is
 * delegated to {@link CudaDirectProvider}.
 */
public class CudaCachingZeroProvider extends CudaDirectProvider implements MemoryProvider {
    private static final Logger log = LoggerFactory.getLogger(CudaCachingZeroProvider.class);

    /** Chunks above this size are never requested from the background preallocator (16 MiB). */
    private static final long PREALLOCATION_MAX_CHUNK_BYTES = 16L * 1024L * 1024L;

    /** Per-shape queues of cached host pointers. */
    protected volatile ConcurrentHashMap<AllocationShape, CacheHolder> zeroCache = new ConcurrentHashMap<>();
    protected final AtomicLong cacheZeroHit = new AtomicLong(0);
    protected final AtomicLong cacheZeroMiss = new AtomicLong(0);
    protected final AtomicLong cacheDeviceHit = new AtomicLong(0);
    protected final AtomicLong cacheDeviceMiss = new AtomicLong(0);
    private final AtomicLong allocRequests = new AtomicLong(0);
    /** Number of bytes currently accounted as sitting in the host cache. */
    protected final AtomicLong zeroCachedAmount = new AtomicLong(0);
    /** Per-device cached byte counters; this class never adds elements to the list itself. */
    protected List<AtomicLong> deviceCachedAmount = new ArrayList<>();
    /** Binary semaphore serialising creation of new {@link CacheHolder} entries. */
    protected final Semaphore singleLock = new Semaphore(1);
    // NOTE(review): not consulted by this class any more (see free()); kept because it is
    // protected and may be read by subclasses.
    protected final long FORCED_CACHE_THRESHOLD = 96;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Queue of cached pointers for a single allocation shape, with an element counter.
     *
     * <p>Every {@link #put(Pointer)} also adds the shape's required memory to the shared byte
     * counter passed to the constructor. {@link #poll()} does not subtract from it; callers
     * adjust the shared counter themselves.
     */
    public class CacheHolder {
        private final Queue<Pointer> queue = new ConcurrentLinkedQueue<>();
        private volatile int counter = 0;
        private final long reqMem;
        private final AtomicLong allocCounter;

        /**
         * @param allocationShape shape whose required memory is charged on every {@code put}
         * @param atomicLong shared byte counter that {@code put} increases
         */
        public CacheHolder(AllocationShape allocationShape, AtomicLong atomicLong) {
            this.reqMem = AllocationUtils.getRequiredMemory(allocationShape);
            this.allocCounter = atomicLong;
        }

        /** @return number of pointers currently held */
        public synchronized int size() {
            return this.counter;
        }

        /**
         * Removes and returns the oldest cached pointer.
         *
         * @return a cached pointer, or {@code null} when the queue is empty
         */
        public synchronized Pointer poll() {
            Pointer head = this.queue.poll();
            if (head != null) {
                this.counter--;
            }
            return head;
        }

        /**
         * Stores a pointer and charges its size to the shared byte counter.
         *
         * @param pointer pointer to cache
         */
        public synchronized void put(Pointer pointer) {
            this.allocCounter.addAndGet(this.reqMem);
            this.counter++;
            this.queue.add(pointer);
        }
    }

    /**
     * Background thread that allocates {@code target} chunks of one shape through the direct
     * provider and, for host allocations, places them into the cache.
     */
    protected class CachePreallocator extends Thread implements Runnable {
        private AllocationShape shape;
        private AllocationStatus location;
        private int target;

        /**
         * @param allocationShape shape to preallocate
         * @param allocationStatus memory location passed to the direct provider
         * @param i number of chunks to allocate
         */
        public CachePreallocator(AllocationShape allocationShape, AllocationStatus allocationStatus, int i) {
            this.shape = allocationShape;
            this.target = i;
            this.location = allocationStatus;
        }

        @Override // java.lang.Thread, java.lang.Runnable
        public void run() {
            CudaCachingZeroProvider.this.ensureCacheHolder(this.shape);
            for (int chunk = 0; chunk < this.target; chunk++) {
                PointersPair allocated = CudaCachingZeroProvider.super.malloc(this.shape, new AllocationPoint(), this.location);
                // NOTE(review): for any location other than HOST the allocation is neither stored
                // nor released here; within this class the thread is only started for HOST.
                if (this.location == AllocationStatus.HOST) {
                    Pointer hostPointer = new CudaPointer(allocated.getHostPointer().address());
                    CudaCachingZeroProvider.this.zeroCache.get(this.shape).put(hostPointer);
                }
            }
        }
    }

    /**
     * Allocates memory, serving cacheable host requests from the cache when a chunk of the same
     * shape is available.
     *
     * <p>Requests that are not for host memory, or whose size exceeds the configured maximum
     * cacheable length, go straight to the direct provider. On a cache miss the request is also
     * delegated, and a {@link CachePreallocator} may be started when preallocation is enabled,
     * the cache holds less than a tenth of its configured maximum and the chunk is smaller than
     * 16 MiB.
     *
     * @param allocationShape shape of the requested chunk
     * @param allocationPoint allocation point; its status is set to HOST on a cache hit
     * @param allocationStatus requested memory location
     * @return pointers for the allocation; on a cache hit both the host and the device pointer
     *         refer to the cached host address
     */
    @Override // org.nd4j.jita.memory.impl.CudaDirectProvider, org.nd4j.jita.memory.MemoryProvider
    public PointersPair malloc(AllocationShape allocationShape, AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
        long requiredMemory = AllocationUtils.getRequiredMemory(allocationShape);

        // The size bound is inclusive so that it matches free(): a chunk that free() is willing
        // to cache must also be eligible to be served from the cache.
        if (allocationStatus != AllocationStatus.HOST
                || requiredMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength()) {
            return super.malloc(allocationShape, allocationPoint, allocationStatus);
        }

        CacheHolder cacheHolder = this.zeroCache.get(allocationShape);
        Pointer cached = (cacheHolder == null) ? null : cacheHolder.poll();

        if (cached == null) {
            // Count the miss exactly once per request.
            this.cacheZeroMiss.incrementAndGet();
            if (CudaEnvironment.getInstance().getConfiguration().isUsePreallocation()
                    && this.zeroCachedAmount.get() < CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache() / 10
                    && requiredMemory < PREALLOCATION_MAX_CHUNK_BYTES) {
                new CachePreallocator(allocationShape, allocationStatus,
                        CudaEnvironment.getInstance().getConfiguration().getPreallocationCalls()).start();
            }
            return super.malloc(allocationShape, allocationPoint, allocationStatus);
        }

        this.cacheZeroHit.incrementAndGet();
        // CacheHolder.put() charged this chunk to the counter; take it back out.
        this.zeroCachedAmount.addAndGet(-requiredMemory);

        PointersPair pointersPair = new PointersPair();
        pointersPair.setDevicePointer(new CudaPointer(cached.address()));
        pointersPair.setHostPointer(new CudaPointer(cached.address()));
        allocationPoint.setAllocationStatus(AllocationStatus.HOST);

        MemoryTracker.getInstance().incrementAllocatedHostAmount(requiredMemory);
        MemoryTracker.getInstance().decrementCachedHostAmount(requiredMemory);
        return pointersPair;
    }

    /**
     * Makes sure {@link #zeroCache} contains a {@link CacheHolder} for the given shape.
     *
     * @param allocationShape shape that needs a cache entry
     * @throws RuntimeException if the calling thread is interrupted while waiting for the lock;
     *         the thread's interrupt flag is restored before throwing
     */
    protected void ensureCacheHolder(AllocationShape allocationShape) {
        if (this.zeroCache.containsKey(allocationShape)) {
            return;
        }

        // Acquire outside the try/finally: the permit must only be released if it was obtained.
        try {
            this.singleLock.acquire();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }

        try {
            // Re-check under the lock; another thread may have created the entry meanwhile.
            if (!this.zeroCache.containsKey(allocationShape)) {
                this.zeroCache.put(allocationShape, new CacheHolder(allocationShape, this.zeroCachedAmount));
            }
        } finally {
            this.singleLock.release();
        }
    }

    /**
     * Releases an allocation, caching the host chunk when it is eligible.
     *
     * <p>Device allocations are forwarded to the direct provider. Points without a host pointer
     * are ignored. A host chunk is forwarded to the direct provider when it is larger than the
     * configured maximum cacheable length or when the cache has reached its configured maximum;
     * otherwise it is zero-filled and stored in the cache for its shape.
     *
     * @param allocationPoint allocation to release
     */
    @Override // org.nd4j.jita.memory.impl.CudaDirectProvider, org.nd4j.jita.memory.MemoryProvider
    public void free(AllocationPoint allocationPoint) {
        if (allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
            super.free(allocationPoint);
            return;
        }
        if (allocationPoint.getHostPointer() == null) {
            return;
        }

        AllocationShape shape = allocationPoint.getShape();
        long requiredMemory = AllocationUtils.getRequiredMemory(shape);

        boolean tooLarge = requiredMemory > CudaEnvironment.getInstance().getConfiguration().getMaximumHostCacheableLength();
        if (tooLarge || this.zeroCachedAmount.get() >= CudaEnvironment.getInstance().getConfiguration().getMaximumHostCache()) {
            super.free(allocationPoint);
            return;
        }

        ensureCacheHolder(shape);
        CacheHolder cacheHolder = this.zeroCache.get(shape);

        // Chunks at or below FORCED_CACHE_THRESHOLD and larger chunks were handled by two
        // identical branches, so a single path is used. The chunk is zeroed before it is cached.
        Pointer.memset(allocationPoint.getHostPointer(), 0, requiredMemory);
        cacheHolder.put(new CudaPointer(allocationPoint.getHostPointer().address()));

        MemoryTracker.getInstance().decrementAllocatedHostAmount(requiredMemory);
        MemoryTracker.getInstance().incrementCachedHostAmount(requiredMemory);
    }

    /**
     * Computes a hit percentage from hit and miss counts.
     *
     * @return hits as a percentage of all requests, or 0 when there were no requests
     */
    private static float hitRatio(long hits, long misses) {
        long total = hits + misses;
        if (total == 0) {
            // Avoid reporting NaN before the first request.
            return 0.0f;
        }
        return ((float) (hits * 100)) / ((float) total);
    }

    /** @return host cache hit percentage */
    private float getZeroCacheHitRatio() {
        return hitRatio(this.cacheZeroHit.get(), this.cacheZeroMiss.get());
    }

    /** @return device cache hit percentage */
    private float getDeviceCacheHitRatio() {
        return hitRatio(this.cacheDeviceHit.get(), this.cacheDeviceMiss.get());
    }

    /**
     * Logs cache counters at debug level.
     *
     * <p>The device amount is read from the first element of {@link #deviceCachedAmount}; 0 is
     * logged when the list is empty.
     */
    @Deprecated
    public void printCacheStats() {
        long deviceCached = this.deviceCachedAmount.isEmpty() ? 0L : this.deviceCachedAmount.get(0).get();
        log.debug("Cached host amount: {}", this.zeroCachedAmount.get());
        log.debug("Cached device amount: {}", deviceCached);
        log.debug("Total shapes in cache: {}", this.zeroCache.size());
        log.debug("Current host hit ratio: {}", getZeroCacheHitRatio());
        log.debug("Current device hit ratio: {}", getDeviceCacheHitRatio());
    }

    /**
     * Releases every cached host pointer and resets the cached byte counter to zero.
     *
     * <p>Cache entries themselves stay in {@link #zeroCache}; only their queues are drained.
     */
    @Override // org.nd4j.jita.memory.impl.CudaDirectProvider, org.nd4j.jita.memory.MemoryProvider
    public void purgeCache() {
        for (AllocationShape allocationShape : this.zeroCache.keySet()) {
            CacheHolder cacheHolder = this.zeroCache.get(allocationShape);
            if (cacheHolder == null) {
                continue;
            }
            // Drain until the queue reports empty; poll() returns null at that point.
            Pointer cached;
            while ((cached = cacheHolder.poll()) != null) {
                freeHost(cached);
                MemoryTracker.getInstance().decrementCachedHostAmount(allocationShape.getNumberOfBytes());
            }
        }
        this.zeroCachedAmount.set(0L);
    }
}
