package org.nd4j.linalg.jcublas.context;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import jcuda.driver.CUevent;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaEvent_t;
import jcuda.runtime.cudaStream_t;
import org.nd4j.linalg.jcublas.CublasPointer;

/* loaded from: input_file:org/nd4j/linalg/jcublas/context/CudaContext.class */
public class CudaContext implements AutoCloseable {
    private CUstream stream;
    private CUevent cUevent;
    private cudaStream_t oldStream;
    private cudaEvent_t oldEvent;
    private cublasHandle handle;
    private CublasPointer resultPointer;
    private AtomicBoolean oldStreamReturned;
    private AtomicBoolean handleReturned;
    private AtomicBoolean streamReturned;
    private boolean streamFromPool;
    private boolean handleFromPool;
    private boolean oldStreamFromPool;
    private boolean free;
    private boolean oldEventDestroyed;
    private boolean eventDestroyed;

    public CudaContext(boolean z) {
        this();
        this.free = z;
    }

    public CudaContext() {
        this.oldStreamReturned = new AtomicBoolean(false);
        this.handleReturned = new AtomicBoolean(false);
        this.streamReturned = new AtomicBoolean(false);
        this.streamFromPool = true;
        this.handleFromPool = true;
        this.oldStreamFromPool = true;
        this.free = true;
        this.oldEventDestroyed = true;
        this.eventDestroyed = true;
        ContextHolder.getInstance().setContext();
    }

    public void syncStream() {
        if (this.eventDestroyed) {
            return;
        }
        JCudaDriver.cuEventSynchronize(this.cUevent);
        JCudaDriver.cuEventDestroy(this.cUevent);
        this.eventDestroyed = true;
    }

    public void syncOldStream() {
        if (this.oldEventDestroyed) {
            return;
        }
        JCuda.cudaStreamWaitEvent(this.oldStream, this.oldEvent, 0);
        JCuda.cudaEventDestroy(this.oldEvent);
        this.oldEventDestroyed = true;
    }

    public void syncHandle() {
        syncOldStream();
    }

    public CublasPointer getResultPointer() {
        return this.resultPointer;
    }

    public void associateHandle() {
        ContextHolder.getInstance().setContext();
        JCublas2.cublasSetStream(this.handle, this.oldStream);
    }

    public void startOldEvent() {
        JCuda.cudaEventRecord(this.oldEvent, this.oldStream);
    }

    public void startNewEvent() {
        JCudaDriver.cuEventRecord(this.cUevent, this.stream);
    }

    public void initStream() {
        if (this.stream == null) {
            try {
                this.stream = (CUstream) ContextHolder.getInstance().getStreamPool().borrowObject();
            } catch (Exception e) {
                this.stream = new CUstream();
                JCudaDriver.cuStreamCreate(this.stream, 1);
                this.streamFromPool = false;
            }
            this.cUevent = new CUevent();
            JCudaDriver.cuEventCreate(this.cUevent, 0);
            this.eventDestroyed = false;
        }
    }

    public void initOldStream() {
        if (this.oldStream == null) {
            try {
                this.oldStream = (cudaStream_t) ContextHolder.getInstance().getOldStreamPool().borrowObject();
            } catch (Exception e) {
                this.oldStreamFromPool = false;
                this.oldStream = new cudaStream_t();
                JCuda.cudaStreamCreate(this.oldStream);
            }
            this.oldEvent = new cudaEvent_t();
            JCuda.cudaEventCreate(this.oldEvent);
            this.oldEventDestroyed = false;
        }
    }

    public void initHandle() {
        if (this.handle == null) {
            try {
                this.handle = (cublasHandle) ContextHolder.getInstance().getHandlePool().borrowObject();
            } catch (Exception e) {
                this.handle = new cublasHandle();
                JCublas2.cublasCreate(this.handle);
                this.handleFromPool = false;
            }
            associateHandle();
        }
    }

