package org.nd4j.jita.workspace;

import java.util.List;
import java.util.Queue;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.Deallocator;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/jita/workspace/CudaWorkspaceDeallocator.class */
/**
 * {@link Deallocator} that releases the memory held by a {@link CudaWorkspace}: its main
 * workspace plane, its external pointer pairs, and its pinned pointer pairs.
 *
 * <p>The workspace's pointer containers are captured in the constructor; this object keeps no
 * reference to the {@link CudaWorkspace} instance itself (presumably so the workspace can become
 * unreachable before deallocation runs — TODO confirm against the workspace's cleanup path).
 */
public class CudaWorkspaceDeallocator implements Deallocator {
    private static final Logger log = LoggerFactory.getLogger(CudaWorkspaceDeallocator.class);

    /** Main workspace plane as returned by {@code workspace()}; may be null (checked on release). */
    private final PointersPair pointersPair;
    /** Pinned pointer pairs; this queue is drained (polled) by {@link #deallocate()}. */
    private final Queue<PointersPair> pinnedPointers;
    /** External pointer pairs; each non-null entry is released once by {@link #deallocate()}. */
    private final List<PointersPair> externalPointers;

    /**
     * Captures the pointer containers of {@code workspace} for later release.
     *
     * @param workspace the workspace whose memory this deallocator will release; must not be null
     * @throws NullPointerException if {@code workspace} is null
     */
    public CudaWorkspaceDeallocator(@NonNull CudaWorkspace workspace) {
        if (workspace == null) {
            throw new NullPointerException("workspace is marked @NonNull but is null");
        }
        this.pointersPair = workspace.workspace();
        this.pinnedPointers = workspace.pinnedPointers();
        this.externalPointers = workspace.externalPointers();
    }

    /**
     * Releases all captured memory through {@code Nd4j.getMemoryManager()}.
     *
     * <p>Order of release:
     * <ol>
     *   <li>the main workspace plane (device pointer, then host pointer), if present;</li>
     *   <li>every non-null external pointer pair (host, then device), exactly once;</li>
     *   <li>every pinned pointer pair (host, then device), removing each from the queue.</li>
     * </ol>
     * Null pointers inside a pair are skipped.
     */
    @Override
    public void deallocate() {
        log.trace("Deallocating CUDA workspace");

        // Main workspace plane: device memory first, then host memory.
        if (pointersPair != null) {
            if (pointersPair.getDevicePointer() != null) {
                Nd4j.getMemoryManager().release(pointersPair.getDevicePointer(), MemoryKind.DEVICE);
            }
            if (pointersPair.getHostPointer() != null) {
                Nd4j.getMemoryManager().release(pointersPair.getHostPointer(), MemoryKind.HOST);
            }
        }

        // External pairs: iterate exactly once. Iterating twice would hand the same pointers
        // to release() a second time (a double release).
        for (PointersPair external : externalPointers) {
            releaseHostThenDevice(external);
        }

        // Pinned pairs: poll() removes each entry, so the queue is empty afterwards.
        PointersPair pinned;
        while ((pinned = pinnedPointers.poll()) != null) {
            releaseHostThenDevice(pinned);
        }
    }

    /**
     * Releases the host pointer and then the device pointer of {@code pair}.
     * A null pair, or a null pointer inside it, is skipped.
     */
    private static void releaseHostThenDevice(PointersPair pair) {
        if (pair == null) {
            return;
        }
        if (pair.getHostPointer() != null) {
            Nd4j.getMemoryManager().release(pair.getHostPointer(), MemoryKind.HOST);
        }
        if (pair.getDevicePointer() != null) {
            Nd4j.getMemoryManager().release(pair.getDevicePointer(), MemoryKind.DEVICE);
        }
    }
}
