package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.GridExecutioner;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/parallelism/ParallelWrapper.class */
/**
 * Data-parallel training wrapper for a {@link MultiLayerNetwork} or {@link ComputationGraph}.
 *
 * <p>{@code workers} {@link Trainer} threads each own a replica of the master model. Every
 * {@code fit(...)} call hands one minibatch to each trainer per round and waits for all of them.
 * On every {@code averagingFrequency}-th <em>full</em> round it averages the replicas' parameters
 * (and, if enabled, their updater state) into the master model. A trailing partial round is
 * fitted but not averaged.
 *
 * <p>{@code fit}, {@code shutdown} and {@code close} are {@code synchronized}.
 */
public class ParallelWrapper implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(ParallelWrapper.class);

    private final Model model;
    private final int workers;
    private final int prefetchSize;
    private Trainer[] zoo;
    private int averagingFrequency = 1;
    private final AtomicLong iterationsCounter = new AtomicLong(0);
    private boolean reportScore = false;
    private boolean averageUpdaters = true;
    private boolean legacyAveraging = false;

    /** Fluent builder for {@link ParallelWrapper}. */
    public static class Builder {
        private Model model;
        private int workers = 2;
        private int prefetchSize = 16;
        private int averagingFrequency = 1;
        private boolean reportScore = false;
        private boolean averageUpdaters = true;
        private boolean legacyAveraging = true;

        /**
         * @param multiLayerNetwork master model to train
         * @throws NullPointerException if {@code multiLayerNetwork} is null
         */
        public Builder(@NonNull MultiLayerNetwork multiLayerNetwork) {
            if (multiLayerNetwork == null) {
                throw new NullPointerException("mln");
            }
            this.model = multiLayerNetwork;
        }

        /**
         * @param computationGraph master model to train
         * @throws NullPointerException if {@code computationGraph} is null
         */
        public Builder(@NonNull ComputationGraph computationGraph) {
            if (computationGraph == null) {
                throw new NullPointerException("graph");
            }
            this.model = computationGraph;
        }

        /**
         * Sets the number of trainer threads (default 2).
         *
         * @throws RuntimeException if {@code i < 2}
         */
        public Builder workers(int i) {
            if (i < 2) {
                throw new RuntimeException("Number of workers can't be lower then 2!");
            }
            this.workers = i;
            return this;
        }

        /**
         * Sets after how many rounds the replicas are averaged (default 1 = every full round).
         *
         * @throws IllegalArgumentException if {@code i < 1}; a zero frequency would otherwise
         *     fail later with a division by zero inside {@code fit}
         */
        public Builder averagingFrequency(int i) {
            if (i < 1) {
                throw new IllegalArgumentException("Averaging frequency must be at least 1, got " + i);
            }
            this.averagingFrequency = i;
            return this;
        }

        /** Whether updater state is averaged together with the parameters (default true). */
        public Builder averageUpdaters(boolean z) {
            this.averageUpdaters = z;
            return this;
        }

        /**
         * Sets the async prefetch buffer size (default 16). Negative values are clamped to 0;
         * 0 disables the async wrapper around the source iterator.
         */
        public Builder prefetchBuffer(int i) {
            this.prefetchSize = Math.max(i, 0);
            return this;
        }

        /**
         * Selects legacy averaging (default true in the builder). Legacy averaging means an
         * explicit mean is installed on the master and pushed back to every replica. Otherwise
         * {@code Nd4j.averageAndPropagate} is used.
         */
        public Builder useLegacyAveraging(boolean z) {
            this.legacyAveraging = z;
            return this;
        }

        /** Whether to log the averaged replica score at INFO after each averaging (default false). */
        public Builder reportScoreAfterAveraging(boolean z) {
            this.reportScore = z;
            return this;
        }

        /** Creates the wrapper; this starts its trainer threads immediately. */
        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            wrapper.reportScore = this.reportScore;
            wrapper.averageUpdaters = this.averageUpdaters;
            wrapper.legacyAveraging = this.legacyAveraging;
            return wrapper;
        }
    }

    /**
     * Daemon worker thread that fits its own model replica on minibatches queued through
     * {@link #feedDataSet} / {@link #feedMultiDataSet}.
     *
     * <p>The trainer serves both queues. It blocks briefly only on the queue matching
     * {@code useMDS} and drains the other without waiting. Either {@code fit(...)} overload
     * therefore works no matter which one created the trainer.
     */
    public static class Trainer extends Thread implements Runnable {
        private final Model originalModel;
        private Model replicatedModel;
        private final LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>();
        private final LinkedBlockingQueue<MultiDataSet> queueMDS = new LinkedBlockingQueue<>();
        // Number of queued-but-not-yet-fitted minibatches.
        private final AtomicInteger running = new AtomicInteger(0);
        private final int threadId;
        // Set by updateModel(); not read anywhere in this file.
        private final AtomicBoolean shouldUpdate = new AtomicBoolean(false);
        private final AtomicBoolean shouldStop = new AtomicBoolean(false);
        // volatile: written by this worker, polled by the thread calling waitTillRunning()/isRunning().
        private volatile Exception thrownException;
        private final boolean useMDS;

        /**
         * @param threadId index of this trainer; non-zero trainers get an empty listener list
         * @param model master model whose configuration is replicated
         * @param useMDS true to block on the MultiDataSet queue, false for the DataSet queue
         */
        public Trainer(int threadId, Model model, boolean useMDS) {
            this.threadId = threadId;
            this.originalModel = model;
            this.useMDS = useMDS;
            setDaemon(true);
            setName("ParallelWrapper trainer " + threadId);
            if (model instanceof ComputationGraph) {
                ComputationGraph clone = ((ComputationGraph) model).clone();
                if (threadId != 0) {
                    clone.setListeners(new ArrayList<>());
                }
                this.replicatedModel = clone;
            }
        }

        /** Same as {@code Trainer(threadId, model, false)}. */
        public Trainer(int threadId, Model model) {
            this(threadId, model, false);
        }

        /** Queues a MultiDataSet for this trainer. @throws NullPointerException if null */
        public void feedMultiDataSet(@NonNull MultiDataSet multiDataSet) {
            if (multiDataSet == null) {
                throw new NullPointerException("dataSet");
            }
            running.incrementAndGet();
            queueMDS.add(multiDataSet);
        }

        /** Queues a DataSet for this trainer. @throws NullPointerException if null */
        public void feedDataSet(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            running.incrementAndGet();
            queue.add(dataSet);
        }

        /** Returns this trainer's replica (may be null before {@link #run()} has created it). */
        public Model getModel() {
            return replicatedModel;
        }

        /**
         * Copies the parameters and updater state of {@code model} into this replica.
         *
         * @throws NullPointerException if {@code model} is null
         */
        public void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            shouldUpdate.set(true);
            if (replicatedModel instanceof MultiLayerNetwork) {
                MultiLayerNetwork replica = (MultiLayerNetwork) replicatedModel;
                replica.setParams(model.params().dup());
                INDArray state = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
                if (state != null) {
                    Updater replicaUpdater = replica.getUpdater();
                    INDArray copy = state.dup();
                    flushExecutioner();
                    replicaUpdater.setStateViewArray(replica, copy, false);
                }
            } else if (replicatedModel instanceof ComputationGraph) {
                ComputationGraph replica = (ComputationGraph) replicatedModel;
                replica.setParams(model.params().dup());
                INDArray state = ((ComputationGraph) model).getUpdater().getStateViewArray();
                if (state != null) {
                    INDArray copy = state.dup();
                    flushExecutioner();
                    replica.getUpdater().setStateViewArray(copy);
                }
            }
            flushExecutioner();
        }

        /**
         * Returns true when no queued minibatch is outstanding. Despite the name, this means the
         * trainer is idle.
         *
         * @throws RuntimeException wrapping the worker's failure, if it failed
         */
        public boolean isRunning() {
            if (thrownException != null) {
                throw new RuntimeException(thrownException);
            }
            return running.get() == 0;
        }

        /** Asks the worker loop to exit after its current poll. */
        public void shutdown() {
            shouldStop.set(true);
        }

        @Override
        public void run() {
            try {
                createReplica();
                while (!shouldStop.get()) {
                    DataSet dataSet = useMDS ? queue.poll() : queue.poll(100L, TimeUnit.MILLISECONDS);
                    if (dataSet != null) {
                        fitDataSet(dataSet);
                    }
                    MultiDataSet multiDataSet =
                            useMDS ? queueMDS.poll(100L, TimeUnit.MILLISECONDS) : queueMDS.poll();
                    if (multiDataSet != null) {
                        fitMultiDataSet(multiDataSet);
                    }
                }
            } catch (Exception e) {
                // Surfaced to the caller through waitTillRunning()/isRunning().
                thrownException = e;
            }
        }

        /**
         * Replaces the replica with a freshly initialised model built from a clone of the master's
         * configuration. The replica is published only after init() completes.
         */
        private void createReplica() {
            if (originalModel instanceof MultiLayerNetwork) {
                MultiLayerNetwork replica = new MultiLayerNetwork(
                        ((MultiLayerNetwork) originalModel).getLayerWiseConfigurations().clone());
                replica.init();
                replicatedModel = replica;
            } else if (originalModel instanceof ComputationGraph) {
                ComputationGraph replica =
                        new ComputationGraph(((ComputationGraph) originalModel).getConfiguration().clone());
                replica.init();
                replicatedModel = replica;
            }
        }

        /** Fits one DataSet on the replica and marks it done. */
        private void fitDataSet(DataSet dataSet) {
            if (replicatedModel instanceof MultiLayerNetwork) {
                ((MultiLayerNetwork) replicatedModel).fit(dataSet);
            } else if (replicatedModel instanceof ComputationGraph) {
                ((ComputationGraph) replicatedModel).fit(dataSet);
            }
            flushExecutioner();
            running.decrementAndGet();
        }

        /** Fits one MultiDataSet on the replica (ComputationGraph only) and marks it done. */
        private void fitMultiDataSet(MultiDataSet multiDataSet) {
            if (!(replicatedModel instanceof ComputationGraph)) {
                throw new RuntimeException("MultiDataSet can be fit into ComputationGraph only");
            }
            ((ComputationGraph) replicatedModel).fit(multiDataSet);
            flushExecutioner();
            running.decrementAndGet();
        }

        /**
         * Blocks until every queued minibatch has been fitted.
         *
         * <p>Interrupts do not abort the wait. They are remembered and the thread's interrupt
         * flag is restored before returning.
         *
         * @throws RuntimeException wrapping the worker's failure, if it failed
         */
        public void waitTillRunning() {
            boolean interrupted = false;
            try {
                while (running.get() != 0) {
                    if (thrownException != null) {
                        throw new RuntimeException(thrownException);
                    }
                    try {
                        Thread.sleep(10L);
                    } catch (InterruptedException e) {
                        interrupted = true;
                    }
                }
            } finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /**
     * Creates the wrapper and starts {@code workers} DataSet-mode trainer threads.
     *
     * @param model master model (MultiLayerNetwork or ComputationGraph)
     * @param workers number of trainer threads
     * @param prefetchSize async prefetch buffer size; values {@code <= 0} disable prefetching
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        // Called only for its side effect. NOTE(review): presumably makes sure the master's updater
        // is instantiated before averaging reads it. Confirm.
        if (model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) model).getUpdater();
        } else if (model instanceof ComputationGraph) {
            ((ComputationGraph) model).getUpdater();
        }
        this.zoo = startTrainers(false);
    }

    /** Creates and starts {@code workers} trainers in the given mode. */
    private Trainer[] startTrainers(boolean useMDS) {
        Trainer[] trainers = new Trainer[workers];
        for (int i = 0; i < workers; i++) {
            trainers[i] = new Trainer(i, model, useMDS);
            trainers[i].start();
        }
        return trainers;
    }

    /**
     * Signals all trainers to stop and drops them. A later {@code fit} starts new trainers.
     * Synchronized so it cannot null {@code zoo} while a {@code fit} call is using it.
     */
    @Override
    public synchronized void close() throws Exception {
        if (zoo != null) {
            for (Trainer trainer : zoo) {
                if (trainer != null) {
                    trainer.shutdown();
                }
            }
            zoo = null;
        }
    }

    /** Same as {@link #close()}, with any exception wrapped in a RuntimeException. */
    public synchronized void shutdown() {
        try {
            close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Trains the master ComputationGraph on every MultiDataSet of the (reset) source.
     *
     * @throws NullPointerException if {@code multiDataSetIterator} is null
     * @throws RuntimeException if the master model is not a ComputationGraph, or a trainer failed
     */
    public synchronized void fit(@NonNull MultiDataSetIterator multiDataSetIterator) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("source");
        }
        // Fail fast instead of after a full round of wasted work.
        if (!(model instanceof ComputationGraph)) {
            throw new RuntimeException("MultiDataSet might be used only with ComputationGraph model");
        }
        if (zoo == null) {
            zoo = startTrainers(true);
        }
        multiDataSetIterator.reset();
        MultiDataSetIterator iterator = (prefetchSize <= 0 || !multiDataSetIterator.asyncSupported())
                ? multiDataSetIterator
                : new AsyncMultiDataSetIterator(multiDataSetIterator, prefetchSize);
        int fed = 0;
        while (iterator.hasNext()) {
            MultiDataSet batch = iterator.next();
            zoo[fed++].feedMultiDataSet(batch);
            if (fed == workers || !iterator.hasNext()) {
                finishRound(fed);
                fed = 0;
            }
        }
        logger.debug("Iterations passed: {}", iterationsCounter.get());
        iterationsCounter.set(0L);
    }

    /**
     * Trains the master model on every DataSet of the (reset) source.
     *
     * @throws NullPointerException if {@code dataSetIterator} is null
     * @throws RuntimeException if a trainer failed
     */
    public synchronized void fit(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("source");
        }
        if (zoo == null) {
            zoo = startTrainers(false);
        }
        dataSetIterator.reset();
        DataSetIterator iterator = (prefetchSize <= 0 || !dataSetIterator.asyncSupported())
                ? dataSetIterator
                : new AsyncDataSetIterator(dataSetIterator, prefetchSize);
        int fed = 0;
        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            zoo[fed++].feedDataSet(batch);
            if (fed == workers || !iterator.hasNext()) {
                finishRound(fed);
                fed = 0;
            }
        }
        logger.debug("Iterations passed: {}", iterationsCounter.get());
        iterationsCounter.set(0L);
    }

    /**
     * Ends one round: waits for the first {@code fed} trainers and, on a full round whose index
     * is a multiple of {@code averagingFrequency}, averages the replicas into the master.
     */
    private void finishRound(int fed) {
        iterationsCounter.incrementAndGet();
        for (int i = 0; i < fed; i++) {
            try {
                zoo[i].waitTillRunning();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        // A trailing partial round is fitted but never averaged.
        if (fed == workers && iterationsCounter.get() % averagingFrequency == 0) {
            averageReplicas(fed);
        }
    }

    /**
     * Averages parameters (and optionally updater state) of the first {@code count} replicas into
     * the master, sets the master's score to the mean replica score and, in legacy mode, pushes
     * the master back to every trainer.
     */
    private void averageReplicas(int count) {
        double score = averageParams(count) / count;
        if (reportScore) {
            logger.info("Averaged score: {}", score);
        }
        if (averageUpdaters) {
            averageUpdaterState(count);
        }
        setModelScore(score);
        if (legacyAveraging) {
            for (Trainer trainer : zoo) {
                trainer.updateModel(model);
            }
        }
    }

    /** Averages replica parameters into the master; returns the sum of the replica scores. */
    private double averageParams(int count) {
        List<INDArray> replicaParams = new ArrayList<>(count);
        double scoreSum = 0.0d;
        for (int i = 0; i < count; i++) {
            Model replica = zoo[i].getModel();
            replicaParams.add(replica.params());
            scoreSum += replica.score();
        }
        INDArray mean = averageInto(model.params(), replicaParams);
        if (mean != null) {
            model.setParams(mean);
        }
        return scoreSum;
    }

    /** Averages replica updater state into the master updater, if the master has any state. */
    private void averageUpdaterState(int count) {
        if (model instanceof MultiLayerNetwork) {
            MultiLayerNetwork network = (MultiLayerNetwork) model;
            Updater updater = network.getUpdater();
            if (updater != null && updater.getStateViewArray() != null) {
                INDArray mean = averageInto(updater.getStateViewArray(), replicaUpdaterStates(count));
                if (mean != null) {
                    updater.setStateViewArray(network, mean, false);
                }
            }
        } else if (model instanceof ComputationGraph) {
            ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
            if (updater != null && updater.getStateViewArray() != null) {
                INDArray mean = averageInto(updater.getStateViewArray(), replicaUpdaterStates(count));
                if (mean != null) {
                    updater.setStateViewArray(mean);
                }
            }
        }
    }

    /** Collects the updater state views of the first {@code count} replicas. */
    private List<INDArray> replicaUpdaterStates(int count) {
        List<INDArray> states = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            Model replica = zoo[i].getModel();
            if (replica instanceof MultiLayerNetwork) {
                states.add(((MultiLayerNetwork) replica).getUpdater().getStateViewArray());
            } else if (replica instanceof ComputationGraph) {
                states.add(((ComputationGraph) replica).getUpdater().getStateViewArray());
            }
        }
        return states;
    }

    /**
     * Legacy mode: returns the element-wise mean of {@code replicas} (a new array shaped like
     * {@code master}) for the caller to install. Otherwise delegates to
     * {@code Nd4j.averageAndPropagate(master, replicas)} and returns null.
     */
    private INDArray averageInto(INDArray master, List<INDArray> replicas) {
        if (legacyAveraging) {
            return meanOf(master, replicas);
        }
        Nd4j.averageAndPropagate(master, replicas);
        return null;
    }

    /** Element-wise mean of {@code arrays}, accumulated into a zero array shaped like {@code template}. */
    private static INDArray meanOf(INDArray template, List<INDArray> arrays) {
        INDArray sum = Nd4j.zeros(template.shape());
        for (INDArray array : arrays) {
            sum.addi(array);
        }
        return sum.divi(arrays.size());
    }

    /** Stores the averaged score on the master model. */
    private void setModelScore(double score) {
        if (model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) model).setScore(score);
        } else if (model instanceof ComputationGraph) {
            ((ComputationGraph) model).setScore(score);
        }
    }

    /** Blocks until a GridExecutioner (if that is the active executioner) has drained its queue. */
    private static void flushExecutioner() {
        if (Nd4j.getExecutioner() instanceof GridExecutioner) {
            ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
        }
    }
}