    public void destroy(CublasPointer cublasPointer, boolean z) {
        if (this.handle != null && !this.handleReturned.get()) {
            try {
                if (this.handleFromPool) {
                    ContextHolder.getInstance().getHandlePool().returnObject(this.handle);
                } else {
                    JCublas2.cublasDestroy(this.handle);
                }
                this.handleReturned.set(true);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        if (this.stream != null && !this.streamReturned.get()) {
            try {
                if (this.streamFromPool) {
                    ContextHolder.getInstance().getStreamPool().returnObject(this.stream);
                } else {
                    JCudaDriver.cuStreamDestroy(this.stream);
                }
                this.streamReturned.set(true);
            } catch (Exception e2) {
                e2.printStackTrace();
            }
        }
        if (this.oldStream != null && !this.oldStreamReturned.get()) {
            try {
                if (this.oldStreamFromPool) {
                    ContextHolder.getInstance().getOldStreamPool().returnObject(this.oldStream);
                } else {
                    JCuda.cudaStreamDestroy(this.oldStream);
                }
                this.oldStreamReturned.set(true);
            } catch (Exception e3) {
                e3.printStackTrace();
            }
        }
        if (cublasPointer != null && z && z) {
            cublasPointer.copyToHost();
            try {
                cublasPointer.close();
            } catch (Exception e4) {
                e4.printStackTrace();
            }
        }
        if (!this.oldEventDestroyed) {
            JCuda.cudaEventDestroy(this.oldEvent);
            this.oldEventDestroyed = true;
        }
        if (this.eventDestroyed) {
            return;
        }
        JCudaDriver.cuEventDestroy(this.cUevent);
        this.eventDestroyed = true;
    }

    public void destroy() {
        if (this.handle != null && !this.handleReturned.get()) {
            try {
                if (this.handleFromPool) {
                    ContextHolder.getInstance().getHandlePool().returnObject(this.handle);
                } else {
                    JCublas2.cublasDestroy(this.handle);
                }
                this.handleReturned.set(true);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        if (this.stream != null && !this.streamReturned.get()) {
            try {
                if (this.streamFromPool) {
                    ContextHolder.getInstance().getStreamPool().returnObject(this.stream);
                } else {
                    JCudaDriver.cuStreamDestroy(this.stream);
                }
                this.streamReturned.set(true);
            } catch (Exception e2) {
                e2.printStackTrace();
            }
        }
        if (this.oldStream != null && !this.oldStreamReturned.get()) {
            try {
                if (this.oldStreamFromPool) {
                    ContextHolder.getInstance().getOldStreamPool().returnObject(this.oldStream);
                } else {
                    JCuda.cudaStreamDestroy(this.oldStream);
                }
                this.oldStreamReturned.set(true);
            } catch (Exception e3) {
                e3.printStackTrace();
            }
        }
        if (this.resultPointer != null) {
            this.resultPointer.copyToHost();
            try {
                this.resultPointer.close();
            } catch (Exception e4) {
                e4.printStackTrace();
            }
        }
    }

    public void finishBlasOperation() {
        destroy();
    }

    public static CudaContext getBlasContext() {
        CudaContext cudaContext = new CudaContext();
        cudaContext.initOldStream();
        cudaContext.initHandle();
        cudaContext.startOldEvent();
        return cudaContext;
    }

    public void syncDevice() {
        JCuda.cudaDeviceSynchronize();
    }

    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        destroy();
    }

    public CUstream getStream() {
        return this.stream;
    }

    public CUevent getCUevent() {
        return this.cUevent;
    }

    public cudaStream_t getOldStream() {
        return this.oldStream;
    }

    public cudaEvent_t getOldEvent() {
        return this.oldEvent;
    }

    public cublasHandle getHandle() {
        return this.handle;
    }

    public AtomicBoolean getOldStreamReturned() {
        return this.oldStreamReturned;
    }

    public AtomicBoolean getHandleReturned() {
        return this.handleReturned;
    }

    public AtomicBoolean getStreamReturned() {
        return this.streamReturned;
    }

    public boolean isStreamFromPool() {
        return this.streamFromPool;
    }

    public boolean isHandleFromPool() {
        return this.handleFromPool;
    }

    public boolean isOldStreamFromPool() {
        return this.oldStreamFromPool;
    }

    public boolean isFree() {
        return this.free;
    }

    public boolean isOldEventDestroyed() {
        return this.oldEventDestroyed;
    }

    public boolean isEventDestroyed() {
        return this.eventDestroyed;
    }

    public void setStream(CUstream cUstream) {
        this.stream = cUstream;
    }

    public void setCUevent(CUevent cUevent) {
        this.cUevent = cUevent;
    }

    public void setOldStream(cudaStream_t cudastream_t) {
        this.oldStream = cudastream_t;
    }

    public void setOldEvent(cudaEvent_t cudaevent_t) {
        this.oldEvent = cudaevent_t;
    }

    public void setHandle(cublasHandle cublashandle) {
        this.handle = cublashandle;
    }

    public void setResultPointer(CublasPointer cublasPointer) {
        this.resultPointer = cublasPointer;
    }

    public void setOldStreamReturned(AtomicBoolean atomicBoolean) {
        this.oldStreamReturned = atomicBoolean;
    }

    public void setHandleReturned(AtomicBoolean atomicBoolean) {
        this.handleReturned = atomicBoolean;
    }

    public void setStreamReturned(AtomicBoolean atomicBoolean) {
        this.streamReturned = atomicBoolean;
    }

    public void setStreamFromPool(boolean z) {
        this.streamFromPool = z;
    }

    public void setHandleFromPool(boolean z) {
        this.handleFromPool = z;
    }

    public void setOldStreamFromPool(boolean z) {
        this.oldStreamFromPool = z;
    }

    public void setFree(boolean z) {
        this.free = z;
    }

    public void setOldEventDestroyed(boolean z) {
        this.oldEventDestroyed = z;
    }

    public void setEventDestroyed(boolean z) {
        this.eventDestroyed = z;
    }

    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CudaContext)) {
            return false;
        }
        CudaContext cudaContext = (CudaContext) obj;
        if (!cudaContext.canEqual(this)) {
            return false;
        }
        CUstream stream = getStream();
        CUstream stream2 = cudaContext.getStream();
        if (stream == null) {
            if (stream2 != null) {
                return false;
            }
        } else if (!stream.equals(stream2)) {
            return false;
        }
        CUevent cUevent = getCUevent();
        CUevent cUevent2 = cudaContext.getCUevent();
        if (cUevent == null) {
            if (cUevent2 != null) {
                return false;
            }
        } else if (!cUevent.equals(cUevent2)) {
            return false;
        }
        cudaStream_t oldStream = getOldStream();
        cudaStream_t oldStream2 = cudaContext.getOldStream();
        if (oldStream == null) {
            if (oldStream2 != null) {
                return false;
            }
        } else if (!oldStream.equals(oldStream2)) {
            return false;
        }
        cudaEvent_t oldEvent = getOldEvent();
        cudaEvent_t oldEvent2 = cudaContext.getOldEvent();
        if (oldEvent == null) {
            if (oldEvent2 != null) {
                return false;
            }
        } else if (!oldEvent.equals(oldEvent2)) {
            return false;
        }
        cublasHandle handle = getHandle();
        cublasHandle handle2 = cudaContext.getHandle();
        if (handle == null) {
            if (handle2 != null) {
                return false;
            }
        } else if (!handle.equals(handle2)) {
            return false;
        }
        CublasPointer resultPointer = getResultPointer();
        CublasPointer resultPointer2 = cudaContext.getResultPointer();
        if (resultPointer == null) {
            if (resultPointer2 != null) {
                return false;
            }
        } else if (!resultPointer.equals(resultPointer2)) {
            return false;
        }
        AtomicBoolean oldStreamReturned = getOldStreamReturned();
        AtomicBoolean oldStreamReturned2 = cudaContext.getOldStreamReturned();
        if (oldStreamReturned == null) {
            if (oldStreamReturned2 != null) {
                return false;
            }
        } else if (!oldStreamReturned.equals(oldStreamReturned2)) {
            return false;
        }
        AtomicBoolean handleReturned = getHandleReturned();
        AtomicBoolean handleReturned2 = cudaContext.getHandleReturned();
        if (handleReturned == null) {
            if (handleReturned2 != null) {
                return false;
            }
        } else if (!handleReturned.equals(handleReturned2)) {
            return false;
        }
        AtomicBoolean streamReturned = getStreamReturned();
        AtomicBoolean streamReturned2 = cudaContext.getStreamReturned();
        if (streamReturned == null) {
            if (streamReturned2 != null) {
                return false;
            }
        } else if (!streamReturned.equals(streamReturned2)) {
            return false;
        }
        return isStreamFromPool() == cudaContext.isStreamFromPool() && isHandleFromPool() == cudaContext.isHandleFromPool() && isOldStreamFromPool() == cudaContext.isOldStreamFromPool() && isFree() == cudaContext.isFree() && isOldEventDestroyed() == cudaContext.isOldEventDestroyed() && isEventDestroyed() == cudaContext.isEventDestroyed();
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof CudaContext;
    }

