package org.datavec.python;

import java.io.File;
import java.io.FileInputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import org.bytedeco.cpython.PyCompilerFlags;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.global.python;
import org.datavec.python.PythonVariables;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.nd4j.linalg.api.buffer.DataType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/datavec/python/PythonExecutioner.class */
/**
 * Static bridge that runs Python source through the embedded CPython interpreter exposed by the
 * JavaCPP {@code python} bindings.
 *
 * <p>Inputs are passed by generating Python assignment statements that are prepended to the user
 * code; outputs are passed back through a per-thread JSON file ({@code temp_<threadId>.json}) that
 * the generated code writes and this class reads and deletes.
 *
 * <p>The interpreter is initialised once from the static initializer at the bottom of the class.
 */
public class PythonExecutioner {
    private static final Logger log = LoggerFactory.getLogger(PythonExecutioner.class);

    /**
     * GIL state handles returned by {@code PyGILState_Ensure()}, keyed by Java thread id. An entry
     * exists only between {@link #acquireGIL()} and {@link #releaseGIL()} on the same thread. The
     * map is concurrent because different threads insert and remove their own entries.
     */
    private static final Map<Long, Integer> gilStates = new ConcurrentHashMap<>();

    private static PyObject module;
    private static PyObject globals;

    /**
     * Initialises the interpreter, caches the {@code __main__} module and its dictionary, and then
     * calls {@code PyEval_SaveThread()}.
     */
    public static void init() {
        log.info("CPython: Py_InitializeEx()");
        python.Py_InitializeEx(1);
        log.info("CPython: PyEval_InitThreads()");
        python.PyEval_InitThreads();
        log.info("CPython: PyImport_AddModule()");
        module = python.PyImport_AddModule("__main__");
        log.info("CPython: PyModule_GetDict()");
        globals = python.PyModule_GetDict(module);
        // The log line used to name PyThreadState_Get(), which is not the call made here.
        log.info("CPython: PyEval_SaveThread()");
        python.PyEval_SaveThread();
    }

    /** Shuts the interpreter down via {@code Py_Finalize()}. */
    public static void free() {
        python.Py_Finalize();
    }

    /**
     * Builds the Python prelude that defines every input variable and mirrors it into a dict
     * named {@code loc}.
     *
     * @param pythonVariables the inputs, or {@code null} for none
     * @return Python source starting with {@code loc={};}
     * @throws Exception if an NDArray input has a data type with no ctypes mapping
     */
    private static String inputCode(PythonVariables pythonVariables) throws Exception {
        StringBuilder code = new StringBuilder("loc={};");
        if (pythonVariables == null) {
            return code.toString();
        }
        for (Map.Entry<String, String> entry : pythonVariables.getStrVariables().entrySet()) {
            appendAssignment(code, entry.getKey(), " = ", "\"\"\"" + escapeStr(entry.getValue()) + "\"\"\"");
        }
        for (Map.Entry<String, Long> entry : pythonVariables.getIntVariables().entrySet()) {
            appendAssignment(code, entry.getKey(), " = ", entry.getValue().toString());
        }
        for (Map.Entry<String, Double> entry : pythonVariables.getFloatVariables().entrySet()) {
            appendAssignment(code, entry.getKey(), " = ", entry.getValue().toString());
        }
        for (Map.Entry<String, Object[]> entry : pythonVariables.getListVariables().entrySet()) {
            appendAssignment(code, entry.getKey(), " = ", jArrayToPyString(entry.getValue()));
        }
        for (Map.Entry<String, String> entry : pythonVariables.getFileVariables().entrySet()) {
            appendAssignment(code, entry.getKey(), " = ", "\"\"\"" + escapeStr(entry.getValue()) + "\"\"\"");
        }
        Map<String, NumpyArray> arrays = pythonVariables.getNDArrayVariables();
        if (!arrays.isEmpty()) {
            code.append("import ctypes; import numpy as np;");
            code.append("__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape);");
            for (Map.Entry<String, NumpyArray> entry : arrays.entrySet()) {
                // NOTE(review): only the copy's raw address reaches Python; presumably something
                // else keeps the copied buffer alive while the script runs - confirm.
                NumpyArray copy = entry.getValue().copy();
                StringBuilder shape = new StringBuilder("(");
                for (long dim : copy.getShape()) {
                    // Trailing comma after every dimension keeps a 1-d shape a valid tuple.
                    shape.append(dim).append(',');
                }
                shape.append(')');
                String call = "__arr_converter(" + copy.getAddress() + "," + shape + ","
                        + ctypesNameFor(copy.getDtype()) + ")";
                appendAssignment(code, entry.getKey(), "=", call);
            }
        }
        return code.toString();
    }

