package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.AddressFromURIString;
import akka.actor.Cancellable;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.cluster.Cluster;
import akka.cluster.ClusterEvent;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.japi.Function;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.deeplearning4j.iterativereduce.actor.core.ClearWorker;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.util.ActorRefUtils;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableWorker;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/WorkerActor.class */
/**
 * Base class for Akka cluster worker actors in the iterative-reduce scheme.
 *
 * <p>On construction the worker:
 * <ol>
 *   <li>joins the cluster at {@code conf.getMasterUrl()};</li>
 *   <li>registers with the distributed pub/sub mediator and subscribes to
 *       {@code MasterActor.BROADCAST} and {@code MasterActor.SHUTDOWN};</li>
 *   <li>publishes a {@link WorkerState} to {@code MasterActor.MASTER};</li>
 *   <li>registers its id with the {@link StateTracker};</li>
 *   <li>starts a heartbeat that polls the tracker every 30 seconds.</li>
 * </ol>
 *
 * <p>NOTE: the constructor calls overridable methods ({@link #setup}, {@link #generateId},
 * {@link #register}, {@link #heartbeat}). Subclass overrides therefore run before subclass fields
 * are initialized.
 *
 * @param <E> the update type produced by {@link #compute()}
 */
public abstract class WorkerActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableWorker<E> {
    /** Current model state held by this worker; returned by {@link #getResults()}, replaced on replication or via {@link #update}. */
    protected E results;
    /** Job taken from the tracker by {@link #checkJobAvailable()}; never reset to {@code null} in this class. */
    protected Job currentJob;
    /** Identifier of this worker, produced by {@link #generateId()}. */
    protected String id;
    Cluster cluster;
    /** Actor ref supplied to the constructor (may be {@code null}); only stored by this class. */
    protected ActorRef clusterClient;
    /** Absolute path of the master actor, from {@code conf.getMasterAbsPath()}. */
    protected String masterPath;
    protected StateTracker<E> tracker;
    /**
     * Work-in-progress flag consulted by the heartbeat. It is never set to {@code true} in this
     * class; presumably subclasses manage it (TODO confirm).
     */
    protected AtomicBoolean isWorking;
    protected Conf conf;
    /** Distributed pub/sub mediator used to talk to the master. */
    protected ActorRef mediator;
    /** Handle of the scheduled heartbeat task; cancelled in {@link #postStop()}. */
    protected Cancellable heartbeat;
    protected static Logger log = LoggerFactory.getLogger(WorkerActor.class);

    /**
     * Creates a worker without a cluster client reference.
     *
     * @param conf         worker configuration (master URL / path)
     * @param stateTracker shared state tracker used to coordinate with the master
     * @throws Exception if setup, registration or heartbeat scheduling fails
     */
    public WorkerActor(Conf conf, StateTracker<E> stateTracker) throws Exception {
        this(conf, null, stateTracker);
    }

