package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.ActorSelection;
import akka.actor.ActorSystem;
import akka.actor.Address;
import akka.actor.PoisonPill;
import akka.actor.Props;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterClient;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.routing.RoundRobinPool;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.iterativereduce.actor.core.ClusterListener;
import org.deeplearning4j.iterativereduce.actor.core.actor.BatchActor;
import org.deeplearning4j.iterativereduce.actor.core.actor.ModelSavingActor;
import org.deeplearning4j.iterativereduce.actor.util.ActorRefUtils;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.deeplearning4j.scaleout.zookeeper.ZooKeeperConfigurationRegister;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/multilayer/ActorNetworkRunner.class */
/**
 * Runs multi-layer network training on an Akka cluster.
 *
 * <p>A runner acts either as the {@code "master"} (the default) or as a worker, depending on the
 * {@code type} given at construction. {@link #setup(Conf)} starts the actors for that role; the
 * {@code train} overloads then publish this runner and the training data to the
 * {@link MasterActor#MASTER} topic through the distributed pub-sub mediator.
 */
public class ActorNetworkRunner implements DeepLearningConfigurable, Serializable {
    private static final long serialVersionUID = -4385335922485305364L;
    private static final Logger log = LoggerFactory.getLogger(ActorNetworkRunner.class);

    /** Name of every actor system created here; all nodes must share it to form one cluster. */
    private static final String SYSTEM_NAME = "ClusterSystem";
    /** Runner type that selects the master code path in {@link #setup(Conf)}. */
    private static final String MASTER_TYPE = "master";
    /** Delay before a worker joins the master's cluster. */
    private static final long WORKER_JOIN_DELAY_MS = 30000L;
    /** Delay after starting the master backend before publishing to it. */
    private static final long MASTER_STARTUP_DELAY_MS = 60000L;
    /** Delay before {@link #train(List)} publishes, giving the cluster time to come up. */
    private static final long CLUSTER_STARTUP_DELAY_MS = 30000L;

    private transient ActorSystem system;
    private Integer epochs;
    private UpdateableImpl result;
    private ActorRef mediator;
    private BaseMultiLayerNetwork startingNetwork;
    private String type = MASTER_TYPE;
    private Address masterAddress;
    private DataSetIterator iter;
    protected ActorRef masterActor;

    /**
     * Creates a runner with an explicit role, dataset and optional initial network.
     *
     * @param type            {@code "master"} for the master; any other value makes this a worker
     * @param iter            data source handed to the batch actors when running as master
     * @param startingNetwork network the master starts from, or {@code null} for the default
     */
    public ActorNetworkRunner(String type, DataSetIterator iter, BaseMultiLayerNetwork startingNetwork) {
        this.type = type;
        this.iter = iter;
        this.startingNetwork = startingNetwork;
    }

    /** Creates a runner with the given role and dataset and no initial network. */
    public ActorNetworkRunner(String type, DataSetIterator iter) {
        this(type, iter, null);
    }

    /** Creates a master runner over the given dataset. */
    public ActorNetworkRunner(DataSetIterator iter) {
        this(MASTER_TYPE, iter, null);
    }

    /** Creates a master runner over the given dataset, starting from {@code startingNetwork}. */
    public ActorNetworkRunner(DataSetIterator iter, BaseMultiLayerNetwork startingNetwork) {
        this(MASTER_TYPE, iter, startingNetwork);
    }

    /**
     * Creates a runner (typically a worker) that connects to an already-running master.
     *
     * @param type      runner role
     * @param masterUrl Akka address of the master, e.g. {@code akka.tcp://ClusterSystem@host:2552}
     */
    public ActorNetworkRunner(String type, String masterUrl) {
        this.type = type;
        this.masterAddress = parseAddress(masterUrl);
    }

    /** Creates a master runner with no dataset; {@link #setup(Conf)} will reject it as master. */
    public ActorNetworkRunner() {
    }

