package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.Cancellable;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterReceptionistExtension;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.event.Logging;
import akka.event.LoggingAdapter;
import akka.japi.Function;
import java.io.DataOutputStream;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.core.MoreWorkMessage;
import org.deeplearning4j.iterativereduce.actor.util.ActorRefUtils;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableMaster;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import scala.Option;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/MasterActor.class */
/**
 * Base actor for the master side of an iterative-reduce run.
 *
 * <p>On construction the actor subscribes itself to the {@link #MASTER} and {@link #FINISH}
 * topics of the distributed pub/sub mediator and starts two repeating scheduler tasks:
 * <ul>
 *   <li>a status check every {@value #STATUS_CHECK_INTERVAL_SECONDS} seconds that clears a
 *       single straggling job after {@value #STALE_JOB_TIMEOUT_MINUTES} minutes and calls
 *       {@link #nextBatch()} once enough worker updates have arrived;</li>
 *   <li>a cleanup every {@value #STALE_WORKER_CHECK_INTERVAL_MINUTES} minute(s) that removes
 *       workers whose last heartbeat is at least {@value #STALE_WORKER_TIMEOUT_SECONDS}
 *       seconds old.</li>
 * </ul>
 * Both tasks are cancelled in {@link #postStop()}.
 *
 * <p>NOTE(review): the scheduled tasks run on the dispatcher, not inside the actor's message
 * loop, and touch the same fields as the actor; confirm the state tracker tolerates that.
 *
 * @param <E> the type of result aggregated by this master
 */
public abstract class MasterActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableMaster<E> {
    /** Initial delay and period, in seconds, of the iteration status check. */
    private static final long STATUS_CHECK_INTERVAL_SECONDS = 10L;
    /** Initial delay and period, in minutes, of the stale worker cleanup. */
    private static final long STALE_WORKER_CHECK_INTERVAL_MINUTES = 1L;
    /** Minutes a single remaining job may linger before the job list is cleared. */
    private static final long STALE_JOB_TIMEOUT_MINUTES = 5L;
    /** Seconds without a heartbeat after which a worker is removed. */
    private static final long STALE_WORKER_TIMEOUT_SECONDS = 120L;
    /** Milliseconds slept between polls while waiting for outstanding jobs to finish. */
    private static final long JOB_WAIT_SLEEP_MILLIS = 30000L;

    protected Conf conf;
    protected LoggingAdapter log;
    /** Actor reference handed to the constructor; not used by this class itself. */
    protected ActorRef batchActor;
    protected StateTracker<E> stateTracker;
    /** Incremented every time a new aggregate is computed from worker updates. */
    protected int epochsComplete;
    /** Time (epoch millis) at which exactly one remaining job was first observed, or null. */
    protected AtomicLong oneDown;
    protected final ActorRef mediator;
    // Topic names. MASTER and FINISH are subscribed to in the constructor; BROADCAST and
    // SHUTDOWN are not referenced in this class (presumably used by other actors).
    public static String BROADCAST = "broadcast";
    public static String MASTER = "result";
    public static String SHUTDOWN = "shutdown";
    public static String FINISH = "finish";
    Cluster cluster;
    ClusterReceptionistExtension receptionist;
    /** Set to true by {@link #doDoneOrNextPhase()} once the tracker has been finished. */
    protected boolean isDone;
    /** Handle of the repeating iteration status check. */
    protected Cancellable forceNextPhase;
    /** Handle of the repeating stale worker cleanup. */
    protected Cancellable clearStateWorkers;
    private boolean began;

