package de.datexis.cdv.retrieval;

import de.datexis.cdv.index.DocumentIndex;
import de.datexis.cdv.index.QueryIndex;
import de.datexis.cdv.model.EntityAspectAnnotation;
import de.datexis.common.AnnotationHelpers;
import de.datexis.common.Timer;
import de.datexis.encoder.IEncoder;
import de.datexis.model.Annotation;
import de.datexis.model.Dataset;
import de.datexis.model.Document;
import de.datexis.model.Query;
import de.datexis.model.Result;
import de.datexis.model.Sentence;
import de.datexis.model.impl.PassageAnnotation;
import de.datexis.retrieval.model.RelevanceResult;
import de.datexis.retrieval.model.ScoredResult;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/datexis/cdv/retrieval/QueryRunner.class */
public class QueryRunner {
    /** Logger used by this class. */
    protected static final Logger log = LoggerFactory.getLogger(QueryRunner.class);
    /** Candidate count (64); presumably the number of documents requested from the document index - confirm against usages. */
    public static final int NUM_CANDIDATES = 64;
    /** Dataset that supplies the queries and the documents searched by this runner. */
    Dataset corpus;
    /** Index used to look up or encode the entity part of a query; may be null. */
    QueryIndex entityIndex;
    /** Index used to look up or encode the aspect part of a query; may be null. */
    QueryIndex aspectIndex;
    /** Selects how passages are extracted from a document's relevance histogram. */
    Strategy strategy;
    /** Document index that the constructor builds in memory from the corpus. */
    DocumentIndex index;
    /** Timer instance created in the constructor. */
    protected Timer timer;

    /* loaded from: input_file:de/datexis/cdv/retrieval/QueryRunner$Candidates.class */
    /** Source of the candidate documents that are scored for each query. */
    public enum Candidates {
        /** Score every document of the corpus. */
        ALL,
        /** Score only the documents referenced by the query's existing GOLD/SILVER relevance results. */
        GIVEN,
        /** Score the documents returned by a search on the document index. */
        INDEX
    }

    /* loaded from: input_file:de/datexis/cdv/retrieval/QueryRunner$Strategy.class */
    /** Passage extraction strategy applied to a document's relevance histogram. */
    public enum Strategy {
        /** Build passages from consecutive sentences whose scores pass fixed thresholds. */
        SENTENCE_THRESHOLD,
        /** Score existing passage annotations by the mean score of the sentences they cover. */
        PASSAGE_RANK
    }

    /**
     * Creates a runner that uses the {@link Strategy#SENTENCE_THRESHOLD} strategy.
     *
     * @param dataset corpus holding the queries and documents
     * @param queryIndex index for entity vectors, may be null
     * @param queryIndex2 index for aspect vectors, may be null
     */
    public QueryRunner(Dataset dataset, QueryIndex queryIndex, QueryIndex queryIndex2) {
        this(dataset, queryIndex, queryIndex2, Strategy.SENTENCE_THRESHOLD);
    }

    /**
     * Creates a runner for the given corpus and builds an in-memory document index over it.
     *
     * <p>If building the index fails with an {@link IOException}, the failure is reported through
     * the class logger (including the stack trace) and construction still completes, so the
     * runner remains usable for retrieval paths that do not need the index.
     *
     * @param dataset corpus holding the queries and documents
     * @param queryIndex index for entity vectors, may be null
     * @param queryIndex2 index for aspect vectors, may be null
     * @param strategy passage extraction strategy
     */
    public QueryRunner(Dataset dataset, QueryIndex queryIndex, QueryIndex queryIndex2, Strategy strategy) {
        this.timer = new Timer();
        this.corpus = dataset;
        this.entityIndex = queryIndex;
        this.aspectIndex = queryIndex2;
        this.strategy = strategy;
        this.index = new DocumentIndex();
        try {
            this.index.createInMemoryIndex(dataset);
        } catch (IOException e) {
            // Route the failure through the logger instead of stderr so it is not lost.
            log.error("Could not create in-memory document index", e);
        }
    }

