package org.neo4j.gds.ml.models.randomforest;

import com.carrotsearch.hppc.ObjectArrayList;
import java.util.List;
import java.util.function.LongUnaryOperator;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer;
import org.neo4j.gds.ml.decisiontree.TreeNode;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.ml.models.TrainingMethod;

/**
 * Model data of a trained random forest regressor: the collection of decision trees whose
 * {@code Double}-valued predictions make up the model.
 */
@ValueClass
public interface RandomForestRegressorData extends Regressor.RegressorData {

    /**
     * Returns the trained decision trees of the forest.
     *
     * @return the forest's decision tree predictors, each producing {@code Double} values
     */
    List<DecisionTreePredictor<Double>> decisionTrees();

    /**
     * Identifies the training method that produced this model data.
     * <p>
     * This is regressor data, so it reports {@link TrainingMethod#RandomForestRegression}.
     * It previously reported {@code RandomForestClassification}, which mislabelled
     * random forest regression models as classification models.
     *
     * @return {@link TrainingMethod#RandomForestRegression}
     */
    @Override
    @Value.Derived
    default TrainingMethod trainerMethod() {
        return TrainingMethod.RandomForestRegression;
    }

    /**
     * Builds the memory estimation for the model data of a random forest regressor.
     * <p>
     * The estimate is computed per node count. It adds the size of one {@code ObjectArrayList}
     * instance (the container of trees) to the estimated size of a single decision tree
     * multiplied by {@code config.numberOfDecisionTrees()}.
     *
     * @param numberOfTrainingSamples maps the node count to the value passed to
     *                                {@code DecisionTreeTrainer.estimateTree} as its sample count;
     *                                presumably the number of training examples — TODO confirm
     *                                against callers
     * @param config                  trainer configuration; supplies the tree settings and the
     *                                number of decision trees
     * @return a memory estimation named "Random forest model data"
     */
    static MemoryEstimation memoryEstimation(
        LongUnaryOperator numberOfTrainingSamples,
        RandomForestTrainerConfig config
    ) {
        return MemoryEstimations.builder("Random forest model data")
            .rangePerNode(
                "Decision trees",
                nodeCount -> MemoryRange.of(MemoryUsage.sizeOfInstance(ObjectArrayList.class))
                    .add(
                        DecisionTreeTrainer.estimateTree(
                            config,
                            numberOfTrainingSamples.applyAsLong(nodeCount),
                            TreeNode.leafMemoryEstimation(Double.class)
                        ).times(config.numberOfDecisionTrees())
                    )
            )
            .build();
    }
}
