package org.neo4j.gds.ml.nodeClassification;

import java.util.Objects;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.BatchTransformer;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/nodeClassification/ParallelNodeClassifier.class */
public class ParallelNodeClassifier {
    private final Classifier classifier;
    private final Features features;
    private final int batchSize;
    private final int concurrency;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;

    /**
     * Creates a classifier runner that splits prediction work into batches consumed in parallel.
     *
     * <p>NOTE(review): the decompiler reported this constructor was originally package-private;
     * it is kept {@code public} here so existing callers continue to compile.
     *
     * @param classifier      the trained classifier handed to each batch consumer
     * @param features        feature source; its {@code size()} bounds the no-argument-ids overload
     * @param batchSize       number of elements per batch passed to {@link BatchQueue}
     * @param concurrency     parallelism used both to size the {@link BatchQueue} and to consume it
     * @param terminationFlag checked by the batch queue while consuming
     * @param progressTracker handed to each batch consumer
     */
    public ParallelNodeClassifier(
        Classifier classifier,
        Features features,
        int batchSize,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        this.classifier = classifier;
        this.features = features;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    /**
     * Predicts for the elements listed in {@code nodeIds}.
     *
     * <p>The result has {@code nodeIds.size()} entries; batch offsets are mapped through
     * {@code nodeIds.get(...)} before reaching the consumer (presumably so that result index
     * {@code i} holds the prediction for {@code nodeIds.get(i)} — confirm against
     * {@code NodeClassificationPredictConsumer}). No probability buffer is filled.
     *
     * @param nodeIds ids to classify; must not be {@code null}
     * @return a new array of length {@code nodeIds.size()} filled by the predict consumer
     * @throws NullPointerException if {@code nodeIds} is {@code null}
     */
    public HugeLongArray predict(HugeLongArray nodeIds) {
        // Check before the first dereference; previously the check ran after size() and was dead.
        Objects.requireNonNull(nodeIds, "nodeIds");
        return predict(nodeIds.size(), nodeIds::get, null);
    }

    /**
     * Predicts for every element of the feature source, in order ({@link BatchTransformer#IDENTITY}).
     *
     * @param predictedProbabilities optional buffer forwarded to the predict consumer; presumably
     *                               receives per-element class probabilities when non-null —
     *                               TODO confirm against {@code NodeClassificationPredictConsumer}
     * @return a new array of length {@code features.size()} filled by the predict consumer
     */
    public HugeLongArray predict(@Nullable HugeObjectArray<double[]> predictedProbabilities) {
        return predict(this.features.size(), BatchTransformer.IDENTITY, predictedProbabilities);
    }

    /**
     * Shared implementation: allocates the result array and consumes {@code elementCount} elements
     * in batches of {@code batchSize} using {@code concurrency} workers.
     *
     * @param elementCount           number of elements to process and length of the result
     * @param batchTransformer       maps batch offsets to the ids handed to the classifier
     * @param predictedProbabilities optional buffer forwarded unchanged to the consumer
     * @return the result array written by the consumer
     */
    private HugeLongArray predict(
        long elementCount,
        BatchTransformer batchTransformer,
        @Nullable HugeObjectArray<double[]> predictedProbabilities
    ) {
        HugeLongArray predictedClasses = HugeLongArray.newArray(elementCount);
        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            this.features,
            batchTransformer,
            this.classifier,
            predictedProbabilities,
            predictedClasses,
            this.progressTracker
        );
        new BatchQueue(elementCount, this.batchSize, this.concurrency)
            .parallelConsume(consumer, this.concurrency, this.terminationFlag);
        return predictedClasses;
    }
}
