package de.datexis.encoder.impl;

import cc.fasttext.FastText;
import cc.fasttext.Matrix;
import cc.fasttext.Vector;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import de.datexis.common.Resource;
import de.datexis.encoder.Encoder;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.SentenceDetectorMENL;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Level;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang.Validate;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Encoder that produces embedding vectors from a loaded FastText model.
 *
 * <p>{@link Token} spans and strings without a space are encoded with the model's word vector;
 * {@link Sentence} spans and strings containing a space are encoded with its sentence vector.
 * After a model is loaded, the FastText methods {@code getPrecomputedWordVectors} and
 * {@code findNN} are looked up and made accessible via reflection so that nearest neighbours can
 * be queried for an arbitrary vector.
 */
@JsonIgnoreProperties(ignoreUnknown = true)
/* loaded from: input_file:de/datexis/encoder/impl/FastTextEncoder.class */
public class FastTextEncoder extends Encoder {
    private static final Logger log = LoggerFactory.getLogger(FastTextEncoder.class);

    /** The loaded FastText model; {@code null} until {@link #loadModel(Resource)} has run. */
    private FastText ft;

    /** Value returned by {@link #getName()}; never assigned by the code in this class. */
    private String modelName;

    /** Resource the model was loaded from, or {@code null} if the model is only referenced. */
    private Resource modelSource;

    /** Length of the embedding vectors; 0 until a model has been loaded. */
    private long size = 0L;

    /** Reflective handle to FastText's {@code getPrecomputedWordVectors()}, or {@code null}. */
    private Method getPrecomputedWordVectors;

    /** Reflective handle to FastText's {@code findNN(Matrix, Vector, int, Set)}, or {@code null}. */
    private Method findNN;

    /** Creates an encoder with the id {@code "FT"} and no model loaded. */
    public FastTextEncoder() {
        super("FT");
    }

    /**
     * Creates an encoder with the given id and no model loaded.
     *
     * @param str the id passed to the {@link Encoder} constructor
     */
    public FastTextEncoder(String str) {
        super(str);
    }

    /**
     * Looks up the two FastText methods needed for vector-based nearest-neighbour search and
     * makes them accessible. On failure both handles are cleared and the error is logged, so
     * callers can detect that the lookup is unavailable.
     */
    private void initializeMethodsFromReflection() {
        try {
            Method precomputedLookup = this.ft.getClass().getDeclaredMethod("getPrecomputedWordVectors");
            precomputedLookup.setAccessible(true);
            Method neighbourSearch = this.ft.getClass().getDeclaredMethod("findNN", Matrix.class, Vector.class, Integer.TYPE, Set.class);
            neighbourSearch.setAccessible(true);
            // Publish both handles together so the fields are never half-initialised.
            this.getPrecomputedWordVectors = precomputedLookup;
            this.findNN = neighbourSearch;
        } catch (ReflectiveOperationException | RuntimeException e) {
            this.getPrecomputedWordVectors = null;
            this.findNN = null;
            log.error("Could not access FastText nearest-neighbour methods via reflection", e);
        }
    }

    /**
     * Creates a new encoder and loads the given model into it.
     *
     * @param resource the FastText model to load
     * @return the encoder with the model loaded
     * @throws IOException if the model cannot be read
     */
    public static FastTextEncoder load(Resource resource) throws IOException {
        FastTextEncoder fastTextEncoder = new FastTextEncoder();
        fastTextEncoder.loadModel(resource);
        return fastTextEncoder;
    }

    /**
     * Loads a FastText model from the given resource, records its vector size and marks the
     * model as available. The resource is remembered as the model source so that
     * {@link #saveModel(Resource, String)} can copy it later.
     *
     * @param resource the FastText model to load
     * @throws IOException if the model cannot be read
     */
    @Override // de.datexis.annotator.IComponent
    public void loadModel(Resource resource) throws IOException {
        log.info("Loading FastText model: {}", resource.getFileName());
        // Close the stream once the model has been read; the original code leaked it.
        try (InputStream in = resource.getInputStream()) {
            this.ft = FastText.DEFAULT_FACTORY.load(in);
        }
        initializeMethodsFromReflection();
        // The vector size is taken from the vector the model returns for the probe word "the".
        this.size = this.ft.getWordVector("the").size();
        setModel(resource);
        setModelAvailable(true);
        this.modelSource = resource;
        log.info("Loaded FastText model '{}' with {} words and vector size {}", resource.getFileName(), this.ft.getDictionary().size(), this.size);
    }

