package org.deeplearning4j.iterator;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.iterator.bert.BertMaskedLMMasker;
import org.deeplearning4j.iterator.bert.BertSequenceMasker;
import org.deeplearning4j.text.tokenization.tokenizer.Tokenizer;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;

/* loaded from: input_file:org/deeplearning4j/iterator/BertIterator.class */
public class BertIterator implements MultiDataSetIterator {
    protected Task task;
    protected TokenizerFactory tokenizerFactory;
    protected int maxTokens;
    protected int minibatchSize;
    protected boolean padMinibatches;
    protected MultiDataSetPreProcessor preProcessor;
    protected LabeledSentenceProvider sentenceProvider;
    protected LabeledPairSentenceProvider sentencePairProvider;
    protected LengthHandling lengthHandling;
    protected FeatureArrays featureArrays;
    protected Map<String, Integer> vocabMap;
    protected BertSequenceMasker masker;
    protected UnsupervisedLabelFormat unsupervisedLabelFormat;
    protected String maskToken;
    protected String prependToken;
    protected String appendToken;
    protected List<String> vocabKeysAsList;

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$Builder.class */
    /**
     * Fluent builder for {@link BertIterator}.
     *
     * <p>Each setter stores its argument and returns this builder. {@link #build()} validates the
     * combination of settings and creates the iterator. Defaults: length handling
     * {@link LengthHandling#FIXED_LENGTH} with no maximum token count set (-1), minibatch size 32, no
     * minibatch padding, feature arrays {@link FeatureArrays#INDICES_MASK_SEGMENTID}, and a
     * {@link BertMaskedLMMasker} as the masker.
     */
    public static class Builder {
        protected Task task;
        protected TokenizerFactory tokenizerFactory;
        protected MultiDataSetPreProcessor preProcessor;
        protected Map<String, Integer> vocabMap;
        protected UnsupervisedLabelFormat unsupervisedLabelFormat;
        protected String maskToken;
        protected String prependToken;
        protected String appendToken;
        protected LengthHandling lengthHandling = LengthHandling.FIXED_LENGTH;
        // -1 means "not set"; build() rejects it unless the length handling is ANY_LENGTH.
        protected int maxTokens = -1;
        protected int minibatchSize = 32;
        protected boolean padMinibatches = false;
        protected LabeledSentenceProvider sentenceProvider = null;
        protected LabeledPairSentenceProvider sentencePairProvider = null;
        protected FeatureArrays featureArrays = FeatureArrays.INDICES_MASK_SEGMENTID;
        protected BertSequenceMasker masker = new BertMaskedLMMasker();

        /**
         * Sets the task the iterator produces labels for. Required.
         *
         * @param task the task
         * @return this builder
         */
        public Builder task(Task task) {
            this.task = task;
            return this;
        }

        /**
         * Sets the factory used to create a tokenizer for every sentence. Required.
         *
         * @param tokenizerFactory the tokenizer factory
         * @return this builder
         */
        public Builder tokenizer(TokenizerFactory tokenizerFactory) {
            this.tokenizerFactory = tokenizerFactory;
            return this;
        }

        /**
         * Sets how the sequence length of the output arrays is chosen.
         *
         * @param lengthHandling the length handling mode, must not be null
         * @param i the maximum number of tokens; not used by {@link LengthHandling#ANY_LENGTH}
         * @return this builder
         * @throws NullPointerException if {@code lengthHandling} is null
         */
        public Builder lengthHandling(@NonNull LengthHandling lengthHandling, int i) {
            if (lengthHandling == null) {
                throw new NullPointerException("lengthHandling is marked @NonNull but is null");
            }
            this.lengthHandling = lengthHandling;
            this.maxTokens = i;
            return this;
        }

        /**
         * Sets the number of examples fetched by {@link BertIterator#next()}.
         *
         * @param i the minibatch size
         * @return this builder
         */
        public Builder minibatchSize(int i) {
            this.minibatchSize = i;
            return this;
        }

        /**
         * Sets whether output arrays are padded up to the minibatch size when fewer examples are
         * available.
         *
         * @param z true to pad
         * @return this builder
         */
        public Builder padMinibatches(boolean z) {
            this.padMinibatches = z;
            return this;
        }

        /**
         * Sets the optional pre-processor applied to every produced MultiDataSet.
         *
         * @param multiDataSetPreProcessor the pre-processor, may be null
         * @return this builder
         */
        public Builder preProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
            this.preProcessor = multiDataSetPreProcessor;
            return this;
        }

        /**
         * Sets the provider of labelled single sentences.
         *
         * @param labeledSentenceProvider the sentence provider
         * @return this builder
         */
        public Builder sentenceProvider(LabeledSentenceProvider labeledSentenceProvider) {
            this.sentenceProvider = labeledSentenceProvider;
            return this;
        }

