package org.neo4j.gds.ml.splitting;

import com.carrotsearch.hppc.predicates.LongLongPredicate;
import com.carrotsearch.hppc.predicates.LongPredicate;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Random;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.DefaultValue;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;

/* loaded from: input_file:org/neo4j/gds/ml/splitting/EdgeSplitter.class */
/**
 * Base class for splitting the relationships of a graph into a "selected" set and a
 * "remaining" set.
 *
 * <p>{@link #splitPositiveExamples} builds the candidate filter and the two relationship
 * builders, then walks every node of the graph and delegates the per-node sampling decision
 * to {@link #positiveSampling}. Subclasses supply that decision and the candidate count.
 *
 * <p>The random number generator is a plain {@link Random} shared by all calls on one instance.
 */
public abstract class EdgeSplitter {
    /**
     * Not referenced inside this class. NOTE(review): presumably the value stored under
     * {@link #RELATIONSHIP_PROPERTY} for positive examples; confirm against the subclasses.
     */
    public static final double POSITIVE = 1.0d;

    /** Property key configured on the builder that collects the selected relationships. */
    public static final String RELATIONSHIP_PROPERTY = "label";

    private final Random rng = new Random();
    private final RelationshipType selectedRelationshipType;
    private final RelationshipType remainingRelationshipType;
    protected final IdMap sourceNodes;
    protected final IdMap targetNodes;
    protected final IdMap rootNodes;
    protected int concurrency;

    /**
     * Outcome of a split: one builder and one relationship count for each of the two sets.
     */
    @ValueClass
    public interface SplitResult {
        RelationshipsBuilder remainingRels();

        long remainingRelCount();

        RelationshipsBuilder selectedRels();

        long selectedRelCount();

        /**
         * Creates a result.
         *
         * @param relationshipsBuilder  builder holding the remaining relationships
         * @param j                     number of remaining relationships
         * @param relationshipsBuilder2 builder holding the selected relationships
         * @param j2                    number of selected relationships
         */
        static SplitResult of(RelationshipsBuilder relationshipsBuilder, long j, RelationshipsBuilder relationshipsBuilder2, long j2) {
            return ImmutableSplitResult.of(relationshipsBuilder, j, relationshipsBuilder2, j2);
        }
    }

    /**
     * Creates a splitter.
     *
     * @param optional          seed for the random number generator; when empty the generator
     *                          keeps the seed chosen by {@link Random#Random()}
     * @param idMap             root node mapping, used as the node set of both relationship builders
     * @param idMap2            nodes accepted as relationship sources
     * @param idMap3            nodes accepted as relationship targets
     * @param relationshipType  type given to the selected relationships
     * @param relationshipType2 type given to the remaining relationships
     * @param i                 concurrency value stored in {@link #concurrency}
     */
    public EdgeSplitter(Optional<Long> optional, IdMap idMap, IdMap idMap2, IdMap idMap3, RelationshipType relationshipType, RelationshipType relationshipType2, int i) {
        this.rootNodes = idMap;
        this.selectedRelationshipType = relationshipType;
        this.remainingRelationshipType = relationshipType2;
        optional.ifPresent(this.rng::setSeed);
        this.sourceNodes = idMap2;
        this.targetNodes = idMap3;
        this.concurrency = i;
    }

    /**
     * Splits the relationships of {@code graph} into a selected and a remaining set.
     *
     * <p>A relationship is a candidate only if the original id of its source is contained in
     * {@link #sourceNodes} and the original id of its target is contained in
     * {@link #targetNodes}.
     *
     * @param graph    graph whose relationships are split
     * @param d        factor applied to the candidate count to obtain the initial sample budget
     *                 handed to {@link #positiveSampling} (truncated towards zero)
     * @param optional property key to configure on the remaining-relationships builder, if any
     * @return the two builders together with the counters maintained during sampling
     */
    public SplitResult splitPositiveExamples(Graph graph, double d, Optional<String> optional) {
        LongPredicate isValidSource = nodeId -> this.sourceNodes.contains(graph.toOriginalNodeId(nodeId));
        LongPredicate isValidTarget = nodeId -> this.targetNodes.contains(graph.toOriginalNodeId(nodeId));
        LongLongPredicate isValidNodePair =
            (source, target) -> isValidSource.apply(source) && isValidTarget.apply(target);

        // Selected relationships are always built as directed and carry the "label" property;
        // remaining relationships keep the direction declared by the graph schema.
        RelationshipsBuilder selectedRelsBuilder = newRelationshipsBuilder(
            this.rootNodes,
            this.selectedRelationshipType,
            Direction.DIRECTED,
            Optional.of(RELATIONSHIP_PROPERTY)
        );
        RelationshipsBuilder remainingRelsBuilder = newRelationshipsBuilder(
            this.rootNodes,
            this.remainingRelationshipType,
            graph.schema().direction(),
            optional
        );

        // Node ids are translated to root ids because both builders are defined over rootNodes.
        RelationshipWithPropertyConsumer remainingRelsConsumer = (source, target, weight) -> {
            remainingRelsBuilder.addFromInternal(graph.toRootNodeId(source), graph.toRootNodeId(target), weight);
            return true;
        };

        long candidateCount = validPositiveRelationshipCandidateCount(graph, isValidNodePair);
        MutableLong samplesRemaining = new MutableLong((long) (candidateCount * d));
        MutableLong candidatesRemaining = new MutableLong(candidateCount);
        MutableLong selectedRelCount = new MutableLong(0L);
        MutableLong remainingRelCount = new MutableLong(0L);

        graph.forEachNode(nodeId -> {
            positiveSampling(
                graph,
                selectedRelsBuilder,
                remainingRelsConsumer,
                selectedRelCount,
                remainingRelCount,
                nodeId,
                isValidNodePair,
                samplesRemaining,
                candidatesRemaining
            );
            return true;
        });

        return SplitResult.of(
            remainingRelsBuilder,
            remainingRelCount.longValue(),
            selectedRelsBuilder,
            selectedRelCount.longValue()
        );
    }

