package org.neo4j.gds.similarity.knn;

import com.carrotsearch.hppc.LongArrayList;
import java.util.List;
import java.util.function.LongFunction;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnSampler;

/**
 * Algorithm factory for {@link Knn}.
 *
 * <p>It has three jobs:
 * <ul>
 *   <li>builds the algorithm instance;</li>
 *   <li>describes its progress-task tree;</li>
 *   <li>estimates its memory footprint.</li>
 * </ul>
 *
 * @param <CONFIG> the concrete KNN configuration type
 */
public class KnnFactory<CONFIG extends KnnBaseConfig> extends GraphAlgorithmFactory<Knn, CONFIG> {
    private static final String KNN_BASE_TASK_NAME = "Knn";

    /**
     * Returns the name used for the root progress task and for the memory estimation.
     *
     * @return the constant task name {@code "Knn"}
     */
    public String taskName() {
        return KNN_BASE_TASK_NAME;
    }

    /**
     * Creates a {@link Knn} instance via {@link Knn#createWithDefaults}.
     *
     * <p>The context is built from the given progress tracker, with {@link DefaultPool#INSTANCE} as executor.
     *
     * @param graph           the graph to compute nearest neighbors on
     * @param config          the KNN configuration
     * @param progressTracker tracker that receives progress updates
     * @return a new, not yet executed {@link Knn} instance
     */
    public Knn build(Graph graph, CONFIG config, ProgressTracker progressTracker) {
        return Knn.createWithDefaults(
            graph,
            config,
            ImmutableKnnContext.builder()
                .progressTracker(progressTracker)
                .executor(DefaultPool.INSTANCE)
                .build()
        );
    }

    /**
     * Estimates the memory required by {@link Knn} for the given configuration.
     *
     * @param config the KNN configuration
     * @return the memory estimation, labelled with {@link #taskName()}
     */
    public MemoryEstimation memoryEstimation(CONFIG config) {
        return memoryEstimation(taskName(), Knn.class, config);
    }

    /**
     * Estimates the memory used by the initial-neighbor sampler of the given type.
     *
     * @param samplerType the configured initial sampler
     * @param boundedK    the effective number of neighbors per node
     * @return the memory range of a single sampler instance
     * @throws IllegalStateException if the sampler type is not handled here
     */
    public static MemoryRange initialSamplerMemoryEstimation(KnnSampler.SamplerType samplerType, long boundedK) {
        switch (samplerType) {
            case UNIFORM:
                return UniformKnnSampler.memoryEstimation(boundedK);
            case RANDOMWALK:
                return RandomWalkKnnSampler.memoryEstimation(boundedK);
            default:
                throw new IllegalStateException("Invalid KnnSampler");
        }
    }

    /**
     * Returns the progress-task tree for the given graph and configuration.
     *
     * @param graph  the graph the algorithm will run on
     * @param config the KNN configuration
     * @return the task tree produced by {@link #knnTaskTree(Graph, KnnBaseConfig)}
     */
    public Task progressTask(Graph graph, CONFIG config) {
        return knnTaskTree(graph, config);
    }

    /**
     * Builds the KNN progress-task tree.
     *
     * <p>The tree has two parts:
     * <ul>
     *   <li>an initialization leaf;</li>
     *   <li>a dynamic iteration task of up to {@code maxIterations} rounds. Each round has three
     *       leaves, each sized by the node count.</li>
     * </ul>
     *
     * @param graph          the graph the algorithm will run on; supplies the node count for each leaf
     * @param knnBaseConfig  configuration providing {@code maxIterations}
     * @return the root task named {@code "Knn"}
     */
    public static Task knnTaskTree(Graph graph, KnnBaseConfig knnBaseConfig) {
        return Tasks.task(
            KNN_BASE_TASK_NAME,
            Tasks.leaf("Initialize random neighbors", graph.nodeCount()),
            Tasks.iterativeDynamic(
                "Iteration",
                () -> List.of(
                    Tasks.leaf("Split old and new neighbors", graph.nodeCount()),
                    Tasks.leaf("Reverse old and new neighbors", graph.nodeCount()),
                    Tasks.leaf("Join neighbors", graph.nodeCount())
                ),
                knnBaseConfig.maxIterations()
            )
        );
    }

    /**
     * Builds a memory estimation for a KNN-style algorithm.
     *
     * <p>The estimation has three parts:
     * <ul>
     *   <li>per-node components: the top-k neighbor lists and four temporary old/new/reverse
     *       candidate lists;</li>
     *   <li>a fixed per-thread component for the initial sampler;</li>
     *   <li>a fixed per-thread component for the sampled-neighbor hash container.</li>
     * </ul>
     *
     * @param taskName       label of the resulting estimation
     * @param algorithmClass class whose instance size is included in the estimation
     * @param config         configuration providing {@code boundedK}, {@code sampledK} and the initial sampler
     * @param <CONFIG>       the concrete KNN configuration type
     * @return the memory estimation
     */
    public static <CONFIG extends KnnBaseConfig> MemoryEstimation memoryEstimation(
        String taskName,
        Class<?> algorithmClass,
        CONFIG config
    ) {
        return MemoryEstimations.setup(taskName, (graphDimensions, concurrency) -> {
            int boundedK = config.boundedK(graphDimensions.nodeCount());
            int sampledK = config.sampledK(graphDimensions.nodeCount());

            // Range for one temporary per-node candidate list.
            // - Lower bound: no per-node payload.
            // - Upper bound: one LongArrayList with a backing long[] of size sampledK per node.
            // Typed explicitly; the previous raw LongFunction bypassed the MemoryRange return check.
            LongFunction<MemoryRange> candidateListEstimation = nodeCount -> MemoryRange.of(
                HugeObjectArray.memoryEstimation(nodeCount, 0L),
                HugeObjectArray.memoryEstimation(
                    nodeCount,
                    MemoryUsage.sizeOfInstance(LongArrayList.class) + MemoryUsage.sizeOfLongArray(sampledK)
                )
            );

            MemoryRange neighborListRange = NeighborList.memoryEstimation(boundedK)
                .estimate(graphDimensions, concurrency)
                .memoryUsage();

            return MemoryEstimations.builder(algorithmClass)
                .rangePerNode("top-k-neighbors-list", nodeCount -> MemoryRange.of(
                    HugeObjectArray.memoryEstimation(nodeCount, neighborListRange.min),
                    HugeObjectArray.memoryEstimation(nodeCount, neighborListRange.max)
                ))
                .rangePerNode("old-neighbors", candidateListEstimation)
                .rangePerNode("new-neighbors", candidateListEstimation)
                .rangePerNode("old-reverse-neighbors", candidateListEstimation)
                .rangePerNode("new-reverse-neighbors", candidateListEstimation)
                .fixed(
                    "initial-random-neighbors (per thread)",
                    initialSamplerMemoryEstimation(config.initialSampler(), boundedK).times(concurrency)
                )
                .fixed(
                    "sampled-random-neighbors (per thread)",
                    MemoryRange.of(
                        MemoryUsage.sizeOfIntArray(MemoryUsage.sizeOfOpenHashContainer(sampledK)) * concurrency
                    )
                )
                .build();
        });
    }
}
