package de.datexis.sector.eval;

import de.datexis.annotator.AnnotatorEvaluation;
import de.datexis.common.AnnotationHelpers;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.model.tag.Tag;
import java.io.Serializable;
import java.util.Collection;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.nd4j.evaluation.EvaluationAveraging;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.LoggerFactory;

/**
 * Classification and ranking evaluation for label vectors encoded by a {@link LookupCacheEncoder}.
 *
 * <p>Each example is a pair of (expected, predicted) vectors. Besides the metrics of the wrapped nd4j
 * {@link Evaluation} (accuracy, top-K accuracy, precision/recall/F1), this class accumulates ranking
 * metrics computed over the predicted scores sorted in descending order: mean reciprocal rank (MRR),
 * mean average precision (MAP), P@1, R@1, P@K and R@K.
 *
 * <p>Instances are mutable and use no synchronization.
 */
public class ClassificationEvaluation extends AnnotatorEvaluation implements IEvaluation<ClassificationEvaluation> {
    /** Encoder providing the label vocabulary and the vector size. */
    protected LookupCacheEncoder encoder;
    /** Number of classes, taken from the encoder's embedding vector size. */
    protected int numClasses;
    /** Cut-off used for P@K, R@K and the top-K accuracy of {@link #eval}. */
    protected int K;
    /** Wrapped nd4j evaluation providing accuracy, precision, recall and F1. */
    protected Evaluation eval;
    // Running sums of per-example metrics; the getters divide them by countExamples.
    protected double mrrsum;
    protected double mapsum;
    protected double p1sum;
    protected double r1sum;
    protected double pksum;
    protected double rksum;

    /**
     * Creates an evaluation comparing GOLD against PRED annotations with K = 3.
     *
     * @param name    evaluation name passed to {@link AnnotatorEvaluation}
     * @param encoder encoder providing label vocabulary and vector size (must not be null)
     */
    public ClassificationEvaluation(String name, LookupCacheEncoder encoder) {
        this(name, Annotation.Source.GOLD, Annotation.Source.PRED, encoder, 3);
    }

    /**
     * Creates an evaluation.
     *
     * @param name      evaluation name passed to {@link AnnotatorEvaluation}
     * @param expected  source of the expected (reference) annotations
     * @param predicted source of the predicted annotations
     * @param encoder   encoder providing label vocabulary and vector size (must not be null)
     * @param k         cut-off for the @K metrics
     */
    public ClassificationEvaluation(String name, Annotation.Source expected, Annotation.Source predicted,
                                    LookupCacheEncoder encoder, int k) {
        super(name, expected, predicted);
        this.K = k;
        this.encoder = encoder;
        this.numClasses = (int) encoder.getEmbeddingVectorSize();
        this.log = LoggerFactory.getLogger(ClassificationEvaluation.class);
        // clear() creates the nd4j Evaluation and zeroes all counters and sums.
        clear();
    }

    /** Resets all counters, running sums and the wrapped nd4j evaluation. */
    protected void clear() {
        this.eval = new Evaluation(this.encoder.getWords(), this.K);
        this.countDocs = 0;
        this.countExamples = 0;
        this.mrrsum = 0.0d;
        this.mapsum = 0.0d;
        this.p1sum = 0.0d;
        this.r1sum = 0.0d;
        this.pksum = 0.0d;
        this.rksum = 0.0d;
    }

    /** Returns the main score of this evaluation, which is MAP. */
    public double getScore() {
        return getMAP();
    }

    /** Not supported; use one of the {@code calculateScoresFrom*} methods instead. */
    public void calculateScores(Collection<Document> collection) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Evaluates annotations of the given type. Each expected annotation is paired with the predicted
     * annotation of maximum overlap.
     *
     * @param docs                 documents to evaluate
     * @param type                 annotation type to compare
     * @param includeUnmatchedPred if true, predicted annotations not matched by any expected one are
     *                             additionally evaluated against their best-overlapping expected annotation
     */
    public void calculateScoresFromAnnotations(Collection<Document> docs, Class<? extends Annotation> type,
                                               boolean includeUnmatchedPred) {
        // Identity semantics: track the exact predicted objects already used as a match.
        IdentityHashMap<Annotation, Boolean> matched = new IdentityHashMap<>();
        this.countDocs += docs.size();
        for (Document doc : docs) {
            for (Annotation expected : doc.getAnnotations(this.expectedSource, type)) {
                Optional<? extends Annotation> predicted =
                        AnnotationHelpers.getAnnotationMaxOverlap(doc, this.predictedSource, type, expected);
                if (predicted.isPresent()) {
                    Annotation match = predicted.get();
                    matched.put(match, Boolean.TRUE);
                    evalExample(expected.getVector(this.encoder.getClass()).transpose(),
                            match.getVector(this.encoder.getClass()).transpose());
                } else {
                    this.log.warn("Could not match predicted Annotation for expected Annotation {}-{}",
                            expected.getBegin(), expected.getEnd());
                }
            }
            if (includeUnmatchedPred) {
                for (Annotation predicted : doc.getAnnotations(this.predictedSource, type)) {
                    if (matched.containsKey(predicted)) {
                        continue;
                    }
                    Optional<? extends Annotation> expected =
                            AnnotationHelpers.getAnnotationMaxOverlap(doc, this.expectedSource, type, predicted);
                    if (expected.isPresent()) {
                        evalExample(expected.get().getVector(this.encoder.getClass()).transpose(),
                                predicted.getVector(this.encoder.getClass()).transpose());
                    }
                }
            }
        }
    }

