package org.deeplearning4j.iterativereduce.actor.deepautoencoder;

import java.io.File;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.iterativereduce.tracker.statetracker.hazelcast.deepautoencoder.DeepAutoEncoderHazelCastStateTracker;
import org.deeplearning4j.scaleout.conf.Conf;
import org.deeplearning4j.scaleout.core.conf.DeepLearningConfigurableDistributed;
import org.deeplearning4j.scaleout.zookeeper.ZookeeperConfigurationRetriever;
import org.deeplearning4j.util.SerializationUtils;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.kohsuke.args4j.spi.DoubleOptionHandler;
import org.kohsuke.args4j.spi.IntOptionHandler;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/iterativereduce/actor/deepautoencoder/DistributedDeepLearningTrainerApp.class */
/**
 * Command-line entry point for distributed deep autoencoder training.
 *
 * <p>Options are parsed with args4j in the constructor. {@link #exec()} then runs in one of two modes:
 * <ul>
 *   <li>{@code -t worker}: fetches a {@link Conf} from ZooKeeper at {@code host:2181}. It creates a
 *       {@link DeepAutoEncoderHazelCastStateTracker} from that configuration's connection string and
 *       sets up a worker {@link DeepAutoEncoderDistributedTrainer} pointed at the master URL.</li>
 *   <li>anything else (default {@code master}): builds a {@link Conf} from the command-line options
 *       and loads the training data. It then sets up a master {@link DeepAutoEncoderDistributedTrainer}
 *       over that data.</li>
 * </ul>
 * {@link #main(String[])} additionally calls {@link #train()} when the type is {@code master}.
 */
public class DistributedDeepLearningTrainerApp implements DeepLearningConfigurableDistributed {
    protected static final Logger log = LoggerFactory.getLogger(DistributedDeepLearningTrainerApp.class);

    /** Algorithm key; see {@link #getClassForAlgorithm()} for the accepted values. */
    @Option(name = "-a", usage = "algorithm to use: sda (stacked denoising autoencoders),dbn (deep belief networks),cdbn (continuous deep belief networks)")
    protected String algorithm;

    /** Parsed but not read by {@link #exec()}, which takes nIn from the data iterator. */
    @Option(name = "-i", usage = "number of inputs (columns in the input matrix)", handler = IntOptionHandler.class)
    protected int inputs;

    /** Parsed but not read by {@link #exec()}, which takes nOut from the data iterator. */
    @Option(name = "-o", usage = "number of outputs for the network", handler = IntOptionHandler.class)
    protected int outputs;

    /** Raw comma-separated hidden layer sizes; parsed into {@link #hiddenLayerSizes} by the constructor. */
    @Option(name = "-hl", usage = "hidden layer sizes (comma separated list)")
    protected String hiddenLayerSizesOption;

    /** Hidden layer sizes; defaults to three layers of 300 when {@code -hl} is absent. */
    protected int[] hiddenLayerSizes = {300, 300, 300};

    @Option(name = "-ad", usage = "address of master worker")
    protected String address;

    /** Fully-qualified {@link DataSetIterator} class name (mutually exclusive with {@code -datasetpath}). */
    @Option(name = "-data", usage = "class to instantiate")
    protected String dataSet;

    /** Path of a serialized {@link DataSet} (mutually exclusive with {@code -data}). */
    @Option(name = "-datasetpath", usage = "dataset path; eg if you save a dataset you can just point the network runner at a path rather than worrying about a class to instantiate")
    protected String dataPath;

    protected DeepAutoEncoderDistributedTrainer runner;

    /** Training data iterator; only populated in master mode. */
    protected DataSetIterator iter;

    @Option(name = "-fte", usage = "number of fine tune epochs to iterate on (default: 100)", handler = IntOptionHandler.class)
    protected int finetuneEpochs = 100;

    @Option(name = "-pte", usage = "number of epochs for pretraining (default: 100)", handler = IntOptionHandler.class)
    protected int pretrainEpochs = 100;

    // No explicit handler: args4j chooses one from the field type (long). IntOptionHandler would cap
    // the seed at the int range.
    @Option(name = "-r", usage = "seed value for the random number generator (default: 123)")
    protected long rngSeed = 123;

    @Option(name = "-k", usage = "the k for rbms (default: 1)", handler = IntOptionHandler.class)
    protected int k = 1;

    // The float options below deliberately have no explicit handler. DoubleOptionHandler yields a
    // Double, and args4j assigns it with reflective Field.set. Field.set cannot narrow a Double into a
    // float field, so passing the option failed at runtime.
    @Option(name = "-c", usage = "corruption level (for denoising autoencoders) (default: 0.3)")
    protected float corruptionLevel = 0.3f;

    @Option(name = "-h", usage = "the host to connect to as a master (default: localhost)")
    protected String host = "localhost";

    @Option(name = "-ftl", usage = "the starter fine tune learning rate (default: 0.1)")
    protected float finetuneLearningRate = 0.1f;

    @Option(name = "-ptl", usage = "the starter pretrain learning rate (default: 0.1)")
    protected float pretrainLearningRate = 0.1f;

