package org.nd4j.linalg.jcublas;

import java.util.Arrays;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.jcublas.JCublas2;
import org.apache.commons.lang3.tuple.Triple;
import org.nd4j.linalg.api.blas.BlasBufferUtil;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.jcublas.buffer.DevicePointerInfo;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.jcublas.context.CudaContext;

/* loaded from: input_file:org/nd4j/linalg/jcublas/CublasPointer.class */
/**
 * Wraps a device pointer into a {@link JCudaBuffer} (optionally viewed through an {@link INDArray})
 * for use in cuBLAS calls.
 *
 * <p>Construction obtains a device pointer from the buffer and, unless the buffer reports that the
 * current thread has already copied it, enqueues an asynchronous host-to-device copy on the
 * context's "old" stream. {@link #close()} releases the device pointer unless this pointer was
 * marked as a result pointer via {@link #setResultPointer(boolean)}; {@link #destroy()} releases it
 * unconditionally, at most once.
 */
public class CublasPointer implements AutoCloseable {
    /** Buffer whose device memory this pointer refers to. */
    private JCudaBuffer buffer;
    /** Pointer obtained from {@link #buffer}; {@link #toString()} tolerates it being null. */
    private Pointer devicePointer;
    /** Array view this pointer was created for, or null when created directly from a buffer. */
    private INDArray arr;
    /** Context whose "old" stream is used for every transfer issued by this class. */
    private CudaContext cudaContext;
    /** Set once {@link #destroy()} has released the device pointer. */
    private boolean closed = false;
    /** When true, {@link #close()} leaves the device pointer allocated. */
    private boolean resultPointer = false;

    /**
     * Releases the device pointer via {@link #destroy()}, unless this pointer has been marked as a
     * result pointer, in which case nothing is freed.
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        if (isResultPointer()) {
            return;
        }
        destroy();
    }

    /**
     * Frees the device pointer held by the buffer. Subsequent calls are no-ops.
     */
    public void destroy() {
        if (this.closed) {
            return;
        }
        if (this.arr != null) {
            // NOTE(review): the constructor and copyToHost() use 2 * length for complex arrays,
            // but this frees arr.length(); confirm how freeDevicePointer keys its allocations.
            this.buffer.freeDevicePointer(this.arr.offset(), this.arr.length());
        } else {
            this.buffer.freeDevicePointer(0, this.buffer.length());
        }
        this.closed = true;
    }

    /** @return the buffer this pointer was created from */
    public JCudaBuffer getBuffer() {
        return this.buffer;
    }

    /** @return the device pointer obtained from the buffer at construction time */
    public Pointer getDevicePointer() {
        return this.devicePointer;
    }

    /**
     * Copies the device contents back to the host buffer. With an array view, only the view's
     * elements (doubled for complex arrays) starting at its offset are copied; otherwise the whole
     * buffer is copied.
     */
    public void copyToHost() {
        if (this.arr != null) {
            this.buffer.copyToHost(this.cudaContext, this.arr.offset(), bufferElementCount(this.arr));
        } else {
            this.buffer.copyToHost(this.cudaContext, 0, this.buffer.length());
        }
    }

    /**
     * Creates a pointer over an entire buffer and enqueues a host-to-device copy of all of it.
     *
     * @param jCudaBuffer the buffer to expose on the device
     * @param cudaContext the context whose old stream is used for the copy
     */
    public CublasPointer(JCudaBuffer jCudaBuffer, CudaContext cudaContext) {
        this.buffer = jCudaBuffer;
        this.devicePointer = jCudaBuffer.getDevicePointer(1, 0, jCudaBuffer.length());
        this.cudaContext = cudaContext;
        cudaContext.initOldStream();
        JCublas2.cublasSetVectorAsync(jCudaBuffer.length(), jCudaBuffer.getElementSize(), jCudaBuffer.getHostPointer(), 1, this.devicePointer, 1, cudaContext.getOldStream());
        jCudaBuffer.setCopied(Thread.currentThread().getName());
    }

