package org.nd4j.jita.concurrency;

import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.concurrency.BasicAffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/concurrency/CudaAffinityManager.class */
/**
 * Affinity manager that maps Java threads to CUDA devices.
 *
 * <p>Threads without an explicit mapping are assigned a device from
 * {@code configuration.getAvailableDevices()} in round-robin order (or always
 * the first available device when {@code configuration.isForcedSingleGPU()}
 * is set). Mappings are kept per thread id in a concurrent map; a
 * per-thread flag records whether the native device has already been
 * selected for the calling thread.
 */
public class CudaAffinityManager extends BasicAffinityManager {
    private static final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger logger = LoggerFactory.getLogger(CudaAffinityManager.class);

    /** Thread id -> device id. */
    private final ConcurrentHashMap<Long, Integer> affinityMap = new ConcurrentHashMap<>();

    /** Round-robin cursor into the list of available devices; only touched under {@code synchronized (this)}. */
    private final AtomicInteger devPtr = new AtomicInteger(0);

    /** Per-thread flag: {@code true} once the native device was set for the thread by this manager. */
    private final ThreadLocal<AtomicBoolean> affiliated = new ThreadLocal<>();

    /**
     * Returns the device mapped to the calling thread, assigning one if needed.
     *
     * @return device id for the current thread
     */
    public Integer getDeviceForCurrentThread() {
        return getDeviceForThread(Thread.currentThread().getId());
    }

    /**
     * Returns the device mapped to the given thread, assigning one if needed.
     *
     * @param thread thread to look up; must not be {@code null}
     * @return device id for the thread
     */
    public Integer getDeviceForThread(Thread thread) {
        return getDeviceForThread(thread.getId());
    }

    /**
     * Returns the device mapped to the thread with the given id. If the thread
     * has no mapping yet, the next device is assigned to it. When the id
     * belongs to the calling thread and the native device has not been
     * selected for it yet, the native device is set as a side effect.
     *
     * @param j thread id
     * @return device id for the thread
     */
    public Integer getDeviceForThread(long j) {
        final Long threadId = Long.valueOf(j);
        final boolean isCurrentThread = j == Thread.currentThread().getId();

        Integer device = this.affinityMap.get(threadId);
        if (device == null) {
            Integer candidate = getNextDevice(j);
            // putIfAbsent keeps the first mapping if another caller raced us,
            // instead of silently overwriting it.
            Integer existing = this.affinityMap.putIfAbsent(threadId, candidate);
            device = existing != null ? existing : candidate;
            if (isCurrentThread) {
                // Fresh mapping for this thread: force native device selection below.
                // The flag is thread-local, so it must only be reset when the id
                // really is the calling thread's.
                this.affiliated.set(new AtomicBoolean(false));
            }
        }

        if (isCurrentThread) {
            AtomicBoolean flag = this.affiliated.get();
            if (flag == null) {
                flag = new AtomicBoolean(false);
                this.affiliated.set(flag);
            }
            if (!flag.get()) {
                setNativeDevice(device.intValue());
                flag.set(true);
            }
        }
        return device;
    }

    /**
     * Manually maps a thread to a device.
     *
     * @param thread thread to map; must not be {@code null}
     * @param num    device id
     */
    public void attachThreadToDevice(Thread thread, Integer num) {
        attachThreadToDevice(thread.getId(), num);
    }

    /**
     * Manually maps the thread with the given id to a device. Only the
     * mapping is updated; the native device is not switched here.
     *
     * @param j   thread id
     * @param num device id
     */
    public void attachThreadToDevice(long j, Integer num) {
        logger.debug("Manually mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{Long.valueOf(j), num, Integer.valueOf(configuration.getAvailableDevices().size())});
        this.affinityMap.put(Long.valueOf(j), num);
    }

    /**
     * Picks the device for a thread that has no mapping yet: the first
     * available device when single-GPU mode is forced, otherwise the next
     * one in round-robin order.
     *
     * @param j thread id (used for logging only)
     * @return chosen device id
     */
    protected Integer getNextDevice(long j) {
        Integer device;
        if (configuration.isForcedSingleGPU()) {
            device = configuration.getAvailableDevices().get(0);
            logger.debug("Single device is forced, mapping to device [{}]", device);
        } else {
            // Snapshot so size and element access see the same list.
            ArrayList<Integer> devices = new ArrayList<>(configuration.getAvailableDevices());
            synchronized (this) {
                int index = this.devPtr.get();
                if (index >= devices.size()) {
                    // Cursor may be stale if the device list shrank; wrap instead of failing.
                    index = 0;
                }
                device = devices.get(index);
                int next = index + 1;
                this.devPtr.set(next >= devices.size() ? 0 : next);
                logger.debug("Mapping thread [{}] to device [{}], out of [{}] devices...", new Object[]{Long.valueOf(j), device, Integer.valueOf(devices.size())});
            }
        }
        return device;
    }

    /**
     * @return number of entries in {@code configuration.getAvailableDevices()}
     */
    public int getNumberOfDevices() {
        return configuration.getAvailableDevices().size();
    }

    /**
     * Touches both the data buffer and the shape-info buffer of the array.
     * Does nothing for {@code null}.
     *
     * @param iNDArray array to touch, may be {@code null}
     */
    public void touch(INDArray iNDArray) {
        if (iNDArray == null) {
            return;
        }
        touch(iNDArray.data());
        touch(iNDArray.shapeInfoDataBuffer());
    }

    /**
     * Asks the constant handler (for constant allocations) or the memory
     * handler (otherwise) to relocate the buffer. Does nothing for {@code null}.
     *
     * @param dataBuffer buffer to touch, may be {@code null}
     */
    public void touch(DataBuffer dataBuffer) {
        if (dataBuffer == null) {
            return;
        }
        if (AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).isConstant()) {
            Nd4j.getConstantHandler().relocateConstantSpace(dataBuffer);
        } else {
            AtomicAllocator.getInstance().getMemoryHandler().relocateObject(dataBuffer);
        }
    }