    /**
     * Creates the master, runs the tracker's pre-train setup, subscribes to the master and
     * finish topics and starts the two periodic maintenance tasks.
     *
     * @param conf         configuration; its number of passes is handed to the state tracker
     * @param actorRef     the batch actor reference to remember
     * @param stateTracker shared state tracker; dereferenced immediately, so must not be null
     * @throws RuntimeException wrapping any exception raised during setup
     */
    public MasterActor(Conf conf, ActorRef actorRef, StateTracker<E> stateTracker) {
        this.log = Logging.getLogger(getContext().system(), this);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        this.cluster = Cluster.get(getContext().system());
        this.receptionist = ClusterReceptionistExtension.get(getContext().system());
        this.isDone = false;
        this.began = false;
        this.conf = conf;
        this.batchActor = actorRef;
        try {
            this.stateTracker = stateTracker;
            this.stateTracker.runPreTrainIterations(conf.getNumPasses());
            this.mediator.tell(new DistributedPubSubMediator.Subscribe(MASTER, getSelf()), getSelf());
            this.mediator.tell(new DistributedPubSubMediator.Subscribe(FINISH, getSelf()), getSelf());
            this.forceNextPhase = context().system().scheduler().schedule(
                    Duration.create(STATUS_CHECK_INTERVAL_SECONDS, TimeUnit.SECONDS),
                    Duration.create(STATUS_CHECK_INTERVAL_SECONDS, TimeUnit.SECONDS),
                    new Runnable() {
                        @Override
                        public void run() {
                            MasterActor.this.checkIterationStatus();
                        }
                    },
                    context().dispatcher());
            this.clearStateWorkers = context().system().scheduler().schedule(
                    Duration.create(STALE_WORKER_CHECK_INTERVAL_MINUTES, TimeUnit.MINUTES),
                    Duration.create(STALE_WORKER_CHECK_INTERVAL_MINUTES, TimeUnit.MINUTES),
                    new Runnable() {
                        @Override
                        public void run() {
                            MasterActor.this.removeStaleWorkers();
                        }
                    },
                    context().dispatcher());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Periodic status check: tracks/clears a single straggling job and advances to the next
     * batch when all workers have reported or no jobs remain.
     *
     * <p>Failures are logged instead of rethrown: an exception escaping a repeating scheduler
     * task stops the task from being rescheduled, which would silently disable this check.
     */
    private void checkIterationStatus() {
        try {
            if (this.stateTracker.isDone()) {
                return;
            }
            List<Job> currentJobs = this.stateTracker.currentJobs();
            this.log.info("Status check on next iteration");
            Collection<String> workerUpdates = this.stateTracker.workerUpdates();
            if (currentJobs.size() != 1) {
                // The straggler timer only applies while exactly one job remains. Reset it
                // otherwise so a timestamp from an earlier round cannot make a fresh job
                // look stale immediately.
                this.oneDown = null;
            } else if (this.oneDown == null) {
                this.log.info("Marking start of stale jobs");
                this.oneDown = new AtomicLong(System.currentTimeMillis());
            } else if (TimeUnit.MILLISECONDS.toMinutes(System.currentTimeMillis() - this.oneDown.get()) >= STALE_JOB_TIMEOUT_MINUTES) {
                this.stateTracker.currentJobs().clear();
                this.oneDown = null;
                this.log.info("Clearing out stale jobs");
            }
            if (workerUpdates.size() >= this.stateTracker.workers().size() || currentJobs.isEmpty()) {
                nextBatch();
            } else {
                this.log.info("Still waiting on next batch, so far we have updates of size: " + workerUpdates.size() + " out of " + this.stateTracker.workers().size());
            }
            this.log.info("Current jobs left " + currentJobs);
        } catch (Exception e) {
            this.log.error(e, "Status check on next iteration failed");
        }
    }

    /**
     * Periodic cleanup: removes every worker whose last recorded heartbeat is at least
     * {@link #STALE_WORKER_TIMEOUT_SECONDS} seconds old.
     *
     * <p>Failures are logged instead of rethrown so the repeating task keeps running.
     */
    private void removeStaleWorkers() {
        try {
            if (this.stateTracker.isDone()) {
                return;
            }
            long now = System.currentTimeMillis();
            Map<String, Long> heartBeats = this.stateTracker.getHeartBeats();
            // Iterate over a snapshot of the keys in case removeWorker mutates the map
            // backing the heartbeats (not visible from here).
            for (String worker : heartBeats.keySet().toArray(new String[0])) {
                Long lastBeat = heartBeats.get(worker);
                if (lastBeat == null) {
                    // Entry disappeared between the snapshot and the lookup.
                    continue;
                }
                if (TimeUnit.MILLISECONDS.toSeconds(now - lastBeat.longValue()) >= STALE_WORKER_TIMEOUT_SECONDS) {
                    this.log.info("Removing stale worker " + worker);
                    this.stateTracker.removeWorker(worker);
                }
            }
        } catch (Exception e) {
            this.log.error(e, "Stale worker check failed");
        }
    }

    /**
     * Aggregates the pending worker updates and asks for more work.
     *
     * <p>Does nothing unless there is at least one worker update and no job is outstanding.
     * Otherwise the result of {@code compute()} is stored as the tracker's current value,
     * every worker is flagged for replication and enabled, the update list is cleared and a
     * {@link MoreWorkMessage} is published on the batch topic.
     *
     * @throws Exception if the state tracker or the publish future fails
     */
    protected void nextBatch() throws Exception {
        if (this.stateTracker.workerUpdates().isEmpty() || !this.stateTracker.currentJobs().isEmpty()) {
            return;
        }
        E compute = compute();
        this.log.info("Updating next batch");
        this.stateTracker.setCurrent(compute);
        for (String str : this.stateTracker.workers()) {
            this.log.info("Replicating new network to " + str);
            this.stateTracker.addReplicate(str);
            this.stateTracker.enableWorker(str);
        }
        this.epochsComplete++;
        this.stateTracker.workerUpdates().clear();
        // NOTE(review): this polls without any delay, and only after setCurrent has already
        // been called with the null value; it can spin indefinitely if the tracker keeps
        // returning null. Left as-is because the intended recovery is not visible here.
        while (compute == null) {
            this.log.info("On next batch master results was null, attempting to grab results again");
            compute = getResults();
        }
        ActorRefUtils.throwExceptionIfExists(Futures.future(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                MasterActor.this.mediator.tell(new DistributedPubSubMediator.Publish(BatchActor.BATCH, MoreWorkMessage.getInstance()), MasterActor.this.getSelf());
                MasterActor.this.log.info("Requesting more work...");
                return null;
            }
        }, context().dispatcher()), context().dispatcher());
    }

