package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.concurrent.ExecutorService;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.MutateConfig;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.TrainConfigTransformer;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageBaseConfig;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.ml.core.EmbeddingUtils;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/algo/GraphSageAlgorithmFactory.class */
/**
 * Factory that builds {@link GraphSage} instances which apply a trained GraphSAGE model,
 * looked up in a {@link ModelCatalog}, to a graph.
 *
 * @param <CONFIG> the GraphSAGE configuration type; it names the model to use and supplies
 *                 the run parameters
 */
public class GraphSageAlgorithmFactory<CONFIG extends GraphSageBaseConfig> extends GraphAlgorithmFactory<GraphSage, CONFIG> {
    private final ModelCatalog modelCatalog;

    /**
     * @param modelCatalog catalog from which trained GraphSAGE models are resolved
     */
    public GraphSageAlgorithmFactory(ModelCatalog modelCatalog) {
        this.modelCatalog = modelCatalog;
    }

    /**
     * Builds a {@link GraphSage} instance for an already-resolved model.
     *
     * <p>The algorithm runs on {@code DefaultPool.INSTANCE} and uses
     * {@code TerminationFlag.RUNNING_TRUE}.
     *
     * @param graph               graph to compute embeddings for
     * @param graphSageParameters supplies the concurrency and batch size
     * @param model               trained GraphSAGE model to apply
     * @param progressTracker     tracker handed to the algorithm
     * @return a configured, not-yet-run {@link GraphSage} instance
     */
    public GraphSage build(Graph graph, GraphSageParameters graphSageParameters, Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> model, ProgressTracker progressTracker) {
        ExecutorService executorService = DefaultPool.INSTANCE;
        if (graph.hasRelationshipProperty()) {
            // Weights are checked before the algorithm is constructed; presumably this throws on
            // invalid weight values (NOTE(review): confirm against EmbeddingUtils).
            EmbeddingUtils.validateRelationshipWeightPropertyValue(graph, graphSageParameters.concurrency(), executorService);
        }
        return new GraphSage(graph, model, graphSageParameters.concurrency(), graphSageParameters.batchSize(), executorService, progressTracker, TerminationFlag.RUNNING_TRUE);
    }

    /**
     * Resolves the model named in {@code config} and builds a {@link GraphSage} instance for it.
     *
     * @param graph           graph to compute embeddings for
     * @param config          configuration naming the model and providing run parameters
     * @param progressTracker tracker handed to the algorithm
     * @return a configured, not-yet-run {@link GraphSage} instance
     */
    public GraphSage build(Graph graph, CONFIG config, ProgressTracker progressTracker) {
        return build(graph, config.toParameters(), resolveModel(config), progressTracker);
    }

    /**
     * @return the task name used for progress reporting (the simple name of {@link GraphSage})
     */
    public String taskName() {
        return GraphSage.class.getSimpleName();
    }

    /**
     * Estimates memory for running the model named in {@code config}.
     *
     * <p>The model is resolved exactly as in {@link #build(Graph, Object, ProgressTracker)}, so the
     * estimate describes the same model the algorithm will run.
     *
     * @param config configuration naming the model
     * @return memory estimation derived from the model's training configuration
     */
    public MemoryEstimation memoryEstimation(CONFIG config) {
        GraphSageTrainConfig trainConfig = resolveModel(config).trainConfig();
        // Whether this is a mutate-mode configuration is passed through to the estimate definition.
        boolean mutating = config instanceof MutateConfig;
        return new GraphSageMemoryEstimateDefinition(TrainConfigTransformer.toMemoryEstimateParameters(trainConfig), mutating).memoryEstimation();
    }

    /**
     * @return a single leaf task whose volume is the graph's node count
     */
    public Task progressTask(Graph graph, CONFIG config) {
        return Tasks.leaf(taskName(), graph.nodeCount());
    }

    /**
     * Single place where the trained model is looked up.
     *
     * <p>Both {@code build} and {@code memoryEstimation} call this method, so they always agree on
     * the owning user ({@code modelUser()}) and the model name.
     */
    private Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics> resolveModel(CONFIG config) {
        return GraphSageModelResolver.resolveModel(this.modelCatalog, config.modelUser(), config.modelName());
    }
}
