package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

/**
 * Entropy-based impurity criterion for decision tree classification.
 *
 * <p>The impurity of a group of training examples is the Shannon entropy of the group's class
 * distribution, reported in bits: {@code H = -sum_k p_k * log2(p_k)}, where {@code p_k} is the
 * fraction of the group's examples whose mapped label is class {@code k}. Besides computing a
 * group's impurity from scratch ({@link #groupImpurity}), this criterion can update an existing
 * {@link EntropyImpurityData} in constant time when a single example is added to or removed from
 * the group ({@link #incrementalImpurity}, {@link #decrementalImpurity}).
 */
public class Entropy implements ImpurityCriterion {

    // Natural logarithm of 2; dividing a value in nats by this converts it to bits.
    private static final double LN_2 = Math.log(2.0d);

    // Indexed by training example id. Each value is used directly as an index into a class-count
    // array of length numberOfClasses, so it must lie in [0, numberOfClasses).
    private final HugeIntArray expectedMappedLabels;
    private final int numberOfClasses;

    /**
     * Mutable per-group state used by {@link Entropy}: the group's entropy (in bits), the number of
     * examples per class and the total number of examples in the group.
     */
    public static class EntropyImpurityData implements ImpurityCriterion.ImpurityData {
        private double impurity;
        private final long[] classCounts;
        private long groupSize;

        EntropyImpurityData(double impurity, long[] classCounts, long groupSize) {
            this.impurity = impurity;
            this.classCounts = classCounts;
            this.groupSize = groupSize;
        }

        /**
         * Estimates the memory footprint of one instance, as computed by {@link MemoryUsage}.
         *
         * @param numberOfClasses length of the class-count array
         * @return the estimated size of the instance plus its class-count array
         */
        public static long memoryEstimation(int numberOfClasses) {
            return MemoryUsage.sizeOfInstance(EntropyImpurityData.class)
                   + MemoryUsage.sizeOfLongArray(numberOfClasses);
        }

        /**
         * Copies impurity, group size and class counts from this instance into {@code impurityData}.
         *
         * <p>The target must be an {@code EntropyImpurityData} (otherwise a
         * {@link ClassCastException} is thrown) whose class-count array is at least as long as this
         * one's (otherwise {@link System#arraycopy} throws an {@link IndexOutOfBoundsException}).
         * The target's array is overwritten in place rather than replaced.
         */
        @Override
        public void copyTo(ImpurityCriterion.ImpurityData impurityData) {
            EntropyImpurityData target = (EntropyImpurityData) impurityData;
            target.setImpurity(impurity());
            target.setGroupSize(groupSize());
            System.arraycopy(classCounts(), 0, target.classCounts(), 0, classCounts().length);
        }

        /** Returns the entropy of the group in bits. */
        @Override
        public double impurity() {
            return this.impurity;
        }

        public void setImpurity(double impurity) {
            this.impurity = impurity;
        }

        /** Returns the live (not copied) per-class example counts of the group. */
        public long[] classCounts() {
            return this.classCounts;
        }

        /** Returns the number of examples in the group. */
        @Override
        public long groupSize() {
            return this.groupSize;
        }

        public void setGroupSize(long groupSize) {
            this.groupSize = groupSize;
        }
    }

