package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.cluster.Cluster;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.event.Logging;
import akka.event.LoggingAdapter;
import akka.japi.Creator;
import akka.japi.Function;
import com.google.common.collect.Lists;
import java.io.DataOutputStream;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterativereduce.actor.core.FinishMessage;
import org.deeplearning4j.iterativereduce.actor.core.ResetMessage;
import org.deeplearning4j.iterativereduce.actor.core.UpdateMessage;
import org.deeplearning4j.iterativereduce.actor.core.api.EpochDoneListener;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableMaster;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.jblas.DoubleMatrix;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/MasterActor.class */
/**
 * Base master actor for iterative-reduce training over an Akka cluster.
 *
 * <p>Behaviour implemented here:
 * <ul>
 *   <li>Training data arriving as a {@code List<Pair<DoubleMatrix, DoubleMatrix>>} or as a single
 *       {@code Pair<DoubleMatrix, DoubleMatrix>} is split into one (input row, label row) pair per
 *       row, cut into batches of {@code conf.getSplit()} rows, and each batch is published on the
 *       {@link #BROADCAST} topic.</li>
 *   <li>{@link Updateable} results are buffered until {@link #partition} of them have arrived; they
 *       are then aggregated with {@link #compute(Collection, Collection)}, the registered
 *       {@link EpochDoneListener} (if any) is notified, and the batch actor is signalled.</li>
 *   <li>The actor marks itself done after {@code conf.getNumPasses()} epochs or on a
 *       {@link FinishMessage}.</li>
 * </ul>
 *
 * @param <E> type of the per-worker update aggregated by this master
 */
public abstract class MasterActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableMaster<E> {
    protected Conf conf;
    /** Result of the most recent {@link #compute(Collection, Collection)} call; null until then. */
    protected E masterResults;
    /** Optional callback, installed by sending an {@link EpochDoneListener} message. */
    protected EpochDoneListener<E> listener;
    /** Actor that receives a {@link ResetMessage} and the last update after every epoch. */
    protected ActorRef batchActor;
    protected int epochsComplete;
    /** Pub-sub topic on which work batches and {@link UpdateMessage}s are published. */
    public static String BROADCAST = "broadcast";
    /** Pub-sub topic this actor subscribes to in its constructor. */
    public static String MASTER = "result";
    /** Topic name kept for other components; not referenced elsewhere in this class. */
    public static String SHUTDOWN = "shutdown";
    /** Pub-sub topic this actor subscribes to in its constructor. */
    public static String FINISH = "finish";
    protected LoggingAdapter log = Logging.getLogger(getContext().system(), this);
    /** Updates received in the current epoch; cleared once an epoch has been computed. */
    protected List<E> updates = new ArrayList<>();
    protected final ActorRef mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    Cluster cluster = Cluster.get(getContext().system());
    /** Number of updates that completes an epoch; set to the batch count by {@link #sendToWorkers}. */
    protected int partition = 1;
    protected boolean isDone = false;

    /**
     * Akka {@link Creator} for {@link MasterActor} subclasses, carrying the configuration and the
     * batch actor reference that the concrete actor constructor needs.
     */
    public static abstract class MasterActorFactory<E> implements Creator<MasterActor<Updateable<E>>> {
        protected Conf conf;
        protected ActorRef batchActor;
        private static final long serialVersionUID = 1932205634961409897L;

        /**
         * @param conf     configuration handed to the created actor
         * @param actorRef batch actor handed to the created actor
         */
        public MasterActorFactory(Conf conf, ActorRef actorRef) {
            this.conf = conf;
            this.batchActor = actorRef;
        }

        /**
         * Creates the master actor instance.
         *
         * <p>NOTE(review): this name comes from the decompiler ("merged with bridge method" for
         * {@code create}); confirm against the original source before renaming.
         */
        public abstract MasterActor<Updateable<E>> m4create() throws Exception;
    }

