package org.neo4j.gds.ml.splitting;

import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

/**
 * Edge splitter for UNDIRECTED graphs.
 *
 * <p>For every node visited by {@code graph.forEachNode}, it does two things in order:
 * <ol>
 *   <li>It performs positive sampling over that node's relationships.</li>
 *   <li>It delegates to the inherited {@code negativeSampling} routine.</li>
 * </ol>
 *
 * <p>Positive sampling looks at each undirected pair only from the side where
 * {@code source < target}. The pair is then handled in one of two ways:
 * <ul>
 *   <li>It is moved into the "selected" builder (NATURAL orientation, property {@code 1.0}).
 *       A fair coin flip decides which endpoint becomes the source.</li>
 *   <li>Otherwise it is kept in the "remaining" builder (UNDIRECTED orientation).</li>
 * </ul>
 *
 * <p>Self-loops ({@code source == target}) are written to neither builder by positive sampling.
 */
public class UndirectedEdgeSplitter extends EdgeSplitter {

    /**
     * Creates a splitter for undirected graphs.
     *
     * @param maybeSeed             forwarded unchanged to {@link EdgeSplitter}; presumably the optional
     *                              RNG seed used by {@code sample} — confirm in EdgeSplitter
     * @param negativeSamplingRatio forwarded unchanged to {@link EdgeSplitter}; the field of the same
     *                              name sizes the negative-sample budget in {@code split}
     */
    public UndirectedEdgeSplitter(Optional<Long> maybeSeed, double negativeSamplingRatio) {
        super(maybeSeed, negativeSamplingRatio);
    }

    /**
     * Splits {@code graph}, using the graph itself as the master graph for negative sampling.
     *
     * @param graph           graph whose relationships are split; must be undirected
     * @param holdoutFraction fraction of relationships used to size the positive and negative budgets
     * @return the split produced by {@link #split(Graph, Graph, double)}
     * @throws IllegalArgumentException if {@code graph} is not undirected
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, double holdoutFraction) {
        Graph masterGraph = graph;
        return split(graph, masterGraph, holdoutFraction);
    }

    /**
     * Splits the relationships of {@code graph} into a selected (held-out) set and a remaining set.
     *
     * <p>The positive budget is {@code (long) (relationshipCount * holdoutFraction) / 2}.
     * The negative budget is
     * {@code (long) (negativeSamplingRatio * relationshipCount * holdoutFraction) / 2}.
     * Both are halved because positive sampling consumes two units of the relationship counter
     * per undirected pair. In other words, {@code relationshipCount()} is treated as counting
     * both directions.
     *
     * @param graph           graph whose relationships are split; must be undirected
     * @param masterGraph     graph handed to {@code negativeSampling}; must be undirected
     * @param holdoutFraction fraction of relationships used to size the budgets
     * @return {@code SplitResult.of(remaining, selected)}: remaining relationships first,
     *         selected (positive and negative samples) second
     * @throws IllegalArgumentException if either graph is not undirected
     */
    @Override
    public EdgeSplitter.SplitResult split(Graph graph, Graph masterGraph, double holdoutFraction) {
        requireUndirected(graph, "EdgeSplitter requires graph to be UNDIRECTED");
        requireUndirected(masterGraph, "EdgeSplitter requires master graph to be UNDIRECTED");

        // The selected builder is created before the remaining builder, as in the original ordering.
        RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilderWithProp(graph, Orientation.NATURAL);
        RelationshipsBuilder remainingRelsBuilder = newRelationshipsBuilder(graph, Orientation.UNDIRECTED);

        long relationshipCount = graph.relationshipCount();
        // AtomicLong is used as a mutable counter that the lambdas below can update.
        AtomicLong positiveBudget = new AtomicLong((long) (relationshipCount * holdoutFraction) / 2);
        AtomicLong negativeBudget = new AtomicLong(
            (long) (negativeSamplingRatio * relationshipCount * holdoutFraction) / 2
        );
        AtomicLong unvisitedRels = new AtomicLong(relationshipCount);

        graph.forEachNode(nodeId -> {
            positiveSampling(graph, selectedRelsBuilder, remainingRelsBuilder, positiveBudget, unvisitedRels, nodeId);
            negativeSampling(graph, masterGraph, selectedRelsBuilder, negativeBudget, nodeId);
            return true;
        });

        return EdgeSplitter.SplitResult.of(remainingRelsBuilder.build(), selectedRelsBuilder.build());
    }

    /**
     * Throws an {@link IllegalArgumentException} carrying {@code message} unless
     * {@code candidate} is undirected.
     */
    private static void requireUndirected(Graph candidate, String message) {
        if (!candidate.isUndirected()) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Samples positive (held-out) relationships among the relationships of {@code nodeId}.
     *
     * <p>Each undirected pair is considered only when visited as {@code source < target}. It is
     * selected with probability {@code 2 * positiveBudget / unvisitedRels}. That is the number of
     * samples still needed divided by the number of undirected pairs left, since
     * {@code unvisitedRels} counts both directions. This is sequential selection sampling.
     *
     * <p>A self-loop lowers {@code unvisitedRels} by one per visit and is otherwise skipped.
     *
     * @param graph                graph being split
     * @param selectedRelsBuilder  receives selected pairs with property {@code 1.0}, in random direction
     * @param remainingRelsBuilder receives pairs that were not selected
     * @param positiveBudget       number of positive samples still to draw; decremented per selection
     * @param unvisitedRels        directed-relationship counter; decremented by 2 per processed pair
     * @param nodeId               node whose relationships are visited
     */
    private void positiveSampling(
        Graph graph,
        RelationshipsBuilder selectedRelsBuilder,
        RelationshipsBuilder remainingRelsBuilder,
        AtomicLong positiveBudget,
        AtomicLong unvisitedRels,
        long nodeId
    ) {
        graph.forEachRelationship(nodeId, (source, target) -> {
            boolean selfLoop = source == target;
            if (selfLoop) {
                unvisitedRels.decrementAndGet();
            }

            // Every undirected pair is visited from both endpoints, so only one direction is handled.
            boolean canonicalDirection = source < target;
            if (canonicalDirection) {
                double selectionProbability = 2.0d * positiveBudget.get() / unvisitedRels.get();
                if (sample(selectionProbability)) {
                    positiveBudget.decrementAndGet();
                    // Fair coin flip decides which endpoint becomes the source of the selected relationship.
                    boolean keepDirection = sample(0.5d);
                    long from = keepDirection ? source : target;
                    long to = keepDirection ? target : source;
                    selectedRelsBuilder.addFromInternal(from, to, 1.0d);
                } else {
                    remainingRelsBuilder.addFromInternal(source, target);
                }
                // Account for this pair and its reverse, which will be visited (and skipped) later.
                unvisitedRels.addAndGet(-2L);
            }
            return true;
        });
    }
}
