package org.deeplearning4j.evaluation;

import java.awt.Color;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.eval.EvaluationCalibration;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.eval.curves.Histogram;
import org.deeplearning4j.eval.curves.PrecisionRecallCurve;
import org.deeplearning4j.eval.curves.ReliabilityDiagram;
import org.deeplearning4j.eval.curves.RocCurve;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.LengthUnit;
import org.deeplearning4j.ui.components.chart.ChartHistogram;
import org.deeplearning4j.ui.components.chart.ChartLine;
import org.deeplearning4j.ui.components.chart.style.StyleChart;
import org.deeplearning4j.ui.components.component.ComponentDiv;
import org.deeplearning4j.ui.components.component.style.StyleDiv;
import org.deeplearning4j.ui.components.table.ComponentTable;
import org.deeplearning4j.ui.components.table.style.StyleTable;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.standalone.StaticPageUtil;

/**
 * Static helpers that render DL4J evaluation results as standalone HTML pages.
 *
 * <p>Supported inputs are binary ROC ({@link ROC}), one-vs-all multi-class ROC
 * ({@link ROCMultiClass}) and calibration statistics ({@link EvaluationCalibration}). Each page is
 * assembled from UI chart/table/text components and rendered with {@code StaticPageUtil.renderHTML}.
 * Files are always written as UTF-8 so that arbitrary class labels survive regardless of the
 * platform default charset.
 */
public class EvaluationTools {
    private static final String ROC_TITLE = "ROC: TPR/Recall (y) vs. FPR (x)";
    private static final String PR_TITLE = "Precision (y) vs. Recall (x)";
    private static final String PR_THRESHOLD_TITLE = "Precision and Recall (y) vs. Classifier Threshold (x)";

    private static final double CHART_WIDTH_PX = 600.0;
    private static final double CHART_HEIGHT_PX = 400.0;
    // Width of the per-curve summary table. Numerically equal to the chart height, but unrelated to it.
    private static final double TABLE_WIDTH_PX = 400.0;

    private static final StyleChart CHART_STYLE = new StyleChart.Builder()
            .width(CHART_WIDTH_PX, LengthUnit.Px)
            .height(CHART_HEIGHT_PX, LengthUnit.Px)
            .margin(LengthUnit.Px, 60, 60, 75, 10)
            .strokeWidth(2.0)
            .seriesColors(Color.BLUE, Color.LIGHT_GRAY)
            .build();

    // Like CHART_STYLE, but with a smaller third margin value and a green second series (recall).
    private static final StyleChart CHART_STYLE_PRECISION_RECALL = new StyleChart.Builder()
            .width(CHART_WIDTH_PX, LengthUnit.Px)
            .height(CHART_HEIGHT_PX, LengthUnit.Px)
            .margin(LengthUnit.Px, 60, 60, 40, 10)
            .strokeWidth(2.0)
            .seriesColors(Color.BLUE, Color.GREEN)
            .build();

    private static final StyleTable TABLE_STYLE = new StyleTable.Builder()
            .backgroundColor(Color.WHITE)
            .headerColor(Color.LIGHT_GRAY)
            .borderWidth(1)
            .columnWidths(LengthUnit.Percent, 50.0, 50.0)
            .width(TABLE_WIDTH_PX, LengthUnit.Px)
            .height(200.0, LengthUnit.Px)
            .build();

    private static final StyleDiv OUTER_DIV_STYLE = new StyleDiv.Builder()
            .width(1200.0, LengthUnit.Px)
            .height(CHART_HEIGHT_PX, LengthUnit.Px)
            .build();
    private static final StyleDiv OUTER_DIV_STYLE_WIDTH_ONLY = new StyleDiv.Builder()
            .width(1200.0, LengthUnit.Px)
            .build();
    private static final StyleDiv INNER_DIV_STYLE = new StyleDiv.Builder()
            .width(CHART_WIDTH_PX, LengthUnit.Px)
            .floatValue(StyleDiv.FloatValue.left)
            .build();
    private static final StyleDiv PAD_DIV_STYLE = new StyleDiv.Builder()
            .width(CHART_WIDTH_PX, LengthUnit.Px)
            .height(100.0, LengthUnit.Px)
            .floatValue(StyleDiv.FloatValue.left)
            .build();
    // Empty spacer placed above the summary table and above the info table in each ROC section.
    private static final ComponentDiv PAD_DIV = new ComponentDiv(PAD_DIV_STYLE, new Component[0]);

