package io.quarkiverse.langchain4j.bam;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkiverse.langchain4j.bam.BamModel;
import io.quarkiverse.langchain4j.bam.TextGenerationResponse;
import io.smallrye.mutiny.Context;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/* loaded from: input_file:io/quarkiverse/langchain4j/bam/BamStreamingChatModel.class */
public class BamStreamingChatModel extends BamModel implements StreamingChatLanguageModel, TokenCountEstimator {
    private final ObjectMapper mapper;

    public BamStreamingChatModel(BamModel.Builder builder) {
        super(builder);
        this.mapper = QuarkusJsonCodecFactory.SnakeCaseObjectMapperHolder.MAPPER;
    }

    public void generate(List<ChatMessage> list, final StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        TextGenerationRequest textGenerationRequest = new TextGenerationRequest(this.modelId, toInput(list), Parameters.builder().decodingMethod(this.decodingMethod).includeStopSequence(this.includeStopSequence).minNewTokens(this.minNewTokens).maxNewTokens(this.maxNewTokens).randomSeed(this.randomSeed).stopSequences(this.stopSequences).temperature(this.temperature).timeLimit(this.timeLimit).topP(this.topP).topK(this.topK).typicalP(this.typicalP).repetitionPenalty(this.repetitionPenalty).truncateInputTokens(this.truncateInputTokens).beamWidth(this.beamWidth).build());
        final Context of = Context.of(new Object[]{"response", new ArrayList()});
        this.client.chatStreaming(textGenerationRequest, this.token, this.version).subscribe().with(of, new Consumer<String>() { // from class: io.quarkiverse.langchain4j.bam.BamStreamingChatModel.1
            @Override // java.util.function.Consumer
            public void accept(String str) {
                if (str != null) {
                    try {
                        if (str.isBlank()) {
                            return;
                        }
                        TextGenerationResponse textGenerationResponse = (TextGenerationResponse) BamStreamingChatModel.this.mapper.readValue(str, TextGenerationResponse.class);
                        ((List) of.get("response")).add(textGenerationResponse);
                        streamingResponseHandler.onNext(textGenerationResponse.results().get(0).generatedText());
                    } catch (Exception e) {
                        streamingResponseHandler.onError(e);
                    }
                }
            }
        }, new Consumer<Throwable>() { // from class: io.quarkiverse.langchain4j.bam.BamStreamingChatModel.2
            @Override // java.util.function.Consumer
            public void accept(Throwable th) {
                streamingResponseHandler.onError(th);
            }
        }, new Runnable() { // from class: io.quarkiverse.langchain4j.bam.BamStreamingChatModel.3
            @Override // java.lang.Runnable
            public void run() {
                List list2 = (List) of.get("response");
                int i = 0;
                int i2 = 0;
                String str = null;
                StringBuilder sb = new StringBuilder();
                for (int i3 = 0; i3 < list2.size(); i3++) {
                    TextGenerationResponse.Results results = ((TextGenerationResponse) list2.get(i3)).results().get(0);
                    if (i3 == 0) {
                        i = results.inputTokenCount();
                    }
                    if (i3 == list2.size() - 1) {
                        i2 = results.generatedTokenCount();
                        str = results.stopReason();
                    }
                    sb.append(results.generatedText());
                }
                streamingResponseHandler.onComplete(Response.from(new AiMessage(sb.toString()), new TokenUsage(Integer.valueOf(i), Integer.valueOf(i2)), BamStreamingChatModel.this.toFinishReason(str)));
            }
        });
    }

    public int estimateTokenCount(List<ChatMessage> list) {
        return this.client.tokenization(new TokenizationRequest(this.modelId, (String) list.stream().map((v0) -> {
            return v0.text();
        }).collect(Collectors.joining(" "))), this.token, this.version).results().get(0).tokenCount();
    }
}
