package org.nd4j.jita.handler.impl;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.concurrency.DeviceAllocationsTracker;
import org.nd4j.jita.allocator.context.ContextPool;
import org.nd4j.jita.allocator.context.ExternalContext;
import org.nd4j.jita.allocator.context.impl.LimitedContextPool;
import org.nd4j.jita.allocator.context.impl.PackedContextPool;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.PointersPair;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.jita.flow.impl.AsynchronousFlowController;
import org.nd4j.jita.flow.impl.SynchronousFlowController;
import org.nd4j.jita.handler.MemoryHandler;
import org.nd4j.jita.memory.MemoryProvider;
import org.nd4j.jita.memory.impl.CudaCachingZeroProvider;
import org.nd4j.jita.memory.impl.CudaDirectProvider;
import org.nd4j.jita.memory.impl.CudaFullCachingProvider;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/handler/impl/CudaZeroHandler.class */
public class CudaZeroHandler implements MemoryHandler {
    /** Shared configuration; static and reassigned by {@code init(...)}, therefore not final. */
    private static Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger log = LoggerFactory.getLogger(CudaZeroHandler.class);
    /** Per-device byte accounting; created in {@code init(...)}, so it is null until init has run. */
    protected volatile DeviceAllocationsTracker deviceMemoryTracker;
    /** Source of CUDA contexts; concrete type is chosen in the constructor from the execution model. */
    private final ContextPool contextPool;
    /** Performs the actual malloc/free; concrete type is chosen in the constructor from the allocation model. */
    private final MemoryProvider memoryProvider;
    /** Orders transfers and actions; concrete type is chosen in the constructor from the execution model. */
    private final FlowController flowController;
    /** Value of {@code configuration.getFirstMemory()} captured at construction time. */
    private final AllocationStatus INITIAL_LOCATION;
    /** Running total of bytes accounted to host allocations (incremented on pickup, decremented on purge). */
    protected final AtomicLong zeroUseCounter = new AtomicLong(0);
    /** Not referenced by any method of this class; kept because it is visible to subclasses. */
    protected Map<Long, Integer> devicesAffinity = new ConcurrentHashMap<>();
    /** Not referenced by any method of this class. */
    private final ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock();
    /** Not referenced by any method of this class. */
    private final AtomicInteger devPtr = new AtomicInteger(0);
    /** Not referenced by any method of this class. */
    private final AtomicBoolean wasInitialised = new AtomicBoolean(false);
    /** Not referenced by any method of this class (getDeviceId() asks Nd4j for the manager directly). */
    private final AffinityManager affinityManager = Nd4j.getAffinityManager();
    /** deviceId -> (objectId -> objectId) for allocations currently tracked as device-resident. */
    private final Map<Integer, ConcurrentHashMap<Long, Long>> deviceAllocations = new ConcurrentHashMap<>();
    /** bucketId -> (objectId -> objectId) for allocations currently tracked as host-resident. */
    private final Map<Long, ConcurrentHashMap<Long, Long>> zeroAllocations = new ConcurrentHashMap<>();
    /** Not referenced by any method of this class. */
    private final AtomicLong zeroCounter = new AtomicLong(0);
    /** Native bridge used for memcpy and device selection calls. */
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();

    /**
     * Marks the shared configuration as initialized, captures the initial memory location and
     * selects the flow controller, context pool and memory provider implementations that match
     * the configured execution and allocation models.
     *
     * @throws RuntimeException if the execution model or the allocation model is not one of the
     *                          values handled below
     */
    public CudaZeroHandler() {
        configuration.setInitialized();
        INITIAL_LOCATION = configuration.getFirstMemory();

        // Execution model decides how transfers are scheduled and how contexts are pooled.
        switch (configuration.getExecutionModel()) {
            case SEQUENTIAL:
                flowController = new SynchronousFlowController();
                contextPool = new LimitedContextPool();
                break;
            case OPTIMIZED:
            case ASYNCHRONOUS:
                flowController = new AsynchronousFlowController();
                contextPool = new PackedContextPool();
                break;
            default:
                throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]");
        }

