package io.quarkiverse.langchain4j.mcp.runtime.http;

import com.fasterxml.jackson.databind.JsonNode;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.mcp.client.protocol.InitializationNotification;
import dev.langchain4j.mcp.client.protocol.McpClientMessage;
import dev.langchain4j.mcp.client.protocol.McpInitializeRequest;
import dev.langchain4j.mcp.client.transport.McpOperationHandler;
import dev.langchain4j.mcp.client.transport.McpTransport;
import io.quarkiverse.langchain4j.QuarkusJsonCodecFactory;
import io.quarkus.rest.client.reactive.QuarkusRestClientBuilder;
import io.smallrye.mutiny.Uni;
import io.smallrye.mutiny.groups.UniOnFailure;
import io.smallrye.mutiny.subscription.Cancellable;
import java.io.IOException;
import java.net.URI;
import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import org.jboss.logging.Logger;
import org.jboss.resteasy.reactive.client.api.LoggingScope;
import org.jboss.resteasy.reactive.server.jackson.JacksonBasicMessageBodyReader;

/**
 * {@link McpTransport} implementation built on the Quarkus REST client.
 *
 * <p>The transport uses two REST clients: one bound to the configured SSE URL
 * (created in the constructor) and one bound to the POST URL that becomes known
 * only after the SSE channel has been opened (created in
 * {@link #start(McpOperationHandler)}). Outgoing messages are POSTed; responses
 * to operations that carry an id are delivered through the
 * {@link McpOperationHandler} handed to {@code start}.
 */
public class QuarkusHttpMcpTransport implements McpTransport {
    private static final Logger log = Logger.getLogger(QuarkusHttpMcpTransport.class);

