package org.nd4j.jita.allocator.impl;

import java.lang.ref.ReferenceQueue;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.LockSupport;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.context.ExternalContext;
import org.nd4j.jita.allocator.enums.Aggressiveness;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.garbage.GarbageBufferReference;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.time.Ring;
import org.nd4j.jita.allocator.time.rings.LockedRing;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.constant.ConstantProtector;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.handler.impl.CudaZeroHandler;
import org.nd4j.jita.workspace.CudaWorkspace;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.rocksdb.HashLinkedListMemTableConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;
import org.springframework.util.backoff.ExponentialBackOff;
import org.springframework.util.backoff.FixedBackOff;

/* loaded from: input_file:org/nd4j/jita/allocator/impl/AtomicAllocator.class */
/**
 * Singleton {@link Allocator} implementation for the CUDA backend.
 *
 * <p>Every allocation is tracked as an {@link AllocationPoint} keyed by an increasing object id in
 * {@code allocationsMap}. Buffers are registered with one of several {@link ReferenceQueue}s; a
 * {@link UnifiedGarbageCollectorThread} per queue releases the memory of buffers that have been
 * enqueued by the JVM. Actual memory operations are delegated to the current {@link MemoryHandler}.
 */
public class AtomicAllocator implements Allocator {
    // Declared before INSTANCE so the logger already exists while the singleton is being constructed.
    private static final Logger log = LoggerFactory.getLogger(AtomicAllocator.class);

    /** Lower bound on the sleep between device GC sweeps, in milliseconds. */
    private static final long DEVICE_GC_MIN_SLEEP_MS = 5000L;
    /** Lower bound on the sleep between zero (host) GC sweeps, in milliseconds. */
    private static final long ZERO_GC_MIN_SLEEP_MS = 10000L;
    /** A device/zero sweep is forced when the previous one is older than this, in milliseconds. */
    private static final long FORCED_SWEEP_INTERVAL_MS = 30000L;
    /** Number of tracked device objects at which a device sweep is triggered. */
    private static final long DEVICE_SWEEP_OBJECT_THRESHOLD = 500L;
    /** Number of tracked host objects at which a zero sweep is triggered. */
    private static final long ZERO_SWEEP_OBJECT_THRESHOLD = 5000L;
    /** How long the polling unified GC thread parks when its queue is empty, in nanoseconds. */
    private static final long UNIFIED_GC_PARK_NANOS = 50000L;
    /** Periodic System GC is only considered if an allocation happened within this window, in ms. */
    private static final long RECENT_ALLOCATION_WINDOW_MS = 3000L;

    protected static ConstantProtector protector;
    private static final AtomicAllocator INSTANCE = new AtomicAllocator();