    /**
     * Currently a no-op: the body is empty, so calling this method retrieves nothing.
     *
     * <p>NOTE(review): presumably intended to run retrieval for all queries with some default
     * candidate source - confirm the intended behavior before relying on it.
     */
    public void retrieveAllQueries() {
    }

    /**
     * Runs retrieval for every query of the corpus and logs progress and timing.
     *
     * <p>{@link Candidates#GIVEN} restricts scoring to the query's given candidates,
     * {@link Candidates#INDEX} to documents found through the document index; any other value
     * (including {@code null}) scores all documents of the corpus.
     *
     * @param candidates source of candidate documents
     */
    public void retrieveAllQueries(Candidates candidates) {
        // Local stopwatch with its own name so it does not shadow the 'timer' field.
        Timer stopwatch = new Timer();
        stopwatch.start();
        long countQueries = this.corpus.countQueries();
        log.info("Retrieving {} queries on {} documents...", countQueries, this.corpus.countDocuments());
        int finished = 0;
        for (Query query : this.corpus.getQueries()) {
            EntityAspectQueryAnnotation annotation = (EntityAspectQueryAnnotation) query.getAnnotation(EntityAspectQueryAnnotation.class);
            // Identity comparison is null-safe; a null argument falls through to the "all documents" path.
            if (candidates == Candidates.GIVEN) {
                retrieveQueryFromCandidates(query);
            } else if (candidates == Candidates.INDEX) {
                retrieveQueryFromIndex(query);
            } else {
                retrieveQuery(query);
            }
            finished++;
            log.info("Finished query {}/{} '{}' ({}) - '{}' [{}]", finished, countQueries, annotation.getEntity(), annotation.getEntityId(), annotation.getAspect(), Timer.millisToLongDHMS(stopwatch.setSplit("query")));
        }
        long total = stopwatch.getLong();
        // Guard the per-query average: an empty query set would otherwise divide by zero.
        long perQuery = countQueries > 0 ? total / countQueries : 0L;
        log.info("Finished {} queries on {} documents... [{}, {}/q]", countQueries, this.corpus.countDocuments(), Timer.millisToLongDHMS(total), Timer.millisToLongDHMS(perQuery));
    }

    /**
     * Runs retrieval over all corpus documents for at most {@code j} queries, in corpus order.
     *
     * @param j maximum number of queries to process; values of zero or less process nothing
     */
    public void retrieveAllQueries(long j) {
        long remaining = j;
        for (Query query : this.corpus.getQueries()) {
            if (remaining <= 0) {
                return;
            }
            remaining--;
            retrieveQuery(query);
        }
    }

    /**
     * Runs every query of the corpus against only the single document the query refers to.
     */
    public void retrieveAllQueriesPerDocument() {
        for (Query query : this.corpus.getQueries()) {
            Collection<Document> ownDocument = Collections.singleton(query.getDocumentRef());
            retrieveQueryFromDocs(query, ownDocument);
        }
    }

    /**
     * Scores the query against all documents of the corpus.
     *
     * @param query query to retrieve results for
     * @return the same query instance that was passed in
     */
    public Query retrieveQuery(Query query) {
        final Collection<Document> allDocuments = this.corpus.getDocuments();
        return retrieveQueryFromDocs(query, allDocuments);
    }

    /**
     * Scores the query only against the documents referenced by its GOLD and SILVER
     * {@code RelevanceResult}s, and passes those results on as passage candidates.
     *
     * <p>The two result collections are merged into a new list, so the collections returned
     * by the query are never modified.
     *
     * @param query query whose given candidates are scored
     * @return the same query instance that was passed in
     */
    public Query retrieveQueryFromCandidates(Query query) {
        // Merge into a fresh list: a wildcard-typed collection cannot be added to, and the
        // collection handed out by the query should not be mutated.
        List<Annotation> candidates = Stream.<Annotation>concat(
                query.getResults(Annotation.Source.GOLD, RelevanceResult.class).stream(),
                query.getResults(Annotation.Source.SILVER, RelevanceResult.class).stream())
            .collect(Collectors.toList());
        Collection<Document> documents = new HashSet<>();
        for (Annotation candidate : candidates) {
            documents.add(candidate.getDocumentRef());
        }
        return retrieveQueryFromDocs(query, documents, candidates);
    }