    private static final StyleText HEADER_TEXT_STYLE = new StyleText.Builder()
            .color(Color.BLACK)
            .fontSize(16.0)
            .underline(true)
            .build();
    private static final StyleDiv HEADER_DIV_STYLE = new StyleDiv.Builder()
            .width(1050.0, LengthUnit.Px)
            .height(30.0, LengthUnit.Px)
            .backgroundColor(Color.LIGHT_GRAY)
            .margin(LengthUnit.Px, 5, 5, 200, 10)
            .floatValue(StyleDiv.FloatValue.left)
            .build();
    private static final StyleDiv HEADER_DIV_STYLE_1400 = new StyleDiv.Builder()
            .width(1250.0, LengthUnit.Px)
            .height(30.0, LengthUnit.Px)
            .backgroundColor(Color.LIGHT_GRAY)
            .margin(LengthUnit.Px, 5, 5, 200, 10)
            .floatValue(StyleDiv.FloatValue.left)
            .build();
    private static final StyleDiv HEADER_DIV_PAD_STYLE = new StyleDiv.Builder()
            .width(1200.0, LengthUnit.Px)
            .height(150.0, LengthUnit.Px)
            .backgroundColor(Color.WHITE)
            .build();
    private static final StyleDiv HEADER_DIV_TEXT_PAD_STYLE = new StyleDiv.Builder()
            .width(120.0, LengthUnit.Px)
            .height(30.0, LengthUnit.Px)
            .backgroundColor(Color.LIGHT_GRAY)
            .floatValue(StyleDiv.FloatValue.left)
            .build();

    // Static legend explaining the metrics shown in each ROC section.
    private static final ComponentTable INFO_TABLE =
            new ComponentTable.Builder(new StyleTable.Builder().backgroundColor(Color.WHITE).borderWidth(0).build())
                    .content(new String[][] {
                            {"Precision", "(true positives) / (true positives + false positives)"},
                            {"True Positive Rate (Recall)", "(true positives) / (data positives)"},
                            {"False Positive Rate", "(false positives) / (data negatives)"}})
                    .build();

    private EvaluationTools() {
        // Static utility class; not instantiable.
    }

    /**
     * Writes the HTML produced by {@link #rocChartToHtml(ROC)} to {@code file}, encoded as UTF-8.
     *
     * @param roc  binary ROC evaluation to render
     * @param file destination file
     * @throws IOException if the file cannot be written
     */
    public static void exportRocChartsToHtmlFile(ROC roc, File file) throws IOException {
        FileUtils.writeStringToFile(file, rocChartToHtml(roc), StandardCharsets.UTF_8);
    }

    /**
     * Writes the HTML produced by {@link #rocChartToHtml(ROCMultiClass)} to {@code file}, encoded as UTF-8.
     *
     * <p>The body only throws {@link IOException}. The broader {@code throws Exception} clause is kept so
     * that existing callers still compile.
     *
     * @param rocMultiClass one-vs-all ROC evaluation to render
     * @param file          destination file
     * @throws Exception if the file cannot be written (in practice an {@link IOException})
     */
    public static void exportRocChartsToHtmlFile(ROCMultiClass rocMultiClass, File file) throws Exception {
        FileUtils.writeStringToFile(file, rocChartToHtml(rocMultiClass), StandardCharsets.UTF_8);
    }

