package zju.cst.aces.runner;

import com.github.javaparser.StaticJavaParser;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.nio.file.CopyOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.stream.Collectors;
import okhttp3.Response;
import zju.cst.aces.api.Task;
import zju.cst.aces.api.config.Config;
import zju.cst.aces.dto.ClassInfo;
import zju.cst.aces.dto.Message;
import zju.cst.aces.dto.MethodInfo;
import zju.cst.aces.dto.PromptInfo;
import zju.cst.aces.parser.ClassParser;
import zju.cst.aces.prompt.PromptGenerator;
import zju.cst.aces.util.CodeExtractor;
import zju.cst.aces.util.TokenCounter;

/**
 * Shared base for the runners that drive unit-test generation for one class under test.
 *
 * <p>Each instance holds the {@link Config}, a {@link PromptGenerator} and the name of the class
 * being processed. The static helpers cover four areas:
 * <ul>
 *   <li>building {@link PromptInfo} objects from the {@code class.json} and per-method JSON files
 *       found under {@code config.getParseOutput()};</li>
 *   <li>post-processing generated test source with JavaParser (imports, package, class name,
 *       timeouts);</li>
 *   <li>reading the {@code choices[0].message.content} field out of an HTTP response body;</li>
 *   <li>writing generated tests and history/mapping JSON files as UTF-8.</li>
 * </ul>
 */
public abstract class AbstractRunner {
    /** Separator between the parts of a generated test name; {@link #exportAttemptMapping} relies on it. */
    public static final String separator = "_";
    /** Simple name of the class under test (the text after the last '.' of {@link #fullClassName}). */
    public String className;
    /** Fully qualified name of the class under test, as passed to the constructor. */
    public String fullClassName;
    public Config config;
    public PromptGenerator promptGenerator;
    /** Shared Gson instance: pretty-printed output, HTML escaping disabled. */
    public static final Gson GSON = new GsonBuilder().setPrettyPrinting().disableHtmlEscaping().create();
    /**
     * Timeout for generated tests, in milliseconds. This is the unit {@link #addTimeout(String, int)}
     * applies for both JUnit 4 and JUnit 5. NOTE(review): callers are not visible here; confirm they
     * pass this value to addTimeout.
     */
    public static int testTimeOut = 8000;

    /** Regex for a bare {@code @Test} at the end of a line; group 1 captures the line ending (LF or CRLF). */
    private static final String BARE_TEST_ANNOTATION = "@Test(\\r?\\n)";

    /**
     * Creates a runner for one class under test.
     *
     * @param config runner configuration
     * @param str    fully qualified name of the class under test
     * @throws IOException declared by this constructor; presumably propagated from
     *                     {@link PromptGenerator} construction (TODO confirm)
     */
    public AbstractRunner(Config config, String str) throws IOException {
        this.fullClassName = str;
        this.className = str.substring(str.lastIndexOf(".") + 1);
        this.config = config;
        this.promptGenerator = new PromptGenerator(config);
    }

    /** Entry point implemented by concrete runners. */
    abstract void start() throws IOException;

    /** Joins {@code list} with {@code "\n"} (no trailing newline). */
    public static String joinLines(List<String> list) {
        return String.join("\n", list);
    }

    /** Joins {@code list} with {@code "\n"}, leaving out every element equal to {@code str}. */
    public static String filterAndJoinLines(List<String> list, String str) {
        return list.stream()
                .filter(line -> !line.equals(str))
                .collect(Collectors.joining("\n"));
    }

    /**
     * Extracts {@code choices[0].message.content} from a chat-completion style JSON response.
     *
     * <p>The response is not closed here; the caller remains responsible for closing it.
     *
     * @param response HTTP response, may be {@code null}
     * @return the message content, or {@code ""} when there is no response or no body
     */
    public static String parseResponse(Response response) {
        if (response == null || response.body() == null) {
            return "";
        }
        Map<?, ?> payload = GSON.fromJson(response.body().charStream(), Map.class);
        List<?> choices = (List<?>) payload.get("choices");
        Map<?, ?> firstChoice = (Map<?, ?>) choices.get(0);
        Map<?, ?> message = (Map<?, ?>) firstChoice.get("message");
        return (String) message.get("content");
    }

