package org.deeplearning4j.spark.text.functions;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.Accumulator;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.spark.text.accumulators.WordFreqAccumulator;
import org.nd4j.common.primitives.AtomicDouble;
import org.nd4j.common.primitives.Counter;
import org.nd4j.common.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/spark/text/functions/TextPipeline.class */
/**
 * Builds a vocabulary and per-sentence word lists from a corpus of text lines held in a Spark RDD.
 *
 * <p>Usage, as far as this class shows: construct with a corpus RDD and a broadcast map of
 * settings, call {@link #buildVocabCache()} to tokenize the corpus, accumulate word frequencies and
 * populate the vocabulary cache, then call {@link #buildVocabWordListRDD()} to derive the
 * vocab-word and sentence-count RDDs. Getters throw {@link IllegalStateException} when the value
 * they expose has not been produced yet.
 *
 * <p>NOTE(review): the no-argument constructor followed by {@link #setRDDVarMap} does not create
 * the Spark context, accumulator or stop-word broadcast (only the two-argument constructor runs the
 * private setup step) — confirm whether that path is meant to be usable.
 */
public class TextPipeline {
    /** Literal compared against when no {@link VectorsConfiguration} is available. */
    private static final String DEFAULT_UNK = "UNK";

    private JavaRDD<String> corpusRDD;
    /** Minimum frequency threshold: words counted fewer times are replaced by the UNK token. */
    private int numWords;
    private int nGrams;
    private String tokenizer;
    private String tokenizerPreprocessor;
    private JavaSparkContext sc;
    private Accumulator<Counter<String>> wordFreqAcc;
    private Broadcast<List<String>> stopWordBroadCast;
    private JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCountRDD;
    private Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast;
    private JavaRDD<List<VocabWord>> vocabWordListRDD;
    private JavaRDD<AtomicLong> sentenceCountRDD;
    private long totalWordCount;
    private boolean useUnk;
    private VectorsConfiguration configuration;
    private List<String> stopWords = new ArrayList<>();
    private VocabCache<VocabWord> vocabCache = new AbstractCache<>();

    /** Creates an unconfigured pipeline; see the class-level note about this path. */
    public TextPipeline() {
    }

    /**
     * Creates a pipeline over the given corpus, reading its settings from the broadcast map and
     * creating the word-frequency accumulator and stop-word broadcast.
     *
     * @param javaRDD   corpus, one text entry per element
     * @param broadcast settings map; see {@link #setRDDVarMap} for the keys that are read
     * @throws Exception declared for backward compatibility with existing callers
     */
    public TextPipeline(JavaRDD<String> javaRDD, Broadcast<Map<String, Object>> broadcast) throws Exception {
        setRDDVarMap(javaRDD, broadcast);
        setup();
    }

    /**
     * Stores the corpus RDD and copies the pipeline settings out of the broadcast map.
     *
     * <p>Keys read: {@code numWords}, {@code nGrams}, {@code useUnk} (required; a missing entry
     * fails with a {@link NullPointerException} on unboxing), {@code tokenizer},
     * {@code tokenPreprocessor}, {@code vectorsConfiguration} and {@code stopWords}. A missing
     * {@code stopWords} entry is treated as an empty list.
     *
     * @param javaRDD   corpus to process
     * @param broadcast broadcast settings map
     */
    public void setRDDVarMap(JavaRDD<String> javaRDD, Broadcast<Map<String, Object>> broadcast) {
        Map<String, Object> settings = broadcast.getValue();
        this.corpusRDD = javaRDD;
        this.numWords = (Integer) settings.get("numWords");
        this.nGrams = (Integer) settings.get("nGrams");
        this.tokenizer = (String) settings.get("tokenizer");
        this.tokenizerPreprocessor = (String) settings.get("tokenPreprocessor");
        this.useUnk = (Boolean) settings.get("useUnk");
        this.configuration = (VectorsConfiguration) settings.get("vectorsConfiguration");

        // The map is untyped; the entry is expected to be a List<String>.
        @SuppressWarnings("unchecked")
        List<String> configuredStopWords = (List<String>) settings.get("stopWords");
        // Never store null: setup() broadcasts this list as-is.
        this.stopWords = configuredStopWords != null ? configuredStopWords : new ArrayList<String>();
    }

    /** Creates the Spark context wrapper, the word-frequency accumulator and the stop-word broadcast. */
    private void setup() {
        this.sc = new JavaSparkContext(this.corpusRDD.context());
        this.wordFreqAcc = this.sc.accumulator(new Counter<String>(), new WordFreqAccumulator());
        this.stopWordBroadCast = this.sc.broadcast(this.stopWords);
    }

