package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import com.googlecode.cqengine.index.support.CloseableIterator;
import de.uni_mannheim.informatik.dws.melt.matching_base.FileUtil;
import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractorMap;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServerException;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.jena.ontology.OntModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/TransformersFilter.class */
/**
 * Filter that scores the correspondences of an input alignment with a transformer model.
 *
 * <p>For every correspondence, the textual representations of both entities are written
 * pairwise to a temporary CSV file, the file is handed to the Python server for prediction,
 * and the highest predicted confidence of all text pairs belonging to a correspondence is
 * stored on that correspondence as an additional confidence (keyed by this class).
 */
public class TransformersFilter extends TransformersBase implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFilter.class);
    private static final String NEWLINE = System.getProperty("line.separator");

    /** Smallest batch size tried when searching for the maximum batch size. */
    private static final int INITIAL_BATCH_SIZE = 4;
    /** The search stops once the candidate batch size is no longer below this bound. */
    private static final int BATCH_SIZE_SEARCH_LIMIT = 8193;
    /** Number of examples requested for the batch size search. */
    private static final int BATCH_SIZE_EXAMPLE_COUNT = 8194;
    /** Batch size returned when the search fails for a reason other than memory exhaustion. */
    private static final int FALLBACK_BATCH_SIZE = 8;

    /** Flag exposed through getter and setter; it is not read by any logic in this class. */
    private boolean changeClass;
    /** Strategy used to select examples for the batch size search; NONE disables the search. */
    private BatchSizeOptimization batchSizeOptimization;

    /**
     * Creates a filter with batch size optimization switched off.
     *
     * @param textExtractor extractor forwarded to the superclass constructor
     * @param str second argument forwarded unchanged to the superclass constructor
     */
    public TransformersFilter(TextExtractor textExtractor, String str) {
        super(textExtractor, str);
        this.changeClass = false;
        this.batchSizeOptimization = BatchSizeOptimization.NONE;
    }

    /**
     * Creates a filter with batch size optimization switched off.
     *
     * @param textExtractorMap extractor forwarded to the superclass constructor
     * @param str second argument forwarded unchanged to the superclass constructor
     */
    public TransformersFilter(TextExtractorMap textExtractorMap, String str) {
        super(textExtractorMap, str);
        this.changeClass = false;
        this.batchSizeOptimization = BatchSizeOptimization.NONE;
    }

    /**
     * Adds a transformer based confidence to every correspondence of the input alignment.
     *
     * <p>The returned object is the input alignment itself. If the prediction file cannot be
     * written, or no correspondence yields any text pair, the alignment is returned after a
     * warning has been logged. The temporary prediction file is always deleted.
     *
     * @param ontModel model containing the first entity of each correspondence
     * @param ontModel2 model containing the second entity of each correspondence
     * @param alignment alignment whose correspondences are scored in place
     * @param properties not used by this method
     * @return the input alignment
     * @throws IllegalArgumentException if the prediction result contains a null confidence
     * @throws Exception if the prediction itself fails
     */
    @Override
    public Alignment match(OntModel ontModel, OntModel ontModel2, Alignment alignment, Properties properties) throws Exception {
        File predictionFile = FileUtil.createFileWithRandomNumber("alignment_transformers_predict", ".txt");
        try {
            Map<Correspondence, List<Integer>> exampleIndices;
            try {
                exampleIndices = createPredictionFile(ontModel, ontModel2, alignment, predictionFile, false);
            } catch (IOException e) {
                // Only a failure to write the file is treated as recoverable; prediction errors propagate.
                LOGGER.warn("Could not write text to prediction file. Return unmodified input alignment.", e);
                return alignment;
            }
            if (exampleIndices.isEmpty()) {
                LOGGER.warn("No correspondences have enough text to be processed (the input alignment has {} correspondences) - the input alignment is returned unchanged.", alignment.size());
                return alignment;
            }
            LOGGER.info("Run prediction");
            List<Double> confidences = predictConfidences(predictionFile);
            LOGGER.info("Finished prediction");
            for (Map.Entry<Correspondence, List<Integer>> entry : exampleIndices.entrySet()) {
                // A correspondence may own several text pairs; keep the best score (never below 0.0).
                double best = 0.0d;
                for (int exampleIndex : entry.getValue()) {
                    Double confidence = confidences.get(exampleIndex);
                    if (confidence == null) {
                        throw new IllegalArgumentException("Could not find a confidence for a given correspondence.");
                    }
                    if (confidence > best) {
                        best = confidence;
                    }
                }
                entry.getKey().addAdditionalConfidence(getClass(), best);
            }
            return alignment;
        } finally {
            predictionFile.delete();
        }
    }

    /**
     * Writes one CSV line per pair of non-blank texts of each correspondence to the given file.
     *
     * <p>Every correspondence first receives an additional confidence of 0.0 keyed by this
     * class. Texts are only paired within the same key of the two textual representation
     * maps; a key that is missing on the second entity contributes no lines.
     *
     * @param ontModel model in which the first entity of each correspondence is looked up
     * @param ontModel2 model in which the second entity of each correspondence is looked up
     * @param alignment correspondences to write examples for
     * @param file destination file, written as UTF-8
     * @param z {@code true} to append to the file, {@code false} to overwrite it
     * @return for each correspondence with at least one written line, the zero-based indices of
     *     its lines among the lines written by this call
     * @throws IOException if the file cannot be opened or written
     */
    public Map<Correspondence, List<Integer>> createPredictionFile(OntModel ontModel, OntModel ontModel2, Alignment alignment, File file, boolean z) throws IOException {
        Map<Correspondence, List<Integer>> exampleIndices = new HashMap<>();
        int exampleCount = 0;
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, z), StandardCharsets.UTF_8))) {
            // The cache's type parameters are declared by the superclass and are not visible here,
            // therefore it stays a raw map that is only handed through to getTextualRepresentation.
            HashMap textCache = new HashMap();
            CloseableIterator<Correspondence> correspondences = alignment.iterator();
            while (correspondences.hasNext()) {
                Correspondence correspondence = correspondences.next();
                correspondence.addAdditionalConfidence(getClass(), 0.0d);
                Map<String, Set<String>> leftTexts = getTextualRepresentation(ontModel.getResource(correspondence.getEntityOne()), textCache);
                Map<String, Set<String>> rightTexts = getTextualRepresentation(ontModel2.getResource(correspondence.getEntityTwo()), textCache);
                for (Map.Entry<String, Set<String>> leftGroup : leftTexts.entrySet()) {
                    Set<String> rightGroup = rightTexts.get(leftGroup.getKey());
                    if (rightGroup == null) {
                        // No texts under this key on the right side: nothing to pair with.
                        continue;
                    }
                    for (String right : rightGroup) {
                        if (StringUtils.isBlank(right)) {
                            continue;
                        }
                        for (String left : leftGroup.getValue()) {
                            if (StringUtils.isBlank(left)) {
                                continue;
                            }
                            writer.write(StringEscapeUtils.escapeCsv(left) + "," + StringEscapeUtils.escapeCsv(right) + NEWLINE);
                            exampleIndices.computeIfAbsent(correspondence, key -> new ArrayList<>()).add(exampleCount);
                            exampleCount++;
                        }
                    }
                }
            }
        }
        LOGGER.info("Wrote {} examples to prediction file {}", exampleCount, file);
        return exampleIndices;
    }

    /**
     * Runs the prediction for the given file on the Python server.
     *
     * <p>If batch size optimization is enabled, the maximum batch size is determined first and
     * stored in the training arguments as {@code per_device_eval_batch_size}.
     *
     * @param file prediction file as written by {@link #createPredictionFile}
     * @return the list returned by the Python server, indexed by line as used in {@link #match}
     * @throws Exception if the Python server call fails
     */
    public List<Double> predictConfidences(File file) throws Exception {
        if (this.batchSizeOptimization != BatchSizeOptimization.NONE) {
            this.trainingArguments.addParameter("per_device_eval_batch_size", getMaximumPerDeviceEvalBatchSize(file));
        }
        return PythonServer.getInstance().transformersPrediction(this, file);
    }

    /**
     * Searches for the largest {@code per_device_eval_batch_size} that works by running trial
     * predictions with doubling batch sizes, starting at 4.
     *
     * <p>The search ends when a trial reports a memory error (half of the failing size is
     * returned), when there are too few examples to fill the next batch (half of that size is
     * returned), when another error occurs (8 is returned), or when all sizes up to 8192
     * succeeded (the last successful size is returned). The training arguments and the visible
     * CUDA devices are temporarily replaced during the search and are restored on every exit
     * path.
     *
     * @param file prediction file from which the trial examples are selected
     * @return the batch size to use
     */
    protected int getMaximumPerDeviceEvalBatchSize(File file) {
        // Remember the current configuration so that it can be restored after the search.
        TransformersTrainerArguments backupArguments = this.trainingArguments;
        String backupCudaVisibleDevices = this.cudaVisibleDevices;
        try {
            this.cudaVisibleDevices = getCudaVisibleDevicesButOnlyOneGPU();
            List<String> examples = getExamplesForBatchSizeOptimization(file, BATCH_SIZE_EXAMPLE_COUNT, this.batchSizeOptimization);
            int batchSize = INITIAL_BATCH_SIZE;
            while (batchSize < BATCH_SIZE_SEARCH_LIMIT) {
                LOGGER.info("Try out per_device_eval_batch_size of {}", batchSize);
                File trialFile = FileUtil.createFileWithRandomNumber("alignment_transformers_predict_find_max_batch_size", ".txt");
                try {
                    if (!writeExamplesToFile(examples, trialFile, batchSize)) {
                        int lastWorking = batchSize / 2;
                        LOGGER.info("File contains too few lines to further increase batch size. Thus use now {}", lastWorking);
                        return lastWorking;
                    }
                    this.trainingArguments = new TransformersTrainerArguments(backupArguments);
                    this.trainingArguments.addParameter("per_device_eval_batch_size", batchSize);
                    PythonServer.getInstance().transformersPrediction(this, trialFile);
                } catch (PythonServerException e) {
                    // Must precede the generic handler, otherwise memory errors are never recognised.
                    if (isOutOfMemoryError(e)) {
                        int lastWorking = batchSize / 2;
                        LOGGER.info("Found memory error, thus returning batchsize of {}", lastWorking);
                        return lastWorking;
                    }
                    LOGGER.warn("Something went wrong in python server during getMaximumPerDeviceEvalBatchSize. Return default of 8", e);
                    return FALLBACK_BATCH_SIZE;
                } catch (Exception e) {
                    LOGGER.warn("Something went wrong during getMaximumPerDeviceEvalBatchSize. Return default of 8", e);
                    return FALLBACK_BATCH_SIZE;
                } finally {
                    trialFile.delete();
                }
                batchSize *= 2;
            }
            LOGGER.info("It looks like that batch sizes up to 8192 works out which is unusual. If greater batch sizes are possible the code to search max batch size needs to be changed.");
            // batchSize was doubled after the last successful trial; return the size that was actually tested.
            return batchSize / 2;
        } finally {
            this.trainingArguments = backupArguments;
            this.cudaVisibleDevices = backupCudaVisibleDevices;
        }
    }

    /**
     * Checks whether the message of the given exception indicates memory exhaustion.
     *
     * @param e exception thrown by the Python server
     * @return {@code true} if the message contains "not enough memory" or "out of memory"
     */
    private static boolean isOutOfMemoryError(PythonServerException e) {
        String message = e.getMessage();
        return message != null && (message.contains("not enough memory") || message.contains("out of memory"));
    }

    /**
     * Returns the change class flag.
     *
     * @return the current value of the flag
     */
    public boolean isChangeClass() {
        return this.changeClass;
    }

    /**
     * Sets the change class flag.
     *
     * @param z the new value of the flag
     */
    public void setChangeClass(boolean z) {
        this.changeClass = z;
    }

    /**
     * Returns whether batch size optimization is enabled.
     *
     * @return {@code true} if the batch size optimization is anything but NONE
     */
    public boolean isOptimizeBatchSize() {
        return this.batchSizeOptimization != BatchSizeOptimization.NONE;
    }

    /**
     * Enables or disables batch size optimization.
     *
     * @param z {@code true} selects USE_LONGEST_TEXTS, {@code false} selects NONE
     */
    public void setOptimizeBatchSize(boolean z) {
        this.batchSizeOptimization = z ? BatchSizeOptimization.USE_LONGEST_TEXTS : BatchSizeOptimization.NONE;
    }

    /**
     * Returns the strategy used for batch size optimization.
     *
     * @return the current strategy
     */
    public BatchSizeOptimization getBatchSizeOptimization() {
        return this.batchSizeOptimization;
    }

    /**
     * Sets the strategy used for batch size optimization.
     *
     * @param batchSizeOptimization the new strategy; NONE disables the optimization
     */
    public void setBatchSizeOptimization(BatchSizeOptimization batchSizeOptimization) {
        this.batchSizeOptimization = batchSizeOptimization;
    }

    /**
     * Enables or disables both batch size optimization and mixed precision optimization.
     *
     * @param z {@code true} to enable both, {@code false} to disable both
     */
    public void setOptimizeAll(boolean z) {
        setOptimizeBatchSize(z);
        setOptimizeForMixedPrecisionTraining(z);
    }

    /**
     * Returns whether both optimizations are enabled.
     *
     * @return {@code true} if batch size optimization and mixed precision optimization are both on
     */
    public boolean isOptimizeAll() {
        return isOptimizeBatchSize() && isOptimizeForMixedPrecisionTraining();
    }
}
