package org.deeplearning4j.iterativereduce.actor.core.actor;

import akka.actor.ActorRef;
import akka.actor.OneForOneStrategy;
import akka.actor.SupervisorStrategy;
import akka.actor.UntypedActor;
import akka.contrib.pattern.DistributedPubSubExtension;
import akka.contrib.pattern.DistributedPubSubMediator;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import akka.japi.Creator;
import akka.japi.Function;
import java.util.List;
import java.util.concurrent.Callable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.iterativereduce.actor.core.UpdateMessage;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.ComputableWorker;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.concurrent.duration.Duration;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/core/actor/WorkerActor.class */
/**
 * Base class for Akka worker actors that take part in an iterative-reduce training round.
 *
 * <p>On construction the worker reads its hyper-parameters from the supplied {@link Conf}
 * (via {@link #setup(Conf)}) and registers with the cluster's distributed pub-sub mediator.
 * It {@code Put}s itself, subscribes to {@link MasterActor#BROADCAST} and
 * {@link MasterActor#SHUTDOWN}, and publishes its own {@link ActorRef} to
 * {@link DoneReaper#REAPER}.
 *
 * <p>Messages handled by {@link #onReceive(Object)}:
 * <ul>
 *   <li>{@code DistributedPubSubMediator.SubscribeAck}: logged.</li>
 *   <li>a {@link List} of (input row, outcome row) {@link Pair}s: stacked into
 *       {@link #combinedInput} / {@link #outcomes}. Then {@link #compute()} is run
 *       asynchronously and its result is published to {@link MasterActor#MASTER}.</li>
 *   <li>{@link UpdateMessage}: its updateable is stored as {@link #results}.</li>
 *   <li>anything else: passed to {@code unhandled}.</li>
 * </ul>
 *
 * <p>The public accessors are {@code synchronized} on this instance. The fields written
 * from {@code onReceive} / {@code updateTraining} are assigned without taking that lock.
 *
 * @param <E> the updateable type exchanged with the master
 */
public abstract class WorkerActor<E extends Updateable<?>> extends UntypedActor implements DeepLearningConfigurable, ComputableWorker<E> {
    protected DoubleMatrix combinedInput;
    protected DoubleMatrix outcomes;
    protected ActorRef mediator = DistributedPubSubExtension.get(getContext().system()).mediator();
    protected E e;
    protected E results;
    private static final Logger log = LoggerFactory.getLogger(WorkerActor.class);
    protected int fineTuneEpochs;
    protected int preTrainEpochs;
    protected int[] hiddenLayerSizes;
    protected int numHidden;
    protected int numVisible;
    /** Set in {@link #setup(Conf)} to the length of {@link #hiddenLayerSizes}. */
    protected int numHiddenNeurons;
    protected long seed;
    protected double learningRate;
    protected double corruptionLevel;
    protected Object[] extraParams;
    protected boolean useRegularization;
    public static final String SYSTEM_NAME = "Workers";

    /**
     * Serializable Akka {@link Creator} used to build concrete {@link WorkerActor}s.
     *
     * @param <E> the element type wrapped by the workers' {@link Updateable}
     */
    public static abstract class WorkerActorFactory<E> implements Creator<WorkerActor<Updateable<E>>> {
        private static final long serialVersionUID = 381253681712601968L;
        // NOTE(review): private and never read in this class, so subclasses cannot access
        // it either. Confirm whether an accessor is intended.
        private Conf conf;

        /**
         * @param conf configuration held by this factory
         */
        public WorkerActorFactory(Conf conf) {
            this.conf = conf;
        }

        /**
         * Creates a new worker actor.
         *
         * <p>NOTE(review): the name {@code m7create} comes from the decompiler, which merged this
         * method with the {@code create} bridge method. Confirm how {@link Creator#create()} is
         * satisfied by concrete subclasses.
         *
         * @return the new worker actor
         * @throws Exception if the actor cannot be created
         */
        public abstract WorkerActor<Updateable<E>> m7create() throws Exception;
    }

    /**
     * Configures this worker and registers it with the distributed pub-sub mediator.
     *
     * <p>NOTE(review): {@link #setup(Conf)} is overridable and is called from this
     * constructor. Subclass overrides therefore run before the subclass's own field
     * initializers.
     *
     * @param conf the training configuration
     */
    public WorkerActor(Conf conf) {
        setup(conf);
        this.mediator.tell(new DistributedPubSubMediator.Put(getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.BROADCAST, getSelf()), getSelf());
        this.mediator.tell(new DistributedPubSubMediator.Subscribe(MasterActor.SHUTDOWN, getSelf()), getSelf());
        // Announces this worker on the reaper topic; the mediator itself is used as sender.
        this.mediator.tell(new DistributedPubSubMediator.Publish(DoneReaper.REAPER, getSelf()), this.mediator);
    }

