package de.datexis.parvec.encoder;

import de.datexis.common.Resource;
import de.datexis.encoder.LookupCacheEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.preprocess.MinimalLowercasePreprocessor;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.inmemory.AbstractCache;
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/parvec/encoder/ParVecEncoder.class */
public class ParVecEncoder extends LookupCacheEncoder {
    protected WordVectors word2Vec;
    protected ParagraphVectors model;
    protected double learningRate;
    protected double minLearningRate;
    protected int batchSize;
    protected int numEpochs;
    protected int iterations;
    protected int layerSize;
    protected int targetSize;
    protected int windowSize;
    protected final DefaultTokenizerFactory tokenizerFactory;
    protected List<VocabWord> labelsList;
    protected List<String> stopwords;
    protected static final Logger log = LoggerFactory.getLogger(ParVecEncoder.class);
    protected static final TokenPreProcess preprocessor = new MinimalLowercasePreprocessor();

    public ParVecEncoder() {
        super("PV");
        this.learningRate = 0.025d;
        this.minLearningRate = 0.001d;
        this.batchSize = 16;
        this.numEpochs = 1;
        this.iterations = 5;
        this.layerSize = 256;
        this.windowSize = 10;
        this.stopwords = new ArrayList();
        this.tokenizerFactory = new DefaultTokenizerFactory();
        this.tokenizerFactory.setTokenPreProcessor(preprocessor);
    }

    public ParVecEncoder withWordEmbedding(WordVectors wordVectors) {
        this.word2Vec = wordVectors;
        return this;
    }

    public void setModelParams(int i, int i2) {
        this.layerSize = i;
        this.windowSize = i2;
    }

    public void setTrainingParams(double d, double d2, int i, int i2, int i3) {
        this.learningRate = d;
        this.minLearningRate = d2;
        this.batchSize = i;
        this.iterations = i2;
        this.numEpochs = i3;
    }

    public void setStopWords(List<String> list) {
        this.stopwords = list;
    }

    public void trainModel(Collection<Document> collection) {
        throw new UnsupportedOperationException("Please call trainModel(Dataset train)");
    }

    public void trainModel(Dataset dataset) {
        ParVecIterator parVecIterator = new ParVecIterator(dataset, true);
        ParagraphVectors.Builder sampling = new ParagraphVectors.Builder().minWordFrequency(3).iterations(this.iterations).epochs(this.numEpochs).layerSize(this.layerSize).learningRate(this.learningRate).minLearningRate(this.minLearningRate).batchSize(this.batchSize).windowSize(this.windowSize).iterate(parVecIterator).trainWordVectors(true).vocabCache(new AbstractCache()).tokenizerFactory(this.tokenizerFactory).stopWords(this.stopwords).sampling(0.0d);
        if (this.word2Vec != null) {
            sampling.useExistingWordVectors(this.word2Vec);
        }
        this.model = sampling.build();
        log.info("training ParVec...");
        this.model.fit();
        log.info("training complete.");
        try {
            Field declaredField = ParagraphVectors.class.getDeclaredField("labelsList");
            declaredField.setAccessible(true);
            this.labelsList = (List) declaredField.get(this.model);
            this.targetSize = this.labelsList.size();
            setModelAvailable(true);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            log.error(e.getMessage(), e);
            throw new RuntimeException(e);
        }
    }

    public INDArray encode(Span span) {
        if (!(span instanceof Sentence)) {
            return encode(span.getText());
        }
        try {
            return this.model.inferVector(((Sentence) span).toTokenizedString().trim().replaceAll("\n", "*NL*").replaceAll("\t", "*t*"), this.learningRate, this.minLearningRate, 1).transpose();
        } catch (ND4JIllegalStateException e) {
            return Nd4j.zeros(this.layerSize).transpose();
        }
    }

    public INDArray encode(Annotation annotation, Document document) {
        try {
            return this.model.inferVector((String) document.streamSentencesInRange(annotation.getBegin(), annotation.getEnd(), false).map(sentence -> {
                return sentence.toTokenizedString().trim().replaceAll("\n", "*NL*").replaceAll("\t", "*t*");
            }).collect(Collectors.joining(" ")), this.learningRate, this.minLearningRate, 1).transpose();
        } catch (ND4JIllegalStateException e) {
            return Nd4j.zeros(this.layerSize).transpose();
        }
    }

