package org.nd4j.linalg.jcublas.kernel;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Properties;
import java.util.regex.Pattern;
import jcuda.jcublas.JCublas;
import jcuda.runtime.JCuda;
import jcuda.utils.KernelLauncher;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.jcublas.CublasPointer;
import org.nd4j.linalg.jcublas.buffer.JCudaBuffer;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.reflections.Reflections;
import org.reflections.scanners.ResourcesScanner;
import org.reflections.scanners.Scanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/jcublas/kernel/KernelFunctionLoader.class */
/**
 * Singleton that extracts the bundled CUDA kernel sources to the temp directory,
 * compiles them with {@code make} and loads the resulting {@code .cubin} modules
 * as {@link KernelLauncher} instances.
 *
 * <p>Launchers are stored per thread (keyed by the calling thread's name) and are
 * loaded lazily for a thread the first time it calls {@link #get(String, String)}.
 * All access to the shared launcher table is guarded by the table's monitor.
 */
public class KernelFunctionLoader {
    public static final String NAME_SPACE = "org.nd4j.linalg.jcuda.jcublas";
    public static final String DOUBLE = "org.nd4j.linalg.jcuda.jcublas.double.functions";
    public static final String FLOAT = "org.nd4j.linalg.jcuda.jcublas.float.functions";
    public static final String CACHE_COMPILED = "org.nd4j.linalg.jcuda.jcublas.cache_compiled";
    /** Property holding the comma separated list of kernel module names. */
    public static final String FUNCTION_KEY = "org.nd4j.linalg.jcuda.jcublas.functions";
    public static final String PRINT_KERNEL_NAME = "printShapeBuffer";

    private static final Logger log = LoggerFactory.getLogger(KernelFunctionLoader.class);
    /** Row: thread name, column: {@code <module>_<dataType>}, value: launcher for that kernel. */
    private static final Table<String, String, KernelLauncher> launchers = HashBasedTable.create();
    private static KernelFunctionLoader INSTANCE;
    private static KernelLauncher printFunction;

    /** Directory (with trailing separator) that holds the compiled {@code .cubin} files. */
    private String kernelPath;
    /** Module names read from {@link #FUNCTION_KEY}; {@code null} until a load was attempted. */
    private String[] modules;
    private final Map<String, String> paths = new HashMap<String, String>();
    private boolean alreadyCompiled = false;
    private boolean init = false;

    private KernelFunctionLoader() {
    }

