package org.neo4j.gds.ml.models;

import java.util.Optional;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainConfig;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainer;
import org.neo4j.gds.ml.models.randomforest.RandomForestTrainerConfig;

/* loaded from: input_file:org/neo4j/gds/ml/models/ClassifierTrainerFactory.class */
/**
 * Static factory that builds a {@link ClassifierTrainer}, or its {@link MemoryEstimation}, for the
 * training method named by a {@link TrainerConfig}.
 *
 * <p>The training method is resolved with {@code TrainingMethod.valueOf(trainerConfig.methodName())}.
 * The config is then cast to the config subtype that matches that method. This class is not
 * instantiable.
 */
public final class ClassifierTrainerFactory {
    private ClassifierTrainerFactory() {
    }

    /**
     * Creates a trainer for the training method named by {@code trainerConfig}.
     *
     * @param trainerConfig   configuration whose {@code methodName()} selects the trainer; it is cast to
     *                        {@link LogisticRegressionTrainConfig} or {@link RandomForestTrainerConfig}
     * @param localIdMap      id map forwarded to the selected trainer
     * @param terminationFlag termination flag forwarded to the selected trainer
     * @param progressTracker progress tracker forwarded to the selected trainer
     * @param i               first constructor argument of both trainers (NOTE(review): presumably the
     *                        concurrency level; confirm against the trainer constructors)
     * @param optional        forwarded only to {@link RandomForestClassifierTrainer}
     * @param z               boolean forwarded only to {@link LogisticRegressionTrainer}
     * @return a newly constructed trainer
     * @throws IllegalArgumentException if {@code methodName()} is not a {@code TrainingMethod} constant
     * @throws IllegalStateException    if the resolved training method has no trainer here
     * @throws ClassCastException       if the config's runtime type does not match its method name
     */
    public static ClassifierTrainer create(
        TrainerConfig trainerConfig,
        LocalIdMap localIdMap,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker,
        int i,
        Optional<Long> optional,
        boolean z
    ) {
        TrainingMethod method = trainingMethodOf(trainerConfig);
        switch (method) {
            case LogisticRegression:
                return new LogisticRegressionTrainer(
                    i,
                    (LogisticRegressionTrainConfig) trainerConfig,
                    localIdMap,
                    z,
                    terminationFlag,
                    progressTracker
                );
            case RandomForest:
                // NOTE(review): the fourth argument is hard-coded to false; its meaning is not visible
                // from this file. Confirm against RandomForestClassifierTrainer's constructor.
                return new RandomForestClassifierTrainer(
                    i,
                    localIdMap,
                    (RandomForestTrainerConfig) trainerConfig,
                    false,
                    optional,
                    progressTracker,
                    terminationFlag
                );
            default:
                throw unsupported(method);
        }
    }

    /**
     * Returns the memory estimation of the trainer selected by {@code trainerConfig}.
     *
     * @param trainerConfig      configuration whose {@code methodName()} selects the trainer
     * @param longUnaryOperator  operator forwarded to the selected trainer's {@code memoryEstimation}
     * @param i                  int forwarded to the selected trainer's {@code memoryEstimation}
     * @param memoryRange        range forwarded to the selected trainer's {@code memoryEstimation}
     * @param z                  boolean forwarded only to {@link LogisticRegressionTrainer#memoryEstimation}
     * @return the estimation reported by the selected trainer
     * @throws IllegalArgumentException if {@code methodName()} is not a {@code TrainingMethod} constant
     * @throws IllegalStateException    if the resolved training method has no estimation here
     * @throws ClassCastException       if the config's runtime type does not match its method name
     */
    public static MemoryEstimation memoryEstimation(
        TrainerConfig trainerConfig,
        LongUnaryOperator longUnaryOperator,
        int i,
        MemoryRange memoryRange,
        boolean z
    ) {
        TrainingMethod method = trainingMethodOf(trainerConfig);
        switch (method) {
            case LogisticRegression:
                // The logistic regression estimation additionally depends on the configured batch size.
                int batchSize = ((LogisticRegressionTrainConfig) trainerConfig).batchSize();
                return LogisticRegressionTrainer.memoryEstimation(z, i, memoryRange, batchSize, longUnaryOperator);
            case RandomForest:
                return RandomForestClassifierTrainer.memoryEstimation(
                    longUnaryOperator,
                    i,
                    memoryRange,
                    (RandomForestTrainerConfig) trainerConfig
                );
            default:
                throw unsupported(method);
        }
    }

    /** Resolves the {@link TrainingMethod} named by the config; throws IllegalArgumentException if unknown. */
    private static TrainingMethod trainingMethodOf(TrainerConfig trainerConfig) {
        return TrainingMethod.valueOf(trainerConfig.methodName());
    }

    /** Builds the exception for a resolved training method that this factory does not handle. */
    private static IllegalStateException unsupported(TrainingMethod method) {
        return new IllegalStateException("No such training method: " + method);
    }
}
