package org.nd4j.linalg.dataset;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.nd4j.common.util.ND4JFileUtils;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.class */
/**
 * {@link DataSetIterator} that splits a {@link DataSet} into fixed-size mini-batches, writes every
 * batch to disk once at construction time, and reads the batches back one at a time in
 * {@link #next()}.
 *
 * <p>Each batch is stored as two files inside a randomly named sub-directory of a parent
 * directory: {@code <uuid>.bin} holds the features and {@code <uuid>.labels.bin} holds the labels.
 * Only {@code numExamples / batchSize} complete batches are written; trailing examples that do
 * not fill a whole batch are not included.
 */
public class MiniBatchFileDataSetIterator implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(MiniBatchFileDataSetIterator.class);

    private int batchSize;
    /** One entry per batch: {@code {featuresFilePath, labelsFilePath}}. */
    private List<String[]> paths;
    /** Index of the batch the next call to {@link #next()} will return. */
    private int currIdx;
    private File rootDir;
    private int totalExamples;
    private int totalLabels;
    private int totalBatches;
    private DataSetPreProcessor dataSetPreProcessor;

    /**
     * Creates the iterator in the ND4J temp directory and deletes the batch files at JVM exit.
     *
     * @param dataSet   the data to split into batches
     * @param batchSize number of examples per batch
     * @throws IOException if the batch files cannot be written
     */
    public MiniBatchFileDataSetIterator(DataSet dataSet, int batchSize) throws IOException {
        this(dataSet, batchSize, true);
    }

    /**
     * Splits {@code dataSet} into batches and writes them under a fresh sub-directory of
     * {@code parentDir}.
     *
     * @param dataSet      the data to split into batches
     * @param batchSize    number of examples per batch; must be positive and not larger than the
     *                     number of examples
     * @param deleteOnExit if true, a JVM shutdown hook recursively deletes the directory created
     *                     here
     * @param parentDir    directory under which the batch directory is created
     * @throws IllegalArgumentException if {@code batchSize} is not positive or exceeds the number of
     *                                  examples
     * @throws IOException              if the directory cannot be created or a batch cannot be written
     */
    public MiniBatchFileDataSetIterator(DataSet dataSet, int batchSize, boolean deleteOnExit, File parentDir)
            throws IOException {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be positive, got " + batchSize);
        }
        int numExamples = dataSet.numExamples();
        if (numExamples < batchSize) {
            throw new IllegalArgumentException("Number of examples smaller than batch size ("
                    + numExamples + " < " + batchSize + ")");
        }
        this.batchSize = batchSize;

        final File dir = new File(parentDir, UUID.randomUUID().toString());
        if (!dir.mkdirs() && !dir.isDirectory()) {
            throw new IOException("Unable to create directory " + dir.getAbsolutePath());
        }
        this.rootDir = dir;

        if (deleteOnExit) {
            // Capture the directory created here rather than reading the mutable rootDir field at
            // shutdown: otherwise setRootDir(...) would redirect the recursive delete to an
            // arbitrary directory. Not referencing `this` also avoids pinning the iterator in memory.
            // Registered before any batch is written so partially written output is also removed.
            Runtime.getRuntime().addShutdownHook(new Thread(() -> {
                try {
                    FileUtils.deleteDirectory(dir);
                } catch (IOException e) {
                    log.error("Failed to delete mini-batch directory {}", dir, e);
                }
            }));
        }

        this.currIdx = 0;
        this.paths = new ArrayList<>();
        this.totalExamples = numExamples;
        this.totalLabels = dataSet.numOutcomes();
        this.totalBatches = numExamples / batchSize;
        for (int b = 0; b < totalBatches; b++) {
            int start = b * batchSize;
            DataSet batch = new DataSet(
                    dataSet.getFeatures().get(NDArrayIndex.interval(start, start + batchSize)),
                    dataSet.getLabels().get(NDArrayIndex.interval(start, start + batchSize)));
            paths.add(writeData(batch));
        }
    }

    /**
     * Creates the iterator in the ND4J temp directory.
     *
     * @param dataSet      the data to split into batches
     * @param batchSize    number of examples per batch
     * @param deleteOnExit if true, the batch directory is deleted at JVM exit
     * @throws IOException if the batch files cannot be written
     */
    public MiniBatchFileDataSetIterator(DataSet dataSet, int batchSize, boolean deleteOnExit) throws IOException {
        this(dataSet, batchSize, deleteOnExit, ND4JFileUtils.getTempDir());
    }

    /** Not supported: batches are pre-written with a fixed size. */
    @Override
    public DataSet next(int num) {
        throw new UnsupportedOperationException("Unable to load custom number of examples");
    }

    /** Always returns 0; the feature width is not tracked by this iterator. */
    @Override
    public int inputColumns() {
        return 0;
    }

    /** @return the number of outcomes (label columns) of the source data set */
    @Override
    public int totalOutcomes() {
        return this.totalLabels;
    }

    @Override
    public boolean resetSupported() {
        return true;
    }

    @Override
    public boolean asyncSupported() {
        return true;
    }

    /** Rewinds to the first batch; the files on disk are reused. */
    @Override
    public void reset() {
        this.currIdx = 0;
    }

    @Override
    public int batch() {
        return this.batchSize;
    }

    @Override
    public void setPreProcessor(DataSetPreProcessor dataSetPreProcessor) {
        this.dataSetPreProcessor = dataSetPreProcessor;
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.dataSetPreProcessor;
    }

    /** Always returns {@code null}; label names are not tracked by this iterator. */
    @Override
    public List<String> getLabels() {
        return null;
    }

    @Override
    public boolean hasNext() {
        return this.currIdx < this.totalBatches;
    }

    /** Intentionally a no-op: batches on disk are never removed individually. */
    @Override
    public void remove() {
    }

    /**
     * Reads the next batch from disk and applies the pre-processor, if one is set.
     *
     * @return the next batch
     * @throws NoSuchElementException if all batches have been returned
     * @throws IllegalStateException  if the batch files cannot be read
     */
    @Override
    public DataSet next() {
        if (!hasNext()) {
            throw new NoSuchElementException("No more batches (total: " + totalBatches + ")");
        }
        DataSet batch;
        try {
            batch = read(this.currIdx);
        } catch (IOException e) {
            throw new IllegalStateException("Unable to read dataset", e);
        }
        if (this.dataSetPreProcessor != null) {
            this.dataSetPreProcessor.preProcess(batch);
        }
        // Only advance once the batch was read and pre-processed successfully.
        this.currIdx++;
        return batch;
    }

    /** Loads batch {@code idx} from its features and labels files, closing both streams. */
    private DataSet read(int idx) throws IOException {
        String[] files = this.paths.get(idx);
        try (DataInputStream features = new DataInputStream(new BufferedInputStream(new FileInputStream(files[0])));
             DataInputStream labels = new DataInputStream(new BufferedInputStream(new FileInputStream(files[1])))) {
            return new DataSet(Nd4j.read(features), Nd4j.read(labels));
        }
    }

    /**
     * Writes one batch to {@code <uuid>.bin} (features) and {@code <uuid>.labels.bin} (labels) in
     * {@link #rootDir}.
     *
     * @return the absolute paths {@code {featuresFile, labelsFile}}
     */
    private String[] writeData(DataSet batch) throws IOException {
        String uuid = UUID.randomUUID().toString();
        File featuresFile = new File(this.rootDir, uuid + ".bin");
        File labelsFile = new File(this.rootDir, uuid + ".labels.bin");
        // close() flushes the buffered stream before closing the underlying file.
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(featuresFile)))) {
            Nd4j.write(batch.getFeatures(), out);
        }
        try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(labelsFile)))) {
            Nd4j.write(batch.getLabels(), out);
        }
        return new String[] {featuresFile.getAbsolutePath(), labelsFile.getAbsolutePath()};
    }

    /** @return the directory the batch files were written to */
    public File getRootDir() {
        return this.rootDir;
    }

    /**
     * Replaces the stored root directory. Already-written batch files are not moved and continue
     * to be read from their original locations; the shutdown hook (if any) still deletes the
     * directory created by the constructor.
     */
    public void setRootDir(File file) {
        this.rootDir = file;
    }
}
