package de.jungblut.classification.eval;

import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import de.jungblut.classification.Classifier;
import de.jungblut.classification.ClassifierFactory;
import de.jungblut.classification.Predictor;
import de.jungblut.datastructure.ArrayUtils;
import de.jungblut.math.DoubleVector;
import de.jungblut.math.MathUtils;
import de.jungblut.partition.BlockPartitioner;
import de.jungblut.partition.Boundaries;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorCompletionService;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static helpers for training and evaluating {@link Classifier}s: evaluation on a
 * train/test split, testing a {@link Predictor} on a fixed test set, and k-fold
 * cross validation executed on a thread pool.
 */
public final class Evaluator {
    private static final Logger LOG = LogManager.getLogger(Evaluator.class);

    /**
     * Evaluates one cross validation fold.
     *
     * <p>The test set is {@code ArrayUtils.subArray} over
     * {@code [splitRanges[fold], splitRanges[fold + 1]]}. Every row whose index lies
     * outside that closed interval goes into the training set. A classifier obtained
     * from {@code classifierFactory.newInstance()} is trained and tested via
     * {@link Evaluator#evaluateSplit(Classifier, DoubleVector[], DoubleVector[],
     * DoubleVector[], DoubleVector[], Double)}.
     *
     * @param <A> the classifier type produced by the factory
     */
    public static class CallableEvaluation<A extends Classifier> implements Callable<EvaluationResult> {
        private final int fold;
        private final int[] splitRanges;
        private final int m;
        private final DoubleVector[] features;
        private final DoubleVector[] outcome;
        private final ClassifierFactory<A> classifierFactory;
        private final Double threshold;

        /**
         * @param fold index of the fold held out as test set
         * @param splitRanges fold boundaries; fold {@code k} spans {@code splitRanges[k]..splitRanges[k + 1]}
         * @param m total number of rows in {@code features} / {@code outcome}
         * @param classifierFactory source of the classifier trained for this fold
         * @param features all feature rows
         * @param outcome all outcome rows, index-aligned with {@code features}
         * @param numFolds currently unused; kept for signature compatibility
         * @param threshold optional decision threshold for binary problems, may be {@code null}
         */
        public CallableEvaluation(int fold, int[] splitRanges, int m, ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numFolds, Double threshold) {
            this.fold = fold;
            this.splitRanges = splitRanges;
            this.m = m;
            this.classifierFactory = classifierFactory;
            this.features = features;
            this.outcome = outcome;
            this.threshold = threshold;
        }

        @Override
        public EvaluationResult call() throws Exception {
            final int testStart = splitRanges[fold];
            final int testEnd = splitRanges[fold + 1];
            // NOTE(review): the training loop below excludes testEnd itself, which is only
            // consistent if ArrayUtils.subArray treats its end index as inclusive - confirm.
            // testEnd is also the start of the next fold, so adjacent folds share that row.
            DoubleVector[] testFeatures = (DoubleVector[]) ArrayUtils.subArray(features, testStart, testEnd);
            DoubleVector[] testOutcome = (DoubleVector[]) ArrayUtils.subArray(outcome, testStart, testEnd);
            DoubleVector[] trainFeatures = new DoubleVector[m - testFeatures.length];
            DoubleVector[] trainOutcome = new DoubleVector[m - testFeatures.length];
            int next = 0;
            for (int row = 0; row < m; row++) {
                if (row < testStart || row > testEnd) {
                    trainFeatures[next] = features[row];
                    trainOutcome[next] = outcome[row];
                    next++;
                }
            }
            return Evaluator.evaluateSplit(classifierFactory.newInstance(), trainFeatures, trainOutcome, testFeatures, testOutcome, threshold);
        }
    }

    /**
     * Metrics collected while testing a predictor.
     *
     * <p>Binary problems ({@code numLabels == 2}) fill the TP/FP/TN/FN counters and the
     * AUC. All other label counts fill {@code correct} and a confusion matrix with the
     * real outcome on rows and the prediction in columns. {@code logLoss} accumulates
     * the sum of {@code outcome .* log(prediction)} over all test rows.
     */
    public static class EvaluationResult {
        int numLabels;
        int correct;
        int testSize;
        int truePositive;
        int falsePositive;
        int trueNegative;
        int falseNegative;
        int[][] confusionMatrix;
        double auc;
        double logLoss;

        /** Returns {@code numerator / denominator} as a floating point ratio (NaN for 0/0). */
        private static double ratio(int numerator, int denominator) {
            return (double) numerator / denominator;
        }

