package de.datexis.ner.tagger;

import com.google.common.collect.Lists;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncoderSet;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Snippet;
import de.datexis.model.Token;
import de.datexis.model.tag.BIO2Tag;
import de.datexis.model.tag.BIOESTag;
import de.datexis.model.tag.Tag;
import de.datexis.ner.MentionAnnotation;
import de.datexis.ner.eval.MentionAnnotatorEval;
import de.datexis.ner.eval.MentionTaggerEval;
import de.datexis.tagger.AbstractIterator;
import de.datexis.tagger.Tagger;
import java.util.Collection;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.parallelism.ParallelWrapper;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/ner/tagger/MentionTagger.class */
/**
 * Sequence tagger that predicts one BIO2 or BIOES tag per token with a bidirectional LSTM
 * network (see {@link #createBLSTM}).
 *
 * <p>Predicted tags are stored on the tokens under {@link Annotation.Source#PRED}.
 */
public class MentionTagger extends Tagger {
    protected static final Logger log = LoggerFactory.getLogger(MentionTagger.class);

    /** Number of parallel training workers; values above 1 train through a {@link ParallelWrapper}. */
    protected int workers = 1;

    /** Tag scheme used for training labels and predictions. */
    protected Class<? extends Tag> tagset = BIOESTag.class;

    /** Type string passed to every predicted tag and used when deriving mention annotations. */
    protected String type = "GENERIC";

    /** Creates an untrained tagger by passing {@code "BLSTM"} to {@link #MentionTagger(String)}. */
    public MentionTagger() {
        this("BLSTM");
    }

    /**
     * Creates an untrained tagger that uses the BIOES tag set with type {@code "GENERIC"}.
     *
     * @param str value forwarded to the {@link Tagger} constructor
     */
    public MentionTagger(String str) {
        super(str);
        setTagset(BIOESTag.class, "GENERIC");
    }

    /**
     * Creates a tagger whose input/output sizes are taken from {@code iterator} and builds its
     * network with {@link #createBLSTM}.
     *
     * @param iterator source of the input and label vector sizes
     * @param ffwLayerSize width of the dense layers in front of the BLSTM; {@code <= 0} omits them
     * @param lstmLayerSize number of LSTM units
     * @param iterations forwarded to {@link #createBLSTM}, which ignores it
     * @param learningRate Adam learning rate
     */
    public MentionTagger(AbstractIterator iterator, int ffwLayerSize, int lstmLayerSize, int iterations, double learningRate) {
        super(iterator.getInputSize(), iterator.getLabelSize());
        this.net = createBLSTM(this.inputVectorSize, ffwLayerSize, lstmLayerSize, this.outputVectorSize, iterations, learningRate);
    }

    /**
     * Replaces the current network with a freshly initialized BLSTM built from the current
     * input and output vector sizes.
     *
     * @param ffwLayerSize width of the dense layers in front of the BLSTM; {@code <= 0} omits them
     * @param lstmLayerSize number of LSTM units
     * @param iterations forwarded to {@link #createBLSTM}, which ignores it
     * @param learningRate Adam learning rate
     * @return this tagger
     */
    public MentionTagger setModelParams(int ffwLayerSize, int lstmLayerSize, int iterations, double learningRate) {
        this.net = createBLSTM(this.inputVectorSize, ffwLayerSize, lstmLayerSize, this.outputVectorSize, iterations, learningRate);
        return this;
    }

    /** @return the tag scheme used for labels and predictions */
    public Class<? extends Tag> getTagset() {
        return this.tagset;
    }

