package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/TransformersFineTunerHpSearch.class */
/**
 * Fine-tuner that, instead of a single fine-tuning run, delegates to
 * {@link PythonServer#transformersFineTuningHpSearch} to perform a hyperparameter search.
 *
 * <p>Search settings (number of trials, test size, optimizing metric, search space and
 * mutation space) are held in this object and read by the Python server call.
 * NOTE(review): exactly how each setting is consumed on the Python side is not visible
 * here — confirm there.
 */
public class TransformersFineTunerHpSearch extends TransformersFineTuner implements Filter {

    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFineTunerHpSearch.class);

    /** Search-space key whose candidates are derived from the maximum per-device batch size. */
    private static final String BATCH_SIZE_PARAMETER = "per_device_train_batch_size";

    /** Number of search trials; defaults to 10. */
    private int numberOfTrials;

    /**
     * Test size in [0, 1]; defaults to 0.33. Presumably the fraction of the input held out
     * for evaluation during the search — confirm on the Python side.
     */
    private float testSize;

    /** Metric the search optimizes; defaults to {@link TransformersOptimizingMetric#AUC}. */
    private TransformersOptimizingMetric optimizingMetric;

    /** Hyperparameter search space; never null after construction. */
    private TransformersHpSearchSpace hpSpace;

    /**
     * Hyperparameter mutation space; never null after construction. Presumably used by a
     * mutation-based scheduler (e.g. population based training) — confirm.
     */
    private TransformersHpSearchSpace hpMutations;

    /**
     * Creates a hyperparameter-search fine-tuner with default search settings.
     *
     * @param textExtractor          text extractor, forwarded to the superclass constructor
     * @param initialModelName       model name, forwarded to the superclass constructor
     * @param resultingModelLocation file forwarded to the superclass constructor; the
     *                               {@code resultingModelLocation} field is what
     *                               {@link #finetuneModel(File)} returns
     */
    public TransformersFineTunerHpSearch(TextExtractor textExtractor, String initialModelName, File resultingModelLocation) {
        super(textExtractor, initialModelName, resultingModelLocation);
        this.numberOfTrials = 10;
        this.testSize = 0.33f;
        this.optimizingMetric = TransformersOptimizingMetric.AUC;
        this.hpSpace = TransformersHpSearchSpace.getDefaultHpSpace();
        this.hpMutations = TransformersHpSearchSpace.getDefaultHpSpaceMutations();
    }

    /**
     * Runs the hyperparameter search via {@link PythonServer#transformersFineTuningHpSearch}.
     *
     * <p>If {@link #isAdjustMaxBatchSize()} is true, the {@code per_device_train_batch_size}
     * entry of both the search space and the mutation space is first set to the candidates
     * from {@link #batchSizeCandidates(int)}. NOTE(review): this calls {@code choice} on the
     * configured {@code hpSpace}/{@code hpMutations} objects and ignores the return value,
     * so it presumably mutates them in place — confirm.
     *
     * @param trainingFile file handed to the Python server (presumably the training data —
     *                     confirm against {@code transformersFineTuningHpSearch})
     * @return the {@code resultingModelLocation} field after the search call
     * @throws Exception whatever the Python server call throws
     */
    @Override
    public File finetuneModel(File trainingFile) throws Exception {
        if (isAdjustMaxBatchSize()) {
            List<Integer> batchSizes = batchSizeCandidates(getMaximumPerDeviceTrainBatchSize());
            LOGGER.info("Set the hyper parameter search space for \"per_device_train_batch_size\" to: {}", batchSizes);
            this.hpSpace.choice(BATCH_SIZE_PARAMETER, batchSizes);
            this.hpMutations.choice(BATCH_SIZE_PARAMETER, batchSizes);
        }
        PythonServer.getInstance().transformersFineTuningHpSearch(this, trainingFile);
        return this.resultingModelLocation;
    }

    /**
     * Computes candidate batch sizes: the powers of two from a start value up to
     * {@code maxBatchSize} (inclusive).
     *
     * <p>The start value depends on {@code maxBatchSize}:
     * <ul>
     *   <li>4 when it is at least 8;</li>
     *   <li>2 when it is at least 4;</li>
     *   <li>1 otherwise.</li>
     * </ul>
     * The smallest sizes are therefore dropped when larger ones are available. The result
     * is empty when {@code maxBatchSize < 1}.
     *
     * @param maxBatchSize the largest allowed batch size
     * @return a new mutable list of candidate sizes in ascending order
     */
    private static List<Integer> batchSizeCandidates(int maxBatchSize) {
        final int start;
        if (maxBatchSize >= 8) {
            start = 4;
        } else if (maxBatchSize >= 4) {
            start = 2;
        } else {
            start = 1;
        }
        List<Integer> candidates = new ArrayList<>();
        // Iterate with a long so doubling cannot overflow: with an int counter and
        // maxBatchSize near Integer.MAX_VALUE the counter would wrap to 0 and loop forever.
        for (long size = start; size <= maxBatchSize; size *= 2) {
            candidates.add((int) size);
        }
        return candidates;
    }

    /** @return the number of search trials */
    public int getNumberOfTrials() {
        return this.numberOfTrials;
    }

    /** @param numberOfTrials the number of search trials */
    public void setNumberOfTrials(int numberOfTrials) {
        this.numberOfTrials = numberOfTrials;
    }

    /** @return the test size, in [0, 1] */
    public float getTestSize() {
        return this.testSize;
    }

    /**
     * Sets the test size.
     *
     * @param testSize a value in [0, 1] (inclusive)
     * @throws IllegalArgumentException if {@code testSize} is outside [0, 1] or NaN
     */
    public void setTestSize(float testSize) {
        // NaN fails both range comparisons, so it must be rejected explicitly.
        if (Float.isNaN(testSize) || testSize < 0.0f || testSize > 1.0f) {
            throw new IllegalArgumentException("Test size should be between zero and one");
        }
        this.testSize = testSize;
    }

    /** @return the metric the search optimizes */
    public TransformersOptimizingMetric getOptimizingMetric() {
        return this.optimizingMetric;
    }

    /** @param optimizingMetric the metric the search should optimize */
    public void setOptimizingMetric(TransformersOptimizingMetric optimizingMetric) {
        this.optimizingMetric = optimizingMetric;
    }

    /** @return the hyperparameter search space (the live object, not a copy) */
    public TransformersHpSearchSpace getHpSpace() {
        return this.hpSpace;
    }

    /**
     * Sets the hyperparameter search space.
     *
     * @param hpSpace the search space; must not be null
     * @throws IllegalArgumentException if {@code hpSpace} is null
     */
    public void setHpSpace(TransformersHpSearchSpace hpSpace) {
        if (hpSpace == null) {
            throw new IllegalArgumentException("HpSpace should not be null.");
        }
        this.hpSpace = hpSpace;
    }

    /** @return the hyperparameter mutation space (the live object, not a copy) */
    public TransformersHpSearchSpace getHpMutations() {
        return this.hpMutations;
    }

    /**
     * Sets the hyperparameter mutation space.
     *
     * @param hpMutations the mutation space; must not be null
     * @throws IllegalArgumentException if {@code hpMutations} is null
     */
    public void setHpMutations(TransformersHpSearchSpace hpMutations) {
        if (hpMutations == null) {
            throw new IllegalArgumentException("HpMutations should not be null.");
        }
        this.hpMutations = hpMutations;
    }
}
