package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.class */
/**
 * Spark partition function that trains word vectors on the sentences of one partition.
 *
 * <p>Each input element pairs a sentence (a list of {@link VocabWord}) with a {@code Long}
 * that is used to derive the per-sentence learning rate (presumably the cumulative word
 * count up to that sentence -- confirm against the caller). Sentences are consumed in
 * batches of {@code batchSize}; every batch is trained {@code iterations} times. When the
 * partition is exhausted the function returns whatever {@code VocabHolder.getSplit(vocab)}
 * yields.
 *
 * <p>All hyper-parameters are read from the broadcast configuration map in the constructor.
 */
public class SecondIterationFunction implements FlatMapFunction<Iterator<Tuple2<List<VocabWord>, Long>>, Map.Entry<VocabWord, INDArray>> {
    private int vectorLength;
    private boolean useAdaGrad;
    private int batchSize;
    private double negative;
    private int window;
    private double alpha;
    private double minAlpha;
    private long totalWordCount;
    private long seed;
    private int maxExp;
    private double[] expTable;
    private int iterations;
    private volatile VocabCache<VocabWord> vocab;
    private volatile transient NegativeHolder negativeHolder;
    private volatile transient VocabHolder vocabHolder;
    private int ithIteration = 1;
    // State of the linear-congruential generator used for window offsets and negative samples.
    private AtomicLong nextRandom = new AtomicLong(5);
    private AtomicLong cid = new AtomicLong(0);
    private AtomicLong aff = new AtomicLong(0);

    /**
     * Reads the training configuration from the broadcast variables.
     *
     * @param broadcast  configuration map; must contain the keys {@code vectorLength},
     *                   {@code useAdaGrad}, {@code negative}, {@code window}, {@code alpha},
     *                   {@code minAlpha}, {@code totalWordCount}, {@code seed}, {@code maxExp},
     *                   {@code iterations} and {@code batchSize} with the boxed types cast below
     * @param broadcast2 the exp lookup table
     * @param broadcast3 the vocabulary cache
     * @throws RuntimeException if the broadcast vocabulary cache is {@code null}
     */
    public SecondIterationFunction(Broadcast<Map<String, Object>> broadcast, Broadcast<double[]> broadcast2, Broadcast<VocabCache<VocabWord>> broadcast3) {
        Map<String, Object> map = broadcast.getValue();
        this.expTable = broadcast2.getValue();
        this.vectorLength = ((Integer) map.get("vectorLength")).intValue();
        this.useAdaGrad = ((Boolean) map.get("useAdaGrad")).booleanValue();
        this.negative = ((Double) map.get("negative")).doubleValue();
        this.window = ((Integer) map.get("window")).intValue();
        this.alpha = ((Double) map.get("alpha")).doubleValue();
        this.minAlpha = ((Double) map.get("minAlpha")).doubleValue();
        this.totalWordCount = ((Long) map.get("totalWordCount")).longValue();
        this.seed = ((Long) map.get("seed")).longValue();
        this.maxExp = ((Integer) map.get("maxExp")).intValue();
        this.iterations = ((Integer) map.get("iterations")).intValue();
        this.batchSize = ((Integer) map.get("batchSize")).intValue();
        this.vocab = broadcast3.getValue();
        if (this.vocab == null) {
            throw new RuntimeException("VocabCache is null");
        }
    }

