package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractorMap;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServerException;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.CorrespondenceRelation;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.jena.ontology.OntModel;
import org.apache.jena.rdf.model.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/SentenceTransformersFineTuner.class */
/**
 * Fine-tuner that writes training files for, and triggers fine-tuning of, a
 * sentence-transformers model through the {@link PythonServer}.
 *
 * <p>Depending on the configured {@link SentenceTransformersLoss} the training file is
 * written either in the classification format of the base class
 * ({@code CosineSimilarityLoss}) or as CSV triplets of anchor, positive and negative text
 * ({@code MultipleNegativesRankingLoss}).
 */
public class SentenceTransformersFineTuner extends TransformersBaseFineTuner {
    private static final Logger LOGGER = LoggerFactory.getLogger(SentenceTransformersFineTuner.class);
    private static final String NEWLINE = System.getProperty("line.separator");

    /** Test split size; validated by {@link #setTestSize(float)} to lie in [0, 1]. */
    private float testSize = 0.33f;
    private int trainBatchSize = 64;
    private int testBatchSize = 32;
    private int numberOfEpochs = 5;
    /** Loss that also decides the layout of the training file. */
    private SentenceTransformersLoss loss = SentenceTransformersLoss.CosineSimilarityLoss;

    /**
     * Creates a fine-tuner with the defaults: test size 0.33, train batch size 64,
     * test batch size 32, 5 epochs and {@code CosineSimilarityLoss}.
     *
     * @param textExtractorMap extractor forwarded to the base class
     * @param str forwarded to the base class (presumably the initial model name or path - TODO confirm)
     * @param file forwarded to the base class (presumably the location of the resulting model - TODO confirm)
     */
    public SentenceTransformersFineTuner(TextExtractorMap textExtractorMap, String str, File file) {
        super(textExtractorMap, str, file);
    }

    /**
     * Convenience constructor which wraps a plain {@link TextExtractor} into a
     * {@link TextExtractorMap} and delegates to the main constructor.
     *
     * @param textExtractor extractor to wrap
     * @param str forwarded unchanged to the main constructor
     * @param file forwarded unchanged to the main constructor
     */
    public SentenceTransformersFineTuner(TextExtractor textExtractor, String str, File file) {
        this(TextExtractorMap.wrapTextExtractor(textExtractor), str, file);
    }

    /**
     * Runs the fine-tuning on the Python server without a second (validation) file.
     *
     * @param file file handed to the Python server (presumably the training file - TODO confirm)
     * @return the configured resulting model location
     * @throws PythonServerException if the Python server call fails
     */
    @Override
    public File finetuneModel(File file) throws PythonServerException {
        PythonServer.getInstance().sentenceTransformersFineTuning(this, file, null);
        return this.resultingModelLocation;
    }

    /**
     * Runs the fine-tuning on the Python server with two files.
     *
     * @param file first file handed to the Python server (presumably the training file - TODO confirm)
     * @param file2 second file handed to the Python server (presumably a validation file - TODO confirm)
     * @return the value reported by the Python server (presumably a validation score - TODO confirm)
     * @throws PythonServerException if the Python server call fails
     */
    public float finetuneModel(File file, File file2) throws PythonServerException {
        return PythonServer.getInstance().sentenceTransformersFineTuning(this, file, file2);
    }

    /**
     * Writes the training file in the layout required by the configured loss.
     *
     * @param ontModel model used to resolve the first entity of each correspondence
     * @param ontModel2 model used to resolve the second entity of each correspondence
     * @param alignment alignment providing the training correspondences
     * @param file file to write to
     * @param z {@code true} to append to the file, {@code false} to overwrite it
     * @return the value returned by the format-specific writer (for the triplet format: the number of written lines)
     * @throws IOException if writing fails or the loss is not recognized
     */
    @Override
    public int writeTrainingFile(OntModel ontModel, OntModel ontModel2, Alignment alignment, File file, boolean z) throws IOException {
        switch (this.loss) {
            case CosineSimilarityLoss:
                return writeClassificationFormat(ontModel, ontModel2, alignment, file, z);
            case MultipleNegativesRankingLoss:
                return writeTripletFormat(ontModel, ontModel2, alignment, file, z);
            default:
                throw new IOException("Loss is not recognized");
        }
    }