    /**
     * Maps every corpus entry through a {@code TokenizerFunction} built from the configured
     * tokenizer, preprocessor and n-gram setting.
     *
     * @return RDD of token lists, one per corpus entry
     * @throws IllegalStateException if no corpus RDD has been assigned
     */
    public JavaRDD<List<String>> tokenize() {
        if (this.corpusRDD == null) {
            throw new IllegalStateException("corpusRDD not assigned. Define TextPipeline with corpusRDD assigned.");
        }
        return this.corpusRDD.map(new TokenizerFunction(this.tokenizer, this.tokenizerPreprocessor, this.nGrams));
    }

    /**
     * Maps the tokenized sentences through an {@code UpdateWordFreqAccumulatorFunction} and forces
     * evaluation with {@code count()} so that the word-frequency accumulator is populated.
     *
     * <p>NOTE(review): the returned RDD is not persisted here, so a later action on it presumably
     * re-runs the mapping and updates the accumulator again — confirm against Spark accumulator
     * semantics before relying on the accumulator value after further actions.
     *
     * @param javaRDD tokenized sentences
     * @return the mapped RDD of (tokens, count) pairs
     */
    public JavaRDD<Pair<List<String>, AtomicLong>> updateAndReturnAccumulatorVal(JavaRDD<List<String>> javaRDD) {
        JavaRDD<Pair<List<String>, AtomicLong>> sentenceWordsCount =
                javaRDD.map(new UpdateWordFreqAccumulatorFunction(this.stopWordBroadCast, this.wordFreqAcc));
        sentenceWordsCount.count();
        return sentenceWordsCount;
    }

    /** Returns the configured UNK token for words counted fewer than {@code numWords} times, else the word. */
    private String filterMinWord(String str, double d) {
        return d < ((double) this.numWords) ? this.configuration.getUNK() : str;
    }

    /**
     * Tells whether the token is the unknown-word token: the configured one when a configuration
     * is present, otherwise the literal {@code "UNK"}.
     */
    private boolean isUnkToken(String token) {
        String unk = this.configuration != null ? this.configuration.getUNK() : DEFAULT_UNK;
        return token.equals(unk);
    }

    /**
     * Registers the token in the vocabulary cache. If the cache already knows the token its
     * frequency is increased; if the word is not yet part of the vocabulary it is added with the
     * next free index.
     */
    private void addTokenToVocabCache(String str, Float f) {
        VocabWord vocabWord;
        if (this.vocabCache.hasToken(str)) {
            vocabWord = this.vocabCache.tokenFor(str);
            vocabWord.increaseElementFrequency(f.intValue());
        } else {
            vocabWord = new VocabWord(f.floatValue(), str);
        }
        if (this.vocabCache.containsWord(str)) {
            return;
        }
        // Index is the vocabulary size before insertion.
        int index = this.vocabCache.numWords();
        this.vocabCache.addToken(vocabWord);
        vocabWord.setIndex(index);
        this.vocabCache.putVocabWord(str);
    }

    /**
     * Adds every counted word to the vocabulary cache, replacing words below the minimum frequency
     * by the UNK token. When {@code useUnk} is false, UNK tokens are skipped instead of added.
     *
     * @param counter word frequencies
     * @throws IllegalStateException if the counter is empty
     */
    public void filterMinWordAddVocab(Counter<String> counter) {
        if (counter.isEmpty()) {
            throw new IllegalStateException("IllegalStateException: wordFreqCounter has nothing. Check accumulator updating");
        }
        for (Map.Entry<String, AtomicDouble> entry : counter.entrySet()) {
            double frequency = entry.getValue().doubleValue();
            String token = filterMinWord(entry.getKey(), frequency);
            // Compare against the configured UNK token (not a fixed literal) so that a custom
            // token is honoured when unknown words are disabled.
            if (!this.useUnk && isUnkToken(token)) {
                continue;
            }
            addTokenToVocabCache(token, (float) frequency);
        }
    }

    /**
     * Tokenizes the corpus, accumulates word frequencies, fills the vocabulary cache, builds a
     * {@link Huffman} structure over the vocabulary words and applies its indexes to the cache,
     * then broadcasts the cache.
     */
    public void buildVocabCache() {
        this.sentenceWordsCountRDD = updateAndReturnAccumulatorVal(tokenize()).cache();
        filterMinWordAddVocab(this.wordFreqAcc.value());
        Huffman huffman = new Huffman(this.vocabCache.vocabWords());
        huffman.build();
        huffman.applyIndexes(this.vocabCache);
        this.vocabCacheBroadcast = this.sc.broadcast(this.vocabCache);
    }