    /**
     * Searches the document index for the query's entity and scores the query against the
     * {@link #NUM_CANDIDATES} documents requested from that search.
     *
     * <p>NOTE(review): the corpus lookup result is unwrapped with {@code get()} without a
     * presence check - confirm every indexed document id is guaranteed to exist in the corpus.
     *
     * @param query query to retrieve results for
     * @return the same query instance that was passed in
     */
    public Query retrieveQueryFromIndex(Query query) {
        EntityAspectQueryAnnotation annotation = (EntityAspectQueryAnnotation) query.getAnnotation(EntityAspectQueryAnnotation.class);
        List candidates = (List) this.index.search(annotation.getEntity(), NUM_CANDIDATES).stream()
            .map(hit -> (Document) this.corpus.getDocument(hit.documentId).get())
            .collect(Collectors.toList());
        return retrieveQueryFromDocs(query, candidates);
    }

    /**
     * Scores the query against the given documents without predefined passage candidates.
     *
     * @param query query to retrieve results for
     * @param collection documents to score
     * @return the same query instance that was passed in
     */
    protected Query retrieveQueryFromDocs(Query query, Collection<Document> collection) {
        final Collection<? extends Annotation> noCandidates = null;
        return retrieveQueryFromDocs(query, collection, noCandidates);
    }

    /**
     * Resolves the entity and aspect vectors of the query once, then scores every non-empty
     * document in parallel and adds the extracted passages to the query.
     *
     * <p>Each vector is first looked up in its index and, if the lookup returns {@code null},
     * produced by encoding the raw text. When neither vector is available no histogram can be
     * computed, so the query is returned without new results.
     *
     * <p>NOTE(review): results are added to the query from parallel stream workers - confirm
     * that {@code Query.addResult} is safe for concurrent use.
     *
     * @param query query to retrieve results for
     * @param collection documents to score
     * @param collection2 optional passage candidates handed to passage retrieval, may be null
     * @return the same query instance that was passed in
     */
    protected Query retrieveQueryFromDocs(Query query, Collection<Document> collection, Collection<? extends Annotation> collection2) {
        EntityAspectQueryAnnotation annotation = (EntityAspectQueryAnnotation) query.getAnnotation(EntityAspectQueryAnnotation.class);
        INDArray entityVector = null;
        INDArray aspectVector = null;
        if (this.entityIndex != null && annotation.hasEntity()) {
            // Prefer the entity id as lookup key and fall back to the entity name.
            entityVector = this.entityIndex.lookup(annotation.getEntityId() != null ? annotation.getEntityId() : annotation.getEntity());
            if (entityVector == null) {
                log.debug("fallback encoding entity '{}'", annotation.getEntity());
                entityVector = this.entityIndex.encode(annotation.getEntity());
            }
        }
        if (this.aspectIndex != null && annotation.hasAspect()) {
            aspectVector = this.aspectIndex.lookup(this.aspectIndex.getKeyPreprocessor().preProcess(annotation.getAspect()));
            if (aspectVector == null) {
                log.error("fallback encoding aspect '{}'", annotation.getAspect());
                aspectVector = this.aspectIndex.encode(annotation.getAspect());
            }
        }
        if (entityVector == null && aspectVector == null) {
            // Without any query vector the histogram would be null and scoring would fail.
            log.warn("No entity or aspect encoding available for query, skipping retrieval");
            return query;
        }
        final INDArray entityEncoding = entityVector;
        final INDArray aspectEncoding = aspectVector;
        collection.stream().parallel()
            .filter(document -> !document.isEmpty())
            .forEach(document -> retrievePassages(document, query, getHistogram(document, entityEncoding, aspectEncoding), collection2));
        return query;
    }

