package org.nd4j.jita.constant;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.cache.ArrayDescriptor;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.CudaDoubleDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaFloatDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaHalfDataBuffer;
import org.nd4j.linalg.jcublas.buffer.CudaIntDataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/constant/ProtectedCudaConstantHandler.class */
/**
 * CUDA {@link ConstantHandler} that tries to place small constant buffers into a per-device
 * constant region (managed as a bump-pointer offset per device) and falls back to regular device
 * memory when a buffer is too large or the region is exhausted.
 *
 * <p>Buffers returned by {@code getConstantBuffer} are cached per device, keyed by
 * {@link ArrayDescriptor}, and are registered with {@link ConstantProtector}.
 */
public class ProtectedCudaConstantHandler implements ConstantHandler {
    protected FlowController flowController;

    /** Capacity, in bytes, of the per-device constant region used by this handler. */
    private static final int MAX_CONSTANT_LENGTH = 49152;
    /** Largest single buffer, in bytes, that is considered for the constant region. */
    private static final int MAX_BUFFER_LENGTH = 272;

    // NOTE: ourInstance must stay first; its construction does not touch the statics below.
    private static final ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler();
    protected static final ConstantProtector protector = ConstantProtector.getInstance();
    private static final Logger logger = LoggerFactory.getLogger(ProtectedCudaConstantHandler.class);

    // These maps are read without holding `lock` (ensureMaps fast path, cache lookups), so they
    // must tolerate concurrent insertion when a new device is first seen.
    protected Map<Integer, AtomicLong> constantOffsets = new ConcurrentHashMap<>();
    protected Map<Integer, Semaphore> deviceLocks = new ConcurrentHashMap<>();
    protected Map<Integer, Map<ArrayDescriptor, DataBuffer>> buffersCache = new ConcurrentHashMap<>();
    protected Map<Integer, Pointer> deviceAddresses = new HashMap<>();
    private Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    /** Guards first-time initialisation of per-device state in {@link #ensureMaps(Integer)}. */
    protected Semaphore lock = new Semaphore(1);
    private boolean resetHappened = false;

    /** Returns the process-wide singleton. */
    public static ProtectedCudaConstantHandler getInstance() {
        return ourInstance;
    }

    private ProtectedCudaConstantHandler() {
    }

    /**
     * Drops every cached constant buffer, clears the {@link ConstantProtector}, and rewinds the
     * constant-region offset of every known device to 0.
     */
    public void purgeConstants() {
        // Build the replacement fully before publishing it, so readers never observe a cache
        // that is missing the per-device maps.
        Map<Integer, Map<ArrayDescriptor, DataBuffer>> freshCache = new ConcurrentHashMap<>();
        for (Integer deviceId : constantOffsets.keySet()) {
            freshCache.put(deviceId, new ConcurrentHashMap<>());
        }
        buffersCache = freshCache;

        protector.purgeProtector();
        resetHappened = true;
        logger.info("Resetting Constants...");

        for (AtomicLong offset : constantOffsets.values()) {
            offset.set(0L);
        }
    }

    /**
     * Returns the number of cached constant buffers for the given device.
     *
     * @param i device id
     */
    protected int amountOfEntries(int i) {
        ensureMaps(Integer.valueOf(i));
        // Previously this always read device 0's cache regardless of the argument.
        return buffersCache.get(i).size();
    }

    /**
     * Uploads {@code dataBuffer} to the current device and marks its allocation point constant.
     *
     * <p>If the buffer is at most {@link #MAX_BUFFER_LENGTH} bytes and fits into the remaining
     * constant region, it is copied there and its device pointer is redirected to the constant
     * address. Otherwise it is copied into regular device memory.
     *
     * @param dataBuffer buffer to upload
     * @return the constant-region device address assigned to the buffer, or {@code 0} when the
     *     buffer was placed in regular device memory instead
     */
    public synchronized long moveToConstantSpace(DataBuffer dataBuffer) {
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);

        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        long requiredMemory = AllocationUtils.getRequiredMemory(point.getShape());
        AtomicLong deviceOffset = constantOffsets.get(deviceId);
        long currentOffset = deviceOffset.get();
        CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

        if (currentOffset + requiredMemory >= MAX_CONSTANT_LENGTH || requiredMemory > MAX_BUFFER_LENGTH) {
            moveToDeviceSpace(dataBuffer, point, deviceId, requiredMemory, context);
            return 0L;
        }

        // Padding: HALF buffers are rounded up by 2 bytes when not a multiple of 4. In global
        // DOUBLE mode every reservation is rounded to a multiple of 8 bytes and the offset is
        // bumped to an 8-byte boundary (presumably to keep doubles naturally aligned).
        long paddedMemory = requiredMemory;
        if (dataBuffer.dataType() == DataBuffer.Type.HALF) {
            if (paddedMemory % 4 != 0) {
                paddedMemory += 2;
            }
        } else if (Nd4j.dataType() == DataBuffer.Type.DOUBLE) {
            if ((paddedMemory / 4) % 2 != 0) {
                paddedMemory += 4;
            }
            long words = currentOffset / 4;
            while (words % 2 != 0) {
                long bumped = deviceOffset.addAndGet(4L);
                words = bumped / 4;
                // Stop once we've stepped past the region; the check below then falls back.
                if (bumped > MAX_CONSTANT_LENGTH) {
                    break;
                }
            }
        }

        long offset = deviceOffset.getAndAdd(paddedMemory);
        // Reject both a start and an end that lie outside the constant region.
        if (offset >= MAX_CONSTANT_LENGTH || offset + paddedMemory > MAX_CONSTANT_LENGTH) {
            moveToDeviceSpace(dataBuffer, point, deviceId, requiredMemory, context);
            return 0L;
        }

