package org.nd4j.jita.workspace;

import java.util.List;
import java.util.Queue;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AllocationShape;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.impl.MemoryTracker;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.AllocationsTracker;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.nd4j.linalg.api.memory.enums.DebugMode;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.MirroringPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.linalg.memory.abstracts.Nd4jWorkspace;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/workspace/CudaWorkspace.class */
/**
 * CUDA implementation of {@link Nd4jWorkspace}.
 *
 * <p>On {@link #init()} a host buffer and, unless the mirroring policy is
 * {@link MirroringPolicy#HOST_ONLY}, a device buffer of {@code currentSize + 1024} bytes are
 * allocated. Sub-allocations are served by advancing the host/device offsets inside those
 * buffers. Requests that do not fit are either retried after resetting the offsets
 * ({@link ResetPolicy#ENDOFBUFFER_REACHED}) or "spilled" into standalone allocations that are
 * recorded as external or pinned allocations and released later by
 * {@link #clearExternalAllocations()} / {@link #clearPinnedAllocations(boolean)}.
 */
public class CudaWorkspace extends Nd4jWorkspace {
    private static final Logger log = LoggerFactory.getLogger((Class<?>) CudaWorkspace.class);

    /** Extra bytes added on top of {@code currentSize} when the backing buffers are allocated. */
    private static final long WORKSPACE_PADDING = 1024L;

    /** Every sub-allocation size is rounded up to a multiple of this many bytes. */
    private static final long ALIGNMENT = 8L;

