package dev.langchain4j.rag.query.transformer;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.rag.query.Query;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * A {@link QueryTransformer} that uses a {@link ChatLanguageModel} to rewrite ("compress") a user
 * query together with the preceding conversation into a single, self-contained query suitable for
 * retrieval.
 *
 * <p>The conversation is taken from the query's metadata ({@code query.metadata().chatMemory()}).
 * When there is no metadata or no previous messages, the query is returned unchanged and the model
 * is not called.
 */
public class CompressingQueryTransformer implements QueryTransformer {

    /**
     * Default prompt. It expects two template variables:
     * <ul>
     *   <li>{@code {{chatMemory}}} – the formatted conversation;</li>
     *   <li>{@code {{query}}} – the text of the new user query.</li>
     * </ul>
     */
    public static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from("Read and understand the conversation between the User and the AI. Then, analyze the new query from the User. Identify all relevant details, terms, and context from both the conversation and the new query. Reformulate this query into a clear, concise, and self-contained format suitable for information retrieval.\n\nConversation:\n{{chatMemory}}\n\nUser query: {{query}}\n\nIt is very important that you provide only reformulated query and nothing else! Do not prepend a query with anything!");

    /** Template used to build the compression prompt; never {@code null}. */
    protected final PromptTemplate promptTemplate;

    /** Model that performs the query reformulation; never {@code null}. */
    protected final ChatLanguageModel chatLanguageModel;

    /**
     * Builder for {@link CompressingQueryTransformer}.
     *
     * <p>{@code chatLanguageModel} is required (validated when {@link #build()} is called);
     * {@code promptTemplate} is optional and falls back to {@link #DEFAULT_PROMPT_TEMPLATE}.
     */
    public static class CompressingQueryTransformerBuilder {
        private ChatLanguageModel chatLanguageModel;
        private PromptTemplate promptTemplate;

        CompressingQueryTransformerBuilder() {
        }

        /** Sets the model used to reformulate queries. */
        public CompressingQueryTransformerBuilder chatLanguageModel(ChatLanguageModel chatLanguageModel) {
            this.chatLanguageModel = chatLanguageModel;
            return this;
        }

        /** Sets a custom prompt template; {@code null} means "use the default". */
        public CompressingQueryTransformerBuilder promptTemplate(PromptTemplate promptTemplate) {
            this.promptTemplate = promptTemplate;
            return this;
        }

        /** Creates the transformer, delegating validation to the two-argument constructor. */
        public CompressingQueryTransformer build() {
            return new CompressingQueryTransformer(this.chatLanguageModel, this.promptTemplate);
        }

        @Override
        public String toString() {
            return "CompressingQueryTransformer.CompressingQueryTransformerBuilder(chatLanguageModel=" + String.valueOf(this.chatLanguageModel) + ", promptTemplate=" + String.valueOf(this.promptTemplate) + ")";
        }
    }

    /**
     * Creates a transformer that uses {@link #DEFAULT_PROMPT_TEMPLATE}.
     *
     * @param chatLanguageModel model used to reformulate queries; must not be {@code null}
     */
    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel) {
        this(chatLanguageModel, DEFAULT_PROMPT_TEMPLATE);
    }

    /**
     * Creates a transformer with a custom prompt template.
     *
     * @param chatLanguageModel model used to reformulate queries; rejected by
     *                          {@code ValidationUtils.ensureNotNull} when {@code null}
     * @param promptTemplate    template with {@code {{query}}} and {@code {{chatMemory}}} variables;
     *                          {@code null} falls back to {@link #DEFAULT_PROMPT_TEMPLATE}
     */
    public CompressingQueryTransformer(ChatLanguageModel chatLanguageModel, PromptTemplate promptTemplate) {
        this.chatLanguageModel = ValidationUtils.ensureNotNull(chatLanguageModel, "chatLanguageModel");
        this.promptTemplate = Utils.getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
    }

    /** Returns a new builder. */
    public static CompressingQueryTransformerBuilder builder() {
        return new CompressingQueryTransformerBuilder();
    }

    /**
     * Compresses the query and the preceding conversation into one self-contained query.
     *
     * @param query the query to transform
     * @return a singleton collection holding either the original query (when there is no metadata
     *         or no chat history) or the reformulated query carrying the original metadata
     */
    @Override
    public Collection<Query> transform(Query query) {
        // Without metadata there is no conversation to take context from.
        if (query.metadata() == null) {
            return Collections.singletonList(query);
        }
        List<ChatMessage> chatMemory = query.metadata().chatMemory();
        if (chatMemory == null || chatMemory.isEmpty()) {
            // No previous messages: nothing to compress, so avoid a pointless model call.
            return Collections.singletonList(query);
        }
        Prompt prompt = createPrompt(query, format(chatMemory));
        String compressedQuery = this.chatLanguageModel.generate(prompt.text());
        // Metadata is known to be non-null here, so it is always carried over.
        return Collections.singletonList(Query.from(compressedQuery, query.metadata()));
    }

    /**
     * Renders the conversation as newline-separated lines, one per message.
     * Messages for which {@link #format(ChatMessage)} returns {@code null} are skipped.
     *
     * @param chatMessages the conversation to render
     * @return the rendered conversation (empty string if every message was skipped)
     */
    protected String format(List<ChatMessage> chatMessages) {
        return chatMessages.stream()
                .map(this::format)
                .filter(Objects::nonNull)
                .collect(Collectors.joining("\n"));
    }

    /**
     * Renders a single message for inclusion in the prompt.
     *
     * @param chatMessage the message to render
     * @return {@code "User: <text>"} for a {@link UserMessage}, {@code "AI: <text>"} for an
     *         {@link AiMessage} without tool execution requests, and {@code null} (meaning "skip")
     *         for any other message, including AI messages that request tool execution
     */
    protected String format(ChatMessage chatMessage) {
        if (chatMessage instanceof UserMessage) {
            return "User: " + chatMessage.text();
        }
        if (!(chatMessage instanceof AiMessage)) {
            return null;
        }
        AiMessage aiMessage = (AiMessage) chatMessage;
        if (aiMessage.hasToolExecutionRequests()) {
            // Tool-call turns carry no conversational text useful for reformulation.
            return null;
        }
        return "AI: " + aiMessage.text();
    }

    /**
     * Fills {@link #promptTemplate} with the query text and the formatted conversation.
     *
     * @param query               the query being transformed; its text binds to {@code {{query}}}
     * @param formattedChatMemory the rendered conversation; binds to {@code {{chatMemory}}}
     * @return the prompt to send to the model
     */
    protected Prompt createPrompt(Query query, String formattedChatMemory) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("query", query.text());
        variables.put("chatMemory", formattedChatMemory);
        return this.promptTemplate.apply(variables);
    }
}
