package de.datexis.cdv.tagger;

import com.google.common.collect.Lists;
import de.datexis.cdv.index.AspectIndex;
import de.datexis.cdv.index.EntityIndex;
import de.datexis.cdv.model.AspectAnnotation;
import de.datexis.cdv.model.EntityAnnotation;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncodingHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.encoder.impl.DummyEncoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import de.datexis.tagger.DocumentSentenceIterator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.TimeUnit;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:de/datexis/cdv/tagger/CDVWordIterator.class */
/**
 * Multi-dataset iterator that encodes each sentence of a document batch word by word.
 *
 * <p>For every batch it builds a word-embedding tensor and a matching word mask with one row per
 * (document, sentence) slot, the flag encoder's sentence-level matrix, and the entity and/or
 * aspect targets together with their masks.
 */
public class CDVWordIterator extends CDVSentenceIterator {
    /** Number of word slots per sentence; a value &lt;= 0 means no fixed limit is configured. */
    protected int maxWordsPerSentence;

    /**
     * Creates an iterator over {@code collection} for the given tagger.
     *
     * @param stage               stage this iterator runs in
     * @param collection          documents to iterate over
     * @param cDVTagger           tagger whose encoders are used
     * @param i                   forwarded unchanged to the {@link CDVSentenceIterator} constructor
     * @param i2                  forwarded unchanged to the {@link CDVSentenceIterator} constructor
     * @param maxWordsPerSentence number of word slots per sentence (&lt;= 0 for no fixed limit)
     * @param i4                  forwarded unchanged to the {@link CDVSentenceIterator} constructor
     * @param z                   forwarded unchanged to the {@link CDVSentenceIterator} constructor
     * @param z2                  forwarded unchanged to the {@link CDVSentenceIterator} constructor
     */
    public CDVWordIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> collection, CDVTagger cDVTagger, int i, int i2, int maxWordsPerSentence, int i4, boolean z, boolean z2) {
        super(stage, collection, cDVTagger, i, i2, i4, z, z2);
        this.maxWordsPerSentence = maxWordsPerSentence;
        reset();
    }

    /**
     * Creates an iterator backed by a fresh {@link CDVTagger} that uses {@code iEncoder} as input
     * encoder and {@link DummyEncoder}s for the second input and for the aspect encoder. No word
     * limit is configured ({@code maxWordsPerSentence == -1}).
     *
     * @param stage      stage this iterator runs in
     * @param collection documents to iterate over
     * @param iEncoder   encoder used for the word input
     * @param i          forwarded unchanged to the {@link CDVSentenceIterator} constructor
     */
    protected CDVWordIterator(AbstractMultiDataSetIterator.Stage stage, Collection<Document> collection, IEncoder iEncoder, int i) {
        super(stage, collection, new CDVTagger(), -1, -1, i, false, false);
        this.tagger.setInputEncoders(iEncoder, new DummyEncoder());
        this.tagger.setAspectEncoder(new DummyEncoder());
        this.maxWordsPerSentence = -1;
        reset();
    }

    /**
     * Fetches the next batch, caps its sentence length at {@link #maxWordsPerSentence} (when that
     * limit is positive), attaches the generated dataset and logs progress.
     *
     * @param i forwarded to {@code nextBatch}
     * @return the batch with its {@code dataset} field populated
     */
    public DocumentSentenceIterator.DocumentBatch nextDocumentBatch(int i) {
        DocumentSentenceIterator.DocumentBatch batch = nextBatch(i);
        if (this.maxWordsPerSentence > 0 && batch.maxSentenceLength > this.maxWordsPerSentence) {
            batch.maxSentenceLength = this.maxWordsPerSentence;
        }
        batch.dataset = generateDataSet(batch);
        reportProgress(batch.maxDocLength, batch.maxSentenceLength);
        return batch;
    }

    /**
     * Encodes a document batch into a {@link MultiDataSet}.
     *
     * <p>Features are the word-embedding tensor of shape
     * {@code [size * maxDocLength, 1, wordSlots, embeddingSize]} and the flag encoder's matrix.
     * Feature masks are the word mask {@code [size * maxDocLength, 1, wordSlots, 1]} and a
     * sentence mask. Labels are the entity and/or aspect targets with one mask each; if only one
     * target is available, a single label/mask pair is emitted.
     *
     * <p>Row {@code d * maxDocLength + s} holds sentence {@code s} of document {@code d}.
     * NOTE(review): this assumes the network consuming the tensor expects document-major rows -
     * confirm against the model configuration.
     *
     * @param documentBatch batch to encode
     * @return the encoded dataset
     */
    @Override // de.datexis.cdv.tagger.CDVSentenceIterator
    public MultiDataSet generateDataSet(DocumentSentenceIterator.DocumentBatch documentBatch) {
        final int docLength = documentBatch.maxDocLength;
        // Without a positive configured limit, size the tensors by the longest sentence in the batch
        // instead of passing a non-positive dimension to Nd4j.zeros.
        final int wordSlots = this.maxWordsPerSentence > 0 ? this.maxWordsPerSentence : documentBatch.maxSentenceLength;
        // Never encode more words than the tensor has slots, even if the batch was not clamped.
        final int wordLimit = Math.min(wordSlots, documentBatch.maxSentenceLength);
        final long rows = (long) documentBatch.size * docLength;

        INDArray embeddings = Nd4j.zeros(DataType.FLOAT, new long[]{rows, 1, wordSlots, this.tagger.inputEncoder.getEmbeddingVectorSize()});
        INDArray wordMask = Nd4j.zeros(DataType.FLOAT, new long[]{rows, 1, wordSlots, 1});
        // Three separate mask instances: one for the inputs and one per target.
        INDArray sentenceMask = createMask(documentBatch.docs, docLength, Sentence.class);
        INDArray entityMask = createMask(documentBatch.docs, docLength, Sentence.class);
        INDArray aspectMask = createMask(documentBatch.docs, docLength, Sentence.class);
        INDArray flags = this.tagger.flagEncoder instanceof Encoder
                ? this.tagger.flagEncoder.encodeMatrix(documentBatch.docs, docLength, Sentence.class)
                : EncodingHelpers.encodeTimeStepMatrix(documentBatch.docs, this.tagger.flagEncoder, docLength, Sentence.class);

        INDArrayIndex[] wordIndex = new INDArrayIndex[4];
        for (int docIdx = 0; docIdx < documentBatch.size; docIdx++) {
            wordIndex[1] = NDArrayIndex.point(0L);
            ArrayList<?> sentences = Lists.newArrayList(((Document) documentBatch.docs.get(docIdx)).getSentences());
            for (int sentIdx = 0; sentIdx < sentences.size() && sentIdx < docLength; sentIdx++) {
                // Stride by maxDocLength so the embedding and mask rows agree, never overlap between
                // documents and always stay inside the size * maxDocLength rows allocated above.
                final long row = (long) docIdx * docLength + sentIdx;
                wordIndex[0] = NDArrayIndex.point(row);
                ArrayList<?> tokens = Lists.newArrayList(((Sentence) sentences.get(sentIdx)).getTokens());
                final int encodedWords = Math.min(tokens.size(), wordLimit);
                for (int wordIdx = 0; wordIdx < encodedWords; wordIdx++) {
                    INDArray vector = this.tagger.inputEncoder.encode((Span) tokens.get(wordIdx));
                    wordIndex[2] = NDArrayIndex.point(wordIdx);
                    wordIndex[3] = NDArrayIndex.all();
                    embeddings.put(wordIndex, vector);
                }
                // Mark exactly the encoded word slots as valid; padding slots stay 0.
                if (encodedWords > 0) {
                    wordMask.get(new INDArrayIndex[]{NDArrayIndex.point(row), NDArrayIndex.point(0L), NDArrayIndex.interval(0, encodedWords), NDArrayIndex.point(0L)}).assign(1.0);
                }
            }
        }

        INDArray entityTarget = null;
        INDArray aspectTarget = null;
        if (this.stage.equals(AbstractMultiDataSetIterator.Stage.TRAIN) || this.stage.equals(AbstractMultiDataSetIterator.Stage.TEST)) {
            // In TRAIN/TEST, targets are only encoded for index-based encoders.
            if (this.tagger.entityEncoder instanceof EntityIndex) {
                entityTarget = encodeTarget(entityMask, documentBatch.docs, docLength, this.tagger.entityEncoder, Sentence.class, EntityAnnotation.class);
            }
            if (this.tagger.aspectEncoder instanceof AspectIndex) {
                aspectTarget = encodeTarget(aspectMask, documentBatch.docs, docLength, this.tagger.aspectEncoder, Sentence.class, AspectAnnotation.class);
            }
        } else {
            // In every other stage, targets come from createTimeStepMatrix for each configured encoder.
            if (this.tagger.entityEncoder != null) {
                entityTarget = EncodingHelpers.createTimeStepMatrix(documentBatch.size, this.tagger.entityEncoder.getEmbeddingVectorSize(), docLength);
            }
            if (this.tagger.aspectEncoder != null) {
                aspectTarget = EncodingHelpers.createTimeStepMatrix(documentBatch.size, this.tagger.aspectEncoder.getEmbeddingVectorSize(), docLength);
            }
        }

        INDArray[] features = {embeddings, flags};
        INDArray[] featureMasks = {wordMask, sentenceMask};
        if (entityTarget != null && aspectTarget != null) {
            return new org.nd4j.linalg.dataset.MultiDataSet(features, new INDArray[]{entityTarget, aspectTarget}, featureMasks, new INDArray[]{entityMask, aspectMask});
        }
        // Single-target case: the entity target wins if present, otherwise the aspect target
        // (which may itself be null when no target could be encoded).
        INDArray[] labels = {entityTarget != null ? entityTarget : aspectTarget};
        INDArray[] labelMasks = {entityTarget != null ? entityMask : aspectMask};
        return new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
    }

    /**
     * Logs the iteration progress at debug level, including an estimated remaining time
     * ({@code "??"} when no estimate can be computed).
     *
     * @param maxDocLength      first dimension reported in the log message
     * @param maxSentenceLength second dimension reported in the log message
     */
    protected void reportProgress(int maxDocLength, int maxSentenceLength) {
        String remaining = "??";
        // No estimate is possible before the cursor has advanced (avoids dividing by zero).
        if (this.cursor != 0) {
            try {
                long elapsed = System.currentTimeMillis() - this.startTime;
                long eta = ((elapsed * this.numExamples) / this.cursor) - elapsed;
                long hours = TimeUnit.MILLISECONDS.toHours(eta);
                long minutes = TimeUnit.MILLISECONDS.toMinutes(eta) - TimeUnit.HOURS.toMinutes(hours);
                long seconds = TimeUnit.MILLISECONDS.toSeconds(eta) - TimeUnit.MINUTES.toSeconds(TimeUnit.MILLISECONDS.toMinutes(eta));
                remaining = String.format("%02d:%02d:%02d", hours, minutes, seconds);
            } catch (RuntimeException ignored) {
                // Progress reporting is best effort: keep the "??" placeholder rather than fail the batch.
            }
        }
        int percent = (int) ((this.cursor * 100.0f) / ((float) this.numExamples));
        this.log.debug("{}: returning {}/{} examples in [{}%, {} remaining] [{} x {}]", new Object[]{this.stage.toString(), this.cursor, this.numExamples, percent, remaining, maxDocLength, maxSentenceLength});
    }
}