    /**
     * Scores the query against a single document and adds the extracted passages to the query.
     *
     * @param document document to score
     * @param query query to retrieve results for
     * @return the same query instance that was passed in
     */
    public Query retrieveQuery(Document document, Query query) {
        INDArray histogram = getHistogram(document, query);
        return retrievePassages(document, query, histogram);
    }

    /**
     * Extracts passages from a document's histogram without predefined passage candidates.
     *
     * @param document document the histogram belongs to
     * @param query query that receives the results
     * @param iNDArray per-sentence relevance histogram of the document
     * @return the same query instance that was passed in
     */
    protected Query retrievePassages(Document document, Query query, INDArray iNDArray) {
        final Collection<? extends Annotation> noCandidates = null;
        return retrievePassages(document, query, iNDArray, noCandidates);
    }

    /**
     * Extracts passages from a document's histogram using the configured strategy.
     *
     * <p>The candidates are only used by {@link Strategy#PASSAGE_RANK}; the threshold strategy
     * ignores them. A {@code null} histogram (no query encoding was available) yields no
     * results instead of failing.
     *
     * @param document document the histogram belongs to
     * @param query query that receives the results
     * @param iNDArray per-sentence relevance histogram of the document, may be null
     * @param collection optional passage candidates, may be null
     * @return the same query instance that was passed in
     */
    protected Query retrievePassages(Document document, Query query, INDArray iNDArray, Collection<? extends Annotation> collection) {
        if (iNDArray == null) {
            // Histogram computation returns null when neither entity nor aspect vector exists.
            log.debug("No histogram available for document, skipping passage retrieval");
            return query;
        }
        switch (this.strategy) {
            case PASSAGE_RANK:
                return retrievePassagesByRanking(document, query, iNDArray, collection);
            case SENTENCE_THRESHOLD:
            default:
                return retrievePassagesByThreshold(document, query, iNDArray);
        }
    }

    /**
     * Resolves the query's entity and aspect vectors and projects them onto the document.
     *
     * <p>Each vector is looked up in its index first and encoded from the raw text when the
     * lookup returns {@code null}.
     *
     * <p>NOTE(review): the aspect key is looked up as-is here, whereas the 3-argument
     * {@code retrieveQueryFromDocs} runs it through the index's key preprocessor first -
     * confirm whether this difference is intended.
     *
     * @param document document to project the query onto
     * @param query query carrying the entity/aspect annotation
     * @return the relevance histogram, or null if neither vector is available
     */
    public INDArray getHistogram(Document document, Query query) {
        EntityAspectQueryAnnotation annotation = (EntityAspectQueryAnnotation) query.getAnnotation(EntityAspectQueryAnnotation.class);

        INDArray entityVector = null;
        if (this.entityIndex != null && annotation.hasEntity()) {
            if (annotation.getEntityId() != null) {
                entityVector = this.entityIndex.lookup(annotation.getEntityId());
            } else {
                entityVector = this.entityIndex.lookup(annotation.getEntity());
            }
            if (entityVector == null) {
                entityVector = this.entityIndex.encode(annotation.getEntity());
            }
        }

        INDArray aspectVector = null;
        if (this.aspectIndex != null && annotation.hasAspect()) {
            aspectVector = this.aspectIndex.lookup(annotation.getAspect());
            if (aspectVector == null) {
                aspectVector = this.aspectIndex.encode(annotation.getAspect());
            }
        }

        return getHistogram(document, entityVector, aspectVector);
    }

