package io.github.alexcheng1982.springai.dashscope;

import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.multimodalconversation.MultiModalConversationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.MultiModalMessage;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.tools.ToolFunction;
import com.alibaba.dashscope.utils.JsonUtils;
import io.github.alexcheng1982.springai.dashscope.api.DashscopeApi;
import io.github.alexcheng1982.springai.dashscope.metadata.DashscopeUsage;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.adapter.rxjava.RxJava2Adapter;
import reactor.core.publisher.Flux;

/* loaded from: input_file:io/github/alexcheng1982/springai/dashscope/DashscopeChatModel.class */
/**
 * Spring AI {@link ChatModel} backed by the Alibaba DashScope SDK.
 *
 * <p>Text-only prompts are sent through {@code DashscopeApi.chatCompletion} /
 * {@code chatCompletionStream} and support function (tool) calling via
 * {@link AbstractToolCallSupport}. When any {@link UserMessage} in the prompt carries media, the
 * whole prompt is sent as a multi-modal conversation instead; tool calls are not extracted from
 * multi-modal responses.
 */
public class DashscopeChatModel extends AbstractToolCallSupport implements ChatModel {

    /** Model used when the caller supplies no options. */
    private static final String DEFAULT_MODEL = "qwen-turbo";

    private static final DashscopeChatOptions DEFAULT_OPTIONS =
        DashscopeChatOptions.builder().withModel(DEFAULT_MODEL).build();

    /** Finish reasons that mark a blocking (non-streaming) response as a tool-call request. */
    private static final Set<String> TOOL_CALL_FINISH_REASONS = Set.of("tool_calls");

    /**
     * Finish reasons that mark a streamed chunk as a tool-call request.
     * NOTE(review): this differs from {@link #TOOL_CALL_FINISH_REASONS} ("tool_calls"); confirm
     * which value DashScope reports in streaming mode.
     */
    private static final Set<String> STREAM_TOOL_CALL_FINISH_REASONS = Set.of("tool");

    private final DashscopeChatOptions defaultOptions;
    private final DashscopeApi dashscopeApi;

    /**
     * Creates a model with the default options ({@code qwen-turbo}) and no function resolver.
     *
     * @param dashscopeApi the API client; must not be null
     */
    public DashscopeChatModel(DashscopeApi dashscopeApi) {
        this(dashscopeApi, DEFAULT_OPTIONS);
    }

    /**
     * Creates a model with the given default options and no function resolver.
     *
     * @param dashscopeApi the API client; must not be null
     * @param dashscopeChatOptions default options merged into every request; must not be null
     */
    public DashscopeChatModel(DashscopeApi dashscopeApi, DashscopeChatOptions dashscopeChatOptions) {
        this(dashscopeApi, dashscopeChatOptions, null);
    }

    /**
     * Creates a model with the default options ({@code qwen-turbo}) and a function resolver.
     *
     * @param dashscopeApi the API client; must not be null
     * @param functionCallbackResolver resolves function names to callbacks; may be null
     */
    public DashscopeChatModel(DashscopeApi dashscopeApi, FunctionCallbackResolver functionCallbackResolver) {
        this(dashscopeApi, DEFAULT_OPTIONS, functionCallbackResolver);
    }

    /**
     * Creates a model with explicit default options and function resolver.
     *
     * @param dashscopeApi the API client; must not be null
     * @param dashscopeChatOptions default options merged into every request; must not be null
     * @param functionCallbackResolver resolves function names to callbacks; may be null
     * @throws IllegalArgumentException if {@code dashscopeApi} or {@code dashscopeChatOptions} is null
     */
    public DashscopeChatModel(DashscopeApi dashscopeApi, DashscopeChatOptions dashscopeChatOptions,
            FunctionCallbackResolver functionCallbackResolver) {
        super(functionCallbackResolver);
        Assert.notNull(dashscopeApi, "dashscopeApi must not be null");
        Assert.notNull(dashscopeChatOptions, "Options must not be null");
        this.dashscopeApi = dashscopeApi;
        this.defaultOptions = dashscopeChatOptions;
    }

    /** Creates a model using a default-constructed {@link DashscopeApi} and default options. */
    public static DashscopeChatModel createDefault() {
        return new DashscopeChatModel(new DashscopeApi());
    }

