package de.uni_trier.wi2.procake.utils.eval;

import de.uni_trier.wi2.procake.data.io.xml.SimilarityTags;
import de.uni_trier.wi2.procake.data.object.DataObject;
import de.uni_trier.wi2.procake.data.objectpool.DataObjectIterator;
import de.uni_trier.wi2.procake.data.objectpool.ObjectPoolFactory;
import de.uni_trier.wi2.procake.data.objectpool.WriteableObjectPool;
import de.uni_trier.wi2.procake.retrieval.IdSimilarityPair;
import de.uni_trier.wi2.procake.retrieval.Query;
import de.uni_trier.wi2.procake.retrieval.RetrievalResultList;
import de.uni_trier.wi2.procake.retrieval.Retriever;
import de.uni_trier.wi2.procake.retrieval.SimpleSimilarityResult;
import de.uni_trier.wi2.procake.utils.eval.metrics.k.KEvalMetric;
import de.uni_trier.wi2.procake.utils.exception.RetrieverEvaluationException;
import de.vandermeer.asciitable.AsciiTable;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.StringWriter;
import java.io.Writer;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.text.DecimalFormatSymbols;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.io.output.TeeWriter;

/* loaded from: input_file:de/uni_trier/wi2/procake/utils/eval/MACFACRetrieverEvaluation.class */
/**
 * Evaluation of a two-stage MAC/FAC retrieval.
 *
 * <p>Every configured (MAC) retriever first retrieves all train cases for every test query. Then,
 * for every index {@code i}, the first {@code filterSizes[i]} entries of each MAC result are put
 * into a fresh object pool on which the FAC retriever ({@link #setFacGTRetriever(Retriever)})
 * retrieves {@code ks[i]} cases. Each FAC result is compared with the ground-truth similarities of
 * the query using every configured metric, and the combined MAC + FAC retrieval time is recorded.
 *
 * @param <TCase> type of the cases in the training object pool
 */
public class MACFACRetrieverEvaluation<TCase extends DataObject> extends RetrieverEvaluation<TCase> {

    /**
     * Divisor converting {@code RetrievalResultList#getRetrievalTime()} values to milliseconds
     * (presumably nanoseconds; the results are reported as "time (ms)").
     */
    private static final double NANOS_PER_MILLI = 1_000_000.0d;

    /** Number of MAC results handed to the FAC stage; paired index-wise with {@link #ks}. */
    private int[] filterSizes;

    /** Number of FAC results to retrieve; paired index-wise with {@link #filterSizes}. */
    private int[] ks;

    /** Retriever used for the FAC stage on the filtered object pool. */
    private Retriever<TCase, Query> facGTRetriever;

    /** Combined MAC + FAC retrieval times (ms) per retriever, k and filter size. */
    private HashMap<RetrieverFSKKeyPair, List<Double>> retrievalTimeResultMapCombined;

    /**
     * Runs the MAC retrieval with every configured retriever, then the FAC retrieval for every
     * (filter size, k) pair, and averages every metric over all test queries.
     *
     * @return the averaged metric value per retriever, metric, k and filter size
     * @throws RetrieverEvaluationException if the filter sizes, ks or FAC retriever are not set
     *     consistently, a retrieval does not return the expected number of results, or no
     *     ground-truth similarities exist for a test query
     */
    @Override // de.uni_trier.wi2.procake.utils.eval.RetrieverEvaluation
    public Map<RetrieverMetricKeyPair, Double> performEvaluation() throws RetrieverEvaluationException {
        validateConfiguration();
        this.metricResults = new HashMap<>();
        this.logger.info("Testing retrievers prior to evaluation...");
        RetrieverEvaluationUtils.testRetrievers(this.trainingObjectPool, this.retrievers);

        this.logger.info("Performing MAC retrieval with every retriever...");
        List<MACRetrievalResult> macResults = performMacRetrievals();

        Map<RetrieverFSKMetricKeyPair, List<Double>> metricValues = initResultMaps();

        this.logger.info("Performing FAC retrieval with every retriever for every filter size and k...");
        performFacRetrievals(macResults, metricValues);

        aggregateMetricResults(metricValues);
        return this.metricResults;
    }

    /**
     * Sets the filter sizes, i.e. how many entries of each MAC result are handed to the FAC stage.
     * Paired index-wise with {@link #setKs(int[])}; the array is copied.
     *
     * @param filterSizes the filter sizes; must have the same length as the ks
     */
    public void setFilterSizes(int[] filterSizes) {
        this.filterSizes = filterSizes == null ? null : filterSizes.clone();
    }

