package de.datexis.cdv;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorComponent;
import de.datexis.cdv.index.AspectIndex;
import de.datexis.cdv.index.EntityIndex;
import de.datexis.cdv.index.QueryIndex;
import de.datexis.cdv.tagger.CDVModelBuilder;
import de.datexis.cdv.tagger.CDVSentenceIterator;
import de.datexis.cdv.tagger.CDVTagger;
import de.datexis.common.AnnotationHelpers;
import de.datexis.common.Timer;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.IEncoder;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.encoder.impl.FastTextEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.impl.PassageAnnotation;
import de.datexis.retrieval.index.IVocabulary;
import de.datexis.retrieval.index.InMemoryIndex;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.tagger.SectorEncoder;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.Tagger;
import java.util.AbstractMap;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/cdv/CDVAnnotator.class */
/**
 * Annotator that attaches vectors produced by a {@link CDVTagger} to documents.
 *
 * <p>Instances are normally assembled with {@link Builder}, which wires the input encoders,
 * the entity and/or aspect target indexes and the training parameters into the tagger, and
 * registers the encoders as components of this annotator.
 */
public class CDVAnnotator extends Annotator {

    protected static final Logger log = LoggerFactory.getLogger(CDVAnnotator.class);

    /** Times the encoding runs whose duration is appended to the tagger's test log. */
    protected Timer timer = new Timer();

    /**
     * Fluent builder that configures a {@link CDVTagger} and wraps it in a {@link CDVAnnotator}.
     *
     * <p>{@link #withInputEncoders} and at least one of {@link #withEntityEncoder},
     * {@link #withAspectEncoder} or {@link #withEntityAspectEncoders} must be called before
     * {@link #build()}.
     */
    public static class Builder {

        IEncoder inputEncoder;
        IEncoder flagEncoder;
        IEncoder entityEncoder;
        IEncoder aspectEncoder;

        protected ILossFunction lossFunc = LossFunctions.LossFunction.MCXENT.getILossFunction();
        protected Activation activation = Activation.SOFTMAX;

        private int examplesPerEpoch = -1;
        private int maxSentencesPerDoc = -1;
        private int maxWordsPerSentence = -1;
        // NOTE(review): only reported by printParams(); build() does not pass it to the model builder.
        private int lstmWordLayerSize = 128;
        // NOTE(review): 5121 stands out next to the other power-of-two defaults (possibly meant 512) — confirm.
        private int lstmSentenceLayerSize = 5121;
        private int embeddingLayerSize = 128;
        private double learningRate = 0.01d;
        private double weightDecay = 1.0E-4d;
        private double dropOut = 0.95d;
        private int batchSize = 16;
        private int numEpochs = 1;
        private boolean classBalancing = false;
        private boolean enabletrainingUI = false;

        CDVTagger tagger = new CDVTagger();
        CDVAnnotator ann = new CDVAnnotator(this.tagger);

        /**
         * No-op kept for source compatibility.
         *
         * @deprecated the flag is ignored.
         */
        @Deprecated
        public Builder withHierarchicalModel(boolean hierarchical) {
            return this;
        }

        /** Sets the class-balancing flag passed to the tagger's training parameters. */
        public Builder withClassBalancing(boolean classBalancing) {
            this.classBalancing = classBalancing;
            return this;
        }

        /** Sets the tagger id. */
        public Builder withId(String id) {
            this.tagger.setId(id);
            return this;
        }

        /** Records the dataset name and (lower-cased) language in the annotator's provenance. */
        public Builder withDataset(String datasetName, WordHelpers.Language language) {
            this.ann.getProvenance().setDataset(datasetName);
            // Locale.ROOT keeps the result stable regardless of the JVM default locale.
            this.ann.getProvenance().setLanguage(language.toString().toLowerCase(Locale.ROOT));
            return this;
        }

        /** Sets the output loss function and activation. */
        public Builder withLossFunction(LossFunctions.LossFunction lossFunction, Activation activation) {
            this.lossFunc = lossFunction.getILossFunction();
            this.activation = activation;
            return this;
        }

        /** Sets the output loss function and activation. */
        public Builder withLossFunction(ILossFunction lossFunction, Activation activation) {
            this.lossFunc = lossFunction;
            this.activation = activation;
            return this;
        }

