package de.citec.scie.classifiers;

import de.citec.scie.classifiers.data.ClassificationResult;
import de.citec.scie.classifiers.data.LabeledDataPoint;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;

/* loaded from: input_file:de/citec/scie/classifiers/ClassifierEvaluation.class */
/**
 * Evaluation of a binary classifier on a training set and a test set.
 *
 * <p>Each data set is summarized by a {@link SetEvaluation} (precision, recall, accuracy, F1 and a
 * sampled ROC curve). An evaluation can be written with {@link #store(Writer)} and read back with
 * {@link #ClassifierEvaluation(Reader)}.
 */
public class ClassifierEvaluation {
    /** Line that introduces the training-set section in the stored format. */
    public static final String TRAIN_SEPARATOR = "--- Training ---";
    /** Line that introduces the test-set section in the stored format. */
    public static final String TEST_SEPARATOR = "--- Test ---";
    /** Number of ROC sample points used when the caller does not specify one. */
    public static final int DEFAULT_ROC_STEPS = 51;
    private final SetEvaluation trainEvaluation;
    private final SetEvaluation testEvaluation;

    /** Orders (data point, classification result) pairs by descending classifier confidence. */
    private static class ConfidenceComparator implements Comparator<Tuple<LabeledDataPoint, ClassificationResult>> {
        private ConfidenceComparator() {
        }

        @Override
        public int compare(Tuple<LabeledDataPoint, ClassificationResult> first, Tuple<LabeledDataPoint, ClassificationResult> second) {
            // Arguments are swapped to obtain descending order.
            return Double.compare(second.getSecond().getConfidence(), first.getSecond().getConfidence());
        }
    }

    /**
     * Performance summary of a classifier on a single labeled data set.
     *
     * <p>Instances are created by classifying a data set, from precomputed values, or by parsing the
     * one-{@code "key = value"}-per-line format written by {@link #store(Writer)}.
     */
    public static class SetEvaluation {
        // Keys of the stored "key = value" format, each followed by the field it describes.
        public static final String N_STR = "N";
        private final int N;
        public static final String PRECISION_STR = "Precision";
        private final double precision;
        public static final String RECALL_STR = "Recall";
        private final double recall;
        public static final String F1_STR = "F1";
        public static final String ACC_STR = "Accuracy";
        private final double accuracy;
        public static final String ROC_THR_STR = "ROC Confidence Thresholds";
        private final double[] roc_thresholds;
        public static final String ROC_SPEC_STR = "ROC Specificities";
        private final double[] roc_specificities;
        public static final String ROC_REC_STR = "ROC Recalls";
        private final double[] roc_recalls;
        public static final String AREA_ROC_STR = "Area under ROC";
        private final double area_under_roc;

        /** Separator between key and value in the stored format (also a literal regex for split). */
        private static final String KEY_VALUE_SEPARATOR = " = ";

        /**
         * Evaluates {@code classifier} on {@code dataPoints} without verbose output.
         *
         * @param classifier the classifier to evaluate
         * @param dataPoints the labeled data set
         * @param rocSteps number of ROC sample points (at least 2)
         * @throws IllegalArgumentException if {@code rocSteps < 2}
         */
        public SetEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> dataPoints, int rocSteps) {
            this(classifier, dataPoints, rocSteps, false);
        }

