package org.neo4j.gds.ml.nodePropertyPrediction;

import java.util.Optional;
import java.util.function.LongUnaryOperator;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.core.utils.paged.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeMergeSort;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.splitting.FractionSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.util.ShuffleUtil;
import org.neo4j.gds.ml.util.TrainingSetWarnings;

/* loaded from: input_file:org/neo4j/gds/ml/nodePropertyPrediction/NodeSplitter.class */
public final class NodeSplitter {
    private final int concurrency;
    private final long numberOfExamples;
    private final ProgressTracker progressTracker;
    private final LongUnaryOperator toOriginalId;
    private final LongUnaryOperator toMappedId;

    /**
     * Result of {@link NodeSplitter#split}: the shuffled array of all examples (as mapped ids)
     * together with the outer split computed over that same array.
     */
    @ValueClass
    public interface NodeSplits {
        ReadOnlyHugeLongArray allTrainingExamples();

        TrainingExamplesSplit outerSplit();
    }

    /**
     * Creates a splitter over the examples whose indices are {@code 0 .. numberOfExamples - 1}.
     * Each index is treated as a mapped node id and translated with {@code toOriginalId}.
     *
     * @param concurrency      concurrency passed to {@link HugeMergeSort#sort}
     * @param numberOfExamples number of examples to split
     * @param progressTracker  tracker that receives small-set warnings
     * @param toOriginalId     maps a mapped (internal) id to its original id
     * @param toMappedId       maps an original id back to its mapped (internal) id
     */
    public NodeSplitter(
        int concurrency,
        long numberOfExamples,
        ProgressTracker progressTracker,
        LongUnaryOperator toOriginalId,
        LongUnaryOperator toMappedId
    ) {
        this.concurrency = concurrency;
        this.numberOfExamples = numberOfExamples;
        this.progressTracker = progressTracker;
        this.toOriginalId = toOriginalId;
        this.toMappedId = toMappedId;
    }

    /**
     * Shuffles all examples and splits them into an outer train/test split.
     *
     * @param testFraction    fraction reserved for the test set; {@code 1 - testFraction} is handed
     *                        to {@link FractionSplitter#split} (presumably as the train fraction)
     * @param validationFolds forwarded to {@link TrainingSetWarnings#warnForSmallNodeSets};
     *                        presumably the number of cross-validation folds, TODO confirm
     * @param randomSeed      seed for the shuffle, passed to {@link ShuffleUtil#createRandomDataGenerator}
     * @return all shuffled examples plus the outer split, both backed by the same read-only view
     */
    public NodeSplits split(double testFraction, int validationFolds, Optional<Long> randomSeed) {
        HugeLongArray examples = HugeLongArray.newArray(numberOfExamples);

        // Order the examples by their original ids before shuffling. The pre-shuffle order then
        // depends only on the original ids and not on how internal ids happen to be assigned.
        examples.setAll(toOriginalId);
        HugeMergeSort.sort(examples, concurrency);
        examples.setAll(index -> toMappedId.applyAsLong(examples.get(index)));

        ShuffleUtil.shuffleHugeLongArray(examples, ShuffleUtil.createRandomDataGenerator(randomSeed));

        // Build a single read-only view and share it between the split and the returned result.
        ReadOnlyHugeLongArray readOnlyExamples = ReadOnlyHugeLongArray.of(examples);
        TrainingExamplesSplit outerSplit = new FractionSplitter().split(readOnlyExamples, 1.0d - testFraction);

        TrainingSetWarnings.warnForSmallNodeSets(
            outerSplit.trainSet().size(),
            outerSplit.testSet().size(),
            validationFolds,
            progressTracker
        );
        return ImmutableNodeSplits.of(readOnlyExamples, outerSplit);
    }
}
