package org.neo4j.gds.embeddings.graphsage;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Arrays;
import java.util.List;
import java.util.OptionalLong;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.LongStream;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.ImmutableRelationshipCursor;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.core.samplers.WeightedUniformSampler;
import org.neo4j.gds.ml.core.subgraph.NeighborhoodSampler;

/* loaded from: input_file:org/neo4j/gds/embeddings/graphsage/BatchSampler.class */
/**
 * Builds the node batches used to train unsupervised GraphSAGE.
 *
 * <p>For every batch of consecutive node ids, an "extended batch" is produced. It is the
 * concatenation of:
 * <ol>
 *   <li>the batch nodes themselves,</li>
 *   <li>one positive node per batch node, reached by a short random walk, and</li>
 *   <li>up to one negative node per batch node, drawn by {@link WeightedUniformSampler} using
 *       {@code degree^DEGREE_SMOOTHING_FACTOR} as the node weight and excluding the positives.</li>
 * </ol>
 */
public final class BatchSampler {
    /** Exponent applied to node degrees to obtain the negative-sampling weights. */
    public static final double DEGREE_SMOOTHING_FACTOR = 0.75d;

    private final Graph graph;

    /**
     * @param graph the graph whose nodes are batched and sampled
     */
    public BatchSampler(Graph graph) {
        this.graph = graph;
    }

    /**
     * Splits all nodes into batches of at most {@code batchSize} nodes and extends each batch
     * with positive and negative samples.
     *
     * @param batchSize   maximum number of nodes per batch
     * @param searchDepth maximum random-walk length used to pick positive samples (must be > 0)
     * @param randomSeed  base seed; each batch derives its own seed from it
     * @return one extended batch per partition
     */
    public List<long[]> extendedBatches(int batchSize, int searchDepth, long randomSeed) {
        return PartitionUtils.rangePartitionWithBatchSize(this.graph.nodeCount(), batchSize, partition -> {
            // Give every batch its own seed. Dividing startNode by nodeCount (the previous code)
            // always yielded 0, so all batches reused the same random streams. startNode / batchSize
            // is the batch index, assuming partitions are cut at multiples of batchSize.
            long batchIndex = partition.startNode() / batchSize;
            return sampleNeighborAndNegativeNodePerBatchNode(partition, searchDepth, randomSeed + batchIndex);
        });
    }

    /**
     * Builds a single extended batch: batch nodes, then positive samples, then negative samples.
     *
     * @param partition      the batch of nodes
     * @param searchDepth    maximum random-walk length for positive samples
     * @param batchLocalSeed seed for this batch
     * @return the concatenated node ids
     */
    long[] sampleNeighborAndNegativeNodePerBatchNode(Partition partition, int searchDepth, long batchLocalSeed) {
        long[] neighbors = neighborBatch(partition, batchLocalSeed, searchDepth).toArray();
        // The partition size is bounded by the int batch size passed to extendedBatches.
        LongStream negatives = negativeBatch(Math.toIntExact(partition.nodeCount()), neighbors, batchLocalSeed);
        return LongStream.concat(
            partition.stream(),
            LongStream.concat(Arrays.stream(neighbors), negatives)
        ).toArray();
    }

    /**
     * Picks one positive sample per batch node. A random walk of 1..{@code searchDepth} steps
     * (chosen uniformly) is run from the node, and the node where the walk ends is kept.
     *
     * <p>If the walk reaches a node for which no neighbor can be sampled, it stops there. A batch
     * node with no sampleable neighbor therefore yields itself.
     *
     * @param partition      the batch of nodes
     * @param batchLocalSeed seed for the walk-length generator
     * @param searchDepth    maximum walk length (must be > 0, required by {@link Random#nextInt(int)})
     * @return one node id per batch node, in batch order
     */
    LongStream neighborBatch(Partition partition, long batchLocalSeed, int searchDepth) {
        LongStream.Builder neighbors = LongStream.builder();
        Random random = new Random(batchLocalSeed);
        partition.consume(startNode -> {
            int remainingSteps = random.nextInt(searchDepth) + 1;
            long currentNode = startNode;
            while (remainingSteps > 0) {
                // Seed depends on position and remaining depth so each step draws differently.
                OptionalLong next = new NeighborhoodSampler(currentNode + remainingSteps)
                    .sampleOne(this.graph, currentNode);
                if (!next.isPresent()) {
                    // Dead end: the walk terminates at the current node.
                    break;
                }
                currentNode = next.getAsLong();
                remainingSteps--;
            }
            neighbors.add(currentNode);
        });
        return neighbors.build();
    }

    /**
     * Draws up to {@code batchSize} negative samples over all nodes, weighting each node by
     * {@code degree^DEGREE_SMOOTHING_FACTOR} and excluding the given positive samples.
     *
     * @param batchSize      number of negative samples requested
     * @param neighbors      positive samples that must not be returned
     * @param batchLocalSeed seed for the sampler
     * @return the sampled node ids
     */
    LongStream negativeBatch(int batchSize, long[] neighbors, long batchLocalSeed) {
        long nodeCount = this.graph.nodeCount();
        WeightedUniformSampler sampler = new WeightedUniformSampler(batchLocalSeed);
        LongHashSet excluded = new LongHashSet(neighbors.length);
        excluded.addAll(neighbors);
        // Each node is presented as the target of a pseudo-relationship whose property is its weight.
        return sampler.sample(
            LongStream.range(0L, nodeCount).mapToObj(nodeId ->
                ImmutableRelationshipCursor.of(0L, nodeId, Math.pow(this.graph.degree(nodeId), DEGREE_SMOOTHING_FACTOR))
            ),
            nodeCount,
            batchSize,
            candidate -> !excluded.contains(candidate)
        );
    }
}
