package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.cluster.Cluster;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.event.Logging;
import akka.event.LoggingAdapter;
import akka.japi.Function;
import com.google.common.collect.Lists;
import java.io.DataOutputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.iterativereduce.actor.core.api.EpochDoneListener;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableMaster;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.jblas.DoubleMatrix;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/MasterActor.class */
/**
 * Base class for the master side of an actor-based iterative-reduce job.
 *
 * <p>On construction the master subscribes itself to the {@link #MASTER} and {@link #FINISH}
 * pub-sub topics and calls {@link #setup(Conf)}. In {@link #preStart()} it registers itself with
 * the distributed pub-sub mediator. Work is handed to registered workers by
 * {@link #sendToWorkers(List)}.
 *
 * <p>NOTE(review): {@link #setup(Conf)} is abstract and invoked from the constructor. When it
 * runs, subclass field initializers have not executed yet.
 *
 * @param <E> the type of update exchanged with workers
 */
public abstract class MasterActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableMaster<E> {
    protected Conf conf;
    protected LoggingAdapter log;
    protected E masterResults;
    protected List<E> updates;
    protected EpochDoneListener<E> listener;
    protected ActorRef batchActor;
    protected int epochsComplete;
    /** Distributed pub-sub mediator used for topic subscription and publishing. */
    protected final ActorRef mediator;
    // Pub-sub topic names. NOTE(review): these are mutable statics; they are probably meant to be
    // constants, but they are left non-final so existing code that might assign them still compiles.
    public static String BROADCAST = "broadcast";
    public static String MASTER = "result";
    public static String SHUTDOWN = "shutdown";
    public static String FINISH = "finish";
    Cluster cluster;
    protected Set<String> workerIds;
    /**
     * Registered workers keyed by worker id. This is a concurrent map because
     * {@link #sendToWorkers(List)} scans it from futures running off the actor thread, while
     * {@link #addWorker(WorkerState)} may add entries at the same time.
     */
    protected Map<String, WorkerState> workers;
    /** Number of batches produced by the most recent {@link #sendToWorkers(List)} call. */
    protected int partition;
    protected boolean isDone;
    /**
     * Serializes the check-and-claim of an available worker across concurrently running dispatch
     * futures, so two batches cannot both claim the same worker.
     */
    private final Object dispatchLock = new Object();

