package org.nd4j.linalg.jcublas;

import jcuda.LogLevel;
import jcuda.Pointer;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.jcublas.JCublas;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.complex.JCublasComplexNDArray;

/* loaded from: input_file:org/nd4j/linalg/jcublas/SimpleJCublas.class */
/**
 * Static helpers that run single-precision CUBLAS routines (legacy JCublas API) on
 * {@link JCublasNDArray} and {@link JCublasComplexNDArray} operands.
 *
 * <p>Every operation follows the same pattern: copy each operand from host memory into a freshly
 * allocated device buffer, invoke one CUBLAS routine, copy the output operand back into its host
 * array, and free the device buffers in a {@code finally} block so that a failing CUBLAS call
 * does not leak device memory. Elements are 4 bytes wide (host data is handled as {@code float});
 * a complex array of {@code length()} elements occupies {@code 2 * length()} floats and is passed
 * to the single-precision complex ({@code C*}) routines.
 *
 * <p>NOTE(review): the matrix routines pass {@code rows()} as the leading dimension, which assumes
 * column-major storage — confirm against the array implementations.
 */
public class SimpleJCublas {

    /** Bytes per element; every buffer handled here is single-precision float data. */
    private static final int FLOAT_BYTES = 4;

    static {
        JCublas.setLogLevel(LogLevel.LOG_DEBUG);
        // With exceptions enabled, CUBLAS errors are thrown rather than returned as status codes,
        // which is why device buffers are released in finally blocks below.
        JCublas.setExceptionsEnabled(true);
        JCublas.cublasInit();
        // Shut CUBLAS down when the JVM exits.
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                JCublas.cublasShutdown();
            }
        });
    }

    /**
     * Releases device buffers previously returned by {@code alloc}.
     *
     * @param pointers device pointers to free; {@code null} entries are skipped so this can be
     *     called from a {@code finally} block after a partially completed allocation sequence
     */
    public static void free(Pointer... pointers) {
        for (Pointer pointer : pointers) {
            if (pointer != null) {
                JCublas.cublasFree(pointer);
            }
        }
    }

    /**
     * Copies {@code arr.length()} floats from a contiguous device buffer into host memory.
     *
     * <p>Host writes start at {@code arr.offset()} elements into {@code to}. When the array does
     * not span its whole backing buffer, consecutive host elements are {@code arr.majorStride()}
     * apart; otherwise they are contiguous.
     *
     * @param arr array describing length, offset and stride of the host data
     * @param from device source buffer (read with unit stride)
     * @param to host destination pointer, normally {@code Pointer.to(arr.data())}
     */
    public static void getData(JCublasNDArray arr, Pointer from, Pointer to) {
        int hostStride = arr.length() == arr.data().length ? 1 : arr.majorStride();
        JCublas.cublasGetVector(arr.length(), FLOAT_BYTES, from, 1,
                to.withByteOffset(arr.offset() * FLOAT_BYTES), hostStride);
    }

    /**
     * Allocates a device buffer of {@code 2 * arr.length()} floats and uploads the complex array's
     * host data (starting at {@code arr.offset()}) into it.
     *
     * <p>NOTE(review): the host data is always read contiguously, even for views that do not span
     * their whole backing buffer — strided complex views may not be copied correctly.
     *
     * @param arr complex array to upload
     * @return the device pointer; the caller must release it with {@link #free(Pointer...)}
     */
    public static Pointer alloc(JCublasComplexNDArray arr) {
        int floats = arr.length() * 2;
        Pointer devicePtr = new Pointer();
        JCublas.cublasAlloc(floats, FLOAT_BYTES, devicePtr);
        try {
            JCublas.cublasSetVector(floats, FLOAT_BYTES,
                    Pointer.to(arr.data()).withByteOffset(arr.offset() * FLOAT_BYTES), 1, devicePtr, 1);
        } catch (RuntimeException e) {
            // Do not leak the freshly allocated buffer if the upload fails.
            JCublas.cublasFree(devicePtr);
            throw e;
        }
        return devicePtr;
    }

    /**
     * Copies {@code 2 * arr.length()} floats from a contiguous device buffer into host memory,
     * starting {@code arr.offset()} floats into {@code to}.
     *
     * @param arr complex array describing length and offset of the host data
     * @param from device source buffer
     * @param to host destination pointer, normally {@code Pointer.to(arr.data())}
     */
    public static void getData(JCublasComplexNDArray arr, Pointer from, Pointer to) {
        JCublas.cublasGetVector(arr.length() * 2, FLOAT_BYTES, from, 1,
                to.withByteOffset(arr.offset() * FLOAT_BYTES), 1);
    }

    /**
     * Allocates a device buffer of {@code arr.length()} floats and uploads the array's host data
     * into it contiguously.
     *
     * <p>Host reads start at {@code arr.offset()}. When the array does not span its whole backing
     * buffer, host elements are read {@code arr.majorStride()} apart.
     *
     * @param arr array to upload
     * @return the device pointer; the caller must release it with {@link #free(Pointer...)}
     */
    public static Pointer alloc(JCublasNDArray arr) {
        Pointer devicePtr = new Pointer();
        JCublas.cublasAlloc(arr.length(), FLOAT_BYTES, devicePtr);
        int hostStride = arr.length() == arr.data().length ? 1 : arr.majorStride();
        try {
            JCublas.cublasSetVector(arr.length(), FLOAT_BYTES,
                    Pointer.to(arr.data()).withByteOffset(arr.offset() * FLOAT_BYTES), hostStride,
                    devicePtr, 1);
        } catch (RuntimeException e) {
            // Do not leak the freshly allocated buffer if the upload fails.
            JCublas.cublasFree(devicePtr);
            throw e;
        }
        return devicePtr;
    }

    /** Copies the device buffer {@code devicePtr} back into {@code arr}'s own host storage. */
    private static void copyToHost(JCublasNDArray arr, Pointer devicePtr) {
        getData(arr, devicePtr, Pointer.to(arr.data()));
    }

    /** Copies the device buffer {@code devicePtr} back into {@code arr}'s own host storage. */
    private static void copyToHost(JCublasComplexNDArray arr, Pointer devicePtr) {
        getData(arr, devicePtr, Pointer.to(arr.data()));
    }

    /** Converts a complex scalar to the single-precision CUBLAS representation. */
    private static cuComplex toCuComplex(IComplexNumber value) {
        return cuComplex.cuCmplx(value.realComponent().floatValue(),
                value.imaginaryComponent().floatValue());
    }

    /**
     * Matrix-vector product {@code y = alpha * a * x + beta * y} (no transpose).
     *
     * @param a matrix of {@code rows() x columns()}
     * @param x input vector
     * @param y input/output vector; overwritten with the result
     * @param alpha scale applied to {@code a * x}
     * @param beta scale applied to the previous contents of {@code y}
     * @return {@code y}
     */
    public static INDArray gemv(INDArray a, INDArray x, INDArray y, float alpha, float beta) {
        JCublas.cublasInit();
        JCublasNDArray cY = (JCublasNDArray) y;
        Pointer dA = null;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dA = alloc((JCublasNDArray) a);
            dX = alloc((JCublasNDArray) x);
            dY = alloc(cY);
            JCublas.cublasSgemv('N', a.rows(), a.columns(), alpha, dA, a.rows(), dX, 1, beta, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dA, dX, dY);
        }
        return y;
    }

    /**
     * Complex matrix product {@code c = alpha * a * b + beta * c} (no transposes).
     *
     * @param a left operand, {@code c.rows() x a.columns()}
     * @param b right operand
     * @param alpha complex scale applied to {@code a * b}
     * @param c input/output matrix; overwritten with the result
     * @param beta complex scale applied to the previous contents of {@code c}
     * @return {@code c}
     */
    public static IComplexNDArray gemm(IComplexNDArray a, IComplexNDArray b, IComplexNumber alpha,
            IComplexNDArray c, IComplexNumber beta) {
        JCublas.cublasInit();
        JCublasComplexNDArray cA = (JCublasComplexNDArray) a;
        JCublasComplexNDArray cC = (JCublasComplexNDArray) c;
        Pointer dA = null;
        Pointer dB = null;
        Pointer dC = null;
        try {
            dA = alloc(cA);
            dB = alloc((JCublasComplexNDArray) b);
            dC = alloc(cC);
            JCublas.cublasCgemm('n', 'n', cC.rows(), cC.columns(), cA.columns(), toCuComplex(alpha),
                    dA, a.rows(), dB, b.rows(), toCuComplex(beta), dC, c.rows());
            copyToHost(cC, dC);
        } finally {
            free(dA, dB, dC);
        }
        return c;
    }

    /**
     * Matrix product {@code c = alpha * a * b + beta * c} (no transposes).
     *
     * @param a left operand, {@code c.rows() x a.columns()}
     * @param b right operand
     * @param c input/output matrix; overwritten with the result
     * @param alpha scale applied to {@code a * b}
     * @param beta scale applied to the previous contents of {@code c}
     * @return {@code c}
     */
    public static INDArray gemm(INDArray a, INDArray b, INDArray c, float alpha, float beta) {
        JCublas.cublasInit();
        JCublasNDArray cC = (JCublasNDArray) c;
        Pointer dA = null;
        Pointer dB = null;
        Pointer dC = null;
        try {
            dA = alloc((JCublasNDArray) a);
            dB = alloc((JCublasNDArray) b);
            dC = alloc(cC);
            JCublas.cublasSgemm('n', 'n', c.rows(), c.columns(), a.columns(), alpha, dA, a.rows(), dB,
                    b.rows(), beta, dC, c.rows());
            copyToHost(cC, dC);
        } finally {
            free(dA, dB, dC);
        }
        return c;
    }

    /**
     * Euclidean norm of a complex vector, including both real and imaginary components.
     *
     * @param x complex vector
     * @return {@code sqrt(sum |x_i|^2)}
     */
    public static float nrm2(IComplexNDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasComplexNDArray) x);
        try {
            return JCublas.cublasScnrm2(x.length(), dX, 1);
        } finally {
            free(dX);
        }
    }

    /**
     * Copies the complex vector {@code x} into {@code y}.
     *
     * @param x source
     * @param y destination; overwritten with {@code x}'s values
     */
    public static void copy(IComplexNDArray x, IComplexNDArray y) {
        JCublas.cublasInit();
        JCublasComplexNDArray cY = (JCublasComplexNDArray) y;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasComplexNDArray) x);
            dY = alloc(cY);
            // Ccopy counts complex elements, so all 2 * length() floats are copied.
            JCublas.cublasCcopy(x.length(), dX, 1, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Index of the complex element with the largest {@code |re| + |im|}.
     *
     * @param x complex vector
     * @return 0-based index (CUBLAS returns 1-based; adjusted to match {@link #iamax(INDArray)})
     */
    public static int iamax(IComplexNDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasComplexNDArray) x);
        try {
            return JCublas.cublasIcamax(x.length(), dX, 1) - 1;
        } finally {
            free(dX);
        }
    }

    /**
     * Sum of {@code |re| + |im|} over all elements of a complex vector.
     *
     * @param x complex vector
     * @return the absolute sum
     */
    public static float asum(IComplexNDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasComplexNDArray) x);
        try {
            return JCublas.cublasScasum(x.length(), dX, 1);
        } finally {
            free(dX);
        }
    }

    /**
     * Exchanges the contents of {@code x} and {@code y}; both host arrays are updated.
     *
     * @param x first vector
     * @param y second vector
     */
    public static void swap(INDArray x, INDArray y) {
        JCublas.cublasInit();
        JCublasNDArray cX = (JCublasNDArray) x;
        JCublasNDArray cY = (JCublasNDArray) y;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc(cX);
            dY = alloc(cY);
            JCublas.cublasSswap(cX.length(), dX, 1, dY, 1);
            // Both device buffers changed, so both must be written back.
            copyToHost(cX, dX);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Sum of absolute values of a real vector.
     *
     * @param x vector
     * @return {@code sum |x_i|}
     */
    public static float asum(INDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasNDArray) x);
        try {
            return JCublas.cublasSasum(x.length(), dX, 1);
        } finally {
            free(dX);
        }
    }

    /**
     * Euclidean norm of a real vector.
     *
     * @param x vector
     * @return {@code sqrt(sum x_i^2)}
     */
    public static float nrm2(INDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasNDArray) x);
        try {
            return JCublas.cublasSnrm2(x.length(), dX, 1);
        } finally {
            free(dX);
        }
    }

    /**
     * Index of the element with the largest absolute value.
     *
     * @param x vector
     * @return 0-based index (CUBLAS returns a 1-based index)
     */
    public static int iamax(INDArray x) {
        JCublas.cublasInit();
        Pointer dX = alloc((JCublasNDArray) x);
        try {
            return JCublas.cublasIsamax(x.length(), dX, 1) - 1;
        } finally {
            free(dX);
        }
    }

    /**
     * {@code y = alpha * x + y}.
     *
     * @param alpha scale applied to {@code x}
     * @param x input vector
     * @param y input/output vector; overwritten with the result
     */
    public static void axpy(float alpha, INDArray x, INDArray y) {
        JCublas.cublasInit();
        JCublasNDArray cX = (JCublasNDArray) x;
        JCublasNDArray cY = (JCublasNDArray) y;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc(cX);
            dY = alloc(cY);
            JCublas.cublasSaxpy(cX.length(), alpha, dX, 1, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Complex {@code y = alpha * x + y}.
     *
     * @param alpha complex scale applied to {@code x}
     * @param x input vector
     * @param y input/output vector; overwritten with the result
     */
    public static void axpy(IComplexNumber alpha, IComplexNDArray x, IComplexNDArray y) {
        JCublasComplexNDArray cX = (JCublasComplexNDArray) x;
        JCublasComplexNDArray cY = (JCublasComplexNDArray) y;
        JCublas.cublasInit();
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc(cX);
            dY = alloc(cY);
            JCublas.cublasCaxpy(cX.length(), toCuComplex(alpha), dX, 1, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Scales {@code x} in place by {@code alpha}.
     *
     * @param alpha scale factor
     * @param x vector; overwritten with the result
     * @return {@code x}
     */
    public static INDArray scal(float alpha, INDArray x) {
        JCublas.cublasInit();
        JCublasNDArray cX = (JCublasNDArray) x;
        Pointer dX = alloc(cX);
        try {
            JCublas.cublasSscal(cX.length(), alpha, dX, 1);
            copyToHost(cX, dX);
        } finally {
            free(dX);
        }
        return x;
    }

    /**
     * Copies the real vector {@code x} into {@code y}.
     *
     * @param x source
     * @param y destination; overwritten with {@code x}'s values
     */
    public static void copy(INDArray x, INDArray y) {
        JCublas.cublasInit();
        JCublasNDArray cY = (JCublasNDArray) y;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasNDArray) x);
            dY = alloc(cY);
            // Buffers hold 4-byte floats, so the single-precision copy must be used.
            JCublas.cublasScopy(x.length(), dX, 1, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Dot product of two real vectors.
     *
     * @param x first vector
     * @param y second vector
     * @return {@code sum x_i * y_i}
     */
    public static float dot(INDArray x, INDArray y) {
        JCublas.cublasInit();
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasNDArray) x);
            dY = alloc((JCublasNDArray) y);
            return JCublas.cublasSdot(x.length(), dX, 1, dY, 1);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Conjugated complex dot product {@code sum conj(x_i) * y_i}.
     *
     * @param x first vector (conjugated)
     * @param y second vector
     * @return the single-precision result widened to an {@link IComplexDouble}
     */
    public static IComplexDouble dot(IComplexNDArray x, IComplexNDArray y) {
        JCublas.cublasInit();
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasComplexNDArray) x);
            dY = alloc((JCublasComplexNDArray) y);
            cuComplex result = JCublas.cublasCdotc(x.length(), dX, 1, dY, 1);
            return Nd4j.createDouble(result.x, result.y);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Rank-1 update {@code a = alpha * x * y^T + a}.
     *
     * @param x vector of length {@code a.rows()}
     * @param y vector of length {@code a.columns()}
     * @param a input/output matrix; overwritten with the result
     * @param alpha scale applied to the outer product
     * @return {@code a}
     */
    public static INDArray ger(INDArray x, INDArray y, INDArray a, float alpha) {
        JCublas.cublasInit();
        JCublasNDArray cA = (JCublasNDArray) a;
        Pointer dX = null;
        Pointer dY = null;
        Pointer dA = null;
        try {
            dX = alloc((JCublasNDArray) x);
            dY = alloc((JCublasNDArray) y);
            dA = alloc(cA);
            // Matrix dimensions come from the target; the device copies of x and y are contiguous.
            JCublas.cublasSger(a.rows(), a.columns(), alpha, dX, 1, dY, 1, dA, a.rows());
            copyToHost(cA, dA);
        } finally {
            free(dX, dY, dA);
        }
        return a;
    }

    /**
     * Scales the complex vector {@code x} in place by {@code alpha}.
     *
     * @param alpha complex scale factor
     * @param x vector; overwritten with the result
     * @return {@code x}
     */
    public static IComplexNDArray zscal(IComplexFloat alpha, IComplexNDArray x) {
        JCublas.cublasInit();
        JCublasComplexNDArray cX = (JCublasComplexNDArray) x;
        Pointer dX = alloc(cX);
        try {
            JCublas.cublasCscal(x.length(), toCuComplex(alpha), dX, 1);
            copyToHost(cX, dX);
        } finally {
            free(dX);
        }
        return x;
    }

    /**
     * Scales the complex vector {@code x} in place by {@code alpha}.
     *
     * <p>The device data is single precision, so {@code alpha} is narrowed to float.
     *
     * @param alpha complex scale factor
     * @param x vector; overwritten with the result
     * @return {@code x}
     */
    public static IComplexNDArray zscal(IComplexDouble alpha, IComplexNDArray x) {
        JCublasComplexNDArray cX = (JCublasComplexNDArray) x;
        JCublas.cublasInit();
        Pointer dX = alloc(cX);
        try {
            JCublas.cublasCscal(x.length(), toCuComplex(alpha), dX, 1);
            copyToHost(cX, dX);
        } finally {
            free(dX);
        }
        return x;
    }

    /**
     * Unconjugated complex dot product {@code sum x_i * y_i}.
     *
     * @param x first vector
     * @param y second vector
     * @return the single-precision result widened to an {@link IComplexDouble}
     */
    public static IComplexDouble dotu(IComplexNDArray x, IComplexNDArray y) {
        JCublas.cublasInit();
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasComplexNDArray) x);
            dY = alloc((JCublasComplexNDArray) y);
            cuComplex result = JCublas.cublasCdotu(x.length(), dX, 1, dY, 1);
            return Nd4j.createDouble(result.x, result.y);
        } finally {
            free(dX, dY);
        }
    }

    /**
     * Unconjugated complex rank-1 update {@code a = alpha * x * y^T + a}.
     *
     * @param x vector of length {@code a.rows()}
     * @param y vector of length {@code a.columns()}
     * @param a input/output matrix; overwritten with the result
     * @param alpha complex scale (narrowed to single precision)
     * @return {@code a}
     */
    public static IComplexNDArray geru(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
            IComplexDouble alpha) {
        JCublas.cublasInit();
        JCublasComplexNDArray cA = (JCublasComplexNDArray) a;
        Pointer dX = null;
        Pointer dY = null;
        Pointer dA = null;
        try {
            dX = alloc((JCublasComplexNDArray) x);
            dY = alloc((JCublasComplexNDArray) y);
            dA = alloc(cA);
            JCublas.cublasCgeru(a.rows(), a.columns(), toCuComplex(alpha), dX, 1, dY, 1, dA, a.rows());
            copyToHost(cA, dA);
        } finally {
            free(dX, dY, dA);
        }
        return a;
    }

    /**
     * Conjugated complex rank-1 update {@code a = alpha * x * y^H + a}.
     *
     * @param x vector of length {@code a.rows()}
     * @param y vector of length {@code a.columns()} (conjugated)
     * @param a input/output matrix; overwritten with the result
     * @param alpha complex scale (narrowed to single precision)
     * @return {@code a}
     */
    public static IComplexNDArray gerc(IComplexNDArray x, IComplexNDArray y, IComplexNDArray a,
            IComplexDouble alpha) {
        JCublas.cublasInit();
        JCublasComplexNDArray cA = (JCublasComplexNDArray) a;
        Pointer dX = null;
        Pointer dY = null;
        Pointer dA = null;
        try {
            dX = alloc((JCublasComplexNDArray) x);
            dY = alloc((JCublasComplexNDArray) y);
            dA = alloc(cA);
            JCublas.cublasCgerc(a.rows(), a.columns(), toCuComplex(alpha), dX, 1, dY, 1, dA, a.rows());
            copyToHost(cA, dA);
        } finally {
            free(dX, dY, dA);
        }
        return a;
    }

    /**
     * {@code y = alpha * x + y}; equivalent to {@link #axpy(float, INDArray, INDArray)}.
     *
     * @param alpha scale applied to {@code x}
     * @param x input vector
     * @param y input/output vector; overwritten with the result
     */
    public static void saxpy(float alpha, INDArray x, INDArray y) {
        JCublas.cublasInit();
        JCublasNDArray cY = (JCublasNDArray) y;
        Pointer dX = null;
        Pointer dY = null;
        try {
            dX = alloc((JCublasNDArray) x);
            dY = alloc(cY);
            JCublas.cublasSaxpy(x.length(), alpha, dX, 1, dY, 1);
            copyToHost(cY, dY);
        } finally {
            free(dX, dY);
        }
    }
}