    /**
     * Loads the model like {@link #loadModel(Resource)} but forgets the model source afterwards,
     * so that {@link #saveModel(Resource, String)} will not copy the model file.
     *
     * @param resource the FastText model to load
     * @throws IOException if the model cannot be read
     */
    public void loadModelAsReference(Resource resource) throws IOException {
        loadModel(resource);
        this.modelSource = null;
    }

    /** Forgets the model source, so that {@link #saveModel(Resource, String)} copies nothing. */
    public void setModelAsReference() {
        this.modelSource = null;
    }

    /**
     * Copies the model file that was loaded into the given directory and registers the copy as
     * this encoder's model. Does nothing if no model source is set. Copy failures are logged and
     * not propagated.
     *
     * @param resource the target location the file name is resolved against
     * @param str the target file name without extension; {@code ".bin.gz"} is appended if the
     *     source file name ends with {@code ".gz"}, otherwise {@code ".bin"}
     */
    @Override // de.datexis.annotator.IComponent
    public void saveModel(Resource resource, String str) {
        if (this.modelSource == null) {
            return;
        }
        try {
            String extension = this.modelSource.getFileName().endsWith(".gz") ? ".bin.gz" : ".bin";
            Resource target = resource.resolve(str + extension);
            FileUtils.copyFile(this.modelSource.toFile(), target.toFile());
            setModel(target);
        } catch (IOException e) {
            // Keep the best-effort behaviour, but log the full stack trace.
            log.error("Could not save FastText model '{}'", this.modelSource.getFileName(), e);
        }
    }

    /**
     * Training is not supported by this encoder.
     *
     * @param collection ignored
     * @throws UnsupportedOperationException always
     */
    @Override // de.datexis.encoder.Encoder
    public void trainModel(Collection<Document> collection) {
        throw new UnsupportedOperationException("model training not implemented");
    }

    /**
     * Returns the stored model name.
     *
     * <p>NOTE(review): nothing in this class assigns {@code modelName}, so this appears to
     * always return {@code null} -- confirm whether that is intended.
     *
     * @return the value of the {@code modelName} field
     */
    @Override // de.datexis.annotator.AnnotatorComponent, de.datexis.annotator.IComponent
    public String getName() {
        return this.modelName;
    }

    /**
     * Converts a FastText vector into an ND4J array of shape {@code (vector.size(), 1)}.
     *
     * @param vector the FastText vector to convert
     * @return a new array holding the vector's components in order
     */
    protected INDArray asINDArray(Vector vector) {
        INDArray column = Nd4j.createUninitialized(vector.size(), 1L);
        int row = 0;
        for (Object component : vector.getData()) {
            // The array has a single column, so the only valid column index is 0 (was 1).
            column.putScalar(row, 0L, ((Float) component).doubleValue());
            row++;
        }
        return column;
    }

    /**
     * Converts an ND4J array into a FastText vector with one component per array element.
     *
     * @param iNDArray the array to convert
     * @return a new FastText vector of length {@code iNDArray.length()}
     */
    protected Vector asVector(INDArray iNDArray) {
        int length = (int) iNDArray.length();
        Vector vector = new Vector(length);
        for (int index = 0; index < length; index++) {
            vector.set(index, iNDArray.getFloat(index));
        }
        return vector;
    }

    /**
     * @param str the word to look up
     * @return the model's word vector for {@code str} as an ND4J array
     */
    protected INDArray getWordVector(String str) {
        return asINDArray(this.ft.getWordVector(str));
    }

