package org.deeplearning4j.iterativereduce.actor.multilayer;

import akka.actor.ActorRef;
import akka.actor.PoisonPill;
import akka.contrib.pattern.ClusterSingletonManager;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.routing.RoundRobinPool;
import java.io.DataOutputStream;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.iterativereduce.actor.core.Ack;
import org.deeplearning4j.iterativereduce.actor.core.DoneMessage;
import org.deeplearning4j.iterativereduce.actor.core.Job;
import org.deeplearning4j.iterativereduce.actor.core.MoreWorkMessage;
import org.deeplearning4j.iterativereduce.actor.core.NoJobFound;
import org.deeplearning4j.iterativereduce.actor.core.ResetMessage;
import org.deeplearning4j.iterativereduce.akka.DeepLearningAccumulator;
import org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.iterativereduce.multi.UpdateableImpl;
import org.deeplearning4j.util.SerializationUtils;
import org.jblas.DoubleMatrix;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/multilayer/MasterActor.class */
public class MasterActor extends org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor<UpdateableImpl> {

    /**
     * Periodic self-messages. The scheduler delivers these to this actor's own mailbox so that every
     * read/write of actor state ({@code updates}, {@code partition}, {@code epochsComplete}, ...)
     * happens on the actor's thread rather than on a scheduler/dispatcher thread.
     */
    private enum Tick {
        /** Re-check whether the current iteration can be closed out. */
        FORCE_NEXT_PHASE,
        /** Sweep the heartbeat table and drop workers that stopped reporting. */
        CLEAR_STALE_WORKERS
    }

    /** A worker whose last heartbeat is at least this many seconds old is removed. */
    private static final long STALE_WORKER_TIMEOUT_SECONDS = 30L;

    protected BaseMultiLayerNetwork network;
    protected AtomicLong lastUpdated;

    /**
     * Creates a master that builds its initial network from {@code conf}, starts the worker pool,
     * broadcasts the initial network and schedules the periodic iteration check and stale-worker
     * sweep (first run after one minute, then every minute).
     *
     * @param conf                  training configuration
     * @param actorRef              passed through to the base master actor
     * @param hazelCastStateTracker shared state tracker
     */
    public MasterActor(Conf conf, ActorRef actorRef, HazelCastStateTracker hazelCastStateTracker) {
        super(conf, actorRef, hazelCastStateTracker);
        this.lastUpdated = new AtomicLong(System.currentTimeMillis());
        setup(conf);
        scheduleTimers();
    }

    /**
     * Creates a master seeded with an existing network, which {@link #setup(Conf)} broadcasts
     * instead of building a fresh one.
     * NOTE(review): unlike the other constructor, this one does not schedule the periodic
     * next-iteration check or stale-worker sweep — confirm that is intended.
     *
     * @param conf                  training configuration
     * @param actorRef              passed through to the base master actor
     * @param baseMultiLayerNetwork network to start training from
     * @param hazelCastStateTracker shared state tracker
     */
    public MasterActor(Conf conf, ActorRef actorRef, BaseMultiLayerNetwork baseMultiLayerNetwork, HazelCastStateTracker hazelCastStateTracker) {
        super(conf, actorRef, hazelCastStateTracker);
        this.lastUpdated = new AtomicLong(System.currentTimeMillis());
        this.network = baseMultiLayerNetwork;
        setup(conf);
    }

    /**
     * Schedules the two periodic self-messages. Sending messages to self, rather than running a
     * Runnable that closes over this actor, keeps all state access on the actor's thread.
     */
    private void scheduleTimers() {
        this.forceNextPhase = context().system().scheduler().schedule(
                Duration.create(1L, TimeUnit.MINUTES), Duration.create(1L, TimeUnit.MINUTES),
                getSelf(), Tick.FORCE_NEXT_PHASE, context().dispatcher(), getSelf());
        this.clearStateWorkers = context().system().scheduler().schedule(
                Duration.create(1L, TimeUnit.MINUTES), Duration.create(1L, TimeUnit.MINUTES),
                getSelf(), Tick.CLEAR_STALE_WORKERS, context().dispatcher(), getSelf());
    }

    /**
     * Closes out the current iteration if enough updates have arrived or no jobs remain.
     * Failures are logged rather than rethrown. Rethrowing would make Akka's default supervision
     * restart this actor and re-run the constructor/{@link #setup(Conf)}; the original
     * Runnable-based timer never affected the actor on failure.
     */
    private void onForceNextPhaseTick() {
        try {
            List<Job> currentJobs = this.stateTracker.currentJobs();
            this.log.info("Status check on next iteration");
            if (this.updates.size() >= this.partition || currentJobs.isEmpty()) {
                nextIteration();
            }
            this.log.info("Current jobs left " + currentJobs);
        } catch (Exception e) {
            this.log.info("Status check on next iteration failed: " + e);
        }
    }