        /** @return the area under the ROC curve (binary problems only). */
        public double getAUC() {
            return this.auc;
        }

        /** @return the accumulated log likelihood, negated and divided by the test size. */
        public double getLogLoss() {
            return (-this.logLoss) / this.testSize;
        }

        /** @return TP / (TP + FP); NaN if nothing was predicted positive. */
        public double getPrecision() {
            return ratio(this.truePositive, this.truePositive + this.falsePositive);
        }

        /** @return TP / (TP + FN); NaN if there were no positive outcomes. */
        public double getRecall() {
            return ratio(this.truePositive, this.truePositive + this.falseNegative);
        }

        /** @return FP / (FP + TN); NaN if there were no negative outcomes. */
        public double getFalsePositiveRate() {
            return ratio(this.falsePositive, this.falsePositive + this.trueNegative);
        }

        /**
         * @return (TP + TN) / (TP + TN + FP + FN) for binary problems, otherwise
         *     correct / testSize
         */
        public double getAccuracy() {
            if (isBinary()) {
                int hits = this.truePositive + this.trueNegative;
                return ratio(hits, hits + this.falsePositive + this.falseNegative);
            }
            return ratio(this.correct, this.testSize);
        }

        /** @return the harmonic mean of precision and recall. */
        public double getF1Score() {
            double precision = getPrecision();
            double recall = getRecall();
            return (2.0d * (precision * recall)) / (precision + recall);
        }

        /** @return number of correctly classified test rows. */
        public int getCorrect() {
            return !isBinary() ? this.correct : this.truePositive + this.trueNegative;
        }

        public int getNumLabels() {
            return this.numLabels;
        }

        public int getTestSize() {
            return this.testSize;
        }

        /** @return the confusion matrix (multi-class problems only, otherwise {@code null}). */
        public int[][] getConfusionMatrix() {
            return this.confusionMatrix;
        }

        /** @return true if this result describes a two-label problem. */
        public boolean isBinary() {
            return this.numLabels == 2;
        }

        /**
         * Adds all counters, AUC, log loss and the confusion matrix of {@code other} into
         * this result. {@code numLabels} is not merged. {@code other} is never modified:
         * when this result has no confusion matrix yet, a copy of the other's is taken.
         *
         * @param other the result to accumulate
         */
        public void add(EvaluationResult other) {
            this.correct += other.correct;
            this.testSize += other.testSize;
            this.truePositive += other.truePositive;
            this.falsePositive += other.falsePositive;
            this.trueNegative += other.trueNegative;
            this.falseNegative += other.falseNegative;
            this.auc += other.auc;
            this.logLoss += other.logLoss;
            if (other.confusionMatrix == null) {
                return;
            }
            if (this.confusionMatrix == null) {
                // copy instead of aliasing so later add()/average() calls cannot mutate 'other'
                int[][] copy = new int[other.confusionMatrix.length][];
                for (int row = 0; row < copy.length; row++) {
                    copy[row] = other.confusionMatrix[row].clone();
                }
                this.confusionMatrix = copy;
                return;
            }
            for (int row = 0; row < this.numLabels; row++) {
                for (int col = 0; col < this.numLabels; col++) {
                    this.confusionMatrix[row][col] += other.confusionMatrix[row][col];
                }
            }
        }

        /**
         * Divides all accumulated values by {@code count}. Integer counters and
         * confusion-matrix cells are truncated by integer division.
         *
         * @param count number of results previously combined via {@link #add}
         */
        public void average(int count) {
            this.correct /= count;
            this.testSize /= count;
            this.truePositive /= count;
            this.falsePositive /= count;
            this.trueNegative /= count;
            this.falseNegative /= count;
            this.auc /= count;
            this.logLoss /= count;
            if (this.confusionMatrix != null) {
                for (int row = 0; row < this.numLabels; row++) {
                    for (int col = 0; col < this.numLabels; col++) {
                        this.confusionMatrix[row][col] /= count;
                    }
                }
            }
        }

        public int getTruePositive() {
            return this.truePositive;
        }

        public int getFalsePositive() {
            return this.falsePositive;
        }

        public int getTrueNegative() {
            return this.trueNegative;
        }

        public int getFalseNegative() {
            return this.falseNegative;
        }

        /** Logs this result to the Evaluator's logger. */
        public void print() {
            print(Evaluator.LOG);
        }

