package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecParam;
import org.deeplearning4j.spark.text.TextPipeline;
import org.deeplearning4j.spark.text.TokenizerFunction;
import org.deeplearning4j.spark.text.TokentoVocabWord;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.class */
/**
 * Spark driver that trains word2vec embeddings over an RDD of text.
 *
 * <p>The vocabulary is built with {@link TextPipeline}, broadcast to the cluster, and each
 * input string is tokenized and mapped to vocabulary words. The resulting sentences are then
 * run through {@link Word2VecSetup} and {@link SentenceBatch} for a configurable number of
 * iterations, and every produced {@link Word2VecChange} is applied to the lookup table.
 *
 * <p>Hyper-parameters are read from the {@link SparkConf} of the RDD's context; see the
 * {@code *_KEY} constants for the property names and {@link #train(JavaRDD)} for the defaults.
 */
public class Word2Vec implements Serializable {
    private static final String ALPHA_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.alpha";
    private static final String VECTOR_LENGTH_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.length";
    private static final String NEGATIVE_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.negative";
    private static final String ADAGRAD_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.adagrad";
    private static final String WINDOW_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.window";
    private static final String ITERATIONS_KEY = "org.deeplearning4j.scaleout.perform.models.word2vec.iterations";

    private Broadcast<VocabCache> vocabCacheBroadcast;
    private String tokenizerFactoryClazz;
    private InMemoryLookupTable table;
    private static final Logger log = LoggerFactory.getLogger(Word2Vec.class);

    /**
     * Creates a trainer that continues training an existing lookup table.
     *
     * @param str class name of the tokenizer factory, handed to {@link TextPipeline} and
     *            {@link TokenizerFunction}
     * @param inMemoryLookupTable table to train; when non-null its weights are NOT reset by
     *                            {@link #train(JavaRDD)}
     */
    public Word2Vec(String str, InMemoryLookupTable inMemoryLookupTable) {
        this.tokenizerFactoryClazz = str;
        this.table = inMemoryLookupTable;
    }

    /**
     * Creates a trainer that builds a fresh lookup table on {@link #train(JavaRDD)}.
     *
     * @param str class name of the tokenizer factory, handed to {@link TextPipeline} and
     *            {@link TokenizerFunction}
     */
    public Word2Vec(String str) {
        this.tokenizerFactoryClazz = str;
    }

    /** Creates a trainer that uses {@link DefaultTokenizerFactory} and a fresh lookup table. */
    public Word2Vec() {
        this(DefaultTokenizerFactory.class.getName());
    }

    /**
     * Builds the vocabulary for {@code javaRDD} and trains the lookup table on it.
     *
     * <p>Configuration read from the RDD's {@link SparkConf} (default in parentheses):
     * alpha (0.025), vector length (100), negative (5.0), adagrad (false), window (5),
     * iterations (5). The table-related settings are only consulted when no table was
     * supplied at construction time.
     *
     * @param javaRDD the text to train on, one element per input string
     * @return the vocabulary held by the broadcast created during this call, paired with the
     *         lookup table that was trained (the supplied table, or the newly built one)
     */
    public Pair<VocabCache, WeightLookupTable> train(JavaRDD<String> javaRDD) {
        Pair<VocabCache, Long> vocabAndWordCount = new TextPipeline(javaRDD).process(this.tokenizerFactoryClazz);
        VocabCache vocab = vocabAndWordCount.getFirst();
        SparkConf conf = javaRDD.context().getConf();
        JavaSparkContext sparkContext = new JavaSparkContext(javaRDD.context());
        this.vocabCacheBroadcast = sparkContext.broadcast(vocab);

        final InMemoryLookupTable lookupTable = this.table != null
                ? this.table
                : new InMemoryLookupTable.Builder()
                        .cache(vocab)
                        .lr(conf.getDouble(ALPHA_KEY, 0.025d))
                        .vectorLength(conf.getInt(VECTOR_LENGTH_KEY, 100))
                        .negative(conf.getDouble(NEGATIVE_KEY, 5.0d))
                        .useAdaGrad(conf.getBoolean(ADAGRAD_KEY, false))
                        .build();
        // Only a table created here is (re)initialised; a caller-supplied table keeps its weights.
        if (this.table == null) {
            lookupTable.resetWeights();
        }

        new Huffman(vocab.vocabWords()).build();

        JavaRDD<Pair<List<VocabWord>, AtomicLong>> sentences = javaRDD
                .map(new TokenizerFunction(this.tokenizerFactoryClazz))
                .map(new TokentoVocabWord(this.vocabCacheBroadcast));

        Word2VecParam param = new Word2VecParam.Builder()
                .negative(lookupTable.getNegative())
                .window(conf.getInt(WINDOW_KEY, 5))
                .expTable(sparkContext.broadcast(lookupTable.getExpTable()))
                .setAlpha(lookupTable.getLr().get())
                .setMinAlpha(0.01d)
                .setVectorLength(lookupTable.getVectorLength())
                .useAdaGrad(lookupTable.isUseAdaGrad())
                .weights(lookupTable)
                .build();
        // The pipeline reports the total as a Long; it is narrowed to int here.
        param.setTotalWords(vocabAndWordCount.getSecond().intValue());

        List<Long> cumulativeCounts = cumulativeWordCounts(sentences);
        JavaPairRDD<List<VocabWord>, Long> indexedSentences =
                pairWithCumulativeCounts(sentences, cumulativeCounts).cache();

        // Read once: the SparkConf does not change while training runs.
        int iterations = conf.getInt(ITERATIONS_KEY, 5);
        for (int iteration = 0; iteration < iterations; iteration++) {
            // The parameters are broadcast anew on every iteration.
            JavaRDD<Word2VecChange> changes = indexedSentences
                    .map(new Word2VecSetup(sparkContext.broadcast(param)))
                    .map(new SentenceBatch());
            applyChanges(changes, lookupTable);
            changes.unpersist();
            log.info("Iteration {}", iteration);
        }
        return new Pair<VocabCache, WeightLookupTable>(this.vocabCacheBroadcast.getValue(), lookupTable);
    }

