package de.uni_trier.wi2.procake.utils.eval;

import de.uni_trier.wi2.procake.data.object.nest.NESTGraphObject;
import de.uni_trier.wi2.procake.data.objectpool.DataObjectIterator;
import de.uni_trier.wi2.procake.data.objectpool.ReadableObjectPool;
import de.uni_trier.wi2.procake.retrieval.Query;
import de.uni_trier.wi2.procake.retrieval.RetrievalResult;
import de.uni_trier.wi2.procake.retrieval.RetrievalResultList;
import de.uni_trier.wi2.procake.retrieval.Retriever;
import de.uni_trier.wi2.procake.retrieval.impl.RetrievalResultImpl;
import de.uni_trier.wi2.procake.retrieval.impl.RetrievalResultListImpl;
import de.uni_trier.wi2.procake.similarity.impl.SimilarityImpl;
import de.uni_trier.wi2.procake.utils.eval.metrics.EvalMetricComputer;
import de.uni_trier.wi2.procake.utils.eval.metrics.EvalMetricResult;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.Callable;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.collections4.map.MultiKeyMap;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates {@link Retriever}s against ground-truth similarities.
 *
 * <p>Usage: set the train and test case bases, register retrievers and metrics, provide ground-truth
 * similarities via {@link #loadGroundTruthSimilarities(String)} or
 * {@link #computeGroundTruthSimilarities(Retriever, String)}, then call
 * {@link #performEvaluation(String)}. Every test case is used as a query against the train case
 * base; each retriever's complete result list is compared with the ground-truth result list by
 * every registered metric.
 */
public class RetrieverEvaluation {
    /** Ground-truth similarity keyed by (query graph id, case graph id). */
    private MultiKeyMap<String, Double> groundTruthSimilaritiesMap;
    private final Logger logger = LoggerFactory.getLogger(RetrieverEvaluation.class);
    /** Retrievers to evaluate, in registration order (also the row order of the result CSV). */
    private final Map<String, Retriever<NESTGraphObject, Query>> retrievers = new LinkedHashMap<>();
    /** Metrics to compute, in registration order (also the column order of the result CSV). */
    private final List<EvalMetricComputer> metrics = new ArrayList<>();
    private ReadableObjectPool<NESTGraphObject> trainCaseBase = null;
    private ReadableObjectPool<NESTGraphObject> testCaseBase = null;
    /** Value passed unchanged to every {@code computeMetric} call; may be {@code null}. */
    private Integer k = null;

    /**
     * Computes one metric for every retrieval that was performed.
     *
     * <p>The shared list of retrieval results is only iterated; each task collects its values into
     * a map of its own and returns it, so tasks for different metrics share no mutable state of
     * their own. NOTE(review): this assumes metric implementations do not mutate the shared result
     * lists — confirm.
     */
    private static final class MetricWorkerTask implements Callable<Map<String, List<EvalMetricResult>>> {
        private final List<RetrieverResultPair> resultsOfAllRetrievals;
        private final EvalMetricComputer metric;
        private final Integer k;

        MetricWorkerTask(List<RetrieverResultPair> resultsOfAllRetrievals, EvalMetricComputer metric, Integer k) {
            this.resultsOfAllRetrievals = resultsOfAllRetrievals;
            this.metric = metric;
            this.k = k;
        }

        /**
         * @return the metric values per retriever name, each list in the order the retrievals were
         *     performed
         */
        @Override
        public Map<String, List<EvalMetricResult>> call() {
            Map<String, List<EvalMetricResult>> resultsByRetriever = new HashMap<>();
            for (RetrieverResultPair pair : this.resultsOfAllRetrievals) {
                EvalMetricResult value = this.metric.computeMetric(pair.getGroundTruthResults(), pair.getResultList(), this.k);
                resultsByRetriever.computeIfAbsent(pair.getRetrieverName(), name -> new ArrayList<>()).add(value);
            }
            return resultsByRetriever;
        }
    }

    /** The result list of one retrieval together with the ground-truth list for the same query. */
    private static final class RetrieverResultPair {
        private final String retrieverName;
        private final RetrievalResultListImpl resultList;
        private final RetrievalResultListImpl groundTruthResults;

        RetrieverResultPair(String retrieverName, RetrievalResultListImpl resultList, RetrievalResultListImpl groundTruthResults) {
            this.retrieverName = retrieverName;
            this.resultList = resultList;
            this.groundTruthResults = groundTruthResults;
        }

        String getRetrieverName() {
            return this.retrieverName;
        }

        RetrievalResultListImpl getResultList() {
            return this.resultList;
        }

        RetrievalResultListImpl getGroundTruthResults() {
            return this.groundTruthResults;
        }
    }

    /**
     * Runs every test case as a query with every registered retriever and evaluates the results
     * with every registered metric.
     *
     * <p>The outcome is one CSV row per retriever with the average retrieval time per query
     * (column {@code time (ms)}) and the average value of each metric. The CSV is always logged and
     * is additionally written to {@code str} when that is not {@code null}.
     *
     * @param str path of the CSV file to write (parent directories are created), or {@code null}
     *     to only log the results
     * @throws IllegalStateException if case bases, retrievers or ground-truth similarities are
     *     missing, if the test case base is empty, or if a metric computation fails or is
     *     interrupted
     * @throws RetrieverEvalException if a retriever does not return one result per train case, or
     *     if no ground-truth similarity exists for a query/case pair
     * @throws IOException if writing the result file fails
     */
    public void performEvaluation(String str) throws IOException, RetrieverEvalException {
        if (this.trainCaseBase == null || this.testCaseBase == null || this.retrievers.isEmpty()) {
            throw new IllegalStateException("Train case base, test case base, and retrievers to evaluate have to be initialized before starting evaluation!");
        }
        if (this.groundTruthSimilaritiesMap == null) {
            throw new IllegalStateException("Ground truth similarities have to be loaded or computed before starting evaluation!");
        }
        if (this.testCaseBase.size() == 0) {
            // Averages over zero retrievals are undefined.
            throw new IllegalStateException("Test case base must not be empty!");
        }
        for (Retriever<NESTGraphObject, Query> retriever : this.retrievers.values()) {
            retriever.setObjectPool(this.trainCaseBase);
        }

        this.logger.info("Computing similarities with every retriever ...");
        Map<String, Double> totalTimeMsByRetriever = new HashMap<>();
        List<RetrieverResultPair> results = retrieveWithAllRetrievers(totalTimeMsByRetriever);
        List<Map<String, EvalMetricResult>> averagesPerMetric = computeAverageMetrics(results);
        String csv = formatResultsAsCsv(totalTimeMsByRetriever, averagesPerMetric);
        this.logger.info("\n{}", csv);
        if (str != null) {
            writeEvalResultsToFile(str, csv);
        }
    }

    /**
     * Performs one retrieval per (retriever, test case) and pairs each result list with the
     * ground-truth list for the same query.
     *
     * @param totalTimeMsByRetriever receives the summed retrieval time per retriever name
     */
    private List<RetrieverResultPair> retrieveWithAllRetrievers(Map<String, Double> totalTimeMsByRetriever) throws IOException, RetrieverEvalException {
        List<RetrieverResultPair> results = new ArrayList<>();
        int totalRetrievals = this.retrievers.size() * this.testCaseBase.size();
        int retrievalNumber = 1;
        for (Map.Entry<String, Retriever<NESTGraphObject, Query>> entry : this.retrievers.entrySet()) {
            String retrieverName = entry.getKey();
            Retriever<NESTGraphObject, Query> retriever = entry.getValue();
            DataObjectIterator<NESTGraphObject> queries = this.testCaseBase.iterator();
            while (queries.hasNext()) {
                NESTGraphObject queryObject = (NESTGraphObject) queries.next();
                this.logger.info("... retrieval {}/{}", retrievalNumber, totalRetrievals);
                RetrievalResultList retrieved = retriever.perform(newFullQuery(retriever, queryObject));
                if (this.trainCaseBase.size() != retrieved.size()) {
                    throw new RetrieverEvalException("Size of retrieval results must equal size of case base", retriever, queryObject);
                }
                RetrievalResultListImpl groundTruth = buildGroundTruthResults(retriever, queryObject);
                // getRetrievalTime() is presumably in nanoseconds; divided by 1e6 for the "time (ms)" column.
                totalTimeMsByRetriever.merge(retrieverName, retrieved.getRetrievalTime() / 1000000.0d, Double::sum);
                // NOTE(review): assumes retrievers return RetrievalResultListImpl instances.
                results.add(new RetrieverResultPair(retrieverName, (RetrievalResultListImpl) retrieved, groundTruth));
                retrievalNumber++;
            }
        }
        return results;
    }

    /** Creates a query for {@code queryObject} that asks for every case of the train case base. */
    private Query newFullQuery(Retriever<NESTGraphObject, Query> retriever, NESTGraphObject queryObject) {
        Query query = retriever.newQuery();
        query.setQueryObject(queryObject);
        query.setNumberOfResults(this.trainCaseBase.size());
        query.setRetrieveCases(true);
        return query;
    }

    /**
     * Builds the ground-truth result list for one query, with one entry per train case.
     *
     * @throws RetrieverEvalException if a ground-truth similarity is missing for a query/case pair
     */
    private RetrievalResultListImpl buildGroundTruthResults(Retriever<NESTGraphObject, Query> retriever, NESTGraphObject queryObject) throws RetrieverEvalException {
        RetrievalResultListImpl groundTruth = new RetrievalResultListImpl();
        DataObjectIterator<NESTGraphObject> cases = this.trainCaseBase.iterator();
        while (cases.hasNext()) {
            NESTGraphObject caseObject = (NESTGraphObject) cases.next();
            Double similarity = this.groundTruthSimilaritiesMap.get(queryObject.getId(), caseObject.getId());
            if (similarity == null) {
                throw new RetrieverEvalException("No ground truth similarity for query " + queryObject.getId() + " and case " + caseObject.getId(), retriever, queryObject);
            }
            groundTruth.add(new RetrievalResultImpl(new SimilarityImpl(null, queryObject, caseObject, similarity.doubleValue()), caseObject));
        }
        return groundTruth;
    }

    /**
     * Computes all metrics concurrently (one task per metric) and averages them per retriever.
     *
     * @return one map (retriever name to averaged result) per metric, in the order of
     *     {@link #metrics}
     * @throws IllegalStateException if a metric computation fails or the waiting thread is
     *     interrupted
     */
    private List<Map<String, EvalMetricResult>> computeAverageMetrics(List<RetrieverResultPair> results) {
        List<Map<String, EvalMetricResult>> averagesPerMetric = new ArrayList<>(this.metrics.size());
        if (this.metrics.isEmpty()) {
            // A thread pool of size 0 is rejected; with no metrics there is nothing to compute.
            return averagesPerMetric;
        }
        ExecutorService executor = Executors.newFixedThreadPool(this.metrics.size());
        try {
            List<Future<Map<String, List<EvalMetricResult>>>> pending = new ArrayList<>(this.metrics.size());
            for (EvalMetricComputer metric : this.metrics) {
                pending.add(executor.submit(new MetricWorkerTask(results, metric, this.k)));
            }
            for (int i = 0; i < this.metrics.size(); i++) {
                Map<String, List<EvalMetricResult>> resultsByRetriever = awaitMetricResults(pending.get(i), this.metrics.get(i));
                Map<String, EvalMetricResult> averageByRetriever = new HashMap<>();
                for (String retrieverName : this.retrievers.keySet()) {
                    // Non-empty: the test case base is non-empty, so every retriever ran at least once.
                    List<EvalMetricResult> metricResults = resultsByRetriever.get(retrieverName);
                    averageByRetriever.put(retrieverName, metricResults.get(0).average(metricResults));
                }
                averagesPerMetric.add(averageByRetriever);
            }
        } finally {
            // Stops any still-running workers if a metric failed; harmless once all have completed.
            executor.shutdownNow();
        }
        return averagesPerMetric;
    }

    /**
     * Waits for one metric task and converts its failures into unchecked exceptions with context.
     */
    private static Map<String, List<EvalMetricResult>> awaitMetricResults(Future<Map<String, List<EvalMetricResult>>> future, EvalMetricComputer metric) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing metric " + metric.getMetricName(), e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Computing metric " + metric.getMetricName() + " failed", e.getCause());
        }
    }

    /** Renders one CSV row per retriever: name, average time per query, then each metric average. */
    private String formatResultsAsCsv(Map<String, Double> totalTimeMsByRetriever, List<Map<String, EvalMetricResult>> averagesPerMetric) throws IOException {
        List<String> header = new ArrayList<>();
        header.add("retriever name");
        header.add("time (ms)");
        for (EvalMetricComputer metric : this.metrics) {
            header.add(metric.getMetricName());
        }
        StringWriter out = new StringWriter();
        try (CSVPrinter printer = new CSVPrinter(out, CSVFormat.DEFAULT.withHeader(header.toArray(new String[0])))) {
            for (String retrieverName : this.retrievers.keySet()) {
                printer.print(retrieverName);
                printer.print(totalTimeMsByRetriever.get(retrieverName) / this.testCaseBase.size());
                for (Map<String, EvalMetricResult> averageByRetriever : averagesPerMetric) {
                    printer.print(averageByRetriever.get(retrieverName).getValue());
                }
                printer.println();
            }
        }
        return out.toString();
    }

    /**
     * Writes {@code content} to {@code filePath} as UTF-8. Missing parent directories are created
     * and an existing file is overwritten.
     */
    private void writeEvalResultsToFile(String filePath, String content) throws IOException {
        Path path = Paths.get(filePath);
        createParentDirectories(path);
        try (Writer writer = Files.newBufferedWriter(path, StandardCharsets.UTF_8)) {
            writer.write(content);
        }
    }

    /** Creates the parent directories of {@code path}, if it has any. */
    private static void createParentDirectories(Path path) throws IOException {
        // toAbsolutePath() so that a bare file name ("results.csv") still has a parent.
        Path parent = path.toAbsolutePath().getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
    }

    /**
     * Replaces the ground-truth similarities with those read from a UTF-8 CSV file.
     *
     * <p>The file must have a header row. Each following row holds the query graph id, the case
     * graph id and the similarity, in this column order (the format written by
     * {@link #computeGroundTruthSimilarities(Retriever, String)}). I/O errors are logged rather than
     * thrown. In that case the map keeps whatever was read before the error, possibly nothing.
     *
     * @param str path of the CSV file
     */
    public void loadGroundTruthSimilarities(String str) {
        this.logger.info("Loading ground truth similarities...");
        this.groundTruthSimilaritiesMap = new MultiKeyMap<>();
        try (Reader reader = new InputStreamReader(new FileInputStream(str), StandardCharsets.UTF_8);
             CSVParser parser = new CSVParser(reader, CSVFormat.DEFAULT.withFirstRecordAsHeader())) {
            for (CSVRecord record : parser) {
                this.groundTruthSimilaritiesMap.put(record.get(0), record.get(1), Double.parseDouble(record.get(2)));
            }
        } catch (IOException e) {
            this.logger.error("Could not load ground truth similarities from {}", str, e);
        }
    }

    /**
     * Computes ground-truth similarities by running every test case as a query with
     * {@code retriever} against the train case base. Optionally exports them as a UTF-8 CSV file.
     *
     * @param retriever the retriever whose similarities are taken as ground truth
     * @param str path of the CSV file to export to (parent directories are created), or
     *     {@code null} to skip the export
     * @throws IllegalStateException if the train or test case base is not set
     * @throws IOException if the export fails
     */
    public void computeGroundTruthSimilarities(Retriever<NESTGraphObject, Query> retriever, String str) throws IOException {
        if (this.trainCaseBase == null || this.testCaseBase == null) {
            throw new IllegalStateException("Train case base and test case base have to be initialized before computing ground truth similarities!");
        }
        retriever.setObjectPool(this.trainCaseBase);
        this.logger.info("Computing ground truth similarities...");
        // LinkedHashMap keeps the export in test-case-base order.
        Map<String, RetrievalResultList> resultsByQueryId = new LinkedHashMap<>();
        int retrievalNumber = 1;
        DataObjectIterator<NESTGraphObject> queries = this.testCaseBase.iterator();
        while (queries.hasNext()) {
            NESTGraphObject queryObject = (NESTGraphObject) queries.next();
            this.logger.info("... retrieval {}/{}", retrievalNumber, this.testCaseBase.size());
            resultsByQueryId.put(queryObject.getId(), retriever.perform(newFullQuery(retriever, queryObject)));
            retrievalNumber++;
        }
        if (str != null) {
            exportGroundTruthSimilarities(str, resultsByQueryId);
        }
        MultiKeyMap<String, Double> similarities = new MultiKeyMap<>();
        for (Map.Entry<String, RetrievalResultList> entry : resultsByQueryId.entrySet()) {
            for (RetrievalResult result : entry.getValue()) {
                similarities.put(entry.getKey(), result.getObject().getId(), result.getSimilarity().getValue());
            }
        }
        this.groundTruthSimilaritiesMap = similarities;
    }

    /**
     * Writes one CSV row (query id, case id, similarity) per retrieval result, as UTF-8. The layout
     * matches what {@link #loadGroundTruthSimilarities(String)} reads.
     */
    private void exportGroundTruthSimilarities(String filePath, Map<String, RetrievalResultList> resultsByQueryId) throws IOException {
        this.logger.info("Exporting ground truth similarities...");
        Path path = Paths.get(filePath);
        createParentDirectories(path);
        try (Writer writer = Files.newBufferedWriter(path, StandardCharsets.UTF_8);
             CSVPrinter printer = new CSVPrinter(writer, CSVFormat.DEFAULT.withHeader("NameQueryGraph", "NameCaseGraph", "Similarity"))) {
            for (Map.Entry<String, RetrievalResultList> entry : resultsByQueryId.entrySet()) {
                for (RetrievalResult result : entry.getValue()) {
                    printer.printRecord(entry.getKey(), result.getObject().getId(), result.getSimilarity().getValue());
                }
            }
        }
        this.logger.info("Export finished.");
    }

    /**
     * Registers a retriever to evaluate under {@code str}. Registering another retriever under the
     * same name replaces the earlier one.
     */
    public void addRetrieverToEvaluate(String str, Retriever<NESTGraphObject, Query> retriever) {
        this.retrievers.put(str, retriever);
    }

    /** Adds a metric to compute for every retriever; metrics appear as CSV columns in this order. */
    public void addMetricToEvaluate(EvalMetricComputer evalMetricComputer) {
        this.metrics.add(evalMetricComputer);
    }

    /** Sets the case base that retrievers search in. */
    public void setTrainCaseBase(ReadableObjectPool<NESTGraphObject> readableObjectPool) {
        this.trainCaseBase = readableObjectPool;
    }

    /** Sets the case base whose cases are used as queries. */
    public void setTestCaseBase(ReadableObjectPool<NESTGraphObject> readableObjectPool) {
        this.testCaseBase = readableObjectPool;
    }

    /** @return the value passed to every metric computation; may be {@code null} */
    public Integer getK() {
        return this.k;
    }

    /** Sets the value passed to every metric computation; {@code null} is passed through unchanged. */
    public void setK(Integer num) {
        this.k = num;
    }
}