    /**
     * Writes anchor/positive/negative triplets as CSV lines (UTF-8).
     *
     * <p>Every EQUIVALENCE correspondence is a positive pair. Its negatives are the INCOMPAT
     * correspondences sharing the same first entity and, if
     * {@code additionallySwitchSourceTarget} is set, also those sharing the same second entity.
     *
     * @param source model containing the first entities of the correspondences
     * @param target model containing the second entities of the correspondences
     * @param alignment alignment with EQUIVALENCE positives and INCOMPAT negatives
     * @param trainFile file to write to
     * @param append {@code true} to append, {@code false} to overwrite
     * @return number of triplet lines written
     * @throws IOException if the file cannot be opened or written
     */
    private int writeTripletFormat(OntModel source, OntModel target, Alignment alignment, File trainFile, boolean append) throws IOException {
        int positivesWithoutExample = 0;
        int positives = 0;
        int examples = 0;
        // Textual representations are cached per resource because one resource can take part in many triplets.
        Map<Resource, Map<String, Set<String>>> cache = new HashMap<>();
        try (Writer writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(trainFile, append), StandardCharsets.UTF_8))) {
            for (Correspondence positive : alignment.getCorrespondencesRelation(CorrespondenceRelation.EQUIVALENCE)) {
                positives++;
                int examplesForPositive = 0;
                Resource left = source.getResource(positive.getEntityOne());
                Resource right = target.getResource(positive.getEntityTwo());

                for (Correspondence negative : alignment.getCorrespondencesSourceRelation(positive.getEntityOne(), CorrespondenceRelation.INCOMPAT)) {
                    Resource negativeResource = target.getResource(negative.getEntityTwo());
                    examplesForPositive += writeOneTriplet(left, right, negativeResource, cache, writer);
                    if (this.additionallySwitchSourceTarget) {
                        examplesForPositive += writeOneTriplet(right, left, negativeResource, cache, writer);
                    }
                }

                if (this.additionallySwitchSourceTarget) {
                    for (Correspondence negative : alignment.getCorrespondencesTargetRelation(positive.getEntityTwo(), CorrespondenceRelation.INCOMPAT)) {
                        // These negatives are first entities of their correspondence, so they are
                        // resolved against the same model as the first entity of the positive pair.
                        Resource negativeResource = source.getResource(negative.getEntityOne());
                        examplesForPositive += writeOneTriplet(right, left, negativeResource, cache, writer);
                        examplesForPositive += writeOneTriplet(left, right, negativeResource, cache, writer);
                    }
                }

                if (examplesForPositive == 0) {
                    positivesWithoutExample++;
                }
                examples += examplesForPositive;
            }
        }
        LOGGER.info("Wrote {} triplet training EXAMPLES. The initial ALIGNMENT contains {} positive correspondences. {} of those correspondences are not used due to insufficient textual data or non existent negatives (the negatives should use the INCOMPAT relation).", examples, positives, positivesWithoutExample);
        return examples;
    }

    /**
     * Writes all text combinations of one (anchor, positive, negative) resource triple.
     *
     * <p>Only texts stored under the same key are combined. A key of the anchor for which the
     * positive or the negative resource provides no texts yields no lines.
     *
     * @param anchor resource whose texts form the first CSV column
     * @param positive resource whose texts form the second CSV column
     * @param negative resource whose texts form the third CSV column
     * @param cache cache handed to {@code getTextualRepresentation}
     * @param writer destination of the CSV lines
     * @return number of lines written
     * @throws IOException if writing fails
     */
    private int writeOneTriplet(Resource anchor, Resource positive, Resource negative, Map<Resource, Map<String, Set<String>>> cache, Writer writer) throws IOException {
        Map<String, Set<String>> anchorTexts = getTextualRepresentation(anchor, cache);
        Map<String, Set<String>> positiveTexts = getTextualRepresentation(positive, cache);
        Map<String, Set<String>> negativeTexts = getTextualRepresentation(negative, cache);

        int written = 0;
        for (Map.Entry<String, Set<String>> anchorEntry : anchorTexts.entrySet()) {
            Set<String> anchorValues = anchorEntry.getValue();
            Set<String> positiveValues = positiveTexts.get(anchorEntry.getKey());
            Set<String> negativeValues = negativeTexts.get(anchorEntry.getKey());
            if (anchorValues == null || positiveValues == null || negativeValues == null) {
                // Not every resource necessarily has texts for every key; such keys cannot form a triplet.
                continue;
            }
            for (String positiveText : positiveValues) {
                for (String negativeText : negativeValues) {
                    for (String anchorText : anchorValues) {
                        written++;
                        writer.write(StringEscapeUtils.escapeCsv(anchorText) + "," + StringEscapeUtils.escapeCsv(positiveText) + "," + StringEscapeUtils.escapeCsv(negativeText) + NEWLINE);
                    }
                }
            }
        }
        return written;
    }

    /**
     * Not supported by this fine-tuner.
     *
     * @param transformersTrainerArguments ignored
     * @throws IllegalArgumentException always
     */
    @Override
    public void setTrainingArguments(TransformersTrainerArguments transformersTrainerArguments) {
        throw new IllegalArgumentException("Training arguments are not used in SentenceTransformersFineTuner.");
    }

    /**
     * Rejects switching to Tensorflow; passing {@code false} has no effect.
     *
     * @param z must be {@code false}
     * @throws IllegalArgumentException if {@code z} is {@code true}
     */
    @Override
    public void setUsingTensorflow(boolean z) {
        if (z) {
            throw new IllegalArgumentException("SentenceTransformersFineTuner only work with Pytorch. Do not set usingTensorflow to true.");
        }
    }

    /**
     * @return the configured test size
     */
    public float getTestSize() {
        return this.testSize;
    }

    /**
     * Sets the test size.
     *
     * @param f value between zero and one (both inclusive)
     * @throws IllegalArgumentException if {@code f} is outside of [0, 1] or not a number
     */
    public void setTestSize(float f) {
        // Written as a negated range check so that NaN (for which every comparison is false) is rejected as well.
        if (!(f >= 0.0f && f <= 1.0f)) {
            throw new IllegalArgumentException("Test size should be between zero and one");
        }
        this.testSize = f;
    }

    /**
     * @return the configured training batch size
     */
    public int getTrainBatchSize() {
        return this.trainBatchSize;
    }

    /**
     * @param i the training batch size to use
     */
    public void setTrainBatchSize(int i) {
        this.trainBatchSize = i;
    }

    /**
     * @return the configured test batch size
     */
    public int getTestBatchSize() {
        return this.testBatchSize;
    }

    /**
     * @param i the test batch size to use
     */
    public void setTestBatchSize(int i) {
        this.testBatchSize = i;
    }

    /**
     * @return the configured number of epochs
     */
    public int getNumberOfEpochs() {
        return this.numberOfEpochs;
    }

    /**
     * @param i the number of epochs to use
     */
    public void setNumberOfEpochs(int i) {
        this.numberOfEpochs = i;
    }

    /**
     * @return the configured loss
     */
    public SentenceTransformersLoss getLoss() {
        return this.loss;
    }

    /**
     * Sets the loss, which also selects the training file layout used by
     * {@link #writeTrainingFile(OntModel, OntModel, Alignment, File, boolean)}.
     *
     * @param sentenceTransformersLoss the loss to use
     */
    public void setLoss(SentenceTransformersLoss sentenceTransformersLoss) {
        this.loss = sentenceTransformersLoss;
    }
}
