package org.deeplearning4j.plot;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.nn.params.PretrainParamInitializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/deeplearning4j/plot/NeuralNetPlotter.class */
/**
 * Debug plotting utilities for neural-network layers.
 *
 * <p>Matrices are written as comma-separated text files under {@code java.io.tmpdir} and handed
 * to a Python plotting script. The script is bundled on the classpath as
 * {@code /scripts/plot.py} and copied to {@code /tmp/plot.py} when this class is initialised.
 * Plotting therefore requires a {@code python} executable resolvable on the {@code PATH}.
 * Filter images are rendered through {@link FilterRenderer}.
 */
public class NeuralNetPlotter implements Serializable {
    /** Location the bundled plotting script is copied to and executed from. */
    private static final String SCRIPT_PATH = "/tmp/plot.py";
    /** Image file name passed to {@code FilterRenderer.renderFilters}. */
    private static final String FILTER_IMAGE = "currimg.png";
    /** How long the activation plot process is left running before it is destroyed. */
    private static final long ACTIVATION_DISPLAY_MILLIS = 10000L;

    private static final ClassPathResource r = new ClassPathResource("/scripts/plot.py");
    private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);
    private static final FilterRenderer render = new FilterRenderer();

    /**
     * Renders a copy of {@code iNDArray} as filter images.
     *
     * <p>The tile size passed to the renderer is {@code sqrt(rows) x sqrt(columns)}, and the final
     * renderer argument is fixed at 10 (its meaning is defined by {@code FilterRenderer}).
     * Failures are logged and swallowed; plotting is best-effort.
     *
     * @param iNDArray the matrix to render; it is {@code dup()}ed and not modified
     */
    public void renderFilter(INDArray iNDArray) {
        try {
            INDArray filters = iNDArray.dup();
            render.renderFilters(filters, FILTER_IMAGE,
                    (int) Math.sqrt(filters.rows()), (int) Math.sqrt(filters.columns()), 10);
        } catch (Exception e) {
            log.error("Unable to plot filter, continuing...", e);
        }
    }

    /**
     * Plots histograms of the layer weights and the given weight gradient, plots the layer's
     * activations, and renders the weights as filter images.
     *
     * @param layer    the layer whose {@code WEIGHT_KEY} parameter is plotted
     * @param iNDArray the weight gradient to histogram next to the weights
     * @param i        passed as the final {@code renderFilters} argument for rank-2 weights
     */
    public void plotNetworkGradient(Layer layer, INDArray iNDArray, int i) {
        INDArray weights = layer.getParam(DefaultParamInitializer.WEIGHT_KEY);
        histogram(new String[]{DefaultParamInitializer.WEIGHT_KEY, "w-gradient"},
                new INDArray[]{weights, iNDArray});
        plotActivations(layer);
        try {
            INDArray filters = weights.dup();
            if (weights.shape().length > 2) {
                render.renderFilters(filters.transpose(), FILTER_IMAGE,
                        filters.columns(), filters.rows(), filters.slices());
            } else {
                int side = (int) Math.sqrt(filters.rows());
                render.renderFilters(filters, FILTER_IMAGE, side, side, i);
            }
        } catch (Exception e) {
            log.error("Unable to plot filter, continuing...", e);
        }
    }

    /**
     * Plots one histogram per gradient variable: the layer parameter under that key and the
     * gradient itself (titled {@code key + "-gradient"}).
     *
     * <p>Titles and matrices are built pairwise, in sorted key order, so each title labels the
     * matrix it is drawn with. Keys whose parameter or gradient is {@code null} are skipped
     * rather than failing the whole plot.
     *
     * @param layer    supplies the parameters, looked up by gradient key
     * @param gradient supplies the variable names and gradient matrices
     */
    public void hist(Layer layer, Gradient gradient) {
        List<String> titles = new ArrayList<>();
        List<INDArray> matrices = new ArrayList<>();
        for (String key : new TreeSet<String>(gradient.gradientForVariable().keySet())) {
            INDArray param = layer.getParam(key);
            if (param != null) {
                titles.add(key);
                matrices.add(param);
            }
            INDArray grad = gradient.gradientForVariable().get(key);
            if (grad != null) {
                titles.add(key + "-gradient");
                matrices.add(grad);
            }
        }
        if (titles.isEmpty()) {
            log.warn("No parameters or gradients to plot");
            return;
        }
        histogram(titles.toArray(new String[titles.size()]),
                matrices.toArray(new INDArray[matrices.size()]));
    }

    /**
     * Plots parameter/gradient histograms using the layer's own current gradient.
     *
     * @param layer the layer to plot
     */
    public void hist(Layer layer) {
        hist(layer, layer.gradient());
    }

    /**
     * Plots parameter/gradient histograms, the layer's activations, and the layer weights as
     * filter images.
     *
     * @param layer    the layer to plot
     * @param gradient the gradient whose variables are histogrammed
     * @param i        passed as the final {@code renderFilters} argument
     */
    public void plotNetworkGradient(Layer layer, Gradient gradient, int i) {
        hist(layer, gradient);
        plotActivations(layer);
        FilterRenderer filterRenderer = new FilterRenderer();
        try {
            INDArray filters = layer.getParam(DefaultParamInitializer.WEIGHT_KEY).dup();
            int side = (int) Math.sqrt(filters.rows());
            filterRenderer.renderFilters(filters, FILTER_IMAGE, side, side, i);
        } catch (Exception e) {
            log.error("Unable to plot filter, continuing...", e);
        }
    }

    /**
     * Draws a scatter plot for each matrix via the Python script's {@code scatter} mode.
     *
     * @param strArr      one title per matrix
     * @param iNDArrayArr matrices to plot; each is flattened with {@code ravel()}
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws RuntimeException         wrapping any {@link IOException} from writing or launching
     */
    public void scatter(String[] strArr, INDArray[] iNDArrayArr) {
        runPlotScript("scatter", "Rendering scatter plots... ", strArr, iNDArrayArr);
    }

    /**
     * Draws a histogram for each matrix via the Python script's {@code multi} mode.
     *
     * @param strArr      one title per matrix
     * @param iNDArrayArr matrices to plot; each is flattened with {@code ravel()}
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws RuntimeException         wrapping any {@link IOException} from writing or launching
     */
    public void histogram(String[] strArr, INDArray[] iNDArrayArr) {
        runPlotScript("multi", "Rendering Matrix histograms... ", strArr, iNDArrayArr);
    }

    /**
     * Writes each matrix to a temp file and runs {@code python plot.py <mode> path,title,...}.
     * The call blocks until the script closes its output.
     */
    private void runPlotScript(String mode, String renderMessage, String[] titles, INDArray[] matrices) {
        if (titles.length != matrices.length) {
            throw new IllegalArgumentException("Titles and matrix lengths must be equal");
        }
        String[] args = new String[matrices.length * 2];
        try {
            for (int k = 0; k < matrices.length; k++) {
                args[2 * k] = writeMatrix(matrices[k].ravel());
                args[2 * k + 1] = titles[k];
            }
            // Argument-list form keeps paths/titles with spaces intact. Merging stderr into stdout
            // prevents the child from blocking on a full, unread stderr pipe while stdout is drained.
            Process process = new ProcessBuilder("python", SCRIPT_PATH, mode, StringUtils.join(args, ","))
                    .redirectErrorStream(true)
                    .start();
            log.info(renderMessage);
            try (InputStream output = process.getInputStream()) {
                log.info("Std out " + IOUtils.readLines(output));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes {@code iNDArray} to a new temp file, one row per line, as comma-separated values
     * with 10 decimal places. The file is scheduled for deletion on JVM exit.
     *
     * @param iNDArray the matrix to write
     * @return the absolute path of the written file
     * @throws IOException if the file cannot be written
     */
    protected String writeMatrix(INDArray iNDArray) throws IOException {
        String path = System.getProperty("java.io.tmpdir") + File.separator + UUID.randomUUID().toString();
        File file = new File(path);
        file.deleteOnExit();
        try (BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file, true))) {
            for (int i = 0; i < iNDArray.rows(); i++) {
                INDArray row = iNDArray.getRow(i);
                StringBuilder line = new StringBuilder();
                for (int j = 0; j < row.length(); j++) {
                    if (j > 0) {
                        line.append(',');
                    }
                    // Locale.ROOT guarantees '.' as decimal separator; a ',' would corrupt the CSV.
                    line.append(String.format(Locale.ROOT, "%.10f", row.getDouble(j)));
                }
                line.append('\n');
                out.write(line.toString().getBytes(StandardCharsets.UTF_8));
            }
        }
        return path;
    }

    /**
     * Plots the layer's mean activations via the script's {@code hbias} mode.
     *
     * <p>The plot process is left running for {@value #ACTIVATION_DISPLAY_MILLIS} ms and then
     * destroyed. This method is best-effort and never throws: failures are logged, and an
     * interrupt during the wait restores the thread's interrupt flag.
     *
     * @param layer the layer whose {@code activationMean()} is plotted; must have input set
     */
    public void plotActivations(Layer layer) {
        if (layer.input() == null) {
            log.warn("Unable to plot; missing input");
            return;
        }
        try {
            String path = writeMatrix(layer.activationMean());
            Process process = new ProcessBuilder("python", SCRIPT_PATH, "hbias", path).start();
            Thread.sleep(ACTIVATION_DISPLAY_MILLIS);
            process.destroy();
            log.info("Rendering hbias " + path);
            try (InputStream errors = process.getErrorStream()) {
                List<?> errorLines = IOUtils.readLines(errors);
                if (!errorLines.isEmpty()) {
                    log.error(errorLines.toString());
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.warn("Interrupted while plotting activations", e);
        } catch (Exception e) {
            log.warn("Image closed", e);
        }
    }

    /** Copies the bundled plotting script from the classpath to {@link #SCRIPT_PATH}. */
    private static void loadIntoTmp() {
        try (InputStream script = r.getInputStream()) {
            FileUtils.writeLines(new File(SCRIPT_PATH), IOUtils.readLines(script));
        } catch (IOException e) {
            throw new IllegalStateException("Unable to load python file", e);
        }
    }

    static {
        loadIntoTmp();
    }
}
