package org.neo4j.gds.ml.nodemodels.multiclasslogisticregression;

import java.util.Iterator;
import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.embeddings.graphsage.ddl4j.tensor.Matrix;
import org.neo4j.gds.ml.Predictor;
import org.neo4j.gds.ml.batch.Batch;
import org.neo4j.gds.ml.batch.BatchTransformer;
import org.neo4j.gds.ml.batch.MappedBatch;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.utils.ProgressLogger;
import org.neo4j.graphalgo.core.utils.paged.HugeLongArray;
import org.neo4j.graphalgo.core.utils.paged.HugeObjectArray;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/multiclasslogisticregression/NodeClassificationPredictConsumer.class */
/**
 * Consumes batches of node ids, runs the node-classification {@link Predictor} on each batch and
 * records, for every node of the batch, the predicted class and (optionally) the per-class scores.
 *
 * <p>Each incoming batch is wrapped in a {@link MappedBatch} using {@code nodeIds} before it is
 * handed to the predictor. The results, however, are written under the ids yielded by the
 * original (un-mapped) batch.
 */
public class NodeClassificationPredictConsumer implements Consumer<Batch> {
    private final Graph graph;
    private final BatchTransformer nodeIds;
    private final Predictor<Matrix, MultiClassNLRData> predictor;
    /** Per-node score vectors; {@code null} when probabilities are not requested. */
    private final HugeObjectArray<double[]> predictedProbabilities;
    /** Per-node predicted class, expressed as the original (external) class id. */
    private final HugeLongArray predictedClasses;
    private final ProgressLogger progressLogger;

    /**
     * Creates a consumer that writes predictions into the given output arrays.
     *
     * @param graph                  graph passed to {@link Predictor#predict}
     * @param batchTransformer       transformer applied to every batch (via {@link MappedBatch})
     *                               before prediction
     * @param predictor              model used to score the nodes of each batch
     * @param predictedProbabilities destination for each node's copied score row, or {@code null}
     *                               to skip storing scores
     * @param predictedClasses       destination for each node's predicted original class id
     * @param progressLogger         receives {@code batch.size()} after each processed batch
     */
    public NodeClassificationPredictConsumer(
        Graph graph,
        BatchTransformer batchTransformer,
        Predictor<Matrix, MultiClassNLRData> predictor,
        @Nullable HugeObjectArray<double[]> predictedProbabilities,
        HugeLongArray predictedClasses,
        ProgressLogger progressLogger
    ) {
        this.graph = graph;
        this.nodeIds = batchTransformer;
        this.predictor = predictor;
        this.predictedProbabilities = predictedProbabilities;
        this.predictedClasses = predictedClasses;
        this.progressLogger = progressLogger;
    }

    /**
     * Predicts every node of {@code batch} and stores the results.
     *
     * <p>The prediction matrix is read row-major from {@link Matrix#data()}: row {@code r} holds
     * the {@code cols()} scores of the {@code r}-th node yielded by {@code batch.nodeIds()}. The
     * predicted class is the column with the largest non-NaN score (the lowest column wins ties).
     * It is translated to an original class id through the model's class id map.
     *
     * @param batch the nodes to predict
     * @throws IllegalStateException if a node's score row is empty or contains only NaN values,
     *                               so that no class can be selected
     */
    @Override
    public void accept(Batch batch) {
        Matrix predictions = this.predictor.predict(this.graph, new MappedBatch(batch, this.nodeIds));
        int numberOfClasses = predictions.cols();
        double[] scores = predictions.data();
        // Hoisted out of the per-node loop; assumes the predictor's model data is stable for the
        // duration of this call.
        MultiClassNLRData modelData = this.predictor.modelData();

        int row = 0;
        for (long nodeId : batch.nodeIds()) {
            int offset = row * numberOfClasses;

            if (this.predictedProbabilities != null) {
                // Copy the row so the stored vector does not alias the prediction matrix buffer.
                double[] nodeScores = new double[numberOfClasses];
                System.arraycopy(scores, offset, nodeScores, 0, numberOfClasses);
                this.predictedProbabilities.set(nodeId, nodeScores);
            }

            int bestClass = argMax(scores, offset, numberOfClasses);
            if (bestClass < 0) {
                throw new IllegalStateException(
                    "Cannot determine predicted class for node " + nodeId
                    + ": prediction row of " + numberOfClasses + " classes has no non-NaN score"
                );
            }
            this.predictedClasses.set(nodeId, modelData.classIdMap().toOriginal(bestClass));
            row++;
        }

        this.progressLogger.logProgress(batch.size());
    }

    /**
     * Finds the position of the largest non-NaN value in {@code values[offset, offset + length)}.
     *
     * @return the index relative to {@code offset} (lowest index on ties), or {@code -1} if the
     *         range is empty or contains only NaN values
     */
    private static int argMax(double[] values, int offset, int length) {
        int best = -1;
        double bestValue = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < length; i++) {
            double value = values[offset + i];
            if (Double.isNaN(value)) {
                continue;
            }
            // Strict '>' keeps the first maximum; 'best < 0' accepts the first non-NaN value
            // even when it is negative infinity.
            if (best < 0 || value > bestValue) {
                best = i;
                bestValue = value;
            }
        }
        return best;
    }
}
