package org.neo4j.gds.ml.metrics;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Optional;
import java.util.PrimitiveIterator;
import java.util.stream.DoubleStream;
import org.neo4j.gds.core.utils.TerminationFlag;
import org.neo4j.gds.core.utils.paged.HugeDoubleArray;
import org.neo4j.gds.core.utils.paged.HugeIntArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

/* loaded from: input_file:org/neo4j/gds/ml/metrics/SignedProbabilities.class */
/**
 * Collects classifier probabilities together with the sign of their ground-truth label.
 *
 * <p>Every probability passed to {@link #add(double, boolean)} is stored as a signed value:
 * unchanged for a positive example and negated for a negative one. {@link #stream()} yields the
 * stored values ordered by ascending absolute value. The number of entries stored with a
 * positive / non-positive sign is tracked in {@link #positiveCount()} / {@link #negativeCount()}.
 *
 * <p>{@link #add(double, boolean)} is {@code synchronized} and may be called from several threads.
 * The count getters are not synchronized. NOTE(review): read them only after all adds have
 * completed (e.g. after {@code parallelConsume} returns).
 */
public abstract class SignedProbabilities {

    /**
     * Substitute stored in place of a probability equal to {@code EdgeSplitter.NEGATIVE}. This keeps
     * a positive example strictly above the positivity threshold used in {@link #add}.
     */
    static final double ALMOST_ZERO = 1.0E-100d;

    /** Column of the prediction matrix read by {@link #computeFromLabeledData}. */
    private static final int POSITIVE_CLASS_INDEX = 1;

    /** Label value that {@link #computeFromLabeledData} treats as a positive example. */
    private static final int POSITIVE_LABEL = 1;

    /** Orders signed probabilities by ascending absolute value. */
    private static final Comparator<Double> ABSOLUTE_VALUE_COMPARATOR =
        Comparator.comparingDouble(value -> Math.abs(value));

    private long positiveCount;
    private long negativeCount;

    /** Stores signed probabilities in an {@link ArrayList}; used when the size fits in an {@code int}. */
    public static final class ArrayBased extends SignedProbabilities {
        private final ArrayList<Double> probabilities;

        private ArrayBased(int capacity) {
            this.probabilities = new ArrayList<>(capacity);
        }

        @Override
        void doAdd(double signedProbability) {
            probabilities.add(signedProbability);
        }

        /** Sorts the stored values in place by absolute value, then streams them. */
        @Override
        public DoubleStream stream() {
            probabilities.sort(ABSOLUTE_VALUE_COMPARATOR);
            return probabilities.stream().mapToDouble(Double::doubleValue);
        }
    }

    /** Stores signed probabilities in a {@link HugeDoubleArray}; used when the size exceeds an {@code int}. */
    public static final class Huge extends SignedProbabilities {
        private final HugeDoubleArray probabilities;
        /** Number of values written so far; also the next write position. */
        private long index = 0;

        Huge(long capacity) {
            this.probabilities = HugeDoubleArray.newArray(capacity);
        }

        @Override
        void doAdd(double signedProbability) {
            probabilities.set(index++, signedProbability);
        }

        /**
         * Streams the values written so far, ordered by absolute value.
         *
         * <p>Only the first {@code index} slots are streamed. Slots that were never written are
         * excluded, consistent with {@link ArrayBased} and with the tracked counts.
         *
         * <p>NOTE(review): {@code Stream.sorted} buffers all elements in a single Java array/list.
         * It therefore likely cannot handle more than ~2^31 elements, which is the range this class
         * exists for. Confirm before relying on it for very large inputs.
         */
        @Override
        public DoubleStream stream() {
            return probabilities.stream()
                .limit(index)
                .boxed()
                .sorted(ABSOLUTE_VALUE_COMPARATOR)
                .mapToDouble(Double::doubleValue);
        }
    }