    /**
     * Removes every worker whose last heartbeat is at least {@link #STALE_WORKER_TIMEOUT_SECONDS}
     * old and shrinks {@code partition} by one per removed worker. Failures are logged for the same
     * reason as in {@link #onForceNextPhaseTick()}.
     */
    private void onClearStaleWorkersTick() {
        try {
            long now = System.currentTimeMillis();
            Map<String, Long> heartBeats = this.stateTracker.getHeartBeats();
            // Iterate a snapshot: removeWorker may change the tracker's heartbeat map while we loop.
            List<Map.Entry<String, Long>> entries = new ArrayList<Map.Entry<String, Long>>(heartBeats.entrySet());
            for (Map.Entry<String, Long> entry : entries) {
                Long lastBeat = entry.getValue();
                if (lastBeat == null) {
                    continue;
                }
                if (TimeUnit.MILLISECONDS.toSeconds(now - lastBeat.longValue()) >= STALE_WORKER_TIMEOUT_SECONDS) {
                    String worker = entry.getKey();
                    this.log.info("Removing stale worker " + worker);
                    this.stateTracker.removeWorker(worker);
                    this.partition--;
                }
            }
        } catch (Exception e) {
            this.log.info("Stale worker sweep failed: " + e);
        }
    }

    /**
     * Averages the parameters of all worker updates into the master result.
     * The result is persisted via {@code stateTracker.setCurrent}.
     *
     * @param collection  worker updates to average
     * @param collection2 unused (kept for the base-class signature)
     * @return the averaged result: {@code getResults()} updated in place, or a new wrapper if that was null
     * @throws RuntimeException wrapping any failure from {@code stateTracker.setCurrent}
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public UpdateableImpl compute(Collection<UpdateableImpl> collection, Collection<UpdateableImpl> collection2) {
        DeepLearningAccumulator accumulator = new DeepLearningAccumulator();
        for (UpdateableImpl update : collection) {
            accumulator.accumulate(update.get());
        }
        UpdateableImpl results = getResults();
        if (results == null) {
            results = new UpdateableImpl(accumulator.averaged());
        } else {
            results.set(accumulator.averaged());
        }
        try {
            this.stateTracker.setCurrent(results);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return results;
    }

    /**
     * Starts the worker pool and publishes the initial network.
     * The pool is a cluster-singleton round-robin pool with one worker per available processor.
     * The published network is {@code this.network} if set, otherwise one built from {@code conf}.
     *
     * @param conf training configuration
     * @throws RuntimeException wrapping any failure to store or publish the initial network
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public void setup(Conf conf) {
        this.log.info("Starting workers");
        context().system().actorOf(ClusterSingletonManager.defaultProps(new RoundRobinPool(Runtime.getRuntime().availableProcessors()).props(WorkerActor.propsFor(conf, this.stateTracker)), "master", PoisonPill.getInstance(), "master"), "worker");
        // NOTE(review): fixed 30 s wait, presumably so workers are up before the first broadcast.
        // It blocks the actor (setup runs from the constructor) for the whole duration.
        try {
            Thread.sleep(30000L);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        this.log.info("Broadcasting initial master network");
        BaseMultiLayerNetwork initial = this.network == null ? buildNetwork(conf) : this.network;
        initial.synchonrizeRng();
        if (conf.getColumnMeans() != null) {
            initial.setColumnMeans(conf.getColumnMeans());
        }
        if (conf.getColumnStds() != null) {
            initial.setColumnStds(conf.getColumnStds());
        }
        UpdateableImpl updateableImpl = new UpdateableImpl(initial);
        try {
            this.stateTracker.setCurrent(updateableImpl);
            this.mediator.tell(new DistributedPubSubMediator.Publish(BROADCAST, updateableImpl), getSelf());
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Builds a fresh multi-layer network from the settings in {@code conf}. */
    private static BaseMultiLayerNetwork buildNetwork(Conf conf) {
        return new BaseMultiLayerNetwork.Builder()
                .numberOfInputs(conf.getnIn())
                .numberOfOutPuts(conf.getnOut())
                .withClazz(conf.getMultiLayerClazz())
                .hiddenLayerSizes(conf.getLayerSizes())
                .renderWeights(conf.getRenderWeightEpochs())
                .useRegularization(conf.isUseRegularization())
                .withDropOut(conf.getDropOut())
                .withLossFunction(conf.getLossFunction())
                .withSparsity(conf.getSparsity())
                .useAdaGrad(conf.isUseAdaGrad())
                .withOptimizationAlgorithm(conf.getOptimizationAlgorithm())
                .withMultiLayerGradientListeners(conf.getMultiLayerGradientListeners())
                .withGradientListeners(conf.getGradientListeners())
                .build();
    }

    /**
     * Closes out one iteration. Does nothing if no updates are pending.
     * Otherwise it:
     * <ol>
     *   <li>averages the pending updates and bumps {@code epochsComplete};</li>
     *   <li>asks the batch actor for more work unless training is done;</li>
     *   <li>clears the pending updates and broadcasts the new weights.</li>
     * </ol>
     *
     * @throws Exception propagated from the state tracker
     */
    protected void nextIteration() throws Exception {
        if (this.updates.isEmpty()) {
            return;
        }
        UpdateableImpl compute = compute(this.updates, this.updates);
        this.epochsComplete++;
        if (!isDone()) {
            this.batchActor.tell(new MoreWorkMessage(compute), getSelf());
        }
        this.updates.clear();
        this.log.info("Broadcasting weights");
        this.mediator.tell(new DistributedPubSubMediator.Publish(BROADCAST, compute), getSelf());
        this.stateTracker.setCurrent(compute);
    }

