package io.quarkiverse.mcp.server.sse.runtime;

import io.quarkiverse.mcp.server.CompletionManager;
import io.quarkiverse.mcp.server.CompletionResponse;
import io.quarkiverse.mcp.server.Notification;
import io.quarkiverse.mcp.server.NotificationManager;
import io.quarkiverse.mcp.server.PromptManager;
import io.quarkiverse.mcp.server.ResourceManager;
import io.quarkiverse.mcp.server.ResourceTemplateManager;
import io.quarkiverse.mcp.server.ToolManager;
import io.quarkiverse.mcp.server.runtime.ConnectionManager;
import io.quarkiverse.mcp.server.runtime.ContextSupport;
import io.quarkiverse.mcp.server.runtime.FeatureArgument;
import io.quarkiverse.mcp.server.runtime.FeatureMetadata;
import io.quarkiverse.mcp.server.runtime.McpConnectionBase;
import io.quarkiverse.mcp.server.runtime.McpMessageHandler;
import io.quarkiverse.mcp.server.runtime.McpMetadata;
import io.quarkiverse.mcp.server.runtime.McpRequest;
import io.quarkiverse.mcp.server.runtime.McpRequestImpl;
import io.quarkiverse.mcp.server.runtime.Messages;
import io.quarkiverse.mcp.server.runtime.NotificationManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.PromptManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateCompletionManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResourceTemplateManagerImpl;
import io.quarkiverse.mcp.server.runtime.ResponseHandlers;
import io.quarkiverse.mcp.server.runtime.SecuritySupport;
import io.quarkiverse.mcp.server.runtime.Sender;
import io.quarkiverse.mcp.server.runtime.ToolManagerImpl;
import io.quarkiverse.mcp.server.runtime.TrafficLogger;
import io.quarkiverse.mcp.server.runtime.config.McpRuntimeConfig;
import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.IdentityProviderManager;
import io.quarkus.vertx.http.runtime.CurrentVertxRequest;
import io.quarkus.vertx.http.runtime.security.QuarkusHttpUser;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.web.RoutingContext;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Singleton;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.runtime.ObjectMethods;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import org.jboss.logging.Logger;

@Singleton
/* loaded from: input_file:io/quarkiverse/mcp/server/sse/runtime/StreamableHttpMcpMessageHandler.class */
public class StreamableHttpMcpMessageHandler extends McpMessageHandler<HttpMcpRequest> implements Handler<RoutingContext> {
    /** Name of the HTTP header that carries the MCP session (connection) id. */
    public static final String MCP_SESSION_ID_HEADER = "Mcp-Session-Id";
    // Feature metadata consulted when deciding whether a message requires an SSE response.
    private final McpMetadata metadata;
    // Receives the current routing context when the request context is activated.
    private final CurrentVertxRequest currentVertxRequest;
    // May be null: only set when the injected Instance is resolvable (see the constructor).
    private final CurrentIdentityAssociation currentIdentityAssociation;
    private static final Logger LOG = Logger.getLogger(StreamableHttpMcpMessageHandler.class);
    // JSON-RPC request methods whose params are inspected to decide whether SSE must be initiated.
    private static final Set<String> FORCE_SSE_REQUESTS = Set.of("tools/call", "prompts/get", "resources/read", "completion/complete");
    // JSON-RPC notification methods that are checked against the registered notification features.
    private static final Set<String> FORCE_SSE_NOTIFICATIONS = Set.of("notifications/initialized", "notifications/roots/list_changed");
    // A feature that declares an argument backed by one of these providers forces SSE initiation.
    private static final Set<FeatureArgument.Provider> FORCE_SSE_PROVIDERS = Set.of(FeatureArgument.Provider.PROGRESS, FeatureArgument.Provider.MCP_LOG, FeatureArgument.Provider.SAMPLING, FeatureArgument.Provider.ROOTS);

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/quarkiverse/mcp/server/sse/runtime/StreamableHttpMcpMessageHandler$HttpMcpRequest.class */
    /**
     * An MCP request bound to a single HTTP exchange. The request acts as its own {@link Sender}:
     * messages are written to the HTTP response either as SSE events (once SSE has been initiated)
     * or as one plain JSON body that ends the response.
     */
    public static class HttpMcpRequest extends McpRequestImpl implements Sender {
        /** {@code true} if the HTTP request did not carry a session id header. */
        final boolean newSession;
        /** Flipped to {@code true} exactly once, when the response is switched to an SSE stream. */
        final AtomicBoolean sse;
        /** The HTTP response all outgoing messages are written to. */
        final HttpServerResponse response;

