package io.quarkiverse.langchain4j.runtime.aiservice;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.function.Function;

import org.jboss.logging.Logger;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolExecutor;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.ServiceOutputParser;
import dev.langchain4j.service.TokenStream;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.smallrye.mutiny.infrastructure.Infrastructure;

/**
 * Runtime implementation of AI service methods.
 * <p>
 * {@link #implement(Input)} turns the arguments of one AI service method invocation into chat messages, sends them
 * to the configured chat model (executing any tools the model requests along the way) and converts the final answer
 * into the method's return type. When an {@link AuditService} is configured the invocation is audited.
 */
public class AiServiceMethodImplementationSupport {

    private static final Logger log = Logger.getLogger(AiServiceMethodImplementationSupport.class);

    /** Upper bound on consecutive tool-execution rounds before the invocation is aborted. */
    private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;

    /** Everything needed to implement a single AI service method invocation. */
    public static class Input {
        final QuarkusAiServiceContext context;
        final AiServiceMethodCreateInfo createInfo;
        final Object[] methodArgs;

        /**
         * @param context the AI service context (models, memory, tools, audit service, ...)
         * @param createInfo metadata about the invoked AI service method
         * @param methodArgs the actual arguments the method was invoked with
         */
        public Input(QuarkusAiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
            this.context = context;
            this.createInfo = createInfo;
            this.methodArgs = methodArgs;
        }
    }

    /**
     * Hook for wrapping a method invocation.
     * <p>
     * NOTE(review): presumably implementations run extra logic around {@code function}, which performs the actual
     * invocation (e.g. {@code implement(Input)}) — confirm against the callers.
     */
    public interface Wrapper {
        Object wrap(Input input, Function<Input, Object> function);
    }

    /**
     * Executes one AI service method invocation.
     * <p>
     * When the context has an {@link AuditService}, an {@link Audit} is created up front and completed exactly once,
     * either with the result or with the failure.
     *
     * @param input the invocation context, method metadata and method arguments
     * @return the value to return from the AI service method
     * @throws RuntimeException any failure while implementing the method is logged and rethrown unchanged
     */
    public Object implement(Input input) {
        QuarkusAiServiceContext context = input.context;
        AiServiceMethodCreateInfo createInfo = input.createInfo;
        Object[] methodArgs = input.methodArgs;

        AuditService auditService = context.auditService;
        Audit audit = null;
        if (auditService != null) {
            audit = auditService.create(new Audit.CreateInfo(createInfo.getInterfaceName(), createInfo.getMethodName(),
                    methodArgs, createInfo.getMemoryIdParamPosition()));
        }

        Object result;
        try {
            result = doImplement(createInfo, methodArgs, context, audit);
        } catch (Exception e) {
            log.errorv(e, "Execution of {0}#{1} failed", createInfo.getInterfaceName(), createInfo.getMethodName());
            if (audit != null) {
                audit.onFailure(e);
                auditService.complete(audit);
            }
            throw e;
        }
        // Completion is handled outside the try block so that a failure while completing the audit does not
        // additionally report the (successful) invocation as failed and complete the same audit twice.
        if (audit != null) {
            audit.onCompletion(result);
            auditService.complete(audit);
        }
        return result;
    }

