package org.neo4j.gds.embeddings.graphsage.algo;

import java.util.List;
import org.neo4j.gds.embeddings.graphsage.Aggregator;
import org.neo4j.gds.embeddings.graphsage.GraphSageHelper;
import org.neo4j.gds.embeddings.graphsage.LayerConfig;
import org.neo4j.graphalgo.AbstractAlgorithmFactory;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimations;
import org.neo4j.graphalgo.core.utils.mem.MemoryRange;
import org.neo4j.graphalgo.core.utils.mem.MemoryUsage;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/algo/GraphSageTrainAlgorithmFactory.class */
/**
 * Factory that builds {@link GraphSageTrain} instances and produces the memory estimation
 * for a GraphSage training run described by a {@link GraphSageTrainConfig}.
 */
public final class GraphSageTrainAlgorithmFactory extends AbstractAlgorithmFactory<GraphSageTrain, GraphSageTrainConfig> {

    /** Multiplier applied to each layer's weight count when accumulating "updateAdamOptimizer". */
    private static final long ADAM_UPDATE_FACTOR = 5L;

    /** Multiplier applied to each layer's weight array size when accumulating "initialAdamOptimizer". */
    private static final long ADAM_INITIAL_FACTOR = 2L;

    /** Creates a factory using the superclass's default construction. */
    public GraphSageTrainAlgorithmFactory() {
    }

    /**
     * Creates a factory that hands the given progress logger factory to the superclass.
     *
     * @param progressLoggerFactory factory passed through to {@code AbstractAlgorithmFactory}
     */
    public GraphSageTrainAlgorithmFactory(ProgressLogger.ProgressLoggerFactory progressLoggerFactory) {
        super(progressLoggerFactory);
    }

    /**
     * Returns the task volume for the given graph and configuration.
     * This implementation ignores both arguments and always reports {@code 1}.
     *
     * @param graph                the graph to train on (unused)
     * @param graphSageTrainConfig the training configuration (unused)
     * @return always {@code 1L}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public long taskVolume(Graph graph, GraphSageTrainConfig graphSageTrainConfig) {
        return 1L;
    }

    /**
     * Returns the task name, which is the simple class name of {@link GraphSageTrain}.
     *
     * @return {@code GraphSageTrain.class.getSimpleName()}
     */
    protected String taskName() {
        return GraphSageTrain.class.getSimpleName();
    }

    /**
     * Builds the training algorithm: a {@link MultiLabelGraphSageTrain} when the configuration
     * reports {@code isMultiLabel()}, otherwise a {@link SingleLabelGraphSageTrain}.
     *
     * @param graph                the graph to train on
     * @param graphSageTrainConfig the training configuration
     * @param allocationTracker    allocation tracker handed to the algorithm
     * @param progressLogger       progress logger handed to the algorithm
     * @return the training algorithm matching the configuration
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public GraphSageTrain build(Graph graph, GraphSageTrainConfig graphSageTrainConfig, AllocationTracker allocationTracker, ProgressLogger progressLogger) {
        if (graphSageTrainConfig.isMultiLabel()) {
            return new MultiLabelGraphSageTrain(graph, graphSageTrainConfig, progressLogger, allocationTracker);
        }
        return new SingleLabelGraphSageTrain(graph, graphSageTrainConfig, progressLogger, allocationTracker);
    }

    /**
     * Returns a memory estimation that is resolved lazily from the graph dimensions:
     * the node count and the estimation node label count are read from the dimensions
     * and combined with the given configuration.
     *
     * @param graphSageTrainConfig the training configuration
     * @return the deferred memory estimation
     */
    public MemoryEstimation memoryEstimation(GraphSageTrainConfig graphSageTrainConfig) {
        return MemoryEstimations.setup(
            "",
            graphDimensions -> estimate(
                graphSageTrainConfig,
                graphDimensions.nodeCount(),
                graphDimensions.estimationNodeLabelCount()
            )
        );
    }

