package zju.cst.aces.prompt;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import zju.cst.aces.api.config.Config;
import zju.cst.aces.dto.Message;
import zju.cst.aces.dto.PromptInfo;
import zju.cst.aces.util.TokenCounter;

/* loaded from: input_file:zju/cst/aces/prompt/PromptGenerator.class */
public class PromptGenerator {
    /** Active configuration; supplies the template location, token limits and the logger. */
    public Config config;
    /** Template renderer derived from {@link #config}; rebuilt whenever the config is replaced. */
    public PromptTemplate promptTemplate;

    /**
     * Creates a generator bound to {@code config}.
     *
     * @param config configuration providing properties, prompt path and token limits
     * @throws IOException not raised by the current body (the same construction in
     *     {@link #setConfig(Config)} declares no checked exception); kept so existing
     *     callers that catch it still compile
     */
    public PromptGenerator(Config config) throws IOException {
        this.config = config;
        this.promptTemplate = newTemplate(config);
    }

    /**
     * Replaces the configuration and rebuilds {@link #promptTemplate} from it.
     *
     * @param config the new configuration
     */
    public void setConfig(Config config) {
        this.config = config;
        this.promptTemplate = newTemplate(config);
    }

    /** Builds a {@link PromptTemplate} from the settings carried by {@code config}. */
    private static PromptTemplate newTemplate(Config config) {
        return new PromptTemplate(config, config.properties, config.getPromptPath(), config.getMaxPromptTokens());
    }

    /**
     * Builds the chat messages for one generation round.
     *
     * <p>When {@code promptInfo} carries no error message, a system message and a user message
     * are produced from the init template. Otherwise only a user message is produced, rendered
     * from the repair template.
     *
     * @param promptInfo the prompt data for the method under test
     * @return the messages, in the order they should be sent
     */
    public List<Message> generateMessages(PromptInfo promptInfo) {
        List<Message> messages = new ArrayList<>();
        if (promptInfo.errorMsg == null) {
            messages.add(Message.ofSystem(createSystemPrompt(promptInfo, this.promptTemplate.TEMPLATE_INIT)));
            messages.add(Message.of(createUserPrompt(promptInfo, this.promptTemplate.TEMPLATE_INIT)));
        } else {
            messages.add(Message.of(createUserPrompt(promptInfo, this.promptTemplate.TEMPLATE_REPAIR)));
        }
        return messages;
    }

    /**
     * Builds a system message and a user message, both derived from {@code str}.
     *
     * @param promptInfo the prompt data for the method under test
     * @param str template name; the system prompt uses its {@code _system} variant
     * @return the system message followed by the user message
     */
    public List<Message> generateMessages(PromptInfo promptInfo, String str) {
        List<Message> messages = new ArrayList<>();
        messages.add(Message.ofSystem(createSystemPrompt(promptInfo, str)));
        messages.add(Message.of(createUserPrompt(promptInfo, str)));
        return messages;
    }

    /**
     * Renders the user prompt for {@code promptInfo} using template {@code str}.
     *
     * <p>For the repair template, the compiler/test error lines are trimmed to fit the token
     * budget left after the unit test, signature, class name, context and other-method briefs
     * are accounted for. That budget is never below {@code config.getMinErrorTokens()}. Lines
     * that would overflow the budget are skipped, but later (shorter) lines may still be kept.
     *
     * @param promptInfo the prompt data; its error message is read only for the repair template
     * @param str the template name to render
     * @return the rendered prompt text
     * @throws RuntimeException wrapping any failure while building the data model or rendering
     */
    public String createUserPrompt(PromptInfo promptInfo, String str) {
        try {
            this.promptTemplate.buildDataModel(this.config, promptInfo);
            if (!str.equals(this.promptTemplate.TEMPLATE_REPAIR)) {
                return this.promptTemplate.renderTemplate(str);
            }
            int allowedTokens = errorMessageTokenBudget(promptInfo);
            String processedErrors = "";
            for (String line : promptInfo.getErrorMsg().getErrorMessage()) {
                String candidate = processedErrors + line + "\n";
                // Count the whole candidate rather than summing per-line counts: tokenization
                // of a concatenation is not necessarily additive.
                if (TokenCounter.countToken(candidate) <= allowedTokens) {
                    processedErrors = candidate;
                }
            }
            this.config.getLog().debug("Allowed tokens: " + allowedTokens);
            this.config.getLog().debug("Processed error message: \n" + processedErrors);
            this.promptTemplate.dataModel.put("unit_test", promptInfo.getUnitTest());
            this.promptTemplate.dataModel.put("error_message", processedErrors);
            return this.promptTemplate.renderTemplate(this.promptTemplate.TEMPLATE_REPAIR);
        } catch (Exception e) {
            throw new RuntimeException("An error occurred while generating the user prompt: " + e, e);
        }
    }