    /** {@code "worker"} selects worker mode; anything else runs as master. */
    @Option(name = "-t", usage = "type of worker")
    protected String type = "master";

    @Option(name = "-sp", usage = "number of inputs to split by default: 10")
    protected int split = 10;

    @Option(name = "-e", usage = "number of examples to iterate on: if unspecified will just iterate on everything found")
    protected int numExamples = -1;

    // NOTE(review): a boolean args4j option is normally a flag that can only switch this on, so the
    // default of true may not be overridable from the command line — confirm with the args4j version used.
    @Option(name = "-adg", usage = "use adagrad; default true")
    protected boolean useAdaGrad = true;

    @Option(name = "-stp", usage = "state tracker port")
    protected int stateTrackerPort = -1;

    /**
     * Parses the command line into this instance's fields.
     *
     * <p>A parse failure is reported (usage to stderr plus an error log) but is not fatal. Fields keep
     * whatever was parsed before the failure, plus their defaults.
     *
     * @param strArr command-line arguments
     * @throws NumberFormatException if {@code -hl} contains a non-integer entry
     */
    public DistributedDeepLearningTrainerApp(String[] strArr) {
        CmdLineParser parser = new CmdLineParser(this);
        try {
            parser.parseArgument(strArr);
        } catch (CmdLineException e) {
            parser.printUsage(System.err);
            log.error("Unable to parse args", e);
        }
        if (this.hiddenLayerSizesOption != null) {
            String[] parts = this.hiddenLayerSizesOption.split(",");
            this.hiddenLayerSizes = new int[parts.length];
            for (int i = 0; i < parts.length; i++) {
                // trim so "300, 300" is accepted as well as "300,300"
                this.hiddenLayerSizes[i] = Integer.parseInt(parts[i].trim());
            }
        }
    }

    /**
     * Initialises {@link #runner} as either a worker or the master, depending on {@link #type}.
     *
     * @throws IllegalArgumentException in master mode, if {@code -a} is missing or unrecognised
     * @throws Exception propagated from configuration retrieval, class loading or trainer setup
     */
    public void exec() throws Exception {
        if ("worker".equals(this.type)) {
            initWorker();
        } else {
            initMaster();
        }
    }

    /** Worker mode: pulls the master's Conf from ZooKeeper and wires a state tracker into the runner. */
    private void initWorker() throws Exception {
        log.info("Initializing conf from zookeeper at " + this.host);
        ZookeeperConfigurationRetriever retriever = new ZookeeperConfigurationRetriever(this.host, 2181, "master");
        try {
            Conf retrieved = retriever.retrieve();
            String masterUrl = retrieved.getMasterUrl();
            log.info("Creating hazel cast state tracker... " + retrieved.getStateTrackerConnectionString());
            DeepAutoEncoderHazelCastStateTracker tracker =
                    new DeepAutoEncoderHazelCastStateTracker(retrieved.getStateTrackerConnectionString());
            log.info("Creating hazel cast via worker " + tracker.connectionString());
            this.runner = new DeepAutoEncoderDistributedTrainer(this.type, masterUrl);
            this.runner.setStateTracker(tracker);
            this.runner.setup(retrieved);
        } finally {
            // close even when setup fails so the ZooKeeper connection is not leaked
            retriever.close();
        }
    }

    /** Master mode: builds a Conf from the command-line options and sets up the runner over {@link #iter}. */
    private void initMaster() throws Exception {
        String algorithmClass = getClassForAlgorithm();
        if (algorithmClass == null) {
            throw new IllegalArgumentException(
                    "Unknown or missing algorithm (-a): " + this.algorithm + "; expected one of sda, dbn, cdbn");
        }
        Conf conf = new Conf();
        getDataSet();
        conf.setMultiLayerClazz(Class.forName(algorithmClass));
        conf.setLayerSizes(this.hiddenLayerSizes);
        conf.setSplit(this.split);
        conf.getConf().setnIn(this.iter.inputColumns());
        conf.getConf().setnOut(this.iter.totalOutcomes());
        conf.getConf().setPretrainEpochs(this.pretrainEpochs);
        conf.getConf().setFinetuneEpochs(this.finetuneEpochs);
        conf.getConf().setSeed(this.rngSeed);
        conf.getConf().setPretrainLearningRate(this.pretrainLearningRate);
        conf.getConf().setFinetuneLearningRate(this.finetuneLearningRate);
        conf.getConf().setUseAdaGrad(this.useAdaGrad);
        conf.getConf().setCorruptionLevel(this.corruptionLevel);
        conf.getConf().setK(this.k);
        this.runner = new DeepAutoEncoderDistributedTrainer("master", this.iter);
        this.runner.setStateTrackerPort(this.stateTrackerPort);
        this.runner.setup(conf);
    }

    /**
     * Starts training on the runner created by {@link #exec()}.
     *
     * @throws IllegalStateException if {@link #exec()} has not been called
     */
    public void train() {
        if (this.runner == null) {
            throw new IllegalStateException("exec() must be called before train()");
        }
        this.runner.train();
    }