        /**
         * Logs the summary metrics. For multi-class results the confusion matrix is
         * additionally printed to {@code System.out}.
         *
         * @param logger destination of the summary lines
         */
        public void print(Logger logger) {
            logger.info("Number of labels: " + getNumLabels());
            logger.info("Testset size: " + getTestSize());
            logger.info("Correctly classified: " + getCorrect());
            logger.info("Accuracy: " + getAccuracy());
            logger.info("Log loss: " + getLogLoss());
            if (!isBinary()) {
                printConfusionMatrix();
                return;
            }
            logger.info("TP: " + this.truePositive);
            logger.info("FP: " + this.falsePositive);
            logger.info("TN: " + this.trueNegative);
            logger.info("FN: " + this.falseNegative);
            logger.info("Precision: " + getPrecision());
            logger.info("Recall: " + getRecall());
            logger.info("F1 Score: " + getF1Score());
            logger.info("AUC: " + getAUC());
        }

        /** Prints the confusion matrix to {@code System.out} using numeric class labels. */
        public void printConfusionMatrix() {
            printConfusionMatrix(null);
        }

        /**
         * Prints the confusion matrix to {@code System.out}. Each row is followed by the
         * number and fraction of misclassified rows of that real class.
         *
         * @param classNames optional display names, one per label; may be {@code null}
         * @throws NullPointerException if there is no confusion matrix
         * @throws IllegalArgumentException if {@code classNames} has the wrong length
         */
        public void printConfusionMatrix(String[] classNames) {
            Preconditions.checkNotNull(this.confusionMatrix, "No confusion matrix found.");
            final int labels = getNumLabels();
            if (classNames != null) {
                Preconditions.checkArgument(classNames.length == labels, "Passed class names doesn't match with number of labels! Expected " + labels + " but was " + classNames.length);
            }
            System.out.println("\nConfusion matrix (real outcome on rows, prediction in columns)\n");
            for (int col = 0; col < labels; col++) {
                System.out.format("%5d", Integer.valueOf(col));
            }
            System.out.format(" <- %5s %5s\t%s\n", "sum", "perc", "class");
            NumberFormat percentFormat = NumberFormat.getPercentInstance();
            for (int row = 0; row < labels; row++) {
                // off-diagonal cells of a row are the misclassified rows of that real class
                int misclassified = 0;
                for (int col = 0; col < labels; col++) {
                    if (row != col) {
                        misclassified += this.confusionMatrix[row][col];
                    }
                    System.out.format("%5d", Integer.valueOf(this.confusionMatrix[row][col]));
                }
                // floating point division so the percentage is not truncated to 0 or 1
                double misclassifiedRatio = (double) misclassified / (misclassified + this.confusionMatrix[row][row]);
                String label = classNames != null ? " " + row + " (" + classNames[row] + ")" : " " + row;
                System.out.format(" <- %5s %5s\t%s\n", Integer.valueOf(misclassified), percentFormat.format(misclassifiedRatio), label);
            }
        }
    }

    private Evaluator() {
        throw new IllegalAccessError();
    }

    /**
     * Splits the data via {@code EvaluationSplit.create}, trains on the train part and
     * tests on the test part, using the predictor's default class extraction.
     *
     * @param classifier classifier to train and test
     * @param features all feature rows
     * @param outcome all outcome rows, index-aligned with {@code features}
     * @param splitFraction forwarded to {@code EvaluationSplit.create}
     * @param random forwarded to {@code EvaluationSplit.create}
     * @return the test metrics
     */
    public static EvaluationResult evaluateClassifier(Classifier classifier, DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random) {
        return evaluateClassifier(classifier, features, outcome, splitFraction, random, null);
    }

    /**
     * Same as {@link #evaluateClassifier(Classifier, DoubleVector[], DoubleVector[], float, boolean)}
     * with an optional binary decision threshold.
     *
     * @param threshold decision threshold for binary problems, or {@code null} for the default
     */
    public static EvaluationResult evaluateClassifier(Classifier classifier, DoubleVector[] features, DoubleVector[] outcome, float splitFraction, boolean random, Double threshold) {
        return evaluateSplit(classifier, EvaluationSplit.create(features, outcome, splitFraction, random), threshold);
    }

    /** Trains on the split's train part and tests on its test part (default threshold). */
    public static EvaluationResult evaluateSplit(Classifier classifier, EvaluationSplit split) {
        return evaluateSplit(classifier, split.getTrainFeatures(), split.getTrainOutcome(), split.getTestFeatures(), split.getTestOutcome(), null);
    }

