package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.Props;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import java.io.DataOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.MoreWorkMessage;
import org.deeplearning4j.iterativereduce.actor.core.NeedsModelMessage;
import org.deeplearning4j.iterativereduce.actor.core.actor.WorkerState;
import org.deeplearning4j.iterativereduce.actor.core.api.EpochDoneListener;
import org.deeplearning4j.iterativereduce.akka.DeepLearningAccumulator;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/multilayer/MasterActor.class */
/**
 * Master actor for iterative-reduce training of a {@link BaseMultiLayerNetwork}.
 *
 * <p>On {@link #setup(Conf)} it starts a round-robin pool of {@link WorkerActor}s and
 * broadcasts the initial master network through the distributed pub-sub mediator. It then:
 * <ul>
 *   <li>splits incoming data ({@code List} of {@link DataSet}s or a single {@link Pair}) into
 *       rows and hands them to workers,</li>
 *   <li>collects {@link UpdateableImpl} updates and, once {@code partition} of them have arrived,
 *       averages them into {@code masterResults}, notifies the listener, asks the batch actor for
 *       more work and re-broadcasts the averaged network.</li>
 * </ul>
 */
public class MasterActor extends org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor<UpdateableImpl> {

    /** How long {@link #setup(Conf)} waits after starting the worker pool before broadcasting. */
    private static final long WORKER_STARTUP_WAIT_MILLIS = 30000L;

    // NOTE(review): this field is never assigned in this class; the 3-arg constructor passes the
    // network to the superclass as a parameter instead. Confirm how (or whether) it gets set;
    // when it is null, setup() builds a fresh network from the Conf.
    protected BaseMultiLayerNetwork network;

    /**
     * Creates a master that builds its initial network from {@code conf}.
     *
     * @param conf training configuration
     * @param actorRef batch actor passed through to the superclass
     */
    public MasterActor(Conf conf, ActorRef actorRef) {
        super(conf, actorRef);
    }

    /**
     * Returns {@link Props} for a master built with {@link #MasterActor(Conf, ActorRef)}.
     */
    public static Props propsFor(Conf conf, ActorRef actorRef) {
        return Props.create(MasterActor.class, new Object[]{conf, actorRef});
    }

    /**
     * Creates a master seeded with an existing network.
     *
     * @param conf training configuration
     * @param actorRef batch actor passed through to the superclass
     * @param baseMultiLayerNetwork network forwarded to the superclass as an extra parameter
     */
    public MasterActor(Conf conf, ActorRef actorRef, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        super(conf, actorRef, new Object[]{baseMultiLayerNetwork});
    }

    /**
     * Returns {@link Props} for a master built with
     * {@link #MasterActor(Conf, ActorRef, BaseMultiLayerNetwork)}.
     */
    public static Props propsFor(Conf conf, ActorRef actorRef, BaseMultiLayerNetwork baseMultiLayerNetwork) {
        return Props.create(MasterActor.class, new Object[]{conf, actorRef, baseMultiLayerNetwork});
    }

    /**
     * Averages worker updates into the master result.
     *
     * <p>On the first call a new {@link UpdateableImpl} is created; afterwards the existing one
     * is overwritten in place with the new average.
     *
     * @param workerUpdates updates to accumulate; each contributes once to the average
     * @param masterUpdates not read by this implementation (kept for the superclass signature)
     * @return the current master result
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public synchronized UpdateableImpl compute(Collection<UpdateableImpl> workerUpdates, Collection<UpdateableImpl> masterUpdates) {
        DeepLearningAccumulator accumulator = new DeepLearningAccumulator();
        for (UpdateableImpl update : workerUpdates) {
            accumulator.accumulate(update.get());
        }
        if (this.masterResults == null) {
            // First round: there is no master result to update in place yet.
            this.masterResults = new UpdateableImpl(accumulator.averaged());
        } else {
            this.masterResults.set(accumulator.averaged());
        }
        return this.masterResults;
    }

    /**
     * Starts the worker pool, waits for it to come up, then broadcasts the initial network.
     *
     * <p>The initial network is {@link #network} when set, otherwise one built from {@code conf}.
     *
     * @param conf training configuration used to size the default network and to create workers
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public void setup(Conf conf) {
        this.log.info("Starting workers");
        int poolSize = Runtime.getRuntime().availableProcessors();
        context().system().actorOf(new RoundRobinPool(poolSize).props(WorkerActor.propsFor(conf)), "worker");
        // NOTE(review): this blocks the actor's thread. Presumably it gives the workers time to
        // start and subscribe before the first broadcast; confirm before shortening.
        try {
            Thread.sleep(WORKER_STARTUP_WAIT_MILLIS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        this.log.info("Broadcasting initial master network");
        BaseMultiLayerNetwork initial = this.network != null ? this.network : createDefaultNetwork(conf);
        this.masterResults = new UpdateableImpl(initial);
        this.mediator.tell(new DistributedPubSubMediator.Publish(BROADCAST, this.masterResults), getSelf());
    }

    /** Builds a fresh network from the sizes and options in {@code conf}. */
    private BaseMultiLayerNetwork createDefaultNetwork(Conf conf) {
        return new BaseMultiLayerNetwork.Builder()
                .numberOfInputs(conf.getnIn())
                .numberOfOutPuts(conf.getnOut())
                .withClazz(conf.getMultiLayerClazz())
                .hiddenLayerSizes(conf.getLayerSizes())
                .renderWeights(conf.getRenderWeightEpochs())
                .useRegularization(conf.isUseRegularization())
                .withSparsity(conf.getSparsity())
                .build();
    }