        /**
         * @param obj the decoded JSON payload of the HTTP request
         * @param mcpConnectionBase the connection (session) this request belongs to
         * @param securitySupport hook used to set the current security identity
         * @param httpServerResponse the response of the HTTP exchange
         * @param z whether this request opens a new session
         * @param contextSupport hook invoked when the request context is activated
         * @param currentIdentityAssociation identity association handed to the parent class, may be {@code null}
         */
        public HttpMcpRequest(Object obj, McpConnectionBase mcpConnectionBase, SecuritySupport securitySupport, HttpServerResponse httpServerResponse, boolean z, ContextSupport contextSupport, CurrentIdentityAssociation currentIdentityAssociation) {
            // No external sender is passed to the parent; sender() returns this instance instead.
            super(obj, mcpConnectionBase, (Sender) null, securitySupport, contextSupport, currentIdentityAssociation);
            this.response = httpServerResponse;
            this.sse = new AtomicBoolean();
            this.newSession = z;
        }

        /**
         * @return this request, which writes directly to the HTTP response
         */
        public Sender sender() {
            return this;
        }

        /**
         * Switches the response to a chunked {@code text/event-stream}.
         *
         * @return {@code true} if this call performed the switch, {@code false} if SSE was already initiated
         */
        boolean initiateSse() {
            boolean switched = this.sse.compareAndSet(false, true);
            if (switched) {
                this.response.setChunked(true);
                this.response.headers().add(HttpHeaders.CONTENT_TYPE, "text/event-stream");
            }
            return switched;
        }

