package org.deeplearning4j.iterativereduce.akka.gradient;

import java.util.List;
import java.util.Objects;
import org.apache.commons.math3.random.MersenneTwister;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.conf.DeepLearningConfigurable;
import org.deeplearning4j.scaleout.iterativereduce.Updateable;
import org.deeplearning4j.scaleout.iterativereduce.multi.gradient.ComputableWorkerImpl;
import org.deeplearning4j.scaleout.iterativereduce.multi.gradient.UpdateableGradientImpl;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/akka/gradient/ComputableWorkerAkka.class */
/**
 * Worker that trains a local {@link BaseMultiLayerNetwork} on a fixed subset of rows of an input
 * matrix and an outcome matrix, and returns the resulting gradient wrapped in an
 * {@link UpdateableGradientImpl}.
 *
 * <p>Lifecycle:
 * <ol>
 *   <li>construct with the data and the row indices assigned to this worker;</li>
 *   <li>call {@link #setup(Conf)} to build the network;</li>
 *   <li>call {@code compute}.</li>
 * </ol>
 * Calling {@code compute} before {@code setup} fails with an {@link IllegalStateException}.
 *
 * <p>NOTE(review): this source appears to be decompiled (see the header comment). The
 * {@code m9compute} and {@code m10compute} names are decompiler renames of {@code compute}. They
 * are kept unchanged so existing callers continue to compile and link.
 */
public class ComputableWorkerAkka extends ComputableWorkerImpl implements DeepLearningConfigurable {
    /** Network built by {@link #setup(Conf)}; {@code null} until setup has run. */
    private BaseMultiLayerNetwork network;
    /** Rows of the input matrix selected by {@link #rows} at construction time. */
    private final DoubleMatrix combinedInput;
    /** Populated by {@link #setup(Conf)}; not read within this class. */
    int fineTuneEpochs;
    /** Populated by {@link #setup(Conf)}; not read within this class. */
    int preTrainEpochs;
    boolean useRegularization;
    int[] hiddenLayerSizes;
    int numOuts;
    int numIns;
    double momentum = 0.0d;
    /** Set to {@code hiddenLayerSizes.length}, i.e. the number of hidden layers, despite the name. */
    int numHiddenNeurons;
    /** Populated by {@link #setup(Conf)}; not read within this class. */
    long seed;
    /** Taken from the configuration's pretrain learning rate; passed to {@code getGradient}. */
    double learningRate;
    /** Populated by {@link #setup(Conf)}; not read within this class. */
    double corruptionLevel;
    ActivationFunction activation;
    /** Row indices this worker was constructed with (a private copy of the caller's array). */
    int[] rows;
    // NOTE(review): never read or written in this class; kept only to avoid changing the
    // object's field layout.
    private boolean iterationComplete;
    /** Number of times {@link #incrementIteration()} has been called. */
    private int currEpoch;
    /** Rows of the outcome matrix selected by {@link #rows} at construction time. */
    private final DoubleMatrix outcomes;
    /** Extra parameters from the configuration, passed to training and gradient computation. */
    Object[] extraParams;

    /**
     * Creates a worker bound to a subset of the training data.
     *
     * @param input      full input matrix; only the rows at {@code rowIndices} are kept
     * @param labels     full outcome matrix; only the rows at {@code rowIndices} are kept
     * @param rowIndices row indices assigned to this worker; copied, so later changes by the
     *                   caller do not affect this worker
     * @throws NullPointerException if any argument is {@code null}
     */
    public ComputableWorkerAkka(DoubleMatrix input, DoubleMatrix labels, int[] rowIndices) {
        Objects.requireNonNull(input, "input");
        Objects.requireNonNull(labels, "labels");
        Objects.requireNonNull(rowIndices, "rowIndices");
        this.rows = rowIndices.clone();
        this.combinedInput = input.getRows(this.rows);
        this.outcomes = labels.getRows(this.rows);
    }

    /**
     * Trains on this worker's data slice and returns the gradient.
     *
     * @param list ignored; the worker always uses the data it was constructed with
     * @return the gradient produced after training
     * @throws IllegalStateException if {@link #setup(Conf)} has not been called
     */
    public UpdateableGradientImpl compute(List<UpdateableGradientImpl> list) {
        return m9compute();
    }

    /**
     * Trains the network on this worker's input and outcome rows, then returns the network's
     * gradient (computed with {@link #extraParams} and {@link #learningRate}) wrapped in an
     * {@link UpdateableGradientImpl}.
     *
     * <p>Decompiler rename of the no-argument {@code compute()}.
     *
     * @return the gradient produced after training
     * @throws IllegalStateException if {@link #setup(Conf)} has not been called
     */
    /* renamed from: compute, reason: merged with bridge method [inline-methods] */
    public UpdateableGradientImpl m9compute() {
        BaseMultiLayerNetwork net = this.network;
        if (net == null) {
            // Fail fast with a clear message instead of an NPE on the null network.
            throw new IllegalStateException("setup(Conf) must be called before compute()");
        }
        net.trainNetwork(this.combinedInput, this.outcomes, this.extraParams);
        return new UpdateableGradientImpl(net.getGradient(this.extraParams, this.learningRate));
    }

    /**
     * Advances the epoch counter.
     *
     * @return always {@code false}
     */
    // NOTE(review): completion is never derived from currEpoch here; confirm the caller
    // expects this method to always return false.
    public boolean incrementIteration() {
        this.currEpoch++;
        return false;
    }

    /**
     * Reads hyperparameters from {@code conf} and builds the network this worker trains.
     *
     * @param conf configuration to read from
     * @throws NullPointerException if {@code conf} is {@code null}
     */
    public void setup(Conf conf) {
        Objects.requireNonNull(conf, "conf");
        this.hiddenLayerSizes = conf.getLayerSizes();
        this.numOuts = conf.getnOut();
        this.numIns = conf.getnIn();
        this.numHiddenNeurons = this.hiddenLayerSizes.length;
        this.seed = conf.getSeed();
        this.useRegularization = conf.isUseRegularization();
        this.momentum = conf.getMomentum();
        this.activation = conf.getFunction();
        // conf.getSeed() is passed directly, not this.seed. That keeps the same MersenneTwister
        // constructor overload (int vs long) the original code resolved to.
        this.network = new BaseMultiLayerNetwork.Builder()
                .numberOfInputs(this.numIns)
                .numberOfOutPuts(this.numOuts)
                .withActivation(this.activation)
                .hiddenLayerSizes(this.hiddenLayerSizes)
                .withRng(new MersenneTwister(conf.getSeed()))
                .useRegularization(this.useRegularization)
                .withMomentum(this.momentum)
                .withClazz(conf.getMultiLayerClazz())
                .build();
        this.learningRate = conf.getPretrainLearningRate();
        this.preTrainEpochs = conf.getPretrainEpochs();
        this.fineTuneEpochs = conf.getFinetuneEpochs();
        this.corruptionLevel = conf.getCorruptionLevel();
        this.extraParams = conf.getDeepLearningParams();
    }

    /**
     * Raw-typed bridge that forwards to {@link #compute(List)}.
     *
     * <p>Decompiler rename of the synthetic {@code Updateable compute(List)} bridge.
     *
     * @param list ignored by the delegate
     * @return the gradient produced after training
     */
    /* renamed from: compute, reason: collision with other method in class */
    // The raw parameter mirrors the erased bridge signature. The cast is unchecked, but the
    // delegate never reads the list, so no element is ever accessed with the wrong type.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public /* bridge */ /* synthetic */ Updateable m10compute(List list) {
        return compute((List<UpdateableGradientImpl>) list);
    }
}
