package org.neo4j.gds.embeddings.graphsage;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Arrays;
import java.util.List;
import java.util.SplittableRandom;
import java.util.stream.LongStream;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.ImmutableRelationshipCursor;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;

/**
 * Builds the extended node batches used for GraphSAGE training.
 *
 * <p>Each extended batch is the concatenation of:
 * <ol>
 *   <li>the batch nodes themselves (a contiguous range of node ids produced by range partitioning),</li>
 *   <li>one node per batch node, reached by a random walk of between 1 and {@code searchDepth} steps;
 *       the walk stops early at a node of degree 0,</li>
 *   <li>nodes drawn by a {@link WeightedUniformSampler} over all nodes, weighted by
 *       {@code degree ^ DEGREE_SMOOTHING_FACTOR}, that are not among the sampled neighbours
 *       (requested count = batch size; whether fewer can be returned depends on the sampler).</li>
 * </ol>
 */
public final class BatchSampler {
    /** Exponent applied to node degrees to weight negative-sample candidates. */
    public static final double DEGREE_SMOOTHING_FACTOR = 0.75d;

    private final Graph graph;

    public BatchSampler(Graph graph) {
        this.graph = graph;
    }

    /**
     * Partitions all nodes into batches of {@code batchSize} and extends each batch with sampled
     * neighbour and negative nodes.
     *
     * @param batchSize   number of nodes per batch
     * @param searchDepth maximum random-walk length used to pick a neighbour (must be positive)
     * @param randomSeed  base seed; each batch derives its own seed from it
     * @return one extended batch (batch nodes, neighbours, negatives) per partition
     */
    public List<long[]> extendedBatches(int batchSize, int searchDepth, long randomSeed) {
        return PartitionUtils.rangePartitionWithBatchSize(this.graph.nodeCount(), batchSize, batch -> {
            // Partitions are disjoint node ranges, so offsetting the seed by the start node gives each
            // batch its own random stream. (Dividing startNode by nodeCount always yields 0, which
            // made every batch reuse the same seed.)
            long batchLocalSeed = batch.startNode() + randomSeed;
            return sampleNeighborAndNegativeNodePerBatchNode(batch, searchDepth, batchLocalSeed);
        });
    }

    /**
     * Returns the batch nodes followed by their sampled neighbours and the negative samples.
     *
     * @param batch          the node range forming this batch
     * @param searchDepth    maximum random-walk length for neighbour sampling
     * @param batchLocalSeed seed used for both neighbour and negative sampling of this batch
     */
    long[] sampleNeighborAndNegativeNodePerBatchNode(Partition batch, int searchDepth, long batchLocalSeed) {
        long[] neighbours = neighborBatch(batch, batchLocalSeed, searchDepth);
        LongStream batchNodes = batch.stream();
        LongStream negatives = negativeBatch(Math.toIntExact(batch.nodeCount()), neighbours, batchLocalSeed);
        return LongStream.concat(batchNodes, LongStream.concat(Arrays.stream(neighbours), negatives)).toArray();
    }

    /**
     * For every node in {@code batch}, performs a random walk of between 1 and {@code searchDepth}
     * steps and records the node where the walk ends.
     *
     * <p>A walk that reaches a node of degree 0 stops there. If the batch node itself has degree 0,
     * the batch node itself is recorded.
     *
     * @return array of length {@code batch.nodeCount()}, aligned with the batch nodes
     */
    long[] neighborBatch(Partition batch, long batchLocalSeed, int searchDepth) {
        int batchNodeCount = Math.toIntExact(batch.nodeCount());
        long[] neighbours = new long[batchNodeCount];
        SplittableRandom random = new SplittableRandom(batchLocalSeed);
        long startNode = batch.startNode();

        for (int offset = 0; offset < batchNodeCount; offset++) {
            int remainingSteps = random.nextInt(searchDepth) + 1;
            long currentNode = startNode + offset;
            while (remainingSteps > 0) {
                int degree = this.graph.degree(currentNode);
                if (degree == 0) {
                    // dead end: the walk ends at the current node
                    break;
                }
                int targetOffset = random.nextInt(degree);
                long target = this.graph.nthTarget(currentNode, targetOffset);
                assert target != -1
                    : "The offset '" + targetOffset + "' is bound by the degree but no target could be found for nodeId " + currentNode;
                currentNode = target;
                remainingSteps--;
            }
            neighbours[offset] = currentNode;
        }
        return neighbours;
    }

    /**
     * Samples {@code batchSize} negative nodes from all nodes of the graph.
     *
     * <p>Each candidate is weighted by {@code degree ^ DEGREE_SMOOTHING_FACTOR}. Nodes contained in
     * {@code batchNeighbors} are rejected, so a sampled neighbour is never reused as a negative example.
     */
    LongStream negativeBatch(int batchSize, long[] batchNeighbors, long batchLocalSeed) {
        long nodeCount = this.graph.nodeCount();
        WeightedUniformSampler sampler = new WeightedUniformSampler(batchLocalSeed);

        LongHashSet neighbourSet = new LongHashSet(batchNeighbors.length);
        neighbourSet.addAll(batchNeighbors);

        return sampler.sample(
            LongStream.range(0L, nodeCount).mapToObj(nodeId ->
                ImmutableRelationshipCursor.of(0L, nodeId, Math.pow(this.graph.degree(nodeId), DEGREE_SMOOTHING_FACTOR))
            ),
            nodeCount,
            batchSize,
            candidate -> !neighbourSet.contains(candidate)
        );
    }
}