    /**
     * Creates a pointer for an array view. Unless the buffer reports that the current thread has
     * already copied it, the whole backing buffer is queued for a host-to-device copy.
     *
     * @param iNDArray    the array to expose on the device
     * @param cudaContext the context whose old stream is used for the copy
     * @throws IllegalStateException if the array (even after {@code dup()}) reports a negative
     *                               element-wise stride
     */
    public CublasPointer(INDArray iNDArray, CudaContext cudaContext) {
        // Complex, non-vector views that do not span their whole buffer are replaced by
        // Shape.toOffsetZero(...) (presumably an offset-zero copy -- confirm).
        if ((iNDArray instanceof IComplexNDArray) && iNDArray.length() * 2 < iNDArray.data().length() && !iNDArray.isVector()) {
            iNDArray = Shape.toOffsetZero(iNDArray);
        }
        this.cudaContext = cudaContext;
        this.buffer = (JCudaBuffer) iNDArray.data();
        String name = Thread.currentThread().getName();
        this.arr = iNDArray;
        // A negative element-wise stride is retried on a dup(); if the copy still reports one, give up.
        if (iNDArray.elementWiseStride() < 0) {
            this.arr = iNDArray.dup();
            this.buffer = (JCudaBuffer) this.arr.data();
            if (this.arr.elementWiseStride() < 0) {
                throw new IllegalStateException("Unable to iterate over buffer");
            }
        }
        int length = bufferElementCount(this.arr);
        // Complex arrays pass half the BLAS stride to getDevicePointer.
        this.devicePointer = this.buffer.getDevicePointer(this.arr, this.arr instanceof IComplexNDArray ? BlasBufferUtil.getBlasStride(this.arr) / 2 : BlasBufferUtil.getBlasStride(this.arr), this.arr.offset(), length);
        if (this.buffer.copied(name)) {
            return;
        }
        // NOTE(review): unlike the buffer constructor, this path does not call
        // cudaContext.initOldStream() before using getOldStream(); confirm the stream is initialized.
        JCublas.cublasSetVectorAsync(
                this.buffer.length(),
                this.arr.data().getElementSize(),
                this.buffer.getHostPointer(),
                1,
                ((DevicePointerInfo) this.buffer.getPointersToContexts().get(name, Triple.of(0, Integer.valueOf(this.buffer.length()), 1))).getPointer(),
                1,
                this.cudaContext.getOldStream());
        this.buffer.setCopied(name);
    }

    /** @return whether {@link #close()} should leave the device pointer allocated */
    public boolean isResultPointer() {
        return this.resultPointer;
    }

    /** @param z true to make {@link #close()} skip freeing the device pointer */
    public void setResultPointer(boolean z) {
        this.resultPointer = z;
    }

    /**
     * Reads the device contents back (synchronously) and formats them as an array string.
     *
     * <p>Without an array view the whole buffer is read. With a view whose element count (doubled
     * for complex arrays) equals the buffer length, the data is read contiguously. Otherwise it is
     * read using the view's BLAS stride.
     */
    @Override
    public String toString() {
        if (this.devicePointer == null) {
            return "No device pointer yet";
        }
        if (this.arr == null) {
            return copyDeviceContentsToString(this.buffer.length(), 1);
        }
        int length = bufferElementCount(this.arr);
        // Compare the view's *buffer* element count: for complex arrays that is 2 * length(). The
        // previous check also accepted length() == buffer.length() for complex arrays, which then
        // read 2 * length() elements at stride 1, past the end of the buffer.
        if (length == this.buffer.length()) {
            return copyDeviceContentsToString(length, 1);
        }
        return copyDeviceContentsToString(length, BlasBufferUtil.getBlasStride(this.arr));
    }

    /**
     * Copies {@code n} elements from the device pointer, taking every {@code deviceStride}-th
     * element, into a fresh host array and formats it with {@link Arrays#toString}. DOUBLE and INT
     * buffers are read as {@code double[]} / {@code int[]}; every other type is read as
     * {@code float[]}.
     */
    private String copyDeviceContentsToString(int n, int deviceStride) {
        DataBuffer.Type type = this.buffer.dataType();
        if (type == DataBuffer.Type.DOUBLE) {
            double[] host = new double[n];
            copyDeviceToHost(n, deviceStride, Pointer.to(host));
            return Arrays.toString(host);
        }
        if (type == DataBuffer.Type.INT) {
            int[] host = new int[n];
            copyDeviceToHost(n, deviceStride, Pointer.to(host));
            return Arrays.toString(host);
        }
        float[] host = new float[n];
        copyDeviceToHost(n, deviceStride, Pointer.to(host));
        return Arrays.toString(host);
    }

    /** Queues a device-to-host copy on the context's old stream and waits for that stream. */
    private void copyDeviceToHost(int n, int deviceStride, Pointer host) {
        JCublas2.cublasGetVectorAsync(n, this.buffer.getElementSize(), this.devicePointer, deviceStride, host, 1, this.cudaContext.getOldStream());
        // The copy is asynchronous. Sync the same stream it was queued on before reading host.
        this.cudaContext.syncOldStream();
    }

    /** Number of buffer elements backing {@code array}: complex arrays use two per element. */
    private static int bufferElementCount(INDArray array) {
        return array instanceof IComplexNDArray ? array.length() * 2 : array.length();
    }

    /**
     * Best-effort close of every given pointer. A failure on one pointer is printed and does not
     * stop the others from being closed. Null arrays and null elements are skipped.
     *
     * @param cublasPointerArr pointers to close; may be null or contain nulls
     */
    public static void free(CublasPointer... cublasPointerArr) {
        if (cublasPointerArr == null) {
            return;
        }
        for (CublasPointer cublasPointer : cublasPointerArr) {
            if (cublasPointer == null) {
                continue;
            }
            try {
                cublasPointer.close();
            } catch (Exception e) {
                // Deliberately best-effort: report and keep freeing the remaining pointers.
                e.printStackTrace();
            }
        }
    }
}