        // Allocation model decides which provider performs malloc/free.
        switch (configuration.getAllocationModel()) {
            case DIRECT:
                memoryProvider = new CudaDirectProvider();
                break;
            case CACHE_HOST:
                memoryProvider = new CudaCachingZeroProvider();
                break;
            case CACHE_ALL:
                memoryProvider = new CudaFullCachingProvider();
                break;
            default:
                throw new RuntimeException("Unknown AllocationModel: [" + configuration.getAllocationModel() + "]");
        }
    }

    /**
     * Replaces the shared configuration, creates the device allocation tracker for it and
     * initializes the flow controller with the given allocator.
     *
     * @param configuration2 configuration to install; must not be null
     * @param allocator      allocator handed to the flow controller; must not be null
     * @throws NullPointerException if either argument is null
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void init(@NonNull Configuration configuration2, @NonNull Allocator allocator) {
        // Explicit checks: the configuration is validated first, then the allocator.
        if (configuration2 == null) {
            throw new NullPointerException("configuration");
        }
        if (allocator == null) {
            throw new NullPointerException("allocator");
        }

        configuration = configuration2;
        deviceMemoryTracker = new DeviceAllocationsTracker(configuration);
        flowController.init(allocator);
    }

    /**
     * Registers a host allocation: adds its size to the host usage counter, assigns the point to
     * a randomly chosen bucket in the range [0, numberOfGcThreads) and records the object id in
     * that bucket, creating the bucket map on first use.
     *
     * @param allocationPoint allocation that was just placed in host memory
     */
    private void pickupHostAllocation(AllocationPoint allocationPoint) {
        final long bucket = RandomUtils.nextInt(0, configuration.getNumberOfGcThreads());
        final Long bucketKey = Long.valueOf(bucket);

        zeroUseCounter.addAndGet(AllocationUtils.getRequiredMemory(allocationPoint.getShape()));
        allocationPoint.setBucketId(bucketKey);

        if (!zeroAllocations.containsKey(bucketKey)) {
            log.debug("Creating bucketID: " + bucket);
            // Re-check under the monitor so only one thread installs the bucket map.
            synchronized (this) {
                if (!zeroAllocations.containsKey(bucketKey)) {
                    zeroAllocations.put(bucketKey, new ConcurrentHashMap<Long, Long>());
                }
            }
        }

        final ConcurrentHashMap<Long, Long> bucketMap = zeroAllocations.get(bucketKey);
        bucketMap.put(allocationPoint.getObjectId(), allocationPoint.getObjectId());
    }

    /**
     * Allocates memory for {@code allocationPoint} on the requested target.
     *
     * <p>HOST: waits (GC + 1s sleep per round) while the host usage counter plus the request
     * would reach the configured maximum, then allocates through the memory provider, optionally
     * zero-fills the host buffer and registers the allocation in a host bucket.
     *
     * <p>DEVICE: makes sure a host allocation exists first, then tries to reserve and allocate
     * device memory. If a soft/hard limit is hit or the provider returns null, the point stays
     * on HOST and the returned pair carries whatever pointers were set up to that moment.
     *
     * @param allocationStatus target location, HOST or DEVICE
     * @param allocationPoint  allocation point being populated
     * @param allocationShape  shape used to compute the required number of bytes
     * @param z                when true, a freshly allocated host buffer is zero-filled
     * @return pointers for the allocation; for DEVICE the host pointer is only set when the host
     *         allocation was created by this call
     * @throws IllegalStateException if the request alone exceeds the host limit, or the target
     *                               is neither HOST nor DEVICE
     * @throws RuntimeException      if the thread is interrupted while waiting for host memory
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public PointersPair alloc(AllocationStatus allocationStatus, AllocationPoint allocationPoint, AllocationShape allocationShape, boolean z) {
        long requiredMemory = AllocationUtils.getRequiredMemory(allocationShape);
        // Result is not used; NOTE(review): presumably called so a context is acquired for this thread/device — confirm.
        getCudaContext();
        switch (allocationStatus) {
            case HOST:
                if (this.zeroUseCounter.get() + requiredMemory >= configuration.getMaximumZeroAllocation()) {
                    if (requiredMemory > configuration.getMaximumZeroAllocation()) {
                        throw new IllegalStateException("You can't allocate more memory, then allowed with configured value: [" + configuration.getMaximumZeroAllocation() + "]");
                    }
                    while (this.zeroUseCounter.get() + requiredMemory >= configuration.getMaximumZeroAllocation()) {
                        log.warn("No available [HOST] memory, sleeping for a while...");
                        // zeroAllocations is keyed by Long; an int literal key would never match.
                        log.debug("Currently used: [" + this.zeroUseCounter.get() + "], allocated objects: [" + this.zeroAllocations.get(Long.valueOf(0L)) + "]");
                        System.gc();
                        try {
                            Thread.sleep(1000L);
                        } catch (InterruptedException e) {
                            // Preserve the interrupt for callers before aborting the wait.
                            Thread.currentThread().interrupt();
                            throw new RuntimeException(e);
                        }
                    }
                }
                PointersPair malloc = this.memoryProvider.malloc(allocationShape, allocationPoint, allocationStatus);
                if (z) {
                    Pointer.memset(malloc.getHostPointer(), 0, requiredMemory);
                    allocationPoint.tickHostWrite();
                }
                pickupHostAllocation(allocationPoint);
                return malloc;
            case DEVICE:
                int intValue = getDeviceId().intValue();
                PointersPair pointersPair = new PointersPair();
                PointersPair pointersPair2 = new PointersPair();
                // A device allocation always needs a host counterpart; create it if it is missing.
                if (allocationPoint.getPointers() == null || allocationPoint.getPointers().getHostPointer() == null) {
                    pointersPair2 = alloc(AllocationStatus.HOST, allocationPoint, allocationPoint.getShape(), z);
                    pointersPair.setDevicePointer(pointersPair2.getHostPointer());
                    pointersPair.setHostPointer(pointersPair2.getHostPointer());
                    allocationPoint.setAllocationStatus(AllocationStatus.HOST);
                    allocationPoint.setPointers(pointersPair2);
                }
                if (requiredMemory >= configuration.getMaximumSingleHostAllocation() || this.deviceMemoryTracker.getAllocatedSize(Integer.valueOf(intValue)) + requiredMemory >= configuration.getMaximumDeviceAllocation()) {
                    log.warn("Soft limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]", Integer.valueOf(intValue));
                    gcAndPause(100L);
                } else if (this.deviceMemoryTracker.reserveAllocationIfPossible(Long.valueOf(Thread.currentThread().getId()), Integer.valueOf(intValue), requiredMemory)) {
                    allocationPoint.setDeviceId(Integer.valueOf(intValue));
                    PointersPair malloc2 = this.memoryProvider.malloc(allocationShape, allocationPoint, allocationStatus);
                    if (malloc2 != null) {
                        pointersPair.setDevicePointer(malloc2.getDevicePointer());
                        allocationPoint.setAllocationStatus(AllocationStatus.DEVICE);
                        if (allocationPoint.getPointers() == null) {
                            throw new RuntimeException("WTF?");
                        }
                        allocationPoint.getPointers().setDevicePointer(malloc2.getDevicePointer());
                        // Move the bookkeeping entry from the host bucket to the device map.
                        this.deviceAllocations.get(Integer.valueOf(intValue)).put(allocationPoint.getObjectId(), allocationPoint.getObjectId());
                        this.zeroAllocations.get(allocationPoint.getBucketId()).remove(allocationPoint.getObjectId());
                        this.deviceMemoryTracker.addToAllocation(Long.valueOf(Thread.currentThread().getId()), Integer.valueOf(intValue), requiredMemory);
                        allocationPoint.tickHostWrite();
                    } else {
                        log.warn("Out of [DEVICE] memory, host memory will be used instead: deviceId: [{}], requested bytes: [{}]", Integer.valueOf(intValue), Long.valueOf(requiredMemory));
                        // NOTE(review): pointersPair2 is empty when the host buffer pre-existed, so this may set null — confirm callers cope.
                        pointersPair.setDevicePointer(pointersPair2.getHostPointer());
                        allocationPoint.setAllocationStatus(AllocationStatus.HOST);
                        gcAndPause(100L);
                    }
                } else {
                    log.warn("Hard limit on [DEVICE] memory hit, please consider tuning memory parameters, deviceId [{}]", Integer.valueOf(intValue));
                    gcAndPause(100L);
                }
                return pointersPair;
            default:
                throw new IllegalStateException("Can't allocate memory on target [" + allocationStatus + "]");
        }
    }

    /** Requests a garbage collection and sleeps for {@code millis}; an interrupt ends the sleep and is re-flagged on the thread. */
    private void gcAndPause(long millis) {
        System.gc();
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Delegates the free-memory probe to the memory provider.
     *
     * @param num device id passed through to the provider
     * @param j   amount of memory passed through to the provider
     * @return whatever the provider reports
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean pingDeviceForFreeMemory(Integer num, long j) {
        final boolean available = memoryProvider.pingDeviceForFreeMemory(num, j);
        return available;
    }

    /**
     * Relocates data between HOST and DEVICE.
     *
     * <p>DEVICE -> HOST only validates that the point has a buffer and dereferences the device
     * pointer; no copy is issued here. HOST -> DEVICE issues an asynchronous host-to-device copy
     * on the context's special stream and commits the transfer, unless the point is constant.
     *
     * @param allocationStatus  current location
     * @param allocationStatus2 target location
     * @param allocationPoint   allocation being relocated
     * @param allocationShape   shape used to compute the number of bytes to copy
     * @param cudaContext       context whose special stream carries the copy
     * @throws IllegalStateException         if the buffer or device pointer is missing, or the copy fails
     * @throws UnsupportedOperationException for any other direction
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void relocate(AllocationStatus allocationStatus, AllocationStatus allocationStatus2, AllocationPoint allocationPoint, AllocationShape allocationShape, CudaContext cudaContext) {
        final boolean deviceToHost = allocationStatus == AllocationStatus.DEVICE && allocationStatus2 == AllocationStatus.HOST;
        if (deviceToHost) {
            if (allocationPoint.getBuffer() == null) {
                throw new IllegalStateException("Target buffer is NULL!");
            }
            // The created pointer is discarded; only the dereference chain is evaluated.
            new CudaPointer(allocationPoint.getPointers().getDevicePointer().address());
            return;
        }

        final boolean hostToDevice = allocationStatus == AllocationStatus.HOST && allocationStatus2 == AllocationStatus.DEVICE;
        if (!hostToDevice) {
            throw new UnsupportedOperationException("Can't relocate data in requested direction: [" + allocationStatus + "] -> [" + allocationStatus2 + "]");
        }

        // Constant buffers are not copied by this method.
        if (allocationPoint.isConstant()) {
            return;
        }
        if (allocationPoint.getPointers().getDevicePointer() == null) {
            throw new IllegalStateException("devicePointer is NULL!");
        }

        // A zero return code from the native call is treated as failure.
        if (nativeOps.memcpyAsync(allocationPoint.getPointers().getDevicePointer(), allocationPoint.getPointers().getHostPointer(), AllocationUtils.getRequiredMemory(allocationShape), CudaConstants.cudaMemcpyHostToDevice, cudaContext.getSpecialStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + allocationPoint.getHostPointer().address() + "] -> [" + allocationPoint.getDevicePointer().address() + "]");
        }
        flowController.commitTransfer(cudaContext.getSpecialStream());
    }

    /**
     * Unsupported legacy operation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void copyback(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        final String reason = "Deprecated call";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Unsupported legacy operation; logs the object id and shape of the point, then fails.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void copyforward(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        final String trace = "copyforward() called on tp[" + allocationPoint.getObjectId() + "], shape: " + allocationPoint.getShape();
        log.info(trace);
        throw new UnsupportedOperationException("Deprecated call");
    }

    /**
     * Unsupported legacy operation.
     *
     * @throws IllegalStateException always, naming the point's current allocation status
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    @Deprecated
    public void fallback(AllocationPoint allocationPoint, AllocationShape allocationShape) {
        final String reason = "Can't fallback from [" + allocationPoint.getAllocationStatus() + "]";
        throw new IllegalStateException(reason);
    }

    /**
     * Releases the memory of an allocation point through the memory provider. When the point is
     * currently DEVICE, its size is first subtracted from the device tracker using the calling
     * thread's id and the point's device id.
     *
     * @param allocationPoint  allocation to release
     * @param allocationStatus not consulted by this implementation
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void free(AllocationPoint allocationPoint, AllocationStatus allocationStatus) {
        final boolean deviceResident = allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE;
        if (deviceResident) {
            final Long threadId = Long.valueOf(Thread.currentThread().getId());
            deviceMemoryTracker.subFromAllocation(threadId, allocationPoint.getDeviceId(), AllocationUtils.getRequiredMemory(allocationPoint.getShape()));
        }
        memoryProvider.free(allocationPoint);
    }

    /**
     * @return the initial memory location captured from the configuration in the constructor
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public AllocationStatus getInitialLocation() {
        return INITIAL_LOCATION;
    }

    /**
     * No-op in this handler: nothing is done for the given thread id / device id.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void initializeDevice(Long l, Integer num) {
    }

    /**
     * Copies {@code j} bytes from {@code pointer} into the buffer at byte offset {@code j2}.
     *
     * <p>Constant buffers are filled with a plain host memcpy. Other buffers get an asynchronous
     * host-to-host copy on the special stream of a context prepared by the flow controller. If
     * the point is DEVICE-resident, the host data is then also copied to the device pointer.
     * The point is finally ticked as written on the device side.
     *
     * @param dataBuffer destination buffer (cast to {@code BaseCudaDataBuffer})
     * @param pointer    source pointer
     * @param j          number of bytes to copy
     * @param j2         byte offset added to the destination addresses
     * @throws IllegalStateException if a native copy returns 0
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpyAsync(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        final AllocationPoint point = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        final CudaPointer hostTarget = new CudaPointer(point.getPointers().getHostPointer().address() + j2);
        CudaContext context = null;

        if (!dataBuffer.isConstant()) {
            context = flowController.prepareAction(point, new AllocationPoint[0]);
            if (nativeOps.memcpyAsync(hostTarget, pointer, j, CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync H2H failed: [" + pointer.address() + "] -> [" + hostTarget.address() + "]");
            }
            flowController.commitTransfer(context.getSpecialStream());
            // For DEVICE-resident points the action is registered after the H2D copy below instead.
            if (point.getAllocationStatus() == AllocationStatus.HOST) {
                flowController.registerAction(context, point, new AllocationPoint[0]);
            }
        } else {
            Pointer.memcpy(new CudaPointer(point.getPointers().getHostPointer().address() + j2, 0L), new CudaPointer(pointer, j), j);
            point.tickHostRead();
        }

        if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
            final CudaPointer deviceTarget = new CudaPointer(point.getPointers().getDevicePointer().address() + j2);
            // The constant path above did not prepare a context, so obtain one now.
            if (context == null) {
                context = flowController.prepareAction(point, new AllocationPoint[0]);
            }
            if (nativeOps.memcpyAsync(deviceTarget, hostTarget, j, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync H2D failed: [" + hostTarget.address() + "] -> [" + deviceTarget.address() + "]");
            }
            flowController.commitTransfer(context.getSpecialStream());
            flowController.registerAction(context, point, new AllocationPoint[0]);
        }

        point.tickDeviceWrite();
    }

    /**
     * Enqueues a device-to-device copy of {@code j} bytes from {@code pointer} into the buffer's
     * device pointer at byte offset {@code j2}, on the context's old stream, and ticks the point
     * as written on the device side. The stream is not synchronized here.
     *
     * @param dataBuffer  destination buffer (cast to {@code BaseCudaDataBuffer})
     * @param pointer     source device pointer
     * @param j           number of bytes to copy
     * @param j2          byte offset added to the destination device address
     * @param cudaContext context whose old stream carries the copy
     * @throws IllegalStateException if the native copy returns 0, the failure code used by the
     *                               other memcpy paths of this class
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpyDevice(DataBuffer dataBuffer, Pointer pointer, long j, long j2, CudaContext cudaContext) {
        AllocationPoint allocationPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        CudaPointer deviceTarget = new CudaPointer(allocationPoint.getPointers().getDevicePointer().address() + j2);
        // Do not silently drop a failed copy: the write tick below would mark stale data as current.
        if (this.nativeOps.memcpyAsync(deviceTarget, pointer, j, CudaConstants.cudaMemcpyDeviceToDevice, cudaContext.getOldStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync D2D failed: [" + pointer.address() + "] -> [" + deviceTarget.address() + "]");
        }
        allocationPoint.tickDeviceWrite();
    }

    /**
     * Copies {@code j} bytes from {@code pointer} into the buffer's host pointer at byte offset
     * {@code j2} and, when the point is DEVICE-resident, forwards the host data to the device
     * pointer. Both copies are enqueued on the old stream, which is synchronized once before the
     * point is ticked as written on the device side.
     *
     * @param dataBuffer destination buffer (cast to {@code BaseCudaDataBuffer})
     * @param pointer    source pointer
     * @param j          number of bytes to copy
     * @param j2         byte offset added to the destination addresses
     * @throws IllegalStateException if a native copy returns 0, the failure code used by the
     *                               other memcpy paths of this class
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpySpecial(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        CudaContext cudaContext = getCudaContext();
        AllocationPoint allocationPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        CudaPointer hostTarget = new CudaPointer(allocationPoint.getPointers().getHostPointer().address() + j2);
        if (this.nativeOps.memcpyAsync(hostTarget, pointer, j, CudaConstants.cudaMemcpyHostToHost, cudaContext.getOldStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync H2H failed: [" + pointer.address() + "] -> [" + hostTarget.address() + "]");
        }
        if (allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaPointer deviceTarget = new CudaPointer(allocationPoint.getPointers().getDevicePointer().address() + j2);
            if (this.nativeOps.memcpyAsync(deviceTarget, hostTarget, j, CudaConstants.cudaMemcpyHostToDevice, cudaContext.getOldStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync H2D failed: [" + hostTarget.address() + "] -> [" + deviceTarget.address() + "]");
            }
        }
        // Both copies were enqueued on the old stream, so a single sync covers them.
        cudaContext.syncOldStream();
        allocationPoint.tickDeviceWrite();
    }

    /**
     * Runs {@code memcpyAsync(...)} with the same arguments and then synchronizes the old stream
     * of the context acquired beforehand.
     *
     * <p>NOTE(review): memcpyAsync enqueues its copies on the special stream of a context from
     * the flow controller, while this method syncs the old stream — confirm that
     * {@code commitTransfer} already guarantees completion.
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpyBlocking(DataBuffer dataBuffer, Pointer pointer, long j, long j2) {
        // The context is acquired before the copy is issued, as in the original ordering.
        final CudaContext context = getCudaContext();
        memcpyAsync(dataBuffer, pointer, j, j2);
        context.syncOldStream();
    }

    /**
     * Copies the contents of {@code dataBuffer2} into {@code dataBuffer}.
     *
     * <p>The source is read from its device pointer when it is DEVICE-resident, otherwise from
     * its host pointer, and written to the destination's host pointer. If the destination is
     * DEVICE-resident, the host data is then forwarded to its device pointer. All copies use the
     * old stream, which is synchronized at the end. The byte count is taken from the source buffer.
     *
     * <p>NOTE(review): every copy passes {@code cudaMemcpyHostToDevice} regardless of the actual
     * direction; kept as-is — confirm the native layer tolerates this.
     *
     * @param dataBuffer  destination buffer (cast to {@code BaseCudaDataBuffer})
     * @param dataBuffer2 source buffer (cast to {@code BaseCudaDataBuffer})
     * @throws IllegalStateException if a native copy returns 0, the failure code used by the
     *                               other memcpy paths of this class
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void memcpy(DataBuffer dataBuffer, DataBuffer dataBuffer2) {
        log.info("Memcpy buffer: {} bytes ", Long.valueOf(dataBuffer.length() * dataBuffer.getElementSize()));
        CudaContext cudaContext = getCudaContext();
        AllocationPoint dstPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        AllocationPoint srcPoint = ((BaseCudaDataBuffer) dataBuffer2).getAllocationPoint();
        CudaPointer dstHost = new CudaPointer(dstPoint.getPointers().getHostPointer().address());

        // Pick the source side that matches where the source buffer currently lives.
        CudaPointer src;
        if (srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
            src = new CudaPointer(srcPoint.getPointers().getDevicePointer().address());
        } else {
            src = new CudaPointer(srcPoint.getPointers().getHostPointer().address());
        }

        long bytes = dataBuffer2.length() * dataBuffer2.getElementSize();
        if (this.nativeOps.memcpyAsync(dstHost, src, bytes, CudaConstants.cudaMemcpyHostToDevice, cudaContext.getOldStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync failed: [" + src.address() + "] -> [" + dstHost.address() + "]");
        }
        if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
            CudaPointer dstDevice = new CudaPointer(dstPoint.getPointers().getDevicePointer().address());
            if (this.nativeOps.memcpyAsync(dstDevice, dstHost, bytes, CudaConstants.cudaMemcpyHostToDevice, cudaContext.getOldStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync H2D failed: [" + dstHost.address() + "] -> [" + dstDevice.address() + "]");
            }
        }
        dstPoint.tickDeviceWrite();
        cudaContext.syncOldStream();
    }

    /**
     * Returns a pointer to the buffer's device-side data.
     *
     * <p>If the point is DEVICE-resident but its device copy is not current, a HOST -> DEVICE
     * relocation is issued first. The point is then ticked as read on the device side.
     *
     * @param dataBuffer  buffer whose device pointer is requested (cast to {@code BaseCudaDataBuffer})
     * @param cudaContext context used for a relocation, if one is needed
     * @return a new pointer built from the point's device pointer, the buffer length and the
     *         buffer offset multiplied by the element size
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Pointer getDevicePointer(DataBuffer dataBuffer, CudaContext cudaContext) {
        AllocationPoint allocationPoint = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();
        if (allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE && !allocationPoint.isActualOnDeviceSide()) {
            relocate(AllocationStatus.HOST, AllocationStatus.DEVICE, allocationPoint, allocationPoint.getShape(), cudaContext);
        }
        allocationPoint.tickDeviceRead();
        return new CudaPointer(allocationPoint.getPointers().getDevicePointer(), dataBuffer.length(), dataBuffer.offset() * dataBuffer.getElementSize());
    }

    /**
     * Returns a pointer to the buffer's host-side data after asking the flow controller to
     * synchronize the point to the host.
     *
     * @param dataBuffer buffer whose host pointer is requested (cast to {@code BaseCudaDataBuffer})
     * @return a new pointer built from the point's host pointer, the buffer length and the
     *         buffer offset multiplied by the element size
     * @throws RuntimeException if the point has no host pointer; pointer state is logged first
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Pointer getHostPointer(DataBuffer dataBuffer) {
        final AllocationPoint point = ((BaseCudaDataBuffer) dataBuffer).getAllocationPoint();

        if (point.getPointers().getHostPointer() == null) {
            // Dump what is known about the point before failing.
            log.info("DevicePointer: " + point.getPointers().getDevicePointer());
            log.info("HostPointer: " + point.getPointers().getHostPointer());
            log.info("AllocStatus: " + point.getAllocationStatus());
            throw new RuntimeException("pointer is null");
        }

        final Long threadId = Long.valueOf(Thread.currentThread().getId());
        synchronizeThreadDevice(threadId, point.getDeviceId(), point);
        return new CudaPointer(point.getPointers().getHostPointer(), dataBuffer.length(), dataBuffer.offset() * dataBuffer.getElementSize());
    }

    /**
     * Moves a DEVICE-resident buffer to the device of the current thread.
     *
     * <p>Nothing happens when the buffer is not DEVICE-resident or already lives on the current
     * device. Otherwise host data is synchronized, the old device memory is freed and removed
     * from the tracker, a new device allocation is requested, and the host data is copied to it
     * on the special stream, which is then synchronized.
     *
     * <p>NOTE(review): {@code alloc} can leave the point on HOST when a limit is hit; the copy
     * below would then target the previously freed device pointer — confirm this cannot occur.
     *
     * @param dataBuffer buffer to relocate
     * @throws RuntimeException      if host synchronization fails or the buffer is constant
     * @throws IllegalStateException if the native copy returns 0, the failure code used by
     *                               {@code relocate} in this class
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void relocateObject(DataBuffer dataBuffer) {
        AllocationPoint allocationPoint = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        if (allocationPoint.getAllocationStatus() != AllocationStatus.DEVICE) {
            return;
        }
        int currentDevice = getDeviceId().intValue();
        int pointDevice = allocationPoint.getDeviceId().intValue();
        if (pointDevice >= 0 && pointDevice == currentDevice) {
            // Already on the current thread's device.
            return;
        }

        // The host copy is the source of the relocation, so it must be current.
        if (!allocationPoint.isActualOnHostSide()) {
            AtomicAllocator.getInstance().synchronizeHostData(dataBuffer);
        }
        if (!allocationPoint.isActualOnHostSide()) {
            throw new RuntimeException("Buffer synchronization failed");
        }
        if (dataBuffer.isConstant()) {
            throw new RuntimeException("Can't relocateObject() for constant buffer");
        }

        this.memoryProvider.free(allocationPoint);
        this.deviceMemoryTracker.subFromAllocation(Long.valueOf(Thread.currentThread().getId()), allocationPoint.getDeviceId(), AllocationUtils.getRequiredMemory(allocationPoint.getShape()));
        alloc(AllocationStatus.DEVICE, allocationPoint, allocationPoint.getShape(), false);

        CudaContext cudaContext = getCudaContext();
        // Literal 1 is the copy kind from the original code; presumably cudaMemcpyHostToDevice — confirm.
        if (this.nativeOps.memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), dataBuffer.length() * dataBuffer.getElementSize(), 1, cudaContext.getSpecialStream()) == 0) {
            throw new IllegalStateException("MemcpyAsync relocate H2D failed: [" + allocationPoint.getHostPointer().address() + "] -> [" + allocationPoint.getDevicePointer().address() + "]");
        }
        cudaContext.syncSpecialStream();
        allocationPoint.tickDeviceRead();
        allocationPoint.tickHostRead();
    }

    /**
     * Promotes a HOST-resident buffer to device memory.
     *
     * <p>Returns false when the buffer is not on HOST. Returns true without doing anything when
     * the memory model is not DELAYED. Constant buffers are moved to constant space. Otherwise
     * device memory is allocated through the provider, the point is switched to DEVICE, its
     * bookkeeping entry moves from the host bucket to the device map, and the tracker is updated.
     *
     * @param dataBuffer buffer to promote
     * @return false if the buffer is not HOST-resident, true otherwise
     * @throws RuntimeException if the provider returns null for the device allocation
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean promoteObject(DataBuffer dataBuffer) {
        AllocationPoint allocationPoint = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        if (allocationPoint.getAllocationStatus() != AllocationStatus.HOST) {
            return false;
        }
        // The status is known to be HOST here, so only the memory model needs checking.
        if (configuration.getMemoryModel() != Configuration.MemoryModel.DELAYED) {
            return true;
        }
        if (dataBuffer.isConstant()) {
            Nd4j.getConstantHandler().moveToConstantSpace(dataBuffer);
            return true;
        }
        PointersPair malloc = this.memoryProvider.malloc(allocationPoint.getShape(), allocationPoint, AllocationStatus.DEVICE);
        if (malloc == null) {
            throw new RuntimeException("Device allocation failed while promoting object [" + allocationPoint.getObjectId() + "], requested bytes: [" + AllocationUtils.getRequiredMemory(allocationPoint.getShape()) + "]");
        }
        Integer deviceId = getDeviceId();
        allocationPoint.getPointers().setDevicePointer(malloc.getDevicePointer());
        allocationPoint.setAllocationStatus(AllocationStatus.DEVICE);
        // Move the bookkeeping entry from the host bucket to the device map.
        this.deviceAllocations.get(deviceId).put(allocationPoint.getObjectId(), allocationPoint.getObjectId());
        this.zeroAllocations.get(allocationPoint.getBucketId()).remove(allocationPoint.getObjectId());
        this.deviceMemoryTracker.addToAllocation(Long.valueOf(Thread.currentThread().getId()), deviceId, AllocationUtils.getRequiredMemory(allocationPoint.getShape()));
        allocationPoint.tickHostWrite();
        return true;
    }

    /**
     * Builds a snapshot table of allocated bytes: one HOST row under column 0 holding the host
     * usage counter, and one DEVICE row per configured device holding that device's tracked size.
     *
     * @return a newly created table of (status, device id) -> bytes
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Table<AllocationStatus, Integer, Long> getAllocationStatistics() {
        final Table<AllocationStatus, Integer, Long> stats = HashBasedTable.create();
        stats.put(AllocationStatus.HOST, Integer.valueOf(0), Long.valueOf(zeroUseCounter.get()));
        for (Integer deviceId : configuration.getAvailableDevices()) {
            final long deviceBytes = getAllocatedDeviceMemory(deviceId);
            stats.put(AllocationStatus.DEVICE, deviceId, Long.valueOf(deviceBytes));
        }
        return stats;
    }

    /**
     * @param num device id
     * @return the size the device tracker reports as allocated for that device
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedDeviceMemory(Integer num) {
        final long allocated = deviceMemoryTracker.getAllocatedSize(num);
        return allocated;
    }

    /**
     * @return the current value of the host usage counter
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedHostMemory() {
        final long hostBytes = zeroUseCounter.get();
        return hostBytes;
    }

    /**
     * @param num device id
     * @return number of object ids tracked for that device, or 0 if the device has no entry
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedDeviceObjects(Integer num) {
        final ConcurrentHashMap<Long, Long> tracked = deviceAllocations.get(num);
        return tracked == null ? 0L : tracked.size();
    }

    /**
     * @param l bucket id
     * @return number of object ids tracked in that host bucket, or 0 if the bucket has no entry
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedHostObjects(Long l) {
        final ConcurrentHashMap<Long, Long> bucket = zeroAllocations.get(l);
        return bucket == null ? 0L : bucket.size();
    }

    /**
     * @return total number of object ids tracked across all host buckets
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public long getAllocatedHostObjects() {
        long total = 0L;
        for (ConcurrentHashMap<Long, Long> bucket : zeroAllocations.values()) {
            total += bucket.size();
        }
        return total;
    }

    /**
     * @param num device id
     * @return the key set of the map tracked for that device (a view of the internal map, not a
     *         copy), or a new empty set if the device has no entry
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Set<Long> getDeviceTrackingPoints(Integer num) {
        final ConcurrentHashMap<Long, Long> tracked = deviceAllocations.get(num);
        if (tracked == null) {
            return new HashSet<Long>();
        }
        return tracked.keySet();
    }

    /**
     * @param l bucket id
     * @return the key set of the map tracked for that host bucket (a view of the internal map,
     *         not a copy), or a new empty set if the bucket has no entry
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Set<Long> getHostTrackingPoints(Long l) {
        final ConcurrentHashMap<Long, Long> bucket = zeroAllocations.get(l);
        if (bucket == null) {
            return new HashSet<Long>();
        }
        return bucket.keySet();
    }

    /**
     * Releases the device memory of a DEVICE-resident allocation and downgrades it to HOST.
     *
     * <p>Waits until the flow controller reports the point released, frees it via
     * {@code free(...)}, removes its object id from the device map and sets the status to HOST.
     * Does nothing when the point is not DEVICE-resident.
     *
     * <p>Tracker accounting: {@code free(...)} subtracts the allocation from the device tracker
     * while the status is still DEVICE, so no further subtraction is performed here.
     * NOTE(review): that subtraction uses the calling thread's id and the point's device id
     * rather than {@code l}/{@code num} — confirm the tracker does not depend on the difference.
     *
     * @param l               thread id supplied by the caller (not used for accounting, see above)
     * @param num             device id whose tracking map holds the object
     * @param l2              object id to remove from the tracking map
     * @param allocationPoint allocation being purged
     * @param z               not consulted by this implementation
     * @throws IllegalStateException if the object id is not tracked for the device, or is still
     *                               tracked after removal
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void purgeDeviceObject(Long l, Integer num, Long l2, AllocationPoint allocationPoint, boolean z) {
        if (allocationPoint.getAllocationStatus() != AllocationStatus.DEVICE) {
            return;
        }
        this.flowController.waitTillReleased(allocationPoint);
        // free() also performs the deviceMemoryTracker subtraction for DEVICE-resident points.
        free(allocationPoint, AllocationStatus.DEVICE);
        if (!this.deviceAllocations.get(num).containsKey(l2)) {
            throw new IllegalStateException("Can't happen ever");
        }
        this.deviceAllocations.get(num).remove(l2);
        if (this.deviceAllocations.get(num).containsKey(l2)) {
            throw new IllegalStateException("Can't happen ever");
        }
        allocationPoint.setAllocationStatus(AllocationStatus.HOST);
    }

    /**
     * Releases a host allocation: removes the object id from its bucket, waits until the flow
     * controller reports the point released, frees it, marks it DEALLOCATED and subtracts its
     * size from the host usage counter.
     *
     * @param l               bucket id holding the object
     * @param l2              object id to remove
     * @param allocationPoint allocation being purged
     * @param z               not consulted by this implementation
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void purgeZeroObject(Long l, Long l2, AllocationPoint allocationPoint, boolean z) {
        zeroAllocations.get(l).remove(l2);
        flowController.waitTillReleased(allocationPoint);
        free(allocationPoint, AllocationStatus.HOST);
        allocationPoint.setAllocationStatus(AllocationStatus.DEALLOCATED);

        final long releasedBytes = AllocationUtils.getRequiredMemory(allocationPoint.getShape());
        zeroUseCounter.addAndGet(-releasedBytes);
    }

    /**
     * Returns the device id the affinity manager assigns to the current thread and makes sure a
     * tracking map exists for that device.
     *
     * <p>The map is installed under this object's monitor with a second check inside the lock,
     * the same way host buckets are created in {@code pickupHostAllocation}. This prevents two
     * threads from each installing a map, where the later put would discard entries already
     * recorded in the earlier one.
     *
     * @return device id for the current thread
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Integer getDeviceId() {
        int intValue = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        Integer deviceKey = Integer.valueOf(intValue);
        if (!this.deviceAllocations.containsKey(deviceKey)) {
            synchronized (this) {
                if (!this.deviceAllocations.containsKey(deviceKey)) {
                    this.deviceAllocations.put(deviceKey, new ConcurrentHashMap<Long, Long>());
                }
            }
        }
        return deviceKey;
    }

    /**
     * @return a new {@code CudaPointer} constructed from the current thread's device id
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Pointer getDeviceIdPointer() {
        final int deviceId = getDeviceId().intValue();
        return new CudaPointer(deviceId);
    }

    /**
     * @return a new set copied from the configuration's available devices
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public Set<Integer> getAvailableDevices() {
        final Set<Integer> devices = new HashSet<Integer>(configuration.getAvailableDevices());
        return devices;
    }

    /**
     * @return a new {@code ExternalContext} wrapping the context returned by {@code getCudaContext()}
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public ExternalContext getDeviceContext() {
        final CudaContext context = getCudaContext();
        return new ExternalContext(context);
    }

    /**
     * @return the context the pool hands out for the current thread's device id
     */
    public CudaContext getCudaContext() {
        final Integer deviceId = getDeviceId();
        return contextPool.acquireContextForDevice(deviceId);
    }

    /**
     * Selects the current thread's device through the native ops, then creates a fresh
     * {@code CudaContext} and runs its handle/stream initialization calls. The created context
     * is not stored or returned by this method.
     *
     * @param l thread id; not consulted by this implementation
     */
    protected void initCudaContextForThread(Long l) {
        nativeOps.setDevice(getDeviceIdPointer());

        final CudaContext context = new CudaContext();
        context.initHandle();
        context.initOldStream();
        context.initStream();
        context.associateHandle();
    }

    /**
     * @return always {@code true} for this handler
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public boolean isDeviceDependant() {
        final boolean deviceDependant = true;
        return deviceDependant;
    }

    /**
     * Asks the flow controller to synchronize the given point to the host.
     *
     * @param l               thread id; not consulted by this implementation
     * @param num             device id; not consulted by this implementation
     * @param allocationPoint point handed to the flow controller
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void synchronizeThreadDevice(Long l, Integer num, AllocationPoint allocationPoint) {
        // Thread and device ids are ignored; only the point itself is forwarded.
        flowController.synchronizeToHost(allocationPoint);
    }

    /**
     * Forwards the action registration to the flow controller unchanged.
     *
     * @param cudaContext context passed through
     * @param iNDArray    first array passed through
     * @param iNDArrayArr remaining arrays passed through
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public void registerAction(CudaContext cudaContext, INDArray iNDArray, INDArray... iNDArrayArr) {
        // Pure delegation; no state of this handler is touched.
        flowController.registerAction(cudaContext, iNDArray, iNDArrayArr);
    }

    /**
     * @return the flow controller selected in the constructor
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public FlowController getFlowController() {
        return flowController;
    }

    /**
     * @return the context pool selected in the constructor
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public ContextPool getContextPool() {
        return contextPool;
    }

    /**
     * @return the memory provider selected in the constructor
     */
    @Override // org.nd4j.jita.handler.MemoryHandler
    public MemoryProvider getMemoryProvider() {
        return memoryProvider;
    }
}
