package org.nd4j.linalg.api.memory;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.nd4j.linalg.api.memory.enums.AllocationKind;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/memory/AllocationsTracker.class */
/**
 * Process-wide registry of per-device allocation counters.
 *
 * <p>Each device id is lazily mapped to its own {@code DeviceAllocationsTracker}. A tracker is
 * created the first time a device id is seen by any method of this class, including read-only
 * queries such as {@link #bytesOnDevice(Integer)}. Tracker creation is safe under concurrent
 * access. NOTE(review): whether the counter updates themselves are thread-safe depends on
 * {@code DeviceAllocationsTracker}, which is not visible here; verify there.
 */
public class AllocationsTracker {
    private static final Logger log = LoggerFactory.getLogger(AllocationsTracker.class);
    private static final AllocationsTracker INSTANCE = new AllocationsTracker();

    /** Device id to the tracker for that device. Entries are only ever added, never removed. */
    private final Map<Integer, DeviceAllocationsTracker> devices = new ConcurrentHashMap<>();

    protected AllocationsTracker() {
    }

    /**
     * Returns the shared singleton instance.
     *
     * @return the process-wide tracker
     */
    public static AllocationsTracker getInstance() {
        return INSTANCE;
    }

    /**
     * Returns the tracker for {@code deviceId}, creating and registering it on first use.
     *
     * <p>{@code computeIfAbsent} on a {@link ConcurrentHashMap} performs the lookup-or-create
     * atomically, so concurrent callers for the same device always observe the same tracker
     * instance without any explicit locking.
     *
     * @param deviceId device identifier; {@code ConcurrentHashMap} rejects {@code null} keys
     * @return the tracker associated with {@code deviceId}, never {@code null}
     * @throws NullPointerException if {@code deviceId} is {@code null}
     */
    protected DeviceAllocationsTracker trackerForDevice(Integer deviceId) {
        return devices.computeIfAbsent(deviceId, id -> new DeviceAllocationsTracker());
    }

    /**
     * Records an allocation by forwarding {@code bytes} unchanged to the device tracker's
     * {@code updateState}.
     *
     * @param kind     allocation category to update
     * @param deviceId device on which the allocation happened
     * @param bytes    amount passed through as the state delta
     */
    public void markAllocated(AllocationKind kind, Integer deviceId, long bytes) {
        trackerForDevice(deviceId).updateState(kind, bytes);
    }

    /**
     * Records a release by forwarding {@code -bytes} to the device tracker's
     * {@code updateState}, i.e. the negation of what {@link #markAllocated} would pass.
     *
     * @param kind     allocation category to update
     * @param deviceId device on which the memory was released
     * @param bytes    amount released; negated before being passed on
     */
    public void markReleased(AllocationKind kind, Integer deviceId, long bytes) {
        trackerForDevice(deviceId).updateState(kind, -bytes);
    }

    /**
     * Shorthand for {@code bytesOnDevice(AllocationKind.GENERAL, deviceId)}.
     *
     * @param deviceId device to query
     * @return the device tracker's state for {@code AllocationKind.GENERAL}
     */
    public long bytesOnDevice(Integer deviceId) {
        return bytesOnDevice(AllocationKind.GENERAL, deviceId);
    }

    /**
     * Returns the device tracker's current state for {@code kind}.
     *
     * <p>Side effect: if {@code deviceId} has never been seen, an empty tracker is registered for
     * it before the query.
     *
     * @param kind     allocation category to read
     * @param deviceId device to query
     * @return the value reported by the tracker's {@code getState(kind)}
     */
    public long bytesOnDevice(AllocationKind kind, Integer deviceId) {
        return trackerForDevice(deviceId).getState(kind);
    }
}