    /**
     * Trains on every sentence of the partition and returns the resulting vector split.
     *
     * @param it iterator over (sentence, count) tuples of this partition
     * @return iterator over the entries produced by {@code VocabHolder.getSplit(vocab)}
     */
    public Iterator<Map.Entry<VocabWord, INDArray>> call(Iterator<Tuple2<List<VocabWord>, Long>> it) {
        this.vocabHolder = VocabHolder.getInstance();
        this.vocabHolder.setSeed(this.seed, this.vectorLength);
        if (this.negative > 0.0d) {
            this.negativeHolder = NegativeHolder.getInstance();
            this.negativeHolder.initHolder(this.vocab, this.expTable, this.vectorLength);
        }

        // A non-positive batch size would never consume the iterator and loop forever,
        // so always take at least one sentence per batch.
        final int effectiveBatchSize = Math.max(1, this.batchSize);

        while (it.hasNext()) {
            List<Pair<List<VocabWord>, Long>> batch = new ArrayList<>();
            while (it.hasNext() && batch.size() < effectiveBatchSize) {
                Tuple2<List<VocabWord>, Long> next = it.next();
                batch.add(Pair.of(next._1(), next._2()));
            }
            for (int i = 0; i < this.iterations; i++) {
                for (Pair<List<VocabWord>, Long> pair : batch) {
                    // Floating-point division: with long division the ratio would be 0 for
                    // every value below totalWordCount and the learning rate would never decay.
                    double progress = pair.getValue().doubleValue() / this.totalWordCount;
                    double sentenceAlpha = Math.max(this.minAlpha, this.alpha - ((this.alpha - this.minAlpha) * progress));
                    trainSentence(pair.getKey(), sentenceAlpha);
                }
            }
        }
        return this.vocabHolder.getSplit(this.vocab).iterator();
    }

    /**
     * Runs {@link #skipGram} for every non-null word of the sentence, drawing a fresh random
     * window offset per position.
     *
     * @param list the sentence; {@code null} or empty sentences are ignored
     * @param d    learning rate to use for this sentence
     */
    public void trainSentence(List<VocabWord> list, double d) {
        if (list == null || list.isEmpty()) {
            return;
        }
        for (int i = 0; i < list.size(); i++) {
            this.nextRandom.set(Math.abs((this.nextRandom.get() * 25214903917L) + 11));
            int offset = ((int) this.nextRandom.get()) % this.window;
            // The long -> int cast may produce a negative value; a negative offset would make
            // skipGram scan MORE than `window` neighbours. Fold it back into [0, window).
            if (offset < 0) {
                offset += this.window;
            }
            if (list.get(i) != null) {
                skipGram(i, list, offset, d);
            }
        }
    }

    /**
     * Trains the word at position {@code i} against its neighbours. Relative positions
     * {@code i2 .. 2*window - i2} (excluding the centre, {@code window}) are visited; positions
     * falling outside the sentence are skipped.
     *
     * @param i    index of the centre word in {@code list}
     * @param list the sentence
     * @param i2   window offset that shrinks the context on both sides
     * @param d    learning rate
     */
    public void skipGram(int i, List<VocabWord> list, int i2, double d) {
        // Check emptiness before indexing; the original tested it only after list.get(i).
        if (list == null || list.isEmpty()) {
            return;
        }
        VocabWord centre = list.get(i);
        if (centre == null) {
            return;
        }
        int end = ((this.window * 2) + 1) - i2;
        for (int a = i2; a < end; a++) {
            if (a == this.window) {
                continue;
            }
            int pos = (i - this.window) + a;
            if (pos >= 0 && pos < list.size()) {
                iterateSample(centre, list.get(pos), d);
            }
        }
    }