        /** Sets the network layer sizes. */
        public Builder withModelParams(int lstmWordLayerSize, int lstmSentenceLayerSize, int embeddingLayerSize) {
            this.lstmWordLayerSize = lstmWordLayerSize;
            this.lstmSentenceLayerSize = lstmSentenceLayerSize;
            this.embeddingLayerSize = embeddingLayerSize;
            return this;
        }

        /** Sets the training hyper-parameters, including the number of examples per epoch. */
        public Builder withTrainingParams(double learningRate, double dropOut, double weightDecay,
                                          int examplesPerEpoch, int batchSize, int numEpochs) {
            this.examplesPerEpoch = examplesPerEpoch;
            return withTrainingParams(learningRate, dropOut, weightDecay, batchSize, numEpochs);
        }

        /** Sets the training hyper-parameters; the number of examples per epoch is left unchanged. */
        public Builder withTrainingParams(double learningRate, double dropOut, double weightDecay,
                                          int batchSize, int numEpochs) {
            this.learningRate = learningRate;
            this.dropOut = dropOut;
            this.weightDecay = weightDecay;
            this.batchSize = batchSize;
            this.numEpochs = numEpochs;
            return this;
        }

        /** Sets the limits on examples per epoch, sentences per document and words per sentence. */
        public Builder withDatasetLimit(int examplesPerEpoch, int maxSentencesPerDoc, int maxWordsPerSentence) {
            this.examplesPerEpoch = examplesPerEpoch;
            this.maxSentencesPerDoc = maxSentencesPerDoc;
            this.maxWordsPerSentence = maxWordsPerSentence;
            return this;
        }

        /**
         * Sets the input and flag encoders, records the feature description in the provenance,
         * and registers both encoders as annotator components.
         */
        public Builder withInputEncoders(String features, Encoder inputEncoder, Encoder flagEncoder) {
            this.inputEncoder = inputEncoder;
            this.flagEncoder = flagEncoder;
            this.tagger.setInputEncoders(inputEncoder, flagEncoder);
            this.ann.getProvenance().setFeatures(features);
            this.ann.addComponent(inputEncoder);
            this.ann.addComponent(flagEncoder);
            return this;
        }

        /** Uses {@code entityIndex} as the only target encoder (provenance name "CDV-E"). */
        public Builder withEntityEncoder(QueryIndex entityIndex) {
            this.entityEncoder = entityIndex;
            this.tagger.setEntityEncoder(entityIndex);
            this.ann.addComponent(entityIndex);
            this.ann.getProvenance().setName("CDV-E");
            return this;
        }

        /** Uses {@code aspectIndex} as the only target encoder (provenance name "CDV-A"). */
        public Builder withAspectEncoder(QueryIndex aspectIndex) {
            this.aspectEncoder = aspectIndex;
            this.tagger.setAspectEncoder(aspectIndex);
            this.ann.addComponent(aspectIndex);
            this.ann.getProvenance().setName("CDV-A");
            return this;
        }

        /** Uses both an entity and an aspect target encoder (provenance name "CDV-EA"). */
        public Builder withEntityAspectEncoders(QueryIndex entityIndex, QueryIndex aspectIndex) {
            this.entityEncoder = entityIndex;
            this.aspectEncoder = aspectIndex;
            this.tagger.setEntityEncoder(entityIndex);
            this.tagger.setAspectEncoder(aspectIndex);
            this.ann.addComponent(entityIndex);
            this.ann.addComponent(aspectIndex);
            this.ann.getProvenance().setName("CDV-EA");
            return this;
        }

        /** Enables or disables the tagger's training UI when {@link #build()} runs. */
        public Builder enableTrainingUI(boolean enabled) {
            this.enabletrainingUI = enabled;
            return this;
        }

