package org.nd4j.jita.flow.impl;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.jita.allocator.Allocator;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
import org.nd4j.jita.allocator.impl.AllocationPoint;
import org.nd4j.jita.allocator.pointers.cuda.cudaEvent_t;
import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t;
import org.nd4j.jita.allocator.time.TimeProvider;
import org.nd4j.jita.allocator.time.providers.OperativeProvider;
import org.nd4j.jita.allocator.utils.AllocationUtils;
import org.nd4j.jita.concurrency.EventsProvider;
import org.nd4j.jita.conf.Configuration;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.jita.flow.FlowController;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.PropertyAccessor;

/**
 * {@link FlowController} that attaches CUDA events to allocation points so that later actions can
 * either be scheduled onto the lane that still has pending work for their operands, or wait for
 * that work to finish.
 *
 * <p>State kept per device and per command lane:
 * <ul>
 *   <li>{@link #eventsBarrier} - queue of events registered on the lane (see {@link #fillTail});</li>
 *   <li>{@link #laneClocks} - value of the device clock when the lane last received an event;</li>
 *   <li>{@link #deviceClocks} - counter advanced by {@link #fillTail} and {@link #sweepTail}.</li>
 * </ul>
 *
 * <p>{@link #init(Allocator)} must be called before any method that consults the allocator.
 */
@Deprecated
public class AsynchronousFlowController implements FlowController {
    // NOTE: configuration must be initialised before MAX_EXECUTION_QUEUE, which reads it.
    private static final Configuration configuration = CudaEnvironment.getInstance().getConfiguration();
    private static final Logger log = LoggerFactory.getLogger(AsynchronousFlowController.class);

    /** Per-lane queue length at which {@link #sweepTail()} starts retiring events. */
    protected static final int MAX_EXECUTION_QUEUE = configuration.getCommandQueueLength();
    protected static final AtomicLong eventCounts = new AtomicLong(0);

    private volatile Allocator allocator;
    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected EventsProvider eventsProvider = new EventsProvider();
    private transient TimeProvider timeProvider = new OperativeProvider();

    /** Number of {@code prepareAction} calls that found no operand with a pending write. */
    protected AtomicLong asyncHit = new AtomicLong(0);
    /** Number of {@code prepareAction} calls that found at least one operand with a pending write. */
    protected AtomicLong asyncMiss = new AtomicLong(0);
    /** How many times each lane was selected by {@code prepareAction(INDArray, INDArray...)}. */
    protected Map<Integer, AtomicLong> lanesCounter = new ConcurrentHashMap<>();
    private AtomicLong totalHits = new AtomicLong(0);

    /** Indexed as {@code [device][lane]}. */
    protected ArrayList<ArrayList<Queue<cudaEvent_t>>> eventsBarrier = new ArrayList<>();
    /** Indexed as {@code [device][lane]}. */
    protected ArrayList<ArrayList<AtomicLong>> laneClocks = new ArrayList<>();
    /** Indexed as {@code [device]}. */
    protected ArrayList<AtomicLong> deviceClocks = new ArrayList<>();

