package org.neo4j.gds.ml.negativeSampling;

import java.util.Optional;
import java.util.SplittableRandom;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;

/* loaded from: input_file:org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.class */
/**
 * Negative sampler that takes the relationships of a user-supplied graph as negative examples and
 * distributes them between a test-set builder and a train-set builder.
 *
 * <p>The supplied graph must be UNDIRECTED. The sampling code assumes each undirected relationship is
 * visited in both directions: it halves {@code relationshipCount()} and only emits the
 * {@code source < target} direction.
 */
public class UserInputNegativeSampler implements NegativeSampler {
    private final Graph negativeExampleGraph;
    private final double testTrainFraction;
    private final SplittableRandom rng;

    /**
     * Creates a sampler over the given negative-example graph.
     *
     * @param negativeExampleGraph graph whose relationships are the negative examples; must be UNDIRECTED
     * @param testTrainFraction    fraction of the negative examples targeted for the test set
     *                             (not range-checked here — NOTE(review): presumably expected in [0, 1])
     * @param randomSeed           seed for reproducible sampling; when empty an unseeded generator is used
     * @throws IllegalArgumentException if {@code negativeExampleGraph} is not UNDIRECTED
     */
    public UserInputNegativeSampler(Graph negativeExampleGraph, double testTrainFraction, Optional<Long> randomSeed) {
        if (!negativeExampleGraph.schema().isUndirected()) {
            throw new IllegalArgumentException("UserInputNegativeSampler requires graph to be UNDIRECTED.");
        }
        this.negativeExampleGraph = negativeExampleGraph;
        this.testTrainFraction = testTrainFraction;
        this.rng = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Routes every relationship of the negative-example graph (once per undirected pair) to either the
     * test-set builder or the train-set builder, all with weight {@link NegativeSampler#NEGATIVE}.
     *
     * <p>The test quota is {@code (long) (relationshipCount() / 2 * testTrainFraction)}, and the train quota is
     * the remainder. Each relationship goes to the test set with probability
     * {@code remainingTest / (remainingTest + remainingTrain)}. If exactly {@code relationshipCount() / 2}
     * pairs with {@code source < target} are visited, this yields exactly the quota in each set.
     *
     * @param testSetBuilder  receives the relationships sampled into the test set
     * @param trainSetBuilder receives the relationships sampled into the train set
     */
    @Override // org.neo4j.gds.ml.negativeSampling.NegativeSampler
    public void produceNegativeSamples(RelationshipsBuilder testSetBuilder, RelationshipsBuilder trainSetBuilder) {
        // Each undirected relationship is counted once per direction, hence the halving.
        long totalRelationshipCount = this.negativeExampleGraph.relationshipCount() / 2;
        long testRelationshipCount = (long) (totalRelationshipCount * this.testTrainFraction);

        MutableLong remainingTestSamples = new MutableLong(testRelationshipCount);
        MutableLong remainingTrainSamples = new MutableLong(totalRelationshipCount - testRelationshipCount);

        // Lambda parameters must not reuse names from the enclosing lambda (JLS 15.27.1), so the
        // node id and the relationship endpoints get distinct names.
        this.negativeExampleGraph.forEachNode(nodeId -> {
            this.negativeExampleGraph.forEachRelationship(nodeId, (source, target) -> {
                // Emit each undirected pair only once; this also skips self-loops.
                if (source >= target) {
                    return true;
                }
                if (sample(testProbability(remainingTestSamples, remainingTrainSamples))) {
                    remainingTestSamples.decrement();
                    testSetBuilder.add(source, target, NegativeSampler.NEGATIVE);
                } else {
                    remainingTrainSamples.decrement();
                    trainSetBuilder.add(source, target, NegativeSampler.NEGATIVE);
                }
                return true;
            });
            return true;
        });
    }

    /**
     * Returns the probability of sending the next relationship to the test set: the remaining test quota
     * divided by the total remaining quota. When both quotas are zero the result is {@code NaN}, for which
     * {@link #sample(double)} returns {@code false}, so the relationship goes to the train set.
     */
    private static double testProbability(MutableLong remainingTestSamples, MutableLong remainingTrainSamples) {
        double remainingTest = remainingTestSamples.doubleValue();
        return remainingTest / (remainingTest + remainingTrainSamples.doubleValue());
    }

    /**
     * Returns {@code true} with probability {@code probability} (clamped to [0, 1] by the comparison;
     * {@code NaN} always yields {@code false}).
     */
    private boolean sample(double probability) {
        return this.rng.nextDouble() < probability;
    }
}
