package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.AddressFromURIString;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.cluster.Cluster;
import akka.cluster.ClusterEvent;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.event.Logging;
import akka.event.LoggingAdapter;
import akka.japi.Function;
import java.util.List;
import java.util.UUID;
import org.deeplearning4j.iterativereduce.actor.core.ClearWorker;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableWorker;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/WorkerActor.class */
/**
 * Base class for worker actors in the actor-based iterative-reduce runtime.
 *
 * <p>On construction a worker:
 * <ul>
 *   <li>reads its hyper-parameters from the supplied {@link Conf};</li>
 *   <li>joins the Akka cluster at the master's URL;</li>
 *   <li>registers itself with the distributed pub/sub mediator;</li>
 *   <li>subscribes to the {@code MasterActor.BROADCAST} and {@code MasterActor.SHUTDOWN} topics;</li>
 *   <li>announces itself by publishing a {@link WorkerState} on {@code MasterActor.MASTER};</li>
 *   <li>marks itself available for work in the {@link StateTracker}.</li>
 * </ul>
 *
 * <p>Subclasses supply the actual work by implementing {@link #compute()}.
 *
 * @param <E> the {@link Updateable} type this worker computes and receives via {@link #update}
 */
public abstract class WorkerActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableWorker<E> {
    /** Distributed pub/sub mediator used to talk to the master; resolved in {@link #setup(Conf)}. */
    protected ActorRef mediator;
    /** Most recent update handed to this worker via {@link #update} / {@link #setE}. */
    protected E e;
    /** Value returned by {@link #getResults()}; assigned by subclasses. */
    protected E results;
    protected LoggingAdapter log;
    protected int fineTuneEpochs;
    protected int preTrainEpochs;
    protected int[] hiddenLayerSizes;
    /** Read from {@code conf.getnOut()}. */
    protected int numHidden;
    /** Read from {@code conf.getnIn()}. */
    protected int numVisible;
    /** Despite the name, holds {@code hiddenLayerSizes.length}, i.e. the number of configured layers. */
    protected int numHiddenNeurons;
    protected int renderWeightEpochs;
    protected long seed;
    /** Read from {@code conf.getPretrainLearningRate()}. */
    protected double learningRate;
    protected double corruptionLevel;
    protected Object[] extraParams;
    /** Unique worker id, generated by {@link #generateId()} during construction. */
    protected String id;
    protected boolean useRegularization;
    Cluster cluster;
    /** Cluster client reference passed to the constructor; may be {@code null}. */
    protected ActorRef clusterClient;
    protected String masterPath;
    protected StateTracker<E> tracker;

    /**
     * Creates a worker without a cluster client reference.
     *
     * @param conf         configuration providing hyper-parameters and the master location
     * @param stateTracker tracker in which this worker is marked available for work
     */
    public WorkerActor(Conf conf, StateTracker<E> stateTracker) {
        this(conf, null, stateTracker);
    }

