package io.quarkiverse.langchain4j.gemini.common;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.request.ChatRequestParameters;
import dev.langchain4j.model.chat.request.ResponseFormat;
import dev.langchain4j.model.chat.request.ResponseFormatType;
import dev.langchain4j.model.chat.request.json.JsonEnumSchema;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import org.jboss.logging.Logger;

/* loaded from: input_file:io/quarkiverse/langchain4j/gemini/common/GeminiChatLanguageModel.class */
public abstract class GeminiChatLanguageModel implements ChatLanguageModel {
    private static final Logger log = Logger.getLogger(GeminiChatLanguageModel.class);
    private final String modelId;
    private final Double temperature;
    private final Integer maxOutputTokens;
    private final Integer topK;
    private final Double topP;
    private final ResponseFormat responseFormat;
    private final List<ChatModelListener> listeners;

    public GeminiChatLanguageModel(String str, Double d, Integer num, Integer num2, Double d2, ResponseFormat responseFormat, List<ChatModelListener> list) {
        this.modelId = str;
        this.temperature = d;
        this.maxOutputTokens = num;
        this.topK = num2;
        this.topP = d2;
        this.responseFormat = responseFormat;
        this.listeners = list;
    }

    public Set<Capability> supportedCapabilities() {
        HashSet hashSet = new HashSet();
        if (this.responseFormat != null && ResponseFormatType.JSON.equals(this.responseFormat.type())) {
            hashSet.add(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
        } else if (this.responseFormat == null) {
            hashSet.add(Capability.RESPONSE_FORMAT_JSON_SCHEMA);
        }
        return hashSet;
    }

    public ChatResponse chat(ChatRequest chatRequest) {
        ChatRequestParameters parameters = chatRequest.parameters();
        ResponseFormat responseFormat = (ResponseFormat) Utils.getOrDefault(parameters.responseFormat(), this.responseFormat);
        GenerateContentRequest map = ContentMapper.map(chatRequest.messages(), chatRequest.toolSpecifications(), GenerationConfig.builder().maxOutputTokens((Integer) Utils.getOrDefault(parameters.maxOutputTokens(), this.maxOutputTokens)).responseMimeType(computeMimeType(responseFormat)).responseSchema(responseFormat != null ? SchemaMapper.fromJsonSchemaToSchema(responseFormat.jsonSchema()) : null).stopSequences(parameters.stopSequences()).temperature((Double) Utils.getOrDefault(parameters.temperature(), this.temperature)).topK((Integer) Utils.getOrDefault(parameters.topK(), this.topK)).topP((Double) Utils.getOrDefault(parameters.topP(), this.topP)).build());
        ChatModelRequest createModelListenerRequest = createModelListenerRequest(map, chatRequest.messages(), chatRequest.toolSpecifications());
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        ChatModelRequestContext chatModelRequestContext = new ChatModelRequestContext(createModelListenerRequest, concurrentHashMap);
        this.listeners.forEach(chatModelListener -> {
            try {
                chatModelListener.onRequest(chatModelRequestContext);
            } catch (Exception e) {
                log.warn("Exception while calling model listener", e);
            }
        });
        try {
            GenerateContentResponse generateContext = generateContext(map);
            String text = GenerateContentResponseHandler.getText(generateContext);
            List<ToolExecutionRequest> toolExecutionRequests = GenerateContentResponseHandler.getToolExecutionRequests(generateContext);
            AiMessage aiMessage = toolExecutionRequests.isEmpty() ? AiMessage.aiMessage(text) : AiMessage.aiMessage(text, toolExecutionRequests);
            TokenUsage tokenUsage = GenerateContentResponseHandler.getTokenUsage(generateContext.usageMetadata());
            FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(generateContext));
            ChatModelResponseContext chatModelResponseContext = new ChatModelResponseContext(createModelListenerResponse(null, this.modelId, Response.from(aiMessage, tokenUsage)), createModelListenerRequest, concurrentHashMap);
            this.listeners.forEach(chatModelListener2 -> {
                try {
                    chatModelListener2.onResponse(chatModelResponseContext);
                } catch (Exception e) {
                    log.warn("Exception while calling model listener", e);
                }
            });
            return ChatResponse.builder().aiMessage(aiMessage).tokenUsage(GenerateContentResponseHandler.getTokenUsage(generateContext.usageMetadata())).finishReason(FinishReasonMapper.map(GenerateContentResponseHandler.getFinishReason(generateContext))).build();
        } catch (RuntimeException e) {
            ChatModelErrorContext chatModelErrorContext = new ChatModelErrorContext(e, createModelListenerRequest, (ChatModelResponse) null, concurrentHashMap);
            this.listeners.forEach(chatModelListener3 -> {
                try {
                    chatModelListener3.onError(chatModelErrorContext);
                } catch (Exception e2) {
                    log.warn("Exception while calling model listener", e2);
                }
            });
            throw e;
        }
    }

    protected abstract GenerateContentResponse generateContext(GenerateContentRequest generateContentRequest);

    private ChatModelRequest createModelListenerRequest(GenerateContentRequest generateContentRequest, List<ChatMessage> list, List<ToolSpecification> list2) {
        return ChatModelRequest.builder().model(this.modelId).messages(list).toolSpecifications(list2).temperature(this.temperature).topP(this.topP).maxTokens(this.maxOutputTokens).build();
    }

    private ChatModelResponse createModelListenerResponse(String str, String str2, Response<AiMessage> response) {
        if (response == null) {
            return null;
        }
        return ChatModelResponse.builder().id(str).model(str2).tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).aiMessage((AiMessage) response.content()).build();
    }

    private String computeMimeType(ResponseFormat responseFormat) {
        return (responseFormat == null || ResponseFormatType.TEXT.equals(responseFormat.type())) ? "text/plain" : (!ResponseFormatType.JSON.equals(responseFormat.type()) || responseFormat.jsonSchema() == null || responseFormat.jsonSchema().rootElement() == null || !(responseFormat.jsonSchema().rootElement() instanceof JsonEnumSchema)) ? "application/json" : "text/x.enum";
    }

    public ChatResponse doChat(ChatRequest chatRequest) {
        ChatResponse chat = chat(ChatRequest.builder().messages(chatRequest.messages()).toolSpecifications(chatRequest.toolSpecifications()).build());
        return ChatResponse.builder().aiMessage(chat.aiMessage()).tokenUsage(chat.tokenUsage()).finishReason(chat.finishReason()).build();
    }
}