    /**
     * Sends the prompt and returns the response.
     *
     * <p>For text requests, a response that asks for tool calls is handled by executing the tools
     * and re-sending the extended conversation, recursively, until the model stops asking.
     */
    @Override
    public ChatResponse call(Prompt prompt) {
        DashscopeApi.ChatCompletionRequest request = createRequest(prompt);
        ChatResponse response = chatCompletionResultToChatResponse(doChatCompletion(request));
        // Multi-modal responses are converted without tool calls, so only text requests loop.
        if (!request.isMultiModalRequest() && isToolCall(response)) {
            return call(new Prompt(handleToolCalls(prompt, response), prompt.getOptions()));
        }
        return response;
    }

    /**
     * Returns a copy of the default options this model was constructed with.
     *
     * <p>Previously this always returned a fresh {@code qwen-turbo} options object and ignored
     * the options passed to the constructor.
     */
    @Override
    public ChatOptions getDefaultOptions() {
        return this.defaultOptions.createCopy();
    }

    /**
     * Streams the response to the prompt.
     *
     * <p>For text requests, a chunk that asks for tool calls is replaced by the stream of the
     * follow-up request that carries the tool results.
     */
    @Override
    public Flux<ChatResponse> stream(Prompt prompt) {
        DashscopeApi.ChatCompletionRequest request = createRequest(prompt);
        if (request.isMultiModalRequest()) {
            return RxJava2Adapter.flowableToFlux(
                this.dashscopeApi.multiModalStream(request.getMultiModalMessages(), request.options())
                    .map(this::multiModalConversationResultToChatResponse));
        }
        return RxJava2Adapter.flowableToFlux(
                this.dashscopeApi.chatCompletionStream(request.getMessages(), request.options()))
            .flatMap(result -> {
                ChatResponse response = generationResultToChatResponse(result);
                if (isToolCall(response, STREAM_TOOL_CALL_FINISH_REASONS)) {
                    return stream(new Prompt(handleToolCalls(prompt, response), prompt.getOptions()));
                }
                return Flux.just(response);
            });
    }

    /** Returns whether a blocking response is a tool-call request. */
    private boolean isToolCall(ChatResponse chatResponse) {
        return isToolCall(chatResponse, TOOL_CALL_FINISH_REASONS);
    }

    /** Converts whichever result variant (multi-modal or text) the completion produced. */
    private ChatResponse chatCompletionResultToChatResponse(DashscopeApi.ChatCompletionResult chatCompletionResult) {
        if (chatCompletionResult.multiModalConversationResult() != null) {
            return multiModalConversationResultToChatResponse(chatCompletionResult.multiModalConversationResult());
        }
        return generationResultToChatResponse(chatCompletionResult.generationResult());
    }

    /**
     * Converts a multi-modal result into a {@link ChatResponse}, one generation per choice.
     * The generation text is the first {@code "text"} part of the choice's content (null if none).
     */
    private ChatResponse multiModalConversationResultToChatResponse(MultiModalConversationResult multiModalConversationResult) {
        List<Generation> generations = multiModalConversationResult.getOutput().getChoices().stream()
            .map(choice -> new Generation(
                new AssistantMessage(firstText(choice.getMessage().getContent())),
                // Normalise a missing finish reason to "" exactly like the text path does.
                ChatGenerationMetadata.builder().finishReason(orEmpty(choice.getFinishReason())).build()))
            .toList();
        return new ChatResponse(generations, buildChatResponseMetadata(multiModalConversationResult));
    }

    /** Converts a text generation result into a {@link ChatResponse}, one generation per choice. */
    private ChatResponse generationResultToChatResponse(GenerationResult generationResult) {
        List<Generation> generations = generationResult.getOutput().getChoices().stream()
            .map(choice -> buildGeneration(choice, new HashMap<>()))
            .toList();
        return new ChatResponse(generations, buildChatResponseMetadata(generationResult));
    }

