package org.nd4j.linalg.jcublas.rng;

import java.util.List;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.rng.NativeRandom;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/rng/CudaNativeRandom.class */
/**
 * CUDA-backend specialisation of {@link NativeRandom}.
 *
 * <p>Initialises the native random generator through {@code nativeOps.initRandom}, handing it the
 * allocator pointer for the inherited {@code stateBuffer} plus a small array of "extra" pointers
 * built by {@link #getExtraPointers()}. That array carries the host view of {@code stateBuffer}
 * and the stream of the current CUDA context.
 */
public class CudaNativeRandom extends NativeRandom {
    private static final Logger log = LoggerFactory.getLogger(CudaNativeRandom.class);

    /** Element count passed to the parent constructor when the caller does not supply one. */
    private static final long DEFAULT_NUMBER_OF_ELEMENTS = 10000000L;

    // Not referenced anywhere in this class; kept for subclasses / external users.
    // NOTE(review): purpose not visible from here — confirm whether it is still needed.
    protected List<DataBuffer> stateBuffers;

    /** Creates a generator seeded with the current wall-clock time in milliseconds. */
    public CudaNativeRandom() {
        this(System.currentTimeMillis());
    }

    /**
     * Creates a generator with an explicit seed and the default element count.
     *
     * @param seed seed forwarded to {@link NativeRandom}
     */
    public CudaNativeRandom(long seed) {
        this(seed, DEFAULT_NUMBER_OF_ELEMENTS);
    }

    /**
     * Creates a generator with an explicit seed and element count.
     *
     * @param seed             seed forwarded to {@link NativeRandom}
     * @param numberOfElements element count forwarded to {@link NativeRandom}
     */
    public CudaNativeRandom(long seed, long numberOfElements) {
        super(seed, numberOfElements);
    }

    /**
     * Initialises the native generator state and stores the returned handle in
     * {@code statePointer}.
     *
     * <p>Afterwards the allocation point backing {@code stateBuffer} is ticked as
     * device-written. This presumably tells the allocator that the native call produced the
     * up-to-date copy on the device; TODO confirm against AtomicAllocator semantics.
     */
    @Override
    public void init() {
        // Build the extra pointers first, matching the original argument evaluation order.
        PointerPointer extras = getExtraPointers();
        AtomicAllocator allocator = AtomicAllocator.getInstance();

        statePointer = nativeOps.initRandom(extras, seed, numberOfElements,
                allocator.getPointer(stateBuffer));

        allocator.getAllocationPoint(stateBuffer).tickDeviceWrite();
    }

    /**
     * Builds the extra-pointer array passed to the native random ops.
     *
     * <p>The array has four slots, of which only two are filled here:
     * <ul>
     *   <li>slot 0 holds the host pointer of {@code stateBuffer};</li>
     *   <li>slot 1 holds the "old" stream of the current device's {@link CudaContext}.</li>
     * </ul>
     * Slots 2 and 3 are not written by this method.
     *
     * @return a freshly allocated pointer array
     */
    @Override
    public PointerPointer getExtraPointers() {
        PointerPointer extras = new PointerPointer(4L);
        AtomicAllocator allocator = AtomicAllocator.getInstance();

        CudaContext context = (CudaContext) allocator.getDeviceContext().getContext();

        extras.put(0L, allocator.getHostPointer(stateBuffer));
        extras.put(1L, context.getOldStream());

        return extras;
    }
}
