package de.datexis.sector.exec;

import de.datexis.common.CommandLineParser;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.impl.BloomEncoder;
import de.datexis.encoder.impl.DummyEncoder;
import de.datexis.encoder.impl.FastTextEncoder;
import de.datexis.encoder.impl.StructureEncoder;
import de.datexis.encoder.impl.Word2VecEncoder;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.sector.SectorAnnotator;
import de.datexis.sector.encoder.ClassEncoder;
import de.datexis.sector.encoder.HeadingEncoder;
import de.datexis.sector.model.SectionAnnotation;
import de.datexis.sector.reader.WikiSectionReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.deeplearning4j.ui.api.UIServer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMultiLabel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/sector/exec/TrainSectorAnnotator.class */
/**
 * Command-line entry point that trains a {@link SectorAnnotator} on a WikiSection dataset.
 *
 * <p>The trained model and its training log are written to a sub-directory of the output path
 * named after the model. If a test dataset is given, the model is evaluated on it afterwards.
 *
 * <p>Note: {@link #runTraining(ExecParams)} always terminates the JVM via {@link System#exit(int)}
 * when it finishes (status 0 on success, 1 on failure).
 */
public class TrainSectorAnnotator {
    protected static final Logger log = LoggerFactory.getLogger(TrainSectorAnnotator.class);

    /** Command-line parameters for {@link TrainSectorAnnotator}. */
    public static class ExecParams implements CommandLineParser.Options {
        /** Training dataset location (option {@code -i}, required). */
        protected String trainFile;
        /** Optional validation dataset location (option {@code -v}); enables early stopping. */
        protected String devFile = null;
        /** Optional test dataset location (option {@code -t}); enables evaluation after training. */
        protected String testFile = null;
        /** Directory in which the model directory is created (option {@code -o}, required). */
        protected String outputPath = null;
        /** Optional word embedding model (option {@code -e}); Bloom filters are used if absent. */
        protected String embeddingsFile = null;
        /** Language code (option {@code -l}); defaults to {@code "en"}. */
        protected String language = null;
        /** Whether to start the training UI (option {@code -u}). */
        protected boolean trainingUI = false;
        /** Whether to evaluate full BEMD segmentation instead of sentence classification (option {@code -s}). */
        protected boolean testSegmentation = false;
        /** Whether to train the multi-label headings model SEC>H instead of SEC>T (option {@code -h}). */
        protected boolean isHeadingsModel = false;

        protected ExecParams() {
        }

        /**
         * Copies the parsed option values into this parameter object.
         *
         * @param commandLine the parsed command line
         */
        public void setParams(CommandLine commandLine) {
            this.trainFile = commandLine.getOptionValue("i");
            this.devFile = commandLine.getOptionValue("v");
            this.testFile = commandLine.getOptionValue("t");
            this.outputPath = commandLine.getOptionValue("o");
            this.embeddingsFile = commandLine.getOptionValue("e");
            this.language = commandLine.getOptionValue("l", "en");
            this.trainingUI = commandLine.hasOption("u");
            this.testSegmentation = commandLine.hasOption("s");
            this.isHeadingsModel = commandLine.hasOption("h");
        }

        /**
         * Declares the supported command-line options.
         *
         * @return the option set; {@code -i} and {@code -o} are required
         */
        public Options setUpCliOptions() {
            Options options = new Options();
            options.addRequiredOption("i", "input", true, "file name of WikiSection training dataset");
            options.addRequiredOption("o", "output", true, "path to create and store the model");
            options.addOption("h", "headings", false, "train multi-label model (SEC>H), otherwise single-label model (SEC>T) is used");
            options.addOption("v", "validation", true, "file name of WikiSection validation dataset (will use early stopping if given)");
            options.addOption("t", "test", true, "file name of WikiSection test dataset (will test after training if given)");
            options.addOption("s", "segment", false, "evaluate full segmentation model instead of faster sentence classification");
            options.addOption("e", "embedding", true, "path to word embedding model, will use bloom filters if not given");
            options.addOption("l", "language", true, "language to use for sentence splitting and stopwords (EN or DE)");
            options.addOption("u", "ui", false, "enable training UI (http://127.0.0.1:9000)");
            return options;
        }
    }

