package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Dimensions;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Scalar;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/MeanSquaredError.class */
/**
 * Mean squared error between a predictions variable and a targets variable, producing a scalar.
 *
 * <p>The forward pass computes {@code sum_i (p_i - t_i)^2 / n} over the flattened data of both
 * operands, where {@code n} is the total number of prediction entries. Both operands receive a
 * gradient: for the differentiated operand {@code a} and the other operand {@code b} it is
 * {@code (upstream / n) * 2 * (a_i - b_i)}.
 */
public class MeanSquaredError extends AbstractVariable<Scalar> {
    private final Variable<Matrix> predictions;
    private final Variable<Matrix> targets;

    /**
     * Creates the loss node.
     *
     * @param predictions the predicted values
     * @param targets     the expected values; at evaluation time they must hold the same number
     *                    of entries as {@code predictions}
     */
    public MeanSquaredError(Variable<Matrix> predictions, Variable<Matrix> targets) {
        super(List.of(predictions, targets), Dimensions.scalar());
        this.predictions = predictions;
        this.targets = targets;
    }

    /**
     * Computes the mean of the squared element-wise differences between predictions and targets.
     *
     * @param computationContext context holding the already-computed parent data
     * @return the mean squared error; {@code NaN} when both operands are empty (0 / 0)
     * @throws IllegalArgumentException if predictions and targets differ in total size
     */
    @Override
    public Scalar apply(ComputationContext computationContext) {
        Tensor<?> predicted = computationContext.data(this.predictions);
        Tensor<?> expected = computationContext.data(this.targets);
        requireSameSize(predicted, expected);

        int size = predicted.totalSize();
        double sumOfSquares = 0.0d;
        for (int i = 0; i < size; i++) {
            double diff = predicted.dataAt(i) - expected.dataAt(i);
            sumOfSquares += diff * diff;
        }
        return new Scalar(sumOfSquares / size);
    }

    /**
     * Computes the gradient of this loss with respect to one of its two parents.
     *
     * @param variable           the parent to differentiate against (predictions or targets)
     * @param computationContext context holding parent data and this node's upstream gradient
     * @return a matrix of partial derivatives, shaped like {@code variable} when its declared
     *         dimensions match the data size, otherwise a single column of {@code n} rows
     * @throws IllegalArgumentException if predictions and targets differ in total size
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        Tensor<?> parentData = computationContext.data(variable);
        // The derivative w.r.t. either operand is 2 * (self - other), so pick the opposite one.
        Tensor<?> otherData = variable == this.predictions
            ? computationContext.data(this.targets)
            : computationContext.data(this.predictions);
        requireSameSize(parentData, otherData);

        int size = parentData.totalSize();
        double[] result = new double[size];
        // Upstream gradient of the scalar loss, pre-divided by n (chain rule through the mean).
        double scale = ((Scalar) computationContext.gradient(this)).data()[0] / size;
        for (int i = 0; i < size; i++) {
            result[i] = scale * 2.0d * (parentData.dataAt(i) - otherData.dataAt(i));
        }
        return shapedLike(variable, result);
    }

    /** Wraps the gradient data in a matrix with the parent's shape when that shape is consistent. */
    private static Matrix shapedLike(Variable<?> variable, double[] gradientData) {
        int[] dims = variable.dimensions();
        // NOTE(review): assumes matrix dimensions are laid out as {rows, cols}; confirm against Dimensions.
        if (dims != null && dims.length == 2 && (long) dims[0] * dims[1] == gradientData.length) {
            return new Matrix(gradientData, dims[0], dims[1]);
        }
        // Legacy layout: a single column holding every entry.
        return new Matrix(gradientData, gradientData.length, 1);
    }

    /** Fails fast when the two operands cannot be compared element by element. */
    private static void requireSameSize(Tensor<?> left, Tensor<?> right) {
        int leftSize = left.totalSize();
        int rightSize = right.totalSize();
        if (leftSize != rightSize) {
            throw new IllegalArgumentException(
                "MeanSquaredError operands must have the same number of entries, got "
                    + leftSize + " and " + rightSize
            );
        }
    }
}
