package org.nd4j.autodiff.listeners.debugging;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener.class */
/**
 * Listener that measures the wall-clock time of individual op executions.
 *
 * <p>The time is taken between the {@code preOpExecution} and {@code opExecution}
 * callbacks using {@link System#currentTimeMillis()}. Depending on the configured
 * {@link Mode}, each measurement is either printed to standard output or
 * accumulated per op name so that statistics can be read back later.
 */
public class OpBenchmarkListener extends BaseListener {
    /** Operation type to benchmark; {@code null} means every operation type matches. */
    private final Operation operation;
    /** Reporting mode; never {@code null} (checked in the constructor). */
    private final Mode mode;
    /** In print mode, an op is reported only if its elapsed time in ms is strictly greater than this. */
    private final long minRuntime;
    /** Per-op timing records keyed by op name; created lazily on the first aggregated op. */
    private Map<String, OpExec> aggregateModeMap;
    /** Timestamp (ms) stored by the most recent {@code preOpExecution} call. */
    private long start;
    /** True between the start and end of the first matching operation. */
    private boolean printActive;
    /** Set once the first matching operation has ended; afterwards {@code printActive} is never set again. */
    private boolean printDone;

    /* loaded from: input_file:org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener$Mode.class */
    /** Selects how the listener reports the op timings it measures. */
    public enum Mode {
        /** Print each measured op to standard output while printing is active. */
        SINGLE_ITER_PRINT,

        /** Accumulate the runtimes of every op, grouped by op name, for later inspection. */
        AGGREGATE
    }

    /* loaded from: input_file:org/nd4j/autodiff/listeners/debugging/OpBenchmarkListener$OpExec.class */
    /**
     * Timing record for one op: its identity, every runtime sample collected for it
     * (in milliseconds) and a textual description captured when the record was created.
     */
    public static class OpExec {
        private final String opOwnName;
        private final String opName;
        private final Class<?> opClass;
        private List<Long> runtimeMs;
        private String firstIter;

        /**
         * @param opOwnName name of the op instance this record belongs to
         * @param opName    op name as reported by the op itself
         * @param opClass   implementing class of the op
         * @param runtimeMs list that receives the runtime samples, in milliseconds
         * @param firstIter description of the op captured when the record was created
         */
        public OpExec(String opOwnName, String opName, Class<?> opClass, List<Long> runtimeMs, String firstIter) {
            this.opOwnName = opOwnName;
            this.opName = opName;
            this.opClass = opClass;
            this.runtimeMs = runtimeMs;
            this.firstIter = firstIter;
        }

        /**
         * Renders the op identity, the sample count, mean/std/min/max of the samples
         * and the stored first-iteration description.
         */
        @Override
        public String toString() {
            DecimalFormat threeDecimals = new DecimalFormat("0.000");
            return this.opOwnName + " - op class: " + this.opClass.getSimpleName() + " (op name: " + this.opName + ")\ncount: " + this.runtimeMs.size() + ", mean: " + threeDecimals.format(avgMs()) + "ms, std: " + threeDecimals.format(stdMs()) + "ms, min: " + minMs() + "ms, max: " + maxMs() + "ms\n" + this.firstIter;
        }

        /**
         * Arithmetic mean of the recorded runtimes.
         *
         * @return mean runtime in milliseconds; {@code NaN} when no samples have been recorded
         */
        public double avgMs() {
            long totalMs = 0;
            for (Long sample : this.runtimeMs) {
                totalMs += sample.longValue();
            }
            // Divide in floating point: long / int would truncate the mean to a whole
            // number of milliseconds (and throw ArithmeticException on an empty list).
            return totalMs / (double) this.runtimeMs.size();
        }

        /** @return standard deviation of the recorded runtimes as computed by {@code INDArray.stdNumber()} */
        public double stdMs() {
            return Nd4j.createFromArray(ArrayUtil.toArrayLong(this.runtimeMs)).stdNumber().doubleValue();
        }

        /** @return smallest recorded runtime in milliseconds */
        public long minMs() {
            return Nd4j.createFromArray(ArrayUtil.toArrayLong(this.runtimeMs)).minNumber().longValue();
        }

        /** @return largest recorded runtime in milliseconds */
        public long maxMs() {
            return Nd4j.createFromArray(ArrayUtil.toArrayLong(this.runtimeMs)).maxNumber().longValue();
        }

        public String getOpOwnName() {
            return this.opOwnName;
        }

        public String getOpName() {
            return this.opName;
        }

        public Class<?> getOpClass() {
            return this.opClass;
        }

        /** @return the live list of runtime samples (not a copy) */
        public List<Long> getRuntimeMs() {
            return this.runtimeMs;
        }

