package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.decisiontree.GiniIndex;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/DecisionTreeClassifierTrainer.class */
/**
 * Decision-tree trainer for classification.
 *
 * <p>Every training sample carries an integer class label. A leaf ("terminal") predicts the class
 * that occurs most often among the samples of the group that reaches it.
 */
public class DecisionTreeClassifierTrainer extends DecisionTreeTrainer<Integer> {

    /**
     * Class label per training sample, looked up by sample id. The constructor asserts that it has
     * the same size as the feature collection. Values are used as indices into a
     * {@code long[numberOfClasses]}, so they must lie in {@code [0, numberOfClasses)}.
     */
    private final HugeIntArray allLabels;

    /** Number of distinct classes; the size of the per-leaf vote buffer in {@link #toTerminal}. */
    private final int numberOfClasses;

    /**
     * Creates a classification trainer.
     *
     * @param impurityCriterion criterion used to score candidate splits
     * @param features          feature vectors of the training samples
     * @param labels            class label per training sample; must have the same size as
     *                          {@code features} (checked with {@code assert}) and hold values in
     *                          {@code [0, numberOfClasses)}
     * @param numberOfClasses   number of distinct class labels
     * @param config            tree training configuration
     * @param featureBagger     strategy choosing which features are considered at each split
     */
    public DecisionTreeClassifierTrainer(
        ImpurityCriterion impurityCriterion,
        Features features,
        HugeIntArray labels,
        int numberOfClasses,
        DecisionTreeTrainerConfig config,
        FeatureBagger featureBagger
    ) {
        super(features, config, impurityCriterion, featureBagger);
        this.numberOfClasses = numberOfClasses;
        // Only checked when assertions are enabled (-ea), same as the decompiled original.
        assert labels.size() == features.size()
            : "Expected one label per feature vector, but got "
              + labels.size() + " labels and " + features.size() + " feature vectors";
        this.allLabels = labels;
    }

    /**
     * Estimates the memory needed to train one classification tree.
     *
     * @param config                  tree training configuration
     * @param numberOfTrainingSamples sample count forwarded to {@link DecisionTreeTrainer#estimateTree}
     * @param numberOfClasses         number of distinct class labels
     * @return memory range covering this instance, the tree itself, and the leaf vote buffer
     */
    public static MemoryRange memoryEstimation(
        DecisionTreeTrainerConfig config,
        long numberOfTrainingSamples,
        int numberOfClasses
    ) {
        // NOTE(review): impurity data is always sized with the Gini estimate, regardless of which
        // ImpurityCriterion is used at training time — confirm this is intended for other criteria.
        return MemoryRange.of(Estimate.sizeOfInstance(DecisionTreeClassifierTrainer.class))
            .add(DecisionTreeTrainer.estimateTree(
                config,
                numberOfTrainingSamples,
                TreeNode.leafMemoryEstimation(Integer.class),
                GiniIndex.GiniImpurityData.memoryEstimation(numberOfClasses)
            ))
            // the long[numberOfClasses] vote buffer allocated by toTerminal
            .add(Estimate.sizeOfLongArray(numberOfClasses));
    }

    /**
     * Computes the prediction of a leaf by majority vote over the labels of the group's samples.
     *
     * <p>Ties are resolved in favour of the lowest class index. An empty group yields class {@code 0}.
     *
     * @param group the samples that reach this leaf: the sample ids stored in {@code group.array()}
     *              at positions {@code [startIdx, startIdx + size)}
     * @return the majority class label
     */
    @Override // org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer
    public Integer toTerminal(Group group) {
        long[] classCounts = new long[numberOfClasses];
        HugeLongArray sampleIds = group.array();

        long start = group.startIdx();
        long end = start + group.size(); // loop-invariant: group is not modified below
        for (long idx = start; idx < end; idx++) {
            classCounts[allLabels.get(sampleIds.get(idx))]++;
        }

        // Strict '>' keeps the first (lowest-index) class among equal counts; starting from -1
        // guarantees class 0 is chosen when every count is zero.
        int majorityClass = 0;
        long maxCount = -1;
        for (int clazz = 0; clazz < classCounts.length; clazz++) {
            if (classCounts[clazz] > maxCount) {
                maxCount = classCounts[clazz];
                majorityClass = clazz;
            }
        }
        return majorityClass;
    }
}
