package jcuda.utils;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.driver.CUcontext;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcurand.curandGenerator;
import jcuda.runtime.dim3;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:jcuda/utils/KernelLauncher.class */
/**
 * Compiles CUDA source files to PTX with {@code nvcc}, loads the PTX as a
 * {@link CUmodule} through the JCuda driver API and launches one kernel
 * ({@link CUfunction}) from it with a configurable grid size, block size,
 * dynamic shared-memory size and stream.
 *
 * <p>Instances are obtained from the static factories ({@code compile},
 * {@code create}, {@code load}) or from {@link #forFunction(String)}; the
 * setters return {@code this} so a launch can be configured fluently before
 * {@link #call(Object...)}.
 *
 * <p>NOTE(review): nothing in this class assigns {@code context}, so
 * {@link #context()} returns {@code null}, and {@code deviceNumber} is only
 * reported by {@link #toString()} — confirm whether callers rely on either.
 */
public class KernelLauncher {
    private static final Logger logger = LoggerFactory.getLogger(KernelLauncher.class);

    /**
     * Prefix prepended to "nvcc" when building the compiler command. Either empty
     * (the OS resolves "nvcc" itself) or a directory ending in a file separator.
     */
    private static String compilerPath = "";

    private CUcontext context;
    private CUmodule module;
    private CUfunction function;
    private CUstream stream;
    private int deviceNumber = 0;
    private dim3 blockSize = new dim3(1, 1, 1);
    private dim3 gridSize = new dim3(1, 1, 1);
    private int sharedMemSize = 0;

    /**
     * Sets the directory containing the {@code nvcc} executable.
     *
     * @param str the directory; {@code null} or empty resets to "no prefix".
     *            A trailing file separator is appended when missing.
     */
    public static void setCompilerPath(String str) {
        if (str == null || str.isEmpty()) {
            // Previously null fell through to str.endsWith (NPE) and "" became a bare
            // separator, turning the command into "/nvcc".
            compilerPath = "";
            return;
        }
        compilerPath = str.endsWith(File.separator) ? str : str + File.separator;
    }

    /**
     * Records the device number to use.
     *
     * @param i zero-based device index
     * @throws CudaException if {@code i} is negative or not below the number of
     *                       devices reported by {@code cuDeviceGetCount}
     */
    public void setDeviceNumber(int i) {
        int[] deviceCount = new int[1];
        JCudaDriver.cuDeviceGetCount(deviceCount);
        if (i < 0 || i >= deviceCount[0]) {
            throw new CudaException("Invalid device number: " + i + ". There are only " + deviceCount[0] + " devices available");
        }
        this.deviceNumber = i;
    }

    /**
     * Writes CUDA source code to a temporary .cu file, compiles it to PTX and
     * returns a launcher for one of its kernels.
     *
     * @param str    CUDA source code
     * @param str2   name of the kernel function to launch
     * @param strArr additional nvcc arguments
     * @return a launcher for kernel {@code str2}
     * @throws CudaException if the temporary file cannot be created or written,
     *                       or compilation/loading fails
     */
    public static KernelLauncher compile(String str, String str2, String... strArr) {
        File cuFile;
        try {
            cuFile = File.createTempFile("temp_JCuda_", ".cu");
        } catch (IOException e) {
            throw new CudaException("Could not create temporary .cu file", e);
        }
        // The .cu file is only needed while nvcc runs; don't leave it in the temp dir forever.
        cuFile.deleteOnExit();
        try (FileOutputStream fileOutputStream = new FileOutputStream(cuFile)) {
            // Platform default charset, unchanged from the original implementation.
            fileOutputStream.write(str.getBytes());
        } catch (IOException e) {
            throw new CudaException("Could not write temporary .cu file", e);
        }
        return create(cuFile.getPath(), str2, strArr);
    }

    /**
     * Compiles the given .cu file (reusing an up-to-date .ptx next to it) and
     * returns a launcher for one of its kernels.
     *
     * @param str    path of the .cu file
     * @param str2   name of the kernel function
     * @param strArr additional nvcc arguments
     */
    public static KernelLauncher create(String str, String str2, String... strArr) {
        return create(str, str2, false, strArr);
    }

    /**
     * Compiles the given .cu file to PTX and returns a launcher for one of its kernels.
     *
     * @param str    path of the .cu file
     * @param str2   name of the kernel function
     * @param z      if {@code true}, recompile even when an up-to-date .ptx file exists
     * @param strArr additional nvcc arguments
     * @throws CudaException if preparing the PTX, loading the module or looking up
     *                       the function fails
     */
    public static KernelLauncher create(String str, String str2, boolean z, String... strArr) {
        try {
            String ptxFileName = preparePtxFile(str, z, strArr);
            KernelLauncher kernelLauncher = new KernelLauncher();
            kernelLauncher.initModule(loadData(ptxFileName));
            kernelLauncher.initFunction(str2);
            return kernelLauncher;
        } catch (IOException e) {
            throw new CudaException("Could not prepare PTX for source file '" + str + "'", e);
        }
    }

