package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.decisiontree.DecisionTreeLoss;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/DecisionTreeClassifierTrainer.class */
/**
 * {@link DecisionTreeTrainer} specialisation whose leaf (terminal) prediction is an
 * {@code Integer}: the mapped index, according to the supplied {@link LocalIdMap}, of the
 * class label that occurs most often among the samples that reach the leaf.
 *
 * @param <LOSS> the loss implementation handed to the base trainer
 */
public class DecisionTreeClassifierTrainer<LOSS extends DecisionTreeLoss> extends DecisionTreeTrainer<LOSS, Integer> {
    /** Label of every sample; indexed by the values stored in the group arrays passed to {@link #toTerminal}. */
    private final HugeLongArray allLabels;
    /** Translates the label values stored in {@link #allLabels} into array indices. */
    private final LocalIdMap classIdMap;

    /**
     * Creates a classifier trainer.
     *
     * @param loss                      loss forwarded to the base trainer
     * @param features                  features forwarded to the base trainer
     * @param hugeLongArray             one label per feature row; with assertions enabled its size
     *                                  must equal {@code features.size()}
     * @param localIdMap                mapping from label value to class index
     * @param decisionTreeTrainerConfig configuration forwarded to the base trainer
     * @param featureBagger             feature bagger forwarded to the base trainer
     */
    public DecisionTreeClassifierTrainer(LOSS loss, Features features, HugeLongArray hugeLongArray, LocalIdMap localIdMap, DecisionTreeTrainerConfig decisionTreeTrainerConfig, FeatureBagger featureBagger) {
        super(features, decisionTreeTrainerConfig, loss, featureBagger);
        this.classIdMap = localIdMap;
        assert hugeLongArray.size() == features.size();
        this.allLabels = hugeLongArray;
    }

    /**
     * Estimates the memory used by this trainer as the sum of: the shallow size of this class,
     * the tree estimate from {@link DecisionTreeTrainer#estimateTree} (using the leaf estimate
     * for {@code Integer} leaves), and one {@code long[]} of length {@code i3}.
     *
     * <p>The first four arguments are forwarded unchanged, in order, to
     * {@code DecisionTreeTrainer.estimateTree}; their meaning is defined there.
     * NOTE(review): {@code i3} is presumably the number of classes, matching the per-class
     * count array allocated in {@link #toTerminal} - confirm against callers.
     *
     * @return the combined memory range
     */
    public static MemoryRange memoryEstimation(int i, int i2, long j, long j2, int i3) {
        MemoryRange treeEstimation = DecisionTreeTrainer.estimateTree(i, i2, j, j2, TreeNode.leafMemoryEstimation(Integer.class));
        return MemoryRange.of(MemoryUsage.sizeOfInstance(DecisionTreeClassifierTrainer.class))
            .add(treeEstimation)
            .add(MemoryUsage.sizeOfLongArray(i3));
    }

    /**
     * Computes the leaf prediction for a group of samples: the class index with the highest count.
     *
     * @param readOnlyHugeLongArray indices into the label array; only the first {@code j}
     *                              entries are read
     * @param j                     number of leading entries to consider; with assertions enabled
     *                              it must be positive and not exceed the array size
     * @return the mapped index of the most frequent class; on a tie the smallest index wins
     */
    @Override // org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer
    public Integer toTerminal(ReadOnlyHugeLongArray readOnlyHugeLongArray, long j) {
        assert j > 0;
        assert readOnlyHugeLongArray.size() >= j;

        // Histogram of class occurrences within the group, indexed by mapped class index.
        final long[] classCounts = new long[this.classIdMap.size()];
        for (long offset = 0; offset < j; offset++) {
            long label = this.allLabels.get(readOnlyHugeLongArray.get(offset));
            classCounts[this.classIdMap.toMapped(label)]++;
        }

        // Arg-max with strict '>' so that the first (lowest) index is kept on ties.
        long highestCount = -1;
        int mostFrequentClassIdx = 0;
        for (int classIdx = 0; classIdx < classCounts.length; classIdx++) {
            if (classCounts[classIdx] > highestCount) {
                highestCount = classCounts[classIdx];
                mostFrequentClassIdx = classIdx;
            }
        }
        return mostFrequentClassIdx;
    }
}