    /**
     * Registers this actor with the distributed pub-sub mediator, subscribes it to the
     * {@link #MASTER} and {@link #FINISH} topics, announces it on {@code DoneReaper.REAPER}, and
     * finally calls {@link #setup(Conf)}.
     *
     * <p>NOTE(review): {@code setup} is overridable and runs before subclass constructors finish.
     *
     * @param conf     training configuration
     * @param actorRef batch actor signalled after every epoch
     */
    public MasterActor(Conf conf, ActorRef actorRef) {
        this.conf = conf;
        this.batchActor = actorRef;
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MASTER, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(FINISH, getSelf()), getSelf());
        // The reaper announcement is sent with the mediator as sender (kept as in the original).
        this.mediator.tell(new DistributedPubSubMediator.Publish(DoneReaper.REAPER, getSelf()), this.mediator);
        setup(conf);
    }

    @Override
    public void preStart() throws Exception {
        super.preStart();
        this.log.info("Pre start on master");
    }

    @Override
    public void postStop() throws Exception {
        super.postStop();
        this.log.info("Post stop on master");
        this.cluster.unsubscribe(getSelf());
    }

    public abstract E compute(Collection<E> collection, Collection<E> collection2);

    public abstract void setup(Conf conf);

    /**
     * Dispatches on message type. Check order matters for objects matching several types:
     * SubscribeAck, EpochDoneListener, Updateable, FinishMessage, UpdateMessage, List, Pair.
     *
     * @throws IllegalArgumentException if training data has no input rows or mismatched label rows
     */
    @Override
    public void onReceive(Object obj) throws Exception {
        if (obj instanceof DistributedPubSubMediator.SubscribeAck) {
            this.log.info("Subscribed " + ((DistributedPubSubMediator.SubscribeAck) obj).toString());
            return;
        }
        if (obj instanceof EpochDoneListener) {
            @SuppressWarnings("unchecked")
            EpochDoneListener<E> newListener = (EpochDoneListener<E>) obj;
            this.listener = newListener;
            this.log.info("Set listener");
            return;
        }
        if (obj instanceof Updateable) {
            @SuppressWarnings("unchecked")
            E update = (E) obj;
            this.updates.add(update);
            if (this.updates.size() < this.partition) {
                return; // epoch not complete yet
            }
            computeAndNotify();
            this.batchActor.tell(new ResetMessage(), getSelf());
            this.epochsComplete++;
            // NOTE(review): the last received update (not masterResults) is forwarded; confirm intent.
            this.batchActor.tell(update, getSelf());
            this.updates.clear();
            if (this.epochsComplete == this.conf.getNumPasses()) {
                this.isDone = true;
            }
            return;
        }
        if (obj instanceof FinishMessage) {
            // Flush a partially filled epoch before finishing.
            if (!this.updates.isEmpty()) {
                computeAndNotify();
            }
            this.isDone = true;
            this.log.info("All done; shutting down");
            return;
        }
        if (obj instanceof UpdateMessage) {
            this.mediator.tell(new DistributedPubSubMediator.Publish(BROADCAST, obj), getSelf());
            return;
        }
        if (obj instanceof List) {
            @SuppressWarnings("unchecked")
            List<Pair<DoubleMatrix, DoubleMatrix>> received = (List<Pair<DoubleMatrix, DoubleMatrix>>) obj;
            // Work on a copy: messages may be shared with the sender or be unmodifiable.
            List<Pair<DoubleMatrix, DoubleMatrix>> rows = new ArrayList<>(received);
            splitListIntoRows(rows);
            sendToWorkers(rows);
            return;
        }
        if (obj instanceof Pair) {
            @SuppressWarnings("unchecked")
            Pair<DoubleMatrix, DoubleMatrix> batch = (Pair<DoubleMatrix, DoubleMatrix>) obj;
            sendToWorkers(toRowPairs(batch));
            return;
        }
        unhandled(obj);
    }

    /** Aggregates the buffered updates into {@link #masterResults} and notifies the listener, if set. */
    private void computeAndNotify() {
        this.masterResults = compute(this.updates, this.updates);
        if (this.listener != null) {
            this.listener.epochComplete(this.masterResults);
        }
    }

    /**
     * Cuts {@code list} into batches of {@code conf.getSplit()} rows and publishes each batch on
     * {@link #BROADCAST} from a future on the actor's dispatcher. Also sets {@link #partition} to
     * the number of batches, which is the number of updates that completes an epoch.
     *
     * @param list one (input row, label row) pair per training example
     */
    public void sendToWorkers(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        final List<List<Pair<DoubleMatrix, DoubleMatrix>>> batches = Lists.partition(list, this.conf.getSplit());
        this.partition = batches.size();
        this.log.info("Found partition of size " + this.partition);
        // Capture actor state on the actor thread; the futures below run on other threads.
        final ActorRef self = getSelf();
        final LoggingAdapter logger = this.log;
        final ActorRef pubSub = this.mediator;
        for (int i = 0; i < batches.size(); i++) {
            final int batchIndex = i;
            // Snapshot now: Lists.partition returns live views backed by 'list'.
            final List<Pair<DoubleMatrix, DoubleMatrix>> batch = new ArrayList<>(batches.get(i));
            Futures.future(new Callable<Void>() {
                @Override
                public Void call() {
                    logger.info("Sending off work for batch " + batchIndex);
                    pubSub.tell(new DistributedPubSubMediator.Publish(BROADCAST, batch), self);
                    return null;
                }
            }, context().dispatcher()).onComplete(new OnComplete<Void>() {
                @Override
                public void onComplete(Throwable failure, Void ignored) {
                    // Rethrowing from a callback only reaches the dispatcher; log with context instead.
                    if (failure != null) {
                        logger.error(failure, "Failed to send work for batch " + batchIndex);
                    }
                }
            }, context().dispatcher());
        }
    }

    /**
     * Replaces every (inputs, labels) matrix pair in {@code list}, in place, with one pair per row.
     * All entries are validated before {@code list} is modified, so a failure leaves it untouched.
     *
     * @param list matrix pairs to expand; modified in place
     * @throws IllegalArgumentException if a pair has no input rows or mismatched label rows
     */
    public void splitListIntoRows(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        this.log.info("Splitting list in to rows...");
        List<Pair<DoubleMatrix, DoubleMatrix>> rows = new ArrayList<>();
        for (Pair<DoubleMatrix, DoubleMatrix> batch : list) {
            rows.addAll(toRowPairs(batch));
        }
        list.clear();
        list.addAll(rows);
    }

    /**
     * Splits an (inputs, labels) matrix pair into one (input row, label row) pair per row.
     *
     * @throws IllegalArgumentException if there are no input rows or the row counts differ
     */
    private static List<Pair<DoubleMatrix, DoubleMatrix>> toRowPairs(Pair<DoubleMatrix, DoubleMatrix> batch) {
        // Copy into ArrayLists: rowsAsList() is not guaranteed to be random-access.
        List<DoubleMatrix> inputRows = new ArrayList<>(batch.getFirst().rowsAsList());
        List<DoubleMatrix> labelRows = new ArrayList<>(batch.getSecond().rowsAsList());
        if (inputRows.isEmpty()) {
            throw new IllegalArgumentException("No input rows found");
        }
        if (inputRows.size() != labelRows.size()) {
            throw new IllegalArgumentException("Label rows not equal to input rows");
        }
        List<Pair<DoubleMatrix, DoubleMatrix>> rows = new ArrayList<>(inputRows.size());
        for (int i = 0; i < inputRows.size(); i++) {
            rows.add(new Pair<>(inputRows.get(i), labelRows.get(i)));
        }
        return rows;
    }

    public abstract void complete(DataOutputStream dataOutputStream);

    /** @return the latest aggregated result, or null before the first epoch completes */
    public E getResults() {
        return this.masterResults;
    }

    /** Stops any failing child without retrying, logging the cause with its stack trace. */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            @Override
            public SupervisorStrategy.Directive apply(Throwable th) {
                // error(Throwable, String) keeps the stack trace; error(String, Object) dropped it.
                MasterActor.this.log.error(th, "Problem with processing");
                return SupervisorStrategy.stop();
            }
        });
    }

    public Conf getConf() {
        return this.conf;
    }

    public int getEpochsComplete() {
        return this.epochsComplete;
    }

    public int getPartition() {
        return this.partition;
    }

    public E getMasterResults() {
        return this.masterResults;
    }

    public boolean isDone() {
        return this.isDone;
    }

    /** @return the live (mutable) buffer of updates for the current epoch */
    public List<E> getUpdates() {
        return this.updates;
    }

    public EpochDoneListener<E> getListener() {
        return this.listener;
    }

    public ActorRef getBatchActor() {
        return this.batchActor;
    }

    public ActorRef getMediator() {
        return this.mediator;
    }

    public static String getBROADCAST() {
        return BROADCAST;
    }

    /** @return the {@link #MASTER} topic name */
    public static String getRESULT() {
        return MASTER;
    }
}
