package de.datexis.cdv.index;

import de.datexis.cdv.model.EntityAspectAnnotation;
import de.datexis.cdv.retrieval.EntityAspectQueryAnnotation;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Query;
import de.datexis.model.Sentence;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.retrieval.index.InMemoryIndex;
import de.datexis.retrieval.preprocess.WikipediaUrlPreprocessor;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/cdv/index/EntityIndex.class */
/**
 * {@link QueryIndex} whose keys are entity IDs.
 *
 * <p>Strings passed to {@link #lookup(String)} and {@link #decode(String)}, as well as the
 * entity IDs read from annotations, may contain several IDs joined by
 * {@link #ID_SEPARATOR_REGEX}; each ID is handled individually.
 */
public class EntityIndex extends QueryIndex {
    protected static final Logger log = LoggerFactory.getLogger(EntityIndex.class);

    /**
     * Regular expression used to split a compound entity string into individual IDs.
     *
     * <p>NOTE(review): this is public and non-final, so any caller can reassign it globally;
     * it is kept mutable only for backward compatibility.
     */
    public static String ID_SEPARATOR_REGEX = ";";

    /** No-arg constructor; presumably used for deserialization or subclasses — TODO confirm. */
    protected EntityIndex() {
    }

    /**
     * Creates an entity index with id {@code "ENT"}, passing a {@link WikipediaUrlPreprocessor}
     * and the given encoder to the {@link QueryIndex} constructor.
     *
     * @param iEncoder encoder used to embed entities
     */
    public EntityIndex(IEncoder iEncoder) {
        super(new WikipediaUrlPreprocessor(), iEncoder);
        this.id = "ENT";
    }

    /**
     * Looks up the vector for a (possibly compound) entity string.
     *
     * <p>The string is split on {@link #ID_SEPARATOR_REGEX}, every part is looked up via the
     * superclass, and the vectors that were found are averaged.
     *
     * @param str one or more entity IDs joined by {@link #ID_SEPARATOR_REGEX}
     * @return a FLOAT column vector of shape {@code [getEmbeddingVectorSize(), 1]} holding the
     *     element-wise mean of all vectors found, or {@code null} if no part was found
     */
    public INDArray lookup(String str) {
        // Fresh accumulator, so vectors returned by super.lookup() are never modified in place.
        INDArray sum = Nd4j.zeros(DataType.FLOAT, new long[]{getEmbeddingVectorSize(), 1});
        int found = 0;
        for (String key : str.split(ID_SEPARATOR_REGEX)) {
            INDArray vector = super.lookup(key);
            if (vector != null) {
                sum.addi(vector);
                found++;
            }
        }
        if (found == 0) {
            return null;
        }
        return found > 1 ? sum.divi(Integer.valueOf(found)) : sum;
    }

    /**
     * Encodes an entity string after replacing underscores with spaces
     * (e.g. {@code "New_York"} is encoded as {@code "New York"}).
     *
     * @param str entity string
     * @return the superclass encoding of the space-separated string
     */
    public INDArray encode(String str) {
        return super.encode(str.replace('_', ' '));
    }

    /**
     * Builds a multi-hot target vector for a (possibly compound) entity string.
     *
     * <p>Every part of {@code str} that resolves to an index position via {@code index(...)}
     * sets that position to 1. A warning is logged if no part is known.
     *
     * @param str one or more entity IDs joined by {@link #ID_SEPARATOR_REGEX}
     * @return a FLOAT column vector of shape {@code [size(), 1]}; all zeros if nothing matched
     */
    public INDArray decode(String str) {
        INDArray target = Nd4j.zeros(DataType.FLOAT, new long[]{size(), 1});
        // Track hits directly instead of reducing over the whole vector afterwards.
        boolean found = false;
        for (String key : str.split(ID_SEPARATOR_REGEX)) {
            int index = index(key);
            if (index >= 0) {
                target.putScalarUnsafe(index, 1.0d);
                found = true;
            }
        }
        if (!found) {
            log.warn("entity '{}' not contained in index", str);
        }
        return target;
    }

