package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.Cancellable;
import akka.actor.Props;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.ClearWorker;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.core.NoJobFound;
import org.deeplearning4j.iterativereduce.actor.util.ActorRefUtils;
import org.deeplearning4j.iterativereduce.tracker.statetracker.StateTracker;
import org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;

/**
 * Worker actor that trains a {@link BaseMultiLayerNetwork} on jobs tracked by a {@link StateTracker}.
 *
 * <p>On construction the worker does the following:
 * <ul>
 *   <li>registers its own path with the cluster's distributed pub/sub mediator;</li>
 *   <li>subscribes to the core master's broadcast topic and to a topic named after its own id;</li>
 *   <li>starts a periodic heartbeat to the master topic;</li>
 *   <li>adds itself to the state tracker.</li>
 * </ul>
 *
 * <p>A received {@link Job} is handled in one of two ways. If the tracker reports no current job
 * for this worker, the worker trains on it locally. Otherwise it forwards the job to a worker the
 * tracker reports as idle. Training runs inside a future on the actor's dispatcher. The resulting
 * {@link UpdateableImpl} (or {@link NoJobFound}) is published to the master topic.
 *
 * <p>{@link #network} and {@link #workerUpdateable} are written from those futures as well as from
 * the actor thread, which is why they are {@code volatile}.
 */
public class WorkerActor extends org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor<UpdateableImpl> {
    /** Pub/sub topic of the core master actor; results, heartbeats and confirmations go here. */
    private static final String MASTER_TOPIC =
            org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.MASTER;
    /** Broadcast topic declared by the core master actor; every worker subscribes to it. */
    private static final String BROADCAST_TOPIC =
            org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.BROADCAST;
    /**
     * Delay before a job that found no idle worker is re-delivered to this actor.
     * The value is a heuristic; tune if needed.
     */
    private static final long DELEGATION_RETRY_SECONDS = 1L;

    /** Network currently trained by this worker; may be replaced from dispatcher threads. */
    protected volatile BaseMultiLayerNetwork network;
    /** Not read or written by this class; kept for subclasses. */
    protected DoubleMatrix combinedInput;
    /** Last update received from the master (a clone), returned by {@link #getResults()}. */
    protected volatile UpdateableImpl workerUpdateable;
    /** Distributed pub/sub mediator of this actor system. */
    protected ActorRef mediator;
    /** Handle of the periodic heartbeat, cancelled when the actor stops. */
    protected Cancellable heartbeat;
    protected static Logger log = LoggerFactory.getLogger(WorkerActor.class);

    /**
     * Creates a worker and joins the pub/sub cluster.
     *
     * @param conf         worker configuration, passed to {@link #setup(Conf)}
     * @param stateTracker shared tracker of workers, jobs and the current network
     */
    public WorkerActor(Conf conf, StateTracker<UpdateableImpl> stateTracker) {
        super(conf, stateTracker);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        setup(conf);
        joinCluster(stateTracker);
    }

    /**
     * Creates a worker with an explicit actor reference (forwarded to the base class) and joins
     * the pub/sub cluster.
     *
     * @param actorRef     actor reference handed to the base worker
     * @param conf         worker configuration, passed to {@link #setup(Conf)}
     * @param stateTracker shared tracker of workers, jobs and the current network
     */
    public WorkerActor(ActorRef actorRef, Conf conf, StateTracker<UpdateableImpl> stateTracker) {
        super(conf, actorRef, stateTracker);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        setup(conf);
        joinCluster(stateTracker);
    }

    /**
     * Registers this actor with the mediator, subscribes to the broadcast and per-worker topics,
     * starts the heartbeat and adds this worker to the tracker. Shared by both constructors so
     * they perform the same steps in the same order.
     */
    private void joinCluster(StateTracker<UpdateableImpl> stateTracker) {
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(BROADCAST_TOPIC, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(this.id, getSelf()), getSelf());
        heartbeat();
        stateTracker.addWorker(this.id);
    }

    /** Props for {@link #WorkerActor(ActorRef, Conf, StateTracker)}. */
    public static Props propsFor(ActorRef actorRef, Conf conf, StateTracker<UpdateableImpl> stateTracker) {
        return Props.create(WorkerActor.class, new Object[]{actorRef, conf, stateTracker});
    }

    /** Props for {@link #WorkerActor(Conf, StateTracker)} backed by a Hazelcast tracker. */
    public static Props propsFor(Conf conf, HazelCastStateTracker hazelCastStateTracker) {
        return Props.create(WorkerActor.class, new Object[]{conf, hazelCastStateTracker});
    }

    /** Publishes {@code message} to the master topic with this actor as sender. */
    private void publishToMaster(Object message) {
        this.mediator.tell(new DistributedPubSubMediator.Publish(MASTER_TOPIC, message), getSelf());
    }

    /**
     * Re-publishes this worker's current job (as reported by the tracker) to the master.
     * Logs a warning and does nothing if the tracker has no job for this worker.
     */
    protected void confirmWorking() {
        Job current = this.tracker.jobFor(this.id);
        if (current == null) {
            log.warn("Not confirming work when none to be found");
            return;
        }
        publishToMaster(current);
    }