    /** Appends {@code name<separator><expression>} followed by {@code loc['name']=name}. */
    private static void appendAssignment(StringBuilder code, String name, String separator, String expression) {
        code.append(name).append(separator).append(expression).append('\n');
        code.append("loc['").append(name).append("']=").append(name).append('\n');
    }

    /** Maps an array data type to the ctypes element type used by the generated converter. */
    private static String ctypesNameFor(DataType dtype) throws Exception {
        if (dtype == DataType.FLOAT) {
            return "ctypes.c_float";
        }
        if (dtype == DataType.DOUBLE) {
            return "ctypes.c_double";
        }
        if (dtype == DataType.SHORT) {
            return "ctypes.c_int16";
        }
        if (dtype == DataType.INT) {
            return "ctypes.c_int32";
        }
        if (dtype == DataType.LONG) {
            return "ctypes.c_int64";
        }
        throw new Exception("Unsupported data type: " + dtype.toString() + ".");
    }

    /** Maps a numpy dtype name to the corresponding {@link DataType}. */
    private static DataType dataTypeForNumpyName(String name) throws Exception {
        if ("float64".equals(name)) {
            return DataType.DOUBLE;
        }
        if ("float32".equals(name)) {
            return DataType.FLOAT;
        }
        if ("int16".equals(name)) {
            return DataType.SHORT;
        }
        if ("int32".equals(name)) {
            return DataType.INT;
        }
        if ("int64".equals(name)) {
            return DataType.LONG;
        }
        throw new Exception("Unsupported array type " + name + ".");
    }

