package org.neo4j.gds.embeddings.graphsage;

import java.util.concurrent.ThreadLocalRandom;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.ddl4j.functions.Weights;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Vector;
import org.neo4j.graphalgo.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/LayerFactory.class */
/**
 * Static factory that builds GraphSAGE {@link Layer} instances from a {@link LayerConfig}, including
 * the randomly initialised weights those layers are constructed with.
 */
public final class LayerFactory {
    private LayerFactory() {
    }

    // Decompiler note (JADX): access was widened from package-private; kept public so existing callers compile.
    /**
     * Creates the layer described by {@code layerConfig}.
     *
     * <p>A {@code rows x cols} weight matrix is always drawn first, using the initialisation bound supplied
     * by the configured activation function. For {@code MEAN} aggregation only that matrix is used. For
     * {@code POOL} aggregation, another {@code rows x cols} matrix, a {@code rows x rows} matrix and a
     * zero-filled vector of length {@code rows} are created as well.
     *
     * @param layerConfig dimensions, sample size, activation function and aggregator type of the layer
     * @return a {@link MeanAggregatingLayer} or a {@link MaxPoolAggregatingLayer}
     * @throws IllegalArgumentException if the aggregator type is neither {@code MEAN} nor {@code POOL}
     *                                  (including {@code null})
     */
    public static Layer createLayer(LayerConfig layerConfig) {
        int rows = layerConfig.rows();
        int cols = layerConfig.cols();
        ActivationFunction activationFunction = layerConfig.activationFunction();
        Aggregator.AggregatorType aggregatorType = layerConfig.aggregatorType();

        // Drawn before dispatching, exactly as before, so random draws happen in the same order.
        Weights<Matrix> weights = generateWeights(rows, cols, activationFunction.weightInitBound(rows, cols));

        if (aggregatorType == Aggregator.AggregatorType.MEAN) {
            return new MeanAggregatingLayer(weights, layerConfig.sampleSize(), activationFunction);
        }
        if (aggregatorType == Aggregator.AggregatorType.POOL) {
            // Generated in the same order as the original constructor-argument evaluation.
            Weights<Matrix> secondWeights =
                generateWeights(rows, cols, activationFunction.weightInitBound(rows, cols));
            Weights<Matrix> squareWeights =
                generateWeights(rows, rows, activationFunction.weightInitBound(rows, rows));
            Weights<Vector> zeroVector = new Weights<>(Vector.fill(0.0d, rows));
            return new MaxPoolAggregatingLayer(
                layerConfig.sampleSize(),
                weights,
                secondWeights,
                squareWeights,
                zeroVector,
                activationFunction
            );
        }
        // IllegalArgumentException extends RuntimeException, so existing catch sites are unaffected.
        throw new IllegalArgumentException(
            StringFormatting.formatWithLocale("Aggregator: %s is unknown", aggregatorType)
        );
    }

    /**
     * Creates a {@code rows x cols} weight matrix. Each entry is drawn independently and uniformly from
     * {@code [-weightBound, weightBound)} using {@link ThreadLocalRandom}, so results cannot be seeded.
     *
     * @param rows        number of rows of the matrix
     * @param cols        number of columns of the matrix
     * @param weightBound absolute bound of the sampled values; must be positive
     * @return the freshly initialised weights
     * @throws ArithmeticException      if {@code rows * cols} overflows an {@code int}
     * @throws IllegalArgumentException if {@code weightBound} is not positive or the element count is
     *                                  negative (both raised by {@link ThreadLocalRandom#doubles(long, double, double)})
     */
    public static Weights<Matrix> generateWeights(int rows, int cols, double weightBound) {
        // multiplyExact fails fast instead of silently producing a wrongly sized data array on overflow.
        int size = Math.multiplyExact(rows, cols);
        double[] data = ThreadLocalRandom.current().doubles(size, -weightBound, weightBound).toArray();
        return new Weights<>(new Matrix(data, rows, cols));
    }
}
