package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.ml.batch.BatchQueue;
import org.neo4j.gds.ml.batch.BatchTransformer;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRPredictor;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRResult;
import org.neo4j.graphalgo.Algorithm;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/**
 * Predicts a class for every node of a graph with a trained multi-class logistic regression model.
 *
 * <p>All node ids are split into batches of {@code batchSize} via {@link BatchQueue} and handed to a
 * {@link NodeClassificationPredictConsumer} through {@code BatchQueue#parallelConsume} with the
 * configured concurrency. When requested, per-class probabilities are collected alongside the
 * per-node predictions.
 */
public class NodeClassificationPredict extends Algorithm<NodeClassificationPredict, MultiClassNLRResult> {
    /** Trained model used to classify the nodes. */
    private final MultiClassNLRPredictor predictor;
    /** Graph whose nodes are classified. */
    private final Graph graph;
    /** Number of node ids per batch handed to the consumer. */
    private final int batchSize;
    /** Concurrency passed to {@code BatchQueue#parallelConsume}. */
    private final int concurrency;
    /** Whether the result should also carry one probability array per node. */
    private final boolean produceProbabilities;
    /** Node property names passed to the consumer as model features. */
    private final List<String> featureProperties;
    /** Tracker used when allocating the result arrays. */
    private final AllocationTracker tracker;

    /**
     * Creates the prediction algorithm.
     *
     * <p>Note: this constructor was package-private in the original source; it is kept public so
     * existing callers continue to compile.
     *
     * @param predictor            the trained model used to classify nodes
     * @param graph                the graph whose nodes are classified
     * @param batchSize            number of node ids per batch
     * @param concurrency          concurrency used when consuming the batches
     * @param produceProbabilities whether per-class probabilities are collected in the result
     * @param featureProperties    node property names passed to the consumer as features
     * @param tracker              allocation tracker for the result arrays
     * @param progressLogger       logger notified at the start and end of {@link #compute()}
     */
    public NodeClassificationPredict(
        MultiClassNLRPredictor predictor,
        Graph graph,
        int batchSize,
        int concurrency,
        boolean produceProbabilities,
        List<String> featureProperties,
        AllocationTracker tracker,
        ProgressLogger progressLogger
    ) {
        this.predictor = predictor;
        this.graph = graph;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.produceProbabilities = produceProbabilities;
        this.featureProperties = featureProperties;
        this.tracker = tracker;
        // progressLogger is a field inherited from Algorithm.
        this.progressLogger = progressLogger;
    }

    /**
     * Runs the prediction over all nodes of the graph.
     *
     * @return a result holding one prediction per node and, when probabilities were requested, one
     *     probability array per node (the probabilities part is {@code null} otherwise)
     */
    public MultiClassNLRResult compute() {
        progressLogger.logStart();

        long nodeCount = graph.nodeCount();
        HugeObjectArray<double[]> probabilities = initProbabilities();
        // One slot per node; filled by the consumer while batches are processed.
        HugeLongArray predictions = HugeLongArray.newArray(nodeCount, tracker);

        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            graph,
            BatchTransformer.IDENTITY,
            predictor,
            probabilities,
            predictions,
            featureProperties,
            progressLogger
        );
        new BatchQueue(nodeCount, batchSize).parallelConsume(consumer, concurrency);

        progressLogger.logFinish();
        return MultiClassNLRResult.of(predictions, probabilities);
    }

    /**
     * Returns this instance, typed as the concrete algorithm.
     *
     * @return {@code this}
     */
    public NodeClassificationPredict me() {
        return this;
    }

    /**
     * Decompiler-generated alias kept for source compatibility.
     *
     * @return the same value as {@link #compute()}
     * @deprecated use {@link #compute()} instead.
     */
    @Deprecated
    public MultiClassNLRResult m13compute() {
        return compute();
    }

    /**
     * Decompiler-generated alias kept for source compatibility.
     *
     * @return {@code this}
     * @deprecated use {@link #me()} instead.
     */
    @Deprecated
    public NodeClassificationPredict m12me() {
        return me();
    }

    /** No-op: this algorithm holds nothing that needs explicit release. */
    public void release() {
    }

    /**
     * Allocates one zero-filled probability array per node, sized to the number of classes known
     * to the model.
     *
     * @return the probability arrays, or {@code null} when probabilities were not requested
     */
    @Nullable
    private HugeObjectArray<double[]> initProbabilities() {
        if (!produceProbabilities) {
            return null;
        }
        int classCount = predictor.modelData().classIdMap().originalIds().length;
        HugeObjectArray<double[]> probabilities =
            HugeObjectArray.newArray(double[].class, graph.nodeCount(), tracker);
        probabilities.setAll(nodeId -> new double[classCount]);
        return probabilities;
    }
}
