package io.quarkiverse.langchain4j.runtime.devui;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.tool.ToolExecutor;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatMessagePojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatResultPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionRequestPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionResultPojo;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.arc.All;
import io.quarkus.arc.Arc;
import io.quarkus.logging.Log;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
import io.vertx.core.json.JsonObject;
import jakarta.enterprise.context.control.ActivateRequestContext;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

@ActivateRequestContext
/**
 * JSON-RPC backend for the Dev UI chat page.
 * <p>
 * Keeps a single "current" {@link ChatMemory}, which is replaced by {@link #reset(String)}. User messages are sent
 * either to the first injected {@link ChatLanguageModel} (blocking, {@link #chat(String, boolean)}) or to the first
 * injected {@link StreamingChatLanguageModel} ({@link #streamingChat(String, boolean)}).
 * <p>
 * When tool metadata is available from {@link ToolsRecorder#getMetadata()}, the model is offered those tools. Each
 * tool execution request is executed and its result is fed back to the model, for at most
 * {@link #MAX_SEQUENTIAL_TOOL_EXECUTIONS} consecutive tool rounds per turn.
 * <p>
 * If a turn fails (synchronously or asynchronously), the memory is restored to the messages it held before the
 * turn started.
 */
public class ChatJsonRPCService {

    /** Maximum number of consecutive model responses that may request tool executions within one chat turn. */
    public static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20;

    private final ChatLanguageModel model;
    private final Optional<StreamingChatLanguageModel> streamingModel;
    private final ChatMemoryProvider memoryProvider;
    // null when no retrieval augmentor is available; augmentation is then skipped.
    private final RetrievalAugmentor retrievalAugmentor;
    private final List<ToolSpecification> toolSpecifications;
    private final Map<String, ToolExecutor> toolExecutors;
    private final AtomicReference<ChatMemory> currentMemory = new AtomicReference<>();
    private final AtomicLong currentMemoryId = new AtomicLong();

    /**
     * Creates the service from the available models, augmentors and tool metadata.
     *
     * @param chatModels all chat models; the first one is used (the list must not be empty)
     * @param streamingChatModels all streaming chat models; the first one, if any, enables streaming chat
     * @param retrievalAugmentorSuppliers suppliers queried in order; the first non-null result is the augmentor
     * @param retrievalAugmentors fallback augmentors, used only when no supplier produced one
     * @param chatMemoryProvider provider used to create a fresh memory on every {@link #reset(String)}
     * @param toolExecutorFactory factory creating an executor for every recorded tool method
     * @throws RuntimeException if a recorded tool class cannot be loaded
     */
    public ChatJsonRPCService(@All List<ChatLanguageModel> chatModels,
            @All List<StreamingChatLanguageModel> streamingChatModels,
            @All List<Supplier<RetrievalAugmentor>> retrievalAugmentorSuppliers,
            @All List<RetrievalAugmentor> retrievalAugmentors,
            ChatMemoryProvider chatMemoryProvider,
            QuarkusToolExecutorFactory toolExecutorFactory) {
        this.model = chatModels.get(0);
        this.streamingModel = streamingChatModels.isEmpty()
                ? Optional.empty()
                : Optional.of(streamingChatModels.get(0));
        this.retrievalAugmentor = findRetrievalAugmentor(retrievalAugmentorSuppliers, retrievalAugmentors);
        this.memoryProvider = chatMemoryProvider;

        Map<String, List<ToolMethodCreateInfo>> metadata = ToolsRecorder.getMetadata();
        if (metadata == null) {
            this.toolSpecifications = List.of();
            this.toolExecutors = Map.of();
            return;
        }
        Map<String, ToolExecutor> executors = new HashMap<>();
        List<ToolSpecification> specifications = new ArrayList<>();
        for (Map.Entry<String, List<ToolMethodCreateInfo>> entry : metadata.entrySet()) {
            for (ToolMethodCreateInfo info : entry.getValue()) {
                executors.put(info.toolSpecification().name(),
                        createToolExecutor(toolExecutorFactory, entry.getKey(), info));
                specifications.add(info.toolSpecification());
            }
        }
        this.toolExecutors = executors;
        this.toolSpecifications = specifications;
    }