    /**
     * Folds in any remaining worker updates, waits for outstanding jobs to drain and then
     * marks training as finished on both this actor and the state tracker.
     *
     * <p>NOTE(review): the wait uses a blocking sleep on the calling thread.
     *
     * @throws Exception if the state tracker fails or the waiting thread is interrupted
     */
    public void doDoneOrNextPhase() throws Exception {
        if (this.stateTracker.workerUpdates().isEmpty()) {
            getMasterResults();
        } else {
            this.stateTracker.setCurrent(compute());
            this.epochsComplete++;
            this.stateTracker.workerUpdates().clear();
        }
        while (!this.stateTracker.currentJobs().isEmpty()) {
            this.log.info("Waiting fo jobs to finish up before next phase...");
            Thread.sleep(JOB_WAIT_SLEEP_MILLIS);
        }
        // Re-checked on purpose: the job list is read again after the loop exits.
        if (this.stateTracker.currentJobs().isEmpty()) {
            this.isDone = true;
            this.stateTracker.finish();
            this.log.info("Done training!");
        }
    }

    /**
     * Convenience constructor that delegates with a null state tracker.
     *
     * <p>NOTE(review): the three-argument constructor dereferences the tracker immediately,
     * so this constructor currently always fails with a RuntimeException wrapping a
     * NullPointerException.
     *
     * @param conf     configuration for the run
     * @param actorRef the batch actor reference to remember
     */
    public MasterActor(Conf conf, ActorRef actorRef) {
        this(conf, actorRef, null);
    }

    /**
     * Logs the cause after the actor has been restarted.
     *
     * @param th the failure that triggered the restart
     */
    public void aroundPostRestart(Throwable th) {
        super.aroundPostRestart(th);
        // The template needs a placeholder, otherwise the logging adapter drops the argument.
        this.log.info("Restarted because of {}", th);
    }

    /**
     * Logs the cause and the message being processed before the actor is restarted.
     *
     * @param th     the failure that triggered the restart
     * @param option the message being processed when the failure occurred, if any
     */
    public void aroundPreRestart(Throwable th, Option<Object> option) {
        super.aroundPreRestart(th, option);
        this.log.info("Restarted because of {} with message {}", th, option);
    }

    /**
     * Registers this actor with the pub/sub mediator and logs its path.
     *
     * @throws Exception if the superclass hook fails
     */
    public void preStart() throws Exception {
        super.preStart();
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.log.info("Setup master with path " + self().path());
        this.log.info("Pre start on master " + self().path().toString());
    }

    /**
     * Unsubscribes from cluster events and cancels both periodic tasks.
     *
     * @throws Exception if the superclass hook fails
     */
    public void postStop() throws Exception {
        super.postStop();
        this.log.info("Post stop on master");
        this.cluster.unsubscribe(getSelf());
        if (this.clearStateWorkers != null) {
            this.clearStateWorkers.cancel();
        }
        if (this.forceNextPhase != null) {
            this.forceNextPhase.cancel();
        }
    }

    /**
     * Implemented by subclasses; receives a stream to write to.
     *
     * @param dataOutputStream the stream handed to the implementation
     */
    public abstract void complete(DataOutputStream dataOutputStream);

    /**
     * Returns the state tracker's current value.
     *
     * @return the current aggregate as reported by the tracker
     * @throws RuntimeException wrapping any exception raised by the tracker
     */
    public E getResults() {
        try {
            return this.stateTracker.getCurrent();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Supervisor strategy that logs every child failure and resumes the child.
     *
     * @return a one-for-one strategy whose decider always resumes
     */
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            public SupervisorStrategy.Directive apply(Throwable th) {
                MasterActor.this.log.error("Problem with processing", th);
                return SupervisorStrategy.resume();
            }
        });
    }

    /**
     * @return the configuration passed to the constructor
     */
    public Conf getConf() {
        return this.conf;
    }

    /**
     * @return the same value as {@link #getResults()}
     */
    public E getMasterResults() {
        return getResults();
    }

    /**
     * @return true once {@link #doDoneOrNextPhase()} has marked training as finished
     */
    public boolean isDone() {
        return this.isDone;
    }
}
