package org.nd4j.linalg.api.ops.performance;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.performance.primitives.AveragingTransactionsHolder;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemcpyDirection;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/api/ops/performance/PerformanceTracker.class */
/**
 * Process-wide singleton that collects memory-transfer bandwidth samples per device.
 *
 * <p>One {@link AveragingTransactionsHolder} is created for every device index in
 * {@code [0, Nd4j.getAffinityManager().getNumberOfDevices())} at construction time. After
 * construction the maps themselves are never structurally modified by this class; only the
 * holders they contain are updated.
 */
public class PerformanceTracker {
    private static final Logger log = LoggerFactory.getLogger(PerformanceTracker.class);
    private static final PerformanceTracker INSTANCE = new PerformanceTracker();

    /** Bandwidth samples, keyed by device index. */
    private final Map<Integer, AveragingTransactionsHolder> bandwidth = new HashMap<>();

    /** Populated per device in the constructor; not read anywhere else in this class. */
    private final Map<Integer, AveragingTransactionsHolder> operations = new HashMap<>();

    /** Registers an empty holder in both maps for every device reported by the affinity manager. */
    private PerformanceTracker() {
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        for (int device = 0; device < numberOfDevices; device++) {
            this.bandwidth.put(device, new AveragingTransactionsHolder());
            this.operations.put(device, new AveragingTransactionsHolder());
        }
    }

    /**
     * Returns the shared tracker instance.
     *
     * @return the singleton {@code PerformanceTracker}
     */
    public static PerformanceTracker getInstance() {
        return INSTANCE;
    }

    /**
     * Records a transaction using {@link MemcpyDirection#HOST_TO_HOST} as the direction.
     *
     * @param i  device index the transaction belongs to
     * @param j  elapsed time of the transaction, in nanoseconds
     * @param j2 amount transferred (presumably bytes; confirm against callers)
     * @return the value computed by
     *         {@link #addMemoryTransaction(int, long, long, MemcpyDirection)}
     */
    public long addMemoryTransaction(int i, long j, long j2) {
        return addMemoryTransaction(i, j, j2, MemcpyDirection.HOST_TO_HOST);
    }

    /**
     * Computes the throughput of a transaction as {@code j2} per microsecond and, when that
     * value is positive, adds it to the bandwidth holder of device {@code i}.
     *
     * <p>A non-positive duration is ignored and yields {@code 0}. Without this guard a zero
     * duration would divide by zero, producing {@code Infinity}, which narrows to
     * {@code Long.MAX_VALUE} and would be recorded as a genuine sample.
     *
     * @param i               device index the transaction belongs to
     * @param j               elapsed time of the transaction, in nanoseconds
     * @param j2              amount transferred (presumably bytes; confirm against callers)
     * @param memcpyDirection direction under which the sample is stored; must not be null
     * @return the computed throughput per microsecond (truncated), or {@code 0} when
     *         {@code j <= 0}
     * @throws NullPointerException if {@code memcpyDirection} is null, or if a positive sample
     *                              is computed for a device index that has no registered holder
     */
    public long addMemoryTransaction(int i, long j, long j2, @NonNull MemcpyDirection memcpyDirection) {
        if (memcpyDirection == null) {
            throw new NullPointerException("direction");
        }
        // No meaningful rate can be derived from a zero or negative duration.
        if (j <= 0) {
            return 0L;
        }
        double elapsedMicros = j / 1000.0d;
        long perMicrosecond = (long) (j2 / elapsedMicros);
        // Values that truncate to zero (or are negative) are too small to be worth recording.
        if (perMicrosecond > 0) {
            this.bandwidth.get(i).addValue(memcpyDirection, perMicrosecond);
        }
        return perMicrosecond;
    }

    /**
     * Clears every per-device bandwidth holder. The {@code operations} holders are left
     * untouched.
     */
    public void clear() {
        for (AveragingTransactionsHolder holder : this.bandwidth.values()) {
            holder.clear();
        }
    }

    /**
     * Produces a start timestamp for a transaction that is about to be measured.
     *
     * @return {@link System#nanoTime()} when the executioner's profiling mode is
     *         {@code BANDWIDTH}, otherwise {@code 0}
     */
    public long helperStartTransaction() {
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH) {
            return System.nanoTime();
        }
        return 0L;
    }

    /**
     * Records a finished transaction when the executioner's profiling mode is
     * {@code BANDWIDTH}; does nothing otherwise. The elapsed time is taken as
     * {@code System.nanoTime() - j}.
     *
     * @param i               device index the transaction belongs to
     * @param j               start timestamp in {@link System#nanoTime()} units (presumably
     *                        the value returned by {@link #helperStartTransaction()})
     * @param j2              amount transferred (presumably bytes; confirm against callers)
     * @param memcpyDirection direction under which the sample is stored; must not be null
     * @throws NullPointerException if {@code memcpyDirection} is null
     */
    public void helperRegisterTransaction(int i, long j, long j2, @NonNull MemcpyDirection memcpyDirection) {
        if (memcpyDirection == null) {
            throw new NullPointerException("direction");
        }
        if (Nd4j.getExecutioner().getProfilingMode() == OpExecutioner.ProfilingMode.BANDWIDTH) {
            addMemoryTransaction(i, System.nanoTime() - j, j2, memcpyDirection);
        }
    }

    /**
     * Builds a snapshot of the current average value for every device and every
     * {@link MemcpyDirection}.
     *
     * @return a newly created map from device index to a map of direction to the average
     *         reported by that device's holder; modifying it does not affect the tracker
     */
    public Map<Integer, Map<MemcpyDirection, Long>> getCurrentBandwidth() {
        Map<Integer, Map<MemcpyDirection, Long>> snapshot = new HashMap<>();
        MemcpyDirection[] directions = MemcpyDirection.values();
        for (Map.Entry<Integer, AveragingTransactionsHolder> entry : this.bandwidth.entrySet()) {
            AveragingTransactionsHolder holder = entry.getValue();
            Map<MemcpyDirection, Long> perDirection = new HashMap<>();
            for (MemcpyDirection direction : directions) {
                perDirection.put(direction, holder.getAverageValue(direction));
            }
            snapshot.put(entry.getKey(), perDirection);
        }
        return snapshot;
    }
}