    /**
     * Creates a worker, joins the cluster, registers with the mediator, master and tracker, and
     * starts the heartbeat.
     *
     * @param conf         worker configuration (master URL / path)
     * @param actorRef     optional cluster client reference, stored in {@link #clusterClient}
     * @param stateTracker shared state tracker used to coordinate with the master
     * @throws Exception if setup, registration or heartbeat scheduling fails
     */
    public WorkerActor(Conf conf, ActorRef actorRef, StateTracker<E> stateTracker) throws Exception {
        this.cluster = Cluster.get(getContext().system());
        this.isWorking = new AtomicBoolean(false);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        setup(conf);
        this.tracker = stateTracker;
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.BROADCAST, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.SHUTDOWN, getSelf()), getSelf());
        // The id must exist before register() builds the WorkerState published to the master.
        this.id = generateId();
        this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, register()), getSelf());
        this.clusterClient = actorRef;
        stateTracker.availableForWork(this.id);
        this.masterPath = conf.getMasterAbsPath();
        log.info("Registered with master " + this.id + " at master " + conf.getMasterAbsPath());
        heartbeat();
        stateTracker.addWorker(this.id);
    }

    /**
     * Schedules the recurring heartbeat on this actor's dispatcher: the first run is after 30
     * seconds, then every 30 seconds. Any heartbeat previously scheduled by this instance is
     * cancelled first, so repeated calls do not stack up timers.
     *
     * <p>NOTE: the task runs on a dispatcher thread, not inside message processing. It therefore
     * reads and writes actor fields concurrently with {@code onReceive}.
     *
     * @throws Exception declared for compatibility with overriding subclasses
     */
    public void heartbeat() throws Exception {
        if (this.heartbeat != null) {
            this.heartbeat.cancel();
        }
        this.heartbeat = context().system().scheduler().schedule(
                Duration.apply(30L, TimeUnit.SECONDS),
                Duration.apply(30L, TimeUnit.SECONDS),
                new Runnable() {
                    @Override
                    public void run() {
                        try {
                            heartbeatTick();
                        } catch (Exception e) {
                            // Akka reschedules a repeating task only after its Runnable returns
                            // normally. Rethrowing here would stop the heartbeat for good after a
                            // single transient tracker failure, so log it and retry next tick.
                            log.error("Heartbeat failed on worker " + WorkerActor.this.id, e);
                        }
                    }
                },
                context().dispatcher());
    }

    /**
     * Performs one heartbeat pass:
     * <ul>
     *   <li>re-registers this worker with the tracker;</li>
     *   <li>pulls the current model if the tracker requests replication;</li>
     *   <li>starts the assigned job, or clears a stale one.</li>
     * </ul>
     *
     * @throws Exception on tracker failures or if the assigned job carries no work
     */
    private void heartbeatTick() throws Exception {
        if (!this.tracker.isDone()) {
            this.tracker.addWorker(this.id);
        }
        if (!this.tracker.isDone() && this.tracker.needsReplicate(this.id)) {
            log.info("Updating worker " + this.id);
            E current = this.tracker.getCurrent();
            if (current == null || current.get() == null) {
                // No usable model published yet; skip the rest of this tick.
                return;
            }
            this.results = current;
            this.tracker.doneReplicating(this.id);
        }
        checkJobAvailable();
        if (this.currentJob != null && !this.isWorking.get() && this.tracker.jobFor(this.id) != null) {
            log.info("Confirmation from " + this.currentJob.getWorkerId() + " on work");
            if (this.currentJob.getWork() == null) {
                throw new IllegalStateException("Work for worker " + this.id + " was null");
            }
            processDataSet(this.currentJob.getWork().asList());
        } else if (this.currentJob == null || (!this.isWorking.get() && this.tracker.jobFor(this.id) != null)) {
            if (this.tracker.jobFor(this.id) != null) {
                this.tracker.clearJob(this.id);
            }
            log.info("Clearing stale job... " + this.id);
        }
    }

    /**
     * Trains on {@code list} asynchronously on this actor's dispatcher.
     *
     * <p>The future does the following:
     * <ol>
     *   <li>calls {@code tracker.beginTraining()};</li>
     *   <li>refreshes {@link #results} if replication is pending;</li>
     *   <li>runs {@link #compute()};</li>
     *   <li>if the result is non-null, adds it as this worker's update and disables the worker in
     *       the tracker.</li>
     * </ol>
     * Failures of the future are handed to {@code ActorRefUtils.throwExceptionIfExists}.
     *
     * @param list the data sets of the current job; a {@code null} or empty list only logs a warning
     */
    protected void processDataSet(final List<DataSet> list) {
        if (list == null || list.isEmpty()) {
            log.warn("Worker " + this.id + " was passed an empty or null list");
            return;
        }
        ActorRefUtils.throwExceptionIfExists(Futures.future(new Callable<E>() {
            @Override
            public E call() throws Exception {
                // NOTE(review): these stacked matrices are built but never read afterwards, since
                // compute() takes no arguments. They are kept to preserve existing behavior,
                // including any exception putRow raises on mismatched rows. TODO confirm whether
                // they were meant to feed compute().
                INDArray features = Nd4j.create(list.size(), list.get(0).getFeatureMatrix().columns());
                INDArray labels = Nd4j.create(list.size(), list.get(0).getLabels().columns());
                for (int i = 0; i < list.size(); i++) {
                    features.putRow(i, list.get(i).getFeatureMatrix());
                    labels.putRow(i, list.get(i).getLabels());
                }
                tracker.beginTraining();
                if (tracker.needsReplicate(id)) {
                    log.info("Updating network for worker " + id);
                    results = tracker.getCurrent();
                    tracker.doneReplicating(id);
                }
                E update = compute();
                if (update != null) {
                    log.info("Done working; adding update to mini batch on worker " + id);
                    tracker.addUpdate(id, update);
                    tracker.disableWorker(id);
                    log.info("Number of updates so far " + tracker.workerUpdates().size());
                }
                return update;
            }
        }, getContext().dispatcher()), context().dispatcher());
    }

    /**
     * Builds the registration message published to the master.
     *
     * @return a {@link WorkerState} carrying this worker's id
     */
    public WorkerState register() {
        return new WorkerState(this.id);
    }

    /**
     * Generates a worker id of the form {@code <host>-<random UUID>}. The host is the
     * {@code akka.remote.netty.tcp.hostname} system property, or {@code localhost} if unset.
     *
     * @return a new worker id
     */
    public String generateId() {
        return System.getProperty("akka.remote.netty.tcp.hostname", "localhost") + "-" + UUID.randomUUID().toString();
    }

    /**
     * Stops the heartbeat, deregisters from the tracker (best effort) and unsubscribes from
     * cluster events.
     *
     * @throws Exception propagated from {@code super.postStop()}
     */
    @Override
    public void postStop() throws Exception {
        // Cancel first: a still-running heartbeat would call tracker.addWorker(id) and undo the
        // removeWorker(id) below, and after a restart it would keep running next to the new one.
        if (this.heartbeat != null) {
            this.heartbeat.cancel();
        }
        super.postStop();
        try {
            this.tracker.removeWorker(this.id);
        } catch (Exception ignored) {
            // Best effort: the tracker may already be shut down when workers stop.
            log.info("Tracker already shut down");
        }
        log.info("Post stop on worker actor");
        this.cluster.unsubscribe(getSelf());
    }

    /**
     * Subscribes this actor to cluster member events.
     *
     * @throws Exception propagated from {@code super.preStart()}
     */
    @Override
    public void preStart() throws Exception {
        super.preStart();
        this.cluster.subscribe(getSelf(), new Class[]{ClusterEvent.MemberEvent.class});
        log.info("Pre start on worker");
    }

    /**
     * Looks up this worker's job in the tracker.
     *
     * <p>If there is no job, or the worker is disabled, a job still held by an idle worker is
     * cleared. Otherwise the method does two things:
     * <ul>
     *   <li>it pulls the current model if replication is pending;</li>
     *   <li>if no job is held locally yet, it adopts the tracker's job as {@link #currentJob} and
     *       writes back {@code new Job(id, null)} (presumably marking the job as taken — TODO
     *       confirm).</li>
     * </ul>
     *
     * @throws Exception on tracker failures
     */
    protected void checkJobAvailable() throws Exception {
        Job jobFor = this.tracker.jobFor(this.id);
        if (jobFor == null || !this.tracker.workerEnabled(this.id)) {
            if (this.isWorking.get() || jobFor == null) {
                return;
            }
            this.tracker.clearJob(this.id);
            log.info("Clearing stale job " + this.id);
            return;
        }
        if (this.tracker.needsReplicate(this.id)) {
            try {
                log.info("Updating worker " + this.id);
                this.results = this.tracker.getCurrent();
                this.tracker.doneReplicating(this.id);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        if (jobFor == null || this.currentJob != null) {
            return;
        }
        log.info("Assigning job for worker " + this.id);
        this.currentJob = jobFor;
        this.tracker.updateJob(new Job(this.id, null));
    }

    /**
     * Ignores {@code list} and delegates to {@link #compute()}.
     *
     * @param list ignored
     * @return the result of {@link #compute()}
     */
    public E compute(List<E> list) {
        return compute();
    }

    /**
     * Performs one unit of training work.
     *
     * @return the update to report, or {@code null} if there is none
     */
    public abstract E compute();

    /**
     * Iteration counting is not supported by this worker.
     *
     * @return always {@code false}
     */
    public boolean incrementIteration() {
        return false;
    }

    /**
     * Stores {@code conf}, records the master path, joins the cluster at the master URL, and
     * (re)acquires the pub/sub mediator.
     *
     * @param conf worker configuration
     */
    public void setup(Conf conf) {
        this.conf = conf;
        String masterUrl = conf.getMasterUrl();
        this.masterPath = conf.getMasterAbsPath();
        Cluster.get(context().system()).join(AddressFromURIString.apply(masterUrl));
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    }

    /**
     * Supervision for this actor's children. Any failure is logged, and a {@link ClearWorker} for
     * this worker is published to the master; the strategy then answers {@code restart}.
     *
     * <p>NOTE(review): with {@code maxNrOfRetries = 0} and {@code Duration.Zero}, Akka may refuse
     * the restart and stop the child instead. Confirm this is intended.
     */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            @Override
            public SupervisorStrategy.Directive apply(Throwable th) {
                log.error("Problem with processing", th);
                mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, new ClearWorker(id)), getSelf());
                return SupervisorStrategy.restart();
            }
        });
    }

    /**
     * Returns the model state currently held by this worker.
     *
     * @return the current results, possibly {@code null}
     */
    public E getResults() {
        return this.results;
    }

    /**
     * Replaces the model state held by this worker.
     *
     * @param e the new results
     */
    public void update(E e) {
        this.results = e;
    }
}
