package de.datexis.ner;

import com.google.common.collect.Lists;
import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.common.Resource;
import de.datexis.common.Timer;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.model.tag.BIOESTag;
import de.datexis.model.tag.Tag;
import de.datexis.ner.MentionAnnotation;
import de.datexis.ner.eval.HTMLExport;
import de.datexis.ner.eval.MentionAnnotatorEvaluation;
import de.datexis.ner.tagger.MentionTagger;
import de.datexis.tagger.Tagger;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Named-entity mention annotator backed by a {@link MentionTagger}.
 *
 * <p>For training, gold {@link MentionAnnotation}s (or existing BIO2 tags) are converted into
 * BIOES tags before the documents are handed to the tagger. For prediction, the tagger tags the
 * documents and the resulting BIO2 tags are turned back into {@link MentionAnnotation}s.
 * Instances are usually created through {@link Builder}.
 */
public class MentionAnnotator extends Annotator {
    protected static final Logger log = LoggerFactory.getLogger(MentionAnnotator.class);

    /**
     * Temporary directory holding the best model written by {@link #trainModelEarlyStopping};
     * {@code null} until that method has saved at least one model.
     */
    protected Resource bestModel;

    /**
     * Fluent builder that configures a {@link MentionTagger} and wraps it in a
     * {@link MentionAnnotator}.
     *
     * <p>NOTE(review): {@link #build()} configures and returns the same annotator instance on every
     * call and adds the encoders to it each time, so calling it more than once presumably
     * registers the encoders repeatedly.
     */
    public static class Builder {
        protected String types = "GENERIC";
        protected Class tagset = BIOESTag.class;
        protected List<Encoder> encoders = new ArrayList<>();
        // NOTE(review): trainingSize is never read and iterations has no setter; iterations is
        // forwarded to the tagger with its default value in build().
        private int trainingSize = -1;
        private int ffwLayerSize = 300;
        private int lstmLayerSize = 100;
        private double learningRate = 0.001d;
        private int iterations = 1;
        private int batchSize = 16;
        private int numEpochs = 1;
        private int workers = 1;
        private boolean enabletrainingUI = false;
        MentionTagger tagger = new MentionTagger();
        MentionAnnotator ann = new MentionAnnotator(this.tagger);

        /**
         * Sets the network layer sizes.
         *
         * @param ffwLayerSize size of the feed-forward layer
         * @param lstmLayerSize size of the LSTM layer
         * @return this builder
         */
        public Builder withModelParams(int ffwLayerSize, int lstmLayerSize) {
            this.ffwLayerSize = ffwLayerSize;
            this.lstmLayerSize = lstmLayerSize;
            return this;
        }

        /**
         * Sets the training hyper-parameters.
         *
         * @param learningRate base learning rate (multiplied by {@code batchSize} in {@link #build()})
         * @param batchSize mini-batch size
         * @param numEpochs number of epochs passed to the tagger
         * @return this builder
         */
        public Builder withTrainingParams(double learningRate, int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.batchSize = batchSize;
            this.numEpochs = numEpochs;
            return this;
        }

        /**
         * Sets the number of workers passed to {@link MentionTagger#setWorkspaceParams}.
         *
         * @param workers worker count
         * @return this builder
         */
        public Builder withWorkspaceParams(int workers) {
            this.workers = workers;
            return this;
        }

        /**
         * Sets the mention type, using the string form of the given type.
         *
         * @param type mention type
         * @return this builder
         */
        public Builder withTypes(MentionAnnotation.Type type) {
            this.types = type.toString();
            return this;
        }

        /**
         * Sets the mention type string. It is used as a tagset argument and in the task name
         * {@code "NER-" + types}.
         *
         * @param types mention type string
         * @return this builder
         */
        public Builder withTypes(String types) {
            this.types = types;
            return this;
        }

        /**
         * Sets the encoders and records a feature description in the annotator's provenance.
         *
         * @param features feature description stored via {@code Provenance.setFeatures}
         * @param encoders encoders used by the tagger
         * @return this builder
         */
        public Builder withEncoders(String features, Encoder... encoders) {
            this.ann.getProvenance().setFeatures(features);
            withEncoders(encoders);
            return this;
        }

