package org.deeplearning4j.datasets.canova;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.canova.api.io.WritableConverter;
import org.canova.api.io.converters.SelfWritableConverter;
import org.canova.api.io.converters.WritableConverterException;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.records.reader.SequenceRecordReader;
import org.canova.api.writable.Writable;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

/* loaded from: input_file:org/deeplearning4j/datasets/canova/RecordReaderDataSetIterator.class */
/**
 * Adapts a Canova {@link RecordReader} to a {@link DataSetIterator}.
 *
 * <p>Each record (a collection of {@link Writable}s) becomes one example row. Up to
 * {@code batchSize} rows are vertically stacked into a single {@link DataSet}.
 *
 * <p><b>Labels.</b>
 * <ul>
 *   <li>If {@code labelIndex >= 0}, the column at that position is the label. For classification
 *       it is turned into an outcome vector of length {@code numPossibleLabels}. For regression
 *       it becomes a scalar.</li>
 *   <li>If {@code numPossibleLabels >= 1} and {@code labelIndex < 0}, the last column of the
 *       first record read is used as the label column.</li>
 *   <li>With no label column, the feature vector is also used as the label.</li>
 * </ul>
 *
 * <p><b>Features.</b> All non-label columns are parsed as doubles. Empty columns are left at
 * the value {@code Nd4j.create} initialises them with.
 *
 * <p><b>Sequences.</b> When the reader is a {@link SequenceRecordReader}, each element of a
 * sequence record is emitted as its own row.
 */
public class RecordReaderDataSetIterator implements DataSetIterator {
    private static final int DEFAULT_BATCH_SIZE = 10;

    private final RecordReader recordReader;
    private final WritableConverter converter;
    private final int batchSize;
    /** Not final: may be inferred from the first record (see {@link #getDataSet}). */
    private int labelIndex;
    private final int numPossibleLabels;
    private final boolean regression;
    /** Remaining rows of the sequence currently being consumed (sequence readers only). */
    private Iterator<Collection<Writable>> sequenceIter;
    /** Most recently produced batch. */
    private DataSet last;
    /** True when {@link #last} was prefetched and must be returned by the next {@code next()} call. */
    private boolean useCurrent;
    /** Optional preprocessor applied to every newly built batch; may be null. */
    private DataSetPreProcessor preProcessor;

