package org.neo4j.gds.embeddings.node2vec;

import org.neo4j.gds.MemoryEstimateDefinition;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryUsage;

/**
 * Memory estimate for the Node2Vec algorithm.
 *
 * <p>The estimate is assembled from three components:
 * <ul>
 *   <li>"random walks": a {@code HugeObjectArray} sized for {@code nodeCount * walksPerNode}
 *       entries, each accounted as a {@code long[walkLength]};</li>
 *   <li>"probability cache": per-node arrays for node frequencies, positive sampling
 *       probabilities and the negative sampling distribution;</li>
 *   <li>"model": per-node center and context embeddings, each accounted as a
 *       {@code float[embeddingDimension]}.</li>
 * </ul>
 */
public final class Node2VecMemoryEstimateDefinition implements MemoryEstimateDefinition {
    private final Node2VecParameters parameters;

    /**
     * @param node2VecParameters the walk and training parameters the estimate is derived from
     */
    public Node2VecMemoryEstimateDefinition(Node2VecParameters node2VecParameters) {
        this.parameters = node2VecParameters;
    }

    /**
     * Builds the full Node2Vec memory estimation tree.
     *
     * @return an estimation with the "random walks", "probability cache" and "model" components
     */
    public MemoryEstimation memoryEstimation() {
        var walkParameters = this.parameters.walkParameters();
        int walksPerNode = walkParameters.walksPerNode;
        int walkLength = walkParameters.walkLength;
        int embeddingDimension = this.parameters.trainParameters().embeddingDimension;

        MemoryEstimation probabilityCache = probabilityCacheMemoryEstimation();
        MemoryEstimation model = modelMemoryEstimation(embeddingDimension);

        return MemoryEstimations.builder(Node2Vec.class)
            // nodeCount is a long, so the walk count below is computed in long arithmetic.
            .perNode("random walks", nodeCount ->
                HugeObjectArray.memoryEstimation(nodeCount * walksPerNode, MemoryUsage.sizeOfLongArray(walkLength)))
            .add("probability cache", probabilityCache)
            .add("model", model)
            .build();
    }

    /**
     * Estimation for the {@code RandomWalkProbabilities} cache: one huge array per node for each
     * of the frequency, positive-sampling and negative-sampling tables.
     */
    private MemoryEstimation probabilityCacheMemoryEstimation() {
        return MemoryEstimations.builder(RandomWalkProbabilities.class)
            .perNode("node frequencies", HugeLongArray::memoryEstimation)
            .perNode("positive sampling probabilities", HugeDoubleArray::memoryEstimation)
            .perNode("negative sampling distribution", HugeLongArray::memoryEstimation)
            .build();
    }

    /**
     * Estimation for the {@code Node2VecModel}: two per-node embedding tables (center and context),
     * each entry accounted as a {@code float[embeddingDimension]}.
     *
     * @param embeddingDimension length of a single embedding vector
     */
    private MemoryEstimation modelMemoryEstimation(int embeddingDimension) {
        long embeddingVectorSize = MemoryUsage.sizeOfFloatArray(embeddingDimension);
        return MemoryEstimations.builder(Node2VecModel.class)
            .perNode("center embeddings", nodeCount ->
                HugeObjectArray.memoryEstimation(nodeCount, embeddingVectorSize))
            .perNode("context embeddings", nodeCount ->
                HugeObjectArray.memoryEstimation(nodeCount, embeddingVectorSize))
            .build();
    }
}