    /**
     * Creates a workspace with the given configuration.
     *
     * @param workspaceConfiguration workspace configuration; must not be null
     * @throws NullPointerException if {@code workspaceConfiguration} is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration) {
        super(workspaceConfiguration);
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
    }

    /**
     * Creates a workspace with the given configuration and id.
     *
     * @param workspaceConfiguration workspace configuration; must not be null
     * @param workspaceId workspace identifier; must not be null
     * @throws NullPointerException if either argument is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration, @NonNull String workspaceId) {
        super(workspaceConfiguration, workspaceId);
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (workspaceId == null) {
            throw new NullPointerException("workspaceId is marked @NonNull but is null");
        }
    }

    /**
     * Creates a workspace with the given configuration and id, bound to a specific device.
     *
     * @param workspaceConfiguration workspace configuration; must not be null
     * @param workspaceId workspace identifier; must not be null
     * @param deviceId device this workspace targets (returned by {@link #targetDevice()});
     *     unboxed, so it must not be null
     * @throws NullPointerException if any argument is null
     */
    public CudaWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration, @NonNull String workspaceId, Integer deviceId) {
        super(workspaceConfiguration, workspaceId);
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (workspaceId == null) {
            throw new NullPointerException("workspaceId is marked @NonNull but is null");
        }
        this.deviceId = deviceId.intValue();
    }

    /**
     * Allocates the backing host buffer and, unless mirroring is {@code HOST_ONLY}, the device
     * buffer, each {@code currentSize + WORKSPACE_PADDING} bytes. Only the device buffer is
     * reported to {@link AllocationsTracker} and {@link MemoryTracker}.
     *
     * @throws ND4JIllegalStateException if the configuration requests an MMAP workspace, or if
     *     either backing allocation returns null
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace
    protected void init() {
        if (this.workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) {
            throw new ND4JIllegalStateException("CUDA do not support MMAP workspaces yet");
        }
        super.init();
        if (this.currentSize.get() <= 0) {
            return;
        }
        this.isInit.set(true);
        long workspaceBytes = this.currentSize.get();
        if (this.isDebug.get()) {
            log.info("Allocating [{}] workspace on device_{}, {} bytes...", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(workspaceBytes));
            Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread();
        }
        long paddedBytes = workspaceBytes + WORKSPACE_PADDING;

        Pointer hostBuffer = this.memoryManager.allocate(paddedBytes, MemoryKind.HOST, false);
        if (hostBuffer == null) {
            throw new ND4JIllegalStateException("Can't allocate memory for workspace");
        }
        this.workspace.setHostPointer(new PagedPointer(hostBuffer));

        if (this.workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) {
            Pointer deviceBuffer = this.memoryManager.allocate(paddedBytes, MemoryKind.DEVICE, false);
            // Same guard as for the host buffer: never wrap a null device pointer.
            if (deviceBuffer == null) {
                throw new ND4JIllegalStateException("Can't allocate memory for workspace");
            }
            this.workspace.setDevicePointer(new PagedPointer(deviceBuffer));
            Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, device, paddedBytes);
            MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(device.intValue(), paddedBytes);
        }
    }

    /**
     * Allocates {@code requiredMemory} bytes of DEVICE memory; see
     * {@link #alloc(long, MemoryKind, DataType, boolean)}.
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace, org.nd4j.linalg.api.memory.MemoryWorkspace
    public PagedPointer alloc(long requiredMemory, DataType dataType, boolean initialize) {
        return alloc(requiredMemory, MemoryKind.DEVICE, dataType, initialize);
    }

    /**
     * Releases the workspace buffers and tracked allocations.
     *
     * <p>Sets {@code currentSize} to 0, resets offsets, always clears pinned allocations (all of
     * them when {@code extended} is true), frees the host and device buffers and reports the
     * device buffer release to the trackers using the previous size plus padding.
     *
     * @param extended when true, external allocations are freed as well and every pinned
     *     allocation is released regardless of its age
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace, org.nd4j.linalg.api.memory.MemoryWorkspace
    public synchronized void destroyWorkspace(boolean extended) {
        long previousSize = this.currentSize.getAndSet(0L);
        reset();
        if (extended) {
            clearExternalAllocations();
        }
        clearPinnedAllocations(extended);
        if (this.workspace.getHostPointer() != null) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(this.workspace.getHostPointer());
        }
        if (this.workspace.getDevicePointer() != null) {
            NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(this.workspace.getDevicePointer(), 0);
            Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, device, previousSize + WORKSPACE_PADDING);
            MemoryTracker.getInstance().decrementWorkspaceAmount(device.intValue(), previousSize + WORKSPACE_PADDING);
        }
        this.workspace.setDevicePointer(null);
        this.workspace.setHostPointer(null);
    }

    /**
     * Allocates a chunk of memory from this workspace.
     *
     * <p>The element capacity of the returned pointer is {@code requiredMemory / sizeof(dataType)}
     * computed before the size is rounded up to {@link #ALIGNMENT} bytes.
     *
     * @param requiredMemory requested size in bytes
     * @param kind HOST or DEVICE; anything else is rejected
     * @param dataType element type, used to compute the element capacity
     * @param initialize when true the returned memory is zero-filled
     * @return the allocated pointer; {@code null} for DEVICE requests when mirroring is
     *     {@code HOST_ONLY}
     * @throws ND4JIllegalStateException for an unknown {@code kind}, or when the workspace is full
     *     and the spill policy is FAIL
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace, org.nd4j.linalg.api.memory.MemoryWorkspace
    public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType dataType, boolean initialize) {
        long numElements = requiredMemory / Nd4j.sizeOfDataType(dataType);
        if (requiredMemory % ALIGNMENT != 0) {
            requiredMemory += ALIGNMENT - (requiredMemory % ALIGNMENT);
        }
        if (!this.isUsed.get()) {
            return allocWhileDisabled(requiredMemory, kind, numElements, initialize);
        }

        // "trimmer": this DEVICE request would exceed the initial block, or the workspace is
        // already in trimmed mode; such requests are spilled as pinned allocations.
        boolean trimmer = (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && requiredMemory + this.cycleAllocations.get() > this.initialBlockSize.get()
                && this.initialBlockSize.get() > 0
                && kind == MemoryKind.DEVICE)
                || this.trimmedMode.get();
        if (trimmer && this.workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !this.trimmedMode.get()) {
            this.trimmedMode.set(true);
            this.trimmedStep.set(this.stepsCount.get());
        }

        if (kind == MemoryKind.DEVICE) {
            return allocDevice(requiredMemory, kind, dataType, numElements, initialize, trimmer);
        }
        if (kind != MemoryKind.HOST) {
            throw new ND4JIllegalStateException("Unknown MemoryKind was passed in: " + kind);
        }
        return allocHost(requiredMemory, numElements, initialize, trimmer);
    }

    /** Serves a request while the workspace is not in use: always a standalone external allocation. */
    private PagedPointer allocWhileDisabled(long requiredMemory, MemoryKind kind, long numElements, boolean initialize) {
        if (this.disabledCounter.incrementAndGet() % 10 == 0) {
            log.warn("Worskpace was turned off, and wasn't enabled after {} allocations", Long.valueOf(this.disabledCounter.get()));
        }
        if (kind != MemoryKind.DEVICE) {
            PagedPointer hostPointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
            this.externalAllocations.add(new PointersPair(hostPointer, null));
            return hostPointer;
        }
        PagedPointer devicePointer = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
        PointersPair pair = new PointersPair(null, devicePointer);
        // Record the size so clearExternalAllocations() undoes the tracker increment below.
        pair.setRequiredMemory(Long.valueOf(requiredMemory));
        this.externalAllocations.add(pair);
        MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
        return devicePointer;
    }

    /** Serves a HOST request from the host buffer, or spills it according to the spill policy. */
    private PagedPointer allocHost(long requiredMemory, long numElements, boolean initialize, boolean trimmer) {
        if (this.hostOffset.get() + requiredMemory <= this.currentSize.get() && !trimmer
                && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
            PagedPointer pointer = this.workspace.getHostPointer().withOffset(this.hostOffset.getAndAdd(requiredMemory), numElements);
            if (initialize) {
                Pointer.memset(pointer, 0, requiredMemory);
            }
            return pointer;
        }
        switch (this.workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL: {
                PagedPointer spilled = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements);
                if (!trimmer) {
                    this.externalAllocations.add(new PointersPair(spilled, null));
                    return spilled;
                }
                // NOTE(review): isLeaked() is a query whose result is discarded; kept as in the
                // original. Confirm whether marking the pointer as leaked was intended.
                spilled.isLeaked();
                this.pinnedAllocations.add(new PointersPair(Long.valueOf(this.stepsCount.get()), 0L, spilled, null));
                return spilled;
            }
            case FAIL:
            default:
                throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
        }
    }

    /**
     * Serves a DEVICE request: from the device buffer if it fits, by resetting the offsets and
     * retrying if the reset policy allows and the request can fit an empty buffer, otherwise by
     * spilling.
     */
    private PagedPointer allocDevice(long requiredMemory, MemoryKind kind, DataType dataType, long numElements, boolean initialize, boolean trimmer) {
        if (this.deviceOffset.get() + requiredMemory <= this.currentSize.get() && !trimmer
                && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
            return allocFromDeviceBuffer(requiredMemory, numElements, initialize);
        }
        // Only retry when the request fits an empty buffer; otherwise the retry would hit this
        // same branch again and recurse without bound.
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED
                && this.currentSize.get() > 0
                && requiredMemory <= this.currentSize.get()
                && !trimmer
                && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) {
            reset();
            this.resetPlanned.set(true);
            return alloc(requiredMemory, kind, dataType, initialize);
        }
        return spillDevice(requiredMemory, numElements, initialize, trimmer);
    }

    /** Carves {@code requiredMemory} bytes out of the device buffer; returns null for HOST_ONLY mirroring. */
    private PagedPointer allocFromDeviceBuffer(long requiredMemory, long numElements, boolean initialize) {
        this.cycleAllocations.addAndGet(requiredMemory);
        long previousOffset = this.deviceOffset.getAndAdd(requiredMemory);
        if (this.workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) {
            return null;
        }
        PagedPointer pointer = this.workspace.getDevicePointer().withOffset(previousOffset, numElements);
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(requiredMemory), Long.valueOf(numElements), Long.valueOf(previousOffset), Long.valueOf(this.deviceOffset.get()), Long.valueOf(this.currentSize.get()), Long.valueOf(pointer.address()));
        }
        if (initialize) {
            CudaContext context = AtomicAllocator.getInstance().getDeviceContext();
            // memsetAsync reports failure by returning 0.
            int status = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(pointer, 0, requiredMemory, 0, context.getSpecialStream());
            if (status == 0) {
                throw new ND4JIllegalStateException("memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread());
            }
            context.syncSpecialStream();
        }
        return pointer;
    }

    /**
     * Spills a DEVICE request into a standalone allocation: pinned when {@code trimmer} is set,
     * external otherwise. Returns null for HOST_ONLY mirroring.
     */
    private PagedPointer spillDevice(long requiredMemory, long numElements, boolean initialize, boolean trimmer) {
        if (trimmer) {
            this.pinnedAllocationsSize.addAndGet(requiredMemory);
        } else {
            this.spilledAllocationsSize.addAndGet(requiredMemory);
        }
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(requiredMemory), Long.valueOf(numElements));
        }
        this.cycleAllocations.addAndGet(requiredMemory);
        if (this.workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) {
            return null;
        }
        switch (this.workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL: {
                if (trimmer) {
                    this.pinnedCount.incrementAndGet();
                    PagedPointer pinned = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
                    // NOTE(review): result of isLeaked() is discarded, as in the original.
                    pinned.isLeaked();
                    this.pinnedAllocations.add(new PointersPair(Long.valueOf(this.stepsCount.get()), Long.valueOf(requiredMemory), null, pinned));
                    MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
                    return pinned;
                }
                this.externalCount.incrementAndGet();
                PagedPointer external = new PagedPointer(this.memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements);
                // NOTE(review): result of isLeaked() is discarded, as in the original.
                external.isLeaked();
                PointersPair pair = new PointersPair(null, external);
                pair.setRequiredMemory(Long.valueOf(requiredMemory));
                this.externalAllocations.add(pair);
                MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), requiredMemory);
                return external;
            }
            case FAIL:
            default:
                throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
        }
    }

    /**
     * Frees pinned allocations from the head of the pinned queue.
     *
     * <p>Stops at the first entry whose allocation step satisfies
     * {@code allocationStep + 2 >= currentStep}, unless {@code extended} is true, in which case
     * every entry is freed. Device entries also decrement {@link MemoryTracker} and
     * {@code pinnedCount}; every freed entry reduces {@code pinnedAllocationsSize}.
     *
     * @param extended when true, free all pinned allocations regardless of their step
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace
    protected void clearPinnedAllocations(boolean extended) {
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(Thread.currentThread().getId()), Long.valueOf(this.cyclesCount.get()));
        }
        while (!this.pinnedAllocations.isEmpty()) {
            PointersPair pair = this.pinnedAllocations.peek();
            if (pair == null) {
                throw new RuntimeException();
            }
            long allocationStep = pair.getAllocationCycle().longValue();
            long currentStep = this.stepsCount.get();
            if (this.isDebug.get()) {
                log.info("Allocation step: {}; Current step: {}", Long.valueOf(allocationStep), Long.valueOf(currentStep));
            }
            if (allocationStep + 2 >= currentStep && !extended) {
                return;
            }
            this.pinnedAllocations.remove();
            if (pair.getDevicePointer() != null) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0);
                MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue(), pair.getRequiredMemory().longValue());
                this.pinnedCount.decrementAndGet();
                if (this.isDebug.get()) {
                    log.info("deleting external device allocation ");
                }
            }
            if (pair.getHostPointer() != null) {
                NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer());
                if (this.isDebug.get()) {
                    log.info("deleting external host allocation ");
                }
            }
            this.pinnedAllocationsSize.addAndGet(-pair.getRequiredMemory().longValue());
        }
    }

    /**
     * Commits pending work on the executioner, then frees every external allocation and resets
     * the spilled-size and external-count counters.
     *
     * <p>Device entries that carry a recorded size are also reported as released to
     * {@link AllocationsTracker} and {@link MemoryTracker}.
     *
     * @throws RuntimeException wrapping any exception raised while freeing
     */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace
    protected void clearExternalAllocations() {
        if (this.isDebug.get()) {
            log.info("Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(Thread.currentThread().getId()), this.guid);
        }
        Nd4j.getExecutioner().commit();
        try {
            for (PointersPair pair : this.externalAllocations) {
                if (pair.getHostPointer() != null) {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer());
                    if (this.isDebug.get()) {
                        log.info("deleting external host allocation... ");
                    }
                }
                if (pair.getDevicePointer() != null) {
                    NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0);
                    if (this.isDebug.get()) {
                        log.info("deleting external device allocation... ");
                    }
                    Long requiredMemory = pair.getRequiredMemory();
                    if (requiredMemory != null) {
                        Integer device = Nd4j.getAffinityManager().getDeviceForCurrentThread();
                        AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, device, requiredMemory.longValue());
                        MemoryTracker.getInstance().decrementWorkspaceAmount(device.intValue(), requiredMemory.longValue());
                    }
                }
            }
            this.spilledAllocationsSize.set(0L);
            this.externalCount.set(0);
            this.externalAllocations.clear();
        } catch (Exception e) {
            // Passing the exception as the trailing argument makes SLF4J log its stack trace.
            log.error("RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(Thread.currentThread().getId()), this.guid, e);
            throw new RuntimeException(e);
        }
    }

    /** Intentionally a no-op for CUDA workspaces. */
    @Override // org.nd4j.linalg.memory.abstracts.Nd4jWorkspace
    protected void resetWorkspace() {
        // nothing to do
    }

    /** Returns the pair holding the workspace's host/device buffers (decompiler note: originally protected). */
    public PointersPair workspace() {
        return this.workspace;
    }

    /** Returns the live queue of pinned allocations (decompiler note: originally protected). */
    public Queue<PointersPair> pinnedPointers() {
        return this.pinnedAllocations;
    }

    /** Returns the live list of external allocations (decompiler note: originally protected). */
    public List<PointersPair> externalPointers() {
        return this.externalAllocations;
    }

    /** Returns a {@link CudaWorkspaceDeallocator} bound to this workspace. */
    @Override // org.nd4j.linalg.api.memory.Deallocatable
    public Deallocator deallocator() {
        return new CudaWorkspaceDeallocator(this);
    }

    /** Returns {@code "Workspace_" + getId()}. */
    @Override // org.nd4j.linalg.api.memory.Deallocatable
    public String getUniqueId() {
        return "Workspace_" + getId();
    }

    /** Returns the {@code deviceId} field. */
    @Override // org.nd4j.linalg.api.memory.Deallocatable
    public int targetDevice() {
        return this.deviceId;
    }

    /** Returns the current device-buffer offset. */
    @Override // org.nd4j.linalg.api.memory.MemoryWorkspace
    public long getPrimaryOffset() {
        return getDeviceOffset();
    }
}