    /**
     * Collects the per-sentence counters to the driver and turns them into running totals:
     * element {@code i} is the sum of the counters of sentences {@code 0..i} inclusive.
     *
     * <p>Declared static so the anonymous function does not hold a reference to the enclosing
     * {@code Word2Vec} instance when Spark serializes it.
     */
    private static List<Long> cumulativeWordCounts(JavaRDD<Pair<List<VocabWord>, AtomicLong>> sentences) {
        List<AtomicLong> perSentence = sentences
                .map(new Function<Pair<List<VocabWord>, AtomicLong>, AtomicLong>() {
                    @Override
                    public AtomicLong call(Pair<List<VocabWord>, AtomicLong> pair) throws Exception {
                        return pair.getSecond();
                    }
                })
                .collect();
        List<Long> cumulative = new ArrayList<Long>(perSentence.size());
        long runningTotal = 0;
        for (AtomicLong count : perSentence) {
            runningTotal += count.get();
            cumulative.add(runningTotal);
        }
        return cumulative;
    }

    /**
     * Pairs every sentence with the running total found at the sentence's position (as
     * assigned by {@code zipWithIndex}) in {@code cumulativeCounts}.
     *
     * <p>Declared static so the anonymous functions capture only {@code cumulativeCounts},
     * not the enclosing {@code Word2Vec} instance.
     */
    private static JavaPairRDD<List<VocabWord>, Long> pairWithCumulativeCounts(
            JavaRDD<Pair<List<VocabWord>, AtomicLong>> sentences, final List<Long> cumulativeCounts) {
        return sentences
                .map(new Function<Pair<List<VocabWord>, AtomicLong>, List<VocabWord>>() {
                    @Override
                    public List<VocabWord> call(Pair<List<VocabWord>, AtomicLong> pair) throws Exception {
                        return pair.getFirst();
                    }
                })
                .zipWithIndex()
                .mapToPair(new PairFunction<Tuple2<List<VocabWord>, Long>, List<VocabWord>, Long>() {
                    @Override
                    public Tuple2<List<VocabWord>, Long> call(Tuple2<List<VocabWord>, Long> indexed) throws Exception {
                        return new Tuple2<List<VocabWord>, Long>(
                                indexed._1(), cumulativeCounts.get(indexed._2().intValue()));
                    }
                });
    }

    /**
     * Applies every change in {@code changes} to {@code lookupTable} via Spark's
     * {@code foreach}.
     *
     * <p>NOTE(review): the function and the captured table travel through Spark's closure
     * serialization, so which copy of the table ends up updated depends on how Spark executes
     * the task - confirm against the target deployment.
     */
    private static void applyChanges(JavaRDD<Word2VecChange> changes, final InMemoryLookupTable lookupTable) {
        changes.foreach(new VoidFunction<Word2VecChange>() {
            @Override
            public void call(Word2VecChange change) throws Exception {
                change.apply(lookupTable);
            }
        });
    }
}