        /**
         * Evaluates {@code classifier} on {@code dataPoints}.
         *
         * <p>Precision and recall are 0 when there are no true positives; accuracy is NaN for an
         * empty data set. The ROC curve is sampled at {@code rocSteps} evenly spaced confidence
         * thresholds from 1.0 down to 0.0. At a threshold, a data point counts as predicted positive
         * when its confidence is not below the threshold. The first and last sample points are
         * fixed to (specificity 1, recall 0) and (specificity 0, recall 1) respectively.
         *
         * <p>NOTE(review): the ROC sweep treats {@code getConfidence()} as a score for the positive
         * label. If it is instead the confidence of the predicted label, the curve is not a standard
         * ROC curve; confirm against {@code ClassificationResult}.
         *
         * @param classifier the classifier to evaluate
         * @param dataPoints the labeled data set
         * @param rocSteps number of ROC sample points (at least 2)
         * @param verbose if true, print a "TrueLabel;ClassifierLabel;Confidence" table (one row per
         *     data point, in descending confidence order) to standard output, and do not switch a
         *     {@code LibLinearClassifier} to silent-correct mode after the first data point
         * @throws IllegalArgumentException if {@code rocSteps < 2}
         */
        public SetEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> dataPoints, int rocSteps, boolean verbose) {
            if (rocSteps < 2) {
                throw new IllegalArgumentException("At least 2 ROC steps are required, but " + rocSteps + " were given.");
            }
            this.N = dataPoints.size();
            ArrayList<Tuple<LabeledDataPoint, ClassificationResult>> results = classifyAll(classifier, dataPoints, verbose);
            // Descending confidence order lets the ROC sweep advance a single cursor.
            Collections.sort(results, new ConfidenceComparator());

            if (verbose) {
                System.out.println("TrueLabel;ClassifierLabel;Confidence");
            }
            int truePositives = 0;
            int trueNegatives = 0;
            int falseNegatives = 0;
            int falsePositives = 0;
            for (Tuple<LabeledDataPoint, ClassificationResult> result : results) {
                LabeledDataPoint point = result.getFirst();
                ClassificationResult classification = result.getSecond();
                if (verbose) {
                    System.out.println(point.getLabel() + ";" + classification.getLabel() + ";" + classification.getConfidence());
                }
                if (point.getLabel()) {
                    if (classification.getLabel()) {
                        truePositives++;
                    } else {
                        falseNegatives++;
                    }
                } else if (classification.getLabel()) {
                    falsePositives++;
                } else {
                    trueNegatives++;
                }
            }

            if (truePositives > 0) {
                this.precision = (double) truePositives / (truePositives + falsePositives);
                this.recall = (double) truePositives / (truePositives + falseNegatives);
            } else {
                this.precision = 0.0d;
                this.recall = 0.0d;
            }
            this.accuracy = (double) (truePositives + trueNegatives) / this.N;

            this.roc_thresholds = new double[rocSteps];
            this.roc_specificities = new double[rocSteps];
            this.roc_recalls = new double[rocSteps];
            sampleRoc(results, truePositives + falseNegatives, trueNegatives + falsePositives,
                    this.roc_thresholds, this.roc_specificities, this.roc_recalls);
            this.area_under_roc = trapezoidalArea(this.roc_specificities, this.roc_recalls);
        }

        /**
         * Creates an evaluation from precomputed values. The arrays are stored without copying.
         *
         * @param n number of evaluated data points
         * @param precision precision
         * @param recall recall
         * @param accuracy accuracy
         * @param rocThresholds confidence threshold of each ROC sample point
         * @param rocSpecificities specificity at each ROC sample point
         * @param rocRecalls recall at each ROC sample point
         * @param areaUnderRoc area under the ROC curve
         */
        public SetEvaluation(int n, double precision, double recall, double accuracy, double[] rocThresholds,
                double[] rocSpecificities, double[] rocRecalls, double areaUnderRoc) {
            this.N = n;
            this.precision = precision;
            this.recall = recall;
            this.accuracy = accuracy;
            this.roc_thresholds = rocThresholds;
            this.roc_specificities = rocSpecificities;
            this.roc_recalls = rocRecalls;
            this.area_under_roc = areaUnderRoc;
        }

        /**
         * Parses an evaluation in the format written by {@link #store(Writer)}, reading
         * {@code reader} to its end. Lines that are not exactly {@code "key = value"} are ignored;
         * the stored F1 value is ignored because it is derived from precision and recall.
         *
         * @param reader source of the stored evaluation
         * @throws IOException if reading fails
         * @throws UnsupportedOperationException if a required key is missing or unparsable, or the
         *     three ROC arrays differ in length
         */
        public SetEvaluation(Reader reader) throws IOException {
            BufferedReader bufferedReader = reader instanceof BufferedReader ? (BufferedReader) reader : new BufferedReader(reader);
            HashMap<String, String> values = readKeyValuePairs(bufferedReader);
            this.N = retrieveInt(values, N_STR);
            this.precision = retrieveDouble(values, PRECISION_STR);
            this.recall = retrieveDouble(values, RECALL_STR);
            this.accuracy = retrieveDouble(values, ACC_STR);
            this.roc_thresholds = retrieveDoubleArr(values, ROC_THR_STR);
            this.roc_specificities = retrieveDoubleArr(values, ROC_SPEC_STR);
            this.roc_recalls = retrieveDoubleArr(values, ROC_REC_STR);
            if (this.roc_thresholds.length != this.roc_specificities.length) {
                throw new UnsupportedOperationException("The input data was inconsistent: The number of confidence thresholds and the number of specificities did not match!");
            }
            if (this.roc_thresholds.length != this.roc_recalls.length) {
                throw new UnsupportedOperationException("The input data was inconsistent: The number of confidence thresholds and the number of recalls did not match!");
            }
            this.area_under_roc = retrieveDouble(values, AREA_ROC_STR);
        }

        /**
         * Classifies every data point and pairs it with its result, in input order.
         *
         * <p>A {@code LibLinearClassifier} is put into non-silent-correct mode for the first data
         * point and, unless {@code verbose}, into silent-correct mode after each classification.
         * NOTE(review): presumably this limits its correction output to one data point; confirm.
         */
        private static ArrayList<Tuple<LabeledDataPoint, ClassificationResult>> classifyAll(
                Classifier classifier, ArrayList<LabeledDataPoint> dataPoints, boolean verbose) {
            ArrayList<Tuple<LabeledDataPoint, ClassificationResult>> results = new ArrayList<>(dataPoints.size());
            LibLinearClassifier libLinear = classifier instanceof LibLinearClassifier ? (LibLinearClassifier) classifier : null;
            if (libLinear != null) {
                libLinear.setSilentCorrect(false);
            }
            for (LabeledDataPoint point : dataPoints) {
                ClassificationResult classification = classifier.classify(point.getData());
                if (!verbose && libLinear != null) {
                    libLinear.setSilentCorrect(true);
                }
                results.add(new Tuple<LabeledDataPoint, ClassificationResult>(point, classification));
            }
            return results;
        }

        /**
         * Fills the ROC sample arrays by sweeping a confidence threshold from 1.0 down to 0.0.
         *
         * @param sorted classification results in descending confidence order
         * @param positives number of data points whose true label is positive
         * @param negatives number of data points whose true label is negative
         * @param thresholds output: threshold of each sample point (length = number of steps, >= 2)
         * @param specificities output: TN / (TN + FP) at each sample point
         * @param recalls output: TP / (TP + FN) at each sample point
         */
        private static void sampleRoc(ArrayList<Tuple<LabeledDataPoint, ClassificationResult>> sorted,
                int positives, int negatives, double[] thresholds, double[] specificities, double[] recalls) {
            int last = thresholds.length - 1;
            // Fixed end points of the curve, independent of the data.
            thresholds[0] = 1.0d;
            specificities[0] = 1.0d;
            recalls[0] = 0.0d;
            thresholds[last] = 0.0d;
            specificities[last] = 0.0d;
            recalls[last] = 1.0d;

            int truePositives = 0;
            int falseNegatives = positives;
            int trueNegatives = negatives;
            int falsePositives = 0;
            int cursor = 0;
            for (int step = 1; step < last; step++) {
                // Derived from the step index (not by repeated subtraction) so rounding errors do
                // not accumulate and grid values such as 0.5 are represented exactly.
                double threshold = (double) (last - step) / last;
                thresholds[step] = threshold;
                while (cursor < sorted.size()) {
                    Tuple<LabeledDataPoint, ClassificationResult> result = sorted.get(cursor);
                    // Kept as a "< threshold" test (not ">=") so a NaN confidence does not stop the sweep.
                    if (result.getSecond().getConfidence() < threshold) {
                        break;
                    }
                    if (result.getFirst().getLabel()) {
                        truePositives++;
                        falseNegatives--;
                    } else {
                        trueNegatives--;
                        falsePositives++;
                    }
                    cursor++;
                }
                // Both ratios are 0 (the array default) when their numerator is 0.
                if (truePositives > 0) {
                    recalls[step] = (double) truePositives / (truePositives + falseNegatives);
                }
                if (trueNegatives > 0) {
                    specificities[step] = (double) trueNegatives / (trueNegatives + falsePositives);
                }
            }
        }

        /** Area under the sampled curve by the trapezoidal rule, with 1 - specificity on the x axis. */
        private static double trapezoidalArea(double[] specificities, double[] recalls) {
            double area = 0.0d;
            for (int k = 1; k < specificities.length; k++) {
                area += (specificities[k - 1] - specificities[k]) * 0.5d * (recalls[k - 1] + recalls[k]);
            }
            return area;
        }

        /** Reads every remaining line; lines that do not split into exactly key and value are skipped. */
        private static HashMap<String, String> readKeyValuePairs(BufferedReader in) throws IOException {
            HashMap<String, String> pairs = new HashMap<>();
            String line;
            while ((line = in.readLine()) != null) {
                String[] parts = line.split(KEY_VALUE_SEPARATOR);
                if (parts.length == 2) {
                    pairs.put(parts[0], parts[1]);
                }
            }
            return pairs;
        }

        /** Returns the raw value for {@code key}, failing if it is absent. */
        private static String retrieveRaw(HashMap<String, String> values, String key) {
            String raw = values.get(key);
            if (raw == null) {
                throw new UnsupportedOperationException(key + " could not be retrieved from input!");
            }
            return raw;
        }

        private static int retrieveInt(HashMap<String, String> values, String key) {
            String raw = retrieveRaw(values, key);
            try {
                return Integer.parseInt(raw);
            } catch (NumberFormatException e) {
                throw new UnsupportedOperationException(key + " could not be retrieved from input!", e);
            }
        }

        private static double retrieveDouble(HashMap<String, String> values, String key) {
            String raw = retrieveRaw(values, key);
            try {
                return Double.parseDouble(raw);
            } catch (NumberFormatException e) {
                throw new UnsupportedOperationException(key + " could not be retrieved from input!", e);
            }
        }

        /** Parses an {@code Arrays.toString} rendering such as {@code [1.0, 0.5]}; {@code []} yields an empty array. */
        private static double[] retrieveDoubleArr(HashMap<String, String> values, String key) {
            String raw = retrieveRaw(values, key);
            String body = raw.replaceAll("[\\[\\]]", "").trim();
            if (body.isEmpty()) {
                return new double[0];
            }
            try {
                String[] parts = body.split(",");
                double[] parsed = new double[parts.length];
                for (int k = 0; k < parts.length; k++) {
                    // Double.parseDouble ignores the blank that follows each comma.
                    parsed[k] = Double.parseDouble(parts[k]);
                }
                return parsed;
            } catch (NumberFormatException e) {
                throw new UnsupportedOperationException(key + " could not be retrieved from input!", e);
            }
        }

        /** @return number of evaluated data points */
        public int getN() {
            return this.N;
        }

        /** @return fraction of correctly classified data points */
        public double getAccuracy() {
            return this.accuracy;
        }

        /** @return TP / (TP + FP), or 0 when there are no true positives */
        public double getPrecision() {
            return this.precision;
        }

        /** @return TP / (TP + FN), or 0 when there are no true positives */
        public double getRecall() {
            return this.recall;
        }

        /** @return harmonic mean of precision and recall, or 0 when both are 0 */
        public double getF1() {
            if (this.precision == 0.0d && this.recall == 0.0d) {
                return 0.0d;
            }
            return ((2.0d * this.precision) * this.recall) / (this.precision + this.recall);
        }

        /** @return the internal (not copied) array of ROC confidence thresholds */
        public double[] getRoc_thresholds() {
            return this.roc_thresholds;
        }

        /** @return the internal (not copied) array of ROC specificities */
        public double[] getRoc_specificities() {
            return this.roc_specificities;
        }

        /** @return the internal (not copied) array of ROC recalls */
        public double[] getRoc_recalls() {
            return this.roc_recalls;
        }

        /** @return area under the ROC curve */
        public double getArea_under_roc() {
            return this.area_under_roc;
        }

        @Override
        public String toString() {
            return N_STR + KEY_VALUE_SEPARATOR + this.N + "\n"
                    + ACC_STR + KEY_VALUE_SEPARATOR + this.accuracy + "\n"
                    + PRECISION_STR + KEY_VALUE_SEPARATOR + this.precision + "\n"
                    + RECALL_STR + KEY_VALUE_SEPARATOR + this.recall + "\n"
                    + F1_STR + KEY_VALUE_SEPARATOR + getF1() + "\n"
                    + AREA_ROC_STR + KEY_VALUE_SEPARATOR + this.area_under_roc;
        }

        /**
         * Writes this evaluation as one {@code "key = value"} line per field, readable by
         * {@link #SetEvaluation(Reader)}. The last line has no trailing line break.
         *
         * @param writer destination; it is neither flushed nor closed
         * @throws IOException if writing fails
         */
        public void store(Writer writer) throws IOException {
            writer.write(N_STR + KEY_VALUE_SEPARATOR + this.N + "\n");
            writer.write(ACC_STR + KEY_VALUE_SEPARATOR + this.accuracy + "\n");
            writer.write(PRECISION_STR + KEY_VALUE_SEPARATOR + this.precision + "\n");
            writer.write(RECALL_STR + KEY_VALUE_SEPARATOR + this.recall + "\n");
            writer.write(F1_STR + KEY_VALUE_SEPARATOR + getF1() + "\n");
            writer.write(ROC_THR_STR + KEY_VALUE_SEPARATOR + Arrays.toString(this.roc_thresholds) + "\n");
            writer.write(ROC_SPEC_STR + KEY_VALUE_SEPARATOR + Arrays.toString(this.roc_specificities) + "\n");
            writer.write(ROC_REC_STR + KEY_VALUE_SEPARATOR + Arrays.toString(this.roc_recalls) + "\n");
            writer.write(AREA_ROC_STR + KEY_VALUE_SEPARATOR + this.area_under_roc);
        }

        /** Same value as {@code Double.valueOf(value).hashCode()}. */
        private static int hashDouble(double value) {
            long bits = Double.doubleToLongBits(value);
            return (int) (bits ^ (bits >>> 32));
        }

        @Override
        public int hashCode() {
            int hash = 5;
            hash = 31 * hash + this.N;
            hash = 31 * hash + hashDouble(this.precision);
            hash = 31 * hash + hashDouble(this.recall);
            hash = 31 * hash + hashDouble(this.accuracy);
            hash = 31 * hash + Arrays.hashCode(this.roc_thresholds);
            hash = 31 * hash + Arrays.hashCode(this.roc_specificities);
            hash = 31 * hash + Arrays.hashCode(this.roc_recalls);
            hash = 31 * hash + hashDouble(this.area_under_roc);
            return hash;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            SetEvaluation other = (SetEvaluation) obj;
            // Bitwise double comparison keeps equals consistent with hashCode (NaN equals NaN).
            return this.N == other.N
                    && Double.doubleToLongBits(this.precision) == Double.doubleToLongBits(other.precision)
                    && Double.doubleToLongBits(this.recall) == Double.doubleToLongBits(other.recall)
                    && Double.doubleToLongBits(this.accuracy) == Double.doubleToLongBits(other.accuracy)
                    && Arrays.equals(this.roc_thresholds, other.roc_thresholds)
                    && Arrays.equals(this.roc_specificities, other.roc_specificities)
                    && Arrays.equals(this.roc_recalls, other.roc_recalls)
                    && Double.doubleToLongBits(this.area_under_roc) == Double.doubleToLongBits(other.area_under_roc);
        }
    }

    /**
     * Evaluates {@code classifier} on training and test data with {@link #DEFAULT_ROC_STEPS} ROC
     * steps and no verbose output.
     */
    public ClassifierEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> trainData, ArrayList<LabeledDataPoint> testData) {
        this(classifier, trainData, testData, DEFAULT_ROC_STEPS, false);
    }

    /** Evaluates {@code classifier} on training and test data with {@link #DEFAULT_ROC_STEPS} ROC steps. */
    public ClassifierEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> trainData, ArrayList<LabeledDataPoint> testData, boolean verbose) {
        this(classifier, trainData, testData, DEFAULT_ROC_STEPS, verbose);
    }

    /** Evaluates {@code classifier} on training and test data without verbose output. */
    public ClassifierEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> trainData, ArrayList<LabeledDataPoint> testData, int rocSteps) {
        this(classifier, trainData, testData, rocSteps, false);
    }

    /**
     * Evaluates {@code classifier} on training data, then on test data.
     *
     * @param classifier the classifier to evaluate
     * @param trainData labeled training data
     * @param testData labeled test data
     * @param rocSteps number of ROC sample points per data set (at least 2)
     * @param verbose whether to print per-data-point results to standard output
     * @throws IllegalArgumentException if {@code rocSteps < 2}
     */
    public ClassifierEvaluation(Classifier classifier, ArrayList<LabeledDataPoint> trainData, ArrayList<LabeledDataPoint> testData, int rocSteps, boolean verbose) {
        this.trainEvaluation = new SetEvaluation(classifier, trainData, rocSteps, verbose);
        this.testEvaluation = new SetEvaluation(classifier, testData, rocSteps, verbose);
    }

    /**
     * Parses an evaluation written by {@link #store(Writer)}. Lines before {@link #TRAIN_SEPARATOR}
     * are ignored. The training section ends at {@link #TEST_SEPARATOR}, and the test section runs
     * to the end of {@code reader}.
     *
     * @param reader source of the stored evaluation
     * @throws IOException if reading fails
     * @throws UnsupportedOperationException if a section is missing a field or is inconsistent
     */
    public ClassifierEvaluation(Reader reader) throws IOException {
        BufferedReader bufferedReader = reader instanceof BufferedReader ? (BufferedReader) reader : new BufferedReader(reader);
        StringBuilder trainingSection = new StringBuilder();
        boolean inTrainingSection = false;
        String line;
        while ((line = bufferedReader.readLine()) != null) {
            if (line.equals(TRAIN_SEPARATOR)) {
                inTrainingSection = true;
            } else if (line.equals(TEST_SEPARATOR)) {
                break;
            } else if (inTrainingSection) {
                // readLine strips the terminator; restore it because SetEvaluation parses one
                // "key = value" pair per line.
                trainingSection.append(line).append('\n');
            }
        }
        this.trainEvaluation = new SetEvaluation(new StringReader(trainingSection.toString()));
        this.testEvaluation = new SetEvaluation(bufferedReader);
    }

    /** @return evaluation on the test data */
    public SetEvaluation getTestEvaluation() {
        return this.testEvaluation;
    }

    /** @return evaluation on the training data */
    public SetEvaluation getTrainEvaluation() {
        return this.trainEvaluation;
    }

    @Override
    public String toString() {
        return "Training Results:\n" + this.trainEvaluation.toString() + "\nTest Results:\n" + this.testEvaluation.toString();
    }

    /**
     * Writes both evaluations, each preceded by its separator line, in the format read by
     * {@link #ClassifierEvaluation(Reader)}.
     *
     * @param writer destination; it is neither flushed nor closed
     * @throws IOException if writing fails
     */
    public void store(Writer writer) throws IOException {
        writer.write(TRAIN_SEPARATOR);
        writer.write("\n");
        this.trainEvaluation.store(writer);
        writer.write("\n");
        writer.write(TEST_SEPARATOR);
        writer.write("\n");
        this.testEvaluation.store(writer);
    }
}
