package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.text.functions.CountCumSum;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.class */
/**
 * Word2vec trainer whose training work is distributed over a Spark cluster.
 *
 * <p>Configure it through the fluent setters (each returns {@code this}), then call
 * {@link #train(JavaRDD)}. Training stores the resulting vocabulary and an
 * {@link InMemoryLookupTable} holding the learned vectors in the {@code vocab} and
 * {@code lookupTable} fields inherited from {@link WordVectorsImpl}.
 */
public class Word2Vec extends WordVectorsImpl<VocabWord> implements Serializable {
    /** Number of precomputed entries in the sigmoid lookup table built by {@link #initExpTable()}. */
    private static final int EXP_TABLE_SIZE = 1000;

    private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);

    /** Set only by {@link #Word2Vec(INDArray)}; not read anywhere else in this class. */
    private INDArray trainedSyn1;
    /** Sigmoid table covers inputs in [-MAX_EXP, MAX_EXP); must be declared before {@code expTable}. */
    private int MAX_EXP = 6;
    private int vectorLength = 100;
    private boolean useAdaGrad = false;
    private int negative = 0;
    private int numWords = 1;
    private int window = 5;
    private double alpha = 0.025d;
    private double minAlpha = 1.0E-4d;
    private int iterations = 1;
    private int nGrams = 1;
    private String tokenizer = "org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory";
    private String tokenPreprocessor = "org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor";
    private boolean removeStop = false;
    private long seed = 42;
    private double[] expTable = initExpTable();

    /**
     * Creates a trainer holding the given matrix.
     *
     * @param iNDArray matrix stored as {@code trainedSyn1}; not otherwise used by this class
     */
    public Word2Vec(INDArray iNDArray) {
        this.trainedSyn1 = iNDArray;
    }

    /** Creates a trainer with default hyper-parameters. */
    public Word2Vec() {
    }

    /**
     * Builds the precomputed sigmoid table that is broadcast to workers during training.
     *
     * <p>Entry {@code i} holds {@code sigmoid(x)} for
     * {@code x = (i / EXP_TABLE_SIZE * 2 - 1) * MAX_EXP}. The table therefore samples the
     * sigmoid on the interval {@code [-MAX_EXP, MAX_EXP)}.
     *
     * <p>NOTE(review): this method is invoked from a field initializer, i.e. during
     * construction, so overriding it in a subclass runs before subclass fields are set.
     *
     * @return a new array of {@value #EXP_TABLE_SIZE} sigmoid values
     */
    public double[] initExpTable() {
        double[] table = new double[EXP_TABLE_SIZE];
        for (int i = 0; i < table.length; i++) {
            // The cast is essential: with integer division i / table.length is always 0,
            // which would fill the whole table with sigmoid(-MAX_EXP).
            double x = ((double) i / table.length * 2.0d - 1.0d) * this.MAX_EXP;
            double exp = FastMath.exp(x);
            table[i] = exp / (exp + 1.0d);
        }
        return table;
    }

    /**
     * Collects the tokenization settings that are broadcast to {@link TextPipeline}.
     *
     * <p>A plain {@link HashMap} is used deliberately. A double-brace (anonymous subclass)
     * map would capture this {@code Word2Vec} instance, and serializing it for the broadcast
     * would ship the whole trainer.
     *
     * @return a new mutable map keyed by {@code numWords}, {@code nGrams}, {@code tokenizer},
     *     {@code tokenPreprocessor} and {@code removeStop}
     */
    public Map<String, Object> getTokenizerVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("numWords", Integer.valueOf(this.numWords));
        vars.put("nGrams", Integer.valueOf(this.nGrams));
        vars.put("tokenizer", this.tokenizer);
        vars.put("tokenPreprocessor", this.tokenPreprocessor);
        vars.put("removeStop", Boolean.valueOf(this.removeStop));
        return vars;
    }

    /**
     * Collects the word2vec hyper-parameters that are broadcast to the training workers.
     *
     * <p>A plain {@link HashMap} is used so that serializing the map does not also serialize
     * this {@code Word2Vec} instance. {@link #train(JavaRDD)} adds a {@code totalWordCount}
     * entry to the returned map, so it must stay mutable.
     *
     * @return a new mutable map keyed by {@code vectorLength}, {@code useAdaGrad},
     *     {@code negative}, {@code window}, {@code alpha}, {@code minAlpha},
     *     {@code iterations}, {@code seed} and {@code maxExp}
     */
    public Map<String, Object> getWord2vecVarMap() {
        Map<String, Object> vars = new HashMap<String, Object>();
        vars.put("vectorLength", Integer.valueOf(this.vectorLength));
        vars.put("useAdaGrad", Boolean.valueOf(this.useAdaGrad));
        vars.put("negative", Integer.valueOf(this.negative));
        vars.put("window", Integer.valueOf(this.window));
        vars.put("alpha", Double.valueOf(this.alpha));
        vars.put("minAlpha", Double.valueOf(this.minAlpha));
        vars.put("iterations", Integer.valueOf(this.iterations));
        vars.put("seed", Long.valueOf(this.seed));
        vars.put("maxExp", Integer.valueOf(this.MAX_EXP));
        return vars;
    }

    /**
     * Trains word vectors on a corpus of sentences.
     *
     * <p>Steps:
     * <ol>
     *   <li>Tokenize the corpus and build the vocabulary via {@link TextPipeline}.</li>
     *   <li>Build the Huffman tree over the vocabulary.</li>
     *   <li>Pair each sentence's word list with its cumulative sentence count.</li>
     *   <li>Run {@code FirstIterationFunction} on each partition and collect
     *       {@code (word index, vector)} pairs on the driver.</li>
     *   <li>Sum those pairs into the syn0 matrix.</li>
     * </ol>
     * On return, {@code vocab} and {@code lookupTable} are populated.
     *
     * @param javaRDD input corpus; each element is one sentence (presumably raw text to be
     *     tokenized by the configured tokenizer — confirm against {@link TextPipeline})
     * @throws Exception propagated from the text pipeline or the Spark job
     */
    public void train(JavaRDD<String> javaRDD) throws Exception {
        log.info("Start training ...");
        // Wraps the RDD's existing SparkContext. It must not be stopped here, because that
        // would shut down the caller's context.
        JavaSparkContext sc = new JavaSparkContext(javaRDD.context());
        Map<String, Object> tokenizerVarMap = getTokenizerVarMap();
        Map<String, Object> word2vecVarMap = getWord2vecVarMap();

        log.info("Tokenization and building VocabCache ...");
        TextPipeline pipeline = new TextPipeline(javaRDD, sc.broadcast(tokenizerVarMap));
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        word2vecVarMap.put("totalWordCount", pipeline.getTotalWordCount());
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();
        JavaRDD<List<VocabWord>> vocabWordListRDD = pipeline.getVocabWordListRDD();
        VocabCache vocabCache = (VocabCache) pipeline.getBroadCastVocabCache().getValue();

        log.info("Building Huffman Tree ...");
        new Huffman(vocabCache.vocabWords()).build();

        log.info("Calculating cumulative sum of sentence counts ...");
        JavaRDD<Long> sentenceCumSumRDD = new CountCumSum(sentenceCountRDD).buildCumSum();

        log.info("Mapping to RDD(vocabWordList, cumulative sentence count) ...");
        JavaPairRDD wordListWithCumSumRDD = vocabWordListRDD.zip(sentenceCumSumRDD)
                .setName("vocabWordListSentenceCumSumRDD")
                .cache();

        log.info("Broadcasting word2vec variables to workers ...");
        Broadcast word2vecVarBroadcast = sc.broadcast(word2vecVarMap);
        Broadcast expTableBroadcast = sc.broadcast(this.expTable);

        log.info("Training word2vec sentences ...");
        List<Pair> indexedVectors;
        try {
            indexedVectors = wordListWithCumSumRDD
                    .mapPartitions(new FirstIterationFunction(word2vecVarBroadcast, expTableBroadcast))
                    .map(new MapToPairFunction())
                    .collect();
        } finally {
            // The cached RDD is only needed for this one job. Release executor storage
            // even if the job fails.
            wordListWithCumSumRDD.unpersist();
        }

        INDArray syn0 = Nd4j.create(vocabCache.numWords(), this.vectorLength);
        for (Pair pair : indexedVectors) {
            // Each pair is (word index, vector). Vectors for the same index are summed in place.
            syn0.getRow(((Integer) pair.getFirst()).intValue()).addi((INDArray) pair.getSecond());
        }

        this.vocab = vocabCache;
        InMemoryLookupTable lookup = new InMemoryLookupTable();
        lookup.setVocab(vocabCache);
        lookup.setVectorLength(this.vectorLength);
        lookup.setSyn0(syn0);
        this.lookupTable = lookup;
    }

    /** @return number of columns of the syn0 matrix built by {@link #train(JavaRDD)} */
    public int getVectorLength() {
        return this.vectorLength;
    }

    /**
     * @param i number of columns of the syn0 matrix; also broadcast as {@code vectorLength}
     * @return this trainer
     */
    public Word2Vec setVectorLength(int i) {
        this.vectorLength = i;
        return this;
    }

    /** @return value broadcast to workers as {@code useAdaGrad} */
    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    /**
     * @param z value broadcast to workers as {@code useAdaGrad}
     * @return this trainer
     */
    public Word2Vec setUseAdaGrad(boolean z) {
        this.useAdaGrad = z;
        return this;
    }

    /** @return value broadcast to workers as {@code negative} */
    public int getNegative() {
        return this.negative;
    }

    /**
     * @param i value broadcast to workers as {@code negative}
     * @return this trainer
     */
    public Word2Vec setNegative(int i) {
        this.negative = i;
        return this;
    }

    /** @return value passed to the tokenization pipeline as {@code numWords} */
    public int getNumWords() {
        return this.numWords;
    }

    /**
     * @param i value passed to the tokenization pipeline as {@code numWords}
     * @return this trainer
     */
    public Word2Vec setNumWords(int i) {
        this.numWords = i;
        return this;
    }

    /** @return value broadcast to workers as {@code window} */
    public int getWindow() {
        return this.window;
    }

    /**
     * @param i value broadcast to workers as {@code window}
     * @return this trainer
     */
    public Word2Vec setWindow(int i) {
        this.window = i;
        return this;
    }

    /** @return value broadcast to workers as {@code alpha} */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @param d value broadcast to workers as {@code alpha}
     * @return this trainer
     */
    public Word2Vec setAlpha(double d) {
        this.alpha = d;
        return this;
    }

    /** @return value broadcast to workers as {@code minAlpha} */
    public double getMinAlpha() {
        return this.minAlpha;
    }

    /**
     * @param d value broadcast to workers as {@code minAlpha}
     * @return this trainer
     */
    public Word2Vec setMinAlpha(double d) {
        this.minAlpha = d;
        return this;
    }

    /** @return value broadcast to workers as {@code iterations} */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * @param i value broadcast to workers as {@code iterations}
     * @return this trainer
     */
    public Word2Vec setIterations(int i) {
        this.iterations = i;
        return this;
    }

    /** @return value passed to the tokenization pipeline as {@code nGrams} */
    public int getnGrams() {
        return this.nGrams;
    }

    /**
     * @param i value passed to the tokenization pipeline as {@code nGrams}
     * @return this trainer
     */
    public Word2Vec setnGrams(int i) {
        this.nGrams = i;
        return this;
    }

    /** @return tokenizer factory class name passed to the pipeline as {@code tokenizer} */
    public String getTokenizer() {
        return this.tokenizer;
    }

    /**
     * @param str fully-qualified tokenizer factory class name, passed to the pipeline as
     *     {@code tokenizer}
     * @return this trainer
     */
    public Word2Vec setTokenizer(String str) {
        this.tokenizer = str;
        return this;
    }

    /** @return preprocessor class name passed to the pipeline as {@code tokenPreprocessor} */
    public String getTokenPreprocessor() {
        return this.tokenPreprocessor;
    }

    /**
     * @param str fully-qualified token preprocessor class name, passed to the pipeline as
     *     {@code tokenPreprocessor}
     * @return this trainer
     */
    public Word2Vec setTokenPreprocessor(String str) {
        this.tokenPreprocessor = str;
        return this;
    }

    /** @return value passed to the tokenization pipeline as {@code removeStop} */
    public boolean isRemoveStop() {
        return this.removeStop;
    }

    /**
     * @param z value passed to the tokenization pipeline as {@code removeStop}
     * @return this trainer
     */
    public Word2Vec setRemoveStop(boolean z) {
        this.removeStop = z;
        return this;
    }

    /** @return value broadcast to workers as {@code seed} */
    public long getSeed() {
        return this.seed;
    }

    /**
     * @param j value broadcast to workers as {@code seed}
     * @return this trainer
     */
    public Word2Vec setSeed(long j) {
        this.seed = j;
        return this;
    }

    /**
     * Returns the sigmoid lookup table built at construction time.
     *
     * @return the internal array itself, not a copy; mutating it affects later training runs
     */
    public double[] getExpTable() {
        return this.expTable;
    }
}