        /**
         * Sets the encoders and records their string form as the provenance architecture.
         *
         * @param encoders encoders used by the tagger
         * @return this builder
         */
        public Builder withEncoders(Encoder... encoders) {
            this.encoders = Lists.newArrayList(encoders);
            this.ann.getProvenance().setArchitecture(this.encoders.toString());
            return this;
        }

        /**
         * Enables or disables the tagger's training UI.
         *
         * @param enable whether to call {@link MentionTagger#enableTrainingUI()} in {@link #build()}
         * @return this builder
         */
        public Builder enableTrainingUI(boolean enable) {
            this.enabletrainingUI = enable;
            return this;
        }

        /**
         * Trains every configured encoder on the documents of {@code dataset}.
         *
         * @param dataset dataset whose document stream is passed to each encoder
         * @return this builder
         */
        public Builder pretrain(Dataset dataset) {
            for (Encoder encoder : this.encoders) {
                encoder.trainModel(dataset.streamDocuments());
            }
            return this;
        }

        /**
         * Registers the encoders, configures the tagger and returns the annotator.
         *
         * @return the configured annotator
         * @throws IllegalArgumentException if an encoder reports no model available
         */
        public MentionAnnotator build() {
            for (Encoder encoder : this.encoders) {
                if (!encoder.isModelAvailable()) {
                    throw new IllegalArgumentException("encoder " + encoder.getId() + " has no model available, please consider pretrain()");
                }
                this.ann.addComponent(encoder);
            }
            this.tagger.setTagset(this.tagset, this.types);
            this.tagger.setEncoders(this.encoders);
            // The learning rate handed to the tagger is scaled by the batch size.
            this.tagger.setModelParams(this.ffwLayerSize, this.lstmLayerSize, this.iterations, this.learningRate * this.batchSize);
            if (this.enabletrainingUI) {
                this.tagger.enableTrainingUI();
            }
            this.tagger.setTrainingParams(this.batchSize, this.numEpochs, true);
            this.tagger.setWorkspaceParams(this.workers);
            this.ann.getProvenance().setTask("NER-" + this.types);
            this.tagger.setName(this.ann.getProvenance().toString());
            return this.ann;
        }
    }

    /** Creates an annotator without a tagger. */
    public MentionAnnotator() {
    }

    /**
     * Creates an annotator around the given tagger.
     *
     * @param tagger tagger to use; expected to be a {@link MentionTagger}
     */
    public MentionAnnotator(Tagger tagger) {
        super(tagger);
    }

    /**
     * Returns the tagger cast to {@link MentionTagger}.
     *
     * <p>NOTE(review): this name appears to be a decompiler artifact of a covariant
     * {@code getTagger()} override; it is kept as-is for compatibility with existing callers.
     *
     * @return the tagger
     */
    public MentionTagger m6getTagger() {
        return (MentionTagger) this.tagger;
    }

    @Override
    public String toString() {
        return getProvenance().toString();
    }

    /**
     * Tags the documents with the mention tagger and creates PRED {@link MentionAnnotation}s from
     * the resulting tags.
     *
     * @param documents documents to annotate in place
     */
    public void annotate(Collection<Document> documents) {
        m6getTagger().tag(documents);
        createAnnotations(documents, Annotation.Source.PRED);
    }

    /**
     * Trains on {@code train} using GOLD annotations, then tests on {@code test}.
     *
     * @param train training dataset
     * @param test test dataset passed to {@link MentionTagger#testModel}
     * @param language language recorded in the provenance
     */
    public void trainModel(Dataset train, Dataset test, WordHelpers.Language language) {
        this.provenance.setDataset(train.getName());
        applyLanguageAndTaggerName(language);
        createTags(train.getDocuments(), Annotation.Source.GOLD);
        m6getTagger().trainModel(train, Annotation.Source.GOLD);
        createTags(test.getDocuments(), Annotation.Source.GOLD);
        m6getTagger().testModel(test, Annotation.Source.GOLD);
    }