    /**
     * Builds and initializes a BLSTM {@link ComputationGraph}.
     *
     * <p>Topology: {@code input -> [FF1 -> FF2 ->] BLSTM -> output}. The two ReLU dense layers are
     * only added when {@code ffwLayerSize > 0}. The bidirectional LSTM averages the forward and
     * backward outputs. The output layer is a softmax trained with multi-class cross entropy.
     * Training uses Adam with L2 regularization of 1e-4.
     *
     * @param inputVectorSize size of the recurrent input per time step
     * @param ffwLayerSize width of both dense layers, or {@code <= 0} to omit them
     * @param lstmLayerSize number of LSTM units
     * @param outputVectorSize size of the output layer
     * @param iterations currently ignored
     * @param learningRate Adam learning rate
     * @return the initialized network
     */
    public static ComputationGraph createBLSTM(long inputVectorSize, long ffwLayerSize, long lstmLayerSize,
            long outputVectorSize, int iterations, double learningRate) {
        log.info("initializing BLSTM network {}:{}:{}:{}:{}",
                inputVectorSize, ffwLayerSize, ffwLayerSize, lstmLayerSize, outputVectorSize);
        ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new Adam(learningRate, 0.9d, 0.999d, 1.0E-8d))
                .l2(1.0E-4d)
                .trainingWorkspaceMode(WorkspaceMode.ENABLED)
                .inferenceWorkspaceMode(WorkspaceMode.ENABLED)
                .graphBuilder()
                .addInputs("input");

        // The BLSTM reads either directly from the input or from the optional dense stack.
        String lstmInput;
        long lstmInputSize;
        if (ffwLayerSize > 0) {
            builder.addLayer("FF1", new DenseLayer.Builder().nIn(inputVectorSize).nOut(ffwLayerSize)
                            .activation(Activation.RELU).weightInit(WeightInit.RELU).build(), "input")
                    .addLayer("FF2", new DenseLayer.Builder().nIn(ffwLayerSize).nOut(ffwLayerSize)
                            .activation(Activation.RELU).weightInit(WeightInit.RELU).build(), "FF1");
            lstmInput = "FF2";
            lstmInputSize = ffwLayerSize;
        } else {
            lstmInput = "input";
            lstmInputSize = inputVectorSize;
        }

        builder.addLayer("BLSTM", new Bidirectional(Bidirectional.Mode.AVERAGE,
                        new LSTM.Builder().nIn(lstmInputSize).nOut(lstmLayerSize)
                                .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()), lstmInput)
                .addLayer("output", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(lstmLayerSize).nOut(outputVectorSize)
                        .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build(), "BLSTM")
                .setOutputs("output")
                .setInputTypes(InputType.recurrent(inputVectorSize))
                .backpropType(BackpropType.Standard);

