package org.neo4j.gds.algorithms.embeddings;

import org.neo4j.gds.algorithms.AlgorithmComputationResult;
import org.neo4j.gds.algorithms.TrainResult;
import org.neo4j.gds.algorithms.runner.AlgorithmResultWithTiming;
import org.neo4j.gds.algorithms.runner.AlgorithmRunner;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer;
import org.neo4j.gds.embeddings.graphsage.ModelData;
import org.neo4j.gds.embeddings.graphsage.algo.GraphSageTrainConfig;
import org.neo4j.gds.modelcatalogservices.ModelCatalogService;

/* loaded from: input_file:org/neo4j/gds/algorithms/embeddings/NodeEmbeddingsAlgorithmsTrainBusinessFacade.class */
/**
 * Train-mode entry point for node embedding algorithms.
 *
 * <p>Delegates the training itself to a {@link NodeEmbeddingsAlgorithmsFacade} and hands any
 * model that comes back to a {@link ModelCatalogService}.
 */
public class NodeEmbeddingsAlgorithmsTrainBusinessFacade {
    private final NodeEmbeddingsAlgorithmsFacade nodeEmbeddingsAlgorithmsFacade;
    private final ModelCatalogService modelCatalogService;

    /**
     * Creates the facade from its two collaborators.
     *
     * @param nodeEmbeddingsAlgorithmsFacade facade whose {@code graphSageTrain} is invoked to train
     * @param modelCatalogService service that receives the trained model
     */
    public NodeEmbeddingsAlgorithmsTrainBusinessFacade(NodeEmbeddingsAlgorithmsFacade nodeEmbeddingsAlgorithmsFacade, ModelCatalogService modelCatalogService) {
        this.modelCatalogService = modelCatalogService;
        this.nodeEmbeddingsAlgorithmsFacade = nodeEmbeddingsAlgorithmsFacade;
    }

    /**
     * Runs GraphSage training under {@link AlgorithmRunner#runWithTiming} and packages the outcome.
     *
     * <p>When the computation yields a model, it is passed to {@code modelCatalogService.set} and,
     * if {@code graphSageTrainConfig.storeModelToDisk()} is true, also to
     * {@code modelCatalogService.storeModelToDisk}.
     *
     * @param str forwarded unchanged as the first argument of {@code graphSageTrain}
     *            (presumably the graph name; verify against the facade)
     * @param graphSageTrainConfig training configuration, also consulted for {@code storeModelToDisk()}
     * @return a {@link TrainResult} carrying the measured compute milliseconds and the trained model;
     *         the model field is {@code null} when the computation produced no result
     */
    public TrainResult<Model<ModelData, GraphSageTrainConfig, GraphSageModelTrainer.GraphSageTrainMetrics>> graphSage(String str, GraphSageTrainConfig graphSageTrainConfig) {
        AlgorithmResultWithTiming timedRun = AlgorithmRunner.runWithTiming(() -> {
            return nodeEmbeddingsAlgorithmsFacade.graphSageTrain(str, graphSageTrainConfig);
        });

        // Unwrap in two steps: timed wrapper -> computation result -> optional model (null if absent).
        AlgorithmComputationResult computation = (AlgorithmComputationResult) timedRun.algorithmResult;
        Model trainedModel = (Model) computation.result().orElse(null);

        publishToCatalog(trainedModel, graphSageTrainConfig);

        return TrainResult.builder()
            .trainMillis(timedRun.computeMilliseconds)
            .algorithmSpecificFields(trainedModel)
            .build();
    }

    /** Hands a non-null model to the model catalog service, storing it to disk too when the config asks for it. */
    private void publishToCatalog(Model trainedModel, GraphSageTrainConfig graphSageTrainConfig) {
        if (trainedModel == null) {
            return;
        }
        modelCatalogService.set(trainedModel);
        // The config flag is read only after the model has been set, as in the original flow.
        if (graphSageTrainConfig.storeModelToDisk()) {
            modelCatalogService.storeModelToDisk(trainedModel);
        }
    }
}
