package org.deeplearning4j.plot.iterationlistener;

import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.plot.NeuralNetPlotter;

/* loaded from: input_file:org/deeplearning4j/plot/iterationlistener/GradientPlotterIterationListener.class */
public class GradientPlotterIterationListener implements IterationListener {
    private int iterations;
    private NeuralNetPlotter plotter;
    private boolean renderFirst;
    private boolean invoked;

    @Override // org.deeplearning4j.optimize.api.IterationListener
    public boolean invoked() {
        return this.invoked;
    }

    @Override // org.deeplearning4j.optimize.api.IterationListener
    public void invoke() {
        this.invoked = true;
    }

    public GradientPlotterIterationListener(int i, boolean z) {
        this.iterations = 10;
        this.plotter = new NeuralNetPlotter();
        this.renderFirst = false;
        this.invoked = false;
        this.iterations = i;
        this.renderFirst = z;
    }

    public GradientPlotterIterationListener(int i, NeuralNetPlotter neuralNetPlotter) {
        this.iterations = 10;
        this.plotter = new NeuralNetPlotter();
        this.renderFirst = false;
        this.invoked = false;
        this.iterations = i;
        this.plotter = neuralNetPlotter;
    }

    public GradientPlotterIterationListener(int i, NeuralNetPlotter neuralNetPlotter, boolean z) {
        this.iterations = 10;
        this.plotter = new NeuralNetPlotter();
        this.renderFirst = false;
        this.invoked = false;
        this.iterations = i;
        this.plotter = neuralNetPlotter;
        this.renderFirst = z;
    }

    public GradientPlotterIterationListener(int i) {
        this.iterations = 10;
        this.plotter = new NeuralNetPlotter();
        this.renderFirst = false;
        this.invoked = false;
        this.iterations = i;
    }

    @Override // org.deeplearning4j.optimize.api.IterationListener
    public void iterationDone(Model model, int i) {
        if (!(i == 0 && this.renderFirst) && (i <= 0 || i % this.iterations != 0)) {
            return;
        }
        invoke();
        Layer layer = (Layer) model;
        this.plotter.updateGraphDirectory(layer);
        this.plotter.plotNetworkGradient(layer, layer.gradient());
    }
}
