package de.datexis.ner.eval;

import com.google.common.collect.Lists;
import de.datexis.annotator.AnnotatorEvaluation;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Token;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.ner.MentionAnnotation;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.nd4j.linalg.primitives.Counter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates mention annotations by comparing, document by document, the annotations of the
 * predicted source against those of the expected source.
 *
 * <p>For every evaluated document the number of true positives, false positives and false
 * negatives is stored under that document's index in {@link #counts}. This allows both
 * micro-averaged scores (computed from the summed counts) and macro-averaged scores (averages
 * of per-document scores) to be reported. True negatives are not meaningful for mention
 * matching and are always recorded as 0.
 *
 * <p>All counts use the configured {@link Annotation.Match} strategy and the same comparison
 * direction ({@code predicted.matches(expected, matchingStrategy)}). Every predicted annotation
 * is therefore counted exactly once, either as a TP or as an FP.
 */
public class MentionAnnotatorEvaluation extends AnnotatorEvaluation {

    protected static Logger log = LoggerFactory.getLogger(MentionAnnotatorEvaluation.class);

    /** Per-document counts for each measure, keyed by the document's index in this evaluation. */
    protected TreeMap<AnnotatorEvaluation.Measure, Counter<Integer>> counts;

    /** Strategy used to decide whether a predicted annotation matches an expected one. */
    Annotation.Match matchingStrategy;

    /**
     * Creates an evaluation that compares {@code Annotation.Source.PRED} annotations against
     * {@code Annotation.Source.GOLD} annotations.
     *
     * @param str   experiment name, passed on to {@link AnnotatorEvaluation}
     * @param match strategy used to match predicted against expected annotations
     */
    public MentionAnnotatorEvaluation(String str, Annotation.Match match) {
        this(str, Annotation.Source.GOLD, Annotation.Source.PRED, match);
    }

    /**
     * Creates an evaluation between two arbitrary annotation sources.
     *
     * @param str     experiment name, passed on to {@link AnnotatorEvaluation}
     * @param source  first source passed to the superclass (GOLD in the two-argument constructor;
     *                presumably the expected source — verify against AnnotatorEvaluation)
     * @param source2 second source passed to the superclass (PRED in the two-argument constructor)
     * @param match   strategy used to match predicted against expected annotations
     */
    public MentionAnnotatorEvaluation(String str, Annotation.Source source, Annotation.Source source2, Annotation.Match match) {
        super(str, source, source2);
        this.matchingStrategy = match;
        clear();
    }

    /** Resets all per-document counters and all aggregate counts to zero. */
    protected void clear() {
        this.counts = new TreeMap<>();
        this.counts.put(AnnotatorEvaluation.Measure.TP, new Counter<>());
        this.counts.put(AnnotatorEvaluation.Measure.FP, new Counter<>());
        this.counts.put(AnnotatorEvaluation.Measure.TN, new Counter<>());
        this.counts.put(AnnotatorEvaluation.Measure.FN, new Counter<>());
        this.countExamples = 0;
        this.countAnnotations = 0;
        this.countDocs = 0;
        this.countSentences = 0;
        this.countTokens = 0;
    }

    /**
     * Returns the stored count of a measure for one document.
     *
     * @param measure the measure to look up
     * @param i       index of the document within this evaluation
     * @return the stored count, as reported by the underlying {@link Counter}
     */
    protected double getCount(AnnotatorEvaluation.Measure measure, int i) {
        return this.counts.get(measure).getCount(i);
    }

    /** Returns the micro-averaged F1 score as the single summary score. */
    public double getScore() {
        return getMicroF1();
    }

    /** Evaluates all documents of a dataset on {@link MentionAnnotation}s. */
    public void calculateScores(Dataset dataset) {
        calculateScoresFromAnnotations(dataset.getDocuments(), MentionAnnotation.class);
    }

    /** Evaluates the given documents on {@link MentionAnnotation}s. */
    public void calculateScores(Collection<Document> collection) {
        calculateScoresFromAnnotations(collection, MentionAnnotation.class);
    }

    /**
     * Computes TP/FP/TN/FN for every document and adds the documents to the aggregate counts.
     *
     * @param collection documents to evaluate
     * @param cls        annotation class to compare
     */
    public void calculateScoresFromAnnotations(Collection<Document> collection, Class<? extends Annotation> cls) {
        for (Document document : collection) {
            // Continue numbering after already evaluated documents. Repeated calls then accumulate
            // instead of overwriting earlier per-document counts, and indices stay aligned with
            // countDocs, which bounds the loops of the macro-averaged scores.
            int docIndex = (int) this.countDocs;
            this.counts.get(AnnotatorEvaluation.Measure.TP).setCount(docIndex, getTP(document, cls));
            this.counts.get(AnnotatorEvaluation.Measure.FP).setCount(docIndex, getFP(document, cls));
            this.counts.get(AnnotatorEvaluation.Measure.TN).setCount(docIndex, getTN(document, cls));
            this.counts.get(AnnotatorEvaluation.Measure.FN).setCount(docIndex, getFN(document, cls));
            this.countTokens += document.countTokens();
            this.countSentences += document.countSentences();
            this.countAnnotations += document.countAnnotations(this.expectedSource, cls);
            this.countDocs++;
        }
        fixCounters();
    }