    /**
     * Creates the master, subscribes it to the {@link #MASTER} and {@link #FINISH} topics, and
     * calls {@link #setup(Conf)}.
     *
     * @param conf     job configuration; stored and passed to {@link #setup(Conf)}
     * @param actorRef the batch actor
     * @param objArr   unused; retained for signature compatibility
     */
    public MasterActor(Conf conf, ActorRef actorRef, Object[] objArr) {
        this.log = Logging.getLogger(getContext().system(), this);
        this.updates = new ArrayList<E>();
        this.mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
        this.cluster = Cluster.get(getContext().system());
        this.workerIds = new HashSet<String>();
        this.workers = new ConcurrentHashMap<String, WorkerState>();
        this.partition = 1;
        this.isDone = false;
        this.conf = conf;
        this.batchActor = actorRef;
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MASTER, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(FINISH, getSelf()), getSelf());
        setup(conf);
    }

    /**
     * Equivalent to {@code MasterActor(conf, actorRef, null)}.
     *
     * @param conf     job configuration
     * @param actorRef the batch actor
     */
    public MasterActor(Conf conf, ActorRef actorRef) {
        this(conf, actorRef, null);
    }

    /** Registers this actor with the distributed pub-sub mediator (via {@code Put}). */
    @Override
    public void preStart() throws Exception {
        super.preStart();
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.log.info("Setup master with path " + self().path());
        this.log.info("Pre start on master " + self().path().toString());
    }

    /** Unsubscribes this actor from cluster events. */
    @Override
    public void postStop() throws Exception {
        super.postStop();
        this.log.info("Post stop on master");
        this.cluster.unsubscribe(getSelf());
    }

    public abstract E compute(Collection<E> collection, Collection<E> collection2);

    public abstract void setup(Conf conf);

    /**
     * Splits {@code list} into consecutive batches of at most {@code conf.getSplit()} elements and
     * records the batch count in {@link #partition}. For each batch, it starts a future that waits
     * until some registered worker is available and hands the batch to it.
     *
     * <p>Each batch is delivered both directly to the worker's {@link ActorRef} and published on
     * the pub-sub topic named after the worker id. NOTE(review): confirm that workers expect, or
     * de-duplicate, this double delivery.
     *
     * <p>NOTE(review): a future that finds no available worker keeps rescanning on a dispatcher
     * thread. If worker availability is restored by messages processed on the same dispatcher,
     * many waiting batches could starve them; consider scheduled retries instead.
     *
     * @param list the data to distribute; not modified
     */
    public void sendToWorkers(List<DataSet> list) {
        final List<List<DataSet>> batches = Lists.partition(list, this.conf.getSplit());
        this.partition = batches.size();
        this.log.info("Found partition of size " + this.partition);
        for (int i = 0; i < batches.size(); i++) {
            final int batchIndex = i;
            Futures.future(new Callable<Void>() {
                @Override
                public Void call() throws Exception {
                    MasterActor.this.log.info("Sending off work for batch " + batchIndex);
                    // Rescan the worker map on every pass. Reusing a single exhausted iterator
                    // would spin forever without ever seeing a worker become available.
                    while (!tryDispatch(batches.get(batchIndex))) {
                        Thread.yield();
                    }
                    return null;
                }
            }, context().dispatcher()).onComplete(new OnComplete<Void>() {
                @Override
                public void onComplete(Throwable failure, Void ignored) {
                    if (failure != null) {
                        // Rethrowing from a callback only reaches the ExecutionContext's failure
                        // reporter; log it here with the batch context instead.
                        MasterActor.this.log.error(failure, "Failed to dispatch batch " + batchIndex);
                    }
                }
            }, context().dispatcher());
        }
    }

    /**
     * Hands {@code batch} to the first available worker found in {@link #workers} and marks that
     * worker unavailable.
     *
     * @param batch the batch to send; copied before sending
     * @return {@code true} if a worker was found and the batch was sent
     */
    private boolean tryDispatch(List<DataSet> batch) {
        synchronized (dispatchLock) {
            for (WorkerState worker : this.workers.values()) {
                if (worker.isAvailable()) {
                    worker.getRef().tell(new ArrayList<DataSet>(batch), getSelf());
                    this.mediator.tell(new DistributedPubSubMediator.Publish(worker.getWorkerId(), new ArrayList<DataSet>(batch)), getSelf());
                    this.log.info("Delegated work to worker " + worker.getWorkerId());
                    worker.setAvailable(false);
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Replaces the contents of {@code list} with one single-row {@link DataSet} per input row.
     * The input row is paired with the label row at the same index.
     *
     * <p>The whole input is validated before {@code list} is modified. If an exception is thrown,
     * the caller's list is left unchanged.
     *
     * @param list data sets to split; rewritten in place on success
     * @throws IllegalArgumentException if a data set has no input rows, or its label row count
     *                                  differs from its input row count
     */
    public void splitListIntoRows(List<DataSet> list) {
        this.log.info("Splitting list in to rows...");
        List<DataSet> rows = new ArrayList<DataSet>();
        for (DataSet dataSet : list) {
            List<DoubleMatrix> inputRows = ((DoubleMatrix) dataSet.getFirst()).rowsAsList();
            List<DoubleMatrix> labelRows = ((DoubleMatrix) dataSet.getSecond()).rowsAsList();
            if (inputRows.isEmpty()) {
                throw new IllegalArgumentException("No input rows found");
            }
            if (inputRows.size() != labelRows.size()) {
                throw new IllegalArgumentException("Label rows not equal to input rows");
            }
            for (int i = 0; i < inputRows.size(); i++) {
                rows.add(new DataSet(inputRows.get(i), labelRows.get(i)));
            }
        }
        list.clear();
        list.addAll(rows);
    }

    /**
     * Registers a worker under its id. Ignored if a worker with the same id is already present.
     *
     * @param workerState the worker to add
     */
    public void addWorker(WorkerState workerState) {
        if (this.workers.containsKey(workerState.getWorkerId())) {
            return;
        }
        this.workers.put(workerState.getWorkerId(), workerState);
        this.log.info("Added worker with id " + workerState.getWorkerId());
    }

    public abstract void complete(DataOutputStream dataOutputStream);

    /** Returns the current master result (synchronized read of {@link #masterResults}). */
    public synchronized E getResults() {
        return this.masterResults;
    }

    /**
     * Supervision policy for child actors: logs the failure and stops the failing child
     * (no retries).
     */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            @Override
            public SupervisorStrategy.Directive apply(Throwable th) {
                // error(Throwable, String) keeps the cause. The (String, Object) overload treats
                // the throwable as a template argument and drops it.
                MasterActor.this.log.error(th, "Problem with processing");
                return SupervisorStrategy.stop();
            }
        });
    }

    public Conf getConf() {
        return this.conf;
    }

    public int getEpochsComplete() {
        return this.epochsComplete;
    }

    /** Returns the batch count from the most recent {@link #sendToWorkers(List)} call. */
    public int getPartition() {
        return this.partition;
    }

    /** Returns {@link #masterResults} without synchronization; see {@link #getResults()}. */
    public E getMasterResults() {
        return this.masterResults;
    }

    public boolean isDone() {
        return this.isDone;
    }

    /** Returns the live, mutable list of updates (not a copy). */
    public List<E> getUpdates() {
        return this.updates;
    }

    public EpochDoneListener<E> getListener() {
        return this.listener;
    }

    public ActorRef getBatchActor() {
        return this.batchActor;
    }

    public ActorRef getMediator() {
        return this.mediator;
    }
}