    /**
     * Evaluates expected vs. predicted tags attached to every span of the given type. Spans missing
     * either tag are skipped with a warning.
     *
     * @param docs     documents to evaluate
     * @param spanType span type whose tags are compared
     * @param tagType  tag type to read from each span
     */
    public <T extends Tag> void calculateScoresFromTags(Collection<Document> docs, Class<? extends Span> spanType,
                                                        Class<T> tagType) {
        this.countDocs += docs.size();
        for (Document doc : docs) {
            for (Span span : (List<? extends Span>) doc.getStream(spanType).collect(Collectors.toList())) {
                Tag expected = span.getTag(this.expectedSource, tagType);
                Tag predicted = span.getTag(this.predictedSource, tagType);
                if (expected == null || predicted == null) {
                    this.log.warn("Skipped sentence without label: docId={} {}-{}",
                            doc.getId(), span.getBegin(), span.getEnd());
                } else {
                    evalExample(expected.getVector().transpose(), predicted.getVector().transpose());
                }
            }
        }
    }

    /**
     * Adds a single example to all metrics.
     *
     * @param expected  expected label vector (entries &gt; 0 are treated as relevant classes)
     * @param predicted predicted score vector
     */
    public void evalExample(INDArray expected, INDArray predicted) {
        // Sort predicted scores descending; element [0] is used below as the class-index ranking.
        INDArray[] sorted = Nd4j.sortWithIndices(Nd4j.toFlattened(new INDArray[]{predicted}).dup(), 1, false);
        INDArray ranking = sorted[0];
        if (ranking.sumNumber().doubleValue() == 0.0d) {
            this.log.warn("Sort on zero vector - please check vector dimensions!");
        }
        this.eval.eval(expected, predicted);
        this.mapsum += AP(expected, ranking);
        this.mrrsum += RR(expected, ranking);
        this.p1sum += Prec(expected, ranking, 1);
        this.r1sum += Rec(expected, ranking, 1);
        this.pksum += Prec(expected, ranking, this.K);
        this.rksum += Rec(expected, ranking, this.K);
        this.countExamples++;
    }

    /** Divides {@code d} by {@code d2}, returning 0 when the denominator is 0. */
    protected double div(double d, double d2) {
        if (d2 == 0.0d) {
            return 0.0d;
        }
        return d / d2;
    }

    /**
     * Returns the 1-based position of {@code classIdx} in {@code ranking}.
     *
     * @throws IllegalArgumentException if the class index does not occur in the ranking
     */
    protected static int rank(int classIdx, INDArray ranking) {
        for (int pos = 0; pos < ranking.length(); pos++) {
            if (ranking.getInt(new int[]{pos}) == classIdx) {
                return pos + 1;
            }
        }
        throw new IllegalArgumentException("index does not exist in labels");
    }

    /** Reciprocal rank of the expected argmax class, or 0 if the expected vector has no positive entry. */
    private double RR(INDArray expected, INDArray ranking) {
        int label = maxIndex(expected);
        if (label >= 0) {
            return 1.0d / rank(label, ranking);
        }
        return 0.0d;
    }

    /** Average precision over all ranked positions holding a relevant class. */
    private double AP(INDArray expected, INDArray ranking) {
        double precisionSum = 0.0d;
        int relevant = 0;
        for (int pos = 0; pos < expected.length(); pos++) {
            if (expected.getDouble(ranking.getInt(new int[]{pos})) > 0.0d) {
                relevant++;
                // Running count == precision at cut-off pos + 1 (avoids re-scanning the prefix).
                precisionSum += (double) relevant / (pos + 1);
            }
        }
        // NOTE(review): enforces single-label examples when assertions are enabled (-ea).
        assert relevant == 1 : "expected exactly one relevant label, found " + relevant;
        return relevant > 0 ? precisionSum / relevant : 0.0d;
    }

    /** Precision at cut-off {@code k}: relevant classes among the top {@code k}, divided by {@code k}. */
    private double Prec(INDArray expected, INDArray ranking, int k) {
        return div(countRelevant(expected, ranking, k), k);
    }