    /**
     * Allocates one event queue and one clock per (device, lane) pair, plus one clock per device.
     * The device count comes from the ND4J affinity manager and the lane count from the CUDA
     * configuration.
     */
    public AsynchronousFlowController() {
        int laneCount = configuration.getCommandLanesNumber();
        int deviceCount = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int device = 0; device < deviceCount; device++) {
            ArrayList<Queue<cudaEvent_t>> barriers = new ArrayList<Queue<cudaEvent_t>>();
            ArrayList<AtomicLong> clocks = new ArrayList<AtomicLong>();
            for (int lane = 0; lane < laneCount; lane++) {
                barriers.add(new ConcurrentLinkedQueue<cudaEvent_t>());
                clocks.add(new AtomicLong(0L));
            }
            eventsBarrier.add(barriers);
            laneClocks.add(clocks);
            deviceClocks.add(new AtomicLong(0L));
        }
    }

    /** No-op in this implementation. */
    @Override // org.nd4j.jita.flow.FlowController
    public void synchronizeToDevice(AllocationPoint allocationPoint) {
    }

    /**
     * Stores the allocator used by every other method of this controller.
     *
     * @param allocator allocator to consult for contexts, device ids and allocation points
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void init(Allocator allocator) {
        this.allocator = allocator;
    }

    /**
     * Brings the host copy of {@code allocationPoint} up to date.
     *
     * <p>Returns immediately when the point reports itself as actual on the host side. Otherwise
     * waits for the pending write (non-constant points only), copies device memory to host memory
     * on the context's special stream when the point lives on the device, and finally ticks the
     * host-read marker.
     *
     * @param allocationPoint allocation point to synchronize
     * @throws IllegalStateException if {@code memcpyAsync} returns 0
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void synchronizeToHost(AllocationPoint allocationPoint) {
        if (allocationPoint.isActualOnHostSide()) {
            return;
        }
        if (!allocationPoint.isConstant()) {
            waitTillFinished(allocationPoint);
        }
        if (allocationPoint.getAllocationStatus() == AllocationStatus.DEVICE && !allocationPoint.isActualOnHostSide()) {
            CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();
            if (nativeOps.memcpyAsync(allocationPoint.getHostPointer(), allocationPoint.getDevicePointer(), AllocationUtils.getRequiredMemory(allocationPoint.getShape()), CudaConstants.cudaMemcpyDeviceToHost, context.getSpecialStream()) == 0) {
                throw new IllegalStateException("MemcpyAsync D2H failed: [" + allocationPoint.getDevicePointer().address() + "] -> [" + allocationPoint.getHostPointer().address() + "]");
            }
            commitTransfer(context.getSpecialStream());
        }
        allocationPoint.tickHostRead();
    }

    /**
     * Synchronizes on and destroys the write event of {@code allocationPoint}, if it has one.
     *
     * @param allocationPoint allocation point whose pending write should be awaited
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void waitTillFinished(AllocationPoint allocationPoint) {
        cudaEvent_t pendingWrite = allocationPoint.getWriteLane();
        if (pendingWrite != null) {
            pendingWrite.synchronize();
            pendingWrite.destroy();
        }
    }

    /**
     * Waits for the pending write of {@code allocationPoint} and then drains all of its read events.
     *
     * @param allocationPoint allocation point to release
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void waitTillReleased(AllocationPoint allocationPoint) {
        waitTillFinished(allocationPoint);
        drainEvents(allocationPoint.getReadLane());
    }

    /**
     * Records an event on the context's stream and attaches it to the arrays taking part in the
     * action: as the write event of {@code iNDArray} and as a read event of every non-null operand.
     *
     * @param cudaContext context the action was issued on
     * @param iNDArray    array written by the action; may be {@code null}
     * @param iNDArrayArr arrays read by the action; {@code null} entries are skipped
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void registerAction(CudaContext cudaContext, INDArray iNDArray, INDArray... iNDArrayArr) {
        if (totalHits.incrementAndGet() % 25000 == 0) {
            log.debug("AsyncHit ratio: [{}]", Float.valueOf(getAsyncHitRatio()));
        }
        cudaEvent_t event = recordEvent(cudaContext);
        if (iNDArray != null) {
            setWriteLane(iNDArray, event);
            allocator.tickDeviceWrite(iNDArray);
        }
        for (INDArray operand : iNDArrayArr) {
            if (operand != null) {
                setReadLane(operand, event);
            }
        }
        fillTail(allocator.getDeviceId().intValue(), event.getLaneId(), event);
    }

    /** No-op in this implementation. */
    @Override // org.nd4j.jita.flow.FlowController
    public void registerActionAllWrite(CudaContext cudaContext, INDArray... iNDArrayArr) {
    }

    /** Sets {@code cudaevent_t} as the write event of the allocation point backing {@code iNDArray}. */
    protected void setWriteLane(INDArray iNDArray, cudaEvent_t cudaevent_t) {
        allocator.getAllocationPoint(iNDArray).setWriteLane(cudaevent_t);
    }

    /** Adds {@code cudaevent_t} to the read events of the allocation point backing {@code iNDArray}. */
    protected void setReadLane(INDArray iNDArray, cudaEvent_t cudaevent_t) {
        allocator.getAllocationPoint(iNDArray).addReadLane(cudaevent_t);
    }

    /** Returns the read-event queue of the allocation point backing {@code iNDArray}. */
    protected Queue<cudaEvent_t> getReadLanes(INDArray iNDArray) {
        return allocator.getAllocationPoint(iNDArray).getReadLane();
    }

    /** Returns the write event of the allocation point backing {@code iNDArray}; may be {@code null}. */
    protected cudaEvent_t getWriteLane(INDArray iNDArray) {
        return allocator.getAllocationPoint(iNDArray).getWriteLane();
    }

    /**
     * Returns the lane id of the array's write event, or -1 when the array is {@code null}, has no
     * write event, or the event has already been destroyed.
     */
    protected int hasActiveWrite(INDArray iNDArray) {
        if (iNDArray == null) {
            return -1;
        }
        return activeLaneOf(getWriteLane(iNDArray));
    }

    /**
     * Returns the lane id of the point's write event, or -1 when it has no write event or the event
     * has already been destroyed.
     */
    protected int hasActiveWrite(AllocationPoint allocationPoint) {
        return activeLaneOf(allocationPoint.getWriteLane());
    }

    /**
     * Returns {@code true} when at least one read event of {@code allocationPoint} has not been
     * destroyed yet.
     */
    protected boolean hasActiveReads(AllocationPoint allocationPoint) {
        Queue<cudaEvent_t> readEvents = allocationPoint.getReadLane();
        if (readEvents.isEmpty()) {
            return false;
        }
        // Iterate over a snapshot so the scan does not depend on the iteration behaviour of the
        // live queue while other code adds or polls events.
        for (cudaEvent_t event : new ArrayList<cudaEvent_t>(readEvents)) {
            if (event != null && !event.isDestroyed()) {
                return true;
            }
        }
        return false;
    }

    /** Array overload of {@link #hasActiveReads(AllocationPoint)}; {@code null} yields {@code false}. */
    protected boolean hasActiveReads(INDArray iNDArray) {
        if (iNDArray == null) {
            return false;
        }
        return hasActiveReads(allocator.getAllocationPoint(iNDArray));
    }

    /**
     * Tells whether the first two entries of {@code iArr} are compatible: equal, or at least one of
     * them is -1 (no pending lane). Only indices 0 and 1 are inspected.
     *
     * @param iArr lane table with at least two entries
     */
    protected boolean isMatchingLanes(int[] iArr) {
        return iArr[0] == iArr[1] || iArr[1] == -1 || iArr[0] == -1;
    }

    /**
     * Tells whether lane {@code i} equals one of the first two entries of {@code iArr} and those two
     * entries are themselves compatible per {@link #isMatchingLanes(int[])}.
     *
     * @param i    lane to compare against
     * @param iArr lane table with at least two entries
     */
    protected boolean isMatchingLanes(int i, int[] iArr) {
        return (i == iArr[0] || i == iArr[1]) && isMatchingLanes(iArr);
    }

    /** Synchronizes on and destroys every queued read event of {@code allocationPoint}. */
    protected void synchronizeReadLanes(AllocationPoint allocationPoint) {
        drainEvents(allocationPoint.getReadLane());
    }

    /** Array overload of {@link #synchronizeReadLanes(AllocationPoint)}; {@code null} is ignored. */
    protected void synchronizeReadLanes(INDArray iNDArray) {
        if (iNDArray == null) {
            return;
        }
        synchronizeReadLanes(allocator.getAllocationPoint(iNDArray));
    }

    /**
     * Records an event on the context's stream and makes it the write event of
     * {@code allocationPoint}. The operand points are not touched by this overload.
     *
     * @param cudaContext        context the action was issued on
     * @param allocationPoint    point written by the action
     * @param allocationPointArr points read by the action (unused here)
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void registerAction(CudaContext cudaContext, AllocationPoint allocationPoint, AllocationPoint... allocationPointArr) {
        cudaEvent_t event = recordEvent(cudaContext);
        allocationPoint.setWriteLane(event);
        fillTail(allocator.getDeviceId().intValue(), event.getLaneId(), event);
    }

    /**
     * Drains pending reads of {@code allocationPoint} and returns the context of a lane chosen by
     * {@code nextRandomLane()} on the current device's context pack.
     *
     * @param allocationPoint    point that is about to be written
     * @param allocationPointArr points that are about to be read (unused here)
     * @return context to run the action on
     */
    @Override // org.nd4j.jita.flow.FlowController
    public CudaContext prepareAction(AllocationPoint allocationPoint, AllocationPoint... allocationPointArr) {
        if (hasActiveReads(allocationPoint)) {
            synchronizeReadLanes(allocationPoint);
        }
        ContextPack pack = allocator.getContextPool().acquireContextPackForDevice(allocator.getDeviceId());
        return pack.getContextForLane(Integer.valueOf(pack.nextRandomLane()));
    }

    /**
     * Returns the first non-negative value among {@code iArr[0]} and {@code iArr[1]}, or 0 when
     * both are negative.
     *
     * @param iArr lane table with at least two entries
     */
    protected int pickFirstLane(int[] iArr) {
        if (iArr[0] >= 0) {
            return iArr[0];
        }
        if (iArr[1] >= 0) {
            return iArr[1];
        }
        return 0;
    }

    /**
     * Chooses the lane (and thus the context) for an action that writes {@code iNDArray} and reads
     * {@code iNDArrayArr}.
     *
     * <p>When no operand has a pending write the call counts as an async hit and the result's own
     * pending lane, or else a lane from {@code nextRandomLane()}, is used. Otherwise it counts as a
     * miss: the action is placed on a lane that already holds pending work, and the writes that
     * cannot be ordered by lane are awaited first. Pending reads of the result array are always
     * drained before it is overwritten.
     *
     * <p>The chosen context is stored on the allocation point of the result and of every non-null
     * operand.
     *
     * @param iNDArray    array the action writes to; may be {@code null}
     * @param iNDArrayArr arrays the action reads from; {@code null} entries are skipped
     * @return context of the chosen lane, never {@code null}
     * @throws IllegalStateException if the context pack has no context for the chosen lane
     */
    @Override // org.nd4j.jita.flow.FlowController
    public CudaContext prepareAction(INDArray iNDArray, INDArray... iNDArrayArr) {
        ContextPack pack = allocator.getContextPool().acquireContextPackForDevice(allocator.getDeviceId());
        int resultLane = hasActiveWrite(iNDArray);
        boolean resultHasReads = hasActiveReads(iNDArray);

        // Slot k of pendingLanes describes present[k], i.e. the k-th NON-NULL operand. Waiting on
        // present[k] rather than iNDArrayArr[k] keeps the two aligned when the varargs contain nulls.
        INDArray[] present = compactNonNull(iNDArrayArr);
        int[] pendingLanes = new int[iNDArrayArr.length + 1];
        Arrays.fill(pendingLanes, -1);
        int activeWrites = 0;
        int lastActiveLane = -1;
        for (int slot = 0; slot < present.length && present[slot] != null; slot++) {
            int operandLane = hasActiveWrite(present[slot]);
            if (operandLane >= 0) {
                pendingLanes[slot] = operandLane;
                activeWrites++;
                lastActiveLane = operandLane;
            }
        }

        int lane = 0;
        if (iNDArray == null || (!resultHasReads && resultLane < 0)) {
            // The result array imposes no constraint; only operand writes matter.
            if (activeWrites > 0) {
                asyncMiss.incrementAndGet();
                if (isMatchingLanes(pendingLanes)) {
                    lane = lastActiveLane;
                } else if (pendingLanes[0] >= 0) {
                    waitTillFinished(allocator.getAllocationPoint(present[0]));
                    lane = pendingLanes[1];
                } else if (pendingLanes[1] >= 0) {
                    waitTillFinished(allocator.getAllocationPoint(present[1]));
                    lane = pendingLanes[0];
                }
            } else {
                asyncHit.incrementAndGet();
                lane = pack.nextRandomLane();
            }
        } else {
            if (resultHasReads) {
                synchronizeReadLanes(iNDArray);
            }
            if (activeWrites > 0) {
                asyncMiss.incrementAndGet();
                boolean lanesAgree = isMatchingLanes(resultLane, pendingLanes);
                lane = resultLane >= 0 ? resultLane : pickFirstLane(pendingLanes);
                if (!lanesAgree) {
                    // Work is spread over different lanes: wait for every operand's write.
                    for (INDArray operand : iNDArrayArr) {
                        if (operand != null) {
                            waitTillFinished(allocator.getAllocationPoint(operand));
                        }
                    }
                }
            } else {
                asyncHit.incrementAndGet();
                lane = resultLane >= 0 ? resultLane : pack.nextRandomLane();
            }
        }

        CudaContext context = pack.getContextForLane(Integer.valueOf(lane));
        if (context == null) {
            // Fail before a null context is published to any allocation point.
            throw new IllegalStateException("Context shouldn't be null: " + lane);
        }
        if (iNDArray != null) {
            allocator.getAllocationPoint(iNDArray).setCurrentContext(context);
        }
        for (INDArray operand : iNDArrayArr) {
            if (operand != null) {
                allocator.getAllocationPoint(operand).setCurrentContext(context);
            }
        }

        // Statistics only. The check-then-put below is not atomic, so two threads using a lane for
        // the first time at the same moment may lose one increment.
        Integer laneKey = Integer.valueOf(lane);
        if (!lanesCounter.containsKey(laneKey)) {
            lanesCounter.put(laneKey, new AtomicLong(0L));
        }
        lanesCounter.get(laneKey).incrementAndGet();
        return context;
    }

    /** Not supported by this implementation; always returns {@code null}. */
    @Override // org.nd4j.jita.flow.FlowController
    public CudaContext prepareActionAllWrite(INDArray... iNDArrayArr) {
        return null;
    }

    /**
     * Returns the percentage of async hits among all hits and misses recorded so far, or 0 when
     * nothing has been recorded yet.
     */
    private float getAsyncHitRatio() {
        long hits = asyncHit.get();
        long total = hits + asyncMiss.get();
        if (total == 0) {
            return 0.0f;
        }
        return ((float) (hits * 100)) / ((float) total);
    }

    /**
     * Appends {@code cudaevent_t} to the queue of the given lane and stamps the lane clock with the
     * advanced device clock.
     *
     * @param i           device index
     * @param i2          lane index
     * @param cudaevent_t event to enqueue
     */
    protected void fillTail(int i, int i2, cudaEvent_t cudaevent_t) {
        eventsBarrier.get(i).get(i2).add(cudaevent_t);
        laneClocks.get(i).get(i2).set(deviceClocks.get(i).incrementAndGet());
    }

    /**
     * Retires at most one event per lane of the current device. A lane is visited when its queue
     * holds at least {@link #MAX_EXECUTION_QUEUE} events or when its clock lags the device clock by
     * more than {@link #MAX_EXECUTION_QUEUE}. The device clock is advanced afterwards.
     */
    protected void sweepTail() {
        int device = allocator.getDeviceId().intValue();
        long deviceClock = deviceClocks.get(device).get();
        for (int lane = 0; lane < configuration.getCommandLanesNumber(); lane++) {
            Queue<cudaEvent_t> queue = eventsBarrier.get(device).get(lane);
            if (queue.size() >= MAX_EXECUTION_QUEUE || laneClocks.get(device).get(lane).get() < deviceClock - MAX_EXECUTION_QUEUE) {
                cudaEvent_t oldest = queue.poll();
                if (oldest != null && !oldest.isDestroyed()) {
                    oldest.synchronize();
                    oldest.destroy();
                }
            }
        }
        deviceClocks.get(device).incrementAndGet();
    }

    /**
     * Synchronizes on and destroys every queued event of every lane of the current device, leaving
     * all lane queues empty.
     */
    protected void cutTail() {
        int device = allocator.getDeviceId().intValue();
        for (int lane = 0; lane < configuration.getCommandLanesNumber(); lane++) {
            drainEvents(eventsBarrier.get(device).get(lane));
        }
    }

    /**
     * Sweeps the lane queues once and then blocks until {@code cudastream_t} has completed.
     *
     * @param cudastream_t stream carrying the transfer to wait for
     */
    @Override // org.nd4j.jita.flow.FlowController
    public void commitTransfer(cudaStream_t cudastream_t) {
        sweepTail();
        cudastream_t.synchronize();
    }

    /** Returns the events provider owned by this controller. */
    @Override // org.nd4j.jita.flow.FlowController
    public EventsProvider getEventsProvider() {
        return eventsProvider;
    }

    /** Creates an event tagged with the context's lane id and registers it on the context's stream. */
    private cudaEvent_t recordEvent(CudaContext cudaContext) {
        cudaEvent_t event = new cudaEvent_t(nativeOps.createEvent());
        event.setLaneId(cudaContext.getLaneId());
        nativeOps.registerEvent(event, cudaContext.getOldStream());
        return event;
    }

    /** Returns the event's lane id, or -1 when the event is {@code null} or already destroyed. */
    private static int activeLaneOf(cudaEvent_t event) {
        if (event == null || event.isDestroyed()) {
            return -1;
        }
        return event.getLaneId();
    }

    /** Polls {@code events} until it is empty, synchronizing on and destroying each polled event. */
    private static void drainEvents(Queue<cudaEvent_t> events) {
        for (cudaEvent_t event = events.poll(); event != null; event = events.poll()) {
            event.synchronize();
            event.destroy();
        }
    }

    /**
     * Returns an array of the same length as {@code operands} with the non-null entries moved to
     * the front in their original order; the remaining tail is {@code null}.
     */
    private static INDArray[] compactNonNull(INDArray[] operands) {
        INDArray[] compacted = new INDArray[operands.length];
        int next = 0;
        for (INDArray operand : operands) {
            if (operand != null) {
                compacted[next++] = operand;
            }
        }
        return compacted;
    }
}
