package org.neo4j.gds.ml.nodeClassification;

import java.util.Iterator;
import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchTransformer;
import org.neo4j.gds.ml.core.batch.MappedBatch;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;

/* loaded from: input_file:org/neo4j/gds/ml/nodeClassification/NodeClassificationPredictConsumer.class */
/**
 * Batch consumer that runs a {@link Classifier} over each incoming {@link Batch}.
 *
 * <p>For every node id in the batch it writes the arg-max class (translated through the
 * classifier's class id map) into {@code predictedClasses}. When a probabilities sink was
 * supplied, it also stores that node's full probability row.
 */
public class NodeClassificationPredictConsumer implements Consumer<Batch> {
    private final Features features;
    private final BatchTransformer nodeIds;
    private final Classifier classifier;

    /** Optional sink for per-node class probabilities; {@code null} means "do not store". */
    @Nullable
    private final HugeObjectArray<double[]> predictedProbabilities;
    private final HugeLongArray predictedClasses;
    private final ProgressTracker progressTracker;

    /**
     * Creates a consumer writing predictions into the given output arrays.
     *
     * @param features               feature source passed to the classifier
     * @param nodeIds                transformer applied to each batch (via {@link MappedBatch})
     *                               before prediction
     * @param classifier             model used to compute class probabilities
     * @param predictedProbabilities optional output for per-node probability rows; may be
     *                               {@code null}, in which case probabilities are not stored
     * @param predictedClasses       output for the predicted (original) class id of each node
     * @param progressTracker        receives {@code batch.size()} progress after each batch
     */
    // NOTE(review): the decompiler reports this constructor was originally package-private;
    // it is kept public here so that existing callers keep compiling.
    public NodeClassificationPredictConsumer(
        Features features,
        BatchTransformer nodeIds,
        Classifier classifier,
        @Nullable HugeObjectArray<double[]> predictedProbabilities,
        HugeLongArray predictedClasses,
        ProgressTracker progressTracker
    ) {
        this.features = features;
        this.nodeIds = nodeIds;
        this.classifier = classifier;
        this.predictedProbabilities = predictedProbabilities;
        this.predictedClasses = predictedClasses;
        this.progressTracker = progressTracker;
    }

    /**
     * Predicts classes for every node in {@code batch} and records the results.
     *
     * <p>Row {@code r} of the probability matrix is read for the {@code r}-th node id yielded
     * by {@code batch.nodeIds()}. Results are stored at that node id in the output arrays.
     *
     * @param batch the batch of node ids to classify
     */
    @Override
    public void accept(Batch batch) {
        int numberOfClasses = classifier.numberOfClasses();
        Matrix probabilities = classifier.predictProbabilities(new MappedBatch(batch, nodeIds), features);

        int row = 0;
        for (long nodeId : batch.nodeIds()) {
            if (predictedProbabilities != null) {
                predictedProbabilities.set(nodeId, probabilities.getRow(row));
            }
            // NOTE(review): argMax yields -1 when no probability exceeds -1.0 (zero classes or an
            // all-NaN row); that value is handed to toOriginal unchanged — confirm it is acceptable.
            int bestClass = argMax(probabilities, row, numberOfClasses);
            predictedClasses.set(nodeId, classifier.classIdMap().toOriginal(bestClass));
            row++;
        }

        progressTracker.logProgress(batch.size());
    }

    /**
     * Returns the column index of the largest value in row {@code row}, considering the first
     * {@code numberOfClasses} columns.
     *
     * <p>Ties resolve to the lowest index because only strictly greater values replace the
     * current best. Returns -1 if no value is greater than -1.0.
     */
    private static int argMax(Matrix probabilities, int row, int numberOfClasses) {
        int bestClass = -1;
        double bestProbability = -1.0d;
        for (int classId = 0; classId < numberOfClasses; classId++) {
            double probability = probabilities.dataAt(row, classId);
            if (probability > bestProbability) {
                bestProbability = probability;
                bestClass = classId;
            }
        }
        return bestClass;
    }
}