    /**
     * Returns the singleton, creating and loading it on first use.
     *
     * <p>A failure while loading is logged rather than propagated, so the returned
     * instance may not have any kernels loaded.
     *
     * @return the shared loader instance
     */
    public static synchronized KernelFunctionLoader getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new KernelFunctionLoader();
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    KernelFunctionLoader.INSTANCE.unload();
                }
            }));
            try {
                INSTANCE.load();
            } catch (Exception e) {
                // Best effort by design: callers still get an instance, but the cause is
                // now reported through the logger instead of printStackTrace().
                log.error("Unable to load cuda kernel functions", e);
            }
        }
        return INSTANCE;
    }

    /**
     * Convenience shortcut for {@code getInstance().get(str, str2)}.
     *
     * @param str  the kernel (module) name
     * @param str2 the data type suffix, e.g. {@code "float"} or {@code "double"}
     * @return the launcher, or {@code null} if none is registered
     */
    public static KernelLauncher launcher(String str, String str2) {
        return getInstance().get(str, str2);
    }

    /**
     * Checks whether a launcher is registered for the given kernel name.
     *
     * @param str the kernel (module) name
     * @return {@code true} if a double or a float launcher exists for the name
     */
    public boolean exists(String str) {
        return get(str, "double") != null || get(str, "float") != null;
    }

    /**
     * Looks up the launcher for a kernel and data type for the calling thread,
     * loading all modules for this thread first if it has none registered yet.
     *
     * <p>The key {@code <str>_<str2>} is tried first, then {@code <str>_strided_<str2>}.
     *
     * @param str  the kernel (module) name
     * @param str2 the data type suffix, e.g. {@code "float"} or {@code "double"}
     * @return the launcher, or {@code null} if neither key is registered
     * @throws IllegalStateException if no module list is available because loading failed
     * @throws RuntimeException      if loading the modules for this thread fails
     */
    public KernelLauncher get(String str, String str2) {
        String threadName = Thread.currentThread().getName();
        // The table is shared by every thread and HashBasedTable is not thread safe.
        synchronized (launchers) {
            if (!launchers.containsRow(threadName)) {
                if (this.modules == null) {
                    throw new IllegalStateException("Kernel modules are not available; the initial load did not complete");
                }
                try {
                    loadModules(this.modules, this.kernelPath);
                    log.debug("Loading modules for " + threadName);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            KernelLauncher found = launchers.get(threadName, str + "_" + str2);
            if (found == null) {
                found = launchers.get(threadName, str + "_strided_" + str2);
            }
            return found;
        }
    }

    /**
     * Marks the loader as not initialized so that a later {@link #load()} runs again.
     * Nothing else is released here.
     */
    public void unload() {
        this.init = false;
    }

    /**
     * Reads {@code cudafunctions.properties} from the class path and compiles/loads
     * the kernels listed in it. Does nothing if the loader is already initialized.
     *
     * @throws IllegalStateException if the properties file is not on the class path
     * @throws Exception             if reading the properties or loading the kernels fails
     */
    public void load() throws Exception {
        if (this.init) {
            return;
        }
        ClassPathResource resource = new ClassPathResource("/cudafunctions.properties", KernelFunctionLoader.class.getClassLoader());
        if (!resource.exists()) {
            throw new IllegalStateException("Please put a cudafunctions.properties in your class path");
        }
        Properties properties = new Properties();
        try (InputStream in = resource.getInputStream()) {
            properties.load(in);
        }
        log.info("Registering cuda functions...");
        compileAndLoad(properties);
        this.init = true;
    }

    private void compileAndLoad(Properties properties) throws IOException {
        compileAndLoad(properties, 0);
    }

    /**
     * Compiles the kernels when no usable cache exists and then loads the modules.
     * If loading from the cache fails, the output directory is deleted and the
     * whole procedure is retried with a fresh compile.
     *
     * @param properties the loaded {@code cudafunctions.properties}
     * @param i          the number of retries performed so far
     * @throws IOException      if extracting/compiling the kernels or clearing the cache fails
     * @throws RuntimeException if the modules cannot be loaded after compiling
     */
    private void compileAndLoad(Properties properties, int i) throws IOException {
        String functions = properties.getProperty(FUNCTION_KEY);
        if (functions == null) {
            throw new IllegalStateException("Property " + FUNCTION_KEY + " is missing from cudafunctions.properties");
        }
        String tmpDir = System.getProperty("java.io.tmpdir");
        this.kernelPath = tmpDir + File.separator + "nd4j-kernels" + File.separator + "output" + File.separator;
        // listFiles() yields null when the directory is missing or is not a directory.
        File[] compiled = new File(this.kernelPath).listFiles();
        // NOTE(review): alreadyCompiled forcing a recompile is kept from the original; confirm intent.
        boolean shouldCompile = compiled == null || compiled.length <= 1 || this.alreadyCompiled;
        String[] moduleNames = functions.split(",");
        this.modules = moduleNames;
        if (shouldCompile) {
            loadCudaKernels();
        } else {
            log.info("Modules appear to already be compiled..attempting to use cache");
            for (String module : moduleNames) {
                String cubin = this.kernelPath + module + ".cubin";
                this.paths.put(module + "_double", cubin);
                this.paths.put(module + "_float", cubin);
            }
        }
        try {
            loadModules(moduleNames, this.kernelPath);
            this.alreadyCompiled = true;
        } catch (IOException | IllegalStateException e) {
            // loadModules reports a missing cubin as IllegalStateException, so that has
            // to take the recompile path as well, not only IOException.
            if (shouldCompile || i >= 3) {
                throw new RuntimeException(e);
            }
            log.warn("Error loading modules...attempting recompile", e);
            FileUtils.deleteDirectory(new File(this.kernelPath));
            properties.setProperty(CACHE_COMPILED, String.valueOf(true));
            compileAndLoad(properties, i + 1);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Debug helper that prints the contents of a buffer to standard out, once via
     * {@link JCublas#printVector} and once element by element from the host copy.
     *
     * <p>NOTE(review): the host data is read through a {@link FloatBuffer}, so this
     * presumably only gives meaningful output for float buffers - confirm.
     *
     * @param jCudaBuffer the buffer to print
     * @param cudaContext the context used to create the device pointer
     * @throws Exception if creating the pointer or copying to the host fails
     */
    public static void printBuffer(JCudaBuffer jCudaBuffer, CudaContext cudaContext) throws Exception {
        CublasPointer cublasPointer = new CublasPointer(jCudaBuffer, cudaContext);
        cublasPointer.copyToHost();
        JCublas.printVector(jCudaBuffer.length(), cublasPointer.getDevicePointer());
        jCudaBuffer.asNio().rewind();
        FloatBuffer asFloatBuffer = cublasPointer.getHostPointer().getByteBuffer(0L, jCudaBuffer.getElementSize() * jCudaBuffer.length()).asFloatBuffer();
        for (int i = 0; i < jCudaBuffer.length(); i++) {
            System.out.println("Item " + i + " is " + asFloatBuffer.get(i));
        }
        JCuda.cudaDeviceSynchronize();
    }

    /**
     * Loads the float and double launcher of every module from its {@code .cubin}
     * file and registers them for the calling thread.
     *
     * @param strArr the module names
     * @param str    the directory (with trailing separator) containing the cubin files
     * @throws IllegalStateException if a cubin file does not exist
     * @throws Exception             if a module cannot be loaded
     */
    private void loadModules(String[] strArr, String str) throws Exception {
        String threadName = Thread.currentThread().getName();
        // Guards the shared launcher table, the path map and the static print function.
        synchronized (launchers) {
            for (String module : strArr) {
                log.debug("Loading " + module);
                String cubinPath = str + module + ".cubin";
                if (!new File(cubinPath).exists()) {
                    throw new IllegalStateException("Unable to find path " + cubinPath + ". Recompiling");
                }
                this.paths.put(module, cubinPath);
                KernelLauncher floatLauncher = KernelLauncher.load(cubinPath, module, "float");
                // The double variant is looked up in the module already loaded for float.
                KernelLauncher doubleLauncher = KernelLauncher.load(module, "double", floatLauncher.getModule());
                launchers.put(threadName, module + "_double", doubleLauncher);
                launchers.put(threadName, module + "_float", floatLauncher);
                if (printFunction == null) {
                    printFunction = KernelLauncher.load(PRINT_KERNEL_NAME, floatLauncher.getModule());
                }
            }
        }
    }

    /**
     * Extracts every resource below {@code org.nd4j.nd4j-kernels} to the temp
     * directory and runs {@code make && /usr/bin/make install} there, logging the
     * output of the build.
     *
     * @throws IOException if extracting a resource or starting the build fails
     */
    private void loadCudaKernels() throws IOException {
        Reflections reflections = new Reflections("org.nd4j.nd4j-kernels", new Scanner[]{new ResourcesScanner()});
        for (String resource : reflections.getResources(Pattern.compile(".*"))) {
            extract(resource);
        }
        File kernelDir = new File(System.getProperty("java.io.tmpdir") + File.separator + "nd4j-kernels");
        new File(kernelDir, "output").mkdirs();
        log.info("Compiling cuda kernels");
        ProcessBuilder processBuilder = new ProcessBuilder("bash", "-c", "make && /usr/bin/make install");
        processBuilder.directory(kernelDir);
        // Merge stderr into stdout so a single reader drains everything the build prints.
        processBuilder.redirectErrorStream(true);
        Process build = processBuilder.start();
        // Drain the output BEFORE waiting: waiting first blocks forever once the
        // child has filled the pipe buffer.
        try (InputStream output = new BufferedInputStream(build.getInputStream())) {
            for (String line : IOUtils.readLines(output, "UTF-8")) {
                log.info(line);
            }
        }
        try {
            int exitCode = build.waitFor();
            if (exitCode != 0) {
                log.warn("Compiling cuda kernels exited with code " + exitCode);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Copies a class path resource into the temp directory. The first two path
     * segments of the resource name are dropped to form the target path.
     *
     * @param str the {@code '/'} separated resource path
     * @return the absolute path of the extracted file
     * @throws IllegalArgumentException if the path has fewer than three segments
     * @throws IllegalStateException    if the resource does not exist
     * @throws IOException              if the file cannot be written
     */
    public String extract(String str) throws IOException {
        String[] segments = str.split("/");
        if (segments.length < 3) {
            throw new IllegalArgumentException("Expected a resource path with at least three segments but got: " + str);
        }
        String[] relative = new String[segments.length - 2];
        System.arraycopy(segments, 2, relative, 0, relative.length);
        File target = new File(System.getProperty("java.io.tmpdir"), StringUtils.join(relative, "/"));
        // Parent directories are created by loadFile; the target itself must stay a file.
        return loadFile(str, target);
    }

    /**
     * Writes the given class path resource to {@code file}, replacing whatever is
     * currently at that location.
     *
     * @param str  the resource path
     * @param file the destination file
     * @return the absolute path of the written file
     * @throws IllegalStateException if the resource does not exist
     * @throws IOException           if copying fails
     */
    private String loadFile(String str, File file) throws IOException {
        ClassPathResource resource = new ClassPathResource(str, KernelFunctionLoader.class.getClassLoader());
        if (!resource.exists()) {
            throw new IllegalStateException("Unable to find file " + resource);
        }
        File parent = file.getParentFile();
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        if (file.exists()) {
            // Also clears an empty directory left at the target path by earlier runs.
            file.delete();
        }
        try (InputStream in = resource.getInputStream();
             BufferedOutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
            IOUtils.copy(in, out);
            out.flush();
        }
        return file.getAbsolutePath();
    }
}
