package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.OneForOneStrategy;
import akka.actor.Props;
import akka.actor.SupervisorStrategy;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.japi.Function;
import java.util.List;
import java.util.concurrent.Callable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.NeedsModelMessage;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;

/**
 * Worker actor that trains a {@link BaseMultiLayerNetwork} on batches of data received from the
 * cluster and publishes the trained result back to the master.
 *
 * <p>On construction the worker registers itself with the cluster's distributed pub-sub mediator.
 * It subscribes both to the master's broadcast topic and to a topic named after its own id. It
 * reacts to these messages:
 * <ul>
 *   <li>a {@link List} of (input row, outcome row) {@link Pair}s: a training batch, which is
 *       trained asynchronously, with the result published to the master topic;</li>
 *   <li>a {@link BaseMultiLayerNetwork}: replaces the local network;</li>
 *   <li>an {@link UpdateableImpl}: stored as the worker's updateable, and a clone of its network
 *       becomes the local network;</li>
 *   <li>an {@link Ack}: only logged.</li>
 * </ul>
 *
 * <p>Training runs on a future, off the actor's message-processing thread. The state accessors
 * declared in this class are therefore {@code synchronized} on this instance.
 */
public class WorkerActor extends org.deeplearning4j.iterativereduce.actor.core.actor.WorkerActor<UpdateableImpl> {
    protected BaseMultiLayerNetwork network;
    protected DoubleMatrix combinedInput;
    protected UpdateableImpl workerUpdateable;
    protected ActorRef mediator;
    protected static Logger log = LoggerFactory.getLogger(WorkerActor.class);
    public static final String SYSTEM_NAME = "Workers";

    /** How long {@link #compute()} waits for the master to send a network before asking again. */
    private static final long NETWORK_REQUEST_INTERVAL_MS = 15000L;

    /**
     * Creates a worker and registers it with the distributed pub-sub mediator.
     *
     * @param conf worker configuration, passed to the superclass and to {@link #setup(Conf)}
     */
    public WorkerActor(Conf conf) {
        super(conf);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        setup(conf);
        registerWithMediator();
    }

    /**
     * Creates a worker bound to the given actor and registers it with the pub-sub mediator.
     *
     * @param actorRef actor reference handed to the superclass constructor
     * @param conf worker configuration, passed to the superclass and to {@link #setup(Conf)}
     */
    public WorkerActor(ActorRef actorRef, Conf conf) {
        super(conf, actorRef);
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        setup(conf);
        registerWithMediator();
    }

    /**
     * Puts this actor into the mediator's registry, then subscribes it to the master broadcast
     * topic and to the topic named by this worker's id. Private so that calling it from the
     * constructors cannot dispatch to a subclass override.
     */
    private void registerWithMediator() {
        mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        mediator.tell(new DistributedPubSubMediator.Subscribe(
                org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.BROADCAST, getSelf()), getSelf());
        mediator.tell(new DistributedPubSubMediator.Subscribe(this.id, getSelf()), getSelf());
    }

    /** Publishes {@code message} to the master topic, with this actor as the sender. */
    private void publishToMaster(Object message) {
        mediator.tell(new DistributedPubSubMediator.Publish(
                org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor.MASTER, message), getSelf());
    }

    /** @return {@link Props} for a worker built with {@link #WorkerActor(ActorRef, Conf)} */
    public static Props propsFor(ActorRef actorRef, Conf conf) {
        return Props.create(WorkerActor.class, actorRef, conf);
    }

    /** @return {@link Props} for a worker built with {@link #WorkerActor(Conf)} */
    public static Props propsFor(Conf conf) {
        return Props.create(WorkerActor.class, conf);
    }

    /**
     * Dispatches an incoming message by type. Unrecognized messages go to {@code unhandled}.
     *
     * @throws IllegalArgumentException if an {@link UpdateableImpl} without a network arrives, or a
     *     training batch is empty
     */
    public void onReceive(Object message) throws Exception {
        if (message instanceof DistributedPubSubMediator.SubscribeAck) {
            log.info("Subscribed to {}", message);
        } else if (message instanceof List) {
            // Assumes list messages are training batches of (input, outcome) pairs; the element
            // type cannot be checked at runtime because of erasure.
            @SuppressWarnings("unchecked")
            List<Pair<DoubleMatrix, DoubleMatrix>> batch = (List<Pair<DoubleMatrix, DoubleMatrix>>) message;
            updateTraining(batch);
        } else if (message instanceof BaseMultiLayerNetwork) {
            setNetwork((BaseMultiLayerNetwork) message);
            log.info("Set network");
        } else if (message instanceof Ack) {
            log.info("Ack from master on worker {}", this.id);
        } else if (message instanceof UpdateableImpl) {
            // Checking the concrete type (not just Updateable) avoids a ClassCastException for
            // other Updateable implementations; those fall through to unhandled().
            applyUpdateable((UpdateableImpl) message);
        } else {
            unhandled(message);
        }
    }

    /**
     * Stores {@code update} and installs a clone of its network as the local network. Validates
     * first so that a rejected update leaves no partial state behind.
     */
    private void applyUpdateable(UpdateableImpl update) {
        if (update.get() == null) {
            log.warn("Unable to initialize network; network was null");
            throw new IllegalArgumentException("Network was null");
        }
        setWorkerUpdateable(update);
        setNetwork(update.get().clone());
        log.info("Updated worker network");
    }

