package org.deeplearning4j.nlp.uima.corpora.treeparser;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.uima.analysis_engine.AnalysisEngine;
import org.apache.uima.analysis_engine.AnalysisEngineDescription;
import org.apache.uima.cas.CAS;
import org.apache.uima.fit.factory.AnalysisEngineFactory;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.util.CasPool;
import org.cleartk.opennlp.tools.ParserAnnotator;
import org.cleartk.opennlp.tools.parser.DefaultOutputTypesHelper;
import org.cleartk.syntax.constituent.type.TopTreebankNode;
import org.cleartk.syntax.constituent.type.TreebankNode;
import org.cleartk.token.type.Sentence;
import org.cleartk.token.type.Token;
import org.cleartk.util.ParamUtil;
import org.deeplearning4j.nlp.uima.annotator.PoStagger;
import org.deeplearning4j.nlp.uima.annotator.SentenceAnnotator;
import org.deeplearning4j.nlp.uima.annotator.StemmerAnnotator;
import org.deeplearning4j.nlp.uima.annotator.TokenizerAnnotator;
import org.deeplearning4j.nlp.uima.tokenization.tokenizerfactory.UimaTokenizerFactory;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.deeplearning4j.text.movingwindow.ContextLabelRetriever;
import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor;
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.common.collection.MultiDimensionalMap;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.SetUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nlp/uima/corpora/treeparser/TreeParser.class */
/**
 * Turns raw text into constituency parse trees by running it through two UIMA analysis engines:
 * a tokenizer pipeline (sentence splitting, tokenization, POS tagging, stemming) and a parser
 * pipeline (ClearTK's OpenNLP {@link ParserAnnotator}).
 *
 * <p>The text is first segmented into sentences in a "document" CAS; each sentence is then
 * re-tokenized and parsed on its own in a second "sentence" CAS. The pooled methods therefore hold
 * two CAS instances from the pool at the same time.
 */
public class TreeParser {
    private static final Logger log = LoggerFactory.getLogger(TreeParser.class);

    /** Every pooled parse call holds a document CAS and a sentence CAS simultaneously. */
    private static final int MIN_POOL_SIZE = 2;

    private final AnalysisEngine parser;
    private final AnalysisEngine tokenizer;
    private final CasPool pool;
    private final TokenizerFactory tf;

    /**
     * Creates a parser from pre-built components.
     *
     * @param analysisEngine  engine that produces the treebank parse
     * @param analysisEngine2 engine that produces sentence and token annotations
     * @param casPool         pool the working CAS instances are checked out from; it must be able
     *                        to hand out at least two instances at once
     */
    public TreeParser(AnalysisEngine analysisEngine, AnalysisEngine analysisEngine2, CasPool casPool) {
        this.parser = analysisEngine;
        this.tokenizer = analysisEngine2;
        this.pool = casPool;
        this.tf = new UimaTokenizerFactory(analysisEngine2, true);
    }

    /**
     * Creates a parser with the default engines from {@link #getParser()} and
     * {@link #getTokenizer()} and a CAS pool sized to the number of available processors, but
     * never smaller than two (a single-entry pool could not serve the document CAS and the
     * sentence CAS at the same time).
     *
     * @throws Exception if an engine or the pool cannot be created
     */
    public TreeParser() throws Exception {
        this.parser = getParser();
        this.tokenizer = getTokenizer();
        int poolSize = Math.max(MIN_POOL_SIZE, Runtime.getRuntime().availableProcessors());
        this.pool = new CasPool(poolSize, this.parser);
        this.tf = new UimaTokenizerFactory(this.tokenizer, true);
    }

