package org.deeplearning4j.models.glove;

import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectorsImpl;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;

/**
 * Weight lookup table used for GloVe training.
 *
 * <p>In addition to the word vectors ({@code syn0}) inherited from {@link InMemoryLookupTable}, it
 * keeps one bias term per row of {@code syn0} and separate {@link AdaGrad} instances for the vector
 * and bias updates. The weighting function is {@code f(X) = min(1, X / maxCount)^xMax}.
 *
 * @deprecated legacy GloVe implementation, kept for backward compatibility.
 */
@Deprecated
public class GloveWeightLookupTable<T extends SequenceElement> extends InMemoryLookupTable<T> {
    /** AdaGrad used for word-vector updates; (re)created by {@link #resetWeights(boolean)}. */
    private AdaGrad weightAdaGrad;
    /** AdaGrad used for bias updates; (re)created by {@link #resetWeights(boolean)}. */
    private AdaGrad biasAdaGrad;
    /** One bias term per row of {@code syn0}. */
    private INDArray bias;
    /** Exponent of the weighting function (named "xMax" for API compatibility); default 0.75. */
    private double xMax = 0.75d;
    /** Co-occurrence count at and above which the weighting function equals 1; default 100. */
    private double maxCount = 100.0d;

    /** Builder for {@link GloveWeightLookupTable}; adds {@code xMax} and {@code maxCount}. */
    public static class Builder<T extends SequenceElement> extends InMemoryLookupTable.Builder<T> {
        private double xMax = 0.75d;
        private double maxCount = 100.0d;

        /** Sets the count at which the weighting function saturates. */
        public Builder<T> maxCount(double maxCount) {
            this.maxCount = maxCount;
            return this;
        }

        /** Sets the exponent of the weighting function. */
        public Builder<T> xMax(double xMax) {
            this.xMax = xMax;
            return this;
        }

        @Override
        public Builder<T> cache(VocabCache<T> vocabCache) {
            super.cache(vocabCache);
            return this;
        }

        @Override
        public Builder<T> negative(double negative) {
            super.negative(negative);
            return this;
        }

        @Override
        public Builder<T> vectorLength(int vectorLength) {
            super.vectorLength(vectorLength);
            return this;
        }

        @Override
        public Builder<T> useAdaGrad(boolean useAdaGrad) {
            super.useAdaGrad(useAdaGrad);
            return this;
        }

        @Override
        public Builder<T> lr(double lr) {
            super.lr(lr);
            return this;
        }

        @Override
        public Builder<T> gen(Random gen) {
            super.gen(gen);
            return this;
        }

        @Override
        public Builder<T> seed(long seed) {
            super.seed(seed);
            return this;
        }

        @Override
        public GloveWeightLookupTable<T> build() {
            return new GloveWeightLookupTable<>(this.vocabCache, this.vectorLength, this.useAdaGrad, this.lr,
                            this.gen, this.negative, this.xMax, this.maxCount);
        }
    }

    /**
     * Creates a lookup table.
     *
     * @param vocab        vocabulary backing the table
     * @param vectorLength word-vector dimensionality
     * @param useAdaGrad   passed through to {@link InMemoryLookupTable}
     * @param lr           learning rate
     * @param gen          random generator (may be null; {@link #resetWeights(boolean)} then uses Nd4j's)
     * @param negative     passed through to {@link InMemoryLookupTable}
     * @param xMax         exponent of the weighting function
     * @param maxCount     count at which the weighting function saturates at 1
     */
    public GloveWeightLookupTable(VocabCache<T> vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen,
                    double negative, double xMax, double maxCount) {
        super(vocab, vectorLength, useAdaGrad, lr, gen, negative);
        this.xMax = xMax;
        this.maxCount = maxCount;
    }