    /**
     * Handles a {@link DoneMessage}. It first folds any pending updates into the master result.
     * Then, once no jobs remain, it either switches from pretraining to finetuning or finishes.
     * The switch saves the pretrained model to {@code pretrain-model.bin} and restarts the batch
     * actor. Finishing marks training as done.
     *
     * @throws Exception propagated from the state tracker or serialization
     */
    protected void checkDone() throws Exception {
        UpdateableImpl masterResults;
        if (this.updates.isEmpty()) {
            masterResults = getMasterResults();
        } else {
            masterResults = compute(this.updates, this.updates);
            this.stateTracker.setCurrent(masterResults);
            this.epochsComplete++;
            this.updates.clear();
        }
        // Query the job list once so both decisions below see the same snapshot.
        boolean noJobsLeft = this.stateTracker.currentJobs().isEmpty();
        if (!noJobsLeft) {
            return;
        }
        if (this.stateTracker.isPretrain()) {
            this.log.info("Switching to finetune mode");
            this.pretrain = false;
            this.stateTracker.moveToFinetune();
            SerializationUtils.saveObject(masterResults.get(), new File("pretrain-model.bin"));
            this.batchActor.tell(ResetMessage.getInstance(), getSelf());
            this.batchActor.tell(new MoreWorkMessage(masterResults), getSelf());
            return;
        }
        this.isDone = true;
        this.stateTracker.finish();
        this.log.info("Done training!");
    }

    /**
     * Dispatches incoming messages: periodic ticks, pub-sub acks, worker results, completion
     * signals, and new data to distribute to workers.
     */
    @Override
    public void onReceive(Object obj) throws Exception {
        if (obj == Tick.FORCE_NEXT_PHASE) {
            onForceNextPhaseTick();
            return;
        }
        if (obj == Tick.CLEAR_STALE_WORKERS) {
            onClearStaleWorkersTick();
            return;
        }
        if (obj instanceof DistributedPubSubMediator.SubscribeAck) {
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", obj), getSelf());
            this.log.info("Subscribed " + obj.toString());
            return;
        }
        if (obj instanceof DistributedPubSubMediator.UnsubscribeAck) {
            // Previously cast to SubscribeAck, which threw ClassCastException for this type.
            this.mediator.tell(new DistributedPubSubMediator.Publish("topics", obj), getSelf());
            this.log.info("Unsubscribed " + obj.toString());
            return;
        }
        if (obj instanceof NoJobFound) {
            // One fewer worker will report this round; re-check whether the iteration is complete.
            this.partition--;
            if (this.updates.size() >= this.partition) {
                nextIteration();
            }
            return;
        }
        if (obj instanceof DoneMessage) {
            this.log.info("Received done message");
            checkDone();
            return;
        }
        if (obj instanceof String) {
            getSender().tell(Ack.getInstance(), getSelf());
            return;
        }
        if (obj instanceof UpdateableImpl) {
            this.updates.add((UpdateableImpl) obj);
            this.log.info("Num updates so far " + this.updates.size() + " and partition size is " + this.partition);
            if (this.updates.size() >= this.partition) {
                nextIteration();
            }
            return;
        }
        if (obj instanceof List) {
            @SuppressWarnings("unchecked") // assumes senders only ever send List<DataSet> here
            List<DataSet> list = (List<DataSet>) obj;
            splitListIntoRows(list);
            sendToWorkers(list);
            return;
        }
        if (obj instanceof DataSet) {
            sendToWorkers(toRowDataSets((DataSet) obj));
            return;
        }
        unhandled(obj);
    }

    /**
     * Splits a data set into single-row data sets, pairing row {@code i} of the features with
     * row {@code i} of the labels.
     */
    private static List<DataSet> toRowDataSets(DataSet dataSet) {
        List<DoubleMatrix> featureRows = ((DoubleMatrix) dataSet.getFirst()).rowsAsList();
        List<DoubleMatrix> labelRows = ((DoubleMatrix) dataSet.getSecond()).rowsAsList();
        List<DataSet> rows = new ArrayList<DataSet>(featureRows.size());
        for (int i = 0; i < featureRows.size(); i++) {
            rows.add(new DataSet(featureRows.get(i), labelRows.get(i)));
        }
        return rows;
    }

    /**
     * Writes the final master network to {@code dataOutputStream}.
     * Assumes master results exist by the time this is called.
     */
    @Override // org.deeplearning4j.iterativereduce.actor.core.actor.MasterActor
    public void complete(DataOutputStream dataOutputStream) {
        getMasterResults().get().write(dataOutputStream);
    }
}