    /**
     * Estimates the memory needed to hold {@code size} signed probabilities.
     *
     * <p>The estimate sums the instance sizes of this class, an {@link Optional} and an
     * {@link ArrayList}, plus one boxed {@link Double} per element. NOTE(review): it has no term
     * for the ArrayList's backing reference array, nor for the {@link Huge} variant's storage.
     *
     * @param size number of probabilities to be stored
     * @return estimated size in bytes
     */
    public static long estimateMemory(long size) {
        return MemoryUsage.sizeOfInstance(SignedProbabilities.class)
            + MemoryUsage.sizeOfInstance(Optional.class)
            + MemoryUsage.sizeOfInstance(ArrayList.class)
            + MemoryUsage.sizeOfInstance(Double.class) * size;
    }

    /**
     * Creates a container for {@code size} values. Uses {@link ArrayBased} when {@code size} fits in
     * an {@code int}, otherwise {@link Huge}.
     */
    static SignedProbabilities create(long size) {
        return size > Integer.MAX_VALUE ? new Huge(size) : new ArrayBased((int) size);
    }

    /**
     * Predicts probabilities for every element of {@code batchQueue} and records them with the sign
     * of their label.
     *
     * <p>For each batch, the classifier's prediction matrix is read at column 1. The element is
     * recorded as positive when its label equals 1. Progress is reported per batch via
     * {@code progressTracker.logSteps(batch.size())}.
     *
     * @param features features passed to the classifier; its size is used as the total step count
     * @param labels label per element id. Presumably a class id where 1 means positive — confirm
     *               against callers.
     * @param classifier model producing the probability matrix for a batch
     * @param batchQueue source of the element ids to evaluate
     * @param concurrency forwarded as the first argument of {@code BatchQueue.parallelConsume}.
     *                    Presumably the concurrency level — TODO confirm.
     * @param terminationFlag forwarded to {@code parallelConsume}
     * @param progressTracker receives the step count and per-batch progress
     * @return the collected signed probabilities
     */
    public static SignedProbabilities computeFromLabeledData(
        Features features,
        HugeIntArray labels,
        Classifier classifier,
        BatchQueue batchQueue,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        progressTracker.setSteps(features.size());
        SignedProbabilities result = create(batchQueue.totalSize());

        batchQueue.parallelConsume(concurrency, ignoredConsumerArg -> batch -> {
            Matrix predictions = classifier.predictProbabilities(batch, features);
            // Row in `predictions` corresponding to the current element id.
            int row = 0;
            PrimitiveIterator.OfLong elementIds = batch.elementIds();
            while (elementIds.hasNext()) {
                double probability = predictions.dataAt(row++, POSITIVE_CLASS_INDEX);
                boolean isPositive = labels.get(elementIds.nextLong()) == POSITIVE_LABEL;
                result.add(probability, isPositive);
            }
            progressTracker.logSteps(batch.size());
        }, terminationFlag);

        return result;
    }

    /**
     * Records {@code probability} signed by {@code isPositive} and updates the counts.
     *
     * <p>A probability equal to {@code EdgeSplitter.NEGATIVE} is replaced by {@link #ALMOST_ZERO}
     * before signing. The signed value counts as positive when it is strictly greater than
     * {@code EdgeSplitter.NEGATIVE}, otherwise as negative.
     *
     * @param probability predicted probability for the example
     * @param isPositive whether the example's ground-truth label is positive
     */
    public synchronized void add(double probability, boolean isPositive) {
        // NOTE(review): EdgeSplitter.NEGATIVE is used as the zero threshold here. It is presumably
        // 0.0, a decompiler-inlined constant. TODO confirm. Replacing an exact match keeps a
        // positive example's signed value strictly above the threshold.
        double adjusted = probability == EdgeSplitter.NEGATIVE ? ALMOST_ZERO : probability;
        double signed = isPositive ? adjusted : -adjusted;
        if (signed > EdgeSplitter.NEGATIVE) {
            positiveCount++;
        } else {
            negativeCount++;
        }
        doAdd(signed);
    }

    /** Stores an already-signed value; called from {@link #add} while holding this object's lock. */
    abstract void doAdd(double signedProbability);

    /** Streams the stored signed values ordered by ascending absolute value. */
    public abstract DoubleStream stream();

    /** Number of stored values whose sign is positive. */
    public long positiveCount() {
        return positiveCount;
    }

    /** Number of stored values whose sign is not positive. */
    public long negativeCount() {
        return negativeCount;
    }
}
