package org.neo4j.gds.ml.nodemodels.metrics;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/metrics/MetricSpecification.class */
/**
 * Specification of evaluation metrics, parsed from user-supplied expressions.
 *
 * <p>An expression is either the name of an {@link AllClassMetric} constant, or a single-class
 * metric of the form {@code NAME(class=<id>)} / {@code NAME(class=*)}. Here {@code NAME} is a key of
 * {@link #SINGLE_CLASS_METRIC_FACTORIES}. Input is upper-cased before matching, and whitespace is
 * allowed inside the brackets.
 */
public interface MetricSpecification {

    /** Regex fragment (one capturing group) matching an optionally negative integer or {@code *}. */
    String NUMBER_OR_STAR = "((?:-?[\\d]+)|(?:\\*))";

    /** Factories for metrics computed for a single class id, keyed by metric name. */
    SortedMap<String, Function<Long, Metric>> SINGLE_CLASS_METRIC_FACTORIES = new TreeMap<>(
        // Explicit type witness: the lambdas need Function<Long, Metric> as their target type.
        Map.<String, Function<Long, Metric>>of(
            F1Score.NAME, classId -> new F1Score(classId),
            Precision.NAME, classId -> new Precision(classId),
            Recall.NAME, classId -> new Recall(classId),
            Accuracy.NAME, classId -> new Accuracy(classId)
        )
    );

    /** Regex alternation of all single-class metric names. */
    String VALID_SINGLE_CLASS_METRICS = String.join("|", SINGLE_CLASS_METRIC_FACTORIES.keySet());

    /** Matches {@code NAME(CLASS=<id or *>)}; group 1 is the metric name, group 2 the class value. */
    Pattern SINGLE_CLASS_METRIC_PATTERN = Pattern.compile(
        "(" + VALID_SINGLE_CLASS_METRICS + ")\\([\\s]*CLASS[\\s]*=[\\s]*" + NUMBER_OR_STAR + "[\\s]*\\)"
    );

    /**
     * Estimates the memory used by the metric instances.
     *
     * <p>An {@link F1Score} instance is used as the representative size of a single metric.
     *
     * @param numberOfClasses maximum number of metric instances to account for (presumably the
     *                        number of classes; verify against callers)
     * @return a range from one metric instance up to {@code numberOfClasses} instances
     */
    static MemoryEstimation memoryEstimation(int numberOfClasses) {
        return MemoryEstimations.builder()
            .rangePerNode("metrics", ignored -> {
                long sizeOfRepresentativeMetric = MemoryUsage.sizeOf(new F1Score(1L));
                return MemoryRange.of(
                    sizeOfRepresentativeMetric,
                    numberOfClasses * sizeOfRepresentativeMetric
                );
            })
            .build();
    }

    /**
     * Builds the canonical string form {@code NAME(class=VALUE)} of a single-class specification.
     */
    static String composeSpecification(String metricName, String classValue) {
        return StringFormatting.formatWithLocale("%s(class=%s)", metricName, classValue);
    }

    /**
     * Creates the metric instances described by this specification.
     *
     * @param classes class ids to expand a {@code class=*} specification over; ignored by
     *                specifications naming a single class or an all-class metric
     */
    Stream<Metric> createMetrics(Collection<Long> classes);

    /** Canonical string representation; also the basis of equality between specifications. */
    String asString();

    /**
     * Parses and validates a list of metric expressions.
     *
     * <p>The first expression is the primary metric and must not contain the {@code *} wildcard.
     * All problems are collected and reported together. Duplicate specifications (equal canonical
     * strings) are removed in the result.
     *
     * @param list user-supplied metric expressions
     * @return the distinct parsed specifications, in input order
     * @throws IllegalArgumentException if the list is empty, if the primary metric uses a wildcard,
     *                                  or if any expression is invalid
     */
    static List<MetricSpecification> parse(List<String> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException(
                StringFormatting.formatWithLocale("No metrics specified, we require at least one")
            );
        }

        List<String> errors = new ArrayList<>();

        String primary = list.get(0).toUpperCase(Locale.ENGLISH);
        if (primary.contains("*")) {
            errors.add(StringFormatting.formatWithLocale(
                "The primary (first) metric provided must be one of %s.",
                String.join(", ", validPrimaryMetricExpressions())
            ));
        }

        List<String> invalid = list.stream()
            .filter(MetricSpecification::invalidSpecification)
            .collect(Collectors.toList());
        if (!invalid.isEmpty()) {
            errors.add(errorMessage(invalid));
        }

        if (!errors.isEmpty()) {
            throw new IllegalArgumentException(String.join(" ", errors));
        }

