package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/MeanSquaredError.class */
/**
 * Decision-tree split loss based on mean squared error.
 *
 * <p>A candidate split is scored by computing, for each side (left and right) independently,
 * the mean of that side's target values and then the average squared deviation from that mean.
 * The two per-side values are added together.
 */
public class MeanSquaredError implements DecisionTreeLoss {

    // Target value per element, looked up via the element indices stored in each split group.
    private final HugeDoubleArray targets;

    /**
     * @param targets target values, indexed by the element ids that appear in the split groups
     */
    public MeanSquaredError(HugeDoubleArray targets) {
        this.targets = targets;
    }

    /**
     * Memory estimate for this loss object, based on {@code sizeOfInstance} of this class.
     * NOTE(review): presumably a shallow size that does not include the referenced
     * {@code targets} array; confirm against {@code MemoryUsage.sizeOfInstance}.
     */
    public static MemoryRange memoryEstimation() {
        return MemoryRange.of(MemoryUsage.sizeOfInstance(MeanSquaredError.class));
    }

    /**
     * Returns the sum of the left group's MSE and the right group's MSE. Each MSE is taken
     * around that group's own mean. An empty group contributes {@code 0.0}.
     *
     * <p>NOTE(review): the two group errors are summed without weighting by group size, so a
     * very small group counts as much as a large one. Confirm this is intended before relying
     * on it as a standard CART-style criterion.
     */
    @Override
    public double splitLoss(Groups groups, GroupSizes groupSizes) {
        HugeLongArray left = groups.left();
        long leftSize = groupSizes.left();
        double leftError = groupMeanSquaredError(left, leftSize, groupMean(left, leftSize));

        HugeLongArray right = groups.right();
        long rightSize = groupSizes.right();
        double rightError = groupMeanSquaredError(right, rightSize, groupMean(right, rightSize));

        return leftError + rightError;
    }

    /**
     * Arithmetic mean of the targets of the first {@code size} elements of {@code group}.
     *
     * @param group element indices into {@code targets}; only positions {@code [0, size)} are read
     * @param size  number of valid entries in {@code group}
     * @return the mean, or {@code 0.0} when {@code size == 0}
     */
    private double groupMean(HugeLongArray group, long size) {
        if (size == 0) {
            // Empty group: avoid 0/0 (NaN) and treat it as contributing nothing.
            return 0.0;
        }
        double sum = 0.0;
        for (long i = 0; i < size; i++) {
            sum += targets.get(group.get(i));
        }
        return sum / size;
    }

    /**
     * Mean of squared deviations from {@code mean} over the targets of the first {@code size}
     * elements of {@code group}.
     *
     * @param group element indices into {@code targets}; only positions {@code [0, size)} are read
     * @param size  number of valid entries in {@code group}
     * @param mean  the value deviations are measured from (normally this group's mean)
     * @return the mean squared error, or {@code 0.0} when {@code size == 0}
     */
    private double groupMeanSquaredError(HugeLongArray group, long size, double mean) {
        if (size == 0) {
            // Empty group: avoid 0/0 (NaN) and treat it as contributing nothing.
            return 0.0;
        }
        double squaredErrorSum = 0.0;
        for (long i = 0; i < size; i++) {
            double error = mean - targets.get(group.get(i));
            squaredErrorSum += error * error;
        }
        return squaredErrorSum / size;
    }
}
