package org.nd4j.autodiff.listeners.impl;

import com.google.flatbuffers.Table;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.graph.UIGraphStructure;
import org.nd4j.graph.UIStaticInfoRecord;
import org.nd4j.graph.ui.LogFileWriter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener.class */
public class UIListener extends BaseListener {
    // ---- Configuration, copied from the Builder by the constructor ----

    /** How an existing log file is handled; validated in the constructor. */
    private FileMode fileMode;
    /** File that the {@link LogFileWriter} is opened on. */
    private File logFile;
    // NOTE(review): lossPlotFreq and performanceStatsFrequency are stored but not read anywhere in this file.
    private int lossPlotFreq;
    private int performanceStatsFrequency;
    /** preUpdate writes update ratios when {@code iteration % updateRatioFrequency == 0}; values <= 0 disable it. */
    private int updateRatioFrequency;
    /** Selects the norm used in preUpdate: L2 uses norm2, anything else uses norm1. */
    private UpdateRatio updateRatioType;
    // NOTE(review): histogramFrequency, histogramTypes and opProfileFrequency are stored but not read in this file.
    private int histogramFrequency;
    private HistogramType[] histogramTypes;
    private int opProfileFrequency;
    /** Metrics to compute during training, keyed by (variable name, label index in the MultiDataSet). */
    private Map<Pair<String, Integer>, List<Evaluation.Metric>> trainEvalMetrics;
    /** opExecution writes iteration-level evaluation when iteration > 0 and {@code iteration % trainEvalFrequency == 0}. */
    private int trainEvalFrequency;
    // NOTE(review): testEvaluation is stored but not read in this file.
    private TestEvaluation testEvaluation;
    /** iterationDone writes the learning rate when {@code iteration % learningRateFrequency == 0}; values <= 0 disable it. */
    private int learningRateFrequency;

    // ---- Runtime state ----

    /** Data set of the iteration in progress: set in iterationStart, read in opExecution, cleared in iterationDone. */
    private MultiDataSet currentIterDataSet;
    /** Log writer; created by restoreLogFile for an existing file, otherwise lazily by initalizeWriter. */
    private LogFileWriter writer;
    /** True once iterationDone has registered the loss event names. */
    private boolean wroteLossNames;
    /** True once iterationDone has registered the "learningRate" event name. */
    private boolean wroteLearningRateName;
    /** Names of the ops whose outputs are evaluated; built on first use in opExecution. */
    private Set<String> relevantOpsForEval;
    /** Evaluation accumulators for the current epoch; reset to null at epoch start and end. */
    private Map<Pair<String, Integer>, Evaluation> epochTrainEval;
    /** Set by epochEnd after it has registered epoch-level evaluation event names. */
    private boolean wroteEvalNames;
    /** Bookkeeping flag for iteration-level evaluation event names; updated in opExecution. */
    private boolean wroteEvalNamesIter;
    /** First iteration at which preUpdate wrote update ratios; -1 until then. */
    private int firstUpdateRatioIter;
    /** Set by restoreLogFile when the existing log should be compared with the model on the next iterationStart. */
    private boolean checkStructureForRestore;

    /* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener$Builder.class */
    /**
     * Fluent builder for {@link UIListener}. Only the log file is mandatory; every
     * other setting has the default shown on its field initializer.
     */
    public static class Builder {
        // Target file - the only required setting.
        private File logFile;
        private FileMode fileMode = FileMode.CREATE_OR_APPEND;

        // Frequencies are in iterations; the consuming code treats values <= 0 as "disabled"
        // for update ratios, learning rate and iteration-level evaluation.
        private int lossPlotFreq = 1;
        private int performanceStatsFrequency = -1;
        private int updateRatioFrequency = -1;
        private int histogramFrequency = -1;
        private int opProfileFrequency = -1;
        private int trainEvalFrequency = 10;
        private int learningRateFrequency = 10;

        private UpdateRatio updateRatioType = UpdateRatio.MEAN_MAGNITUDE;
        private HistogramType[] histogramTypes;
        private Map<Pair<String, Integer>, List<Evaluation.Metric>> trainEvalMetrics;
        private TestEvaluation testEvaluation = null;