        /**
         * Initializes the tagger network and training configuration, and returns the annotator.
         *
         * <p>A multi-task network is built when both entity and aspect encoders are set. Otherwise
         * a single-task network is built for whichever target encoder is set.
         *
         * @return the configured annotator
         * @throws IllegalStateException if the input encoders or all target encoders are missing
         */
        public CDVAnnotator build() {
            if (this.inputEncoder == null || this.flagEncoder == null) {
                throw new IllegalStateException(
                        "Input encoders are not set; call withInputEncoders() before build()");
            }
            if (this.entityEncoder == null && this.aspectEncoder == null) {
                throw new IllegalStateException("No target encoder is set; call withEntityEncoder(), "
                        + "withAspectEncoder() or withEntityAspectEncoders() before build()");
            }
            this.tagger.setIteratorClass(CDVSentenceIterator.class);
            if (this.entityEncoder != null && this.aspectEncoder != null) {
                this.tagger.initializeNetwork(CDVModelBuilder.buildMultiTaskCDV(
                        this.inputEncoder.getEmbeddingVectorSize(),
                        this.flagEncoder.getEmbeddingVectorSize(),
                        this.lstmSentenceLayerSize,
                        this.embeddingLayerSize,
                        this.entityEncoder.getEmbeddingVectorSize(),
                        this.aspectEncoder.getEmbeddingVectorSize(),
                        this.learningRate, this.dropOut, this.weightDecay,
                        this.lossFunc, this.activation));
            } else {
                IEncoder targetEncoder = this.entityEncoder != null ? this.entityEncoder : this.aspectEncoder;
                this.tagger.initializeNetwork(CDVModelBuilder.buildSingleTaskCDV(
                        this.inputEncoder.getEmbeddingVectorSize(),
                        this.flagEncoder.getEmbeddingVectorSize(),
                        this.lstmSentenceLayerSize,
                        this.embeddingLayerSize,
                        targetEncoder.getEmbeddingVectorSize(),
                        this.learningRate, this.dropOut, this.weightDecay,
                        this.lossFunc, this.activation));
            }
            if (this.enabletrainingUI) {
                this.tagger.enableTrainingUI();
            }
            this.tagger.setTrainingParams(this.examplesPerEpoch, this.maxSentencesPerDoc, this.batchSize,
                    this.numEpochs, true, this.classBalancing);
            this.tagger.setTrainingLimits(this.examplesPerEpoch, this.maxSentencesPerDoc, this.maxWordsPerSentence);
            this.tagger.setEmbeddingLayerSize(this.embeddingLayerSize);
            this.tagger.setName(this.ann.getProvenance().toString());
            // printParams() reads the tagger name, so it must run after setName().
            this.tagger.appendTrainLog(printParams());
            addIndexEncoderComponents(this.entityEncoder);
            addIndexEncoderComponents(this.aspectEncoder);
            return this.ann;
        }

        /**
         * If {@code index} is an {@link InMemoryIndex} wrapping an {@link Encoder}, registers that
         * encoder and every encoder returned by its {@code getEncoders()} as annotator components.
         * Does nothing for {@code null} or other index types.
         */
        private void addIndexEncoderComponents(IEncoder index) {
            if (!(index instanceof InMemoryIndex) || !(index.getEncoder() instanceof Encoder)) {
                return;
            }
            this.ann.addComponent(index.getEncoder());
            Iterator<?> nested = index.getEncoder().getEncoders().iterator();
            while (nested.hasNext()) {
                this.ann.addComponent((Encoder) nested.next());
            }
        }

        /** Renders the current configuration as a tab-separated report for the training log. */
        private String printParams() {
            StringBuilder sb = new StringBuilder();
            sb.append("TRAINING PARAMS: ").append(this.tagger.getName()).append("\n");
            sb.append("\nDataset:\n");
            sb.append("File").append("\t").append(this.ann.getProvenance().getDataset()).append("\n");
            sb.append("Language").append("\t").append(this.ann.getProvenance().getLanguage()).append("\n");
            sb.append("\nInput Encoders:\n");
            sb.append("input").append("\t").append(this.inputEncoder.getClass().getSimpleName())
                    .append("\t").append(this.inputEncoder.getEmbeddingVectorSize()).append("\n");
            sb.append("flag").append("\t").append(this.flagEncoder.getClass().getSimpleName())
                    .append("\t").append(this.flagEncoder.getEmbeddingVectorSize()).append("\n");
            sb.append("\nNetwork Params:\n");
            sb.append("LSTM").append("\t").append(this.lstmWordLayerSize).append("\n");
            sb.append("BLSTM").append("\t").append(this.lstmSentenceLayerSize).append("\n");
            sb.append("EMB").append("\t").append(this.embeddingLayerSize).append("\n");
            sb.append("\nTraining Params:\n");
            sb.append("examples per epoch").append("\t").append(this.examplesPerEpoch).append("\n");
            sb.append("max time series length").append("\t").append(this.maxSentencesPerDoc).append("\n");
            sb.append("epochs").append("\t").append(this.numEpochs).append("\n");
            sb.append("batch size").append("\t").append(this.batchSize).append("\n");
            sb.append("learning rate").append("\t").append(this.learningRate).append("\n");
            sb.append("dropout").append("\t").append(this.dropOut).append("\n");
            sb.append("weight decay").append("\t").append(this.weightDecay).append("\n");
            sb.append("loss").append("\t").append(this.lossFunc.toString()).append("\n");
            sb.append("\n");
            return sb.toString();
        }
    }

