package org.deeplearning4j.ui.flow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.flow.beans.Description;
import org.deeplearning4j.ui.flow.beans.LayerInfo;
import org.deeplearning4j.ui.flow.beans.LayerParams;
import org.deeplearning4j.ui.flow.beans.ModelInfo;
import org.deeplearning4j.ui.flow.beans.ModelState;
import org.deeplearning4j.ui.flow.data.FlowStaticPersistable;
import org.deeplearning4j.ui.flow.data.FlowUpdatePersistable;
import org.deeplearning4j.ui.storage.mapdb.MapDBStatsStorage;
import org.deeplearning4j.ui.weights.HistogramBin;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Iteration listener that feeds the UI's "flow" page.
 *
 * <p>Every {@code frequency} iterations it builds a {@link ModelState} snapshot and posts it to a
 * {@link StatsStorageRouter} as a {@link FlowUpdatePersistable}. The snapshot contains:
 * <ul>
 *   <li>score, samples/sec, batches/sec and iteration time;</li>
 *   <li>elapsed training time;</li>
 *   <li>per-layer learning rates;</li>
 *   <li>histograms of layer-indexed W/RW/RWF/b parameters.</li>
 * </ul>
 *
 * <p>On the first reported iteration it also posts a {@link FlowStaticPersistable} describing the
 * model's layer graph. If {@code openBrowser} was requested, it additionally tries to open
 * {@code http://localhost:<port>/flow?sid=<sessionID>}.
 *
 * <p>{@link #iterationDone(Model, int)} is {@code synchronized}.
 *
 * @deprecated kept only for backward compatibility.
 */
@Deprecated
public class FlowIterationListener implements IterationListener {
    private static final String FORMAT = "%02d:%02d:%02d";
    public static final String INPUT = "INPUT";

    /** Layer type given to every layer that has no outgoing connection. */
    private static final String OUTPUT = "OUTPUT";
    private static final String INPUT_COLOR = "#99ff66";
    private static final String OUTPUT_COLOR = "#e6e6e6";
    /** Colours cycled through, one per layer type, for all types other than INPUT/OUTPUT. */
    private static final List<String> COLORS = Collections.unmodifiableList(Arrays.asList("#9966ff", "#ff9933",
                    "#ffff99", "#3366ff", "#0099cc", "#669999", "#66ffff"));
    /** Histogram settings used for parameter summaries. */
    private static final int HISTOGRAM_BIN_COUNT = 14;
    private static final int HISTOGRAM_ROUNDING = 6;
    /**
     * Id assigned to every layer discovered in a ComputationGraph. The value is hard-coded, so ids are
     * not unique for graphs. NOTE(review): confirm the UI does not rely on unique ids there.
     */
    private static final int GRAPH_LAYER_ID = 121;

    private static final Logger log = LoggerFactory.getLogger(FlowIterationListener.class);

    private final int frequency;
    private boolean firstIteration = true;
    private final ModelState modelState = new ModelState();
    private final AtomicLong iterationCount = new AtomicLong(0L);
    /** Wall-clock time at the end of the previous iterationDone call (or at construction). */
    private long lastTime = System.currentTimeMillis();
    /** Wall-clock time at the start of the current reported iteration. */
    private long currTime;
    private final long initTime = System.currentTimeMillis();
    private final StatsStorageRouter ssr;
    private final String sessionID;
    private final String workerID;
    private final boolean openBrowser;

    /** Creates a listener that reports on every iteration. */
    protected FlowIterationListener() {
        this(1);
    }

    /**
     * Creates a listener backed by a new {@link MapDBStatsStorage}. The storage is attached to the UI
     * server, and a browser is opened on the first report.
     *
     * @param frequency report every {@code frequency} iterations; must be positive
     */
    public FlowIterationListener(int frequency) {
        this(new MapDBStatsStorage(), frequency, null, null, true);
    }

    /**
     * Legacy constructor: {@code address} is only null-checked and the second argument is ignored.
     *
     * @deprecated use {@link #FlowIterationListener(int)}.
     */
    @Deprecated
    public FlowIterationListener(@NonNull String address, int unused, int frequency) {
        this(frequency);
        if (address == null) {
            throw new NullPointerException("address");
        }
    }

    /**
     * Creates a listener posting to the given router.
     *
     * @param router      destination for static info and updates; must not be null
     * @param frequency   report every {@code frequency} iterations; must be positive
     * @param sessionID   session id, or null to generate a random UUID
     * @param workerID    worker id, or null to derive one from the JVM UID and the current thread id
     * @param openBrowser when true: attach the router to the UI server (if it is a StatsStorage)
     *                    and try to open a browser on the first report
     * @throws NullPointerException     if {@code router} is null
     * @throws IllegalArgumentException if {@code frequency} is not positive
     */
    public FlowIterationListener(StatsStorageRouter router, int frequency, String sessionID, String workerID,
                    boolean openBrowser) {
        if (router == null) {
            throw new NullPointerException("statsStorageRouter");
        }
        if (frequency <= 0) {
            // A zero frequency would make the modulo in iterationDone throw ArithmeticException.
            throw new IllegalArgumentException("frequency must be positive, got: " + frequency);
        }
        this.frequency = frequency;
        this.ssr = router;
        this.sessionID = sessionID == null ? UUID.randomUUID().toString() : sessionID;
        this.workerID = workerID == null ? UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId()
                        : workerID;
        this.openBrowser = openBrowser;
        if (openBrowser && router instanceof StatsStorage) {
            UIServer.getInstance().attach((StatsStorage) router);
        }
        log.info("FlowIterationListener path: http://localhost:{}/flow", UIServer.getInstance().getPort());
    }

    /**
     * Legacy constructor: {@code connectionInfo} is only null-checked.
     *
     * @deprecated use {@link #FlowIterationListener(int)}.
     */
    @Deprecated
    public FlowIterationListener(@NonNull UiConnectionInfo connectionInfo, int frequency) {
        this(frequency);
        if (connectionInfo == null) {
            throw new NullPointerException("connectionInfo");
        }
    }

    /** Always returns false; this listener does not track invocation. */
    @Override
    public boolean invoked() {
        return false;
    }

    /** No-op. */
    @Override
    public void invoke() {
    }

    /**
     * Counts the iteration and, every {@code frequency} iterations, posts a model-state update.
     * On the first report it also posts the static model description and optionally opens a browser.
     */
    @Override
    public synchronized void iterationDone(Model model, int iteration) {
        if (iterationCount.incrementAndGet() % frequency == 0) {
            currTime = System.currentTimeMillis();
            if (firstIteration) {
                ssr.putStaticInfo(new FlowStaticPersistable(sessionID, workerID, System.currentTimeMillis(),
                                buildModelInfo(model)));
            }
            buildModelState(model);
            ssr.putUpdate(new FlowUpdatePersistable(sessionID, workerID, System.currentTimeMillis(), modelState));
            if (firstIteration) {
                // Cleared regardless of openBrowser. Previously this happened only when a browser was
                // requested, so static info was re-posted on every report.
                firstIteration = false;
                if (openBrowser) {
                    String url = "http://localhost:" + UIServer.getInstance().getPort() + "/flow?sid=" + sessionID;
                    try {
                        UiUtils.tryOpenBrowser(url, log);
                    } catch (Exception e) {
                        // Best effort only: failing to open a browser must not interrupt training.
                        log.warn("Could not open browser at {}", url, e);
                    }
                }
            }
        }
        lastTime = System.currentTimeMillis();
    }

    /**
     * Adds to {@code modelInfo} every vertex that takes input from a vertex named in {@code parentNames}.
     * New layers are placed on row {@code y}, and each parent is connected to its child.
     *
     * @param modelInfo   model description being built; mutated
     * @param vertices    all graph vertices, indexed by {@link VertexIndices#getVertexIndex()}
     * @param parentNames names of the vertices on the previous row
     * @param y           row assigned to newly created layers
     * @return the layers newly added by this call, in discovery order
     */
    protected List<LayerInfo> flattenToY(ModelInfo modelInfo, GraphVertex[] vertices, List<String> parentNames,
                    int y) {
        List<LayerInfo> added = new ArrayList<>();
        int x = 0;
        for (GraphVertex vertex : vertices) {
            VertexIndices[] inputs = vertex.getInputVertices();
            if (inputs == null) {
                continue;
            }
            for (VertexIndices input : inputs) {
                String parentName = vertices[input.getVertexIndex()].getVertexName();
                if (!parentNames.contains(parentName)) {
                    continue;
                }
                try {
                    LayerInfo info = modelInfo.getLayerInfoByName(vertex.getVertexName());
                    boolean isNew = info == null;
                    if (isNew) {
                        info = getLayerInfo(vertex.getLayer(), x, y, GRAPH_LAYER_ID);
                    }
                    info.setName(vertex.getVertexName());
                    if (vertex.getLayer() == null) {
                        // Layer-less vertices are labelled by their vertex class.
                        info.setLayerType(vertex.getClass().getSimpleName());
                    }
                    if (info.getName().endsWith("-merge")) {
                        info.setLayerType("MERGE");
                    }
                    if (isNew) {
                        x++;
                        modelInfo.addLayer(info);
                        added.add(info);
                    }
                    LayerInfo parent = modelInfo.getLayerInfoByName(parentName);
                    if (parent != null) {
                        parent.addConnection(info);
                    }
                } catch (Exception e) {
                    log.warn("Could not describe graph vertex {}", vertex.getVertexName(), e);
                }
            }
        }
        return added;
    }

    /**
     * Refreshes {@link #modelState} from the model. It updates throughput, timings, scores, per-layer
     * learning rates and parameter histograms.
     */
    protected void buildModelState(Model model) {
        long iterationMillis = currTime - lastTime;
        // Clamp to 1 ms so a zero interval cannot turn the throughput figures into Infinity/NaN.
        float seconds = Math.max(iterationMillis, 1L) / 1000.0f;
        INDArray input = model.input();
        long exampleLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
        long batchSize = input.lengthLong() / exampleLength;
        modelState.addPerformanceSamples((float) batchSize / seconds);
        modelState.addPerformanceBatches(1.0f / seconds);
        modelState.setIterationTime(iterationMillis);
        float score = (float) model.score();
        modelState.addScore(score);
        modelState.setScore(score);
        modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));

        Layer[] layers = null;
        if (model instanceof MultiLayerNetwork) {
            layers = ((MultiLayerNetwork) model).getLayers();
        } else if (model instanceof ComputationGraph) {
            layers = ((ComputationGraph) model).getLayers();
        }
        if (layers != null) {
            List<Double> learningRates = new ArrayList<>(layers.length);
            for (Layer layer : layers) {
                learningRates.add(layer.conf().getLayer().getLearningRate());
            }
            modelState.setLearningRates(learningRates);
        }

        modelState.setLayerParams(buildLayerParams(model.paramTable()));
    }

    /**
     * Builds parameter histograms from parameter keys of the form {@code <layerIndex>_<name>}.
     * Keys not starting with a digit are skipped. Only W, RW, RWF and b (case-insensitive) get a
     * histogram, but every layer index seen gets an entry.
     */
    private static Map<Integer, LayerParams> buildLayerParams(Map<String, INDArray> paramTable) {
        Map<Integer, LayerParams> layerParams = new LinkedHashMap<>();
        for (Map.Entry<String, INDArray> entry : paramTable.entrySet()) {
            String key = entry.getKey();
            if (key.isEmpty() || !Character.isDigit(key.charAt(0))) {
                continue;
            }
            int separator = key.indexOf('_');
            int layerIndex;
            try {
                layerIndex = Integer.parseInt(separator < 0 ? key : key.substring(0, separator));
            } catch (NumberFormatException e) {
                // Digit-prefixed but not "<int>_...": not a layer-indexed parameter, skip it.
                continue;
            }
            String paramName = separator < 0 ? key : key.substring(separator + 1);
            LayerParams params = layerParams.get(layerIndex);
            if (params == null) {
                params = new LayerParams();
                layerParams.put(layerIndex, params);
            }
            // Histograms are built only for recognised parameter kinds.
            if (paramName.equalsIgnoreCase("w")) {
                params.setW(histogramOf(entry.getValue()).getData());
            } else if (paramName.equalsIgnoreCase("rw")) {
                params.setRW(histogramOf(entry.getValue()).getData());
            } else if (paramName.equalsIgnoreCase("rwf")) {
                params.setRWF(histogramOf(entry.getValue()).getData());
            } else if (paramName.equalsIgnoreCase("b")) {
                params.setB(histogramOf(entry.getValue()).getData());
            }
        }
        return layerParams;
    }

    /** Builds a histogram over a copy of {@code values}. */
    private static HistogramBin histogramOf(INDArray values) {
        return new HistogramBin.Builder(values.dup()).setBinCount(HISTOGRAM_BIN_COUNT)
                        .setRounding(HISTOGRAM_ROUNDING).build();
    }

    /**
     * Builds the static layer-graph description of a ComputationGraph or MultiLayerNetwork.
     * Any other model type yields an empty description.
     */
    protected ModelInfo buildModelInfo(Model model) {
        ModelInfo modelInfo = new ModelInfo();
        if (model instanceof ComputationGraph) {
            describeGraph((ComputationGraph) model, modelInfo);
        } else if (model instanceof MultiLayerNetwork) {
            describeMultiLayerNetwork((MultiLayerNetwork) model, modelInfo);
        }
        for (LayerInfo info : modelInfo.getLayers()) {
            if (info.getConnections().size() == 0) {
                info.setLayerType(OUTPUT);
            }
        }
        assignColors(modelInfo);
        return modelInfo;
    }

    /** Adds the graph's inputs on row 0, then walks the graph breadth-first, one row per step. */
    private void describeGraph(ComputationGraph graph, ModelInfo modelInfo) {
        List<String> networkInputs = graph.getConfiguration().getNetworkInputs();
        int x = 0;
        for (String inputName : networkInputs) {
            INDArray input = graph.getVertex(inputName).getInputs()[0];
            modelInfo.addLayer(createInputLayerInfo(inputName, x, input, "Vertex name: " + inputName + "<br/>"));
            x++;
        }
        GraphVertex[] vertices = graph.getVertices();
        List<String> frontier = new ArrayList<>(networkInputs);
        for (int y = 1; y < vertices.length && !frontier.isEmpty(); y++) {
            List<LayerInfo> added = flattenToY(modelInfo, vertices, frontier, y);
            frontier = new ArrayList<>(added.size());
            for (LayerInfo info : added) {
                frontier.add(info.getName());
            }
        }
    }

    /** Lays out the input followed by each layer in a single column (x = 0), one row per layer. */
    private void describeMultiLayerNetwork(MultiLayerNetwork network, ModelInfo modelInfo) {
        LayerInfo inputInfo = createInputLayerInfo("Input", 0, network.input(), "");
        // addConnection(x, y) appears to take target coordinates: link each row to the next one.
        inputInfo.addConnection(0, 1);
        modelInfo.addLayer(inputInfo);
        int y = 1;
        for (Layer layer : network.getLayers()) {
            LayerInfo info = getLayerInfo(layer, 0, y, y);
            info.addConnection(0, y + 1);
            modelInfo.addLayer(info);
            y++;
        }
        // The last row has no successor; remove the dangling link added in the loop.
        modelInfo.getLayerInfoByCoords(0, y - 1).dropConnections();
    }

    /**
     * Creates the row-0 INPUT entry (id 0) describing an input array's per-example size and batch size.
     */
    private static LayerInfo createInputLayerInfo(String name, int x, INDArray input, String textPrefix) {
        long inputSize = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
        long batchSize = input.lengthLong() / inputSize;
        StringBuilder text = new StringBuilder(textPrefix);
        text.append("Model input").append("<br/>");
        text.append("Input size: ").append(inputSize).append("<br/>");
        text.append("Batch size: ").append(batchSize).append("<br/>");
        LayerInfo info = new LayerInfo();
        info.setId(0L);
        info.setName(name);
        info.setY(0);
        info.setX(x);
        info.setLayerType(INPUT);
        info.setDescription(new Description());
        info.getDescription().setMainLine("Model input");
        info.getDescription().setText(text.toString());
        return info;
    }

    /**
     * Colours layers by type. INPUT and OUTPUT get fixed colours; other types cycle through
     * {@link #COLORS}. Every type, including INPUT/OUTPUT, advances the cycle.
     */
    private static void assignColors(ModelInfo modelInfo) {
        int colorIndex = 0;
        for (String type : modelInfo.getLayerTypes()) {
            String paletteColor = COLORS.get(colorIndex);
            colorIndex = (colorIndex + 1) % COLORS.size();
            String color;
            if (INPUT.equals(type)) {
                color = INPUT_COLOR;
            } else if (OUTPUT.equals(type)) {
                color = OUTPUT_COLOR;
            } else {
                color = paletteColor;
            }
            for (LayerInfo info : modelInfo.getLayersByType(type)) {
                info.setColor(color);
            }
        }
    }

    /**
     * Describes a single layer at the given position.
     *
     * <p>The name and type are filled in when they can be read. If the layer is null or its
     * configuration cannot be read, the type is "n/a" and the description is left empty.
     */
    private LayerInfo getLayerInfo(Layer layer, int x, int y, int id) {
        LayerInfo info = new LayerInfo();
        info.setX(x);
        info.setY(y);
        info.setName(layerNameOrDefault(layer));
        info.setId(id);
        Description description = new Description();
        info.setDescription(description);
        if (layer == null) {
            // Graph vertices without a backing layer; flattenToY overwrites name/type afterwards.
            info.setLayerType("n/a");
            return info;
        }
        try {
            info.setLayerType(layer.getClass().getSimpleName().replaceAll("Layer$", ""));
            org.deeplearning4j.nn.conf.layers.Layer confLayer = layer.conf().getLayer();
            StringBuilder mainLine = new StringBuilder();
            StringBuilder subLine = new StringBuilder();
            StringBuilder fullText = new StringBuilder();
            if (layer.type() == Layer.Type.CONVOLUTIONAL) {
                ConvolutionLayer conv = (ConvolutionLayer) confLayer;
                mainLine.append("K: " + Arrays.toString(conv.getKernelSize()) + " S: "
                                + Arrays.toString(conv.getStride()) + " P: " + Arrays.toString(conv.getPadding()));
                subLine.append("nIn/nOut: [" + conv.getNIn() + "/" + conv.getNOut() + "]");
                fullText.append("Kernel size: ").append(Arrays.toString(conv.getKernelSize())).append("<br/>");
                fullText.append("Stride: ").append(Arrays.toString(conv.getStride())).append("<br/>");
                fullText.append("Padding: ").append(Arrays.toString(conv.getPadding())).append("<br/>");
                fullText.append("Inputs number: ").append(conv.getNIn()).append("<br/>");
                fullText.append("Outputs number: ").append(conv.getNOut()).append("<br/>");
            } else if (confLayer instanceof SubsamplingLayer) {
                SubsamplingLayer pool = (SubsamplingLayer) confLayer;
                fullText.append("Kernel size: ").append(Arrays.toString(pool.getKernelSize())).append("<br/>");
                fullText.append("Stride: ").append(Arrays.toString(pool.getStride())).append("<br/>");
                fullText.append("Padding: ").append(Arrays.toString(pool.getPadding())).append("<br/>");
                fullText.append("Pooling type: ").append(pool.getPoolingType().toString()).append("<br/>");
            } else if (confLayer instanceof FeedForwardLayer) {
                FeedForwardLayer dense = (FeedForwardLayer) confLayer;
                mainLine.append("nIn/nOut: [" + dense.getNIn() + "/" + dense.getNOut() + "]");
                subLine.append(info.getLayerType());
                fullText.append("Inputs number: ").append(dense.getNIn()).append("<br/>");
                fullText.append("Outputs number: ").append(dense.getNOut()).append("<br/>");
            } else if (confLayer instanceof BaseOutputLayer) {
                // The original compared the runtime layer against this conf class, which is always false.
                // NOTE(review): likely shadowed by the FeedForwardLayer branch if BaseOutputLayer extends it.
                BaseOutputLayer out = (BaseOutputLayer) confLayer;
                mainLine.append("Outputs: [" + out.getNOut() + "]");
                fullText.append("Outputs number: ").append(out.getNOut()).append("<br/>");
            }
            subLine.append(" A: [").append(confLayer.getActivationFunction()).append("]");
            fullText.append("Activation function: ").append("<b>").append(confLayer.getActivationFunction())
                            .append("</b>").append("<br/>");
            description.setMainLine(mainLine.toString());
            description.setSubLine(subLine.toString());
            description.setText(fullText.toString());
        } catch (Exception e) {
            // Unexpected configuration shape: keep the entry but mark its type as unknown.
            info.setLayerType("n/a");
        }
        return info;
    }

    /** Returns the configured layer name, or "unnamed" if it is null, empty or unreadable. */
    private static String layerNameOrDefault(Layer layer) {
        String name = null;
        if (layer != null) {
            try {
                name = layer.conf().getLayer().getLayerName();
            } catch (RuntimeException ignored) {
                // Best effort: fall back to the placeholder below.
            }
        }
        return name == null || name.isEmpty() ? "unnamed" : name;
    }

    /** Formats a duration in milliseconds as {@code HH:MM:SS} (hours are not wrapped at 24). */
    protected String parseTime(long millis) {
        long totalMinutes = TimeUnit.MILLISECONDS.toMinutes(millis);
        long hours = TimeUnit.MILLISECONDS.toHours(millis);
        long minutes = totalMinutes - TimeUnit.HOURS.toMinutes(hours);
        long seconds = TimeUnit.MILLISECONDS.toSeconds(millis) - TimeUnit.MINUTES.toSeconds(totalMinutes);
        return String.format(FORMAT, hours, minutes, seconds);
    }
}
