package de.datexis.sector.tagger;

import com.google.common.collect.Lists;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncoderSet;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:de/datexis/sector/tagger/SectorTaggerIterator.class */
/**
 * Iterates over documents in batches and builds the {@link MultiDataSet}s consumed by a {@link SectorTagger},
 * using sentences as time steps.
 *
 * <p>Each generated example has three inputs and two outputs:
 * <ul>
 *   <li>Inputs: the encodings produced by the tagger's {@code bagEncoder}, {@code embEncoder} and
 *       {@code flagEncoder}.</li>
 *   <li>Outputs: the section target encoding, given twice.</li>
 * </ul>
 * All inputs and outputs share one sentence mask.
 */
public class SectorTaggerIterator extends DocumentSentenceIterator {
    /** Encoders whose outputs form the network inputs (bag, embedding, flag). */
    protected EncoderSet inputEncoders;
    /** Encoder whose output forms the network targets. */
    protected EncoderSet targetEncoders;
    /** Tagger providing the encoders used to build inputs and targets. */
    protected SectorTagger tagger;
    /** If true, heading targets are encoded with {@code HeadingEncoder.encodeSubsampled}. */
    protected boolean requireSubsampling;

    /**
     * Creates an iterator over all documents of a dataset.
     *
     * @param stage              iteration stage; gold targets are only encoded for TRAIN and TEST
     * @param dataset            dataset whose documents are iterated
     * @param sectorTagger       tagger whose encoders produce the inputs and targets
     * @param batchSize          forwarded to {@link DocumentSentenceIterator}; presumably the batch size
     * @param randomize          forwarded to {@link DocumentSentenceIterator}; presumably enables shuffling
     * @param requireSubsampling if true, heading targets are encoded with {@code encodeSubsampled}
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Dataset dataset, SectorTagger sectorTagger, int batchSize, boolean randomize, boolean requireSubsampling) {
        this(stage, (Collection<Document>) dataset.getDocuments(), sectorTagger, batchSize, randomize, requireSubsampling);
    }

    /**
     * Creates an iterator over the given documents, passing -1 as {@code numExamples} to the full constructor.
     *
     * @param stage              iteration stage; gold targets are only encoded for TRAIN and TEST
     * @param docs               documents to iterate over
     * @param sectorTagger       tagger whose encoders produce the inputs and targets
     * @param batchSize          forwarded to {@link DocumentSentenceIterator}; presumably the batch size
     * @param randomize          forwarded to {@link DocumentSentenceIterator}; presumably enables shuffling
     * @param requireSubsampling if true, heading targets are encoded with {@code encodeSubsampled}
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger sectorTagger, int batchSize, boolean randomize, boolean requireSubsampling) {
        this(stage, docs, sectorTagger, -1, batchSize, randomize, requireSubsampling);
    }

    /**
     * Creates an iterator over the given documents, passing -1 as {@code maxTimeSeriesLength} to the full
     * constructor.
     *
     * @param stage              iteration stage; gold targets are only encoded for TRAIN and TEST
     * @param docs               documents to iterate over
     * @param sectorTagger       tagger whose encoders produce the inputs and targets
     * @param numExamples        forwarded to {@link DocumentSentenceIterator}; presumably a cap on examples
     * @param batchSize          forwarded to {@link DocumentSentenceIterator}; presumably the batch size
     * @param randomize          forwarded to {@link DocumentSentenceIterator}; presumably enables shuffling
     * @param requireSubsampling if true, heading targets are encoded with {@code encodeSubsampled}
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger sectorTagger, int numExamples, int batchSize, boolean randomize, boolean requireSubsampling) {
        this(stage, docs, sectorTagger, numExamples, -1, batchSize, randomize, requireSubsampling);
    }

    /**
     * Creates an iterator over the given documents.
     *
     * @param stage               iteration stage; gold targets are only encoded for TRAIN and TEST
     * @param docs                documents to iterate over
     * @param sectorTagger        tagger whose encoders produce the inputs and targets
     * @param numExamples         forwarded to {@link DocumentSentenceIterator}; presumably a cap on the number
     *                            of examples (-1 from the shorter constructors) — confirm in base class
     * @param maxTimeSeriesLength forwarded to {@link DocumentSentenceIterator}; presumably the maximum number of
     *                            time steps (-1 from the shorter constructors) — confirm in base class
     * @param batchSize           forwarded to {@link DocumentSentenceIterator}; presumably the batch size
     * @param randomize           forwarded to {@link DocumentSentenceIterator}; presumably enables shuffling
     * @param requireSubsampling  if true, heading targets are encoded with {@code encodeSubsampled}
     */
    public SectorTaggerIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> docs, SectorTagger sectorTagger, int numExamples, int maxTimeSeriesLength, int batchSize, boolean randomize, boolean requireSubsampling) {
        super(stage, docs, numExamples, maxTimeSeriesLength, batchSize, randomize);
        this.tagger = sectorTagger;
        this.inputEncoders = new EncoderSet(new Encoder[]{sectorTagger.bagEncoder, sectorTagger.embEncoder, sectorTagger.flagEncoder});
        this.targetEncoders = new EncoderSet(new Encoder[]{sectorTagger.targetEncoder});
        this.requireSubsampling = requireSubsampling;
        // Re-initialise iteration state once the encoders are in place (note: reset() is overridable).
        reset();
    }

    /**
     * Reports that asynchronous prefetching is supported by this iterator.
     *
     * @return always {@code true}
     */
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Builds one {@link MultiDataSet} from a batch of documents, using sentences as time steps.
     *
     * <p>For stages other than TRAIN and TEST, no gold annotations are read and the targets are all-zero.
     *
     * @param documentBatch batch of documents to encode
     * @return a dataset with three feature arrays (bag, embedding, flag encodings) and two identical label
     *         arrays. One sentence mask is shared by all features and labels.
     */
    public MultiDataSet generateDataSet(DocumentSentenceIterator.DocumentBatch documentBatch) {
        final List<Document> docs = documentBatch.docs;
        final int maxLength = documentBatch.maxDocLength;

        INDArray mask = createMask(docs, maxLength, Sentence.class);
        INDArray bagFeatures = tagger.bagEncoder.encodeMatrix(docs, maxLength, Sentence.class);
        INDArray embFeatures = tagger.embEncoder.encodeMatrix(docs, maxLength, Sentence.class);
        INDArray flagFeatures = tagger.flagEncoder.encodeMatrix(docs, maxLength, Sentence.class);

        boolean hasGoldTargets = stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN)
                || stage.equals(AbstractMultiDataSetIterator.Stage.TEST);
        INDArray target = hasGoldTargets
                ? encodeTarget(docs, maxLength, Sentence.class)
                : Nd4j.zeros(new long[]{documentBatch.size, tagger.targetEncoder.getEmbeddingVectorSize(), maxLength});

        // The same target is supplied for both outputs — presumably the network has two output layers; verify
        // against the SectorTagger graph.
        return new org.nd4j.linalg.dataset.MultiDataSet(
                new INDArray[]{bagFeatures, embFeatures, flagFeatures},
                new INDArray[]{target, target},
                new INDArray[]{mask, mask, mask},
                new INDArray[]{mask, mask});
    }

    /**
     * Creates a FLOAT mask of shape {@code [docs.size(), maxTimeSteps]}.
     *
     * <p>Entry {@code (d, t)} is 1 when document {@code d} has more than {@code t} spans of the requested type.
     * All other entries are 0.
     *
     * @param docs          documents in the batch
     * @param maxTimeSteps  number of time-step columns
     * @param timeStepClass {@code Token.class} or {@code Sentence.class}; any other class yields an all-zero mask
     * @return the mask array
     */
    public INDArray createMask(List<Document> docs, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{docs.size(), maxTimeSteps});
        for (int batchIndex = 0; batchIndex < docs.size(); batchIndex++) {
            int steps = Math.min(countTimeSteps(docs.get(batchIndex), timeStepClass), maxTimeSteps);
            for (int t = 0; t < steps; t++) {
                mask.putScalar(new int[]{batchIndex, t}, 1);
            }
        }
        return mask;
    }

    /**
     * Encodes the GOLD section annotations of each document as per-time-step targets.
     *
     * <p>The result has shape {@code [docs.size(), targetVectorSize, maxTimeSteps]}. Each span (token or
     * sentence) receives the encoding of the first section, in sorted order, whose end lies after the span's
     * begin; the last section is used for any trailing spans. A document without GOLD section annotations keeps
     * an all-zero target, and the remaining documents of the batch are still encoded.
     *
     * @param docs          documents in the batch
     * @param maxTimeSteps  number of time steps to encode per document
     * @param timeStepClass {@code Token.class} or {@code Sentence.class}; any other class yields zero targets
     * @return the target array
     */
    public INDArray encodeTarget(List<Document> docs, int maxTimeSteps, Class<? extends Span> timeStepClass) {
        INDArray encoding = Nd4j.zeros(new long[]{docs.size(), tagger.targetEncoder.getEmbeddingVectorSize(), maxTimeSteps});
        for (int batchIndex = 0; batchIndex < docs.size(); batchIndex++) {
            Document doc = docs.get(batchIndex);
            List<? extends Span> spans = timeStepSpans(doc, timeStepClass);
            Iterator<SectionAnnotation> sections = doc.streamAnnotations(Annotation.Source.GOLD, SectionAnnotation.class)
                    .sorted()
                    .collect(Collectors.toList())
                    .iterator();
            if (!sections.hasNext()) {
                // Previously this returned, silently leaving every later document in the batch unencoded.
                continue;
            }
            SectionAnnotation section = sections.next();
            INDArray sectionVector = encodeAnnotation(tagger.targetEncoder, section);
            int steps = Math.min(spans.size(), maxTimeSteps);
            for (int t = 0; t < steps; t++) {
                // Skip every section that ends before this span begins; several short sections may lie between
                // two consecutive spans.
                boolean advanced = false;
                while (spans.get(t).getBegin() >= section.getEnd() && sections.hasNext()) {
                    section = sections.next();
                    advanced = true;
                }
                if (advanced) {
                    // Encode once per change so subsampled encodings are drawn once per section reached.
                    sectionVector = encodeAnnotation(tagger.targetEncoder, section);
                }
                EncodingHelpers.putTimeStep(encoding, batchIndex, t, sectionVector.dup());
            }
        }
        return encoding;
    }

    /**
     * Encodes a single section annotation with the given target encoder.
     *
     * @param encoder           the target encoder; {@link HeadingEncoder} and {@link ClassEncoder} are supported
     * @param sectionAnnotation the section to encode
     * @return the encoding of the section's heading or label. For other encoder types this is the fallback
     *         {@code Nd4j.create(1)}.
     */
    protected INDArray encodeAnnotation(Encoder encoder, SectionAnnotation sectionAnnotation) {
        if (encoder instanceof HeadingEncoder) {
            HeadingEncoder headingEncoder = (HeadingEncoder) encoder;
            return requireSubsampling
                    ? headingEncoder.encodeSubsampled(sectionAnnotation.getSectionHeading())
                    : headingEncoder.encode(sectionAnnotation.getSectionHeading());
        }
        if (encoder instanceof ClassEncoder) {
            return ((ClassEncoder) encoder).encode(sectionAnnotation.getSectionLabel());
        }
        // NOTE(review): unsupported encoder types fall back to a 1-element array, which likely does not match
        // the target vector size — kept for compatibility.
        return Nd4j.create(1);
    }

    /** Counts the spans of the requested time-step type in a document (0 for unsupported types). */
    private static int countTimeSteps(Document doc, Class<? extends Span> timeStepClass) {
        if (timeStepClass == Token.class) {
            return doc.countTokens();
        }
        if (timeStepClass == Sentence.class) {
            return doc.countSentences();
        }
        return 0;
    }

    /** Returns a copy of the document's spans of the requested time-step type (empty for unsupported types). */
    private static List<? extends Span> timeStepSpans(Document doc, Class<? extends Span> timeStepClass) {
        if (timeStepClass == Token.class) {
            return Lists.newArrayList(doc.getTokens());
        }
        if (timeStepClass == Sentence.class) {
            return Lists.newArrayList(doc.getSentences());
        }
        return Collections.emptyList();
    }
}
