package org.deeplearning4j.arbiter.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.RegressionValue;
import org.deeplearning4j.arbiter.scoring.ScoreFunctions;

/* loaded from: input_file:org/deeplearning4j/arbiter/server/ArbiterCliGenerator.class */
/**
 * Command-line tool that builds an Arbiter {@link OptimizationConfiguration} from a serialized
 * search space plus command-line options, and writes that configuration as JSON to
 * {@code --configSavePath}.
 *
 * <p>Arguments are parsed by JCommander into the {@code @Parameter}-annotated fields below.
 */
public class ArbiterCliGenerator {

    /**
     * Path to the JSON search space ({@link MultiLayerSpace} or {@link ComputationGraphSpace}).
     * Not marked required for JCommander, but {@link #runMain} cannot proceed without it.
     */
    @Parameter(names = {"--searchSpacePath"})
    private String searchSpacePath = null;

    /** Candidate generation strategy: {@code random} or {@code gridsearch}. */
    @Parameter(names = {"--candidateType"}, required = true)
    private String candidateType = null;

    /** Passed to {@link GridSearchCandidateGenerator} as its discretization count (grid search only). */
    @Parameter(names = {"--discretizationCount"})
    private int discretizationCount = 5;

    /** Grid search order: {@code random} or {@code sequence}; required when using grid search. */
    @Parameter(names = {"--gridSearchOrder"})
    private String gridSearchOrder = null;

    /** Network kind described by the search space: {@code multilayer} or {@code compgraph}. */
    @Parameter(names = {"--neuralNetType"}, required = true)
    private String neuralNetType = null;

    /** Class name stored under the data-factory key of the candidates' data parameters. */
    @Parameter(names = {"--dataSetIteratorClass"}, required = true)
    private String dataSetIteratorClass = null;

    /** Path handed to {@link FileModelSaver}. */
    @Parameter(names = {"--modelOutputPath"}, required = true)
    private String modelOutputPath = null;

    /** Score function name; valid values depend on {@code --problemType} and network type. */
    @Parameter(names = {"--score"}, required = true)
    private String score = null;

    /**
     * {@code classification} or {@code regression}. NOTE(review): marked required, so the
     * {@code "classification"} default is never used after a successful parse.
     */
    @Parameter(names = {"--problemType"}, required = true)
    private String problemType = "classification";

    /** File the generated JSON configuration is written to. */
    @Parameter(names = {"--configSavePath"}, required = true)
    private String configSavePath = null;

    @Parameter(names = {"--duration"}, description = "The number of minutes to run for. Default is -1 which means run till convergence.")
    private long duration = -1;

    @Parameter(names = {"--numCandidates"}, description = "The number of candidates to generate. Default is 1.")
    private int numCandidates = 1;

    /**
     * NOTE(review): identical to {@link #REGRESSION}, so it cannot be distinguished on the command
     * line; confirm whether a distinct value was intended.
     */
    public static final String REGRESSION_MULTI = "regression";
    public static final String REGRESSION = "regression";
    /** Misspelled legacy name kept for compatibility; use {@link #CLASSIFICATION}. */
    @Deprecated
    public static final String CLASSIFICIATION = "classification";
    /** Problem type value for classification. */
    public static final String CLASSIFICATION = "classification";
    public static final String RANDOM_CANDIDATE = "random";
    public static final String GRID_SEARCH_CANDIDATE = "gridsearch";
    public static final String SEQUENTIAL_ORDER = "sequence";
    public static final String RANDOM_ORDER = "random";
    public static final String COMP_GRAPH = "compgraph";
    public static final String MULTI_LAYER = "multilayer";
    public static final String ACCURACY = "accuracy";
    public static final String F1 = "f1";
    public static final String ACCURACY_MULTI = "accuracy_multi";
    public static final String F1_MULTI = "f1_multi";
    public static final String REGRESSION_SCORE = "regression_score";
    public static final String REGRESSION_SCORE_MULTI = "regression_score_multi";

