package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.PoisonPill;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import java.io.DataOutputStream;
import java.util.Collection;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.DoneMessage;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.core.MoreWorkMessage;
import org.deeplearning4j.iterativereduce.actor.core.actor.BatchActor;
import org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.DeepLearningAccumulatorIterateAndUpdate;
import org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/multilayer/MasterActor.class */
/**
 * Master actor for distributed multi-layer network training.
 *
 * <p>On construction it spawns the worker pool and stores the initial network in the
 * state tracker. At runtime it reacts to messages by aggregating worker updates
 * ({@link #m16compute()}), prompting the batch pipeline for more work, and delegating
 * per-worker jobs loaded from the state tracker.
 */
public class MasterActor extends org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor<UpdateableImpl> {
    /** Number of additional load attempts made after the first lookup for a worker comes back empty. */
    private static final int MAX_LOAD_RETRIES = 3;

    /** Pause before each additional load attempt, in milliseconds. */
    private static final long LOAD_RETRY_SLEEP_MILLIS = 10000L;

    /** Optional pre-built network; when {@code null}, {@link #setup(Conf)} builds one from the configuration. */
    protected BaseMultiLayerNetwork network;

    /**
     * Creates a master that initializes its network from the configuration.
     *
     * @param conf                  the training configuration
     * @param actorRef              actor reference forwarded to the base class
     * @param hazelCastStateTracker the shared state tracker
     */
    public MasterActor(Conf conf, ActorRef actorRef, HazelCastStateTracker hazelCastStateTracker) {
        super(conf, actorRef, hazelCastStateTracker);
        setup(conf);
    }

    /**
     * Creates a master that starts from an existing network.
     *
     * @param conf                  the training configuration
     * @param actorRef              actor reference forwarded to the base class
     * @param baseMultiLayerNetwork the network to broadcast as the initial state
     * @param hazelCastStateTracker the shared state tracker
     */
    public MasterActor(Conf conf, ActorRef actorRef, BaseMultiLayerNetwork baseMultiLayerNetwork, HazelCastStateTracker hazelCastStateTracker) {
        super(conf, actorRef, hazelCastStateTracker);
        this.network = baseMultiLayerNetwork;
        setup(conf);
    }