    /**
     * Trains on {@code dataset} using the annotations from {@code source}, with the defaults
     * {@code limit = -1} and {@code true} for both flags.
     *
     * @param dataset training dataset
     * @param source annotation source used to create training tags
     * @param language language recorded in the provenance
     */
    public void trainModel(Dataset dataset, Annotation.Source source, WordHelpers.Language language) {
        trainModel(dataset, source, language, -1, true, true);
    }

    /**
     * Trains on {@code dataset} using the annotations from {@code source}.
     *
     * @param dataset training dataset
     * @param source annotation source used to create training tags
     * @param language language recorded in the provenance
     * @param limit forwarded to {@link MentionTagger#trainModel}; {@code -1} by default
     *     (presumably "no limit" — confirm against MentionTagger)
     * @param unusedFlag currently ignored
     * @param trainFlag forwarded unchanged to {@link MentionTagger#trainModel}
     */
    public void trainModel(Dataset dataset, Annotation.Source source, WordHelpers.Language language, int limit, boolean unusedFlag, boolean trainFlag) {
        this.provenance.setDataset(dataset.getName());
        applyLanguageAndTaggerName(language);
        createTags(dataset.getDocuments(), source);
        m6getTagger().trainModel(dataset, source, limit, trainFlag);
    }

    /**
     * Trains epoch by epoch with early stopping, evaluating on {@code test} after each epoch.
     *
     * <p>When an epoch's STRONG-match score is at least the best so far, the model and an HTML
     * report are written to a fresh temporary directory ({@link #bestModel}), and the previously
     * saved directory is deleted. Otherwise, the remaining patience is decremented. Training stops
     * once more than {@code minEpochs} epochs have run and the patience has dropped below zero,
     * or after {@code maxEpochs} epochs. At least one epoch always runs.
     *
     * @param train training dataset
     * @param test evaluation dataset; its documents receive PRED annotations
     * @param source annotation source used for training tags and as the evaluation reference
     * @param language language recorded in the provenance
     * @param limit forwarded to {@link MentionTagger#trainModel} each epoch
     * @param minEpochs early stopping is only possible after this many epochs
     * @param maxEpochs upper bound on the epoch counter
     * @param patience number of non-improving epochs tolerated before the counter goes negative
     */
    public void trainModelEarlyStopping(Dataset train, Dataset test, Annotation.Source source, WordHelpers.Language language, int limit, int minEpochs, int maxEpochs, int patience) {
        this.provenance.setDataset(train.getName());
        applyLanguageAndTaggerName(language);
        createTags(train.getDocuments(), source);
        MentionTagger mentionTagger = m6getTagger();
        Timer timer = new Timer();
        int epoch = 1;
        double bestScore = 0.0d;
        int remainingPatience = patience;
        timer.start();
        do {
            mentionTagger.appendTrainLog("\n");
            mentionTagger.appendTrainLog("EPOCH " + epoch + ": training " + this.tagger.getName());
            mentionTagger.trainModel(train, source, limit, true);
            mentionTagger.appendTestLog("Testing epoch " + epoch);
            annotate(test.getDocuments());
            MentionAnnotatorEvaluation evaluation = new MentionAnnotatorEvaluation("TraiNER epoch " + epoch, source, Annotation.Source.PRED, Annotation.Match.STRONG);
            evaluation.calculateScores(test.getDocuments());
            evaluation.printAnnotationStats();
            double score = evaluation.getScore();
            timer.setSplit("epoch");
            mentionTagger.appendTrainLog("EPOCH " + epoch + " complete: score " + score, timer.getLong("epoch"));
            if (score >= bestScore) {
                snapshotBestModel(test, source, epoch);
                bestScore = score;
                remainingPatience = patience;
            } else {
                remainingPatience--;
            }
            epoch++;
            if (epoch > minEpochs && remainingPatience < 0) {
                break;
            }
        } while (epoch <= maxEpochs);
        timer.stop();
        mentionTagger.appendTrainLog("Training complete: " + this.tagger.getName() + " with score " + bestScore, timer.getLong());
        mentionTagger.appendTrainLog("\n");
    }

