package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.function.IntToDoubleFunction;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/* loaded from: input_file:org/nd4j/evaluation/classification/EvaluationBinary.class */
public class EvaluationBinary extends BaseEvaluation<EvaluationBinary> {
    public static final int DEFAULT_PRECISION = 4;
    public static final double DEFAULT_EDGE_VALUE = 0.0d;
    protected int axis;
    private int[] countTruePositive;
    private int[] countFalsePositive;
    private int[] countTrueNegative;
    private int[] countFalseNegative;
    private ROCBinary rocBinary;
    private List<String> labels;

    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private INDArray decisionThreshold;

    /* loaded from: input_file:org/nd4j/evaluation/classification/EvaluationBinary$Metric.class */
    public enum Metric implements IMetric {
        ACCURACY,
        F1,
        PRECISION,
        RECALL,
        GMEASURE,
        MCC,
        FAR;

        @Override // org.nd4j.evaluation.IMetric
        public Class<? extends IEvaluation> getEvaluationClass() {
            return EvaluationBinary.class;
        }

        @Override // org.nd4j.evaluation.IMetric
        public boolean minimize() {
            return false;
        }
    }

    protected EvaluationBinary(int i, ROCBinary rOCBinary, List<String> list, INDArray iNDArray) {
        this.axis = 1;
        this.axis = i;
        this.rocBinary = rOCBinary;
        this.labels = list;
        this.decisionThreshold = iNDArray;
    }

    public EvaluationBinary(INDArray iNDArray) {
        this.axis = 1;
        if (iNDArray != null) {
            if (!iNDArray.isRowVectorOrScalar()) {
                throw new IllegalArgumentException("Decision threshold array must be a row vector; got array with shape " + Arrays.toString(iNDArray.shape()));
            }
            if (iNDArray.minNumber().doubleValue() < 0.0d) {
                throw new IllegalArgumentException("Invalid decision threshold array: minimum value is less than 0");
            }
            if (iNDArray.maxNumber().doubleValue() > 1.0d) {
                throw new IllegalArgumentException("invalid decision threshold array: maximum value is greater than 1.0");
            }
            this.decisionThreshold = iNDArray;
        }
    }

    public EvaluationBinary(int i, Integer num) {
        this.axis = 1;
        this.countTruePositive = new int[i];
        this.countFalsePositive = new int[i];
        this.countTrueNegative = new int[i];
        this.countFalseNegative = new int[i];
        if (num != null) {
            this.rocBinary = new ROCBinary(num.intValue());
        }
    }

    public void setAxis(int i) {
        this.axis = i;
    }

    public int getAxis() {
        return this.axis;
    }