    /**
     * Writes {@code str} to {@code path} as UTF-8, creating missing parent directories first.
     *
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public static void exportTest(String str, Path path) {
        File parentDir = path.toAbsolutePath().getParent().toFile();
        if (!parentDir.exists()) {
            parentDir.mkdirs();
        }
        try {
            writeUtf8(path.toFile(), str);
        } catch (IOException e) {
            throw new RuntimeException("In AbstractRunner.exportTest: " + e, e);
        }
    }

    /**
     * Writes {@code content} to {@code file} as UTF-8, replacing any existing content. The stream is
     * closed even if the write fails.
     */
    private static void writeUtf8(File file, String content) throws IOException {
        try (OutputStreamWriter writer =
                     new OutputStreamWriter(new FileOutputStream(file), StandardCharsets.UTF_8)) {
            writer.write(content);
        }
    }

    /**
     * Extracts the code portion of an LLM reply using {@link CodeExtractor}.
     *
     * <p>Best effort: any exception is logged and {@code ""} is returned.
     */
    public String extractCode(String str) {
        try {
            return new CodeExtractor(str).getExtractedCode();
        } catch (Exception e) {
            this.config.getLog().error("In AbstractRunner.extractCode: " + e);
            return "";
        }
    }

    /**
     * Parses {@code str} and adds the Mockito / JUnit 5 wildcard and static-wildcard imports plus
     * every import in {@code list}, then pretty-prints the unit.
     *
     * @param str  Java source of a compilation unit
     * @param list import lines such as {@code "import a.b.C;"}; the {@code "import "} text and all
     *             {@code ';'} characters are stripped before the name is passed to JavaParser
     * @return the re-printed source
     */
    public static String repairImports(String str, List<String> list) {
        CompilationUnit unit = StaticJavaParser.parse(str);
        unit.addImport("org.mockito", false, true);
        unit.addImport("org.junit.jupiter.api", false, true);
        unit.addImport("org.mockito.Mockito", true, true);
        unit.addImport("org.junit.jupiter.api.Assertions", true, true);
        for (String importLine : list) {
            unit.addImport(importLine.replace("import ", "").replace(";", ""));
        }
        return unit.toString();
    }

    /** Parses {@code str}, sets its package declaration to {@code str2} and re-prints it. */
    public static String repairPackage(String str, String str2) {
        CompilationUnit unit = StaticJavaParser.parse(str);
        unit.setPackageDeclaration(str2);
        return unit.toString();
    }

    /**
     * Adds a per-test timeout of {@code i} milliseconds to generated test source.
     *
     * <ul>
     *   <li>JUnit 4 (source contains {@code import org.junit.Test}): adds a {@code timeout = i}
     *       attribute to {@code @Test(...)} and bare {@code @Test}, unless a
     *       {@code @Test(timeout =} is already present.</li>
     *   <li>JUnit 5 (source contains {@code import org.junit.jupiter.api.Test}): imports
     *       {@code Timeout} and adds {@code @Timeout(value = i, unit = MILLISECONDS)} after each bare
     *       {@code @Test}, unless the {@code Timeout} import is already present. The unit is given
     *       explicitly because JUnit 5's default unit is seconds.</li>
     *   <li>Otherwise: logs a warning and returns the source unchanged.</li>
     * </ul>
     * The line ending after a bare {@code @Test} (LF or CRLF) is preserved.
     */
    public String addTimeout(String str, int i) {
        if (str.contains("import org.junit.Test")) {
            if (str.contains("@Test(timeout =")) {
                return str;
            }
            return str.replace("@Test(", "@Test(timeout = " + i + ", ")
                    .replaceAll(BARE_TEST_ANNOTATION, "@Test(timeout = " + i + ")$1");
        }
        if (!str.contains("import org.junit.jupiter.api.Test")) {
            this.config.getLog().warn("Generated with unknown JUnit version, try without adding timeout.");
            return str;
        }
        if (str.contains("import org.junit.jupiter.api.Timeout;")) {
            return str;
        }
        String withImport = repairImports(str, List.of("import org.junit.jupiter.api.Timeout;"));
        // Fully qualified TimeUnit avoids having to add (or clash with) another import.
        return withImport.replaceAll(BARE_TEST_ANNOTATION,
                "@Test$1    @Timeout(value = " + i + ", unit = java.util.concurrent.TimeUnit.MILLISECONDS)$1");
    }

