package org.nd4j.linalg.jcublas.kernel;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import jcuda.utils.KernelLauncher;
import org.apache.commons.io.IOUtils;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/nd4j/linalg/jcublas/kernel/KernelFunctionLoader.class */
/**
 * Loads the CUDA kernel functions listed in {@code cudafunctions.properties} and registers a
 * {@link KernelLauncher} for each one, keyed by {@code <function>_<float|double>}.
 *
 * <p>Kernel sources and their imports are copied from the classpath ({@code /kernels/<type>/})
 * into {@code java.io.tmpdir}. When needed they are compiled to {@code .ptx} with {@code nvcc}.
 * The launchers are then loaded from those {@code .ptx} files.
 */
public class KernelFunctionLoader {
    public static final String NAME_SPACE = "org.nd4j.linalg.jcuda.jcublas";
    public static final String DOUBLE = "org.nd4j.linalg.jcuda.jcublas.double.functions";
    public static final String FLOAT = "org.nd4j.linalg.jcuda.jcublas.float.functions";
    public static final String IMPORTS_FLOAT = "org.nd4j.linalg.jcuda.jcublas.float.imports";
    public static final String IMPORTS_DOUBLE = "org.nd4j.linalg.jcuda.jcublas.double.imports";
    public static final String CACHE_COMPILED = "org.nd4j.linalg.jcuda.jcublas.cache_compiled";

    /** Upper bound on the attempt counter before a failed load is rethrown. */
    private static final int MAX_ATTEMPTS = 3;

    private static KernelFunctionLoader INSTANCE;
    private static final Logger log = LoggerFactory.getLogger(KernelFunctionLoader.class);
    /** Registered launchers, keyed by {@code <function>_<type>}. */
    private static final Map<String, KernelLauncher> launchers = new HashMap<String, KernelLauncher>();
    /** Path of the {@code .ptx} file used for each registered function, keyed like {@link #launchers}. */
    private final Map<String, String> paths = new HashMap<String, String>();
    private boolean init = false;

    private KernelFunctionLoader() {
    }

