package de.datexis.encoder;

import de.datexis.common.Resource;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyHolder;
import org.deeplearning4j.models.word2vec.wordstore.VocabularyWord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/encoder/LookupCacheEncoder.class */
public abstract class LookupCacheEncoder extends Encoder {
    protected VocabularyHolder vocab;
    protected int totalWords;

    public LookupCacheEncoder() {
        this("");
    }

    public LookupCacheEncoder(String str) {
        super(str);
        this.totalWords = 0;
        this.log = LoggerFactory.getLogger(LookupCacheEncoder.class);
        this.vocab = new VocabularyHolder.Builder().build();
    }

    public int getTotalWords() {
        return this.totalWords;
    }

    public void setTotalWords(int i) {
        this.totalWords = i;
    }

    @Override // de.datexis.encoder.IEncoder
    public long getEmbeddingVectorSize() {
        return this.vocab.numWords();
    }

    public int getIndex(String str) {
        return this.vocab.indexOf(str);
    }

    public INDArray oneHot(String str) {
        INDArray zeros = Nd4j.zeros(new long[]{getEmbeddingVectorSize(), 1});
        int index = getIndex(str);
        if (index >= 0) {
            zeros.put(index, 0, Double.valueOf(1.0d));
        } else {
            this.log.warn("could not encode class '{}'. is it contained in training set?", str);
        }
        return zeros;
    }

    public int getFrequency(String str) {
        VocabularyWord vocabularyWordByString = this.vocab.getVocabularyWordByString(str);
        if (vocabularyWordByString != null) {
            return vocabularyWordByString.getCount();
        }
        return 0;
    }

    public double getProbability(String str) {
        return getFrequency(str) / this.totalWords;
    }

    public double getConfidence(INDArray iNDArray, int i) {
        return iNDArray.getDouble(i);
    }

    public double getMaxConfidence(INDArray iNDArray) {
        return iNDArray.max(new int[]{0}).sumNumber().doubleValue();
    }

    public String getWord(int i) {
        VocabularyWord vocabularyWordByIdx = this.vocab.getVocabularyWordByIdx(Integer.valueOf(i));
        if (vocabularyWordByIdx != null) {
            return vocabularyWordByIdx.getWord();
        }
        return null;
    }

    public boolean isUnknown(String str) {
        return !this.vocab.containsWord(str);
    }

    @Override // de.datexis.annotator.IComponent
    public void saveModel(Resource resource, String str) {
        Resource resolve = resource.resolve(str + ".tsv.gz");
        try {
            OutputStreamWriter outputStreamWriter = new OutputStreamWriter(resolve.getGZIPOutputStream());
            Throwable th = null;
            try {
                try {
                    int i = 0;
                    for (VocabularyWord vocabularyWord : this.vocab.getVocabulary()) {
                        i++;
                        outputStreamWriter.write(vocabularyWord.getHuffmanNode().getIdx() + "\t" + vocabularyWord.getWord() + "\t" + vocabularyWord.getCount() + "\n");
                    }
                    setModel(resolve);
                    this.log.info("saved " + i + " words");
                    if (outputStreamWriter != null) {
                        if (0 != 0) {
                            try {
                                outputStreamWriter.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        } else {
                            outputStreamWriter.close();
                        }
                    }
                } catch (Throwable th3) {
                    th = th3;
                    throw th3;
                }
            } finally {
            }
        } catch (IOException e) {
            this.log.error(e.toString());
        }
        try {
            ObjectOutputStream objectOutputStream = new ObjectOutputStream(resource.resolve(str + ".bin").getOutputStream());
            Throwable th4 = null;
            try {
                objectOutputStream.writeObject(this.vocab);
                if (objectOutputStream != null) {
                    if (0 != 0) {
                        try {
                            objectOutputStream.close();
                        } catch (Throwable th5) {
                            th4.addSuppressed(th5);
                        }
                    } else {
                        objectOutputStream.close();
                    }
                }
            } finally {
            }
        } catch (IOException e2) {
            this.log.error(e2.toString());
        }
    }

    @Override // de.datexis.annotator.IComponent
    public void loadModel(Resource resource) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resource.getInputStream(), "UTF-8"));
        Throwable th = null;
        while (true) {
            try {
                try {
                    String readLine = bufferedReader.readLine();
                    if (readLine == null) {
                        break;
                    }
                    String[] split = readLine.split("\\t");
                    VocabularyWord vocabularyWord = new VocabularyWord(split[1]);
                    vocabularyWord.setCount(1000000000 - Integer.parseInt(split[0]));
                    this.vocab.addWord(vocabularyWord);
                } catch (Throwable th2) {
                    th = th2;
                    throw th2;
                }
            } catch (Throwable th3) {
                if (bufferedReader != null) {
                    if (th != null) {
                        try {
                            bufferedReader.close();
                        } catch (Throwable th4) {
                            th.addSuppressed(th4);
                        }
                    } else {
                        bufferedReader.close();
                    }
                }
                throw th3;
            }
        }
        this.vocab.updateHuffmanCodes();
        setModel(resource);
        setModelAvailable(true);
        this.log.info("loaded " + this.vocab.numWords() + " words from " + resource.toString());
        if (bufferedReader != null) {
            if (0 == 0) {
                bufferedReader.close();
                return;
            }
            try {
                bufferedReader.close();
            } catch (Throwable th5) {
                th.addSuppressed(th5);
            }
        }
    }

    @JsonIgnore
    public List<String> getWords() {
        return (List) this.vocab.words().stream().map(vocabularyWord -> {
            return vocabularyWord.getWord();
        }).collect(Collectors.toList());
    }

    public Collection<String> getNearestNeighbours(String str, int i) {
        throw new UnsupportedOperationException("No nearest words in LookupCache.");
    }

    public String getNearestNeighbour(INDArray iNDArray) {
        throw new UnsupportedOperationException("No nearest words in LookupCache.");
    }

    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        throw new UnsupportedOperationException("No nearest words in LookupCache.");
    }
}
