package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/MatrixSum.class */
/**
 * Computation-graph node that adds a list of matrix variables element-wise.
 *
 * <p>Every parent is expected to have the same dimensions as the first parent. That
 * consistency check is only performed when Java assertions are enabled ({@code -ea}).
 */
public class MatrixSum extends AbstractVariable<Matrix> {

    /**
     * Creates a node summing the given matrix variables.
     *
     * @param list the matrices to add; must be non-empty and all of identical dimensions
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public MatrixSum(List<Variable<Matrix>> list) {
        super(list, validatedDimensions(list));
    }

    /**
     * Evaluates the sum by starting from a zero matrix of this node's shape and adding
     * each parent's value (as looked up in {@code computationContext}) into it in place.
     *
     * @param computationContext context holding the already-computed parent values
     * @return a new matrix containing the element-wise sum of all parents
     */
    @Override
    public Matrix apply(ComputationContext computationContext) {
        Matrix sum = Matrix.fill(0.0d, dimension(0), dimension(1));
        for (Variable<?> parent : parents()) {
            sum.addInPlace(computationContext.data(parent));
        }
        return sum;
    }

    /**
     * Returns the gradient of this node with respect to one of its parents.
     *
     * <p>The derivative of a sum with respect to any single summand is the identity, so the
     * gradient recorded for this node in the context is passed through unchanged,
     * regardless of which parent {@code variable} is.
     *
     * @param variable the parent being differentiated against (not needed for a sum)
     * @param computationContext context holding the gradient flowing into this node
     * @return the gradient the context holds for this node
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        return computationContext.gradient(this);
    }

    /**
     * Returns the dimensions shared by all matrices in {@code list}.
     *
     * <p>The dimensions are taken from the first element. When assertions are enabled,
     * every element is checked to match them.
     *
     * @param list the matrices to be summed
     * @return the dimensions of the first matrix
     * @throws IllegalArgumentException if {@code list} is empty
     */
    private static int[] validatedDimensions(List<Variable<Matrix>> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("MatrixSum requires at least one matrix to sum");
        }
        int[] dimensions = list.get(0).dimensions();
        for (Variable<Matrix> variable : list) {
            assert Arrays.equals(variable.dimensions(), dimensions)
                : "All summed matrices must have the same dimensions: expected "
                    + Arrays.toString(dimensions)
                    + " but got "
                    + Arrays.toString(variable.dimensions());
        }
        return dimensions;
    }
}
