package org.deeplearning4j.spark.impl.evaluation;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator;
import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.class */
/**
 * Per-JVM singleton that runs evaluations with a bounded number of concurrent worker threads.
 *
 * <p>The first {@code evalNumWorkers} concurrent callers of {@link #execute} become workers. Each worker does
 * the following:
 * <ul>
 *   <li>builds its own network instance from the broadcast configuration, initialised from a cached
 *       {@link DeviceLocalNDArray} of the broadcast parameters;</li>
 *   <li>evaluates the caller's own data;</li>
 *   <li>evaluates any data queued by non-worker callers.</li>
 * </ul>
 * Queued data is accumulated into the <em>worker's</em> {@link IEvaluation} instances. The future handed back
 * to a queued caller therefore completes with a {@code null} result. A worker's future carries the evaluations
 * it accumulated.
 *
 * <p>NOTE(review): queued data is evaluated with whichever worker's model and evaluations pick it up. This
 * assumes all concurrent callers in the JVM evaluate the same network with compatible evaluations — confirm
 * against callers.
 */
public class EvaluationRunner {
    private static final Logger log = LoggerFactory.getLogger(EvaluationRunner.class);
    private static final EvaluationRunner INSTANCE = new EvaluationRunner();

    /** Number of threads currently acting as evaluation workers; only changed while holding {@link #workerLock}. */
    private final AtomicInteger workerCount = new AtomicInteger(0);

    /**
     * Makes two decisions atomic with respect to each other:
     * <ul>
     *   <li>"become a worker or enqueue" (callers);</li>
     *   <li>"queue is empty, so stop being a worker" (workers).</li>
     * </ul>
     * This guarantees queued data always has a registered worker that will still see it.
     */
    private final Object workerLock = new Object();

    /** Data submitted by callers that did not become workers, waiting for a worker to evaluate it. */
    private final Queue<Eval> queue = new ConcurrentLinkedQueue<>();

    /**
     * Deserialized parameters, keyed by the broadcast {@code byte[]} instance. Arrays use identity-based
     * equals/hashCode, so there is one entry per distinct array object, and keys are held weakly. Guarded by
     * {@code this}.
     */
    private final Map<byte[], DeviceLocalNDArray> paramsMap = new WeakHashMap<>();

    /** Data queued by a non-worker caller, plus the future a worker completes once the data has been evaluated. */
    private static final class Eval {
        private final Iterator<DataSet> ds;
        private final Iterator<MultiDataSet> mds;
        private final IEvaluation[] evaluations;
        private final EvaluationFuture future;

        Eval(Iterator<DataSet> ds, Iterator<MultiDataSet> mds, IEvaluation[] evaluations, EvaluationFuture future) {
            this.ds = ds;
            this.mds = mds;
            this.evaluations = evaluations;
            this.future = future;
        }

        Iterator<DataSet> getDs() {
            return ds;
        }

        Iterator<MultiDataSet> getMds() {
            return mds;
        }

        /** The submitter's evaluations; kept for diagnostics only (queued data goes into the worker's evaluations). */
        IEvaluation[] getEvaluations() {
            return evaluations;
        }

        EvaluationFuture getFuture() {
            return future;
        }

        @Override
        public String toString() {
            return "EvaluationRunner.Eval(ds=" + ds + ", mds=" + mds + ", evaluations="
                    + Arrays.deepToString(evaluations) + ", future=" + future + ")";
        }
    }

    /**
     * One-shot future, completed by releasing a single permit on its semaphore. A failure, if any, is recorded
     * before the release.
     *
     * <p>Waiting methods hand the permit back after acquiring it. This lets {@link #get()} be called repeatedly,
     * and lets {@link #isDone()} keep reporting {@code true} once the future has completed.
     */
    public static class EvaluationFuture implements Future<IEvaluation[]> {
        private Semaphore semaphore;
        private volatile IEvaluation[] result;
        private volatile Throwable exception;

        private EvaluationFuture() {
            this.semaphore = new Semaphore(0);
        }

        /**
         * Cancellation is not supported.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            throw new UnsupportedOperationException("Not supported");
        }

        @Override
        public boolean isCancelled() {
            return false;
        }

        @Override
        public boolean isDone() {
            return semaphore.availablePermits() > 0;
        }

        /**
         * Blocks until the evaluation has completed.
         *
         * @return the worker's evaluations, or {@code null} for data that was queued for another worker
         * @throws ExecutionException if evaluating the data failed
         */
        @Override
        public IEvaluation[] get() throws InterruptedException, ExecutionException {
            semaphore.acquire();
            semaphore.release(); // put the permit back: keeps isDone() true and get() repeatable
            return outcome();
        }

