package org.nd4j.linalg.profiler.data;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.profiler.data.primitives.ComparableAtomicLong;
import org.nd4j.linalg.profiler.data.primitives.TimeSet;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/profiler/data/StringAggregator.class */
/**
 * Collects timing samples grouped by a string key and renders them as
 * human-readable reports.
 *
 * <p>Each key maps to a {@link TimeSet} holding the recorded samples. Samples
 * strictly greater than {@link #THRESHOLD} that are recorded together with an
 * op are additionally counted per "key + op" label as "long calls".
 *
 * <p>The reports label raw samples as "ns" and divide sums by 1,000,000 to
 * obtain "ms", so samples are presumably expected in nanoseconds.
 *
 * <p>Both maps are {@link ConcurrentHashMap}s and entries are created
 * atomically, so concurrent first-time recordings of the same key share one
 * entry. NOTE(review): whether {@code TimeSet.addTime} itself is safe for
 * concurrent use is not visible from this file.
 */
public class StringAggregator {
    /** Samples strictly greater than this value are counted as long calls. */
    private static final long THRESHOLD = 100000;

    /** Recorded samples per key. */
    private final ConcurrentHashMap<String, TimeSet> times = new ConcurrentHashMap<>();

    /** Number of over-threshold calls per "key + op" label. */
    private final ConcurrentHashMap<String, ComparableAtomicLong> longCalls = new ConcurrentHashMap<>();

    /**
     * Replaces the statistics of every known key with an empty {@link TimeSet}
     * and every long-call counter with a zeroed one. Keys themselves are kept.
     */
    public void reset() {
        for (String key : this.times.keySet()) {
            this.times.put(key, new TimeSet());
        }
        for (String label : this.longCalls.keySet()) {
            this.longCalls.put(label, new ComparableAtomicLong(0L));
        }
    }

    /**
     * Records a sample for {@code str}; if it exceeds {@link #THRESHOLD}, also
     * counts a long call under the label {@code "<str> <opName> (<opNum>)"}.
     *
     * @param str key the sample is grouped under
     * @param op  op that produced the sample; only dereferenced when the
     *            sample exceeds the threshold
     * @param j   measured time
     */
    public void putTime(String str, Op op, long j) {
        recordTime(str, j);
        if (j > THRESHOLD) {
            countLongCall(str + StringUtils.SPACE + op.opName() + " (" + op.opNum() + ")");
        }
    }

    /**
     * Records a sample for {@code str}; if it exceeds {@link #THRESHOLD}, also
     * counts a long call under the label {@code "<str> <opName> (<opHash>)"}.
     *
     * @param str      key the sample is grouped under
     * @param customOp op that produced the sample; only dereferenced when the
     *                 sample exceeds the threshold
     * @param j        measured time
     */
    public void putTime(String str, CustomOp customOp, long j) {
        recordTime(str, j);
        if (j > THRESHOLD) {
            countLongCall(str + StringUtils.SPACE + customOp.opName() + " (" + customOp.opHash() + ")");
        }
    }

    /**
     * Records a sample for {@code str} without any long-call accounting.
     *
     * @param str key the sample is grouped under
     * @param j   measured time
     */
    public void putTime(String str, long j) {
        recordTime(str, j);
    }

    /**
     * @param str a key that has been recorded before
     * @return median reported by the key's {@link TimeSet}
     * @throws NullPointerException if nothing was ever recorded for {@code str}
     */
    protected long getMedian(String str) {
        return this.times.get(str).getMedian();
    }

    /**
     * @param str a key that has been recorded before
     * @return average reported by the key's {@link TimeSet}
     * @throws NullPointerException if nothing was ever recorded for {@code str}
     */
    protected long getAverage(String str) {
        return this.times.get(str).getAverage();
    }

    /**
     * @param str a key that has been recorded before
     * @return maximum reported by the key's {@link TimeSet}
     * @throws NullPointerException if nothing was ever recorded for {@code str}
     */
    protected long getMaximum(String str) {
        return this.times.get(str).getMaximum();
    }

