package de.uni_trier.wi2.procake.utils.eval;

import de.uni_trier.wi2.procake.data.object.DataObject;
import de.uni_trier.wi2.procake.data.objectpool.DataObjectIterator;
import de.uni_trier.wi2.procake.data.objectpool.WriteableObjectPool;
import de.uni_trier.wi2.procake.data.trainingObjectPool.TrainingObjectPool;
import de.uni_trier.wi2.procake.retrieval.IdSimilarityPair;
import de.uni_trier.wi2.procake.retrieval.Query;
import de.uni_trier.wi2.procake.retrieval.RetrievalResultList;
import de.uni_trier.wi2.procake.retrieval.Retriever;
import de.uni_trier.wi2.procake.retrieval.SimpleSimilarityResult;
import de.uni_trier.wi2.procake.utils.exception.RetrieverEvaluationException;
import de.vandermeer.asciitable.AsciiTable;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.StringWriter;
import java.io.Writer;
import java.math.RoundingMode;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.io.output.TeeWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_trier/wi2/procake/utils/eval/RetrieverEvaluation.class */
/**
 * Evaluates one or more retrievers against ground-truth similarities.
 *
 * <p>Every case of the test pool is used as a query against the train pool with every registered
 * retriever. Each registered {@link EvalMetric} is computed per query, and the per-query values are
 * averaged over the size of the test pool. Optionally, the per-case-pair similarities of every
 * retriever are tracked (see {@link #trackSimilarityResults()}) so they can be exported as CSV.
 */
public class RetrieverEvaluation<TCase extends DataObject> {
    protected final Logger logger;
    /** Retrievers to evaluate, keyed by display name, in insertion order. */
    protected final Map<String, Retriever<TCase, Query>> retrievers;
    protected final List<EvalMetric> metrics;
    protected TrainingObjectPool<TCase> trainingObjectPool;
    protected List<SimpleSimilarityResult> groundTruthSimilarities;
    protected Integer k;
    /** Mean metric value per (retriever, metric); filled by {@link #performEvaluation()}. */
    protected Map<RetrieverMetricKeyPair, Double> metricResults;
    /** Similarity of every retriever per (query, case) pair; only filled when tracking is enabled. */
    protected Map<CasePair, Collection<RetrieverSimilarityPair>> similarityResults;
    /** Retrieval time of every single query per retriever name (stored in the "time (ms)" unit). */
    protected HashMap<String, List<Double>> retrievalTimeResultMap;
    protected DecimalFormat decimalFormat;
    protected boolean trackSimilarityResults;

    /** Creates an evaluation without case bases; set them later via {@link #setTrainTestCaseBase}. */
    public RetrieverEvaluation() {
        this(null, null);
    }

    /**
     * Creates an evaluation for the given case bases.
     *
     * @param trainCaseBase pool the retrievers search in; may be {@code null} together with {@code testCaseBase}
     * @param testCaseBase pool whose cases are used as queries; may be {@code null} together with {@code trainCaseBase}
     */
    public RetrieverEvaluation(WriteableObjectPool<TCase> trainCaseBase, WriteableObjectPool<TCase> testCaseBase) {
        this.logger = LoggerFactory.getLogger(RetrieverEvaluation.class);
        this.decimalFormat = new DecimalFormat("#.####");
        this.retrievers = new LinkedHashMap<>();
        this.metrics = new ArrayList<>();
        if (trainCaseBase != null && testCaseBase != null) {
            // The third pool of the TrainingObjectPool is intentionally left unset.
            this.trainingObjectPool = new TrainingObjectPool<>(trainCaseBase, testCaseBase, (WriteableObjectPool) null);
        }
    }