    /**
     * Performs the sampling for a single node. Called once per node by
     * {@link #splitPositiveExamples}.
     *
     * @param graph                            graph being split
     * @param relationshipsBuilder             builder for the selected relationships
     * @param relationshipWithPropertyConsumer consumer that adds a relationship to the remaining set
     * @param mutableLong                      counter reported as {@link SplitResult#selectedRelCount()}
     * @param mutableLong2                     counter reported as {@link SplitResult#remainingRelCount()}
     * @param j                                id of the node currently visited
     * @param longLongPredicate                accepts (source, target) pairs that are valid candidates
     * @param mutableLong3                     sample budget, initialised to {@code (long) (candidateCount * fraction)}
     * @param mutableLong4                     candidate budget, initialised to the candidate count
     */
    protected abstract void positiveSampling(Graph graph, RelationshipsBuilder relationshipsBuilder, RelationshipWithPropertyConsumer relationshipWithPropertyConsumer, MutableLong mutableLong, MutableLong mutableLong2, long j, LongLongPredicate longLongPredicate, MutableLong mutableLong3, MutableLong mutableLong4);

    /**
     * Returns the number of candidate relationships of {@code graph}, given the
     * (source, target) filter built by {@link #splitPositiveExamples}.
     */
    protected abstract long validPositiveRelationshipCandidateCount(Graph graph, LongLongPredicate longLongPredicate);

    /**
     * Draws one uniform value from the splitter's random number generator.
     *
     * @param d threshold; values {@code <= 0} never succeed, values {@code >= 1} always succeed
     * @return {@code true} if the drawn value is strictly smaller than {@code d}
     */
    public boolean sample(double d) {
        return this.rng.nextDouble() < d;
    }

    /**
     * Computes how many samples to take for one node: {@code d / j2} rounded down, plus one
     * more with probability equal to the discarded fractional part, capped at {@code j}.
     *
     * @param j  upper bound for the returned value
     * @param d  dividend of the average
     * @param j2 divisor of the average
     * @return the number of samples, never larger than {@code j}
     */
    protected long samplesPerNode(long j, double d, long j2) {
        double averageSamples = d / j2;
        long wholeSamples = (long) averageSamples;
        // Always draw exactly once so that seeded runs consume the same random sequence.
        boolean extraSample = sample(averageSamples - ((double) wholeSamples));
        // The cast saturates at Long.MAX_VALUE for huge or infinite averages (e.g. j2 == 0);
        // skipping the increment there avoids wrapping around to a negative count.
        if (extraSample && wholeSamples < Long.MAX_VALUE) {
            wholeSamples++;
        }
        return Math.min(j, wholeSamples);
    }

    /**
     * Creates a relationships builder over {@code idMap} with {@link Aggregation#SINGLE}
     * aggregation and a builder concurrency of 1.
     *
     * @param idMap            node set of the builder
     * @param relationshipType type of the relationships to build
     * @param direction        direction, converted to the builder's orientation
     * @param optional         key of a double-valued property to configure, if any
     */
    private static RelationshipsBuilder newRelationshipsBuilder(IdMap idMap, RelationshipType relationshipType, Direction direction, Optional<String> optional) {
        return GraphFactory.initRelationshipsBuilder()
            .relationshipType(relationshipType)
            .aggregation(Aggregation.SINGLE)
            .nodes(idMap)
            .orientation(direction.toOrientation())
            .addAllPropertyConfigs(optional
                .map(propertyKey -> List.of(GraphFactory.PropertyConfig.of(propertyKey, Aggregation.SINGLE, DefaultValue.forDouble())))
                .orElse(List.of()))
            .concurrency(1)
            .executorService(Pools.DEFAULT)
            .build();
    }
}
