package org.nd4j.linalg.jcublas.context;

import java.util.Objects;

import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaStream_t;

/* loaded from: input_file:org/nd4j/linalg/jcublas/context/CudaContext.class */
/**
 * Holds the CUDA resources used for a unit of GPU work: an optional driver-API stream
 * ({@link CUstream}), an optional runtime-API stream ({@link cudaStream_t}, referred to here as
 * the "old" stream) and an optional cuBLAS handle.
 *
 * <p>Resources are created lazily through the {@code init*} methods and released by
 * {@link #destroy()} / {@link #close()}. Releasing is idempotent: after the first call the fields
 * are cleared, so a later call (for example {@code close()} from try-with-resources after
 * {@link #finishBlasOperation()}) does not free the same native handles twice.
 *
 * <p>This class performs no synchronization of its own.
 */
public class CudaContext implements AutoCloseable {
    /** Driver-API stream; {@code null} until {@link #initStream()} or {@link #setStream}. */
    private CUstream stream;
    /** Runtime-API stream; {@code null} until {@link #initOldStream()} or {@link #setOldStream}. */
    private cudaStream_t oldStream;
    /** cuBLAS handle; {@code null} until {@link #initHandle()} or {@link #setHandle}. */
    private cublasHandle handle;

    /**
     * Creates an empty context and asks the {@code ContextHolder} singleton to set the current
     * context for the calling thread. No streams or handles are allocated here.
     */
    public CudaContext() {
        ContextHolder.getInstance().setContext();
    }

    /** Blocks until all work queued on the driver-API stream has completed. */
    public void syncStream() {
        JCudaDriver.cuStreamSynchronize(this.stream);
    }

    /** Blocks until all work queued on the runtime-API ("old") stream has completed. */
    public void syncOldStream() {
        JCuda.cudaStreamSynchronize(this.oldStream);
    }

    /**
     * Waits for cuBLAS work to finish. The handle is bound to the old stream by
     * {@link #associateHandle()}, so this synchronizes that stream.
     */
    public void syncHandle() {
        syncOldStream();
    }

    /** Binds the cuBLAS handle to the runtime-API stream so cuBLAS calls are queued on it. */
    public void associateHandle() {
        JCublas2.cublasSetStream(this.handle, this.oldStream);
    }

    /**
     * Allocates a new driver-API stream. Flag value {@code 1} is {@code CU_STREAM_NON_BLOCKING}
     * in the CUDA driver API.
     */
    public void initStream() {
        this.stream = new CUstream();
        JCudaDriver.cuStreamCreate(this.stream, 1);
    }

    /** Allocates a new runtime-API stream. */
    public void initOldStream() {
        this.oldStream = new cudaStream_t();
        JCuda.cudaStreamCreate(this.oldStream);
    }

    /**
     * Creates a cuBLAS handle and binds it to the current runtime-API stream. Call
     * {@link #initOldStream()} first if the handle should use a dedicated stream.
     */
    public void initHandle() {
        this.handle = new cublasHandle();
        JCublas2.cublasCreate(this.handle);
        associateHandle();
    }

    /**
     * Releases the cuBLAS handle, the driver-API stream and the runtime-API stream, in that order
     * (the handle first, because it is bound to a stream).
     *
     * <p>The fields are cleared before the native calls, so repeated calls are no-ops rather
     * than double frees. Each release is attempted even if an earlier one throws; the first
     * exception propagates.
     */
    public void destroy() {
        final cublasHandle h = this.handle;
        final CUstream s = this.stream;
        final cudaStream_t os = this.oldStream;
        // Clear first so the context never holds references to freed native resources.
        this.handle = null;
        this.stream = null;
        this.oldStream = null;
        try {
            if (h != null) {
                JCublas2.cublasDestroy(h);
            }
        } finally {
            try {
                if (s != null) {
                    JCudaDriver.cuStreamDestroy(s);
                }
            } finally {
                if (os != null) {
                    JCuda.cudaStreamDestroy(os);
                }
            }
        }
    }

    /** Waits for queued cuBLAS work on the old stream, then releases all resources. */
    public void finishBlasOperation() {
        syncOldStream();
        destroy();
    }

    /**
     * Creates a context ready for cuBLAS calls: a fresh runtime-API stream plus a cuBLAS handle
     * bound to it.
     *
     * @return a new context; the caller owns it and must release it via {@link #close()},
     *     {@link #destroy()} or {@link #finishBlasOperation()}
     */
    public static CudaContext getBlasContext() {
        CudaContext cudaContext = new CudaContext();
        cudaContext.initOldStream();
        cudaContext.initHandle();
        return cudaContext;
    }

    /**
     * Releases all resources via {@link #destroy()}. Safe to call more than once.
     *
     * @throws Exception kept in the signature for compatibility with existing overrides
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        destroy();
    }

    /** @return the driver-API stream, or {@code null} if none is set */
    public CUstream getStream() {
        return this.stream;
    }

    /** @return the runtime-API stream, or {@code null} if none is set */
    public cudaStream_t getOldStream() {
        return this.oldStream;
    }

    /** @return the cuBLAS handle, or {@code null} if none is set */
    public cublasHandle getHandle() {
        return this.handle;
    }

    /** Replaces the driver-API stream reference; the previous stream is not released. */
    public void setStream(CUstream cUstream) {
        this.stream = cUstream;
    }

    /** Replaces the runtime-API stream reference; the previous stream is not released. */
    public void setOldStream(cudaStream_t cudastream_t) {
        this.oldStream = cudastream_t;
    }

    /** Replaces the cuBLAS handle reference; the previous handle is not released. */
    public void setHandle(cublasHandle cublashandle) {
        this.handle = cublashandle;
    }

    /**
     * Two contexts are equal when their stream, old stream and handle references are equal.
     * The canEqual check keeps equality symmetric for subclasses.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CudaContext)) {
            return false;
        }
        CudaContext other = (CudaContext) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(getStream(), other.getStream())
                && Objects.equals(getOldStream(), other.getOldStream())
                && Objects.equals(getHandle(), other.getHandle());
    }

    /**
     * Hook for subclasses that add state: returns whether {@code obj} may be equal to this
     * instance.
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof CudaContext;
    }

    /** Hash over the same three fields as {@link #equals}; null fields contribute 0. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Objects.hashCode(getStream());
        result = result * prime + Objects.hashCode(getOldStream());
        result = result * prime + Objects.hashCode(getHandle());
        return result;
    }

    @Override
    public String toString() {
        return "CudaContext(stream=" + getStream() + ", oldStream=" + getOldStream() + ", handle=" + getHandle() + ")";
    }
}