    /**
     * Loads an existing module file (e.g. PTX) and returns a launcher for one of its kernels.
     *
     * @param str  path of the module file
     * @param str2 name of the kernel function
     */
    public static KernelLauncher load(String str, String str2) {
        KernelLauncher kernelLauncher = new KernelLauncher();
        kernelLauncher.initModule(loadData(str));
        kernelLauncher.initFunction(str2);
        return kernelLauncher;
    }

    /**
     * Loads module data from a stream and returns a launcher for one of its kernels.
     * The stream is read to the end but not closed; the caller owns it.
     *
     * @param inputStream stream containing the module data
     * @param str         name of the kernel function
     */
    public static KernelLauncher load(InputStream inputStream, String str) {
        KernelLauncher kernelLauncher = new KernelLauncher();
        kernelLauncher.initModule(loadData(inputStream));
        kernelLauncher.initFunction(str);
        return kernelLauncher;
    }

    /** Reads the whole file at {@code str}, with a trailing 0 byte appended. */
    private static byte[] loadData(String str) {
        try (FileInputStream fileInputStream = new FileInputStream(new File(str))) {
            return loadData(fileInputStream);
        } catch (FileNotFoundException e) {
            throw new CudaException("Could not open '" + str + "'", e);
        } catch (IOException e) {
            // Only close() can throw a plain IOException here; loadData(InputStream) wraps its own.
            throw new CudaException("Could not close '" + str + "'", e);
        }
    }

    /** Reads {@code inputStream} to the end (without closing it), with a trailing 0 byte appended. */
    private static byte[] loadData(InputStream inputStream) {
        try {
            byte[] data = toByteArray(inputStream);
            // Append a NUL terminator, presumably so PTX text is handed to the driver as a C string.
            return Arrays.copyOf(data, data.length + 1);
        } catch (IOException e) {
            throw new CudaException("Could not load data", e);
        }
    }

    private KernelLauncher() {
    }

    /**
     * Returns a new launcher for another kernel in the same module. Launch settings
     * (grid/block size, shared memory, stream) are NOT copied; the new launcher starts
     * with the defaults.
     *
     * @param str name of the kernel function
     */
    public KernelLauncher forFunction(String str) {
        KernelLauncher kernelLauncher = new KernelLauncher();
        kernelLauncher.module = this.module;
        kernelLauncher.initFunction(str);
        return kernelLauncher;
    }

    /** Loads {@code bArr} as this launcher's module via {@code cuModuleLoadDataEx} with no JIT options. */
    private void initModule(byte[] bArr) {
        this.module = new CUmodule();
        JCudaDriver.cuModuleLoadDataEx(this.module, Pointer.to(bArr), 0, new int[0], Pointer.to(new int[0]));
    }

    /**
     * Looks up kernel {@code str} in this launcher's module.
     *
     * @throws CudaException if the function cannot be found (typically a C++-mangled name)
     */
    private void initFunction(String str) {
        this.function = new CUfunction();
        String str2 = "Could not get function '" + str + "' from module. \nName in module might be mangled. Try adding the line \nextern \"C\"\nbefore the function you want to call, or open the PTX/CUBIN \nfile with a text editor to find out the mangled function name";
        int result;
        try {
            result = JCudaDriver.cuModuleGetFunction(this.function, this.module, str);
        } catch (CudaException e) {
            // JCudaDriver may throw instead of returning an error code (presumably when
            // JCuda exceptions are enabled); attach the explanatory message.
            throw new CudaException(str2, e);
        }
        if (result != 0) {
            throw new CudaException(str2);
        }
    }

    /** Returns the module this launcher's kernel was loaded from. */
    public CUmodule getModule() {
        return this.module;
    }

    /** Sets the grid's x and y dimensions; z is left unchanged. */
    public KernelLauncher setGridSize(int i, int i2) {
        this.gridSize.x = i;
        this.gridSize.y = i2;
        return this;
    }

    /** Sets all three grid dimensions. */
    public KernelLauncher setGridSize(int i, int i2, int i3) {
        this.gridSize.x = i;
        this.gridSize.y = i2;
        this.gridSize.z = i3;
        return this;
    }

    /** Sets all three block dimensions. */
    public KernelLauncher setBlockSize(int i, int i2, int i3) {
        this.blockSize.x = i;
        this.blockSize.y = i2;
        this.blockSize.z = i3;
        return this;
    }