    @Override // org.nd4j.evaluation.BaseEvaluation, org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2) {
        eval(iNDArray, iNDArray2, (INDArray) null);
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, List<? extends Serializable> list) {
        eval(iNDArray, iNDArray2, iNDArray3);
    }

    @Override // org.nd4j.evaluation.BaseEvaluation, org.nd4j.evaluation.IEvaluation
    public void eval(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        INDArray gt;
        long longValue = Nd4j.getExecutioner().execAndReturn((ReduceOp) new MatchCondition(iNDArray2, Conditions.isNan(), new int[0])).getFinalResult().longValue();
        Preconditions.checkState(longValue == 0, "Cannot perform evaluation with NaNs present in predictions: %s NaNs present in predictions INDArray", longValue);
        if (this.countTruePositive != null && this.countTruePositive.length != iNDArray.size(this.axis)) {
            throw new IllegalStateException("Labels array does not match stored state size. Expected labels array with size " + this.countTruePositive.length + ", got labels array with size " + iNDArray.size(this.axis) + " for axis " + this.axis);
        }
        Triple<INDArray, INDArray, INDArray> reshapeAndExtractNotMasked = BaseEvaluation.reshapeAndExtractNotMasked(iNDArray, iNDArray2, iNDArray3, this.axis);
        INDArray first = reshapeAndExtractNotMasked.getFirst();
        INDArray second = reshapeAndExtractNotMasked.getSecond();
        INDArray third = reshapeAndExtractNotMasked.getThird();
        if (first.dataType() != second.dataType()) {
            first = first.castTo(second.dataType());
        }
        if (this.decisionThreshold != null && this.decisionThreshold.dataType() != second.dataType()) {
            this.decisionThreshold = this.decisionThreshold.castTo(second.dataType());
        }
        if (this.decisionThreshold != null) {
            gt = Nd4j.createUninitialized(DataType.BOOL, second.shape());
            Nd4j.getExecutioner().exec((BroadcastOp) new BroadcastGreaterThan(second, this.decisionThreshold, gt, 1));
        } else {
            gt = second.gt(Double.valueOf(0.5d));
        }
        INDArray castTo = gt.castTo(second.dataType());
        INDArray rsub = first.rsub(Double.valueOf(1.0d));
        INDArray rsub2 = castTo.rsub(Double.valueOf(1.0d));
        INDArray mul = castTo.mul(first);
        INDArray mul2 = rsub2.mul(rsub);
        INDArray mul3 = castTo.mul(rsub);
        INDArray mul4 = rsub2.mul(first);
        if (third != null) {
            third = third.castTo(mul.dataType());
            mul.muli(third);
            mul2.muli(third);
            mul3.muli(third);
            mul4.muli(third);
        }
        int[] asInt = mul.sum(0).data().asInt();
        int[] asInt2 = mul2.sum(0).data().asInt();
        int[] asInt3 = mul3.sum(0).data().asInt();
        int[] asInt4 = mul4.sum(0).data().asInt();
        if (this.countTruePositive == null) {
            int length = asInt.length;
            this.countTruePositive = new int[length];
            this.countFalsePositive = new int[length];
            this.countTrueNegative = new int[length];
            this.countFalseNegative = new int[length];
        }
        addInPlace(this.countTruePositive, asInt);
        addInPlace(this.countFalsePositive, asInt3);
        addInPlace(this.countTrueNegative, asInt2);
        addInPlace(this.countFalseNegative, asInt4);
        if (this.rocBinary != null) {
            this.rocBinary.eval(first, second, third);
        }
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public void merge(EvaluationBinary evaluationBinary) {
        if (evaluationBinary.countTruePositive == null) {
            return;
        }
        if (this.countTruePositive == null) {
            this.countTruePositive = evaluationBinary.countTruePositive;
            this.countFalsePositive = evaluationBinary.countFalsePositive;
            this.countTrueNegative = evaluationBinary.countTrueNegative;
            this.countFalseNegative = evaluationBinary.countFalseNegative;
            this.rocBinary = evaluationBinary.rocBinary;
            return;
        }
        if (this.countTruePositive.length != evaluationBinary.countTruePositive.length) {
            throw new IllegalStateException("Cannot merge EvaluationBinary instances with different sizes. This size: " + this.countTruePositive.length + ", other size: " + evaluationBinary.countTruePositive.length);
        }
        addInPlace(this.countTruePositive, evaluationBinary.countTruePositive);
        addInPlace(this.countTrueNegative, evaluationBinary.countTrueNegative);
        addInPlace(this.countFalsePositive, evaluationBinary.countFalsePositive);
        addInPlace(this.countFalseNegative, evaluationBinary.countFalseNegative);
        if (this.rocBinary != null) {
            this.rocBinary.merge(evaluationBinary.rocBinary);
        }
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public void reset() {
        this.countTruePositive = null;
    }

    private static void addInPlace(int[] iArr, int[] iArr2) {
        for (int i = 0; i < iArr.length; i++) {
            int i2 = i;
            iArr[i2] = iArr[i2] + iArr2[i];
        }
    }

    public int numLabels() {
        if (this.countTruePositive == null) {
            return -1;
        }
        return this.countTruePositive.length;
    }

    public void setLabelNames(List<String> list) {
        if (list == null) {
            this.labels = null;
        } else {
            this.labels = new ArrayList(list);
        }
    }

    public int totalCount(int i) {
        assertIndex(i);
        return this.countTruePositive[i] + this.countTrueNegative[i] + this.countFalseNegative[i] + this.countFalsePositive[i];
    }

    public int truePositives(int i) {
        assertIndex(i);
        return this.countTruePositive[i];
    }

    public int trueNegatives(int i) {
        assertIndex(i);
        return this.countTrueNegative[i];
    }

    public int falsePositives(int i) {
        assertIndex(i);
        return this.countFalsePositive[i];
    }

    public int falseNegatives(int i) {
        assertIndex(i);
        return this.countFalseNegative[i];
    }

    public double averageAccuracy() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += accuracy(i);
        }
        return d / numLabels();
    }

    public double accuracy(int i) {
        assertIndex(i);
        return (this.countTruePositive[i] + this.countTrueNegative[i]) / totalCount(i);
    }

    public double averagePrecision() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += precision(i);
        }
        return d / numLabels();
    }

    public double precision(int i) {
        assertIndex(i);
        return this.countTruePositive[i] / (this.countTruePositive[i] + this.countFalsePositive[i]);
    }

    public double averageRecall() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += recall(i);
        }
        return d / numLabels();
    }

    public double recall(int i) {
        assertIndex(i);
        return this.countTruePositive[i] / (this.countTruePositive[i] + this.countFalseNegative[i]);
    }

    public double averageF1() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += f1(i);
        }
        return d / numLabels();
    }

    public double fBeta(double d, int i) {
        assertIndex(i);
        return EvaluationUtils.fBeta(d, precision(i), recall(i));
    }

    public double f1(int i) {
        return fBeta(1.0d, i);
    }

    public double matthewsCorrelation(int i) {
        assertIndex(i);
        return EvaluationUtils.matthewsCorrelation(truePositives(i), falsePositives(i), falseNegatives(i), trueNegatives(i));
    }

    public double averageMatthewsCorrelation() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += matthewsCorrelation(i);
        }
        return d / numLabels();
    }

    public double gMeasure(int i) {
        return EvaluationUtils.gMeasure(precision(i), recall(i));
    }

    public double averageGMeasure() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += gMeasure(i);
        }
        return d / numLabels();
    }

    public double falsePositiveRate(int i) {
        assertIndex(i);
        return falsePositiveRate(i, 0.0d);
    }

    public double falsePositiveRate(int i, double d) {
        return EvaluationUtils.falsePositiveRate(falsePositives(i), trueNegatives(i), d);
    }

    public double falseNegativeRate(Integer num) {
        return falseNegativeRate(num, 0.0d);
    }

    public double falseNegativeRate(Integer num, double d) {
        return EvaluationUtils.falseNegativeRate(falseNegatives(num.intValue()), truePositives(num.intValue()), d);
    }

    public ROCBinary getROCBinary() {
        return this.rocBinary;
    }

    private void assertIndex(int i) {
        if (this.countTruePositive == null) {
            throw new UnsupportedOperationException("EvaluationBinary does not have any stats: eval must be called first");
        }
        if (i < 0 || i >= this.countTruePositive.length) {
            throw new IllegalArgumentException("Invalid input: output number must be between 0 and " + (i - 1) + ". Got index: " + i);
        }
    }

    public double averageFalseAlarmRate() {
        double d = 0.0d;
        for (int i = 0; i < numLabels(); i++) {
            d += falseAlarmRate(i);
        }
        return d / numLabels();
    }

    public double falseAlarmRate(int i) {
        assertIndex(i);
        return (falsePositiveRate(i) + falseNegativeRate(Integer.valueOf(i))) / 2.0d;
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public String stats() {
        return stats(4);
    }

    public String stats(int i) {
        StringBuilder sb = new StringBuilder();
        int i2 = 15;
        if (this.labels != null) {
            Iterator<String> it = this.labels.iterator();
            while (it.hasNext()) {
                i2 = Math.max(it.next().length(), i2);
            }
        }
        String str = "%-12." + i + "f";
        String str2 = "%-" + (i2 + 5) + "s" + str + str + str + str + "%-8d%-7d%-7d%-7d%-7d";
        String str3 = "%-" + (i2 + 5) + "s%-12s%-12s%-12s%-12s%-8s%-7s%-7s%-7s%-7s";
        List asList = Arrays.asList("Label", "Accuracy", "F1", "Precision", "Recall", "Total", "TP", "TN", "FP", "FN");
        if (this.rocBinary != null) {
            str3 = str3 + "%-12s";
            str2 = str2 + str;
            asList = new ArrayList(asList);
            asList.add("AUC");
        }
        sb.append(String.format(str3, asList.toArray()));
        if (this.countTrueNegative != null) {
            for (int i3 = 0; i3 < this.countTrueNegative.length; i3++) {
                List asList2 = Arrays.asList(this.labels == null ? String.valueOf(i3) : this.labels.get(i3), Double.valueOf(accuracy(i3)), Double.valueOf(f1(i3)), Double.valueOf(precision(i3)), Double.valueOf(recall(i3)), Integer.valueOf(totalCount(i3)), Integer.valueOf(truePositives(i3)), Integer.valueOf(trueNegatives(i3)), Integer.valueOf(falsePositives(i3)), Integer.valueOf(falseNegatives(i3)));
                if (this.rocBinary != null) {
                    asList2 = new ArrayList(asList2);
                    asList2.add(Double.valueOf(this.rocBinary.calculateAUC(i3)));
                }
                sb.append("\n").append(String.format(str2, asList2.toArray()));
            }
            if (this.decisionThreshold != null) {
                sb.append("\nPer-output decision thresholds: ").append(Arrays.toString(this.decisionThreshold.dup().data().asFloat()));
            }
        } else {
            sb.append("\n-- No Data --\n");
        }
        return sb.toString();
    }

    public double scoreForMetric(Metric metric, int i) {
        switch (metric) {
            case ACCURACY:
                return accuracy(i);
            case F1:
                return f1(i);
            case PRECISION:
                return precision(i);
            case RECALL:
                return recall(i);
            case GMEASURE:
                return gMeasure(i);
            case MCC:
                return matthewsCorrelation(i);
            case FAR:
                return falseAlarmRate(i);
            default:
                throw new IllegalStateException("Unknown metric: " + metric);
        }
    }

    public static EvaluationBinary fromJson(String str) {
        return (EvaluationBinary) fromJson(str, EvaluationBinary.class);
    }

    public static EvaluationBinary fromYaml(String str) {
        return (EvaluationBinary) fromYaml(str, EvaluationBinary.class);
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public double getValue(IMetric iMetric) {
        if (!(iMetric instanceof Metric)) {
            throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + iMetric);
        }
        switch ((Metric) iMetric) {
            case ACCURACY:
                return averageAccuracy();
            case F1:
                return averageF1();
            case PRECISION:
                return averagePrecision();
            case RECALL:
                return averageRecall();
            case GMEASURE:
                return averageGMeasure();
            case MCC:
                return averageMatthewsCorrelation();
            case FAR:
                return averageFalseAlarmRate();
            default:
                throw new IllegalStateException("Can't get value for non-binary evaluation Metric " + iMetric);
        }
    }

    @Override // org.nd4j.evaluation.IEvaluation
    public EvaluationBinary newInstance() {
        return this.rocBinary != null ? new EvaluationBinary(this.axis, this.rocBinary.newInstance(), this.labels, this.decisionThreshold) : new EvaluationBinary(this.axis, null, this.labels, this.decisionThreshold);
    }

    public EvaluationBinary() {
        this.axis = 1;
    }

    @Override // org.nd4j.evaluation.BaseEvaluation
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof EvaluationBinary)) {
            return false;
        }
        EvaluationBinary evaluationBinary = (EvaluationBinary) obj;
        if (!evaluationBinary.canEqual(this) || !super.equals(obj) || !Arrays.equals(getCountTruePositive(), evaluationBinary.getCountTruePositive()) || !Arrays.equals(getCountFalsePositive(), evaluationBinary.getCountFalsePositive()) || !Arrays.equals(getCountTrueNegative(), evaluationBinary.getCountTrueNegative()) || !Arrays.equals(getCountFalseNegative(), evaluationBinary.getCountFalseNegative())) {
            return false;
        }
        ROCBinary rOCBinary = getROCBinary();
        ROCBinary rOCBinary2 = evaluationBinary.getROCBinary();
        if (rOCBinary == null) {
            if (rOCBinary2 != null) {
                return false;
            }
        } else if (!rOCBinary.equals(rOCBinary2)) {
            return false;
        }
        List<String> labels = getLabels();
        List<String> labels2 = evaluationBinary.getLabels();
        if (labels == null) {
            if (labels2 != null) {
                return false;
            }
        } else if (!labels.equals(labels2)) {
            return false;
        }
        INDArray decisionThreshold = getDecisionThreshold();
        INDArray decisionThreshold2 = evaluationBinary.getDecisionThreshold();
        return decisionThreshold == null ? decisionThreshold2 == null : decisionThreshold.equals(decisionThreshold2);
    }

    @Override // org.nd4j.evaluation.BaseEvaluation
    protected boolean canEqual(Object obj) {
        return obj instanceof EvaluationBinary;
    }

    @Override // org.nd4j.evaluation.BaseEvaluation
    public int hashCode() {
        int hashCode = (((((((super.hashCode() * 59) + Arrays.hashCode(getCountTruePositive())) * 59) + Arrays.hashCode(getCountFalsePositive())) * 59) + Arrays.hashCode(getCountTrueNegative())) * 59) + Arrays.hashCode(getCountFalseNegative());
        ROCBinary rOCBinary = getROCBinary();
        int hashCode2 = (hashCode * 59) + (rOCBinary == null ? 43 : rOCBinary.hashCode());
        List<String> labels = getLabels();
        int hashCode3 = (hashCode2 * 59) + (labels == null ? 43 : labels.hashCode());
        INDArray decisionThreshold = getDecisionThreshold();
        return (hashCode3 * 59) + (decisionThreshold == null ? 43 : decisionThreshold.hashCode());
    }

    public int[] getCountTruePositive() {
        return this.countTruePositive;
    }

    public int[] getCountFalsePositive() {
        return this.countFalsePositive;
    }

    public int[] getCountTrueNegative() {
        return this.countTrueNegative;
    }

    public int[] getCountFalseNegative() {
        return this.countFalseNegative;
    }

    public List<String> getLabels() {
        return this.labels;
    }

    public INDArray getDecisionThreshold() {
        return this.decisionThreshold;
    }

    public void setCountTruePositive(int[] iArr) {
        this.countTruePositive = iArr;
    }

    public void setCountFalsePositive(int[] iArr) {
        this.countFalsePositive = iArr;
    }

    public void setCountTrueNegative(int[] iArr) {
        this.countTrueNegative = iArr;
    }

    public void setCountFalseNegative(int[] iArr) {
        this.countFalseNegative = iArr;
    }

    public void setRocBinary(ROCBinary rOCBinary) {
        this.rocBinary = rOCBinary;
    }

    public void setLabels(List<String> list) {
        this.labels = list;
    }

    public void setDecisionThreshold(INDArray iNDArray) {
        this.decisionThreshold = iNDArray;
    }

    @Override // org.nd4j.evaluation.BaseEvaluation
    public String toString() {
        return "EvaluationBinary(axis=" + getAxis() + ", countTruePositive=" + Arrays.toString(getCountTruePositive()) + ", countFalsePositive=" + Arrays.toString(getCountFalsePositive()) + ", countTrueNegative=" + Arrays.toString(getCountTrueNegative()) + ", countFalseNegative=" + Arrays.toString(getCountFalseNegative()) + ", rocBinary=" + getROCBinary() + ", labels=" + getLabels() + ", decisionThreshold=" + getDecisionThreshold() + ")";
    }
}