        /**
         * Like {@link #get()}, but waits at most the given time.
         *
         * @throws TimeoutException if the evaluation did not complete in time
         */
        @Override
        public IEvaluation[] get(long timeout, @NonNull TimeUnit unit)
                throws InterruptedException, ExecutionException, TimeoutException {
            if (unit == null) {
                throw new NullPointerException("unit is marked non-null but is null");
            }
            if (!semaphore.tryAcquire(timeout, unit)) {
                throw new TimeoutException("Evaluation did not complete within " + timeout + " " + unit);
            }
            semaphore.release(); // put the permit back, as in get()
            return outcome();
        }

        /** Returns the result, or rethrows the recorded failure wrapped in an ExecutionException. */
        private IEvaluation[] outcome() throws ExecutionException {
            Throwable failure = exception;
            if (failure != null) {
                throw new ExecutionException(failure);
            }
            return result;
        }

        public void setSemaphore(Semaphore semaphore) {
            this.semaphore = semaphore;
        }

        public void setResult(IEvaluation[] result) {
            this.result = result;
        }

        public void setException(Throwable exception) {
            this.exception = exception;
        }

        public Semaphore getSemaphore() {
            return semaphore;
        }

        public IEvaluation[] getResult() {
            return result;
        }

        public Throwable getException() {
            return exception;
        }
    }

    public static EvaluationRunner getInstance() {
        return INSTANCE;
    }

    private EvaluationRunner() {
    }

    /**
     * Evaluates the given data. The work is either done on the calling thread, as one of at most
     * {@code evalNumWorkers} workers, or queued for an already-running worker.
     *
     * @param evals          evaluations updated when the caller becomes a worker; a worker also accumulates all
     *                       queued data it processes into these instances (they are not updated for queued callers)
     * @param evalNumWorkers maximum number of concurrent worker threads; must be > 0
     * @param evalBatchSize  batch size passed to the iterator adapters
     * @param ds             data to evaluate, or null if {@code mds} is given
     * @param mds            data to evaluate; only used when {@code ds} is null
     * @param isCG           true if {@code json} is a ComputationGraph configuration, false for MultiLayerNetwork
     * @param json           broadcast network configuration JSON
     * @param params         broadcast serialized network parameters
     * @return one of two futures:
     *         <ul>
     *           <li>worker: an already-completed future whose result is {@code evals} (its {@code get()} throws
     *               if evaluating the caller's own data failed);</li>
     *           <li>queued caller: a future that completes with a {@code null} result once a worker has
     *               evaluated the data.</li>
     *         </ul>
     */
    public Future<IEvaluation[]> execute(IEvaluation[] evals, int evalNumWorkers, int evalBatchSize,
                                         Iterator<DataSet> ds, Iterator<MultiDataSet> mds, boolean isCG,
                                         Broadcast<String> json, Broadcast<byte[]> params) {
        Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaluation workers: must be > 0. Got: %s", evalNumWorkers);
        Preconditions.checkState(ds != null || mds != null, "No data provided - both DataSet and MultiDataSet iterators were null");

        DeviceLocalNDArray deviceLocalParams = getDeviceLocalParams(params);

        EvaluationFuture queued = registerWorkerOrEnqueue(evalNumWorkers, evals, ds, mds);
        if (queued != null) {
            return queued;
        }
        return runAsWorker(evals, evalBatchSize, ds, mds, isCG, json, deviceLocalParams);
    }

    /** Returns the cached deserialized parameters for this broadcast value, deserializing them on first use. */
    private DeviceLocalNDArray getDeviceLocalParams(Broadcast<byte[]> params) {
        byte[] serialized = params.getValue();
        synchronized (this) {
            DeviceLocalNDArray cached = paramsMap.get(serialized);
            if (cached == null) {
                try {
                    cached = new DeviceLocalNDArray(Nd4j.read(new ByteArrayInputStream(serialized)));
                } catch (RuntimeException e) {
                    throw new RuntimeException("Failed to deserialize broadcast network parameters", e);
                }
                paramsMap.put(serialized, cached);
            }
            return cached;
        }
    }

