package org.neo4j.gds.similarity;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.TreeSet;
import java.util.function.Predicate;
import org.neo4j.gds.core.utils.Intersections;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.UserFunction;
import org.neo4j.values.storable.Values;

/* loaded from: input_file:org/neo4j/gds/similarity/SimilaritiesFunc.class */
/**
 * Cypher user functions that compute similarity scores and distances between two numeric vectors.
 *
 * <p>The set-based functions ({@code jaccard}, {@code overlap}) return {@code 0.0} for a {@code null}
 * argument and ignore {@code null} elements. They compare elements by numeric value through
 * {@link NumberComparator}, so, for example, {@code 1} (Long) and {@code 1.0} (Double) count as the
 * same value. They never modify the lists they are given.
 *
 * <p>The vector-based functions ({@code cosine}, {@code pearson}, {@code euclideanDistance},
 * {@code euclidean}) require two non-null, non-empty vectors of equal length. They read
 * {@code null} elements as {@code 0.0}.
 */
public class SimilaritiesFunc {
    private static final Predicate<Number> IS_NULL = Predicate.isEqual(null);
    private static final Comparator<Number> NUMBER_COMPARATOR = new NumberComparator();
    private static final String CATEGORY_KEY = "category";
    private static final String WEIGHT_KEY = "weight";

    /**
     * Orders {@link Number}s by numeric value, independent of their boxed type.
     *
     * <p>Two {@code Long}s are compared exactly. A {@code Long} paired with any other number type is
     * compared through Neo4j storable values (presumably so large longs are not rounded to double
     * first). All other pairs are compared as doubles.
     */
    static class NumberComparator implements Comparator<Number> {
        @Override
        public int compare(Number left, Number right) {
            boolean leftIsLong = left instanceof Long;
            boolean rightIsLong = right instanceof Long;
            if (leftIsLong && rightIsLong) {
                return Long.compare(left.longValue(), right.longValue());
            }
            if (leftIsLong) {
                return Values.longValue(left.longValue()).compareTo(Values.doubleValue(right.doubleValue()));
            }
            if (rightIsLong) {
                return Values.doubleValue(left.doubleValue()).compareTo(Values.longValue(right.longValue()));
            }
            return Double.compare(left.doubleValue(), right.doubleValue());
        }
    }

    /**
     * Jaccard similarity of two collections, counting repeated values with multiplicity.
     *
     * @param list  first collection; {@code null} elements are ignored
     * @param list2 second collection; {@code null} elements are ignored
     * @return {@code 0.0} if either argument is {@code null}; {@code 1.0} if both are empty after
     *     dropping nulls; otherwise |intersection| / |union|
     */
    @UserFunction("gds.similarity.jaccard")
    @Description("RETURN gds.similarity.jaccard(vector1, vector2) - Given two collection vectors, calculate Jaccard similarity")
    public double jaccardSimilarity(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        if (list == null || list2 == null) {
            return 0.0d;
        }
        return jaccard(list, list2);
    }

    /**
     * Cosine similarity of two equal-length vectors, computed by {@code Intersections.cosine}.
     *
     * @throws IllegalArgumentException if a vector is {@code null}, empty, or the lengths differ
     */
    @UserFunction("gds.similarity.cosine")
    @Description("RETURN gds.similarity.cosine(vector1, vector2) - Given two collection vectors, calculate cosine similarity")
    public double cosineSimilarity(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        // Validate before converting, so that bad input produces a descriptive error instead of an NPE.
        int length = validateLength(list, list2);
        return Intersections.cosine(toArray(list), toArray(list2), length);
    }

    /**
     * Pearson similarity of two equal-length vectors, computed by {@code Intersections.pearson}.
     *
     * @throws IllegalArgumentException if a vector is {@code null}, empty, or the lengths differ
     */
    @UserFunction("gds.similarity.pearson")
    @Description("RETURN gds.similarity.pearson(vector1, vector2) - Given two collection vectors, calculate pearson similarity")
    public double pearsonSimilarity(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        int length = validateLength(list, list2);
        return Intersections.pearson(toArray(list), toArray(list2), length);
    }

    /**
     * Euclidean distance between two equal-length vectors: the square root of
     * {@code Intersections.sumSquareDelta}.
     *
     * @throws IllegalArgumentException if a vector is {@code null}, empty, or the lengths differ
     */
    @UserFunction("gds.similarity.euclideanDistance")
    @Description("RETURN gds.similarity.euclideanDistance(vector1, vector2) - Given two collection vectors, calculate the euclidean distance (square root of the sum of the squared differences)")
    public double euclideanDistance(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        int length = validateLength(list, list2);
        return Math.sqrt(Intersections.sumSquareDelta(toArray(list), toArray(list2), length));
    }

