package org.neo4j.gds.ml.nodemodels.metrics;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.utils.StringFormatting;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/metrics/AllClassMetric.class */
/**
 * Classification metrics that are computed over all classes at once.
 *
 * <p>Every constant delegates to a {@link MetricStrategy}. A strategy receives the target
 * labels and the predicted labels, which are compared position by position, together with a
 * multiset of class counts whose keys are the classes to evaluate.
 */
public enum AllClassMetric implements Metric {
    /**
     * Weighted F1. Each class in {@code globalClassCounts} contributes its F1 score multiplied
     * by its count; the sum is divided by {@code globalClassCounts.size()}.
     * Returns {@code 0.0} when the multiset is empty.
     */
    F1_WEIGHTED(AllClassMetric::f1Weighted),

    /**
     * Macro F1: the unweighted mean of the per-class F1 scores over the keys of
     * {@code globalClassCounts}. Returns {@code -1.0} when there are no classes.
     */
    F1_MACRO(AllClassMetric::f1Macro),

    /**
     * Accuracy: the fraction of positions where the prediction equals the target. It is
     * computed exactly with {@link BigDecimal}, rounded up to 8 decimal places. Returns
     * {@code 0.0} for empty input. The class counts are not used.
     */
    ACCURACY(AllClassMetric::accuracy);

    /** Score reported by F1_WEIGHTED and ACCURACY when there is nothing to score. */
    private static final double EMPTY_SCORE = 0.0;

    /** Sentinel reported by F1_MACRO when there are no classes to average over. */
    private static final double F1_MACRO_NO_CLASSES = -1.0;

    /** Number of decimal places kept when dividing in the accuracy computation. */
    private static final int ACCURACY_SCALE = 8;

    private final MetricStrategy strategy;

    /** Computes a single metric value from targets, predictions and class counts. */
    @FunctionalInterface
    interface MetricStrategy {
        double compute(HugeLongArray targets, HugeLongArray predictions, Multiset<Long> globalClassCounts);
    }

    AllClassMetric(MetricStrategy metricStrategy) {
        this.strategy = metricStrategy;
    }

    /**
     * Evaluates this metric.
     *
     * @param targets           the expected class of each element
     * @param predictions       the predicted class of each element, aligned index-wise with {@code targets}
     * @param globalClassCounts class counts; its keys are the classes evaluated by the F1 variants
     * @return the metric value
     */
    @Override
    public double compute(HugeLongArray targets, HugeLongArray predictions, Multiset<Long> globalClassCounts) {
        return this.strategy.compute(targets, predictions, globalClassCounts);
    }

    /** Count-weighted sum of per-class F1 scores, divided by the multiset size. */
    private static double f1Weighted(
        HugeLongArray targets,
        HugeLongArray predictions,
        Multiset<Long> globalClassCounts
    ) {
        if (globalClassCounts.size() == 0) {
            return EMPTY_SCORE;
        }
        double weightedSum = globalClassCounts.keys().stream()
            .mapToDouble(classId ->
                globalClassCounts.count(classId) * new F1Score(classId).compute(targets, predictions))
            .sum();
        return weightedSum / globalClassCounts.size();
    }

    /** Plain average of per-class F1 scores, or -1.0 when there are no classes. */
    private static double f1Macro(
        HugeLongArray targets,
        HugeLongArray predictions,
        Multiset<Long> globalClassCounts
    ) {
        return globalClassCounts.keys().stream()
            .mapToDouble(classId -> new F1Score(classId).compute(targets, predictions))
            .average()
            .orElse(F1_MACRO_NO_CLASSES);
    }

    /** Fraction of matching positions, rounded up to {@link #ACCURACY_SCALE} decimals. */
    private static double accuracy(
        HugeLongArray targets,
        HugeLongArray predictions,
        Multiset<Long> globalClassCounts
    ) {
        // The message is only built when the assertion fails (java -ea).
        assert targets.size() == predictions.size() : StringFormatting.formatWithLocale(
            "Metrics require equal length targets and predictions. Sizes are %d and %d respectively.",
            targets.size(),
            predictions.size()
        );

        long size = targets.size();
        if (size == 0) {
            return EMPTY_SCORE;
        }

        long correct = 0;
        for (long i = 0; i < size; i++) {
            if (predictions.get(i) == targets.get(i)) {
                correct++;
            }
        }

        // BigDecimal keeps the division exact before the final rounding, as before.
        return BigDecimal.valueOf(correct)
            .divide(BigDecimal.valueOf(size), ACCURACY_SCALE, RoundingMode.UP)
            .doubleValue();
    }
}
