package de.datexis.encoder.impl;

import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/impl/OneHotEncoder.class */
/**
 * Encoder that maps a single word to a 1-hot column vector over the vocabulary
 * held by the {@link LookupCacheEncoder} base class.
 *
 * <p>Every word is normalized with a shared {@link MinimalLowercasePreprocessor}
 * before it is looked up in, or added to, the vocabulary.</p>
 */
public class OneHotEncoder extends LookupCacheEncoder {
    /** Normalizer applied to every word before any vocabulary access. */
    private static final TokenPreProcess preprocessor = new MinimalLowercasePreprocessor();

    /** Creates an encoder with the default id {@code "1H"}. */
    public OneHotEncoder() {
        this("1H");
    }

    /**
     * Creates an encoder with a custom id.
     *
     * @param str the id handed to the {@link LookupCacheEncoder} constructor
     */
    public OneHotEncoder(String str) {
        super(str);
        this.log = LoggerFactory.getLogger(OneHotEncoder.class);
    }

    /** @return the display name of this encoder, {@code "1-hot Encoder"} */
    @Override
    public String getName() {
        return "1-hot Encoder";
    }

    /**
     * Encodes the text of the given span.
     *
     * @param span the span whose {@code getText()} value is encoded
     * @return the result of {@link #encode(String)} for the span text
     */
    @Override
    public INDArray encode(Span span) {
        return encode(span.getText());
    }

    /**
     * Encodes a single word as a 1-hot column vector.
     *
     * @param str the raw word; it is preprocessed before the vocabulary lookup
     * @return a vector of shape {@code (getEmbeddingVectorSize(), 1)} holding 1.0 at the
     *         vocabulary index of the word, or all zeros if the lookup yields a negative index
     */
    @Override
    public INDArray encode(String str) {
        final INDArray vector = Nd4j.zeros(new long[]{getEmbeddingVectorSize(), 1});
        final String word = preprocessor.preProcess(str);
        final int index = this.vocab.indexOf(word);
        // A negative index means the word is not in the vocabulary: keep the zero vector.
        if (index >= 0) {
            vector.put(index, 0, Double.valueOf(1.0d));
        }
        return vector;
    }

    /**
     * Checks whether a word is missing from the vocabulary.
     *
     * @param str the raw word; it is preprocessed before the lookup
     * @return {@code true} if the vocabulary does not contain the preprocessed word
     */
    @Override
    public boolean isUnknown(String str) {
        final String word = preprocessor.preProcess(str);
        return !this.vocab.containsWord(word);
    }

    /**
     * Trains the vocabulary from all tokens of the given documents, using a
     * truncation threshold of 1.
     *
     * @param collection the documents whose tokens are added to the vocabulary
     */
    @Override
    public void trainModel(Collection<Document> collection) {
        trainModel(collection, 1);
    }

    /**
     * Trains the vocabulary from all tokens of the given documents.
     *
     * <p>Each token text is preprocessed; non-empty results are either added to the
     * vocabulary or have their counter incremented. {@code totalWords} counts every token,
     * including those that become empty after preprocessing. Afterwards the vocabulary is
     * truncated and its Huffman codes are updated.</p>
     *
     * @param collection the documents whose tokens are added to the vocabulary
     * @param i          threshold passed to {@code vocab.truncateVocabulary} (presumably the
     *                   minimum word frequency to keep; confirm against the vocabulary class)
     */
    public void trainModel(Collection<Document> collection, int i) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.timer.start();
        this.totalWords = 0;
        for (Document document : collection) {
            for (Token token : document.getTokens()) {
                final String word = preprocessor.preProcess(token.getText());
                this.totalWords++;
                if (word.isEmpty()) {
                    continue;
                }
                if (this.vocab.containsWord(word)) {
                    this.vocab.incrementWordCounter(word);
                } else {
                    this.vocab.addWord(word);
                }
            }
        }
        // Remember the size before truncation so the log can report both numbers.
        final int untruncatedSize = this.vocab.numWords();
        this.vocab.truncateVocabulary(i);
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        appendTrainLog("trained " + this.vocab.numWords() + " words (" + untruncatedSize + " total)", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Returns the words belonging to the highest-valued entries of a vector.
     *
     * <p>Only entries with a strictly positive value are considered. Each index is returned
     * at most once, ordered by descending value (ties resolved in favour of the lower index).
     * The result therefore holds fewer than {@code i} words when the vector has fewer than
     * {@code i} positive entries.</p>
     *
     * @param iNDArray the vector to inspect; entries are read by linear index
     * @param i        the maximum number of words to return
     * @return up to {@code i} words, obtained via {@code getWord(index)}, best match first
     */
    @Override
    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        final int length = (int) iNDArray.length();
        // Work on a private copy so consumed entries can be masked out.
        final double[] scores = new double[length];
        for (int idx = 0; idx < length; idx++) {
            scores[idx] = iNDArray.getDouble(idx);
        }
        final int limit = Math.max(0, Math.min(i, length));
        final Collection<String> neighbours = new ArrayList<>(limit);
        for (int rank = 0; rank < limit; rank++) {
            int bestIndex = -1;
            double bestScore = 0.0d;
            for (int idx = 0; idx < length; idx++) {
                if (scores[idx] > bestScore) {
                    bestIndex = idx;
                    bestScore = scores[idx];
                }
            }
            if (bestIndex < 0) {
                // No positive entry left: stop instead of emitting an arbitrary word.
                break;
            }
            // Mask only the winner. NEGATIVE_INFINITY can never beat the 0.0 threshold,
            // unlike Double.MIN_VALUE, which is a tiny positive number.
            scores[bestIndex] = Double.NEGATIVE_INFINITY;
            neighbours.add(getWord(bestIndex));
        }
        return neighbours;
    }
}
