package org.deeplearning4j.models.embeddings.reader.impl;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.class */
/**
 * {@link BasicModelUtils} variant that answers nearest-neighbour queries with a flat, exhaustive
 * scan: the query vector is compared by cosine similarity against the vector of every word
 * returned by the vocabulary cache.
 *
 * @param <T> the vocabulary element type
 */
public class FlatModelUtils<T extends SequenceElement> extends BasicModelUtils<T> {
    private static final Logger log = LoggerFactory.getLogger(FlatModelUtils.class);

    /**
     * Returns up to {@code n} vocabulary words most similar to {@code label}, excluding
     * {@code label} itself.
     *
     * <p>The label is removed <em>before</em> the ranking is truncated, so a word present in the
     * vocabulary yields {@code n} neighbours (fewer only if there are not enough other words)
     * rather than {@code n - 1}.
     *
     * @param label the query word
     * @param n maximum number of neighbours to return; non-positive values yield an empty list
     * @return neighbour words in descending order of cosine similarity; empty when the lookup
     *     table has no vector for {@code label}
     */
    @Override
    public Collection<String> wordsNearest(String label, int n) {
        INDArray vector = this.lookupTable.vector(label);
        if (vector == null) {
            // Presumably an out-of-vocabulary label; scanning would otherwise NPE on dup().
            return new ArrayList<>();
        }
        List<String> ranked = rankVocabularyBySimilarity(vector);
        ranked.remove(label);
        return firstN(ranked, n);
    }

    /**
     * Returns the {@code top} vocabulary words whose vectors are most similar to {@code words}.
     *
     * @param words the query vector (not modified; it is duplicated before normalisation)
     * @param top maximum number of words to return; non-positive values yield an empty list
     * @return words in descending order of cosine similarity to {@code words}
     */
    @Override
    public Collection<String> wordsNearest(INDArray words, int top) {
        return firstN(rankVocabularyBySimilarity(words), top);
    }

    /**
     * Scores every vocabulary word by cosine similarity to {@code query} and returns all words,
     * most similar first, in a freshly allocated mutable list.
     */
    private List<String> rankVocabularyBySimilarity(INDArray query) {
        // The normalised query does not depend on the loop variable, so compute it once.
        // dup() keeps the caller's array untouched by the normalisation.
        INDArray queryUnit = Transforms.unitVec(query.dup());

        Counter<String> similarities = new Counter<>();
        for (String word : this.vocabCache.words()) {
            INDArray candidateUnit = Transforms.unitVec(this.lookupTable.vector(word).dup());
            similarities.incrementCount(word, (float) Transforms.cosineSim(queryUnit, candidateUnit));
        }
        return new ArrayList<>(similarities.getSortedKeys());
    }

    /** Copies the first {@code n} entries of {@code ranked}, clamping {@code n} into range. */
    private static List<String> firstN(List<String> ranked, int n) {
        int limit = Math.max(0, Math.min(n, ranked.size()));
        return new ArrayList<>(ranked.subList(0, limit));
    }
}