    /** Timeout applied when the builder does not specify one. */
    private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60L);

    private final String sseUrl;
    private final McpSseEndpoint sseEndpoint;
    private final Duration timeout;
    private final boolean logResponses;
    private final boolean logRequests;
    private SseSubscriber mcpSseEventListener;
    /** Handle on the open SSE subscription; kept so that {@link #close()} can cancel it. */
    private volatile Cancellable sseSubscription;
    private volatile String postUrl;
    private volatile McpPostEndpoint postEndpoint;
    private volatile McpOperationHandler operationHandler;

    /** Fluent builder for {@link QuarkusHttpMcpTransport}. */
    public static class Builder {
        private String sseUrl;
        private Duration timeout;
        private boolean logRequests = false;
        private boolean logResponses = false;

        /** Sets the URL of the SSE endpoint (mandatory). */
        public Builder sseUrl(String str) {
            this.sseUrl = str;
            return this;
        }

        /** Sets the connect/read timeout; when left {@code null} a 60 second default is used. */
        public Builder timeout(Duration duration) {
            this.timeout = duration;
            return this;
        }

        /** Enables or disables logging of outgoing requests. */
        public Builder logRequests(boolean z) {
            this.logRequests = z;
            return this;
        }

        /** Enables or disables logging of incoming responses. */
        public Builder logResponses(boolean z) {
            this.logResponses = z;
            return this;
        }

        /** Creates the transport from the values collected so far. */
        public QuarkusHttpMcpTransport build() {
            return new QuarkusHttpMcpTransport(this);
        }
    }

    /**
     * Creates the transport and the REST client for the SSE endpoint. No network
     * traffic is initiated here; that happens in {@link #start(McpOperationHandler)}.
     *
     * @param builder source of the configuration; its SSE URL must not be {@code null}
     */
    public QuarkusHttpMcpTransport(Builder builder) {
        this.sseUrl = ValidationUtils.ensureNotNull(builder.sseUrl, "Missing SSE endpoint URL");
        this.timeout = Utils.getOrDefault(builder.timeout, DEFAULT_TIMEOUT);
        this.logRequests = builder.logRequests;
        this.logResponses = builder.logResponses;
        QuarkusRestClientBuilder clientBuilder = newRestClientBuilder(this.sseUrl)
                .loggingScope(LoggingScope.ALL)
                .register(new JacksonBasicMessageBodyReader(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER));
        applyClientLogging(clientBuilder);
        this.sseEndpoint = clientBuilder.build(McpSseEndpoint.class);
    }

    /**
     * Opens the SSE channel, waits for the POST URL and then builds the REST client
     * used for outgoing messages.
     *
     * @param mcpOperationHandler handler that pending operations are registered with
     * @throws RuntimeException if the POST URL is not obtained (failure, timeout or interruption)
     */
    public void start(McpOperationHandler mcpOperationHandler) {
        this.operationHandler = mcpOperationHandler;
        this.mcpSseEventListener = startSseChannel(this.logResponses);
        QuarkusRestClientBuilder clientBuilder = newRestClientBuilder(this.postUrl)
                .register(new JacksonBasicMessageBodyReader(QuarkusJsonCodecFactory.ObjectMapperHolder.MAPPER));
        applyClientLogging(clientBuilder);
        this.postEndpoint = clientBuilder.build(McpPostEndpoint.class);
    }

    /**
     * Sends the initialize request and, once its response has arrived, sends the
     * {@link InitializationNotification}.
     *
     * @return a future completed with the response to the initialize request after
     *         the notification has been posted
     */
    public CompletableFuture<JsonNode> initialize(McpInitializeRequest mcpInitializeRequest) {
        return execute(mcpInitializeRequest, mcpInitializeRequest.getId())
                .onItem().transformToUni(initializeResponse ->
                        execute(new InitializationNotification(), null)
                                .onItem().transform(ignored -> initializeResponse))
                .subscribeAsCompletionStage();
    }

    /** Currently performs no check. */
    public void checkHealth() {
    }

    /**
     * Posts a message and returns a future for the response registered under the message's id.
     */
    public CompletableFuture<JsonNode> executeOperationWithResponse(McpClientMessage mcpClientMessage) {
        return execute(mcpClientMessage, mcpClientMessage.getId()).subscribeAsCompletionStage();
    }

    /**
     * Posts a message for which no response is awaited. A failure is only logged
     * because there is no caller to report it to.
     */
    public void executeOperationWithoutResponse(McpClientMessage mcpClientMessage) {
        execute(mcpClientMessage, null).subscribe().with(
                ignored -> {
                },
                failure -> log.warn("Failed to send a message that does not expect a response", failure));
    }

    /**
     * Posts {@code message} to the POST endpoint.
     *
     * @param message the message to send
     * @param operationId id under which the result future is registered with the operation
     *        handler, or {@code null} when no response is awaited; in that case the returned
     *        {@code Uni} completes with {@code null} as soon as the POST succeeds
     * @return a {@code Uni} backed by the result future
     * @throws IllegalStateException if {@link #start(McpOperationHandler)} has not completed yet
     */
    private Uni<JsonNode> execute(McpClientMessage message, Long operationId) {
        McpPostEndpoint endpoint = this.postEndpoint;
        if (endpoint == null) {
            throw new IllegalStateException("The transport has not been started");
        }
        CompletableFuture<JsonNode> result = new CompletableFuture<>();
        if (operationId != null) {
            this.operationHandler.startOperation(operationId, result);
        }
        endpoint.post(message)
                .onFailure().invoke(result::completeExceptionally)
                .onItem().invoke(response -> {
                    int status = response.getStatus();
                    if (!isExpectedStatusCode(status)) {
                        result.completeExceptionally(new RuntimeException("Unexpected status code: " + status));
                    } else if (operationId == null) {
                        result.complete(null);
                    }
                })
                // Subscribing is what actually triggers the HTTP call; the stage itself is not needed.
                .subscribeAsCompletionStage();
        return Uni.createFrom().completionStage(result);
    }

    /** Returns {@code true} for HTTP status codes in the 2xx range. */
    private boolean isExpectedStatusCode(int i) {
        return i >= 200 && i < 300;
    }

    /**
     * Subscribes to the SSE endpoint and blocks until the POST URL is available.
     * The future handed to the {@link SseSubscriber} is presumably completed by it
     * with the URL announced by the server (the subscriber is not visible here).
     *
     * @param logEvents forwarded to the {@link SseSubscriber}
     * @return the subscriber attached to the SSE stream
     * @throws RuntimeException if the URL does not arrive in time, the stream fails first,
     *         or the waiting thread is interrupted; the subscription is cancelled in all cases
     */
    private SseSubscriber startSseChannel(boolean logEvents) {
        CompletableFuture<String> postUrlFuture = new CompletableFuture<>();
        SseSubscriber subscriber = new SseSubscriber(this.operationHandler, logEvents, postUrlFuture);
        Cancellable subscription = this.sseEndpoint.get().subscribe().with(subscriber, failure -> {
            if (postUrlFuture.isDone()) {
                return;
            }
            log.warn("Failed to connect to the SSE channel, the MCP client will not be used", failure);
            postUrlFuture.completeExceptionally(failure);
        });
        this.sseSubscription = subscription;
        try {
            // A non-positive timeout means "wait (practically) forever".
            long waitMillis = this.timeout.toMillis() > 0 ? this.timeout.toMillis() : Integer.MAX_VALUE;
            this.postUrl = buildAbsolutePostUrl(postUrlFuture.get(waitMillis, TimeUnit.MILLISECONDS));
            log.debugf("Received the server's POST URL: %s", this.postUrl);
            return subscriber;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so that callers further up can observe it.
            Thread.currentThread().interrupt();
            cancelSseSubscription();
            throw new RuntimeException(e);
        } catch (Exception e) {
            cancelSseSubscription();
            throw new RuntimeException(e);
        }
    }

    /** Resolves a (possibly relative) POST URL against the SSE URL. */
    private String buildAbsolutePostUrl(String str) {
        try {
            return URI.create(this.sseUrl).resolve(str).toString();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a REST client builder for {@code baseUrl} with the configured timeout applied
     * to both connect and read. Milliseconds are used so that sub-second timeouts are not
     * truncated to zero.
     */
    private QuarkusRestClientBuilder newRestClientBuilder(String baseUrl) {
        return QuarkusRestClientBuilder.newBuilder()
                .baseUri(URI.create(baseUrl))
                .connectTimeout(this.timeout.toMillis(), TimeUnit.MILLISECONDS)
                .readTimeout(this.timeout.toMillis(), TimeUnit.MILLISECONDS);
    }

    /** Installs the MCP client logger when request or response logging is enabled. */
    private void applyClientLogging(QuarkusRestClientBuilder clientBuilder) {
        if (this.logRequests || this.logResponses) {
            clientBuilder.loggingScope(LoggingScope.REQUEST_RESPONSE);
            clientBuilder.clientLogger(new McpHttpClientLogger(this.logRequests, this.logResponses));
        }
    }

    /** Cancels the SSE subscription if one is open; safe to call repeatedly. */
    private void cancelSseSubscription() {
        Cancellable subscription = this.sseSubscription;
        if (subscription != null) {
            this.sseSubscription = null;
            subscription.cancel();
        }
    }

    /**
     * Cancels the SSE subscription opened by {@link #start(McpOperationHandler)}, if any.
     */
    public void close() throws IOException {
        cancelSseSubscription();
    }
}
