package de.datexis.cdv.index;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import de.datexis.cdv.model.EntityAspectAnnotation;
import de.datexis.cdv.preprocess.AspectPreprocessor;
import de.datexis.common.Resource;
import de.datexis.common.WordHelpers;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.model.Query;
import de.datexis.model.Sentence;
import de.datexis.model.Span;
import de.datexis.model.Token;
import de.datexis.preprocess.DocumentFactory;
import de.datexis.retrieval.tagger.LSTMSentenceTaggerIterator;
import de.datexis.tagger.AbstractMultiDataSetIterator;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Query index over aspect headings.
 *
 * <p>A heading may name several aspects at once (e.g. {@code "Signs and Symptoms"}); it is split
 * into parts using {@link #HEADING_SEPARATOR_REGEX}, and every part is trimmed and then handled
 * individually by the inherited {@link QueryIndex} operations.
 */
public class AspectIndex extends QueryIndex {
    protected static final Logger log = LoggerFactory.getLogger(AspectIndex.class);

    /**
     * Regular expression used to split a heading into its parts: {@code " | "}, {@code " and "},
     * {@code "&"} or {@code "/"}. Deliberately left non-final and re-read on every split, because
     * the field is public and may be reassigned by existing callers.
     */
    public static String HEADING_SEPARATOR_REGEX = " \\| | and |&|/";

    /** Label that is replaced by {@link #DESCRIPTION_LABEL} before it is split and indexed. */
    private static final String ABSTRACT_LABEL = "Abstract";

    /** Replacement for {@link #ABSTRACT_LABEL}. */
    private static final String DESCRIPTION_LABEL = "Description";

    /** Creates an index without preprocessor or encoder; only available to subclasses. */
    protected AspectIndex() {
    }

    /**
     * Creates an index that uses an {@link AspectPreprocessor} for keys and the given encoder,
     * and sets the id to {@code "ASP"}.
     *
     * @param iEncoder encoder handed to the superclass constructor
     */
    public AspectIndex(IEncoder iEncoder) {
        super(new AspectPreprocessor(), iEncoder);
        this.id = "ASP";
    }

    /**
     * Encodes the text of a {@link Token}.
     *
     * @param span the span to encode; only {@link Token} instances are supported
     * @return the result of {@link #encode(String)} for the token text
     * @throws IllegalArgumentException if {@code span} is {@code null} or not a {@link Token}
     */
    public INDArray encode(Span span) {
        if (span instanceof Token) {
            return encode(span.getText());
        }
        // A null span previously failed with a NullPointerException while building this message.
        String type = (span == null) ? "null" : String.valueOf(span.getClass());
        throw new IllegalArgumentException("Index is not configured to encode " + type);
    }

    /**
     * Encodes a heading as the mean of the encodings of its parts.
     *
     * @param str heading, possibly consisting of several parts
     * @return a vector of shape {@code [getEmbeddingVectorSize(), 1]}; all zeros if the superclass
     *     returned {@code null} for every part
     */
    public INDArray encode(String str) {
        INDArray sum = Nd4j.zeros(DataType.FLOAT, new long[]{getEmbeddingVectorSize(), 1});
        int encodedParts = 0;
        for (String part : splitHeading(str)) {
            INDArray vector = super.encode(part.trim());
            if (vector != null) {
                sum.addi(vector);
                encodedParts++;
            }
        }
        return average(sum, encodedParts);
    }

    /**
     * Looks up a heading as the mean of the lookup results of its parts.
     *
     * @param str heading, possibly consisting of several parts
     * @return a vector of shape {@code [getEmbeddingVectorSize(), 1]}, or {@code null} if the
     *     superclass lookup returned {@code null} for every part
     */
    public INDArray lookup(String str) {
        INDArray sum = Nd4j.zeros(DataType.FLOAT, new long[]{getEmbeddingVectorSize(), 1});
        int foundParts = 0;
        for (String part : splitHeading(str)) {
            INDArray vector = super.lookup(part.trim());
            if (vector != null) {
                sum.addi(vector);
                foundParts++;
            }
        }
        return foundParts == 0 ? null : average(sum, foundParts);
    }

    /**
     * Converts a heading into an indicator vector over the index.
     *
     * <p>Every part that {@code index(...)} resolves to a non-negative position is marked with
     * {@code 1.0}. If no part resolves, the position returned by {@code findIndex(encode(str))} is
     * marked instead, provided it is non-negative.
     *
     * @param str heading, possibly consisting of several parts
     * @return a vector of shape {@code [size(), 1]}; all zeros if neither the direct lookup nor
     *     the fallback produced a position
     */
    public INDArray decode(String str) {
        INDArray target = Nd4j.zeros(DataType.FLOAT, new long[]{size(), 1});
        // Tracked explicitly instead of reducing the array with maxNumber(): only 1.0 values are
        // ever written, so "nothing written" and "maximum is zero" are the same condition.
        boolean matched = false;
        for (String part : splitHeading(str)) {
            int position = index(part.trim());
            if (position >= 0) {
                target.putScalarUnsafe(position, 1.0d);
                matched = true;
            }
        }
        if (!matched) {
            int nearest = findIndex(encode(str));
            if (nearest >= 0) {
                target.putScalarUnsafe(nearest, 1.0d);
                log.warn("heading '{}' not contained in index, using nearest neighbour '{}'", str, key(nearest));
            } else {
                // Do not resolve key() for a negative position and do not claim a neighbour was used.
                log.warn("heading '{}' not contained in index and no nearest neighbour found", str);
            }
        }
        return target;
    }

    /**
     * Builds the key and vector index from the labels of a resource. Each label is split into its
     * parts; the first occurrence of every preprocessed part is kept as its example sentence.
     *
     * @param resource resource handed to {@link LSTMSentenceTaggerIterator} to read the labels
     */
    @Override
    public void encodeIndexFromLabels(Resource resource) {
        List<String> labels = new LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage.ENCODE, this.encoder, (IEncoder) null, resource, "utf-8", WordHelpers.Language.EN, true, 64).getLabels();
        // NOTE(review): kept as a raw type because the element type expected by
        // encodeAndBuildVectorIndex is not visible from this class.
        Multimap examples = ArrayListMultimap.create();
        List<String> keys = new ArrayList<>();
        for (String label : labels) {
            for (String rawPart : splitHeading(normalizeLabel(label))) {
                String part = rawPart.trim();
                String key = this.keyPreprocessor.preProcess(part);
                keys.add(key);
                if (!examples.containsKey(key)) {
                    // Only build the sentence when it is actually stored.
                    examples.put(key, DocumentFactory.createSentenceFromTokenizedString(part));
                }
            }
        }
        buildKeyIndex(keys, false);
        encodeAndBuildVectorIndex(examples, false);
        setModelAvailable(true);
    }

    /**
     * Builds the key and vector index from labelled sentences of a resource. Each sentence is
     * registered as an example for every part of its label.
     *
     * @param resource resource handed to {@link LSTMSentenceTaggerIterator}
     * @param set passed through unchanged to {@link LSTMSentenceTaggerIterator}
     * @param z passed through unchanged to {@link LSTMSentenceTaggerIterator}
     */
    @Override
    public void encodeIndexFromSentences(Resource resource, Set<String> set, boolean z) {
        LSTMSentenceTaggerIterator sentences = new LSTMSentenceTaggerIterator(AbstractMultiDataSetIterator.Stage.ENCODE, this.encoder, (IEncoder) null, resource, "utf-8", WordHelpers.Language.EN, set, z, 1);
        log.info("Reading {} examples...", sentences.getNumExamples());
        // NOTE(review): raw types kept because the generic signatures of nextLabeledSentence and
        // encodeAndBuildVectorIndex are not visible from this class.
        ArrayListMultimap examples = ArrayListMultimap.create();
        while (sentences.hasNext()) {
            Map.Entry labeled = sentences.nextLabeledSentence();
            String label = normalizeLabel((String) labeled.getKey());
            for (String part : splitHeading(label)) {
                examples.put(this.keyPreprocessor.preProcess(part.trim()), labeled.getValue());
            }
        }
        buildKeyIndex(examples.keys(), false);
        encodeAndBuildVectorIndex(examples, false);
        setModelAvailable(true);
    }

    /** Intentionally does nothing for this index. */
    @Override
    public void encodeFromQueries(Collection<Query> collection) {
    }

    /**
     * Builds the key and vector index from the gold {@link EntityAspectAnnotation}s of the given
     * documents, encoding each distinct heading part once with the configured encoder.
     *
     * @param collection documents whose gold annotations are indexed
     * @deprecated flagged as deprecated in this class; NOTE(review): presumably superseded by
     *     {@link #encodeIndexFromLabels(Resource)} or
     *     {@link #encodeIndexFromSentences(Resource, Set, boolean)} - confirm.
     */
    @Deprecated
    public void trainModel(Collection<Document> collection) {
        List<String> keys = new ArrayList<>();
        // NOTE(review): raw type kept because the value type expected by buildVectorIndex is not
        // visible from this class.
        Map vectors = new HashMap();
        for (Document document : collection) {
            for (EntityAspectAnnotation annotation : document.getAnnotations(Annotation.Source.GOLD, EntityAspectAnnotation.class)) {
                if (annotation.getAspect() == null) {
                    continue;
                }
                for (String rawPart : splitHeading(annotation.getLabel())) {
                    // Trimmed like in the other index builders, so that "A & B" yields the keys
                    // "A" and "B" rather than "A " and " B".
                    String part = rawPart.trim();
                    String key = this.keyPreprocessor.preProcess(part);
                    keys.add(key);
                    if (!vectors.containsKey(key)) {
                        vectors.put(key, this.encoder.encode(DocumentFactory.createSentenceFromTokenizedString(part)));
                    }
                }
            }
        }
        buildKeyIndex(keys, false);
        buildVectorIndex(vectors, false);
        setModelAvailable(true);
    }

    /** Splits a heading into its untrimmed parts using the current separator expression. */
    private static String[] splitHeading(String heading) {
        return heading.split(HEADING_SEPARATOR_REGEX);
    }

    /** Maps the label {@code "Abstract"} to {@code "Description"}; other labels are unchanged. */
    private static String normalizeLabel(String label) {
        return ABSTRACT_LABEL.equals(label) ? DESCRIPTION_LABEL : label;
    }

    /** Divides {@code sum} in place by {@code count} when more than one vector was added. */
    private static INDArray average(INDArray sum, int count) {
        return count > 1 ? sum.divi(count) : sum;
    }
}