        /**
         * @param file log file to write to; must not be null
         * @throws NullPointerException if {@code file} is null
         */
        public Builder(@NonNull File file) {
            if (file == null) {
                throw new NullPointerException("logFile is marked @NonNull but is null");
            }
            this.logFile = file;
        }

        /** Sets how an already existing log file is treated. */
        public Builder fileMode(FileMode fileMode) {
            this.fileMode = fileMode;
            return this;
        }

        /** Sets the loss plotting frequency. */
        public Builder plotLosses(int i) {
            this.lossPlotFreq = i;
            return this;
        }

        /** Sets the performance statistics frequency. */
        public Builder performanceStats(int i) {
            this.performanceStatsFrequency = i;
            return this;
        }

        /**
         * Requests training-time evaluation metrics for one variable. Repeated calls for the
         * same (variable, label index) pair accumulate metrics; duplicates are ignored.
         *
         * @param str       name of the variable to evaluate
         * @param i         index of the labels array to evaluate against
         * @param metricArr metrics to record
         */
        public Builder trainEvaluationMetrics(String str, int i, Evaluation.Metric... metricArr) {
            if (this.trainEvalMetrics == null) {
                this.trainEvalMetrics = new LinkedHashMap<>();
            }
            final Pair<String, Integer> key = new Pair<>(str, Integer.valueOf(i));
            List<Evaluation.Metric> registered = this.trainEvalMetrics.get(key);
            if (registered == null) {
                registered = new ArrayList<>();
                this.trainEvalMetrics.put(key, registered);
            }
            for (int idx = 0; idx < metricArr.length; idx++) {
                final Evaluation.Metric candidate = metricArr[idx];
                if (!registered.contains(candidate)) {
                    registered.add(candidate);
                }
            }
            return this;
        }

        /** Shortcut for {@link #trainEvaluationMetrics} with the ACCURACY metric. */
        public Builder trainAccuracy(String str, int i) {
            return trainEvaluationMetrics(str, i, Evaluation.Metric.ACCURACY);
        }

        /** Shortcut for {@link #trainEvaluationMetrics} with the F1 metric. */
        public Builder trainF1(String str, int i) {
            return trainEvaluationMetrics(str, i, Evaluation.Metric.F1);
        }

        /** Sets how often (in iterations) iteration-level training evaluation is written. */
        public Builder trainEvalFrequency(int i) {
            this.trainEvalFrequency = i;
            return this;
        }

        /** Enables update ratios at the given frequency using {@link UpdateRatio#MEAN_MAGNITUDE}. */
        public Builder updateRatios(int i) {
            return updateRatios(i, UpdateRatio.MEAN_MAGNITUDE);
        }

        /** Sets the update ratio frequency and the way the ratio is computed. */
        public Builder updateRatios(int i, UpdateRatio updateRatio) {
            this.updateRatioFrequency = i;
            this.updateRatioType = updateRatio;
            return this;
        }

        /** Sets the histogram frequency and the histogram types (the array is stored as passed). */
        public Builder histograms(int i, HistogramType... histogramTypeArr) {
            this.histogramFrequency = i;
            this.histogramTypes = histogramTypeArr;
            return this;
        }

        /** Sets the op profiling frequency. */
        public Builder profileOps(int i) {
            this.opProfileFrequency = i;
            return this;
        }

        /** Sets the test evaluation configuration. */
        public Builder testEvaluation(TestEvaluation testEvaluation) {
            this.testEvaluation = testEvaluation;
            return this;
        }

        /** Sets how often (in iterations) the learning rate is written. */
        public Builder learningRate(int i) {
            this.learningRateFrequency = i;
            return this;
        }

        /** Creates the listener; the constructor validates the log file against the file mode. */
        public UIListener build() {
            return new UIListener(this);
        }
    }