    /** Recall at cut-off {@code k}: relevant classes among the top {@code k}, divided by the expected sum. */
    private double Rec(INDArray expected, INDArray ranking, int k) {
        double total = expected.sumNumber().doubleValue();
        if (total == 0.0d) {
            return 0.0d;
        }
        return countRelevant(expected, ranking, k) / total;
    }

    /** Counts the top-{@code k} ranked classes with a positive expected value; k is clamped to the ranking length. */
    private static int countRelevant(INDArray expected, INDArray ranking, int k) {
        long limit = Math.min((long) k, (long) ranking.length());
        int hits = 0;
        for (int pos = 0; pos < limit; pos++) {
            if (expected.getDouble(ranking.getInt(new int[]{pos})) > 0.0d) {
                hits++;
            }
        }
        return hits;
    }

    /**
     * Returns the index of the largest entry, or -1 if no entry exceeds {@link Double#MIN_VALUE}.
     * Note that {@code Double.MIN_VALUE} is the smallest positive double, so all-zero or all-negative
     * vectors yield -1; {@link #RR} relies on this to detect examples without a positive label.
     */
    protected static int maxIndex(INDArray iNDArray) {
        int best = -1;
        double bestValue = Double.MIN_VALUE;
        for (int i = 0; i < iNDArray.length(); i++) {
            double value = iNDArray.getDouble(i);
            if (value > bestValue) {
                bestValue = value;
                best = i;
            }
        }
        return best;
    }

    /** Top-1 accuracy of the wrapped nd4j evaluation. */
    public double getAccuracy() {
        return this.eval.accuracy();
    }

    /** Top-K accuracy of the wrapped nd4j evaluation. */
    public double getAccuracyK() {
        return this.eval.topNAccuracy();
    }

    /** True positives of class {@code classIdx} divided by its positives; 0 if the class was never seen. */
    public double getAccuracy(int classIdx) {
        int truePositives = this.eval.truePositives().getOrDefault(classIdx, 0);
        int positives = this.eval.positive().getOrDefault(classIdx, 0);
        return div(truePositives, positives);
    }

    /** Micro-averaged precision. */
    public double getMicroPrecision() {
        return this.eval.precision(EvaluationAveraging.Micro);
    }

    /** Precision averaged over all {@link #numClasses} classes. */
    public double getMacroPrecision() {
        double sum = 0.0d;
        for (int c = 0; c < this.numClasses; c++) {
            sum += getPrecision(c);
        }
        return div(sum, this.numClasses);
    }

    /** Precision of a single class. */
    public double getPrecision(int i) {
        return this.eval.precision(Integer.valueOf(i));
    }

    /** Micro-averaged recall. */
    public double getMicroRecall() {
        return this.eval.recall(EvaluationAveraging.Micro);
    }

    /** Recall averaged over all {@link #numClasses} classes. */
    public double getMacroRecall() {
        double sum = 0.0d;
        for (int c = 0; c < this.numClasses; c++) {
            sum += getRecall(c);
        }
        return div(sum, this.numClasses);
    }

    /** Recall of a single class. */
    public double getRecall(int i) {
        return this.eval.recall(i);
    }

    /** Micro-averaged F1. */
    public double getMicroF1() {
        return this.eval.f1(EvaluationAveraging.Micro);
    }

    /** F1 averaged over all {@link #numClasses} classes. */
    public double getMacroF1() {
        double sum = 0.0d;
        for (int c = 0; c < this.numClasses; c++) {
            sum += getF1(c);
        }
        return div(sum, this.numClasses);
    }

    /** F1 of a single class. */
    public double getF1(int i) {
        return this.eval.f1(i);
    }

    /** Mean reciprocal rank; 0 if no example was evaluated. */
    protected double getMRR() {
        return div(this.mrrsum, this.countExamples);
    }

    /** Mean average precision; 0 if no example was evaluated. */
    public double getMAP() {
        return div(this.mapsum, this.countExamples);
    }

    /** Mean P@K; 0 if no example was evaluated. */
    public double getPrecisionK() {
        return div(this.pksum, this.countExamples);
    }

    /** Mean R@K; 0 if no example was evaluated. */
    public double getRecallK() {
        return div(this.rksum, this.countExamples);
    }

    /** Mean P@1; 0 if no example was evaluated. */
    public double getPrecision1() {
        return div(this.p1sum, this.countExamples);
    }

    /** Mean R@1; 0 if no example was evaluated. */
    public double getRecall1() {
        return div(this.r1sum, this.countExamples);
    }

    /** Evaluates each row of {@code labels} against the same row of {@code predictions}. */
    public void eval(INDArray labels, INDArray predictions) {
        for (int row = 0; row < labels.rows(); row++) {
            evalExample(labels.getRow(row), predictions.getRow(row));
        }
    }

