package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Vector;
import org.neo4j.graphalgo.utils.StringFormatting;

/**
 * Adds a vector to every row of a matrix (row-wise broadcast).
 *
 * <p>For a matrix {@code M} of shape {@code rows x cols} and a vector {@code v} whose length
 * equals {@code cols}, the result {@code R} has the same shape as {@code M} and
 * {@code R[i][j] = M[i][j] + v[j]}. Matrix data is addressed in row-major order
 * ({@code i * cols + j}).
 */
public class MatrixVectorSum extends AbstractVariable<Matrix> {
    private final Variable<Matrix> matrix;
    private final Variable<Vector> vector;
    private final int rows;
    private final int cols;

    /**
     * Creates the broadcast sum of {@code matrix} and {@code vector}.
     *
     * @param matrix the matrix operand; its dimensions become the dimensions of this variable
     * @param vector the vector operand; its length must equal the matrix's column count
     *               (only checked when assertions are enabled, i.e. with {@code -ea})
     */
    public MatrixVectorSum(Variable<Matrix> matrix, Variable<Vector> vector) {
        super(List.of(matrix, vector), matrix.dimensions());
        assert matrix.dimension(1) == vector.dimension(0) : StringFormatting.formatWithLocale(
            "Cannot broadcast vector with length %d to a matrix with %d columns",
            new Object[]{vector.dimension(0), matrix.dimension(1)}
        );
        this.matrix = matrix;
        this.vector = vector;
        this.rows = matrix.dimension(0);
        this.cols = matrix.dimension(1);
    }

    /**
     * Computes {@code M + v}, adding {@code v[j]} to column {@code j} of every row.
     *
     * @param ctx context supplying the current values of the matrix and vector operands
     * @return a new {@code rows x cols} matrix holding the element-wise sums
     */
    @Override
    public Matrix apply(ComputationContext ctx) {
        double[] matrixData = ctx.data(matrix).data();
        double[] vectorData = ctx.data(vector).data();
        double[] result = new double[matrixData.length];
        for (int row = 0; row < rows; row++) {
            int rowOffset = row * cols;
            for (int col = 0; col < cols; col++) {
                result[rowOffset + col] = matrixData[rowOffset + col] + vectorData[col];
            }
        }
        return new Matrix(result, rows, cols);
    }

    /**
     * Back-propagates the gradient of this variable to one of its operands.
     *
     * @param parent the operand to differentiate with respect to; the matrix operand gets the
     *               upstream gradient, any other variable is treated as the vector operand
     * @param ctx    context supplying the upstream gradient of this variable
     * @return the upstream gradient itself for the matrix operand, otherwise a new vector of
     *         length {@code cols} holding the column sums of the upstream gradient
     */
    @Override
    public Tensor<?> gradient(Variable<?> parent, ComputationContext ctx) {
        Tensor<?> upstream = ctx.gradient(this);
        if (parent == matrix) {
            // d(M + v)/dM is the identity, so the upstream gradient is passed through as-is.
            // NOTE(review): this returns the same tensor instance the context holds (no copy);
            // confirm callers never mutate it in place.
            return upstream;
        }
        // v[j] is added to every row, so its gradient accumulates over all rows: a column sum.
        double[] columnSums = new double[cols];
        for (int row = 0; row < rows; row++) {
            int rowOffset = row * cols;
            for (int col = 0; col < cols; col++) {
                columnSums[col] += upstream.dataAt(rowOffset + col);
            }
        }
        return new Vector(columnSums);
    }
}
