package org.neo4j.gds.ml.models;

import java.util.Optional;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainer;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainConfig;
import org.neo4j.gds.ml.models.mlp.MLPClassifierTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Static factory that maps a {@link TrainerConfig} to the matching {@link ClassifierTrainer}
 * implementation, and to that trainer's {@link MemoryEstimation}.
 *
 * <p>Dispatch is driven by {@code TrainerConfig#method()}. Supported methods are
 * {@code LogisticRegression}, {@code RandomForestClassification} and {@code MLPClassification};
 * any other method results in an {@link IllegalStateException}.
 */
public final class ClassifierTrainerFactory {

    private ClassifierTrainerFactory() {
        // static utility class; not instantiable
    }

    /**
     * Creates the classifier trainer corresponding to {@code trainerConfig.method()}.
     *
     * <p>The config is downcast to the method-specific config type, so it must be an instance
     * of the config class matching its reported method. Otherwise a {@link ClassCastException}
     * is thrown.
     *
     * @param trainerConfig    training configuration; its {@code method()} selects the trainer
     * @param numberOfClasses  class-count argument forwarded to every trainer (presumably the
     *                         number of target classes — verify against callers)
     * @param terminationFlag  termination flag forwarded to the trainer
     * @param progressTracker  progress tracker forwarded to the trainer
     * @param messageLogLevel  log level forwarded to the trainer
     * @param concurrency      concurrency setting forwarded to the trainer (assumed to be the
     *                         number of worker threads — TODO confirm)
     * @param randomSeed       optional seed; forwarded only to the random forest and MLP trainers
     * @param reduceClassCount flag forwarded only to the logistic regression trainer
     * @param metricsHandler   metrics handler; forwarded only to the random forest trainer
     * @return a new trainer for the configured method
     * @throws IllegalStateException if the configured method has no classifier trainer
     */
    public static ClassifierTrainer create(
        TrainerConfig trainerConfig,
        int numberOfClasses,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker,
        LogLevel messageLogLevel,
        int concurrency,
        Optional<Long> randomSeed,
        boolean reduceClassCount,
        ModelSpecificMetricsHandler metricsHandler
    ) {
        switch (trainerConfig.method()) {
            case LogisticRegression:
                return new LogisticRegressionTrainer(
                    concurrency,
                    (LogisticRegressionTrainConfig) trainerConfig,
                    numberOfClasses,
                    reduceClassCount,
                    terminationFlag,
                    progressTracker,
                    messageLogLevel
                );
            case RandomForestClassification:
                return new RandomForestClassifierTrainer(
                    concurrency,
                    numberOfClasses,
                    (RandomForestClassifierTrainerConfig) trainerConfig,
                    randomSeed,
                    progressTracker,
                    messageLogLevel,
                    terminationFlag,
                    metricsHandler
                );
            case MLPClassification:
                return new MLPClassifierTrainer(
                    numberOfClasses,
                    (MLPClassifierTrainConfig) trainerConfig,
                    randomSeed,
                    progressTracker,
                    messageLogLevel,
                    terminationFlag,
                    concurrency
                );
            default:
                // e.g. methods that exist in TrainingMethod but have no classifier trainer here
                throw new IllegalStateException("No such training method.");
        }
    }

    /**
     * Returns the memory estimation for the trainer selected by {@code trainerConfig.method()}.
     *
     * @param trainerConfig           training configuration; its {@code method()} selects the
     *                                estimation
     * @param numberOfTrainingSamples operator forwarded to the trainer's estimation (presumably
     *                                maps a node count to a training-sample count — verify)
     * @param numberOfClasses         class-count argument forwarded to the estimation
     * @param featureDimension        feature-dimension range forwarded to the estimation
     * @param reduceClassCount        flag forwarded only to the logistic regression estimation
     * @return the estimation for the configured method
     * @throws IllegalStateException if the configured method has no classifier trainer
     */
    public static MemoryEstimation memoryEstimation(
        TrainerConfig trainerConfig,
        LongUnaryOperator numberOfTrainingSamples,
        int numberOfClasses,
        MemoryRange featureDimension,
        boolean reduceClassCount
    ) {
        switch (trainerConfig.method()) {
            case LogisticRegression:
                return LogisticRegressionTrainer.memoryEstimation(
                    reduceClassCount,
                    numberOfClasses,
                    featureDimension,
                    ((LogisticRegressionTrainConfig) trainerConfig).batchSize(),
                    numberOfTrainingSamples
                );
            case RandomForestClassification:
                return RandomForestClassifierTrainer.memoryEstimation(
                    numberOfTrainingSamples,
                    numberOfClasses,
                    featureDimension,
                    (RandomForestClassifierTrainerConfig) trainerConfig
                );
            case MLPClassification:
                // NOTE(review): MLP reports an empty estimation; presumably not implemented yet — confirm
                return MemoryEstimations.empty();
            default:
                throw new IllegalStateException("No such training method.");
        }
    }
}