    /**
     * Schedules a heartbeat every 10 seconds, with the first one after 10 seconds. Each tick
     * publishes {@code register()} to the master and re-adds this worker to the tracker.
     */
    protected void heartbeat() {
        this.heartbeat = context().system().scheduler().schedule(
                Duration.apply(10L, TimeUnit.SECONDS),
                Duration.apply(10L, TimeUnit.SECONDS),
                new Runnable() {
                    @Override
                    public void run() {
                        log.info("Sending heartbeat to master");
                        WorkerActor.this.publishToMaster(WorkerActor.this.register());
                        WorkerActor.this.tracker.addWorker(WorkerActor.this.id);
                    }
                },
                context().dispatcher());
    }

    /**
     * Dispatches an incoming message by type.
     *
     * <ul>
     *   <li>subscribe/unsubscribe acks are logged and re-published on the {@code "topics"} topic;</li>
     *   <li>a {@link Job} is trained locally or delegated to an idle worker;</li>
     *   <li>a {@link BaseMultiLayerNetwork} replaces the current network;</li>
     *   <li>an {@link Ack} is logged;</li>
     *   <li>an {@link UpdateableImpl} updates the network asynchronously;</li>
     *   <li>anything else is passed to {@code unhandled}.</li>
     * </ul>
     */
    public void onReceive(Object obj) throws Exception {
        if (obj instanceof DistributedPubSubMediator.SubscribeAck
                || obj instanceof DistributedPubSubMediator.UnsubscribeAck) {
            handleSubscriptionAck(obj);
        } else if (obj instanceof Job) {
            handleJob((Job) obj);
        } else if (obj instanceof BaseMultiLayerNetwork) {
            setNetwork((BaseMultiLayerNetwork) obj);
            log.info("Set network");
        } else if (obj instanceof Ack) {
            log.info("Ack from master on worker {}", this.id);
        } else if (obj instanceof UpdateableImpl) {
            // Checking the concrete type avoids a ClassCastException for other Updateable implementations.
            handleUpdate((UpdateableImpl) obj);
        } else {
            unhandled(obj);
        }
    }

    /**
     * Re-publishes a mediator ack on the {@code "topics"} topic and logs it. Both ack kinds are
     * handled here; the old code cast every ack to SubscribeAck and failed on UnsubscribeAck.
     */
    private void handleSubscriptionAck(Object ack) {
        this.mediator.tell(new DistributedPubSubMediator.Publish("topics", ack), getSelf());
        if (ack instanceof DistributedPubSubMediator.SubscribeAck) {
            log.info("Subscribed to {}", ack);
        } else {
            log.info("Unsubscribed from {}", ack);
        }
    }