        // NOTE(review): the literal 1 is the copy-kind argument passed through to NativeOps;
        // presumably host-to-device — confirm against the NativeOps API.
        nativeOps.memcpyConstantAsync(offset, point.getPointers().getHostPointer(), requiredMemory, 1, context.getSpecialStream());
        flowController.commitTransfer(context.getSpecialStream());

        long address = deviceAddresses.get(deviceId).address() + offset;
        point.setAllocationStatus(AllocationStatus.CONSTANT);
        point.getPointers().setDevicePointer(new CudaPointer(address));
        point.setConstant(true);
        point.tickDeviceWrite();
        point.setDeviceId(deviceId.intValue());
        point.tickHostRead();
        protector.persistDataBuffer(dataBuffer);
        return address;
    }

    /**
     * Fallback for buffers that do not go into the constant region: ensures a device allocation
     * exists (for HOST-resident points under the DELAYED memory model), copies host data to the
     * device pointer on the special stream, and marks the allocation point constant.
     */
    private void moveToDeviceSpace(DataBuffer dataBuffer, AllocationPoint point, Integer deviceId,
                                   long requiredMemory, CudaContext context) {
        if (point.getAllocationStatus() == AllocationStatus.HOST
                && configuration.getMemoryModel() == Configuration.MemoryModel.DELAYED) {
            AtomicAllocator.getInstance().getMemoryHandler().alloc(AllocationStatus.DEVICE, point, point.getShape(), false);
        }
        nativeOps.memcpyAsync(point.getPointers().getDevicePointer(), point.getPointers().getHostPointer(), requiredMemory, 1, context.getSpecialStream());
        flowController.commitTransfer(context.getSpecialStream());
        point.setConstant(true);
        point.tickDeviceWrite();
        point.tickHostRead();
        point.setDeviceId(deviceId.intValue());
        protector.persistDataBuffer(dataBuffer);
    }

    /**
     * Returns a cached constant buffer holding the same values as {@code dataBuffer}, read back
     * through {@code asInt}/{@code asFloat}/{@code asDouble}. HALF buffers are routed through
     * their float representation.
     *
     * @throws IllegalStateException if the buffer is not one of the known CUDA buffer types
     */
    public DataBuffer relocateConstantSpace(DataBuffer dataBuffer) {
        ensureMaps(AtomicAllocator.getInstance().getDeviceId());
        if (dataBuffer instanceof CudaIntDataBuffer) {
            return getConstantBuffer(dataBuffer.asInt());
        }
        if (dataBuffer instanceof CudaFloatDataBuffer) {
            return getConstantBuffer(dataBuffer.asFloat());
        }
        if (dataBuffer instanceof CudaDoubleDataBuffer) {
            return getConstantBuffer(dataBuffer.asDouble());
        }
        if (dataBuffer instanceof CudaHalfDataBuffer) {
            return getConstantBuffer(dataBuffer.asFloat());
        }
        throw new IllegalStateException("Unknown CudaDataBuffer type");
    }

    /**
     * Lazily creates the per-device cache, offset counter, lock, and constant-space address for
     * {@code num}. Uses an unlocked fast-path check, then re-checks under {@link #lock}.
     */
    private void ensureMaps(Integer num) {
        if (buffersCache.containsKey(num)) {
            return;
        }
        if (flowController == null) {
            flowController = AtomicAllocator.getInstance().getFlowController();
        }

        // Acquire outside the try/finally: releasing a permit that was never acquired would
        // permanently add a permit and break mutual exclusion.
        try {
            lock.acquire();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        try {
            if (!buffersCache.containsKey(num)) {
                // Fetch first, so a failure here leaves no partially-initialised device state.
                Pointer constantSpace = nativeOps.getConstantSpace();
                constantOffsets.put(num, new AtomicLong(0L));
                deviceLocks.put(num, new Semaphore(1));
                deviceAddresses.put(num, constantSpace);
                // Published last: its presence is what the unlocked fast path checks.
                buffersCache.put(num, new ConcurrentHashMap<>());
            }
        } finally {
            lock.release();
        }
    }

    /** Returns a cached constant buffer with the given int values, uploading it on first use. */
    public DataBuffer getConstantBuffer(int[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = buffersCache.get(deviceId).get(descriptor);
        if (cached != null) {
            return cached;
        }
        return cacheConstant(deviceId, descriptor, Nd4j.createBuffer(array));
    }

    /** Returns a cached constant buffer with the given float values, uploading it on first use. */
    public DataBuffer getConstantBuffer(float[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = buffersCache.get(deviceId).get(descriptor);
        if (cached != null) {
            return cached;
        }
        return cacheConstant(deviceId, descriptor, Nd4j.createBuffer(array));
    }

    /** Returns a cached constant buffer with the given double values, uploading it on first use. */
    public DataBuffer getConstantBuffer(double[] array) {
        ArrayDescriptor descriptor = new ArrayDescriptor(array);
        Integer deviceId = AtomicAllocator.getInstance().getDeviceId();
        ensureMaps(deviceId);
        DataBuffer cached = buffersCache.get(deviceId).get(descriptor);
        if (cached != null) {
            return cached;
        }
        return cacheConstant(deviceId, descriptor, Nd4j.createBuffer(array));
    }

    /** Marks {@code buffer} constant, uploads it, and stores it in the device's cache. */
    private DataBuffer cacheConstant(Integer deviceId, ArrayDescriptor descriptor, DataBuffer buffer) {
        buffer.setConstant(true);
        moveToConstantSpace(buffer);
        buffersCache.get(deviceId).put(descriptor, buffer);
        return buffer;
    }
}
