package de.uni_mannheim.informatik.dws.melt.matching_ml.python;

import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.MatcherYAAAJena;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.CorrespondenceRelation;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.jena.ontology.OntModel;
import org.apache.jena.riot.web.HttpNames;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/MachineLearningScikitFilter.class */
/**
 * Filter that learns a binary classifier from the additional confidences of correspondences and keeps only the
 * correspondences the classifier predicts as positive (prediction {@code 1}).
 *
 * <p>Training data is produced by a {@link MatcherYAAAJena} (the "training generator"). Every correspondence it
 * returns with relation {@link CorrespondenceRelation#EQUIVALENCE} is written as a positive example. Every other
 * correspondence is written as a negative example. The features are the additional confidence values whose names are
 * given in {@code confidenceNames}; a missing value is written as {@code 0.0}. When no names are given, all distinct
 * additional confidence keys of the training alignment are used. Learning and prediction are delegated to
 * {@link PythonServer}.
 *
 * <p>Whenever learning or prediction fails, the input alignment is returned unfiltered and the failure is logged.
 */
public class MachineLearningScikitFilter extends MatcherYAAAJena implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(MachineLearningScikitFilter.class);

    /** Header name of the label column in training CSV files (same value Jena's {@code HttpNames.paramTarget} held). */
    private static final String TARGET_COLUMN = "target";

    /** Cross validation folds used when not specified by the caller. */
    private static final int DEFAULT_CROSS_VALIDATION_NUMBER = 5;

    /** Number of parallel jobs used when not specified by the caller. */
    private static final int DEFAULT_NUMBER_OF_PARALLEL_JOBS = 1;

    private final MatcherYAAAJena trainingGenerator;
    private final List<String> confidenceNames;
    private final int crossValidationNumber;
    private final int numberOfParallelJobs;

    /**
     * Creates a filter which uses the input alignment given to {@link #match} as its training data.
     * The filter uses the default confidence names, cross validation number and number of parallel jobs.
     */
    public MachineLearningScikitFilter() {
        this(identityMatcher());
    }

    /**
     * Creates a filter which always trains on the given, fixed alignment.
     *
     * @param trainingAlignment alignment used as training data for every call to {@link #match}
     */
    public MachineLearningScikitFilter(final Alignment trainingAlignment) {
        this(fixedAlignmentMatcher(trainingAlignment));
    }

    /**
     * Creates a filter which always trains on the given, fixed alignment.
     *
     * @param trainingAlignment alignment used as training data for every call to {@link #match}
     * @param crossValidationNumber number of cross validation folds passed to the Python server
     * @param numberOfParallelJobs number of parallel jobs passed to the Python server
     */
    public MachineLearningScikitFilter(final Alignment trainingAlignment, int crossValidationNumber,
            int numberOfParallelJobs) {
        this(fixedAlignmentMatcher(trainingAlignment), null, crossValidationNumber, numberOfParallelJobs);
    }

    /**
     * Creates a filter whose training data is generated by the given matcher.
     *
     * @param trainingGenerator matcher producing the training alignment
     */
    public MachineLearningScikitFilter(MatcherYAAAJena trainingGenerator) {
        this(trainingGenerator, null);
    }

    /**
     * Creates a filter whose training data is generated by the given matcher.
     *
     * @param trainingGenerator matcher producing the training alignment
     * @param confidenceNames names of the additional confidences used as features; {@code null} or empty means all
     *                        distinct keys of the training alignment
     */
    public MachineLearningScikitFilter(MatcherYAAAJena trainingGenerator, List<String> confidenceNames) {
        this(trainingGenerator, confidenceNames, DEFAULT_CROSS_VALIDATION_NUMBER, DEFAULT_NUMBER_OF_PARALLEL_JOBS);
    }

    /**
     * Creates a filter whose training data is generated by the given matcher.
     *
     * @param trainingGenerator matcher producing the training alignment
     * @param crossValidationNumber number of cross validation folds passed to the Python server
     * @param numberOfParallelJobs number of parallel jobs passed to the Python server
     */
    public MachineLearningScikitFilter(MatcherYAAAJena trainingGenerator, int crossValidationNumber,
            int numberOfParallelJobs) {
        this(trainingGenerator, null, crossValidationNumber, numberOfParallelJobs);
    }

    /**
     * Creates a filter whose training data is generated by the given matcher.
     *
     * @param trainingGenerator matcher producing the training alignment
     * @param confidenceNames names of the additional confidences used as features; {@code null} or empty means all
     *                        distinct keys of the training alignment
     * @param crossValidationNumber number of cross validation folds passed to the Python server
     * @param numberOfParallelJobs number of parallel jobs passed to the Python server
     */
    public MachineLearningScikitFilter(MatcherYAAAJena trainingGenerator, List<String> confidenceNames,
            int crossValidationNumber, int numberOfParallelJobs) {
        this.trainingGenerator = trainingGenerator;
        this.confidenceNames = confidenceNames;
        this.crossValidationNumber = crossValidationNumber;
        this.numberOfParallelJobs = numberOfParallelJobs;
    }

    /** Returns a matcher which hands back its input alignment unchanged. */
    private static MatcherYAAAJena identityMatcher() {
        return new MatcherYAAAJena() {
            @Override
            public Alignment match(OntModel source, OntModel target, Alignment inputAlignment,
                    Properties properties) {
                return inputAlignment;
            }
        };
    }

    /** Returns a matcher which ignores its inputs and always returns {@code fixed}. */
    private static MatcherYAAAJena fixedAlignmentMatcher(final Alignment fixed) {
        return new MatcherYAAAJena() {
            @Override
            public Alignment match(OntModel source, OntModel target, Alignment inputAlignment,
                    Properties properties) {
                return fixed;
            }
        };
    }

    /**
     * Generates the training alignment with the training generator, then filters {@code inputAlignment} with a model
     * trained on it.
     */
    @Override
    public Alignment match(OntModel source, OntModel target, Alignment inputAlignment, Properties properties)
            throws Exception {
        Alignment trainingAlignment = this.trainingGenerator.match(source, target, inputAlignment, properties);
        return trainAndApplyMLModel(trainingAlignment, inputAlignment, this.confidenceNames,
                this.crossValidationNumber, this.numberOfParallelJobs);
    }

    /**
     * Trains a model on {@code trainingAlignment} and uses it to filter {@code alignmentToBeFiltered}.
     *
     * @param trainingAlignment correspondences used as training data (EQUIVALENCE = positive, others = negative)
     * @param alignmentToBeFiltered alignment whose correspondences are classified
     * @param confidenceNames feature names; {@code null} or empty means all distinct keys of the training alignment
     * @param crossValidationNumber number of cross validation folds passed to the Python server
     * @param numberOfParallelJobs number of parallel jobs passed to the Python server
     * @return the correspondences predicted as {@code 1}, or {@code alignmentToBeFiltered} unchanged if no features
     *         are available or anything fails
     */
    public static Alignment trainAndApplyMLModel(Alignment trainingAlignment, Alignment alignmentToBeFiltered,
            List<String> confidenceNames, int crossValidationNumber, int numberOfParallelJobs) {
        List<String> features = resolveConfidenceNames(trainingAlignment, confidenceNames);
        if (features.isEmpty()) {
            LOGGER.warn("No attributes available for learning. Return unfiltered alignment.");
            return alignmentToBeFiltered;
        }
        File trainingFile = null;
        File predictFile = null;
        try {
            trainingFile = File.createTempFile("trainingsFile", ".csv");
            writeDataset(new ArrayList<>(trainingAlignment), trainingFile, true, features);
            predictFile = File.createTempFile("testFile", ".csv");
            // The order of this list defines the row order of the predict file, i.e. the order of the predictions.
            List<Correspondence> correspondences = new ArrayList<>(alignmentToBeFiltered);
            writeDataset(correspondences, predictFile, false, features);
            List<Integer> predictions = PythonServer.getInstance().learnAndApplyMLModel(
                    trainingFile, predictFile, crossValidationNumber, numberOfParallelJobs);
            return filterAlignment(alignmentToBeFiltered, correspondences, predictions);
        } catch (Exception e) {
            LOGGER.error("learnAndApplyMLModel failed. Return unfiltered alignment.", e);
            return alignmentToBeFiltered;
        } finally {
            // Delete the temp files on success and failure alike.
            deleteTempFile(trainingFile);
            deleteTempFile(predictFile);
        }
    }

    /**
     * Trains a model on {@code trainingAlignment} and lets the Python server store it in {@code modelFile}.
     *
     * @param trainingAlignment correspondences used as training data (EQUIVALENCE = positive, others = negative)
     * @param modelFile file the model is stored in
     * @param confidenceNames feature names; {@code null} or empty means all distinct keys of the training alignment
     * @param crossValidationNumber number of cross validation folds passed to the Python server
     * @param numberOfParallelJobs number of parallel jobs passed to the Python server
     * @return the feature names used for training; pass these to {@link #applyStoredMLModel}. The list is empty if no
     *         features were available. It is also returned when training fails (the failure is only logged).
     */
    public static List<String> trainAndStoreMLModel(Alignment trainingAlignment, File modelFile,
            List<String> confidenceNames, int crossValidationNumber, int numberOfParallelJobs) {
        List<String> features = resolveConfidenceNames(trainingAlignment, confidenceNames);
        if (features.isEmpty()) {
            LOGGER.error("No attributes available for learning. Did not create any model file.");
            return features;
        }
        File trainingFile = null;
        try {
            trainingFile = File.createTempFile("trainingsFile", ".csv");
            writeDataset(new ArrayList<>(trainingAlignment), trainingFile, true, features);
            PythonServer.getInstance().trainAndStoreMLModel(trainingFile, modelFile, crossValidationNumber,
                    numberOfParallelJobs);
        } catch (Exception e) {
            LOGGER.error("trainAndStoreMLModel failed. Did not create any model file.", e);
        } finally {
            deleteTempFile(trainingFile);
        }
        return features;
    }

    /**
     * Filters {@code alignment} with a model previously stored by {@link #trainAndStoreMLModel}.
     *
     * @param modelFile the stored model
     * @param alignment alignment whose correspondences are classified
     * @param confidenceNames feature names; must match those used for training
     * @return the correspondences predicted as {@code 1}, or {@code alignment} unchanged on any failure
     */
    public static Alignment applyStoredMLModel(File modelFile, Alignment alignment, List<String> confidenceNames) {
        if (modelFile == null || !modelFile.exists()) {
            LOGGER.error("Model file does not exist. Return unfiltered alignment.");
            return alignment;
        }
        if (confidenceNames == null || confidenceNames.isEmpty()) {
            LOGGER.error("No confidence names for prediction are provided. Return unfiltered alignment.");
            return alignment;
        }
        File predictFile = null;
        try {
            // Use a temp file: a fixed "testFile.csv" in the working directory could clobber a user file.
            predictFile = File.createTempFile("testFile", ".csv");
            List<Correspondence> correspondences = new ArrayList<>(alignment);
            writeDataset(correspondences, predictFile, false, confidenceNames);
            List<Integer> predictions = PythonServer.getInstance().applyStoredMLModel(modelFile, predictFile);
            return filterAlignment(alignment, correspondences, predictions);
        } catch (Exception e) {
            LOGGER.error("applyStoredMLModel failed. Return unfiltered alignment.", e);
            return alignment;
        } finally {
            deleteTempFile(predictFile);
        }
    }

    /** Returns {@code confidenceNames} if non-empty, otherwise all distinct confidence keys of the alignment. */
    private static List<String> resolveConfidenceNames(Alignment trainingAlignment, List<String> confidenceNames) {
        if (confidenceNames != null && !confidenceNames.isEmpty()) {
            return confidenceNames;
        }
        return new ArrayList<>(trainingAlignment.getDistinctCorrespondenceConfidenceKeys());
    }

    /** Deletes a temporary file if it was created; logs a warning if deletion fails. */
    private static void deleteTempFile(File file) {
        if (file != null && file.exists() && !file.delete()) {
            LOGGER.warn("Could not delete temporary file {}", file.getAbsolutePath());
        }
    }

    /**
     * Builds a new alignment containing the correspondences whose positional prediction is {@code 1}.
     *
     * @param alignment original alignment; used as template for the result and returned on size mismatch
     * @param correspondences correspondences in the same order as they were written to the predict file
     * @param predictions one prediction per correspondence
     */
    private static Alignment filterAlignment(Alignment alignment, List<Correspondence> correspondences,
            List<Integer> predictions) {
        if (correspondences.size() != predictions.size()) {
            LOGGER.warn("Size of correspondences and predictions do not have the same size. Return unfiltered alignment.");
            return alignment;
        }
        Alignment filtered = new Alignment(alignment, false);
        for (int i = 0; i < predictions.size(); i++) {
            if (predictions.get(i).intValue() == 1) {
                filtered.add(correspondences.get(i));
            }
        }
        return filtered;
    }

    /**
     * Writes the correspondences as a UTF-8 CSV file.
     *
     * <p>The file has one column per confidence name; a missing confidence is written as {@code 0.0}. When
     * {@code withLabel} is set, an additional {@value #TARGET_COLUMN} column holds 1 for EQUIVALENCE correspondences
     * and 0 otherwise.
     *
     * @throws IOException if the file cannot be written
     */
    private static void writeDataset(List<Correspondence> correspondences, File file, boolean withLabel,
            List<String> confidenceNames) throws IOException {
        try (CSVPrinter printer = CSVFormat.DEFAULT.print(file, StandardCharsets.UTF_8)) {
            List<String> header = new ArrayList<>(confidenceNames);
            if (withLabel) {
                header.add(TARGET_COLUMN);
            }
            printer.printRecord(header);

            int positives = 0;
            int negatives = 0;
            for (Correspondence correspondence : correspondences) {
                List<Object> row = new ArrayList<>(confidenceNames.size() + 1);
                for (String confidenceName : confidenceNames) {
                    row.add(Double.valueOf(correspondence.getAdditionalConfidenceOrDefault(confidenceName, 0.0d)));
                }
                if (withLabel) {
                    if (correspondence.getRelation() == CorrespondenceRelation.EQUIVALENCE) {
                        row.add(1);
                        positives++;
                    } else {
                        row.add(0);
                        negatives++;
                    }
                }
                printer.printRecord(row);
            }

            if (withLabel) {
                LOGGER.info("Created training file with {} positive and {} negative examples ({} attribute(s)).",
                        positives, negatives, confidenceNames.size());
            } else {
                LOGGER.info("Created predict file with {} examples ({} attribute(s)).",
                        correspondences.size(), confidenceNames.size());
            }
        }
    }
}