    /** Removes the sentinel key {@code -1} from every per-document counter. */
    protected void fixCounters() {
        this.counts.get(AnnotatorEvaluation.Measure.TP).removeKey(-1);
        this.counts.get(AnnotatorEvaluation.Measure.FP).removeKey(-1);
        this.counts.get(AnnotatorEvaluation.Measure.TN).removeKey(-1);
        this.counts.get(AnnotatorEvaluation.Measure.FN).removeKey(-1);
    }

    /** Collects all annotations of the given source and class from a document into a list. */
    private List<Annotation> annotationsOf(Document document, Annotation.Source source, Class<? extends Annotation> cls) {
        return Lists.newArrayList(document.streamAnnotations(source, cls).iterator());
    }

    /** Returns whether {@code predicted} matches {@code expected} under {@link #matchingStrategy}. */
    private boolean isMatch(Annotation predicted, Annotation expected) {
        return predicted.matches(expected, this.matchingStrategy);
    }

    /**
     * Counts predicted annotations that match at least one expected annotation.
     * Also adds the result to {@code countExamples}.
     *
     * @return number of true positives in {@code document}
     */
    public double getTP(Document document, Class<? extends Annotation> cls) {
        List<Annotation> predicted = annotationsOf(document, this.predictedSource, cls);
        List<Annotation> expected = annotationsOf(document, this.expectedSource, cls);
        int tp = (int) predicted.stream()
                .filter(pred -> expected.stream().anyMatch(gold -> isMatch(pred, gold)))
                .count();
        this.countExamples += tp;
        return tp;
    }

    /**
     * Counts predicted annotations that match no expected annotation.
     * Also adds the result to {@code countExamples}. Together with {@link #getTP(Document, Class)},
     * this makes {@code countExamples} the number of predictions.
     *
     * @return number of false positives in {@code document}
     */
    public double getFP(Document document, Class<? extends Annotation> cls) {
        List<Annotation> predicted = annotationsOf(document, this.predictedSource, cls);
        List<Annotation> expected = annotationsOf(document, this.expectedSource, cls);
        int fp = (int) predicted.stream()
                .filter(pred -> expected.stream().noneMatch(gold -> isMatch(pred, gold)))
                .count();
        this.countExamples += fp;
        return fp;
    }

    /** True negatives are not defined for mention matching; always returns 0. */
    public double getTN(Document document, Class<? extends Annotation> cls) {
        return 0.0d;
    }

    /**
     * Counts expected annotations that no predicted annotation matches. This uses the
     * configured {@link #matchingStrategy}, the same strategy that TP and FP use.
     *
     * @return number of false negatives in {@code document}
     */
    public double getFN(Document document, Class<? extends Annotation> cls) {
        List<Annotation> predicted = annotationsOf(document, this.predictedSource, cls);
        List<Annotation> expected = annotationsOf(document, this.expectedSource, cls);
        return (int) expected.stream()
                .filter(gold -> predicted.stream().noneMatch(pred -> isMatch(pred, gold)))
                .count();
    }

    /** Returns the total number of true positives over all evaluated documents. */
    public double getTP() {
        return this.counts.get(AnnotatorEvaluation.Measure.TP).totalCount();
    }

    /** Returns the total number of false positives over all evaluated documents. */
    public double getFP() {
        return this.counts.get(AnnotatorEvaluation.Measure.FP).totalCount();
    }

    /** Returns the total number of true negatives over all evaluated documents (always 0 here). */
    public double getTN() {
        return this.counts.get(AnnotatorEvaluation.Measure.TN).totalCount();
    }

    /** Returns the total number of false negatives over all evaluated documents. */
    public double getFN() {
        return this.counts.get(AnnotatorEvaluation.Measure.FN).totalCount();
    }

    /** Divides {@code d} by {@code d2}, returning 0 instead of NaN/Infinity when {@code d2} is 0. */
    protected double div(double d, double d2) {
        if (d2 == 0.0d) {
            return 0.0d;
        }
        return d / d2;
    }

    /**
     * Returns the fraction of tokens whose expected and predicted {@link BIO2Tag}s are equal.
     *
     * @return token-level tag accuracy, or 0 if the dataset contains no tokens
     */
    public double getTokenAccuracy(Dataset dataset) {
        List<Token> tokens = dataset.streamTokens().collect(Collectors.toList());
        double correct = 0.0d;
        for (Token token : tokens) {
            if (token.getTag(this.expectedSource, BIO2Tag.class).get().equals(token.getTag(this.predictedSource, BIO2Tag.class).get())) {
                correct += 1.0d;
            }
        }
        // div() avoids returning NaN for an empty dataset
        return div(correct, tokens.size());
    }