    /** Sets the dynamic shared-memory size in bytes passed to the launch. */
    public KernelLauncher setSharedMemSize(int i) {
        this.sharedMemSize = i;
        return this;
    }

    /** Sets the stream used for launches ({@code null} is passed through to the driver). */
    public KernelLauncher setStream(CUstream cUstream) {
        this.stream = cUstream;
        return this;
    }

    /** Sets grid and block size, keeping the current shared-memory size and stream. */
    public KernelLauncher setup(dim3 dim3Var, dim3 dim3Var2) {
        return setup(dim3Var, dim3Var2, this.sharedMemSize, this.stream);
    }

    /** Sets grid size, block size and shared-memory size, keeping the current stream. */
    public KernelLauncher setup(dim3 dim3Var, dim3 dim3Var2, int i) {
        return setup(dim3Var, dim3Var2, i, this.stream);
    }

    /** Returns the stored context; see the class note — nothing here assigns it. */
    public CUcontext context() {
        return this.context;
    }

    /**
     * Sets the complete launch configuration.
     *
     * @param dim3Var  grid size (all three components are used)
     * @param dim3Var2 block size
     * @param i        dynamic shared-memory size in bytes
     * @param cUstream stream to launch on
     */
    public KernelLauncher setup(dim3 dim3Var, dim3 dim3Var2, int i, CUstream cUstream) {
        // Include z: launches use gridSize.z, and dropping it left a stale value behind.
        setGridSize(dim3Var.x, dim3Var.y, dim3Var.z);
        setBlockSize(dim3Var2.x, dim3Var2.y, dim3Var2.z);
        setSharedMemSize(i);
        setStream(cUstream);
        return this;
    }

    /**
     * Launches the kernel with the current configuration via {@code cuLaunchKernel}.
     *
     * <p>Supported argument types: {@link Pointer}, {@link curandGenerator}, boxed
     * {@code Byte/Short/Integer/Long/Float/Double}, and {@code byte[]/short[]/int[]/
     * long[]/float[]/double[]} (each wrapped with the matching {@code Pointer.to}).
     *
     * @param objArr kernel arguments in declaration order
     * @throws CudaException if an argument is {@code null} or of an unsupported type
     */
    public void call(Object... objArr) {
        Pointer[] kernelParameters = new Pointer[objArr.length];
        for (int i = 0; i < objArr.length; i++) {
            kernelParameters[i] = toKernelParameter(objArr[i], i);
        }
        JCudaDriver.cuLaunchKernel(this.function, this.gridSize.x, this.gridSize.y, this.gridSize.z, this.blockSize.x, this.blockSize.y, this.blockSize.z, this.sharedMemSize, this.stream, Pointer.to(kernelParameters), null);
    }

    /** Wraps one kernel argument in a {@link Pointer} to its value, as cuLaunchKernel expects. */
    private static Pointer toKernelParameter(Object obj, int index) {
        if (obj == null) {
            throw new CudaException("Argument " + index + " is null and may not be passed to a function");
        }
        if (obj instanceof Pointer) {
            return Pointer.to((Pointer) obj);
        } else if (obj instanceof Byte) {
            return Pointer.to(new byte[]{((Byte) obj).byteValue()});
        } else if (obj instanceof Short) {
            return Pointer.to(new short[]{((Short) obj).shortValue()});
        } else if (obj instanceof Integer) {
            return Pointer.to(new int[]{((Integer) obj).intValue()});
        } else if (obj instanceof Long) {
            return Pointer.to(new long[]{((Long) obj).longValue()});
        } else if (obj instanceof Float) {
            return Pointer.to(new float[]{((Float) obj).floatValue()});
        } else if (obj instanceof Double) {
            return Pointer.to(new double[]{((Double) obj).doubleValue()});
        } else if (obj instanceof double[]) {
            return Pointer.to((double[]) obj);
        } else if (obj instanceof float[]) {
            return Pointer.to((float[]) obj);
        } else if (obj instanceof int[]) {
            return Pointer.to((int[]) obj);
        } else if (obj instanceof long[]) {
            return Pointer.to((long[]) obj);
        } else if (obj instanceof short[]) {
            return Pointer.to((short[]) obj);
        } else if (obj instanceof byte[]) {
            return Pointer.to((byte[]) obj);
        } else if (obj instanceof curandGenerator) {
            return Pointer.to((curandGenerator) obj);
        }
        throw new CudaException("Type " + obj.getClass() + " may not be passed to a function");
    }

