package org.neo4j.gds.embeddings.graphsage;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryRange;
import org.neo4j.gds.ml.core.Variable;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Tensor;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/Aggregator.class */
/**
 * Contract for a GraphSAGE aggregator: it combines a matrix variable with a
 * {@link SubGraph} into a new matrix variable and exposes its weights, its
 * {@link AggregatorType} and its {@link ActivationFunction}.
 */
public interface Aggregator {

    /**
     * The supported aggregator kinds. Each constant supplies its own memory estimation formula.
     */
    enum AggregatorType {
        MEAN {
            /**
             * Estimates one input-width double array plus two embedding-width double arrays,
             * sized by the node-count bounds. The two "previous" node-count arguments are
             * not used by this formula.
             */
            @Override
            public MemoryRange memoryEstimation(
                long minNodeCount,
                long maxNodeCount,
                long minPreviousNodeCount,
                long maxPreviousNodeCount,
                int inputDimension,
                int embeddingDimension
            ) {
                return MemoryRange.of(
                    Estimate.sizeOfDoubleArray(minNodeCount * inputDimension)
                        + (2 * Estimate.sizeOfDoubleArray(minNodeCount * embeddingDimension)),
                    Estimate.sizeOfDoubleArray(maxNodeCount * inputDimension)
                        + (2 * Estimate.sizeOfDoubleArray(maxNodeCount * embeddingDimension))
                );
            }
        },
        POOL {
            /**
             * Estimates three embedding-width double arrays sized by the "previous" node-count
             * bounds plus six embedding-width double arrays sized by the node-count bounds.
             * The input dimension is not used by this formula.
             */
            @Override
            public MemoryRange memoryEstimation(
                long minNodeCount,
                long maxNodeCount,
                long minPreviousNodeCount,
                long maxPreviousNodeCount,
                int inputDimension,
                int embeddingDimension
            ) {
                return MemoryRange.of(
                    (3 * Estimate.sizeOfDoubleArray(minPreviousNodeCount * embeddingDimension))
                        + (6 * Estimate.sizeOfDoubleArray(minNodeCount * embeddingDimension)),
                    (3 * Estimate.sizeOfDoubleArray(maxPreviousNodeCount * embeddingDimension))
                        + (6 * Estimate.sizeOfDoubleArray(maxNodeCount * embeddingDimension))
                );
            }
        };

        /** Names of all constants, used for membership checks and for listing valid values in error messages. */
        private static final List<String> VALUES = Arrays.stream(values())
            .map(AggregatorType::name)
            .collect(Collectors.toList());

        /**
         * Looks up a constant by name after upper-casing the input.
         *
         * @param str the aggregator name, in any letter case
         * @return the constant whose name equals the upper-cased input
         * @throws IllegalArgumentException if no constant has that name (raised by {@code valueOf})
         */
        public static AggregatorType of(String str) {
            return valueOf(StringFormatting.toUpperCaseWithLocale(str));
        }

        /**
         * Converts a loosely typed value into an {@code AggregatorType}.
         *
         * @param obj either an {@code AggregatorType}, which is returned unchanged, or a
         *            {@code String} naming a constant in any letter case
         * @return the matching aggregator type
         * @throws IllegalArgumentException if {@code obj} is {@code null}, is neither a string nor
         *                                  an aggregator type, or names an unsupported aggregator
         */
        public static AggregatorType parse(Object obj) {
            if (obj instanceof AggregatorType) {
                return (AggregatorType) obj;
            }
            if (obj instanceof String) {
                String normalizedName = StringFormatting.toUpperCaseWithLocale((String) obj);
                if (VALUES.contains(normalizedName)) {
                    return of(normalizedName);
                }
                throw new IllegalArgumentException(StringFormatting.formatWithLocale("Aggregator `%s` is not supported. Must be one of: %s.", new Object[]{obj, StringJoining.join(VALUES)}));
            }
            // A null input has no class to report; name it explicitly rather than failing with a NullPointerException.
            String actualType = obj == null ? "null" : obj.getClass().getSimpleName();
            throw new IllegalArgumentException(StringFormatting.formatWithLocale("Expected Aggregator or String. Got %s.", new Object[]{actualType}));
        }

        /**
         * Returns the string form of the given aggregator type.
         *
         * @param aggregatorType the type to render; must not be {@code null}
         * @return the result of {@code aggregatorType.toString()}
         */
        public static String toString(AggregatorType aggregatorType) {
            return aggregatorType.toString();
        }

        /**
         * Estimates the memory needed by an aggregator of this type.
         *
         * <p>NOTE(review): the parameter names were reconstructed from decompiled code and the
         * arithmetic in the two implementations; confirm them against the upstream sources.
         * The result is built with {@code MemoryRange.of(a, b)}, presumably as (min, max).
         *
         * @param minNodeCount         count used for the first bound
         * @param maxNodeCount         count used for the second bound
         * @param minPreviousNodeCount secondary count used for the first bound (only by {@code POOL})
         * @param maxPreviousNodeCount secondary count used for the second bound (only by {@code POOL})
         * @param inputDimension       first dimension multiplier (only used by {@code MEAN})
         * @param embeddingDimension   second dimension multiplier
         * @return the estimated memory range
         */
        public abstract MemoryRange memoryEstimation(
            long minNodeCount,
            long maxNodeCount,
            long minPreviousNodeCount,
            long maxPreviousNodeCount,
            int inputDimension,
            int embeddingDimension
        );
    }

    /**
     * Builds the aggregated matrix variable for the given input and sub-graph.
     *
     * @param variable the input matrix variable
     * @param subGraph the sub-graph to aggregate over
     * @return the aggregated matrix variable
     */
    Variable<Matrix> aggregate(Variable<Matrix> variable, SubGraph subGraph);

    /** Returns the weights of this aggregator. */
    List<Weights<? extends Tensor<?>>> weights();

    /** Returns the weights of this aggregator, presumably excluding any bias terms. */
    List<Weights<? extends Tensor<?>>> weightsWithoutBias();

    /** Returns the kind of this aggregator. */
    AggregatorType type();

    /** Returns the activation function associated with this aggregator. */
    ActivationFunction activationFunction();
}