    /**
     * Runs every registered retriever with every test case as query and averages all metrics.
     *
     * @return map from (retriever name, metric) to the metric value averaged over the test pool
     * @throws IllegalStateException if case bases or retrievers are missing, or the test pool is empty
     * @throws RetrieverEvaluationException if the pre-evaluation checks of retrievers or ground truth
     *     fail, or a retrieval does not return every case of the train pool
     * @throws java.util.NoSuchElementException if no ground-truth entry exists for a query
     */
    public Map<RetrieverMetricKeyPair, Double> performEvaluation() throws RetrieverEvaluationException {
        // Validate first so that a misconfigured call does not wipe results of a previous run.
        if (this.trainingObjectPool == null
                || this.trainingObjectPool.getTrainPool() == null
                || this.trainingObjectPool.getTestPool() == null
                || this.retrievers.isEmpty()) {
            throw new IllegalStateException(
                    "Train case base, test case base, and retrievers to evaluate have to be initialized before starting evaluation!");
        }
        if (this.trainingObjectPool.getTestPool().size() == 0) {
            // Averages are taken over the test pool size, so an empty pool cannot be evaluated.
            throw new IllegalStateException("The test case base must not be empty!");
        }
        this.metricResults = new HashMap<>();
        if (this.trackSimilarityResults) {
            this.similarityResults = new HashMap<>();
        }

        this.logger.info("Testing retrievers prior to evaluation...");
        RetrieverEvaluationUtils.testRetrievers(this.trainingObjectPool, this.retrievers);
        this.logger.info("Testing retrievers was successful!");
        this.logger.info("Testing ground-truth similarities prior to evaluation...");
        RetrieverEvaluationUtils.testGroundTruthSimilarities(this.trainingObjectPool, this.groundTruthSimilarities);
        this.logger.info("Testing ground-truth similarities was successful!");
        this.logger.info("Performing retrieval with every retriever...");

        Map<RetrieverMetricKeyPair, List<Double>> metricValues = new HashMap<>();
        this.retrievalTimeResultMap = new HashMap<>();
        for (String retrieverName : this.retrievers.keySet()) {
            for (EvalMetric metric : this.metrics) {
                metricValues.put(new RetrieverMetricKeyPair(retrieverName, metric), new ArrayList<>());
            }
            this.retrievalTimeResultMap.put(retrieverName, new ArrayList<>());
        }

        int totalRetrievals = this.retrievers.size() * this.trainingObjectPool.getTestPool().size();
        int retrievalNumber = 1;
        for (Map.Entry<String, Retriever<TCase, Query>> entry : this.retrievers.entrySet()) {
            String retrieverName = entry.getKey();
            Retriever<TCase, Query> retriever = entry.getValue();
            retriever.setObjectPool(this.trainingObjectPool.getTrainPool());
            DataObjectIterator<TCase> it = this.trainingObjectPool.getTestPool().iterator();
            while (it.hasNext()) {
                DataObject queryObject = (DataObject) it.next();
                this.logger.debug("... retrieval {}/{} ({})", retrievalNumber, totalRetrievals, retrieverName);
                evaluateSingleQuery(retrieverName, retriever, queryObject, metricValues);
                retrievalNumber++;
            }
        }

        for (String retrieverName : this.retrievers.keySet()) {
            for (EvalMetric metric : this.metrics) {
                RetrieverMetricKeyPair key = new RetrieverMetricKeyPair(retrieverName, metric);
                this.metricResults.put(key, Double.valueOf(averageOverTestPool(metricValues.get(key))));
            }
        }
        return this.metricResults;
    }

