package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Relu;
import org.neo4j.gds.ml.core.functions.Sigmoid;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.graphalgo.utils.StringFormatting;
import org.neo4j.graphalgo.utils.StringJoining;

/**
 * Activation functions selectable for GraphSAGE layers.
 *
 * <p>Each constant supplies a factory that wraps a {@code Variable<Matrix>} in the matching
 * activation node ({@link Sigmoid} or {@link Relu}), plus a bound used for weight initialization.
 */
public enum ActivationFunction {
    SIGMOID {
        @Override
        public Function<Variable<Matrix>, Variable<Matrix>> activationFunction() {
            return Sigmoid::new;
        }

        /** Returns {@code sqrt(2 / (firstDimension + secondDimension))}. */
        @Override
        public double weightInitBound(int firstDimension, int secondDimension) {
            // Add in double arithmetic so very large dimensions cannot overflow int.
            return Math.sqrt(2.0d / ((double) firstDimension + secondDimension));
        }
    },
    RELU {
        @Override
        public Function<Variable<Matrix>, Variable<Matrix>> activationFunction() {
            return Relu::new;
        }

        /** Returns {@code sqrt(2 / secondDimension)}; {@code firstDimension} is not used. */
        @Override
        public double weightInitBound(int firstDimension, int secondDimension) {
            return Math.sqrt(2.0d / secondDimension);
        }
    };

    /** Names of all constants, used by {@link #parse(Object)} to validate string input. */
    private static final List<String> VALUES = Arrays.stream(values())
        .map(ActivationFunction::name)
        .collect(Collectors.toList());

    /**
     * Returns a factory that wraps an input variable in this activation function.
     *
     * @return function mapping a {@code Variable<Matrix>} to its activated {@code Variable<Matrix>}
     */
    public abstract Function<Variable<Matrix>, Variable<Matrix>> activationFunction();

    /**
     * Returns the bound used when initializing weights for a layer feeding this activation.
     *
     * <p>NOTE(review): which dimension is fan-in vs. fan-out is decided by callers not shown
     * here. Confirm before changing the formulas.
     *
     * @param firstDimension  first weight dimension
     * @param secondDimension second weight dimension
     * @return the initialization bound
     */
    public abstract double weightInitBound(int firstDimension, int secondDimension);

    /**
     * Looks up the constant whose name equals {@code str} after upper-casing it.
     *
     * @param str activation function name, case-insensitive
     * @return the matching constant
     * @throws IllegalArgumentException if no constant has that name (from {@link Enum#valueOf})
     */
    public static ActivationFunction of(String str) {
        return valueOf(StringFormatting.toUpperCaseWithLocale(str));
    }

    /**
     * Converts a user-supplied value into an {@link ActivationFunction}.
     *
     * @param obj either an {@code ActivationFunction} (returned as-is) or a case-insensitive name
     * @return the matching constant
     * @throws IllegalArgumentException if {@code obj} is {@code null}, of any other type, or a
     *                                  string that names no constant
     */
    public static ActivationFunction parse(Object obj) {
        if (obj instanceof ActivationFunction) {
            return (ActivationFunction) obj;
        }
        if (!(obj instanceof String)) {
            // null has no class; report it explicitly instead of failing with a NullPointerException.
            String actualType = obj == null ? "null" : obj.getClass().getSimpleName();
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Expected ActivationFunction or String. Got %s.",
                new Object[]{actualType}
            ));
        }
        String upperCased = StringFormatting.toUpperCaseWithLocale((String) obj);
        if (!VALUES.contains(upperCased)) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "ActivationFunction `%s` is not supported. Must be one of: %s.",
                new Object[]{obj, StringJoining.join(VALUES)}
            ));
        }
        return of(upperCased);
    }

    /**
     * Returns {@code activationFunction.toString()}. No constant overrides {@code toString},
     * so this is the constant's name.
     *
     * @param activationFunction the constant to render
     * @return its string form
     */
    public static String toString(ActivationFunction activationFunction) {
        return activationFunction.toString();
    }
}