    /** Renames the first class/interface declared in {@code str} to {@code str2} and re-prints it. */
    public static String changeTestName(String str, String str2) {
        CompilationUnit unit = StaticJavaParser.parse(str);
        unit.findFirst(ClassOrInterfaceDeclaration.class).ifPresent(decl -> decl.setName(str2));
        return unit.toString();
    }

    /** Builds the {@code "package...\nimports\nsignature {\n"} prefix shared by every prompt context. */
    private static String classHeader(ClassInfo classInfo) {
        return classInfo.packageDeclaration + "\n" + joinLines(classInfo.imports) + "\n"
                + classInfo.classSignature + " {\n";
    }

    /**
     * Builds a prompt for {@code methodInfo} without dependency information.
     *
     * <p>The context is the class header, the fields (only if the method uses fields) and the focal
     * method's source. "Other methods" are the briefs and bodies of every other method of the class.
     */
    public static PromptInfo generatePromptInfoWithoutDep(Config config, ClassInfo classInfo, MethodInfo methodInfo) throws IOException {
        PromptInfo promptInfo = new PromptInfo(false, classInfo.fullClassName, methodInfo.methodName, methodInfo.methodSignature);
        promptInfo.setClassInfo(classInfo);
        promptInfo.setMethodInfo(methodInfo);

        String context = classHeader(classInfo);
        if (methodInfo.useField) {
            context += joinLines(classInfo.fields) + "\n";
        }
        promptInfo.setContext(context + methodInfo.sourceCode + "\n}");
        promptInfo.setOtherMethodBrief(filterAndJoinLines(classInfo.methodsBrief, methodInfo.brief));

        StringBuilder otherBodies = new StringBuilder();
        for (String signature : classInfo.methodSigs.keySet()) {
            if (!signature.equals(methodInfo.methodSignature)) {
                otherBodies.append(getBody(config, classInfo, signature));
            }
        }
        promptInfo.setOtherMethodBodies(otherBodies.toString());
        return promptInfo;
    }

    /**
     * Builds a prompt for {@code methodInfo} that includes dependency information.
     *
     * <ul>
     *   <li>Constructor deps of the class are added unless the same class already appears among
     *       the method's dependencies.</li>
     *   <li>Dependent methods in the same class contribute their brief and body to the
     *       "other methods" sections.</li>
     *   <li>Dependent methods in other classes are added as method deps and expanded up to
     *       {@code config.getDependencyDepth()}.</li>
     *   <li>Constructor briefs/bodies are included when the class has a constructor.</li>
     *   <li>Fields and getter/setter briefs/bodies are included when the method uses fields.</li>
     * </ul>
     */
    public static PromptInfo generatePromptInfoWithDep(Config config, ClassInfo classInfo, MethodInfo methodInfo) throws IOException {
        PromptInfo promptInfo = new PromptInfo(true, classInfo.fullClassName, methodInfo.methodName, methodInfo.methodSignature);
        promptInfo.setClassInfo(classInfo);
        promptInfo.setMethodInfo(methodInfo);

        for (Map.Entry<String, Set<String>> entry : classInfo.constructorDeps.entrySet()) {
            String depClass = entry.getKey();
            if (!methodInfo.dependentMethods.containsKey(depClass)) {
                promptInfo.addConstructorDeps(depClass, getDepInfo(config, depClass, entry.getValue()));
            }
        }

        List<String> sameClassBriefs = new ArrayList<>();
        List<String> sameClassBodies = new ArrayList<>();
        for (Map.Entry<String, Set<String>> entry : methodInfo.dependentMethods.entrySet()) {
            String depClass = entry.getKey();
            Set<String> depMethods = entry.getValue();
            if (depClass.equals(classInfo.getClassName())) {
                for (String signature : depMethods) {
                    MethodInfo dependency = getMethodInfo(config, classInfo, signature);
                    if (dependency != null) {
                        sameClassBriefs.add(dependency.brief);
                        sameClassBodies.add(dependency.sourceCode);
                    }
                }
            } else {
                promptInfo.addMethodDeps(depClass, getDepInfo(config, depClass, depMethods));
                addMethodDepsByDepth(config, depClass, depMethods, promptInfo, config.getDependencyDepth());
            }
        }

        String context = classHeader(classInfo);
        StringBuilder otherBriefs = new StringBuilder();
        StringBuilder otherBodies = new StringBuilder();
        if (classInfo.hasConstructor) {
            otherBriefs.append(joinLines(classInfo.constructorBrief)).append('\n');
            otherBodies.append(getBodies(config, classInfo, classInfo.constructorSigs)).append('\n');
        }
        if (methodInfo.useField) {
            context += joinLines(classInfo.fields) + "\n";
            otherBriefs.append(joinLines(classInfo.getterSetterBrief)).append('\n');
            otherBodies.append(getBodies(config, classInfo, classInfo.getterSetterSigs)).append('\n');
        }
        otherBriefs.append(joinLines(sameClassBriefs)).append('\n');
        otherBodies.append(joinLines(sameClassBodies)).append('\n');

        promptInfo.setContext(context + methodInfo.sourceCode + "\n}");
        promptInfo.setOtherMethodBrief(otherBriefs.toString());
        promptInfo.setOtherMethodBodies(otherBodies.toString());
        return promptInfo;
    }

