package org.datavec.cli.transforms.text.nlp;

import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.regex.Pattern;
import org.apache.commons.math3.util.Pair;
import org.datavec.api.berkeley.Counter;
import org.datavec.api.conf.Configuration;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Writable;
import org.datavec.cli.transforms.Transform;
import org.datavec.nlp.metadata.DefaultVocabCache;
import org.datavec.nlp.metadata.VocabCache;
import org.datavec.nlp.stopwords.StopWords;
import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess;
import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
import org.datavec.nlp.tokenization.tokenizer.preprocessor.EndingPreProcessor;
import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/datavec/cli/transforms/text/nlp/TfidfTextVectorizerTransform.class */
/**
 * Transform that converts two-field {@code (text, label)} records into TF-IDF feature vectors.
 *
 * <p>{@link #collectStatistics(Collection)} feeds each record's text into a {@link VocabCache}
 * (term counts, document counts, number of documents). It also assigns each distinct label a
 * numeric id. {@link #transform(Collection)} then replaces a record's contents with one
 * {@link DoubleWritable} per vector column, followed by the record's label id. Statistics are
 * presumably collected over the whole dataset before any record is transformed.
 */
public class TfidfTextVectorizerTransform implements Transform {
    /** Tokenizer factory created from {@link #TOKENIZER} in {@link #initialize(Configuration)}. */
    protected TokenizerFactory tokenizerFactory;
    /** Configuration key for the minimum word frequency handed to the vocab cache (default 5). */
    public static final String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency";
    /** Configuration key for a custom stop-word collection. */
    public static final String STOP_WORDS = "org.nd4j.nlp.stopwords";
    /** Configuration key for the fully-qualified {@link TokenizerFactory} class to instantiate. */
    public static final String TOKENIZER = "org.datavec.nlp.tokenizerfactory";

    /** Removes every character that is not an ASCII letter or a space (non-URL tokens only). */
    private static final Pattern NON_LETTER_OR_SPACE = Pattern.compile("[^a-zA-Z ]");
    /** Matches a single digit; each match is replaced by the letter {@code d}. */
    private static final Pattern DIGIT = Pattern.compile("\\d");

    /** Stop words read from {@link #STOP_WORDS}, falling back to {@link StopWords#getStopWords()}. */
    // NOTE(review): stopWords is loaded but never consulted when building the vocabulary or
    // vectors; confirm whether stop-word filtering was intended.
    protected Collection<String> stopWords;
    /** Vocabulary and document statistics accumulated by {@link #collectStatistics(Collection)}. */
    protected VocabCache cache;
    /** Value passed to the {@link DefaultVocabCache} constructor. */
    protected int minWordFrequency = 0;
    /** Trimmed label -> (id assigned in first-seen order, number of records carrying the label). */
    public Map<String, Pair<Integer, Integer>> recordLabels = new LinkedHashMap<>();
    /** Applied to every token after the letter filter and before digit replacement. */
    final EndingPreProcessor preProcessor = new EndingPreProcessor();

    /** Returns the number of words currently in the vocabulary cache. */
    public int getVocabularySize() {
        return this.cache.vocabWords().size();
    }

    /** Prints every vocabulary word with its index to standard output (debugging aid). */
    public void debugPrintVocabList() {
        System.out.println("Vocabulary Words: ");
        int vocabSize = this.cache.vocabWords().size();
        for (int i = 0; i < vocabSize; i++) {
            System.out.println(i + ". " + this.cache.wordAt(i));
        }
    }

    /**
     * Consumes all tokens of one document and records them in the vocab cache.
     *
     * <p>Every occurrence increments the word's term count. The document count of a word is
     * incremented at most once per call, i.e. once per document that contains it.
     *
     * @param tokenizer tokenizer over a single document's text
     */
    public void doWithTokens(Tokenizer tokenizer) {
        Set<String> seenInDocument = new HashSet<>();
        while (tokenizer.hasMoreTokens()) {
            String token = tokenizer.nextToken();
            this.cache.incrementCount(token);
            // Set.add returns true only for the first occurrence of a token in this document.
            if (seenInDocument.add(token)) {
                this.cache.incrementDocCount(token);
            }
        }
    }

    /**
     * Instantiates the tokenizer factory named by {@link #TOKENIZER}.
     *
     * <p>If the key is unset, {@link DefaultTokenizerFactory} is used. The class must have an
     * accessible no-argument constructor.
     *
     * @param configuration configuration to read the class name from
     * @return a new tokenizer factory instance
     * @throws RuntimeException if the class cannot be loaded or instantiated
     */
    public TokenizerFactory createTokenizerFactory(Configuration configuration) {
        String className = configuration.get(TOKENIZER, DefaultTokenizerFactory.class.getName());
        try {
            return (TokenizerFactory) Class.forName(className).getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException("Unable to instantiate tokenizer factory " + className, e);
        }
    }

