package org.deeplearning4j.datasets.iterator.impl;

import java.util.concurrent.atomic.AtomicLong;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/datasets/iterator/impl/BenchmarkMultiDataSetIterator.class */
/**
 * A {@link MultiDataSetIterator} for benchmarking that hands out the same pre-built feature and
 * label arrays on every call to {@link #next()}, until a fixed number of iterations is reached.
 *
 * <p>Only the outer {@code INDArray[]} containers are copied per call; the {@link INDArray}
 * instances themselves are shared by every returned {@link MultiDataSet}, so consumers must not
 * modify them in place.
 */
public class BenchmarkMultiDataSetIterator implements MultiDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(BenchmarkMultiDataSetIterator.class);

    /** Feature arrays returned (by reference) from every {@link #next()} call. */
    private final INDArray[] baseFeatures;
    /** Label arrays returned (by reference) from every {@link #next()} call. */
    private final INDArray[] baseLabels;
    /** Number of {@link #next()} calls after which {@link #hasNext()} reports {@code false}. */
    private final long limit;
    /** Number of {@link #next()} calls since construction or the last {@link #reset()}. */
    private final AtomicLong counter = new AtomicLong(0);

    /**
     * Creates an iterator backed by random features and constant labels.
     *
     * @param featuresShape   shape of each feature array, one entry per input; element
     *                        {@code featuresShape[k][0]} is also used as the row count of label
     *                        array {@code k}
     * @param numLabels       number of label columns for each input; must have the same length as
     *                        {@code featuresShape}
     * @param totalIterations number of minibatches to return before {@link #hasNext()} is false
     * @throws IllegalArgumentException if {@code featuresShape} and {@code numLabels} differ in length
     */
    public BenchmarkMultiDataSetIterator(int[][] featuresShape, int[] numLabels, int totalIterations) {
        if (featuresShape.length != numLabels.length) {
            throw new IllegalArgumentException("Number of input features must match length of input labels.");
        }
        int count = featuresShape.length;
        this.baseFeatures = new INDArray[count];
        this.baseLabels = new INDArray[count];
        for (int k = 0; k < count; k++) {
            baseFeatures[k] = Nd4j.rand(featuresShape[k]);
            baseLabels[k] = Nd4j.create(featuresShape[k][0], numLabels[k]);
            // Every row gets a 1.0 in column 1; this presumes numLabels[k] >= 2.
            baseLabels[k].getColumn(1L).assign(1.0);
        }
        // Flush any queued ops so array creation is not counted in the benchmark itself.
        Nd4j.getExecutioner().commit();
        this.limit = totalIterations;
    }

    /**
     * Creates an iterator that repeatedly returns copies of the given example's arrays.
     *
     * @param example         source of the feature and label arrays; each array is duplicated
     *                        once here, so later changes to {@code example} are not observed
     * @param totalIterations number of minibatches to return before {@link #hasNext()} is false
     */
    public BenchmarkMultiDataSetIterator(MultiDataSet example, int totalIterations) {
        INDArray[] features = example.getFeatures();
        INDArray[] labels = example.getLabels();

        this.baseFeatures = new INDArray[features.length];
        for (int k = 0; k < features.length; k++) {
            baseFeatures[k] = features[k].dup();
        }

        this.baseLabels = new INDArray[labels.length];
        for (int k = 0; k < labels.length; k++) {
            // Labels go into baseLabels (previously they overwrote baseFeatures, leaving labels null).
            baseLabels[k] = labels[k].dup();
        }

        Nd4j.getExecutioner().commit();
        this.limit = totalIterations;
    }

    /** Not supported: this iterator always returns its fixed minibatch. */
    @Override
    public MultiDataSet next(int num) {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** Resets the iteration counter so that {@link #hasNext()} is true again (for a positive limit). */
    @Override
    public void reset() {
        counter.set(0L);
    }

    /** Ignored: this iterator does not apply preprocessing. */
    @Override
    public void setPreProcessor(MultiDataSetPreProcessor preProcessor) {
    }

    /** Always returns {@code null}; no preprocessor is ever stored. */
    @Override
    public MultiDataSetPreProcessor getPreProcessor() {
        return null;
    }

    @Override
    public boolean hasNext() {
        return counter.get() < limit;
    }

    /**
     * Returns a {@link MultiDataSet} wrapping the shared base arrays and increments the counter.
     *
     * <p>Note: this does not itself enforce the limit; callers are expected to check
     * {@link #hasNext()}.
     */
    @Override
    public MultiDataSet next() {
        counter.incrementAndGet();
        // Shallow copies: fresh outer arrays, same INDArray instances (avoids per-batch allocation).
        return new MultiDataSet(baseFeatures.clone(), baseLabels.clone());
    }

    /** No-op: there is nothing to remove from this synthetic source. */
    @Override
    public void remove() {
    }
}