    /**
     * Parses the command line and runs training.
     *
     * <p>On a parse error, the error and the usage help are printed and the JVM exits with status 1.
     *
     * @param args command-line arguments
     * @throws IOException declared for compatibility; all failures are handled here
     */
    public static void main(String[] args) throws IOException {
        ExecParams params = new ExecParams();
        try {
            new CommandLineParser(params).parse(args);
            new TrainSectorAnnotator().runTraining(params);
            System.exit(0);
        } catch (ParseException e) {
            // Tell the user what was wrong (e.g. a missing required option) before the usage text.
            System.err.println(e.getMessage());
            new HelpFormatter().printHelp("texoo-train-sector", "TeXoo: train SectorAnnotator from WikiSection dataset", params.setUpCliOptions(), "", true);
            System.exit(1);
        } catch (Throwable t) {
            log.error("Training failed", t);
            System.exit(1);
        }
    }

    /**
     * Reads the datasets, builds and trains the annotator, writes model and logs, and optionally tests it.
     *
     * <p>This method never returns normally. After training (successful or not), it stops the
     * training UI if enabled and terminates the JVM with status 0 on success or 1 on failure.
     *
     * @param execParams parsed command-line parameters
     * @throws IOException if reading a dataset or embedding model fails before training starts
     */
    protected void runTraining(ExecParams execParams) throws IOException {
        Resource trainPath = Resource.fromDirectory(execParams.trainFile);
        Resource devPath = execParams.devFile != null ? Resource.fromDirectory(execParams.devFile) : null;
        Resource testPath = execParams.testFile != null ? Resource.fromDirectory(execParams.testFile) : null;
        Resource outputPath = Resource.fromDirectory(execParams.outputPath);
        WordHelpers.Language language = WordHelpers.getLanguage(execParams.language);

        Dataset train = readDataset(trainPath);
        Dataset dev = readDataset(devPath);
        Dataset test = readDataset(testPath);

        SectorAnnotator.Builder builder = new SectorAnnotator.Builder();
        if (execParams.embeddingsFile == null) {
            initializeInputEncodings_bloom(builder, train, language);
        } else {
            initializeInputEncodings_wemb(builder, Resource.fromFile(execParams.embeddingsFile));
        }
        if (execParams.isHeadingsModel) {
            initializeHeadingsTarget(builder, train, language);
        } else {
            initializeClassLabelsTarget(builder, train);
        }
        SectorAnnotator sector = builder
                .withDataset(train.getName(), language)
                .withModelParams(0, 256, 128)
                .withTrainingParams(0.01d, 0.5d, 2048, 396, 16, 10)
                .enableTrainingUI(execParams.trainingUI)
                .build();

        boolean success = false;
        try {
            if (dev == null) {
                sector.trainModel(train);
            } else {
                sector.trainModelEarlyStopping(train, dev, 10, 10, 100);
            }
            Resource modelPath = outputPath.resolve(sector.m4getTagger().getName());
            modelPath.toFile().mkdirs();
            sector.writeModel(modelPath);
            sector.writeTrainLog(modelPath);
            if (test != null) {
                if (execParams.testSegmentation) {
                    log.info("Testing full BEMD segmentation model (might take longer)");
                    sector.annotate(test.getDocuments(), SectorAnnotator.SegmentationMethod.BEMD);
                    sector.evaluateModel(test, false, true, true);
                } else {
                    log.info("Testing sentence classification (fast, but no segmentation)");
                    sector.annotate(test.getDocuments(), SectorAnnotator.SegmentationMethod.NONE);
                    sector.evaluateModel(test, true, false, false);
                }
            }
            sector.writeTestLog(modelPath);
            success = true;
        } catch (Throwable t) {
            // The finally block calls System.exit(), so this exception never reaches the caller.
            // Report it here, otherwise the failure cause would be lost.
            log.error("Training failed", t);
            throw t;
        } finally {
            stopTrainingUiAndExit(execParams.trainingUI, success);
        }
    }

    /**
     * Reads a WikiSection dataset from JSON.
     *
     * @param resource dataset location, may be {@code null}
     * @return the dataset, or {@code null} if {@code resource} is {@code null}
     */
    private static Dataset readDataset(Resource resource) throws IOException {
        return resource == null ? null : WikiSectionReader.readDatasetFromJSON(resource);
    }

    /**
     * Stops the training UI server if it was enabled, then terminates the JVM.
     *
     * @param trainingUI whether the training UI was enabled
     * @param success    {@code true} to exit with status 0, {@code false} to exit with status 1
     */
    private static void stopTrainingUiAndExit(boolean trainingUI, boolean success) {
        try {
            if (trainingUI) {
                UIServer.getInstance().stop();
            }
        } catch (Exception | NoClassDefFoundError e) {
            // Best effort: a failing UI shutdown must not change the exit status.
            log.warn("Could not stop training UI server", e);
        } finally {
            System.exit(success ? 0 : 1);
        }
    }

