package org.deeplearning4j.iterativereduce.actor.single;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.LFWDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.core.conf.DeepLearningConfigurableDistributed;
import org.deeplearning4j.scaleout.zookeeper.ZooKeeperConfigurationRegister;
import org.deeplearning4j.scaleout.zookeeper.ZookeeperConfigurationRetriever;
import org.jblas.DoubleMatrix;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.spi.DoubleOptionHandler;
import org.kohsuke.args4j.spi.IntOptionHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/single/ActorNetworkRunnerApp.class */
/**
 * Command-line entry point that starts an {@link ActorNetworkRunner} either as a master or as a worker.
 *
 * <p>A master (any {@code -t} value other than {@code worker}, including the default {@code master})
 * builds a {@link Conf} from the command-line options, creates the data set iterator, starts the
 * runner and publishes the resulting conf (including the master URL) to ZooKeeper at
 * {@code host:2181}. A worker ({@code -t worker}) instead reads that conf back from ZooKeeper and
 * starts a runner pointed at the published master URL.
 */
public class ActorNetworkRunnerApp implements DeepLearningConfigurableDistributed {
    protected static final Logger log = LoggerFactory.getLogger(ActorNetworkRunnerApp.class);

    /** Port used for every ZooKeeper connection made by this app. */
    private static final int ZOOKEEPER_PORT = 2181;

    @Option(name = "-a", usage = "algorithm to use: da (denoising autoencoder), rbm (restricted boltzmann machine), crbm (continuous restricted boltzmann machine)")
    protected String algorithm;

    @Option(name = "-i", usage = "number of inputs (columns in the input matrix)", handler = IntOptionHandler.class)
    protected int inputs;

    @Option(name = "-o", usage = "number hidden units for the network", handler = IntOptionHandler.class)
    protected int outputs;

    @Option(name = "-ad", usage = "address of master worker")
    protected String address;

    @Option(name = "-data", usage = "dataset to train on: options: mnist, iris, lfw")
    protected String dataSet;

    /** Runner created by {@link #exec()}; null until exec has run. */
    protected ActorNetworkRunner runner;

    /** Training data; only created on the master, null on workers. */
    protected DataSetIterator iter;

    /** Number of batches handed to the runner by {@link #train()} so far (used for logging). */
    private int batchesTrained;

    @Option(name = "-pte", usage = "number of epochs for pretraining (default: 1)", handler = IntOptionHandler.class)
    protected int pretrainEpochs = 1;

    @Option(name = "-r", usage = "seed value for the random number generator (default: 123)", handler = IntOptionHandler.class)
    protected long rngSeed = 123;

    @Option(name = "-k", usage = "the k for rbms (default: 1)", handler = IntOptionHandler.class)
    protected int k = 1;

    @Option(name = "-c", usage = "corruption level (for denoising autoencoders) (default: 0.3)", handler = DoubleOptionHandler.class)
    protected double corruptionLevel = 0.3d;

    @Option(name = "-h", usage = "the host to connect to as a master (default: localhost)")
    protected String host = "localhost";

    @Option(name = "-ptl", usage = "the starter pretrain learning rate (default: 0.1)", handler = DoubleOptionHandler.class)
    protected double pretrainLearningRate = 0.1d;

    @Option(name = "-t", usage = "type of worker")
    protected String type = "master";

    @Option(name = "-sp", usage = "number of inputs to split by (default: 10)")
    protected int split = 10;

    @Option(name = "-e", usage = "number of examples to train on: if unspecified will just train on everything found")
    protected int numExamples = -1;

    @Option(name = "-l2", usage = "l2 regularization constant")
    protected double l2 = 0.1d;

    @Option(name = "-m", usage = "momentum")
    protected double momentum = 0.1d;

    /**
     * Parses the command-line arguments into this instance's {@code @Option} fields.
     *
     * @param strArr raw command-line arguments
     * @throws IllegalArgumentException if the arguments cannot be parsed; usage is printed to
     *     stderr first, so callers never receive a half-configured instance
     */
    public ActorNetworkRunnerApp(String[] strArr) {
        CmdLineParser parser = new CmdLineParser(this);
        try {
            parser.parseArgument(strArr);
        } catch (CmdLineException e) {
            parser.printUsage(System.err);
            log.error("Unable to parse args", e);
            throw new IllegalArgumentException("Unable to parse args", e);
        }
    }

    /**
     * Starts the runner as a worker (when {@code -t worker}) or as a master (any other type).
     *
     * @throws Exception if ZooKeeper access, runner setup, or loading the algorithm class fails
     */
    public void exec() throws Exception {
        if ("worker".equals(this.type)) {
            startWorker();
        } else {
            startMaster();
        }
    }

    /** Reads the conf published by the master from ZooKeeper and starts a worker runner with it. */
    private void startWorker() throws Exception {
        log.info("Initializing conf from zookeeper at " + this.host);
        ZookeeperConfigurationRetriever retriever =
                new ZookeeperConfigurationRetriever(this.host, ZOOKEEPER_PORT, "master");
        try {
            Conf conf = retriever.retreive();
            this.runner = new ActorNetworkRunner(this.type, conf.getMasterUrl());
            this.runner.setup(conf);
        } finally {
            // Release the ZooKeeper connection even if setup fails.
            retriever.close();
        }
    }

