package org.deeplearning4j.models.embeddings.learning.impl.elements;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.learning.ElementsLearningAlgorithm;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.impl.AggregateCBOW;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.DeviceLocalNDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.class */
/**
 * Continuous Bag-of-Words ("CBOW") element learning algorithm.
 *
 * <p>For every position of a sequence, the indices of the surrounding context elements are
 * collected and handed to an {@link AggregateCBOW} op together with the current element. Ops are
 * buffered in a per-thread ({@link ThreadLocal}) batch. The batch is executed through the ND4J
 * executioner in three cases: it reaches the configured batch size, it exceeds an internal hard
 * limit, or {@link #finish()} is called.
 *
 * <p>{@link #configure} must be called before training. It casts the supplied lookup table to
 * {@link InMemoryLookupTable}, so other {@link WeightLookupTable} implementations fail with a
 * {@link ClassCastException}.
 */
public class CBOW<T extends SequenceElement> implements ElementsLearningAlgorithm<T> {
    /** Queued aggregates beyond this count force an immediate flush in {@link #iterateSample}. */
    private static final int MAX_PENDING_AGGREGATES = 4096;

    /** Multiplier/increment of the LCG-style step used to advance {@code nextRandom}. */
    private static final long RANDOM_MULTIPLIER = 25214903917L;
    private static final long RANDOM_INCREMENT = 11L;

    private VocabCache<T> vocabCache;
    private WeightLookupTable<T> lookupTable;
    private VectorsConfiguration configuration;
    private static final Logger logger = LoggerFactory.getLogger(CBOW.class);
    protected static double MAX_EXP = 6.0d;
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    protected double sampling;
    protected int[] variableWindows;
    protected DeviceLocalNDArray syn0;
    protected DeviceLocalNDArray syn1;
    protected DeviceLocalNDArray syn1Neg;
    protected DeviceLocalNDArray expTable;
    protected DeviceLocalNDArray table;
    protected ThreadLocal<List<Aggregate>> batches = new ThreadLocal<>();

    /**
     * Returns the calling thread's queue of pending aggregates.
     *
     * @return the pending list, or {@code null} if this thread has not queued anything yet
     */
    public List<Aggregate> getBatch() {
        return this.batches.get();
    }

    @Override
    public String getCodeName() {
        return "CBOW";
    }

