package org.deeplearning4j.apps;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.dbn.CDBN;
import org.deeplearning4j.nn.activation.HardTanh;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/* loaded from: input_file:org/deeplearning4j/apps/DataSetTrainer.class */
/**
 * Command-line entry point that loads a serialized {@link DataSet}, trains a {@link CDBN} on it
 * batch by batch, and writes the trained model to disk.
 *
 * <p>Usage: {@code DataSetTrainer <dataset-file> <batch-size> [output-file]}
 */
public class DataSetTrainer {
    /** Path the trained model is written to when no third argument is supplied. */
    private static final String DEFAULT_OUTPUT_PATH = "nn-model.bin";

    /** Command-line synopsis included in argument-validation errors. */
    private static final String USAGE =
            "Usage: DataSetTrainer <dataset-file> <batch-size> [output-file (default "
                    + DEFAULT_OUTPUT_PATH + ")]";

    /**
     * Trains a CDBN on every batch of a serialized data set and saves the resulting model.
     *
     * @param args {@code args[0]} path of the data set file; {@code args[1]} batch size
     *     (positive integer); optional {@code args[2]} model output path, defaulting to
     *     {@code nn-model.bin}
     * @throws FileNotFoundException if the data set file cannot be opened or the model file
     *     cannot be created
     * @throws IllegalArgumentException if required arguments are missing, the batch size is not
     *     a positive integer, or the data set yields no batches
     * @throws java.io.UncheckedIOException if closing either file fails (for the model file this
     *     can mean buffered bytes were not written)
     */
    public static void main(String[] args) throws FileNotFoundException {
        if (args.length < 2) {
            throw new IllegalArgumentException(USAGE);
        }
        String inputPath = args[0];
        int batchSize = Integer.parseInt(args[1]);
        if (batchSize <= 0) {
            throw new IllegalArgumentException(
                    "batch size must be positive, got " + batchSize + "; " + USAGE);
        }
        String outputPath = args.length > 2 ? args[2] : DEFAULT_OUTPUT_PATH;

        List<?> batches = loadBatches(inputPath, batchSize);
        if (batches.isEmpty()) {
            throw new IllegalArgumentException("Data set " + inputPath + " produced no batches");
        }
        // The first batch determines the network's input width and number of outputs.
        CDBN network = buildNetwork((DataSet) batches.get(0));
        for (Object batch : batches) {
            trainOnBatch(network, (DataSet) batch);
        }
        saveModel(network, outputPath);
    }

    /** Loads the data set stored at {@code path} and splits it into batches of {@code batchSize}. */
    private static List<?> loadBatches(String path, int batchSize) throws FileNotFoundException {
        DataSet data = DataSet.empty();
        // try-with-resources so the file handle is released even if load() throws.
        try (BufferedInputStream in =
                new BufferedInputStream(new FileInputStream(new File(path)))) {
            data.load(in);
        } catch (FileNotFoundException e) {
            // Preserve the checked exception main() declares.
            throw e;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to close data set file " + path, e);
        }
        return data.dataSetBatches(batchSize);
    }

    /**
     * Builds the network, sizing its input and output layers from {@code sample}. Hidden layers
     * are 1/2, 1/4 and 1/8 of the input width (integer division).
     */
    private static CDBN buildNetwork(DataSet sample) {
        int numInputs = ((DoubleMatrix) sample.getFirst()).columns;
        // NOTE(review): with fewer than 8 input columns the last hidden layer has size 0;
        // confirm whether CDBN tolerates that or a minimum width should be enforced.
        return new CDBN.Builder()
                .useRegularization(false)
                .numberOfInputs(numInputs)
                .hiddenLayerSizes(new int[] {numInputs / 2, numInputs / 4, numInputs / 8})
                .withActivation(new HardTanh())
                .numberOfOutPuts(sample.numOutcomes())
                .build();
    }

    /** Normalizes a batch's inputs by column sums, then pretrains and fine-tunes on the batch. */
    private static void trainOnBatch(CDBN network, DataSet batch) {
        DoubleMatrix input = MatrixUtil.normalizeByColumnSums((DoubleMatrix) batch.getFirst());
        DoubleMatrix labels = (DoubleMatrix) batch.getSecond();
        // NOTE(review): the literal arguments (presumably k / learning rate / epochs) are
        // inferred, not documented here -- confirm against CDBN.pretrain and CDBN.finetune.
        network.pretrain(input, 1, 0.01d, 1000);
        network.finetune(labels, 0.01d, 1000);
    }

    /**
     * Writes {@code network} to {@code path} and closes the file. Closing flushes the buffer;
     * the original code never did, so unless CDBN.write flushed the stream itself, the tail of
     * the model could be lost (BufferedOutputStream is not flushed automatically at JVM exit).
     */
    private static void saveModel(CDBN network, String path) throws FileNotFoundException {
        try (BufferedOutputStream out =
                new BufferedOutputStream(new FileOutputStream(new File(path)))) {
            network.write(out);
        } catch (FileNotFoundException e) {
            // Preserve the checked exception main() declares.
            throw e;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write model file " + path, e);
        }
    }
}