    public INDArray encode(String str) {
        String str2 = (String) DocumentFactory.createTokensFromText(str).stream().map(token -> {
            return token.getText().trim().replaceAll("\n", "*NL*").replaceAll("\t", "*t*");
        }).collect(Collectors.joining(" "));
        try {
            return this.model.inferVector(str2).transpose();
        } catch (ND4JIllegalStateException e) {
            log.trace("unknown paragraph vector for '{}'", str2);
            return Nd4j.zeros(this.layerSize).transpose();
        }
    }

    public void saveModel(Resource resource, String str) {
        try {
            Resource resolve = resource.resolve(str + ".zip");
            WordVectorSerializer.writeParagraphVectors(this.model, resolve.getOutputStream());
            setModel(resolve);
        } catch (IOException e) {
            log.error(e.toString());
        }
    }

    public static ParVecEncoder load(Resource resource) throws IOException {
        ParVecEncoder parVecEncoder = new ParVecEncoder();
        parVecEncoder.loadModel(resource);
        return parVecEncoder;
    }

    public void loadModel(Resource resource) throws IOException {
        this.model = WordVectorSerializer.readParagraphVectors(resource.getInputStream());
        this.model.setTokenizerFactory(this.tokenizerFactory);
        this.layerSize = this.model.getLayerSize();
        try {
            Field declaredField = ParagraphVectors.class.getDeclaredField("labelsList");
            declaredField.setAccessible(true);
            this.labelsList = (List) declaredField.get(this.model);
            this.targetSize = this.labelsList.size();
            log.info("Loaded ParagraphVectors with {} classes and layer size {}", Integer.valueOf(this.targetSize), Integer.valueOf(this.layerSize));
            setModel(resource);
            setModelAvailable(true);
        } catch (IllegalAccessException | NoSuchFieldException e) {
            log.error(e.getMessage(), e);
            throw new RuntimeException(e);
        }
    }

    @JsonIgnore
    public List<String> getWords() {
        return (List) this.labelsList.stream().map((v0) -> {
            return v0.getLabel();
        }).collect(Collectors.toList());
    }

    public int getTotalWords() {
        return this.labelsList.size();
    }

    public long getEmbeddingVectorSize() {
        return this.model.inferVector("test").length();
    }

    public long getOutputVectorSize() {
        return this.targetSize;
    }

    public int getInputVectorSize() {
        return 0;
    }

    public String getWord(int i) {
        if (this.labelsList.size() < i) {
            return null;
        }
        return this.labelsList.get(i).getWord();
    }

    public int getIndex(String str) {
        return IntStream.range(0, this.labelsList.size()).filter(i -> {
            return str.equals(this.labelsList.get(i).getWord());
        }).findFirst().orElse(-1);
    }

    public INDArray oneHot(String str) {
        INDArray zeros = Nd4j.zeros(this.targetSize);
        int index = getIndex(str);
        if (index >= 0) {
            zeros.putScalar(index, 1.0d);
        } else {
            log.warn("could not encode class '{}'. is it contained in training set?", str);
        }
        return zeros.transpose();
    }

    public String getNearestNeighbour(INDArray iNDArray) {
        return getNearestNeighbours(iNDArray, 1).stream().findFirst().orElse(null);
    }

    public Collection<String> getNearestNeighbours(INDArray iNDArray, int i) {
        INDArray[] sortWithIndices = Nd4j.sortWithIndices(Nd4j.toFlattened(new INDArray[]{iNDArray}).dup(), 1, false);
        if (sortWithIndices[0].length() <= 1 || sortWithIndices[0].sumNumber().doubleValue() == 0.0d) {
            log.warn("NearestNeighbour on zero vector - please check vector alignment!");
        }
        INDArray iNDArray2 = sortWithIndices[0];
        ArrayList arrayList = new ArrayList(i);
        for (int i2 = 0; i2 < i; i2++) {
            arrayList.add(getWord(iNDArray2.getInt(new int[]{i2})));
        }
        return arrayList;
    }

    public INDArray getPredictions(INDArray iNDArray) {
        return new LabelSeeker(getWords(), this.model.getLookupTable()).getScoresAsVector(iNDArray).transpose();
    }
}