    /**
     * Returns the singleton loader, creating and loading it on first use.
     *
     * <p>On first creation a shutdown hook is registered that calls {@link #unload()}. If loading
     * fails, the failure is logged and the (possibly partially loaded) instance is still returned.
     */
    public static synchronized KernelFunctionLoader getInstance() {
        if (INSTANCE == null) {
            INSTANCE = new KernelFunctionLoader();
            Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
                @Override
                public void run() {
                    KernelFunctionLoader.INSTANCE.unload();
                }
            }));
            try {
                INSTANCE.load();
            } catch (Exception e) {
                // Best effort: callers still get an instance, but the failure is reported via the
                // logger rather than dumped to stderr.
                log.error("Unable to load cuda kernel functions", e);
            }
        }
        return INSTANCE;
    }

    /** Classpath-relative folder holding the kernels for the given data type. */
    private static String dataFolder(DataBuffer.Type type) {
        return "/kernels/" + (type == DataBuffer.Type.FLOAT ? "float" : "double");
    }

    /**
     * Convenience lookup on the singleton; see {@link #get(String, String)}.
     *
     * @param functionName kernel function name
     * @param dataType     {@code "float"} or {@code "double"}
     * @return the matching launcher, or {@code null} if none is registered
     */
    public static KernelLauncher launcher(String functionName, String dataType) {
        return getInstance().get(functionName, dataType);
    }

    /**
     * Returns whether a launcher is registered for {@code functionName} in either precision.
     *
     * @param functionName kernel function name
     * @return {@code true} if a double or float launcher (plain or strided) exists
     */
    public boolean exists(String functionName) {
        return get(functionName, "double") != null || get(functionName, "float") != null;
    }

    /**
     * Looks up the launcher registered as {@code <functionName>_<dataType>}. If there is none, it
     * falls back to {@code <functionName>_strided_<dataType>}.
     *
     * @param functionName kernel function name
     * @param dataType     {@code "float"} or {@code "double"}
     * @return the launcher, or {@code null} if neither key is registered
     */
    public KernelLauncher get(String functionName, String dataType) {
        KernelLauncher launcher = launchers.get(functionName + "_" + dataType);
        if (launcher == null) {
            launcher = launchers.get(functionName + "_strided_" + dataType);
        }
        return launcher;
    }

    /**
     * Marks this loader as uninitialized so that the next {@link #load()} runs again. Already
     * registered launchers are not removed.
     */
    public void unload() {
        this.init = false;
    }

    /**
     * Reads {@code cudafunctions.properties} from the classpath, copies the float and double
     * imports to the temp directory, then compiles (if needed) and loads the float and double
     * kernel functions. Does nothing if already initialized.
     *
     * @throws IllegalStateException if the properties file or a required key is missing
     * @throws Exception             if copying, compiling or loading the kernels fails
     */
    public void load() throws Exception {
        if (this.init) {
            return;
        }
        ClassPathResource resource =
                new ClassPathResource("/cudafunctions.properties", KernelFunctionLoader.class.getClassLoader());
        if (!resource.exists()) {
            throw new IllegalStateException("Please put a cudafunctions.properties in your class path");
        }
        Properties properties = new Properties();
        // Properties.load leaves the stream open, so close it ourselves.
        InputStream in = resource.getInputStream();
        try {
            properties.load(in);
        } finally {
            in.close();
        }
        log.info("Registering cuda functions...");
        ensureImports(properties, "float");
        ensureImports(properties, "double");
        compileAndLoad(properties, FLOAT, "float");
        compileAndLoad(properties, DOUBLE, "double");
        this.init = true;
    }

    private void compileAndLoad(Properties properties, String key, String dataType) throws IOException {
        compileAndLoad(properties, key, dataType, 0);
    }

    /**
     * Compiles (when needed) and loads the comma-separated kernel functions listed under
     * {@code key}.
     *
     * <p>If loading from the cache fails, the method recompiles once by recursing with
     * {@code attempt + 1}.
     *
     * @param properties parsed {@code cudafunctions.properties}; {@link #CACHE_COMPILED} may be
     *                   set to {@code true} on retry
     * @param key        property key holding the function list
     * @param dataType   {@code "float"} or {@code "double"}
     * @param attempt    0 on the first call, incremented on each recompile retry
     * @throws IllegalStateException if {@code key} is missing
     * @throws IOException           if extraction or {@code nvcc} fails
     * @throws RuntimeException      wrapping the load error when freshly compiled modules still
     *                               fail to load, or retries are exhausted
     */
    private void compileAndLoad(Properties properties, String key, String dataType, int attempt) throws IOException {
        String[] functions = requiredList(properties, key);
        String tmpDir = System.getProperty("java.io.tmpdir");
        String outputDir = tmpDir + File.separator + "kernels" + File.separator + dataType + File.separator;
        boolean cacheCompiled = Boolean.parseBoolean(properties.getProperty(CACHE_COMPILED, "true"));
        // Compile when caching is enabled and either the kernels folder is absent or a previous
        // attempt to load the cached .ptx files failed.
        boolean compile = cacheCompiled
                && (!new File(tmpDir + File.separator + "kernels").exists() || attempt > 0);

        if (compile) {
            log.info("Loading " + dataType + " cuda functions");
            compileToPtx(functions, dataType, outputDir);
        } else {
            log.info("Modules appear to already be compiled..attempting to use cache");
            for (String function : functions) {
                log.info("Is cached: " + function);
                this.paths.put(function + "_" + dataType, outputDir + function + ".ptx");
            }
        }

        try {
            for (String function : functions) {
                log.info("Loading " + function);
                String ptxPath = outputDir + function + ".ptx";
                String launcherKey = function + "_" + dataType;
                this.paths.put(launcherKey, ptxPath);
                launchers.put(launcherKey, KernelLauncher.load(ptxPath, launcherKey));
            }
        } catch (Exception e) {
            if (compile || attempt >= MAX_ATTEMPTS) {
                throw new RuntimeException(e);
            }
            log.warn("Error loading modules...attempting recompile");
            properties.setProperty(CACHE_COMPILED, String.valueOf(true));
            compileAndLoad(properties, key, dataType, attempt + 1);
        }
    }

    /**
     * Extracts each {@code /kernels/<dataType>/<function>.cu} from the classpath and compiles
     * them all with {@code nvcc -ptx} into {@code outputDir}.
     *
     * @throws IOException if extraction fails, {@code nvcc} exits non-zero, or the wait is
     *                     interrupted (the interrupt flag is restored)
     */
    private void compileToPtx(String[] functions, String dataType, String outputDir) throws IOException {
        DataBuffer.Type type = dataType.equals("float") ? DataBuffer.Type.FLOAT : DataBuffer.Type.DOUBLE;
        // Pass arguments as a list: Runtime.exec(String) splits on whitespace and would break
        // on temp-directory paths that contain spaces.
        List<String> command = new ArrayList<String>();
        command.add("nvcc");
        command.add("-g");
        command.add("-G");
        command.add("--include-path");
        command.add(outputDir);
        command.add("-ptx");
        for (String function : functions) {
            command.add(extract("/kernels/" + dataType + "/" + function + ".cu", type));
        }
        command.add("--output-directory");
        command.add(outputDir);

        // Merge stderr into stdout so a single read drains the process. Draining the two pipes
        // one after the other can deadlock once the unread pipe's buffer fills.
        Process process = new ProcessBuilder(command).redirectErrorStream(true).start();
        String output = new String(IOUtils.toByteArray(process.getInputStream()));
        int exitValue;
        try {
            exitValue = process.waitFor();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Interrupted while waiting for nvcc output", e);
        }
        if (exitValue != 0) {
            log.error("nvcc process exitValue " + exitValue);
            log.error("nvcc output:\n" + output);
            throw new IOException("Could not create .ptx file: " + output);
        }
    }

    /**
     * Copies the classpath resource {@code path} into the temp directory. The kernel folder for
     * {@code type} is created first if it does not exist.
     *
     * @param path classpath path of the resource to extract
     * @param type data type whose kernel folder is ensured under the temp directory
     * @return absolute path of the extracted file
     * @throws IOException           if copying fails
     * @throws IllegalStateException if the resource is not on the classpath
     */
    public String extract(String path, DataBuffer.Type type) throws IOException {
        File folder = new File(System.getProperty("java.io.tmpdir"), dataFolder(type));
        if (!folder.exists()) {
            folder.mkdirs();
        }
        return loadFile(path);
    }

    /**
     * Copies every comma-separated import listed under the float or double imports key to the
     * temp directory.
     *
     * @throws IllegalStateException if the imports key or one of the listed files is missing
     */
    private void ensureImports(Properties properties, String dataType) throws IOException {
        String key = dataType.equals("float") ? IMPORTS_FLOAT : IMPORTS_DOUBLE;
        for (String importName : requiredList(properties, key)) {
            loadFile("/kernels/" + dataType + "/" + importName);
        }
    }

    /** Returns the comma-separated values of {@code key}, failing with a clear message if it is absent. */
    private static String[] requiredList(Properties properties, String key) {
        String value = properties.getProperty(key);
        if (value == null) {
            throw new IllegalStateException("Missing property " + key + " in cudafunctions.properties");
        }
        return value.split(",");
    }

    /**
     * Copies a classpath resource to the same relative path under {@code java.io.tmpdir},
     * replacing any existing file and scheduling it for deletion on JVM exit.
     *
     * @return absolute path of the written file
     * @throws IllegalStateException if the resource does not exist
     * @throws IOException           if the file cannot be written
     */
    private String loadFile(String resourcePath) throws IOException {
        ClassPathResource resource = new ClassPathResource(resourcePath, KernelFunctionLoader.class.getClassLoader());
        if (!resource.exists()) {
            throw new IllegalStateException("Unable to find file " + resource);
        }
        File target = new File(System.getProperty("java.io.tmpdir"), resourcePath);
        File parent = target.getParentFile();
        if (!parent.exists()) {
            parent.mkdirs();
        }
        if (target.exists()) {
            target.delete();
        }
        target.createNewFile();
        // Close both streams even when the copy fails; the original leaked them.
        InputStream in = resource.getInputStream();
        try {
            OutputStream out = new BufferedOutputStream(new FileOutputStream(target));
            try {
                IOUtils.copy(in, out);
            } finally {
                out.close(); // close() also flushes the buffer
            }
        } finally {
            in.close();
        }
        target.deleteOnExit();
        return target.getAbsolutePath();
    }
}
