package org.deeplearning4j.scaleout.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.UntypedActor;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedDeque;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.nn.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.actor.core.protocol.ResetMessage;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.api.workrouter.WorkRouter;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.job.JobIterator;
import org.deeplearning4j.scaleout.messages.DoneMessage;
import org.deeplearning4j.scaleout.messages.MoreWorkMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/scaleout/actor/core/actor/BatchActor.class */
public class BatchActor extends UntypedActor implements DeepLearningConfigurable {
    protected JobIterator iter;
    private static final Logger log = LoggerFactory.getLogger(BatchActor.class);
    public static final String BATCH = "batch";
    private transient StateTracker stateTracker;
    private transient Configuration conf;
    private WorkRouter workRouter;
    private final ActorRef mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    private Queue<String> workers = new ConcurrentLinkedDeque();
    private int numDataSets = 0;

    public BatchActor(JobIterator jobIterator, StateTracker stateTracker, Configuration configuration, WorkRouter workRouter) {
        this.iter = jobIterator;
        this.stateTracker = stateTracker;
        this.conf = configuration;
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.SHUTDOWN, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(BATCH, getSelf()), getSelf());
        this.workRouter = workRouter;
    }

    public void onReceive(Object obj) throws Exception {
        if ((obj instanceof DistributedPubSubMediator.SubscribeAck) || (obj instanceof DistributedPubSubMediator.UnsubscribeAck)) {
            log.info("Susbcribed batch actor");
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", obj), getSelf());
            return;
        }
        if (obj instanceof ResetMessage) {
            this.iter.reset();
            self().tell(MoreWorkMessage.getInstance(), self());
            return;
        }
        if (obj instanceof MoreWorkMessage) {
            log.info("Saving model");
            this.mediator.tell(new DistributedPubSubMediator.Publish(ModelSavingActor.SAVE, MoreWorkMessage.getInstance()), this.mediator);
            if (!this.iter.hasNext()) {
                if (this.iter.hasNext()) {
                    unhandled(obj);
                    return;
                } else {
                    this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, DoneMessage.getInstance()), this.mediator);
                    return;
                }
            }
            log.info("Propagating new work to master");
            this.numDataSets++;
            log.info("Iterating over next dataset " + this.numDataSets);
            List workers = this.stateTracker.workers();
            Iterator it = workers.iterator();
            while (it.hasNext()) {
                log.info("Worker " + ((String) it.next()));
            }
            Iterator it2 = this.stateTracker.workerData().iterator();
            while (it2.hasNext()) {
                this.stateTracker.removeWorkerData((String) it2.next());
            }
            int size = workers.size();
            int inputSplit = this.stateTracker.inputSplit();
            if (size == 0) {
                size = Runtime.getRuntime().availableProcessors();
            }
            log.info("Number of workers " + size + " and batch size is " + inputSplit);
            Iterator it3 = this.stateTracker.workers().iterator();
            while (it3.hasNext()) {
                this.stateTracker.enableWorker((String) it3.next());
            }
            log.info("Batch size for worker is " + (size * inputSplit));
            int i = 0;
            while (i < size && this.iter.hasNext()) {
                String nextWorker = nextWorker();
                log.info("Saving data for worker " + nextWorker);
                if (nextWorker == null) {
                    i--;
                } else {
                    Job next = this.iter.next(nextWorker);
                    if (next == null) {
                        break;
                    } else {
                        this.workRouter.routeJob(next);
                    }
                }
                i++;
            }
            this.stateTracker.incrementBatchesRan(workers.size());
            this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, this.stateTracker.workerData()), this.mediator);
        }
    }

    private String nextWorker() {
        while (this.workers.isEmpty()) {
            for (String str : this.stateTracker.workers()) {
                if (this.stateTracker.jobFor(str) == null && !this.workers.contains(str)) {
                    this.workers.add(str);
                }
            }
            log.info("Refilling queue with size of " + this.workers.size() + " out of " + this.stateTracker.numWorkers());
        }
        return this.workers.poll();
    }

    public JobIterator getIter() {
        return this.iter;
    }

    public void setup(Configuration configuration) {
    }
}
