package org.deeplearning4j.iterativereduce.actor.single;

import akka.actor.ActorRef;
import akka.actor.ActorSystem;
import akka.actor.Address;
import akka.actor.PoisonPill;
import akka.actor.Props;
import akka.cluster.Cluster;
import akka.contrib.pattern.ClusterClient;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import com.typesafe.config.ConfigFactory;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.iterativereduce.actor.core.actor.BatchActor;
import org.deeplearning4j.iterativereduce.actor.core.actor.ModelSavingActor;
import org.deeplearning4j.iterativereduce.actor.core.actor.SimpleClusterListener;
import org.deeplearning4j.iterativereduce.actor.core.api.EpochDoneListener;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.single.UpdateableSingleImpl;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Drives training of an {@link UpdateableSingleImpl} model across an Akka cluster.
 *
 * <p>A runner whose type is {@code "master"} starts a backend node ({@link #startBackend}) and
 * publishes itself and training batches to the {@link MasterActor#MASTER} pub-sub topic. Any other
 * type starts a worker ({@link #startWorker}) that joins the cluster at the configured master address.
 *
 * <p>NOTE(review): this instance is itself published as a pub-sub message, so it presumably has to
 * stay serializable (the actor system field is {@code transient}); keep field names/types stable.
 */
public class ActorNetworkRunner implements DeepLearningConfigurable, EpochDoneListener<UpdateableSingleImpl> {
    private static final long serialVersionUID = -4385335922485305364L;
    private static final Logger log = LoggerFactory.getLogger(ActorNetworkRunner.class);
    /** Name given to every actor system this class creates. */
    private static final String SYSTEM_NAME = "ClusterSystem";

    /** Wait after starting the backend before the master publishes itself. */
    private static final long MASTER_STARTUP_WAIT_MS = 30000L;
    /** Wait after starting a worker before joining the cluster. */
    private static final long WORKER_STARTUP_WAIT_MS = 5000L;
    /** Wait before publishing a training batch to the master. */
    private static final long TRAIN_PUBLISH_WAIT_MS = 30000L;
    /** Wait after broadcasting intermediate results before starting the next epoch. */
    private static final long EPOCH_BROADCAST_WAIT_MS = 15000L;

    private transient ActorSystem system;
    /** Epochs completed on the current batch of samples. */
    private Integer currEpochs = 0;
    /** Epochs to run per batch; read from {@link Conf#getPretrainEpochs()} in {@link #setup}. */
    private Integer epochs;
    /** The batch currently being trained; re-published at the start of every epoch. */
    private List<Pair<DoubleMatrix, DoubleMatrix>> samples;
    /** Most recent model reported through {@link #epochComplete}. */
    private UpdateableSingleImpl result;
    private ActorRef mediator;
    /** {@code "master"} for the coordinating node; any other value makes this runner a worker. */
    private String type = "master";
    private Address masterAddress;
    /** Source of training batches; only set on the master. */
    private DataSetIterator iter;

    /**
     * Creates a runner that trains on the given data.
     *
     * @param type node type; {@code "master"} starts the backend, anything else starts a worker
     * @param iter source of training batches (required when {@code type} is {@code "master"})
     */
    public ActorNetworkRunner(String type, DataSetIterator iter) {
        this.type = type;
        this.iter = iter;
    }

    /**
     * Creates a runner that connects to an existing master.
     *
     * @param type      node type; normally a worker type (anything other than {@code "master"})
     * @param masterUri master address as a URI; its scheme, user-info, host and port become the
     *                  Akka address protocol, system name, host and port
     *                  (presumably of the form {@code akka.tcp://ClusterSystem@host:port} — confirm)
     * @throws IllegalArgumentException if {@code masterUri} is not a valid URI
     */
    public ActorNetworkRunner(String type, String masterUri) {
        this.type = type;
        URI uri = URI.create(masterUri);
        this.masterAddress = Address.apply(uri.getScheme(), uri.getUserInfo(), uri.getHost(), uri.getPort());
    }

    /** Creates a master runner with no data source; one must be supplied before {@link #setup}. */
    public ActorNetworkRunner() {
    }

    /**
     * Starts a backend actor system that hosts the batch actor and the cluster-singleton master.
     *
     * @param address address to join, or {@code null} to join this node's own address
     * @param role    cluster role assigned to the new actor system
     * @param conf    configuration handed to the batch actor and master actor
     * @param iter    data source handed to the batch actor
     * @return the address that was joined
     */
    public static Address startBackend(Address address, String role, Conf conf, DataSetIterator iter) {
        ActorSystem backend = ActorSystem.create(SYSTEM_NAME,
                ConfigFactory.parseString("akka.cluster.roles=[" + role + "]").withFallback(ConfigFactory.load()));
        ActorRef batchActor = backend.actorOf(Props.create(new BatchActor.BatchActorFactory(iter, conf.getNumPasses())));
        Cluster cluster = Cluster.get(backend);
        Address joinAddress = address == null ? cluster.selfAddress() : address;
        cluster.join(joinAddress);
        backend.actorOf(ClusterSingletonManager.defaultProps(
                MasterActor.propsFor(conf, batchActor), "active", PoisonPill.getInstance(), "master"));
        return joinAddress;
    }

    /**
     * Starts a worker actor system that talks to the master via its {@code /user/receptionist}.
     *
     * <p>Blocks for {@value #WORKER_STARTUP_WAIT_MS} ms before joining the cluster at {@code address}.
     *
     * @param address master address
     * @param conf    configuration handed to the worker actor
     */
    public static void startWorker(Address address, Conf conf) {
        ActorSystem workerSystem = ActorSystem.create(SYSTEM_NAME);
        ActorRef clusterClient = workerSystem.actorOf(
                ClusterClient.defaultProps(Collections.singleton(workerSystem.actorSelection(address + "/user/receptionist"))),
                "clusterClient");
        workerSystem.actorOf(WorkerActor.propsFor(clusterClient, conf), "worker");
        pause(WORKER_STARTUP_WAIT_MS);
        Cluster.get(workerSystem).join(address);
    }

    /**
     * Creates this runner's actor system and either starts the master backend or a worker.
     *
     * @param conf configuration; {@code getPretrainEpochs()} sets the number of epochs per batch
     * @throws IllegalStateException if this is a master without a data source, or a worker without
     *                               a master address
     */
    public void setup(Conf conf) {
        this.system = ActorSystem.create(SYSTEM_NAME);
        this.system.actorOf(Props.create(SimpleClusterListener.class), "clusterListener");
        this.mediator = DistributedPubSubExtension.get(system).mediator();
        this.epochs = conf.getPretrainEpochs();

        if ("master".equals(type)) {
            if (iter == null) {
                throw new IllegalStateException("Unable to initialize no dataset to train");
            }
            this.masterAddress = startBackend(null, "master", conf, iter);
            pause(MASTER_STARTUP_WAIT_MS);
            publish(MasterActor.MASTER, this);
            return;
        }

        // Without a master address the receptionist path would be "null/user/receptionist".
        if (masterAddress == null) {
            throw new IllegalStateException("Unable to start worker: no master address configured");
        }
        startWorker(masterAddress, conf.copy());
        pause(WORKER_STARTUP_WAIT_MS);
        Cluster.get(system).join(masterAddress);
        log.info("Setup worker nodes");
    }

    /**
     * Publishes a batch of samples to the master for training.
     *
     * <p>Blocks for {@value #TRAIN_PUBLISH_WAIT_MS} ms first. The runner itself is re-published to
     * {@link MasterActor#MASTER} before the batch.
     *
     * @param list the (input, label) pairs to train on; retained and re-sent for later epochs
     */
    public void train(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        this.samples = list;
        log.info("Publishing to results for training");
        pause(TRAIN_PUBLISH_WAIT_MS);
        publish(MasterActor.MASTER, this);
        publish(MasterActor.MASTER, list);
    }

    /** Trains on a single (input, label) pair; see {@link #train(List)}. */
    public void train(Pair<DoubleMatrix, DoubleMatrix> pair) {
        train(new ArrayList<>(Collections.singletonList(pair)));
    }

    /** Trains on a single input/label matrix pair; see {@link #train(List)}. */
    public void train(DoubleMatrix input, DoubleMatrix labels) {
        train(new Pair<>(input, labels));
    }

    /**
     * Handles the end of one epoch.
     *
     * <p>While fewer than {@code epochs} epochs have run on the current batch, the intermediate model
     * is broadcast and the same batch is re-published. Otherwise, if more data is available, the epoch
     * counter is reset and the next batch is trained. In every case the model is recorded as the
     * latest result and published to {@link ModelSavingActor#SAVE}.
     *
     * @param updateableSingleImpl the model produced by the finished epoch
     */
    @Override
    public void epochComplete(UpdateableSingleImpl updateableSingleImpl) {
        // Record first so getResult() is current during the (long) waits below.
        this.result = updateableSingleImpl;
        currEpochs = currEpochs + 1;

        if (currEpochs < epochs) {
            publish(MasterActor.BROADCAST, updateableSingleImpl);
            log.info("Updating result on epoch {}", currEpochs);
            pause(EPOCH_BROADCAST_WAIT_MS);
            log.info("Starting next epoch");
            publish(MasterActor.MASTER, samples);
        } else if (iter != null && iter.hasNext()) {
            // Each batch gets its own full run of epochs; without this reset every batch after the
            // first would be trained for a single epoch only.
            currEpochs = 0;
            train((Pair<DoubleMatrix, DoubleMatrix>) iter.next());
        }

        publish(ModelSavingActor.SAVE, updateableSingleImpl);
    }

    /** Returns the most recent model passed to {@link #epochComplete}, or {@code null} if none yet. */
    public UpdateableSingleImpl getResult() {
        return result;
    }

    /** Called when training finishes; this runner performs no cleanup. */
    @Override
    public void finish() {
    }

    /** Returns the master address: set by the URI constructor, or by {@link #setup} on the master. */
    public Address getMasterAddress() {
        return masterAddress;
    }

    /** Publishes {@code message} to {@code topic} through the distributed pub-sub mediator. */
    private void publish(String topic, Object message) {
        mediator.tell(new DistributedPubSubMediator.Publish(topic, message), mediator);
    }

    /** Sleeps for {@code millis} ms; on interruption, restores the interrupt flag and returns early. */
    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
