package org.deeplearning4j.datasets.canova;

import java.beans.ConstructorProperties;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.canova.common.data.NDArrayWritable;
import org.deeplearning4j.berkeley.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/datasets/canova/RecordReaderMultiDataSetIterator.class */
public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator {
    // Number of examples requested per call by the no-argument next method.
    private int batchSize;
    // How sequences of differing length are positioned along the time axis.
    private AlignmentMode alignmentMode;
    // Non-sequence readers, keyed by the name they were registered under in the Builder.
    private Map<String, RecordReader> recordReaders;
    // Sequence (time series) readers, keyed by the name they were registered under in the Builder.
    private Map<String, SequenceRecordReader> sequenceRecordReaders;
    // Ordered descriptions of the input arrays; each refers to a reader by name.
    private List<SubsetDetails> inputs;
    // Ordered descriptions of the output arrays; each refers to a reader by name.
    private List<SubsetDetails> outputs;
    // Optional hook applied to every MultiDataSet before it is returned; null means no pre-processing.
    private MultiDataSetPreProcessor preProcessor;

    /* loaded from: input_file:org/deeplearning4j/datasets/canova/RecordReaderMultiDataSetIterator$AlignmentMode.class */
    /**
     * Controls where each sequence is written along the time dimension when the sequences
     * returned by the sequence readers do not all have the same length.
     */
    public enum AlignmentMode {
        // Builder default. No batch-wide maximum length is computed; the time dimension of each
        // array is taken from the first sequence of its reader and steps are written from index 0.
        EQUAL_LENGTH,
        // Steps are written from time index 0; when a mask array is created, the steps after the
        // end of a shorter sequence are set to 0 in it.
        ALIGN_START,
        // Steps are written starting at (longest sequence of the same example across all sequence
        // readers - length of this sequence); when a mask array is created, the steps before that
        // offset are set to 0 in it.
        ALIGN_END
    }

    /* loaded from: input_file:org/deeplearning4j/datasets/canova/RecordReaderMultiDataSetIterator$Builder.class */
    /**
     * Fluent builder for {@link RecordReaderMultiDataSetIterator}.
     *
     * <p>Readers are registered under a name with {@link #addReader} / {@link #addSequenceReader};
     * inputs and outputs are then declared by referring to those names. The order in which
     * inputs (and outputs) are added is the order of the arrays in the produced data sets.
     * All configuration is validated in {@link #build()}.
     */
    public static class Builder {
        private int batchSize;
        private AlignmentMode alignmentMode = AlignmentMode.EQUAL_LENGTH;
        private final Map<String, RecordReader> recordReaders = new HashMap<>();
        private final Map<String, SequenceRecordReader> sequenceRecordReaders = new HashMap<>();
        private final List<SubsetDetails> inputs = new ArrayList<>();
        private final List<SubsetDetails> outputs = new ArrayList<>();

        /**
         * @param i number of examples per minibatch; must be positive (checked in {@link #build()})
         */
        public Builder(int i) {
            this.batchSize = i;
        }

        /**
         * Registers a (non-sequence) record reader.
         *
         * @param str          name used to refer to the reader from inputs/outputs
         * @param recordReader the reader; a previous reader with the same name is replaced
         * @return this builder
         */
        public Builder addReader(String str, RecordReader recordReader) {
            this.recordReaders.put(str, recordReader);
            return this;
        }

        /**
         * Registers a sequence record reader.
         *
         * @param str                  name used to refer to the reader from inputs/outputs
         * @param sequenceRecordReader the reader; a previous sequence reader with the same name is replaced
         * @return this builder
         */
        public Builder addSequenceReader(String str, SequenceRecordReader sequenceRecordReader) {
            this.sequenceRecordReaders.put(str, sequenceRecordReader);
            return this;
        }

        /**
         * Sets how sequences of different lengths are aligned; defaults to
         * {@link AlignmentMode#EQUAL_LENGTH}.
         *
         * @param alignmentMode the alignment mode; must not be null (checked in {@link #build()})
         * @return this builder
         */
        public Builder sequenceAlignmentMode(AlignmentMode alignmentMode) {
            this.alignmentMode = alignmentMode;
            return this;
        }

        /**
         * Adds an input that uses every column of the named reader.
         *
         * @param str reader name
         * @return this builder
         */
        public Builder addInput(String str) {
            this.inputs.add(new SubsetDetails(str, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an input that uses a contiguous range of columns of the named reader.
         *
         * @param str reader name
         * @param i   first column (inclusive)
         * @param i2  last column (inclusive)
         * @return this builder
         */
        public Builder addInput(String str, int i, int i2) {
            this.inputs.add(new SubsetDetails(str, false, false, -1, i, i2));
            return this;
        }

        /**
         * Adds an input built by one-hot encoding a single integer column of the named reader.
         *
         * @param str reader name
         * @param i   column holding the class index
         * @param i2  number of classes (width of the one-hot vector)
         * @return this builder
         */
        public Builder addInputOneHot(String str, int i, int i2) {
            this.inputs.add(new SubsetDetails(str, false, true, i2, i, -1));
            return this;
        }

        /**
         * Adds an output that uses every column of the named reader.
         *
         * @param str reader name
         * @return this builder
         */
        public Builder addOutput(String str) {
            this.outputs.add(new SubsetDetails(str, true, false, -1, -1, -1));
            return this;
        }

        /**
         * Adds an output that uses a contiguous range of columns of the named reader.
         *
         * @param str reader name
         * @param i   first column (inclusive)
         * @param i2  last column (inclusive)
         * @return this builder
         */
        public Builder addOutput(String str, int i, int i2) {
            this.outputs.add(new SubsetDetails(str, false, false, -1, i, i2));
            return this;
        }

        /**
         * Adds an output built by one-hot encoding a single integer column of the named reader.
         *
         * @param str reader name
         * @param i   column holding the class index
         * @param i2  number of classes (width of the one-hot vector)
         * @return this builder
         */
        public Builder addOutputOneHot(String str, int i, int i2) {
            this.outputs.add(new SubsetDetails(str, false, true, i2, i, -1));
            return this;
        }

        /**
         * Validates the configuration and creates the iterator.
         *
         * @return the configured iterator
         * @throws IllegalStateException if no reader was added, the batch size is not positive,
         *                               no input/output was declared, the alignment mode is null,
         *                               an input/output names an unknown reader, or a column
         *                               range / one-hot class count is invalid
         */
        public RecordReaderMultiDataSetIterator build() {
            if (this.recordReaders.isEmpty() && this.sequenceRecordReaders.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no readers");
            }
            if (this.batchSize <= 0) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with batch size <= 0");
            }
            if (this.inputs.isEmpty() && this.outputs.isEmpty()) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with no inputs/outputs");
            }
            if (this.alignmentMode == null) {
                throw new IllegalStateException("Cannot construct RecordReaderMultiDataSetIterator with null alignment mode");
            }
            validateSubsets(this.inputs, "input");
            validateSubsets(this.outputs, "output");
            return new RecordReaderMultiDataSetIterator(this);
        }

        /** Checks that every subset refers to a registered reader and has a usable column spec. */
        private void validateSubsets(List<SubsetDetails> subsets, String kind) {
            for (SubsetDetails subset : subsets) {
                String name = subset.readerName;
                if (!this.recordReaders.containsKey(name) && !this.sequenceRecordReaders.containsKey(name)) {
                    throw new IllegalStateException("Invalid " + kind + " name: \"" + name + "\" - no reader found with this name");
                }
                if (subset.entireReader) {
                    // Whole-reader subsets carry placeholder (-1) column values; nothing to check.
                    continue;
                }
                if (subset.subsetStart < 0) {
                    throw new IllegalStateException("Invalid " + kind + " for reader \"" + name
                            + "\": column index must be >= 0, got " + subset.subsetStart);
                }
                if (subset.oneHot) {
                    if (subset.oneHotNumClasses <= 0) {
                        throw new IllegalStateException("Invalid " + kind + " for reader \"" + name
                                + "\": number of classes must be > 0, got " + subset.oneHotNumClasses);
                    }
                } else if (subset.subsetEndInclusive < subset.subsetStart) {
                    throw new IllegalStateException("Invalid " + kind + " for reader \"" + name
                            + "\": last column (" + subset.subsetEndInclusive
                            + ") must be >= first column (" + subset.subsetStart + ")");
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/datasets/canova/RecordReaderMultiDataSetIterator$SubsetDetails.class */
    /**
     * Immutable description of one input or output array: which reader it comes from and
     * which of that reader's columns are used (all of them, a contiguous range, or a single
     * column expanded to a one-hot vector).
     */
    public static class SubsetDetails {
        /** Name of the reader (record or sequence) supplying the values. */
        private final String readerName;
        /** When true every column of the reader is used and the column fields are ignored. */
        private final boolean entireReader;
        /** When true the column at {@code subsetStart} is read as an int and one-hot encoded. */
        private final boolean oneHot;
        /** Width of the one-hot vector; only meaningful when {@code oneHot} is true. */
        private final int oneHotNumClasses;
        /** First column of the range, or the class-index column for one-hot subsets. */
        private final int subsetStart;
        /** Last column of the range, inclusive; unused for one-hot subsets. */
        private final int subsetEndInclusive;

        /**
         * Creates a subset description; arguments map one-to-one onto the fields named in the
         * {@code ConstructorProperties} annotation, in the same order.
         */
        @ConstructorProperties({"readerName", "entireReader", "oneHot", "oneHotNumClasses", "subsetStart", "subsetEndInclusive"})
        public SubsetDetails(String str, boolean z, boolean z2, int i, int i2, int i3) {
            // Source selection.
            readerName = str;
            entireReader = z;
            // One-hot configuration.
            oneHot = z2;
            oneHotNumClasses = i;
            // Column range.
            subsetStart = i2;
            subsetEndInclusive = i3;
        }
    }

    /**
     * Creates an iterator from a validated builder.
     *
     * <p>The reader maps and the input/output lists are copied, so adding readers, inputs or
     * outputs to the builder after this point does not affect the created iterator. The reader
     * instances themselves are shared, not copied.
     *
     * @param builder the builder holding the configuration
     */
    private RecordReaderMultiDataSetIterator(Builder builder) {
        this.batchSize = builder.batchSize;
        this.alignmentMode = builder.alignmentMode;
        this.recordReaders = new HashMap<>(builder.recordReaders);
        this.sequenceRecordReaders = new HashMap<>(builder.sequenceRecordReaders);
        this.inputs = new ArrayList<>(builder.inputs);
        this.outputs = new ArrayList<>(builder.outputs);
    }

    /* renamed from: next, reason: merged with bridge method [inline-methods] */
    /**
     * Returns the next minibatch using the batch size configured in the builder.
     *
     * <p>The decompiler had renamed this method to {@code m11next}; the original name is
     * restored so the class again provides the no-argument {@code next()}.
     *
     * @return the next multi data set
     * @throws NoSuchElementException if any reader is exhausted
     */
    public MultiDataSet next() {
        return next(this.batchSize);
    }

    /**
     * Decompiler-generated alias of {@link #next()}, kept only so that code written against the
     * decompiled name keeps compiling.
     *
     * @return the next multi data set
     * @deprecated use {@link #next()}
     */
    @Deprecated
    public MultiDataSet m11next() {
        return next();
    }

    /**
     * Removal is not supported by this iterator.
     *
     * @throws UnsupportedOperationException always
     */
    public void remove() {
        final String message = "Remove not supported";
        throw new UnsupportedOperationException(message);
    }

    /**
     * Reads up to {@code i} examples from every reader and assembles them into one
     * {@link MultiDataSet}.
     *
     * <p>Each reader is polled independently; the number of examples in the result is the
     * smallest number any reader could supply. Inputs and outputs are produced in the order
     * they were declared. Mask arrays are attached only if at least one sequence subset
     * needed one. The configured pre-processor, if any, is applied before returning.
     *
     * @param i maximum number of examples to read; must be positive
     * @return the assembled data set
     * @throws NoSuchElementException   if any reader has no more data
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public MultiDataSet next(int i) {
        if (!hasNext()) {
            throw new NoSuchElementException("No next elements");
        }
        if (i <= 0) {
            // Previously this fell through with zero examples and failed later with an
            // unrelated index/shape error.
            throw new IllegalArgumentException("Invalid number of examples: " + i + " (must be > 0)");
        }

        Map<String, List<Collection<Writable>>> records = new HashMap<>();
        Map<String, List<Collection<Collection<Writable>>>> sequences = new HashMap<>();
        int minExamples = Integer.MAX_VALUE;

        for (Map.Entry<String, RecordReader> entry : this.recordReaders.entrySet()) {
            RecordReader reader = entry.getValue();
            List<Collection<Writable>> batch = new ArrayList<>(i);
            for (int n = 0; n < i && reader.hasNext(); n++) {
                batch.add(reader.next());
            }
            minExamples = Math.min(minExamples, batch.size());
            records.put(entry.getKey(), batch);
        }
        for (Map.Entry<String, SequenceRecordReader> entry : this.sequenceRecordReaders.entrySet()) {
            SequenceRecordReader reader = entry.getValue();
            List<Collection<Collection<Writable>>> batch = new ArrayList<>(i);
            for (int n = 0; n < i && reader.hasNext(); n++) {
                batch.add(reader.sequenceRecord());
            }
            minExamples = Math.min(minExamples, batch.size());
            sequences.put(entry.getKey(), batch);
        }
        if (minExamples == Integer.MAX_VALUE) {
            throw new RuntimeException("Error occurred during data set generation: no readers?");
        }

        // For ALIGN_END: per example, the longest sequence across all sequence readers.
        int[] longestSequence = null;
        if (this.alignmentMode == AlignmentMode.ALIGN_END) {
            longestSequence = new int[minExamples];
            for (List<Collection<Collection<Writable>>> batch : sequences.values()) {
                for (int ex = 0; ex < batch.size() && ex < minExamples; ex++) {
                    longestSequence[ex] = Math.max(longestSequence[ex], batch.get(ex).size());
                }
            }
        }

        // Longest sequence over all readers and examples; -1 means "not computed"
        // (EQUAL_LENGTH), in which case the length is taken from the data during conversion.
        int maxLength = -1;
        if (this.alignmentMode != AlignmentMode.EQUAL_LENGTH) {
            for (List<Collection<Collection<Writable>>> batch : sequences.values()) {
                for (Collection<Collection<Writable>> sequence : batch) {
                    maxLength = Math.max(maxLength, sequence.size());
                }
            }
        }

        INDArray[] features = new INDArray[this.inputs.size()];
        INDArray[] featureMasks = new INDArray[this.inputs.size()];
        if (!convertSubsets(this.inputs, records, sequences, minExamples, maxLength, longestSequence, features, featureMasks)) {
            featureMasks = null;
        }

        INDArray[] labels = new INDArray[this.outputs.size()];
        INDArray[] labelMasks = new INDArray[this.outputs.size()];
        if (!convertSubsets(this.outputs, records, sequences, minExamples, maxLength, longestSequence, labels, labelMasks)) {
            labelMasks = null;
        }

        org.nd4j.linalg.dataset.MultiDataSet multiDataSet =
                new org.nd4j.linalg.dataset.MultiDataSet(features, labels, featureMasks, labelMasks);
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(multiDataSet);
        }
        return multiDataSet;
    }

    /**
     * Converts every subset to an array (and, for sequence subsets, a mask), filling
     * {@code arrays} and {@code masks} at the subset's position.
     *
     * @return true if at least one non-null mask array was produced
     */
    private boolean convertSubsets(List<SubsetDetails> subsets,
                                   Map<String, List<Collection<Writable>>> records,
                                   Map<String, List<Collection<Collection<Writable>>>> sequences,
                                   int minExamples, int maxLength, int[] longestSequence,
                                   INDArray[] arrays, INDArray[] masks) {
        boolean anyMask = false;
        int idx = 0;
        for (SubsetDetails subset : subsets) {
            if (records.containsKey(subset.readerName)) {
                arrays[idx] = convertWritables(records.get(subset.readerName), minExamples, subset);
            } else {
                Pair<INDArray, INDArray> converted = convertWritablesSequence(
                        sequences.get(subset.readerName), minExamples, maxLength, subset, longestSequence);
                arrays[idx] = converted.getFirst();
                masks[idx] = converted.getSecond();
                if (masks[idx] != null) {
                    anyMask = true;
                }
            }
            idx++;
        }
        return anyMask;
    }

    /**
     * Converts the first {@code i} records of a non-sequence reader into a 2d array with one
     * row per example.
     *
     * <p>Depending on the subset the columns are: every value of the record, a one-hot vector
     * built from the int value at {@code subsetStart}, or the values in the inclusive column
     * range. A value is written via {@code toDouble()}; if that throws
     * {@link UnsupportedOperationException} and the value is an {@link NDArrayWritable}, its
     * array is written as the row instead.
     *
     * @param list          records read from the reader; must hold at least {@code i} entries
     * @param i             number of examples (rows) to convert
     * @param subsetDetails which columns to use and how
     * @return array with {@code i} rows
     */
    private INDArray convertWritables(List<Collection<Writable>> list, int i, SubsetDetails subsetDetails) {
        INDArray result;
        if (subsetDetails.entireReader) {
            // Width is taken from the first record.
            result = Nd4j.create(i, list.get(0).size());
        } else if (subsetDetails.oneHot) {
            result = Nd4j.zeros(i, subsetDetails.oneHotNumClasses);
        } else {
            result = Nd4j.create(i, (subsetDetails.subsetEndInclusive - subsetDetails.subsetStart) + 1);
        }

        for (int row = 0; row < i; row++) {
            Collection<Writable> record = list.get(row);
            if (subsetDetails.entireReader) {
                int col = 0;
                for (Writable value : record) {
                    try {
                        result.putScalar(row, col, value.toDouble());
                    } catch (UnsupportedOperationException e) {
                        if (!(value instanceof NDArrayWritable)) {
                            throw e;
                        }
                        // Array-valued writable: its array becomes the whole row.
                        result.putRow(row, ((NDArrayWritable) value).get());
                    }
                    col++;
                }
            } else if (subsetDetails.oneHot) {
                Writable classValue = null;
                if (record instanceof List) {
                    classValue = ((List<Writable>) record).get(subsetDetails.subsetStart);
                } else {
                    // No random access: walk the collection up to the class-index column.
                    Iterator<Writable> it = record.iterator();
                    for (int c = 0; c <= subsetDetails.subsetStart; c++) {
                        classValue = it.next();
                    }
                }
                result.putScalar(row, classValue.toInt(), 1.0d);
            } else {
                Iterator<Writable> it = record.iterator();
                // Skip the columns before the range.
                for (int c = 0; c < subsetDetails.subsetStart; c++) {
                    it.next();
                }
                int col = 0;
                for (int c = subsetDetails.subsetStart; c <= subsetDetails.subsetEndInclusive; c++) {
                    Writable value = it.next();
                    try {
                        result.putScalar(row, col, value.toDouble());
                    } catch (UnsupportedOperationException e) {
                        if (!(value instanceof NDArrayWritable)) {
                            throw e;
                        }
                        // Array-valued writable: the column range is sliced out of its array.
                        result.putRow(row, ((NDArrayWritable) value).get().get(new INDArrayIndex[]{
                                NDArrayIndex.all(),
                                NDArrayIndex.interval(subsetDetails.subsetStart, subsetDetails.subsetEndInclusive + 1)}));
                    }
                    col++;
                }
            }
        }
        return result;
    }

    /**
     * Converts the first {@code i} sequences of a sequence reader into a 3d array indexed
     * [example, column, time step] (created in 'f' order), plus an optional 2d mask array
     * indexed [example, time step].
     *
     * <p>A mask (initially all ones) is created only if some sequence in {@code list} is
     * shorter than the time dimension. In it, every time step of an example that received no
     * data is set to 0: the steps before the start offset (ALIGN_END) and the steps after the
     * last written step. The latter also covers ALIGN_END examples whose longest sequence is
     * shorter than the batch-wide maximum, which were previously left marked as valid.
     *
     * @param list          sequences read from the reader; must hold at least {@code i} entries
     * @param i             number of examples to convert
     * @param i2            length of the time dimension, or -1 to use the length of the first sequence
     * @param subsetDetails which columns to use and how
     * @param iArr          per-example longest sequence length; only read for ALIGN_END
     * @return pair of (data array, mask array or null)
     */
    private Pair<INDArray, INDArray> convertWritablesSequence(List<Collection<Collection<Writable>>> list, int i, int i2, SubsetDetails subsetDetails, int[] iArr) {
        int maxLength = (i2 == -1) ? list.get(0).size() : i2;

        int columns;
        if (subsetDetails.entireReader) {
            // Width is taken from the first time step of the first sequence.
            columns = list.get(0).iterator().next().size();
        } else if (subsetDetails.oneHot) {
            columns = subsetDetails.oneHotNumClasses;
        } else {
            columns = (subsetDetails.subsetEndInclusive - subsetDetails.subsetStart) + 1;
        }
        INDArray data = Nd4j.create(new int[]{i, columns, maxLength}, 'f');

        boolean needMask = false;
        for (Collection<Collection<Writable>> sequence : list) {
            if (sequence.size() < maxLength) {
                needMask = true;
                break;
            }
        }
        INDArray mask = needMask ? Nd4j.ones(i, maxLength) : null;

        for (int ex = 0; ex < i; ex++) {
            Collection<Collection<Writable>> sequence = list.get(ex);
            int startOffset;
            if (this.alignmentMode == AlignmentMode.ALIGN_START || this.alignmentMode == AlignmentMode.EQUAL_LENGTH) {
                startOffset = 0;
            } else {
                // ALIGN_END: shift so this sequence ends where the example's longest sequence ends.
                startOffset = iArr[ex] - sequence.size();
            }

            int stepsWritten = 0;
            for (Collection<Writable> timeStep : sequence) {
                int t = startOffset + stepsWritten;
                stepsWritten++;
                if (subsetDetails.entireReader) {
                    int col = 0;
                    for (Writable value : timeStep) {
                        try {
                            data.putScalar(ex, col, t, value.toDouble());
                        } catch (UnsupportedOperationException e) {
                            if (!(value instanceof NDArrayWritable)) {
                                throw e;
                            }
                            // Array-valued writable: its array fills the whole column axis at (ex, t).
                            data.get(new INDArrayIndex[]{NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)})
                                    .putRow(0, ((NDArrayWritable) value).get());
                        }
                        col++;
                    }
                } else if (subsetDetails.oneHot) {
                    Writable classValue = null;
                    if (timeStep instanceof List) {
                        classValue = ((List<Writable>) timeStep).get(subsetDetails.subsetStart);
                    } else {
                        // No random access: walk the collection up to the class-index column.
                        Iterator<Writable> it = timeStep.iterator();
                        for (int c = 0; c <= subsetDetails.subsetStart; c++) {
                            classValue = it.next();
                        }
                    }
                    data.putScalar(ex, classValue.toInt(), t, 1.0d);
                } else {
                    Iterator<Writable> it = timeStep.iterator();
                    // Skip the columns before the range.
                    for (int c = 0; c < subsetDetails.subsetStart; c++) {
                        it.next();
                    }
                    int col = 0;
                    for (int c = subsetDetails.subsetStart; c <= subsetDetails.subsetEndInclusive; c++) {
                        Writable value = it.next();
                        try {
                            data.putScalar(ex, col, t, value.toDouble());
                        } catch (UnsupportedOperationException e) {
                            if (!(value instanceof NDArrayWritable)) {
                                throw e;
                            }
                            // Array-valued writable: the column range is sliced out of its array.
                            data.get(new INDArrayIndex[]{NDArrayIndex.point(ex), NDArrayIndex.all(), NDArrayIndex.point(t)})
                                    .putRow(0, ((NDArrayWritable) value).get().get(new INDArrayIndex[]{
                                            NDArrayIndex.all(),
                                            NDArrayIndex.interval(subsetDetails.subsetStart, subsetDetails.subsetEndInclusive + 1)}));
                        }
                        col++;
                    }
                }
            }

            if (needMask) {
                // Padding before the data (startOffset is only non-zero for ALIGN_END).
                if (this.alignmentMode == AlignmentMode.ALIGN_END) {
                    for (int t = 0; t < startOffset; t++) {
                        mask.putScalar(ex, t, 0.0d);
                    }
                }
                // Padding after the data: always for ALIGN_START, and for any other mode whenever
                // the data stops short of the time dimension.
                for (int t = startOffset + stepsWritten; t < maxLength; t++) {
                    mask.putScalar(ex, t, 0.0d);
                }
            }
        }
        return new Pair<>(data, mask);
    }

    /**
     * Installs the pre-processor that is applied to every data set before it is returned.
     *
     * @param multiDataSetPreProcessor the pre-processor, or null to disable pre-processing
     */
    public void setPreProcessor(MultiDataSetPreProcessor multiDataSetPreProcessor) {
        preProcessor = multiDataSetPreProcessor;
    }

    /**
     * Resets every underlying reader: first all record readers, then all sequence readers.
     */
    public void reset() {
        for (RecordReader reader : this.recordReaders.values()) {
            reader.reset();
        }
        for (SequenceRecordReader sequenceReader : this.sequenceRecordReaders.values()) {
            sequenceReader.reset();
        }
    }

    /**
     * Reports whether another data set can be produced.
     *
     * @return false as soon as any record reader or sequence reader has no more data;
     *         true otherwise (including when no readers are registered)
     */
    public boolean hasNext() {
        for (RecordReader reader : this.recordReaders.values()) {
            if (!reader.hasNext()) {
                return false;
            }
        }
        for (SequenceRecordReader sequenceReader : this.sequenceRecordReaders.values()) {
            if (!sequenceReader.hasNext()) {
                return false;
            }
        }
        return true;
    }
}