    /* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener$FileMode.class */
    /**
     * Policy for a log file that may already exist. The existence checks described
     * here are the ones performed by the UIListener constructor.
     */
    public enum FileMode {
        /** The log file must not exist yet; otherwise construction fails with an IllegalStateException-style check. */
        CREATE,
        /** The log file must already exist; otherwise construction fails. */
        APPEND,
        /** No existence requirement. For an existing non-empty log, the recorded graph structure is compared with the model. */
        CREATE_OR_APPEND,
        /** No existence requirement; per the listener's error messages, this mode disables the structure comparison. */
        CREATE_APPEND_NOCHECK
    }

    /* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener$HistogramType.class */
    /**
     * Kinds of histogram that can be requested through {@code Builder.histograms}.
     * NOTE(review): nothing in this file consumes these values yet; the meaning of each
     * constant is assumed from its name - confirm against the code that renders histograms.
     */
    public enum HistogramType {
        PARAMETERS,
        PARAMETER_GRADIENTS,
        PARAMETER_UPDATES,
        ACTIVATIONS,
        ACTIVATION_GRADIENTS
    }

    /* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener$TestEvaluation.class */
    /**
     * Empty placeholder type. An instance can be passed to {@code Builder.testEvaluation},
     * but it carries no state and is not read anywhere in this file.
     */
    public static class TestEvaluation {
    }

    /* loaded from: input_file:org/nd4j/autodiff/listeners/impl/UIListener$UpdateRatio.class */
    /** Selects the norm used by preUpdate when computing log10(update norm / parameter norm). */
    public enum UpdateRatio {
        /** Uses the L2 norm (norm2Number) of both the parameter and the update. */
        L2,
        /** Uses the L1 norm (norm1Number) of both the parameter and the update. */
        MEAN_MAGNITUDE
    }

    private UIListener(Builder builder) {
        this.firstUpdateRatioIter = -1;
        this.fileMode = builder.fileMode;
        this.logFile = builder.logFile;
        this.lossPlotFreq = builder.lossPlotFreq;
        this.performanceStatsFrequency = builder.performanceStatsFrequency;
        this.updateRatioFrequency = builder.updateRatioFrequency;
        this.updateRatioType = builder.updateRatioType;
        this.histogramFrequency = builder.histogramFrequency;
        this.histogramTypes = builder.histogramTypes;
        this.opProfileFrequency = builder.opProfileFrequency;
        this.trainEvalMetrics = builder.trainEvalMetrics;
        this.trainEvalFrequency = builder.trainEvalFrequency;
        this.testEvaluation = builder.testEvaluation;
        this.learningRateFrequency = builder.learningRateFrequency;
        switch (this.fileMode) {
            case CREATE:
                Preconditions.checkState(!this.logFile.exists(), "Log file already exists and fileMode is set to CREATE: %s\nEither delete the existing file, specify a path that doesn't exist, or set the UIListener to another mode such as CREATE_OR_APPEND", this.logFile);
                break;
            case APPEND:
                Preconditions.checkState(this.logFile.exists(), "Log file does not exist and fileMode is set to APPEND: %s\nEither specify a path to an existing log file for this model, or set the UIListener to another mode such as CREATE_OR_APPEND", this.logFile);
                break;
        }
        if (this.logFile.exists()) {
            restoreLogFile();
        }
    }