        public String getFirstIter() {
            return this.firstIter;
        }

        public void setRuntimeMs(List<Long> runtimeMs) {
            this.runtimeMs = runtimeMs;
        }

        public void setFirstIter(String firstIter) {
            this.firstIter = firstIter;
        }

        /** Two records are equal when all five properties are equal (null-safe). */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof OpExec)) {
                return false;
            }
            OpExec other = (OpExec) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            return sameValue(getOpOwnName(), other.getOpOwnName())
                    && sameValue(getOpName(), other.getOpName())
                    && sameValue(getOpClass(), other.getOpClass())
                    && sameValue(getRuntimeMs(), other.getRuntimeMs())
                    && sameValue(getFirstIter(), other.getFirstIter());
        }

        /** Hook that lets subclasses opt out of equality with plain {@code OpExec} instances. */
        protected boolean canEqual(Object obj) {
            return obj instanceof OpExec;
        }

        /** Combines the five properties with multiplier 59, using 43 for a null property. */
        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + hashOrDefault(getOpOwnName());
            result = result * 59 + hashOrDefault(getOpName());
            result = result * 59 + hashOrDefault(getOpClass());
            result = result * 59 + hashOrDefault(getRuntimeMs());
            result = result * 59 + hashOrDefault(getFirstIter());
            return result;
        }

        /** Null-safe equality of two property values. */
        private static boolean sameValue(Object left, Object right) {
            return left == null ? right == null : left.equals(right);
        }

        /** Hash of a property value, or 43 when the value is null. */
        private static int hashOrDefault(Object value) {
            return value == null ? 43 : value.hashCode();
        }
    }

    /**
     * Creates a listener with a minimum runtime threshold of 0 ms.
     *
     * @param operation operation type to benchmark, or {@code null} to match every operation
     * @param mode      reporting mode; must not be {@code null}
     * @throws NullPointerException if {@code mode} is {@code null} (raised by the delegated constructor)
     */
    public OpBenchmarkListener(Operation operation, @NonNull Mode mode) {
        // The delegated constructor already rejects a null mode, so a second
        // check after this call could never be reached.
        this(operation, mode, 0L);
    }

    /**
     * Creates a listener.
     *
     * @param operation operation type to benchmark, or {@code null} to match every operation
     * @param mode      reporting mode; must not be {@code null}
     * @param j         minimum runtime in milliseconds; stored as the print threshold
     * @throws NullPointerException if {@code mode} is {@code null}
     */
    public OpBenchmarkListener(Operation operation, @NonNull Mode mode, long j) {
        if (null == mode) {
            throw new NullPointerException("mode is marked @NonNull but is null");
        }
        this.minRuntime = j;
        this.mode = mode;
        this.operation = operation;
    }

    /**
     * Reports whether this listener applies to the given operation type.
     *
     * @return {@code true} when no operation was configured, or when the configured
     *         operation is the one passed in
     */
    @Override // org.nd4j.autodiff.listeners.Listener
    public boolean isActive(Operation operation) {
        if (this.operation == null) {
            return true;
        }
        return operation == this.operation;
    }

    /**
     * Enables printing when a matching operation starts, unless a matching
     * operation has already completed once.
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void operationStart(SameDiff sameDiff, Operation operation) {
        boolean tracked = this.operation == null || this.operation == operation;
        if (!this.printDone && tracked) {
            this.printActive = true;
        }
    }

    /**
     * Disables printing and marks it as done when a matching operation ends.
     * Has no effect once printing has been marked as done.
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void operationEnd(SameDiff sameDiff, Operation operation) {
        if (this.printDone) {
            return;
        }
        boolean tracked = this.operation == null || this.operation == operation;
        if (tracked) {
            this.printDone = true;
            this.printActive = false;
        }
    }

    /** Stores the current time so that {@code opExecution} can compute the op's elapsed time. */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void preOpExecution(SameDiff sameDiff, At at, SameDiffOp sameDiffOp) {
        final long nowMs = System.currentTimeMillis();
        this.start = nowMs;
    }

    /**
     * Records how long the op took, measured from the timestamp stored by
     * {@code preOpExecution}.
     *
     * <p>In {@link Mode#SINGLE_ITER_PRINT} the op is printed to standard output when
     * printing is active and the elapsed time is strictly greater than the minimum
     * runtime. In {@link Mode#AGGREGATE} the elapsed time is appended to the record
     * for this op name, creating the record (and the map) on first use.
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void opExecution(SameDiff sameDiff, At at, MultiDataSet multiDataSet, SameDiffOp sameDiffOp, INDArray[] iNDArrayArr) {
        long now = System.currentTimeMillis();
        long elapsedMs = now - this.start;

        if (this.mode == Mode.SINGLE_ITER_PRINT) {
            if (this.printActive && elapsedMs > this.minRuntime) {
                System.out.println(getOpString(sameDiffOp, Long.valueOf(now)));
            }
            return;
        }
        if (this.mode != Mode.AGGREGATE) {
            return;
        }

        if (this.aggregateModeMap == null) {
            // Typed (non-raw) map; insertion order keeps ops in first-execution order.
            this.aggregateModeMap = new LinkedHashMap<String, OpExec>();
        }
        String opOwnName = sameDiffOp.getName();
        // One lookup instead of containsKey + put + get.
        OpExec stats = this.aggregateModeMap.get(opOwnName);
        if (stats == null) {
            stats = new OpExec(opOwnName, sameDiffOp.getOp().opName(), sameDiffOp.getOp().getClass(),
                    new ArrayList<Long>(), getOpString(sameDiffOp, null));
            this.aggregateModeMap.put(opOwnName, stats);
        }
        stats.getRuntimeMs().add(Long.valueOf(elapsedMs));
    }

    /**
     * Builds a textual description of an op: its name, class and op name, optionally
     * the elapsed time, followed by the shape info of its arrays.
     *
     * @param sameDiffOp op to describe
     * @param l          end timestamp in ms; when non-null, the elapsed time since the
     *                   stored start timestamp is included, otherwise it is omitted
     * @return the description
     */
    private String getOpString(SameDiffOp sameDiffOp, Long l) {
        final StringBuilder out = new StringBuilder();
        out.append(sameDiffOp.getName());
        out.append(" - ");
        out.append(sameDiffOp.getOp().getClass().getSimpleName());
        out.append("(");
        out.append(sameDiffOp.getOp().opName());
        out.append(") - ");
        if (l != null) {
            long elapsedMs = l.longValue() - this.start;
            out.append(elapsedMs).append(" ms\n");
        }

        final Object function = sameDiffOp.getOp();
        if (!(function instanceof DynamicCustomOp)) {
            // Non-custom ops expose at most three arrays: x, y and z.
            Op legacyOp = (Op) function;
            INDArray x = legacyOp.x();
            if (x != null) {
                out.append("  x: ").append(x.shapeInfoToString());
            }
            INDArray y = legacyOp.y();
            if (y != null) {
                out.append("  y: ").append(y.shapeInfoToString());
            }
            INDArray z = legacyOp.z();
            if (z != null) {
                out.append("  z: ").append(z.shapeInfoToString());
            }
            return out.toString();
        }

        DynamicCustomOp customOp = (DynamicCustomOp) function;
        int inputIndex = 0;
        for (INDArray input : customOp.inputArguments()) {
            out.append("  in ").append(inputIndex++).append(": ");
            out.append(input.shapeInfoToString()).append("\n");
        }
        int outputIndex = 0;
        for (INDArray output : customOp.outputArguments()) {
            out.append("  out ").append(outputIndex++).append(": ");
            out.append(output.shapeInfoToString()).append("\n");
        }

        long[] intArgs = customOp.iArgs();
        boolean[] boolArgs = customOp.bArgs();
        double[] floatArgs = customOp.tArgs();
        // Argument arrays are listed only when present and non-empty.
        if (intArgs != null && intArgs.length != 0) {
            out.append("  iargs: ").append(Arrays.toString(intArgs)).append("\n");
        }
        if (boolArgs != null && boolArgs.length != 0) {
            out.append("  bargs: ").append(Arrays.toString(boolArgs)).append("\n");
        }
        if (floatArgs != null && floatArgs.length != 0) {
            out.append("  targs: ").append(Arrays.toString(floatArgs)).append("\n");
        }
        return out.toString();
    }

    /** @return the operation type being benchmarked, or {@code null} if every operation matches */
    public Operation getOperation() {
        return operation;
    }

    /** @return the reporting mode this listener was created with */
    public Mode getMode() {
        return mode;
    }

    /** @return the runtime threshold in milliseconds that an op must exceed to be printed */
    public long getMinRuntime() {
        return minRuntime;
    }

    /**
     * @return the live map of op name to timing record, or {@code null} if no op has
     *         been recorded in aggregate mode yet
     */
    public Map<String, OpExec> getAggregateModeMap() {
        return aggregateModeMap;
    }

    /** @return {@code true} once a matching operation has ended */
    public boolean isPrintDone() {
        return printDone;
    }

    /** @return the timestamp in ms stored by the most recent {@code preOpExecution} call */
    private long getStart() {
        return start;
    }

    /** @return {@code true} while a matching operation is in progress and printing is enabled */
    private boolean isPrintActive() {
        return printActive;
    }
}