    /** Performs one retrieval and records its time, its metric values and (optionally) similarities. */
    private void evaluateSingleQuery(
            String retrieverName,
            Retriever<TCase, Query> retriever,
            DataObject queryObject,
            Map<RetrieverMetricKeyPair, List<Double>> metricValues)
            throws RetrieverEvaluationException {
        Query query = retriever.newQuery();
        query.setQueryObject(queryObject);
        // Request the complete ranking so that every train case can be compared to the ground truth.
        query.setNumberOfResults(this.trainingObjectPool.getTrainPool().size());
        query.setRetrieveCases(true);
        RetrievalResultList retrievalResult = retriever.perform(query);
        if (this.trainingObjectPool.getTrainPool().size() != retrievalResult.size()) {
            throw new RetrieverEvaluationException(
                    "Retrieval results do not contain every graph of case base!",
                    this.trainingObjectPool.getTrainPool().toString());
        }
        // Divided by 1e6 to match the "time (ms)" column; assumes getRetrievalTime() is in ns — TODO confirm.
        this.retrievalTimeResultMap.get(retrieverName)
                .add(Double.valueOf(retrievalResult.getRetrievalTime() / 1000000.0d));

        SimpleSimilarityResult groundTruth = this.groundTruthSimilarities.stream()
                .filter(candidate -> candidate.getQueryID().equals(queryObject.getId()))
                .findFirst()
                .orElseThrow();
        SimpleSimilarityResult retrieved = SimpleSimilarityResult.fromRetrievalResultList(retrievalResult);

        // Metrics are computed in parallel, but appended sequentially: ArrayList is not thread-safe,
        // and equal metrics registered twice would otherwise write into the same list concurrently.
        List<Double> values = this.metrics.parallelStream()
                .map(metric -> Double.valueOf(metric.computeEvalMetric(groundTruth, retrieved)))
                .collect(Collectors.toList());
        for (int index = 0; index < this.metrics.size(); index++) {
            metricValues.get(new RetrieverMetricKeyPair(retrieverName, this.metrics.get(index))).add(values.get(index));
        }

        if (this.trackSimilarityResults) {
            Iterator<IdSimilarityPair> pairs = retrieved.iterator();
            while (pairs.hasNext()) {
                IdSimilarityPair pair = pairs.next();
                CasePair casePair = new CasePair(queryObject, this.trainingObjectPool.getTrainPool().getObject(pair.getId()));
                this.similarityResults.computeIfAbsent(casePair, ignored -> new ArrayList<>())
                        .add(new RetrieverSimilarityPair(retrieverName, pair.getSimilarity()));
            }
        }
    }

    /** Sums the values left to right and divides by the current test pool size. */
    private double averageOverTestPool(List<Double> values) {
        double sum = values.stream().reduce(Double::sum).orElseThrow();
        return sum / this.trainingObjectPool.getTestPool().size();
    }

    public void setGroundTruthSimilarities(List<SimpleSimilarityResult> groundTruthSimilarities) {
        this.groundTruthSimilarities = groundTruthSimilarities;
    }

    /** Registers (or replaces) the retriever evaluated under the given name. */
    public void addRetrieverToEvaluate(String name, Retriever<TCase, Query> retriever) {
        this.retrievers.put(name, retriever);
    }

    public void addMetricToEvaluate(EvalMetric evalMetric) {
        this.metrics.add(evalMetric);
    }

    /** Sets the pool the retrievers search in and the pool whose cases serve as queries. */
    public void setTrainTestCaseBase(WriteableObjectPool<TCase> trainCaseBase, WriteableObjectPool<TCase> testCaseBase) {
        this.trainingObjectPool = new TrainingObjectPool<>(trainCaseBase, testCaseBase, (WriteableObjectPool) null);
    }

    public Integer getK() {
        return this.k;
    }

    public void setK(Integer k) {
        this.k = k;
    }

    public void setDecimalFormat(DecimalFormat decimalFormat) {
        this.decimalFormat = decimalFormat;
    }

    /** Registers every entry of the map via {@link #addRetrieverToEvaluate}. */
    public void addRetrieversToEvaluate(Map<String, Retriever<TCase, Query>> retrieversByName) {
        retrieversByName.forEach(this::addRetrieverToEvaluate);
    }

    /** Registers every metric of the collection via {@link #addMetricToEvaluate}. */
    public void addMetricsToEvaluate(Collection<EvalMetric> evalMetrics) {
        evalMetrics.forEach(this::addMetricToEvaluate);
    }

    /**
     * Loads ground-truth similarities and validates them against the current case bases.
     *
     * <p>Validation failures are only logged here. {@link #performEvaluation()} validates again and
     * throws on failure.
     *
     * @param source argument passed to {@code RetrieverEvaluationUtils.loadGroundTruthSimilarities}
     *     (presumably a file path — TODO confirm)
     */
    public void importGroundTruthSimilarities(String source) {
        this.logger.debug("Importing ground truth similarities...");
        this.groundTruthSimilarities = RetrieverEvaluationUtils.loadGroundTruthSimilarities(source);
        try {
            RetrieverEvaluationUtils.testGroundTruthSimilarities(this.trainingObjectPool, this.groundTruthSimilarities);
        } catch (RetrieverEvaluationException e) {
            this.logger.error("Imported ground-truth similarities failed validation", e);
        }
    }

