package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/Relu.class */
/**
 * Leaky rectifier activation applied to the data of a single parent variable.
 *
 * <p>Despite the class name, non-positive inputs are not clamped to zero: they are scaled by
 * {@link #ALPHA}. Strictly positive inputs pass through unchanged. The result is declared with
 * the same dimensions as the parent variable.
 *
 * @param <T> tensor type produced by this variable and by its parent
 */
public class Relu<T extends Tensor<T>> extends SingleParentVariable<T> {
    /** Slope used for inputs that are not strictly positive. */
    private static final double ALPHA = 0.01d;

    /**
     * Creates the activation on top of {@code variable}, reusing that variable's dimensions.
     *
     * @param variable the parent variable whose data is fed through the activation
     */
    public Relu(Variable<T> variable) {
        super(variable, variable.dimensions());
    }

    /**
     * Computes the activation by mapping the parent's data held in the context through
     * {@link #activate(double)} (presumably element by element; see {@code Tensor#map}).
     *
     * @param computationContext context that holds the parent's already-computed data
     * @return the mapped tensor
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public T apply(ComputationContext computationContext) {
        return (T) computationContext.data(parent()).map(Relu::activate);
    }

    /**
     * Computes the local derivative of the activation by mapping the data of {@code variable}
     * through {@link #slope(double)}.
     *
     * <p>NOTE(review): only the local derivative is returned here; no upstream gradient of this
     * node is read from the context. Confirm that the caller applies the chain rule.
     *
     * @param variable the variable whose data the derivative is evaluated at
     * @param computationContext context that holds the data of {@code variable}
     * @return tensor holding {@code 1.0} where the input is strictly positive, {@link #ALPHA} otherwise
     */
    @Override // org.neo4j.gds.embeddings.graphsage.ddl4j.Variable
    public T gradient(Variable<?> variable, ComputationContext computationContext) {
        return (T) computationContext.data(variable).map(Relu::slope);
    }

    /** Returns {@code value} itself when strictly positive, otherwise {@code ALPHA * value}. */
    private static double activate(double value) {
        if (value > 0.0d) {
            return value;
        }
        return ALPHA * value;
    }

    /** Returns the derivative of {@link #activate(double)} at {@code value}. */
    private static double slope(double value) {
        // Strict comparison on purpose: zero and NaN both fall through to ALPHA.
        return value > 0.0d ? 1.0d : ALPHA;
    }
}
