package com.huaweicloud.pangu.dev.sdk.agent;

import com.alibaba.fastjson.JSON;
import com.huaweicloud.pangu.dev.sdk.api.agent.Agent;
import com.huaweicloud.pangu.dev.sdk.api.agent.AgentListener;
import com.huaweicloud.pangu.dev.sdk.api.agent.AgentSessionHelper;
import com.huaweicloud.pangu.dev.sdk.api.callback.StreamCallBack;
import com.huaweicloud.pangu.dev.sdk.api.callback.StreamResult;
import com.huaweicloud.pangu.dev.sdk.api.llms.LLM;
import com.huaweicloud.pangu.dev.sdk.api.llms.config.LLMModuleConfig;
import com.huaweicloud.pangu.dev.sdk.api.llms.request.ConversationMessage;
import com.huaweicloud.pangu.dev.sdk.api.llms.request.Role;
import com.huaweicloud.pangu.dev.sdk.api.llms.response.LLMResp;
import com.huaweicloud.pangu.dev.sdk.api.retriever.ToolRetriever;
import com.huaweicloud.pangu.dev.sdk.api.tool.Tool;
import com.huaweicloud.pangu.dev.sdk.exception.PanguDevSDKException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.similarity.LevenshteinDistance;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link Agent} implementation holding the state shared by concrete agents: the backing
 * {@link LLM}, the registered tools, the iteration limit, an optional {@link AgentListener} and an
 * optional {@link ToolRetriever}.
 *
 * <p>Subclasses implement {@link #react(AgentSession)}. The helpers {@link #toolExecute},
 * {@link #getTool}, {@link #needInterrupt}, {@link #noticeSessionEnd} and
 * {@link #getSystemPrompt} are available to those implementations.
 *
 * <p>Instances hold mutable state and perform no synchronization.
 */
public abstract class AbstractAgent implements Agent {
    private static final Logger log = LoggerFactory.getLogger(AbstractAgent.class);

    /** The model driving the agent; its module config also supplies tool tags and the system prompt. */
    protected final LLM llm;

    /** Registered tools keyed by tool id, kept in insertion order. */
    protected final LinkedHashMap<String, Tool> toolMap = new LinkedHashMap<>();

    /** Number of recorded actions at which {@link #needInterrupt} stops the agent; defaults to 15. */
    protected int maxIterations = 15;

    /** Optional listener notified of session lifecycle events; may be {@code null}. */
    protected AgentListener agentListener;

    /** Optional retriever; when set, its search result replaces {@link #toolMap} before each run. */
    protected ToolRetriever toolRetriever;

    /**
     * Stream callback that splits the LLM token stream into plain-text output and tool-call output.
     *
     * <p>Text outside the configured tool tags (the module property's unify tool tag prefix and
     * suffix) is forwarded to the text callback. Text between the tags is forwarded to the tool
     * callback. The tool callback also receives {@code onStart} when the prefix is seen, and
     * {@code onEnd} with the accumulated tool text when the suffix is seen. Tags are only
     * recognised when a whole tag is contained in a single token.
     */
    protected static class AgentStreamCallBack implements StreamCallBack {
        private final StreamCallBack textStreamCallBack;
        private final StreamCallBack toolStreamCallBack;
        private final LLMModuleConfig llmModuleConfig;

        /** Accumulates the text of the tool segment currently being streamed. */
        private StringBuffer toolStringBuffer = new StringBuffer();

        /** True after a tool prefix has been seen and before its matching suffix arrives. */
        private boolean inToolStream = false;

        /**
         * @param streamCallBack callback receiving plain-text tokens
         * @param streamCallBack2 callback receiving tool-call tokens
         * @param lLMModuleConfig config supplying the tool tag prefix and suffix
         */
        public AgentStreamCallBack(StreamCallBack streamCallBack, StreamCallBack streamCallBack2, LLMModuleConfig lLMModuleConfig) {
            this.textStreamCallBack = streamCallBack;
            this.toolStreamCallBack = streamCallBack2;
            this.llmModuleConfig = lLMModuleConfig;
        }

        /** Forwards stream start to the text callback only; the tool stream starts on a prefix. */
        @Override
        public void onStart(String str) {
            this.textStreamCallBack.onStart(str);
        }

        /** Forwards stream end to the text callback. */
        @Override
        public void onEnd(String str, StreamResult streamResult, LLMResp lLMResp) {
            this.textStreamCallBack.onEnd(str, streamResult, lLMResp);
        }

        /** Forwards errors to both the text and the tool callbacks. */
        @Override
        public void onError(String str, StreamResult streamResult) {
            this.textStreamCallBack.onError(str, streamResult);
            this.toolStreamCallBack.onError(str, streamResult);
        }

        /**
         * Routes a token to the text or tool callback depending on the tool tags it contains and
         * on whether a tool segment is currently open.
         */
        @Override
        public void onNewToken(String str, LLMResp lLMResp) {
            String answer = lLMResp.getAnswer();
            String unifyToolTagPrefix = this.llmModuleConfig.getLllModuleProperty().getUnifyToolTagPrefix();
            String unifyToolTagSuffix = this.llmModuleConfig.getLllModuleProperty().getUnifyToolTagSuffix();
            if (StringUtils.contains(answer, unifyToolTagPrefix)) {
                // Text before the prefix still belongs to the plain-text stream.
                this.textStreamCallBack.onNewToken(str, textResp(StringUtils.substringBefore(answer, unifyToolTagPrefix)));
                this.toolStreamCallBack.onStart(str);
                this.toolStringBuffer = new StringBuffer();
                this.inToolStream = true;
                String afterPrefix = StringUtils.substringAfter(answer, unifyToolTagPrefix);
                if (StringUtils.contains(afterPrefix, unifyToolTagSuffix)) {
                    // The whole tool call arrived in one token. Close it now instead of leaking the
                    // suffix and trailing text into the tool stream and never ending it.
                    finishToolStream(str, afterPrefix, unifyToolTagSuffix);
                } else {
                    this.toolStreamCallBack.onNewToken(str, textResp(afterPrefix));
                    this.toolStringBuffer.append(afterPrefix);
                }
                return;
            }
            if (!this.inToolStream) {
                // Outside a tool segment everything is plain text. A stray suffix must not end a
                // tool stream that was never started (which would replay a stale buffer).
                this.textStreamCallBack.onNewToken(str, lLMResp);
                return;
            }
            if (StringUtils.contains(answer, unifyToolTagSuffix)) {
                finishToolStream(str, answer, unifyToolTagSuffix);
                return;
            }
            this.toolStreamCallBack.onNewToken(str, lLMResp);
            this.toolStringBuffer.append(answer);
        }

        /**
         * Emits the tool text preceding {@code suffix}, ends the tool stream with the accumulated
         * tool text, and forwards whatever follows the suffix to the text callback.
         */
        private void finishToolStream(String str, String answer, String suffix) {
            String beforeSuffix = StringUtils.substringBefore(answer, suffix);
            this.toolStreamCallBack.onNewToken(str, textResp(beforeSuffix));
            this.toolStringBuffer.append(beforeSuffix);
            this.toolStreamCallBack.onEnd(str, new StreamResult(), textResp(this.toolStringBuffer.toString()));
            this.inToolStream = false;
            this.textStreamCallBack.onNewToken(str, textResp(StringUtils.substringAfter(answer, suffix)));
        }

        /** Builds a response carrying only the given answer text. */
        private static LLMResp textResp(String answer) {
            return LLMResp.builder().answer(answer).build();
        }
    }

    /**
     * @param llm the model used by this agent
     */
    public AbstractAgent(LLM llm) {
        this.llm = llm;
    }

    /** Registers (or replaces) a tool under its tool id. */
    @Override
    public void addTool(Tool tool) {
        this.toolMap.put(tool.getToolId(), tool);
    }

    /** Removes the tool registered under the given id, if any. */
    @Override
    public void removeTool(String str) {
        this.toolMap.remove(str);
    }

    /** Removes all registered tools. */
    @Override
    public void clearTool() {
        this.toolMap.clear();
    }

    /**
     * Sets the iteration limit used by {@link #needInterrupt}.
     *
     * @throws PanguDevSDKException if {@code i} is not positive
     */
    @Override
    public void setMaxIterations(int i) {
        if (i <= 0) {
            throw new PanguDevSDKException("iterations value not legal.");
        }
        this.maxIterations = i;
    }

    /**
     * Performs the agent's reasoning on the session. Implementations are provided by subclasses.
     */
    protected abstract void react(AgentSession agentSession);

    /** Runs the agent on a single user message. */
    @Override
    public AgentSession run(String str) {
        List<ConversationMessage> messages = new ArrayList<>();
        messages.add(ConversationMessage.builder().role(Role.USER).content(str).build());
        return run(messages);
    }

    /** Runs the agent on a conversation, wrapping it in a new session. */
    @Override
    public AgentSession run(List<ConversationMessage> list) {
        return run(AgentSessionHelper.initAgentSession(list));
    }

    /**
     * Runs the agent on the session until {@link #react} returns.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Refreshes tools from the retriever, if one is set.</li>
     *   <li>Disables system-message appending on the LLM module config.</li>
     *   <li>Marks the session RUNNING and notifies the listener.</li>
     *   <li>Appends an empty assistant message when the last message is not from the assistant.</li>
     *   <li>Calls {@link #react}.</li>
     *   <li>Calls {@code AgentSessionHelper.updateAssistantMessage}, and logs the plan if any
     *       actions were recorded.</li>
     * </ol>
     *
     * @throws PanguDevSDKException if the session has no messages
     */
    @Override
    public AgentSession run(AgentSession agentSession) {
        List<ConversationMessage> messages = agentSession.getMessages();
        // Fail fast with a clear error, before any listener is told the session started.
        if (messages == null || messages.isEmpty()) {
            throw new PanguDevSDKException("agent session messages can not be empty.");
        }
        initToolsFromRetriever(messages);
        this.llm.getLLMConfig().getLlmModuleConfig().setEnableAppendSystemMessage(false);
        noticeSessionStart(agentSession);
        agentSession.setByStep(false);
        if (messages.get(messages.size() - 1).getRole() != Role.ASSISTANT) {
            agentSession.setCurrentMessage(ConversationMessage.builder().role(Role.ASSISTANT).build());
            messages.add(agentSession.getCurrentMessage());
        }
        react(agentSession);
        AgentSessionHelper.updateAssistantMessage(agentSession, false);
        if (!agentSession.getCurrentMessage().getActions().isEmpty()) {
            log.info(AgentSessionHelper.printPlan(agentSession));
        }
        return agentSession;
    }

    /**
     * Performs a single {@link #react} step in by-step mode.
     *
     * <p>Unless the session is FINISHED afterwards, the current action is appended to the current
     * message and the listener is notified of the iteration.
     */
    @Override
    public AgentSession runStep(AgentSession agentSession) {
        initToolsFromRetriever(agentSession.getMessages());
        this.llm.getLLMConfig().getLlmModuleConfig().setEnableAppendSystemMessage(false);
        agentSession.setByStep(true);
        react(agentSession);
        if (agentSession.getAgentSessionStatus() != AgentSessionStatus.FINISHED) {
            noticeSessionIteration(agentSession, agentSession.getCurrentAction());
        }
        return agentSession;
    }

    /** Sets the (single) listener; replaces any previously set listener. */
    @Override
    public void addListener(AgentListener agentListener) {
        this.agentListener = agentListener;
    }

    /** Sets the retriever used to select tools at the start of each run/step. */
    @Override
    public void setToolRetriever(ToolRetriever toolRetriever) {
        this.toolRetriever = toolRetriever;
    }

    /**
     * Installs a stream callback on the LLM that routes plain text to {@code streamCallBack} and
     * tool-call text to {@code streamCallBack2}.
     */
    @Override
    public void setStreamCallback(StreamCallBack streamCallBack, StreamCallBack streamCallBack2) {
        this.llm.setStreamCallback(new AgentStreamCallBack(streamCallBack, streamCallBack2, this.llm.getLLMConfig().getLlmModuleConfig()));
    }

    /** Marks the session RUNNING and notifies the listener, if any. */
    private void noticeSessionStart(AgentSession agentSession) {
        agentSession.setAgentSessionStatus(AgentSessionStatus.RUNNING);
        if (this.agentListener != null) {
            this.agentListener.onSessionStart(agentSession);
        }
    }

    /** Records the action on the current message, marks the session RUNNING and notifies the listener. */
    private void noticeSessionIteration(AgentSession agentSession, AgentAction agentAction) {
        agentSession.getCurrentMessage().getActions().add(agentAction);
        agentSession.setAgentSessionStatus(AgentSessionStatus.RUNNING);
        if (this.agentListener != null) {
            this.agentListener.onSessionIteration(agentSession);
        }
    }

    /** Records the final action, marks the session FINISHED and notifies the listener, if any. */
    public void noticeSessionEnd(AgentSession agentSession, AgentAction agentAction) {
        agentSession.getCurrentMessage().getActions().add(agentAction);
        agentSession.setAgentSessionStatus(AgentSessionStatus.FINISHED);
        if (this.agentListener != null) {
            this.agentListener.onSessionEnd(agentSession);
        }
    }

    /**
     * Decides whether the agent loop should stop.
     *
     * @return {@code true} when the number of recorded actions reached {@link #maxIterations}, or
     *     when the listener requests an interruption (the session is then marked INTERRUPTED);
     *     {@code false} when no action has been recorded yet or neither condition holds
     */
    public boolean needInterrupt(AgentSession agentSession) {
        if (agentSession.getCurrentMessage().getActions().isEmpty()) {
            return false;
        }
        if (agentSession.getCurrentMessage().getActions().size() >= this.maxIterations) {
            log.warn("agent stopped due to iteration limit. maxIterations is {}", this.maxIterations);
            return true;
        }
        if (this.agentListener == null || !this.agentListener.onCheckInterruptRequirement(agentSession)) {
            return false;
        }
        agentSession.setAgentSessionStatus(AgentSessionStatus.INTERRUPTED);
        log.info("agent stopped due to manual interruption");
        return true;
    }

    /**
     * Runs {@code tool} with the JSON input and continues the loop.
     *
     * <p>Outside by-step mode, the result is stored as the current action's observation:
     * {@code toString()} for strings and numbers, JSON otherwise. The iteration is then recorded
     * and {@link #react} is invoked again. In by-step mode this method returns without running
     * the tool.
     *
     * @throws PanguDevSDKException if the tool invocation or result serialization fails
     */
    public void toolExecute(Tool tool, String str, AgentSession agentSession) {
        if (agentSession.isByStep()) {
            return;
        }
        AgentAction currentAction;
        // Only the tool call itself is guarded. Failures from the recursive react() below must
        // propagate unchanged, instead of being re-wrapped and attributed to this tool.
        try {
            Object runFromJson = tool.runFromJson(str);
            currentAction = agentSession.getCurrentAction();
            if ((runFromJson instanceof String) || (runFromJson instanceof Number)) {
                currentAction.setObservation(runFromJson.toString());
            } else {
                currentAction.setObservation(JSON.toJSONString(runFromJson));
            }
        } catch (Exception e) {
            log.error("tool execute failed, tool={}, input={}", tool, str);
            throw new PanguDevSDKException("tool execute failed", e);
        }
        noticeSessionIteration(agentSession, currentAction);
        react(agentSession);
    }

    /**
     * Looks up a tool by id. If there is no exact match, returns the registered tool whose id has
     * the smallest Levenshtein distance to {@code str}.
     *
     * @throws PanguDevSDKException if no tools are registered
     */
    public Tool getTool(String str) {
        if (this.toolMap.isEmpty()) {
            log.error("there is no tool in agent");
            throw new PanguDevSDKException("there is no tool in agent");
        }
        Tool tool = this.toolMap.get(str);
        if (tool != null) {
            return tool;
        }
        LevenshteinDistance levenshteinDistance = new LevenshteinDistance();
        String orElse = this.toolMap.keySet().stream()
                .min(Comparator.comparingInt(toolId -> levenshteinDistance.apply(toolId, str)))
                .orElse("");
        log.warn("can not find tool for {} in {}, the most similar tool is {}", str, this.toolMap.keySet(), orElse);
        return this.toolMap.get(orElse);
    }

    /**
     * Returns the system prompt from the LLM module config when non-empty. Otherwise returns the
     * content of the last SYSTEM message in the session, or {@code null} if there is none.
     */
    public String getSystemPrompt(AgentSession agentSession) {
        String systemPrompt = this.llm.getLLMConfig().getLlmModuleConfig().getSystemPrompt();
        if (StringUtils.isNotEmpty(systemPrompt)) {
            return systemPrompt;
        }
        List<ConversationMessage> messages = agentSession.getMessages();
        for (int size = messages.size() - 1; size >= 0; size--) {
            ConversationMessage conversationMessage = messages.get(size);
            if (conversationMessage.getRole() == Role.SYSTEM) {
                return conversationMessage.getContent();
            }
        }
        return null;
    }

    /**
     * Replaces the registered tools with the retriever's search result for the preprocessed
     * messages. Does nothing when no retriever is set or {@code list} is {@code null}.
     */
    private void initToolsFromRetriever(List<ConversationMessage> list) {
        if (this.toolRetriever == null || list == null) {
            return;
        }
        List<Tool> search = this.toolRetriever.search(this.toolRetriever.getQueryPreprocessor().apply(list));
        this.toolMap.clear();
        if (search != null) {
            search.forEach(this::addTool);
        }
    }
}