    /**
     * Configures Bloom-filter input encodings trained on the given dataset.
     *
     * @param builder  builder to configure
     * @param dataset  dataset whose documents are used to train the Bloom encoder
     * @param language language passed to the Bloom encoder training
     * @return the configured builder
     */
    protected SectorAnnotator.Builder initializeInputEncodings_bloom(SectorAnnotator.Builder builder, Dataset dataset, WordHelpers.Language language) {
        BloomEncoder bloomEncoder = new BloomEncoder(4096, 5);
        bloomEncoder.trainModel(dataset.getDocuments(), 5, language);
        return builder.withInputEncoders("bloom", bloomEncoder, new DummyEncoder(), new StructureEncoder());
    }

    /**
     * Configures word-embedding input encodings loaded from {@code resource}.
     *
     * <p>Files ending in {@code .bin} or {@code .bin.gz} are loaded with {@link FastTextEncoder};
     * all other files with {@link Word2VecEncoder}.
     *
     * @param builder  builder to configure
     * @param resource embedding model file
     * @return the configured builder
     * @throws IOException if the embedding model cannot be loaded
     */
    protected SectorAnnotator.Builder initializeInputEncodings_wemb(SectorAnnotator.Builder builder, Resource resource) throws IOException {
        StructureEncoder structureEncoder = new StructureEncoder();
        String fileName = resource.getFileName();
        if (fileName.endsWith(".bin") || fileName.endsWith(".bin.gz")) {
            FastTextEncoder fastTextEncoder = new FastTextEncoder();
            fastTextEncoder.loadModel(resource);
            return builder.withInputEncoders("ft", new DummyEncoder(), fastTextEncoder, structureEncoder);
        }
        Word2VecEncoder word2VecEncoder = new Word2VecEncoder();
        word2VecEncoder.loadModel(resource);
        return builder.withInputEncoders("w2v", new DummyEncoder(), word2VecEncoder, structureEncoder);
    }

    /**
     * Configures the single-label topic target (SEC>T): a {@link ClassEncoder} trained on all section
     * labels of the dataset, with MCXENT loss and softmax activation.
     *
     * @param builder builder to configure
     * @param dataset training dataset
     * @return the configured builder
     */
    protected SectorAnnotator.Builder initializeClassLabelsTarget(SectorAnnotator.Builder builder, Dataset dataset) {
        List<String> labels = collectSectionValues(dataset, SectionAnnotation::getSectionLabel);
        ClassEncoder classEncoder = new ClassEncoder();
        classEncoder.trainModel(labels, 0);
        return builder.withId("SEC>T").withTargetEncoder(classEncoder).withLossFunction((ILossFunction) new LossMCXENT(), Activation.SOFTMAX, false);
    }

    /**
     * Configures the multi-label headings target (SEC>H): a {@link HeadingEncoder} trained on all
     * section headings of the dataset, with multi-label loss and sigmoid activation.
     *
     * @param builder  builder to configure
     * @param dataset  training dataset
     * @param language language passed to the heading encoder training
     * @return the configured builder
     */
    protected SectorAnnotator.Builder initializeHeadingsTarget(SectorAnnotator.Builder builder, Dataset dataset, WordHelpers.Language language) {
        List<String> headings = collectSectionValues(dataset, SectionAnnotation::getSectionHeading);
        HeadingEncoder headingEncoder = new HeadingEncoder();
        headingEncoder.trainModel(headings, 20, 3, language);
        return builder.withId("SEC>H").withTargetEncoder(headingEncoder).withLossFunction((ILossFunction) new LossMultiLabel(), Activation.SIGMOID, false);
    }

    /**
     * Collects one value per {@link SectionAnnotation} across all documents of a dataset.
     *
     * @param dataset   dataset to scan
     * @param extractor maps each section annotation to the value to collect
     * @return the values in document and annotation order, duplicates included
     */
    private static List<String> collectSectionValues(Dataset dataset, Function<SectionAnnotation, String> extractor) {
        List<String> values = new ArrayList<>();
        for (Document document : dataset.getDocuments()) {
            for (SectionAnnotation section : document.getAnnotations(SectionAnnotation.class)) {
                values.add(extractor.apply(section));
            }
        }
        return values;
    }
}