    /**
     * Parses the arguments, builds the optimization configuration they describe and writes it as
     * UTF-8 JSON to {@code --configSavePath}.
     *
     * <p>If parsing fails, the error and usage text are printed and the JVM exits with status 1.
     *
     * @param strArr command-line arguments
     * @throws IllegalArgumentException if the neural net type, candidate type, grid search order or
     *     score is not supported, or {@code --searchSpacePath} is missing
     * @throws IllegalStateException if the problem type is neither classification nor regression
     * @throws Exception if the search space cannot be loaded or the configuration cannot be written
     */
    public void runMain(String... strArr) throws Exception {
        JCommander jCommander = new JCommander(this);
        try {
            jCommander.parse(strArr);
        } catch (ParameterException e) {
            System.err.println(e.getMessage());
            jCommander.usage();
            // Short pause before exiting, presumably so the usage output is flushed — confirm.
            try {
                Thread.sleep(500L);
            } catch (InterruptedException ignored) {
                Thread.currentThread().interrupt();
            }
            System.exit(1);
            return;
        }

        // Data parameters handed to each candidate. The key is presumably
        // DataSetIteratorFactoryProvider's factory key (inlined by the compiler) — confirm.
        Map<String, Object> dataParameters = new HashMap<>();
        dataParameters.put("org.deeplearning4j.arbiter.data.data.factory", this.dataSetIteratorClass);

        boolean gridSearch = useGridSearch();
        OptimizationConfiguration.Builder builder = new OptimizationConfiguration.Builder();
        if (MULTI_LAYER.equals(this.neuralNetType)) {
            MultiLayerSpace space = loadMultiLayer();
            if (gridSearch) {
                builder.candidateGenerator(new GridSearchCandidateGenerator(
                        space, this.discretizationCount, getMode(), dataParameters));
            } else {
                builder.candidateGenerator(new RandomSearchGenerator(space, dataParameters));
            }
            builder.scoreFunction(scoreFunctionMultiLayerNetwork());
        } else if (COMP_GRAPH.equals(this.neuralNetType)) {
            ComputationGraphSpace space = loadCompGraph();
            if (gridSearch) {
                builder.candidateGenerator(new GridSearchCandidateGenerator(
                        space, this.discretizationCount, getMode(), dataParameters));
            } else {
                builder.candidateGenerator(new RandomSearchGenerator(space, dataParameters));
            }
            builder.scoreFunction(scoreFunctionCompGraph());
        } else {
            throw new IllegalArgumentException("Illegal neural net type " + this.neuralNetType);
        }

        String json = builder
                .dataProvider(new DataSetIteratorFactoryProvider())
                .modelSaver(new FileModelSaver(this.modelOutputPath))
                .terminationConditions(getConditions())
                .build()
                .toJson();
        FileUtils.writeStringToFile(new File(this.configSavePath), json, StandardCharsets.UTF_8);
    }

    /**
     * Command-line entry point; delegates to {@link #runMain(String...)}.
     *
     * @param strArr command-line arguments
     * @throws Exception propagated from {@link #runMain(String...)}
     */
    public static void main(String... strArr) throws Exception {
        new ArbiterCliGenerator().runMain(strArr);
    }

    /** Returns true for grid search, false for random search; rejects any other candidate type. */
    private boolean useGridSearch() {
        if (GRID_SEARCH_CANDIDATE.equals(this.candidateType)) {
            return true;
        }
        if (RANDOM_CANDIDATE.equals(this.candidateType)) {
            return false;
        }
        throw new IllegalArgumentException("Illegal candidate type " + this.candidateType);
    }

    /**
     * Builds the termination conditions: a time limit when {@code --duration} is positive and a
     * candidate limit when {@code --numCandidates} is positive. If neither applies, a limit of one
     * candidate is used, so the list is never empty.
     */
    private List<TerminationCondition> getConditions() {
        List<TerminationCondition> conditions = new ArrayList<>();
        if (this.duration > 0) {
            conditions.add(new MaxTimeCondition(this.duration, TimeUnit.MINUTES));
        }
        if (this.numCandidates > 0) {
            conditions.add(new MaxCandidatesCondition(this.numCandidates));
        }
        if (conditions.isEmpty()) {
            conditions.add(new MaxCandidatesCondition(1));
        }
        return conditions;
    }