    /**
     * Reads the JSON file written by the code from {@link #outputCode(PythonVariables)}, deletes
     * it, and stores each value into {@code pythonVariables}.
     *
     * @param pythonVariables the declared outputs; {@code null} means nothing was requested, so
     *                        no file was written and nothing is read
     * @throws RuntimeException wrapping any parse or conversion failure
     */
    private static void _readOutputs(PythonVariables pythonVariables) {
        if (pythonVariables == null) {
            return;
        }
        String tempFile = getTempFile();
        String json = read(tempFile);
        new File(tempFile).delete();
        try {
            // A fresh parser per call: this method may run on several threads at once.
            JSONObject outputs = (JSONObject) new JSONParser().parse(json);
            for (String name : pythonVariables.getVariables()) {
                PythonVariables.Type type = pythonVariables.getType(name);
                if (type == PythonVariables.Type.NDARRAY) {
                    JSONObject meta = (JSONObject) outputs.get(name);
                    long address = ((Number) meta.get("address")).longValue();
                    long[] shape = jsonArrayToLongArray((JSONArray) meta.get("shape"));
                    long[] strides = jsonArrayToLongArray((JSONArray) meta.get("strides"));
                    DataType dataType = dataTypeForNumpyName((String) meta.get("dtype"));
                    pythonVariables.setValue(name, new NumpyArray(address, shape, strides, dataType, true));
                } else if (type == PythonVariables.Type.LIST) {
                    pythonVariables.setValue(name, ((JSONArray) outputs.get(name)).toArray());
                } else {
                    pythonVariables.setValue(name, outputs.get(name));
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Takes the GIL for the current thread unless {@code PyGILState_Check()} reports it is
     * already held, and remembers the returned state for {@link #releaseGIL()}.
     */
    private static void acquireGIL() {
        log.info("---_enterSubInterpreter()---");
        if (python.PyGILState_Check() != 1) {
            int state = python.PyGILState_Ensure();
            gilStates.put(Long.valueOf(Thread.currentThread().getId()), Integer.valueOf(state));
            log.info("GIL ensured");
        }
    }

    /**
     * Releases the GIL if the current thread holds it through {@link #acquireGIL()}.
     *
     * <p>The stored state is removed before use so it cannot be released twice, and a thread that
     * holds the GIL without a stored state (it was not taken by {@link #acquireGIL()}) is left
     * alone instead of failing on a missing map entry.
     */
    private static void releaseGIL() {
        if (python.PyGILState_Check() == 1) {
            Integer state = gilStates.remove(Long.valueOf(Thread.currentThread().getId()));
            if (state != null) {
                log.info("Releasing gil...");
                python.PyGILState_Release(state.intValue());
                log.info("Gil released.");
            }
        }
    }

    /**
     * Runs {@code str} wrapped in a per-thread function via {@code PyRun_SimpleStringFlags}.
     *
     * @param str Python source to execute
     * @throws RuntimeException with message {@code exec failed} when the interpreter reports a
     *                          non-zero status
     */
    public static void exec(String str) {
        String functionalCode = getFunctionalCode("__f_" + Thread.currentThread().getId(), str);
        acquireGIL();
        try {
            log.info("CPython: PyRun_SimpleStringFlag()");
            log.info(functionalCode);
            if (python.PyRun_SimpleStringFlags(functionalCode, (PyCompilerFlags) null) != 0) {
                python.PyErr_Print();
                throw new RuntimeException("exec failed");
            }
            log.info("Exec done");
        } finally {
            // Release on the failure path too; otherwise a failing script keeps the GIL held.
            releaseGIL();
        }
    }

    /**
     * Runs {@code str}, then reads the declared outputs back into {@code pythonVariables}.
     *
     * @param str             Python source to execute
     * @param pythonVariables outputs to collect; may be {@code null} for none
     */
    public static void exec(String str, PythonVariables pythonVariables) {
        exec(str + '\n' + outputCode(pythonVariables));
        _readOutputs(pythonVariables);
    }

    /**
     * Runs {@code str} with {@code pythonVariables} defined as inputs and collects
     * {@code pythonVariables2} as outputs.
     *
     * @throws Exception if the input prelude cannot be generated
     */
    public static void exec(String str, PythonVariables pythonVariables, PythonVariables pythonVariables2) throws Exception {
        exec(inputCode(pythonVariables) + str, pythonVariables2);
    }

    /**
     * Runs a transform that needs no inputs.
     *
     * @return the transform's output variables after execution
     * @throws Exception with message {@code Required inputs not provided.} if the transform
     *                   declares at least one input variable
     */
    public static PythonVariables exec(PythonTransform pythonTransform) throws Exception {
        if (pythonTransform.getInputs() != null && pythonTransform.getInputs().getVariables().length > 0) {
            throw new Exception("Required inputs not provided.");
        }
        exec(pythonTransform.getCode(), null, pythonTransform.getOutputs());
        return pythonTransform.getOutputs();
    }

    /**
     * Runs a transform with the given inputs.
     *
     * @return the transform's output variables after execution
     */
    public static PythonVariables exec(PythonTransform pythonTransform, PythonVariables pythonVariables) throws Exception {
        exec(pythonTransform.getCode(), pythonVariables, pythonTransform.getOutputs());
        return pythonTransform.getOutputs();
    }

    /**
     * Refreshes the cached {@code __main__} module and dictionary and looks {@code name} up in it.
     *
     * <p>The result of {@code PyDict_GetItemString} is a borrowed reference, so callers must not
     * pass it to {@code Py_DecRef}.
     *
     * @throws IllegalArgumentException if no such global exists
     */
    private static PyObject lookupGlobal(String name) {
        log.info("CPython: PyImport_AddModule()");
        module = python.PyImport_AddModule("__main__");
        log.info("CPython: PyModule_GetDict()");
        globals = python.PyModule_GetDict(module);
        PyObject value = python.PyDict_GetItemString(globals, name);
        if (value == null) {
            throw new IllegalArgumentException("Python variable not found: " + name);
        }
        return value;
    }

    /** Drops one owned reference; a {@code null} handle is ignored. */
    private static void release(PyObject owned) {
        if (owned != null) {
            python.Py_DecRef(owned);
        }
    }

    /** Encodes a Python str object as UTF-8 bytes and copies it out, freeing the bytes object. */
    private static String asJavaString(PyObject unicode) {
        PyObject encoded = python.PyUnicode_AsEncodedString(unicode, "UTF-8", "strict");
        try {
            return python.PyBytes_AsString(encoded).getString();
        } finally {
            release(encoded);
        }
    }

    /** Reads {@code sequence[index]} as a long, releasing the temporary key and item objects. */
    private static long longItemAt(PyObject sequence, long index) {
        PyObject key = python.PyLong_FromLong(index);
        PyObject item = null;
        try {
            item = python.PyObject_GetItem(sequence, key);
            return python.PyLong_AsLongLong(item);
        } finally {
            release(item);
            release(key);
        }
    }

    /**
     * Returns the value of the global {@code str} as a Java string.
     *
     * <p>NOTE(review): none of the {@code eval*} methods takes the GIL itself; presumably callers
     * are expected to hold it - confirm.
     */
    public static String evalSTRING(String str) {
        return asJavaString(lookupGlobal(str));
    }

    /** Returns the value of the global {@code str} converted with {@code PyLong_AsLongLong}. */
    public static long evalINTEGER(String str) {
        return python.PyLong_AsLongLong(lookupGlobal(str));
    }

    /** Returns the value of the global {@code str} converted with {@code PyFloat_AsDouble}. */
    public static double evalFLOAT(String str) {
        return python.PyFloat_AsDouble(lookupGlobal(str));
    }

    /**
     * Returns the global {@code str} as an array by parsing its {@code str()} form as JSON after
     * replacing single quotes with double quotes.
     *
     * @throws Exception if the text is not valid JSON
     */
    public static Object[] evalLIST(String str) throws Exception {
        PyObject text = python.PyObject_Str(lookupGlobal(str));
        String json;
        try {
            json = asJavaString(text);
        } finally {
            // PyObject_Str returns a new reference that has to be given back.
            release(text);
        }
        // Parser instances are created per call rather than shared between threads.
        return ((JSONArray) new JSONParser().parse(json.replace("'", "\""))).toArray();
    }

    /**
     * Wraps the numpy array stored in the global {@code str} by reading its data address, shape,
     * strides and dtype name.
     *
     * @throws Exception if the dtype is not one of float64, float32, int16, int32 or int64
     */
    public static NumpyArray evalNDARRAY(String str) throws Exception {
        PyObject array = lookupGlobal(str);
        PyObject arrayInterface = null;
        PyObject shape = null;
        PyObject strides = null;
        try {
            arrayInterface = python.PyObject_GetAttrString(array, "__array_interface__");
            // Borrowed from the interface dict: must not be released here.
            PyObject data = python.PyDict_GetItemString(arrayInterface, "data");
            long address = longItemAt(data, 0L);

            shape = python.PyObject_GetAttrString(array, "shape");
            int rank = (int) python.PyObject_Size(shape);
            long[] dims = new long[rank];
            for (int i = 0; i < rank; i++) {
                dims[i] = longItemAt(shape, i);
            }

            strides = python.PyObject_GetAttrString(array, "strides");
            long[] strideValues = new long[rank];
            for (int i = 0; i < rank; i++) {
                strideValues[i] = longItemAt(strides, i);
            }

            DataType dataType = dataTypeForNumpyName(dtypeNameOf(array));
            return new NumpyArray(address, dims, strideValues, dataType, true);
        } finally {
            // Runs on the unsupported-dtype path as well, so the attribute objects never leak.
            release(strides);
            release(shape);
            release(arrayInterface);
        }
    }

    /** Reads {@code array.dtype.name} as a Java string, releasing both attribute objects. */
    private static String dtypeNameOf(PyObject array) {
        PyObject dtype = python.PyObject_GetAttrString(array, "dtype");
        PyObject name = null;
        try {
            name = python.PyObject_GetAttrString(dtype, "name");
            return asJavaString(name);
        } finally {
            release(name);
            release(dtype);
        }
    }

    /**
     * Builds Python source that accumulates, in {@code __error_message}, a note for every
     * declared output that is missing from {@code locals()} or has the wrong type.
     */
    private static String getOutputCheckCode(PythonVariables pythonVariables) {
        StringBuilder code = new StringBuilder("__error_message=''\n");
        for (String name : pythonVariables.getVariables()) {
            code.append(String.format("if '%s' not in locals(): __error_message += '%s not found.'\n", name, name));
            String pythonType = null;
            switch (pythonVariables.getType(name)) {
                case INT:
                    pythonType = "int";
                    break;
                case STR:
                    pythonType = "str";
                    break;
                case FLOAT:
                    pythonType = "float";
                    break;
                case BOOL:
                    pythonType = "bool";
                    break;
                case NDARRAY:
                    pythonType = "np.ndarray";
                    break;
                case LIST:
                    pythonType = "list";
                    break;
                default:
                    // Other types get the presence check only.
                    break;
            }
            if (pythonType != null) {
                code.append(String.format(
                        "if not isinstance(%s, %s): __error_message += '%s is not of required type.'\n",
                        name, pythonType, name));
            }
        }
        return code.toString();
    }

    /**
     * Builds the Python epilogue that dumps every declared output to the per-thread temp file as
     * one JSON object. NDArray outputs are written as address/shape/strides/dtype metadata.
     *
     * @param pythonVariables the declared outputs, or {@code null} for none
     * @return Python source, or an empty string when {@code pythonVariables} is {@code null}
     */
    private static String outputCode(PythonVariables pythonVariables) {
        if (pythonVariables == null) {
            return "";
        }
        StringBuilder entries = new StringBuilder();
        boolean hasNdarray = false;
        for (String name : pythonVariables.getVariables()) {
            // Separator goes between entries, so an empty variable list yields a valid "{}".
            if (entries.length() > 0) {
                entries.append(',');
            }
            entries.append('"').append(name).append("\":");
            if (pythonVariables.getType(name) == PythonVariables.Type.NDARRAY) {
                hasNdarray = true;
                entries.append("serialize_ndarray_metadata(").append(name).append(')');
            } else {
                entries.append(name);
            }
        }
        StringBuilder code = new StringBuilder();
        if (hasNdarray) {
            code.append("serialize_ndarray_metadata=lambda x:{\"address\":x.__array_interface__['data'][0],\"shape\":x.shape,\"strides\":x.strides,\"dtype\":str(x.dtype)}\n");
        }
        code.append("import json\nwith open('").append(getTempFile()).append("', 'w') as ___fobj_:json.dump({");
        code.append(entries).append("}, ___fobj_)\n");
        return code.toString();
    }

    /**
     * Reads a whole file as UTF-8 text.
     *
     * @return the file content, or an empty string if it cannot be read (the failure is logged)
     */
    private static String read(String str) {
        try {
            return new String(Files.readAllBytes(new File(str).toPath()), StandardCharsets.UTF_8);
        } catch (Exception e) {
            // Best effort by design: callers treat an empty result as "no content".
            log.warn("Could not read " + str, e);
            return "";
        }
    }

    /**
     * Renders a Java array as a Python list literal. Nested arrays recurse, strings are quoted
     * and escaped, {@code null} becomes {@code None} and booleans become {@code True}/{@code False}.
     */
    private static String jArrayToPyString(Object[] objArr) {
        StringBuilder literal = new StringBuilder("[");
        for (int i = 0; i < objArr.length; i++) {
            if (i > 0) {
                literal.append(',');
            }
            Object element = objArr[i];
            if (element == null) {
                literal.append("None");
            } else if (element instanceof Object[]) {
                literal.append(jArrayToPyString((Object[]) element));
            } else if (element instanceof String) {
                literal.append('"').append(escapeStr((String) element)).append('"');
            } else if (element instanceof Boolean) {
                literal.append(((Boolean) element).booleanValue() ? "True" : "False");
            } else {
                literal.append(element.toString().replace("\"", "\\\""));
            }
        }
        return literal.append(']').toString();
    }

    /**
     * Escapes text for use inside a double-quoted or triple-double-quoted Python string literal.
     *
     * <p>Every double quote is escaped, so a value ending in a quote cannot run into the closing
     * delimiter. Line breaks are written as escape sequences because
     * {@link #getFunctionalCode(String, String)} indents every physical line of the script and
     * would otherwise insert spaces into multi-line values.
     */
    private static String escapeStr(String str) {
        return str.replace("\\", "\\\\")
                .replace("\"", "\\\"")
                .replace("\r", "\\r")
                .replace("\n", "\\n");
    }

    /**
     * Wraps {@code str2} in a function named {@code str} and appends a call to it.
     *
     * @param str  name of the generated function
     * @param str2 Python source forming the function body
     */
    private static String getFunctionalCode(String str, String str2) {
        StringBuilder code = new StringBuilder(String.format("def %s():\n", str));
        // Keeps the function body non-empty when the supplied code is blank or comments only.
        code.append("    pass\n");
        for (String line : str2.split(Pattern.quote("\n"))) {
            code.append("    ").append(line).append('\n');
        }
        return code.append("\n\n").append(str).append("()\n").toString();
    }

    /** Returns (and logs) the name of the JSON exchange file for the current thread. */
    private static String getTempFile() {
        String str = "temp_" + Thread.currentThread().getId() + ".json";
        log.info(str);
        return str;
    }

    /** Converts a JSON array of numbers to a {@code long[]}. */
    private static long[] jsonArrayToLongArray(JSONArray jSONArray) {
        long[] values = new long[jSONArray.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = ((Number) jSONArray.get(i)).longValue();
        }
        return values;
    }

    static {
        init();
    }
}
