package org.nd4j.linalg.dataset.api.iterator;

import java.util.Objects;

import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DataSetIterator} that replays a wrapped iterator several times.
 *
 * <p>When the wrapped iterator runs out of data and fewer than {@code numPasses} rollovers have
 * happened, the wrapped iterator is {@link DataSetIterator#reset() reset} and iteration continues.
 * Because the first pass happens before any rollover, the wrapped data is traversed
 * {@code numPasses + 1} times in total.
 *
 * <p>NOTE(review): if {@code numPasses} was meant to be the total number of epochs, this is one
 * pass too many. The behavior is kept as-is because callers may depend on the current epoch count.
 *
 * <p>If a {@link DataSetPreProcessor} is set, it is applied to every {@link DataSet} returned from
 * {@link #next()} and {@link #next(int)}. The metadata methods ({@code totalExamples},
 * {@code inputColumns}, {@code batch}, {@code cursor}, ...) delegate to the wrapped iterator.
 *
 * <p>This class keeps mutable per-iteration state and has no synchronization.
 */
public class MultipleEpochsIterator implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(MultipleEpochsIterator.class);

    /** Maximum number of times the wrapped iterator is reset once it is exhausted. */
    private final int numPasses;
    /** The iterator whose data is replayed. */
    private final DataSetIterator iter;
    /** Optional pre-processor applied to each returned DataSet; may be null. */
    private DataSetPreProcessor preProcessor;
    /** Batches handed out in the current pass; only used for logging. */
    private int batch = 0;
    /** Number of rollovers (resets of the wrapped iterator) performed so far. */
    private int passes = 0;

    /**
     * Creates an iterator that replays {@code iter} after it is exhausted.
     *
     * @param numPasses how many times the wrapped iterator may be reset after running out of data;
     *     values {@code <= 0} mean the data is traversed once
     * @param iter the iterator to replay
     * @throws NullPointerException if {@code iter} is null
     */
    public MultipleEpochsIterator(int numPasses, DataSetIterator iter) {
        this.numPasses = numPasses;
        this.iter = Objects.requireNonNull(iter, "iter");
    }

    /**
     * Starts a new pass if the wrapped iterator is exhausted and rollovers remain, then counts the
     * batch about to be fetched. Shared by {@link #next()} and {@link #next(int)}.
     */
    private void rollOverIfExhausted() {
        if (!iter.hasNext() && passes < numPasses) {
            passes++;
            batch = 0;
            log.info("Epoch {} batch {}", passes, batch);
            iter.reset();
        }
        batch++;
    }

    /** Applies the configured pre-processor (if any) in place and returns the same DataSet. */
    private DataSet applyPreProcessor(DataSet dataSet) {
        if (preProcessor != null) {
            preProcessor.preProcess(dataSet);
        }
        return dataSet;
    }

    /**
     * Returns the next DataSet of the requested size, resetting the wrapped iterator first if the
     * current pass is exhausted and rollovers remain.
     *
     * @param num number of examples requested, forwarded to the wrapped iterator
     * @return the (possibly pre-processed) DataSet from the wrapped iterator
     */
    @Override
    public DataSet next(int num) {
        rollOverIfExhausted();
        return applyPreProcessor(iter.next(num));
    }

    @Override
    public int totalExamples() {
        return iter.totalExamples();
    }

    @Override
    public int inputColumns() {
        return iter.inputColumns();
    }

    @Override
    public int totalOutcomes() {
        return iter.totalOutcomes();
    }

    /** Restarts from the first pass and resets the wrapped iterator. */
    @Override
    public void reset() {
        passes = 0;
        batch = 0;
        iter.reset();
    }

    /** Returns the wrapped iterator's batch size (not the internal per-pass batch counter). */
    @Override
    public int batch() {
        return iter.batch();
    }

    @Override
    public int cursor() {
        return iter.cursor();
    }

    @Override
    public int numExamples() {
        return iter.numExamples();
    }

    /**
     * Sets the pre-processor applied to every DataSet this iterator returns.
     *
     * @param preProcessor the pre-processor, or null to disable pre-processing
     */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.preProcessor = preProcessor;
    }

    /**
     * Returns true while the wrapped iterator has data or further rollovers remain.
     *
     * <p>NOTE(review): if the wrapped iterator yields nothing even after a reset, this still reports
     * true while rollovers remain. What {@link #next()} then does depends on the wrapped iterator.
     */
    @Override
    public boolean hasNext() {
        return iter.hasNext() || passes < numPasses;
    }

    /**
     * Returns the next DataSet, resetting the wrapped iterator first if the current pass is
     * exhausted and rollovers remain.
     *
     * @return the (possibly pre-processed) DataSet from the wrapped iterator
     */
    @Override
    public DataSet next() {
        rollOverIfExhausted();
        return applyPreProcessor(iter.next());
    }

    /** Delegates removal to the wrapped iterator. */
    @Override
    public void remove() {
        iter.remove();
    }
}