    /**
     * Initializes (or, when {@code reset} is true, re-initializes) the weights, biases and AdaGrad
     * state. Components that already exist are kept when {@code reset} is false.
     *
     * @param reset whether to overwrite already-initialized state
     */
    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null) {
            this.rng = Nd4j.getRandom();
        }
        if (this.syn0 == null || reset) {
            // numWords() + 1 rows; the extra row presumably backs the UNK vector put below — TODO confirm.
            this.syn0 = Nd4j.rand(new int[] {this.vocab.numWords() + 1, this.vectorLength}, this.rng)
                            .subi(0.5d).divi(this.vectorLength);
            putVector(WordVectorsImpl.DEFAULT_UNK,
                            Nd4j.rand(1, this.vectorLength, this.rng).subi(0.5d).divi(this.vectorLength));
        }
        if (this.weightAdaGrad == null || reset) {
            this.weightAdaGrad = new AdaGrad(new long[] {this.vocab.numWords() + 1, this.vectorLength}, this.lr.get());
        }
        if (this.bias == null || reset) {
            this.bias = Nd4j.create(this.syn0.rows());
        }
        if (this.biasAdaGrad == null || reset) {
            this.biasAdaGrad = new AdaGrad(this.bias.shape(), this.lr.get());
        }
    }

    /** Re-initializes all weights, biases and AdaGrad state. */
    @Override
    public void resetWeights() {
        resetWeights(true);
    }

    /**
     * Performs one GloVe update for a co-occurring word pair and returns the weighted error.
     *
     * <p>The error is {@code f(X) * (w1·w2 + b1 + b2 - log X)} with
     * {@code f(X) = min(1, X / maxCount)^xMax}. A NaN error (e.g. from a non-positive count) is
     * replaced by {@link Nd4j#EPS_THRESHOLD}. Each word's vector is then reduced by the gradient
     * returned by the weight AdaGrad (error times the other word's vector), and each bias by the
     * gradient returned by the bias AdaGrad. Note that {@code w2}'s update uses {@code w1}'s vector
     * after {@code w1} has already been updated.
     *
     * @param w1    first word
     * @param w2    second word
     * @param score co-occurrence count X of the pair
     * @return the weighted error used for the update
     * @throws IllegalArgumentException if either word's index is outside {@code syn0}
     */
    public double iterateSample(T w1, T w2, double score) {
        // Validate before slicing so bad indices fail with a meaningful message.
        checkIndex(w1);
        checkIndex(w2);
        int i1 = w1.getIndex();
        int i2 = w2.getIndex();
        INDArray w1Vector = this.syn0.slice(i1);
        INDArray w2Vector = this.syn0.slice(i2);

        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector) + this.bias.getDouble(i1)
                        + this.bias.getDouble(i2);
        // min(...) already caps the weight at 1 for counts >= maxCount, so no branch is needed.
        double weight = Math.pow(Math.min(1.0d, score / this.maxCount), this.xMax);
        double diff = weight * (prediction - Math.log(score));
        if (Double.isNaN(diff)) {
            diff = Nd4j.EPS_THRESHOLD;
        }
        // Each word's gradient is taken w.r.t. the OTHER word's vector.
        update(w1, w1Vector, w2Vector, diff);
        update(w2, w2Vector, w1Vector, diff);
        return diff;
    }

    /** Throws if the element's index does not address a row of {@code syn0}. */
    private void checkIndex(T word) {
        int index = word.getIndex();
        if (index < 0 || index >= this.syn0.rows()) {
            throw new IllegalArgumentException("Illegal index for word " + word.getLabel());
        }
    }

    /** Applies one AdaGrad step to {@code vector} (in place) and to the word's bias term. */
    private void update(T word, INDArray vector, INDArray otherVector, double diff) {
        int index = word.getIndex();
        vector.subi(this.weightAdaGrad.getGradient(otherVector.mul(diff), index, this.syn0.shape()));
        this.bias.putScalar(index,
                        this.bias.getDouble(index) - this.biasAdaGrad.getGradient(diff, index, this.bias.shape()));
    }

    /** Returns the AdaGrad used for word-vector updates (may be null before initialization). */
    public AdaGrad getWeightAdaGrad() {
        return this.weightAdaGrad;
    }

    /** Returns the AdaGrad used for bias updates (may be null before initialization). */
    public AdaGrad getBiasAdaGrad() {
        return this.biasAdaGrad;
    }

    /**
     * Loads word vectors from a UTF-8 text stream with lines of the form
     * {@code word v1 v2 ... vN} (single-space separated).
     *
     * <p>The vector length is taken from the first non-blank line. Lines whose number of values
     * differs from it are skipped. The iterator over the stream is closed when this method returns
     * or throws.
     *
     * @param inputStream source of the vectors
     * @param vocabCache  vocabulary used to map words to row indices
     * @return the loaded table
     * @throws IOException if reading fails or the stream contains no data lines
     */
    public static GloveWeightLookupTable load(InputStream inputStream, VocabCache<? extends SequenceElement> vocabCache)
                    throws IOException {
        LineIterator lines = IOUtils.lineIterator(inputStream, "UTF-8");
        try {
            // Raw type: a static method has no access to the class type parameter T.
            GloveWeightLookupTable glove = null;
            Map<String, float[]> wordVectors = new HashMap<>();
            while (lines.hasNext()) {
                String line = lines.nextLine().trim();
                if (line.isEmpty()) {
                    continue;
                }
                String[] tokens = line.split(" ");
                String word = tokens[0];
                if (glove == null) {
                    glove = new Builder().cache((VocabCache) vocabCache).vectorLength(tokens.length - 1).build();
                }
                int dim = glove.layerSize();
                if (word.isEmpty() || dim < 1 || tokens.length - 1 != dim) {
                    continue;
                }
                wordVectors.put(word, read(tokens, dim));
            }
            if (glove == null) {
                throw new IOException("No word vectors found in input");
            }
            glove.setSyn0(weights(glove, wordVectors, vocabCache));
            glove.resetWeights(false);
            return glove;
        } finally {
            lines.close();
        }
    }

    /**
     * Builds a {@code map.size() x layerSize} matrix, placing each vector at its word's vocabulary
     * index. Words with an index outside {@code [0, map.size())}, or vectors of the wrong length,
     * are left out (their rows stay as created).
     */
    private static INDArray weights(GloveWeightLookupTable glove, Map<String, float[]> map, VocabCache vocabCache) {
        int rows = map.size();
        int dim = glove.layerSize();
        INDArray matrix = Nd4j.create(rows, dim);
        for (Map.Entry<String, float[]> entry : map.entrySet()) {
            int index = vocabCache.indexOf(entry.getKey());
            if (index < 0 || index >= rows) {
                continue;
            }
            INDArray vector = Nd4j.create(Nd4j.createBuffer(entry.getValue()));
            if (vector.length() == dim) {
                matrix.putRow(index, vector);
            }
        }
        return matrix;
    }

    /** Parses {@code tokens[1..]} into a float array of length {@code dim}; extra tokens are ignored. */
    private static float[] read(String[] tokens, int dim) {
        float[] values = new float[dim];
        int count = Math.min(dim, tokens.length - 1);
        for (int i = 0; i < count; i++) {
            values[i] = Float.parseFloat(tokens[i + 1]);
        }
        return values;
    }

    /** Not supported by the GloVe table; use {@link #iterateSample(SequenceElement, SequenceElement, double)}. */
    @Override
    public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        throw new UnsupportedOperationException(
                        "GloveWeightLookupTable only supports iterateSample(T, T, double)");
    }

    /** Returns the exponent of the weighting function. */
    public double getxMax() {
        return this.xMax;
    }

    /** Sets the exponent of the weighting function. */
    public void setxMax(double xMax) {
        this.xMax = xMax;
    }

    /** Returns the count at which the weighting function saturates at 1. */
    public double getMaxCount() {
        return this.maxCount;
    }

    /** Sets the count at which the weighting function saturates at 1. */
    public void setMaxCount(double maxCount) {
        this.maxCount = maxCount;
    }

    /** Returns the bias vector (may be null before initialization). */
    public INDArray getBias() {
        return this.bias;
    }

    /** Replaces the bias vector. */
    public void setBias(INDArray bias) {
        this.bias = bias;
    }
}