    /**
     * Sets the number of results the FAC retriever returns. Paired index-wise with
     * {@link #setFilterSizes(int[])}; the array is copied.
     *
     * @param ks the ks; must have the same length as the filter sizes
     */
    public void setKs(int[] ks) {
        this.ks = ks == null ? null : ks.clone();
    }

    /**
     * Sets the retriever used for the FAC stage.
     *
     * @param retriever the FAC retriever
     */
    public void setFacGTRetriever(Retriever<TCase, Query> retriever) {
        this.facGTRetriever = retriever;
    }

    /**
     * Writes the averaged results of the last {@link #performEvaluation()} run as UTF-8 CSV to the
     * given stream and returns the same CSV text. The stream is closed afterwards.
     *
     * <p>The CSV has two header rows: the retriever name above each of its metric columns and its
     * time column, and below that the metric names and "time (ms)". Every data row holds one
     * (filter size, k) pair. The results are also logged at debug level as CSV and as ASCII table.
     *
     * @param outputStream the stream to write to; closed by this method
     * @return the written CSV text
     * @throws IOException if writing fails
     * @throws RetrieverEvaluationException if {@link #performEvaluation()} has not been run yet
     */
    @Override // de.uni_trier.wi2.procake.utils.eval.RetrieverEvaluation
    public String writeMetricResultsAsCSV(OutputStream outputStream) throws IOException, RetrieverEvaluationException {
        if (this.retrievalTimeResultMapCombined == null) {
            throw new RetrieverEvaluationException("performEvaluation() must be called before writing results!", "no evaluation results available");
        }
        this.logger.info("\n=== OVERALL RESULTS ===");
        List<String> retrieverNames = new ArrayList<>(this.retrievers.keySet());
        this.logger.info("Writing CSV results to file...");

        // First header row: every retriever name spans its metric columns plus its time column.
        List<String> retrieverHeader = new ArrayList<>();
        retrieverHeader.add("fs");
        retrieverHeader.add(SimilarityTags.ATT_K);
        for (String retrieverName : retrieverNames) {
            for (EvalMetric ignored : this.metrics) {
                retrieverHeader.add(retrieverName);
            }
            retrieverHeader.add(retrieverName);
        }

        // Second header row: the metric names and the time column, repeated per retriever.
        List<String> metricHeader = new ArrayList<>();
        metricHeader.add("fs");
        metricHeader.add(SimilarityTags.ATT_K);
        for (String ignored : retrieverNames) {
            for (EvalMetric metric : this.metrics) {
                metricHeader.add(metric.getMetricName());
            }
            metricHeader.add("time (ms)");
        }

        AsciiTable asciiTable = new AsciiTable();
        asciiTable.addRule();
        asciiTable.addRow(retrieverHeader);
        asciiTable.addRule();
        asciiTable.addRow(metricHeader);
        asciiTable.addRule();

        // Locale-independent '.' decimal separator so the CSV parses identically everywhere.
        this.decimalFormat.setDecimalFormatSymbols(DecimalFormatSymbols.getInstance(Locale.ENGLISH));
        this.decimalFormat.setRoundingMode(RoundingMode.HALF_UP);

        StringWriter csvCopy = new StringWriter();
        Writer streamWriter = new OutputStreamWriter(outputStream, StandardCharsets.UTF_8);
        CSVFormat csvFormat = CSVFormat.EXCEL.builder().setHeader(retrieverHeader.toArray(new String[0])).build();
        // The printer writes to the stream and to an in-memory copy at once; closing it closes both,
        // also when formatting a row fails.
        try (CSVPrinter csvPrinter = new CSVPrinter(new TeeWriter(new Writer[] {streamWriter, csvCopy}), csvFormat)) {
            csvPrinter.printRecord(metricHeader);
            for (int i = 0; i < this.filterSizes.length; i++) {
                int filterSize = this.filterSizes[i];
                int k = this.ks[i];
                List<String> row = new ArrayList<>();
                row.add(String.valueOf(filterSize));
                row.add(String.valueOf(k));
                for (String retrieverName : retrieverNames) {
                    for (EvalMetric metric : this.metrics) {
                        row.add(this.decimalFormat.format(this.metricResults.get(new RetrieverFSKMetricKeyPair(retrieverName, metric, k, filterSize))));
                    }
                    List<Double> times = this.retrievalTimeResultMapCombined.get(new RetrieverFSKKeyPair(retrieverName, k, filterSize));
                    row.add(this.decimalFormat.format(average(times)));
                }
                csvPrinter.printRecord(row);
                asciiTable.addRow(row);
                asciiTable.addRule();
            }
            csvPrinter.flush();
        }

        this.logger.debug("\n=== Evaluation results as CSV ===");
        String csv = csvCopy.toString();
        this.logger.debug("\n" + csv);
        this.logger.debug("\n" + asciiTable.render());
        return csv;
    }