    /**
     * Performs one training step for a (word, context word) pair: updates the syn1 vectors of
     * {@code vocabWord}'s points, optionally the negative-sampling table rows, and finally
     * adds the accumulated error to the syn0 vector of {@code vocabWord2}.
     *
     * <p>Does nothing if either word is {@code null}, if {@code vocabWord2} has a negative
     * index, or if both words share the same index.
     *
     * @param vocabWord  word whose codes/points drive the update
     * @param vocabWord2 word whose syn0 vector is updated
     * @param d          learning rate for the codes/points pass
     * @throws IllegalStateException if {@code vocabWord} contains a negative point
     */
    public void iterateSample(VocabWord vocabWord, VocabWord vocabWord2, double d) {
        if (vocabWord == null || vocabWord2 == null || vocabWord2.getIndex() < 0 || vocabWord2.getIndex() == vocabWord.getIndex()) {
            return;
        }
        int contextIndex = vocabWord2.getIndex();
        // Accumulated error, applied to the syn0 vector once at the very end.
        INDArray neu1e = Nd4j.create(this.vectorLength);
        INDArray l1 = this.vocabHolder.getSyn0Vector(Integer.valueOf(contextIndex), this.vocab);

        // Pass over the word's codes/points (looks like the hierarchical-softmax part).
        for (int c = 0; c < vocabWord.getCodeLength(); c++) {
            byte code = ((Byte) vocabWord.getCodes().get(c)).byteValue();
            int point = ((Integer) vocabWord.getPoints().get(c)).intValue();
            if (point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1 = this.vocabHolder.getSyn1Vector(Integer.valueOf(point));
            double dot = Nd4j.getBlasWrapper().level1().dot(this.vectorLength, 1.0d, l1, syn1);
            // Written as a negated range test so that NaN is skipped as well.
            if (!(dot >= (-this.maxExp) && dot < this.maxExp)) {
                continue;
            }
            // NOTE(review): expTable.length / maxExp is an integer division -- confirm intended.
            int idx = (int) ((dot + this.maxExp) * ((this.expTable.length / this.maxExp) / 2.0d));
            if (idx >= this.expTable.length) {
                continue;
            }
            double g = ((1 - code) - this.expTable[idx]) * (this.useAdaGrad ? vocabWord.getGradient(c, d, d) : d);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, l1, syn1);
        }

        if (this.negative > 0.0d) {
            int target = vocabWord.getIndex();
            for (int s = 0; s < this.negative + 1.0d; s++) {
                int label;
                if (s == 0) {
                    // First round uses the word itself as the positive example.
                    label = 1;
                } else {
                    this.nextRandom.set(Math.abs((this.nextRandom.get() * 25214903917L) + 11));
                    target = this.negativeHolder.getTable().getInt(new int[]{(int) Math.abs(((int) (this.nextRandom.get() >> 16)) % this.negativeHolder.getTable().length())});
                    if (target <= 0) {
                        target = (((int) this.nextRandom.get()) % (this.vocab.numWords() - 1)) + 1;
                    }
                    // Drawing the word itself as a negative sample is meaningless: skip it.
                    if (target == vocabWord.getIndex()) {
                        continue;
                    }
                    label = 0;
                }
                if (target < 0 || target >= this.negativeHolder.getSyn1Neg().rows()) {
                    continue;
                }
                double f = Nd4j.getBlasWrapper().dot(l1, this.negativeHolder.getSyn1Neg().slice(target));
                double g;
                if (f > this.maxExp) {
                    g = this.useAdaGrad ? vocabWord.getGradient(target, label - 1, this.alpha) : (label - 1) * this.alpha;
                } else if (f < (-this.maxExp)) {
                    g = label * (this.useAdaGrad ? vocabWord.getGradient(target, this.alpha, this.alpha) : this.alpha);
                } else {
                    int idx = (int) ((f + this.maxExp) * ((this.expTable.length / this.maxExp) / 2));
                    // No table entry for this value: skip the sample instead of applying
                    // an undefined gradient.
                    if (idx >= this.expTable.length) {
                        continue;
                    }
                    g = this.useAdaGrad ? vocabWord.getGradient(target, label - this.expTable[idx], this.alpha) : (label - this.expTable[idx]) * this.alpha;
                }
                Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, this.negativeHolder.getSyn1Neg().slice(target), neu1e);
                Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, g, l1, this.negativeHolder.getSyn1Neg().slice(target));
            }
        }
        Nd4j.getBlasWrapper().level1().axpy(this.vectorLength, 1.0d, neu1e, l1);
    }

    /**
     * Builds a 1 x {@code i} random row vector seeded with {@code j * seed}, shifted by -0.5
     * and divided by {@code i}. Currently not called from within this class.
     */
    private INDArray getRandomSyn0Vec(int i, long j) {
        return Nd4j.rand(new int[]{1, i}, j * this.seed).subi(Double.valueOf(0.5d)).divi(Integer.valueOf(i));
    }
}