    /**
     * Recursively adds the transitive method dependencies of {@code set} (methods of class
     * {@code str}) to {@code promptInfo}, together with the constructor deps of {@code str}.
     *
     * @param i remaining depth; nothing is added when {@code i <= 1}
     */
    public static void addMethodDepsByDepth(Config config, String str, Set<String> set, PromptInfo promptInfo, int i) throws IOException {
        if (i <= 1 || set.isEmpty()) {
            return;
        }
        // Loop-invariant: the class and its constructor deps are the same for every method in `set`.
        ClassInfo classInfo = getClassInfo(config, str);
        if (classInfo == null) {
            return;
        }
        addConstructorDepsByDepth(config, classInfo, promptInfo);
        for (String signature : set) {
            MethodInfo methodInfo = getMethodInfo(config, classInfo, signature);
            if (methodInfo == null) {
                continue;
            }
            for (Map.Entry<String, Set<String>> entry : methodInfo.dependentMethods.entrySet()) {
                String depClass = entry.getKey();
                Set<String> depMethods = entry.getValue();
                promptInfo.addMethodDeps(depClass, getDepInfo(config, depClass, depMethods));
                addMethodDepsByDepth(config, depClass, depMethods, promptInfo, i - 1);
            }
        }
    }

    /** Adds every constructor dependency of {@code classInfo} to {@code promptInfo}. */
    public static void addConstructorDepsByDepth(Config config, ClassInfo classInfo, PromptInfo promptInfo) throws IOException {
        for (Map.Entry<String, Set<String>> entry : classInfo.constructorDeps.entrySet()) {
            String depClass = entry.getKey();
            promptInfo.addConstructorDeps(depClass, getDepInfo(config, depClass, entry.getValue()));
        }
    }

    /**
     * Loads the parsed {@code class.json} for class {@code str} from the parse output directory.
     *
     * @return the class info, or {@code null} if no such file exists
     */
    public static ClassInfo getClassInfo(Config config, String str) throws IOException {
        Path classDir = config.getParseOutput().resolve(Task.getFullClassName(config, str).replace(".", File.separator));
        Path classJson = classDir.resolve("class.json");
        if (!classJson.toFile().exists()) {
            return null;
        }
        return GSON.fromJson(Files.readString(classJson, StandardCharsets.UTF_8), ClassInfo.class);
    }

    /**
     * Loads the parsed JSON for the method with signature {@code str} of {@code classInfo}.
     *
     * @return the method info, or {@code null} if no such file exists
     */
    public static MethodInfo getMethodInfo(Config config, ClassInfo classInfo, String str) throws IOException {
        String packagePath = classInfo.packageDeclaration.replace("package ", "").replace(".", File.separator).replace(";", "");
        Path methodJson = config.getParseOutput()
                .resolve(packagePath)
                .resolve(classInfo.className)
                .resolve(ClassParser.getFilePathBySig(str, classInfo));
        if (!methodJson.toFile().exists()) {
            return null;
        }
        return GSON.fromJson(Files.readString(methodJson, StandardCharsets.UTF_8), MethodInfo.class);
    }

