package de.datexis.cdv.tagger;

import de.datexis.cdv.CDVAnnotator;
import de.datexis.cdv.index.AspectIndex;
import de.datexis.cdv.index.EntityIndex;
import de.datexis.cdv.index.QueryIndex;
import de.datexis.cdv.model.AspectAnnotation;
import de.datexis.cdv.model.EntityAnnotation;
import de.datexis.common.AnnotationHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.model.impl.PassageAnnotation;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:de/datexis/cdv/tagger/CDVSentenceIterator.class */
/**
 * Document iterator that builds sentence-level {@link MultiDataSet}s for a {@link CDVTagger}.
 *
 * <p>Each batch carries two feature arrays (from the tagger's {@code inputEncoder} and
 * {@code flagEncoder}) and one or two label arrays (entity and/or aspect targets), together with
 * the matching feature and label masks. Time steps are sentences.
 */
public class CDVSentenceIterator extends DocumentSentenceIterator {

    /** Tagger whose encoders are used to build the features and the targets. */
    protected CDVTagger tagger;

    /**
     * Forwarded to {@code CDVAnnotator.lookupAnnotations}; when {@code true}, lookup weights below
     * 1.0 are written into the label mask by {@link #encodeTarget}.
     */
    protected boolean balancing;

    /**
     * Creates the iterator and resets it.
     *
     * @param stage      stage of the iteration; TRAIN and TEST produce real targets, any other
     *                   stage produces empty target matrices (see {@link #generateDataSet})
     * @param collection documents to iterate over
     * @param cDVTagger  tagger providing the encoders
     * @param i          forwarded unchanged to the superclass constructor
     *                   (NOTE(review): presumably the number of examples - confirm)
     * @param i2         forwarded unchanged to the superclass constructor
     *                   (NOTE(review): presumably the maximum time series length - confirm)
     * @param i3         forwarded unchanged to the superclass constructor
     *                   (NOTE(review): presumably the batch size - confirm)
     * @param z          forwarded unchanged to the superclass constructor
     *                   (NOTE(review): presumably a randomize flag - confirm)
     * @param z2         value for {@link #balancing}
     */
    public CDVSentenceIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> collection, CDVTagger cDVTagger, int i, int i2, int i3, boolean z, boolean z2) {
        super(stage, collection, i, i2, i3, z);
        this.tagger = cDVTagger;
        this.balancing = z2;
        reset();
    }

    /**
     * Builds the data set for one batch of documents.
     *
     * <p>In the TRAIN and TEST stages the labels are encoded from the documents' annotations, but
     * only for an entity encoder that is an {@link EntityIndex} and an aspect encoder that is an
     * {@link AspectIndex}. In every other stage the labels are freshly created time step matrices
     * for each non-null encoder.
     *
     * <p>If both targets are available the data set has two labels (entity, aspect); otherwise it
     * has a single label, which is the entity target if present and the aspect target otherwise.
     * That single label is {@code null} when neither target could be built.
     *
     * @param documentBatch batch providing the documents, the batch size and the maximum document length
     * @return data set with two feature arrays, one or two label arrays, and their masks
     */
    public MultiDataSet generateDataSet(DocumentSentenceIterator.DocumentBatch documentBatch) {
        // Three separate mask instances: encodeTarget() mutates the label masks in place, so they
        // must not be shared with each other or with the feature mask.
        INDArray inputMask = createMask(documentBatch.docs, documentBatch.maxDocLength, Sentence.class);
        INDArray entityMask = createMask(documentBatch.docs, documentBatch.maxDocLength, Sentence.class);
        INDArray aspectMask = createMask(documentBatch.docs, documentBatch.maxDocLength, Sentence.class);

        INDArray input = this.tagger.inputEncoder instanceof Encoder
                ? this.tagger.inputEncoder.encodeMatrix(documentBatch.docs, documentBatch.maxDocLength, Sentence.class)
                : EncodingHelpers.encodeTimeStepMatrix(documentBatch.docs, this.tagger.inputEncoder, documentBatch.maxDocLength, Sentence.class);
        INDArray flags = this.tagger.flagEncoder instanceof Encoder
                ? this.tagger.flagEncoder.encodeMatrix(documentBatch.docs, documentBatch.maxDocLength, Sentence.class)
                : EncodingHelpers.encodeTimeStepMatrix(documentBatch.docs, this.tagger.flagEncoder, documentBatch.maxDocLength, Sentence.class);

        INDArray entityTarget = null;
        INDArray aspectTarget = null;
        boolean withAnnotations = this.stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN)
                || this.stage.equals(AbstractMultiDataSetIterator.Stage.TEST);
        if (withAnnotations) {
            if (this.tagger.entityEncoder != null && (this.tagger.entityEncoder instanceof EntityIndex)) {
                entityTarget = encodeTarget(entityMask, documentBatch.docs, documentBatch.maxDocLength, this.tagger.entityEncoder, Sentence.class, EntityAnnotation.class);
            }
            if (this.tagger.aspectEncoder != null && (this.tagger.aspectEncoder instanceof AspectIndex)) {
                aspectTarget = encodeTarget(aspectMask, documentBatch.docs, documentBatch.maxDocLength, this.tagger.aspectEncoder, Sentence.class, AspectAnnotation.class);
            }
        } else {
            // No annotations are read in this stage: only allocate matrices of the target shape.
            if (this.tagger.entityEncoder != null) {
                entityTarget = EncodingHelpers.createTimeStepMatrix(documentBatch.size, this.tagger.entityEncoder.getEmbeddingVectorSize(), documentBatch.maxDocLength);
            }
            if (this.tagger.aspectEncoder != null) {
                aspectTarget = EncodingHelpers.createTimeStepMatrix(documentBatch.size, this.tagger.aspectEncoder.getEmbeddingVectorSize(), documentBatch.maxDocLength);
            }
        }

        // Both feature arrays share the same mask instance.
        INDArray[] features = {input, flags};
        INDArray[] featureMasks = {inputMask, inputMask};

        INDArray[] labels;
        INDArray[] labelMasks;
        if (entityTarget != null && aspectTarget != null) {
            labels = new INDArray[]{entityTarget, aspectTarget};
            labelMasks = new INDArray[]{entityMask, aspectMask};
        } else {
            // Single-label case: prefer the entity target, fall back to the aspect target.
            boolean hasEntity = entityTarget != null;
            labels = new INDArray[]{hasEntity ? entityTarget : aspectTarget};
            labelMasks = new INDArray[]{hasEntity ? entityMask : aspectMask};
        }
        return new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
    }

    /**
     * Creates a float mask of shape {@code [list.size(), i]} holding 1 for every time step that
     * exists in the respective document and 0 elsewhere.
     *
     * @param list documents, one mask row per document
     * @param i    number of mask columns; longer documents are truncated to this length
     * @param cls  time step type: {@link Token} counts tokens, {@link Sentence} counts sentences;
     *             any other class yields an all-zero row
     * @return the newly allocated mask
     */
    public INDArray createMask(List<Document> list, int i, Class<? extends Span> cls) {
        INDArray mask = Nd4j.zeros(DataType.FLOAT, new long[]{list.size(), i});
        for (int row = 0; row < list.size(); row++) {
            Document document = list.get(row);
            int length = 0;
            if (cls == Token.class) {
                length = document.countTokens();
            } else if (cls == Sentence.class) {
                length = document.countSentences();
            }
            int limit = Math.min(length, i);
            for (int step = 0; step < limit; step++) {
                mask.putScalar(new int[]{row, step}, 1);
            }
        }
        return mask;
    }

    /**
     * Encodes the target vectors of all documents into one time step matrix and adjusts the given
     * label mask in place.
     *
     * <p>For every entry returned by {@code AnnotationHelpers.getSpanAnnotationsMultiMap} (at most
     * {@code i} per document) the entry's annotations are resolved to a vector and a weight via
     * {@code CDVAnnotator.lookupAnnotations}. Consecutive entries with equal annotation collections
     * reuse the previous lookup result. If the lookup yields no vector, the mask at that time step
     * is set to 0 and nothing is written into the matrix. Otherwise a copy of the vector is written
     * and, when {@link #balancing} is enabled and the weight is below 1.0, the weight replaces the
     * mask value.
     *
     * @param iNDArray label mask to modify, indexed as {@code [document, time step]}
     * @param list     documents to encode
     * @param i        maximum number of time steps per document
     * @param iEncoder target encoder; it is cast to {@link QueryIndex} for the lookup, so a
     *                 {@link ClassCastException} is thrown for other encoders once an entry is processed
     * @param cls      span type passed to the annotation helper
     * @param cls2     annotation type passed to the annotation helper
     * @return matrix created by {@code EncodingHelpers.createTimeStepMatrix} for
     *         {@code list.size()} examples, the encoder's vector size and {@code i} time steps
     */
    // Raw types are kept on purpose: the generic signatures of the project helpers used here are
    // not visible from this file.
    @SuppressWarnings({"rawtypes", "unchecked"})
    public <S extends Span, A extends PassageAnnotation> INDArray encodeTarget(INDArray iNDArray, List<Document> list, int i, IEncoder iEncoder, Class<S> cls, Class<A> cls2) {
        INDArray targets = EncodingHelpers.createTimeStepMatrix(list.size(), iEncoder.getEmbeddingVectorSize(), i);
        for (int row = 0; row < list.size(); row++) {
            Document document = list.get(row);
            int step = 0;
            // Cache of the most recent lookup. No placeholder vector is allocated up front: the
            // first entry never equals the null "previous" collection and therefore always
            // triggers a lookup that sets both values.
            INDArray vector = null;
            double weight = 0.0d;
            Object previous = null;
            for (Map.Entry entry : AnnotationHelpers.getSpanAnnotationsMultiMap(document, cls, cls2)) {
                if (step >= i) {
                    break;
                }
                Collection annotations = (Collection) entry.getValue();
                if (!annotations.equals(previous)) {
                    Map.Entry<INDArray, Double> lookup = CDVAnnotator.lookupAnnotations((QueryIndex) iEncoder, annotations, this.balancing);
                    vector = lookup.getKey();
                    weight = lookup.getValue().doubleValue();
                    previous = annotations;
                }
                if (vector == null) {
                    // No target for this time step: exclude it via the mask.
                    iNDArray.putScalar(new int[]{row, step}, 0);
                } else {
                    if (this.balancing && weight < 1.0d) {
                        iNDArray.putScalar(new int[]{row, step}, weight);
                    }
                    // dup() so that the cached vector is not shared with the target matrix.
                    EncodingHelpers.putTimeStep(targets, row, step, vector.dup());
                }
                step++;
            }
        }
        return targets;
    }
}