    /**
     * Returns the first non-null augmentor: suppliers are queried lazily in order, then the plain list is searched.
     * Returns {@code null} if none is found.
     */
    private static RetrievalAugmentor findRetrievalAugmentor(List<Supplier<RetrievalAugmentor>> suppliers,
            List<RetrievalAugmentor> augmentors) {
        for (Supplier<RetrievalAugmentor> supplier : suppliers) {
            RetrievalAugmentor augmentor = supplier.get();
            if (augmentor != null) {
                return augmentor;
            }
        }
        for (RetrievalAugmentor augmentor : augmentors) {
            if (augmentor != null) {
                return augmentor;
            }
        }
        return null;
    }

    /** Loads the tool class, looks up its bean in the Arc container and wraps the tool method in an executor. */
    private static ToolExecutor createToolExecutor(QuarkusToolExecutorFactory factory, String toolClassName,
            ToolMethodCreateInfo info) {
        Class<?> toolClass;
        try {
            toolClass = Thread.currentThread().getContextClassLoader().loadClass(toolClassName);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException("Unable to load tool class " + toolClassName, e);
        }
        Object toolBean = Arc.container().select(toolClass).get();
        return factory.create(new QuarkusToolExecutor.Context(toolBean, info.invokerClassName(), info.methodName(),
                info.argumentMapperClassName()));
    }

    /**
     * Discards the current conversation and starts a new one with a fresh random memory id.
     *
     * @param systemMessage optional system message added to the new memory; ignored when {@code null} or empty
     * @return always {@code "OK"}
     */
    public String reset(String systemMessage) {
        ChatMemory previous = currentMemory.get();
        if (previous != null) {
            previous.clear();
        }
        long memoryId = ThreadLocalRandom.current().nextLong();
        currentMemoryId.set(memoryId);
        ChatMemory memory = memoryProvider.get(memoryId);
        currentMemory.set(memory);
        if (systemMessage != null && !systemMessage.isEmpty()) {
            memory.add(new SystemMessage(systemMessage));
        }
        return "OK";
    }

    /** Returns {@code true} when a streaming chat model is available. */
    public boolean isStreamingChatSupported() {
        return streamingModel.isPresent();
    }

    /**
     * Sends a user message to the streaming model and streams back JSON events.
     * <p>
     * Emitted events:
     * <ul>
     * <li>{@code augmentedMessage}, when RAG was applied;</li>
     * <li>{@code token}, for each streamed token;</li>
     * <li>{@code toolExecutionRequest} and {@code toolExecutionResult}, for each tool call;</li>
     * <li>a final {@code message} event carrying the complete answer.</li>
     * </ul>
     * On any failure the memory is rolled back and the stream fails.
     *
     * @param message the user's message
     * @param ragEnabled whether to pass the message through the retrieval augmentor (if one is available)
     */
    public Multi<JsonObject> streamingChat(String message, boolean ragEnabled) {
        ChatMemory memory = currentOrNewMemory();
        return Multi.createFrom().<JsonObject> emitter(emitter -> {
            // Snapshot when the turn actually starts (at subscription), as a copy so clear() cannot affect it.
            List<ChatMessage> snapshot = new ArrayList<>(memory.messages());
            try {
                ChatMessage augmented = addUserMessage(memory, message, ragEnabled);
                if (augmented != null) {
                    emitter.emit(new JsonObject().put("augmentedMessage", augmented.text()));
                }
                StreamingChatLanguageModel streaming = streamingModel
                        .orElseThrow(() -> new IllegalStateException("No streaming chat model is available"));
                if (toolSpecifications.isEmpty()) {
                    streaming.generate(memory.messages(), new StreamingResponseHandler<AiMessage>() {
                        @Override
                        public void onNext(String token) {
                            emitter.emit(new JsonObject().put("token", token));
                        }

                        @Override
                        public void onComplete(Response<AiMessage> response) {
                            AiMessage aiMessage = response.content();
                            memory.add(aiMessage);
                            emitter.emit(new JsonObject().put("message", aiMessage.text()));
                            emitter.complete();
                        }

                        @Override
                        public void onError(Throwable error) {
                            failTurn(memory, snapshot, emitter, error);
                        }
                    });
                } else {
                    executeWithToolsAndStreaming(memory, emitter, MAX_SEQUENTIAL_TOOL_EXECUTIONS, snapshot);
                }
            } catch (Throwable t) {
                failTurn(memory, snapshot, emitter, t);
            }
        }).runSubscriptionOn(Infrastructure.getDefaultWorkerPool());
    }