    /**
     * Parses every sentence of the given text and appends a copy of each leaf beneath itself
     * (see {@link #addPreTerminal(Tree)}).
     *
     * @param str                  text to parse; an empty string yields an empty list
     * @param sentencePreProcessor optional pre-processor applied to the whole text, may be null
     * @return one tree per detected sentence
     * @throws Exception if tokenization, parsing or tree construction fails
     */
    public List<Tree> getTrees(String str, SentencePreProcessor sentencePreProcessor) throws Exception {
        if (str.isEmpty()) {
            return new ArrayList<>();
        }
        String text = sentencePreProcessor != null ? sentencePreProcessor.preProcess(str) : str;
        List<Tree> trees = new ArrayList<>();
        CAS documentCas = null;
        CAS sentenceCas = null;
        try {
            documentCas = checkOutCas();
            sentenceCas = checkOutCas();
            documentCas.setDocumentText(text);
            this.tokenizer.process(documentCas);
            for (Sentence sentence : JCasUtil.select(documentCas.getJCas(), Sentence.class)) {
                try {
                    Pair<String, MultiDimensionalMap<Integer, Integer, String>> labelled =
                            ContextLabelRetriever.stringWithLabels(sentence.getCoveredText(), this.tf);
                    tokenizeAndParse(sentenceCas, labelled.getFirst());
                    // The parse lives in the sentence CAS; the document CAS was only tokenized.
                    TopTreebankNode top = JCasUtil.selectSingle(sentenceCas.getJCas(), TopTreebankNode.class);
                    trees.add(TreeFactory.buildTree(top));
                } finally {
                    // The sentence CAS is reused, so it has to be cleared before the next sentence.
                    sentenceCas.reset();
                }
            }
        } finally {
            release(sentenceCas);
            release(documentCas);
        }
        for (Tree tree : trees) {
            addPreTerminal(tree);
        }
        return trees;
    }

    /**
     * Recursively walks the tree and, for every leaf, adds a copy of that leaf (labelled with the
     * leaf's value) as its single child.
     *
     * @param tree root of the (sub)tree to modify in place
     */
    private void addPreTerminal(Tree tree) {
        if (tree.isLeaf()) {
            Tree copy = new Tree(tree);
            copy.setLabel(tree.value());
            tree.children().add(copy);
            copy.setParent(tree);
            return;
        }
        for (Tree child : tree.children()) {
            addPreTerminal(child);
        }
    }

    /**
     * Parses every sentence of the given text and returns the raw treebank nodes.
     *
     * <p>Each sentence is parsed in a fresh CAS obtained from the tokenizer engine rather than a
     * pooled one; presumably because the returned nodes remain backed by that CAS and must not be
     * reset or reused.
     *
     * @param str text to parse; an empty string yields an empty list
     * @return the top treebank node of each detected sentence
     * @throws Exception if tokenization or parsing fails
     */
    public List<TreebankNode> getTreebankTrees(String str) throws Exception {
        if (str.isEmpty()) {
            return new ArrayList<>();
        }
        List<TreebankNode> nodes = new ArrayList<>();
        CAS documentCas = checkOutCas();
        try {
            documentCas.setDocumentText(str);
            this.tokenizer.process(documentCas);
            for (Sentence sentence : JCasUtil.select(documentCas.getJCas(), Sentence.class)) {
                CAS sentenceCas = this.tokenizer.newCAS();
                tokenizeAndParse(sentenceCas, sentence.getCoveredText());
                nodes.add(JCasUtil.selectSingle(sentenceCas.getJCas(), TopTreebankNode.class));
            }
        } finally {
            this.pool.releaseCas(documentCas);
        }
        return nodes;
    }