    /** Builds the conf, starts the master runner and publishes the conf to ZooKeeper. */
    private void startMaster() throws Exception {
        // Resolve the algorithm before loading data so a typo fails fast instead of after a download.
        String algorithmClass = getClassForAlgorithm();
        if (algorithmClass == null) {
            throw new IllegalArgumentException(
                    "Unknown algorithm (-a): " + this.algorithm + "; expected one of: da, rbm, crbm");
        }
        getDataSet();

        Conf conf = new Conf();
        conf.setMultiLayerClazz(Class.forName(algorithmClass));
        // Values < 1 mean "not given on the command line": derive them from the data set.
        conf.setnIn(this.inputs < 1 ? this.iter.inputColumns() : this.inputs);
        conf.setnOut(this.outputs < 1 ? this.iter.totalOutcomes() : this.outputs);
        conf.setPretrainEpochs(this.pretrainEpochs);
        conf.setSeed(this.rngSeed);
        conf.setPretrainLearningRate(this.pretrainLearningRate);
        conf.setL2(this.l2);
        conf.setMomentum(this.momentum);
        conf.setCorruptionLevel(this.corruptionLevel);
        conf.setSplit(this.split);
        conf.setK(this.k);

        this.runner = new ActorNetworkRunner("master", this.iter);
        this.runner.setup(conf);
        // The master URL is only known once the runner is set up, so publish the conf afterwards.
        conf.setMasterUrl(this.runner.getMasterAddress().toString());

        ZooKeeperConfigurationRegister register =
                new ZooKeeperConfigurationRegister(conf, "master", this.host, ZOOKEEPER_PORT);
        try {
            register.register();
        } finally {
            register.close();
        }
    }

    /**
     * Hands the next available batch from the data set iterator to the runner. Does nothing when
     * the iterator is exhausted. Call repeatedly while {@link #isDone()} is false to consume all data.
     *
     * @throws IllegalStateException if no iterator exists (exec() not run, or this is a worker)
     */
    public void train() {
        if (this.iter == null) {
            throw new IllegalStateException("No data set iterator available; run exec() as a master first");
        }
        if (this.iter.hasNext()) {
            this.batchesTrained++;
            log.info("Training next batch " + this.batchesTrained);
            this.runner.train((Pair<DoubleMatrix, DoubleMatrix>) this.iter.next());
        }
    }

    public void shutdown() {
    }

    /** Returns the raw {@code -data} option value. */
    public String getData() {
        return this.dataSet;
    }

    /**
     * Creates {@link #iter} for the data set named by {@code -data}. Workers skip this because they
     * receive their data from the master.
     *
     * @throws IllegalArgumentException if {@code -data} is missing or not one of mnist, iris, lfw
     * @throws RuntimeException wrapping any failure while constructing the iterator
     */
    protected void getDataSet() {
        if ("worker".equals(this.type)) {
            return;
        }
        DataSetIterator created;
        try {
            created = createIterator(this.dataSet);
        } catch (Exception e) {
            throw new RuntimeException("Unable to load data set " + this.dataSet, e);
        }
        if (created == null) {
            throw new IllegalArgumentException(
                    "Unknown data set (-data): " + this.dataSet + "; expected one of: mnist, iris, lfw");
        }
        this.iter = created;
    }

    /** Maps a data set name to a new iterator using {@link #split} and {@link #numExamples}; null if unknown. */
    private DataSetIterator createIterator(String name) throws Exception {
        if (name == null) {
            return null;
        }
        switch (name) {
            case "mnist":
                return new MnistDataSetIterator(this.split, this.numExamples);
            case "iris":
                return new IrisDataSetIterator(this.split, this.numExamples);
            case "lfw":
                return new LFWDataSetIterator(this.split, this.numExamples);
            default:
                return null;
        }
    }

    /**
     * Maps the {@code -a} option to the fully-qualified class name of the network to train.
     *
     * @return the class name for {@code da}, {@code rbm} or {@code crbm}; null for a missing or
     *     unrecognized algorithm
     */
    protected String getClassForAlgorithm() {
        if (this.algorithm == null) {
            return null;
        }
        switch (this.algorithm) {
            case "da":
                return "org.deeplearning4j.sda.DenoisingAutoEncoder";
            case "rbm":
                return "org.deeplearning4j.rbm.RBM";
            case "crbm":
                return "org.deeplearning4j.rbm.CRBM";
            default:
                return null;
        }
    }

    /**
     * Returns true when there is nothing left to train: either no iterator exists (e.g. on a
     * worker) or the iterator has no further batches.
     */
    public boolean isDone() {
        return this.iter == null || !this.iter.hasNext();
    }

    /** Parses arguments, starts the runner and, on a master, trains one batch. */
    public static void main(String[] strArr) throws Exception {
        ActorNetworkRunnerApp app = new ActorNetworkRunnerApp(strArr);
        app.exec();
        if ("master".equals(app.type)) {
            app.train();
        }
    }

    public void setup(Conf conf) {
    }

    // Accessors for the parsed command-line options.

    public String getAlgorithm() {
        return this.algorithm;
    }

    public int getInputs() {
        return this.inputs;
    }

    public int getOutputs() {
        return this.outputs;
    }

    public int getPretrainEpochs() {
        return this.pretrainEpochs;
    }

    public long getRngSeed() {
        return this.rngSeed;
    }

    public int getK() {
        return this.k;
    }

    public double getCorruptionLevel() {
        return this.corruptionLevel;
    }

    public String getHost() {
        return this.host;
    }

    public double getPretrainLearningRate() {
        return this.pretrainLearningRate;
    }

    public String getType() {
        return this.type;
    }

    public String getAddress() {
        return this.address;
    }

    public int getSplit() {
        return this.split;
    }

    public int getNumExamples() {
        return this.numExamples;
    }
}
