package de.jungblut.math.cuda;

import de.jungblut.math.dense.DenseDoubleMatrix;
import java.util.Random;
import jcuda.Pointer;
import jcuda.jcublas.JCublas;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaDeviceProp;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Static helpers for dense double-precision matrix multiplication on CUDA device 0, using either
 * the cuBLAS v2 API (JCublas2) or the legacy cuBLAS API (JCublas).
 *
 * <p>The backend is chosen once, in the static initializer. JCublas2 is used when the digits in
 * the device name form a number greater than 400; otherwise the legacy JCublas API is used. If
 * initialization fails, the error is logged and {@link #CUBLAS2_AVAILABLE} stays {@code false}.
 *
 * <p>All device matrices are stored column-major with the leading dimension equal to the row
 * count. No synchronization is performed around the shared cuBLAS handle.
 */
public final class JCUDAMatrixUtils {
    private static final Logger LOG = LogManager.getLogger(JCUDAMatrixUtils.class);

    /** Size in bytes of one {@code double} element. */
    private static final int DOUBLE_BYTES = 8;
    /** cuBLAS v2 {@code cublasOperation_t} value CUBLAS_OP_N (no transpose). */
    private static final int CUBLAS_OP_N = 0;
    /** cuBLAS v2 {@code cublasOperation_t} value CUBLAS_OP_T (transpose). */
    private static final int CUBLAS_OP_T = 1;
    /** cuBLAS v2 {@code cublasPointerMode_t} value CUBLAS_POINTER_MODE_HOST (scalars live on the host). */
    private static final int CUBLAS_POINTER_MODE_HOST = 0;

    /**
     * Exception mode passed to JCuda/JCublas/JCublas2 by the static initializer.
     *
     * <p>NOTE: this value is read only once, during class initialization. Assigning this static
     * field from another class itself triggers that initialization first (JLS 12.4.1). As a
     * result, such an assignment does not change the exception mode that was already applied.
     */
    public static boolean EXCEPTIONS_ENABLED = false;

    /** {@code true} once the cuBLAS v2 handle has been created and configured successfully. */
    public static boolean CUBLAS2_AVAILABLE;

    /** cuBLAS v2 handle; only valid when {@link #CUBLAS2_AVAILABLE} is {@code true}. */
    private static cublasHandle handle;

    /**
     * Multiplies {@code a * b} on the GPU without transposing either operand.
     *
     * @param a left operand
     * @param b right operand
     * @return the product, copied back to the host
     */
    public static DenseDoubleMatrix multiply(DenseDoubleMatrix a, DenseDoubleMatrix b) {
        return multiply(a, b, false, false);
    }

    /**
     * Runs DGEMM ({@code C = 1.0 * op(A) * op(B) + 0.0 * C}) on matrices already resident on the
     * device and copies the result back to the host.
     *
     * <p>The input pointers are NOT freed. The temporary result buffer is always freed, even when
     * cuBLAS or the copy-back throws.
     *
     * @param a device pointer to the left operand
     * @param b device pointer to the right operand
     * @param dim supplies m/n/k, the leading dimensions and the transpose flags; the result is
     *     read back as an {@code getM() x getN()} matrix using {@code getM()} as leading
     *     dimension (assumes {@code getLdC() == getM()} — TODO confirm in MatrixDimension)
     * @return the product as a host matrix
     */
    public static DenseDoubleMatrix multiply(Pointer a, Pointer b, MatrixDimension dim) {
        final int rows = dim.getM();
        final int cols = dim.getN();
        final boolean transposeA = dim.isTransposeA();
        final boolean transposeB = dim.isTransposeB();

        Pointer result = new Pointer();
        // Compute the byte count in long: rows * cols * 8 can overflow int for large results.
        // cudaMalloc is used for both backends so freePointer (cudaFree) is always the matching release.
        JCuda.cudaMalloc(result, (long) rows * cols * DOUBLE_BYTES);
        try {
            if (CUBLAS2_AVAILABLE) {
                // The pointer mode was set to HOST at init, so alpha/beta are host-side scalars.
                // They are plain Java arrays and must never be passed to cudaFree.
                Pointer alpha = Pointer.to(new double[] {1.0d});
                Pointer beta = Pointer.to(new double[] {0.0d});
                JCublas2.cublasDgemm(handle,
                        transposeA ? CUBLAS_OP_T : CUBLAS_OP_N,
                        transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                        rows, cols, dim.getK(),
                        alpha, a, dim.getLdA(), b, dim.getLdB(),
                        beta, result, dim.getLdC());
            } else {
                // The legacy API takes 'n' / 't' / 'c' as the op character; 't' requests transposition.
                JCublas.cublasDgemm(transposeA ? 't' : 'n', transposeB ? 't' : 'n',
                        rows, cols, dim.getK(),
                        1.0d, a, dim.getLdA(), b, dim.getLdB(),
                        0.0d, result, dim.getLdC());
            }
            JCuda.cudaDeviceSynchronize();
            return getMatrix(result, rows, cols);
        } finally {
            freePointer(result);
        }
    }

    /**
     * Copies both operands to the device, multiplies them with {@link #multiply(Pointer, Pointer,
     * MatrixDimension)} and frees the device copies again, also on failure.
     *
     * @param a left operand
     * @param b right operand
     * @param transposeA forwarded to {@code MatrixDimension}; presumably "use A transposed"
     * @param transposeB forwarded to {@code MatrixDimension}; presumably "use B transposed"
     * @return the product as a host matrix
     */
    public static DenseDoubleMatrix multiply(DenseDoubleMatrix a, DenseDoubleMatrix b,
            boolean transposeA, boolean transposeB) {
        Pointer deviceA = memcpyMatrix(a);
        try {
            Pointer deviceB = memcpyMatrix(b);
            try {
                return multiply(deviceA, deviceB, new MatrixDimension(a, b, transposeA, transposeB));
            } finally {
                freePointer(deviceB);
            }
        } finally {
            freePointer(deviceA);
        }
    }

    /**
     * Allocates device memory and uploads {@code matrix} in column-major order, using the row
     * count as leading dimension.
     *
     * <p>The caller owns the returned pointer and must release it with {@link #freePointer}. If
     * the upload throws, the allocation is released before rethrowing.
     *
     * @param matrix host matrix to upload
     * @return device pointer to {@code rows * cols} doubles
     */
    public static Pointer memcpyMatrix(DenseDoubleMatrix matrix) {
        final int rows = matrix.getRowCount();
        final int cols = matrix.getColumnCount();
        Pointer host = Pointer.to(matrix.getColumnMajorMatrix());
        Pointer device = new Pointer();
        // Compute the byte count in long to avoid int overflow for large matrices.
        JCuda.cudaMalloc(device, (long) rows * cols * DOUBLE_BYTES);
        try {
            if (CUBLAS2_AVAILABLE) {
                JCublas2.cublasSetMatrix(rows, cols, DOUBLE_BYTES, host, rows, device, rows);
            } else {
                JCublas.cublasSetMatrix(rows, cols, DOUBLE_BYTES, host, rows, device, rows);
            }
        } catch (RuntimeException e) {
            freePointer(device);
            throw e;
        }
        return device;
    }

    /**
     * Downloads a column-major {@code rows x cols} device matrix (leading dimension {@code rows})
     * into a new host matrix. The device pointer is not freed.
     *
     * @param pointer device pointer holding at least {@code rows * cols} doubles
     * @param rows number of rows
     * @param cols number of columns
     * @return a new host matrix built from the column-major array (assumes the
     *     {@code DenseDoubleMatrix(double[], int, int)} constructor expects column-major data —
     *     TODO confirm)
     */
    public static DenseDoubleMatrix getMatrix(Pointer pointer, int rows, int cols) {
        double[] data = new double[rows * cols];
        Pointer host = Pointer.to(data);
        if (CUBLAS2_AVAILABLE) {
            JCublas2.cublasGetMatrix(rows, cols, DOUBLE_BYTES, pointer, rows, host, rows);
        } else {
            JCublas.cublasGetMatrix(rows, cols, DOUBLE_BYTES, pointer, rows, host, rows);
        }
        return new DenseDoubleMatrix(data, rows, cols);
    }

    /**
     * Releases device memory obtained from {@code cudaMalloc} (for example via
     * {@link #memcpyMatrix}). Must not be called with host pointers.
     *
     * @param pointer device pointer to free
     */
    public static void freePointer(Pointer pointer) {
        JCuda.cudaFree(pointer);
    }

    /**
     * Tears down whichever cuBLAS backend was initialized. This is invoked by the shutdown hook
     * registered in the static initializer.
     *
     * @param cublashandle the v2 handle to destroy; ignored for the legacy backend
     */
    public static void cublasDestroy(cublasHandle cublashandle) {
        if (CUBLAS2_AVAILABLE) {
            JCublas2.cublasDestroy(cublashandle);
        } else {
            JCublas.cublasShutdown();
        }
    }

    /**
     * Formats a byte count for logging, e.g. {@code 2147483648} becomes "2.0 GiB" in binary mode.
     *
     * @param bytes non-negative byte count
     * @param si {@code true} for powers of 1000 (kB, MB, ...), {@code false} for powers of 1024
     *     (KiB, MiB, ...)
     * @return the human-readable size
     */
    private static String humanReadableByteCount(long bytes, boolean si) {
        int unit = si ? 1000 : 1024;
        if (bytes < unit) {
            return bytes + " B";
        }
        int exponent = (int) (Math.log(bytes) / Math.log(unit));
        String prefix = (si ? "kMGTPE" : "KMGTPE").charAt(exponent - 1) + (si ? "" : "i");
        return String.format("%.1f %sB", bytes / Math.pow(unit, exponent), prefix);
    }

    /**
     * Extracts all ASCII digits from a device name and parses them as one number. For example,
     * "GeForce GTX 480" yields 480.
     *
     * @param deviceName the CUDA device name
     * @return the parsed number, or -1 when the name contains no digits or the digits do not fit
     *     in an {@code int}; such devices then use the legacy API instead of aborting initialization
     */
    private static int deviceNumber(String deviceName) {
        String digits = deviceName.replaceAll("[^\\d]", "");
        try {
            return Integer.parseInt(digits);
        } catch (NumberFormatException ignored) {
            // No digits (e.g. "TITAN RTX") or too many to fit in an int.
            return -1;
        }
    }

    /** Small benchmark comparing GPU and CPU multiplication of random matrices. */
    public static void main(String[] args) {
        DenseDoubleMatrix a = new DenseDoubleMatrix(40000, 784, new Random());
        DenseDoubleMatrix b = new DenseDoubleMatrix(784, 300, new Random());

        // GPU timing includes the host<->device transfers.
        long start = System.currentTimeMillis();
        DenseDoubleMatrix gpuResult = multiply(a, b);
        LOG.info("GPU took: " + (((float) (System.currentTimeMillis() - start)) / 1000.0f) + "s!");

        start = System.currentTimeMillis();
        DenseDoubleMatrix cpuResult = a.multiply(b);
        LOG.info("CPU took: " + (((float) (System.currentTimeMillis() - start)) / 1000.0f) + "s!");

        // NOTE: signed sum of element differences; positive and negative errors can cancel out.
        LOG.info("Matrix difference: " + cpuResult.subtract(gpuResult).sum());
    }

    static {
        CUBLAS2_AVAILABLE = false;
        try {
            JCuda.setExceptionsEnabled(EXCEPTIONS_ENABLED);
            cudaDeviceProp props = new cudaDeviceProp();
            JCuda.cudaGetDeviceProperties(props, 0);
            // Double precision requires compute capability >= 1.3.
            if (props.major < 1 || (props.major == 1 && props.minor < 3)) {
                throw new IllegalArgumentException("WARN Double precision computing only allowed since capability 1.3! You have " + props.major + "." + props.minor + "! If you have exceptions turned off, then this may result in strange behaviour.");
            }
            if (deviceNumber(props.getName()) > 400) {
                JCublas2.setExceptionsEnabled(EXCEPTIONS_ENABLED);
                JCublas2.initialize();
                handle = new cublasHandle();
                JCublas2.cublasCreate(handle);
                // HOST mode: the alpha/beta scalars passed to cublasDgemm are host-side values.
                JCublas2.cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST);
                // Only advertise the v2 backend once the handle is fully set up.
                CUBLAS2_AVAILABLE = true;
            } else {
                JCublas.setExceptionsEnabled(EXCEPTIONS_ENABLED);
                JCublas.cublasInit();
            }
            Runtime.getRuntime().addShutdownHook(new Thread() {
                @Override
                public void run() {
                    JCUDAMatrixUtils.cublasDestroy(JCUDAMatrixUtils.handle);
                }
            });
            LOG.info("Using device " + props.getName() + " with total RAM of " + humanReadableByteCount(props.totalGlobalMem, false) + ". Compute capability: " + props.major + "." + props.minor);
        } catch (Throwable th) {
            // Best-effort: a missing GPU or native library must not prevent class loading,
            // but keep the stack trace so the failure can be diagnosed.
            LOG.error(th.getLocalizedMessage(), th);
        }
    }
}
