package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.BatchTransformer;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeClassificationResult;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionPredictor;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/NodeClassificationPredict.class */
/**
 * Predicts a class, and optionally per-class probabilities, for every node of a graph using a
 * trained {@link NodeLogisticRegressionPredictor}.
 *
 * <p>The node id range {@code [0, nodeCount)} is split into batches of {@code batchSize}.
 * The batches are consumed by up to {@code concurrency} workers through a
 * {@code NodeClassificationPredictConsumer}. That consumer receives the two result arrays, both
 * sized to the graph's node count, and presumably fills them per node (confirm in the consumer).
 */
public class NodeClassificationPredict extends Algorithm<NodeClassificationPredict, NodeClassificationResult> {
    // Trained model used to score each batch.
    private final NodeLogisticRegressionPredictor predictor;
    // Graph whose nodes are classified.
    private final Graph graph;
    // Number of node ids per batch handed to the BatchQueue.
    private final int batchSize;
    // Maximum number of workers consuming batches in parallel.
    private final int concurrency;
    // When false, no probability array is allocated and the result carries null for it.
    private final boolean produceProbabilities;
    // Node property keys forwarded to the per-batch consumer.
    private final List<String> featureProperties;
    // Tracker charged for the large result arrays allocated here.
    private final AllocationTracker allocationTracker;

    // NOTE: the decompiler reports this method was package-private in the original source.
    /**
     * Builds a memory estimation for running this algorithm.
     *
     * <p>The estimation contains the following parts:
     * <ul>
     *   <li>A per-node {@code double[classCount]} array, only when probabilities are requested.</li>
     *   <li>A per-node predicted-class array.</li>
     *   <li>A fixed amount for the predictor's computation graph.</li>
     * </ul>
     *
     * @param produceProbabilities whether the probability array is included in the estimate
     * @param batchSize   first size argument forwarded to
     *                    {@code NodeLogisticRegressionPredictor.sizeOfPredictionsVariableInBytes};
     *                    presumably the batch size (TODO confirm)
     * @param featureCount second size argument forwarded to the same call; presumably the number
     *                    of features (TODO confirm)
     * @param classCount  length of each node's probability array; also forwarded to the same call
     * @return the assembled estimation
     */
    public static MemoryEstimation memoryEstimation(
        boolean produceProbabilities,
        int batchSize,
        int featureCount,
        int classCount
    ) {
        MemoryEstimations.Builder estimation = MemoryEstimations.builder(NodeClassificationPredict.class);

        if (produceProbabilities) {
            estimation.perNode(
                "predicted probabilities",
                nodeCount -> HugeObjectArray.memoryEstimation(nodeCount, MemoryUsage.sizeOfDoubleArray(classCount))
            );
        }

        estimation.perNode("predicted classes", HugeLongArray::memoryEstimation);
        estimation.fixed(
            "computation graph",
            NodeLogisticRegressionPredictor.sizeOfPredictionsVariableInBytes(batchSize, featureCount, classCount)
        );

        return estimation.build();
    }

    /**
     * Creates the algorithm. No work happens until {@code m73compute()} is called.
     *
     * @param predictor            trained model used to score nodes
     * @param graph                graph whose nodes are classified
     * @param batchSize            number of node ids per batch
     * @param concurrency          maximum number of parallel batch consumers
     * @param produceProbabilities whether to also materialize per-class probabilities
     * @param featureProperties    node property keys handed to the batch consumer
     * @param allocationTracker    tracker charged for the result arrays
     * @param progressTracker      progress reporting, passed to the base class and the consumer
     */
    public NodeClassificationPredict(
        NodeLogisticRegressionPredictor predictor,
        Graph graph,
        int batchSize,
        int concurrency,
        boolean produceProbabilities,
        List<String> featureProperties,
        AllocationTracker allocationTracker,
        ProgressTracker progressTracker
    ) {
        super(progressTracker);
        this.predictor = predictor;
        this.graph = graph;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.produceProbabilities = produceProbabilities;
        this.featureProperties = featureProperties;
        this.allocationTracker = allocationTracker;
    }

    // NOTE: the decompiler renamed this from `compute` after merging it with its bridge method.
    /**
     * Classifies every node of the graph.
     *
     * <p>Progress is reported as one sub-task. The batch queue is given the algorithm's
     * termination flag.
     *
     * @return the predicted classes, plus the probability array; the latter is {@code null}
     *         unless probabilities were requested
     */
    public NodeClassificationResult m73compute() {
        progressTracker.beginSubTask();

        HugeObjectArray<double[]> probabilities = initProbabilities();
        long nodeCount = graph.nodeCount();
        HugeLongArray predictedClasses = HugeLongArray.newArray(nodeCount, allocationTracker);

        // The queue is built before the consumer, matching the original evaluation order.
        BatchQueue batches = new BatchQueue(nodeCount, batchSize);
        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            graph,
            BatchTransformer.IDENTITY,
            predictor,
            probabilities,
            predictedClasses,
            featureProperties,
            progressTracker
        );
        batches.parallelConsume(consumer, concurrency, terminationFlag);

        progressTracker.endSubTask();
        return NodeClassificationResult.of(predictedClasses, probabilities);
    }

    // NOTE: the decompiler renamed this from `me` after merging it with its bridge method.
    /**
     * Returns this instance with its concrete type.
     *
     * @return {@code this}
     */
    public NodeClassificationPredict m72me() {
        return this;
    }

    /**
     * Intentionally empty: this class holds no resources of its own to free.
     */
    public void release() {
    }

    /**
     * Allocates one zeroed {@code double[]} per node, sized to the model's number of class ids.
     * This happens only when probabilities were requested.
     *
     * @return the probability array, or {@code null} when probabilities are not produced
     */
    @Nullable
    private HugeObjectArray<double[]> initProbabilities() {
        if (produceProbabilities) {
            int classCount = predictor.modelData().classIdMap().originalIds().length;
            HugeObjectArray<double[]> probabilities =
                HugeObjectArray.newArray(double[].class, graph.nodeCount(), allocationTracker);
            probabilities.setAll(nodeId -> new double[classCount]);
            return probabilities;
        }
        return null;
    }
}