    /**
     * Binds this learner to a vocabulary, a lookup table and a configuration.
     *
     * <p>When negative sampling is enabled and the table has no {@code syn1Neg} yet, the table's
     * weights are reset with HS/negative settings taken from the configuration. The table's arrays
     * are then wrapped in {@link DeviceLocalNDArray}s.
     *
     * @param vocabCache    vocabulary used for sizes and subsampling statistics
     * @param lookupTable   weight table; must be an {@link InMemoryLookupTable}
     * @param configuration training hyper-parameters
     * @throws NullPointerException if any argument is {@code null}
     * @throws ClassCastException   if {@code lookupTable} is not an {@link InMemoryLookupTable}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> lookupTable,
                    @NonNull VectorsConfiguration configuration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache is marked @NonNull but is null");
        }
        if (lookupTable == null) {
            throw new NullPointerException("lookupTable is marked @NonNull but is null");
        }
        if (configuration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = lookupTable;
        this.configuration = configuration;
        this.window = configuration.getWindow();
        this.useAdaGrad = configuration.isUseAdaGrad();
        this.negative = configuration.getNegative();
        this.sampling = configuration.getSampling();

        if (configuration.getNegative() > 0.0d
                        && ((InMemoryLookupTable<T>) lookupTable).getSyn1Neg() == null) {
            InMemoryLookupTable<T> inMemory = (InMemoryLookupTable<T>) lookupTable;
            logger.info("Initializing syn1Neg...");
            inMemory.setUseHS(configuration.isUseHierarchicSoftmax());
            inMemory.setNegative(configuration.getNegative());
            inMemory.resetWeights(false);
        }

        InMemoryLookupTable<T> inMemory = (InMemoryLookupTable<T>) lookupTable;
        this.syn0 = new DeviceLocalNDArray(inMemory.getSyn0());
        this.syn1 = new DeviceLocalNDArray(inMemory.getSyn1());
        this.syn1Neg = new DeviceLocalNDArray(inMemory.getSyn1Neg());
        this.expTable = new DeviceLocalNDArray(Nd4j.create(inMemory.getExpTable()));
        this.table = new DeviceLocalNDArray(inMemory.getTable());
        this.variableWindows = configuration.getVariableWindows();
    }

    /** No pre-training pass is needed for CBOW; this is intentionally a no-op. */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
        // intentionally empty
    }

    /** Executes and clears any aggregates still queued by the calling thread. */
    @Override
    public void finish() {
        List<Aggregate> batch = this.batches == null ? null : this.batches.get();
        if (batch != null && !batch.isEmpty()) {
            flushBatch(batch);
        }
    }

    /**
     * Trains on one sequence, running a CBOW step for every element position.
     *
     * <p>If {@code sampling > 0} the sequence is first subsampled. If variable windows are
     * configured, one of them is picked at random for the whole sequence. Each position then
     * draws a random offset {@code b} in {@code [0, window)} that shrinks the context span.
     *
     * @param sequence     sequence to learn from
     * @param nextRandom   shared pseudo-random state; advanced by this call
     * @param learningRate learning rate forwarded to every aggregate
     * @return always {@code 0.0}; no score is computed here
     */
    @Override
    public double learnSequence(Sequence<T> sequence, AtomicLong nextRandom, double learningRate) {
        Sequence<T> effective = this.sampling > 0.0d ? applySubsampling(sequence, nextRandom) : sequence;

        int currentWindow = this.window;
        if (this.variableWindows != null && this.variableWindows.length != 0) {
            // NOTE(review): RandomUtils is not imported in this file; confirm it resolves.
            currentWindow = this.variableWindows[RandomUtils.nextInt(this.variableWindows.length)];
        }

        List<T> elements = effective.getElements();
        for (int position = 0; position < elements.size(); position++) {
            advanceRandom(nextRandom);
            // Reduce the full long first. Casting to int before '%' can produce a negative
            // offset, which would silently widen the context window.
            int b = (int) Math.floorMod(nextRandom.get(), (long) currentWindow);
            cbow(position, elements, b, nextRandom, learningRate, currentWindow);
        }
        return 0.0d;
    }

    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Builds one {@link AggregateCBOW} op for {@code currentWord} and either executes it at once
     * or queues it in the calling thread's batch.
     *
     * <p>With hierarchical softmax enabled, the element's points/codes are copied into int
     * arrays; positions whose point is negative are left as {@code 0}. Otherwise both arrays are
     * empty. {@code nextRandom} is read for the op and then advanced.
     *
     * @param currentWord     target element
     * @param windowWords     vocabulary indices of the context elements
     * @param nextRandom      shared pseudo-random state; advanced by this call
     * @param alpha           learning rate forwarded to the op
     * @param isInference     when {@code true}, the op is executed immediately instead of queued
     * @param numLabels       forwarded to {@link AggregateCBOW} as-is
     * @param trainWords      forwarded to {@link AggregateCBOW} as-is
     * @param inferenceVector forwarded to {@link AggregateCBOW} as-is; may be {@code null}
     */
    public void iterateSample(T currentWord, int[] windowWords, AtomicLong nextRandom, double alpha,
                    boolean isInference, int numLabels, boolean trainWords, INDArray inferenceVector) {
        int[] points;
        int[] codes;
        if (this.configuration.isUseHierarchicSoftmax()) {
            int codeLength = currentWord.getCodeLength();
            points = new int[codeLength];
            codes = new int[codeLength];
            for (int p = 0; p < codeLength; p++) {
                int point = currentWord.getPoints().get(p).intValue();
                if (point >= 0) {
                    codes[p] = currentWord.getCodes().get(p).byteValue();
                    points[p] = point;
                }
            }
        } else {
            points = new int[0];
            codes = new int[0];
        }

        // Lazily create syn1Neg if negative sampling is on but no wrapper exists yet.
        if (this.negative > 0.0d && this.syn1Neg == null) {
            ((InMemoryLookupTable<T>) this.lookupTable).initNegative();
            this.syn1Neg = new DeviceLocalNDArray(((InMemoryLookupTable<T>) this.lookupTable).getSyn1Neg());
        }

        List<Aggregate> batch = this.batches.get();
        if (batch == null) {
            batch = new ArrayList<>();
            this.batches.set(batch);
        }

        AggregateCBOW op = new AggregateCBOW(this.syn0.get(), this.syn1.get(), this.syn1Neg.get(),
                        this.expTable.get(), this.table.get(), currentWord.getIndex(), windowWords, points, codes,
                        (int) this.negative, currentWord.getIndex(), this.lookupTable.layerSize(), alpha,
                        nextRandom.get(), this.vocabCache.numWords(), numLabels, trainWords, inferenceVector);
        advanceRandom(nextRandom);

        if (isInference) {
            Nd4j.getExecutioner().exec(op);
            return;
        }

        batch.add(op);
        if (batch.size() > MAX_PENDING_AGGREGATES) {
            flushBatch(batch);
        }
    }

    /**
     * Runs one CBOW step for the element at {@code i}.
     *
     * <p>Context positions are {@code i - currentWindow + a} for {@code a} in
     * {@code [b, 2*currentWindow + 1 - b)}, skipping {@code a == currentWindow} (the target
     * itself) and any position outside the list. The resulting span is symmetric:
     * {@code currentWindow - b} elements on each side. After queueing the aggregate, the
     * thread's batch is flushed once it reaches the configured batch size.
     *
     * @param i             index of the target element in {@code sentence}
     * @param sentence      elements of the (possibly subsampled) sequence
     * @param b             random window reduction, expected in {@code [0, currentWindow)}
     * @param nextRandom    shared pseudo-random state
     * @param alpha         learning rate
     * @param currentWindow window size chosen for this sequence
     */
    public void cbow(int i, List<T> sentence, int b, AtomicLong nextRandom, double alpha, int currentWindow) {
        // Derive the span from currentWindow, not this.window. Mixing the two made the span
        // asymmetric whenever a variable window was in use.
        int end = currentWindow * 2 + 1 - b;
        T currentWord = sentence.get(i);

        List<Integer> context = new ArrayList<>();
        for (int a = b; a < end; a++) {
            if (a == currentWindow) {
                continue; // centre position is the target element itself
            }
            int c = i - currentWindow + a;
            if (c >= 0 && c < sentence.size()) {
                context.add(sentence.get(c).getIndex());
            }
        }
        int[] windowWords = context.stream().mapToInt(Integer::intValue).toArray();

        iterateSample(currentWord, windowWords, nextRandom, alpha, false, 0, true, null);

        List<Aggregate> batch = this.batches == null ? null : this.batches.get();
        if (batch != null && batch.size() >= this.configuration.getBatchSize()) {
            flushBatch(batch);
        }
    }

    /**
     * Returns a copy of {@code sequence} with frequent elements randomly discarded.
     *
     * <p>Each element is kept when
     * {@code (sqrt(f / (s*N)) + 1) * (s*N) / f >= (nextRandom & 0xFFFF) / 65536}.
     * Here {@code f} is the element frequency, {@code s} the sampling rate and {@code N} the
     * vocabulary's total word occurrences. Sequence id and labels are copied over. If
     * {@code sampling <= 0}, the input is returned unchanged.
     *
     * @param sequence   sequence to subsample
     * @param nextRandom shared pseudo-random state; advanced once per element
     * @return the subsampled sequence, or {@code sequence} itself when sampling is disabled
     * @throws NullPointerException if either argument is {@code null}
     */
    public Sequence<T> applySubsampling(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom) {
        if (sequence == null) {
            throw new NullPointerException("sequence is marked @NonNull but is null");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom is marked @NonNull but is null");
        }
        if (this.sampling <= 0.0d) {
            return sequence;
        }

        Sequence<T> result = new Sequence<>();
        result.setSequenceId(sequence.getSequenceId());
        if (sequence.getSequenceLabels() != null) {
            result.setSequenceLabels(sequence.getSequenceLabels());
        }
        if (sequence.getSequenceLabel() != null) {
            result.setSequenceLabel(sequence.getSequenceLabel());
        }

        // Loop-invariant: sampling rate scaled by the corpus size.
        double threshold = this.sampling * this.vocabCache.totalWordOccurrences();
        for (T element : sequence.getElements()) {
            double keepScore = ((Math.sqrt(element.getElementFrequency() / threshold) + 1.0d) * threshold)
                            / element.getElementFrequency();
            advanceRandom(nextRandom);
            if (keepScore >= (nextRandom.get() & 65535) / 65536.0d) {
                result.addElement(element);
            }
        }
        return result;
    }

    /** Advances {@code nextRandom} by one LCG-style step ({@code |x * M + 11|}). */
    private static void advanceRandom(AtomicLong nextRandom) {
        nextRandom.set(Math.abs((nextRandom.get() * RANDOM_MULTIPLIER) + RANDOM_INCREMENT));
    }

    /** Executes every aggregate in {@code batch} through the ND4J executioner, then empties it. */
    private static void flushBatch(List<Aggregate> batch) {
        Nd4j.getExecutioner().exec(batch);
        batch.clear();
    }

    public DeviceLocalNDArray getSyn0() {
        return this.syn0;
    }

    public DeviceLocalNDArray getSyn1() {
        return this.syn1;
    }

    public DeviceLocalNDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    public DeviceLocalNDArray getExpTable() {
        return this.expTable;
    }

    public DeviceLocalNDArray getTable() {
        return this.table;
    }

    public void setSyn0(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn0 = deviceLocalNDArray;
    }

    public void setSyn1(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn1 = deviceLocalNDArray;
    }

    public void setSyn1Neg(DeviceLocalNDArray deviceLocalNDArray) {
        this.syn1Neg = deviceLocalNDArray;
    }

    public void setExpTable(DeviceLocalNDArray deviceLocalNDArray) {
        this.expTable = deviceLocalNDArray;
    }

    public void setTable(DeviceLocalNDArray deviceLocalNDArray) {
        this.table = deviceLocalNDArray;
    }
}
