package org.neo4j.gds.ml.linkmodels;

import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;

/* loaded from: input_file:org/neo4j/gds/ml/linkmodels/LinkPredictionTrainFactory.class */
/**
 * Factory for the {@link LinkPredictionTrain} algorithm.
 *
 * <p>Creates the algorithm instance. Delegates memory estimation to
 * {@code LinkPredictionTrainEstimation}. Describes the progress-task tree reported while training.
 */
public class LinkPredictionTrainFactory extends AlgorithmFactory<LinkPredictionTrain, LinkPredictionTrainConfig> {

    /** Package-private: instances are only created from within this package. */
    LinkPredictionTrainFactory() {
    }

    /**
     * Name used for the root node of the progress-task tree.
     *
     * @return the constant {@code "LinkPredictionTrain"}
     */
    protected String taskName() {
        return "LinkPredictionTrain";
    }

    /**
     * Instantiates the link-prediction training algorithm.
     *
     * <p>The allocation tracker is accepted for signature compatibility but is not passed on.
     *
     * @param graph the graph to train on
     * @param linkPredictionTrainConfig training configuration
     * @param allocationTracker unused by this factory
     * @param progressTracker tracker handed to the algorithm for progress reporting
     * @return a new {@link LinkPredictionTrain}
     */
    public LinkPredictionTrain build(Graph graph, LinkPredictionTrainConfig linkPredictionTrainConfig, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        LinkPredictionTrain algorithm = new LinkPredictionTrain(graph, linkPredictionTrainConfig, progressTracker);
        return algorithm;
    }

    /**
     * Memory estimation for a training run with the given configuration.
     *
     * @param linkPredictionTrainConfig training configuration
     * @return the estimation produced by {@code LinkPredictionTrainEstimation.estimate}
     */
    public MemoryEstimation memoryEstimation(LinkPredictionTrainConfig linkPredictionTrainConfig) {
        MemoryEstimation estimation = LinkPredictionTrainEstimation.estimate(linkPredictionTrainConfig);
        return estimation;
    }

    /**
     * Builds the progress-task tree for a training run.
     *
     * <pre>
     * LinkPredictionTrain
     * ├── ModelSelection  (volume = number of parameter sets × validation folds)
     * ├── Training
     * └── Evaluation
     *     ├── Training
     *     └── Testing
     * </pre>
     *
     * @param graph the graph (not consulted here)
     * @param linkPredictionTrainConfig supplies the parameter sets and fold count
     * @return the root task
     */
    public Task progressTask(Graph graph, LinkPredictionTrainConfig linkPredictionTrainConfig) {
        // Resolve the root name first to keep the original argument-evaluation order.
        String rootName = taskName();

        // One unit of work per (parameter set, validation fold) pair.
        Task modelSelection = Tasks.leaf(
            "ModelSelection",
            linkPredictionTrainConfig.params().size() * linkPredictionTrainConfig.validationFolds()
        );
        Task finalTraining = Tasks.leaf("Training");

        Task evaluationTraining = Tasks.leaf("Training");
        Task evaluationTesting = Tasks.leaf("Testing");
        Task evaluation = Tasks.task("Evaluation", evaluationTraining, new Task[]{evaluationTesting});

        return Tasks.task(rootName, modelSelection, new Task[]{finalTraining, evaluation});
    }
}