    /**
     * Wraps the text in {@code <label> ... </label>} markup and builds one labelled tree per
     * sentence. Sentences that cannot be parsed are logged and skipped.
     *
     * @param str  text to parse; an empty string yields an empty list
     * @param str2 label the whole text is wrapped in
     * @param list allowed labels; they are lower-cased before being handed to the tree factory
     * @return the trees of all sentences that could be parsed
     * @throws Exception if the text as a whole cannot be tokenized
     */
    public List<Tree> getTreesWithLabels(String str, String str2, List<String> list) throws Exception {
        if (str.isEmpty()) {
            return new ArrayList<>();
        }
        List<String> lowerCaseLabels = lowerCase(list);
        List<Tree> trees = new ArrayList<>();
        CAS documentCas = null;
        CAS sentenceCas = null;
        try {
            documentCas = checkOutCas();
            sentenceCas = checkOutCas();
            documentCas.setDocumentText("<" + str2 + "> " + str + " </" + str2 + ">");
            this.tokenizer.process(documentCas);
            for (Sentence sentence : JCasUtil.select(documentCas.getJCas(), Sentence.class)) {
                String sentenceText = sentence.getCoveredText();
                if (sentenceText.isEmpty()) {
                    continue;
                }
                try {
                    Pair<String, MultiDimensionalMap<Integer, Integer, String>> labelled =
                            ContextLabelRetriever.stringWithLabels(sentenceText, this.tf);
                    tokenizeAndParse(sentenceCas, labelled.getFirst());
                    List<TopTreebankNode> topNodes =
                            new ArrayList<>(JCasUtil.select(sentenceCas.getJCas(), TopTreebankNode.class));
                    if (topNodes.isEmpty()) {
                        // Nothing to build a tree from: skip instead of failing on get(0).
                        log.warn("Unable to parse " + sentenceText);
                        continue;
                    }
                    if (topNodes.size() > 1) {
                        log.warn("More than one top level node for a treebank parse. Only accepting first input node.");
                    }
                    trees.add(TreeFactory.buildTree(topNodes.get(0), labelled, lowerCaseLabels));
                } catch (Exception e) {
                    // Best effort: one bad sentence must not abort the whole text.
                    log.warn("Unable to parse " + sentenceText, e);
                } finally {
                    sentenceCas.reset();
                }
            }
        } finally {
            release(sentenceCas);
            release(documentCas);
        }
        return trees;
    }

    /**
     * Builds one labelled tree per sentence of already-annotated text. A sentence is skipped when
     * it yields no parse or when it carries a label that is not in the allowed list.
     *
     * @param str  text to parse; an empty string yields an empty list
     * @param list allowed labels; they are lower-cased before comparison
     * @return the trees of all valid sentences
     * @throws Exception if tokenization, parsing or tree construction fails
     */
    public List<Tree> getTreesWithLabels(String str, List<String> list) throws Exception {
        if (str.isEmpty()) {
            return new ArrayList<>();
        }
        List<String> lowerCaseLabels = lowerCase(list);
        List<Tree> trees = new ArrayList<>();
        CAS documentCas = null;
        CAS sentenceCas = null;
        try {
            documentCas = checkOutCas();
            sentenceCas = checkOutCas();
            documentCas.setDocumentText(str);
            this.tokenizer.process(documentCas);
            for (Sentence sentence : JCasUtil.select(documentCas.getJCas(), Sentence.class)) {
                try {
                    Pair<String, MultiDimensionalMap<Integer, Integer, String>> labelled =
                            ContextLabelRetriever.stringWithLabels(sentence.getCoveredText(), this.tf);
                    tokenizeAndParse(sentenceCas, labelled.getFirst());
                    List<TopTreebankNode> topNodes =
                            new ArrayList<>(JCasUtil.select(sentenceCas.getJCas(), TopTreebankNode.class));
                    if (topNodes.isEmpty()) {
                        // Nothing to build a tree from: skip instead of failing on get(0).
                        log.warn("Unable to parse " + sentence.getCoveredText());
                        continue;
                    }
                    if (topNodes.size() > 1) {
                        log.warn("More than one top level node for a treebank parse. Only accepting first input node.");
                    }
                    boolean onlyKnownLabels =
                            SetUtils.difference(labelled.getSecond().values(), lowerCaseLabels).isEmpty();
                    if (onlyKnownLabels) {
                        trees.add(TreeFactory.buildTree(topNodes.get(0), labelled, lowerCaseLabels));
                    } else {
                        log.warn("Found invalid sentence. Skipping");
                    }
                } finally {
                    sentenceCas.reset();
                }
            }
        } finally {
            release(sentenceCas);
            release(documentCas);
        }
        return trees;
    }

