package org.nd4j.linalg.jcublas;

import java.lang.reflect.Method;
import java.util.List;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import org.nd4j.linalg.api.ndarray.BaseNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/jcublas/JCublasNDArray.class */
/**
 * {@link BaseNDArray} that can mirror its host data into a device buffer through the legacy
 * JCublas API ({@code cublasAlloc} / {@code cublasSetVector} / {@code cublasGetVector} /
 * {@code cublasFree}).
 *
 * <p>Every constructor calls {@link #setupJcuBlas()}, which creates a device {@link Pointer}
 * (not yet backed by device memory) and a host {@link Pointer} into {@link #data()} shifted by
 * {@link #offset()} elements. Device memory is only allocated by {@link #alloc()}.
 */
public class JCublasNDArray extends BaseNDArray {
    /** Element size, in bytes, used for every host/device transfer (single-precision floats). */
    private static final int FLOAT_BYTES = 4;

    /** Device-side handle; filled in by {@code cublasAlloc} in {@link #alloc()}. */
    private Pointer pointer;
    /** Host-side pointer to {@link #data()}, advanced by {@link #offset()} elements. */
    private Pointer dataPointer;

    // The constructors below delegate to the matching BaseNDArray constructor and then call
    // setupJcuBlas() so that pointer()/dataPointer() are never null after construction.

    public JCublasNDArray(double[][] dArr) {
        super(dArr);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, char c) {
        super(fArr, iArr, c);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int i, char c) {
        super(fArr, iArr, i, c);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, int i, char c) {
        super(iArr, iArr2, i, c);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, char c) {
        super(iArr, iArr2, c);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int i, char c) {
        super(iArr, i, c);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr) {
        super(iArr);
        setupJcuBlas();
    }

    public JCublasNDArray(int i, int i2, char c) {
        super(i, i2, c);
        setupJcuBlas();
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, char c) {
        super(list, iArr, c);
        setupJcuBlas();
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, int[] iArr2, char c) {
        super(list, iArr, iArr2, c);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, char c) {
        super(fArr, iArr, iArr2, c);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, int i, char c) {
        super(fArr, iArr, iArr2, i, c);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr) {
        super(fArr, iArr);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int i) {
        super(fArr, iArr, i);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int[] iArr2, int i) {
        super(iArr, iArr2, i);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int[] iArr2) {
        super(iArr, iArr2);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, int i) {
        super(iArr, i);
        setupJcuBlas();
    }

    public JCublasNDArray(int[] iArr, char c) {
        super(iArr, c);
        setupJcuBlas();
    }

    public JCublasNDArray(int i, int i2) {
        super(i, i2);
        setupJcuBlas();
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr) {
        super(list, iArr);
        setupJcuBlas();
    }

    public JCublasNDArray(List<INDArray> list, int[] iArr, int[] iArr2) {
        super(list, iArr, iArr2);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2) {
        super(fArr, iArr, iArr2);
        setupJcuBlas();
    }

    public JCublasNDArray(float[] fArr, int[] iArr, int[] iArr2, int i) {
        super(fArr, iArr, iArr2, i);
        setupJcuBlas();
    }

    /**
     * Creates a copy of {@code jCublasNDArray} with the same shape, backed by the host buffer of
     * {@code jCublasNDArray.dup()}.
     *
     * @param jCublasNDArray the array to copy
     */
    public JCublasNDArray(JCublasNDArray jCublasNDArray) {
        // Use the full shape (not just rows/columns) so arrays of rank > 2 keep their length;
        // clone so the two arrays do not share the same shape array instance.
        this(jCublasNDArray.shape().clone());
        // Duplicate the ARGUMENT's contents, not this newly constructed instance.
        this.data = jCublasNDArray.dup().data();
        // setupJcuBlas() already ran inside this(...) and returns early once the device pointer
        // exists, so the host pointer has to be re-targeted at the replaced data explicitly.
        refreshDataPointer();
    }

    /**
     * Creates an array from double values, narrowed to floats via {@code ArrayUtil.floatCopyOf}.
     *
     * @param dArr   source values
     * @param iArr   shape
     * @param iArr2  stride
     * @param i      offset, in elements, into the data
     */
    public JCublasNDArray(double[] dArr, int[] iArr, int[] iArr2, int i) {
        this.data = ArrayUtil.floatCopyOf(dArr);
        this.stride = iArr2;
        this.offset = i;
        initShape(iArr);
        setupJcuBlas();
    }

    public JCublasNDArray(float[][] fArr) {
        super(fArr);
        setupJcuBlas();
    }

    /**
     * Creates the device {@link Pointer} and the host data pointer. This only runs the first
     * time: once {@link #pointer} is non-null it returns without doing anything.
     */
    protected void setupJcuBlas() {
        if (this.pointer != null) {
            return;
        }
        this.pointer = new Pointer();
        refreshDataPointer();
    }

