package org.deeplearning4j.plot;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.NeuralNetworkGradient;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/deeplearning4j/plot/NeuralNetPlotter.class */
public class NeuralNetPlotter {
    private static ClassPathResource r = new ClassPathResource("/scripts/plot.py");
    private static Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);

    public void renderFilter(DoubleMatrix doubleMatrix, int i, int i2, long j) {
        try {
            String writeMatrix = writeMatrix(doubleMatrix);
            Process exec = Runtime.getRuntime().exec("python /tmp/plot.py filter " + writeMatrix + " " + i + " " + i2 + " " + j);
            log.info("Std out " + IOUtils.readLines(exec.getInputStream()).toString());
            log.info("Rendering weights " + writeMatrix);
            log.error(IOUtils.readLines(exec.getErrorStream()).toString());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public void plotNetworkGradient(NeuralNetwork neuralNetwork, NeuralNetworkGradient neuralNetworkGradient) {
        plotMatrices(new String[]{"W", "hbias", "vbias", "w-gradient", "hbias-gradient", "vbias-gradient"}, new DoubleMatrix[]{neuralNetwork.getW(), neuralNetwork.gethBias(), neuralNetwork.getvBias(), neuralNetworkGradient.getwGradient(), neuralNetworkGradient.gethBiasGradient(), neuralNetworkGradient.getvBiasGradient()});
        plotActivations(neuralNetwork);
    }

    public void plotMatrices(String[] strArr, DoubleMatrix[] doubleMatrixArr) {
        String[] strArr2 = new String[doubleMatrixArr.length * 2];
        try {
            if (strArr.length != doubleMatrixArr.length) {
                throw new IllegalArgumentException("Titles and matrix lengths must be equal");
            }
            for (int i = 0; i < strArr2.length - 1; i += 2) {
                strArr2[i] = writeMatrix(MatrixUtil.unroll(doubleMatrixArr[i / 2]));
                strArr2[i + 1] = strArr[i / 2];
            }
            Process exec = Runtime.getRuntime().exec("python /tmp/plot.py multi " + StringUtils.join(strArr2, ","));
            log.info("Rendering multiple matrices... ");
            log.info("Std out " + IOUtils.readLines(exec.getInputStream()).toString());
            log.error(IOUtils.readLines(exec.getErrorStream()).toString());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    protected String writeMatrix(DoubleMatrix doubleMatrix) throws IOException {
        String str = System.getProperty("java.io.tmpdir") + File.separator + UUID.randomUUID().toString();
        File file = new File(str);
        BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(file, true));
        file.deleteOnExit();
        for (int i = 0; i < doubleMatrix.rows; i++) {
            DoubleMatrix row = doubleMatrix.getRow(i);
            StringBuffer stringBuffer = new StringBuffer();
            for (int i2 = 0; i2 < row.length; i2++) {
                stringBuffer.append(String.format("%.10f", Double.valueOf(row.get(i2))));
                if (i2 < row.length - 1) {
                    stringBuffer.append(",");
                }
            }
            stringBuffer.append("\n");
            bufferedOutputStream.write(stringBuffer.toString().getBytes());
            bufferedOutputStream.flush();
        }
        bufferedOutputStream.close();
        return str;
    }

    public void plotWeights(NeuralNetwork neuralNetwork) {
        try {
            String writeMatrix = writeMatrix(neuralNetwork.getW());
            Process exec = Runtime.getRuntime().exec("python /tmp/plot.py weights " + writeMatrix);
            log.info("Rendering weights " + writeMatrix);
            log.error(IOUtils.readLines(exec.getErrorStream()).toString());
        } catch (Exception e) {
        }
    }

    public void plotActivations(NeuralNetwork neuralNetwork) {
        try {
            if (neuralNetwork.getInput() == null) {
                throw new IllegalStateException("Unable to plot; missing input");
            }
            String writeMatrix = writeMatrix(neuralNetwork.getInput().mmul(neuralNetwork.getW()).addRowVector(neuralNetwork.gethBias()));
            Process exec = Runtime.getRuntime().exec("python /tmp/plot.py hbias " + writeMatrix);
            Thread.sleep(10000L);
            exec.destroy();
            log.info("Rendering hbias " + writeMatrix);
            log.error(IOUtils.readLines(exec.getErrorStream()).toString());
        } catch (Exception e) {
            log.warn("Image closed");
        }
    }

    private static void loadIntoTmp() {
        try {
            FileUtils.writeLines(new File("/tmp/plot.py"), IOUtils.readLines(r.getInputStream()));
        } catch (IOException e) {
            throw new IllegalStateException("Unable to load python file");
        }
    }

    static {
        loadIntoTmp();
    }
}