    public int hashCode() {
        CUstream stream = getStream();
        int hashCode = (1 * 59) + (stream == null ? 0 : stream.hashCode());
        CUevent cUevent = getCUevent();
        int hashCode2 = (hashCode * 59) + (cUevent == null ? 0 : cUevent.hashCode());
        cudaStream_t oldStream = getOldStream();
        int hashCode3 = (hashCode2 * 59) + (oldStream == null ? 0 : oldStream.hashCode());
        cudaEvent_t oldEvent = getOldEvent();
        int hashCode4 = (hashCode3 * 59) + (oldEvent == null ? 0 : oldEvent.hashCode());
        cublasHandle handle = getHandle();
        int hashCode5 = (hashCode4 * 59) + (handle == null ? 0 : handle.hashCode());
        CublasPointer resultPointer = getResultPointer();
        int hashCode6 = (hashCode5 * 59) + (resultPointer == null ? 0 : resultPointer.hashCode());
        AtomicBoolean oldStreamReturned = getOldStreamReturned();
        int hashCode7 = (hashCode6 * 59) + (oldStreamReturned == null ? 0 : oldStreamReturned.hashCode());
        AtomicBoolean handleReturned = getHandleReturned();
        int hashCode8 = (hashCode7 * 59) + (handleReturned == null ? 0 : handleReturned.hashCode());
        AtomicBoolean streamReturned = getStreamReturned();
        return (((((((((((((hashCode8 * 59) + (streamReturned == null ? 0 : streamReturned.hashCode())) * 59) + (isStreamFromPool() ? 79 : 97)) * 59) + (isHandleFromPool() ? 79 : 97)) * 59) + (isOldStreamFromPool() ? 79 : 97)) * 59) + (isFree() ? 79 : 97)) * 59) + (isOldEventDestroyed() ? 79 : 97)) * 59) + (isEventDestroyed() ? 79 : 97);
    }

    public String toString() {
        return "CudaContext(stream=" + getStream() + ", cUevent=" + getCUevent() + ", oldStream=" + getOldStream() + ", oldEvent=" + getOldEvent() + ", handle=" + getHandle() + ", resultPointer=" + getResultPointer() + ", oldStreamReturned=" + getOldStreamReturned() + ", handleReturned=" + getHandleReturned() + ", streamReturned=" + getStreamReturned() + ", streamFromPool=" + isStreamFromPool() + ", handleFromPool=" + isHandleFromPool() + ", oldStreamFromPool=" + isOldStreamFromPool() + ", free=" + isFree() + ", oldEventDestroyed=" + isOldEventDestroyed() + ", eventDestroyed=" + isEventDestroyed() + ")";
    }
}