    /**
     * Points {@link #dataPointer} at {@link #data()} advanced by {@link #offset()} elements.
     * Does nothing when there is no host data.
     */
    private void refreshDataPointer() {
        if (this.data != null) {
            this.dataPointer = Pointer.to(data()).withByteOffset(hostByteOffset());
        }
    }

    /** Byte offset of this array's first element inside {@link #data()} (computed in long). */
    private long hostByteOffset() {
        return (long) offset() * FLOAT_BYTES;
    }

    /**
     * Host-side increment for transfers. It is 1 when the array covers its whole backing buffer
     * ({@code length == data.length}); otherwise it is {@link #majorStride()}.
     *
     * @throws IllegalStateException if there is no host data
     */
    private int hostStride() {
        if (this.data == null) {
            throw new IllegalStateException("No host data to transfer");
        }
        return this.length == this.data.length ? 1 : majorStride();
    }

    /**
     * Reads the byte offset of {@link #pointer} through reflection on
     * {@code Pointer.getByteOffset()}.
     *
     * @throws IllegalStateException if the reflective call fails (the cause is preserved)
     */
    private long getPointerOffset() {
        try {
            Method getByteOffset = Pointer.class.getDeclaredMethod("getByteOffset");
            getByteOffset.setAccessible(true);
            return (Long) getByteOffset.invoke(this.pointer);
        } catch (Exception e) {
            throw new IllegalStateException("Unable to get declared pointer", e);
        }
    }

    /**
     * Round-trip smoke test. It uploads the data with {@link #alloc()} (using the same host
     * stride as a regular upload, and releasing any previous device buffer), then copies the
     * device buffer back into a throw-away contiguous array.
     */
    public void allocTest() {
        alloc();
        float[] scratch = new float[this.length];
        // Read back contiguously: the scratch buffer holds exactly `length` elements.
        JCublas.cublasGetVector(this.length, FLOAT_BYTES, this.pointer, 1, Pointer.to(scratch), 1);
    }

    /**
     * Frees the current device buffer, allocates {@code length} floats on the device and uploads
     * the host data into them. The device buffer is written contiguously (increment 1).
     *
     * @throws IllegalStateException if there is no host data; nothing is freed in that case
     */
    public void alloc() {
        // Resolve the stride first so a missing host buffer fails before the device buffer is freed.
        int hostStride = hostStride();
        refreshDataPointer();
        free();
        this.pointer = new Pointer();
        JCublas.cublasAlloc(this.length, FLOAT_BYTES, this.pointer);
        JCublas.cublasSetVector(this.length, FLOAT_BYTES, this.dataPointer, hostStride, this.pointer, 1);
    }

    /**
     * Releases the device buffer through {@code cublasFree}. Errors are swallowed on purpose.
     * {@link #pointer} itself is not reset here.
     */
    public void free() {
        try {
            JCublas.cublasFree(this.pointer);
        } catch (CudaException ignored) {
            // Best-effort release: a failure to free (e.g. nothing allocated yet) must not
            // prevent alloc() from allocating a new buffer.
        }
    }

    /**
     * Copies the device buffer into {@code fArr}.
     *
     * <p>NOTE(review): when the array is a view ({@code length != data.length}), the copy writes
     * with stride {@link #majorStride()}. {@code fArr} must then be large enough for that stride.
     *
     * @param fArr destination host array
     */
    public void getData(float[] fArr) {
        getData(Pointer.to(fArr));
    }

    /**
     * Copies {@code length} floats from the device buffer to {@code hostPointer}. The device is
     * read contiguously; the host is written with the same stride used by {@link #alloc()}.
     *
     * @param hostPointer destination host pointer
     * @throws IllegalStateException if there is no host data to derive the stride from
     */
    public void getData(Pointer hostPointer) {
        JCublas.cublasGetVector(this.length, FLOAT_BYTES, pointer(), 1, hostPointer, hostStride());
    }

    /** Copies the device buffer back into this array's own host data. */
    public void getData() {
        getData(this.dataPointer);
    }

    /** @return host pointer to {@link #data()} advanced by {@link #offset()} elements */
    public Pointer dataPointer() {
        return this.dataPointer;
    }

    /** @return the device pointer (backed by device memory only after {@link #alloc()}) */
    public Pointer pointer() {
        return this.pointer;
    }

    /**
     * Returns the device pointer advanced by {@link #offset()} elements.
     *
     * <p>NOTE(review): {@link #alloc()} uploads starting at the host offset, so device element 0
     * already corresponds to host element {@code offset()}. Applying the offset again here looks
     * like a double offset. Confirm against the callers before relying on it.
     */
    public Pointer pointerWithOffset() {
        return this.pointer.withByteOffset(hostByteOffset());
    }
}