    /**
     * @param str the sentence text to encode
     * @return the model's sentence vector for {@code str} as an ND4J array
     */
    protected INDArray getSentenceVector(String str) {
        return asINDArray(this.ft.getSentenceVector(str));
    }

    /**
     * Checks whether the model's dictionary reports a non-positive id for the word.
     *
     * <p>NOTE(review): id 0 is treated as unknown here as well -- confirm that this matches the
     * id scheme of the FastText dictionary in use.
     *
     * @param str the word to check
     * @return {@code true} if the dictionary id of {@code str} is {@code <= 0}
     */
    public boolean isUnknown(String str) {
        return this.ft.getDictionary().getId(str) <= 0;
    }

    /**
     * Encodes a span: tokens by word vector, sentences by the sentence vector of their tokenized
     * string, and any other span via {@link #encode(String)} on its text.
     *
     * @param span the span to encode
     * @return the embedding vector
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(Span span) {
        if (span instanceof Token) {
            return getWordVector(span.getText());
        }
        if (span instanceof Sentence) {
            return getSentenceVector(((Sentence) span).toTokenizedString());
        }
        return encode(span.getText());
    }

    /**
     * @return the vector size recorded when the model was loaded, or 0 if none was loaded
     */
    @Override // de.datexis.encoder.IEncoder
    public long getEmbeddingVectorSize() {
        return this.size;
    }

    /**
     * Encodes a string: as a sentence vector if it contains a space, otherwise as a word vector.
     *
     * @param str the text to encode
     * @return the embedding vector
     */
    @Override // de.datexis.encoder.IEncoder
    public INDArray encode(String str) {
        return str.contains(" ") ? getSentenceVector(str) : getWordVector(str);
    }

    /**
     * Returns the model's nearest neighbours of a word, ordered by descending score.
     *
     * @param str the query word
     * @param i the number of neighbours requested from the model
     * @return the neighbouring words, highest score first
     */
    public List<String> getNearestNeighbours(String str, int i) {
        Multimap<String, Float> scoresByWord = this.ft.nn(i, str);
        return wordsByDescendingScore(scoresByWord);
    }

    /**
     * Returns the model's nearest neighbours of an arbitrary vector, ordered by descending score.
     * The search uses the reflectively accessed FastText methods; if they are unavailable or the
     * reflective call fails, the error is logged and an empty list is returned.
     *
     * @param iNDArray the query vector
     * @param i the number of neighbours requested; must be positive
     * @return the neighbouring words, highest score first, or an empty list on reflective failure
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public List<String> getNearestNeighbours(INDArray iNDArray, int i) {
        Validate.isTrue(i > 0, "Not positive factor");
        if (this.findNN == null || this.getPrecomputedWordVectors == null) {
            // Reflection setup failed or no model was loaded: behave like a failed reflective call.
            log.error("FastText nearest-neighbour methods are not available; returning no neighbours");
            return Collections.emptyList();
        }
        try {
            Matrix wordVectors = (Matrix) this.getPrecomputedWordVectors.invoke(this.ft);
            // The reflective result is inverted before use, so it is keyed by score, not by word.
            @SuppressWarnings("unchecked")
            Multimap<Float, String> wordsByScore = (Multimap<Float, String>) this.findNN.invoke(this.ft, wordVectors, asVector(iNDArray), i, new HashSet<String>());
            Multimap<String, Float> scoresByWord = Multimaps.invertFrom(wordsByScore, ArrayListMultimap.<String, Float>create());
            return wordsByDescendingScore(scoresByWord);
        } catch (IllegalAccessException | InvocationTargetException e) {
            log.error("Reflective FastText nearest-neighbour search failed", e);
            return Collections.emptyList();
        }
    }

    /** Lists the words of a word-to-score multimap, ordered by descending score. */
    private static List<String> wordsByDescendingScore(Multimap<String, Float> scoresByWord) {
        return scoresByWord.entries().stream()
            .sorted((left, right) -> right.getValue().compareTo(left.getValue()))
            .map(entry -> entry.getKey())
            .collect(Collectors.toList());
    }
}
