package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.decisiontree.SplitMeanSquaredError;
import org.neo4j.gds.ml.models.Features;

/**
 * Decision tree trainer for regression: leaves hold a {@link Double} prediction.
 *
 * <p>The leaf value for a group of training samples is the arithmetic mean of their targets
 * (see {@link #toTerminal(Group)}). The impurity criterion, configuration and feature bagger are
 * handed to {@link DecisionTreeTrainer}.
 */
public class DecisionTreeRegressorTrainer extends DecisionTreeTrainer<Double> {

    // One regression target per training sample; looked up by the sample ids stored in a Group.
    private final HugeDoubleArray targets;

    /**
     * Creates a regression tree trainer.
     *
     * @param impurityCriterion criterion forwarded to {@link DecisionTreeTrainer}
     * @param features          training features
     * @param targets           regression targets; must have the same size as {@code features}
     *                          (checked only when assertions are enabled)
     * @param config            tree training configuration forwarded to {@link DecisionTreeTrainer}
     * @param featureBagger     feature bagger forwarded to {@link DecisionTreeTrainer}
     */
    public DecisionTreeRegressorTrainer(
        ImpurityCriterion impurityCriterion,
        Features features,
        HugeDoubleArray targets,
        DecisionTreeTrainerConfig config,
        FeatureBagger featureBagger
    ) {
        super(features, config, impurityCriterion, featureBagger);
        assert targets.size() == features.size()
            : "targets size " + targets.size() + " does not match features size " + features.size();
        this.targets = targets;
    }

    /**
     * Estimates the memory required by this trainer and the tree it builds.
     *
     * @param config                  tree training configuration
     * @param numberOfTrainingSamples sample count forwarded to {@link DecisionTreeTrainer#estimateTree}
     *                                (presumably the number of training samples — confirm against callers)
     * @return the shallow size of a trainer instance plus the tree estimate computed with
     *     {@code Double} leaves and mean-squared-error impurity data
     */
    public static MemoryRange memoryEstimation(DecisionTreeTrainerConfig config, long numberOfTrainingSamples) {
        return MemoryRange.of(Estimate.sizeOfInstance(DecisionTreeRegressorTrainer.class))
            .add(DecisionTreeTrainer.estimateTree(
                config,
                numberOfTrainingSamples,
                TreeNode.leafMemoryEstimation(Double.class),
                SplitMeanSquaredError.MSEImpurityData.memoryEstimation()
            ));
    }

    /**
     * Computes the leaf prediction for {@code group} as the mean target of its samples.
     *
     * <p>The samples are the ids stored in {@code group.array()} at positions
     * {@code [group.startIdx(), group.startIdx() + group.size())}. An empty group yields {@code NaN}
     * (0.0 divided by 0).
     *
     * @param group the samples that reach this leaf
     * @return the arithmetic mean of the targets of those samples
     */
    @Override
    public Double toTerminal(Group group) {
        HugeLongArray sampleIds = group.array();
        long start = group.startIdx();
        long end = start + group.size();

        double sum = 0.0d;
        for (long i = start; i < end; i++) {
            sum += targets.get(sampleIds.get(i));
        }
        return sum / group.size();
    }
}