    /**
     * Builds a generation from one choice, carrying over function-type tool calls.
     *
     * @param choice the DashScope choice
     * @param map metadata attached to the resulting {@link AssistantMessage}
     */
    private Generation buildGeneration(GenerationOutput.Choice choice, Map<String, Object> map) {
        var message = choice.getMessage();
        List<AssistantMessage.ToolCall> toolCalls = message.getToolCalls() == null
            ? List.of()
            : message.getToolCalls().stream()
                // Null-safe type check; only function tool calls are mapped.
                .filter(toolCall -> "function".equals(toolCall.getType()))
                .filter(ToolCallFunction.class::isInstance)
                .map(ToolCallFunction.class::cast)
                .map(function -> new AssistantMessage.ToolCall(function.getId(), function.getType(),
                    function.getFunction().getName(), function.getFunction().getArguments()))
                .toList();
        return new Generation(
            new AssistantMessage(message.getContent(), map, toolCalls),
            ChatGenerationMetadata.builder().finishReason(orEmpty(choice.getFinishReason())).build());
    }

    private ChatResponseMetadata buildChatResponseMetadata(GenerationResult generationResult) {
        return ChatResponseMetadata.builder().usage(new DashscopeUsage(generationResult.getUsage())).build();
    }

    private ChatResponseMetadata buildChatResponseMetadata(MultiModalConversationResult multiModalConversationResult) {
        return ChatResponseMetadata.builder().usage(new DashscopeUsage(multiModalConversationResult.getUsage())).build();
    }

