package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.List;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.mem.Estimate;
import org.neo4j.gds.mem.MemoryEstimateDefinition;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryRange;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainEstimateDefinition.class */
/**
 * Memory estimation for GraphSage training.
 *
 * <p>The estimation has two parts:
 * <ul>
 *   <li>"residentMemory": the weights of every layer.</li>
 *   <li>"temporaryMemory": several training-time components.
 *     <ul>
 *       <li>The algorithm instance.</li>
 *       <li>Per-label weights (multi-label only).</li>
 *       <li>The initial node features (per node).</li>
 *       <li>The per-epoch optimizer state, plus per-thread batch training state.</li>
 *     </ul>
 *   </li>
 * </ul>
 */
public class GraphSageTrainEstimateDefinition implements MemoryEstimateDefinition {
    private final GraphSageTrainMemoryEstimateParameters parameters;

    /**
     * @param graphSageTrainMemoryEstimateParameters the training parameters the estimation is derived from
     */
    public GraphSageTrainEstimateDefinition(GraphSageTrainMemoryEstimateParameters graphSageTrainMemoryEstimateParameters) {
        this.parameters = graphSageTrainMemoryEstimateParameters;
    }

    /**
     * Returns an estimation that is resolved lazily against the graph dimensions.
     * It uses their node count and estimated node label count.
     */
    public MemoryEstimation memoryEstimation() {
        return MemoryEstimations.setup(
            "",
            graphDimensions -> estimate(
                this.parameters,
                graphDimensions.nodeCount(),
                graphDimensions.estimationNodeLabelCount()
            )
        );
    }

    /**
     * Builds the full training estimation for the given graph size.
     *
     * @param params         training parameters (layers, features, batch size, multi-label flag)
     * @param nodeCount      number of nodes in the graph
     * @param nodeLabelCount estimated number of node labels
     * @return the composed memory estimation
     */
    private MemoryEstimation estimate(
        GraphSageTrainMemoryEstimateParameters params,
        long nodeCount,
        int nodeLabelCount
    ) {
        List<LayerParameters> layers = params.layerParameters();
        int layerCount = layers.size();

        MemoryEstimations.Builder weights = MemoryEstimations.builder("GraphSageTrain")
            .startField("residentMemory")
            .startField("weights");

        long initialAdamOptimizer = 0L;
        long updateAdamOptimizer = 0L;
        for (int layerIndex = 0; layerIndex < layerCount; layerIndex++) {
            LayerParameters layer = layers.get(layerIndex);
            // Widen before multiplying: an int product of large dimensions would silently overflow.
            long weightCount = (long) layer.rows() * layer.cols();

            weights.fixed("layer " + (layerIndex + 1), layerWeightsSize(layer, weightCount));

            initialAdamOptimizer += 2 * Estimate.sizeOfDoubleArray(weightCount);
            // NOTE(review): this adds an element count, while the other fixed(...) values in this
            // method are sizeOfDoubleArray(...) results. Confirm whether it should be scaled to bytes.
            updateAdamOptimizer += 5 * weightCount;
        }

        MemoryEstimations.Builder temporaryMemory = weights
            .endField()
            .endField()
            .startField("temporaryMemory")
            .field("this.instance", GraphSage.class);

        boolean isMultiLabel = params.isMultiLabel();
        long featureDimension = params.estimationFeatureDimension();
        if (isMultiLabel) {
            temporaryMemory.fixed("weightsByLabel", weightsByLabelRange(params, featureDimension).times(nodeLabelCount));
        }

        return temporaryMemory
            .rangePerNode("initialFeatures", nodes -> MemoryRange.of(
                // Lower bound: multi-label graphs may hold a single value per node.
                HugeObjectArray.memoryEstimation(nodes, Estimate.sizeOfDoubleArray(isMultiLabel ? 1L : featureDimension)),
                HugeObjectArray.memoryEstimation(nodes, Estimate.sizeOfDoubleArray(featureDimension))
            ))
            .startField("trainOnEpoch")
            .fixed("initialAdamOptimizer", initialAdamOptimizer)
            .perThread("concurrentBatches", trainOnBatchEstimation(params, nodeCount, nodeLabelCount, updateAdamOptimizer))
            .endField()
            .endField()
            .build();
    }

    /**
     * Computes the size of one layer's weights.
     *
     * <p>The base size is the rows x cols weight array. For POOL aggregation it adds two
     * rows x rows arrays and one rows-sized array.
     */
    private static long layerWeightsSize(LayerParameters layer, long weightCount) {
        long size = Estimate.sizeOfDoubleArray(weightCount);
        if (layer.aggregatorType() == AggregatorType.POOL) {
            long squareCount = (long) layer.rows() * layer.rows();
            size += 2 * Estimate.sizeOfDoubleArray(squareCount) + Estimate.sizeOfDoubleArray(layer.rows());
        }
        return size;
    }

    /**
     * Returns the range of weights needed for one label (multi-label case).
     *
     * <p>The lower bound is a featureDimension array. The upper bound is a
     * featureDimension * (numberOfFeatureProperties + 1) array.
     */
    private static MemoryRange weightsByLabelRange(
        GraphSageTrainMemoryEstimateParameters params,
        long featureDimension
    ) {
        long minSize = Estimate.sizeOfDoubleArray(featureDimension);
        long maxSize = Estimate.sizeOfDoubleArray(featureDimension * (params.numberOfFeatureProperties() + 1));
        return MemoryRange.of(minSize, maxSize);
    }

    /**
     * Returns the estimation for training on a single batch; it is counted once per thread.
     *
     * <p>The embeddings are estimated for 3 * batchSize nodes. Presumably that covers each batch
     * node plus sampled positive/negative neighbours (TODO confirm against the training code).
     */
    private static MemoryEstimation trainOnBatchEstimation(
        GraphSageTrainMemoryEstimateParameters params,
        long nodeCount,
        int nodeLabelCount,
        long updateAdamOptimizer
    ) {
        return MemoryEstimations.builder()
            .startField("trainOnBatch")
            .add(GraphSageHelper.embeddingsEstimation(params, 3 * params.batchSize(), nodeCount, nodeLabelCount, true))
            .fixed("updateAdamOptimizer", updateAdamOptimizer)
            .endField()
            .build();
    }
}