    /**
     * Prepares an already existing log file for further writing.
     *
     * <p>In the appending modes ({@link FileMode#APPEND}, {@link FileMode#CREATE_OR_APPEND}) an
     * empty file holds nothing to restore, so it is deleted and the writer is left unset.
     * Otherwise a {@link LogFileWriter} is opened on the file and, in the appending modes,
     * the static info is scanned so that the graph structure check can be scheduled.
     *
     * <p>Fix: the original condition {@code (empty && CREATE_OR_APPEND) || APPEND} deleted every
     * existing log in APPEND mode, discarding the data the caller asked to append to.
     *
     * @throws RuntimeException wrapping the IOException if the file cannot be opened or read
     */
    protected void restoreLogFile() {
        final boolean appending = this.fileMode == FileMode.APPEND || this.fileMode == FileMode.CREATE_OR_APPEND;

        if (appending && this.logFile.length() == 0) {
            // Zero-length file: nothing to append to. Remove it; the writer is created elsewhere when needed.
            this.logFile.delete();
            return;
        }

        try {
            this.writer = new LogFileWriter(this.logFile);
        } catch (IOException e) {
            throw new RuntimeException("Error restoring existing log file at path: " + this.logFile.getAbsolutePath(), e);
        }

        if (!appending) {
            // CREATE_APPEND_NOCHECK: continue writing without validating the stored structure.
            return;
        }

        final LogFileWriter.StaticInfo staticInfo;
        try {
            staticInfo = this.writer.readStatic();
        } catch (IOException e) {
            throw new RuntimeException("Error restoring existing log file, static info at path: " + this.logFile.getAbsolutePath(), e);
        }

        // Null checks happen before any dereference (the original read getData() first).
        final List<Pair<UIStaticInfoRecord, Table>> records = staticInfo == null ? null : staticInfo.getData();
        if (records == null) {
            return;
        }
        for (Pair<UIStaticInfoRecord, Table> record : records) {
            // Info type 0 is the record whose payload checkStructureForRestore reads as the graph structure.
            // The model is not available here, so only flag the comparison for the first iteration.
            if (record.getFirst().infoType() == 0) {
                this.checkStructureForRestore = true;
                break;
            }
        }
    }

