package io.quarkiverse.langchain4j.bedrock.runtime;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.CustomMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.chat.response.ChatResponseMetadata;
import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.bedrock.runtime.config.ChatModelConfig;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Function;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;

/* loaded from: input_file:io/quarkiverse/langchain4j/bedrock/runtime/BedrockConverseStreamingChatModel.class */
public class BedrockConverseStreamingChatModel implements StreamingChatLanguageModel {
    private final BedrockRuntimeAsyncClient client;
    private final String modelId;
    private final ChatModelConfig config;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel$5, reason: invalid class name */
    /* loaded from: input_file:io/quarkiverse/langchain4j/bedrock/runtime/BedrockConverseStreamingChatModel$5.class */
    public static /* synthetic */ class AnonymousClass5 {
        static final /* synthetic */ int[] $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason = new int[StopReason.values().length];

        static {
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.END_TURN.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.STOP_SEQUENCE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.GUARDRAIL_INTERVENED.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.TOOL_USE.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.MAX_TOKENS.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[StopReason.CONTENT_FILTERED.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/quarkiverse/langchain4j/bedrock/runtime/BedrockConverseStreamingChatModel$StreamContext.class */
    public class StreamContext {
        private FinishReason stopReason;
        private final StreamingChatResponseHandler handler;
        private final StringBuilder finalCompletion = new StringBuilder();
        private TokenUsage tokenUsage = new TokenUsage();

        public StreamContext(StreamingChatResponseHandler streamingChatResponseHandler) {
            this.handler = streamingChatResponseHandler;
        }

        public Consumer<MessageStopEvent> setStopReason() {
            return new Consumer<MessageStopEvent>() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.StreamContext.1
                @Override // java.util.function.Consumer
                public void accept(MessageStopEvent messageStopEvent) {
                    StreamContext.this.stopReason = StreamContext.this.mapFinishReason(messageStopEvent.stopReason());
                }
            };
        }

        public Consumer<ConverseStreamMetadataEvent> updateTokenUsage() {
            return new Consumer<ConverseStreamMetadataEvent>() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.StreamContext.2
                @Override // java.util.function.Consumer
                public void accept(ConverseStreamMetadataEvent converseStreamMetadataEvent) {
                    software.amazon.awssdk.services.bedrockruntime.model.TokenUsage usage = converseStreamMetadataEvent.usage();
                    StreamContext.this.tokenUsage = StreamContext.this.tokenUsage.add(new TokenUsage(usage.inputTokens(), usage.outputTokens(), usage.totalTokens()));
                }
            };
        }

        public Consumer<ContentBlockDeltaEvent> handleChunk() {
            return new Consumer<ContentBlockDeltaEvent>() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.StreamContext.3
                @Override // java.util.function.Consumer
                public void accept(ContentBlockDeltaEvent contentBlockDeltaEvent) {
                    String text = contentBlockDeltaEvent.delta().text();
                    StreamContext.this.finalCompletion.append(text);
                    StreamContext.this.handler.onPartialResponse(text);
                }
            };
        }

        public Runnable handleCompletion() {
            return new Runnable() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.StreamContext.4
                @Override // java.lang.Runnable
                public void run() {
                    StreamContext.this.handler.onCompleteResponse(ChatResponse.builder().aiMessage(new AiMessage(StreamContext.this.finalCompletion.toString())).metadata(ChatResponseMetadata.builder().modelName(BedrockConverseStreamingChatModel.this.modelId).tokenUsage(StreamContext.this.tokenUsage).finishReason(StreamContext.this.stopReason).build()).build());
                }
            };
        }