    /**
     * Renders a condensed view of dependency class {@code str}: the header, the fields, the
     * constructor sources (if any), the getter/setter briefs and the sources of the methods in
     * {@code set}, closed with {@code "}"}. Methods whose parse output is missing are skipped.
     *
     * @return the rendered text, or {@code null} if the class has no parse output
     */
    public static String getDepInfo(Config config, String str, Set<String> set) throws IOException {
        ClassInfo classInfo = getClassInfo(config, str);
        if (classInfo == null) {
            return null;
        }
        StringBuilder info = new StringBuilder(classHeader(classInfo))
                .append(joinLines(classInfo.fields)).append('\n');
        if (classInfo.hasConstructor) {
            for (String constructorSig : classInfo.constructorSigs) {
                MethodInfo constructor = getMethodInfo(config, classInfo, constructorSig);
                if (constructor != null) {
                    info.append(constructor.getSourceCode()).append('\n');
                }
            }
            info.append('\n');
        }
        info.append(joinLines(classInfo.getterSetterBrief)).append('\n');
        for (String signature : set) {
            MethodInfo method = getMethodInfo(config, classInfo, signature);
            if (method != null) {
                info.append(method.getSourceCode()).append('\n');
            }
        }
        return info.append('}').toString();
    }

    /** Concatenates {@link #getBody} for every signature in {@code list}, each followed by {@code "\n"}. */
    public static String getBodies(Config config, ClassInfo classInfo, List<String> list) throws IOException {
        StringBuilder bodies = new StringBuilder();
        for (String signature : list) {
            bodies.append(getBody(config, classInfo, signature)).append('\n');
        }
        return bodies.toString();
    }

    /**
     * Returns the source code of method {@code str} of {@code classInfo}.
     *
     * @return the source, or {@code ""} (with a warning) when the method's parse output is missing.
     *         Missing entries are skipped the same way as in {@link #getDepInfo}.
     */
    public static String getBody(Config config, ClassInfo classInfo, String str) throws IOException {
        MethodInfo methodInfo = getMethodInfo(config, classInfo, str);
        if (methodInfo == null) {
            config.getLog().warn("In AbstractRunner.getBody: no parsed method info for " + str);
            return "";
        }
        return methodInfo.sourceCode;
    }

    /**
     * Writes {@code promptInfo.getRecords()} as JSON to
     * {@code <history>/class<index>/method<n>/attempt<i>/records.json}. The method and attempt
     * mapping files are written along the way if they are missing.
     *
     * @throws RuntimeException wrapping the {@link IOException} if writing fails
     */
    public void exportRecord(PromptInfo promptInfo, ClassInfo classInfo, int i) {
        String methodIndex = classInfo.methodSigs.get(promptInfo.methodSignature);
        Path classDir = this.config.getHistoryPath().resolve("class" + classInfo.index);
        exportMethodMapping(classInfo, classDir);
        Path methodDir = classDir.resolve("method" + methodIndex);
        exportAttemptMapping(promptInfo, methodDir);
        Path attemptDir = methodDir.resolve("attempt" + i);
        if (!attemptDir.toFile().exists()) {
            attemptDir.toFile().mkdirs();
        }
        try {
            writeUtf8(attemptDir.resolve("records.json").toFile(), GSON.toJson(promptInfo.getRecords()));
        } catch (IOException e) {
            throw new RuntimeException("In AbstractRunner.exportRecord: " + e, e);
        }
    }

    /**
     * Copies {@code classMapping.json} from {@code config.tmpOutput} into {@code path}, unless it is
     * already there. Synchronized on the class so concurrent callers do not race on the copy.
     */
    public static synchronized void exportClassMapping(Config config, Path path) {
        if (!path.toFile().exists()) {
            path.toFile().mkdirs();
        }
        Path target = path.resolve("classMapping.json");
        if (target.toFile().exists()) {
            return;
        }
        try {
            Files.copy(config.tmpOutput.resolve("classMapping.json"), target);
        } catch (IOException e) {
            throw new RuntimeException("In AbstractRunner.exportClassMapping: " + e, e);
        }
    }