    /**
     * Writes the current model and an HTML evaluation report into a new temporary directory and
     * makes it the {@link #bestModel}. The superseded directory is deleted only if the new one was
     * written successfully, so a failed write never destroys the last good snapshot.
     */
    private void snapshotBestModel(Dataset test, Annotation.Source source, int epoch) {
        Resource previous = this.bestModel;
        this.bestModel = Resource.createTempDirectory();
        try {
            writeModel(this.bestModel, m6getTagger().getName());
            String html = new HTMLExport(test.getDocuments(), BIOESTag.class, source, Annotation.Source.PRED).getHTML();
            // Explicit charset: the two-argument overload uses the platform default encoding.
            FileUtils.writeStringToFile(this.bestModel.resolve("test_" + epoch + ".html").toFile(), html, StandardCharsets.UTF_8);
        } catch (IOException e) {
            log.error("Could not write output: " + e.toString(), e);
            return;
        }
        if (previous != null) {
            // Each improvement creates a fresh temp directory; drop the outdated one to avoid leaking disk space.
            FileUtils.deleteQuietly(previous.toFile());
        }
    }

    /**
     * Copies the best model saved by {@link #trainModelEarlyStopping} into {@code resource}.
     *
     * @param resource target directory
     * @param name currently unused
     * @throws IOException if copying fails
     * @throws IllegalStateException if no best model has been saved yet
     */
    public void writeBestModel(Resource resource, String name) throws IOException {
        if (this.bestModel == null) {
            throw new IllegalStateException("no best model available, please run trainModelEarlyStopping() first");
        }
        FileUtils.copyDirectory(this.bestModel.toFile(), resource.toFile());
    }

    /**
     * Trains the tagger directly on sentences. The dataset name in the provenance is not updated.
     *
     * @param sentences training sentences
     * @param source annotation source passed to the tagger
     * @param language language recorded in the provenance
     */
    public void trainModel(Collection<Sentence> sentences, Annotation.Source source, WordHelpers.Language language) {
        applyLanguageAndTaggerName(language);
        m6getTagger().trainModel(sentences, source, true);
    }

    /**
     * Records the lower-cased language in the provenance and renames the tagger after the
     * updated provenance.
     */
    private void applyLanguageAndTaggerName(WordHelpers.Language language) {
        // Locale.ROOT keeps the result independent of the JVM default locale (e.g. Turkish 'I').
        this.provenance.setLanguage(language.toString().toLowerCase(Locale.ROOT));
        m6getTagger().setName(this.provenance.toString());
    }

    /**
     * Ensures every document has BIOES tags for {@code source}. Existing BIO2 tags are converted
     * when present; otherwise, the tags are derived from the {@link MentionAnnotation}s.
     *
     * @param documents documents to tag in place
     * @param source annotation source to read from and tag
     */
    protected void createTags(Iterable<Document> documents, Annotation.Source source) {
        for (Document document : documents) {
            if (document.isTagAvaliable(source, BIOESTag.class)) {
                continue;
            }
            if (document.isTagAvaliable(source, BIO2Tag.class)) {
                BIO2Tag.convertToBIOES(document, source);
            } else {
                MentionAnnotation.createTagsFromAnnotations(document, source, BIOESTag.class);
            }
            document.setTagAvailable(source, BIOESTag.class, true);
        }
    }

    /**
     * Replaces each document's {@link MentionAnnotation}s for {@code source} with annotations
     * built from its BIO2 tags for the same source. Logs an error for documents without BIO2 tags.
     *
     * @param documents documents to annotate in place
     * @param source annotation source to clear, read tags from and annotate
     */
    protected void createAnnotations(Iterable<Document> documents, Annotation.Source source) {
        for (Document document : documents) {
            document.clearAnnotations(source, MentionAnnotation.class);
            if (document.isTagAvaliable(source, BIO2Tag.class)) {
                // Use the same source that was cleared and checked, not a hard-coded PRED.
                MentionAnnotation.annotateFromTags(document, source, (Class<? extends Tag>) BIO2Tag.class);
            } else {
                log.error("BIO2Tag not set");
            }
        }
    }
}