    /**
     * Trains on the split's train part and tests on its test part.
     *
     * @param threshold decision threshold for binary problems, or {@code null} for the default
     */
    public static EvaluationResult evaluateSplit(Classifier classifier, EvaluationSplit split, Double threshold) {
        return evaluateSplit(classifier, split.getTrainFeatures(), split.getTrainOutcome(), split.getTestFeatures(), split.getTestOutcome(), threshold);
    }

    /**
     * Trains {@code classifier} on the train data, then tests it on the test data.
     *
     * @param threshold decision threshold for binary problems, or {@code null} for the default
     * @return the test metrics
     */
    public static EvaluationResult evaluateSplit(Classifier classifier, DoubleVector[] trainFeatures, DoubleVector[] trainOutcome, DoubleVector[] testFeatures, DoubleVector[] testOutcome, Double threshold) {
        classifier.train(trainFeatures, trainOutcome);
        return testClassifier(classifier, testFeatures, testOutcome, threshold);
    }

    /** Tests an already trained predictor using its default class extraction. */
    public static EvaluationResult testClassifier(Predictor predictor, DoubleVector[] features, DoubleVector[] outcome) {
        return testClassifier(predictor, features, outcome, null);
    }

    /**
     * Tests an already trained predictor on the given rows.
     *
     * <p>The label count is {@code max(2, outcome[0].getDimension())}. With two labels,
     * the binary counters and the AUC are computed. Otherwise the outcome's max index
     * is treated as the real class and a confusion matrix is built.
     *
     * @param predictor trained predictor
     * @param features test feature rows (must not be empty)
     * @param outcome test outcome rows, index-aligned with {@code features}
     * @param threshold decision threshold for binary problems, or {@code null} for the default
     * @return the collected metrics
     */
    public static EvaluationResult testClassifier(Predictor predictor, DoubleVector[] features, DoubleVector[] outcome, Double threshold) {
        EvaluationResult result = new EvaluationResult();
        result.numLabels = Math.max(2, outcome[0].getDimension());
        result.testSize = outcome.length;
        if (result.isBinary()) {
            List<MathUtils.PredictionOutcomePair> pairs = new ArrayList<>(features.length);
            for (int row = 0; row < features.length; row++) {
                DoubleVector actual = outcome[row];
                DoubleVector prediction = predictor.predict(features[row]);
                int outcomeClass = observeBinaryClassificationElement(predictor, threshold, result, actual, prediction);
                pairs.add(MathUtils.PredictionOutcomePair.from(outcomeClass, prediction.get(0)));
            }
            result.auc = MathUtils.computeAUC(pairs);
        } else {
            int[][] confusion = new int[result.numLabels][result.numLabels];
            for (int row = 0; row < features.length; row++) {
                DoubleVector prediction = predictor.predict(features[row]);
                DoubleVector actual = outcome[row];
                result.logLoss += actual.multiply(MathUtils.logVector(prediction)).sum();
                int actualClass = actual.maxIndex();
                int predictedClass = predictor.extractPredictedClass(prediction);
                confusion[actualClass][predictedClass]++;
                if (actualClass == predictedClass) {
                    result.correct++;
                }
            }
            result.confusionMatrix = confusion;
        }
        return result;
    }

    /**
     * Records one binary prediction in {@code result}. The log loss term is accumulated
     * and TP/FN (real class 1) or TN/FP (real class 0) is incremented.
     *
     * @param predictor predictor used to extract the predicted class
     * @param threshold decision threshold, or {@code null} for the predictor's default
     * @param result result to update
     * @param outcome real outcome; its first element is truncated to the class
     * @param prediction the predictor's raw prediction
     * @return the real outcome class (0 or 1)
     * @throws IllegalArgumentException if the outcome class is neither 0 nor 1
     */
    public static int observeBinaryClassificationElement(Predictor predictor, Double threshold, EvaluationResult result, DoubleVector outcome, DoubleVector prediction) {
        int outcomeClass = (int) outcome.get(0);
        result.logLoss += outcome.multiply(MathUtils.logVector(prediction)).sum();
        int predictedClass = threshold == null ? predictor.extractPredictedClass(prediction) : predictor.extractPredictedClass(prediction, threshold.doubleValue());
        if (outcomeClass == 1) {
            if (predictedClass == 1) {
                result.truePositive++;
            } else {
                result.falseNegative++;
            }
        } else {
            if (outcomeClass != 0) {
                throw new IllegalArgumentException("Outcome class was neither 0 or 1. Was: " + outcomeClass + "; the supplied outcome value was: " + outcome.get(0));
            }
            if (predictedClass == 0) {
                result.trueNegative++;
            } else {
                result.falsePositive++;
            }
        }
        return outcomeClass;
    }