        /**
         * Sends a message to the client.
         *
         * @param jsonObject the message, {@code null} is ignored
         * @return the future of the underlying write; already succeeded for a {@code null} message
         */
        public Future<Void> send(JsonObject jsonObject) {
            if (jsonObject == null) {
                return Future.succeededFuture();
            }
            messageSent(jsonObject);
            if (!this.sse.get()) {
                // Plain mode: the message is the whole response body.
                this.response.putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
                return this.response.end(jsonObject.toBuffer());
            }
            // SSE mode: one "message" event per JSON message, the stream stays open.
            String event = "event: message\ndata: " + jsonObject.encode() + "\n\n";
            return this.response.write(event);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/quarkiverse/mcp/server/sse/runtime/StreamableHttpMcpMessageHandler$ScanResult.class */
    /**
     * Result of scanning an incoming payload before it is processed.
     *
     * @param forceSseInit whether SSE has to be initiated on the response before the payload is handled
     * @param containsRequest whether the payload contains at least one JSON-RPC request; used to choose
     *        between a plain end of the response and a {@code 202} status once handling completed
     */
    public record ScanResult(boolean forceSseInit, boolean containsRequest) {
    }

    /**
     * Creates the handler. All managers, the configuration, the metadata and the Vert.x instance are
     * passed on to the parent class; the metadata and the current-request holder are also kept locally.
     *
     * @param instance lazy handle for the identity association; when it is not resolvable the handler
     *        keeps {@code null} instead
     */
    StreamableHttpMcpMessageHandler(McpRuntimeConfig mcpRuntimeConfig, ConnectionManager connectionManager, PromptManagerImpl promptManagerImpl, ToolManagerImpl toolManagerImpl, ResourceManagerImpl resourceManagerImpl, PromptCompletionManagerImpl promptCompletionManagerImpl, ResourceTemplateManagerImpl resourceTemplateManagerImpl, ResourceTemplateCompletionManagerImpl resourceTemplateCompletionManagerImpl, NotificationManagerImpl notificationManagerImpl, ResponseHandlers responseHandlers, CurrentVertxRequest currentVertxRequest, Instance<CurrentIdentityAssociation> instance, McpMetadata mcpMetadata, Vertx vertx) {
        super(mcpRuntimeConfig, connectionManager, promptManagerImpl, toolManagerImpl, resourceManagerImpl, promptCompletionManagerImpl, resourceTemplateManagerImpl, resourceTemplateCompletionManagerImpl, notificationManagerImpl, responseHandlers, mcpMetadata, vertx);
        this.metadata = mcpMetadata;
        this.currentVertxRequest = currentVertxRequest;
        if (instance.isResolvable()) {
            this.currentIdentityAssociation = instance.get();
        } else {
            this.currentIdentityAssociation = null;
        }
    }

    /**
     * Handles an HTTP request that carries one MCP message or a batch of messages.
     * <ol>
     * <li>The {@code Accept} header must mention both {@code application/json} and
     * {@code text/event-stream}; otherwise the request fails with 400.</li>
     * <li>Without a session id header a new connection is created and registered; with one the
     * connection is looked up and the request fails with 404 if it does not exist.</li>
     * <li>The body is decoded, scanned to find out whether SSE must be initiated up front, and then
     * passed to the message handling of the parent class.</li>
     * </ol>
     * If anything in step 3 throws, a JSON-RPC parse error ({@code -32700}) is sent back and a
     * connection that was registered for a new session by this very request is removed again, so that
     * failed first requests do not leave orphaned connections behind.
     *
     * @param routingContext the routing context of the HTTP request
     */
    @Override
    public void handle(final RoutingContext routingContext) {
        HttpServerRequest request = routingContext.request();
        List<String> acceptValues = request.headers().getAll(HttpHeaders.ACCEPT);
        if (!accepts(acceptValues, "application/json") || !accepts(acceptValues, "text/event-stream")) {
            LOG.errorf("Invalid Accept header: %s", acceptValues);
            routingContext.fail(400);
            return;
        }

        final String sessionId = request.getHeader(MCP_SESSION_ID_HEADER);
        final boolean newSession = sessionId == null;
        final McpConnectionBase connection;
        if (newSession) {
            String connectionId = ConnectionManager.connectionId();
            LOG.debugf("Streamable connection initialized [%s]", connectionId);
            TrafficLogger trafficLogger = this.config.trafficLogging().enabled() ? new TrafficLogger(this.config.trafficLogging().textLimit()) : null;
            connection = new StreamableHttpMcpConnection(connectionId, this.config.clientLogging().defaultLevel(), trafficLogger, this.config.autoPingInterval());
            this.connectionManager.add(connection);
        } else {
            connection = this.connectionManager.get(sessionId);
        }
        if (connection == null) {
            LOG.errorf("Mcp session not found: %s", sessionId);
            routingContext.fail(404);
            return;
        }

        try {
            Object payload = Json.decodeValue(routingContext.body().buffer());
            // The authenticated user, if any, is expected to be a Quarkus HTTP user.
            final QuarkusHttpUser user = (QuarkusHttpUser) routingContext.user();
            SecuritySupport securitySupport = new SecuritySupport() {
                public void setCurrentIdentity(CurrentIdentityAssociation currentIdentityAssociation) {
                    if (user != null) {
                        currentIdentityAssociation.setIdentity(user.getSecurityIdentity());
                    } else {
                        currentIdentityAssociation.setIdentity(QuarkusHttpUser.getSecurityIdentity(routingContext, (IdentityProviderManager) null));
                    }
                }
            };
            ContextSupport contextSupport = new ContextSupport() {
                public void requestContextActivated() {
                    StreamableHttpMcpMessageHandler.this.currentVertxRequest.setCurrent(routingContext);
                }
            };
            HttpMcpRequest httpMcpRequest = new HttpMcpRequest(payload, connection, securitySupport, routingContext.response(), newSession, contextSupport, this.currentIdentityAssociation);
            ScanResult scan = scan(httpMcpRequest);
            if (scan.forceSseInit()) {
                httpMcpRequest.initiateSse();
            }
            handle((McpRequest) httpMcpRequest).onComplete(asyncResult -> {
                HttpServerResponse response = routingContext.response();
                if (!asyncResult.succeeded()) {
                    if (!response.ended()) {
                        response.setStatusCode(500).end();
                    }
                    return;
                }
                if (httpMcpRequest.sse.get()) {
                    // All events were written; close the SSE stream.
                    response.end();
                    return;
                }
                if (response.ended()) {
                    // A plain JSON response was already sent by HttpMcpRequest.send().
                    return;
                }
                if (scan.containsRequest()) {
                    routingContext.end();
                } else {
                    // Only notifications/responses were received: acknowledge without a body.
                    response.setStatusCode(202).end();
                }
            });
        } catch (Exception e) {
            if (newSession) {
                // The connection was registered by this request only; do not leak it on failure.
                this.connectionManager.remove(connection.id());
            }
            LOG.errorf(e, "Unable to parse the JSON message");
            routingContext.response().putHeader(HttpHeaders.CONTENT_TYPE, "application/json");
            routingContext.end(Messages.newError((Object) null, -32700, "Unable to parse the JSON message").toBuffer());
        }
    }

    /**
     * Terminates the session identified by the session id header of the request.
     * Fails the request with 404 if the header is missing or no such session exists; otherwise the
     * connection is removed from the connection manager and the response is ended.
     *
     * @param routingContext the routing context of the termination request
     */
    public void terminateSession(RoutingContext routingContext) {
        String sessionId = routingContext.request().getHeader(MCP_SESSION_ID_HEADER);
        if (sessionId == null) {
            LOG.errorf("Mcp session id header is missing: %s", routingContext.normalizedPath());
            routingContext.fail(404);
            return;
        }
        McpConnectionBase connection = this.connectionManager.get(sessionId);
        if (connection == null) {
            LOG.errorf("Mcp session not found: %s", sessionId);
            routingContext.fail(404);
            return;
        }
        boolean removed = this.connectionManager.remove(connection.id());
        if (removed) {
            LOG.infof("Mcp session terminated: %s", connection.id());
        }
        routingContext.end();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Adds the session id header, holding the id of the request's connection, to the HTTP response.
     *
     * @param httpMcpRequest the request being processed
     */
    public void afterInitialize(HttpMcpRequest httpMcpRequest) {
        String sessionId = httpMcpRequest.connection().id();
        httpMcpRequest.response.headers().add(MCP_SESSION_ID_HEADER, sessionId);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Removes the connection of the given request from the connection manager.
     *
     * @param httpMcpRequest the request being processed
     */
    public void initializeFailed(HttpMcpRequest httpMcpRequest) {
        String connectionId = httpMcpRequest.connection().id();
        this.connectionManager.remove(connectionId);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Removes the connection of the given request, but only if the request opened a new session;
     * connections of already existing sessions are left untouched.
     *
     * @param httpMcpRequest the request being processed
     */
    public void jsonrpcValidationFailed(HttpMcpRequest httpMcpRequest) {
        if (!httpMcpRequest.newSession) {
            return;
        }
        String connectionId = httpMcpRequest.connection().id();
        this.connectionManager.remove(connectionId);
    }

    /**
     * Tells whether any of the given {@code Accept} header values mentions the given media type.
     * The comparison ignores case because media type names are case-insensitive, so a value such as
     * {@code Application/JSON} is accepted as well.
     *
     * @param list the raw {@code Accept} header values of the request
     * @param str the media type to look for, e.g. {@code application/json}
     * @return {@code true} if at least one header value contains the media type
     */
    private boolean accepts(List<String> list, String str) {
        for (String headerValue : list) {
            if (containsIgnoreCase(headerValue, str)) {
                return true;
            }
        }
        return false;
    }

    /** Case-insensitive counterpart of {@link String#contains(CharSequence)}. */
    private static boolean containsIgnoreCase(String text, String fragment) {
        int fragmentLength = fragment.length();
        int lastStart = text.length() - fragmentLength;
        for (int start = 0; start <= lastStart; start++) {
            if (text.regionMatches(true, start, fragment, 0, fragmentLength)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Inspects the payload of the request before it is processed.
     * <ul>
     * <li>Single message: SSE is forced if {@code forceSse} says so for that message.</li>
     * <li>Batch whose first element is not a response: SSE is forced if the batch has more than one
     * element or the first element forces it; the batch contains a request if any element is one.</li>
     * <li>Anything else (including a batch starting with a response): nothing is forced and no
     * request is reported.</li>
     * </ul>
     *
     * @param httpMcpRequest the request whose JSON payload is scanned
     * @return the scan outcome
     */
    private ScanResult scan(HttpMcpRequest httpMcpRequest) {
        if (httpMcpRequest.json() instanceof JsonObject message) {
            return new ScanResult(forceSse(httpMcpRequest, message), Messages.isRequest(message));
        }
        if (!(httpMcpRequest.json() instanceof JsonArray batch)) {
            return new ScanResult(false, false);
        }
        if (Messages.isResponse(batch.getJsonObject(0))) {
            return new ScanResult(false, false);
        }
        boolean forceSseInit = batch.size() > 1 || forceSse(httpMcpRequest, batch.getJsonObject(0));
        boolean containsRequest = false;
        for (Object element : batch) {
            if (element instanceof JsonObject candidate && Messages.isRequest(candidate)) {
                containsRequest = true;
                break;
            }
        }
        return new ScanResult(forceSseInit, containsRequest);
    }

    /**
     * Decides whether the given message requires SSE to be initiated before it is handled.
     * <ul>
     * <li>A request whose method is one of {@code FORCE_SSE_REQUESTS} is delegated, together with its
     * {@code params}, to the feature-specific check; without {@code params} SSE is not forced.</li>
     * <li>A notification whose method is one of {@code FORCE_SSE_NOTIFICATIONS} is delegated to
     * {@code forceSseNotification}.</li>
     * <li>Any other message, or a message without a {@code method}, does not force SSE.</li>
     * </ul>
     *
     * @param httpMcpRequest the enclosing request (not used by the decision itself)
     * @param jsonObject the JSON-RPC message to inspect
     * @return {@code true} if SSE must be initiated
     * @throws IllegalArgumentException if a method is listed in {@code FORCE_SSE_REQUESTS} but has no
     *         dedicated check here
     */
    private boolean forceSse(HttpMcpRequest httpMcpRequest, JsonObject jsonObject) {
        String method = jsonObject.getString("method");
        if (method == null) {
            return false;
        }
        if (Messages.isRequest(jsonObject) && FORCE_SSE_REQUESTS.contains(method)) {
            JsonObject params = jsonObject.getJsonObject("params");
            if (params == null) {
                return false;
            }
            return switch (method) {
                case "tools/call" -> forceSseTool(params);
                case "prompts/get" -> forceSsePrompt(params);
                case "resources/read" -> forceSseResource(params);
                case "completion/complete" -> forceSseCompletion(params);
                // Only reachable if FORCE_SSE_REQUESTS gains an entry without a matching case.
                default -> throw new IllegalArgumentException("Unexpected value: " + method);
            };
        }
        return Messages.isNotification(jsonObject) && FORCE_SSE_NOTIFICATIONS.contains(method) && forceSseNotification(method);
    }

    /**
     * Decides whether calling the tool named in the params requires SSE.
     * For a tool found in the metadata, SSE is required if any of its arguments uses one of the
     * {@code FORCE_SSE_PROVIDERS}. Otherwise the tool manager is consulted and SSE is required if the
     * tool exists there and is not backed by a method.
     *
     * @param jsonObject the {@code params} of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseTool(JsonObject jsonObject) {
        String toolName = jsonObject.getString("name");
        var declared = this.metadata.tools().stream()
                .filter(candidate -> candidate.info().name().equals(toolName))
                .findFirst();
        if (declared.isPresent()) {
            return declared.get().info().arguments().stream()
                    .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
        }
        ToolManager.ToolInfo registered = this.toolManager.getTool(toolName);
        return registered != null && !registered.isMethod();
    }

    /**
     * Decides whether getting the prompt named in the params requires SSE.
     * For a prompt found in the metadata, SSE is required if any of its arguments uses one of the
     * {@code FORCE_SSE_PROVIDERS}. Otherwise the prompt manager is consulted and SSE is required if the
     * prompt exists there and is not backed by a method.
     *
     * @param jsonObject the {@code params} of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSsePrompt(JsonObject jsonObject) {
        String promptName = jsonObject.getString("name");
        var declared = this.metadata.prompts().stream()
                .filter(candidate -> candidate.info().name().equals(promptName))
                .findFirst();
        if (declared.isPresent()) {
            return declared.get().info().arguments().stream()
                    .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
        }
        PromptManager.PromptInfo registered = this.promptManager.getPrompt(promptName);
        return registered != null && !registered.isMethod();
    }

    /**
     * Decides whether reading the resource identified by the {@code uri} in the params requires SSE.
     * Lookup order: resources in the metadata (SSE if any argument uses one of the
     * {@code FORCE_SSE_PROVIDERS}), then the resource manager, then the first matching resource
     * template; for the latter two SSE is required if the entry is not backed by a method.
     *
     * @param jsonObject the {@code params} of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseResource(JsonObject jsonObject) {
        String resourceUri = jsonObject.getString("uri");
        var declared = this.metadata.resources().stream()
                .filter(candidate -> candidate.info().uri().equals(resourceUri))
                .findFirst();
        if (declared.isPresent()) {
            return declared.get().info().arguments().stream()
                    .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
        }
        ResourceManager.ResourceInfo registered = this.resourceManager.getResource(resourceUri);
        if (registered != null) {
            return !registered.isMethod();
        }
        ResourceTemplateManager.ResourceTemplateInfo template = this.resourceTemplateManager.findMatching(resourceUri);
        return template != null && !template.isMethod();
    }

    /**
     * Decides whether a completion request requires SSE. The params must contain a {@code ref} with a
     * {@code name} and an {@code argument} with a {@code name}; the {@code type} of the reference
     * selects between prompt completions ({@code ref/prompt}) and resource template completions
     * ({@code ref/resource}). Any other shape does not force SSE.
     *
     * @param jsonObject the {@code params} of the request
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseCompletion(JsonObject jsonObject) {
        JsonObject reference = jsonObject.getJsonObject("ref");
        if (reference == null) {
            return false;
        }
        String referenceType = reference.getString("type");
        String referenceName = reference.getString("name");
        JsonObject argument = jsonObject.getJsonObject("argument");
        String argumentName = null;
        if (argument != null) {
            argumentName = argument.getString("name");
        }
        if (referenceName == null || argumentName == null) {
            return false;
        }
        if ("ref/prompt".equals(referenceType)) {
            return forceSseCompletion(referenceName, argumentName, this.metadata.promptCompletions(), this.promptCompletionManager);
        }
        if ("ref/resource".equals(referenceType)) {
            return forceSseCompletion(referenceName, argumentName, this.metadata.resourceTemplateCompletions(), this.resourceTemplateCompletionManager);
        }
        return false;
    }

    /**
     * Decides whether the completion for the given reference name and argument name requires SSE.
     * A completion from the metadata matches if its name equals {@code str} and the name of its first
     * parameter argument equals {@code str2}; it requires SSE if any of its arguments uses one of the
     * {@code FORCE_SSE_PROVIDERS}. Without a match the completion manager is consulted and SSE is
     * required if the completion exists there and is not backed by a method.
     *
     * @param str the name of the referenced prompt or resource template
     * @param str2 the name of the argument being completed
     * @param list the completions known from the metadata
     * @param completionManager the manager to fall back to
     * @return {@code true} if SSE must be initiated
     * @throws java.util.NoSuchElementException if a metadata completion with a matching name declares
     *         no parameter argument
     */
    private boolean forceSseCompletion(String str, String str2, List<FeatureMetadata<CompletionResponse>> list, CompletionManager completionManager) {
        FeatureMetadata<CompletionResponse> declared = null;
        for (FeatureMetadata<CompletionResponse> candidate : list) {
            if (!candidate.info().name().equals(str)) {
                continue;
            }
            FeatureArgument completedArgument = candidate.info().arguments().stream()
                    .filter(FeatureArgument::isParam)
                    .findFirst()
                    .orElseThrow();
            if (str2.equals(completedArgument.name())) {
                declared = candidate;
                break;
            }
        }
        if (declared != null) {
            return declared.info().arguments().stream()
                    .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider()));
        }
        CompletionManager.CompletionInfo registered = completionManager.getCompletion(str, str2);
        return registered != null && !registered.isMethod();
    }

    /**
     * Decides whether handling the given notification requires SSE.
     * First, all notification features from the metadata whose type (stored in the description of the
     * feature info) matches the notification are collected; SSE is required if any of their arguments
     * uses one of the {@code FORCE_SSE_PROVIDERS}. Otherwise SSE is required if the notification
     * manager holds a notification of the matching type that is not backed by a method.
     *
     * @param str the JSON-RPC method of the notification
     * @return {@code true} if SSE must be initiated
     */
    private boolean forceSseNotification(String str) {
        // Collected eagerly so that every metadata entry is evaluated before any argument is checked.
        var declared = this.metadata.notifications().stream()
                .filter(candidate -> Notification.Type.valueOf(candidate.info().description()) == Notification.Type.from(str))
                .toList();
        boolean declaredNeedsSse = declared.stream()
                .anyMatch(notification -> notification.info().arguments().stream()
                        .anyMatch(argument -> FORCE_SSE_PROVIDERS.contains(argument.provider())));
        if (declaredNeedsSse) {
            return true;
        }
        for (NotificationManager.NotificationInfo registered : this.notificationManager) {
            if (!registered.isMethod() && registered.type() == Notification.Type.from(str)) {
                return true;
            }
        }
        return false;
    }
}