    public CDVAnnotator() {
    }

    public CDVAnnotator(Tagger tagger) {
        super(tagger);
    }

    protected CDVAnnotator(AnnotatorComponent annotatorComponent) {
        super(annotatorComponent);
    }

    /**
     * Returns the underlying tagger cast to {@link CDVTagger}.
     * (The decompiler renamed this from {@code getTagger}.)
     */
    public CDVTagger m1getTagger() {
        return (CDVTagger) super.getTagger();
    }

    /** Returns the tagger's entity target encoder (may be {@code null}). */
    @JsonIgnore
    public IEncoder getEntityEncoder() {
        return m1getTagger().getEntityEncoder();
    }

    /** Returns the tagger's aspect target encoder (may be {@code null}). */
    @JsonIgnore
    public IEncoder getAspectEncoder() {
        return m1getTagger().getAspectEncoder();
    }

    /** Annotates the documents at sentence level; equivalent to {@link #annotateSentences}. */
    public void annotate(Collection<Document> collection) {
        annotateSentences(collection);
    }

    /** Attaches CDV sentence vectors to every sentence of the given documents. */
    public void annotateSentences(Collection<Document> collection) {
        runEncoding("Running CDV neural net encoding...",
                () -> m1getTagger().attachCDVSentenceVectors(collection, AbstractMultiDataSetIterator.Stage.ENCODE));
    }

    /** Attaches a CDV document matrix to each of the given documents. */
    public void annotateDocuments(Collection<Document> collection) {
        runEncoding("Running CDV neural net encoding...",
                () -> m1getTagger().attachCDVDocumentMatrix(collection));
    }

    /** Attaches the tagger's baseline matrix to each of the given documents. */
    public void annotateDocumentsBaseline(Collection<Document> collection) {
        runEncoding("Running CDV baseline encoding...",
                () -> m1getTagger().attachMatrixBaseline(collection));
    }

    /**
     * Logs {@code message}, runs {@code encoding} under the timer and appends the elapsed time
     * to the tagger's test log. The timer is stopped even if the encoding throws.
     */
    private void runEncoding(String message, Runnable encoding) {
        log.info(message);
        this.timer.start();
        try {
            encoding.run();
        } finally {
            this.timer.stop();
        }
        m1getTagger().appendTestLog("Encoding complete", this.timer.getLong());
    }

    /**
     * Prints nearest-neighbour predictions for every section annotation to stdout.
     *
     * @deprecated debugging helper that writes to {@code System.out}.
     * @throws IllegalArgumentException if the entity encoder is not a {@link FastTextEncoder}
     *                                  and {@code aspectIndex} is not exactly an {@link AspectIndex}
     */
    @Deprecated
    public void printPredictions(Dataset dataset, AspectIndex aspectIndex, SectionAnnotation.Field field) {
        IEncoder entityEncoder = getEntityEncoder();
        if (entityEncoder.getClass() == FastTextEncoder.class) {
            Iterator it = dataset.getDocuments().iterator();
            while (it.hasNext()) {
                for (Map.Entry entry : AnnotationHelpers.getSpanAnnotationsMap((Document) it.next(), Sentence.class, SectionAnnotation.class)) {
                    System.out.println(((SectionAnnotation) entry.getValue()).getSectionHeading());
                    System.out.println(entityEncoder.getNearestNeighbours(encodeAnnotation(entityEncoder, (PassageAnnotation) entry.getValue()), 3).toString()
                            + " -> " + entityEncoder.getNearestNeighbours(((Sentence) entry.getKey()).getVector(CDVTagger.class), 3).toString());
                }
            }
            return;
        }
        if (aspectIndex.getClass() != AspectIndex.class) {
            // Report the class that was actually checked.
            throw new IllegalArgumentException("Target encoder has no evaluation: " + aspectIndex.getClass().toString());
        }
        for (Document document : dataset.getDocuments()) {
            System.out.println();
            System.out.println(document.getId());
            for (Map.Entry entry : AnnotationHelpers.getSpanAnnotationsMap(document, Sentence.class, SectionAnnotation.class)) {
                System.out.println(aspectIndex.getKeyPreprocessor().preProcess(((SectionAnnotation) entry.getValue()).getAnnotation(field))
                        + "\t -> " + aspectIndex.find(((Sentence) entry.getKey()).getVector(getAspectEncoder().getClass()), 3).toString());
            }
        }
    }