    /** Single-threaded variant of {@link #crossValidateClassifier(ClassifierFactory, DoubleVector[], DoubleVector[], int, int, Double, int, boolean)}. */
    public static <A extends Classifier> EvaluationResult crossValidateClassifier(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, int folds, Double threshold, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, folds, threshold, 1, verbose);
    }

    /**
     * Runs k-fold cross validation and returns the per-fold results averaged.
     *
     * <p>{@code features} and {@code outcome} are shuffled in place via
     * {@code ArrayUtils.multiShuffle}. Each fold is evaluated by a
     * {@link CallableEvaluation} on a pool of {@code numThreads} daemon threads. The
     * pool is shut down before returning. Folds that fail are logged and excluded
     * from the average. On interruption, waiting stops, the interrupt flag is
     * restored and the folds collected so far are averaged.
     *
     * @param classifierFactory source of one classifier per fold
     * @param features all feature rows (shuffled in place)
     * @param outcome all outcome rows, index-aligned with {@code features} (shuffled in place)
     * @param numLabels label count stored in the returned result
     * @param folds number of folds
     * @param threshold decision threshold for binary problems, or {@code null} for the default
     * @param numThreads size of the evaluation thread pool
     * @param verbose if true, split ranges and per-fold results are logged
     * @return the averaged result
     */
    public static <A extends Classifier> EvaluationResult crossValidateClassifier(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, int folds, Double threshold, int numThreads, boolean verbose) {
        final int numPartitions = folds + 1;
        ArrayUtils.multiShuffle(features, new DoubleVector[][]{outcome});
        EvaluationResult result = new EvaluationResult();
        result.numLabels = numLabels;
        final int m = features.length;
        List<Boundaries.Range> partitions = new ArrayList<>(new BlockPartitioner().partition(numPartitions, m).getBoundaries());
        int[] splitRanges = new int[numPartitions];
        for (int i = 1; i < numPartitions; i++) {
            splitRanges[i] = partitions.get(i).getEnd();
        }
        // NOTE(review): presumably the last partition's end is m (exclusive), so it is moved
        // onto the last valid row index - confirm against BlockPartitioner.
        splitRanges[numPartitions - 1] = splitRanges[numPartitions - 1] - 1;
        if (verbose) {
            LOG.info("Computed split ranges: " + Arrays.toString(splitRanges) + "\n");
        }
        ExecutorService pool = Executors.newFixedThreadPool(numThreads, new ThreadFactoryBuilder().setDaemon(true).build());
        try {
            ExecutorCompletionService<EvaluationResult> completionService = new ExecutorCompletionService<>(pool);
            for (int fold = 0; fold < folds; fold++) {
                completionService.submit(new CallableEvaluation<A>(fold, splitRanges, m, classifierFactory, features, outcome, folds, threshold));
            }
            int completedFolds = 0;
            for (int i = 0; i < folds; i++) {
                try {
                    EvaluationResult foldResult = completionService.take().get();
                    if (verbose) {
                        LOG.info("Fold: " + (i + 1));
                        foldResult.print();
                        LOG.info("");
                    }
                    result.add(foldResult);
                    completedFolds++;
                } catch (InterruptedException e) {
                    // restore the flag and stop waiting; take() would otherwise fail again immediately
                    Thread.currentThread().interrupt();
                    LOG.error("Interrupted while waiting for cross validation folds.", e);
                    break;
                } catch (ExecutionException e) {
                    LOG.error("A cross validation fold failed.", e);
                }
            }
            // average only over folds that actually contributed, avoiding division by zero
            if (completedFolds > 0) {
                result.average(completedFolds);
            }
        } finally {
            pool.shutdownNow();
        }
        return result;
    }

    /** Single-threaded ten-fold cross validation. */
    public static <A extends Classifier> EvaluationResult tenFoldCrossValidation(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, Double threshold, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, 10, threshold, verbose);
    }

    /** Ten-fold cross validation on {@code numThreads} threads. */
    public static <A extends Classifier> EvaluationResult tenFoldCrossValidation(ClassifierFactory<A> classifierFactory, DoubleVector[] features, DoubleVector[] outcome, int numLabels, Double threshold, int numThreads, boolean verbose) {
        return crossValidateClassifier(classifierFactory, features, outcome, numLabels, 10, threshold, numThreads, verbose);
    }
}