    /** Fails fast if filter sizes, ks or the FAC retriever are missing or inconsistent. */
    private void validateConfiguration() throws RetrieverEvaluationException {
        String context = "filterSizes=" + Arrays.toString(this.filterSizes) + ", ks=" + Arrays.toString(this.ks);
        if (this.filterSizes == null || this.ks == null) {
            throw new RetrieverEvaluationException("Filter sizes and ks must be set before evaluation!", context);
        }
        if (this.filterSizes.length != this.ks.length) {
            throw new RetrieverEvaluationException("Filter sizes and ks must have the same length (" + this.filterSizes.length + "!=" + this.ks.length + ")!", context);
        }
        if (this.facGTRetriever == null) {
            throw new RetrieverEvaluationException("FAC retriever must be set before evaluation!", context);
        }
    }

    /**
     * Retrieves all train cases for every test query with every MAC retriever. The returned list
     * is ordered retriever-major (all queries of the first retriever, then the next retriever).
     */
    private List<MACRetrievalResult> performMacRetrievals() throws RetrieverEvaluationException {
        int trainPoolSize = this.trainingObjectPool.getTrainPool().size();
        int totalRetrievals = this.retrievers.size() * this.trainingObjectPool.getTestPool().size();
        List<MACRetrievalResult> macResults = new ArrayList<>(totalRetrievals);
        int retrievalNo = 1;
        for (Map.Entry<String, Retriever<TCase, Query>> entry : this.retrievers.entrySet()) {
            String retrieverName = entry.getKey();
            Retriever<TCase, Query> retriever = entry.getValue();
            retriever.setObjectPool(this.trainingObjectPool.getTrainPool());
            DataObjectIterator<TCase> queries = this.trainingObjectPool.getTestPool().iterator();
            while (queries.hasNext()) {
                DataObject queryObject = (DataObject) queries.next();
                this.logger.debug("... retrieval " + retrievalNo + "/" + totalRetrievals + " (" + retrieverName + ")");
                Query query = retriever.newQuery();
                query.setQueryObject(queryObject);
                query.setNumberOfResults(trainPoolSize);
                query.setRetrieveCases(true);
                RetrievalResultList result = retriever.perform(query);
                if (result.size() != trainPoolSize) {
                    throw new RetrieverEvaluationException("Retrieval results do not contain every object of case base!", this.trainingObjectPool.getTrainPool().toString());
                }
                macResults.add(new MACRetrievalResult(SimpleSimilarityResult.fromRetrievalResultList(result), result.getRetrievalTime() / NANOS_PER_MILLI, retrieverName));
                retrievalNo++;
            }
        }
        return macResults;
    }

    /**
     * Creates an empty metric value list for every (retriever, metric, k, filter size) key and
     * resets {@link #retrievalTimeResultMapCombined} with an empty list per (retriever, k, filter size).
     */
    private Map<RetrieverFSKMetricKeyPair, List<Double>> initResultMaps() {
        Map<RetrieverFSKMetricKeyPair, List<Double>> metricValues = new HashMap<>();
        this.retrievalTimeResultMapCombined = new HashMap<>();
        for (String retrieverName : this.retrievers.keySet()) {
            for (int i = 0; i < this.filterSizes.length; i++) {
                int filterSize = this.filterSizes[i];
                int k = this.ks[i];
                for (EvalMetric metric : this.metrics) {
                    metricValues.put(new RetrieverFSKMetricKeyPair(retrieverName, metric, k, filterSize), new ArrayList<>());
                }
                this.retrievalTimeResultMapCombined.put(new RetrieverFSKKeyPair(retrieverName, k, filterSize), new ArrayList<>());
            }
        }
        return metricValues;
    }