        private FinishReason mapFinishReason(StopReason stopReason) {
            if (stopReason == null) {
                return FinishReason.OTHER;
            }
            switch (AnonymousClass5.$SwitchMap$software$amazon$awssdk$services$bedrockruntime$model$StopReason[stopReason.ordinal()]) {
                case 1:
                case 2:
                case 3:
                    return FinishReason.STOP;
                case 4:
                    return FinishReason.TOOL_EXECUTION;
                case 5:
                    return FinishReason.LENGTH;
                case 6:
                    return FinishReason.CONTENT_FILTER;
                default:
                    return FinishReason.OTHER;
            }
        }
    }

    public BedrockConverseStreamingChatModel(BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient, String str, ChatModelConfig chatModelConfig) {
        this.client = bedrockRuntimeAsyncClient;
        this.modelId = str;
        this.config = chatModelConfig;
    }

    public void chat(final ChatRequest chatRequest, final StreamingChatResponseHandler streamingChatResponseHandler) {
        StreamContext streamContext = new StreamContext(streamingChatResponseHandler);
        this.client.converseStream(new Consumer<ConverseStreamRequest.Builder>(this) { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.2
            final /* synthetic */ BedrockConverseStreamingChatModel this$0;

            {
                this.this$0 = this;
            }

            @Override // java.util.function.Consumer
            public void accept(ConverseStreamRequest.Builder builder) {
                builder.modelId(this.this$0.modelId).messages(this.this$0.toBedrockMessages(chatRequest)).inferenceConfig(this.this$0.createInferenceConfig());
            }
        }, ((ConverseStreamResponseHandler.Builder) ((ConverseStreamResponseHandler.Builder) ConverseStreamResponseHandler.builder().subscriber(ConverseStreamResponseHandler.Visitor.builder().onMessageStop(streamContext.setStopReason()).onMetadata(streamContext.updateTokenUsage()).onContentBlockDelta(streamContext.handleChunk()).build()).onComplete(streamContext.handleCompletion())).onError(new Consumer<Throwable>(this) { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.1
            final /* synthetic */ BedrockConverseStreamingChatModel this$0;

            {
                this.this$0 = this;
            }

            @Override // java.util.function.Consumer
            public void accept(Throwable th) {
                streamingChatResponseHandler.onError(th);
            }
        })).build());
    }

    private Consumer<InferenceConfiguration.Builder> createInferenceConfig() {
        return new Consumer<InferenceConfiguration.Builder>() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.3
            @Override // java.util.function.Consumer
            public void accept(InferenceConfiguration.Builder builder) {
                builder.maxTokens(BedrockConverseStreamingChatModel.this.config.maxTokens());
                if (BedrockConverseStreamingChatModel.this.config.temperature().isPresent()) {
                    builder.temperature(Float.valueOf((float) BedrockConverseStreamingChatModel.this.config.temperature().getAsDouble()));
                }
                if (BedrockConverseStreamingChatModel.this.config.topP().isPresent()) {
                    builder.topP(Float.valueOf((float) BedrockConverseStreamingChatModel.this.config.topP().getAsDouble()));
                }
            }
        };
    }

    private List<Message> toBedrockMessages(ChatRequest chatRequest) {
        return chatRequest.messages().stream().map(messageTransformer()).toList();
    }

    private Function<ChatMessage, Message> messageTransformer() {
        return new Function<ChatMessage, Message>() { // from class: io.quarkiverse.langchain4j.bedrock.runtime.BedrockConverseStreamingChatModel.4
            @Override // java.util.function.Function
            public Message apply(ChatMessage chatMessage) {
                String text;
                ConversationRole conversationRole;
                if (chatMessage instanceof SystemMessage) {
                    text = ((SystemMessage) chatMessage).text();
                    conversationRole = ConversationRole.ASSISTANT;
                } else if (chatMessage instanceof UserMessage) {
                    text = ((UserMessage) chatMessage).singleText();
                    conversationRole = ConversationRole.USER;
                } else if (chatMessage instanceof AiMessage) {
                    text = ((AiMessage) chatMessage).text();
                    conversationRole = ConversationRole.USER;
                } else if (chatMessage instanceof ToolExecutionResultMessage) {
                    text = ((ToolExecutionResultMessage) chatMessage).text();
                    conversationRole = ConversationRole.ASSISTANT;
                } else {
                    if (!(chatMessage instanceof CustomMessage)) {
                        throw new IllegalArgumentException(chatMessage == null ? "null" : chatMessage.getClass().getName());
                    }
                    text = ((CustomMessage) chatMessage).text();
                    conversationRole = ConversationRole.USER;
                }
                return (Message) Message.builder().content(new ContentBlock[]{ContentBlock.fromText(text)}).role(conversationRole).build();
            }
        };
    }
}
