package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import de.uni_mannheim.informatik.dws.melt.matching_base.FileUtil;
import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractorMap;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServerException;
import java.io.File;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/TransformersFineTuner.class */
/**
 * Fine-tuner that delegates the actual training to the {@link PythonServer}
 * ({@code transformersFineTuning}).
 *
 * <p>Before training, it can optionally probe for the largest {@code per_device_train_batch_size}
 * that works. Probing is controlled by {@link #setBatchSizeOptimization(BatchSizeOptimization)}.
 * With {@link BatchSizeOptimization#NONE} (the default), the configured training arguments are used
 * as they are.
 */
public class TransformersFineTuner extends TransformersBaseFineTuner implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFineTuner.class);

    /**
     * Strategy passed to {@code getExamplesForBatchSizeOptimization} when probing for the maximum
     * batch size. {@link BatchSizeOptimization#NONE} disables probing entirely.
     */
    protected BatchSizeOptimization batchSizeOptimization = BatchSizeOptimization.NONE;

    /**
     * Creates a fine-tuner based on a single text extractor. Batch size optimization is disabled.
     *
     * @param textExtractor extractor passed on to {@link TransformersBaseFineTuner}
     * @param str passed on to the base-class constructor (NOTE(review): presumably the initial model
     *     name; confirm against {@link TransformersBaseFineTuner})
     * @param file passed on to the base-class constructor (NOTE(review): presumably where the
     *     fine-tuned model is written; confirm against {@link TransformersBaseFineTuner})
     */
    public TransformersFineTuner(TextExtractor textExtractor, String str, File file) {
        super(textExtractor, str, file);
    }

    /**
     * Creates a fine-tuner based on a text extractor map. Batch size optimization is disabled.
     *
     * @param textExtractorMap extractor map passed on to {@link TransformersBaseFineTuner}
     * @param str passed on to the base-class constructor (NOTE(review): presumably the initial model
     *     name; confirm against {@link TransformersBaseFineTuner})
     * @param file passed on to the base-class constructor (NOTE(review): presumably where the
     *     fine-tuned model is written; confirm against {@link TransformersBaseFineTuner})
     */
    public TransformersFineTuner(TextExtractorMap textExtractorMap, String str, File file) {
        super(textExtractorMap, str, file);
    }

    /**
     * Fine-tunes the model on the given training file through the python server.
     *
     * <p>If batch size optimization is enabled, the maximum per-device train batch size is
     * determined first (see {@link #getMaximumPerDeviceTrainBatchSize(File)}). It is then stored in
     * this instance's training arguments as {@code per_device_train_batch_size}, and it stays there
     * after this call.
     *
     * @param file the training file
     * @return {@code resultingModelLocation} of this fine-tuner
     * @throws Exception propagated from the python server fine-tuning call
     */
    @Override
    public File finetuneModel(File file) throws Exception {
        if (this.batchSizeOptimization != BatchSizeOptimization.NONE) {
            this.trainingArguments.addParameter("per_device_train_batch_size", Integer.valueOf(getMaximumPerDeviceTrainBatchSize(file)));
        }
        PythonServer.getInstance().transformersFineTuning(this, file);
        return this.resultingModelLocation;
    }

    /**
     * Determines the maximum per-device train batch size from the training file of a previous
     * match call.
     *
     * @return the batch size as computed by {@link #getMaximumPerDeviceTrainBatchSize(File)}
     * @throws IllegalArgumentException if no non-empty training file has been generated yet
     */
    public int getMaximumPerDeviceTrainBatchSize() {
        if (this.trainingFile == null || !this.trainingFile.exists() || this.trainingFile.length() == 0) {
            throw new IllegalArgumentException("Cannot get maximum per device train batch size because no training file is generated. Did you call the match method before (e.g. in a pipeline)?");
        }
        return getMaximumPerDeviceTrainBatchSize(this.trainingFile);
    }

    /**
     * Searches for the largest {@code per_device_train_batch_size} that can be trained.
     *
     * <p>During the search, CUDA devices are restricted via
     * {@code getCudaVisibleDevicesButOnlyOneGPU()}. Starting at 4, the batch size is doubled up to
     * 8192. For every candidate, a temporary file with examples is written and a single training
     * step is run with {@code max_steps = 1} and {@code save_at_end = false}. The search ends when:
     * <ul>
     *   <li>the examples do not fill the candidate size, or the python server reports a memory
     *       error. In both cases, half of the candidate size is returned (2 if the very first
     *       candidate fails).</li>
     *   <li>any other error occurs. In that case, the default of 8 is returned.</li>
     *   <li>8192 succeeded. In that case, 8192 is returned.</li>
     * </ul>
     *
     * <p>The training arguments and CUDA device selection of this instance are always restored
     * before this method returns. This also holds when an exception propagates, so probe-only
     * settings never leak into a later fine-tuning run.
     *
     * @param file the training file from which the probing examples are taken
     * @return the determined batch size
     */
    public int getMaximumPerDeviceTrainBatchSize(File file) {
        // Probing temporarily replaces both fields; remember the caller's values for the final restore.
        TransformersTrainerArguments originalArguments = this.trainingArguments;
        String originalCudaVisibleDevices = this.cudaVisibleDevices;
        this.cudaVisibleDevices = getCudaVisibleDevicesButOnlyOneGPU();
        try {
            List<String> examples = getExamplesForBatchSizeOptimization(file, 8194, this.batchSizeOptimization);
            int batchSize = 4;
            while (batchSize <= 8192) {
                LOGGER.info("Try out per_device_train_batch_size of {}", batchSize);
                File probeFile = FileUtil.createFileWithRandomNumber("alignment_transformers_find_max_batch_size", ".txt");
                try {
                    if (!writeExamplesToFile(examples, probeFile, batchSize)) {
                        int fallback = batchSize / 2;
                        LOGGER.info("File contains too few lines to further increase batch size. Thus use now {}", fallback);
                        return fallback;
                    }
                    // Fresh copy per probe so the probe-only settings never touch the caller's arguments.
                    this.trainingArguments = new TransformersTrainerArguments(originalArguments);
                    this.trainingArguments.addParameter("per_device_train_batch_size", Integer.valueOf(batchSize));
                    this.trainingArguments.addParameter("save_at_end", false);
                    this.trainingArguments.addParameter("max_steps", 1);
                    PythonServer.getInstance().transformersFineTuning(this, probeFile);
                } catch (PythonServerException e) {
                    String message = e.getMessage();
                    // A message-less error is treated as a non-memory error instead of failing with an NPE.
                    boolean memoryError = message != null
                            && (message.contains("not enough memory") || message.contains("out of memory"));
                    if (!memoryError) {
                        LOGGER.warn("Something went wrong in python server during getMaximumPerDeviceTrainBatchSize. Return default of 8", e);
                        return 8;
                    }
                    int fallback = batchSize / 2;
                    LOGGER.info("Found memory error, thus returning batchsize of {}", fallback);
                    return fallback;
                } catch (Exception e) {
                    LOGGER.warn("Something went wrong during getMaximumPerDeviceTrainBatchSize. Return default of 8", e);
                    return 8;
                } finally {
                    probeFile.delete();
                }
                batchSize *= 2;
            }
            LOGGER.info("It looks like that batch sizes up to 8192 works out which is unusual. If greater batch sizes are possible the code to search max batch size needs to be changed.");
            // The loop doubled once more after the last successful probe; return the size actually tested.
            return batchSize / 2;
        } finally {
            this.trainingArguments = originalArguments;
            this.cudaVisibleDevices = originalCudaVisibleDevices;
        }
    }

    /**
     * Sets the {@code fp16} training argument to {@code true}. NOTE(review): presumably this enables
     * mixed-precision training in the python trainer; confirm there.
     */
    public void addTrainingParameterToMakeTrainingFaster() {
        this.trainingArguments.addParameter("fp16", true);
    }

    /**
     * Returns whether batch size optimization is enabled.
     *
     * @return {@code true} if the batch size optimization is anything other than
     *     {@link BatchSizeOptimization#NONE}
     */
    public boolean isAdjustMaxBatchSize() {
        return this.batchSizeOptimization != BatchSizeOptimization.NONE;
    }

    /**
     * Enables or disables batch size optimization.
     *
     * @param z {@code true} selects {@link BatchSizeOptimization#USE_LONGEST_TEXTS};
     *     {@code false} selects {@link BatchSizeOptimization#NONE}
     */
    public void setAdjustMaxBatchSize(boolean z) {
        this.batchSizeOptimization = z ? BatchSizeOptimization.USE_LONGEST_TEXTS : BatchSizeOptimization.NONE;
    }

    /**
     * Returns the configured batch size optimization strategy.
     *
     * @return the batch size optimization strategy
     */
    public BatchSizeOptimization getBatchSizeOptimization() {
        return this.batchSizeOptimization;
    }

    /**
     * Sets the batch size optimization strategy.
     *
     * @param batchSizeOptimization the strategy; {@link BatchSizeOptimization#NONE} disables probing
     */
    public void setBatchSizeOptimization(BatchSizeOptimization batchSizeOptimization) {
        this.batchSizeOptimization = batchSizeOptimization;
    }
}
