package de.datexis.cdv.train;

import de.datexis.annotator.Annotator;
import de.datexis.annotator.AnnotatorFactory;
import de.datexis.cdv.CDVAnnotator;
import de.datexis.cdv.index.AspectIndex;
import de.datexis.cdv.index.EntityIndex;
import de.datexis.cdv.loss.LossHuber;
import de.datexis.common.CommandLineParser;
import de.datexis.common.Configuration;
import de.datexis.common.ObjectSerializer;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.Encoder;
import de.datexis.encoder.EncoderAnnotator;
import de.datexis.encoder.IEncoder;
import de.datexis.encoder.impl.BagOfWordsEncoder;
import de.datexis.encoder.impl.FastTextEncoder;
import de.datexis.encoder.impl.StructureEncoder;
import de.datexis.encoder.impl.Word2VecEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.parvec.encoder.ParVecWordsEncoder;
import de.datexis.preprocess.IdentityPreprocessor;
import de.datexis.retrieval.encoder.LSTMSentenceAnnotator;
import de.datexis.sector.encoder.ParVecSentenceEncoder;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Iterator;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/cdv/train/TrainCDVAnnotator.class */
public class TrainCDVAnnotator {
    /** Class-wide SLF4J logger. */
    protected static final Logger log = LoggerFactory.getLogger(TrainCDVAnnotator.class);
    /** Dataset language; set in trainCompleteCDVModel (DE when the dataset name starts with "de_", otherwise EN). */
    protected WordHelpers.Language lang;
    /** Sentence input encoder; assigned by loadSentenceEmbedding. */
    protected Encoder sentenceEmb;
    /** Structure encoder used as second input encoder; created by loadSentenceEmbedding. */
    protected StructureEncoder positionalEmb;
    /** Encoder backing the entity index; assigned by loadEntityEmbedding. */
    protected IEncoder entityEmb;
    /** Encoder backing the aspect index; assigned by loadAspectEmbedding. */
    protected IEncoder aspectEmb;
    /** Entity knowledge base; stays null unless loadEntityEmbedding has run. */
    protected EntityIndex entityIndex = null;
    /** Aspect knowledge base; stays null unless loadAspectEmbedding has run. */
    protected AspectIndex aspectIndex = null;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:de/datexis/cdv/train/TrainCDVAnnotator$TrainingParams.class */
    /**
     * Command line parameters of the CDV training run.
     *
     * <p>The field initializers act as defaults, but {@link #setParams(CommandLine)} overwrites
     * most of them: the boolean switches ({@code entityModel}, {@code aspectModel},
     * {@code balancing}, {@code trainingUI}) are true only if the corresponding option is present
     * on the command line.
     */
    public static class TrainingParams implements CommandLineParser.Options {
        protected String inputPath = null;
        protected String datasetName = "wd_disease";
        protected String inputEmbedding = null;
        protected String entityEmbedding = null;
        protected String aspectEmbedding = null;
        protected String searchPath = "models/common";
        protected String outputPath = "models";
        protected String modelName = null;
        protected boolean trainingUI = true;
        protected boolean entityModel = true;
        protected boolean aspectModel = true;
        protected boolean balancing = false;
        protected int epochs = 50;

        protected TrainingParams() {
        }

        /**
         * Copies the parsed command line values into this parameter object.
         *
         * @param commandLine the parsed command line, built from {@link #setUpCliOptions()}
         */
        public void setParams(CommandLine commandLine) {
            this.inputPath = commandLine.getOptionValue("i");
            this.datasetName = commandLine.getOptionValue("d");
            this.modelName = commandLine.getOptionValue("m");
            this.inputEmbedding = commandLine.getOptionValue("w");
            // "-e" / "-a" double as switches: presence enables the model, the value is the embedding path.
            this.entityModel = commandLine.hasOption("e");
            this.entityEmbedding = commandLine.getOptionValue("e");
            this.aspectModel = commandLine.hasOption("a");
            this.aspectEmbedding = commandLine.getOptionValue("a");
            // Keep the initializer default when "-s" is omitted instead of replacing it with null;
            // the search path is later passed to Resource.fromDirectory unconditionally.
            this.searchPath = commandLine.getOptionValue("s", this.searchPath);
            this.balancing = commandLine.hasOption("b");
            this.trainingUI = commandLine.hasOption("u");
            this.outputPath = commandLine.getOptionValue("o", Configuration.getProperty("de.datexis.path.results"));
        }