        /**
         * Sets the provider of labelled sentence pairs. Cannot be combined with a sentence provider.
         *
         * @param labeledPairSentenceProvider the sentence pair provider
         * @return this builder
         */
        public Builder sentencePairProvider(LabeledPairSentenceProvider labeledPairSentenceProvider) {
            this.sentencePairProvider = labeledPairSentenceProvider;
            return this;
        }

        /**
         * Sets which feature arrays are produced.
         *
         * @param featureArrays the feature array layout
         * @return this builder
         */
        public Builder featureArrays(FeatureArrays featureArrays) {
            this.featureArrays = featureArrays;
            return this;
        }

        /**
         * Sets the vocabulary, mapping each token to its integer index. Required.
         *
         * @param map the token to index map
         * @return this builder
         */
        public Builder vocabMap(Map<String, Integer> map) {
            this.vocabMap = map;
            return this;
        }

        /**
         * Sets the masker used for the {@link Task#UNSUPERVISED} task.
         *
         * @param bertSequenceMasker the masker
         * @return this builder
         */
        public Builder masker(BertSequenceMasker bertSequenceMasker) {
            this.masker = bertSequenceMasker;
            return this;
        }

        /**
         * Sets the label array format for the {@link Task#UNSUPERVISED} task.
         *
         * @param unsupervisedLabelFormat the label format
         * @return this builder
         */
        public Builder unsupervisedLabelFormat(UnsupervisedLabelFormat unsupervisedLabelFormat) {
            this.unsupervisedLabelFormat = unsupervisedLabelFormat;
            return this;
        }

        /**
         * Sets the vocabulary token passed to the masker for the {@link Task#UNSUPERVISED} task.
         *
         * @param str the mask token
         * @return this builder
         */
        public Builder maskToken(String str) {
            this.maskToken = str;
            return this;
        }

        /**
         * Sets a token placed at the start of every token sequence.
         *
         * @param str the token to prepend, or null for none
         * @return this builder
         */
        public Builder prependToken(String str) {
            this.prependToken = str;
            return this;
        }