    /**
     * Projects the available query vectors onto the document.
     *
     * <p>With both vectors the stacked projection is used; with only one, that vector is
     * projected against the document vectors of its own index.
     *
     * @param document document to project onto
     * @param iNDArray entity vector, may be null
     * @param iNDArray2 aspect vector, may be null
     * @return the relevance histogram, or null if both vectors are null
     */
    protected INDArray getHistogram(Document document, INDArray iNDArray, INDArray iNDArray2) {
        final boolean hasEntity = iNDArray != null;
        final boolean hasAspect = iNDArray2 != null;
        if (hasEntity && hasAspect) {
            return projectQuery(document, iNDArray, iNDArray2);
        }
        if (hasEntity) {
            return projectQuery(document, iNDArray, (IEncoder) this.entityIndex);
        }
        return hasAspect ? projectQuery(document, iNDArray2, (IEncoder) this.aspectIndex) : null;
    }

    /**
     * Multiplies the unit-scaled, transposed query vector with the document vectors stored for
     * the encoder's class and returns the transposed product.
     *
     * <p>NOTE(review): presumably the document matrix holds one column per sentence, making
     * the result a per-sentence score vector - confirm against the encoder output shape.
     *
     * @param document document whose vectors are used
     * @param iNDArray query vector
     * @param iEncoder encoder whose class identifies the document vectors
     * @return transposed product of query and document vectors
     */
    protected INDArray projectQuery(Document document, INDArray iNDArray, IEncoder iEncoder) {
        INDArray unitQuery = Transforms.unitVec(iNDArray);
        INDArray documentMatrix = document.getVector(iEncoder.getClass());
        INDArray product = unitQuery.transpose().mmul(documentMatrix);
        return product.transpose();
    }

    /**
     * Projects a combined entity/aspect query onto the document.
     *
     * <p>The unit-scaled entity and aspect vectors are stacked vertically, as are the
     * document's entity and aspect vectors. Every column of the stacked document matrix is
     * scaled to unit length, then the unit-scaled stacked query is multiplied with it.
     *
     * @param document document whose entity and aspect vectors are used
     * @param iNDArray entity query vector
     * @param iNDArray2 aspect query vector
     * @return transposed product of the stacked query and the normalized stacked document matrix
     */
    protected INDArray projectQuery(Document document, INDArray iNDArray, INDArray iNDArray2) {
        INDArray entityMatrix = document.getVector(this.entityIndex.getClass());
        INDArray aspectMatrix = document.getVector(this.aspectIndex.getClass());
        INDArray stackedQuery = Nd4j.vstack(new INDArray[]{Transforms.unitVec(iNDArray), Transforms.unitVec(iNDArray2)});
        INDArray stackedDocument = Nd4j.vstack(new INDArray[]{entityMatrix, aspectMatrix});
        int column = 0;
        while (column < stackedDocument.size(1)) {
            // Normalize each column of the stacked document matrix to unit length.
            INDArray normalized = Transforms.unitVec(stackedDocument.getColumn(column));
            stackedDocument.getColumn(column).assign(normalized);
            column++;
        }
        INDArray product = Transforms.unitVec(stackedQuery).transpose().mmul(stackedDocument);
        return product.transpose();
    }

    /**
     * Averages two histograms, or returns the only non-null one.
     *
     * @param iNDArray first histogram, may be null
     * @param iNDArray2 second histogram, may be null
     * @return element-wise mean of both histograms, or the single non-null histogram
     * @throws IllegalArgumentException if both histograms are null
     */
    @Deprecated
    protected INDArray mergeHistograms(INDArray iNDArray, INDArray iNDArray2) {
        if (iNDArray == null && iNDArray2 == null) {
            throw new IllegalArgumentException("Both encodings are null");
        }
        if (iNDArray == null) {
            return iNDArray2;
        }
        if (iNDArray2 == null) {
            return iNDArray;
        }
        INDArray sum = iNDArray.add(iNDArray2);
        return sum.divi(Double.valueOf(2.0d));
    }