    /**
     * Writes the tracked similarity results as CSV to the given file, or only renders them when
     * {@code path} is {@code null}.
     *
     * @return the CSV content
     */
    public String writeSimilarityResultsAsCSV(String path) throws IOException {
        if (path == null) {
            return writeSimilarityResultsAsCSV(new ByteArrayOutputStream());
        }
        File file = new File(path);
        this.logger.info("Writing similarity CSV results to file path \"" + file.getAbsolutePath() + "\" ...");
        ensureParentDirectoryExists(file);
        try (FileOutputStream fileOutputStream = new FileOutputStream(file)) {
            return writeSimilarityResultsAsCSV(fileOutputStream);
        }
    }

    public String getSimilarityResultsAsCSVString() throws IOException, RetrieverEvaluationException {
        return writeSimilarityResultsAsCSV((String) null);
    }

    /**
     * Writes one CSV row per (query, case) pair. Columns are the ground-truth similarity followed by
     * every retriever's similarity, in retriever-name order. The stream is flushed but not closed.
     *
     * @return the CSV content that was written
     * @throws IllegalStateException if the stream is {@code null} or no similarity results were tracked
     */
    public String writeSimilarityResultsAsCSV(OutputStream outputStream) throws IOException {
        if (outputStream == null) {
            throw new IllegalStateException("The output stream must not be null!");
        }
        if (this.similarityResults == null) {
            throw new IllegalStateException("The similarity results must not be null!");
        }
        List<String> header = new ArrayList<>(List.of("qCase", "cCase", "GT"));
        this.retrievers.keySet().stream().sorted().forEach(header::add);
        configureDecimalFormat();

        StringWriter stringWriter = new StringWriter();
        CSVPrinter printer = newCsvPrinter(outputStream, stringWriter, header);
        List<CasePair> sortedPairs = this.similarityResults.keySet().stream()
                .sorted(RetrieverEvaluation::compareCasePairs)
                .collect(Collectors.toList());
        for (CasePair casePair : sortedPairs) {
            List<Object> row = new ArrayList<>();
            row.add(casePair.getQueryGraph().getId());
            row.add(casePair.getCaseGraph().getId());
            row.add(this.decimalFormat.format(getGroundTruthSimilarity(casePair)));
            // Sorted by retriever name so the columns line up with the sorted header.
            this.similarityResults.get(casePair).stream()
                    .sorted(Comparator.comparing(RetrieverSimilarityPair::getRetriever))
                    .forEach(pair -> row.add(this.decimalFormat.format(pair.getSimilarity())));
            printer.printRecord(row);
        }
        printer.flush();
        return stringWriter.toString();
    }

    /** Orders case pairs by query id, then by case id. */
    private static int compareCasePairs(CasePair first, CasePair second) {
        int byQuery = first.getQueryGraph().getId().compareTo(second.getQueryGraph().getId());
        return byQuery != 0 ? byQuery : first.getCaseGraph().getId().compareTo(second.getCaseGraph().getId());
    }

    /** Looks up the ground-truth similarity of a (query, case) pair; throws if none exists. */
    private double getGroundTruthSimilarity(CasePair casePair) {
        return this.groundTruthSimilarities.stream()
                .filter(result -> result.getQueryID().equals(casePair.getQueryGraph().getId()))
                .flatMap(result -> result.getCaseSimilarities().stream())
                .filter(pair -> pair.getId().equals(casePair.getCaseGraph().getId()))
                .mapToDouble(pair -> pair.getSimilarity())
                .findFirst()
                .orElseThrow();
    }

    /**
     * Writes the averaged metric results as CSV to the given file, or only renders them when
     * {@code path} is {@code null}.
     *
     * @return the CSV content
     */
    public String writeMetricResultsAsCSV(String path) throws IOException, RetrieverEvaluationException {
        if (path == null) {
            return writeMetricResultsAsCSV(new ByteArrayOutputStream());
        }
        File file = new File(path);
        this.logger.info("Writing metric CSV results to file path \"" + file.getAbsolutePath() + "\" ...");
        ensureParentDirectoryExists(file);
        try (FileOutputStream fileOutputStream = new FileOutputStream(file)) {
            return writeMetricResultsAsCSV(fileOutputStream);
        }
    }