    /**
     * Writes {@code methodMapping.json} into {@code path}, unless it already exists. The file maps
     * {@code "method<index>"} to the method's name, signature, class name and package declaration.
     */
    public void exportMethodMapping(ClassInfo classInfo, Path path) {
        if (!path.toFile().exists()) {
            path.toFile().mkdirs();
        }
        File target = path.resolve("methodMapping.json").toFile();
        if (target.exists()) {
            return;
        }
        Map<String, Map<String, String>> mapping = new TreeMap<>();
        classInfo.methodSigs.forEach((signature, index) -> {
            Map<String, String> entry = new LinkedHashMap<>();
            entry.put("methodName", signature.split("\\(")[0]);
            entry.put("signature", signature);
            entry.put("className", classInfo.className);
            entry.put("packageDeclaration", classInfo.packageDeclaration);
            mapping.put("method" + index, entry);
        });
        try {
            writeUtf8(target, GSON.toJson(mapping));
        } catch (IOException e) {
            throw new RuntimeException("In AbstractRunner.exportMethodMapping: " + e, e);
        }
    }

    /**
     * Writes {@code attemptMapping.json} into {@code path}, unless it already exists. It has one
     * {@code "attempt<k>"} entry for each k in {@code 0 .. config.getTestNumber()-1}, each describing
     * the test class name that attempt k uses.
     *
     * <p>The attempt number is the token between the last separator and the trailing
     * {@code "_Test"} of {@code promptInfo.getFullTestName()}. It may have any number of digits.
     *
     * @throws IllegalStateException if the full test name does not contain {@code "_Test"}
     * @throws RuntimeException      wrapping the {@link IOException} if writing fails
     */
    public void exportAttemptMapping(PromptInfo promptInfo, Path path) {
        if (!path.toFile().exists()) {
            path.toFile().mkdirs();
        }
        File target = path.resolve("attemptMapping.json").toFile();
        if (target.exists()) {
            return;
        }
        String fullTestName = promptInfo.getFullTestName();
        String testSuffix = separator + "Test";
        int suffixStart = fullTestName.lastIndexOf(testSuffix);
        if (suffixStart < 0) {
            throw new IllegalStateException("In AbstractRunner.exportAttemptMapping: unexpected test name " + fullTestName);
        }
        // Keep everything up to (and including) the separator before the attempt number.
        String prefix = fullTestName.substring(0, fullTestName.lastIndexOf(separator, suffixStart - 1) + 1);
        String testPath = promptInfo.getTestPath().toString();

        Map<String, Map<String, String>> mapping = new TreeMap<>();
        for (int attempt = 0; attempt < this.config.getTestNumber(); attempt++) {
            String name = prefix + attempt + testSuffix;
            Map<String, String> entry = new LinkedHashMap<>();
            entry.put("testClassName", name.substring(name.lastIndexOf(".") + 1));
            entry.put("fullName", name);
            entry.put("path", testPath);
            entry.put("className", promptInfo.className);
            entry.put("packageDeclaration", promptInfo.classInfo.packageDeclaration);
            entry.put("methodName", promptInfo.methodName);
            entry.put("methodSig", promptInfo.methodSignature);
            mapping.put("attempt" + attempt, entry);
        }
        try {
            writeUtf8(target, GSON.toJson(mapping));
        } catch (IOException e) {
            throw new RuntimeException("In AbstractRunner.exportAttemptMapping: " + e, e);
        }
    }

    /** Returns true when the summed token count of all message contents exceeds {@code config.maxPromptTokens}. */
    public static boolean isExceedMaxTokens(Config config, List<Message> list) {
        int total = 0;
        for (Message message : list) {
            total += TokenCounter.countToken(message.getContent());
        }
        return total > config.maxPromptTokens;
    }

    /** Returns true when the token count of {@code str} exceeds {@code config.maxPromptTokens}. */
    public static boolean isExceedMaxTokens(Config config, String str) {
        return TokenCounter.countToken(str) > config.maxPromptTokens;
    }
}