    private transient MemoryHandler memoryHandler;
    private final AtomicLong allocationsCounter = new AtomicLong(0);
    // Source of object ids for new AllocationPoints.
    private final AtomicLong objectsTracker = new AtomicLong(0);
    // objectId -> AllocationPoint for every allocation still tracked by this allocator.
    private final Map<Long, AllocationPoint> allocationsMap = new ConcurrentHashMap<>();
    // Guards swaps of memoryHandler and configuration.
    private final ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock();
    private final ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock();
    private final Map<Integer, UnifiedGarbageCollectorThread> collectorsUnified = new ConcurrentHashMap<>();
    private final AtomicBoolean shouldStop = new AtomicBoolean(false);
    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
    // NOTE(review): these rings are never written by this class; kept in case other code relies on them.
    private final Ring deviceLong = new LockedRing(30);
    private final Ring deviceShort = new LockedRing(30);
    private final Ring zeroLong = new LockedRing(30);
    private final Ring zeroShort = new LockedRing(30);
    // GC thread index -> the ReferenceQueue that thread drains.
    private final Map<Integer, ReferenceQueue<BaseDataBuffer>> queueMap = new ConcurrentHashMap<>();
    private ConstantHandler constantHandler = Nd4j.getConstantHandler();
    // Wall-clock time (ms) of the most recent allocateMemory call.
    private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis());
    private Configuration configuration = CudaEnvironment.getInstance().getConfiguration();

    /**
     * Periodically sweeps unused device-side allocations of a single device until {@code terminate}
     * is set or the thread is interrupted. Not started by this class ({@link #initDeviceCollectors()}
     * is empty).
     */
    private class DeviceGarbageCollectorThread extends Thread {
        private final Integer deviceId;
        private final AtomicBoolean terminate;

        public DeviceGarbageCollectorThread(Integer deviceId, AtomicBoolean terminate) {
            this.deviceId = deviceId;
            this.terminate = terminate;
            setName("device gc thread [" + deviceId + "]");
            setDaemon(true);
        }

        @Override
        public void run() {
            log.info("Starting device GC for device: {}", deviceId);
            long lastSweep = System.currentTimeMillis();
            while (!terminate.get()) {
                try {
                    Thread.sleep(Math.max(configuration.getMinimumTTLMilliseconds(), DEVICE_GC_MIN_SLEEP_MS));
                } catch (InterruptedException e) {
                    // Treat interruption as a shutdown request and keep the flag visible to the caller.
                    Thread.currentThread().interrupt();
                    return;
                }

                Aggressiveness aggressiveness = configuration.getGpuDeallocAggressiveness();
                long allocatedObjects = memoryHandler.getAllocatedDeviceObjects(deviceId);
                double allocatedMemory = memoryHandler.getAllocatedDeviceMemory(deviceId);
                double maxMemory = configuration.getMaximumDeviceAllocation();

                // Raise the level to at least URGENT under object-count or memory pressure...
                if ((allocatedObjects > 100000 || allocatedMemory > maxMemory * 0.75d)
                        && aggressiveness.compareTo(Aggressiveness.URGENT) < 0) {
                    aggressiveness = Aggressiveness.URGENT;
                }
                // ...and force IMMEDIATE when close to the configured limit.
                if (allocatedMemory > maxMemory * 0.85d) {
                    aggressiveness = Aggressiveness.IMMEDIATE;
                }

                // Skip the sweep only while usage is low AND the previous sweep is recent.
                boolean lowUsage = allocatedMemory < maxMemory * 0.25d
                        && allocatedObjects < DEVICE_SWEEP_OBJECT_THRESHOLD;
                boolean recentlySwept = lastSweep > System.currentTimeMillis() - FORCED_SWEEP_INTERVAL_MS;
                if (!(lowUsage && recentlySwept)) {
                    seekUnusedDevice(0L, deviceId, aggressiveness);
                    lastSweep = System.currentTimeMillis();
                }
            }
        }
    }

    /**
     * Drains one {@link ReferenceQueue} of {@link GarbageBufferReference}s and releases the
     * allocations they point to.
     *
     * <p>Thread 0 polls its queue without blocking and, when the queue is empty, may trigger the
     * periodic System GC configured on the memory manager; every other thread blocks on its queue.
     * The thread runs until it is interrupted.
     */
    public class UnifiedGarbageCollectorThread extends Thread {
        private final ReferenceQueue<BaseDataBuffer> queue;
        private final int threadId;
        // Time (ms) of the last release performed by thread 0; written here, not read by this class.
        private final AtomicLong stopper = new AtomicLong(System.currentTimeMillis());

        public UnifiedGarbageCollectorThread(Integer threadId, @NonNull ReferenceQueue<BaseDataBuffer> queue) {
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            this.queue = queue;
            setDaemon(true);
            setName("UniGC thread " + threadId);
            this.threadId = threadId.intValue();
        }

        @Override
        public void run() {
            while (!isInterrupted()) {
                GarbageBufferReference reference;
                try {
                    reference = threadId == 0
                            ? (GarbageBufferReference) queue.poll()
                            : (GarbageBufferReference) queue.remove();
                } catch (InterruptedException e) {
                    // remove() cleared the flag; restore it and stop like any cancelled worker.
                    Thread.currentThread().interrupt();
                    break;
                }

                if (reference != null) {
                    release(reference.getPoint());
                } else if (threadId == 0) {
                    idle();
                }
            }
            log.debug("{} stopped after interruption", getName());
        }

        /**
         * Releases the tracking state (and, for non-attached points, the memory) of {@code point}.
         *
         * @throws IllegalStateException if an attached point is no longer tracked; this ends the
         *     collector thread
         */
        private void release(AllocationPoint point) {
            if (point.isAttached()) {
                // Attached points got their memory from the current CudaWorkspace in allocateMemory;
                // only the events and the tracking entry are released here.
                if (!allocationsMap.containsKey(point.getObjectId())) {
                    throw new IllegalStateException(
                            "Attached AllocationPoint " + point.getObjectId() + " is not tracked by the allocator");
                }
                getFlowController().waitTillReleased(point);
                getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
                getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
                allocationsMap.remove(point.getObjectId());
                return;
            }

            if (threadId == 0) {
                stopper.set(System.currentTimeMillis());
            }
            if (point.getAllocationStatus() == AllocationStatus.HOST) {
                purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
            } else if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
                purgeDeviceObject(0L, Integer.valueOf(point.getDeviceId()), point.getObjectId(), point, false);
                purgeZeroObject(point.getBucketId(), point.getObjectId(), point, false);
            }
        }

        /**
         * Called by thread 0 when its queue is empty: invokes System GC if periodic GC is active,
         * an allocation happened recently and the GC window has elapsed; otherwise parks briefly.
         */
        private void idle() {
            try {
                if (Nd4j.getMemoryManager().isPeriodicGcActive()) {
                    long now = System.currentTimeMillis();
                    if (useTracker.get() > now - RECENT_ALLOCATION_WINDOW_MS
                            && now > Nd4j.getMemoryManager().getLastGcTime() + Nd4j.getMemoryManager().getAutoGcWindow()) {
                        Nd4j.getMemoryManager().invokeGc();
                        return;
                    }
                }
                LockSupport.parkNanos(UNIFIED_GC_PARK_NANOS);
            } catch (Exception e) {
                // Best effort: a failed GC check must not kill the collector thread.
                log.debug("Periodic GC check failed", e);
            }
        }
    }

    /**
     * Periodically sweeps unused host ("zero") allocations of one bucket until {@code terminate} is
     * set or the thread is interrupted. Not started by this class.
     */
    private class ZeroGarbageCollectorThread extends Thread {
        private final Long bucketId;
        private final AtomicBoolean terminate;

        public ZeroGarbageCollectorThread(Long bucketId, AtomicBoolean terminate) {
            this.bucketId = bucketId;
            this.terminate = terminate;
            setName("zero gc thread " + bucketId);
            setDaemon(true);
        }

        @Override
        public void run() {
            log.debug("Starting zero GC for thread: {}", bucketId);
            long lastSweep = System.currentTimeMillis();
            while (!terminate.get()) {
                try {
                    Thread.sleep(Math.max(configuration.getMinimumTTLMilliseconds(), ZERO_GC_MIN_SLEEP_MS));
                } catch (InterruptedException e) {
                    // Treat interruption as a shutdown request and keep the flag visible to the caller.
                    Thread.currentThread().interrupt();
                    return;
                }

                Aggressiveness aggressiveness = configuration.getHostDeallocAggressiveness();
                long allocatedObjects = memoryHandler.getAllocatedHostObjects(bucketId);
                double allocatedMemory = memoryHandler.getAllocatedHostMemory();
                double maxMemory = configuration.getMaximumZeroAllocation();

                // Raise the level to at least URGENT under object-count or memory pressure...
                if ((allocatedObjects > 500000 || allocatedMemory > maxMemory * 0.75d)
                        && aggressiveness.compareTo(Aggressiveness.URGENT) < 0) {
                    aggressiveness = Aggressiveness.URGENT;
                }
                // ...and force IMMEDIATE when close to the configured limit.
                if (allocatedMemory > maxMemory * 0.85d) {
                    aggressiveness = Aggressiveness.IMMEDIATE;
                }

                // Skip the sweep only while usage is low AND the previous sweep is recent.
                boolean lowUsage = allocatedMemory < maxMemory * 0.25d
                        && allocatedObjects < ZERO_SWEEP_OBJECT_THRESHOLD;
                boolean recentlySwept = lastSweep > System.currentTimeMillis() - FORCED_SWEEP_INTERVAL_MS;
                if (!(lowUsage && recentlySwept)) {
                    seekUnusedZero(bucketId, aggressiveness);
                    lastSweep = System.currentTimeMillis();
                }
            }
        }
    }

    /**
     * Returns the process-wide allocator.
     *
     * @throws RuntimeException if called re-entrantly while the singleton is still being constructed
     */
    public static AtomicAllocator getInstance() {
        // INSTANCE can only be null here when reached recursively from the constructor during class init.
        if (INSTANCE == null) {
            throw new RuntimeException("AtomicAllocator is NULL");
        }
        return INSTANCE;
    }

    private AtomicAllocator() {
        applyConfiguration();
        this.memoryHandler = new CudaZeroHandler();
        this.memoryHandler.init(this.configuration, this);
        initDeviceCollectors();
        initHostCollectors();
        protector = ConstantProtector.getInstance();
    }

    /** Pushes the current configuration's debug/verbose/P2P/grid/thread settings to the native layer. */
    public void applyConfiguration() {
        CudaEnvironment.getInstance().notifyConfigurationApplied();
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(configuration.isDebug());
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableVerboseMode(configuration.isVerbose());
        NativeOpsHolder.getInstance().getDeviceNativeOps().enableP2P(configuration.isCrossDeviceAccessAllowed());
        NativeOpsHolder.getInstance().getDeviceNativeOps().setGridLimit(configuration.getMaximumGridSize());
        NativeOpsHolder.getInstance().getDeviceNativeOps().setOmpNumThreads(configuration.getMaximumBlockSize());
        NativeOpsHolder.getInstance().getDeviceNativeOps().setOmpMinThreads(configuration.getMinimumBlockSize());
    }

    /** Creates one reference queue and one started {@link UnifiedGarbageCollectorThread} per configured GC thread. */
    protected void initHostCollectors() {
        for (int i = 0; i < configuration.getNumberOfGcThreads(); i++) {
            ReferenceQueue<BaseDataBuffer> queue = new ReferenceQueue<>();
            UnifiedGarbageCollectorThread collector = new UnifiedGarbageCollectorThread(i, queue);
            // Attach the collector to the device returned by getDeviceId() before starting it.
            Nd4j.getAffinityManager().attachThreadToDevice(collector, getDeviceId());
            queueMap.put(i, queue);
            collector.start();
            collectorsUnified.put(i, collector);
        }
    }

    /** Hook for starting per-device collectors; intentionally empty in this implementation. */
    protected void initDeviceCollectors() {
    }

    @Override
    public ExternalContext getDeviceContext() {
        return memoryHandler.getDeviceContext();
    }

    /**
     * Replaces the memory handler and initialises it with the current configuration.
     *
     * @throws NullPointerException if {@code memoryHandler} is null
     */
    @Override
    public void setMemoryHandler(@NonNull MemoryHandler memoryHandler) {
        if (memoryHandler == null) {
            throw new NullPointerException("memoryHandler is marked @NonNull but is null");
        }
        globalLock.writeLock().lock();
        try {
            this.memoryHandler = memoryHandler;
            this.memoryHandler.init(configuration, this);
        } finally {
            // Always release, even if init() throws.
            globalLock.writeLock().unlock();
        }
    }

    /**
     * Replaces the configuration unless the allocator is flagged as initialised.
     *
     * @throws NullPointerException if {@code configuration} is null
     */
    @Override
    public void applyConfiguration(@NonNull Configuration configuration) {
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (wasInitialised.get()) {
            return;
        }
        globalLock.writeLock().lock();
        try {
            this.configuration = configuration;
        } finally {
            globalLock.writeLock().unlock();
        }
    }

    @Override
    public Configuration getConfiguration() {
        globalLock.readLock().lock();
        try {
            return configuration;
        } finally {
            globalLock.readLock().unlock();
        }
    }

    @Override
    public Pointer getPointer(DataBuffer buffer, CudaContext context) {
        return memoryHandler.getDevicePointer(buffer, context);
    }

    /** Returns the device pointer of {@code buffer} using the current device context. */
    public Pointer getPointer(DataBuffer buffer) {
        return memoryHandler.getDevicePointer(buffer, (CudaContext) getDeviceContext().getContext());
    }

    /** @deprecated {@code shape} and {@code flag} are ignored; use {@link #getPointer(DataBuffer, CudaContext)}. */
    @Override
    @Deprecated
    public Pointer getPointer(DataBuffer buffer, AllocationShape shape, boolean flag, CudaContext context) {
        return memoryHandler.getDevicePointer(buffer, context);
    }

    @Override
    public Pointer getPointer(INDArray array, CudaContext context) {
        return memoryHandler.getDevicePointer(array.data(), context);
    }

    /** Synchronizes host data for {@code array}, then returns its host pointer. */
    @Override
    public Pointer getHostPointer(INDArray array) {
        synchronizeHostData(array);
        return memoryHandler.getHostPointer(array.data());
    }

    @Override
    public Pointer getHostPointer(DataBuffer buffer) {
        return memoryHandler.getHostPointer(buffer);
    }

    @Override
    public void synchronizeHostData(INDArray array) {
        synchronizeHostData(backingBuffer(array));
    }

    /**
     * Asks the memory handler to synchronize {@code buffer}'s host copy for the calling thread.
     * No-op for constant buffers or when the handler is not device dependant.
     *
     * @throws RuntimeException if the buffer has no tracked AllocationPoint
     */
    @Override
    public void synchronizeHostData(DataBuffer buffer) {
        if (buffer.isConstant() || !memoryHandler.isDeviceDependant()) {
            return;
        }
        AllocationPoint point = getAllocationPoint(buffer.getTrackingPoint());
        if (point == null) {
            throw new RuntimeException("AllocationPoint is NULL");
        }
        memoryHandler.synchronizeThreadDevice(Thread.currentThread().getId(), memoryHandler.getDeviceId(), point);
    }

    /** Returns the device id recorded on {@code array}'s AllocationPoint. */
    public Integer getDeviceId(INDArray array) {
        return Integer.valueOf(getAllocationPoint(array).getDeviceId());
    }

    /**
     * Frees the memory of {@code point} and stops tracking it.
     *
     * <p>For DEVICE points the provider is called twice: once with DEVICE status and once after the
     * status is switched to HOST (presumably releasing the device and then the host side — confirm
     * against the MemoryProvider implementation).
     */
    public void freeMemory(AllocationPoint point) {
        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            getMemoryHandler().getMemoryProvider().free(point);
            point.setAllocationStatus(AllocationStatus.HOST);
            getMemoryHandler().getMemoryProvider().free(point);
            getMemoryHandler().forget(point, AllocationStatus.DEVICE);
        } else {
            getMemoryHandler().getMemoryProvider().free(point);
            getMemoryHandler().forget(point, AllocationStatus.HOST);
        }
        allocationsMap.remove(point.getObjectId());
    }

    /**
     * Allocates memory at the location implied by the configured memory model.
     *
     * @return the new AllocationPoint, or {@code null} if the memory model is neither IMMEDIATE nor DELAYED
     */
    @Override
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape shape, boolean initialize) {
        Configuration.MemoryModel model = configuration.getMemoryModel();
        if (model == Configuration.MemoryModel.IMMEDIATE) {
            return allocateMemory(buffer, shape, memoryHandler.getInitialLocation(), initialize);
        }
        if (model == Configuration.MemoryModel.DELAYED) {
            return allocateMemory(buffer, shape, AllocationStatus.HOST, initialize);
        }
        return null;
    }

    /**
     * Allocates and starts tracking memory for {@code buffer}.
     *
     * <p>Attached buffers get device and host memory from the current {@link CudaWorkspace}; if the
     * workspace returns no device memory, the host pointer is used for both and the point is marked
     * HOST. Other buffers are allocated by the memory handler at {@code location}.
     */
    @Override
    public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape shape, AllocationStatus location, boolean initialize) {
        AllocationPoint point = new AllocationPoint();
        useTracker.set(System.currentTimeMillis());

        Long objectId = objectsTracker.getAndIncrement();
        point.setObjectId(objectId);
        point.setShape(shape);

        // Spread reclamation work by registering with a randomly chosen GC thread's queue.
        ReferenceQueue<BaseDataBuffer> queue = queueMap.get(RandomUtils.nextInt(0, configuration.getNumberOfGcThreads()));
        point.attachReference(new GarbageBufferReference((BaseDataBuffer) buffer, queue, point));
        point.setDeviceId(-1);

        if (buffer.isAttached()) {
            long requiredMemory = AllocationUtils.getRequiredMemory(shape);
            // Result unused; NOTE(review): presumably called for its side effect of preparing a context for this thread.
            getMemoryHandler().getCudaContext();
            point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue());

            CudaWorkspace workspace = (CudaWorkspace) Nd4j.getMemoryManager().getCurrentWorkspace();
            PagedPointer devicePointer = workspace.alloc(requiredMemory, MemoryKind.DEVICE, shape.getDataType(), initialize);
            PagedPointer hostPointer = workspace.alloc(requiredMemory, MemoryKind.HOST, shape.getDataType(), initialize);

            PointersPair pair = new PointersPair();
            pair.setHostPointer(hostPointer);
            if (devicePointer != null) {
                pair.setDevicePointer(devicePointer);
                point.setAllocationStatus(AllocationStatus.DEVICE);
            } else {
                pair.setDevicePointer(hostPointer);
                point.setAllocationStatus(AllocationStatus.HOST);
            }
            point.setAttached(true);
            point.setPointers(pair);
        } else {
            point.setPointers(memoryHandler.alloc(location, point, shape, initialize));
        }

        allocationsMap.put(objectId, point);
        point.tickDeviceWrite();
        return point;
    }

    /** Returns the tracked AllocationPoint for {@code objectId}, or {@code null} if untracked. */
    protected AllocationPoint getAllocationPoint(Long objectId) {
        return allocationsMap.get(objectId);
    }

    /** Stops tracking {@code objectId}, purges its host object, and hands its events back to the events provider. */
    protected void purgeZeroObject(Long bucketId, Long objectId, AllocationPoint point, boolean copyback) {
        allocationsMap.remove(objectId);
        memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback);
        getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent());
        getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent());
    }

    /** Delegates purging of a device object to the memory handler. */
    protected void purgeDeviceObject(Long threadId, Integer deviceId, Long objectId, AllocationPoint point, boolean copyback) {
        memoryHandler.purgeDeviceObject(threadId, deviceId, objectId, point, copyback);
    }

    /**
     * Purges host-resident allocations of {@code bucketId} whose buffer reference is gone.
     *
     * @param aggressiveness currently not used by this implementation
     * @return the number of bytes (per {@link AllocationUtils#getRequiredMemory}) released
     */
    protected synchronized long seekUnusedZero(Long bucketId, Aggressiveness aggressiveness) {
        int checked = (int) memoryHandler.getAllocatedHostObjects(bucketId);
        long freed = 0L;
        int purged = 0;
        int survived = 0;
        for (Long objectId : memoryHandler.getHostTrackingPoints(bucketId)) {
            AllocationPoint point = getAllocationPoint(objectId);
            if (point == null || point.getAllocationStatus() != AllocationStatus.HOST) {
                continue;
            }
            if (point.getBuffer() == null) {
                purgeZeroObject(bucketId, objectId, point, false);
                freed += AllocationUtils.getRequiredMemory(point.getShape());
                purged++;
            } else {
                survived++;
            }
        }
        log.debug("Zero {} elements checked: [{}], deleted: {}, survived: {}", bucketId, checked, purged, survived);
        return freed;
    }

    /**
     * Purges device-resident allocations of {@code deviceId} whose buffer reference is gone.
     *
     * @param aggressiveness currently not used by this implementation
     * @return the number of bytes (per {@link AllocationUtils#getRequiredMemory}) released
     */
    protected long seekUnusedDevice(Long threadId, Integer deviceId, Aggressiveness aggressiveness) {
        long freed = 0L;
        int purged = 0;
        int survived = 0;
        for (Long objectId : memoryHandler.getDeviceTrackingPoints(deviceId)) {
            AllocationPoint point = getAllocationPoint(objectId);
            if (point == null) {
                // Already untracked (e.g. released by a GC thread); nothing to purge.
                continue;
            }
            if (point.getBuffer() != null) {
                survived++;
            } else if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
                purgeDeviceObject(threadId, deviceId, objectId, point, false);
                freed += AllocationUtils.getRequiredMemory(point.getShape());
                purgeZeroObject(point.getBucketId(), objectId, point, false);
                purged++;
            }
        }
        // No relocation is performed by this implementation, hence the constant 0.
        log.debug("Thread/Device [{}/{}] elements purged: [{}]; Relocated: [0]; Survivors: [{}]",
                threadId, deviceId, purged, survived);
        return freed;
    }

    /** Not implemented: always returns 0. */
    public long getTotalAllocatedHostMemory() {
        return 0L;
    }

    /** Returns the number of allocations currently tracked. */
    protected int getTotalTrackingPoints() {
        return allocationsMap.size();
    }

    /** Not implemented: always returns 0. */
    public long getTotalAllocatedDeviceMemory(Integer deviceId) {
        return 0L;
    }

    @Override
    public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        memoryHandler.memcpyAsync(dstBuffer, srcPointer, length, dstOffset);
    }

    @Override
    public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        memoryHandler.memcpySpecial(dstBuffer, srcPointer, length, dstOffset);
    }

    @Override
    public void memcpyDevice(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset, CudaContext context) {
        memoryHandler.memcpyDevice(dstBuffer, srcPointer, length, dstOffset, context);
    }

    @Override
    public void memcpyBlocking(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
        memoryHandler.memcpyBlocking(dstBuffer, srcPointer, length, dstOffset);
    }

    @Override
    public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
        memoryHandler.memcpy(dstBuffer, srcBuffer);
    }

    @Override
    public Integer getDeviceId() {
        return memoryHandler.getDeviceId();
    }

    @Override
    public Pointer getDeviceIdPointer() {
        return new CudaPointer(getDeviceId().intValue());
    }

    @Override
    public void tickHostWrite(DataBuffer buffer) {
        getAllocationPoint(buffer.getTrackingPoint()).tickHostWrite();
    }

    @Override
    public void tickHostWrite(INDArray array) {
        tickHostWrite(backingBuffer(array));
    }

    @Override
    public void tickDeviceWrite(INDArray array) {
        getAllocationPoint(backingBuffer(array).getTrackingPoint()).tickDeviceWrite();
    }

    @Override
    public AllocationPoint getAllocationPoint(INDArray array) {
        return getAllocationPoint(backingBuffer(array));
    }

    /**
     * Returns the AllocationPoint tracked for {@code buffer}.
     *
     * @throws RuntimeException if {@code buffer} is a {@link CompressedDataBuffer}
     */
    @Override
    public AllocationPoint getAllocationPoint(DataBuffer buffer) {
        if (buffer instanceof CompressedDataBuffer) {
            log.warn("Trying to get AllocationPoint from CompressedDataBuffer");
            throw new RuntimeException("AP CDB");
        }
        return getAllocationPoint(buffer.getTrackingPoint());
    }

    @Override
    public void registerAction(CudaContext context, INDArray result, INDArray... operands) {
        memoryHandler.registerAction(context, result, operands);
    }

    @Override
    public FlowController getFlowController() {
        return memoryHandler.getFlowController();
    }

    @Override
    public ContextPool getContextPool() {
        return memoryHandler.getContextPool();
    }

    @Override
    public DataBuffer getConstantBuffer(int[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array);
    }

    @Override
    public DataBuffer getConstantBuffer(float[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array);
    }

    @Override
    public DataBuffer getConstantBuffer(double[] array) {
        return Nd4j.getConstantHandler().getConstantBuffer(array);
    }

    /** Moves {@code buffer} to constant space via the constant handler and returns the same instance. */
    @Override
    public DataBuffer moveToConstant(DataBuffer buffer) {
        Nd4j.getConstantHandler().moveToConstantSpace(buffer);
        return buffer;
    }

    @Override
    public MemoryHandler getMemoryHandler() {
        return memoryHandler;
    }

    /** Returns {@code array.data().originalDataBuffer()} when non-null, otherwise {@code array.data()}. */
    private static DataBuffer backingBuffer(INDArray array) {
        DataBuffer original = array.data().originalDataBuffer();
        return original == null ? array.data() : original;
    }
}
