package org.neo4j.gds.ml.splitting;

import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import org.neo4j.gds.Orientation;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.DefaultValue;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.Relationships;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;

/* loaded from: input_file:org/neo4j/gds/ml/splitting/EdgeSplitter.class */
/**
 * Base class for relationship splitters: implementations divide the relationships of a graph
 * into a "selected" and a "remaining" set (see {@link SplitResult}) and can use
 * {@link #negativeSampling} to add randomly drawn negative examples.
 *
 * <p>Randomness comes from a per-thread {@link Random}. When a seed is supplied, every thread's
 * generator is initialised with that same seed.
 */
public abstract class EdgeSplitter {
    /** Property value written for relationships produced by {@link #negativeSampling}. */
    public static final double NEGATIVE = 0.0d;
    /** Counterpart of {@link #NEGATIVE}; not referenced in this class (presumably used by subclasses). */
    public static final double POSITIVE = 1.0d;
    /** Not referenced in this class; presumably the property key holding the label - TODO confirm. */
    public static final String RELATIONSHIP_PROPERTY = "label";
    /** Number of rejected candidates per call to {@link #negativeSampling} that are redrawn for free. */
    private static final int MAX_RETRIES = 20;
    private final ThreadLocal<Random> rng;
    /** Stored for subclasses; not read anywhere in this class. */
    final double negativeSamplingRatio;

    /** Pair of relationship sets produced by a split. */
    @ValueClass
    /* loaded from: input_file:org/neo4j/gds/ml/splitting/EdgeSplitter$SplitResult.class */
    public interface SplitResult {
        Relationships remainingRels();

        Relationships selectedRels();

        /**
         * Creates a result via the generated immutable implementation.
         *
         * @param relationships  forwarded as the first argument of {@code ImmutableSplitResult.of}
         * @param relationships2 forwarded as the second argument of {@code ImmutableSplitResult.of}
         */
        static SplitResult of(Relationships relationships, Relationships relationships2) {
            return ImmutableSplitResult.of(relationships, relationships2);
        }
    }

    /**
     * NOTE(review): decompiler output reports this constructor was originally package-private.
     *
     * @param optional seed for the per-thread random generators; when empty, unseeded generators are used
     * @param d        the negative sampling ratio, stored in {@link #negativeSamplingRatio}
     */
    public EdgeSplitter(Optional<Long> optional, double d) {
        this.rng = ThreadLocal.withInitial(
            () -> optional.map(seed -> new Random(seed)).orElseGet(Random::new)
        );
        this.negativeSamplingRatio = d;
    }

    public abstract SplitResult split(Graph graph, double d);

    public abstract SplitResult split(Graph graph, Graph graph2, double d);

    /**
     * Draws a uniform number from [0, 1) and reports whether it is below {@code d}.
     * NOTE(review): decompiler output reports this method was originally protected.
     *
     * @param d probability of returning {@code true}
     * @return {@code true} with probability {@code d} (always for {@code d >= 1}, never for {@code d <= 0})
     */
    public boolean sample(double d) {
        double draw = this.rng.get().nextDouble();
        return draw < d;
    }

    /** Returns a random value in {@code [0, graph.nodeCount())}; requires a non-empty graph. */
    private long randomNodeId(Graph graph) {
        long raw = this.rng.get().nextLong();
        // The remainder lies strictly between -nodeCount and nodeCount, so abs() cannot overflow.
        return Math.abs(raw % graph.nodeCount());
    }

    /**
     * Decides how many samples to draw, given a budget {@code d} to spread over {@code j2} units.
     * The integer part of {@code d / j2} is always taken; one extra sample is added with a
     * probability equal to the fractional part. The result never exceeds {@code j}.
     * NOTE(review): decompiler output reports this method was originally protected.
     *
     * @param j  upper bound for the result
     * @param d  total number of samples still to be distributed
     * @param j2 number of units the remaining samples are spread over
     * @return the number of samples to draw
     */
    public long samplesPerNode(long j, double d, long j2) {
        double expected = d / j2;
        long guaranteed = (long) expected;
        double fraction = expected - ((double) guaranteed);
        long extra = sample(fraction) ? 1 : 0;
        return Math.min(j, guaranteed + extra);
    }