    /**
     * Creates a copy of the given array while the calling thread is
     * temporarily switched to the target device. The thread's original
     * device and mapping are restored afterwards, even if replication fails.
     *
     * @param num      target device id
     * @param iNDArray array to replicate, may be {@code null}
     * @return the replicated array, or {@code null} if the input was {@code null}
     * @throws UnsupportedOperationException if the array is a view
     */
    public synchronized INDArray replicateToDevice(Integer num, INDArray iNDArray) {
        if (iNDArray == null) {
            return null;
        }
        if (iNDArray.isView()) {
            throw new UnsupportedOperationException("It's impossible to replicate View");
        }
        int[] shape = iNDArray.shape();
        int[] stride = iNDArray.stride();
        int elementWiseStride = iNDArray.elementWiseStride();
        char ordering = iNDArray.ordering();

        // Result intentionally ignored; the call is kept for whatever side effects
        // the allocator performs when a pointer is requested.
        AtomicAllocator.getInstance().getPointer(iNDArray, (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext());

        int originalDevice = getDeviceForCurrentThread().intValue();
        long threadId = Thread.currentThread().getId();
        try {
            setNativeDevice(num.intValue());
            attachThreadToDevice(threadId, num);
            return Nd4j.createArrayFromShapeBuffer(replicateToDevice(num, iNDArray.data()), Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 0, elementWiseStride, ordering));
        } finally {
            // Always put the thread back on its original device.
            attachThreadToDevice(threadId, Integer.valueOf(originalDevice));
            setNativeDevice(originalDevice);
        }
    }

    /**
     * Creates a copy of the given buffer while the calling thread is
     * switched to the target device (if it is not already on it). The
     * previous device and mapping are restored afterwards, even if the
     * allocation or copy fails.
     *
     * @param num        target device id
     * @param dataBuffer buffer to replicate, may be {@code null}
     * @return the new buffer, or {@code null} if the input was {@code null}
     */
    public DataBuffer replicateToDevice(Integer num, DataBuffer dataBuffer) {
        if (dataBuffer == null) {
            return null;
        }
        int originalDevice = AtomicAllocator.getInstance().getDeviceId().intValue();
        boolean switchDevice = originalDevice != num.intValue();
        long threadId = Thread.currentThread().getId();
        try {
            if (switchDevice) {
                setNativeDevice(num.intValue());
                Nd4j.getAffinityManager().attachThreadToDevice(threadId, num);
            }
            DataBuffer replica = Nd4j.createBuffer(dataBuffer.length(), false);
            AtomicAllocator.getInstance().memcpy(replica, dataBuffer);
            return replica;
        } finally {
            if (switchDevice) {
                setNativeDevice(originalDevice);
                Nd4j.getAffinityManager().attachThreadToDevice(threadId, Integer.valueOf(originalDevice));
            }
        }
    }

    /**
     * Records a write on the array's allocation point: {@code tickHostWrite()}
     * for {@code HOST}, {@code tickDeviceWrite()} for {@code DEVICE}. Any other
     * location is ignored.
     *
     * @param iNDArray array to tag
     * @param location location to tag
     */
    public void tagLocation(INDArray iNDArray, AffinityManager.Location location) {
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(iNDArray).tickDeviceWrite();
        }
    }

    /**
     * Records a write on the buffer's allocation point: {@code tickHostWrite()}
     * for {@code HOST}, {@code tickDeviceWrite()} for {@code DEVICE}. Any other
     * location is ignored.
     *
     * @param dataBuffer buffer to tag
     * @param location   location to tag
     */
    public void tagLocation(DataBuffer dataBuffer, AffinityManager.Location location) {
        if (location == AffinityManager.Location.HOST) {
            AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).tickHostWrite();
        } else if (location == AffinityManager.Location.DEVICE) {
            AtomicAllocator.getInstance().getAllocationPoint(dataBuffer).tickDeviceWrite();
        }
    }

    /** Selects the given device for the calling thread via the native ops. */
    private static void setNativeDevice(int deviceId) {
        NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(new CudaPointer(deviceId));
    }
}
