package org.nd4j.jita.concurrency;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CUDA implementation of the affinity manager: answers which device a thread or array is
 * associated with, switches the calling thread between devices, and replicates / tags /
 * synchronizes arrays and buffers via {@link AtomicAllocator}.
 *
 * <p>Device switching is delegated to the native ops obtained from {@link NativeOpsHolder}.
 */
public class CudaAffinityManager extends BasicAffinityManager {
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);

    /** Thread id -> device id cache, filled lazily by {@link #getDeviceForThread(long)}. */
    private final Map<Long, Integer> affinityMap = new ConcurrentHashMap<>();

    /** Round-robin cursor into the configured list of available devices. */
    private final AtomicInteger devPtr = new AtomicInteger(0);

    // Not referenced anywhere in this class; retained as-is.
    private final ThreadLocal<AtomicBoolean> affiliated = new ThreadLocal<>();

    /** Cached device count; a negative value means "not queried yet". */
    private final AtomicInteger numberOfDevices = new AtomicInteger(-1);

    /**
     * Returns the device id reported by the native ops for the calling thread.
     *
     * @return current device id
     */
    @Override
    public Integer getDeviceForCurrentThread() {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
    }

    /**
     * Returns the device id recorded for the given thread id.
     *
     * <p>If nothing is recorded yet and {@code j} is the calling thread's id, the current native
     * device is looked up, cached and returned.
     *
     * @param j thread id
     * @return device id associated with the thread
     * @throws RuntimeException if no device is recorded for {@code j} and {@code j} is not the
     *                          calling thread
     */
    @Override
    public Integer getDeviceForThread(long j) {
        Integer device = this.affinityMap.get(j);
        if (device != null) {
            return device;
        }
        if (j != Thread.currentThread().getId()) {
            throw new RuntimeException("Affinity for thread [" + j + "] wasn't defined yet");
        }
        device = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
        this.affinityMap.put(j, device);
        return device;
    }

    /**
     * Picks the next device from the configured list of available devices.
     *
     * <p>When single-GPU mode is forced, or no devices are reported, the first configured device
     * is returned. Otherwise devices are handed out in round-robin order.
     *
     * @param j id of the thread the device is being picked for (used for logging only)
     * @return device id taken from the configured available-devices list
     */
    protected Integer getNextDevice(long j) {
        Integer device;
        if (CudaEnvironment.getInstance().getConfiguration().isForcedSingleGPU() || getNumberOfDevices() <= 0) {
            device = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().get(0);
            logger.debug("Single device is forced, mapping to device [{}]", device);
        } else {
            synchronized (this) {
                int deviceCount = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().size();
                int index = this.devPtr.get();
                // The list may have shrunk since the cursor was last advanced; restart from the
                // beginning instead of indexing past the end.
                if (index >= deviceCount) {
                    index = 0;
                }
                device = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices().get(index);
                // Advance the cursor, wrapping around at the end of the list.
                this.devPtr.set(index + 1 >= deviceCount ? 0 : index + 1);

                Thread currentThread = Thread.currentThread();
                String threadName = currentThread.getId() == j ? currentThread.getName() : "N/A";
                logger.debug("Mapping thread [{} - {}] to device [{}], out of [{}] devices...", j, threadName, device, deviceCount);
            }
        }
        return device;
    }

    /**
     * Returns the number of devices reported by the native ops, querying it on first use and
     * caching the result afterwards.
     *
     * @return number of available devices
     */
    @Override
    public int getNumberOfDevices() {
        if (this.numberOfDevices.get() < 0) {
            synchronized (this) {
                // Re-check under the lock so the native query is not repeated once a value is stored.
                if (this.numberOfDevices.get() < 0) {
                    this.numberOfDevices.set(NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices());
                }
            }
        }
        return this.numberOfDevices.get();
    }

    /**
     * Touches both the data buffer and the shape-info buffer of the array.
     * Does nothing for a {@code null} array.
     *
     * @param iNDArray array to touch, may be {@code null}
     */
    @Override
    public void touch(INDArray iNDArray) {
        if (iNDArray == null) {
            return;
        }
        touch(iNDArray.data());
        touch(iNDArray.shapeInfoDataBuffer());
    }

    /**
     * Asks the constant handler (for constant buffers) or the memory handler (for everything
     * else) to relocate the buffer. Does nothing for a {@code null} buffer.
     *
     * @param dataBuffer buffer to touch, may be {@code null}
     */
    @Override
    public void touch(DataBuffer dataBuffer) {
        if (dataBuffer == null) {
            return;
        }
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
        if (point.isConstant()) {
            Nd4j.getConstantHandler().relocateConstantSpace(dataBuffer);
        } else {
            AtomicAllocator.getInstance().getMemoryHandler().relocateObject(dataBuffer);
        }
    }

    /**
     * Creates a copy of the array while the calling thread is switched to the target device.
     *
     * <p>String arrays are simply duplicated. The calling thread's original device is restored
     * before returning, including when the copy fails with an exception. The method is
     * synchronized on this instance.
     *
     * @param num      target device id
     * @param iNDArray array to replicate, may be {@code null}
     * @return the replicated array, or {@code null} if {@code iNDArray} is {@code null}
     * @throws UnsupportedOperationException if the array is a view
     */
    @Override
    public synchronized INDArray replicateToDevice(Integer num, INDArray iNDArray) {
        if (iNDArray == null) {
            return null;
        }
        if (iNDArray.isS()) {
            return iNDArray.dup(iNDArray.ordering());
        }
        if (iNDArray.isView()) {
            throw new UnsupportedOperationException("It's impossible to replicate View");
        }

        long[] shape = iNDArray.shape();
        long[] stride = iNDArray.stride();
        int elementWiseStride = iNDArray.elementWiseStride();
        char ordering = iNDArray.ordering();
        DataType dataType = iNDArray.dataType();
        boolean isEmpty = iNDArray.isEmpty();

        // Return value intentionally discarded. NOTE(review): presumably invoked for its side
        // effect of making the source data current on the device side - confirm.
        AtomicAllocator.getInstance().getPointer(iNDArray, AtomicAllocator.getInstance().getDeviceContext());

        int originalDevice = getDeviceForCurrentThread().intValue();
        boolean switchRequired = originalDevice != num.intValue();
        if (switchRequired) {
            unsafeSetDevice(num);
        }
        try {
            DataBuffer replicatedData = replicateToDevice(num, iNDArray.data());
            return Nd4j.createArrayFromShapeBuffer(replicatedData, Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, elementWiseStride, ordering, dataType, isEmpty).getFirst());
        } finally {
            // Always put the thread back on its original device, even if replication failed.
            if (switchRequired) {
                unsafeSetDevice(Integer.valueOf(originalDevice));
            }
        }
    }

    /**
     * Creates a new buffer of the same data type and length while the calling thread is switched
     * to the target device, and copies the source buffer into it.
     *
     * <p>The calling thread's original device is restored before returning, including when
     * allocation or the copy fails with an exception.
     *
     * @param num        target device id
     * @param dataBuffer buffer to replicate, may be {@code null}
     * @return the new buffer, or {@code null} if {@code dataBuffer} is {@code null}
     */
    @Override
    public DataBuffer replicateToDevice(Integer num, DataBuffer dataBuffer) {
        if (dataBuffer == null) {
            return null;
        }
        int originalDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        boolean switchRequired = originalDevice != num.intValue();
        if (switchRequired) {
            Nd4j.getAffinityManager().unsafeSetDevice(num);
        }
        try {
            DataBuffer replica = Nd4j.createBuffer(dataBuffer.dataType(), dataBuffer.length(), false);
            AtomicAllocator.getInstance().memcpy(replica, dataBuffer);
            return replica;
        } finally {
            // Always put the thread back on its original device, even if the copy failed.
            if (switchRequired) {
                Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(originalDevice));
            }
        }
    }

    /**
     * Updates the access counters of the array's allocation point according to
     * {@code location}: host write for {@code HOST}, device write for {@code DEVICE}, device
     * write followed by host read for {@code EVERYWHERE}. Empty arrays and any other location
     * value are ignored.
     *
     * @param iNDArray array to tag
     * @param location location to tag the array with
     */
    @Override
    public void tagLocation(INDArray iNDArray, AffinityManager.Location location) {
        if (iNDArray.isEmpty()) {
            return;
        }
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickDeviceWrite();
        } else if (location == AffinityManager.Location.EVERYWHERE) {
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(iNDArray);
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Updates the access counters of the buffer's allocation point according to
     * {@code location}: host write for {@code HOST}, device write for {@code DEVICE}, device
     * write followed by host read for {@code EVERYWHERE}. Any other location value is ignored.
     *
     * @param dataBuffer buffer to tag
     * @param location   location to tag the buffer with
     */
    @Override
    public void tagLocation(DataBuffer dataBuffer, AffinityManager.Location location) {
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).tickDeviceWrite();
        } else if (location == AffinityManager.Location.EVERYWHERE) {
            AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(dataBuffer);
            point.tickDeviceWrite();
            point.tickHostRead();
        }
    }

    /**
     * Returns the device id the allocator reports for the given array.
     *
     * @param iNDArray array to look up, must not be {@code null}
     * @return device id of the array
     * @throws NullPointerException if {@code iNDArray} is {@code null}
     */
    @Override
    public Integer getDeviceForArray(@NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        return AtomicAllocator.getInstance().getDeviceId(iNDArray);
    }

    /**
     * Switches the calling thread to the given device through the native ops and then resets the
     * memory handler's cached context.
     *
     * @param num device id to switch to
     */
    @Override
    public void unsafeSetDevice(Integer num) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(num.intValue());
        AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
    }

    /**
     * Synchronizes the array's data to the requested location: host, device, or both (host
     * first, then device). Empty and string arrays are skipped.
     *
     * @param iNDArray array to synchronize
     * @param location where the data has to be present
     */
    @Override
    public void ensureLocation(INDArray iNDArray, AffinityManager.Location location) {
        if (iNDArray.isEmpty() || iNDArray.isS()) {
            return;
        }
        // The host pointer is allocated lazily; request it before any synchronization.
        ((BaseCudaDataBuffer) iNDArray.data()).lazyAllocateHostPointer();
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(iNDArray);
        switch (location) {
            case HOST:
                AtomicAllocator.getInstance().synchronizeHostData(iNDArray);
                break;
            case DEVICE:
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
                break;
            case EVERYWHERE:
            default:
                AtomicAllocator.getInstance().synchronizeHostData(iNDArray);
                AtomicAllocator.getInstance().getFlowController().synchronizeToDevice(point);
                break;
        }
    }

    /**
     * Reports where the array's data is currently actual, based on the flags of its allocation
     * point. Empty arrays are reported as {@code EVERYWHERE}.
     *
     * @param iNDArray array to inspect
     * @return {@code EVERYWHERE} if actual on both sides, {@code DEVICE} if actual on the device
     *         side only, {@code HOST} otherwise
     */
    @Override
    public AffinityManager.Location getActiveLocation(INDArray iNDArray) {
        if (iNDArray.isEmpty()) {
            return AffinityManager.Location.EVERYWHERE;
        }
        AllocationPoint point = AtomicAllocator.getInstance().getAllocationPoint(iNDArray);
        boolean actualOnDevice = point.isActualOnDeviceSide();
        if (actualOnDevice && point.isActualOnHostSide()) {
            return AffinityManager.Location.EVERYWHERE;
        }
        return actualOnDevice ? AffinityManager.Location.DEVICE : AffinityManager.Location.HOST;
    }

    /**
     * Tells whether cross-device access can be used: the native ops must report P2P as available
     * and the configuration must allow cross-device access.
     *
     * @return {@code true} if both conditions hold
     */
    @Override
    public boolean isCrossDeviceAccessSupported() {
        return NativeOpsHolder.getInstance().getDeviceNativeOps().isP2PAvailable() && CudaEnvironment.getInstance().getConfiguration().isCrossDeviceAccessAllowed();
    }

    /**
     * Forwards the cross-device access flag to the CUDA configuration.
     *
     * @param z {@code true} to allow cross-device access, {@code false} to disallow it
     */
    @Override
    public void allowCrossDeviceAccess(boolean z) {
        CudaEnvironment.getInstance().getConfiguration().allowCrossDeviceAccess(z);
    }
}
