package io.quarkiverse.langchain4j.watsonx;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import io.quarkiverse.langchain4j.watsonx.WatsonxModel;
import io.quarkiverse.langchain4j.watsonx.bean.Parameters;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationRequest;
import io.quarkiverse.langchain4j.watsonx.bean.TextGenerationResponse;
import io.quarkiverse.langchain4j.watsonx.bean.TokenizationRequest;
import io.smallrye.mutiny.Context;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

/* loaded from: input_file:io/quarkiverse/langchain4j/watsonx/WatsonxStreamingChatModel.class */
public class WatsonxStreamingChatModel extends WatsonxModel implements StreamingChatLanguageModel, TokenCountEstimator {
    public WatsonxStreamingChatModel(WatsonxModel.Builder builder) {
        super(builder);
    }

    public void generate(List<ChatMessage> list, final StreamingResponseHandler<AiMessage> streamingResponseHandler) {
        Parameters.LengthPenalty lengthPenalty = null;
        if (Objects.nonNull(this.decayFactor) || Objects.nonNull(this.startIndex)) {
            lengthPenalty = new Parameters.LengthPenalty(this.decayFactor, this.startIndex);
        }
        TextGenerationRequest textGenerationRequest = new TextGenerationRequest(this.modelId, this.projectId, toInput(list), Parameters.builder().decodingMethod(this.decodingMethod).lengthPenalty(lengthPenalty).minNewTokens(this.minNewTokens).maxNewTokens(this.maxNewTokens).randomSeed(this.randomSeed).stopSequences(this.stopSequences).temperature(this.temperature).topP(this.topP).topK(this.topK).repetitionPenalty(this.repetitionPenalty).truncateInputTokens(this.truncateInputTokens).includeStopSequence(this.includeStopSequence).build());
        final Context of = Context.of(new Object[]{"response", new ArrayList()});
        this.client.chatStreaming(textGenerationRequest, this.version).subscribe().with(of, new Consumer<TextGenerationResponse>() { // from class: io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel.1
            @Override // java.util.function.Consumer
            public void accept(TextGenerationResponse textGenerationResponse) {
                if (textGenerationResponse != null) {
                    try {
                        if (textGenerationResponse.results() == null || textGenerationResponse.results().isEmpty()) {
                            return;
                        }
                        ((List) of.get("response")).add(textGenerationResponse);
                        streamingResponseHandler.onNext(textGenerationResponse.results().get(0).generatedText());
                    } catch (Exception e) {
                        streamingResponseHandler.onError(e);
                    }
                }
            }
        }, new Consumer<Throwable>() { // from class: io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel.2
            @Override // java.util.function.Consumer
            public void accept(Throwable th) {
                streamingResponseHandler.onError(th);
            }
        }, new Runnable() { // from class: io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel.3
            @Override // java.lang.Runnable
            public void run() {
                List list2 = (List) of.get("response");
                int i = 0;
                int i2 = 0;
                String str = null;
                StringBuilder sb = new StringBuilder();
                for (int i3 = 0; i3 < list2.size(); i3++) {
                    TextGenerationResponse.Result result = ((TextGenerationResponse) list2.get(i3)).results().get(0);
                    if (i3 == 0) {
                        i = result.inputTokenCount();
                    }
                    if (i3 == list2.size() - 1) {
                        i2 = result.generatedTokenCount();
                        str = result.stopReason();
                    }
                    sb.append(result.generatedText());
                }
                streamingResponseHandler.onComplete(Response.from(new AiMessage(sb.toString()), new TokenUsage(Integer.valueOf(i), Integer.valueOf(i2)), WatsonxStreamingChatModel.this.toFinishReason(str)));
            }
        });
    }

    public int estimateTokenCount(List<ChatMessage> list) {
        final TokenizationRequest tokenizationRequest = new TokenizationRequest(this.modelId, toInput(list), this.projectId);
        return ((Integer) retryOn(new Callable<Integer>() { // from class: io.quarkiverse.langchain4j.watsonx.WatsonxStreamingChatModel.4
            /* JADX WARN: Can't rename method to resolve collision */
            @Override // java.util.concurrent.Callable
            public Integer call() throws Exception {
                return Integer.valueOf(WatsonxStreamingChatModel.this.client.tokenization(tokenizationRequest, WatsonxStreamingChatModel.this.version).result().tokenCount());
            }
        })).intValue();
    }
}