        // Build the configuration once; the previous code built it twice and discarded the first result.
        ComputationGraph graph = new ComputationGraph(builder.build());
        graph.init();
        return graph;
    }

    /** @return the type string attached to predicted tags */
    public String getType() {
        return this.type;
    }

    /** @param type type string attached to predicted tags */
    public void setType(String type) {
        this.type = type;
    }

    /**
     * Sets the tag scheme and derives the output vector size from a fresh instance of it.
     * If the tag class cannot be instantiated, the error is logged and the output size stays unchanged.
     *
     * @param tagset tag class; needs an accessible no-argument constructor
     * @return this tagger
     */
    public MentionTagger setTagset(Class<? extends Tag> tagset) {
        this.tagset = tagset;
        try {
            this.outputVectorSize = tagset.getDeclaredConstructor().newInstance().getVectorSize();
        } catch (Exception e) {
            log.error("Could not set output vector size", e);
        }
        return this;
    }

    /**
     * Sets the tag scheme (see {@link #setTagset(Class)}) and the tag type.
     *
     * @return this tagger
     */
    public MentionTagger setTagset(Class<? extends Tag> tagset, String type) {
        setTagset(tagset);
        this.type = type;
        return this;
    }

    /**
     * @param batchSize examples per mini-batch
     * @param numEpochs number of passes over the training data
     * @param randomize value passed to the training iterator's randomization flag
     * @return this tagger
     */
    public MentionTagger setTrainingParams(int batchSize, int numEpochs, boolean randomize) {
        this.batchSize = batchSize;
        this.numEpochs = numEpochs;
        this.randomize = randomize;
        return this;
    }

    /**
     * @param workers number of parallel training workers; values above 1 enable {@link ParallelWrapper}
     * @return this tagger
     */
    public MentionTagger setWorkspaceParams(int workers) {
        this.workers = workers;
        return this;
    }

    /**
     * @return a new {@link EncoderSet} wrapping {@link #getEncoders()}
     * @deprecated kept for callers that still expect an {@link EncoderSet}
     */
    @JsonIgnore
    @Deprecated
    public EncoderSet getEncoderSet() {
        return new EncoderSet((Encoder[]) getEncoders().toArray(new Encoder[0]));
    }

    /** Trains on the GOLD annotations of {@code dataset}. */
    public void trainModel(Dataset dataset) {
        trainModel(dataset, Annotation.Source.GOLD);
    }

    /** Trains on the {@code source} annotations of {@code dataset}, using the configured randomize flag. */
    public void trainModel(Dataset dataset, Annotation.Source source) {
        trainModel(dataset, source, -1, this.randomize);
    }

    /**
     * Trains on the {@code source} annotations of {@code dataset}.
     *
     * @param numExamples passed to {@code MentionTaggerIterator}; the shorter overloads pass -1
     *     (NOTE(review): presumably "no limit" — confirm)
     * @param randomize randomization flag passed to the iterator
     */
    public void trainModel(Dataset dataset, Annotation.Source source, int numExamples, boolean randomize) {
        trainModel(new MentionTaggerIterator(dataset.getDocuments(), dataset.getName(), getEncoderSet(),
                this.tagset, source, numExamples, this.batchSize, randomize));
    }

    /**
     * Trains on a loose collection of sentences by wrapping them in a single {@link Snippet}.
     *
     * @param randomize passed both to the {@link Snippet} constructor and to the iterator
     */
    public void trainModel(Collection<Sentence> sentences, Annotation.Source source, boolean randomize) {
        trainModel(new MentionTaggerIterator(Lists.newArrayList(new Document[] {new Snippet(sentences, randomize)}),
                "training", getEncoderSet(), this.tagset, source, -1, this.batchSize, randomize));
    }

    /**
     * Trains the network for {@code numEpochs} epochs over {@code iterator}, resetting it after
     * each epoch, and then marks the model as available.
     *
     * <p>With more than one worker, training runs through a {@link ParallelWrapper} in averaging
     * mode. The wrapper's threads are shut down once training ends, even if training fails.
     */
    protected void trainModel(MentionTaggerIterator iterator) {
        appendTrainLog("Training " + getName() + " with " + iterator.numExamples() + " examples in "
                + (iterator.numExamples() / iterator.batch()) + " batches for " + this.numEpochs + " epochs.");
        ParallelWrapper wrapper = this.workers > 1
                ? new ParallelWrapper.Builder<>(this.net)
                        .prefetchBuffer(this.workers * 4)
                        .workers(this.workers)
                        .trainingMode(ParallelWrapper.TrainingMode.AVERAGING)
                        .workspaceMode(WorkspaceMode.ENABLED)
                        .build()
                : null;
        long examplesSeen = 0;
        this.timer.start();
        try {
            for (int epoch = 1; epoch <= this.numEpochs; epoch++) {
                this.timer.setSplit("epoch");
                if (wrapper != null) {
                    wrapper.fit(iterator);
                } else if (this.net instanceof ComputationGraph) {
                    ((ComputationGraph) this.net).fit(iterator);
                } else if (this.net instanceof MultiLayerNetwork) {
                    ((MultiLayerNetwork) this.net).fit(iterator);
                }
                examplesSeen += iterator.numExamples();
                appendTrainLog("Completed epoch " + epoch + " of " + this.numEpochs + "\t" + examplesSeen,
                        this.timer.getLong("epoch"));
                iterator.reset();
            }
        } finally {
            // Without this the wrapper's worker threads outlive the training run.
            if (wrapper != null) {
                wrapper.shutdown();
            }
        }
        this.timer.stop();
        appendTrainLog("Training complete", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Predicts tags for all tokens of {@code documents} and stores them under
     * {@link Annotation.Source#PRED}.
     */
    public synchronized void tag(Collection<Document> documents) {
        log.debug("Labeling Documents...");
        MentionTaggerIterator iterator = new MentionTaggerIterator(documents, "train", getEncoderSet(),
                this.tagset, -1, this.batchSize, false);
        iterator.reset();
        while (iterator.hasNext()) {
            Pair<?, ?> next = iterator.nextDataSet();
            INDArray predicted = predict((DataSet) next.getKey());
            createTags(sentencesOf(next), predicted, iterator.getTagset(), Annotation.Source.PRED, this.type, false, true);
        }
        markPredictedTagsAvailable(iterator);
    }

    /** Predicts tags for a loose collection of sentences by wrapping them in a single {@link Snippet}. */
    public void tagSentences(Collection<Sentence> sentences) {
        tag(Lists.newArrayList(new Document[] {new Snippet(sentences, false)}));
    }

    /**
     * Tags {@code dataset} and appends tag-level and mention-level evaluation statistics to the test log.
     *
     * <p>Documents without annotations from {@code source} first receive mention annotations derived
     * from their BIO2 tags of that source. Existing PRED mention annotations are then replaced by
     * ones derived from the predicted BIO2 tags.
     */
    public void testModel(Dataset dataset, Annotation.Source source) {
        test(new MentionTaggerIterator(dataset.getDocuments(), dataset.getName(), getEncoderSet(),
                this.tagset, -1, this.batchSize, false));
        MentionTaggerEval tagEval = new MentionTaggerEval(getName(), this.tagset);
        tagEval.calculateMeasures(dataset);
        appendTestLog(tagEval.printExperimentStats());
        appendTestLog(tagEval.printDatasetStats());
        appendTestLog(tagEval.printTrainingCurve());
        appendTestLog(tagEval.printSequenceClassStats(false));
        MentionAnnotatorEval annotatorEval = new MentionAnnotatorEval(getName());
        for (Document document : dataset.getDocuments()) {
            if (document.countAnnotations(source) == 0) {
                MentionAnnotation.annotateFromTags(document, source, (Class<? extends Tag>) BIO2Tag.class, this.type);
            }
            document.clearAnnotations(Annotation.Source.PRED, MentionAnnotation.class);
            MentionAnnotation.annotateFromTags(document, Annotation.Source.PRED, (Class<? extends Tag>) BIO2Tag.class, this.type);
        }
        annotatorEval.setTestDataset(dataset, 0L, 0L);
        annotatorEval.evaluateAnnotations();
        appendTestLog(annotatorEval.printAnnotationStats());
    }

    /**
     * Runs the network over {@code iterator}, storing PRED tags and collecting time-series
     * classification metrics against the labels.
     *
     * <p>Batches whose evaluation throws {@link IllegalStateException} are logged and skipped for
     * metrics, but still get tagged.
     *
     * @return the accumulated evaluation
     */
    public Evaluation test(MentionTaggerIterator iterator) {
        this.timer.start();
        appendTrainLog("Evaluating " + getName() + " with " + iterator.numExamples() + " examples...");
        Evaluation evaluation = new Evaluation(iterator.getLabelSize());
        iterator.reset();
        while (iterator.hasNext()) {
            Pair<?, ?> next = iterator.nextDataSet();
            DataSet batch = (DataSet) next.getKey();
            INDArray predicted = predict(batch);
            try {
                evaluation.evalTimeSeries(batch.getLabels(), predicted, batch.getLabelsMaskArray());
            } catch (IllegalStateException e) {
                log.warn(e.toString());
            }
            createTags(sentencesOf(next), predicted, iterator.getTagset(), Annotation.Source.PRED, this.type, true, true);
        }
        markPredictedTagsAvailable(iterator);
        this.timer.stop();
        appendTrainLog("Evaluation complete", this.timer.getLong());
        return evaluation;
    }

    /** Attaches a {@link StatsListener} backed by in-memory storage to the network and the DL4J UI server. */
    public void enableTrainingUI() {
        InMemoryStatsStorage storage = new InMemoryStatsStorage();
        this.net.addListeners(new TrainingListener[] {new StatsListener(storage, 1)});
        UIServer.getInstance().attach(storage);
    }

    /**
     * Writes per-token tags decoded from a batch of network predictions.
     *
     * <p>The {@code b}-th sentence is read from example {@code b} of {@code predictions}, and its
     * {@code t}-th token from time step {@code t} (via {@code EncodingHelpers.getTimeStep}). Each
     * token receives a {@link BIO2Tag} or {@link BIOESTag}, depending on {@code tagset}, stored
     * under {@code source}. Other tag classes produce no tags. For BIOES, each sentence is then
     * passed to {@link BIOESTag#correctCRF} and, if {@code convertToBIO2} is set, to
     * {@link BIOESTag#convertToBIO2}.
     *
     * @param sentences sentences of the batch, in the same order as the examples in {@code predictions}
     * @param predictions network output for the batch
     * @param tagset tag class to create ({@code BIO2Tag.class} or {@code BIOESTag.class})
     * @param source annotation source under which tags are stored
     * @param type type string passed to every created tag
     * @param evaluating has no effect; the compiled method branched on it without doing anything.
     *     {@link #test} passes {@code true} and {@link #tag} passes {@code false}
     * @param convertToBIO2 whether BIOES predictions are additionally converted to BIO2
     */
    public static void createTags(Iterable<Sentence> sentences, INDArray predictions, Class<?> tagset,
            Annotation.Source source, String type, boolean evaluating, boolean convertToBIO2) {
        int example = 0;
        for (Sentence sentence : sentences) {
            int timeStep = 0;
            for (Token token : sentence.getTokens()) {
                INDArray vector = EncodingHelpers.getTimeStep(predictions, example, timeStep++);
                if (tagset.equals(BIO2Tag.class)) {
                    token.putTag(source, new BIO2Tag(vector, type, true));
                }
                if (tagset.equals(BIOESTag.class)) {
                    token.putTag(source, new BIOESTag(vector, type, true));
                }
            }
            example++;
            if (tagset.equals(BIOESTag.class)) {
                BIOESTag.correctCRF(sentence, source);
                if (convertToBIO2) {
                    BIOESTag.convertToBIO2(sentence, source);
                }
            }
        }
    }

    /**
     * Runs a forward pass over one batch, applying its feature and label masks.
     *
     * @throws IllegalStateException if the network is neither a {@link MultiLayerNetwork} nor a
     *     {@link ComputationGraph}
     */
    private INDArray predict(DataSet batch) {
        INDArray features = batch.getFeatures();
        INDArray featuresMask = batch.getFeaturesMaskArray();
        INDArray labelsMask = batch.getLabelsMaskArray();
        if (this.net instanceof MultiLayerNetwork) {
            return ((MultiLayerNetwork) this.net).output(features, false, featuresMask, labelsMask);
        }
        if (this.net instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) this.net;
            graph.setLayerMaskArrays(new INDArray[] {featuresMask}, new INDArray[] {labelsMask});
            return graph.outputSingle(features);
        }
        throw new IllegalStateException("Unsupported network type: "
                + (this.net == null ? "null" : this.net.getClass().getName()));
    }

    /** Extracts the sentence list paired with a batch by {@code MentionTaggerIterator.nextDataSet()}. */
    @SuppressWarnings("unchecked")
    private static Iterable<Sentence> sentencesOf(Pair<?, ?> batch) {
        return (Iterable<Sentence>) batch.getValue();
    }

    /**
     * Marks the iterator's tag set as available for PRED on every document, and BIO2 as well
     * when the configured tag set is not BIO2.
     */
    private void markPredictedTagsAvailable(MentionTaggerIterator iterator) {
        for (Document document : iterator.getDocuments()) {
            document.setTagAvailable(Annotation.Source.PRED, iterator.getTagset(), true);
            if (!this.tagset.equals(BIO2Tag.class)) {
                document.setTagAvailable(Annotation.Source.PRED, BIO2Tag.class, true);
            }
        }
    }
}