    /**
     * Applies a master update off the actor thread. An update without a network makes the worker
     * reload the tracker's current network. Otherwise the update is cloned into
     * {@link #workerUpdateable} and its network becomes the current one.
     */
    private void handleUpdate(final UpdateableImpl update) {
        ActorRefUtils.throwExceptionIfExists(Futures.future(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                if (update.get() == null) {
                    WorkerActor.this.setNetwork(WorkerActor.this.tracker.getCurrent().get());
                } else {
                    WorkerActor.this.setWorkerUpdateable(update.clone());
                    WorkerActor.this.setNetwork(update.get());
                }
                return null;
            }
        }, context().dispatcher()), context().dispatcher());
    }

    /**
     * Takes the job if this worker is free according to the tracker; otherwise delegates it.
     */
    private void handleJob(Job job) throws Exception {
        if (this.tracker.jobFor(this.id) != null) {
            delegateToIdleWorker(job);
            return;
        }
        this.tracker.addJobToCurrent(job);
        log.info("Confirmation from {} on work", job.getWorkerId());
        // The work may be a List<DataSet> or a single DataSet (compute() accepts both).
        List<DataSet> work = asDataSetList(job.getWork());
        confirmWorking();
        updateTraining(work);
    }

    /**
     * Forwards {@code job} to the first worker the tracker reports as having no job. If no worker
     * is idle, the job is re-delivered to this actor after {@link #DELEGATION_RETRY_SECONDS} rather
     * than spinning on the actor thread. The previous loop never terminated when all workers were
     * busy.
     */
    private void delegateToIdleWorker(Job job) throws Exception {
        for (String worker : this.tracker.workers()) {
            if (this.tracker.jobFor(worker) == null) {
                job.setWorkerId(worker);
                this.mediator.tell(new DistributedPubSubMediator.Publish(worker, job), getSelf());
                log.info("Delegated work to worker {}", worker);
                return;
            }
        }
        log.info("No idle worker for job; retrying in {} second(s)", DELEGATION_RETRY_SECONDS);
        context().system().scheduler().scheduleOnce(
                Duration.apply(DELEGATION_RETRY_SECONDS, TimeUnit.SECONDS),
                getSelf(), job, context().dispatcher(), getSelf());
    }

    /**
     * Views a job's work payload as a list of data sets. A list is returned as-is, a single
     * {@link DataSet} is wrapped, and {@code null} becomes an empty list.
     */
    @SuppressWarnings("unchecked")
    private static List<DataSet> asDataSetList(Object work) {
        if (work == null) {
            return Collections.<DataSet>emptyList();
        }
        if (work instanceof List) {
            return (List<DataSet>) work;
        }
        return Collections.singletonList((DataSet) work);
    }

    /**
     * Runs {@link #compute()} in a future on the actor's dispatcher. The result is published to the
     * master; {@link NoJobFound} is published when there was nothing to train.
     *
     * @param list the job's data. It is not read by this implementation, because {@link #compute()}
     *             re-reads the job from the tracker. The previous version built two matrices from
     *             it and then discarded them.
     */
    protected void updateTraining(List<DataSet> list) {
        ActorRefUtils.throwExceptionIfExists(Futures.future(new Callable<UpdateableImpl>() {
            @Override
            public UpdateableImpl call() throws Exception {
                UpdateableImpl result = WorkerActor.this.compute();
                if (result != null) {
                    log.info("Updating parent actor...");
                    WorkerActor.this.publishToMaster(result);
                } else {
                    WorkerActor.this.publishToMaster(NoJobFound.getInstance());
                    log.info("No job found; unlocking worker {}", WorkerActor.this.id);
                }
                return result;
            }
        }, getContext().dispatcher()), context().dispatcher());
    }

    /** Ignores {@code list} and delegates to {@link #compute()}. */
    @Override
    public UpdateableImpl compute(List<UpdateableImpl> list) {
        return compute();
    }

    /**
     * Trains the network on this worker's current job from the tracker.
     *
     * <p>If no network is set yet, the method polls the tracker's current network until one is
     * available. When the tracker is in pretrain mode it pretrains; otherwise it sets the input,
     * feeds forward and fine-tunes on the labels. Afterwards it clears the job from the tracker.
     *
     * @return the trained network wrapped in an {@link UpdateableImpl}, or {@code null} when the
     *         tracker has no job for this worker
     * @throws IllegalStateException if the job's work is {@code null}
     * @throws RuntimeException      wrapping tracker failures when fetching the network or clearing the job
     */
    @Override
    public UpdateableImpl compute() {
        log.info("Training network");
        BaseMultiLayerNetwork current = getNetwork();
        // Poll until a network is available.
        // NOTE(review): this spins as long as the tracker returns null.
        while (current == null) {
            try {
                current = this.tracker.getCurrent().get();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            setNetwork(current);
        }

        Job job = this.tracker.jobFor(this.id);
        if (job == null) {
            return null;
        }
        log.info("Found job for worker {}", this.id);
        DataSet dataSet = toDataSet(job.getWork());
        if (dataSet == null) {
            throw new IllegalStateException("No job found for worker " + this.id);
        }

        if (this.tracker.isPretrain()) {
            log.info("Worker {} pretraining", this.id);
            current.pretrain((DoubleMatrix) dataSet.getFirst(), this.extraParams);
        } else {
            current.setInput((DoubleMatrix) dataSet.getFirst());
            log.info("Worker {} finetune", this.id);
            current.feedForward((DoubleMatrix) dataSet.getFirst());
            current.finetune((DoubleMatrix) dataSet.getSecond(), this.learningRate, this.fineTuneEpochs);
        }

        try {
            this.tracker.clearJob(job);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return new UpdateableImpl(current);
    }

    /** Merges a list payload into one {@link DataSet}, or casts a single-data-set payload. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static DataSet toDataSet(Object work) {
        return work instanceof List ? DataSet.merge((List) work) : (DataSet) work;
    }

    /** This worker never advances the iteration counter itself. */
    @Override
    public boolean incrementIteration() {
        return false;
    }

    /** Delegates to the base worker's setup. */
    @Override
    public void setup(Conf conf) {
        super.setup(conf);
    }

    /**
     * Stops the heartbeat, runs the base stop hook, then tells the master to clear this worker.
     * The heartbeat is cancelled first so no heartbeat is published after the actor has stopped.
     */
    public void aroundPostStop() {
        if (this.heartbeat != null) {
            this.heartbeat.cancel();
        }
        super.aroundPostStop();
        publishToMaster(new ClearWorker(this.id));
    }

    /** Returns the last update received from the master (may be {@code null}). */
    @Override
    public UpdateableImpl getResults() {
        return this.workerUpdateable;
    }

    /** Replaces the stored worker update. */
    @Override
    public void update(UpdateableImpl updateableImpl) {
        this.workerUpdateable = updateableImpl;
    }

    public synchronized BaseMultiLayerNetwork getNetwork() {
        return this.network;
    }

    /** Synchronized like {@link #getNetwork()}, since futures on dispatcher threads call it. */
    public synchronized void setNetwork(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this.network = baseMultiLayerNetwork;
    }

    public UpdateableImpl getWorkerUpdateable() {
        return this.workerUpdateable;
    }

    public void setWorkerUpdateable(UpdateableImpl updateableImpl) {
        this.workerUpdateable = updateableImpl;
    }
}
