package org.nd4j.linalg.dataset;

import java.util.List;
import java.util.Objects;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.callbacks.DataSetCallback;
import org.nd4j.linalg.dataset.callbacks.DefaultCallback;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link DataSetIterator} decorator that prefetches {@link DataSet}s from a backing iterator on
 * a background daemon thread into a {@link BlockingQueue}.
 *
 * <p>The prefetch thread sets its device through the ND4J affinity manager, optionally fetches
 * each element inside a per-thread memory workspace, passes it to an optional
 * {@link DataSetCallback} and enqueues it. A dedicated {@code terminator} instance, compared by
 * identity, marks the end of a pass. A failure on the prefetch thread is recorded in
 * {@code throwable} and rethrown to the consumer by {@link #hasNext()} and {@link #next()}.
 *
 * <p>Consumer-side state ({@code nextElement}) is not synchronised, so {@link #hasNext()},
 * {@link #next()}, {@link #reset()} and {@link #shutdown()} should be called from one thread.
 */
public class AsyncDataSetIterator implements DataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(AsyncDataSetIterator.class);

    /** Prefetch depth used by {@link #AsyncDataSetIterator(DataSetIterator)}. */
    private static final int DEFAULT_PREFETCH_SIZE = 8;

    /** Smallest prefetch depth; smaller requested sizes are raised to this value. */
    private static final int MIN_PREFETCH_SIZE = 2;

    /** Minimum size (10 MiB) requested for the prefetch thread's workspace. */
    private static final long WORKSPACE_MIN_SIZE = 10L * 1024L * 1024L;

    /** Iterator that actually produces the data. */
    protected DataSetIterator backedIterator;
    /** End-of-pass marker put on the buffer by the producer; never handed to callers. */
    protected DataSet terminator = new DataSet();
    /** Element already taken from the buffer by {@link #hasNext()} but not yet returned by {@link #next()}. */
    protected DataSet nextElement;
    /** Queue shared between the prefetch thread (producer) and the caller (consumer). */
    protected BlockingQueue<DataSet> buffer;
    /** Current prefetch thread, or {@code null} if none has been started. */
    protected AsyncPrefetchThread thread;
    /** Checked by the prefetch loop before every element; cleared to stop it. */
    protected AtomicBoolean shouldWork = new AtomicBoolean(true);
    /** Failure recorded by the prefetch thread, rethrown to the consumer. */
    protected volatile RuntimeException throwable;
    /** Whether the prefetch thread fetches elements inside a memory workspace. */
    protected boolean useWorkspace = true;
    /** Effective prefetch depth (at least {@value #MIN_PREFETCH_SIZE}). */
    protected int prefetchSize;
    /** Id of the workspace used by the prefetch thread. */
    protected String workspaceId;
    /** Device the prefetch thread is assigned to; must be non-null when a thread is started. */
    protected Integer deviceId;
    /** Set once the terminator has been consumed (or after {@link #shutdown()}). */
    protected AtomicBoolean hasDepleted = new AtomicBoolean(false);
    /** Optional hook invoked on every fetched element, on the prefetch thread. */
    protected DataSetCallback callback;

    /**
     * Daemon thread that drains the backing iterator into the shared buffer, followed by the
     * terminator. Completion (normal, interrupted or failed) is signalled through a latch that
     * {@link #shutdown()} waits on.
     */
    public class AsyncPrefetchThread extends Thread {
        private final BlockingQueue<DataSet> queue;
        private final DataSetIterator iterator;
        private final DataSet terminator;
        private final WorkspaceConfiguration configuration;
        private final int deviceId;
        /** Released in {@link #run()}'s finally block, whatever the outcome. */
        private final CountDownLatch finished = new CountDownLatch(1);
        /** Assigned on the prefetch thread in {@link #run()}, released by {@link #shutdown()}. */
        private volatile MemoryWorkspace workspace;

        /**
         * @param queue            buffer to fill; must not be null
         * @param iterator         source of the data; must not be null
         * @param terminator       end-of-stream marker enqueued after the last element; must not be null
         * @param ignoredWorkspace currently unused; the workspace is obtained on the prefetch thread
         * @param deviceId         device passed to the affinity manager when the thread starts
         */
        protected AsyncPrefetchThread(@NonNull BlockingQueue<DataSet> queue, @NonNull DataSetIterator iterator,
                        @NonNull DataSet terminator, MemoryWorkspace ignoredWorkspace, int deviceId) {
            // Explicit checks keep the contract even when Lombok does not process @NonNull.
            this.queue = Objects.requireNonNull(queue, "queue is marked @NonNull but is null");
            this.iterator = Objects.requireNonNull(iterator, "iterator is marked @NonNull but is null");
            this.terminator = Objects.requireNonNull(terminator, "terminator is marked @NonNull but is null");
            this.deviceId = deviceId;
            this.configuration = WorkspaceConfiguration.builder()
                            .minSize(WORKSPACE_MIN_SIZE)
                            .overallocationLimit(AsyncDataSetIterator.this.prefetchSize + 2)
                            .policyReset(ResetPolicy.ENDOFBUFFER_REACHED)
                            .policyLearning(LearningPolicy.FIRST_LOOP)
                            .policyAllocation(AllocationPolicy.OVERALLOCATE)
                            .policySpill(SpillPolicy.REALLOCATE)
                            .build();
            setDaemon(true);
            setName("ADSI prefetch thread");
        }

        /**
         * Fetches elements until the backing iterator is exhausted or {@code shouldWork} is cleared,
         * then enqueues the terminator.
         *
         * <p>An interrupt stops the loop quietly (no terminator, no recorded failure). Any other
         * failure is recorded for the consumer and rethrown.
         */
        @Override
        public void run() {
            try {
                Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                AsyncDataSetIterator.this.externalCall();
                if (AsyncDataSetIterator.this.useWorkspace) {
                    workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(configuration,
                                    AsyncDataSetIterator.this.workspaceId);
                }
                while (AsyncDataSetIterator.this.shouldWork.get() && iterator.hasNext()) {
                    DataSet next;
                    if (AsyncDataSetIterator.this.useWorkspace) {
                        try (MemoryWorkspace ignored = workspace.notifyScopeEntered()) {
                            next = fetch();
                        }
                    } else {
                        next = fetch();
                    }
                    // NOTE(review): presumably flushes queued ops so the element is ready before the
                    // consumer sees it — confirm against the executioner's commit() contract.
                    Nd4j.getExecutioner().commit();
                    if (next != null) {
                        queue.put(next);
                    }
                }
                queue.put(terminator);
            } catch (InterruptedException e) {
                // Raised by reset()/shutdown() via interrupt(): a requested stop, not a failure.
                Thread.currentThread().interrupt();
                AsyncDataSetIterator.this.shouldWork.set(false);
            } catch (RuntimeException e) {
                recordFailure(e);
                throw new RuntimeException(e);
            } catch (Exception e) {
                recordFailure(new RuntimeException(e));
                throw new RuntimeException(e);
            } catch (Error e) {
                recordFailure(new RuntimeException(e));
                throw e;
            } finally {
                finished.countDown();
            }
        }

        /** Pulls one element from the backing iterator and passes it to the callback, if any. */
        private DataSet fetch() {
            DataSet element = iterator.next();
            DataSetCallback hook = AsyncDataSetIterator.this.callback;
            if (hook != null) {
                hook.call(element);
            }
            return element;
        }

        /**
         * Publishes a prefetch failure and wakes a consumer that may be blocked in
         * {@code buffer.take()}, which would otherwise wait forever.
         */
        private void recordFailure(RuntimeException error) {
            AsyncDataSetIterator.this.throwable = error;
            // offer() rather than put(): if the queue is full the consumer is not blocked in take(),
            // and its next hasNext() call checks throwable before taking again.
            queue.offer(terminator);
        }

        /**
         * Waits until {@link #run()} has finished, then destroys the workspace it used (if any).
         *
         * @throws RuntimeException if the calling thread is interrupted while waiting
         */
        public void shutdown() {
            try {
                finished.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            MemoryWorkspace ws = workspace;
            workspace = null;
            if (ws != null) {
                AsyncDataSetIterator.log.debug("Manually destroying ADSI workspace");
                ws.destroyWorkspace(true);
            }
        }
    }

    /** Initialises only the default field values; buffer, backing iterator and thread stay null. */
    protected AsyncDataSetIterator() {
    }

    /** Wraps {@code baseIterator} with a prefetch depth of {@value #DEFAULT_PREFETCH_SIZE}. */
    public AsyncDataSetIterator(DataSetIterator baseIterator) {
        this(baseIterator, DEFAULT_PREFETCH_SIZE);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, BlockingQueue<DataSet> queue) {
        this(baseIterator, queueSize, queue, true);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, newBuffer(queueSize));
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace, new DefaultCallback(), deviceId);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, boolean useWorkspace,
                    DataSetCallback callback) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace, callback);
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, BlockingQueue<DataSet> queue,
                    boolean useWorkspace) {
        this(baseIterator, queueSize, queue, useWorkspace, new DefaultCallback());
    }

    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, BlockingQueue<DataSet> queue,
                    boolean useWorkspace, DataSetCallback callback) {
        this(baseIterator, queueSize, queue, useWorkspace, callback,
                        Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    /**
     * Creates the iterator and immediately starts the prefetch thread.
     *
     * @param baseIterator iterator to prefetch from; reset first if it supports reset and is exhausted
     * @param queueSize    requested prefetch depth; values below {@value #MIN_PREFETCH_SIZE} are raised
     * @param queue        buffer shared with the prefetch thread
     * @param useWorkspace whether elements are fetched inside a memory workspace
     * @param callback     hook called on every element on the prefetch thread; may be null
     * @param deviceId     device assigned to the prefetch thread; must not be null
     */
    public AsyncDataSetIterator(DataSetIterator baseIterator, int queueSize, BlockingQueue<DataSet> queue,
                    boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        this.prefetchSize = Math.max(MIN_PREFETCH_SIZE, queueSize);
        this.deviceId = deviceId;
        this.callback = callback;
        this.useWorkspace = useWorkspace;
        this.buffer = queue;
        this.backedIterator = baseIterator;
        this.workspaceId = "ADSI_ITER-" + UUID.randomUUID().toString();
        if (baseIterator.resetSupported() && !baseIterator.hasNext()) {
            this.backedIterator.reset();
        }
        this.thread = new AsyncPrefetchThread(this.buffer, baseIterator, this.terminator, null, deviceId.intValue());
        this.thread.start();
    }

    /** Creates a bounded buffer whose capacity honours the same minimum as {@code prefetchSize}. */
    private static BlockingQueue<DataSet> newBuffer(int queueSize) {
        return new LinkedBlockingQueue<>(Math.max(MIN_PREFETCH_SIZE, queueSize));
    }

    /** Not supported: elements are produced one at a time by the prefetch thread. */
    @Override
    public DataSet next(int num) {
        throw new UnsupportedOperationException();
    }

    @Override
    public int inputColumns() {
        return this.backedIterator.inputColumns();
    }

    @Override
    public int totalOutcomes() {
        return this.backedIterator.totalOutcomes();
    }

    @Override
    public boolean resetSupported() {
        return this.backedIterator.resetSupported();
    }

    /** Always false; presumably so callers do not wrap this iterator in another async layer. */
    @Override
    public boolean asyncSupported() {
        return false;
    }

    /** Hook run on the prefetch thread before anything is fetched; does nothing by default. */
    protected void externalCall() {
    }

    /**
     * Stops the current prefetch thread, resets the backing iterator and starts a fresh prefetch
     * pass. Any failure recorded by the previous pass is discarded.
     */
    @Override
    public void reset() {
        stopPrefetching();
        this.backedIterator.reset();
        this.throwable = null;
        this.nextElement = null;
        this.hasDepleted.set(false);
        this.shouldWork.set(true);
        this.thread = new AsyncPrefetchThread(this.buffer, this.backedIterator, this.terminator, null,
                        this.deviceId.intValue());
        this.thread.start();
    }

    /**
     * Stops the prefetch thread and releases its workspace. Afterwards {@link #hasNext()} returns
     * false until {@link #reset()} is called.
     */
    public void shutdown() {
        stopPrefetching();
        this.nextElement = null;
        // Without this, a later hasNext() would block forever on an empty buffer with no producer.
        this.hasDepleted.set(true);
    }

    /**
     * Signals the prefetch thread to stop, waits for it to exit, destroys its workspace and empties
     * the buffer. The thread-related steps are skipped when no thread exists.
     *
     * @throws RuntimeException if the calling thread is interrupted while waiting
     */
    private void stopPrefetching() {
        this.shouldWork.set(false);
        // Clearing first lets a producer blocked in put() proceed; the interrupt covers the rest.
        this.buffer.clear();
        AsyncPrefetchThread current = this.thread;
        if (current != null) {
            current.interrupt();
            try {
                current.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            current.shutdown();
        }
        // Drop anything (e.g. the terminator) the producer enqueued before it exited.
        this.buffer.clear();
    }

    @Override
    public int batch() {
        return this.backedIterator.batch();
    }

    /** Sets the pre-processor on the backing iterator; elements already buffered are not re-processed. */
    @Override
    public void setPreProcessor(DataSetPreProcessor preProcessor) {
        this.backedIterator.setPreProcessor(preProcessor);
    }

    @Override
    public DataSetPreProcessor getPreProcessor() {
        return this.backedIterator.getPreProcessor();
    }

    @Override
    public List<String> getLabels() {
        return this.backedIterator.getLabels();
    }

    /**
     * Returns whether another element is available, blocking until the prefetch thread produces
     * one or signals the end of the pass.
     *
     * @throws RuntimeException the failure recorded by the prefetch thread, if any, or a wrapper
     *                          around {@link InterruptedException} if the caller is interrupted
     */
    @Override
    public boolean hasNext() {
        if (this.throwable != null) {
            throw this.throwable;
        }
        if (this.hasDepleted.get()) {
            return false;
        }
        if (this.nextElement != null) {
            return this.nextElement != this.terminator;
        }
        try {
            this.nextElement = this.buffer.take();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("Premature end of loop!");
            throw new RuntimeException(e);
        }
        if (this.nextElement != this.terminator) {
            return true;
        }
        this.hasDepleted.set(true);
        // The producer also enqueues the terminator when it fails; surface that failure here.
        if (this.throwable != null) {
            throw this.throwable;
        }
        return false;
    }

    /**
     * Returns the next element, waiting for it if necessary, or {@code null} once the pass is
     * exhausted (kept for compatibility instead of {@code NoSuchElementException}).
     *
     * @throws RuntimeException the failure recorded by the prefetch thread, if any
     */
    @Override
    public DataSet next() {
        if (!hasNext()) {
            return null;
        }
        DataSet current = this.nextElement;
        this.nextElement = null;
        return current;
    }

    /** Unsupported; calls are silently ignored. */
    @Override
    public void remove() {
    }
}