    /**
     * Renders a binary ROC evaluation as an HTML page with two sections. The first holds a summary
     * table (AUROC, AUPRC, positive/negative counts) beside the ROC curve. The second holds the
     * precision/recall charts.
     *
     * @param roc binary ROC evaluation to render
     * @return the complete HTML page
     */
    public static String rocChartToHtml(ROC roc) {
        Component rocSection = getRocFromPoints(ROC_TITLE, roc.getRocCurve(), roc.getCountActualPositive(),
                roc.getCountActualNegative(), roc.calculateAUC(), roc.calculateAUCPR());
        Component prSection = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve());
        return StaticPageUtil.renderHTML(rocSection, prSection);
    }

    /**
     * Renders a one-vs-all ROC evaluation as HTML, labelling classes only by index.
     *
     * @param rocMultiClass one-vs-all ROC evaluation to render
     * @return the complete HTML page
     */
    public static String rocChartToHtml(ROCMultiClass rocMultiClass) {
        return rocChartToHtml(rocMultiClass, null);
    }

    /**
     * Renders a one-vs-all ROC evaluation as HTML, with one headed ROC + precision/recall section per class.
     *
     * @param rocMultiClass one-vs-all ROC evaluation to render
     * @param classNames    optional human-readable class names, indexed by class. May be {@code null} or
     *                      shorter than the number of classes. Classes without a name are labelled
     *                      "Class i" only.
     * @return the complete HTML page
     */
    public static String rocChartToHtml(ROCMultiClass rocMultiClass, List<String> classNames) {
        int numClasses = rocMultiClass.getNumClasses();
        // Five components are emitted per class: spacer, header indent, header bar, ROC section, P/R section.
        List<Component> components = new ArrayList<>(5 * numClasses);
        for (int i = 0; i < numClasses; i++) {
            String classLabel = "Class " + i;
            if (classNames != null && i < classNames.size()) {
                classLabel += " (" + classNames.get(i) + ")";
            }
            components.add(new ComponentDiv(HEADER_DIV_PAD_STYLE, new Component[0]));
            components.add(new ComponentDiv(HEADER_DIV_TEXT_PAD_STYLE, new Component[0]));
            components.add(new ComponentDiv(HEADER_DIV_STYLE,
                    new ComponentText(classLabel + " vs. All", HEADER_TEXT_STYLE)));
            components.add(getRocFromPoints(ROC_TITLE, rocMultiClass.getRocCurve(i),
                    rocMultiClass.getCountActualPositive(i), rocMultiClass.getCountActualNegative(i),
                    rocMultiClass.calculateAUC(i), rocMultiClass.calculateAUCPR(i)));
            components.add(getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rocMultiClass.getPrecisionRecallCurve(i)));
        }
        return StaticPageUtil.renderHTML(components);
    }

    /**
     * Writes the HTML produced by {@link #evaluationCalibrationToHtml(EvaluationCalibration)} to
     * {@code file}, encoded as UTF-8. The lowercase "e" in the method name is kept for API compatibility.
     *
     * @param evaluationCalibration calibration statistics to render
     * @param file                  destination file
     * @throws IOException if the file cannot be written
     */
    public static void exportevaluationCalibrationToHtmlFile(EvaluationCalibration evaluationCalibration, File file)
            throws IOException {
        FileUtils.writeStringToFile(file, evaluationCalibrationToHtml(evaluationCalibration), StandardCharsets.UTF_8);
    }

    /**
     * Renders calibration statistics as an HTML page with four headed sections:
     * <ol>
     *   <li>label vs. predicted class-count histograms,</li>
     *   <li>per-class reliability diagrams against the ideal diagonal,</li>
     *   <li>residual-plot histograms (all classes, then each class),</li>
     *   <li>prediction-probability histograms (all classes, then each class).</li>
     * </ol>
     *
     * @param evaluationCalibration calibration statistics to render
     * @return the complete HTML page
     */
    public static String evaluationCalibrationToHtml(EvaluationCalibration evaluationCalibration) {
        int numClasses = evaluationCalibration.numClasses();
        List<Component> sections = new ArrayList<>();

        // Class distributions: one unit-wide bin per class index, centred on the index.
        sections.add(sectionHeader("Labels and Network Prediction Class Distributions (X: Class Index. Y: Count)"));
        int[] labelCounts = evaluationCalibration.getLabelCountsEachClass();
        int[] predictionCounts = evaluationCalibration.getPredictionCountsEachClass();
        ChartHistogram.Builder labelHistogram = new ChartHistogram.Builder("Label Class Distribution", CHART_STYLE);
        ChartHistogram.Builder predictionHistogram =
                new ChartHistogram.Builder("Predicted Class Distribution", CHART_STYLE);
        for (int i = 0; i < numClasses; i++) {
            double lower = i - 0.5;
            double upper = i + 0.5;
            labelHistogram.addBin(lower, upper, labelCounts[i]);
            predictionHistogram.addBin(lower, upper, predictionCounts[i]);
        }
        sections.add(new ComponentDiv(OUTER_DIV_STYLE_WIDTH_ONLY, labelHistogram.build(), predictionHistogram.build()));

        // Reliability diagrams, each plotted against the y = x "ideal classifier" line.
        sections.add(sectionHeader("Reliability Diagrams (X: Mean Predicted Value. Y: Fraction Positives)"));
        double[] diagonal = {0.0, 1.0};
        List<Component> reliabilityCharts = new ArrayList<>(numClasses);
        for (int i = 0; i < numClasses; i++) {
            ReliabilityDiagram diagram = evaluationCalibration.getReliabilityDiagram(i);
            reliabilityCharts.add(new ChartLine.Builder(diagram.getTitle(), CHART_STYLE)
                    .addSeries("Classifier", diagram.getMeanPredictedValueX(), diagram.getFractionPositivesY())
                    .addSeries("Ideal Classifier", diagonal, diagonal)
                    .build());
        }
        sections.add(new ComponentDiv(OUTER_DIV_STYLE_WIDTH_ONLY, reliabilityCharts));

        // Residual plots: aggregate histogram first, then one per class.
        sections.add(sectionHeader("Network Predictions - Residual Plots - |Label(i) - P(class(i))|"));
        List<Component> residualCharts = new ArrayList<>(numClasses + 1);
        residualCharts.add(getHistogram(evaluationCalibration.getResidualPlotAllClasses()));
        for (int i = 0; i < numClasses; i++) {
            residualCharts.add(getHistogram(evaluationCalibration.getResidualPlot(i)));
        }
        sections.add(new ComponentDiv(OUTER_DIV_STYLE_WIDTH_ONLY, residualCharts));

        // Predicted-probability histograms: aggregate first, then one per class.
        sections.add(sectionHeader("Network Prediction Probabilities (X: P(class). Y: Count)"));
        List<Component> probabilityCharts = new ArrayList<>(numClasses + 1);
        probabilityCharts.add(getHistogram(evaluationCalibration.getProbabilityHistogramAllClasses()));
        for (int i = 0; i < numClasses; i++) {
            probabilityCharts.add(getHistogram(evaluationCalibration.getProbabilityHistogram(i)));
        }
        sections.add(new ComponentDiv(OUTER_DIV_STYLE_WIDTH_ONLY, probabilityCharts));

        return StaticPageUtil.renderHTML(sections);
    }

    /** Builds the grey section-title bar used by the calibration page. */
    private static ComponentDiv sectionHeader(String text) {
        return new ComponentDiv(HEADER_DIV_STYLE_1400, new ComponentText(text, HEADER_TEXT_STYLE));
    }

    /**
     * Builds one ROC section. The left column holds a spacer, the summary table, another spacer and the
     * metric legend. The right column holds the ROC curve on [0,1]x[0,1] with an unnamed diagonal baseline.
     *
     * @param title         chart title
     * @param rocCurve      curve to plot (x = FPR, y = TPR)
     * @param positiveCount number of actual positives shown in the summary table
     * @param negativeCount number of actual negatives shown in the summary table
     * @param auroc         area under the ROC curve, printed with 5 decimals
     * @param auprc         area under the precision/recall curve, printed with 5 decimals
     */
    private static Component getRocFromPoints(String title, RocCurve rocCurve, long positiveCount,
                                              long negativeCount, double auroc, double auprc) {
        double[] diagonal = {0.0, 1.0};
        ChartLine rocChart = new ChartLine.Builder(title, CHART_STYLE)
                .setXMin(0.0).setXMax(1.0)
                .setYMin(0.0).setYMax(1.0)
                .addSeries("ROC", rocCurve.getX(), rocCurve.getY())
                .addSeries("", diagonal, diagonal)
                .build();
        ComponentTable summaryTable = new ComponentTable.Builder(TABLE_STYLE)
                .header("Field", "Value")
                .content(new String[][] {
                        {"AUROC: Area under ROC:", String.format("%.5f", auroc)},
                        {"AUPRC: Area under P/R:", String.format("%.5f", auprc)},
                        {"Total Data Positive Count", String.valueOf(positiveCount)},
                        {"Total Data Negative Count", String.valueOf(negativeCount)}})
                .build();
        return new ComponentDiv(OUTER_DIV_STYLE,
                new ComponentDiv(INNER_DIV_STYLE, PAD_DIV, summaryTable, PAD_DIV, INFO_TABLE),
                new ComponentDiv(INNER_DIV_STYLE, rocChart));
    }

    /** Places the precision-vs-recall chart and the precision/recall-vs-threshold chart side by side. */
    private static Component getPRCharts(String prTitle, String thresholdTitle, PrecisionRecallCurve curve) {
        return new ComponentDiv(OUTER_DIV_STYLE,
                new ComponentDiv(INNER_DIV_STYLE, getPrecisionRecallCurve(prTitle, curve)),
                new ComponentDiv(INNER_DIV_STYLE, getPrecisionRecallVsThreshold(thresholdTitle, curve)));
    }

    /** Plots precision (y) against recall (x) on fixed [0,1] axes. */
    private static Component getPrecisionRecallCurve(String title, PrecisionRecallCurve curve) {
        return new ChartLine.Builder(title, CHART_STYLE)
                .setXMin(0.0).setXMax(1.0)
                .setYMin(0.0).setYMax(1.0)
                .addSeries("P vs R", curve.getRecall(), curve.getPrecision())
                .build();
    }

    /** Plots precision and recall (y) against the classifier threshold (x) on fixed [0,1] axes, with a legend. */
    private static Component getPrecisionRecallVsThreshold(String title, PrecisionRecallCurve curve) {
        double[] threshold = curve.getThreshold();
        return new ChartLine.Builder(title, CHART_STYLE_PRECISION_RECALL)
                .setXMin(0.0).setXMax(1.0)
                .setYMin(0.0).setYMax(1.0)
                .addSeries("Precision", threshold, curve.getPrecision())
                .addSeries("Recall", threshold, curve.getRecall())
                .showLegend(true)
                .build();
    }

    /** Converts an evaluation {@link Histogram} into a chart, copying each bin's bounds and count. */
    private static Component getHistogram(Histogram histogram) {
        ChartHistogram.Builder builder = new ChartHistogram.Builder(histogram.getTitle(), CHART_STYLE);
        double[] lowerBounds = histogram.getBinLowerBounds();
        double[] upperBounds = histogram.getBinUpperBounds();
        int[] counts = histogram.getBinCounts();
        for (int i = 0; i < counts.length; i++) {
            builder.addBin(lowerBounds[i], upperBounds[i], counts[i]);
        }
        return builder.build();
    }
}
