package org.neo4j.gds.embeddings.graphsage;

import java.util.List;
import java.util.concurrent.ExecutorService;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.ComputationContext;
import org.neo4j.gds.ml.core.subgraph.SubGraph;
import org.neo4j.gds.ml.core.tensor.Matrix;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/GraphSageEmbeddingsGenerator.class */
/**
 * Computes node embeddings by evaluating the computation graph of a stack of GraphSAGE
 * {@link Layer}s over the nodes of a graph, one range (batch) of consecutive node ids at a time.
 *
 * <p>Batches are run as separate tasks on the supplied executor. Each task writes only to the
 * slots of the result array in its own node id range.
 */
public class GraphSageEmbeddingsGenerator {
    private final Layer[] layers;
    private final int batchSize;
    private final int concurrency;
    private final boolean isWeighted;
    private final FeatureFunction featureFunction;
    private final ExecutorService executor;
    private final ProgressTracker progressTracker;

    /**
     * Creates a generator; no work is done until {@link #makeEmbeddings} is called.
     *
     * @param layers          layers whose embeddings computation graph is evaluated for every batch
     * @param batchSize       number of consecutive node ids handled by one task
     * @param concurrency     concurrency level handed to {@code ParallelUtil.runWithConcurrency}
     * @param isWeighted      forwarded to {@code GraphSageHelper.subGraphsPerLayer} when building
     *                        the per-layer subgraphs of a batch
     * @param featureFunction produces the input of the computation graph for each batch
     * @param executor        executor the batch tasks are run on
     * @param progressTracker opened/closed around each {@link #makeEmbeddings} call and
     *                        advanced by the node count of every finished batch
     */
    public GraphSageEmbeddingsGenerator(
        Layer[] layers,
        int batchSize,
        int concurrency,
        boolean isWeighted,
        FeatureFunction featureFunction,
        ExecutorService executor,
        ProgressTracker progressTracker
    ) {
        this.layers = layers;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.isWeighted = isWeighted;
        this.featureFunction = featureFunction;
        this.executor = executor;
        this.progressTracker = progressTracker;
    }

    /**
     * Computes an embedding for every node of {@code graph}.
     *
     * <p>The node id range {@code [0, graph.nodeCount())} is split into ranges of
     * {@code batchSize} ids (via {@code PartitionUtils.rangePartitionWithBatchSize}). One task per
     * range is run with the configured concurrency on the configured executor.
     *
     * @param graph    graph whose nodes are embedded
     * @param features array passed unchanged to the feature function for every batch
     *                 (presumably per-node input features; confirm against callers)
     * @return a new array of size {@code graph.nodeCount()}; the entry at node id {@code n} is the
     *         row of the forward-pass result computed for {@code n}
     */
    public HugeObjectArray<double[]> makeEmbeddings(Graph graph, HugeObjectArray<double[]> features) {
        HugeObjectArray<double[]> result = HugeObjectArray.newArray(double[].class, graph.nodeCount());

        this.progressTracker.beginSubTask();
        ParallelUtil.runWithConcurrency(
            this.concurrency,
            PartitionUtils.rangePartitionWithBatchSize(
                graph.nodeCount(),
                this.batchSize,
                partition -> createEmbeddings(graph, partition, features, result)
            ),
            this.executor
        );
        // NOTE(review): if runWithConcurrency throws, endSubTask() is not reached; confirm whether
        // the progress tracker needs a failure/cleanup call in that case.
        this.progressTracker.endSubTask();

        return result;
    }

    /**
     * Builds the task that embeds the nodes of one partition.
     *
     * <p>The task first builds the per-layer subgraphs for the partition's node ids. It then runs
     * the forward pass of the layers' computation graph. Finally it copies row {@code i} of the
     * result into {@code result[partition.startNode() + i]} and reports the partition's node count
     * as progress.
     *
     * @param graph     graph the partition belongs to
     * @param partition contiguous node id range handled by this task
     * @param features  forwarded to the feature function
     * @param result    shared output array; only this partition's slots are written
     * @return a runnable performing the work described above
     */
    private Runnable createEmbeddings(
        Graph graph,
        Partition partition,
        HugeObjectArray<double[]> features,
        HugeObjectArray<double[]> result
    ) {
        return () -> {
            List<SubGraph> subGraphs = GraphSageHelper.subGraphsPerLayer(
                graph,
                this.isWeighted,
                partition.stream().toArray(),
                this.layers
            );

            // The feature function is applied to the original node ids of the LAST subgraph in the
            // list (presumably the widest neighbourhood, needed by the first layer; TODO confirm
            // in GraphSageHelper).
            SubGraph lastSubGraph = subGraphs.get(subGraphs.size() - 1);
            Matrix batchEmbeddings = new ComputationContext().forward(
                GraphSageHelper.embeddingsComputationGraph(
                    subGraphs,
                    this.layers,
                    this.featureFunction.apply(graph, lastSubGraph.originalNodeIds(), features)
                )
            );

            long startNode = partition.startNode();
            long nodeCount = partition.nodeCount();
            // Row `row` of the forward-pass result belongs to node `startNode + row`.
            for (int row = 0; row < nodeCount; row++) {
                result.set(startNode + row, batchEmbeddings.getRow(row));
            }
            this.progressTracker.logProgress(nodeCount);
        };
    }
}
