package org.neo4j.gds.ml.negativeSampling;

import java.util.Collection;
import java.util.Optional;
import java.util.SplittableRandom;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.utils.StringFormatting;

/* loaded from: input_file:org/neo4j/gds/ml/negativeSampling/UserInputNegativeSampler.class */
/**
 * {@link NegativeSampler} that does not generate negative examples itself but takes them from a
 * user-supplied graph, and distributes that graph's relationships between a test set and a train set.
 *
 * <p>The graph must be undirected. Every relationship in it must connect a node carrying one of the
 * configured source labels with a node carrying one of the configured target labels, in either
 * orientation. This is checked eagerly when the sampler is constructed.
 */
public class UserInputNegativeSampler implements NegativeSampler {
    // Graph whose relationships are the negative examples to distribute.
    private final Graph negativeExampleGraph;
    // Target share of the negative examples that is routed to the test set.
    private final double testTrainFraction;
    // Drives the test/train assignment; seeded when a random seed was supplied.
    private final SplittableRandom rng;

    /**
     * Creates a sampler over the relationships of {@code graph} and validates their end-point labels.
     *
     * @param graph            undirected graph whose relationships are the negative examples
     * @param testTrainFraction fraction of the negative examples routed to the test set
     *                         (presumably in [0, 1]; not validated here — TODO confirm caller does)
     * @param randomSeed       optional seed for a reproducible test/train assignment
     * @param sourceLabels     labels allowed at one end of every negative relationship
     * @param targetLabels     labels allowed at the other end of every negative relationship
     * @throws IllegalArgumentException if the graph is not undirected, or if some relationship does
     *                                  not connect a source-labelled node with a target-labelled node
     */
    public UserInputNegativeSampler(
        Graph graph,
        double testTrainFraction,
        Optional<Long> randomSeed,
        Collection<NodeLabel> sourceLabels,
        Collection<NodeLabel> targetLabels
    ) {
        if (!graph.schema().isUndirected()) {
            throw new IllegalArgumentException("UserInputNegativeSampler requires graph to be UNDIRECTED.");
        }
        this.negativeExampleGraph = graph;
        this.testTrainFraction = testTrainFraction;
        this.rng = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
        validateNegativeRelationships(sourceLabels, targetLabels);
    }

    /**
     * Adds every relationship of the negative example graph to either the test or the train builder,
     * tagged with {@link NegativeSampler#NEGATIVE}.
     *
     * <p>Each relationship is visited once, through the orientation whose source id is smaller than
     * its target id, so self-loops are skipped. The split is drawn sequentially: each visited pair
     * goes to the test set with probability {@code testRemaining / (testRemaining + trainRemaining)}.
     * Once the test budget reaches zero that probability is zero, so the test set never exceeds
     * {@code (long) (relationshipCount / 2 * testTrainFraction)} pairs.
     *
     * @param testSetBuilder  receives the pairs assigned to the test set
     * @param trainSetBuilder receives the pairs assigned to the train set
     */
    @Override // org.neo4j.gds.ml.negativeSampling.NegativeSampler
    public void produceNegativeSamples(RelationshipsBuilder testSetBuilder, RelationshipsBuilder trainSetBuilder) {
        // The graph is undirected, so each relationship is assumed to be reported once per direction.
        long totalRelationshipCount = negativeExampleGraph.relationshipCount() / 2;
        long testRelationshipCount = (long) (totalRelationshipCount * testTrainFraction);

        MutableLong testRemaining = new MutableLong(testRelationshipCount);
        MutableLong trainRemaining = new MutableLong(totalRelationshipCount - testRelationshipCount);

        negativeExampleGraph.forEachNode(nodeId -> {
            negativeExampleGraph.forEachRelationship(nodeId, (source, target) -> {
                // Only handle the (smaller id -> larger id) orientation: one visit per pair, no self-loops.
                if (source >= target) {
                    return true;
                }
                if (sample(testProbability(testRemaining.longValue(), trainRemaining.longValue()))) {
                    testRemaining.decrement();
                    testSetBuilder.add(source, target, NegativeSampler.NEGATIVE);
                } else {
                    trainRemaining.decrement();
                    trainSetBuilder.add(source, target, NegativeSampler.NEGATIVE);
                }
                return true;
            });
            return true;
        });
    }

    // Probability that the next visited pair is routed to the test set. Returns 0 once both budgets
    // are used up (instead of a 0/0 NaN), which likewise sends the pair to the train set.
    private static double testProbability(long testRemaining, long trainRemaining) {
        long remaining = testRemaining + trainRemaining;
        return remaining > 0 ? (double) testRemaining / remaining : 0.0;
    }

    // Returns true with probability `probability`, drawing one value from the sampler's RNG.
    private boolean sample(double probability) {
        return rng.nextDouble() < probability;
    }

    /**
     * Checks that every relationship of the negative example graph connects a node carrying one of
     * {@code sourceLabels} with a node carrying one of {@code targetLabels}, in either orientation.
     *
     * @throws IllegalArgumentException naming the first offending relationship (original node ids
     *                                  and labels) together with the expected label sets
     */
    private void validateNegativeRelationships(Collection<NodeLabel> sourceLabels, Collection<NodeLabel> targetLabels) {
        negativeExampleGraph.forEachNode(nodeId -> {
            negativeExampleGraph.forEachRelationship(nodeId, (source, target) -> {
                if (nodePairsHaveValidLabels(
                    negativeExampleGraph.nodeLabels(source),
                    negativeExampleGraph.nodeLabels(target),
                    sourceLabels,
                    targetLabels
                )) {
                    return true;
                }
                throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                    "There is a relationship of negativeRelationshipType between nodes %s and %s. The nodes have types %s and %s. However, they need to be between %s and %s.",
                    negativeExampleGraph.toOriginalNodeId(source),
                    negativeExampleGraph.toOriginalNodeId(target),
                    negativeExampleGraph.nodeLabels(source),
                    negativeExampleGraph.nodeLabels(target),
                    sourceLabels.toString(),
                    targetLabels.toString()
                ));
            });
            return true;
        });
    }

    /**
     * Returns whether a pair of nodes links the two allowed label sets, in either orientation.
     *
     * <p>This is true when the first node has one of {@code sourceLabels} and the second has one of
     * {@code targetLabels}, or when the first has one of {@code targetLabels} and the second has one
     * of {@code sourceLabels}.
     */
    private static boolean nodePairsHaveValidLabels(
        Collection<NodeLabel> firstNodeLabels,
        Collection<NodeLabel> secondNodeLabels,
        Collection<NodeLabel> sourceLabels,
        Collection<NodeLabel> targetLabels
    ) {
        return (hasAnyLabel(firstNodeLabels, sourceLabels) && hasAnyLabel(secondNodeLabels, targetLabels))
            || (hasAnyLabel(firstNodeLabels, targetLabels) && hasAnyLabel(secondNodeLabels, sourceLabels));
    }

    // True if at least one of `nodeLabels` is contained in `allowedLabels`.
    private static boolean hasAnyLabel(Collection<NodeLabel> nodeLabels, Collection<NodeLabel> allowedLabels) {
        return nodeLabels.stream().anyMatch(allowedLabels::contains);
    }
}
