package org.deeplearning4j.plot.iterationlistener;

import java.util.ArrayList;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;

/* loaded from: input_file:org/deeplearning4j/plot/iterationlistener/AccuracyPlotterIterationListener.class */
/**
 * Iteration listener that evaluates a {@link MultiLayerNetwork} on a fixed set of features and
 * labels after every iteration, records the resulting accuracy, and periodically renders the
 * accumulated accuracy history with a {@link NeuralNetPlotter}.
 *
 * <p>The graph is written to {@code plotter.getLayerGraphFilePath() + "accuracy.png"} on iteration
 * 0 when {@code renderFirst} is set, and on every positive iteration that is a multiple of the
 * render interval ({@code epochs}).
 *
 * <p>Constructors that do not take a network and evaluation data leave them unset; in that case
 * {@link #iterationDone(Model, int)} throws {@link IllegalStateException}.
 */
public class AccuracyPlotterIterationListener implements IterationListener {
    /** Render interval, in iterations; always positive. */
    private final int epochs;
    /** Features fed to the network for evaluation; may be null (see class doc). */
    private final INDArray input;
    /** Network being evaluated; may be null (see class doc). */
    private final MultiLayerNetwork network;
    /** Expected labels for {@link #input}; may be null (see class doc). */
    private final INDArray labels;
    private final NeuralNetPlotter plotter;
    /** Whether to also render on iteration 0. */
    private final boolean renderFirst;
    /** Accuracy history, one entry per completed iteration. */
    private final ArrayList<Double> accuracy = new ArrayList<>();
    private boolean invoked = false;

    @Override // org.deeplearning4j.optimize.api.IterationListener
    public boolean invoked() {
        return this.invoked;
    }

    @Override // org.deeplearning4j.optimize.api.IterationListener
    public void invoke() {
        this.invoked = true;
    }

    /**
     * Creates a listener with a default plotter and no network/evaluation data.
     *
     * @param i render interval in iterations (must be positive)
     * @param z whether to also render on iteration 0
     */
    public AccuracyPlotterIterationListener(int i, boolean z) {
        this(i, new NeuralNetPlotter(), null, null, null, z);
    }

    /**
     * Creates a listener with the given plotter and no network/evaluation data.
     *
     * @param i render interval in iterations (must be positive)
     * @param neuralNetPlotter plotter used to render the accuracy graph
     */
    public AccuracyPlotterIterationListener(int i, NeuralNetPlotter neuralNetPlotter) {
        this(i, neuralNetPlotter, null, null, null, false);
    }

    /**
     * Creates a listener with the given plotter and no network/evaluation data.
     *
     * @param i render interval in iterations (must be positive)
     * @param neuralNetPlotter plotter used to render the accuracy graph
     * @param z whether to also render on iteration 0
     */
    public AccuracyPlotterIterationListener(int i, NeuralNetPlotter neuralNetPlotter, boolean z) {
        this(i, neuralNetPlotter, null, null, null, z);
    }

    /**
     * Creates a listener that evaluates {@code multiLayerNetwork} on {@code dataSet}.
     *
     * @param i render interval in iterations (must be positive)
     * @param multiLayerNetwork network to evaluate
     * @param dataSet evaluation data; its features and labels are captured here
     */
    public AccuracyPlotterIterationListener(int i, MultiLayerNetwork multiLayerNetwork, DataSet dataSet) {
        this(i, new NeuralNetPlotter(), multiLayerNetwork, dataSet.getFeatures(), dataSet.getLabels(), false);
    }

    /**
     * Creates a listener that evaluates {@code multiLayerNetwork} on {@code dataSet}.
     *
     * @param i render interval in iterations (must be positive)
     * @param multiLayerNetwork network to evaluate
     * @param dataSet evaluation data; its features and labels are captured here
     * @param z whether to also render on iteration 0
     */
    public AccuracyPlotterIterationListener(int i, MultiLayerNetwork multiLayerNetwork, DataSet dataSet, boolean z) {
        this(i, new NeuralNetPlotter(), multiLayerNetwork, dataSet.getFeatures(), dataSet.getLabels(), z);
    }

    /**
     * Creates a listener that evaluates {@code multiLayerNetwork} on explicit feature/label arrays.
     *
     * @param i render interval in iterations (must be positive)
     * @param multiLayerNetwork network to evaluate
     * @param iNDArray evaluation features
     * @param iNDArray2 expected labels for {@code iNDArray}
     */
    public AccuracyPlotterIterationListener(int i, MultiLayerNetwork multiLayerNetwork, INDArray iNDArray, INDArray iNDArray2) {
        this(i, new NeuralNetPlotter(), multiLayerNetwork, iNDArray, iNDArray2, false);
    }

    /**
     * Creates a listener with a default plotter and no network/evaluation data.
     *
     * @param i render interval in iterations (must be positive)
     */
    public AccuracyPlotterIterationListener(int i) {
        // Previously `this.epochs = this.epochs`, which silently ignored the argument.
        this(i, new NeuralNetPlotter(), null, null, null, false);
    }

    /** Single initialisation path shared by all public constructors. */
    private AccuracyPlotterIterationListener(int epochs, NeuralNetPlotter plotter, MultiLayerNetwork network,
                                             INDArray input, INDArray labels, boolean renderFirst) {
        if (epochs <= 0) {
            // A zero interval would otherwise fail later with ArithmeticException in iterationDone.
            throw new IllegalArgumentException("epochs must be positive, got " + epochs);
        }
        this.epochs = epochs;
        this.plotter = plotter;
        this.network = network;
        this.input = input;
        this.labels = labels;
        this.renderFirst = renderFirst;
    }

    /**
     * Evaluates the configured network on the configured features and labels.
     *
     * @return the accuracy reported by {@link Evaluation#accuracy()}
     * @throws IllegalStateException if no network or evaluation data was supplied at construction
     */
    private double calculateAccuracy() {
        if (this.network == null || this.input == null || this.labels == null) {
            throw new IllegalStateException(
                    "No network/evaluation data configured; use a constructor that supplies a "
                            + "MultiLayerNetwork and a DataSet or feature/label arrays");
        }
        Evaluation evaluation = new Evaluation();
        evaluation.eval(this.labels, this.network.output(this.input));
        return evaluation.accuracy();
    }

    /**
     * Records the current accuracy and, on render iterations, marks the listener as invoked and
     * writes the accuracy graph.
     *
     * @param model the model being trained (unused; the network given at construction is evaluated)
     * @param iteration index of the iteration that just completed
     */
    @Override // org.deeplearning4j.optimize.api.IterationListener
    public void iterationDone(Model model, int iteration) {
        this.accuracy.add(calculateAccuracy());
        boolean renderNow = (iteration == 0 && this.renderFirst)
                || (iteration > 0 && iteration % this.epochs == 0);
        if (!renderNow) {
            return;
        }
        invoke();
        this.plotter.renderGraph("accuracy", this.plotter.writeArray(this.accuracy),
                this.plotter.getLayerGraphFilePath() + "accuracy.png");
    }
}