    /**
     * Scores passage annotations of the document by the mean histogram value of the sentences
     * they cover and adds one result per scored passage.
     *
     * <p>When no candidates are given, the document's sorted GOLD {@code PassageAnnotation}s
     * are used and their document reference is set to this document. Candidates that refer
     * to a different document instance are skipped, as are passages covering no sentence.
     *
     * @param document document being scored
     * @param query query that receives the results
     * @param iNDArray per-sentence relevance histogram of the document
     * @param collection passage candidates, or null to use the document's GOLD passages
     * @return the same query instance that was passed in
     */
    protected Query retrievePassagesByRanking(Document document, Query query, INDArray iNDArray, Collection<? extends Annotation> collection) {
        Collection<? extends Annotation> passages = collection;
        if (passages == null) {
            passages = (Collection) document.streamAnnotations(Annotation.Source.GOLD, PassageAnnotation.class, true).sorted().map(passage -> {
                passage.setDocumentRef(document);
                return passage;
            }).collect(Collectors.toList());
        }
        for (Annotation passage : passages) {
            // Only passages attached to this exact document instance are scored.
            if (passage.getDocumentRef() != document) {
                continue;
            }
            List sentences = (List) AnnotationHelpers.streamSpansInRange(document, Sentence.class, passage.getBegin(), passage.getEnd(), true).collect(Collectors.toList());
            INDArray scores = Nd4j.zeros(new int[]{sentences.size()});
            int filled = 0;
            for (Object item : sentences) {
                Sentence sentence = (Sentence) item;
                scores.putScalar(filled, iNDArray.getDouble(document.getSentenceIndexAtPosition(sentence.getBegin())));
                filled++;
            }
            if (filled > 0) {
                addResult(query, document, passage, scores.meanNumber().doubleValue());
            }
        }
        return query;
    }

    /**
     * Computes the {@code d}-th percentile of the values by linear interpolation between the
     * two neighbouring order statistics at position {@code d/100 * (n + 1)}.
     *
     * <p>Both neighbour indices are clamped to the valid range, so {@code d = 100} returns the
     * maximum (instead of reading one element past the end) and percentiles outside
     * {@code [0, 100]} return the minimum or maximum respectively. The input is not modified;
     * a sorted copy is used.
     *
     * @param d percentile in percent, e.g. 50 for the median
     * @param iNDArray values to evaluate
     * @return the interpolated percentile value
     * @throws IllegalArgumentException if the array contains no elements
     */
    public static double percentile(double d, INDArray iNDArray) {
        if (iNDArray.length() == 0) {
            throw new IllegalArgumentException("Cannot compute a percentile of an empty array");
        }
        INDArray sorted = Nd4j.sort(iNDArray.dup(iNDArray.ordering()), true);
        final int last = (int) sorted.length() - 1;
        double position = (d / 100.0d) * (sorted.length() + 1);
        double floor = FastMath.floor(position);
        int rank = (int) floor;
        double fraction = position - floor;
        // rank is 1-based; clamp both neighbours into [0, last].
        int lowerIndex = Math.min(Math.max(0, rank - 1), last);
        int upperIndex = Math.min(Math.max(0, rank), last);
        double lower = sorted.getDouble(lowerIndex);
        return lower + (fraction * (sorted.getDouble(upperIndex) - lower));
    }

    /**
     * Computes the root mean square of all elements of the array.
     *
     * @param iNDArray values to evaluate
     * @return square root of the mean of the squared elements; 0 for an empty array
     */
    public static double rms(INDArray iNDArray) {
        final long count = iNDArray.length();
        if (count == 0) {
            return 0.0d;
        }
        double sumOfSquares = 0.0d;
        for (int idx = 0; idx < count; idx++) {
            sumOfSquares += Math.pow(iNDArray.getDouble(idx), 2.0d);
        }
        return sumOfSquares == 0.0d ? 0.0d : Math.sqrt(sumOfSquares / count);
    }

    /**
     * Logs the first ten predicted results of the query with their rank, the entity and
     * aspect of the referenced annotation, and the confidence.
     *
     * @param query query whose PRED results are logged
     */
    protected void printResults(Query query) {
        int rank = 0;
        for (Result result : query.getResults(Annotation.Source.PRED, Result.class)) {
            rank++;
            if (rank > 10) {
                break;
            }
            EntityAspectAnnotation annotationRef = result.getAnnotationRef();
            Object[] arguments = {Integer.valueOf(rank), annotationRef.getEntity(), annotationRef.getAspect(), Double.valueOf(result.getConfidence())};
            log.info(" rank {}: {} - {} ({})", arguments);
        }
    }

