package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.NodeProperties;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.Predictor;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.nodemodels.metrics.ClassificationMetric;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/ClassificationMetricComputer.class */
/**
 * {@link MetricComputer} that evaluates a predictor on a set of node ids and scores the result
 * with a list of classification metrics.
 *
 * <p>Predictions are produced in parallel by a {@code NodeClassificationPredictConsumer}. The
 * expected labels are read from the graph's node property named by
 * {@link NodeClassificationTrainConfig#targetProperty()}.
 */
public class ClassificationMetricComputer implements MetricComputer {
    private final AllocationTracker allocationTracker;
    private final List<Metric> metrics;
    private final Multiset<Long> classCounts;
    private final Graph graph;
    private final NodeClassificationTrainConfig config;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * @param allocationTracker tracker used when allocating the prediction and target arrays
     * @param metrics           metrics to compute; each is cast to {@link ClassificationMetric}
     *                          in {@link #computeMetrics}
     * @param classCounts       per-class counts, passed unchanged to every metric
     * @param graph             graph providing node features and the target property
     * @param config            supplies feature properties, target property and concurrency
     * @param progressTracker   forwarded to the prediction consumer
     * @param terminationFlag   checked by the parallel batch consumption
     */
    public ClassificationMetricComputer(
        AllocationTracker allocationTracker,
        List<Metric> metrics,
        Multiset<Long> classCounts,
        Graph graph,
        NodeClassificationTrainConfig config,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.allocationTracker = allocationTracker;
        this.metrics = metrics;
        this.classCounts = classCounts;
        this.graph = graph;
        this.config = config;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Runs {@code predictor} over every node id in {@code evaluationSet} and computes each
     * configured metric against the nodes' target-property values.
     *
     * @param evaluationSet node ids to evaluate
     * @param predictor     model used to predict a class for each node
     * @return one entry per configured metric, mapping the metric to its score
     * @throws ClassCastException    if a configured metric is not a {@link ClassificationMetric}
     * @throws IllegalStateException if the configured metric list contains duplicate keys
     *                               (behavior of {@link Collectors#toMap})
     */
    @Override
    public Map<Metric, Double> computeMetrics(HugeLongArray evaluationSet, Predictor<Matrix, ?> predictor) {
        HugeLongArray predictedClasses = HugeLongArray.newArray(evaluationSet.size(), this.allocationTracker);

        // NOTE(review): the fourth argument is null. It is presumably an optional output
        // (e.g. predicted probabilities) that metric computation does not need. Confirm
        // against NodeClassificationPredictConsumer.
        var predictConsumer = new NodeClassificationPredictConsumer(
            this.graph,
            evaluationSet::get,
            predictor,
            null,
            predictedClasses,
            this.config.featureProperties(),
            this.progressTracker
        );
        new BatchQueue(evaluationSet.size())
            .parallelConsume(predictConsumer, this.config.concurrency(), this.terminationFlag);

        HugeLongArray localTargets = makeLocalTargets(evaluationSet);

        // Collectors.toMap is target-typed to Map<Metric, Double>, so no raw-type cast is needed.
        return this.metrics.stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> Double.valueOf(
                ((ClassificationMetric) metric).compute(localTargets, predictedClasses, this.classCounts)
            )
        ));
    }

    /**
     * Builds an array aligned with {@code evaluationSet}: position {@code i} holds the long
     * value of the target property for node {@code evaluationSet.get(i)}.
     */
    private HugeLongArray makeLocalTargets(HugeLongArray evaluationSet) {
        HugeLongArray targets = HugeLongArray.newArray(evaluationSet.size(), this.allocationTracker);
        NodeProperties targetProperty = this.graph.nodeProperties(this.config.targetProperty());
        targets.setAll(index -> targetProperty.longValue(evaluationSet.get(index)));
        return targets;
    }
}
