package de.datexis.sector.encoder;

import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Document;
import de.datexis.model.Span;
import de.datexis.preprocess.LowercasePreprocessor;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.LoggerFactory;

/**
 * Encoder that maps class labels to one-hot column vectors over a vocabulary of known classes.
 *
 * <p>Every label is passed through a {@link LowercasePreprocessor} before it is looked up or added
 * to the vocabulary. The vocabulary is built from plain label strings via
 * {@link #trainModel(Iterable, int)} or {@link #trainModelUsingHead(Iterable)}; training on
 * {@link Document}s is not supported.
 */
public class ClassEncoder extends LookupCacheEncoder {

    /** Normalizer applied to every label before lookup and before insertion during training. */
    private static final TokenPreProcess preprocessor = new LowercasePreprocessor();

    /** Default identifier of this encoder. */
    public static final String ID = "CLS";

    /** Creates an encoder with the default identifier {@link #ID}. */
    public ClassEncoder() {
        this(ID);
    }

    /**
     * Creates an encoder with a custom identifier.
     *
     * @param id identifier handed to the {@link LookupCacheEncoder} constructor
     */
    public ClassEncoder(String id) {
        super(id);
        this.log = LoggerFactory.getLogger(ClassEncoder.class);
    }

    /**
     * Returns the human-readable name of this encoder.
     *
     * @return {@code "Classification Encoder"}
     */
    public String getName() {
        return "Classification Encoder";
    }

    /**
     * Encodes the text of a span as a one-hot class vector.
     *
     * @param span span whose {@code getText()} is used as the class label
     * @return a column vector of length {@link #getEmbeddingVectorSize()}; see {@link #oneHot(String)}
     */
    public INDArray encode(Span span) {
        return encode(span.getText());
    }

    /**
     * Returns the length of the encoded vectors.
     *
     * @return the number of classes currently held in the vocabulary
     */
    public long getEmbeddingVectorSize() {
        return this.vocab.numWords();
    }

    /**
     * Encodes a class label as a one-hot vector.
     *
     * @param label the class label
     * @return the result of {@link #oneHot(String)}
     */
    public INDArray encode(String label) {
        return oneHot(label);
    }

    /**
     * Looks up the vocabulary index of a class label after normalization.
     *
     * @param label the class label
     * @return the index reported by the vocabulary; a negative value means the label is unknown
     *     (this is how {@link #oneHot(String)} interprets it)
     */
    public int getIndex(String label) {
        return this.vocab.indexOf(preprocessor.preProcess(label));
    }

    /**
     * Builds a one-hot column vector for a class label.
     *
     * <p>If the label is not in the vocabulary, a warning is logged and an all-zero vector is
     * returned instead of failing.
     *
     * @param label the class label
     * @return a {@code getEmbeddingVectorSize() x 1} vector with a single 1.0 at the label's index,
     *     or all zeros when the label is unknown
     */
    public INDArray oneHot(String label) {
        INDArray vector = Nd4j.zeros(getEmbeddingVectorSize(), 1L);
        int index = getIndex(label);
        if (index >= 0) {
            vector.putScalar(index, 0, 1.0);
        } else {
            this.log.warn("could not encode class '{}'. is it contained in training set?", label);
        }
        return vector;
    }

    /**
     * Tells whether a class label is absent from the vocabulary after normalization.
     *
     * @param label the class label
     * @return {@code true} if the normalized label is not contained in the vocabulary
     */
    public boolean isUnknown(String label) {
        return !this.vocab.containsWord(preprocessor.preProcess(label));
    }

    /**
     * Not supported: this encoder is trained from label strings, not documents.
     *
     * @param documents ignored
     * @throws UnsupportedOperationException always
     */
    public void trainModel(Collection<Document> documents) {
        throw new UnsupportedOperationException("cannot train classification on Documents");
    }

    /**
     * Trains the vocabulary on all labels, then keeps only the frequent classes.
     *
     * <p>The vocabulary is first trained with threshold 0. It is then truncated with the mean
     * class count (rounded toward zero) as the threshold. An empty vocabulary uses threshold 0.
     *
     * @param classes class labels, one entry per occurrence
     */
    public void trainModelUsingHead(Iterable<String> classes) {
        trainModel(classes, 0);
        double totalCount = 0.0;
        for (VocabularyWord word : this.vocab.words()) {
            totalCount += word.getCount();
        }
        int numClasses = this.vocab.numWords();
        int meanCount = numClasses > 0 ? (int) (totalCount / numClasses) : 0;
        // NOTE(review): assumes truncateVocabulary drops entries whose count is below the threshold
        this.vocab.truncateVocabulary(meanCount);
        this.vocab.updateHuffmanCodes();
        appendTrainLog("truncated to " + this.vocab.numWords() + " classes");
    }