    /**
     * Accumulates the pending worker updates and stores the merged result as the
     * current state in the state tracker.
     *
     * <p>NOTE(review): the decompiler reported this method as renamed from {@code compute}
     * (merged with a bridge method). Confirm whether the base class expects {@code compute()}.
     *
     * @return the merged result, or {@code null} when there are no worker updates or
     *         accumulation fails
     * @throws RuntimeException if the merged result cannot be stored in the state tracker
     */
    public UpdateableImpl m16compute() {
        DeepLearningAccumulatorIterateAndUpdate accumulator = (DeepLearningAccumulatorIterateAndUpdate) this.stateTracker.updates();
        if (this.stateTracker.workerUpdates().isEmpty()) {
            return null;
        }

        UpdateableImpl results;
        try {
            accumulator.accumulate();
            results = getResults();
            if (results == null) {
                results = accumulator.accumulated();
            } else {
                results.set(accumulator.accumulated().get());
            }
        } catch (Exception e) {
            // Accumulation is best effort: report and signal "no result" to the caller.
            this.log.debug("Unable to accumulate results", e);
            return null;
        }

        // Kept outside the best-effort block above so a failure to store the result
        // propagates instead of being reported as an accumulation problem.
        try {
            this.stateTracker.setCurrent(results);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return results;
    }

    /**
     * Starts the worker pool and publishes the initial network to the state tracker.
     *
     * <p>The pool is a round-robin router with one routee per available processor,
     * wrapped in a cluster singleton manager. The initial network is {@link #network}
     * when set, otherwise {@code conf.init()}.
     *
     * @param conf the training configuration
     * @throws RuntimeException if the state tracker rejects the initial state
     */
    public void setup(Conf conf) {
        this.log.info("Starting workers");
        context().system().actorOf(ClusterSingletonManager.defaultProps(new RoundRobinPool(Runtime.getRuntime().availableProcessors()).props(WorkerActor.propsFor(conf, this.stateTracker)), "master", PoisonPill.getInstance(), "master"), "worker");
        this.log.info("Broadcasting initial master network");
        BaseMultiLayerNetwork init = this.network == null ? conf.init() : this.network;
        // Initialize the layers with a single all-zero row sized to the configured input width.
        init.initializeLayers(Nd4j.zeros(1, conf.getConf().getnIn()));
        try {
            this.stateTracker.setCurrent(new UpdateableImpl(init));
            this.log.info("Stored " + this.stateTracker.getCurrent().get());
            this.stateTracker.setMiniBatchSize(conf.getSplit());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Dispatches an incoming message.
     *
     * <ul>
     *   <li>pub-sub subscribe/unsubscribe acks are republished on the "topics" topic and logged;</li>
     *   <li>{@link DoneMessage} triggers {@code doDoneOrNextPhase()};</li>
     *   <li>any {@link String} is answered with an {@link Ack};</li>
     *   <li>{@link MoreWorkMessage} is published on the {@link BatchActor#BATCH} topic;</li>
     *   <li>a {@link Collection} is treated as worker ids to delegate jobs to;</li>
     *   <li>anything else is passed to {@code unhandled}.</li>
     * </ul>
     *
     * @param obj the received message
     * @throws Exception if job delegation fails or the thread is interrupted while waiting for data
     */
    public void onReceive(Object obj) throws Exception {
        boolean subscribed = obj instanceof DistributedPubSubMediator.SubscribeAck;
        if (subscribed || (obj instanceof DistributedPubSubMediator.UnsubscribeAck)) {
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", obj), getSelf());
            // Log the message itself rather than casting: the ack may be either type.
            this.log.info((subscribed ? "Subscribed " : "Unsubscribed ") + obj);
            return;
        }
        if (obj instanceof DoneMessage) {
            this.log.info("Received done message");
            doDoneOrNextPhase();
            return;
        }
        if (obj instanceof String) {
            getSender().tell(Ack.getInstance(), getSelf());
            return;
        }
        if (obj instanceof MoreWorkMessage) {
            this.log.info("Prompted for more work, starting pipeline");
            this.mediator.tell(new DistributedPubSubMediator.Publish(BatchActor.BATCH, MoreWorkMessage.getInstance()), getSelf());
            return;
        }
        if (!(obj instanceof Collection)) {
            unhandled(obj);
            return;
        }
        // Element types are erased at runtime; a non-String element surfaces as a
        // ClassCastException during iteration.
        @SuppressWarnings("unchecked")
        Collection<String> workers = (Collection<String>) obj;
        delegateJobs(workers);
    }

    /**
     * Registers one job per worker id with the state tracker, stopping at the first
     * worker for which no data can be loaded.
     */
    private void delegateJobs(Collection<String> workers) throws Exception {
        for (String worker : workers) {
            DataSet data = loadForWorkerWithRetries(worker);
            if (data == null) {
                this.log.info("No data found for worker..." + worker + " returning");
                return;
            }
            this.stateTracker.addJobToCurrent(new Job(worker, data.copy()));
            this.log.info("Job delegated for " + worker);
        }
    }

    /**
     * Loads the data set for a worker, pausing and retrying up to
     * {@link #MAX_LOAD_RETRIES} more times while the state tracker returns {@code null}.
     *
     * @return the loaded data set, or {@code null} if every attempt came back empty
     */
    private DataSet loadForWorkerWithRetries(String worker) throws Exception {
        DataSet data = this.stateTracker.loadForWorker(worker);
        for (int retry = 0; data == null && retry < MAX_LOAD_RETRIES; retry++) {
            // Pause before each retry so no time is spent sleeping after the final attempt.
            this.log.info("Data still not found....sleeping for 10 seconds and trying again");
            Thread.sleep(LOAD_RETRY_SLEEP_MILLIS);
            data = this.stateTracker.loadForWorker(worker);
        }
        return data;
    }

    /**
     * Writes the master results to the given stream.
     *
     * @param dataOutputStream the stream to write to
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public void complete(DataOutputStream dataOutputStream) {
        getMasterResults().get().write(dataOutputStream);
    }
}
