package org.neo4j.gds.similarity.filteredknn;

import org.apache.commons.lang3.function.TriFunction;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.similarity.filteredknn.FilteredKnnBaseConfig;
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.KnnContext;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.gds.termination.TerminationFlag;

/* loaded from: input_file:org/neo4j/gds/similarity/filteredknn/FilteredKnnFactory.class */
/**
 * Factory that produces {@link FilteredKnn} algorithm instances.
 *
 * <p>The factory holds two creation strategies: one used when the configuration asks for
 * target-node seeding and one used when it does not. Which strategy runs is decided per call in
 * {@link #build(Graph, FilteredKnnBaseConfig, ProgressTracker)}.
 *
 * @param <CONFIG> the concrete Filtered KNN configuration type
 */
public class FilteredKnnFactory<CONFIG extends FilteredKnnBaseConfig> extends GraphAlgorithmFactory<FilteredKnn, CONFIG> {
    private static final String FILTERED_KNN_TASK_NAME = "Filtered KNN";

    /** Creates the algorithm when {@code config.seedTargetNodes()} is {@code false}. */
    private final TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> unseededFilteredKnnSupplier;

    /** Creates the algorithm when {@code config.seedTargetNodes()} is {@code true}. */
    private final TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> seededFilteredKnnSupplier;

    /**
     * Creates a factory wired to the standard {@link FilteredKnn} creators:
     * {@code FilteredKnn.createWithoutSeeding} for the unseeded case and
     * {@code FilteredKnn.createWithDefaultSeeding} for the seeded case. Both are handed
     * {@link TerminationFlag#RUNNING_TRUE}.
     */
    public FilteredKnnFactory() {
        this(
            (graph, config, knnContext) ->
                FilteredKnn.createWithoutSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE),
            (graph, config, knnContext) ->
                FilteredKnn.createWithDefaultSeeding(graph, config, knnContext, TerminationFlag.RUNNING_TRUE)
        );
    }

    /**
     * Creates a factory with explicit creation strategies.
     *
     * @param unseededFilteredKnnSupplier strategy applied when target-node seeding is disabled
     * @param seededFilteredKnnSupplier   strategy applied when target-node seeding is enabled
     */
    FilteredKnnFactory(
        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> unseededFilteredKnnSupplier,
        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> seededFilteredKnnSupplier
    ) {
        this.unseededFilteredKnnSupplier = unseededFilteredKnnSupplier;
        this.seededFilteredKnnSupplier = seededFilteredKnnSupplier;
    }

    /**
     * @return the task name of this algorithm, {@code "Filtered KNN"}
     */
    public String taskName() {
        return FILTERED_KNN_TASK_NAME;
    }

    /**
     * Builds a {@link FilteredKnn} instance for the given graph and configuration.
     *
     * <p>A {@link KnnContext} is assembled from the supplied progress tracker and
     * {@link DefaultPool#INSTANCE} as executor, then passed to the seeded or unseeded strategy
     * depending on {@code config.seedTargetNodes()}.
     *
     * @param graph           the graph to run on
     * @param config          the algorithm configuration
     * @param progressTracker the tracker placed into the KNN context
     * @return the algorithm instance produced by the selected strategy
     */
    public FilteredKnn build(Graph graph, CONFIG config, ProgressTracker progressTracker) {
        // The context is created before the configuration is consulted, as in the original flow.
        KnnContext knnContext = ImmutableKnnContext
            .builder()
            .progressTracker(progressTracker)
            .executor(DefaultPool.INSTANCE)
            .build();

        TriFunction<Graph, CONFIG, KnnContext, FilteredKnn> selectedSupplier = config.seedTargetNodes()
            ? seededFilteredKnnSupplier
            : unseededFilteredKnnSupplier;

        return selectedSupplier.apply(graph, config, knnContext);
    }

    /**
     * Provides the memory estimation for a run with the given configuration.
     *
     * @param config the algorithm configuration, converted via {@code toMemoryEstimationParameters()}
     * @return the estimation obtained from {@link FilteredKnnMemoryEstimateDefinition}
     */
    public MemoryEstimation memoryEstimation(CONFIG config) {
        FilteredKnnMemoryEstimateDefinition estimateDefinition =
            new FilteredKnnMemoryEstimateDefinition(config.toMemoryEstimationParameters());
        return estimateDefinition.memoryEstimation();
    }

    /**
     * Provides the progress task tree for a run; delegates to {@code KnnFactory.knnTaskTree}.
     *
     * @param graph  the graph whose node count is passed to the task tree
     * @param config the configuration whose {@code maxIterations()} is passed to the task tree
     * @return the task returned by {@code KnnFactory.knnTaskTree}
     */
    public Task progressTask(Graph graph, CONFIG config) {
        return KnnFactory.knnTaskTree(
            graph.nodeCount(),
            config.maxIterations()
        );
    }
}