    /**
     * Maps {@code --gridSearchOrder} to a grid search mode.
     *
     * @throws IllegalArgumentException if the option is missing or not {@code random}/{@code sequence}
     */
    private GridSearchCandidateGenerator.Mode getMode() {
        if (RANDOM_ORDER.equals(this.gridSearchOrder)) {
            return GridSearchCandidateGenerator.Mode.RandomOrder;
        }
        if (SEQUENTIAL_ORDER.equals(this.gridSearchOrder)) {
            return GridSearchCandidateGenerator.Mode.Sequential;
        }
        throw new IllegalArgumentException("Illegal mode " + this.gridSearchOrder
                + " (--gridSearchOrder must be '" + RANDOM_ORDER + "' or '" + SEQUENTIAL_ORDER + "')");
    }

    /**
     * Selects the score function for a computation graph. This accepts the same names as the
     * multi-layer variant plus their {@code _multi} forms.
     *
     * @throws IllegalStateException if the problem type is unknown
     * @throws IllegalArgumentException if the score is not valid for the problem type
     */
    private ScoreFunction scoreFunctionCompGraph() {
        if (CLASSIFICATION.equals(this.problemType)) {
            switch (this.score) {
                case ACCURACY:
                case ACCURACY_MULTI:
                    return ScoreFunctions.testSetAccuracy();
                case F1:
                case F1_MULTI:
                    return ScoreFunctions.testSetF1();
                default:
                    throw invalidScore();
            }
        }
        if (REGRESSION.equals(this.problemType)) {
            switch (this.score) {
                case REGRESSION_SCORE:
                case REGRESSION_SCORE_MULTI:
                    // NOTE(review): valueOf receives the literal option text; unless RegressionValue
                    // declares constants with exactly these names it throws — confirm intent.
                    return ScoreFunctions.testSetRegression(RegressionValue.valueOf(this.score));
                default:
                    throw invalidScore();
            }
        }
        throw new IllegalStateException("Illegal problem type " + this.problemType);
    }

    /**
     * Selects the score function for a multi-layer network.
     *
     * @throws IllegalStateException if the problem type is unknown
     * @throws IllegalArgumentException if the score is not valid for the problem type
     */
    private ScoreFunction scoreFunctionMultiLayerNetwork() {
        if (CLASSIFICATION.equals(this.problemType)) {
            switch (this.score) {
                case ACCURACY:
                    return ScoreFunctions.testSetAccuracy();
                case F1:
                    return ScoreFunctions.testSetF1();
                default:
                    throw invalidScore();
            }
        }
        if (REGRESSION.equals(this.problemType)) {
            switch (this.score) {
                case REGRESSION_SCORE:
                    // NOTE(review): see scoreFunctionCompGraph — valueOf receives the literal option text.
                    return ScoreFunctions.testSetRegression(RegressionValue.valueOf(this.score));
                default:
                    throw invalidScore();
            }
        }
        throw new IllegalStateException("Illegal problem type " + this.problemType);
    }

    /** Builds the exception reported when {@code --score} does not fit {@code --problemType}. */
    private IllegalArgumentException invalidScore() {
        return new IllegalArgumentException("Score " + this.score + " not valid for type " + this.problemType);
    }

    /** Reads the search space file as UTF-8, failing clearly if no path was supplied. */
    private String readSearchSpaceJson() throws IOException {
        if (this.searchSpacePath == null) {
            throw new IllegalArgumentException("--searchSpacePath is required");
        }
        return FileUtils.readFileToString(new File(this.searchSpacePath), StandardCharsets.UTF_8);
    }

    /** Deserializes the search space file as a {@link ComputationGraphSpace}. */
    private ComputationGraphSpace loadCompGraph() throws Exception {
        return ComputationGraphSpace.fromJson(readSearchSpaceJson());
    }

    /** Deserializes the search space file as a {@link MultiLayerSpace}. */
    private MultiLayerSpace loadMultiLayer() throws Exception {
        return MultiLayerSpace.fromJson(readSearchSpaceJson());
    }
}