    /**
     * Registers the caller as a worker if fewer than {@code maxWorkers} are running. Otherwise enqueues its data.
     *
     * @return null if the caller is now a worker, otherwise the future for the queued data
     */
    private EvaluationFuture registerWorkerOrEnqueue(int maxWorkers, IEvaluation[] evals,
                                                     Iterator<DataSet> ds, Iterator<MultiDataSet> mds) {
        synchronized (workerLock) {
            if (workerCount.get() < maxWorkers) {
                workerCount.incrementAndGet();
                return null;
            }
            log.debug("Submitting evaluation from thread {} for processing in evaluation thread", Thread.currentThread().getId());
            EvaluationFuture future = new EvaluationFuture();
            queue.add(new Eval(ds, mds, evals, future));
            return future;
        }
    }

    /**
     * Runs as a registered worker. Builds the model, evaluates the caller's own data, then drains the queue
     * until it is empty and deregisters.
     */
    private EvaluationFuture runAsWorker(IEvaluation[] evals, int evalBatchSize, Iterator<DataSet> ds,
                                         Iterator<MultiDataSet> mds, boolean isCG, Broadcast<String> json,
                                         DeviceLocalNDArray deviceLocalParams) {
        log.debug("Starting evaluation in thread {}", Thread.currentThread().getId());
        EvaluationFuture own = new EvaluationFuture();
        own.setResult(evals);
        boolean registered = true;
        Throwable failure = null;
        try {
            Model model = buildModel(isCG, json.getValue(), deviceLocalParams);
            evaluateInto(model, evals, ds, mds, evalBatchSize, own);
            drainQueueThenDeregister(model, evals, evalBatchSize);
            registered = false;
            Nd4j.getExecutioner().commit();
            return own;
        } catch (RuntimeException | Error e) {
            failure = e;
            throw e;
        } finally {
            if (registered) {
                deregisterAfterFailure(failure);
            }
            log.debug("Finished evaluation in thread {}", Thread.currentThread().getId());
        }
    }

    /** Creates a network from its JSON configuration, initialised with the shared parameters. */
    private static Model buildModel(boolean isCG, String json, DeviceLocalNDArray params) {
        if (isCG) {
            ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
            graph.init(params.get(), false);
            return graph;
        }
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
        network.init(params.get(), false);
        return network;
    }

    /**
     * Evaluates the data and completes {@code future} exactly once. Any failure is recorded on the future
     * instead of being thrown.
     */
    private static void evaluateInto(Model model, IEvaluation[] evals, Iterator<DataSet> ds,
                                     Iterator<MultiDataSet> mds, int batchSize, EvaluationFuture future) {
        try {
            doEval(model, evals, ds, mds, batchSize);
        } catch (Throwable t) {
            future.setException(t);
        } finally {
            future.getSemaphore().release(1);
        }
    }

    /**
     * Evaluates queued data into this worker's {@code evals} until the queue is empty, then deregisters. The
     * emptiness check and the deregistration happen under {@link #workerLock}, so no submitter can enqueue in
     * between.
     */
    private void drainQueueThenDeregister(Model model, IEvaluation[] evals, int batchSize) {
        while (true) {
            Eval next;
            synchronized (workerLock) {
                next = queue.poll();
                if (next == null) {
                    workerCount.decrementAndGet();
                    return;
                }
            }
            evaluateInto(model, evals, next.getDs(), next.getMds(), batchSize, next.getFuture());
        }
    }

    /**
     * Deregisters a worker that is exiting with an exception. If no worker remains, the still-queued data is
     * failed instead of waiting indefinitely for a worker that may never arrive.
     */
    private void deregisterAfterFailure(Throwable cause) {
        List<Eval> orphaned = new ArrayList<>();
        synchronized (workerLock) {
            if (workerCount.decrementAndGet() == 0) {
                Eval e;
                while ((e = queue.poll()) != null) {
                    orphaned.add(e);
                }
            }
        }
        for (Eval e : orphaned) {
            EvaluationFuture future = e.getFuture();
            future.setException(new IllegalStateException(
                    "Evaluation worker thread failed before processing queued evaluation", cause));
            future.getSemaphore().release(1);
        }
    }

    /**
     * Runs the model's evaluation over {@code ds}, or over {@code mds} when {@code ds} is null, accumulating
     * into {@code evals}.
     */
    private static void doEval(Model model, IEvaluation[] evals, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int batchSize) {
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            if (ds != null) {
                network.doEvaluation(new IteratorDataSetIterator(ds, batchSize), evals);
            } else {
                network.doEvaluation(new IteratorMultiDataSetIterator(mds, batchSize), evals);
            }
            return;
        }
        ComputationGraph graph = (ComputationGraph) model;
        if (ds != null) {
            graph.doEvaluation(new IteratorDataSetIterator(ds, batchSize), evals);
        } else {
            graph.doEvaluation(new IteratorMultiDataSetIterator(mds, batchSize), evals);
        }
    }
}