    /**
     * Sends a user message to the blocking model (executing tools if any are registered).
     *
     * @param message the user's message
     * @param ragEnabled whether to pass the message through the retrieval augmentor (if one is available)
     * @return the full conversation on success; otherwise a result carrying only the error message (in which case
     *         the memory has been rolled back)
     */
    public ChatResultPojo chat(String message, boolean ragEnabled) {
        ChatMemory memory = currentOrNewMemory();
        // Copy, so clearing the memory during rollback cannot also clear the snapshot.
        List<ChatMessage> snapshot = new ArrayList<>(memory.messages());
        try {
            addUserMessage(memory, message, ragEnabled);
            if (toolSpecifications.isEmpty()) {
                memory.add(model.generate(memory.messages()).content());
            } else {
                executeWithTools(memory);
            }
            return new ChatResultPojo(ChatMessagePojo.listFromMemory(memory), null);
        } catch (Throwable t) {
            restore(memory, snapshot);
            Log.warn(t);
            return new ChatResultPojo(null, t.getMessage());
        }
    }

    /** Returns the current memory, creating one (with no system message) if none exists yet. */
    private ChatMemory currentOrNewMemory() {
        ChatMemory memory = currentMemory.get();
        if (memory == null) {
            reset("");
            memory = currentMemory.get();
        }
        return memory;
    }

    /**
     * Adds the user's message to memory. When RAG is requested and an augmentor exists, the augmented message is
     * added instead.
     *
     * @return the augmented message when augmentation was applied, otherwise {@code null}
     */
    private ChatMessage addUserMessage(ChatMemory memory, String message, boolean ragEnabled) {
        if (retrievalAugmentor == null || !ragEnabled) {
            memory.add(new UserMessage(message));
            return null;
        }
        UserMessage userMessage = UserMessage.from(message);
        Metadata metadata = Metadata.from(userMessage, currentMemoryId.get(), memory.messages());
        ChatMessage augmented = retrievalAugmentor.augment(new AugmentationRequest(userMessage, metadata))
                .chatMessage();
        memory.add(augmented);
        return augmented;
    }

    /**
     * Runs the blocking model/tool loop until the model answers without tool requests.
     *
     * @throws RuntimeException after {@link #MAX_SEQUENTIAL_TOOL_EXECUTIONS} consecutive tool rounds, or if a tool
     *         fails or is unknown
     */
    private Response<AiMessage> executeWithTools(ChatMemory memory) {
        Response<AiMessage> response = model.generate(memory.messages(), toolSpecifications);
        int toolExecutionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
        while (true) {
            if (toolExecutionsLeft-- == 0) {
                throw tooManyToolExecutions();
            }
            AiMessage aiMessage = response.content();
            memory.add(aiMessage);
            if (!aiMessage.hasToolExecutionRequests()) {
                return Response.from(aiMessage, new TokenUsage(), response.finishReason());
            }
            for (ToolExecutionRequest request : aiMessage.toolExecutionRequests()) {
                memory.add(executeTool(request));
            }
            response = model.generate(memory.messages(), toolSpecifications);
        }
    }