    /**
     * Builds the key and vector index from GOLD {@link EntityAspectAnnotation}s in documents.
     *
     * <p>Each entity ID is preprocessed with {@code keyPreprocessor}. The first annotation
     * seen for a key defines its vector, which is obtained by encoding the annotation's
     * tokenized entity text.
     *
     * @param collection documents whose GOLD entity annotations are indexed
     * @deprecated superseded; {@link #encodeFromQueries(Collection)} appears to be the
     *     replacement — TODO confirm
     */
    @Deprecated
    public void trainModel(Collection<Document> collection) {
        ArrayList<String> keys = new ArrayList<>();
        Map<String, INDArray> vectors = new HashMap<>();
        for (Document document : collection) {
            for (EntityAspectAnnotation annotation : document.getAnnotations(Annotation.Source.GOLD, EntityAspectAnnotation.class)) {
                String entityId = annotation.getEntityId();
                if (entityId == null) {
                    continue;
                }
                // Tokenize lazily: only needed when a key is seen for the first time.
                Sentence entitySentence = null;
                for (String rawKey : entityId.split(ID_SEPARATOR_REGEX)) {
                    String key = this.keyPreprocessor.preProcess(rawKey);
                    keys.add(key);
                    if (!vectors.containsKey(key)) {
                        if (entitySentence == null) {
                            entitySentence = DocumentFactory.createSentenceFromTokenizedString(annotation.getEntity());
                        }
                        vectors.put(key, this.encoder.encode(entitySentence));
                    }
                }
            }
        }
        buildKeyIndex(keys, false);
        buildVectorIndex(vectors, false);
        setModelAvailable(true);
    }

    /**
     * Builds the key and vector index from the {@link EntityAspectQueryAnnotation}s of queries.
     *
     * <p>For every entity ID, a vector is first looked up from the encoder if it is an
     * {@link InMemoryIndex}. Otherwise, or if that lookup yields {@code null}, the tokenized
     * entity text is encoded instead. Queries without an annotation or without an entity ID
     * are skipped with a warning.
     *
     * @param collection queries whose entity annotations are indexed
     */
    @Override // de.datexis.cdv.index.QueryIndex
    public void encodeFromQueries(Collection<Query> collection) {
        ArrayList<String> keys = new ArrayList<>();
        Map<String, INDArray> vectors = new HashMap<>();
        for (Query query : collection) {
            EntityAspectQueryAnnotation annotation = (EntityAspectQueryAnnotation) query.getAnnotation(EntityAspectQueryAnnotation.class);
            if (annotation == null) {
                // Previously dereferenced unconditionally and failed with an NPE.
                log.warn("Found query {} without entity annotation - skipping", query.getId());
                continue;
            }
            String entityId = annotation.getEntityId();
            if (entityId == null) {
                log.warn("Found query {} without entityID for '{}' - skipping", query.getId(), annotation.getEntity());
                continue;
            }
            // Tokenize lazily: only needed for the fallback encoding path.
            Sentence entitySentence = null;
            for (String key : entityId.split(ID_SEPARATOR_REGEX)) {
                keys.add(key);
                if (vectors.containsKey(key)) {
                    continue;
                }
                INDArray vector = this.encoder instanceof InMemoryIndex ? this.encoder.lookup(key) : null;
                if (vector == null) {
                    log.info("Fallback encoding entity {} '{}'", entityId, annotation.getEntity());
                    if (entitySentence == null) {
                        entitySentence = DocumentFactory.createSentenceFromTokenizedString(annotation.getEntity());
                    }
                    vector = this.encoder.encode(entitySentence);
                }
                vectors.put(key, vector);
            }
        }
        buildKeyIndex(keys, true);
        buildVectorIndex(vectors, true);
        setModelAvailable(true);
    }
}
