package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeSerialIndirectMergeSort;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/Splitter.class */
/**
 * Searches a {@link Group} of sample indices for the feature and threshold whose
 * two-way partition yields the lowest combined impurity.
 *
 * <p>{@link #findBestSplit(Group)} writes into the instance-level {@code sortCache}
 * and {@code rightImpurityData} buffers, so overlapping calls on the same instance
 * would overwrite each other's intermediate state.
 */
public class Splitter {
    private final ImpurityCriterion impurityCriterion;
    private final Features features;
    private final FeatureBagger featureBagger;
    private final int minLeafSize;
    /** Scratch buffer handed to the merge sort on every feature pass. */
    private final HugeLongArray sortCache;
    /** Reusable impurity state of the "right" side while candidates are scanned. */
    private final ImpurityCriterion.ImpurityData rightImpurityData;

    /**
     * Creates a splitter with pre-allocated scratch buffers.
     *
     * <p>Decompiler note: access was widened from package-private.
     *
     * @param j                 length of the sort cache that is allocated up front
     * @param impurityCriterion criterion used to build, update and combine impurity data
     * @param featureBagger     supplies the feature indices examined per split search
     * @param features          feature vectors, looked up by sample index
     * @param i                 minimum number of samples each side of a split must keep
     */
    public Splitter(long j, ImpurityCriterion impurityCriterion, FeatureBagger featureBagger, Features features, int i) {
        this.featureBagger = featureBagger;
        this.impurityCriterion = impurityCriterion;
        this.features = features;
        this.minLeafSize = i;
        this.sortCache = HugeLongArray.newArray(j);
        this.rightImpurityData = emptyImpurityData(impurityCriterion);
    }

    /**
     * Estimates the memory used by a splitter.
     *
     * <p>Decompiler note: access was widened from package-private.
     *
     * @param j  element count used for every {@link HugeLongArray} estimate
     * @param j2 size that is counted four times; presumably the footprint of one
     *           impurity-data object (TODO confirm against callers)
     * @return the summed estimate
     */
    public static long memoryEstimation(long j, long j2) {
        long instanceSize = MemoryUsage.sizeOfInstance(Splitter.class);
        // One array of j elements (the constructor allocates one such array as sort cache).
        long singleArraySize = HugeLongArray.memoryEstimation(j);
        long quadrupledSecondArgument = 4 * j2;
        // Four more arrays (findBestSplit allocates four arrays per call).
        long fourArraysSize = 4 * HugeLongArray.memoryEstimation(j);
        return instanceSize + singleArraySize + quadrupledSecondArgument + fourArraysSize;
    }

    /**
     * Scans every sampled feature for the split position with the lowest combined impurity.
     *
     * <p>For each feature the group's sample indices are sorted by that feature's value and
     * then moved one by one from the right side to the left side. A position is only scored
     * once both sides hold at least {@code minLeafSize} entries.
     *
     * <p>If no position is ever scored, the returned split carries the initial sentinel
     * values: feature index {@code -1}, threshold {@code Double.MAX_VALUE} and left size
     * {@code -1}.
     *
     * <p>Decompiler note: access was widened from package-private.
     *
     * @param group the samples to split, together with their impurity data
     * @return the best split found and its two child groups
     */
    public DecisionTreeTrainer.Split findBestSplit(Group group) {
        int bestFeatureIdx = -1;
        double bestThreshold = Double.MAX_VALUE;
        double lowestImpurity = Double.MAX_VALUE;
        long bestLeftSize = -1;

        HugeLongArray leftCandidates = HugeLongArray.newArray(group.size());
        HugeLongArray rightCandidates = HugeLongArray.newArray(group.size());
        HugeLongArray bestLeft = HugeLongArray.newArray(group.size());
        HugeLongArray bestRight = HugeLongArray.newArray(group.size());
        ImpurityCriterion.ImpurityData bestLeftImpurity = emptyImpurityData(this.impurityCriterion);
        ImpurityCriterion.ImpurityData bestRightImpurity = emptyImpurityData(this.impurityCriterion);

        // Everything starts on the right side, in the order given by the group.
        rightCandidates.setAll(offset -> group.array().get(group.startIdx() + offset));
        rightCandidates.copyTo(bestRight, group.size());

        for (int featureIdx : this.featureBagger.sample()) {
            // Order the samples by the current feature.
            HugeSerialIndirectMergeSort.sort(
                rightCandidates,
                group.size(),
                sampleIdx -> this.features.get(sampleIdx)[featureIdx],
                this.sortCache
            );
            group.impurityData().copyTo(this.rightImpurityData);

            // Move the first (minLeafSize - 1) samples to the left without scoring them.
            for (long leftSize = 1; leftSize < this.minLeafSize; leftSize++) {
                long movedIdx = rightCandidates.get(leftSize - 1);
                leftCandidates.set(leftSize - 1, movedIdx);
                this.impurityCriterion.decrementalImpurity(movedIdx, this.rightImpurityData);
            }
            ImpurityCriterion.ImpurityData leftImpurity =
                this.impurityCriterion.groupImpurity(leftCandidates, 0L, this.minLeafSize - 1);

            // Score every position where both sides hold at least minLeafSize samples.
            boolean improved = false;
            for (long leftSize = this.minLeafSize; leftSize <= group.size() - this.minLeafSize; leftSize++) {
                long movedIdx = rightCandidates.get(leftSize - 1);
                leftCandidates.set(leftSize - 1, movedIdx);
                this.impurityCriterion.incrementalImpurity(movedIdx, leftImpurity);
                this.impurityCriterion.decrementalImpurity(movedIdx, this.rightImpurityData);

                double candidateImpurity =
                    this.impurityCriterion.combinedImpurity(leftImpurity, this.rightImpurityData);
                // Written as !(a < b) so that a NaN impurity is skipped, exactly like a failed "<" test.
                if (!(candidateImpurity < lowestImpurity)) {
                    continue;
                }

                improved = true;
                bestFeatureIdx = featureIdx;
                bestThreshold = this.features.get(movedIdx)[featureIdx];
                lowestImpurity = candidateImpurity;
                bestLeftSize = leftSize;
                leftImpurity.copyTo(bestLeftImpurity);
                this.rightImpurityData.copyTo(bestRightImpurity);
            }

            if (improved) {
                // Keep this feature's ordering as the best one by swapping working and best
                // arrays; the previous best arrays become the working arrays of the next pass.
                HugeLongArray previousBestRight = bestRight;
                bestRight = rightCandidates;
                rightCandidates = previousBestRight;

                HugeLongArray previousBestLeft = bestLeft;
                bestLeft = leftCandidates;
                leftCandidates = previousBestLeft;
            }
        }

        return ImmutableSplit.of(
            bestFeatureIdx,
            bestThreshold,
            ImmutableGroups.of(
                ImmutableGroup.of(bestLeft, 0L, bestLeftSize, bestLeftImpurity),
                ImmutableGroup.of(bestRight, bestLeftSize, group.size() - bestLeftSize, bestRightImpurity)
            )
        );
    }

    /** Builds impurity data over an empty array, used as a reusable copy target. */
    private static ImpurityCriterion.ImpurityData emptyImpurityData(ImpurityCriterion criterion) {
        return criterion.groupImpurity(HugeLongArray.of(new long[0]), 0L, 0L);
    }
}