    /**
     * Stacks a batch of (input, outcome) pairs into two matrices and trains on them asynchronously.
     * On success the trained result is published to the master and the worker is marked available.
     *
     * @param batch non-empty list of (input row, outcome row) pairs
     * @throws IllegalArgumentException if {@code batch} is empty
     */
    private void updateTraining(List<Pair<DoubleMatrix, DoubleMatrix>> batch) {
        if (batch.isEmpty()) {
            throw new IllegalArgumentException("Received an empty training batch on worker " + this.id);
        }
        Pair<DoubleMatrix, DoubleMatrix> first = batch.get(0);
        DoubleMatrix inputs = new DoubleMatrix(batch.size(), first.getFirst().columns);
        DoubleMatrix outcomes = new DoubleMatrix(batch.size(), first.getSecond().columns);
        // Iterate rather than index so non-random-access lists are not quadratic.
        int row = 0;
        for (Pair<DoubleMatrix, DoubleMatrix> example : batch) {
            inputs.putRow(row, example.getFirst());
            outcomes.putRow(row, example.getSecond());
            row++;
        }
        setCombinedInput(inputs);
        setOutcomes(outcomes);

        Futures.future(new Callable<UpdateableImpl>() {
            @Override
            public UpdateableImpl call() throws Exception {
                UpdateableImpl result = WorkerActor.this.compute();
                log.info("Updating parent actor...");
                WorkerActor.this.publishToMaster(result);
                return result;
            }
        }, getContext().dispatcher()).onComplete(new OnComplete<UpdateableImpl>() {
            @Override
            public void onComplete(Throwable failure, UpdateableImpl result) throws Throwable {
                if (failure != null) {
                    log.error("Unable to process work ", failure);
                    // NOTE(review): availableForWork() is skipped on failure, as before; confirm
                    // this is the intended protocol.
                    throw failure;
                }
                WorkerActor.this.availableForWork();
            }
        }, getContext().dispatcher());
    }

    /** Ignores {@code list} and trains on the current combined input; see {@link #compute()}. */
    @Override
    public UpdateableImpl compute(List<UpdateableImpl> list) {
        return compute();
    }

    /**
     * Trains the local network on the current combined input and outcomes.
     *
     * <p>If no network has arrived yet, this asks the master for one and waits, asking again every
     * {@link #NETWORK_REQUEST_INTERVAL_MS} ms until a network is installed. The monitor is held for
     * the whole training step, so {@link #setNetwork} calls block until training finishes.
     *
     * @return an {@link UpdateableImpl} wrapping the trained network
     * @throws IllegalStateException if the thread is interrupted while waiting for a network
     */
    @Override
    public synchronized UpdateableImpl compute() {
        log.info("Training network");
        while (getNetwork() == null) {
            log.info("Network is null, this worker has recently joined the cluster. Asking master for a copy of the current network");
            publishToMaster(new NeedsModelMessage(this.id));
            try {
                // wait() releases this monitor so the actor thread can enter the synchronized
                // setNetwork() and notify us. Thread.sleep() would keep the lock held and deadlock.
                wait(NETWORK_REQUEST_INTERVAL_MS);
            } catch (InterruptedException e) {
                // Stop instead of looping: with the interrupt flag restored, every further wait()
                // would throw immediately and spin.
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for a network from the master", e);
            }
        }
        BaseMultiLayerNetwork current = getNetwork();
        current.trainNetwork(getCombinedInput(), getOutcomes(), this.extraParams);
        return new UpdateableImpl(current);
    }

    /** This worker does not track iterations; always returns {@code false}. */
    @Override
    public boolean incrementIteration() {
        return false;
    }

    /** Delegates to the superclass setup; nothing extra is configured here. */
    @Override
    public void setup(Conf conf) {
        super.setup(conf);
    }

    /** Logs any child failure and stops the failing child (no restarts). */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            @Override
            public SupervisorStrategy.Directive apply(Throwable th) {
                log.error("Problem with processing", th);
                return SupervisorStrategy.stop();
            }
        });
    }

    /** @return the most recently stored updateable, read under this instance's monitor */
    @Override
    public UpdateableImpl getResults() {
        return getWorkerUpdateable();
    }

    /** Stores {@code updateableImpl} as this worker's updateable. */
    @Override
    public synchronized void update(UpdateableImpl updateableImpl) {
        this.workerUpdateable = updateableImpl;
    }

    public synchronized BaseMultiLayerNetwork getNetwork() {
        return this.network;
    }

    /** Installs a network and wakes any {@link #compute()} call waiting for one. */
    public synchronized void setNetwork(BaseMultiLayerNetwork baseMultiLayerNetwork) {
        this.network = baseMultiLayerNetwork;
        notifyAll();
    }

    @Override
    public synchronized DoubleMatrix getCombinedInput() {
        return this.combinedInput;
    }

    @Override
    public synchronized void setCombinedInput(DoubleMatrix doubleMatrix) {
        this.combinedInput = doubleMatrix;
    }

    public synchronized UpdateableImpl getWorkerUpdateable() {
        return this.workerUpdateable;
    }

    public synchronized void setWorkerUpdateable(UpdateableImpl updateableImpl) {
        this.workerUpdateable = updateableImpl;
    }
}