    /**
     * Builds the request: converts messages, merges prompt options over a copy of the defaults,
     * and attaches tool definitions for every runtime and default function name.
     */
    private DashscopeApi.ChatCompletionRequest createRequest(Prompt prompt) {
        List<DashscopeApi.ChatCompletionMessage> messages = toDashscopeMessages(prompt.getInstructions());
        // defaultOptions is asserted non-null in the constructor; work on a copy of it.
        DashscopeChatOptions options = this.defaultOptions.createCopy();
        Set<String> functionNames = new HashSet<>();
        if (prompt.getOptions() != null) {
            DashscopeChatOptions runtimeOptions =
                ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class, DashscopeChatOptions.class);
            options = options.copyFrom(runtimeOptions);
            functionNames.addAll(runtimeFunctionCallbackConfigurations(options));
        }
        if (!CollectionUtils.isEmpty(this.defaultOptions.getFunctions())) {
            functionNames.addAll(this.defaultOptions.getFunctions());
        }
        if (!functionNames.isEmpty()) {
            options.setTools(getToolFunctions(functionNames));
        }
        return new DashscopeApi.ChatCompletionRequest(messages, options);
    }

    /** Resolves function names to callbacks and converts them to DashScope tool definitions. */
    private List<ToolBase> getToolFunctions(Set<String> set) {
        return resolveFunctionCallbacks(set).stream().map(this::toToolFunction).toList();
    }

    /** Converts a Spring AI function callback into a DashScope function tool definition. */
    private ToolBase toToolFunction(FunctionCallback functionCallback) {
        FunctionDefinition definition = FunctionDefinition.builder()
            .name(functionCallback.getName())
            .description(functionCallback.getDescription())
            .parameters(JsonUtils.parseString(functionCallback.getInputTypeSchema()).getAsJsonObject())
            .build();
        return ToolFunction.builder().function(definition).build();
    }

    /** Executes the request synchronously against the multi-modal or text endpoint. */
    private DashscopeApi.ChatCompletionResult doChatCompletion(DashscopeApi.ChatCompletionRequest chatCompletionRequest) {
        if (chatCompletionRequest.isMultiModalRequest()) {
            return new DashscopeApi.ChatCompletionResult(this.dashscopeApi.multiModal(
                chatCompletionRequest.getMultiModalMessages(), chatCompletionRequest.options()));
        }
        return new DashscopeApi.ChatCompletionResult(this.dashscopeApi.chatCompletion(
            chatCompletionRequest.getMessages(), chatCompletionRequest.options()));
    }

    /**
     * Converts Spring AI messages to DashScope messages.
     *
     * <p>If any user message carries media, every message is converted to a multi-modal message.
     * Otherwise plain messages are used, and a tool-response message expands into one DashScope
     * tool message per tool response so no parallel tool result is dropped.
     */
    private List<DashscopeApi.ChatCompletionMessage> toDashscopeMessages(
            List<org.springframework.ai.chat.messages.Message> list) {
        boolean hasMedia = list.stream().anyMatch(message ->
            message instanceof UserMessage userMessage && !CollectionUtils.isEmpty(userMessage.getMedia()));
        if (hasMedia) {
            return list.stream()
                .map(this::toDashscopeMultiModalMessage)
                .map(DashscopeApi.ChatCompletionMessage::new)
                .toList();
        }
        return list.stream()
            .flatMap(message -> toDashscopeMessageList(message).stream())
            .map(DashscopeApi.ChatCompletionMessage::new)
            .toList();
    }

    /**
     * Converts one Spring AI message to DashScope messages. A non-empty tool-response message
     * yields one tool message per response; anything else yields exactly one message.
     */
    private List<com.alibaba.dashscope.common.Message> toDashscopeMessageList(
            org.springframework.ai.chat.messages.Message message) {
        if (message instanceof ToolResponseMessage toolResponseMessage
                && !CollectionUtils.isEmpty(toolResponseMessage.getResponses())) {
            String role = roleFrom(message.getMessageType());
            return toolResponseMessage.getResponses().stream()
                .map(toolResponse -> com.alibaba.dashscope.common.Message.builder()
                    .role(role)
                    .toolCallId(toolResponse.id())
                    .name(toolResponse.name())
                    .content(toolResponse.responseData())
                    .build())
                .toList();
        }
        return List.of(toDashscopeMessage(message));
    }

    /**
     * Converts a single Spring AI message to a plain DashScope message, copying role and text,
     * and (for assistant messages) the requested tool calls.
     */
    private com.alibaba.dashscope.common.Message toDashscopeMessage(org.springframework.ai.chat.messages.Message message) {
        var builder = com.alibaba.dashscope.common.Message.builder()
            .role(roleFrom(message.getMessageType()))
            .content(message.getText());
        if (message instanceof AssistantMessage assistantMessage) {
            builder.toolCalls(assistantMessage.getToolCalls().stream()
                .<com.alibaba.dashscope.tools.ToolCallBase>map(this::toToolCallFunction)
                .toList());
        }
        return builder.build();
    }

    /** Converts a Spring AI tool call into a DashScope function tool call. */
    private ToolCallFunction toToolCallFunction(AssistantMessage.ToolCall toolCall) {
        ToolCallFunction toolCallFunction = new ToolCallFunction();
        toolCallFunction.setId(toolCall.id());
        // CallFunction is an inner (non-static) class of ToolCallFunction, so it needs an
        // enclosing instance.
        ToolCallFunction.CallFunction callFunction = toolCallFunction.new CallFunction();
        callFunction.setName(toolCall.name());
        callFunction.setArguments(toolCall.arguments());
        toolCallFunction.setFunction(callFunction);
        return toolCallFunction;
    }

    /**
     * Converts a Spring AI message to a DashScope multi-modal message: one part per media item
     * (keyed by the media MIME type's primary type, e.g. {@code image}) followed by a text part.
     */
    private MultiModalMessage toDashscopeMultiModalMessage(final org.springframework.ai.chat.messages.Message message) {
        // Plain maps instead of double-brace anonymous HashMap subclasses, which capture the
        // enclosing model instance and are distinct runtime classes.
        List<Map<String, Object>> parts = new ArrayList<>();
        if (message instanceof UserMessage userMessage) {
            for (Media media : userMessage.getMedia()) {
                Map<String, Object> mediaPart = new HashMap<>();
                mediaPart.put(media.getMimeType().getType(), media.getData());
                parts.add(mediaPart);
            }
        }
        Map<String, Object> textPart = new HashMap<>();
        textPart.put("text", message.getText());
        parts.add(textPart);
        return MultiModalMessage.builder().role(roleFrom(message.getMessageType())).content(parts).build();
    }

    /** Maps a Spring AI message type to a DashScope role; anything unlisted is sent as user. */
    private String roleFrom(MessageType messageType) {
        return switch (messageType) {
            case SYSTEM -> Role.SYSTEM.getValue();
            case ASSISTANT -> Role.ASSISTANT.getValue();
            case TOOL -> Role.TOOL.getValue();
            default -> Role.USER.getValue();
        };
    }

    /** Returns the first String {@code "text"} value among the content parts, or null if none. */
    private static String firstText(List<?> content) {
        if (content == null) {
            return null;
        }
        for (Object part : content) {
            if (part instanceof Map<?, ?> map && map.get("text") instanceof String text) {
                return text;
            }
        }
        return null;
    }

    /** Returns {@code value}, or the empty string when it is null. */
    private static String orEmpty(String value) {
        return value != null ? value : "";
    }
}
