package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.nodemodels.metrics.ClassificationMetric;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.Features;
import org.openjdk.jol.util.Multiset;

/* loaded from: input_file:org/neo4j/gds/ml/nodemodels/ClassificationMetricComputer.class */
/**
 * Scores a {@link Classifier} on an evaluation set using a fixed list of classification metrics.
 *
 * <p>For each call to {@link #computeMetrics(HugeLongArray, Classifier)}, classes are predicted for
 * all evaluated nodes in parallel, and each configured metric then compares those predictions
 * position-wise with the ground-truth targets of the same nodes.
 */
public class ClassificationMetricComputer implements MetricComputer {
    private final List<Metric> metrics;
    private final Multiset<Long> classCounts;
    private final Features features;
    private final HugeLongArray targets;
    private final int concurrency;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * Creates a metric computer.
     *
     * @param metrics         metrics to evaluate; every element must be a {@link ClassificationMetric}
     *                        because it is cast to that type when computing. The list is stored by
     *                        reference, not copied.
     * @param classCounts     multiset of class labels handed unchanged to every metric
     *                        (presumably per-class occurrence counts — TODO confirm with callers)
     * @param features        node features forwarded to the prediction consumer
     * @param targets         ground-truth class per node, indexed by the node ids stored in the
     *                        evaluation set
     * @param concurrency     concurrency passed to {@code BatchQueue#parallelConsume}
     * @param progressTracker progress tracker forwarded to the prediction consumer
     * @param terminationFlag termination flag passed to {@code BatchQueue#parallelConsume}
     */
    public ClassificationMetricComputer(
        List<Metric> metrics,
        Multiset<Long> classCounts,
        Features features,
        HugeLongArray targets,
        int concurrency,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.metrics = metrics;
        this.classCounts = classCounts;
        this.features = features;
        this.targets = targets;
        this.concurrency = concurrency;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Predicts a class for every node in {@code evaluationSet} and scores the predictions with each
     * configured metric.
     *
     * @param evaluationSet node ids to evaluate; position {@code j} in this array is the "local index"
     *                      shared by the predictions and the local targets
     * @param classifier    classifier used for prediction
     * @return a map from each configured metric to its score
     * @throws NullPointerException  if {@code evaluationSet} is {@code null}
     * @throws ClassCastException    if a configured metric is not a {@link ClassificationMetric}
     * @throws IllegalStateException if the metric list contains two equal metrics
     *                               ({@link Collectors#toMap} rejects duplicate keys)
     */
    @Override // org.neo4j.gds.ml.nodemodels.MetricComputer
    public Map<Metric, Double> computeMetrics(HugeLongArray evaluationSet, Classifier classifier) {
        // Fail fast with a descriptive message; the argument is dereferenced immediately below.
        Objects.requireNonNull(evaluationSet, "evaluationSet");
        long evaluationSetSize = evaluationSet.size();

        // One slot per local index, handed to the consumer as its output array. The metrics compare
        // it position-wise with localTargets.
        HugeLongArray predictedClasses = HugeLongArray.newArray(evaluationSetSize);

        // The queue covers local indices [0, evaluationSetSize); evaluationSet::get maps a local index
        // to a node id. NOTE(review): the role of the null argument is not visible here — confirm
        // against NodeClassificationPredictConsumer.
        NodeClassificationPredictConsumer predictConsumer = new NodeClassificationPredictConsumer(
            features,
            evaluationSet::get,
            classifier,
            null,
            predictedClasses,
            progressTracker
        );
        new BatchQueue(evaluationSetSize).parallelConsume(predictConsumer, concurrency, terminationFlag);

        HugeLongArray localTargets = makeLocalTargets(evaluationSet, targets);
        return metrics.stream().collect(Collectors.toMap(
            metric -> metric,
            metric -> ((ClassificationMetric) metric).compute(localTargets, predictedClasses, classCounts)
        ));
    }

    /**
     * Gathers the target of every evaluated node into an array indexed by local position.
     *
     * @param evaluationSet node ids, indexed by local position
     * @param targets       targets indexed by node id
     * @return a new array whose element {@code j} is {@code targets.get(evaluationSet.get(j))}
     */
    private static HugeLongArray makeLocalTargets(HugeLongArray evaluationSet, HugeLongArray targets) {
        HugeLongArray localTargets = HugeLongArray.newArray(evaluationSet.size());
        localTargets.setAll(localIndex -> targets.get(evaluationSet.get(localIndex)));
        return localTargets;
    }
}