    /**
     * Derives the vocab-word list RDD and the sentence-count RDD from the sentence/word-count RDD,
     * computes the total word count and unpersists the source RDD.
     *
     * @throws IllegalStateException if {@link #buildVocabCache()} has not been run
     */
    public void buildVocabWordListRDD() {
        if (this.sentenceWordsCountRDD == null) {
            throw new IllegalStateException("SentenceWordCountRDD must be defined first. Run buildVocabCache first.");
        }
        this.vocabWordListRDD = this.sentenceWordsCountRDD
                .map(new WordsListToVocabWordsFunction(this.vocabCacheBroadcast))
                .setName("vocabWordListRDD")
                .cache();
        this.sentenceCountRDD = this.sentenceWordsCountRDD
                .map(new GetSentenceCountFunction())
                .setName("sentenceCountRDD")
                .cache();
        // Force evaluation of the vocab-word RDD before the source RDD is unpersisted.
        this.vocabWordListRDD.count();
        this.totalWordCount = this.sentenceCountRDD.reduce(new ReduceSentenceCount()).get();
        this.sentenceWordsCountRDD.unpersist();
    }

    /**
     * @return the word-frequency accumulator
     * @throws IllegalStateException if the accumulator has not been created
     */
    public Accumulator<Counter<String>> getWordFreqAcc() {
        if (this.wordFreqAcc == null) {
            throw new IllegalStateException("IllegalStateException: wordFreqAcc not set at TextPipline.");
        }
        return this.wordFreqAcc;
    }

    /**
     * @return the broadcast vocabulary cache
     * @throws IllegalStateException if the vocabulary is empty or has not been broadcast yet
     */
    public Broadcast<VocabCache<VocabWord>> getBroadCastVocabCache() throws IllegalStateException {
        // The cache can be filled without being broadcast (filterMinWordAddVocab is public),
        // so check the broadcast handle too rather than handing out null.
        if (this.vocabCacheBroadcast == null || this.vocabCache == null || this.vocabCache.numWords() <= 0) {
            throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
        }
        return this.vocabCacheBroadcast;
    }

    /**
     * @return the vocabulary cache
     * @throws IllegalStateException if the cache is missing or empty
     */
    public VocabCache<VocabWord> getVocabCache() throws IllegalStateException {
        if (this.vocabCache == null || this.vocabCache.numWords() <= 0) {
            throw new IllegalStateException("IllegalStateException: VocabCache not set at TextPipline.");
        }
        return this.vocabCache;
    }

    /**
     * @return the RDD of (tokens, count) pairs produced by {@link #buildVocabCache()}
     * @throws IllegalStateException if it has not been produced
     */
    public JavaRDD<Pair<List<String>, AtomicLong>> getSentenceWordsCountRDD() {
        if (this.sentenceWordsCountRDD == null) {
            throw new IllegalStateException("IllegalStateException: sentenceWordsCountRDD not set at TextPipline.");
        }
        return this.sentenceWordsCountRDD;
    }

    /**
     * @return the RDD of vocab-word lists produced by {@link #buildVocabWordListRDD()}
     * @throws IllegalStateException if it has not been produced
     */
    public JavaRDD<List<VocabWord>> getVocabWordListRDD() throws IllegalStateException {
        if (this.vocabWordListRDD == null) {
            throw new IllegalStateException("IllegalStateException: vocabWordListRDD not set at TextPipline.");
        }
        return this.vocabWordListRDD;
    }

    /**
     * @return the RDD of per-sentence counts produced by {@link #buildVocabWordListRDD()}
     * @throws IllegalStateException if it has not been produced
     */
    public JavaRDD<AtomicLong> getSentenceCountRDD() throws IllegalStateException {
        if (this.sentenceCountRDD == null) {
            throw new IllegalStateException("IllegalStateException: sentenceCountRDD not set at TextPipline.");
        }
        return this.sentenceCountRDD;
    }

    /**
     * @return the total word count computed by {@link #buildVocabWordListRDD()}
     * @throws IllegalStateException if the count is zero, which is treated as "not set"
     */
    public Long getTotalWordCount() {
        if (this.totalWordCount == 0) {
            throw new IllegalStateException("IllegalStateException: totalWordCount not set at TextPipline.");
        }
        return Long.valueOf(this.totalWordCount);
    }
}
