package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;

/* loaded from: input_file:org/neo4j/gds/ml/decisiontree/SplitMeanSquaredError.class */
/**
 * Impurity criterion that scores a group of training examples by the mean squared error of
 * their target values around the group mean, computed as
 * {@code mean(target^2) - mean(target)^2}.
 *
 * <p>Besides computing the impurity of a whole group from scratch, the criterion can update
 * an existing {@link MSEImpurityData} when a single example is added to or removed from the
 * group, using the running sum and sum of squares stored in that object.
 */
public class SplitMeanSquaredError implements ImpurityCriterion {
    /** Target value per example; indexed by the example ids passed to this criterion. */
    private final HugeDoubleArray targets;

    /**
     * Mutable running statistics of one group: its impurity, the sum and the sum of squares
     * of its target values, and the number of examples it contains.
     */
    public static class MSEImpurityData implements ImpurityCriterion.ImpurityData {
        private double impurity;
        private double sumOfSquares;
        private double sum;
        private long groupSize;

        /**
         * @param impurity     impurity of the group
         * @param sumOfSquares sum of the squared target values of the group
         * @param sum          sum of the target values of the group
         * @param groupSize    number of examples in the group
         */
        MSEImpurityData(double impurity, double sumOfSquares, double sum, long groupSize) {
            this.impurity = impurity;
            this.sumOfSquares = sumOfSquares;
            this.sum = sum;
            this.groupSize = groupSize;
        }

        /**
         * @return the value reported by {@code MemoryUsage.sizeOfInstance} for this class
         */
        public static long memoryEstimation() {
            return MemoryUsage.sizeOfInstance(MSEImpurityData.class);
        }

        @Override
        public double impurity() {
            return this.impurity;
        }

        @Override
        public long groupSize() {
            return this.groupSize;
        }

        /**
         * Copies all four statistics of this object into {@code impurityData}.
         *
         * @param impurityData destination; must be an {@code MSEImpurityData}
         * @throws ClassCastException if {@code impurityData} is of a different type
         */
        @Override
        public void copyTo(ImpurityCriterion.ImpurityData impurityData) {
            MSEImpurityData target = (MSEImpurityData) impurityData;
            target.setImpurity(impurity());
            target.setSumOfSquares(sumOfSquares());
            target.setSum(sum());
            target.setGroupSize(groupSize());
        }

        public void setGroupSize(long groupSize) {
            this.groupSize = groupSize;
        }

        public void setSum(double sum) {
            this.sum = sum;
        }

        public void setSumOfSquares(double sumOfSquares) {
            this.sumOfSquares = sumOfSquares;
        }

        public double sum() {
            return this.sum;
        }

        public double sumOfSquares() {
            return this.sumOfSquares;
        }

        public void setImpurity(double impurity) {
            this.impurity = impurity;
        }
    }

    /**
     * @param hugeDoubleArray target values, looked up by example id
     */
    public SplitMeanSquaredError(HugeDoubleArray hugeDoubleArray) {
        this.targets = hugeDoubleArray;
    }

    /**
     * @return a memory range built from the instance size of this class
     *         (the target array itself is not included)
     */
    public static MemoryRange memoryEstimation() {
        return MemoryRange.of(MemoryUsage.sizeOfInstance(SplitMeanSquaredError.class));
    }

    /**
     * Computes the impurity statistics of the group formed by the example ids stored in
     * {@code group[startIdx .. startIdx + size)}.
     *
     * @param group    array holding example ids
     * @param startIdx position in {@code group} of the first example of the group
     * @param size     number of consecutive entries of {@code group} that belong to the group
     * @return freshly allocated statistics for the group; for {@code size <= 0} an object
     *         with group size 0 is returned without reading {@code group}
     */
    @Override
    public MSEImpurityData groupImpurity(HugeLongArray group, long startIdx, long size) {
        if (size <= 0) {
            // NOTE(review): NegativeSampler.NEGATIVE looks like a decompiler substitution for a
            // plain numeric zero; its value is not visible here, so it is kept unchanged.
            return new MSEImpurityData(NegativeSampler.NEGATIVE, NegativeSampler.NEGATIVE, NegativeSampler.NEGATIVE, 0L);
        }

        double sum = 0.0d;
        double sumOfSquares = 0.0d;
        // The group spans `size` entries starting at `startIdx`. Bounding the loop by `size`
        // alone would skip entries (or all of them) whenever startIdx > 0 while still
        // dividing by `size` below.
        long endIdx = startIdx + size;
        for (long i = startIdx; i < endIdx; i++) {
            double value = this.targets.get(group.get(i));
            sum += value;
            sumOfSquares += value * value;
        }

        return new MSEImpurityData(meanSquaredError(sum, sumOfSquares, size), sumOfSquares, sum, size);
    }

    /**
     * Updates {@code impurityData} in place as if the example {@code featureVectorIdx} had
     * been added to its group.
     *
     * @param featureVectorIdx id of the example being added; used to look up its target
     * @param impurityData     statistics to update; must be an {@code MSEImpurityData}
     */
    @Override
    public void incrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        MSEImpurityData data = (MSEImpurityData) impurityData;
        double value = this.targets.get(featureVectorIdx);
        updateImpurityData(
            data.sum() + value,
            data.sumOfSquares() + (value * value),
            data.groupSize() + 1,
            data
        );
    }

    /**
     * Updates {@code impurityData} in place as if the example {@code featureVectorIdx} had
     * been removed from its group.
     *
     * <p>Removing the last example leaves a group size of 0; the impurity is then computed
     * with a division by zero and is not a finite number.
     *
     * @param featureVectorIdx id of the example being removed; used to look up its target
     * @param impurityData     statistics to update; must be an {@code MSEImpurityData}
     */
    @Override
    public void decrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        MSEImpurityData data = (MSEImpurityData) impurityData;
        double value = this.targets.get(featureVectorIdx);
        updateImpurityData(
            data.sum() - value,
            data.sumOfSquares() - (value * value),
            data.groupSize() - 1,
            data
        );
    }

    /** Stores the given running statistics and the impurity derived from them. */
    private static void updateImpurityData(double sum, double sumOfSquares, long groupSize, MSEImpurityData impurityData) {
        impurityData.setImpurity(meanSquaredError(sum, sumOfSquares, groupSize));
        impurityData.setSum(sum);
        impurityData.setSumOfSquares(sumOfSquares);
        impurityData.setGroupSize(groupSize);
    }

    /** Returns {@code sumOfSquares / groupSize - (sum / groupSize)^2}. */
    private static double meanSquaredError(double sum, double sumOfSquares, long groupSize) {
        double mean = sum / groupSize;
        return (sumOfSquares / groupSize) - (mean * mean);
    }
}