    /**
     * Similarity derived from the Euclidean distance {@code d} as {@code 1 / (1 + d)}.
     *
     * @throws IllegalArgumentException if a vector is {@code null}, empty, or the lengths differ
     */
    @UserFunction("gds.similarity.euclidean")
    @Description("RETURN gds.similarity.euclidean(vector1, vector2) - Given two collection vectors, calculate similarity based on euclidean distance")
    public double euclideanSimilarity(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        return 1.0d / (1.0d + euclideanDistance(list, list2));
    }

    /**
     * Overlap similarity: the number of distinct shared values divided by the size of the smaller
     * collection. Values are matched by numeric value through {@link #NUMBER_COMPARATOR}.
     *
     * @param list  first collection; {@code null} elements are ignored
     * @param list2 second collection; {@code null} elements are ignored
     * @return {@code 0.0} if either argument is {@code null} or either collection is empty after
     *     dropping nulls; otherwise |distinct common values| / min(|list|, |list2|)
     */
    @UserFunction("gds.similarity.overlap")
    @Description("RETURN gds.similarity.overlap(vector1, vector2) - Given two collection vectors, calculate overlap similarity")
    public double overlapSimilarity(@Name("vector1") List<Number> list, @Name("vector2") List<Number> list2) {
        // The null check must come before any list access.
        if (list == null || list2 == null) {
            return 0.0d;
        }
        List<Number> left = withoutNulls(list);
        List<Number> right = withoutNulls(list2);
        int smallerSize = Math.min(left.size(), right.size());
        if (smallerSize == 0) {
            return 0.0d;
        }
        TreeSet<Number> common = distinctByValue(left);
        common.retainAll(distinctByValue(right));
        // Use floating-point division; integer division would truncate every partial overlap to 0.
        return (double) common.size() / smallerSize;
    }

    /** Converts {@code list} to a primitive array, reading {@code null} elements as {@code 0.0}. */
    private double[] toArray(List<Number> list) {
        int size = list.size();
        double[] values = new double[size];
        for (int i = 0; i < size; i++) {
            values[i] = getDoubleValue(list.get(i));
        }
        return values;
    }

    /**
     * Checks that both vectors are usable for an element-wise computation.
     *
     * @return the common vector length
     * @throws IllegalArgumentException if a vector is {@code null}, empty, or the lengths differ
     */
    private int validateLength(List<Number> list, List<Number> list2) {
        if (list == null || list2 == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }
        if (list.size() != list2.size() || list.isEmpty()) {
            throw new IllegalArgumentException("Vectors must be non-empty and of the same size");
        }
        return list.size();
    }

    /**
     * Multiset Jaccard similarity: walks both sorted, null-free copies in lockstep, counting matched
     * pairs (intersection) and every consumed element or pair (union).
     */
    private double jaccard(List<Number> list, List<Number> list2) {
        // Work on copies so that the caller's lists are neither filtered nor reordered.
        List<Number> left = withoutNulls(list);
        List<Number> right = withoutNulls(list2);
        left.sort(NUMBER_COMPARATOR);
        right.sort(NUMBER_COMPARATOR);

        int i = 0;
        int j = 0;
        int intersection = 0;
        long union = 0;
        while (i < left.size() && j < right.size()) {
            int cmp = NUMBER_COMPARATOR.compare(left.get(i), right.get(j));
            if (cmp == 0) {
                intersection++;
                i++;
                j++;
            } else if (cmp < 0) {
                i++;
            } else {
                j++;
            }
            union++;
        }
        // Elements left over in either list are unmatched and belong only to the union.
        union += (left.size() - i) + (right.size() - j);
        if (union == 0) {
            // Two empty collections are treated as identical.
            return 1.0d;
        }
        return (double) intersection / union;
    }

    /** Returns a mutable copy of {@code values} without {@code null} elements; the argument is untouched. */
    private static List<Number> withoutNulls(List<Number> values) {
        List<Number> copy = new ArrayList<>(values);
        copy.removeIf(IS_NULL);
        return copy;
    }

    /** Returns the distinct values of {@code values}, where distinctness is decided by {@link #NUMBER_COMPARATOR}. */
    private static TreeSet<Number> distinctByValue(List<Number> values) {
        TreeSet<Number> distinct = new TreeSet<>(NUMBER_COMPARATOR);
        distinct.addAll(values);
        return distinct;
    }

    /** Returns {@code number} as a double, or {@code 0.0} when it is {@code null}. */
    private static double getDoubleValue(Number number) {
        return number == null ? 0.0d : number.doubleValue();
    }
}
