package org.deeplearning4j.datasets.canova;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/datasets/canova/SequenceRecordReaderDataSetIterator.class */
/**
 * {@link DataSetIterator} that builds time-series mini-batches from two parallel
 * {@link SequenceRecordReader}s: one supplying feature sequences and one supplying label sequences.
 *
 * <p>Each sequence record is a collection of time steps, and each time step is a collection of
 * {@link Writable} values. A single sequence is first converted to a
 * {@code [timeSeriesLength, numColumns]} matrix. The matrices of a mini-batch are then stacked into
 * one array of shape {@code [numExamples, numColumns, timeSeriesLength]}. All sequences within one
 * mini-batch must have the same shape.
 *
 * <p>Labels are handled in one of two ways:
 * <ul>
 *   <li>In regression mode, every value of a label time step is used directly.</li>
 *   <li>Otherwise, the first value of each label time step is read as an integer class index and
 *       one-hot encoded over {@code numPossibleLabels} classes.</li>
 * </ul>
 */
public class SequenceRecordReaderDataSetIterator implements DataSetIterator {
    private final SequenceRecordReader recordReader;
    private final SequenceRecordReader labelsReader;
    private final int miniBatchSize;
    private final boolean regression;
    private final int numPossibleLabels;
    private DataSetPreProcessor preProcessor;
    private int cursor = 0;
    // -1 until the first batch has been loaded.
    private int inputColumns = -1;
    private int totalOutcomes = -1;
    // A batch loaded early by preLoad() (to answer inputColumns()/totalOutcomes()) that has not yet
    // been handed to the caller. It is stored WITHOUT preprocessing.
    private boolean useStored = false;
    private DataSet stored = null;