    /**
     * Parses every sentence of the given text, logging the treebank parse and every node at info
     * level along the way.
     *
     * @param str text to parse; an empty string yields an empty list
     * @return one tree per detected sentence
     * @throws Exception if tokenization, parsing or tree construction fails
     */
    public List<Tree> getTrees(String str) throws Exception {
        if (str.isEmpty()) {
            return new ArrayList<>();
        }
        List<Tree> trees = new ArrayList<>();
        CAS documentCas = null;
        CAS sentenceCas = null;
        try {
            documentCas = checkOutCas();
            sentenceCas = checkOutCas();
            documentCas.setDocumentText(str);
            this.tokenizer.process(documentCas);
            for (Sentence sentence : JCasUtil.select(documentCas.getJCas(), Sentence.class)) {
                try {
                    tokenizeAndParse(sentenceCas, sentence.getCoveredText());
                    TopTreebankNode top = JCasUtil.selectSingle(sentenceCas.getJCas(), TopTreebankNode.class);
                    log.info("Tree bank parse " + top.getTreebankParse());
                    for (TreebankNode node : JCasUtil.select(sentenceCas.getJCas(), TreebankNode.class)) {
                        log.info("Node val " + node.getNodeValue() + " and label " + node.getNodeType()
                                + " and tags was " + node.getNodeTags());
                    }
                    trees.add(TreeFactory.buildTree(top));
                } finally {
                    sentenceCas.reset();
                }
            }
        } finally {
            release(sentenceCas);
            release(documentCas);
        }
        return trees;
    }

    /**
     * Builds the default tokenizer engine: sentence annotator, tokenizer, English POS tagger and
     * English stemmer, aggregated in that order.
     *
     * @return a new analysis engine
     * @throws Exception if the engine cannot be created
     */
    public static AnalysisEngine getTokenizer() throws Exception {
        AnalysisEngineDescription[] stages = new AnalysisEngineDescription[]{
                SentenceAnnotator.getDescription(),
                TokenizerAnnotator.getDescription(),
                PoStagger.getDescription("en"),
                StemmerAnnotator.getDescription("English")};
        return AnalysisEngineFactory.createEngine(
                AnalysisEngineFactory.createEngineDescription(stages), new Object[0]);
    }

    /**
     * Builds the default parser engine around ClearTK's {@link ParserAnnotator}. The annotator is
     * configured to use the POS tags already present in the CAS; the model path comes from the
     * {@code parserModelPath} parameter and falls back to {@code /models/en-parser-chunking.bin}.
     *
     * @return a new analysis engine
     * @throws Exception if the engine cannot be created
     */
    public static AnalysisEngine getParser() throws Exception {
        Object[] parserConfig = new Object[]{
                "useTagsFromCas", true,
                "parserModelPath", ParamUtil.getParameterValue("parserModelPath", "/models/en-parser-chunking.bin"),
                "outputTypesHelperClassName", DefaultOutputTypesHelper.class.getName()};
        AnalysisEngineDescription[] stages = new AnalysisEngineDescription[]{
                AnalysisEngineFactory.createEngineDescription(ParserAnnotator.class, parserConfig)};
        return AnalysisEngineFactory.createEngine(
                AnalysisEngineFactory.createEngineDescription(stages), new Object[0]);
    }

    /**
     * Checks a CAS out of the pool, failing with a descriptive error instead of a later
     * NullPointerException when the pool has no free instance.
     */
    private CAS checkOutCas() {
        CAS cas = this.pool.getCas();
        if (cas == null) {
            throw new IllegalStateException(
                    "No CAS available in the pool; each parse needs two CAS instances at the same time");
        }
        return cas;
    }

    /** Returns a CAS to the pool; a null argument (checkout never happened) is ignored. */
    private void release(CAS cas) {
        if (cas != null) {
            this.pool.releaseCas(cas);
        }
    }

    /** Loads the text into the given (empty) CAS and runs the tokenizer followed by the parser. */
    private void tokenizeAndParse(CAS cas, String text) throws Exception {
        cas.setDocumentText(text);
        this.tokenizer.process(cas);
        this.parser.process(cas);
    }

    /** Returns a new list holding the lower-cased form of every given label. */
    private static List<String> lowerCase(List<String> labels) {
        List<String> lowered = new ArrayList<>(labels.size());
        for (String label : labels) {
            lowered.add(label.toLowerCase());
        }
        return lowered;
    }
}