    /**
     * Dispatches an incoming message. See the class documentation for the protocol.
     *
     * @param obj the received message
     * @throws Exception propagated to the actor's supervisor
     */
    @Override
    @SuppressWarnings("unchecked") // message payloads arrive untyped over the pub-sub bus
    public void onReceive(Object obj) throws Exception {
        if (obj instanceof DistributedPubSubMediator.SubscribeAck) {
            log.info("Subscribed to {}", obj);
            return;
        }
        if (obj instanceof List) {
            updateTraining((List<Pair<DoubleMatrix, DoubleMatrix>>) obj);
        } else if (obj instanceof UpdateMessage) {
            this.results = (E) ((UpdateMessage) obj).getUpdateable().get();
        } else {
            unhandled(obj);
        }
    }

    /**
     * Stacks a batch of (input, outcome) row pairs into matrices and trains on them.
     *
     * <p>Row {@code i} of {@link #combinedInput} / {@link #outcomes} is the first / second
     * element of the {@code i}-th pair. The column counts are taken from the first pair.
     * {@link #compute()} then runs on this actor's dispatcher. A successful, non-null result
     * is published to {@link MasterActor#MASTER}. A failure or a null result is logged and
     * not published, because a null message cannot be delivered by Akka.
     *
     * <p>NOTE(review): {@code compute()} reads the shared fields off the actor thread. If a
     * new batch arrives before it finishes, the fields are overwritten mid-computation.
     *
     * @param list the training batch; an empty batch is logged and ignored
     */
    private void updateTraining(List<Pair<DoubleMatrix, DoubleMatrix>> list) {
        if (list.isEmpty()) {
            // Without this guard list.get(0) below would throw IndexOutOfBoundsException.
            log.warn("Received an empty training batch; ignoring it");
            return;
        }
        final int rows = list.size();
        DoubleMatrix input = new DoubleMatrix(rows, list.get(0).getFirst().columns);
        DoubleMatrix labels = new DoubleMatrix(rows, list.get(0).getSecond().columns);
        for (int i = 0; i < rows; i++) {
            Pair<DoubleMatrix, DoubleMatrix> pair = list.get(i);
            input.putRow(i, pair.getFirst());
            labels.putRow(i, pair.getSecond());
        }
        this.combinedInput = input;
        this.outcomes = labels;

        // Resolve self on the actor's own thread rather than inside the future callback.
        final ActorRef self = getSelf();
        Futures.future(new Callable<E>() {
            @Override
            public E call() throws Exception {
                return WorkerActor.this.compute();
            }
        }, getContext().dispatcher()).onComplete(new OnComplete<E>() {
            @Override
            public void onComplete(Throwable failure, E result) throws Throwable {
                if (failure != null) {
                    log.error("Computation failed; not updating parent actor", failure);
                    return;
                }
                if (result == null) {
                    log.warn("Computation produced no result; not updating parent actor");
                    return;
                }
                log.info("Updating parent actor...");
                WorkerActor.this.mediator.tell(new DistributedPubSubMediator.Publish(MasterActor.MASTER, result), self);
            }
        }, getContext().dispatcher());
    }

    /**
     * Ignores {@code list} and delegates to {@link #compute()}.
     *
     * @param list ignored
     * @return the result of {@link #compute()}
     */
    public E compute(List<E> list) {
        return compute();
    }

    /**
     * Performs one unit of work on the current {@link #combinedInput} / {@link #outcomes}.
     *
     * @return the computed updateable. Null results are not forwarded to the master.
     */
    public abstract E compute();

    /**
     * @return always {@code false}
     */
    public boolean incrementIteration() {
        return false;
    }

    /**
     * Copies hyper-parameters out of {@code conf} into this worker's fields.
     * {@link #numHiddenNeurons} is set to the number of entries in {@code conf.getLayerSizes()}.
     *
     * @param conf the training configuration
     */
    public void setup(Conf conf) {
        this.hiddenLayerSizes = conf.getLayerSizes();
        this.numHidden = conf.getnOut();
        this.numVisible = conf.getnIn();
        this.numHiddenNeurons = this.hiddenLayerSizes.length;
        this.seed = conf.getSeed();
        this.useRegularization = conf.isUseRegularization();
        this.learningRate = conf.getPretrainLearningRate();
        this.preTrainEpochs = conf.getPretrainEpochs();
        this.fineTuneEpochs = conf.getFinetuneEpochs();
        this.corruptionLevel = conf.getCorruptionLevel();
        this.extraParams = conf.getDeepLearningParams();
    }

