package de.datexis.tagger;

import de.datexis.encoder.Encoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Token;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.stream.StreamSupport;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:de/datexis/tagger/CachedSentenceIterator.class */
/**
 * Iterates over the sentences of a document collection, encoding each document with all configured
 * encoders the first time one of its sentences is needed.
 *
 * <p>Every encoded document is remembered in {@code docsInUse} so that its cached vectors can later
 * be released with {@link #clearCache()}. Subclasses turn a batch of sentences into a {@link DataSet}
 * by implementing {@link #generateDataSet(ArrayList, int, int)}.
 *
 * <p>The iterator keeps mutable document/sentence cursors, and {@link #asyncSupported()} returns
 * {@code false}.
 */
public abstract class CachedSentenceIterator extends AbstractIterator {

    /** Iterator over the documents; {@code null} until {@link #reset()} is called. */
    protected Iterator<Document> docIt;

    /** Iterator over the sentences of the current document; {@code null} before the first document. */
    protected Iterator<Sentence> sentIt;

    /**
     * Creates the iterator and counts the sentences contained in {@code collection}.
     *
     * @param collection the documents to iterate over
     * @param str        passed through to {@code AbstractIterator} (NOTE(review): presumably a name — confirm)
     * @param i          maximum number of sentences to return; a negative value means "all sentences"
     * @param i2         passed through to {@code AbstractIterator} (presumably the batch size — confirm)
     * @param z          passed through to {@code AbstractIterator} (presumably the randomize flag — confirm)
     */
    public CachedSentenceIterator(Collection<Document> collection, String str, int i, int i2, boolean z) {
        super(collection, str, i, i2, z);
        // Explicitly (re)initialised after super(); the iterators only become usable after reset().
        this.docIt = null;
        this.sentIt = null;
        this.totalExamples = collection.stream().mapToInt(Document::countSentences).sum();
        this.numExamples = i < 0 ? this.totalExamples : i;
    }

    /** Always {@code false}: iteration mutates shared cursor and cache state. */
    @Override // de.datexis.tagger.AbstractIterator
    public boolean asyncSupported() {
        return false;
    }

    /**
     * Rewinds the iterator to the first document.
     *
     * <p>If randomization is enabled, the documents are reshuffled first. The progress timer is
     * restarted as well.
     */
    @Override // de.datexis.tagger.AbstractIterator
    public final void reset() {
        this.cursor = 0;
        if (this.randomize) {
            this.documents = randomizeDocuments(this.documents);
        }
        this.docIt = this.documents.iterator();
        this.sentIt = null;
        this.startTime = System.currentTimeMillis();
    }

    /** Returns {@code true} once {@code numExamples} sentences have been handed out. */
    private boolean reachedEnd() {
        return this.cursor >= this.numExamples;
    }

    /**
     * Returns whether another sentence is available within the example limit.
     *
     * <p>Advances to (and encodes) following documents as needed until one with a remaining
     * sentence is found. The limit is checked first, so no further document is pulled in or
     * encoded once the limit is reached. A loop is used instead of recursion, so long runs of
     * sentence-less documents cannot overflow the stack.
     *
     * @return {@code true} if {@link #nextSentence()} may be called
     */
    @Override // de.datexis.tagger.AbstractIterator
    public boolean hasNext() {
        if (reachedEnd()) {
            return false;
        }
        while (!hasNextSentence()) {
            if (!hasNextDocument()) {
                return false;
            }
            this.sentIt = nextDocument().getSentences().iterator();
        }
        return true;
    }

    /** Returns whether another document is available; {@code false} before {@link #reset()}. */
    public boolean hasNextDocument() {
        return this.docIt != null && this.docIt.hasNext();
    }

    /**
     * Advances to the next document, makes it the current document and encodes it.
     *
     * @return the newly current document
     */
    public Document nextDocument() {
        this.currDocument = this.docIt.next();
        encodeDocument(this.currDocument);
        return this.currDocument;
    }

    /** Returns whether the current document has another sentence. */
    public boolean hasNextSentence() {
        return this.sentIt != null && this.sentIt.hasNext();
    }

