package org.neo4j.gds.ml.models;

import java.util.Optional;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainConfig;
import org.neo4j.gds.ml.models.linearregression.LinearRegressionTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainerConfig;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/models/RegressionTrainerFactory.class */
/**
 * Static factory that turns a {@link TrainerConfig} into the {@link RegressorTrainer}
 * implementation matching its training method.
 *
 * <p>This class holds no state and cannot be instantiated; use {@link #create}.
 */
public final class RegressionTrainerFactory {

    private RegressionTrainerFactory() {
        // static factory only; never instantiated
    }

    /**
     * Builds a regression trainer for the method reported by {@code config.method()}.
     *
     * <p>The config is cast to the method-specific config type. If the runtime type does not
     * match the reported method, the cast throws {@link ClassCastException}.
     *
     * @param config          trainer configuration selecting the regression method
     * @param terminationFlag forwarded to the linear regression trainer only
     * @param progressTracker forwarded to whichever trainer is created
     * @param concurrency     forwarded as the first constructor argument of the created trainer
     *                        (presumably the concurrency level; verify against the trainer constructors)
     * @param randomSeed      forwarded to the random forest regression trainer only
     * @return a newly constructed trainer for {@code config.method()}
     * @throws IllegalStateException if the method is neither {@code LinearRegression} nor
     *                               {@code RandomForestRegression}
     */
    public static RegressorTrainer create(
        TrainerConfig config,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker,
        int concurrency,
        Optional<Long> randomSeed
    ) {
        switch (config.method()) {
            case LinearRegression: {
                LinearRegressionTrainConfig linearConfig = (LinearRegressionTrainConfig) config;
                return new LinearRegressionTrainer(
                    concurrency,
                    linearConfig,
                    terminationFlag,
                    progressTracker
                );
            }
            case RandomForestRegression: {
                RandomForestRegressorTrainerConfig forestConfig = (RandomForestRegressorTrainerConfig) config;
                return new RandomForestRegressorTrainer(
                    concurrency,
                    forestConfig,
                    randomSeed,
                    progressTracker
                );
            }
            default: {
                // Any other method (e.g. a classification method) is a caller error here.
                String message = StringFormatting.formatWithLocale(
                    "Method %s is not a regression method",
                    config.method()
                );
                throw new IllegalStateException(message);
            }
        }
    }
}
