package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.algo.ActivationFunctionType;
import org.neo4j.gds.embeddings.graphsage.algo.AggregatorType;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.ml.core.functions.MultiMean;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/MeanAggregator.class */
/**
 * GraphSAGE aggregator of type {@link AggregatorType#MEAN}.
 *
 * <p>It takes the {@link MultiMean} of the incoming variable over a {@link SubGraph}, multiplies
 * that by this aggregator's single weight matrix (via
 * {@link MatrixMultiplyWithTransposedSecondOperand}), and then applies the configured activation
 * function.
 *
 * <p>This aggregator owns exactly one {@link Weights} instance and has no separate bias term.
 * For that reason {@link #weights()} and {@link #weightsWithoutBias()} report the same content.
 */
public class MeanAggregator implements Aggregator {

    /** The single trainable weight matrix used in {@link #aggregate}. */
    private final Weights<Matrix> weightMatrix;

    /** Activation applied to the projected mean; taken from the wrapper at construction. */
    private final ActivationFunction activation;

    /** Type tag of {@link #activation}, exposed through {@link #activationFunctionType()}. */
    private final ActivationFunctionType activationType;

    /**
     * Creates a mean aggregator.
     *
     * @param weights the weight matrix to multiply the mean by
     * @param activationFunctionWrapper supplies both the activation function and its type tag
     */
    public MeanAggregator(Weights<Matrix> weights, ActivationFunctionWrapper activationFunctionWrapper) {
        this.weightMatrix = weights;
        this.activation = activationFunctionWrapper.activationFunction();
        this.activationType = activationFunctionWrapper.activationFunctionType();
    }

    /**
     * Builds the computation {@code activation(MultiMean(variable, subGraph) x weights^T)}.
     *
     * <p>NOTE(review): the exact meaning of the mean depends on {@link MultiMean}. Presumably it
     * is a per-node mean over the sub-graph's neighbourhood; confirm against that class.
     *
     * @param variable the input feature variable
     * @param subGraph the sub-graph the mean is taken over
     * @return the activated, projected mean
     */
    @Override
    public Variable<Matrix> aggregate(Variable<Matrix> variable, SubGraph subGraph) {
        Variable<Matrix> neighbourhoodMean = new MultiMean(variable, subGraph);
        Variable<Matrix> projectedMean =
            MatrixMultiplyWithTransposedSecondOperand.of(neighbourhoodMean, this.weightMatrix);
        return this.activation.apply(projectedMean);
    }

    /**
     * Returns all trainable weights of this aggregator.
     *
     * @return an immutable single-element list holding the weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.weightMatrix);
    }

    /**
     * Returns the trainable weights excluding any bias. This aggregator has no bias, so the
     * result has the same content as {@link #weights()}.
     *
     * @return an immutable single-element list holding the weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weightsWithoutBias() {
        return List.of(this.weightMatrix);
    }

    /**
     * Returns the aggregator type.
     *
     * @return always {@link AggregatorType#MEAN}
     */
    @Override
    public AggregatorType type() {
        return AggregatorType.MEAN;
    }

    /**
     * Returns the activation function type supplied at construction.
     *
     * @return the activation function type
     */
    @Override
    public ActivationFunctionType activationFunctionType() {
        return this.activationType;
    }

    /**
     * Returns the raw data backing the weight matrix.
     *
     * @return the matrix returned by {@code weights.data()}
     */
    public Matrix weightsData() {
        return this.weightMatrix.data();
    }
}