    /**
     * Returns the next sentence of the current document and counts it towards the example limit.
     *
     * <p>The cursor is only advanced once a sentence has actually been obtained.
     *
     * @return the next sentence
     */
    public Sentence nextSentence() {
        Sentence sentence = this.sentIt.next();
        this.cursor++;
        return sentence;
    }

    /**
     * Registers {@code document} as cached and runs every configured encoder over its tokens.
     *
     * @param document the document to encode
     */
    protected void encodeDocument(Document document) {
        this.docsInUse.add(document);
        this.encoders.forEach(encoder -> encoder.encodeEach(document, Token.class));
    }

    /**
     * Releases the cached vectors of every encoded document except the current one.
     *
     * <p>If anything was released, {@code docsInUse} is reset to contain only the current document.
     *
     * @return {@code true} if at least one document's vectors were cleared
     */
    protected boolean clearCache() {
        boolean cleared = false;
        for (Document document : this.docsInUse) {
            if (document != this.currDocument) {
                clearCachedDocument(document);
                cleared = true;
            }
        }
        if (cleared) {
            this.docsInUse.clear();
            this.docsInUse.add(this.currDocument);
        }
        return cleared;
    }

    /**
     * Clears the vectors attached to the tokens, annotations and sentences of {@code document}.
     *
     * @param document the document whose vectors are released
     */
    protected void clearCachedDocument(Document document) {
        document.getTokens().forEach(Token::clearVectors);
        document.getAnnotations().forEach(annotation -> annotation.clearVectors());
        document.getSentences().forEach(Sentence::clearVectors);
    }

    /**
     * Builds the next batch of {@code i} sentences and converts it into a {@link DataSet}.
     *
     * @param i number of sentences in the batch (padded with empty sentences when data runs out)
     * @return the generated data set
     */
    public DataSet next(int i) {
        return toDataSet(nextBatchOfSentences(i), i);
    }

    /**
     * Builds the next batch using the configured batch size.
     *
     * @return the generated data set together with the sentences it was built from
     */
    public Pair<DataSet, ArrayList<Sentence>> nextDataSet() {
        return nextDataSet(this.batchSize);
    }

    /**
     * Builds the next batch of {@code i} sentences.
     *
     * @param i number of sentences in the batch (padded with empty sentences when data runs out)
     * @return the generated data set together with the sentences it was built from
     */
    public Pair<DataSet, ArrayList<Sentence>> nextDataSet(int i) {
        Pair<ArrayList<Sentence>, Integer> batch = nextBatchOfSentences(i);
        DataSet dataSet = toDataSet(batch, i);
        return new ImmutablePair<>(dataSet, batch.getKey());
    }

    /** Reports progress, generates the data set for {@code batch} and logs its dimensions. */
    private DataSet toDataSet(Pair<ArrayList<Sentence>, Integer> batch, int size) {
        reportProgress();
        DataSet dataSet = generateDataSet(batch.getKey(), size, batch.getValue());
        this.log.trace("Iterate: example size " + size + " Sentences x " + batch.getValue() + " Tokens");
        return dataSet;
    }

    /**
     * Collects exactly {@code i} sentences.
     *
     * <p>Once the data (or the example limit) is exhausted, the batch is padded with empty
     * {@code Sentence} instances so that it always has {@code i} entries.
     *
     * @param i number of sentences to collect
     * @return the sentences, paired with the largest token count among them
     */
    public Pair<ArrayList<Sentence>, Integer> nextBatchOfSentences(int i) {
        ArrayList<Sentence> batch = new ArrayList<>(i);
        int maxTokens = 0;
        for (int n = 0; n < i; n++) {
            Sentence sentence = hasNext() ? nextSentence() : new Sentence();
            batch.add(sentence);
            maxTokens = Math.max(maxTokens, sentence.countTokens());
        }
        return new ImmutablePair<>(batch, maxTokens);
    }

    /**
     * Converts a batch of sentences into a {@link DataSet}.
     *
     * @param arrayList the sentences of the batch
     * @param i         number of sentences in the batch
     * @param i2        largest token count among the sentences
     * @return the generated data set
     */
    public abstract DataSet generateDataSet(ArrayList<Sentence> arrayList, int i, int i2);
}