    /**
     * Creates a worker, joins the cluster, and registers the worker with the master and the tracker.
     *
     * <p>This constructor invokes the overridable methods {@link #setup(Conf)},
     * {@link #generateId()} and {@link #register()}. Subclass overrides therefore run before the
     * subclass's own field initializers.
     *
     * @param conf         configuration providing hyper-parameters and the master location
     * @param actorRef     cluster client reference stored in {@link #clusterClient}; may be {@code null}
     * @param stateTracker tracker in which this worker is marked available for work
     */
    public WorkerActor(Conf conf, ActorRef actorRef, StateTracker<E> stateTracker) {
        this.log = Logging.getLogger(getContext().system(), this);
        this.cluster = Cluster.get(getContext().system());
        // setup() also resolves the mediator that is used immediately below.
        setup(conf);
        this.tracker = stateTracker;
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.BROADCAST, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.SHUTDOWN, getSelf()), getSelf());
        // The id must exist before register(), which embeds it in the WorkerState it returns.
        this.id = generateId();
        this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, register()), getSelf());
        this.clusterClient = actorRef;
        stateTracker.availableForWork(this.id);
        this.masterPath = conf.getMasterAbsPath();
        this.log.info("Registered worker {} with master at {}", this.id, this.masterPath);
    }

    /**
     * Builds the registration message published to the master.
     *
     * @return a {@link WorkerState} carrying this worker's id
     */
    public WorkerState register() {
        return new WorkerState(this.id);
    }

    /**
     * Generates this worker's id.
     *
     * @return a random UUID in string form
     */
    public String generateId() {
        return UUID.randomUUID().toString();
    }

    /**
     * Removes this worker from the state tracker and unsubscribes from cluster events.
     *
     * <p>Tracker removal is best effort: any exception it throws is logged and ignored.
     */
    public void postStop() throws Exception {
        super.postStop();
        try {
            this.tracker.removeWorker(this.id);
        } catch (Exception ex) {
            // The tracker may already be gone during shutdown. Log the cause so that
            // unexpected failures are not silently hidden.
            this.log.info("Tracker already shut down: {}", ex);
        }
        this.log.info("Post stop on worker actor");
        this.cluster.unsubscribe(getSelf());
    }

    /** Subscribes this actor to cluster membership events ({@code ClusterEvent.MemberEvent}). */
    public void preStart() throws Exception {
        super.preStart();
        this.cluster.subscribe(getSelf(), ClusterEvent.MemberEvent.class);
        this.log.info("Pre start on worker");
    }

    /**
     * Ignores {@code list} and delegates to {@link #compute()}.
     *
     * @param list ignored
     * @return the result of {@link #compute()}
     */
    public E compute(List<E> list) {
        return compute();
    }

    /**
     * Performs this worker's unit of work.
     *
     * @return the computed update
     */
    public abstract E compute();

    /**
     * Always returns {@code false}; this base class does not track iterations.
     *
     * @return {@code false}
     */
    public boolean incrementIteration() {
        return false;
    }

    /**
     * Copies hyper-parameters from {@code conf} into this worker's fields. It then joins the Akka
     * cluster at {@code conf.getMasterUrl()} and resolves the distributed pub/sub mediator.
     *
     * @param conf configuration to read from; {@code conf.getLayerSizes()} must be non-null
     *             because its length is read
     */
    public void setup(Conf conf) {
        this.hiddenLayerSizes = conf.getLayerSizes();
        this.numHidden = conf.getnOut();
        this.numVisible = conf.getnIn();
        this.numHiddenNeurons = this.hiddenLayerSizes.length;
        this.seed = conf.getSeed();
        this.renderWeightEpochs = conf.getRenderWeightEpochs();
        this.useRegularization = conf.isUseRegularization();
        this.learningRate = conf.getPretrainLearningRate();
        this.preTrainEpochs = conf.getPretrainEpochs();
        this.fineTuneEpochs = conf.getFinetuneEpochs();
        this.corruptionLevel = conf.getCorruptionLevel();
        this.extraParams = conf.getDeepLearningParams();
        String masterUrl = conf.getMasterUrl();
        this.masterPath = conf.getMasterAbsPath();
        Cluster.get(getContext().system()).join(AddressFromURIString.apply(masterUrl));
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    }

    /**
     * Supervision strategy for this worker's children. Every failure is logged with its stack
     * trace, the master is told to clear this worker via a {@link ClearWorker} message, and a
     * restart is requested.
     *
     * <p>NOTE(review): the strategy is built with {@code maxNrOfRetries = 0}. In Akka a retry
     * limit below 1 denies restart permission, so the failing child is likely stopped rather than
     * restarted. Confirm this is intended.
     *
     * @return a new one-for-one strategy
     */
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            public SupervisorStrategy.Directive apply(Throwable th) {
                // Use the Throwable-first overload. error(String, Object) would treat th as a
                // template argument and drop the stack trace.
                WorkerActor.this.log.error(th, "Problem with processing");
                WorkerActor.this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, new ClearWorker(WorkerActor.this.id)), WorkerActor.this.getSelf());
                return SupervisorStrategy.restart();
            }
        });
    }

    /**
     * Returns the current results.
     *
     * @return the value of {@link #results}; may be {@code null} if no subclass has set it
     */
    public E getResults() {
        return this.results;
    }

    /**
     * Stores the given update as this worker's current value.
     *
     * @param e the new value for {@link #e}
     */
    public void update(E e) {
        this.e = e;
    }

    /**
     * Returns the last stored update.
     *
     * @return the value of {@link #e}
     */
    public E getE() {
        return this.e;
    }

    /**
     * Sets the stored update; equivalent to {@link #update}.
     *
     * @param e the new value for {@link #e}
     */
    public void setE(E e) {
        this.e = e;
    }
}
