package org.neo4j.gds.ml.nodeClassification;

import java.util.function.LongUnaryOperator;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.TrainerConfig;

/* loaded from: input_file:org/neo4j/gds/ml/nodeClassification/ClassificationMetricComputer.class */
public final class ClassificationMetricComputer {

    /**
     * Constant passed as the third constructor argument of {@link ParallelNodeClassifier}
     * when predicting the evaluation set. NOTE(review): presumably the prediction batch size —
     * confirm against the {@code ParallelNodeClassifier} constructor.
     */
    private static final int EVALUATION_BATCH_SIZE = 100;

    /** Predicted class per evaluation-set position. */
    private final HugeIntArray predictedClasses;

    /** True label per evaluation-set position, aligned index-by-index with {@link #predictedClasses}. */
    private final HugeIntArray labels;

    private ClassificationMetricComputer(HugeIntArray predictedClasses, HugeIntArray labels) {
        this.predictedClasses = predictedClasses;
        this.labels = labels;
    }

    /**
     * Scores the stored predictions against the stored labels.
     *
     * @param metric the metric to evaluate; it receives the labels first and the predictions second
     * @return the value returned by {@code metric.compute(labels, predictedClasses)}
     */
    public double score(ClassificationMetric metric) {
        HugeIntArray expected = labels;
        HugeIntArray actual = predictedClasses;
        return metric.compute(expected, actual);
    }

    /**
     * Builds a computer for the given evaluation set by predicting its classes with
     * {@code classifier} and gathering the matching labels.
     *
     * @param features         features handed to the {@link ParallelNodeClassifier}
     * @param labels           label array indexed by the ids stored in {@code evaluationSet}
     * @param evaluationSet    ids of the elements to evaluate; defines the local index order
     * @param classifier       the trained classifier used for prediction
     * @param concurrency      forwarded to {@link ParallelNodeClassifier} (presumably its concurrency)
     * @param terminationFlag  forwarded to {@link ParallelNodeClassifier}
     * @param progressTracker  forwarded to {@link ParallelNodeClassifier}
     * @return a computer holding the predictions and the local labels for {@code evaluationSet}
     */
    public static ClassificationMetricComputer forEvaluationSet(
        Features features,
        HugeIntArray labels,
        ReadOnlyHugeLongArray evaluationSet,
        Classifier classifier,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        ParallelNodeClassifier nodeClassifier = new ParallelNodeClassifier(
            classifier,
            features,
            EVALUATION_BATCH_SIZE,
            concurrency,
            terminationFlag,
            progressTracker
        );
        // Prediction runs before the targets are gathered, matching the argument
        // evaluation order of the original single-expression form.
        HugeIntArray predictions = nodeClassifier.predict(evaluationSet);
        HugeIntArray localTargets = makeLocalTargets(evaluationSet, labels);
        return new ClassificationMetricComputer(predictions, localTargets);
    }

    /**
     * Copies the labels of the ids in {@code evaluationSet} into a new dense array, so that
     * position {@code k} holds {@code labels.get(evaluationSet.get(k))}.
     */
    private static HugeIntArray makeLocalTargets(ReadOnlyHugeLongArray evaluationSet, HugeIntArray labels) {
        HugeIntArray localTargets = HugeIntArray.newArray(evaluationSet.size());
        localTargets.setAll(localIndex -> labels.get(evaluationSet.get(localIndex)));
        return localTargets;
    }

    /**
     * Estimates the memory needed to compute evaluation metrics.
     *
     * <p>NOTE(review): "local targets" and "predicted classes" are sized as {@code HugeLongArray}
     * although both are allocated as {@code HugeIntArray}. This is a conservative overestimate and
     * is kept as-is, because estimate values may be pinned elsewhere.
     *
     * @param trainerConfig      config whose method selects the classifier runtime estimate
     * @param batchSize          forwarded to {@code ClassifierFactory.runtimeOverheadMemoryEstimation}
     *                           (presumably the prediction batch size — confirm)
     * @param trainSetSize       forwarded to {@code ClassifierFactory.dataMemoryEstimation}
     * @param evaluationSetSize  maps the node count to the number of evaluated elements; sizes the
     *                           target and prediction arrays
     * @param classCount         forwarded to both {@code ClassifierFactory} estimations
     * @param featureCount       forwarded to both {@code ClassifierFactory} estimations
     * @param isReduced          forwarded to both {@code ClassifierFactory} estimations
     * @return the combined estimation tree named "computing metrics"
     */
    public static MemoryEstimation estimateEvaluation(
        TrainerConfig trainerConfig,
        int batchSize,
        LongUnaryOperator trainSetSize,
        LongUnaryOperator evaluationSetSize,
        int classCount,
        int featureCount,
        boolean isReduced
    ) {
        return MemoryEstimations.builder("computing metrics")
            .perNode(
                "local targets",
                nodeCount -> HugeLongArray.memoryEstimation(evaluationSetSize.applyAsLong(nodeCount))
            )
            .perNode(
                "predicted classes",
                nodeCount -> HugeLongArray.memoryEstimation(evaluationSetSize.applyAsLong(nodeCount))
            )
            .add(
                "classifier model",
                ClassifierFactory.dataMemoryEstimation(trainerConfig, trainSetSize, classCount, featureCount, isReduced)
            )
            .rangePerNode(
                "classifier runtime",
                ignoredNodeCount -> ClassifierFactory.runtimeOverheadMemoryEstimation(
                    trainerConfig.method(),
                    batchSize,
                    classCount,
                    featureCount,
                    isReduced
                )
            )
            .build();
    }
}