        /**
         * Sets a token placed after each sentence. {@link #build()} only accepts this together with
         * a sentence pair provider.
         *
         * @param str the token to append, or null for none
         * @return this builder
         */
        public Builder appendToken(String str) {
            this.appendToken = str;
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @return the configured iterator
         * @throws IllegalStateException if a required setting is missing or settings conflict
         */
        public BertIterator build() {
            Preconditions.checkState(this.task != null, "No task has been set. Use .task(BertIterator.Task.X) to set the task to be performed");
            Preconditions.checkState(this.tokenizerFactory != null, "No tokenizer factory has been set. A tokenizer factory (such as BertWordPieceTokenizerFactory) is required");
            Preconditions.checkState(this.vocabMap != null, "Cannot create iterator: No vocabMap has been set. Use Builder.vocabMap(Map<String,Integer>) to set");

            final boolean unsupervised = this.task == Task.UNSUPERVISED;
            Preconditions.checkState(!unsupervised || this.masker != null, "If task is UNSUPERVISED training, a masker must be set via masker(BertSequenceMasker) method");
            Preconditions.checkState(!unsupervised || this.unsupervisedLabelFormat != null, "If task is UNSUPERVISED training, a label format must be set via unsupervisedLabelFormat(UnsupervisedLabelFormat) method");
            Preconditions.checkState(!unsupervised || this.maskToken != null, "If task is UNSUPERVISED training, the mask token in the vocab (such as \"[MASK]\" must be specified");

            // Fail fast: a negative maxTokens would otherwise surface later as a
            // NegativeArraySizeException / IllegalArgumentException while building a minibatch.
            Preconditions.checkState(this.lengthHandling == LengthHandling.ANY_LENGTH || this.maxTokens >= 0,
                    "A maximum token count must be set via lengthHandling(LengthHandling, int) when length handling is "
                            + this.lengthHandling + ": got maxTokens=" + this.maxTokens);

            if (this.sentencePairProvider != null) {
                Preconditions.checkState(this.task == Task.SEQ_CLASSIFICATION, "Currently only supervised sequence classification is set up with sentence pairs. \".task(BertIterator.Task.SEQ_CLASSIFICATION)\" is required with a sentence pair provider");
                Preconditions.checkState(this.featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID, "Currently only supervised sequence classification is set up with sentence pairs. \".featureArrays(FeatureArrays.INDICES_MASK_SEGMENTID)\" is required with a sentence pair provider");
                Preconditions.checkState(this.lengthHandling == LengthHandling.FIXED_LENGTH, "Currently only fixed length is supported for sentence pairs. \".lengthHandling(BertIterator.LengthHandling.FIXED_LENGTH, maxLength)\" is required with a sentence pair provider");
                // The iterator prefers the single-sentence provider when both are present, so the
                // pair provider would be silently ignored; reject the ambiguous configuration.
                Preconditions.checkState(this.sentenceProvider == null, "Provide either a sentence provider or a sentence pair provider. Both cannot be non null");
            }
            if (this.appendToken != null) {
                Preconditions.checkState(this.sentencePairProvider != null, "Tokens are only appended with sentence pairs. Sentence pair provider is not set. Set sentence pair provider.");
            }
            return new BertIterator(this);
        }
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$FeatureArrays.class */
    /** Selects which feature arrays the iterator produces for each minibatch. */
    public enum FeatureArrays {
        /** One feature array (token indices) with one feature mask array. */
        INDICES_MASK,
        /** Two feature arrays (token indices, segment ids); the mask for the segment ids is null. */
        INDICES_MASK_SEGMENTID
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$LengthHandling.class */
    /** Selects how the sequence length of the output arrays is chosen for single sentences. */
    public enum LengthHandling {
        /** Output length is always the configured maximum token count. */
        FIXED_LENGTH,
        /** Output length is the token count of the longest sentence in the minibatch. */
        ANY_LENGTH,
        /** Output length is the smaller of the configured maximum and the longest sentence. */
        CLIP_ONLY
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$SentenceListProcessed.class */
    /**
     * Holder for a tokenized minibatch: the token list and label of every sentence plus the
     * sequence length to use for the output arrays.
     */
    public static class SentenceListProcessed {
        private int listLength;
        private int maxL;
        private List<Pair<List<String>, String>> tokensAndLabelList;

        /**
         * @param expectedSize number of sentences expected, used to presize the backing list
         */
        private SentenceListProcessed(int expectedSize) {
            listLength = expectedSize;
            tokensAndLabelList = new ArrayList<>(expectedSize);
        }

        /** Appends one tokenized sentence together with its label. */
        public void addProcessedToList(Pair<List<String>, String> tokensAndLabel) {
            tokensAndLabelList.add(tokensAndLabel);
        }

        /** @return the sequence length to use for the output arrays */
        public int getMaxL() {
            return maxL;
        }

        /** @param length the sequence length to use for the output arrays */
        public void setMaxL(int length) {
            maxL = length;
        }

        /** @return the live list of (tokens, label) pairs, in insertion order */
        public List<Pair<List<String>, String>> getTokensAndLabelList() {
            return tokensAndLabelList;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$SentencePairListProcessed.class */
    /**
     * Holder for a tokenized minibatch of sentence pairs. Wraps a {@link SentenceListProcessed}
     * and additionally records, per example, the token position passed in alongside it.
     */
    public static class SentencePairListProcessed {
        private int listLength = 0;
        private long[] segIdOnesFrom;
        private int cursor = 0;
        private SentenceListProcessed sentenceListProcessed;

        /**
         * @param size number of sentence pairs that will be added
         */
        private SentencePairListProcessed(int size) {
            listLength = size;
            segIdOnesFrom = new long[size];
            sentenceListProcessed = new SentenceListProcessed(size);
        }

        /**
         * Appends one combined token sequence.
         *
         * @param segmentStart position stored for this example in the segment-start array
         * @param tokensAndLabel combined tokens and the label
         */
        public void addProcessedToList(long segmentStart, Pair<List<String>, String> tokensAndLabel) {
            segIdOnesFrom[cursor] = segmentStart;
            sentenceListProcessed.addProcessedToList(tokensAndLabel);
            cursor++;
        }

        /** @param length the sequence length to use for the output arrays */
        public void setMaxL(int length) {
            sentenceListProcessed.setMaxL(length);
        }

        /** @return the sequence length to use for the output arrays */
        public int getMaxL() {
            return sentenceListProcessed.getMaxL();
        }

        /** @return the live list of (tokens, label) pairs, in insertion order */
        public List<Pair<List<String>, String>> getTokensAndLabelList() {
            return sentenceListProcessed.getTokensAndLabelList();
        }

        /** @return the per-example segment-start positions (internal array, not copied) */
        public long[] getSegIdOnesFrom() {
            return segIdOnesFrom;
        }
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$Task.class */
    /** The kind of labels the iterator produces. */
    public enum Task {
        /** Labels are derived from the input tokens via the configured masker. */
        UNSUPERVISED,
        /** One label per sequence, encoded one-hot over the provider's label classes. */
        SEQ_CLASSIFICATION
    }

    /* loaded from: input_file:org/deeplearning4j/iterator/BertIterator$UnsupervisedLabelFormat.class */
    /** Layout of the label array produced for {@link Task#UNSUPERVISED}. */
    public enum UnsupervisedLabelFormat {
        /** INT array allocated as [minibatch, sequence length]; holds vocabulary indices. */
        RANK2_IDX,
        /** FLOAT array allocated as [minibatch, vocabulary size, sequence length]. */
        RANK3_NCL,
        /** FLOAT array allocated as [sequence length, minibatch, vocabulary size]. */
        RANK3_LNC
    }

    /**
     * Creates an iterator by copying every setting from the given builder.
     *
     * @param builder the source of all configuration values
     */
    protected BertIterator(Builder builder) {
        // What to produce and from which data
        task = builder.task;
        sentenceProvider = builder.sentenceProvider;
        sentencePairProvider = builder.sentencePairProvider;

        // Tokenization and vocabulary
        tokenizerFactory = builder.tokenizerFactory;
        vocabMap = builder.vocabMap;
        prependToken = builder.prependToken;
        appendToken = builder.appendToken;

        // Array shapes and layout
        lengthHandling = builder.lengthHandling;
        maxTokens = builder.maxTokens;
        minibatchSize = builder.minibatchSize;
        padMinibatches = builder.padMinibatches;
        featureArrays = builder.featureArrays;

        // Unsupervised task settings
        masker = builder.masker;
        unsupervisedLabelFormat = builder.unsupervisedLabelFormat;
        maskToken = builder.maskToken;

        preProcessor = builder.preProcessor;
    }

    /**
     * Reports whether the configured provider has more data. The single-sentence provider takes
     * precedence over the sentence pair provider.
     *
     * @return true if another minibatch can be produced; false when the active provider is
     *         exhausted or when no provider is configured at all
     */
    @Override // java.util.Iterator
    public boolean hasNext() {
        if (this.sentenceProvider != null) {
            return this.sentenceProvider.hasNext();
        }
        // Without any provider there is nothing to iterate; answer false instead of failing
        // with a NullPointerException.
        return this.sentencePairProvider != null && this.sentencePairProvider.hasNext();
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Returns the next minibatch using the configured minibatch size.
     *
     * @return the next minibatch, as produced by {@link #next(int)}
     */
    @Override // java.util.Iterator
    public MultiDataSet next() {
        return next(this.minibatchSize);
    }

    /**
     * Not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // java.util.Iterator
    public void remove() {
        throw new UnsupportedOperationException("Not supported");
    }

    /**
     * Builds the next minibatch from up to {@code i} examples of the active provider.
     *
     * <p>Examples are tokenized, converted to feature arrays and masks, then labels are created
     * for the configured task. The pre-processor, if any, is applied before returning.
     *
     * @param i the maximum number of examples to fetch
     * @return the assembled minibatch
     * @throws IllegalStateException if {@link #hasNext()} is false
     * @throws UnsupportedOperationException if no provider is configured
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSet next(int i) {
        Preconditions.checkState(hasNext(), "No next element available");

        final List<Pair<List<String>, String>> tokensAndLabels;
        final int outLength;
        long[] segmentStarts = null;
        int taken = 0;

        if (sentenceProvider != null) {
            List<Pair<String, String>> sentences = new ArrayList<>(i);
            while (sentenceProvider.hasNext() && taken < i) {
                sentences.add(sentenceProvider.nextSentence());
                taken++;
            }
            SentenceListProcessed processed = tokenizeMiniBatch(sentences);
            tokensAndLabels = processed.getTokensAndLabelList();
            outLength = processed.getMaxL();
        } else if (sentencePairProvider != null) {
            List<Triple<String, String, String>> sentencePairs = new ArrayList<>(i);
            while (sentencePairProvider.hasNext() && taken < i) {
                sentencePairs.add(sentencePairProvider.nextSentencePair());
                taken++;
            }
            SentencePairListProcessed processedPairs = tokenizePairsMiniBatch(sentencePairs);
            tokensAndLabels = processedPairs.getTokensAndLabelList();
            outLength = processedPairs.getMaxL();
            segmentStarts = processedPairs.getSegIdOnesFrom();
        } else {
            throw new UnsupportedOperationException("Labelled sentence provider is null and no other iterator types have yet been implemented");
        }

        Pair<INDArray[], INDArray[]> featuresAndMasks = convertMiniBatchFeatures(tokensAndLabels, outLength, segmentStarts);
        INDArray[] features = featuresAndMasks.getFirst();
        INDArray[] featureMasks = featuresAndMasks.getSecond();

        // Label creation may overwrite entries of features[0] for the unsupervised task.
        Pair<INDArray[], INDArray[]> labelsAndMasks = convertMiniBatchLabels(tokensAndLabels, features, outLength);

        org.nd4j.linalg.dataset.MultiDataSet result = new org.nd4j.linalg.dataset.MultiDataSet(
                features, labelsAndMasks.getFirst(), featureMasks, labelsAndMasks.getSecond());
        if (preProcessor != null) {
            preProcessor.preProcess(result);
        }
        return result;
    }

    /**
     * Converts unlabelled sentences to feature arrays and feature masks, applying the
     * pre-processor when one is set.
     *
     * @param list the sentences to featurize
     * @return a pair of (feature arrays, feature mask arrays)
     */
    public Pair<INDArray[], INDArray[]> featurizeSentences(List<String> list) {
        SentenceListProcessed processed = tokenizeMiniBatch(addDummyLabel(list));
        Pair<INDArray[], INDArray[]> raw =
                convertMiniBatchFeatures(processed.getTokensAndLabelList(), processed.getMaxL(), null);
        if (preProcessor == null) {
            return raw;
        }
        // Wrap in a label-free MultiDataSet so the pre-processor can be applied to the features.
        org.nd4j.linalg.dataset.MultiDataSet wrapped = new org.nd4j.linalg.dataset.MultiDataSet(
                raw.getFirst(), (INDArray[]) null, raw.getSecond(), (INDArray[]) null);
        preProcessor.preProcess(wrapped);
        return new Pair<>(wrapped.getFeatures(), wrapped.getFeaturesMaskArrays());
    }

    /**
     * Converts unlabelled sentence pairs to feature arrays and feature masks, applying the
     * pre-processor when one is set.
     *
     * @param list the sentence pairs to featurize
     * @return a pair of (feature arrays, feature mask arrays)
     * @throws IllegalStateException if no sentence pair provider is configured
     */
    public Pair<INDArray[], INDArray[]> featurizeSentencePairs(List<Pair<String, String>> list) {
        Preconditions.checkState(sentencePairProvider != null, "The featurizeSentencePairs method is meant for inference with sentence pairs. Use only when the sentence pair provider is set (i.e not null).");
        SentencePairListProcessed processed = tokenizePairsMiniBatch(addDummyLabelForPairs(list));
        Pair<INDArray[], INDArray[]> raw = convertMiniBatchFeatures(
                processed.getTokensAndLabelList(), processed.getMaxL(), processed.getSegIdOnesFrom());
        if (preProcessor == null) {
            return raw;
        }
        // Wrap in a label-free MultiDataSet so the pre-processor can be applied to the features.
        org.nd4j.linalg.dataset.MultiDataSet wrapped = new org.nd4j.linalg.dataset.MultiDataSet(
                raw.getFirst(), (INDArray[]) null, raw.getSecond(), (INDArray[]) null);
        preProcessor.preProcess(wrapped);
        return new Pair<>(wrapped.getFeatures(), wrapped.getFeaturesMaskArrays());
    }

    /**
     * Converts tokenized examples to feature arrays and feature masks.
     *
     * <p>Row {@code r}, column {@code c} of the index array holds the vocabulary index of token
     * {@code c} of example {@code r}; the mask holds 1 where a token is present. Tokens beyond
     * position {@code i} are dropped; shorter sequences and padding rows stay zero. With segment
     * ids enabled and segment starts supplied, positions at or after an example's segment start
     * get segment id 1.
     *
     * @param list tokenized examples (the label part of each pair is not read here)
     * @param i the sequence length of the output arrays
     * @param jArr per-example segment start positions, or null for single sentences
     * @return a pair of (feature arrays, feature mask arrays)
     * @throws IllegalStateException if a token is not in the vocabulary
     */
    private Pair<INDArray[], INDArray[]> convertMiniBatchFeatures(List<Pair<List<String>, String>> list, int i, long[] jArr) {
        final int numExamples = list.size();
        // Padding never shrinks the batch: a request larger than minibatchSize must still fit,
        // otherwise filling the arrays below would run past their end.
        final int rows = this.padMinibatches ? Math.max(this.minibatchSize, numExamples) : numExamples;
        final boolean withSegmentIds = this.featureArrays == FeatureArrays.INDICES_MASK_SEGMENTID;

        int[][] tokenIndices = new int[rows][i];
        int[][] tokenMask = new int[rows][i];
        int[][] segmentIds = withSegmentIds ? new int[rows][i] : null;

        for (int ex = 0; ex < numExamples; ex++) {
            List<String> tokens = list.get(ex).getFirst();
            int length = Math.min(i, tokens.size());
            for (int pos = 0; pos < length; pos++) {
                String token = tokens.get(pos);
                Integer vocabIndex = this.vocabMap.get(token);
                Preconditions.checkState(vocabIndex != null, "Unknown token encountered: token \"%s\" is not in vocabulary", token);
                tokenIndices[ex][pos] = vocabIndex.intValue();
                tokenMask[ex][pos] = 1;
                // Segment starts are only meaningful when the segment id array is produced.
                if (segmentIds != null && jArr != null && pos >= jArr[ex]) {
                    segmentIds[ex][pos] = 1;
                }
            }
        }

        INDArray indicesArray = Nd4j.createFromArray(tokenIndices);
        INDArray maskArray = Nd4j.createFromArray(tokenMask);
        INDArray[] features;
        INDArray[] featureMasks;
        if (withSegmentIds) {
            features = new INDArray[]{indicesArray, Nd4j.createFromArray(segmentIds)};
            // No mask is supplied for the segment id array.
            featureMasks = new INDArray[]{maskArray, null};
        } else {
            features = new INDArray[]{indicesArray};
            featureMasks = new INDArray[]{maskArray};
        }
        return new Pair<>(features, featureMasks);
    }

    /**
     * Tokenizes every sentence of a minibatch and decides the output sequence length according to
     * the configured length handling.
     *
     * @param list (sentence, label) pairs; labels are carried through unchanged and may be null
     * @return the tokenized sentences with the chosen output length
     * @throws RuntimeException if the length handling mode is not implemented
     */
    private SentenceListProcessed tokenizeMiniBatch(List<Pair<String, String>> list) {
        SentenceListProcessed processed = new SentenceListProcessed(list.size());
        // Start at 0 rather than -1 so an empty batch never yields a negative array length.
        int longest = 0;
        for (Pair<String, String> sentenceAndLabel : list) {
            List<String> tokens = tokenizeSentence(sentenceAndLabel.getFirst());
            processed.addProcessedToList(new Pair<List<String>, String>(tokens, sentenceAndLabel.getSecond()));
            longest = Math.max(longest, tokens.size());
        }

        final int outLength;
        switch (this.lengthHandling) {
            case FIXED_LENGTH:
                outLength = this.maxTokens;
                break;
            case ANY_LENGTH:
                outLength = longest;
                break;
            case CLIP_ONLY:
                outLength = Math.min(this.maxTokens, longest);
                break;
            default:
                throw new RuntimeException("Not implemented length handling mode: " + this.lengthHandling);
        }
        processed.setMaxL(outLength);
        return processed;
    }

    /**
     * Tokenizes sentence pairs and joins each pair into one token sequence of the form
     * {@code [prepend] first [append] second [append]}, trimming the sentences so the combined
     * sequence fits into the configured maximum token count.
     *
     * @param list (first sentence, second sentence, label) triples; labels may be null
     * @return the combined sequences, each with the position where the second sentence starts
     */
    private SentencePairListProcessed tokenizePairsMiniBatch(List<Triple<String, String, String>> list) {
        SentencePairListProcessed processed = new SentencePairListProcessed(list.size());
        for (Triple<String, String, String> pairAndLabel : list) {
            List<String> firstTokens = tokenizeSentence(pairAndLabel.getFirst(), true);
            List<String> secondTokens = tokenizeSentence(pairAndLabel.getSecond(), true);
            List<String> combined = new ArrayList<>(maxTokens);

            // Room left for sentence tokens once the special tokens are accounted for:
            // one prepend token, and the append token after each of the two sentences.
            int budget = maxTokens;
            budget -= (prependToken != null) ? 1 : 0;
            budget -= (appendToken != null) ? 2 : 0;

            final int firstSize = firstTokens.size();
            final int secondSize = secondTokens.size();
            if (firstSize + secondSize > budget) {
                final int half = budget / 2;
                if (Math.min(firstSize, secondSize) > half) {
                    // Both sentences exceed half the budget: split the budget between them.
                    firstTokens.subList(half, firstSize).clear();
                    secondTokens.subList(budget - half, secondSize).clear();
                } else if (firstSize < secondSize) {
                    // Keep the shorter first sentence whole and trim the second.
                    secondTokens.subList(budget - firstSize, secondSize).clear();
                } else {
                    // Keep the second sentence whole and trim the first.
                    firstTokens.subList(budget - secondSize, firstSize).clear();
                }
            }

            if (prependToken != null) {
                combined.add(prependToken);
            }
            combined.addAll(firstTokens);
            if (appendToken != null) {
                combined.add(appendToken);
            }
            // Everything from here on belongs to the second sentence.
            final int secondStart = combined.size();
            combined.addAll(secondTokens);
            if (appendToken != null) {
                combined.add(appendToken);
            }
            processed.addProcessedToList(secondStart, new Pair<List<String>, String>(combined, pairAndLabel.getThird()));
        }
        processed.setMaxL(maxTokens);
        return processed;
    }

    /**
     * Creates the label array and label mask for a minibatch according to the configured task.
     *
     * @param list tokenized examples with their labels
     * @param iNDArrayArr the feature arrays; for the unsupervised task, entries of the first
     *                    array are overwritten with the masker's replacement tokens
     * @param i the sequence length of the output arrays
     * @return a pair of (label arrays, label mask arrays); the mask array may be null
     * @throws IllegalStateException if the task or unsupervised label format is not supported, a
     *                               label is unknown, or no provider is available for classification
     */
    private Pair<INDArray[], INDArray[]> convertMiniBatchLabels(List<Pair<List<String>, String>> list, INDArray[] iNDArrayArr, int i) {
        final int numExamples = list.size();
        // Padding never shrinks the batch: a request larger than minibatchSize must still fit.
        final int rows = this.padMinibatches ? Math.max(this.minibatchSize, numExamples) : numExamples;
        if (this.task == Task.SEQ_CLASSIFICATION) {
            return classificationLabels(list, numExamples, rows);
        }
        if (this.task == Task.UNSUPERVISED) {
            return maskedLmLabels(list, iNDArrayArr, i, numExamples, rows);
        }
        throw new IllegalStateException("Task not yet implemented: " + this.task);
    }

    /** Builds one-hot sequence classification labels, plus a label mask when padding rows exist. */
    private Pair<INDArray[], INDArray[]> classificationLabels(List<Pair<List<String>, String>> list, int numExamples, int rows) {
        final int numClasses;
        final List<String> allLabels;
        if (this.sentenceProvider != null) {
            numClasses = this.sentenceProvider.numLabelClasses();
            allLabels = this.sentenceProvider.allLabels();
        } else if (this.sentencePairProvider != null) {
            numClasses = this.sentencePairProvider.numLabelClasses();
            allLabels = this.sentencePairProvider.allLabels();
        } else {
            throw new IllegalStateException("Cannot create classification labels: neither a sentence provider nor a sentence pair provider is set");
        }

        // Resolve and validate every label before allocating the label array.
        int[] classIndices = new int[rows];
        for (int ex = 0; ex < numExamples; ex++) {
            String label = list.get(ex).getRight();
            classIndices[ex] = allLabels.indexOf(label);
            Preconditions.checkState(classIndices[ex] >= 0, "Provided label \"%s\" for sentence does not exist in set of classes/categories", label);
        }

        INDArray labels = Nd4j.create(DataType.FLOAT, rows, numClasses);
        for (int ex = 0; ex < numExamples; ex++) {
            labels.putScalar(ex, classIndices[ex], 1.0d);
        }

        INDArray[] labelMasks = null;
        if (this.padMinibatches && numExamples != rows) {
            // Mark real examples with 1 so padding rows can be ignored.
            INDArray mask = Nd4j.zeros(DataType.FLOAT, rows, 1);
            mask.get(NDArrayIndex.interval(0, numExamples), NDArrayIndex.all()).assign((Number) 1);
            labelMasks = new INDArray[]{mask};
        }
        return new Pair<>(new INDArray[]{labels}, labelMasks);
    }

    /** Builds labels and label mask for the unsupervised task and writes replacement tokens into the features. */
    private Pair<INDArray[], INDArray[]> maskedLmLabels(List<Pair<List<String>, String>> list, INDArray[] iNDArrayArr, int i, int numExamples, int rows) {
        if (this.vocabKeysAsList == null) {
            // Invert the vocabulary once: position = index, value = token.
            String[] keysByIndex = new String[this.vocabMap.size()];
            for (Map.Entry<String, Integer> entry : this.vocabMap.entrySet()) {
                keysByIndex[entry.getValue().intValue()] = entry.getKey();
            }
            this.vocabKeysAsList = Arrays.asList(keysByIndex);
        }
        final int vocabSize = this.vocabMap.size();

        INDArray labelMask = Nd4j.zeros(DataType.INT, rows, i);
        final INDArray labels;
        if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
            labels = Nd4j.create(DataType.INT, rows, i);
        } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
            labels = Nd4j.create(DataType.FLOAT, rows, vocabSize, i);
        } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
            labels = Nd4j.create(DataType.FLOAT, i, rows, vocabSize);
        } else {
            throw new IllegalStateException("Unknown unsupervised label format: " + this.unsupervisedLabelFormat);
        }

        for (int ex = 0; ex < numExamples; ex++) {
            List<String> originalTokens = list.get(ex).getFirst();
            Pair<List<String>, boolean[]> masked = this.masker.maskSequence(originalTokens, this.maskToken, this.vocabKeysAsList);
            List<String> replacedTokens = masked.getFirst();
            boolean[] isMasked = masked.getSecond();
            int length = Math.min(isMasked.length, i);
            for (int pos = 0; pos < length; pos++) {
                if (!isMasked[pos]) {
                    continue;
                }
                // The label is the ORIGINAL token at a masked position.
                int originalIndex = this.vocabMap.get(originalTokens.get(pos)).intValue();
                if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK2_IDX) {
                    labels.putScalar(ex, pos, originalIndex);
                } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_NCL) {
                    // Array is [minibatch, vocab, length]: the class index is the second dimension.
                    labels.putScalar(ex, originalIndex, pos, 1.0d);
                } else if (this.unsupervisedLabelFormat == UnsupervisedLabelFormat.RANK3_LNC) {
                    // Array is [length, minibatch, vocab].
                    labels.putScalar(pos, ex, originalIndex, 1.0d);
                }
                labelMask.putScalar(ex, pos, 1.0d);
                // The network input sees the masker's replacement token instead of the original.
                iNDArrayArr[0].putScalar(ex, pos, this.vocabMap.get(replacedTokens.get(pos)).intValue());
            }
        }
        return new Pair<>(new INDArray[]{labels}, new INDArray[]{labelMask});
    }

    /**
     * Tokenizes a single sentence, adding the configured prepend/append tokens.
     *
     * @param str the sentence to tokenize
     * @return the tokens of the sentence
     */
    private List<String> tokenizeSentence(String str) {
        return tokenizeSentence(str, false);
    }

    /**
     * Tokenizes a sentence with the configured tokenizer factory.
     *
     * @param str the sentence to tokenize
     * @param z when true, the prepend and append tokens are NOT added
     * @return a new mutable list of tokens
     */
    private List<String> tokenizeSentence(String str, boolean z) {
        final boolean addSpecialTokens = !z;
        Tokenizer tokenizer = tokenizerFactory.create(str);
        List<String> tokens = new ArrayList<>();
        if (addSpecialTokens && prependToken != null) {
            tokens.add(prependToken);
        }
        while (tokenizer.hasMoreTokens()) {
            tokens.add(tokenizer.nextToken());
        }
        if (addSpecialTokens && appendToken != null) {
            tokens.add(appendToken);
        }
        return tokens;
    }

    /**
     * Pairs every sentence with a null label so unlabelled input can go through the same
     * tokenization path as labelled input.
     *
     * @param list the sentences
     * @return (sentence, null) pairs in the same order
     */
    private List<Pair<String, String>> addDummyLabel(List<String> list) {
        List<Pair<String, String>> withNullLabels = new ArrayList<>(list.size());
        for (String sentence : list) {
            withNullLabels.add(new Pair<String, String>(sentence, null));
        }
        return withNullLabels;
    }

    /**
     * Extends every sentence pair with a null label so unlabelled pairs can go through the same
     * tokenization path as labelled pairs.
     *
     * @param list the sentence pairs
     * @return (first, second, null) triples in the same order
     */
    private List<Triple<String, String, String>> addDummyLabelForPairs(List<Pair<String, String>> list) {
        List<Triple<String, String, String>> withNullLabels = new ArrayList<>(list.size());
        for (Pair<String, String> sentencePair : list) {
            String first = sentencePair.getFirst();
            String second = sentencePair.getSecond();
            withNullLabels.add(new Triple<String, String, String>(first, second, null));
        }
        return withNullLabels;
    }

    /** @return always true */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean resetSupported() {
        return true;
    }

    /** @return always true */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public boolean asyncSupported() {
        return true;
    }

    /**
     * Resets whichever providers are configured so iteration can start over. Does nothing when
     * no provider is set.
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void reset() {
        if (this.sentenceProvider != null) {
            this.sentenceProvider.reset();
        }
        // The pair provider must be reset as well, otherwise an iterator built on sentence
        // pairs stays exhausted after the first pass even though resetSupported() is true.
        if (this.sentencePairProvider != null) {
            this.sentencePairProvider.reset();
        }
    }

    /** @return a new {@link Builder} with default settings */
    public static Builder builder() {
        return new Builder();
    }

    /** @return the pre-processor applied to produced data, or null if none is set */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public MultiDataSetPreProcessor getPreProcessor() {
        return this.preProcessor;
    }

    /**
     * Replaces the pre-processor applied to produced data.
     *
     * @param multiDataSetPreProcessor the new pre-processor; null disables pre-processing
     */
    @Override // org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        this.preProcessor = multiDataSetPreProcessor;
    }
}