    /**
     * Sets up the tokenizer and its token pre-processor, the stop words, and an empty vocab cache.
     *
     * @param configuration source of {@link #TOKENIZER}, {@link #MIN_WORD_FREQUENCY} (default 5)
     *     and {@link #STOP_WORDS}
     */
    public void initialize(Configuration configuration) {
        this.tokenizerFactory = createTokenizerFactory(configuration);
        this.tokenizerFactory.setTokenPreProcessor(new TokenPreProcess() {
            @Override
            public String preProcess(String token) {
                String normalized = token;
                // Tokens starting with "http://" skip the letter filter and lower-casing.
                if (!normalized.startsWith("http://")) {
                    // Locale.ROOT keeps lower-casing locale-independent (e.g. no Turkish dotless i).
                    normalized = NON_LETTER_OR_SPACE.matcher(normalized).replaceAll("")
                            .toLowerCase(Locale.ROOT);
                }
                String stemmed = TfidfTextVectorizerTransform.this.preProcessor.preProcess(normalized);
                return DIGIT.matcher(stemmed).replaceAll("d");
            }
        });
        this.minWordFrequency = configuration.getInt(MIN_WORD_FREQUENCY, 5);
        this.stopWords = configuration.getStringCollection(STOP_WORDS);
        if (this.stopWords == null || this.stopWords.isEmpty()) {
            this.stopWords = StopWords.getStopWords();
        }
        this.cache = new DefaultVocabCache(this.minWordFrequency);
    }

    /**
     * Counts how often each token occurs in {@code str}.
     *
     * @param str text to tokenize with {@link #tokenizerFactory}
     * @return token -> occurrence count
     */
    protected Counter<String> wordFrequenciesForSentence(String str) {
        Tokenizer tokenizer = this.tokenizerFactory.create(str);
        Counter<String> counts = new Counter<>();
        while (tokenizer.hasMoreTokens()) {
            try {
                counts.incrementCount(tokenizer.nextToken(), 1.0d);
            } catch (NoSuchElementException e) {
                // Best effort: skip a token the tokenizer failed to produce.
                System.out.println("Bad Token");
            }
        }
        return counts;
    }

    /**
     * Builds the TF-IDF vector for one text record.
     *
     * <p>Entry {@code i} corresponds to vocabulary word {@code cache.wordAt(i)}. It combines the
     * word's count in {@code str} with {@code cache.numDocs()} and {@code cache.idf(word)} via
     * {@code NLPUtils}. The value from {@code cache.idf(word)} is presumably a document frequency;
     * confirm against VocabCache.
     *
     * @param str record text
     * @return vector with one entry per vocabulary word
     */
    public INDArray convertTextRecordToTFIDFVector(String str) {
        Counter<String> termCounts = wordFrequenciesForSentence(str);
        int vocabSize = this.cache.vocabWords().size();
        INDArray vector = Nd4j.create(vocabSize);
        int numDocs = (int) this.cache.numDocs();
        for (int i = 0; i < vocabSize; i++) {
            String word = this.cache.wordAt(i);
            vector.putScalar(i, NLPUtils.tfidf(
                    NLPUtils.tf((int) termCounts.getCount(word)),
                    NLPUtils.idf(numDocs, (int) this.cache.idf(word))));
        }
        return vector;
    }

    /**
     * Adds one {@code (text, label)} record to the vocabulary, document and label statistics.
     *
     * <p>Records that do not have exactly two fields are ignored. This mirrors
     * {@link #transform(Collection)}, which only vectorizes two-field records.
     *
     * @param collection record whose first field is the text and second field is the label
     */
    @Override
    public void collectStatistics(Collection<Writable> collection) {
        if (collection.size() != 2) {
            return;
        }
        Object[] fields = collection.toArray();
        String text = fields[0].toString();
        String label = fields[1].toString();
        trackLabel(label);
        Tokenizer tokenizer = this.tokenizerFactory.create(text);
        this.cache.incrementNumDocs(1.0d);
        doWithTokens(tokenizer);
    }

    /**
     * Registers one occurrence of a label.
     *
     * <p>Labels are trimmed. A new label gets the next sequential id; a known label has its count
     * incremented.
     */
    private void trackLabel(String str) {
        String label = str.trim();
        Pair<Integer, Integer> existing = this.recordLabels.get(label);
        if (existing == null) {
            this.recordLabels.put(label, new Pair<>(this.recordLabels.size(), 1));
        } else {
            this.recordLabels.put(label, new Pair<>(existing.getFirst(), existing.getSecond() + 1));
        }
    }

    /** Returns the number of distinct (trimmed) labels seen so far. */
    public int getNumberOfLabelsSeen() {
        return this.recordLabels.size();
    }

    /**
     * Looks up the numeric id assigned to a label.
     *
     * @param str label text; surrounding whitespace is ignored, matching how labels are stored
     * @return the label's id, or {@code null} if {@code str} is null or the label was never seen
     */
    public Integer getLabelID(String str) {
        if (str == null) {
            return null;
        }
        Pair<Integer, Integer> entry = this.recordLabels.get(str.trim());
        return entry == null ? null : entry.getFirst();
    }

    /**
     * Replaces a {@code (text, label)} record with its TF-IDF vector followed by the label id.
     *
     * <p>Records that do not have exactly two fields are left untouched.
     *
     * @param collection record to rewrite in place
     * @throws IllegalStateException if the label was not seen during statistics collection; the
     *     record is left unmodified in that case
     */
    @Override
    public void transform(Collection<Writable> collection) {
        if (collection.size() != 2) {
            return;
        }
        Object[] fields = collection.toArray();
        String text = fields[0].toString();
        String label = fields[1].toString();
        Integer labelId = getLabelID(label);
        // Fail before clearing the collection so the record is not silently destroyed.
        if (labelId == null) {
            throw new IllegalStateException(
                    "Label '" + label + "' was not seen during collectStatistics");
        }
        INDArray vector = convertTextRecordToTFIDFVector(text);
        collection.clear();
        for (int i = 0; i < vector.columns(); i++) {
            collection.add(new DoubleWritable(vector.getDouble(0, i)));
        }
        collection.add(new DoubleWritable(labelId.intValue()));
    }

    /** No post-processing is needed once statistics have been collected. */
    @Override
    public void evaluateStatistics() {
    }
}