    /**
     * Assembles the estimation tree: "residentMemory" holds the per-layer weights, and
     * "temporaryMemory" holds the instance, optional "weightsByLabel", "initialFeatures" and
     * the "trainOnEpoch" section with its per-thread "trainOnBatch" part.
     *
     * <p>All size products are computed in {@code long} so that large layer dimensions
     * cannot wrap around in {@code int} arithmetic before being widened.
     *
     * @param graphSageTrainConfig the training configuration
     * @param nodeCount            number of nodes, forwarded to the embeddings estimation
     * @param labelCount           number of node labels; scales "weightsByLabel" and is
     *                             forwarded to the embeddings estimation
     * @return the complete memory estimation
     */
    private MemoryEstimation estimate(GraphSageTrainConfig graphSageTrainConfig, long nodeCount, int labelCount) {
        List<LayerConfig> layerConfigs = graphSageTrainConfig.layerConfigs();
        int numberOfLayers = layerConfigs.size();

        MemoryEstimations.Builder weightsBuilder = MemoryEstimations.builder("GraphSageTrain")
            .startField("residentMemory")
            .startField("weights");

        long initialAdamOptimizer = 0L;
        long updateAdamOptimizer = 0L;
        for (int layerIndex = 0; layerIndex < numberOfLayers; layerIndex++) {
            LayerConfig layerConfig = layerConfigs.get(layerIndex);
            // Widen before multiplying: rows * cols may exceed Integer.MAX_VALUE.
            long weightCount = (long) layerConfig.rows() * layerConfig.cols();

            weightsBuilder.fixed("layer " + (layerIndex + 1), layerWeightsMemory(layerConfig, weightCount));

            initialAdamOptimizer += ADAM_INITIAL_FACTOR * MemoryUsage.sizeOfDoubleArray(weightCount);
            // NOTE(review): this accumulates an element count, whereas initialAdamOptimizer
            // accumulates sizeOfDoubleArray(...) results; kept as-is, confirm this is intended.
            updateAdamOptimizer += ADAM_UPDATE_FACTOR * weightCount;
        }

        boolean isMultiLabel = graphSageTrainConfig.isMultiLabel();
        long featuresSize = graphSageTrainConfig.featuresSize();

        // Lower bound of the per-node feature array length depends on the multi-label setting;
        // the upper bound is always the configured features size.
        long minFeatureCount;
        if (isMultiLabel) {
            minFeatureCount = graphSageTrainConfig.degreeAsProperty() ? 2L : 1L;
        } else {
            minFeatureCount = featuresSize;
        }
        MemoryEstimation initialFeatures = HugeObjectArray.memoryEstimation(
            MemoryEstimations.of(
                "",
                MemoryRange.of(
                    MemoryUsage.sizeOfDoubleArray(minFeatureCount),
                    MemoryUsage.sizeOfDoubleArray(featuresSize)
                )
            )
        );

        MemoryEstimations.Builder temporaryMemory = weightsBuilder
            .endField()
            .endField()
            .startField("temporaryMemory")
            .field("this.instance", GraphSage.class);

        if (isMultiLabel) {
            boolean degreeAsProperty = graphSageTrainConfig.degreeAsProperty();
            long minProperties = degreeAsProperty ? 2L : 1L;
            long maxProperties = (long) graphSageTrainConfig.featureProperties().size()
                + (degreeAsProperty ? 1L : 0L)
                + 1L;
            MemoryRange weightsPerLabel = MemoryRange.of(
                MemoryUsage.sizeOfDoubleArray(featuresSize * minProperties),
                MemoryUsage.sizeOfDoubleArray(featuresSize * maxProperties)
            );
            temporaryMemory.fixed("weightsByLabel", weightsPerLabel.times(labelCount));
        }

        return temporaryMemory
            .add("initialFeatures", initialFeatures)
            .startField("trainOnEpoch")
            .fixed("initialAdamOptimizer", initialAdamOptimizer)
            .perThread(
                "concurrentBatches",
                MemoryEstimations.builder()
                    .startField("trainOnBatch")
                    .add(GraphSageHelper.embeddingsEstimation(
                        graphSageTrainConfig,
                        3 * graphSageTrainConfig.batchSize(),
                        nodeCount,
                        labelCount,
                        true
                    ))
                    .fixed("updateAdamOptimizer", updateAdamOptimizer)
                    .endField()
                    .build()
            )
            .endField()
            .endField()
            .build();
    }

    /**
     * Computes the weight memory reported for a single layer: one double array of
     * {@code weightCount} elements, plus, for the POOL aggregator, two arrays of
     * {@code rows * rows} elements and one array of {@code rows} elements.
     */
    private static long layerWeightsMemory(LayerConfig layerConfig, long weightCount) {
        long weightsMemory = MemoryUsage.sizeOfDoubleArray(weightCount);
        if (layerConfig.aggregatorType() == Aggregator.AggregatorType.POOL) {
            long rows = layerConfig.rows();
            weightsMemory += MemoryUsage.sizeOfDoubleArray(rows * rows);
            weightsMemory += MemoryUsage.sizeOfDoubleArray(rows * rows);
            weightsMemory += MemoryUsage.sizeOfDoubleArray(rows);
        }
        return weightsMemory;
    }
}