    /** Creates an iterator without a label column, using a {@link SelfWritableConverter}. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize) {
        this(recordReader, new SelfWritableConverter(), batchSize, -1, -1);
    }

    /** Creates a classification iterator using a {@link SelfWritableConverter}. */
    public RecordReaderDataSetIterator(RecordReader recordReader, int batchSize, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), batchSize, labelIndex, numPossibleLabels);
    }

    /** Creates an iterator with the default batch size and no label column. */
    public RecordReaderDataSetIterator(RecordReader recordReader) {
        this(recordReader, new SelfWritableConverter());
    }

    /** Creates a classification iterator with the default batch size (10). */
    public RecordReaderDataSetIterator(RecordReader recordReader, int labelIndex, int numPossibleLabels) {
        this(recordReader, new SelfWritableConverter(), DEFAULT_BATCH_SIZE, labelIndex, numPossibleLabels);
    }

    /**
     * Primary constructor.
     *
     * @param recordReader      source of records
     * @param converter         applied to the label column before parsing; may be null
     * @param batchSize         number of rows per batch returned by {@link #next()}
     * @param labelIndex        column holding the label, or a negative value for "none / infer"
     * @param numPossibleLabels number of classes for classification; values &lt; 1 mean "no label"
     *                          unless {@code labelIndex >= 0}
     * @param regression        if true the label column is used as a raw scalar target
     */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
            int labelIndex, int numPossibleLabels, boolean regression) {
        this.recordReader = recordReader;
        this.converter = converter;
        this.batchSize = batchSize;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.regression = regression;
    }

    /** Creates a classification iterator with an explicit converter. */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int batchSize,
            int labelIndex, int numPossibleLabels) {
        this(recordReader, converter, batchSize, labelIndex, numPossibleLabels, false);
    }

    /** Creates an iterator with an explicit converter, the default batch size and no label column. */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter) {
        this(recordReader, converter, DEFAULT_BATCH_SIZE, -1, -1);
    }

    /** Creates a classification iterator with an explicit converter and the default batch size (10). */
    public RecordReaderDataSetIterator(RecordReader recordReader, WritableConverter converter, int labelIndex,
            int numPossibleLabels) {
        this(recordReader, converter, DEFAULT_BATCH_SIZE, labelIndex, numPossibleLabels);
    }

    /**
     * Reads up to {@code num} records and stacks them into one {@link DataSet}.
     *
     * <p>If a batch was prefetched by {@link #inputColumns()} or {@link #totalOutcomes()}, that
     * batch is returned instead and {@code num} is ignored. If no records remain, the previously
     * returned batch is returned again. That batch may be {@code null} if nothing was ever read.
     */
    public DataSet next(int num) {
        if (useCurrent) {
            useCurrent = false;
            return last;
        }
        List<INDArray> featureRows = new ArrayList<INDArray>();
        List<INDArray> labelRows = new ArrayList<INDArray>();
        for (int row = 0; row < num && hasMoreRecords(); row++) {
            DataSet example = getDataSet(nextRecord());
            featureRows.add(example.getFeatureMatrix());
            labelRows.add(example.getLabels());
        }
        if (featureRows.isEmpty()) {
            // Historical behaviour: an exhausted reader yields the previous batch again.
            return last;
        }
        DataSet batch = new DataSet(
                Nd4j.vstack(featureRows.toArray(new INDArray[0])),
                Nd4j.vstack(labelRows.toArray(new INDArray[0])));
        if (preProcessor != null) {
            preProcessor.preProcess(batch);
        }
        last = batch;
        return batch;
    }

    /** Returns true while either the current sequence or the underlying reader has rows left. */
    private boolean hasMoreRecords() {
        return (sequenceIter != null && sequenceIter.hasNext()) || recordReader.hasNext();
    }

    /**
     * Returns the next single row.
     *
     * <p>For a {@link SequenceRecordReader}, a new sequence record is pulled only when the
     * current one is used up.
     */
    private Collection<Writable> nextRecord() {
        if (recordReader instanceof SequenceRecordReader) {
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                sequenceIter = ((SequenceRecordReader) recordReader).sequenceRecord().iterator();
            }
            return sequenceIter.next();
        }
        return recordReader.next();
    }

    /** Converts one record into a single-example {@link DataSet} (features, label). */
    private DataSet getDataSet(Collection<Writable> record) {
        List<Writable> columns = record instanceof List
                ? (List<Writable>) record
                : new ArrayList<Writable>(record);
        if (numPossibleLabels >= 1 && labelIndex < 0) {
            // Infer the label as the last column; the inferred index is kept for later records.
            labelIndex = columns.size() - 1;
        }
        boolean hasLabel = labelIndex >= 0;
        if (hasLabel && labelIndex >= columns.size()) {
            throw new IllegalStateException("Label index " + labelIndex
                    + " is out of range for a record with " + columns.size() + " columns");
        }

        INDArray features = Nd4j.create(hasLabel ? columns.size() - 1 : columns.size());
        INDArray label = null;
        // Separate counter so features after the label column are not shifted or written out of bounds.
        int featureIdx = 0;
        for (int col = 0; col < columns.size(); col++) {
            Writable value = columns.get(col);
            if (hasLabel && col == labelIndex) {
                if (!value.toString().isEmpty()) {
                    label = toLabel(value);
                }
            } else {
                if (!value.toString().isEmpty()) {
                    features.putScalar(featureIdx, Double.parseDouble(value.toString()));
                }
                featureIdx++;
            }
        }
        return new DataSet(features, hasLabel ? label : features);
    }

    /**
     * Turns the raw label column into a label array.
     *
     * <p>The converter is applied first, if one is set. A converter failure is reported and the
     * unconverted value is used instead. For classification, a value at or above
     * {@code numPossibleLabels} is shifted down by one before one-hot encoding.
     *
     * @throws IllegalStateException for classification when {@code numPossibleLabels < 1}
     */
    private INDArray toLabel(Writable raw) {
        Writable value = raw;
        if (converter != null) {
            try {
                value = converter.convert(value);
            } catch (WritableConverterException e) {
                // Best effort (as before): report and fall back to the unconverted value.
                e.printStackTrace();
            }
        }
        if (regression) {
            return Nd4j.scalar(Double.valueOf(value.toString()));
        }
        if (numPossibleLabels < 1) {
            throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
        }
        int classIndex = Double.valueOf(value.toString()).intValue();
        if (classIndex >= numPossibleLabels) {
            // NOTE(review): presumably tolerates 1-based class ids; only adjusts by a single step.
            classIndex--;
        }
        return FeatureUtil.toOutcomeVector(classIndex, numPossibleLabels);
    }

    public int totalExamples() {
        throw new UnsupportedOperationException();
    }

    /** Number of feature columns, prefetching (but not consuming) a batch if none was read yet. */
    public int inputColumns() {
        return peekBatch().numInputs();
    }

    /** Number of label columns, prefetching (but not consuming) a batch if none was read yet. */
    public int totalOutcomes() {
        return peekBatch().numOutcomes();
    }

    /**
     * Returns the last batch.
     *
     * <p>If no batch has been read yet, one is fetched first and marked so that the following
     * {@code next()} call returns it instead of skipping it.
     */
    private DataSet peekBatch() {
        if (last == null) {
            DataSet batch = next();
            if (batch == null) {
                throw new IllegalStateException("Record reader has no data; cannot determine data set shape");
            }
            useCurrent = true;
        }
        return last;
    }

    /** NOTE(review): no-op — the underlying RecordReader is not rewound, so iteration cannot restart. */
    public void reset() {
    }

    public int batch() {
        return batchSize;
    }

    public int cursor() {
        throw new UnsupportedOperationException();
    }

    public int numExamples() {
        throw new UnsupportedOperationException();
    }

    /** Sets a preprocessor applied to every subsequently built batch; {@code null} disables it. */
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.preProcessor = dataSetPreProcessor;
    }

    /** True if a prefetched batch is pending or any row remains to be read. */
    @Override // java.util.Iterator
    public boolean hasNext() {
        return useCurrent || hasMoreRecords();
    }

    /** Returns the next batch of up to {@link #batch()} rows. */
    @Override // java.util.Iterator
    public DataSet next() {
        return next(batchSize);
    }

    @Override // java.util.Iterator
    public void remove() {
        throw new UnsupportedOperationException();
    }
}