    /**
     * Builds passages from runs of consecutive sentences using two fixed thresholds.
     *
     * <p>A passage opens at the first sentence scoring at least 0.8 and is extended by every
     * following sentence scoring at least 0.6. The first sentence scoring below 0.6 closes
     * the passage (without being included) and a result is added whose score is the mean of
     * the contained sentence scores. A passage still open at the end of the document is
     * added as well.
     *
     * @param document document being scored
     * @param query query that receives the results
     * @param iNDArray per-sentence relevance histogram, indexed in document sentence order
     * @return the same query instance that was passed in
     */
    protected Query retrievePassagesByThreshold(Document document, Query query, INDArray iNDArray) {
        final double openThreshold = 0.8d;
        final double continueThreshold = 0.6d;
        int sentenceIndex = 0;
        boolean inPassage = false;
        int passageBegin = 0;
        int passageEnd = 0;
        double sentenceCount = 1.0d;
        double scoreSum = 0.0d;
        for (Sentence sentence : document.getSentences()) {
            double score = iNDArray.getDouble(sentenceIndex);
            sentenceIndex++;
            if (!inPassage) {
                if (score >= openThreshold) {
                    // Start a new passage at this sentence.
                    inPassage = true;
                    sentenceCount = 1.0d;
                    scoreSum = score;
                    passageBegin = sentence.getBegin();
                    passageEnd = sentence.getEnd();
                }
            } else if (score < continueThreshold) {
                // Score dropped: close the passage before this sentence.
                inPassage = false;
                addResult(query, document, passageBegin, passageEnd, scoreSum / sentenceCount);
            } else {
                sentenceCount += 1.0d;
                scoreSum += score;
                passageEnd = sentence.getEnd();
            }
        }
        if (inPassage) {
            addResult(query, document, passageBegin, passageEnd, scoreSum / sentenceCount);
        }
        return query;
    }

    /**
     * Adds a predicted result spanning the given character range of the document to the
     * query, using {@code d} as both confidence and score.
     *
     * @param query query that receives the result
     * @param document document the result belongs to
     * @param i begin position of the result span
     * @param i2 end position of the result span
     * @param d relevance used as confidence and score
     */
    public void addResult(Query query, Document document, int i, int i2, double d) {
        ScoredResult scoredResult = new ScoredResult(Annotation.Source.PRED, document, i, i2);
        scoredResult.setConfidence(d);
        scoredResult.setScore(Double.valueOf(d));
        query.addResult(scoredResult);
        // Only extract the passage text when trace logging is actually enabled.
        if (log.isTraceEnabled()) {
            log.trace("adding result from document '{}' with relevance {}: '{}'", document.getTitle(), d, document.getText(scoredResult));
        }
    }

    /**
     * Adds a predicted result covering the span of the given annotation to the query, using
     * {@code d} as both confidence and score and keeping a reference to the annotation.
     *
     * @param query query that receives the result
     * @param document document the result belongs to
     * @param annotation annotation whose begin/end define the result span
     * @param d relevance used as confidence and score
     */
    public void addResult(Query query, Document document, Annotation annotation, double d) {
        ScoredResult scoredResult = new ScoredResult(Annotation.Source.PRED, document, annotation.getBegin(), annotation.getEnd());
        scoredResult.setConfidence(d);
        scoredResult.setScore(Double.valueOf(d));
        scoredResult.setAnnotationRef(annotation);
        query.addResult(scoredResult);
        // Only extract the passage text when trace logging is actually enabled.
        if (log.isTraceEnabled()) {
            log.trace("adding result from document '{}' with relevance {}: '{}'", document.getTitle(), d, document.getText(scoredResult));
        }
    }
}