    /**
     * Builds the conversation, calls the chat model, runs requested tools and parses the final answer.
     *
     * @return a {@link TokenStream} for streaming methods, otherwise the parsed model answer
     */
    private static Object doImplement(AiServiceMethodCreateInfo createInfo, Object[] methodArgs,
            QuarkusAiServiceContext context, Audit audit) {
        Optional<SystemMessage> systemMessage = prepareSystemMessage(createInfo, methodArgs);
        UserMessage userMessage = prepareUserMessage(context, createInfo, methodArgs);
        if (audit != null) {
            audit.initialMessages(systemMessage, userMessage);
        }

        Object memoryId = memoryId(createInfo, methodArgs).orElse("default");

        if (context.retrievalAugmentor != null) {
            List<ChatMessage> chatMemoryMessages = context.hasChatMemory()
                    ? context.chatMemory(memoryId).messages()
                    : null;
            userMessage = context.retrievalAugmentor.augment(userMessage,
                    Metadata.from(userMessage, memoryId, chatMemoryMessages));
        }

        // Append the output format instructions while keeping the user name (if any) of the message.
        userMessage = createUserMessage(userMessage.name(),
                userMessage.text() + createInfo.getUserMessageInfo().getOutputFormatInstructions());

        // With chat memory the conversation lives in the memory; otherwise it is tracked in a local list.
        List<ChatMessage> messages;
        if (context.hasChatMemory()) {
            ChatMemory chatMemory = context.chatMemory(memoryId);
            systemMessage.ifPresent(chatMemory::add);
            chatMemory.add(userMessage);
            messages = chatMemory.messages();
        } else {
            messages = new ArrayList<>();
            systemMessage.ifPresent(messages::add);
            messages.add(userMessage);
        }

        Class<?> returnType = createInfo.getReturnType();
        if (returnType.equals(TokenStream.class)) {
            return new AiServiceTokenStream(messages, context, memoryId);
        }

        Future<Moderation> moderationFuture = triggerModerationIfNeeded(context, createInfo, messages);
        Response<AiMessage> response = generate(context, messages, audit);
        TokenUsage tokenUsage = response.tokenUsage();
        // Blocks until moderation (if any) finished, so the local message list is not mutated while it is read.
        AiServices.verifyModerationIfNeeded(moderationFuture);

        int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        while (true) {
            if (executionsLeft-- == 0) {
                throw Exceptions.runtime("Something is wrong, exceeded %s sequential tool executions",
                        MAX_SEQUENTIAL_TOOL_EXECUTIONS);
            }

            AiMessage aiMessage = response.content();
            addToConversation(context, memoryId, messages, aiMessage);

            if (!aiMessage.hasToolExecutionRequests()) {
                return ServiceOutputParser.parse(Response.from(aiMessage, tokenUsage, response.finishReason()),
                        returnType);
            }

            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                ToolExecutionResultMessage resultMessage = executeTool(context, toolExecutionRequest, memoryId);
                if (audit != null) {
                    audit.addApplicationToLLMMessage(resultMessage);
                }
                addToConversation(context, memoryId, messages, resultMessage);
            }

            response = generate(context, conversationMessages(context, memoryId, messages), audit);
            tokenUsage = sumTokenUsage(tokenUsage, response.tokenUsage());
        }
    }

    /** Calls the chat model (passing the tool specifications when configured) and audits the response. */
    private static Response<AiMessage> generate(QuarkusAiServiceContext context, List<ChatMessage> messages,
            Audit audit) {
        log.debug("Attempting to obtain AI response");
        Response<AiMessage> response = context.toolSpecifications == null
                ? context.chatModel.generate(messages)
                : context.chatModel.generate(messages, context.toolSpecifications);
        log.debug("AI response obtained");
        if (audit != null) {
            audit.addLLMToApplicationMessage(response);
        }
        return response;
    }

    /** Records {@code message} in the chat memory when one is configured, otherwise in the local {@code messages}. */
    private static void addToConversation(QuarkusAiServiceContext context, Object memoryId,
            List<ChatMessage> messages, ChatMessage message) {
        if (context.hasChatMemory()) {
            context.chatMemory(memoryId).add(message);
        } else {
            messages.add(message);
        }
    }

    /** Returns the current conversation: the chat memory contents when configured, otherwise the local list. */
    private static List<ChatMessage> conversationMessages(QuarkusAiServiceContext context, Object memoryId,
            List<ChatMessage> messages) {
        return context.hasChatMemory() ? context.chatMemory(memoryId).messages() : messages;
    }

    /**
     * Executes a single tool request.
     *
     * @throws RuntimeException if no executor is registered for the requested tool name
     */
    private static ToolExecutionResultMessage executeTool(QuarkusAiServiceContext context,
            ToolExecutionRequest toolExecutionRequest, Object memoryId) {
        log.debugv("Attempting to execute tool {0}", toolExecutionRequest);
        ToolExecutor toolExecutor = (ToolExecutor) context.toolExecutors.get(toolExecutionRequest.name());
        if (toolExecutor == null) {
            throw Exceptions.runtime("Tool executor %s not found", toolExecutionRequest.name());
        }
        String result = toolExecutor.execute(toolExecutionRequest, memoryId);
        log.debugv("Result of {0} is '{1}'", toolExecutionRequest, result);
        return ToolExecutionResultMessage.from(toolExecutionRequest, result);
    }

    /** Adds {@code next} to {@code accumulated}, tolerating a missing (null) accumulated usage. */
    private static TokenUsage sumTokenUsage(TokenUsage accumulated, TokenUsage next) {
        return accumulated == null ? next : accumulated.add(next);
    }

    /**
     * Starts moderation of {@code messages} (tool messages removed) on Mutiny's default executor.
     *
     * @return the pending moderation, or {@code null} when the method does not require moderation
     */
    private static Future<Moderation> triggerModerationIfNeeded(AiServiceContext context,
            AiServiceMethodCreateInfo createInfo, List<ChatMessage> messages) {
        if (!createInfo.isRequiresModeration()) {
            return null;
        }
        log.debug("Moderation is required and it will be executed in the background");
        FutureTask<Moderation> moderationTask = new FutureTask<>(() -> {
            List<ChatMessage> messagesToModerate = AiServices.removeToolMessages(messages);
            log.debug("Attempting to moderate messages");
            Moderation moderation = context.moderationModel.moderate(messagesToModerate).content();
            log.debug("Moderation completed");
            return moderation;
        });
        // execute() works for any Executor, so no cast to ExecutorService is required.
        Infrastructure.getDefaultExecutor().execute(moderationTask);
        return moderationTask;
    }

    /** Renders the system message template, if the method declares one, from the method arguments. */
    private static Optional<SystemMessage> prepareSystemMessage(AiServiceMethodCreateInfo createInfo,
            Object[] methodArgs) {
        if (createInfo.getSystemMessageInfo().isEmpty()) {
            return Optional.empty();
        }
        AiServiceMethodCreateInfo.TemplateInfo templateInfo = createInfo.getSystemMessageInfo().get();
        Map<String, Object> variables = templateVariables(templateInfo, methodArgs);
        return Optional.of(PromptTemplate.from(templateInfo.getText()).apply(variables).toSystemMessage());
    }

    /**
     * Builds the user message from either the user message template or the single user message argument.
     *
     * @throws IllegalStateException if the method metadata describes neither a template nor a parameter
     * @throws IllegalArgumentException if the user message argument is {@code null}
     */
    private static UserMessage prepareUserMessage(AiServiceContext context, AiServiceMethodCreateInfo createInfo,
            Object[] methodArgs) {
        AiServiceMethodCreateInfo.UserMessageInfo userMessageInfo = createInfo.getUserMessageInfo();

        // A null user name argument is treated as "no user name" instead of failing with an NPE.
        String userName = null;
        if (userMessageInfo.getUserNameParamPosition().isPresent()) {
            Object userNameArg = methodArgs[userMessageInfo.getUserNameParamPosition().get()];
            userName = userNameArg == null ? null : userNameArg.toString();
        }

        if (userMessageInfo.getTemplate().isPresent()) {
            AiServiceMethodCreateInfo.TemplateInfo templateInfo = userMessageInfo.getTemplate().get();
            Map<String, Object> variables = templateVariables(templateInfo, methodArgs);
            return createUserMessage(userName, PromptTemplate.from(templateInfo.getText()).apply(variables).text());
        }

        if (!userMessageInfo.getParamPosition().isPresent()) {
            throw new IllegalStateException("Unable to construct UserMessage for class '"
                    + context.aiServiceClass.getName() + "'. Please contact the maintainers");
        }
        Integer paramPosition = userMessageInfo.getParamPosition().get();
        Object userMessageArg = methodArgs[paramPosition];
        if (userMessageArg == null) {
            throw new IllegalArgumentException("Unable to construct UserMessage for class '"
                    + context.aiServiceClass.getName() + "' because parameter with index " + paramPosition
                    + " is null");
        }
        return createUserMessage(userName, toString(userMessageArg));
    }

    /** Maps each template variable name to its (transformed) method argument value. */
    private static Map<String, Object> templateVariables(AiServiceMethodCreateInfo.TemplateInfo templateInfo,
            Object[] methodArgs) {
        Map<String, Object> variables = new HashMap<>();
        for (Map.Entry<String, Integer> entry : templateInfo.getNameToParamPosition().entrySet()) {
            variables.put(entry.getKey(), transformTemplateParamValue(methodArgs[entry.getValue()]));
        }
        return variables;
    }

    /** Creates a user message, attaching the user name only when one is given. */
    private static UserMessage createUserMessage(String name, String text) {
        return name == null ? UserMessage.userMessage(text) : UserMessage.userMessage(name, text);
    }

    /**
     * Renders arrays (of any component type, including primitives) as {@code [item1, item2]} so templates receive
     * a readable value; every other value (including {@code null}) is passed through unchanged.
     */
    private static Object transformTemplateParamValue(Object value) {
        if (value == null || !value.getClass().isArray()) {
            return value;
        }
        return joinArray(value, element -> String.valueOf(element));
    }

    /**
     * Returns the memory id argument, if the method declares one.
     *
     * @throws IllegalArgumentException if the memory id argument is {@code null}; falling back to a shared default
     *         memory here could silently mix unrelated conversations
     */
    private static Optional<Object> memoryId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
        if (!createInfo.getMemoryIdParamPosition().isPresent()) {
            return Optional.empty();
        }
        int position = createInfo.getMemoryIdParamPosition().get();
        Object memoryId = methodArgs[position];
        if (memoryId == null) {
            throw new IllegalArgumentException("The memory id (parameter with index " + position + ") of method "
                    + createInfo.getInterfaceName() + "#" + createInfo.getMethodName() + " must not be null");
        }
        return Optional.of(memoryId);
    }

    /** Converts a user message argument to text: arrays element-wise, structured prompts via their template. */
    private static String toString(Object value) {
        if (value == null) {
            return "null";
        }
        if (value.getClass().isArray()) {
            return arrayToString(value);
        }
        if (value.getClass().isAnnotationPresent(StructuredPrompt.class)) {
            return StructuredPromptProcessor.toPrompt(value).text();
        }
        return value.toString();
    }

    /** Renders an array as {@code [a, b, c]}, converting each element with {@link #toString(Object)}. */
    private static String arrayToString(Object array) {
        return joinArray(array, element -> toString(element));
    }

    /** Joins the elements of an array of any component type as {@code [a, b, c]} using {@code elementToString}. */
    private static String joinArray(Object array, Function<Object, String> elementToString) {
        int length = Array.getLength(array);
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(elementToString.apply(Array.get(array, i)));
        }
        return sb.append(']').toString();
    }
}