    /**
     * Runs the FAC stage for every MAC result and every (filter size, k) pair and appends the
     * metric values and combined MAC + FAC times to the prepared lists.
     */
    private void performFacRetrievals(List<MACRetrievalResult> macResults, Map<RetrieverFSKMetricKeyPair, List<Double>> metricValues) throws RetrieverEvaluationException {
        Map<String, SimpleSimilarityResult> groundTruthByQueryId = indexGroundTruthByQueryId();
        int resultNo = 1;
        for (MACRetrievalResult macResult : macResults) {
            SimpleSimilarityResult macSimilarities = macResult.getSimpleSimResult();
            String retrieverName = macResult.getRetriever();
            String queryId = macSimilarities.getQueryID();
            this.logger.debug("==== " + queryId + " and retriever \"" + retrieverName + "\" (" + resultNo + "/" + macResults.size() + ")");
            TCase queryObject = this.trainingObjectPool.getTestPool().getObject(queryId);
            SimpleSimilarityResult groundTruth = groundTruthByQueryId.get(queryId);
            if (groundTruth == null) {
                throw new RetrieverEvaluationException("No ground-truth similarities found for query " + queryId + "!", this.trainingObjectPool.getTrainPool().toString());
            }
            for (int i = 0; i < this.filterSizes.length; i++) {
                int filterSize = this.filterSizes[i];
                int k = this.ks[i];
                RetrievalResultList facResult = performFacRetrieval(queryObject, macSimilarities, filterSize, k);
                SimpleSimilarityResult facSimilarities = SimpleSimilarityResult.fromRetrievalResultList(facResult);
                double combinedTimeMs = macResult.getTimeMs() + (facResult.getRetrievalTime() / NANOS_PER_MILLI);
                this.retrievalTimeResultMapCombined.get(new RetrieverFSKKeyPair(retrieverName, k, filterSize)).add(combinedTimeMs);
                // Sequential on purpose: KEvalMetric instances are mutated (setK) before computing.
                for (EvalMetric metric : this.metrics) {
                    if (metric instanceof KEvalMetric) {
                        ((KEvalMetric) metric).setK(k);
                    }
                    metricValues.get(new RetrieverFSKMetricKeyPair(retrieverName, metric, k, filterSize)).add(metric.computeEvalMetric(groundTruth, facSimilarities));
                }
            }
            resultNo++;
        }
    }

    /**
     * Indexes the ground-truth similarities by query ID. Keeps the first entry per ID, matching the
     * result of a linear {@code findFirst} search.
     */
    private Map<String, SimpleSimilarityResult> indexGroundTruthByQueryId() {
        Map<String, SimpleSimilarityResult> index = new HashMap<>();
        for (SimpleSimilarityResult groundTruth : this.groundTruthSimilarities) {
            index.putIfAbsent(groundTruth.getQueryID(), groundTruth);
        }
        return index;
    }

    /**
     * Runs the FAC retriever for one query on a fresh pool holding the first {@code filterSize}
     * cases of the MAC result, and checks that exactly {@code k} results are returned.
     */
    private RetrievalResultList performFacRetrieval(TCase queryObject, SimpleSimilarityResult macSimilarities, int filterSize, int k) throws RetrieverEvaluationException {
        WriteableObjectPool filteredPool = ObjectPoolFactory.newObjectPool();
        Iterator<IdSimilarityPair> candidates = macSimilarities.getCaseSimilarities().iterator();
        // Check before storing, so a non-positive filter size cannot fill the pool with every case.
        while (filteredPool.size() < filterSize && candidates.hasNext()) {
            filteredPool.store(this.trainingObjectPool.getTrainPool().getObject(candidates.next().getId()));
        }
        if (filteredPool.size() < filterSize) {
            throw new RetrieverEvaluationException("A-Star object pool must exactly have the size of filter size (" + filterSize + "!=" + filteredPool.size() + ")!", this.trainingObjectPool.getTrainPool().toString());
        }
        this.facGTRetriever.setObjectPool(filteredPool);
        Query facQuery = this.facGTRetriever.newQuery();
        facQuery.setQueryObject(queryObject);
        facQuery.setNumberOfResults(k);
        facQuery.setRetrieveCases(true);
        RetrievalResultList facResult = this.facGTRetriever.perform(facQuery);
        if (facResult.size() != k) {
            throw new RetrieverEvaluationException("FAC A-Star retrieval results must exactly have the size of k!", this.trainingObjectPool.getTrainPool().toString());
        }
        return facResult;
    }

    /** Stores the mean of the collected values for every (retriever, metric, k, filter size) key. */
    private void aggregateMetricResults(Map<RetrieverFSKMetricKeyPair, List<Double>> metricValues) {
        for (int i = 0; i < this.filterSizes.length; i++) {
            int filterSize = this.filterSizes[i];
            int k = this.ks[i];
            for (String retrieverName : this.retrievers.keySet()) {
                for (EvalMetric metric : this.metrics) {
                    RetrieverFSKMetricKeyPair key = new RetrieverFSKMetricKeyPair(retrieverName, metric, k, filterSize);
                    this.metricResults.put(key, average(metricValues.get(key)));
                }
            }
        }
    }

    /**
     * Arithmetic mean of the given values.
     *
     * @throws java.util.NoSuchElementException if {@code values} is empty (e.g. empty test pool)
     */
    private static double average(List<Double> values) {
        return values.stream().mapToDouble(Double::doubleValue).average().getAsDouble();
    }
}