    /** Same as {@link #eval(INDArray, INDArray)}; record metadata is ignored. */
    public void eval(INDArray labels, INDArray predictions, List<? extends Serializable> recordMetaData) {
        eval(labels, predictions);
    }

    /** Same as {@link #eval(INDArray, INDArray, INDArray)}; record metadata is ignored. */
    public void eval(INDArray labels, INDArray predictions, INDArray mask, List<? extends Serializable> recordMetaData) {
        eval(labels, predictions, mask);
    }

    /**
     * Evaluates with an optional mask. Rank-3 labels are treated as time series; a non-null mask is only
     * supported as a rank-2 time-series mask.
     *
     * @throws UnsupportedOperationException for any other mask layout
     */
    public void eval(INDArray labels, INDArray predictions, INDArray mask) {
        if (mask == null) {
            if (labels.rank() == 3) {
                evalTimeSeries(labels, predictions, null);
            } else {
                eval(labels, predictions);
            }
            return;
        }
        if (labels.rank() != 3 || mask.rank() != 2) {
            throw new UnsupportedOperationException(getClass().getSimpleName() + " does not support per-output masking");
        }
        evalTimeSeries(labels, predictions, mask);
    }

    /** Evaluates unmasked time-series data. */
    public void evalTimeSeries(INDArray labels, INDArray predictions) {
        evalTimeSeries(labels, predictions, null);
    }

    /** Evaluates the non-masked time steps of time-series data. */
    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
        Pair<INDArray, INDArray> steps = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
        if (steps == null) {
            return;
        }
        eval(steps.getFirst(), steps.getSecond());
    }

    /** Adds the counters and running sums of {@code other} to this evaluation; null is ignored. */
    public void merge(ClassificationEvaluation other) {
        if (other == null) {
            return;
        }
        this.eval.merge(other.eval);
        this.countDocs += other.countDocs;
        this.countExamples += other.countExamples;
        this.mrrsum += other.mrrsum;
        this.mapsum += other.mapsum;
        this.p1sum += other.p1sum;
        this.r1sum += other.r1sum;
        this.pksum += other.pksum;
        this.rksum += other.rksum;
    }

    /** Resets this evaluation to its initial, empty state. */
    public void reset() {
        clear();
    }

    /** Not supported. */
    public String stats() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Not supported. */
    public String toJson() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Not supported. */
    public String toYaml() {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /** Returns {@link #getScore()} regardless of the requested metric. */
    public double getValue(IMetric iMetric) {
        return getScore();
    }

    /** Creates a new, empty evaluation with the same sources, encoder and K as this one. */
    public ClassificationEvaluation newInstance() {
        return new ClassificationEvaluation(null, this.expectedSource, this.predictedSource, this.encoder, this.K);
    }

    /** Formats accuracy, P and R at 1 and at K, and MAP as a tab-separated table. */
    public String printClassificationAtKStats() {
        StringBuilder sb = new StringBuilder();
        sb.append(" Acc@1\t Acc@").append(this.K).append("\t P@1\t P@").append(this.K).append("\t R@1\t R@").append(this.K).append("\t MAP\n");
        sb.append(fDbl(getAccuracy())).append("\t");
        sb.append(fDbl(getAccuracyK())).append("\t");
        sb.append(fDbl(getPrecision1())).append("\t");
        sb.append(fDbl(getPrecisionK())).append("\t");
        sb.append(fDbl(getRecall1())).append("\t");
        sb.append(fDbl(getRecallK())).append("\t");
        sb.append(fDbl(getMAP())).append("\t");
        sb.append("\n");
        return sb.toString();
    }

    /** Formats counts, MRR, accuracy, MAP and macro precision/recall/F1 as a tab-separated table. */
    public String printClassificationStats() {
        StringBuilder sb = new StringBuilder();
        sb.append(" count\t TP\t FP\t MRR\t P@1\t MAP\t mPrec\t mRec\t mF1\n");
        sb.append(fInt(countExamples())).append("\t");
        sb.append(fInt(this.eval.getTruePositives().totalCount())).append("\t");
        sb.append(fInt(this.eval.getFalsePositives().totalCount())).append("\t");
        // NOTE(review): MRR is scaled by 1/100 unlike every other column — confirm this is intended.
        sb.append(fDbl(getMRR() / 100.0d)).append("\t");
        sb.append(fDbl(getAccuracy())).append("\t");
        sb.append(fDbl(getMAP())).append("\t");
        sb.append(fDbl(getMacroPrecision())).append("\t");
        sb.append(fDbl(getMacroRecall())).append("\t");
        sb.append(fDbl(getMacroF1())).append("\t");
        sb.append("\n");
        return sb.toString();
    }
}
