package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.function.Function;
import org.neo4j.gds.embeddings.graphsage.ddl4j.Variable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MatrixMultiplyWithTransposedSecondOperand;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.MultiMean;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.WeightedMultiMean;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Tensor;
import org.neo4j.gds.embeddings.graphsage.subgraph.SubGraph;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/MeanAggregator.class */
/**
 * GraphSAGE aggregator that combines neighbour representations with a mean-style
 * aggregation, projects the result with a learned weight matrix and applies an
 * activation function.
 *
 * <p>When the sub-graph carries a relationship-weights function, a
 * {@link WeightedMultiMean} is used. Otherwise an unweighted {@link MultiMean}
 * over {@code subGraph.adjacency} / {@code subGraph.selfAdjacency} is used.
 */
public class MeanAggregator implements Aggregator {
    private final Weights<Matrix> weights;
    private final Function<Variable<Matrix>, Variable<Matrix>> activationFunction;

    /**
     * Creates a mean aggregator.
     *
     * @param weights  the trainable weight matrix used to project the aggregated representations
     * @param function the activation applied to the projected result
     */
    public MeanAggregator(Weights<Matrix> weights, Function<Variable<Matrix>, Variable<Matrix>> function) {
        this.weights = weights;
        this.activationFunction = function;
    }

    /**
     * Aggregates the given representations over the sub-graph and applies
     * {@code activation(aggregated x weights^T)}.
     *
     * <p>The "transposed second operand" reading is taken from the
     * {@link MatrixMultiplyWithTransposedSecondOperand} name and should be
     * confirmed against that class.
     *
     * @param variable the previous layer's representations
     * @param subGraph the sub-graph supplying adjacency information and, optionally,
     *                 a relationship-weights function
     * @return the activated, projected aggregation
     */
    @Override
    public Variable<Matrix> aggregate(Variable<Matrix> variable, SubGraph subGraph) {
        // orElseGet keeps the unweighted fallback lazy. With orElse, a MultiMean node
        // was built on every call, even when the weighted branch was the one returned.
        Variable<Matrix> aggregated = subGraph.maybeRelationshipWeightsFunction
            .<Variable<Matrix>>map(relationshipWeights -> new WeightedMultiMean(variable, relationshipWeights, subGraph))
            .orElseGet(() -> new MultiMean(variable, subGraph.adjacency, subGraph.selfAdjacency));

        Variable<Matrix> projected = MatrixMultiplyWithTransposedSecondOperand.of(aggregated, this.weights);
        return this.activationFunction.apply(projected);
    }

    /**
     * Returns the trainable parameters of this aggregator.
     *
     * @return an immutable single-element list containing the weight matrix
     */
    @Override
    public List<Weights<? extends Tensor<?>>> weights() {
        return List.of(this.weights);
    }
}