    /**
     * Creates a relationships builder that carries a single double property.
     * NOTE(review): decompiler output reports this method was originally package-private.
     */
    public RelationshipsBuilder newRelationshipsBuilderWithProp(Graph graph, Orientation orientation) {
        return newRelationshipsBuilder(graph, orientation, true);
    }

    /**
     * Creates a relationships builder without properties.
     * NOTE(review): decompiler output reports this method was originally package-private.
     */
    public RelationshipsBuilder newRelationshipsBuilder(Graph graph, Orientation orientation) {
        return newRelationshipsBuilder(graph, orientation, false);
    }

    /**
     * Shared builder set-up: {@code SINGLE} aggregation, concurrency 1, default executor.
     *
     * @param z whether to register one double property (also aggregated with {@code SINGLE})
     */
    private RelationshipsBuilder newRelationshipsBuilder(Graph graph, Orientation orientation, boolean z) {
        return GraphFactory.initRelationshipsBuilder()
            .aggregation(Aggregation.SINGLE)
            .nodes(graph)
            .orientation(orientation)
            .addAllPropertyConfigs(z ? List.of(GraphFactory.PropertyConfig.of(Aggregation.SINGLE, DefaultValue.forDouble())) : List.of())
            .concurrency(1)
            .executorService(Pools.DEFAULT)
            .build();
    }

    /**
     * Adds negative relationships starting at node {@code j}: random targets that are neither
     * {@code j} itself nor one of its neighbours in {@code graph2}. Each accepted target is written
     * to the builder with property value {@link #NEGATIVE} and decrements {@code atomicLong}.
     *
     * <p>A rejected candidate is redrawn without using up a sample slot at most
     * {@link #MAX_RETRIES} times per call; after that, each rejection forfeits one slot, so the
     * loop always terminates. The same target may be drawn more than once.
     * NOTE(review): decompiler output reports this method was originally package-private.
     *
     * @param graph                graph from which candidate targets are drawn and whose ids are mapped to root ids
     * @param graph2               graph whose relationships define the neighbours that must be avoided
     * @param relationshipsBuilder receives the negative relationships
     * @param atomicLong           number of negative samples still to be produced; decremented per accepted sample
     * @param j                    source node id
     */
    public void negativeSampling(Graph graph, Graph graph2, RelationshipsBuilder relationshipsBuilder, AtomicLong atomicLong, long j) {
        int degree = graph2.degree(j);
        // All nodes except j itself and its neighbours are possible targets.
        long maxSamples = (graph2.nodeCount() - 1) - degree;
        double samplesRemaining = atomicLong.get();
        // Presumably the nodes not yet visited, assuming ids are processed in ascending order - TODO confirm.
        long nodesLeft = graph.nodeCount() - j;
        long wanted = samplesPerNode(maxSamples, samplesRemaining, nodesLeft);

        HashSet<Long> neighbours = new HashSet<>(degree);
        graph2.forEachRelationship(j, (source, target) -> {
            neighbours.add(target);
            return true;
        });

        int retriesLeft = MAX_RETRIES;
        // long counter: 'wanted' is a long and may exceed the int range.
        long slotsUsed = 0;
        while (slotsUsed < wanted) {
            long candidate = randomNodeId(graph);
            boolean rejected = candidate == j || neighbours.contains(candidate);
            if (!rejected) {
                atomicLong.decrementAndGet();
                relationshipsBuilder.addFromInternal(graph.toRootNodeId(j), graph.toRootNodeId(candidate), NEGATIVE);
                slotsUsed++;
            } else if (retriesLeft > 0) {
                // Free retry: draw again without consuming a slot.
                retriesLeft--;
            } else {
                // Retry budget exhausted: give up on this slot so the loop is bounded.
                slotsUsed++;
            }
        }
    }
}