    /**
     * Streams one model response and, when it completes, processes it on the default executor.
     * <p>
     * All failures terminate the stream through {@link #failTurn}: model errors, tool failures and exceeding the
     * tool limit. Exceptions are never thrown into the model's callback or lost on the executor thread.
     *
     * @param toolExecutionsLeft tool rounds still allowed; counts down exactly like {@link #executeWithTools}
     * @param snapshot memory contents to restore if the turn fails
     */
    private void executeWithToolsAndStreaming(ChatMemory memory, MultiEmitter<? super JsonObject> emitter,
            int toolExecutionsLeft, List<ChatMessage> snapshot) {
        StreamingChatLanguageModel streaming = streamingModel
                .orElseThrow(() -> new IllegalStateException("No streaming chat model is available"));
        streaming.generate(memory.messages(), toolSpecifications, new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                emitter.emit(new JsonObject().put("token", token));
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                // Processing (including tool execution) is moved off the model's callback thread.
                // NOTE(review): presumably because tools may block - confirm.
                Infrastructure.getDefaultExecutor().execute(() -> {
                    try {
                        handleStreamedResponse(memory, emitter, response.content(), toolExecutionsLeft, snapshot);
                    } catch (Throwable t) {
                        failTurn(memory, snapshot, emitter, t);
                    }
                });
            }

            @Override
            public void onError(Throwable error) {
                failTurn(memory, snapshot, emitter, error);
            }
        });
    }

    /**
     * Handles one complete streamed response. It either finishes the stream or runs the requested tools (emitting
     * request/result events) and asks the model again.
     */
    private void handleStreamedResponse(ChatMemory memory, MultiEmitter<? super JsonObject> emitter,
            AiMessage aiMessage, int toolExecutionsLeft, List<ChatMessage> snapshot) {
        if (toolExecutionsLeft == 0) {
            throw tooManyToolExecutions();
        }
        memory.add(aiMessage);
        if (!aiMessage.hasToolExecutionRequests()) {
            emitter.emit(new JsonObject().put("message", aiMessage.text()));
            emitter.complete();
            return;
        }
        for (ToolExecutionRequest request : aiMessage.toolExecutionRequests()) {
            emitter.emit(new JsonObject().put("toolExecutionRequest",
                    new ToolExecutionRequestPojo(request.id(), request.name(), request.arguments())));
            ToolExecutionResultMessage result = executeTool(request);
            memory.add(result);
            emitter.emit(new JsonObject().put("toolExecutionResult",
                    new ToolExecutionResultPojo(result.id(), result.toolName(), result.text())));
        }
        executeWithToolsAndStreaming(memory, emitter, toolExecutionsLeft - 1, snapshot);
    }

    /**
     * Executes a single tool request against the current memory id.
     *
     * @throws IllegalStateException if the model requested a tool with no registered executor
     */
    private ToolExecutionResultMessage executeTool(ToolExecutionRequest request) {
        ToolExecutor executor = toolExecutors.get(request.name());
        if (executor == null) {
            throw new IllegalStateException("No tool executor registered for tool '" + request.name() + "'");
        }
        return ToolExecutionResultMessage.from(request, executor.execute(request, currentMemoryId.get()));
    }

    private static RuntimeException tooManyToolExecutions() {
        return new RuntimeException(
                "Something is wrong, exceeded " + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions");
    }

    /** Rolls the memory back to {@code snapshot}, logs the error and terminates the stream with it. */
    private static void failTurn(ChatMemory memory, List<ChatMessage> snapshot,
            MultiEmitter<? super JsonObject> emitter, Throwable error) {
        restore(memory, snapshot);
        Log.warn(error);
        emitter.fail(error);
    }

    /** Replaces the memory's contents with the messages in {@code snapshot}. */
    private static void restore(ChatMemory memory, List<ChatMessage> snapshot) {
        memory.clear();
        snapshot.forEach(memory::add);
    }
}