    public String getMetricResultsAsCSVString() throws IOException, RetrieverEvaluationException {
        return writeMetricResultsAsCSV((String) null);
    }

    /**
     * Writes one CSV row per retriever: name, mean retrieval time and every averaged metric. The
     * stream is flushed but not closed.
     *
     * @return the CSV content that was written
     * @throws IllegalStateException if the stream is {@code null} or no evaluation has been performed
     */
    public String writeMetricResultsAsCSV(OutputStream outputStream) throws IOException, RetrieverEvaluationException {
        if (outputStream == null) {
            throw new IllegalStateException("The output stream must not be null!");
        }
        requireMetricResults();
        configureDecimalFormat();
        StringWriter stringWriter = new StringWriter();
        CSVPrinter printer = newCsvPrinter(outputStream, stringWriter, metricTableHeader());
        for (String retrieverName : this.retrievers.keySet()) {
            printer.printRecord(metricTableRow(retrieverName));
        }
        printer.flush();
        return stringWriter.toString();
    }

    /**
     * Renders the averaged metric results as an ASCII table and logs it at DEBUG level.
     *
     * @throws IllegalStateException if no evaluation has been performed
     */
    public void printMetricResultsAsASCIITable() {
        requireMetricResults();
        AsciiTable asciiTable = new AsciiTable();
        asciiTable.addRule();
        asciiTable.addRow(metricTableHeader());
        asciiTable.addRule();
        configureDecimalFormat();
        for (String retrieverName : this.retrievers.keySet()) {
            asciiTable.addRow(metricTableRow(retrieverName));
            asciiTable.addRule();
        }
        this.logger.debug("\n" + asciiTable.render());
    }

    /** Enables recording of per-case-pair similarities during the next evaluation. */
    public void trackSimilarityResults() {
        this.trackSimilarityResults = true;
    }

    /** Fails fast when metric results are requested before {@link #performEvaluation()} ran. */
    private void requireMetricResults() {
        if (this.metricResults == null || this.retrievalTimeResultMap == null) {
            throw new IllegalStateException("The metric results must not be null! Run performEvaluation() first.");
        }
    }

    /** Header shared by the metric CSV and the ASCII table. */
    private List<String> metricTableHeader() {
        List<String> header = new ArrayList<>();
        header.add("retriever name");
        header.add("time (ms)");
        for (EvalMetric metric : this.metrics) {
            header.add(metric.getMetricName());
        }
        return header;
    }

    /** Formatted row of one retriever: name, mean retrieval time, then every averaged metric. */
    private List<String> metricTableRow(String retrieverName) {
        List<String> row = new ArrayList<>();
        row.add(retrieverName);
        row.add(this.decimalFormat.format(averageOverTestPool(this.retrievalTimeResultMap.get(retrieverName))));
        for (EvalMetric metric : this.metrics) {
            row.add(this.decimalFormat.format(this.metricResults.get(new RetrieverMetricKeyPair(retrieverName, metric))));
        }
        return row;
    }

    /** Forces '.' as decimal separator and half-up rounding, independent of the default locale. */
    private void configureDecimalFormat() {
        this.decimalFormat.setDecimalFormatSymbols(DecimalFormatSymbols.getInstance(Locale.ENGLISH));
        this.decimalFormat.setRoundingMode(RoundingMode.HALF_UP);
    }

    /** Creates a CSV printer that writes to the stream and mirrors everything into {@code copy}. */
    private static CSVPrinter newCsvPrinter(OutputStream outputStream, Writer copy, List<String> header)
            throws IOException {
        Writer tee = new TeeWriter(new Writer[] {new OutputStreamWriter(outputStream), copy});
        return new CSVPrinter(tee, CSVFormat.EXCEL.builder().setHeader(header.toArray(new String[0])).build());
    }

    /** Creates missing parent directories; paths without a parent (e.g. "out.csv") are fine. */
    private static void ensureParentDirectoryExists(File file) {
        File parent = file.getAbsoluteFile().getParentFile();
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
    }
}