    /**
     * @param featuresReader    reader producing one feature sequence per record
     * @param labels            reader producing the matching label sequence per record
     * @param miniBatchSize     number of sequences returned by {@link #next()}
     * @param numPossibleLabels number of classes for one-hot encoding (ignored when {@code regression})
     * @param regression        if true, label values are used as-is instead of one-hot encoded
     */
    public SequenceRecordReaderDataSetIterator(SequenceRecordReader featuresReader, SequenceRecordReader labels,
                                               int miniBatchSize, int numPossibleLabels, boolean regression) {
        this.recordReader = featuresReader;
        this.labelsReader = labels;
        this.miniBatchSize = miniBatchSize;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /** Returns true if a buffered batch is pending or the feature reader has more sequences. */
    @Override // java.util.Iterator
    public boolean hasNext() {
        // A batch buffered by preLoad() must still be delivered even if the reader is exhausted.
        return this.useStored || this.recordReader.hasNext();
    }

    /** Returns the next mini-batch of (at most) {@code miniBatchSize} sequences. */
    @Override // java.util.Iterator
    public DataSet next() {
        return next(this.miniBatchSize);
    }

    /**
     * Returns the next mini-batch, applying the preprocessor (if any) exactly once.
     *
     * <p>If a batch was buffered by {@code inputColumns()}/{@code totalOutcomes()}, that batch is
     * returned as-is regardless of {@code num}.
     *
     * @param num maximum number of sequences to read
     * @throws NoSuchElementException   if no more sequences are available
     * @throws IllegalArgumentException if a new batch must be read and {@code num <= 0}
     */
    public DataSet next(int num) {
        DataSet dataSet;
        if (this.useStored) {
            this.useStored = false;
            dataSet = this.stored;
            this.stored = null;
        } else {
            dataSet = loadBatch(num);
        }
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    public int totalExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    /** Number of feature columns; loads (and buffers) the first batch if not yet known. */
    public int inputColumns() {
        if (this.inputColumns == -1) {
            preLoad();
        }
        return this.inputColumns;
    }

    /** Number of label columns; loads (and buffers) the first batch if not yet known. */
    public int totalOutcomes() {
        if (this.totalOutcomes == -1) {
            preLoad();
        }
        return this.totalOutcomes;
    }

    /**
     * Reads one batch ahead so column counts are known. The batch is kept un-preprocessed; next(int)
     * preprocesses it when it is finally returned, so preprocessing is not applied twice.
     */
    private void preLoad() {
        this.stored = loadBatch(this.miniBatchSize);
        this.useStored = true;
    }

    /** Currently a no-op: the underlying readers are not rewound by this iterator. */
    public void reset() {
    }

    public int batch() {
        return this.miniBatchSize;
    }

    public int cursor() {
        return this.cursor;
    }

    public int numExamples() {
        throw new UnsupportedOperationException("Not supported");
    }

    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    @Override // java.util.Iterator
    public void remove() {
        throw new UnsupportedOperationException("Remove not supported for this iterator");
    }

    /**
     * Reads up to {@code num} feature/label sequence pairs and packs them into a DataSet.
     *
     * <p>Also advances the cursor and records the column counts on first use. No preprocessing is
     * applied here.
     */
    private DataSet loadBatch(int num) {
        if (!this.recordReader.hasNext()) {
            throw new NoSuchElementException();
        }
        if (num <= 0) {
            throw new IllegalArgumentException("Mini-batch size must be positive, got " + num);
        }
        List<INDArray> featureList = new ArrayList<INDArray>(num);
        List<INDArray> labelList = new ArrayList<INDArray>(num);
        while (featureList.size() < num && this.recordReader.hasNext()) {
            Collection<Collection<Writable>> featureSequence = this.recordReader.sequenceRecord();
            Collection<Collection<Writable>> labelSequence = this.labelsReader.sequenceRecord();
            featureList.add(getFeatures(featureSequence));
            labelList.add(getLabels(labelSequence));
        }
        INDArray features = stackTimeSeries(featureList, "features");
        INDArray labels = stackTimeSeries(labelList, "labels");
        this.cursor += featureList.size();
        if (this.inputColumns == -1) {
            this.inputColumns = features.size(1);
        }
        if (this.totalOutcomes == -1) {
            this.totalOutcomes = labels.size(1);
        }
        return new DataSet(features, labels);
    }

    /**
     * Stacks per-example {@code [timeSeriesLength, numColumns]} matrices into a single array of shape
     * {@code [numExamples, numColumns, timeSeriesLength]}.
     *
     * @throws IllegalStateException if the examples do not all share the first example's shape
     */
    private static INDArray stackTimeSeries(List<INDArray> examples, String what) {
        INDArray first = examples.get(0);
        int timeSeriesLength = first.size(0);
        int numColumns = first.size(1);
        INDArray out = Nd4j.create(new int[]{examples.size(), numColumns, timeSeriesLength});
        for (int i = 0; i < examples.size(); i++) {
            INDArray example = examples.get(i);
            if (example.size(0) != timeSeriesLength || example.size(1) != numColumns) {
                throw new IllegalStateException("All " + what + " sequences in a mini-batch must have the same shape:"
                        + " expected [" + timeSeriesLength + "," + numColumns + "] but example " + i + " has ["
                        + example.size(0) + "," + example.size(1) + "]");
            }
            // The tensor along dimensions {1,2} for example i has shape [numColumns, timeSeriesLength],
            // while the per-example matrix is [timeSeriesLength, numColumns]: transpose before copying.
            out.tensorAlongDimension(i, new int[]{1, 2}).assign(example.transpose());
        }
        return out;
    }

    /** Converts a feature sequence to a {@code [timeSteps, valuesPerStep]} matrix. */
    private INDArray getFeatures(Collection<Collection<Writable>> sequence) {
        return sequenceToMatrix(sequence, "features");
    }

    /**
     * Converts a label sequence to a matrix with one row per time step.
     *
     * <p>In regression mode the values are copied directly, exactly like features. Otherwise the first
     * value of each time step is treated as a class index and one-hot encoded over
     * {@code numPossibleLabels} columns.
     */
    private INDArray getLabels(Collection<Collection<Writable>> sequence) {
        if (this.regression) {
            return sequenceToMatrix(sequence, "labels");
        }
        if (sequence.isEmpty()) {
            throw new IllegalStateException("Cannot convert an empty labels sequence");
        }
        INDArray out = Nd4j.create(new int[]{sequence.size(), this.numPossibleLabels});
        int step = 0;
        for (Collection<Writable> timeStep : sequence) {
            int classIndex = timeStep.iterator().next().toInt();
            out.getRow(step).assign(FeatureUtil.toOutcomeVector(classIndex, this.numPossibleLabels));
            step++;
        }
        return out;
    }

    /**
     * Copies a sequence into a {@code [timeSteps, valuesPerStep]} matrix.
     *
     * <p>The column count is taken from the first time step.
     */
    private static INDArray sequenceToMatrix(Collection<Collection<Writable>> sequence, String what) {
        if (sequence.isEmpty()) {
            throw new IllegalStateException("Cannot convert an empty " + what + " sequence");
        }
        INDArray out = null;
        int step = 0;
        for (Collection<Writable> timeStep : sequence) {
            if (out == null) {
                out = Nd4j.create(new int[]{sequence.size(), timeStep.size()});
            }
            int column = 0;
            for (Writable value : timeStep) {
                // Row = time step, column = position within the time step.
                out.put(step, column++, Double.valueOf(value.toDouble()));
            }
            step++;
        }
        return out;
    }
}