    /**
     * Compares the graph structure stored in the existing log file with the given model and
     * fails if the placeholders or variable names differ. If the log contains no graph
     * structure record, nothing is compared. On success the pending-check flag is cleared.
     *
     * <p>Changes from the original: the static info is null-checked before it is dereferenced,
     * name lookups use hash sets instead of repeated list scans, and variables that exist only
     * in the file are now counted and reported (previously such a mismatch was reported as
     * "0 variable names differ").
     *
     * @param sameDiff model that is about to continue writing to the log file
     * @throws IllegalStateException if the model structure differs from the one in the file
     * @throws RuntimeException wrapping the IOException if the static info cannot be read
     */
    protected void checkStructureForRestore(SameDiff sameDiff) {
        final LogFileWriter.StaticInfo staticInfo;
        try {
            staticInfo = this.writer.readStatic();
        } catch (IOException e) {
            throw new RuntimeException("Error restoring existing log file, static info at path: " + this.logFile.getAbsolutePath(), e);
        }

        // Locate the first graph structure record (info type 0), if any.
        UIGraphStructure structure = null;
        final List<Pair<UIStaticInfoRecord, Table>> records = staticInfo == null ? null : staticInfo.getData();
        if (records != null) {
            for (Pair<UIStaticInfoRecord, Table> record : records) {
                if (record.getFirst().infoType() == 0) {
                    structure = (UIGraphStructure) record.getSecond();
                    break;
                }
            }
        }

        if (structure != null) {
            // ---- Placeholders ----
            final int filePlaceholderCount = structure.inputsLength();
            final List<String> filePlaceholders = new ArrayList<>(filePlaceholderCount);
            for (int i = 0; i < filePlaceholderCount; i++) {
                filePlaceholders.add(structure.inputs(i));
            }
            final List<String> modelPlaceholders = sameDiff.inputs();
            if (modelPlaceholders.size() != filePlaceholders.size() || !modelPlaceholders.containsAll(filePlaceholders)) {
                throw new IllegalStateException("Error continuing collection of UI stats in existing model file " + this.logFile.getAbsolutePath() + ": Model structure differs. Existing (file) model placeholders: " + filePlaceholders + " vs. current model placeholders: " + modelPlaceholders + ". To disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
            }

            // ---- Variables ----
            final int fileVariableCount = structure.variablesLength();
            final List<String> fileVarNames = new ArrayList<>(fileVariableCount);
            for (int i = 0; i < fileVariableCount; i++) {
                fileVarNames.add(structure.variables(i).name());
            }
            final List<SDVariable> modelVariables = sameDiff.variables();
            final List<String> modelVarNames = new ArrayList<>(modelVariables.size());
            for (SDVariable variable : modelVariables) {
                modelVarNames.add(variable.getVarName());
            }

            final Set<String> fileNameSet = new HashSet<>(fileVarNames);
            final Set<String> modelNameSet = new HashSet<>(modelVarNames);
            if (modelVarNames.size() != fileVarNames.size() || !modelNameSet.containsAll(fileNameSet)) {
                // Names only in the current model (at most 10 kept for the message).
                int onlyInModel = 0;
                final List<String> onlyInModelSample = new ArrayList<>();
                for (String name : modelVarNames) {
                    if (!fileNameSet.contains(name)) {
                        onlyInModel++;
                        if (onlyInModelSample.size() < 10) {
                            onlyInModelSample.add(name);
                        }
                    }
                }
                // Names only in the file (at most 10 kept for the message).
                int onlyInFile = 0;
                final List<String> onlyInFileSample = new ArrayList<>();
                for (String name : fileVarNames) {
                    if (!modelNameSet.contains(name)) {
                        onlyInFile++;
                        if (onlyInFileSample.size() < 10) {
                            onlyInFileSample.add(name);
                        }
                    }
                }

                final StringBuilder sb = new StringBuilder();
                sb.append("Error continuing collection of UI stats in existing model file ").append(this.logFile.getAbsolutePath()).append(": Current model structure differs vs. model structure in file - ").append(onlyInModel + onlyInFile).append(" variable names differ.");
                if (onlyInModelSample.size() == onlyInModel) {
                    sb.append("\nVariables in new model not present in existing (file) model: ").append(onlyInModelSample);
                } else {
                    sb.append("\nFirst 10 variables in new model not present in existing (file) model: ").append(onlyInModelSample);
                }
                if (onlyInFile > 0) {
                    if (onlyInFileSample.size() == onlyInFile) {
                        sb.append("\nVariables in existing (file) model not present in new model: ").append(onlyInFileSample);
                    } else {
                        sb.append("\nFirst 10 variables in existing (file) model not present in new model: ").append(onlyInFileSample);
                    }
                }
                sb.append("\nTo disable this check, use FileMode.CREATE_APPEND_NOCHECK though this may result issues when rendering data via UI");
                throw new IllegalStateException(sb.toString());
            }
        }

        // Structure matches (or there was nothing to compare): do not check again.
        this.checkStructureForRestore = false;
    }

    /**
     * Creates the log writer for the given model, converting any I/O failure into an
     * unchecked exception so listener callbacks do not have to declare it.
     *
     * @param sameDiff model whose structure is written to the new log
     * @throws RuntimeException wrapping the IOException raised by {@link #initializeHelper}
     */
    protected void initalizeWriter(SameDiff sameDiff) {
        try {
            this.initializeHelper(sameDiff);
        } catch (IOException ioFailure) {
            throw new RuntimeException(ioFailure);
        }
    }

    /**
     * Opens a writer on the log file and writes the static section: the graph structure
     * of the model followed by the end-of-static marker.
     *
     * @param sameDiff model whose graph structure is recorded
     * @throws IOException if the writer cannot be created or a write fails
     */
    protected void initializeHelper(SameDiff sameDiff) throws IOException {
        final LogFileWriter freshWriter = new LogFileWriter(this.logFile);
        // Publish the writer before writing, so the field is set even if a write below fails.
        this.writer = freshWriter;
        freshWriter.writeGraphStructure(sameDiff);
        freshWriter.writeFinishStaticMarker();
    }

    /**
     * This listener only takes part in training.
     *
     * @param operation operation that is about to run
     * @return true only for {@link Operation#TRAINING}
     */
    @Override // org.nd4j.autodiff.listeners.Listener
    public boolean isActive(Operation operation) {
        return Operation.TRAINING == operation;
    }

    /**
     * Discards the evaluation accumulators of the previous epoch, so that opExecution
     * creates fresh ones the next time an evaluated op runs.
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void epochStart(SameDiff sameDiff, At at) {
        this.epochTrainEval = null;
    }

    /**
     * Writes the evaluation metrics accumulated over the epoch, one scalar event per
     * (variable, metric) named {@code evaluation/<variable>/train/<metric>}, then clears
     * the accumulators.
     *
     * <p>Fix: the "names written" flag used to be set after the first evaluated variable, so
     * with several evaluated variables the event names of all but the first were never
     * registered. The flag is now set only after every variable has been processed.
     *
     * @return always {@link ListenerResponse#CONTINUE}
     * @throws RuntimeException wrapping the IOException if an event cannot be written
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public ListenerResponse epochEnd(SameDiff sameDiff, At at, LossCurve lossCurve, long j) {
        if (this.epochTrainEval != null) {
            final long currentTimeMillis = System.currentTimeMillis();
            for (Map.Entry<Pair<String, Integer>, Evaluation> entry : this.epochTrainEval.entrySet()) {
                final String prefix = "evaluation/" + entry.getKey().getFirst();
                for (Evaluation.Metric metric : this.trainEvalMetrics.get(entry.getKey())) {
                    final String eventName = prefix + "/train/" + metric.toString().toLowerCase();
                    // The name may already be registered when continuing an existing log file.
                    if (!this.wroteEvalNames && !this.writer.registeredEventName(eventName)) {
                        this.writer.registerEventNameQuiet(eventName);
                    }
                    try {
                        this.writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.EVALUATION, currentTimeMillis, at.iteration(), at.epoch(), Double.valueOf(entry.getValue().scoreForMetric(metric)));
                    } catch (IOException e) {
                        throw new RuntimeException("Error writing to log file", e);
                    }
                }
            }
            // Every variable's event names have been handled; skip registration in later epochs.
            this.wroteEvalNames = true;
        }
        this.epochTrainEval = null;
        return ListenerResponse.CONTINUE;
    }

    /**
     * Prepares the listener for an iteration: creates the writer on first use, runs the
     * pending structure check for a restored log file, and remembers the iteration's data
     * so opExecution can read the labels from it.
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void iterationStart(SameDiff sameDiff, At at, MultiDataSet multiDataSet, long j) {
        final boolean writerMissing = (this.writer == null);
        if (writerMissing) {
            initalizeWriter(sameDiff);
        }

        // Flag is set by restoreLogFile; the check itself clears it when it passes.
        if (this.checkStructureForRestore) {
            checkStructureForRestore(sameDiff);
        }

        this.currentIterDataSet = multiDataSet;
    }

    /**
     * Writes the results of a finished iteration: every loss as {@code losses/<name>},
     * the total as {@code losses/totalLoss} when there is more than one loss, and - at the
     * configured frequency, if the updater has one - the current learning rate.
     *
     * <p>Fix: a failure while writing the learning rate used to be rethrown without its
     * cause; the IOException is now chained like in the other write paths.
     *
     * @throws RuntimeException wrapping the IOException if an event cannot be written
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void iterationDone(SameDiff sameDiff, At at, MultiDataSet multiDataSet, Loss loss) {
        final long currentTimeMillis = System.currentTimeMillis();
        final List<String> lossNames = loss.getLossNames();
        final double[] losses = loss.getLosses();

        if (!this.wroteLossNames) {
            // Names may already be registered when continuing an existing log file.
            for (String lossName : lossNames) {
                final String eventName = "losses/" + lossName;
                if (!this.writer.registeredEventName(eventName)) {
                    this.writer.registerEventNameQuiet(eventName);
                }
            }
            if (loss.numLosses() > 1 && !this.writer.registeredEventName("losses/totalLoss")) {
                this.writer.registerEventNameQuiet("losses/totalLoss");
            }
            this.wroteLossNames = true;
        }

        try {
            for (int i = 0; i < losses.length; i++) {
                this.writer.writeScalarEvent("losses/" + lossNames.get(i), LogFileWriter.EventSubtype.LOSS, currentTimeMillis, at.iteration(), at.epoch(), Double.valueOf(losses[i]));
            }
            // With a single loss the total would just duplicate it.
            if (losses.length > 1) {
                this.writer.writeScalarEvent("losses/totalLoss", LogFileWriter.EventSubtype.LOSS, currentTimeMillis, at.iteration(), at.epoch(), Double.valueOf(loss.totalLoss()));
            }
        } catch (IOException e) {
            throw new RuntimeException("Error writing to log file", e);
        }

        // The iteration's data is no longer needed by opExecution.
        this.currentIterDataSet = null;

        if (this.learningRateFrequency <= 0) {
            return;
        }
        if (!this.wroteLearningRateName) {
            if (!this.writer.registeredEventName("learningRate")) {
                this.writer.registerEventNameQuiet("learningRate");
            }
            this.wroteLearningRateName = true;
        }
        if (at.iteration() % this.learningRateFrequency != 0) {
            return;
        }
        final IUpdater updater = sameDiff.getTrainingConfig().getUpdater();
        if (updater.hasLearningRate()) {
            try {
                this.writer.writeScalarEvent("learningRate", LogFileWriter.EventSubtype.LEARNING_RATE, currentTimeMillis, at.iteration(), at.epoch(), Double.valueOf(updater.getLearningRate(at.iteration(), at.epoch())));
            } catch (IOException e) {
                throw new RuntimeException("Error writing to log file", e);
            }
        }
    }

    /**
     * Accumulates training evaluation for the configured variables as their producing ops
     * execute, and - at the configured iteration frequency - writes the running metric
     * values as {@code evaluation/train_iter/<variable>/<metric>} events.
     *
     * <p>Fixes relative to the original:
     * <ul>
     *   <li>Only variables that are outputs of the op being executed are evaluated. The
     *       original looked up every configured variable in this op's outputs and indexed
     *       the output array with -1 when evaluating variables produced by different ops.</li>
     *   <li>The "cannot evaluate" message now reports the variable's name rather than the
     *       (always null at that point) producing-op name, and an unknown variable name is
     *       reported explicitly instead of failing with a NullPointerException.</li>
     *   <li>Event names are registered whenever the writer does not know them yet; the
     *       original gated this on a flag that it reset on every call.</li>
     *   <li>A failed write keeps the IOException as the cause.</li>
     * </ul>
     *
     * @param iNDArrayArr outputs of {@code sameDiffOp}, in the order of its output names
     * @throws IllegalStateException via Preconditions if a configured variable is unknown or
     *                               is not the output of an op
     * @throws RuntimeException wrapping the IOException if an event cannot be written
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void opExecution(SameDiff sameDiff, At at, MultiDataSet multiDataSet, SameDiffOp sameDiffOp, INDArray[] iNDArrayArr) {
        if (at.operation() != Operation.TRAINING || this.trainEvalMetrics == null || this.trainEvalMetrics.isEmpty()) {
            return;
        }
        final long currentTimeMillis = System.currentTimeMillis();

        // Resolve, once, which ops produce the variables that should be evaluated.
        if (this.relevantOpsForEval == null) {
            final Set<String> producingOps = new HashSet<>();
            for (Pair<String, Integer> key : this.trainEvalMetrics.keySet()) {
                final Variable variable = sameDiff.getVariables().get(key.getFirst());
                Preconditions.checkState(variable != null, "Cannot evaluate on variable \"%s\": no variable with this name exists", key.getFirst());
                final String producingOp = variable.getOutputOfOp();
                Preconditions.checkState(producingOp != null, "Cannot evaluate on variable of type %s - variable name: \"%s\"", variable.getVariable().getVariableType(), variable.getName());
                producingOps.add(producingOp);
            }
            // Assigned only after all variables validated, so a failure is reported again on the next call.
            this.relevantOpsForEval = producingOps;
        }

        if (!this.relevantOpsForEval.contains(sameDiffOp.getName())) {
            return;
        }

        // Accumulators live for one epoch; epochStart/epochEnd reset them to null.
        if (this.epochTrainEval == null) {
            this.epochTrainEval = new HashMap<>();
            for (Pair<String, Integer> key : this.trainEvalMetrics.keySet()) {
                this.epochTrainEval.put(key, new Evaluation());
            }
        }

        final boolean writeThisIteration = this.trainEvalFrequency > 0 && at.iteration() > 0 && at.iteration() % this.trainEvalFrequency == 0;
        final List<String> opOutputNames = sameDiffOp.getOutputsOfOp();

        for (Map.Entry<Pair<String, Integer>, List<Evaluation.Metric>> entry : this.trainEvalMetrics.entrySet()) {
            final Pair<String, Integer> key = entry.getKey();
            final int outputIndex = opOutputNames.indexOf(key.getFirst());
            if (outputIndex < 0) {
                // This variable is produced by a different op; it is handled when that op executes.
                continue;
            }

            final int labelIndex = key.getSecond().intValue();
            final Evaluation evaluation = this.epochTrainEval.get(key);
            evaluation.eval(this.currentIterDataSet.getLabels(labelIndex), iNDArrayArr[outputIndex], this.currentIterDataSet.getLabelsMaskArray(labelIndex));

            if (!writeThisIteration) {
                continue;
            }
            for (Evaluation.Metric metric : entry.getValue()) {
                final String eventName = "evaluation/train_iter/" + key.getKey() + "/" + metric.toString().toLowerCase();
                // Might already be registered, e.g. by an earlier iteration or an existing log file.
                if (!this.writer.registeredEventName(eventName)) {
                    this.writer.registerEventNameQuiet(eventName);
                }
                this.wroteEvalNamesIter = true;
                try {
                    this.writer.writeScalarEvent(eventName, LogFileWriter.EventSubtype.EVALUATION, currentTimeMillis, at.iteration(), at.epoch(), Double.valueOf(evaluation.scoreForMetric(metric)));
                } catch (IOException e) {
                    throw new RuntimeException("Error writing to log file", e);
                }
            }
        }
    }

    /**
     * Writes, at the configured frequency, the log10 ratio between the norm of a parameter's
     * update and the norm of the parameter itself as a {@code logUpdateRatio/<variable>} event.
     * The ratio is reported as 0 when the parameter norm is 0 and is bounded below by -10.
     *
     * @param variable parameter about to be updated
     * @param iNDArray update that is about to be applied to the parameter
     * @throws RuntimeException wrapping the IOException if the event cannot be written
     */
    @Override // org.nd4j.autodiff.listeners.BaseListener, org.nd4j.autodiff.listeners.Listener
    public void preUpdate(SameDiff sameDiff, At at, Variable variable, INDArray iNDArray) {
        if (this.writer == null) {
            initalizeWriter(sameDiff);
        }

        final boolean enabled = this.updateRatioFrequency > 0;
        if (!enabled || at.iteration() % this.updateRatioFrequency != 0) {
            return;
        }

        // Event names are registered during the first iteration that writes ratios.
        if (this.firstUpdateRatioIter < 0) {
            this.firstUpdateRatioIter = at.iteration();
        }
        if (this.firstUpdateRatioIter == at.iteration()) {
            final String nameToRegister = "logUpdateRatio/" + variable.getName();
            if (!this.writer.registeredEventName(nameToRegister)) {
                this.writer.registerEventNameQuiet(nameToRegister);
            }
        }

        // Parameter norm first, then update norm, using the norm selected by the ratio type.
        final double paramNorm;
        final double updateNorm;
        if (this.updateRatioType == UpdateRatio.L2) {
            paramNorm = variable.getVariable().getArr().norm2Number().doubleValue();
            updateNorm = iNDArray.norm2Number().doubleValue();
        } else {
            paramNorm = variable.getVariable().getArr().norm1Number().doubleValue();
            updateNorm = iNDArray.norm1Number().doubleValue();
        }

        final double logRatio;
        if (paramNorm == 0.0d) {
            // Avoid dividing by zero for an all-zero parameter.
            logRatio = 0.0d;
        } else {
            logRatio = Math.max(-10.0d, Math.log10(updateNorm / paramNorm));
        }

        try {
            this.writer.writeScalarEvent("logUpdateRatio/" + variable.getName(), LogFileWriter.EventSubtype.LOSS, System.currentTimeMillis(), at.iteration(), at.epoch(), Double.valueOf(logRatio));
        } catch (IOException ioFailure) {
            throw new RuntimeException("Error writing to log file", ioFailure);
        }
    }

    /**
     * Convenience factory for a {@link Builder}.
     *
     * @param file log file to write to; the Builder constructor rejects null
     * @return a new builder with default settings
     */
    public static Builder builder(File file) {
        return new Builder(file);
    }
}
