package org.neo4j.gds.ml.core;

import java.util.List;
import org.assertj.core.api.AssertionsForInterfaceTypes;
import org.assertj.core.data.Offset;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Scalar;

/* loaded from: input_file:org/neo4j/gds/ml/core/FiniteDifferenceTest.class */
/**
 * Test mix-in that compares gradients produced by {@link ComputationContext#backward}
 * against a numerical forward-difference approximation of the same loss.
 *
 * <p>Implementing test classes may override {@link #tolerance()} and {@link #epsilon()}
 * to tune the comparison.
 */
public interface FiniteDifferenceTest {
    /**
     * Failure message template. Format arguments, in order: the autograd gradient,
     * the finite-difference gradient, and the flat coordinate index being checked.
     */
    String FAIL_MESSAGE = "AutoGrad of %f and FiniteDifference gradients of %f differs for coordinate %s more than the tolerance.";

    /**
     * @return the maximum absolute difference allowed between the autograd gradient
     *         and the finite-difference gradient
     */
    default double tolerance() {
        return 1.0E-5d;
    }

    /**
     * @return the step added to a single weight coordinate when computing the
     *         finite-difference approximation
     */
    default double epsilon() {
        return 1.0E-4d;
    }

    /**
     * Convenience overload that checks a single weights variable.
     *
     * @param weights  the weights whose gradient is verified
     * @param variable the scalar loss that is differentiated
     */
    default void finiteDifferenceShouldApproximateGradient(Weights<?> weights, Variable<Scalar> variable) {
        finiteDifferenceShouldApproximateGradient(List.of(weights), variable);
    }

    /**
     * For every coordinate of every given weights variable, asserts that the gradient
     * reported by the computation context equals the forward-difference quotient
     * {@code (loss(w + epsilon * e_i) - loss(w)) / epsilon} within {@link #tolerance()}.
     *
     * <p>Note: the perturbation applied to a coordinate is not undone, so the weights
     * are left modified when this method returns, and each coordinate is checked at
     * the point resulting from all earlier perturbations. The baseline loss and the
     * autograd gradient are recomputed for every coordinate, so each individual
     * comparison is still self-consistent.
     *
     * @param list     the weights whose gradients are verified
     * @param variable the scalar loss that is differentiated
     */
    default void finiteDifferenceShouldApproximateGradient(List<Weights<?>> list, Variable<Scalar> variable) {
        // Loop-invariant settings; read once so every coordinate uses the same values.
        final double step = epsilon();
        final double allowedError = tolerance();

        for (Weights<?> weights : list) {
            for (int i = 0; i < Dimensions.totalSize(weights.dimensions()); i++) {
                // Baseline loss and autograd gradient at the current weights.
                ComputationContext baselineContext = new ComputationContext();
                double baselineLoss = baselineContext.forward(variable).value();
                baselineContext.backward(variable);
                double autoGradient = baselineContext.gradient(weights).dataAt(i);

                // Perturb coordinate i and re-evaluate the loss in a fresh context.
                weights.data().addDataAt(i, step);
                double perturbedLoss = new ComputationContext().forward(variable).value();
                double finiteDifferenceGradient = (perturbedLoss - baselineLoss) / step;

                // Argument order follows FAIL_MESSAGE: autograd first, finite difference second.
                AssertionsForInterfaceTypes.assertThat(finiteDifferenceGradient)
                    .isNotNaN()
                    .withFailMessage(FAIL_MESSAGE, autoGradient, finiteDifferenceGradient, i)
                    .isEqualTo(autoGradient, Offset.offset(allowedError));
            }
        }
    }
}