    /**
     * @param expectedMappedLabels class index of every training example, in [0, numberOfClasses)
     * @param numberOfClasses      number of distinct classes
     */
    public Entropy(HugeIntArray expectedMappedLabels, int numberOfClasses) {
        this.expectedMappedLabels = expectedMappedLabels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Creates an entropy criterion from original labels by mapping each one through
     * {@code classIdMap}.
     *
     * @param originalLabels label of every training example; must be non-empty (checked by an
     *                       {@code assert}, so only when assertions are enabled)
     * @param classIdMap     maps original labels to class indices; its size is taken as the number
     *                       of classes
     * @return a criterion over the mapped labels
     */
    public static Entropy fromOriginalLabels(HugeLongArray originalLabels, LocalIdMap classIdMap) {
        assert originalLabels.size() > 0;

        HugeIntArray mappedLabels = HugeIntArray.newArray(originalLabels.size());
        mappedLabels.setAll(i -> classIdMap.toMapped(originalLabels.get(i)));
        return new Entropy(mappedLabels, classIdMap.size());
    }

    /**
     * Estimates the memory needed by an {@code Entropy} instance.
     *
     * @param numberOfTrainingSamples length of the mapped label array
     * @return the estimated memory range of the label array plus the instance itself
     */
    public static MemoryRange memoryEstimation(long numberOfTrainingSamples) {
        return MemoryRange.of(HugeIntArray.memoryEstimation(numberOfTrainingSamples))
            .add(MemoryRange.of(MemoryUsage.sizeOfInstance(Entropy.class)));
    }

    /**
     * Computes from scratch the entropy of the examples stored at positions
     * {@code [startIndex, startIndex + size)} of {@code group}.
     *
     * @param group      example ids; each id indexes the mapped label array
     * @param startIndex first position in {@code group} that belongs to the group
     * @param size       number of consecutive positions that belong to the group
     * @return new impurity data holding the entropy in bits (0 for an empty group), the per-class
     *     counts and {@code size} as the group size
     */
    @Override
    public EntropyImpurityData groupImpurity(HugeLongArray group, long startIndex, long size) {
        long[] classCounts = new long[numberOfClasses];
        if (size == 0) {
            return new EntropyImpurityData(0.0d, classCounts, size);
        }

        // The range ends at startIndex + size (not size) so that groups which do not start at
        // position 0 are counted over exactly `size` elements, matching the divisor below.
        long endIndex = startIndex + size;
        for (long position = startIndex; position < endIndex; position++) {
            classCounts[expectedMappedLabels.get(group.get(position))]++;
        }

        double entropyNats = 0.0d;
        for (long count : classCounts) {
            if (count > 0) {
                // Cast before dividing: long / long would truncate every proportion below 1 to 0.
                double proportion = (double) count / size;
                entropyNats -= proportion * Math.log(proportion);
            }
        }
        return new EntropyImpurityData(entropyNats / LN_2, classCounts, size);
    }

    /**
     * Updates {@code impurityData} in place as if the example {@code exampleIndex} were added to the
     * group.
     *
     * @param exampleIndex index into the mapped label array
     * @param impurityData must be an {@link EntropyImpurityData}
     */
    @Override
    public void incrementalImpurity(long exampleIndex, ImpurityCriterion.ImpurityData impurityData) {
        EntropyImpurityData data = (EntropyImpurityData) impurityData;
        int label = expectedMappedLabels.get(exampleIndex);
        updateImpurityData(label, data.groupSize() + 1, data.classCounts()[label] + 1, data);
    }

    /**
     * Updates {@code impurityData} in place as if the example {@code exampleIndex} were removed from
     * the group. The example's class must currently have a positive count.
     *
     * @param exampleIndex index into the mapped label array
     * @param impurityData must be an {@link EntropyImpurityData}
     */
    @Override
    public void decrementalImpurity(long exampleIndex, ImpurityCriterion.ImpurityData impurityData) {
        EntropyImpurityData data = (EntropyImpurityData) impurityData;
        int label = expectedMappedLabels.get(exampleIndex);
        assert data.classCounts()[label] > 0 : "cannot remove an example of a class with count 0";
        updateImpurityData(label, data.groupSize() - 1, data.classCounts()[label] - 1, data);
    }

    /**
     * Recomputes the entropy after the count of one class changed, without touching other classes.
     *
     * <p>With N examples and class counts c_k, the entropy in nats is
     * {@code H = ln(N) - (1/N) * sum_k c_k * ln(c_k)}. Hence {@code -sum_k c_k * ln(c_k)} can be
     * recovered from the old entropy as {@code (H - ln(N)) * N}. Then the old term of the changed
     * class is swapped for its new term, and the result is turned back into an entropy for the new
     * group size.
     *
     * @param label         class whose count changed
     * @param newGroupSize  group size after the change
     * @param newClassCount count of {@code label} after the change
     * @param data          state to update; holds the values before the change on entry
     */
    private static void updateImpurityData(
        int label,
        long newGroupSize,
        long newClassCount,
        EntropyImpurityData data
    ) {
        long oldGroupSize = data.groupSize();
        long oldClassCount = data.classCounts()[label];

        double newImpurity = 0.0d;
        if (newGroupSize > 0) {
            // Old entropy in nats, then converted to -sum_k c_k * ln(c_k) for the old counts.
            double negatedTermSum = data.impurity() * LN_2;
            if (oldGroupSize > 0) {
                negatedTermSum = (negatedTermSum - Math.log(oldGroupSize)) * oldGroupSize;
            }
            // Swap the changed class's term; 0 * ln(0) is treated as 0 by skipping it.
            if (oldClassCount > 0) {
                negatedTermSum += oldClassCount * Math.log(oldClassCount);
            }
            if (newClassCount > 0) {
                negatedTermSum -= newClassCount * Math.log(newClassCount);
            }
            newImpurity = ((negatedTermSum / newGroupSize) + Math.log(newGroupSize)) / LN_2;
        }

        data.classCounts()[label] = newClassCount;
        data.setGroupSize(newGroupSize);
        data.setImpurity(newImpurity);
    }
}
