package org.neo4j.gds.embeddings.graphsage;

import java.util.Random;
import org.neo4j.gds.ml.core.functions.Weights;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.core.tensor.Vector;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/LayerFactory.class */
/**
 * Static factory for GraphSAGE {@link Layer} instances and their randomly initialised weights.
 *
 * <p>Not instantiable; all members are static.
 */
public final class LayerFactory {
    private LayerFactory() {
    }

    /**
     * Builds a layer for the aggregator type selected in {@code layerConfig}.
     *
     * <p>Every weight matrix is filled via {@link #generateWeights(int, int, double, long)} with a bound
     * taken from the activation function's {@code weightInitBound} for that matrix's shape. All seeds
     * are derived from {@code layerConfig.randomSeed()}, so the same config always yields the same
     * initial weights.
     *
     * @param layerConfig supplies the weight shape ({@code rows} x {@code cols}), activation function,
     *                    random seed, sample size and aggregator type
     * @return a {@link MeanAggregatingLayer} for {@code MEAN}, or a {@link MaxPoolAggregatingLayer}
     *         for {@code POOL}
     * @throws IllegalArgumentException if the aggregator type is not handled by this factory
     */
    public static Layer createLayer(LayerConfig layerConfig) {
        int rows = layerConfig.rows();
        int cols = layerConfig.cols();
        ActivationFunction activationFunction = layerConfig.activationFunction();
        long randomSeed = layerConfig.randomSeed();

        // Same bound for every rows x cols matrix below; compute it once.
        double rowsByColsBound = activationFunction.weightInitBound(rows, cols);
        Weights<Matrix> weights = generateWeights(rows, cols, rowsByColsBound, randomSeed);

        switch (layerConfig.aggregatorType()) {
            case MEAN:
                return new MeanAggregatingLayer(weights, layerConfig.sampleSize(), activationFunction, randomSeed);
            case POOL:
                // Seed offsets (+1, +2) make the additional matrices differ from `weights`
                // instead of repeating the same random sequence.
                // NOTE(review): the local names assume MaxPoolAggregatingLayer's constructor order is
                // (sampleSize, pool, self, neighbors, bias, activation, seed) - confirm against that class.
                Weights<Matrix> selfWeights = generateWeights(rows, cols, rowsByColsBound, randomSeed + 1);
                Weights<Matrix> neighborsWeights = generateWeights(
                    rows,
                    rows,
                    activationFunction.weightInitBound(rows, rows),
                    randomSeed + 2
                );
                // Explicit type argument instead of a raw Weights instance.
                // Presumably a length-`rows` vector filled with 0.0 - see Vector.create.
                Weights<Vector> bias = new Weights<>(Vector.create(0.0d, rows));
                return new MaxPoolAggregatingLayer(
                    layerConfig.sampleSize(),
                    weights,
                    selfWeights,
                    neighborsWeights,
                    bias,
                    activationFunction,
                    randomSeed
                );
            default:
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "Aggregator: %s is unknown",
                    layerConfig.aggregatorType()
                ));
        }
    }

    /**
     * Creates a {@code rows} x {@code cols} weight matrix whose entries are drawn uniformly from
     * {@code [-weightBound, weightBound)}.
     *
     * <p>The entries come from {@code new Random(randomSeed).doubles(...)}, so identical arguments
     * always produce an identical matrix.
     *
     * @param rows        number of rows of the matrix
     * @param cols        number of columns of the matrix
     * @param weightBound half-width of the sampling interval; must be positive
     * @param randomSeed  seed for the {@link Random} that draws the entries
     * @return the generated matrix wrapped in {@link Weights}
     * @throws ArithmeticException      if {@code rows * cols} overflows an {@code int}
     * @throws IllegalArgumentException if {@code rows * cols} is negative or {@code weightBound} is not
     *                                  positive (thrown by {@link Random#doubles(long, double, double)})
     */
    public static Weights<Matrix> generateWeights(int rows, int cols, double weightBound, long randomSeed) {
        // multiplyExact fails loudly instead of silently wrapping on huge shapes.
        int size = Math.multiplyExact(rows, cols);
        double[] data = new Random(randomSeed).doubles(size, -weightBound, weightBound).toArray();
        // NOTE(review): element layout (row- vs column-major) is defined by Matrix - confirm there.
        return new Weights<>(new Matrix(data, rows, cols));
    }
}