    /**
     * Starts a backend actor system that hosts the batch actors and the master singleton.
     *
     * @param address address to join, or {@code null} to join this backend's own address
     * @param role    currently unused
     * @param conf    configuration passed to {@link MasterActor#propsFor}
     * @param iter    dataset handed to each {@link BatchActor}
     * @return the address the backend joined
     */
    public Address startBackend(Address address, String role, Conf conf, DataSetIterator iter) {
        Config backendConfig = ConfigFactory.parseString("akka.cluster.roles=[master,worker]")
                .withFallback(ConfigFactory.load());
        ActorSystem backend = ActorSystem.create(SYSTEM_NAME, backendConfig);

        // One batch actor per available core, fed round-robin.
        ActorRef batchActor = backend.actorOf(new RoundRobinPool(Runtime.getRuntime().availableProcessors())
                .props(Props.create(BatchActor.class, iter)));
        Props masterProps = startingNetwork != null
                ? MasterActor.propsFor(conf, batchActor, startingNetwork)
                : MasterActor.propsFor(conf, batchActor);
        log.info("Started batch actor");

        Address joinAddress = address != null ? address : Cluster.get(backend).selfAddress();
        Cluster.get(backend).join(joinAddress);
        masterActor = backend.actorOf(
                ClusterSingletonManager.defaultProps(masterProps, "active", PoisonPill.getInstance(), "master"));
        return joinAddress;
    }

    /**
     * Starts the actor infrastructure for this runner's role.
     *
     * <p>Both roles first create {@link #system} and its pub-sub mediator. A master then starts the
     * backend, model saver and ZooKeeper registration. A worker starts a cluster client and worker
     * actor, then joins the master. Both paths block for a fixed delay while the cluster forms.
     *
     * @param conf cluster configuration; the master writes its URL and actor path back into it
     * @throws IllegalStateException if a master has no dataset, or a worker has no master address
     */
    public void setup(final Conf conf) {
        system = ActorSystem.create(SYSTEM_NAME);
        mediator = DistributedPubSubExtension.get(system).mediator();
        epochs = Integer.valueOf(conf.getPretrainEpochs());
        if (isMaster()) {
            setupMaster(conf);
        } else {
            setupWorker(conf);
        }
    }

    /** Starts a worker actor system that reaches the master's receptionist through a cluster client. */
    private void setupWorker(Conf conf) {
        Address master = resolveMasterAddress(conf);
        // Build the seed list as a value rather than spliced HOCON text: an unquoted
        // "akka.tcp://..." would be cut at "//", which HOCON treats as a comment.
        Config workerConfig = ConfigFactory
                .parseMap(Collections.singletonMap("akka.cluster.seed-nodes",
                        Collections.singletonList(conf.getMasterUrl())))
                .withFallback(ConfigFactory.load());
        log.info("Starting workers");
        ActorSystem workerSystem = ActorSystem.create(SYSTEM_NAME, workerConfig);
        workerSystem.actorOf(Props.create(ClusterListener.class));

        Set<ActorSelection> initialContacts = new HashSet<>();
        initialContacts.add(workerSystem.actorSelection(master + "/user/receptionist"));
        ActorRef clusterClient = workerSystem.actorOf(ClusterClient.defaultProps(initialContacts), "clusterClient");
        workerSystem.actorOf(WorkerActor.propsFor(clusterClient, conf), "worker");

        sleepQuietly(WORKER_JOIN_DELAY_MS);
        Cluster.get(workerSystem).join(master);
        log.info("Setup worker nodes");
    }

