package org.neo4j.gds.ml.splitting;

import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.graphalgo.Orientation;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.loading.construction.RelationshipsBuilder;

/* loaded from: input_file:org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.class */
/**
 * Edge splitter that treats relationships in their natural (directed) orientation.
 *
 * <p>Each outgoing relationship of every node ends up in one of two sets:
 * <ul>
 *   <li>the "selected" set, built with a relationship property (positive samples are added with
 *       value {@code 1.0}, and negative samples are added by the inherited
 *       {@code negativeSampling} helper); or</li>
 *   <li>the "remaining" set, built without a property.</li>
 * </ul>
 */
public class DirectedEdgeSplitter extends EdgeSplitter {

    /**
     * Creates a directed edge splitter.
     *
     * @param maybeSeed             forwarded unchanged to the {@link EdgeSplitter} constructor
     *                              (presumably an optional random seed — confirm in EdgeSplitter)
     * @param negativeSamplingRatio forwarded unchanged to the {@link EdgeSplitter} constructor;
     *                              presumably stored as {@code negativeSamplingRatio}, which scales
     *                              the negative-sample budget in {@link #split(Graph, Graph, double)}
     */
    public DirectedEdgeSplitter(Optional<Long> maybeSeed, double negativeSamplingRatio) {
        super(maybeSeed, negativeSamplingRatio);
    }

    /**
     * Splits {@code graph}, using the graph itself as the second ("master") graph.
     *
     * @param graph           the graph whose relationships are split
     * @param holdoutFraction fraction of {@code graph}'s relationships targeted for selection
     * @return the split result, see {@link #split(Graph, Graph, double)}
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, double holdoutFraction) {
        return split(graph, graph, holdoutFraction);
    }

    /**
     * Splits the relationships of {@code graph} into a selected set and a remaining set.
     *
     * <p>The positive budget is {@code relationshipCount * holdoutFraction}. The negative budget is
     * {@code negativeSamplingRatio * relationshipCount * holdoutFraction}. Both budgets are shared
     * across all nodes and are consumed while iterating nodes in order.
     *
     * @param graph           the graph whose relationships are split
     * @param masterGraph     passed only to {@code negativeSampling}; presumably used to avoid
     *                        sampling negatives that exist as relationships there — confirm
     * @param holdoutFraction fraction of {@code graph}'s relationships targeted for selection
     * @return {@code SplitResult.of(remaining, selected)}, where {@code remaining} holds the
     *         unselected relationships and {@code selected} holds the positive and negative
     *         samples with their property values
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, Graph masterGraph, double holdoutFraction) {
        RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilderWithProp(graph, Orientation.NATURAL);
        RelationshipsBuilder remainingRelsBuilder = newRelationshipsBuilder(graph, Orientation.NATURAL);

        // Cast to long, not int: a double-to-int narrowing conversion clamps at Integer.MAX_VALUE,
        // which would silently cap the positive budget on large graphs. This also keeps the cast
        // consistent with the negative budget below.
        AtomicLong positiveSamples = new AtomicLong((long) (graph.relationshipCount() * holdoutFraction));
        AtomicLong negativeSamples = new AtomicLong(
            (long) (this.negativeSamplingRatio * graph.relationshipCount() * holdoutFraction)
        );

        graph.forEachNode(nodeId -> {
            positiveSampling(graph, selectedRelsBuilder, remainingRelsBuilder, positiveSamples, nodeId);
            negativeSampling(graph, masterGraph, selectedRelsBuilder, negativeSamples, nodeId);
            return true;
        });
        return EdgeSplitter.SplitResult.of(remainingRelsBuilder.build(), selectedRelsBuilder.build());
    }

    /**
     * Distributes the outgoing relationships of {@code nodeId} between the selected builder and the
     * remaining builder.
     *
     * <p>The per-node quota comes from {@code samplesPerNode(degree, globalRemaining,
     * nodesLeftIncludingThis)}. Each relationship is then selected with probability
     * {@code localQuotaLeft / relationshipsLeftToVisit}. Every selection decrements both the local
     * quota and the shared {@code positiveSamples} budget.
     *
     * @param graph                the graph being split
     * @param selectedRelsBuilder  receives selected relationships with property value {@code 1.0}
     * @param remainingRelsBuilder receives all relationships that were not selected
     * @param positiveSamples      shared, remaining global positive budget (mutated)
     * @param nodeId               internal id of the source node
     */
    private void positiveSampling(
        Graph graph,
        RelationshipsBuilder selectedRelsBuilder,
        RelationshipsBuilder remainingRelsBuilder,
        AtomicLong positiveSamples,
        long nodeId
    ) {
        int degree = graph.degree(nodeId);
        long relsToSelectFromThisNode = samplesPerNode(degree, positiveSamples.get(), graph.nodeCount() - nodeId);
        AtomicLong localSelectedRemaining = new AtomicLong(relsToSelectFromThisNode);
        AtomicLong localRelsRemaining = new AtomicLong(degree);

        graph.forEachRelationship(nodeId, (source, target) -> {
            double localSelectedDouble = localSelectedRemaining.get();
            // sample() is called for every relationship, before the guard below. If sample() draws
            // from an RNG (presumably it does), the number of draws stays independent of the quota.
            boolean isSelected = sample(localSelectedDouble / localRelsRemaining.getAndDecrement());
            // NOTE(review): EdgeSplitter.NEGATIVE serves as a zero threshold on the local quota here;
            // confirm it is 0.0 (this looks like a decompiler substitution for a literal 0).
            if (relsToSelectFromThisNode <= 0 || localSelectedDouble <= EdgeSplitter.NEGATIVE || !isSelected) {
                remainingRelsBuilder.addFromInternal(source, target);
                return true;
            }
            positiveSamples.decrementAndGet();
            localSelectedRemaining.decrementAndGet();
            selectedRelsBuilder.addFromInternal(source, target, 1.0d);
            return true;
        });
    }
}