    /**
     * Builds the class vocabulary from a sequence of labels.
     *
     * <p>Each label is normalized. Labels that are empty after normalization still count towards
     * {@code totalWords} but are not added. A repeated label increments its counter. Finally the
     * vocabulary is truncated with {@code minFrequency}, its Huffman codes are rebuilt, and the
     * model is marked as available.
     *
     * @param classes class labels, one entry per occurrence
     * @param minFrequency threshold passed to the vocabulary's {@code truncateVocabulary}
     */
    public void trainModel(Iterable<String> classes, int minFrequency) {
        appendTrainLog("Training " + getName() + " model...");
        setModel(null);
        this.timer.start();
        this.totalWords = 0;
        for (String label : classes) {
            String normalized = preprocessor.preProcess(label);
            this.totalWords++;
            if (normalized.isEmpty()) {
                continue;
            }
            if (this.vocab.containsWord(normalized)) {
                this.vocab.incrementWordCounter(normalized);
            } else {
                this.vocab.addWord(normalized);
            }
        }
        int numClassesBeforeTruncation = this.vocab.numWords();
        this.vocab.truncateVocabulary(minFrequency);
        this.vocab.updateHuffmanCodes();
        this.timer.stop();
        appendTrainLog("trained " + this.vocab.numWords() + " classes ("
                + numClassesBeforeTruncation + " total)", this.timer.getLong());
        setModelAvailable(true);
    }

    /**
     * Returns the class with the highest score in a vector.
     *
     * @param vector score vector indexed like the vocabulary
     * @return the best-scoring class label, or {@code null} if the vector has no entries
     */
    public String getNearestNeighbour(INDArray vector) {
        Collection<String> nearest = getNearestNeighbours(vector, 1);
        if (nearest.isEmpty()) {
            return null;
        }
        return nearest.iterator().next();
    }

    /**
     * Returns the labels of the {@code n} highest-scoring classes, best first.
     *
     * @param vector score vector indexed like the vocabulary
     * @param n maximum number of labels to return; the result never exceeds {@code vector.length()}
     * @return the class labels ordered by descending score (ties go to the lower index)
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public Collection<String> getNearestNeighbours(INDArray vector, int n) {
        int[] ranked = rankIndicesDescending(vector, n);
        Collection<String> result = new ArrayList<>(ranked.length);
        for (int index : ranked) {
            result.add(getWord(index));
        }
        return result;
    }

    /**
     * Returns the {@code n} highest-scoring classes together with their scores, best first.
     *
     * @param vector score vector indexed like the vocabulary
     * @param n maximum number of entries to return; the result never exceeds {@code vector.length()}
     * @return (label, score) entries ordered by descending score (ties go to the lower index)
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public Collection<Map.Entry<String, Double>> getNearestNeighbourEntries(INDArray vector, int n) {
        int[] ranked = rankIndicesDescending(vector, n);
        Collection<Map.Entry<String, Double>> result = new ArrayList<>(ranked.length);
        for (int index : ranked) {
            result.add(new AbstractMap.SimpleEntry<String, Double>(getWord(index), vector.getDouble(index)));
        }
        return result;
    }

    /**
     * Returns the indices of the {@code n} largest entries of {@code scores}, largest first.
     *
     * <p>Each index appears at most once, and at most {@code scores.length()} indices are returned.
     * Ties go to the lower index. NaN entries rank below every number, so scores of any sign
     * (including zero and negative) are ranked correctly.
     */
    private static int[] rankIndicesDescending(INDArray scores, int n) {
        if (n < 0) {
            throw new IllegalArgumentException("number of neighbours must not be negative: " + n);
        }
        int length = (int) scores.length();
        double[] values = new double[length];
        for (int i = 0; i < length; i++) {
            double value = scores.getDouble(i);
            values[i] = Double.isNaN(value) ? Double.NEGATIVE_INFINITY : value;
        }
        int count = Math.min(n, length);
        boolean[] taken = new boolean[length];
        int[] ranked = new int[count];
        for (int rank = 0; rank < count; rank++) {
            int best = -1;
            for (int i = 0; i < length; i++) {
                // strict '>' keeps the earliest index among equal scores
                if (!taken[i] && (best < 0 || values[i] > values[best])) {
                    best = i;
                }
            }
            // only the winner of this pass is removed from further consideration
            taken[best] = true;
            ranked[rank] = best;
        }
        return ranked;
    }
}