    /**
     * Dispatches an incoming message by type.
     *
     * <p>Unknown message types are passed to {@code unhandled}.
     */
    @Override
    public void onReceive(Object message) throws Exception {
        if (message instanceof DistributedPubSubMediator.SubscribeAck) {
            this.log.info("Subscribed " + message.toString());
        } else if (message instanceof EpochDoneListener) {
            this.listener = (EpochDoneListener) message;
            this.log.info("Set listener");
        } else if (message instanceof WorkerState) {
            addWorker((WorkerState) message);
        } else if (message instanceof NeedsModelMessage) {
            this.log.info("Sending networks over");
            getSender().tell(this.masterResults.get(), getSelf());
        } else if (message instanceof String) {
            handleWorkerAvailable(message.toString());
        } else if (message instanceof UpdateableImpl) {
            handleUpdate((UpdateableImpl) message);
        } else if (message instanceof List) {
            @SuppressWarnings("unchecked") // senders of a List are expected to send DataSets
            List<DataSet> dataSets = (List<DataSet>) message;
            splitListIntoRows(dataSets);
            sendToWorkers(dataSets);
        } else if (message instanceof Pair) {
            sendToWorkers(toRowDataSets((Pair<?, ?>) message));
        } else {
            unhandled(message);
        }
    }

    /**
     * Marks the worker with the given id as available, registering it if it is not yet known,
     * and acknowledges the sender.
     */
    private void handleWorkerAvailable(String workerId) {
        WorkerState state = this.workers.get(workerId);
        if (state == null) {
            state = new WorkerState(workerId, getSender());
            state.setAvailable(true);
            // Register the newly seen worker so later lookups by id find this state.
            addWorker(state);
            this.log.info("Worker " + state.getWorkerId() + " available for work");
            // NOTE(review): the purpose of this empty-string reply is not visible here; kept as-is.
            getSender().tell("", getSelf());
        } else {
            state.setAvailable(true);
            this.log.info("Worker " + state.getWorkerId() + " available for work");
        }
        getSender().tell(new Ack(), getSelf());
    }

    /**
     * Records a worker update. Once {@code partition} updates have been collected, averages them,
     * notifies the listener, requests more work, clears the buffer and re-broadcasts the result.
     */
    private void handleUpdate(UpdateableImpl update) {
        this.updates.add(update);
        this.log.info("Num updates so far " + this.updates.size() + " and partition size is " + this.partition);
        if (this.updates.size() < this.partition) {
            return;
        }
        this.masterResults = compute(this.updates, this.updates);
        if (this.listener != null) {
            this.listener.epochComplete(this.masterResults);
        }
        this.epochsComplete++;
        this.batchActor.tell(new MoreWorkMessage(this.masterResults), getSelf());
        this.updates.clear();
        this.log.info("Broadcasting weights");
        this.mediator.tell(new DistributedPubSubMediator.Publish(BROADCAST, this.masterResults), getSelf());
    }

    /**
     * Splits an (inputs, labels) pair of matrices into one {@link DataSet} per row.
     *
     * @param pair pair whose first and second elements are {@link DoubleMatrix} instances
     * @return one single-row data set per input row
     * @throws IllegalArgumentException if the input and label matrices have different row counts
     */
    private List<DataSet> toRowDataSets(Pair<?, ?> pair) {
        List<DoubleMatrix> inputRows = ((DoubleMatrix) pair.getFirst()).rowsAsList();
        List<DoubleMatrix> labelRows = ((DoubleMatrix) pair.getSecond()).rowsAsList();
        if (inputRows.size() != labelRows.size()) {
            throw new IllegalArgumentException("Input rows (" + inputRows.size()
                    + ") and label rows (" + labelRows.size() + ") differ");
        }
        List<DataSet> rows = new ArrayList<DataSet>(inputRows.size());
        for (int i = 0; i < inputRows.size(); i++) {
            rows.add(new DataSet(inputRows.get(i), labelRows.get(i)));
        }
        return rows;
    }

    /**
     * Writes the current master network to {@code dataOutputStream}.
     *
     * @param dataOutputStream destination stream; not closed by this method
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public void complete(DataOutputStream dataOutputStream) {
        this.masterResults.get().write(dataOutputStream);
    }
}
