package org.neo4j.gds.embeddings.graphsage.ddl4j.functions;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.ddl4j.AbstractVariable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.ComputationContext;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/ddl4j/functions/Weights.class */
/**
 * An {@link AbstractVariable} that simply holds on to a tensor supplied at construction time.
 *
 * <p>The tensor is kept by reference; no copy is made. Both {@link #apply(ComputationContext)}
 * and {@link #data()} hand back that same instance, so any in-place change to it is visible
 * through this variable. The superclass is given an empty list (presumably "no parent
 * variables" — confirm against {@link AbstractVariable}) and the tensor's own dimensions.
 *
 * @param <T> concrete tensor type held by these weights
 */
public class Weights<T extends Tensor<T>> extends AbstractVariable<T> {

    /** The wrapped tensor, returned as-is by {@link #apply} and {@link #data()}. */
    private final T data;

    /**
     * Wraps {@code initialData} without copying it.
     *
     * @param initialData tensor to hold; its {@code dimensions()} are forwarded to the superclass
     * @throws NullPointerException if {@code initialData} is {@code null}
     *     (raised when its dimensions are read)
     */
    public Weights(T initialData) {
        super(List.of(), initialData.dimensions());
        data = initialData;
    }

    /**
     * Creates weights backed by a freshly constructed {@link Matrix}.
     *
     * @param firstDim first argument forwarded to {@code new Matrix(int, int)}
     *     (presumably the row count — confirm against {@link Matrix})
     * @param secondDim second argument forwarded to {@code new Matrix(int, int)}
     *     (presumably the column count — confirm against {@link Matrix})
     * @return new weights wrapping the new matrix
     */
    public static Weights<Matrix> ofMatrix(int firstDim, int secondDim) {
        Matrix matrix = new Matrix(firstDim, secondDim);
        return new Weights<>(matrix);
    }

    /**
     * Returns the wrapped tensor.
     *
     * @return the exact instance passed to the constructor
     */
    public T data() {
        return data;
    }

    /**
     * Returns the wrapped tensor; the context is not consulted.
     *
     * @param computationContext ignored
     * @return the exact instance passed to the constructor
     */
    @Override
    public T apply(ComputationContext computationContext) {
        return data;
    }

    /**
     * Unsupported for weights: this method always throws.
     *
     * @param variable ignored
     * @param computationContext ignored
     * @return never returns normally
     * @throws AbstractVariable.NotAFunctionException always
     */
    @Override
    public Tensor<?> gradient(Variable<?> variable, ComputationContext computationContext) {
        throw new AbstractVariable.NotAFunctionException();
    }

    /**
     * Reports that this variable requires a gradient.
     *
     * @return always {@code true}
     */
    @Override
    public boolean requireGradient() {
        return true;
    }
}
