package org.nd4j.jita.memory.impl;

import java.util.concurrent.ConcurrentHashMap;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/memory/impl/CudaFullCachingProvider.class */
/**
 * Memory provider that keeps a per-device, per-shape cache of device pointers on top of
 * {@link CudaCachingZeroProvider}.
 *
 * <p>{@link #free(AllocationPoint)} stores eligible device pointers in {@link #deviceCache}
 * instead of releasing them, and {@link #malloc(AllocationShape, AllocationPoint, AllocationStatus)}
 * hands a cached pointer back when one with the same shape is available for the current device.
 * Every request that is not eligible for caching is delegated to the parent class.
 */
public class CudaFullCachingProvider extends CudaCachingZeroProvider {
    /** Requests of this many bytes or more bypass the device cache entirely. */
    protected final long MAX_GPU_ALLOCATION = this.configuration.getMaximumSingleDeviceAllocation();
    /** Once {@code deviceCachedAmount} reaches this value, freed pointers are no longer cached. */
    protected final long MAX_GPU_CACHE = this.configuration.getMaximumDeviceCache();
    /** Cache holders keyed by device id, then by allocation shape. */
    protected volatile ConcurrentHashMap<Integer, ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder>> deviceCache = new ConcurrentHashMap<>();
    private static final Logger log = LoggerFactory.getLogger(CudaFullCachingProvider.class);

    /**
     * Allocates memory for the given shape, reusing a cached device pointer when possible.
     *
     * <p>Only {@link AllocationStatus#DEVICE} requests smaller than {@link #MAX_GPU_ALLOCATION}
     * are looked up in the cache; everything else, and every cache miss, goes to the parent.
     *
     * @param allocationShape  shape describing the requested buffer
     * @param allocationPoint  allocation point; on a cache hit its status and device id are updated
     * @param allocationStatus requested memory location
     * @return on a cache hit, a pair holding only the cached device pointer; otherwise whatever
     *         the parent implementation returns
     */
    @Override // org.nd4j.jita.memory.impl.CudaCachingZeroProvider, org.nd4j.jita.memory.impl.CudaDirectProvider, org.nd4j.jita.memory.MemoryProvider
    public PointersPair malloc(AllocationShape allocationShape, AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
        long requiredMemory = AllocationUtils.getRequiredMemory(allocationShape);
        if (allocationStatus != AllocationStatus.DEVICE || requiredMemory >= this.MAX_GPU_ALLOCATION) {
            return super.malloc(allocationShape, allocationPoint, allocationStatus);
        }
        int deviceId = AtomicAllocator.getInstance().getDeviceId().intValue();
        ensureDeviceCacheHolder(Integer.valueOf(deviceId), allocationShape);
        CudaCachingZeroProvider.CacheHolder cacheHolder = this.deviceCache.get(Integer.valueOf(deviceId)).get(allocationShape);
        Pointer cached = cacheHolder == null ? null : cacheHolder.poll();
        if (cached == null) {
            this.cacheDeviceMiss.incrementAndGet();
            return super.malloc(allocationShape, allocationPoint, allocationStatus);
        }
        this.cacheDeviceHit.incrementAndGet();
        // The pointer has left the cache, so it no longer counts towards the cached amount.
        this.deviceCachedAmount.addAndGet(-requiredMemory);
        PointersPair pointersPair = new PointersPair();
        pointersPair.setDevicePointer(cached);
        allocationPoint.setAllocationStatus(AllocationStatus.DEVICE);
        allocationPoint.setDeviceId(deviceId);
        return pointersPair;
    }

    /**
     * Releases the memory behind the given allocation point, caching device pointers when eligible.
     *
     * <p>Non-device allocations, allocations of {@link #MAX_GPU_ALLOCATION} bytes or more, and any
     * allocation freed while the cache is at or above {@link #MAX_GPU_CACHE} are passed to the
     * parent. The size bound matches the one in {@code malloc}, so nothing is cached that
     * {@code malloc} would never look up.
     *
     * @param allocationPoint allocation point to release
     * @throws RuntimeException if the point's device id changes during the call, or if the
     *                          validator maps the pointer address to a different device
     */
    @Override // org.nd4j.jita.memory.impl.CudaCachingZeroProvider, org.nd4j.jita.memory.impl.CudaDirectProvider, org.nd4j.jita.memory.MemoryProvider
    public void free(AllocationPoint allocationPoint) {
        if (allocationPoint.getAllocationStatus() != AllocationStatus.DEVICE) {
            super.free(allocationPoint);
            return;
        }
        AllocationShape shape = allocationPoint.getShape();
        int deviceId = allocationPoint.getDeviceId();
        long address = allocationPoint.getDevicePointer().address();
        long requiredMemory = AllocationUtils.getRequiredMemory(shape);
        if (requiredMemory >= this.MAX_GPU_ALLOCATION || this.deviceCachedAmount.get() >= this.MAX_GPU_CACHE) {
            super.free(allocationPoint);
            return;
        }
        ensureDeviceCacheHolder(Integer.valueOf(deviceId), shape);
        CudaCachingZeroProvider.CacheHolder cacheHolder = this.deviceCache.get(Integer.valueOf(deviceId)).get(shape);
        // Re-read the device id to detect a change since it was first sampled above.
        if (allocationPoint.getDeviceId() != deviceId) {
            throw new RuntimeException("deviceId changed!");
        }
        if (this.validator.get(Long.valueOf(address)).intValue() != deviceId) {
            log.error("MISMATCH: {}", Long.valueOf(address));
            throw new RuntimeException("PEW");
        }
        // Cache the address that was just validated rather than re-reading the pointer.
        cacheHolder.put(new CudaPointer(address));
    }

    /**
     * Makes sure {@link #deviceCache} contains a cache holder for the given device and shape.
     *
     * <p>Creation happens while holding {@code singleLock}; the lock is acquired at most once per
     * call and released only if it was actually acquired.
     *
     * @param num             device id
     * @param allocationShape shape the holder is keyed by
     * @throws RuntimeException if acquiring the lock or creating the holder fails
     */
    protected void ensureDeviceCacheHolder(Integer num, AllocationShape allocationShape) {
        ConcurrentHashMap<AllocationShape, CudaCachingZeroProvider.CacheHolder> shapeCache = this.deviceCache.get(num);
        if (shapeCache != null && shapeCache.containsKey(allocationShape)) {
            return;
        }
        boolean acquired = false;
        try {
            this.singleLock.acquire();
            acquired = true;
            // Re-check under the lock: another thread may have created the entries meanwhile.
            shapeCache = this.deviceCache.get(num);
            if (shapeCache == null) {
                shapeCache = new ConcurrentHashMap<>();
                this.deviceCache.put(num, shapeCache);
            }
            if (!shapeCache.containsKey(allocationShape)) {
                shapeCache.put(allocationShape, new CudaCachingZeroProvider.CacheHolder(this, allocationShape, this.deviceCachedAmount));
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            if (e instanceof RuntimeException) {
                throw (RuntimeException) e;
            }
            throw new RuntimeException(e);
        } finally {
            if (acquired) {
                this.singleLock.release();
            }
        }
    }
}