    /**
     * Computes how many tokens the error message may occupy in a repair prompt.
     *
     * <p>The result is the configured prompt maximum minus the tokens of the other prompt parts,
     * floored at {@code config.getMinErrorTokens()}.
     */
    private int errorMessageTokenBudget(PromptInfo promptInfo) {
        int maxPromptTokens = this.config.getMaxPromptTokens();
        int otherPartsTokens = TokenCounter.countToken(promptInfo.getUnitTest())
                + TokenCounter.countToken(promptInfo.getMethodSignature())
                + TokenCounter.countToken(promptInfo.getClassName())
                + TokenCounter.countToken(promptInfo.getContext())
                + TokenCounter.countToken(promptInfo.getOtherMethodBrief());
        return Math.max(maxPromptTokens - otherPartsTokens, this.config.getMinErrorTokens());
    }

    /**
     * Renders the system-prompt variant of template {@code str} (see {@link #addSystemFileName}).
     *
     * <p>NOTE(review): {@code promptInfo} is not used here. The template is rendered against
     * whatever data model {@link #promptTemplate} currently holds, and the callers in this class
     * invoke this before {@link #createUserPrompt} rebuilds that model. Confirm that system
     * templates do not depend on per-prompt data.
     *
     * @param promptInfo currently unused
     * @param str the base template name
     * @return the rendered system prompt, or {@code ""} if rendering fails with an
     *     {@link IOException} (presumably when no system variant of the template exists; confirm)
     * @throws RuntimeException wrapping any other rendering failure
     */
    public String createSystemPrompt(PromptInfo promptInfo, String str) {
        try {
            return this.promptTemplate.renderTemplate(addSystemFileName(str));
        } catch (Exception e) {
            if (e instanceof IOException) {
                return "";
            }
            throw new RuntimeException("An error occurred while generating the system prompt: " + e, e);
        }
    }

    /**
     * Derives the system-template file name by inserting {@code _system} before the extension.
     *
     * <p>Only the last dot of the final path segment counts as the extension separator. So
     * {@code init.ftl -> init_system.ftl}, {@code a.b.ftl -> a.b_system.ftl} and
     * {@code ./prompts/init.ftl -> ./prompts/init_system.ftl}. Names without an extension, or
     * ending in a dot, are returned unchanged.
     *
     * @param str template file name or path
     * @return the system-template name
     */
    public String addSystemFileName(String str) {
        int dot = str.lastIndexOf('.');
        int lastSeparator = Math.max(str.lastIndexOf('/'), str.lastIndexOf('\\'));
        if (dot < 0 || dot < lastSeparator || dot == str.length() - 1) {
            return str;
        }
        return str.substring(0, dot) + "_system" + str.substring(dot);
    }

    /**
     * Hook for chain-of-thought prompt construction; this implementation returns {@code ""}.
     *
     * @param cot the chain-of-thought structure (ignored here)
     * @return an empty string
     */
    public String buildCOT(COT<?> cot) {
        return "";
    }

    /**
     * Hook for tree-of-thought prompt construction; this implementation returns {@code ""}.
     *
     * @param tot the tree-of-thought structure (ignored here)
     * @return an empty string
     */
    public String buildTOT(TOT<?> tot) {
        return "";
    }
}