    /**
     * @param str a key that has been recorded before
     * @return minimum reported by the key's {@link TimeSet}
     * @throws NullPointerException if nothing was ever recorded for {@code str}
     */
    protected long getMinimum(String str) {
        return this.times.get(str).getMinimum();
    }

    /**
     * @param str a key that has been recorded before
     * @return sum reported by the key's {@link TimeSet}
     * @throws NullPointerException if nothing was ever recorded for {@code str}
     */
    protected long getSum(String str) {
        return this.times.get(str).getSum();
    }

    /**
     * Renders the total time over all keys followed by one line per key with
     * that key's share of the total (in percent) and its summed time.
     * Keys appear in the order produced by {@code ArrayUtil.sortMapByValue}.
     *
     * @return the multi-line report
     */
    public String asPercentageString() {
        StringBuilder sb = new StringBuilder();
        Map<String, TimeSet> sorted = ArrayUtil.sortMapByValue(this.times);

        long total = 0L;
        for (String key : sorted.keySet()) {
            total += getSum(key);
        }

        sb.append("Total time spent: ").append(total / 1000000).append(" ms.").append("\n");
        for (String key : sorted.keySet()) {
            long sum = getSum(key);
            // Guard against division by zero when nothing has been recorded.
            float percent = total == 0 ? 0.0f : (((float) sum) * 100.0f) / ((float) total);
            sb.append(key).append("  >>> ").append(" perc: ").append(percent).append(StringUtils.SPACE).append("Time spent: ").append(sum / 1000000).append(" ms");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Renders min/max/average/median per key, then a blank line, then one line
     * per long-call label with its count. The per-key call count is only
     * printed while no long calls have been registered at all.
     * Both sections are ordered by {@code ArrayUtil.sortMapByValue}.
     *
     * @return the multi-line report
     */
    public String asString() {
        StringBuilder sb = new StringBuilder();
        Map<String, TimeSet> sortedTimes = ArrayUtil.sortMapByValue(this.times);
        for (String key : sortedTimes.keySet()) {
            long maximum = getMaximum(key);
            long minimum = getMinimum(key);
            long average = getAverage(key);
            long median = getMedian(key);
            sb.append(key).append("  >>> ");
            if (this.longCalls.size() == 0) {
                sb.append(StringUtils.SPACE).append(sortedTimes.get(key).size()).append(" calls; ");
            }
            sb.append("Min: ").append(minimum).append(" ns; ").append("Max: ").append(maximum).append(" ns; ").append("Average: ").append(average).append(" ns; ").append("Median: ").append(median).append(" ns; ");
            sb.append("\n");
        }
        sb.append("\n");
        Map<String, ComparableAtomicLong> sortedLongCalls = ArrayUtil.sortMapByValue(this.longCalls);
        for (String label : sortedLongCalls.keySet()) {
            sb.append(label).append("  >>> ").append(sortedLongCalls.get(label).get());
            sb.append("\n");
        }
        sb.append("\n");
        return sb.toString();
    }

    /** Adds one sample to the key's TimeSet, creating the TimeSet atomically on first use. */
    private void recordTime(String key, long time) {
        TimeSet set = this.times.get(key);
        if (set == null) {
            TimeSet fresh = new TimeSet();
            // putIfAbsent returns the winner if another thread installed one first.
            TimeSet existing = this.times.putIfAbsent(key, fresh);
            set = existing != null ? existing : fresh;
        }
        set.addTime(time);
    }

    /** Increments the counter for a long-call label, creating the counter atomically on first use. */
    private void countLongCall(String label) {
        ComparableAtomicLong counter = this.longCalls.get(label);
        if (counter == null) {
            ComparableAtomicLong fresh = new ComparableAtomicLong(0L);
            // putIfAbsent returns the winner if another thread installed one first.
            ComparableAtomicLong existing = this.longCalls.putIfAbsent(label, fresh);
            counter = existing != null ? existing : fresh;
        }
        counter.incrementAndGet();
    }
}