    /** No-op. */
    public void shutdown() {
    }

    /** @return the {@code -data} iterator class name, or {@code null} if not given */
    public String getData() {
        return this.dataSet;
    }

    /**
     * Populates {@link #iter} from either {@code -datasetpath} or {@code -data}; does nothing for workers.
     *
     * <p>With a path, the file is deserialized as a {@link DataSet} and wrapped in a
     * {@link ListDataSetIterator} with batch size {@link #split}. With a class name, the class is
     * instantiated in one of two ways. If {@code numExamples < 0} or {@code split == 0}, its no-arg
     * constructor is used. Otherwise a {@code (split, numExamples)} constructor is used.
     *
     * @throws IllegalStateException if both or neither of {@code -data} and {@code -datasetpath} are set
     * @throws RuntimeException wrapping any failure while loading or instantiating the data
     */
    protected void getDataSet() {
        if ("worker".equals(this.type)) {
            return;
        }
        if (this.dataPath != null && this.dataSet != null) {
            throw new IllegalStateException("Can't have both a data set class and a dataset path defined");
        }
        if (this.dataPath == null && this.dataSet == null) {
            throw new IllegalStateException("Either a data set class (-data) or a dataset path (-datasetpath) must be defined");
        }
        try {
            if (this.dataPath != null) {
                // Java-native deserialization: only point this at trusted files.
                DataSet loaded = (DataSet) SerializationUtils.readObject(new File(this.dataPath));
                this.iter = new ListDataSetIterator(loaded.asList(), this.split);
            } else if (this.numExamples < 0 || this.split == 0) {
                this.iter = (DataSetIterator) Class.forName(this.dataSet).newInstance();
            } else {
                this.iter = newSizedIterator(Class.forName(this.dataSet));
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Instantiates {@code clazz} with {@code (split, numExamples)}. It tries the boxed
     * {@code (Integer, Integer)} signature first, as before, then falls back to {@code (int, int)}.
     */
    private DataSetIterator newSizedIterator(Class<?> clazz) throws Exception {
        try {
            return (DataSetIterator) clazz.getConstructor(Integer.class, Integer.class)
                    .newInstance(Integer.valueOf(this.split), Integer.valueOf(this.numExamples));
        } catch (NoSuchMethodException boxedMissing) {
            return (DataSetIterator) clazz.getConstructor(int.class, int.class)
                    .newInstance(this.split, this.numExamples);
        }
    }

    /**
     * Maps the {@code -a} key to a fully-qualified model class name.
     *
     * @return the class name for {@code sda}, {@code dbn} or {@code cdbn}; {@code null} for a missing
     *     or unrecognised key
     */
    protected String getClassForAlgorithm() {
        if (this.algorithm == null) {
            return null;
        }
        switch (this.algorithm) {
            case "sda":
                return "org.deeplearning4j.models.classifiers.sda.StackedDenoisingAutoEncoder";
            case "dbn":
                return "org.deeplearning4j.models.classifiers.dbn.DBN";
            case "cdbn":
                return "org.deeplearning4j.models.classifiers.dbn.CDBN";
            default:
                return null;
        }
    }

    /**
     * @return {@code true} when there is no more training data. That is the case when no iterator has
     *     been loaded (e.g. worker mode) or the iterator is exhausted.
     */
    public boolean isDone() {
        return this.iter == null || !this.iter.hasNext();
    }

    /** Parses arguments, runs {@link #exec()}, and trains when running as master. */
    public static void main(String[] strArr) throws Exception {
        DistributedDeepLearningTrainerApp app = new DistributedDeepLearningTrainerApp(strArr);
        app.exec();
        if ("master".equals(app.type)) {
            app.train();
        }
    }

    /** No-op. */
    public void setup(Conf conf) {
    }

    public String getAlgorithm() {
        return this.algorithm;
    }

    public int getInputs() {
        return this.inputs;
    }

    public int getOutputs() {
        return this.outputs;
    }

    public int getFinetuneEpochs() {
        return this.finetuneEpochs;
    }

    public int getPretrainEpochs() {
        return this.pretrainEpochs;
    }

    public long getRngSeed() {
        return this.rngSeed;
    }

    public int getK() {
        return this.k;
    }

    public float getCorruptionLevel() {
        return this.corruptionLevel;
    }

    public String getHost() {
        return this.host;
    }

    /** Fine-tune learning rate (method name spelling kept for compatibility with existing callers). */
    public float getFinetineLearningRate() {
        return this.finetuneLearningRate;
    }

    public float getPretrainLearningRate() {
        return this.pretrainLearningRate;
    }

    public String getHiddenLayerSizesOption() {
        return this.hiddenLayerSizesOption;
    }

    public int[] getHiddenLayerSizes() {
        return this.hiddenLayerSizes;
    }

    public String getType() {
        return this.type;
    }

    public String getAddress() {
        return this.address;
    }

    public int getSplit() {
        return this.split;
    }

    public int getNumExamples() {
        return this.numExamples;
    }
}
