package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.NonNull;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.class */
/**
 * Process-wide holder for the negative-sampling state used by word2vec training: the
 * {@code syn1Neg} weight matrix and the unigram sampling table.
 *
 * <p>Obtain it via {@link #getInstance()} and initialise it once with
 * {@link #initHolder(VocabCache, double[], int)}; later calls to {@code initHolder} are no-ops.
 */
public class NegativeHolder implements Serializable {
    /** Lower bound on the number of entries in the sampling table. */
    private static final int MIN_TABLE_SIZE = 100000;
    /** Exponent applied to each word frequency when building the sampling table. */
    private static final double TABLE_POWER = 0.75d;

    private static final NegativeHolder ourInstance = new NegativeHolder();

    private volatile INDArray syn1Neg;
    private volatile INDArray table;
    // Not serialized; rebuilt in readObject so deserialized instances do not hold a null flag.
    private transient AtomicBoolean wasInit = new AtomicBoolean(false);
    private transient VocabCache<VocabWord> vocab;

    /**
     * Returns the shared holder instance.
     *
     * @return the singleton {@code NegativeHolder}
     */
    public static NegativeHolder getInstance() {
        return ourInstance;
    }

    private NegativeHolder() {
    }

    /**
     * Initialises {@code syn1Neg} and the sampling table from the given vocabulary. Only the
     * first call has an effect; subsequent calls return immediately.
     *
     * @param vocabCache vocabulary whose size and word frequencies drive initialisation
     * @param expTable   only its length is used, as a candidate table size (the table has at
     *                   least {@value #MIN_TABLE_SIZE} entries); must not be {@code null}
     * @param layerSize  number of columns of the zero-initialised {@code syn1Neg} matrix
     * @throws NullPointerException if {@code vocabCache} is {@code null}
     */
    public synchronized void initHolder(@NonNull VocabCache<VocabWord> vocabCache, double[] expTable, int layerSize) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        if (this.wasInit.get()) {
            return;
        }
        this.vocab = vocabCache;
        this.syn1Neg = Nd4j.zeros(vocabCache.numWords(), layerSize);
        makeTable(Math.max(expTable.length, MIN_TABLE_SIZE), TABLE_POWER);
        this.wasInit.set(true);
    }

    /**
     * Fills {@link #table} with word indices. Each word occupies a share of the table
     * proportional to {@code frequency^power}, normalised by the sum over the vocabulary.
     *
     * <p>NOTE(review): behaviour for an empty vocabulary depends on how
     * {@code wordFrequency(wordAtIndex(0))} handles a missing word; not guarded here.
     *
     * @param tableSize number of entries in the table
     * @param power     exponent applied to each word frequency
     */
    protected void makeTable(int tableSize, double power) {
        int numWords = this.vocab.numWords();
        this.table = Nd4j.create(new FloatBuffer(tableSize));

        // Normaliser: sum of frequency^power over the whole vocabulary.
        double totalPow = 0.0d;
        for (String word : this.vocab.words()) {
            totalPow += Math.pow(this.vocab.wordFrequency(word), power);
        }

        int wordIdx = 0;
        // Cumulative probability mass covered by words [0, wordIdx].
        double cumulative = Math.pow(this.vocab.wordFrequency(this.vocab.wordAtIndex(0)), power) / totalPow;
        for (int i = 0; i < tableSize; i++) {
            this.table.putScalar(i, wordIdx);
            if ((i * 1.0d) / tableSize > cumulative) {
                // Advance to the next word, clamping at the last vocabulary index.
                if (wordIdx < numWords - 1) {
                    wordIdx++;
                }
                String word = this.vocab.wordAtIndex(wordIdx);
                if (word != null) {
                    cumulative += Math.pow(this.vocab.wordFrequency(word), power) / totalPow;
                }
            }
        }
    }

    /**
     * Restores the transient init flag after Java deserialization, which skips field
     * initializers. Without this, {@code wasInit} would be {@code null} and
     * {@link #initHolder} would throw. The flag is considered set when both serialized
     * arrays are present, matching the state {@code initHolder} leaves behind.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.wasInit = new AtomicBoolean(this.syn1Neg != null && this.table != null);
    }

    /**
     * @return the {@code syn1Neg} matrix, or {@code null} before initialisation
     */
    public INDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    /**
     * @return the negative-sampling table, or {@code null} before initialisation
     */
    public INDArray getTable() {
        return this.table;
    }
}