    /**
     * Stops any failing child actor immediately (zero retries) after logging the cause.
     */
    @Override
    public SupervisorStrategy supervisorStrategy() {
        return new OneForOneStrategy(0, Duration.Zero(), new Function<Throwable, SupervisorStrategy.Directive>() {
            @Override
            public SupervisorStrategy.Directive apply(Throwable th) {
                log.error("Problem with processing", th);
                return SupervisorStrategy.stop();
            }
        });
    }

    /** @return the last updateable received via an {@link UpdateMessage} (or set directly) */
    public E getResults() {
        return this.results;
    }

    /** Stores {@code e} as the current updateable (same field as {@link #setE}). */
    public void update(E e) {
        this.e = e;
    }

    /** @return the stacked input matrix from the most recent batch */
    public synchronized DoubleMatrix getCombinedInput() {
        return this.combinedInput;
    }

    /** @param doubleMatrix the new stacked input matrix */
    public synchronized void setCombinedInput(DoubleMatrix doubleMatrix) {
        this.combinedInput = doubleMatrix;
    }

    /** @return the stacked outcome matrix from the most recent batch */
    public synchronized DoubleMatrix getOutcomes() {
        return this.outcomes;
    }

    /** @param doubleMatrix the new stacked outcome matrix */
    public synchronized void setOutcomes(DoubleMatrix doubleMatrix) {
        this.outcomes = doubleMatrix;
    }

    /** @return the pub-sub mediator this worker talks to */
    public synchronized ActorRef getMediator() {
        return this.mediator;
    }

    /** @param actorRef the pub-sub mediator to use from now on */
    public synchronized void setMediator(ActorRef actorRef) {
        this.mediator = actorRef;
    }

    /** @return the current updateable */
    public synchronized E getE() {
        return this.e;
    }

    /** @param e the new current updateable */
    public synchronized void setE(E e) {
        this.e = e;
    }

    /** @return the number of fine-tune epochs */
    public synchronized int getFineTuneEpochs() {
        return this.fineTuneEpochs;
    }

    /** @param i the number of fine-tune epochs */
    public synchronized void setFineTuneEpochs(int i) {
        this.fineTuneEpochs = i;
    }

    /** @return the number of pre-train epochs */
    public synchronized int getPreTrainEpochs() {
        return this.preTrainEpochs;
    }

    /** @param i the number of pre-train epochs */
    public synchronized void setPreTrainEpochs(int i) {
        this.preTrainEpochs = i;
    }

    /** @return the hidden layer sizes array (not copied) */
    public synchronized int[] getHiddenLayerSizes() {
        return this.hiddenLayerSizes;
    }

    /** @param iArr the hidden layer sizes (stored without copying) */
    public synchronized void setHiddenLayerSizes(int[] iArr) {
        this.hiddenLayerSizes = iArr;
    }

    /** @return the value read from {@code conf.getnOut()} in {@link #setup(Conf)} */
    public synchronized int getNumHidden() {
        return this.numHidden;
    }

    /** @param i the new value for {@link #numHidden} */
    public synchronized void setNumHidden(int i) {
        this.numHidden = i;
    }

    /** @return the value read from {@code conf.getnIn()} in {@link #setup(Conf)} */
    public synchronized int getNumVisible() {
        return this.numVisible;
    }

    /** @param i the new value for {@link #numVisible} */
    public synchronized void setNumVisible(int i) {
        this.numVisible = i;
    }

    /** @return the value of {@link #numHiddenNeurons} */
    public synchronized int getNumHiddenNeurons() {
        return this.numHiddenNeurons;
    }

    /** @param i the new value for {@link #numHiddenNeurons} */
    public synchronized void setNumHiddenNeurons(int i) {
        this.numHiddenNeurons = i;
    }

    /** @return the random seed */
    public synchronized long getSeed() {
        return this.seed;
    }

    /** @param j the random seed */
    public synchronized void setSeed(long j) {
        this.seed = j;
    }

    /** @return the learning rate (initialised from the pre-train learning rate) */
    public synchronized double getLearningRate() {
        return this.learningRate;
    }

    /** @param d the learning rate */
    public synchronized void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** @return the corruption level */
    public synchronized double getCorruptionLevel() {
        return this.corruptionLevel;
    }

    /** @param d the corruption level */
    public synchronized void setCorruptionLevel(double d) {
        this.corruptionLevel = d;
    }

    /** @return the extra deep-learning parameters array (not copied) */
    public synchronized Object[] getExtraParams() {
        return this.extraParams;
    }

    /** @param objArr the extra deep-learning parameters (stored without copying) */
    public synchronized void setExtraParams(Object[] objArr) {
        this.extraParams = objArr;
    }

    /** @param e the new results value */
    public synchronized void setResults(E e) {
        this.results = e;
    }
}