        return list.stream()
            .map(MetricSpecification::parse)
            .distinct()
            .collect(Collectors.toList());
    }

    /**
     * Parses a single metric expression.
     *
     * @param str the user-supplied expression (case-insensitive)
     * @return the parsed specification
     * @throws IllegalArgumentException if the expression is neither a single-class metric
     *                                  expression nor an {@link AllClassMetric} name, or if the
     *                                  class id does not fit in a {@code long}
     */
    static MetricSpecification parse(String str) {
        String upperCased = StringFormatting.toUpperCaseWithLocale(str);
        Matcher matcher = SINGLE_CLASS_METRIC_PATTERN.matcher(upperCased);
        if (!matcher.matches()) {
            return parseAllClassMetric(str, upperCased);
        }

        String metricName = matcher.group(1);
        String classValue = matcher.group(2);
        // The pattern only admits factory keys, so the lookup cannot miss.
        Function<Long, Metric> factory = SINGLE_CLASS_METRIC_FACTORIES.get(metricName);

        if (classValue.equals("*")) {
            // Wildcard: one metric per class id supplied at creation time.
            return createSpecification(
                classes -> classes.stream().map(factory),
                composeSpecification(metricName, "*")
            );
        }

        long classId = parseClassId(str, classValue);
        // Canonicalize the class value (e.g. "-0" and "007" become "0" and "7").
        return createSpecification(
            ignored -> Stream.of(factory.apply(classId)),
            composeSpecification(metricName, String.valueOf(classId))
        );
    }

    /** Resolves an expression that did not match the single-class pattern as an all-class metric. */
    private static MetricSpecification parseAllClassMetric(String original, String upperCased) {
        AllClassMetric metric;
        try {
            metric = AllClassMetric.valueOf(upperCased);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(errorMessage(List.of(original)), e);
        }
        return createSpecification(ignored -> Stream.of(metric), upperCased);
    }

    /** Parses a regex-validated class id, reporting values outside the {@code long} range clearly. */
    private static long parseClassId(String original, String classValue) {
        try {
            return Long.parseLong(classValue);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                StringFormatting.formatWithLocale(
                    "Invalid class value `%s` in metric expression `%s`: it must fit in a 64-bit signed integer.",
                    classValue,
                    original
                ),
                e
            );
        }
    }

    /**
     * Wraps a metric factory and its canonical string into a specification.
     *
     * <p>Equality and hash code are based solely on {@code str}.
     */
    static MetricSpecification createSpecification(
        Function<Collection<Long>, Stream<Metric>> function,
        String str
    ) {
        return new MetricSpecification() {
            @Override
            public Stream<Metric> createMetrics(Collection<Long> classes) {
                return function.apply(classes);
            }

            @Override
            public String asString() {
                return str;
            }

            @Override
            public String toString() {
                return asString();
            }

            @Override
            public boolean equals(Object other) {
                return other instanceof MetricSpecification
                    && asString().equals(((MetricSpecification) other).asString());
            }

            @Override
            public int hashCode() {
                return asString().hashCode();
            }
        };
    }

    /** All accepted expression shapes, including the {@code class=*} wildcard form. */
    private static List<String> allValidMetricExpressions() {
        return validMetricExpressions(true);
    }

    /** Expression shapes accepted for the primary metric (no wildcard). */
    private static List<String> validPrimaryMetricExpressions() {
        return validMetricExpressions(false);
    }

    /**
     * Lists human-readable expression shapes for error messages.
     *
     * @param includeWildcard whether to include the {@code NAME(class=*)} form
     */
    private static List<String> validMetricExpressions(boolean includeWildcard) {
        List<String> expressions = new ArrayList<>();
        for (AllClassMetric metric : AllClassMetric.values()) {
            expressions.add(metric.name());
        }
        for (String metricName : SINGLE_CLASS_METRIC_FACTORIES.keySet()) {
            if (includeWildcard) {
                expressions.add(metricName + "(class=*)");
            }
            expressions.add(metricName + "(class=<class value>)");
        }
        return expressions;
    }

    /** Maps each specification to its canonical string. */
    static List<String> specificationsToString(List<MetricSpecification> list) {
        return list.stream()
            .map(MetricSpecification::asString)
            .collect(Collectors.toList());
    }

    /**
     * Always throws an {@link IllegalArgumentException} describing {@code str} as invalid.
     */
    static void failSingleSpecification(String str) {
        throw new IllegalArgumentException(errorMessage(List.of(str)));
    }

    /**
     * Builds the error message listing invalid expressions and all accepted expression shapes.
     */
    static String errorMessage(List<String> list) {
        String pluralSuffix = list.size() == 1 ? "" : "s";
        String quotedExpressions = list.stream()
            .map(expression -> "`" + expression + "`")
            .collect(Collectors.joining(", "));
        return StringFormatting.formatWithLocale(
            "Invalid metric expression%s %s. Available metrics are %s (case insensitive and space allowed between brackets).",
            pluralSuffix,
            quotedExpressions,
            String.join(", ", allValidMetricExpressions())
        );
    }

    /**
     * Returns {@code true} if {@code str} (case-insensitively) is neither a single-class metric
     * expression nor the name of an {@link AllClassMetric}.
     */
    static boolean invalidSpecification(String str) {
        String upperCased = str.toUpperCase(Locale.ENGLISH);
        if (SINGLE_CLASS_METRIC_PATTERN.matcher(upperCased).matches()) {
            return false;
        }
        return Arrays.stream(AllClassMetric.values())
            .map(AllClassMetric::name)
            .noneMatch(upperCased::equals);
    }
}