    /** Returns TP / (TP + FN) over all documents, or 0 if undefined. */
    public double getAccuracy() {
        return div(getTP(), getTP() + getFN());
    }

    /** Returns TP / (TP + FN) for a single document, or 0 if undefined. */
    protected double getAccuracy(int i) {
        return div(getCount(AnnotatorEvaluation.Measure.TP, i), getCount(AnnotatorEvaluation.Measure.TP, i) + getCount(AnnotatorEvaluation.Measure.FN, i));
    }

    /** Returns micro-averaged precision TP / (TP + FP), or 0 if undefined. */
    public double getMicroPrecision() {
        return div(getTP(), getTP() + getFP());
    }

    /**
     * Returns the mean per-document precision over all documents that contain at least one
     * prediction. Precision is undefined for the other documents, so they are skipped.
     *
     * @return macro-averaged precision, or 0 if no document contains a prediction
     */
    public double getMacroPrecision() {
        double sum = 0.0d;
        int docs = 0;
        for (int i = 0; i < this.countDocs; i++) {
            if (getCount(AnnotatorEvaluation.Measure.TP, i) + getCount(AnnotatorEvaluation.Measure.FP, i) > 0.0d) {
                sum += getPrecision(i);
                docs++;
            }
        }
        return div(sum, docs);
    }

    /** Returns TP / (TP + FP) for a single document, or 0 if undefined. */
    protected double getPrecision(int i) {
        return div(getCount(AnnotatorEvaluation.Measure.TP, i), getCount(AnnotatorEvaluation.Measure.TP, i) + getCount(AnnotatorEvaluation.Measure.FP, i));
    }

    /** Returns micro-averaged recall TP / (TP + FN), or 0 if undefined. */
    public double getMicroRecall() {
        return div(getTP(), getTP() + getFN());
    }

    /**
     * Returns the mean per-document recall over all documents that contain at least one
     * expected annotation. Recall is undefined for the other documents, so they are skipped.
     *
     * @return macro-averaged recall, or 0 if no document contains an expected annotation
     */
    public double getMacroRecall() {
        double sum = 0.0d;
        int docs = 0;
        for (int i = 0; i < this.countDocs; i++) {
            if (getCount(AnnotatorEvaluation.Measure.TP, i) + getCount(AnnotatorEvaluation.Measure.FN, i) > 0.0d) {
                sum += getRecall(i);
                docs++;
            }
        }
        return div(sum, docs);
    }

    /** Returns TP / (TP + FN) for a single document, or 0 if undefined. */
    protected double getRecall(int i) {
        return div(getCount(AnnotatorEvaluation.Measure.TP, i), getCount(AnnotatorEvaluation.Measure.TP, i) + getCount(AnnotatorEvaluation.Measure.FN, i));
    }

    /** Returns the harmonic mean of micro precision and micro recall. */
    public double getMicroF1() {
        return getF1(getMicroPrecision(), getMicroRecall());
    }

    /** Returns the harmonic mean of macro precision and macro recall. */
    public double getMacroF1() {
        return getF1(getMacroPrecision(), getMacroRecall());
    }

    /** Returns the F1 score of a single document. */
    protected double getF1(int i) {
        return getF1(getPrecision(i), getRecall(i));
    }

    /** Harmonic mean of precision {@code d} and recall {@code d2}; 0 if both are 0. */
    private double getF1(double d, double d2) {
        if (d + d2 == 0.0d) {
            return 0.0d;
        }
        return ((2.0d * d) * d2) / (d + d2);
    }

    /** Prints the report header and this evaluation's row to stdout and returns both. */
    public String printAnnotationStats() {
        return printHeader() + printRow();
    }

    /** Builds the report header, prints it to stdout and returns it. */
    public static String printHeader() {
        StringBuilder sb = new StringBuilder();
        sb.append("ANNOTATION [micro-avg]\n").append("Experiment ----------------------------------------\t#Docs\t#Tokns\t#Anns\t#Pred\t#TP\t#FP\t#TN\t#FN\tPrec\tRec\tF1");
        sb.append("\n");
        System.out.print(sb.toString());
        return sb.toString();
    }

    /** Builds a tab-separated row with the micro-averaged results, prints it to stdout and returns it. */
    public String printRow() {
        StringBuilder sb = new StringBuilder();
        sb.append(fStr(this.experimentName, 50)).append("\t");
        sb.append(fInt(countDocuments())).append("\t");
        sb.append(fInt(countTokens())).append("\t");
        sb.append(fInt(countAnnotations())).append("\t");
        sb.append(fInt(countExamples())).append("\t");
        sb.append(fInt(getTP())).append("\t");
        sb.append(fInt(getFP())).append("\t");
        sb.append(fInt(getTN())).append("\t");
        sb.append(fInt(getFN())).append("\t");
        sb.append(fDbl(getMicroPrecision())).append("\t");
        sb.append(fDbl(getMicroRecall())).append("\t");
        sb.append(fDbl(getMicroF1())).append("\t");
        sb.append("\n");
        System.out.print(sb.toString());
        return sb.toString();
    }
}