    /** Starts the master backend, publishes this runner, starts the model saver and registers the conf. */
    private void setupMaster(final Conf conf) {
        if (iter == null) {
            throw new IllegalStateException("Unable to initialize no dataset to train");
        }
        log.info("Starting master");
        masterAddress = startBackend(null, MASTER_TYPE, conf, iter);
        sleepQuietly(MASTER_STARTUP_DELAY_MS);
        mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, this), mediator);

        log.info("Starting model saver");
        // NOTE(review): the leading comma in this argument looks accidental; confirm what
        // ModelSavingActor expects before changing it.
        system.actorOf(Props.create(ModelSavingActor.class, ",model-saver"));
        Cluster.get(system).join(masterAddress);

        // Publish the master's location through the conf so workers can find it.
        conf.setMasterUrl(getMasterAddress().toString());
        conf.setMasterAbsPath(ActorRefUtils.absPath(masterActor, system));
        log.info("Stored master path of {}", conf.getMasterAbsPath());

        registerWithZooKeeper(conf);
        log.info("Setup master with epochs {}", epochs);
    }

    /** Asynchronously registers {@code conf} with the ZooKeeper at localhost:2181 and logs the outcome. */
    private void registerWithZooKeeper(final Conf conf) {
        Futures.future(new Callable<Void>() {
            @Override
            public Void call() throws Exception {
                log.info("Registering with zookeeper; if the logging stops here, ensure zookeeper is started");
                ZooKeeperConfigurationRegister register =
                        new ZooKeeperConfigurationRegister(conf, "master", "localhost", 2181);
                try {
                    register.register();
                } finally {
                    // Close even when registration fails so the connection is not leaked.
                    register.close();
                }
                return null;
            }
        }, system.dispatcher()).onComplete(new OnComplete<Void>() {
            @Override
            public void onComplete(Throwable failure, Void ignored) {
                if (failure != null) {
                    log.error("Failed to register conf with zookeeper", failure);
                } else {
                    log.info("Registered conf with zookeeper");
                }
            }
        }, system.dispatcher());
    }

    /**
     * Publishes this runner and then {@code list} to the master topic, after waiting for the cluster.
     *
     * @param list training batches of (input, label) matrices
     * @throws IllegalStateException if {@link #setup(Conf)} has not been called
     */
    public void train(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        if (mediator == null) {
            throw new IllegalStateException("setup(Conf) must be called before train");
        }
        log.info("Publishing to results for training");
        log.info("Waiting for cluster to go up...");
        sleepQuietly(CLUSTER_STARTUP_DELAY_MS);
        log.info("Done waiting");
        mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, this), mediator);
        log.info("Started pipeline");
        mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, list), mediator);
        log.info("Published results");
    }

    /**
     * Trains on the next batch from the configured iterator, or logs a warning if it is exhausted.
     *
     * @throws IllegalStateException if no dataset iterator was configured
     */
    public void train() {
        if (iter == null) {
            throw new IllegalStateException("No dataset iterator configured; nothing to train");
        }
        if (iter.hasNext()) {
            train((Pair<DoubleMatrix, DoubleMatrix>) iter.next());
        } else {
            log.warn("No data found");
        }
    }

    /** Trains on a single (input, label) batch. */
    public void train(Pair<DoubleMatrix, DoubleMatrix> pair) {
        List<Pair<DoubleMatrix, DoubleMatrix>> batch = new ArrayList<>(1);
        batch.add(pair);
        train(batch);
    }

    /** Trains on a single batch given as separate input and label matrices. */
    public void train(DoubleMatrix input, DoubleMatrix labels) {
        train(new Pair<>(input, labels));
    }

    /**
     * Returns the stored result.
     *
     * <p>NOTE(review): nothing in this class assigns {@code result}, so this appears to always
     * return {@code null} — confirm whether it is meant to be populated.
     */
    public UpdateableImpl getResult() {
        return result;
    }

    /** Returns the master's cluster address, or {@code null} if it is not yet known. */
    public Address getMasterAddress() {
        return masterAddress;
    }

    /** Returns true when this runner was configured with the master role. */
    private boolean isMaster() {
        return MASTER_TYPE.equals(type);
    }

    /**
     * Returns the known master address. If none is known, parses it from {@code conf}'s master URL,
     * which the master writes there as {@code Address.toString()}.
     */
    private Address resolveMasterAddress(Conf conf) {
        if (masterAddress == null) {
            String url = conf.getMasterUrl();
            if (url == null) {
                throw new IllegalStateException(
                        "No master address known: pass one to the constructor or set the master url on the conf");
            }
            masterAddress = parseAddress(url);
        }
        return masterAddress;
    }

    /** Parses an Akka address URI of the form {@code scheme://system@host:port}. */
    private static Address parseAddress(String url) {
        URI uri = URI.create(url);
        return Address.apply(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort());
    }

    /** Sleeps for {@code millis}; on interruption returns early with the interrupt flag restored. */
    private static void sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