    /**
     * Stores encoder vectors on the document's section annotations.
     *
     * <p>GOLD annotations get the encoding of their label ({@link ClassEncoder}) or heading
     * ({@link HeadingEncoder}). PRED annotations get the mean of their sentences' vectors, plus
     * the nearest label (or the two nearest headings joined with "/") and the max confidence.
     * Encoders of any other class leave the annotations untouched.
     *
     * @deprecated kept for older evaluation code.
     */
    @Deprecated
    protected static void attachVectorsToAnnotations(Document document, LookupCacheEncoder lookupCacheEncoder) {
        for (SectionAnnotation gold : document.getAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)) {
            if (lookupCacheEncoder.getClass() == ClassEncoder.class) {
                gold.putVector(ClassEncoder.class, lookupCacheEncoder.encode(gold.getSectionLabel()));
            } else if (lookupCacheEncoder.getClass() == HeadingEncoder.class) {
                gold.putVector(HeadingEncoder.class, lookupCacheEncoder.encode(gold.getSectionHeading()));
            }
        }
        for (SectionAnnotation pred : document.getAnnotations(Annotation.Source.PRED, SectionAnnotation.class)) {
            int count = 0;
            INDArray mean = Nd4j.zeros(new long[]{lookupCacheEncoder.getEmbeddingVectorSize(), 1});
            Iterator it = ((List) document.streamSentencesInRange(pred.getBegin(), pred.getEnd(), false).collect(Collectors.toList())).iterator();
            while (it.hasNext()) {
                mean.addi(((Sentence) it.next()).getVector(lookupCacheEncoder.getClass()));
                count++;
            }
            if (count > 1) {
                mean.divi(count);
            }
            if (lookupCacheEncoder.getClass() == ClassEncoder.class) {
                pred.putVector(ClassEncoder.class, mean);
                pred.setSectionLabel(lookupCacheEncoder.getNearestNeighbour(mean));
                pred.setConfidence(lookupCacheEncoder.getMaxConfidence(mean));
            } else if (lookupCacheEncoder.getClass() == HeadingEncoder.class) {
                pred.putVector(HeadingEncoder.class, mean);
                pred.setSectionHeading(StringUtils.join(lookupCacheEncoder.getNearestNeighbours(mean, 2), "/"));
                pred.setConfidence(lookupCacheEncoder.getMaxConfidence(mean));
            }
        }
    }

    /**
     * Encodes the label of {@code passageAnnotation} with {@code iEncoder}.
     * For {@link FastTextEncoder} the literal separator " | " is first replaced by a single space.
     */
    public static INDArray encodeAnnotation(IEncoder iEncoder, PassageAnnotation passageAnnotation) {
        String label = passageAnnotation.getLabel();
        // Literal replace: replaceAll(" | ", " ") would treat the pattern as the regex "space OR space" and do nothing.
        return iEncoder instanceof FastTextEncoder ? iEncoder.encode(label.replace(" | ", " ")) : iEncoder.encode(label);
    }

    /**
     * Averages the vectors of all annotation labels. An {@link IVocabulary} is queried with
     * {@code lookup}, any other encoder with {@code encode}. Missing or all-zero vectors are
     * skipped with a warning.
     *
     * @return an {@code [size, 1]} FLOAT vector; all zeros when nothing could be encoded
     * @deprecated prefer {@link #lookupAnnotations}.
     */
    @Deprecated
    public static INDArray encodeAnnotations(IEncoder iEncoder, Collection<PassageAnnotation> collection) {
        INDArray sum = Nd4j.zeros(DataType.FLOAT, new long[]{iEncoder.getEmbeddingVectorSize(), 1});
        int count = 0;
        for (PassageAnnotation passageAnnotation : collection) {
            INDArray vector = iEncoder instanceof IVocabulary
                    ? ((IVocabulary) iEncoder).lookup(passageAnnotation.getLabel())
                    : iEncoder.encode(passageAnnotation.getLabel());
            if (isMissing(vector)) {
                log.warn("could not encode/lookup '{}'", passageAnnotation.getLabel());
            } else {
                sum.addi(vector);
                count++;
            }
        }
        return count > 1 ? sum.divi(count) : sum;
    }

    /**
     * Combines the index vectors of all annotation labels into one unit vector.
     *
     * <p>The label "Abstract" is looked up as "description". For an {@link AspectIndex}, labels
     * missing from the index are encoded with {@code encode}. When {@code weighted} is true:
     * <ul>
     *   <li>for an {@link EntityIndex}, annotations starting at offset 0 get weight 1.0 and all
     *       others 0.1;</li>
     *   <li>for an {@link AspectIndex}, each annotation gets weight
     *       {@code 1 + weightFactor(label)}.</li>
     * </ul>
     * In both cases the returned confidence is the maximum weight (entity) or the maximum
     * weight factor (aspect). When {@code weighted} is false, all weights and the confidence
     * are 1.0.
     *
     * @return (unit vector, confidence), or {@code (null, 0.0)} if no label produced a non-zero vector
     */
    public static Map.Entry<INDArray, Double> lookupAnnotations(QueryIndex queryIndex, Collection<? extends PassageAnnotation> collection, boolean weighted) {
        INDArray sum = Nd4j.zeros(DataType.FLOAT, new long[]{queryIndex.getEmbeddingVectorSize(), 1});
        int count = 0;
        double confidence = weighted ? 0.0d : 1.0d;
        for (PassageAnnotation passageAnnotation : collection) {
            double weight = 1.0d;
            String label = passageAnnotation.getLabel();
            if ("Abstract".equals(label)) {
                label = "description";
            }
            INDArray vector = queryIndex.lookup(label);
            if (queryIndex instanceof EntityIndex) {
                if (vector == null) {
                    log.trace("missing encoding for entity '{}'", label);
                }
                if (weighted) {
                    weight = passageAnnotation.getBegin() == 0 ? 1.0d : 0.1d;
                    confidence = Math.max(confidence, weight);
                }
            } else if (queryIndex instanceof AspectIndex) {
                if (vector == null) {
                    vector = queryIndex.encode(label);
                }
                if (weighted) {
                    double weightFactor = queryIndex.weightFactor(label);
                    confidence = Math.max(confidence, weightFactor);
                    weight = 1.0d + weightFactor;
                }
            }
            if (!isMissing(vector)) {
                // mul() rather than muli(): the looked-up vector may be shared with the index,
                // and scaling it in place would corrupt later lookups.
                sum.addi(weight == 1.0d ? vector : vector.mul(weight));
                count++;
            }
        }
        if (count == 0) {
            return new AbstractMap.SimpleEntry<>(null, 0.0d);
        }
        return new AbstractMap.SimpleEntry<>(Transforms.unitVec(count > 1 ? sum.divi(count) : sum), confidence);
    }

    /** Returns true when {@code vector} is {@code null} or all of its elements are zero. */
    private static boolean isMissing(INDArray vector) {
        // Absolute maximum: a plain max() of 0 does not imply a zero vector when entries are negative.
        return vector == null || vector.amaxNumber().doubleValue() == 0.0d;
    }

    /**
     * Stacks the per-sentence vectors stored under {@code str} into a matrix with one row per
     * sentence. The width is taken from the first sentence's vector.
     *
     * @throws IllegalArgumentException if the document has no sentences
     */
    protected static INDArray getLayerMatrix(Document document, String str) {
        if (document.countSentences() == 0) {
            throw new IllegalArgumentException("Document " + document.getId() + " has no sentences");
        }
        long width = document.getSentence(0).getVector(str).length();
        INDArray matrix = Nd4j.zeros(new long[]{document.countSentences(), width});
        int row = 0;
        Iterator<?> sentences = document.getSentences().iterator();
        while (sentences.hasNext()) {
            matrix.getRow(row++).assign(((Sentence) sentences.next()).getVector(str));
        }
        return matrix;
    }

    /** Same as {@link #getLayerMatrix(Document, String)}, keyed by the class's canonical name. */
    protected static INDArray getLayerMatrix(Document document, Class<?> cls) {
        return getLayerMatrix(document, cls.getCanonicalName());
    }

    /** Returns the matrix of sentence vectors stored under {@link SectorEncoder}. */
    protected static INDArray getEmbeddingMatrix(Document document) {
        return getLayerMatrix(document, SectorEncoder.class);
    }
}
