package de.datexis.sector.encoder;

import de.datexis.common.WordHelpers;
import de.datexis.encoder.impl.BagOfWordsEncoder;
import de.datexis.model.Span;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/sector/encoder/HeadingEncoder.class */
/**
 * Bag-of-words encoder for headings, registered under the ID {@value #ID}.
 *
 * <p>It extends {@link BagOfWordsEncoder} in two ways: every {@code encode*} variant
 * substitutes the vector of {@link #encodeOtherClass()} whenever the superclass
 * encoding does not sum to a positive value, and {@link #getNearestNeighbours}
 * decodes a score vector back into the highest-ranked vocabulary words.
 */
public class HeadingEncoder extends BagOfWordsEncoder {
    /** Encoder ID handed to the superclass constructor. */
    public static final String ID = "HL";
    protected static final Logger log = LoggerFactory.getLogger(HeadingEncoder.class);
    /**
     * Label of the fallback class. Its preprocessed form is added to the vocabulary
     * by {@link #trainModel}, and {@link #getNearestNeighbours} returns it when no
     * other word qualifies.
     */
    public static String OTHER_CLASS = "other";

    public HeadingEncoder() {
        super(ID);
    }

    /**
     * Rebuilds the vocabulary from a list of heading strings.
     *
     * <p>Each heading is split with {@code WordHelpers.splitSpaces} and every token is
     * run through the preprocessor. Every non-empty token counts towards
     * {@code totalWords}; only tokens that are not stop words and are at least
     * {@code i2} characters long are added to (or counted in) the vocabulary.
     * Afterwards the vocabulary is truncated with {@code i}, the preprocessed
     * {@link #OTHER_CLASS} label is added, and the model is marked as available.
     *
     * @param list     heading strings to train on
     * @param i        argument passed to {@code vocab.truncateVocabulary}
     *                 (presumably a minimum word frequency; confirm against the vocabulary class)
     * @param i2       minimum length of a preprocessed token to enter the vocabulary
     * @param language language passed to {@code setLanguage} before tokens are processed
     */
    public void trainModel(List<String> list, int i, int i2, WordHelpers.Language language) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.totalWords = 0;
        this.timer.start();
        setLanguage(language);
        for (String heading : list) {
            for (String rawToken : WordHelpers.splitSpaces(heading)) {
                String token = this.preprocessor.preProcess(rawToken);
                if (token.isEmpty()) {
                    continue;
                }
                this.totalWords++;
                if (this.wordHelpers.isStopWord(token) || token.length() < i2) {
                    continue;
                }
                if (this.vocab.containsWord(token)) {
                    this.vocab.incrementWordCounter(token);
                } else {
                    this.vocab.addWord(token);
                }
            }
        }
        // Remember the vocabulary size before truncation; it is only used in the training log.
        int untruncatedSize = this.vocab.numWords();
        this.vocab.truncateVocabulary(i);
        // Added after truncation, so the fallback label is not subject to it.
        this.vocab.addWord(this.preprocessor.preProcess(OTHER_CLASS));
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        appendTrainLog("trained " + this.vocab.numWords() + " words (" + untruncatedSize + " total)", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Encodes a heading given as a plain string.
     *
     * @param str heading text; {@code null} yields the other-class vector
     * @return the encoding of the space-split tokens, or the other-class vector
     */
    public INDArray encode(String str) {
        if (str == null) {
            return encodeOtherClass();
        }
        return encode(WordHelpers.splitSpaces(str));
    }

    /**
     * Encodes a sequence of spans via the superclass.
     *
     * @param iterable spans to encode
     * @return the superclass encoding, or the other-class vector if its sum is not positive
     */
    public INDArray encode(Iterable<? extends Span> iterable) {
        return otherClassIfEmpty(super.encode(iterable));
    }

    /**
     * Encodes an array of tokens via the superclass.
     *
     * @param strArr tokens to encode
     * @return the superclass encoding, or the other-class vector if its sum is not positive
     */
    protected INDArray encode(String[] strArr) {
        return otherClassIfEmpty(super.encode(strArr));
    }

    /**
     * Delegates to the superclass {@code encodeSubsampled}.
     *
     * @param str heading text
     * @return the superclass result, or the other-class vector if its sum is not positive
     */
    public INDArray encodeSubsampled(String str) {
        return otherClassIfEmpty(super.encodeSubsampled(str));
    }

    /** Returns {@code encoded} if its elements sum to a positive value, otherwise the other-class vector. */
    private INDArray otherClassIfEmpty(INDArray encoded) {
        return encoded.sumNumber().doubleValue() > 0.0d ? encoded : encodeOtherClass();
    }

    /**
     * Builds the vector used when an input has no positive encoding.
     *
     * @return an all-zero array of shape {@code (getEmbeddingVectorSize(), 1)}
     */
    protected INDArray encodeOtherClass() {
        return Nd4j.zeros(getEmbeddingVectorSize(), 1L);
    }

    /**
     * Decodes a score vector into its highest-ranked vocabulary words.
     *
     * <p>A copy of the vector is sorted in descending order. Words are collected from
     * the top until {@code i} words are found, the vector is exhausted, or a score is
     * either zero or below the midpoint between the top score and the median score.
     * The {@link #OTHER_CLASS} label is skipped while collecting and does not count
     * towards {@code i}.
     *
     * @param iNDArray score vector; presumably one entry per vocabulary index, since the
     *                 ranked indices are resolved with {@code getWord} (confirm against callers)
     * @param i        maximum number of words to return; values below 1 collect nothing
     * @return the collected words in rank order, or a single-element collection
     *         containing {@link #OTHER_CLASS} if none qualified
     */
    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        // Sort a copy so the caller's vector is left untouched; [0] = ranked indices, [1] = ranked scores.
        INDArray[] sorted = Nd4j.sortWithIndices(iNDArray.dup(), 0, false);
        INDArray rankedIndices = sorted[0];
        INDArray rankedScores = sorted[1];
        // The zero-vector check must look at the scores; summing the index array says nothing about them.
        if (rankedScores.sumNumber().doubleValue() == 0.0d) {
            log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        }
        double topScore = rankedScores.getDouble(0L);
        double medianScore = rankedScores.medianNumber().doubleValue();
        double cutoff = (topScore + medianScore) / 2.0d;
        long entryCount = rankedIndices.length();
        // Not presized with i: a negative i would throw and a huge i would over-allocate.
        List<String> neighbours = new ArrayList<>();
        int position = 0;
        // Bounding position by entryCount prevents reading past the end when fewer than
        // i words qualify but no score ever falls below the cut-off.
        while (neighbours.size() < i && position < entryCount) {
            String word = getWord(rankedIndices.getInt(new int[]{position}));
            double score = rankedScores.getDouble(position);
            if (score == 0.0d || score < cutoff) {
                break;
            }
            if (!word.equals(OTHER_CLASS)) {
                neighbours.add(word);
            }
            position++;
        }
        if (neighbours.isEmpty()) {
            neighbours.add(OTHER_CLASS);
        }
        return neighbours;
    }
}
