package org.neo4j.gds.ml.splitting;

import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.graphalgo.Orientation;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.core.loading.construction.RelationshipsBuilder;

/* loaded from: input_file:org/neo4j/gds/ml/splitting/DirectedEdgeSplitter.class */
/**
 * Edge splitter for directed graphs: relationships are traversed in their natural orientation.
 *
 * <p>Each relationship of the input graph ends up in exactly one of two outputs:
 * <ul>
 *   <li>the "selected" set, which carries a property value of {@code 1.0} for positive samples;</li>
 *   <li>the "remaining" set, which has no property.</li>
 * </ul>
 * Negative samples are added to the selected set by the inherited {@code negativeSampling} helper.
 */
public class DirectedEdgeSplitter extends EdgeSplitter {

    /**
     * Creates a splitter.
     *
     * @param maybeSeed optional random seed, forwarded unchanged to {@link EdgeSplitter}
     */
    public DirectedEdgeSplitter(Optional<Long> maybeSeed) {
        super(maybeSeed);
    }

    /**
     * Creates a splitter with a fixed random seed.
     *
     * @param seed the random seed
     */
    public DirectedEdgeSplitter(long seed) {
        this(Optional.of(seed));
    }

    /**
     * Splits {@code graph}, using the graph itself as the master graph for negative sampling.
     *
     * @param graph           graph whose relationships are split
     * @param holdoutFraction fraction of {@code graph.relationshipCount()} to use as the budget
     *                        for positive samples and for negative samples
     * @return the remaining relationships and the selected (labelled) relationships
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, double holdoutFraction) {
        return split(graph, graph, holdoutFraction);
    }

    /**
     * Splits the relationships of {@code graph} into a selected set and a remaining set.
     *
     * @param graph           graph whose relationships are split
     * @param masterGraph     graph passed to {@code negativeSampling}. It is presumably used to
     *                        check that a candidate negative pair is not an existing edge
     *                        (TODO confirm in EdgeSplitter).
     * @param holdoutFraction fraction of {@code graph.relationshipCount()} to use as the budget
     *                        for positive samples and for negative samples
     * @return {@code SplitResult.of(remaining, selected)}
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, Graph masterGraph, double holdoutFraction) {
        RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilderWithProp(graph, Orientation.NATURAL);
        RelationshipsBuilder remainingRelsBuilder = newRelationshipsBuilder(graph, Orientation.NATURAL);

        // Compute the budget once, as a long. The previous (int) cast on the positive budget
        // clamped it at Integer.MAX_VALUE (double->int narrowing saturates), whereas the
        // negative budget already used the full long range.
        long sampleCount = (long) (graph.relationshipCount() * holdoutFraction);

        // Handed to negativeSampling. Its initial value is the budget, so it is presumably
        // decremented there as samples are drawn (not visible here; confirm in EdgeSplitter).
        AtomicLong negativeSamplesRemaining = new AtomicLong(sampleCount);
        // Counts upward as positive samples are selected across all nodes.
        AtomicLong selectedPositiveCount = new AtomicLong(0L);

        graph.forEachNode(nodeId -> {
            positiveSampling(graph, selectedRelsBuilder, remainingRelsBuilder, sampleCount, selectedPositiveCount, nodeId);
            negativeSampling(graph, masterGraph, selectedRelsBuilder, negativeSamplesRemaining, nodeId);
            return true;
        });
        return EdgeSplitter.SplitResult.of(remainingRelsBuilder.build(), selectedRelsBuilder.build());
    }

    /**
     * Distributes the outgoing relationships of one node between the selected and remaining builders.
     *
     * @param graph                 graph being split
     * @param selectedRelsBuilder   receives selected relationships with property {@code 1.0}
     * @param remainingRelsBuilder  receives all relationships that are not selected
     * @param positiveSamples       total positive-sample budget for the whole graph
     * @param selectedRelCount      running count of relationships selected so far; incremented here
     * @param nodeId                node whose relationships are processed
     */
    private void positiveSampling(
        Graph graph,
        RelationshipsBuilder selectedRelsBuilder,
        RelationshipsBuilder remainingRelsBuilder,
        long positiveSamples,
        AtomicLong selectedRelCount,
        long nodeId
    ) {
        int degree = graph.degree(nodeId);
        // NOTE(review): "graph.nodeCount() - nodeId" treats the nodes still to visit as those with
        // a larger id. This assumes forEachNode visits ids in ascending order; confirm.
        long relsToSelectFromThisNode = samplesPerNode(
            degree,
            positiveSamples - selectedRelCount.get(),
            graph.nodeCount() - nodeId
        );
        long selectedBeforeThisNode = selectedRelCount.get();
        // AtomicLong is used only as a mutable holder that the lambda below can capture.
        AtomicLong relsLeftAtThisNode = new AtomicLong(degree);

        graph.forEachRelationship(nodeId, (source, target) -> {
            double stillToSelectHere = relsToSelectFromThisNode - (selectedRelCount.get() - selectedBeforeThisNode);
            // Evaluated for every relationship, before the guard, so the random stream
            // (and thus the split produced for a given seed) stays the same.
            boolean isSelected = sample(stillToSelectHere / relsLeftAtThisNode.getAndDecrement());
            if (relsToSelectFromThisNode > 0 && stillToSelectHere > 0 && isSelected) {
                selectedRelCount.incrementAndGet();
                // 1.0 is the property value written for positive samples.
                selectedRelsBuilder.addFromInternal(source, target, 1.0d);
            } else {
                remainingRelsBuilder.addFromInternal(source, target);
            }
            return true;
        });
    }
}