    /**
     * Ensures an up-to-date .ptx file exists for the given .cu file, invoking nvcc if needed.
     *
     * @param str    path of the .cu file
     * @param z      force recompilation even if the .ptx is newer than the .cu file
     * @param strArr additional nvcc arguments
     * @return path of the .ptx file
     * @throws IOException   if nvcc cannot be started or its output cannot be read
     * @throws CudaException if the input is missing, nvcc fails, or the wait is interrupted
     */
    private static String preparePtxFile(String str, boolean z, String... strArr) throws IOException {
        logger.info("Preparing PTX for \n{}", str);
        File cuFile = new File(str);
        if (!cuFile.exists()) {
            throw new CudaException("Input file not found: " + str);
        }
        String ptxFileName = ptxFileNameFor(str);
        File ptxFile = new File(ptxFileName);
        if (ptxFile.exists() && !z && cuFile.lastModified() < ptxFile.lastModified()) {
            return ptxFileName;
        }

        // Build an argument list instead of one string: Runtime.exec(String) splits on
        // spaces, which broke compiler or file paths containing spaces.
        List<String> command = new ArrayList<>();
        command.add(compilerPath + "nvcc");
        String dataModel = System.getProperty("sun.arch.data.model");
        if (dataModel != null) {
            command.add("-m" + dataModel);
        }
        addArguments(command, strArr);
        command.add("-ptx");
        command.add(cuFile.getPath());
        command.add("-o");
        command.add(ptxFileName);
        logger.info("Executing\n{}", command);

        Process process = new ProcessBuilder(command).start();
        // Drain stdout concurrently: reading the two pipes one after another can deadlock
        // if nvcc fills the stdout pipe buffer while we are still blocked on stderr.
        StreamDrainer stdoutDrainer = new StreamDrainer(process.getInputStream());
        stdoutDrainer.start();
        String errorMessage = new String(toByteArray(process.getErrorStream()));
        try {
            String outputMessage = new String(stdoutDrainer.await());
            int exitValue = process.waitFor();
            logger.info("nvcc process exitValue {}", exitValue);
            if (exitValue == 0) {
                return ptxFileName;
            }
            logger.error("errorMessage:\n{}", errorMessage);
            logger.error("outputMessage:\n{}", outputMessage);
            throw new CudaException("Could not create .ptx file: " + errorMessage);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new CudaException("Interrupted while waiting for nvcc output", e);
        }
    }

    /**
     * Returns {@code cuFileName} with the extension of its final path component replaced
     * by ".ptx", or with ".ptx" appended if that component has no dot. Dots in parent
     * directory names are ignored.
     */
    private static String ptxFileNameFor(String cuFileName) {
        int lastDot = cuFileName.lastIndexOf('.');
        int lastSeparator = Math.max(cuFileName.lastIndexOf('/'), cuFileName.lastIndexOf(File.separatorChar));
        if (lastDot <= lastSeparator) {
            return cuFileName + ".ptx";
        }
        return cuFileName.substring(0, lastDot) + ".ptx";
    }

    /**
     * Appends user-supplied nvcc arguments to {@code command}. Each entry is split on
     * whitespace so that entries like "-arch sm_30" still become separate arguments,
     * as they did with Runtime.exec(String). Null entries are skipped.
     */
    private static void addArguments(List<String> command, String... strArr) {
        if (strArr == null) {
            return;
        }
        for (String argument : strArr) {
            if (argument == null) {
                continue;
            }
            for (String token : argument.trim().split("\\s+")) {
                if (!token.isEmpty()) {
                    command.add(token);
                }
            }
        }
    }

    /** Reads {@code inputStream} until EOF and returns its bytes; the stream is not closed. */
    private static byte[] toByteArray(InputStream inputStream) throws IOException {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        byte[] buffer = new byte[8192];
        int read;
        while ((read = inputStream.read(buffer)) != -1) {
            byteArrayOutputStream.write(buffer, 0, read);
        }
        return byteArrayOutputStream.toByteArray();
    }

    /** Reads an input stream to exhaustion on a daemon thread and hands back the bytes. */
    private static final class StreamDrainer extends Thread {
        private final InputStream input;
        private byte[] data = new byte[0];
        private IOException failure;

        StreamDrainer(InputStream input) {
            super("nvcc-stdout-drainer");
            this.input = input;
            setDaemon(true);
        }

        @Override
        public void run() {
            try {
                data = toByteArray(input);
            } catch (IOException e) {
                failure = e;
            }
        }

        /** Waits for the stream to be fully read; join() makes the fields written by run() visible. */
        byte[] await() throws IOException, InterruptedException {
            join();
            if (failure != null) {
                throw failure;
            }
            return data;
        }
    }

    @Override
    public String toString() {
        return "KernelLauncher{deviceNumber=" + this.deviceNumber + ", context=" + this.context + ", module=" + this.module + ", function=" + this.function + ", blockSize=" + this.blockSize + ", gridSize=" + this.gridSize + ", sharedMemSize=" + this.sharedMemSize + ", stream=" + this.stream + '}';
    }
}
