package org.nd4j.linalg.jcublas;

import java.util.Arrays;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.jcublas.JCublas2;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.ContextHolder;
import org.nd4j.linalg.util.LinearUtil;

/* loaded from: input_file:org/nd4j/linalg/jcublas/CublasPointer.class */
/**
 * Pairs a {@link JCudaBuffer} with a device pointer obtained from it so the pointer can be
 * passed to CUBLAS routines and later released through {@link #close()} (for example from a
 * try-with-resources statement).
 *
 * <p>Both constructors enqueue a host-to-device copy on the stream returned by
 * {@code ContextHolder.getInstance().getCudaStream()}; see each constructor for whether it waits
 * for that copy to finish.
 */
public class CublasPointer implements AutoCloseable {
    final JCudaBuffer buffer;
    final Pointer devicePointer;
    private boolean closed = false;
    /** Array this pointer was created for, or {@code null} when created from a raw buffer. */
    private final INDArray arr;

    /**
     * Obtains a device pointer covering the whole buffer, uploads the full host contents with
     * unit stride, and waits on the stream until the upload has completed.
     *
     * @param jCudaBuffer the buffer to upload
     */
    public CublasPointer(JCudaBuffer jCudaBuffer) {
        this.buffer = jCudaBuffer;
        this.arr = null;
        this.devicePointer = jCudaBuffer.getDevicePointer(1, 0, jCudaBuffer.length());
        JCublas2.cublasSetVectorAsync(jCudaBuffer.length(), jCudaBuffer.getElementSize(), jCudaBuffer.getHostPointer(), 1, this.devicePointer, 1, ContextHolder.getInstance().getCudaStream());
        ContextHolder.syncStream();
    }

    /**
     * Obtains a device pointer for {@code iNDArray}'s view of its data buffer and, unless the
     * buffer reports it was already copied for the current thread name, enqueues an upload of
     * the entire host buffer to the device pointer registered under {@code (threadName, 0)}.
     *
     * <p>The upload is asynchronous; this constructor does not wait for it to complete.
     * Real and complex arrays are handled identically.
     *
     * @param iNDArray the array to expose; its {@code data()} must be a {@link JCudaBuffer}
     */
    public CublasPointer(INDArray iNDArray) {
        this.buffer = (JCudaBuffer) iNDArray.data();
        this.arr = iNDArray;
        this.devicePointer = this.buffer.getDevicePointer(iNDArray.majorStride(), iNDArray.offset(), iNDArray.length());

        String threadName = Thread.currentThread().getName();
        // copied()/setCopied() are keyed by thread name; presumably this avoids re-uploading the
        // same buffer from the same thread (NOTE(review): confirm against BaseCudaDataBuffer).
        if (this.buffer.copied(threadName)) {
            return;
        }
        Pointer threadDevicePointer = ((BaseCudaDataBuffer.DevicePointerInfo) this.buffer.getPointersToContexts().get(threadName, 0)).getPointer();
        JCublas.cublasSetVectorAsync(this.buffer.length(), iNDArray.data().getElementSize(), this.buffer.getHostPointer(), 1, threadDevicePointer, 1, ContextHolder.getInstance().getCudaStream());
        this.buffer.setCopied(threadName);
    }

    /**
     * Releases the device pointer through {@link JCudaBuffer#freeDevicePointer}, using the
     * array's offset when this pointer was built from an array and {@code 0} otherwise.
     * Calling this more than once has no further effect.
     */
    @Override // java.lang.AutoCloseable
    public void close() throws Exception {
        if (this.closed) {
            return;
        }
        if (this.arr != null) {
            this.buffer.freeDevicePointer(this.arr.offset());
        } else {
            this.buffer.freeDevicePointer(0);
        }
        this.closed = true;
    }

    /** @return the buffer this pointer was created from */
    public JCudaBuffer getBuffer() {
        return this.buffer;
    }

    /** @return the device pointer obtained at construction time */
    public Pointer getDevicePointer() {
        return this.devicePointer;
    }

    /**
     * Delegates to {@link JCudaBuffer#copyToHost} with the array's offset, or {@code 0} when
     * this pointer was created from a raw buffer.
     */
    public void copyToHost() {
        if (this.arr != null) {
            this.buffer.copyToHost(this.arr.offset());
        } else {
            this.buffer.copyToHost(0);
        }
    }

    /**
     * Renders the device-side contents by copying them back to a temporary host array.
     * This performs a device-to-host copy and synchronizes the stream, so it is not free.
     *
     * @return the device values formatted with {@link Arrays#toString}, or a placeholder when
     *     no device pointer is available
     */
    @Override
    public String toString() {
        if (this.devicePointer == null) {
            return "No device pointer yet";
        }
        if (this.arr == null) {
            // Raw buffer: read the whole buffer contiguously. (Previously the stride was derived
            // from the null array, and the host array was read without waiting for the copy.)
            return copyDeviceToString(this.buffer.dataType(), this.buffer.length(), 1);
        }
        // Complex arrays store two scalars per element.
        int length = this.arr instanceof IComplexNDArray ? this.arr.length() * 2 : this.arr.length();
        boolean coversWholeBuffer = length == this.buffer.length() || this.arr.length() == this.buffer.length();
        // A view smaller than the buffer is read with its major stride. A full-buffer array is
        // read contiguously.
        int deviceStride = coversWholeBuffer ? 1 : this.arr.majorStride();
        return copyDeviceToString(this.arr.data().dataType(), length, deviceStride);
    }

    /**
     * Copies {@code length} elements starting at {@link #devicePointer} (spaced by
     * {@code deviceStride} on the device) into a dense host array, waits for the copy, and
     * formats the result.
     */
    private String copyDeviceToString(DataBuffer.Type type, int length, int deviceStride) {
        if (type == DataBuffer.Type.DOUBLE) {
            double[] host = new double[length];
            JCublas2.cublasGetVectorAsync(length, this.buffer.getElementSize(), this.devicePointer, deviceStride, Pointer.to(host), 1, ContextHolder.getInstance().getCudaStream());
            ContextHolder.syncStream();
            return Arrays.toString(host);
        }
        float[] host = new float[length];
        JCublas2.cublasGetVectorAsync(length, this.buffer.getElementSize(), this.devicePointer, deviceStride, Pointer.to(host), 1, ContextHolder.getInstance().getCudaStream());
        ContextHolder.syncStream();
        return Arrays.toString(host);
    }
}