        /**
         * Declares the command line options understood by the trainer.
         *
         * @return the options; {@code -i}, {@code -d} and {@code -m} are required
         */
        public Options setUpCliOptions() {
            Options options = new Options();
            options.addRequiredOption("i", "dataset", true, "path to the WikiSection training dataset");
            options.addRequiredOption("d", "datasetname", true, "name of the data set, e.g. en_disease");
            options.addRequiredOption("m", "modelname", true, "model name");
            options.addOption("w", "wordemb", true, "path to a pretrained embedding");
            options.addOption("e", "entity", true, "path to the entity embedding");
            options.addOption("a", "aspect", true, "path to the aspect embedding");
            options.addOption("o", "output", true, "path to create the output folder in");
            options.addOption("s", "search", true, "search path for pre-trained word embeddings");
            options.addOption("b", "balancing", false, "use class balancing during training");
            options.addOption("u", "ui", false, "enable training UI");
            return options;
        }
    }

    /**
     * Command line entry point: parses the arguments, trains the model and terminates the JVM.
     *
     * <p>Exits with status 1 after printing usage help when the arguments cannot be parsed, and
     * with status 1 after printing the stack trace on any other failure.
     *
     * @param strArr the raw command line arguments
     */
    public static void main(String[] strArr) throws IOException, ParseException {
        TrainingParams trainingParams = new TrainingParams();
        try {
            new CommandLineParser(trainingParams).parse(strArr);
            new TrainCDVAnnotator().trainCompleteCDVModel(trainingParams);
            System.exit(0);
        } catch (ParseException e) {
            // Must precede the generic handler: ParseException is itself an Exception, so a
            // later catch clause for it would be unreachable and the usage help never shown.
            new HelpFormatter().printHelp("train-cdv", "TeXoo: train contextualized discourse vectors (CDV)", trainingParams.setUpCliOptions(), "", true);
            System.exit(1);
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Reads the training dataset, loads the configured embeddings, trains a CDV annotator and
     * writes it to a time-stamped folder below the output path.
     *
     * <p>Once training has been attempted this method does not return: it terminates the JVM with
     * status 0 if the model was trained and saved, and with status 1 (after printing the stack
     * trace) otherwise.
     *
     * @param trainingParams the parsed training parameters
     * @throws IllegalArgumentException if neither the entity nor the aspect model is enabled
     * @throws IOException if the dataset or one of the embeddings cannot be read
     */
    public void trainCompleteCDVModel(TrainingParams trainingParams) throws IOException {
        // Fail fast, before any dataset or embedding is loaded.
        if (!trainingParams.entityModel && !trainingParams.aspectModel) {
            throw new IllegalArgumentException("No entity or aspect index given.");
        }
        Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);
        Resource inputDir = Resource.fromDirectory(trainingParams.inputPath);
        Resource outputDir = Resource.fromDirectory(trainingParams.outputPath);
        this.lang = trainingParams.datasetName.startsWith("de_") ? WordHelpers.Language.DE : WordHelpers.Language.EN;
        Dataset trainingSet = readDatasetFromJSON(inputDir);
        log.info("read {} articles", trainingSet.countDocuments());
        loadSentenceEmbedding(trainingParams, trainingSet);
        if (trainingParams.aspectModel) {
            loadAspectEmbedding(trainingParams, trainingSet);
        }
        if (trainingParams.entityModel) {
            loadEntityEmbedding(trainingParams, trainingSet);
        }
        CDVAnnotator annotator = buildAnnotator(trainingParams);
        boolean success = false;
        try {
            this.positionalEmb.setCachingEnabled(true);
            this.sentenceEmb.setCachingEnabled(true);
            annotator.trainModel(trainingSet);
            // NOTE(review): "m1getTagger" looks like a decompiler-renamed accessor (probably
            // getTagger()) - confirm against the CDVAnnotator sources before renaming.
            String folderName = new SimpleDateFormat("yyMMdd_HHmm_").format(new Date()) + annotator.m1getTagger().getName();
            saveModel(annotator, outputDir.resolve(folderName));
            success = true;
        } catch (Throwable th) {
            th.printStackTrace();
        } finally {
            // The exit status reflects whether training and saving both completed.
            System.exit(success ? 0 : 1);
        }
    }

    /**
     * Configures the annotator variant matching the enabled models: "CDV-EA" for entity and
     * aspect, "CDV-E" for entity only, "CDV-A" for aspect only.
     */
    private CDVAnnotator buildAnnotator(TrainingParams trainingParams) {
        if (trainingParams.entityModel && trainingParams.aspectModel) {
            return new CDVAnnotator.Builder()
                    .withId("CDV-EA")
                    .withClassBalancing(trainingParams.balancing)
                    .withInputEncoders(trainingParams.modelName, this.sentenceEmb, this.positionalEmb)
                    .withEntityAspectEncoders(this.entityIndex, this.aspectIndex)
                    .withLossFunction(new LossHuber(), Activation.TANH)
                    .withModelParams(768, 1024, 512)
                    .withTrainingParams(0.001d, 0.0d, 1.0E-4d, 16, trainingParams.epochs)
                    .withDatasetLimit(-1, 396, 96)
                    .withDataset(trainingParams.datasetName, this.lang)
                    .enableTrainingUI(trainingParams.trainingUI)
                    .build();
        }
        if (trainingParams.entityModel) {
            return new CDVAnnotator.Builder()
                    .withId("CDV-E")
                    .withClassBalancing(trainingParams.balancing)
                    .withInputEncoders(trainingParams.modelName, this.sentenceEmb, this.positionalEmb)
                    .withEntityEncoder(this.entityIndex)
                    .withLossFunction(new LossHuber(), Activation.TANH)
                    .withModelParams(128, 512, 128)
                    .withTrainingParams(5.0E-4d, 0.0d, 1.0E-4d, 16, trainingParams.epochs)
                    .withDatasetLimit(-1, 396, 96)
                    .withDataset(trainingParams.datasetName, this.lang)
                    .enableTrainingUI(trainingParams.trainingUI)
                    .build();
        }
        if (!trainingParams.aspectModel) {
            throw new IllegalArgumentException("No entity or aspect index given.");
        }
        return new CDVAnnotator.Builder()
                .withId("CDV-A")
                .withClassBalancing(trainingParams.balancing)
                .withInputEncoders(trainingParams.modelName, this.sentenceEmb, this.positionalEmb)
                .withAspectEncoder(this.aspectIndex)
                .withLossFunction(new LossHuber(), Activation.TANH)
                .withModelParams(128, 512, 128)
                .withTrainingParams(5.0E-4d, 0.0d, 1.0E-4d, 16, trainingParams.epochs)
                .withDatasetLimit(-1, 396, 96)
                .withDataset(trainingParams.datasetName, this.lang)
                .enableTrainingUI(trainingParams.trainingUI)
                .build();
    }

    /**
     * Selects and loads the sentence input encoder and creates the structure encoder.
     *
     * <p>Without an input embedding a bag-of-words encoder is trained on the dataset. Otherwise
     * the encoder is chosen from the embedding name: "pv.zip" suffix, ".bin"/".bin.gz" suffix,
     * or a Word2Vec model for everything else. The names ELMo, BERT-base, BioBERT and BERT-large
     * are rejected.
     *
     * @param trainingParams supplies the input embedding path or name
     * @param dataset documents used to train the bag-of-words fallback
     * @throws UnsupportedOperationException for the REST-only embedding names
     * @throws IOException if the embedding cannot be loaded
     */
    private void loadSentenceEmbedding(TrainingParams trainingParams, Dataset dataset) throws IOException {
        String embeddingName = trainingParams.inputEmbedding;
        Resource embeddingModel = null;
        if (embeddingName != null) {
            embeddingModel = Resource.fromFile(embeddingName);
        }

        if (embeddingModel == null) {
            // No pre-trained embedding available: fall back to bag-of-words trained on the data.
            this.sentenceEmb = new BagOfWordsEncoder();
            this.sentenceEmb.trainModel(dataset.getDocuments());
        } else if (embeddingName.endsWith("pv.zip")) {
            this.sentenceEmb = loadParVecEmbedding(embeddingModel);
        } else if (embeddingName.endsWith(".bin") || embeddingName.endsWith(".bin.gz")) {
            this.sentenceEmb = loadFastTextEmbedding(embeddingModel);
        } else {
            boolean restOnlyEmbedding = embeddingName.equalsIgnoreCase("ELMo")
                    || embeddingName.equalsIgnoreCase("BERT-base")
                    || embeddingName.equalsIgnoreCase("BioBERT")
                    || embeddingName.equalsIgnoreCase("BERT-large");
            if (restOnlyEmbedding) {
                throw new UnsupportedOperationException("REST embedding is not possible outside Beuth infrastructure");
            }
            this.sentenceEmb = loadWord2VecEmbedding(embeddingModel);
        }
        this.positionalEmb = new StructureEncoder();
    }

    /**
     * Loads the aspect encoder and the aspect knowledge base "aspect.index.bin" that is resolved
     * against the aspect embedding path.
     *
     * <p>A directory is loaded as a stored annotator; a ".bin"/".bin.gz" file as FastText; any
     * other file as Word2Vec.
     *
     * @param trainingParams supplies the aspect embedding path and the search path
     * @param dataset not used by this method
     * @throws IllegalArgumentException if the knowledge base file does not exist
     * @throws IOException if the embedding or the knowledge base cannot be read
     */
    private void loadAspectEmbedding(TrainingParams trainingParams, Dataset dataset) throws IOException {
        Resource embeddingModel = Resource.fromFile(trainingParams.aspectEmbedding);
        Resource searchPath = Resource.fromDirectory(trainingParams.searchPath);
        Resource knowledgeBase = embeddingModel.resolve("aspect.index.bin");
        // Check for the knowledge base first so a missing file is reported before any model is loaded.
        if (!knowledgeBase.exists()) {
            throw new IllegalArgumentException("Aspect encoder needs to provide a knowledge base called aspect.index.bin");
        }
        if (embeddingModel.isDirectory()) {
            // Hold the result as the Annotator base type so both instanceof branches are well-typed.
            Annotator annotator = AnnotatorFactory.loadAnnotator(embeddingModel, new Resource[]{searchPath});
            if (annotator instanceof EncoderAnnotator) {
                EncoderAnnotator encoderAnnotator = (EncoderAnnotator) annotator;
                this.aspectEmb = encoderAnnotator.getEncoder();
                if (encoderAnnotator.getEncoder() instanceof FastTextEncoder) {
                    ((FastTextEncoder) encoderAnnotator.getEncoder()).setModelAsReference();
                }
            } else if (annotator instanceof LSTMSentenceAnnotator) {
                LSTMSentenceAnnotator sentenceAnnotator = (LSTMSentenceAnnotator) annotator;
                if (sentenceAnnotator.getTagger().getInputEncoder() instanceof FastTextEncoder) {
                    ((FastTextEncoder) sentenceAnnotator.getTagger().getInputEncoder()).setModelAsReference();
                }
                this.aspectEmb = sentenceAnnotator.asEncoder();
            } else {
                log.warn("annotator loaded from {} is neither an EncoderAnnotator nor an LSTMSentenceAnnotator; aspect encoder was not changed", embeddingModel);
            }
        } else if (trainingParams.aspectEmbedding.endsWith(".bin") || trainingParams.aspectEmbedding.endsWith(".bin.gz")) {
            this.aspectEmb = loadFastTextEmbedding(embeddingModel);
        } else {
            this.aspectEmb = loadWord2VecEmbedding(embeddingModel);
        }
        this.aspectIndex = new AspectIndex(this.aspectEmb);
        this.aspectIndex.loadModel(knowledgeBase);
    }

    /**
     * Loads the entity encoder and the entity knowledge base "entity.index.bin" that is resolved
     * against the entity embedding path.
     *
     * <p>A directory is loaded as a stored annotator; a ".bin"/".bin.gz" file as FastText; any
     * other file as Word2Vec.
     *
     * @param trainingParams supplies the entity embedding path and the search path
     * @param dataset not used by this method
     * @throws IllegalArgumentException if the knowledge base file does not exist
     * @throws IOException if the embedding or the knowledge base cannot be read
     */
    private void loadEntityEmbedding(TrainingParams trainingParams, Dataset dataset) throws IOException {
        Resource embeddingModel = Resource.fromFile(trainingParams.entityEmbedding);
        Resource searchPath = Resource.fromDirectory(trainingParams.searchPath);
        Resource knowledgeBase = embeddingModel.resolve("entity.index.bin");
        // Check for the knowledge base first so a missing file is reported before any model is loaded.
        if (!knowledgeBase.exists()) {
            throw new IllegalArgumentException("Entity encoder needs to provide a knowledge base called entity.index.bin");
        }
        if (embeddingModel.isDirectory()) {
            // Hold the result as the Annotator base type so both instanceof branches are well-typed.
            Annotator annotator = AnnotatorFactory.loadAnnotator(embeddingModel, new Resource[]{searchPath});
            if (annotator instanceof EncoderAnnotator) {
                EncoderAnnotator encoderAnnotator = (EncoderAnnotator) annotator;
                this.entityEmb = encoderAnnotator.getEncoder();
                if (encoderAnnotator.getEncoder() instanceof FastTextEncoder) {
                    ((FastTextEncoder) encoderAnnotator.getEncoder()).setModelAsReference();
                }
            } else if (annotator instanceof LSTMSentenceAnnotator) {
                // Reuse the annotator loaded above instead of loading it from disk a second time.
                LSTMSentenceAnnotator sentenceAnnotator = (LSTMSentenceAnnotator) annotator;
                if (sentenceAnnotator.getTagger().getInputEncoder() instanceof FastTextEncoder) {
                    ((FastTextEncoder) sentenceAnnotator.getTagger().getInputEncoder()).setModelAsReference();
                }
                this.entityEmb = sentenceAnnotator.asEncoder();
            } else {
                log.warn("annotator loaded from {} is neither an EncoderAnnotator nor an LSTMSentenceAnnotator; entity encoder was not changed", embeddingModel);
            }
        } else if (trainingParams.entityEmbedding.endsWith(".bin") || trainingParams.entityEmbedding.endsWith(".bin.gz")) {
            this.entityEmb = loadFastTextEmbedding(embeddingModel);
        } else {
            this.entityEmb = loadWord2VecEmbedding(embeddingModel);
        }
        this.entityIndex = new EntityIndex(this.entityEmb);
        this.entityIndex.loadModel(knowledgeBase);
    }

    /**
     * Loads a paragraph-vector model as a word-level encoder.
     *
     * @param resource location of the stored model
     * @return the encoder with the model loaded
     * @throws IOException if the model cannot be read
     */
    private ParVecWordsEncoder loadParVecEmbedding(Resource resource) throws IOException {
        // NOTE(review): this sentence encoder is loaded and then discarded. Kept as-is in case
        // loadModel has side effects that the words encoder relies on - confirm before removing.
        ParVecSentenceEncoder sentenceEncoder = new ParVecSentenceEncoder();
        sentenceEncoder.loadModel(resource);

        ParVecWordsEncoder wordsEncoder = new ParVecWordsEncoder();
        wordsEncoder.loadModel(resource);
        return wordsEncoder;
    }

    /**
     * Creates a Word2Vec encoder for the given model, loaded as reference and configured with an
     * identity preprocessor.
     *
     * @param resource location of the Word2Vec model
     * @return the configured encoder
     * @throws IOException if the model cannot be read
     */
    private Word2VecEncoder loadWord2VecEmbedding(Resource resource) throws IOException {
        final Word2VecEncoder encoder = new Word2VecEncoder();
        encoder.loadModelAsReference(resource);
        final IdentityPreprocessor preprocessor = new IdentityPreprocessor();
        encoder.setPreprocessor(preprocessor);
        return encoder;
    }

    /**
     * Creates a FastText encoder for the given model, loaded as reference.
     *
     * @param resource location of the FastText model
     * @return the encoder
     * @throws IOException if the model cannot be read
     */
    private FastTextEncoder loadFastTextEmbedding(Resource resource) throws IOException {
        final FastTextEncoder encoder = new FastTextEncoder();
        encoder.loadModelAsReference(resource);
        return encoder;
    }

    /**
     * Writes the model, the training log and the test log of the annotator into the given folder,
     * creating the folder if necessary.
     *
     * @param annotator the trained annotator to persist
     * @param resource the target folder
     * @throws IOException if the folder cannot be created or writing fails
     */
    private void saveModel(Annotator annotator, Resource resource) throws IOException {
        // mkdirs() also returns false when the folder already exists, so only fail when the
        // folder is really missing afterwards.
        if (!resource.toFile().mkdirs() && !resource.toFile().isDirectory()) {
            throw new IOException("could not create model directory " + resource.toString());
        }
        annotator.writeModel(resource);
        annotator.writeTrainLog(resource);
        annotator.writeTestLog(resource);
        log.info("model written to {}", resource.toString());
    }

    /**
     * Reads all documents from a JSON resource into a dataset named after the file (without the
     * ".json" part). Every annotation of every document is marked as GOLD with confidence 1.0.
     *
     * @param resource the JSON resource to read
     * @return the dataset containing all documents that were read
     * @throws IOException if the resource cannot be read
     */
    public static Dataset readDatasetFromJSON(Resource resource) throws IOException {
        log.info("Reading Wiki Articles from {}", resource.toString());
        String datasetName = resource.getFileName().replace(".json", "");
        Dataset result = new Dataset(datasetName);
        for (Iterator documents = ObjectSerializer.readJSONDocumentIterable(resource); documents.hasNext(); ) {
            Document current = (Document) documents.next();
            for (Annotation gold : current.getAnnotations()) {
                gold.setSource(Annotation.Source.GOLD);
                gold.setConfidence(1.0d);
            }
            result.addDocument(current);
        }
        return result;
    }
}
