package io.quarkiverse.mcp.server.proxy;

import io.quarkiverse.mcp.server.sse.client.SseClient;
import io.quarkus.runtime.Quarkus;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Phaser;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.jboss.logging.Logger;
import picocli.CommandLine;

/**
 * Command that bridges stdio to an SSE endpoint.
 *
 * <p>Non-blank lines read from stdin are queued and POSTed to the message endpoint announced by the
 * server through an {@code endpoint} SSE event. The data of every other SSE event is printed to the
 * original stdout. {@link System#out} is redirected to a null stream so that nothing else written to
 * it ends up on the stdio channel.
 */
@CommandLine.Command(name = "stdio-sse-proxy", mixinStandardHelpOptions = true)
public class StdioSseProxy implements Runnable {

    private static final Logger LOG = Logger.getLogger(StdioSseProxy.class);

    /** The SSE endpoint to connect to; relative message endpoints are resolved against it. */
    @CommandLine.Parameters(description = {"The URI of the target SSE endpoint"}, defaultValue = "http://localhost:8080/mcp/sse")
    URI sseEndpoint;

    /** Timeout in seconds for connecting and for waiting on the message endpoint. */
    @CommandLine.Option(names = {"-t", "--timeout"}, defaultValue = "10", description = {"The timeout in seconds; used when connecting to the SSE endpoint and to obtain the message endpoint"})
    int timeout;

    /** Pause in milliseconds between two polls of the stdin queue. */
    @CommandLine.Option(names = {"-s", "--sleep"}, defaultValue = "60", description = {"The sleep time in milliseconds; used when processing the stdin queue"})
    int sleep;

    /** Whether a 400 response from the message endpoint triggers an SSE reconnect and one resend. */
    @CommandLine.Option(names = {"--reconnect"}, negatable = true, defaultValue = "true", description = {"If set to true then the proxy attempts to reconnect if a message endpoint returns http status 400"})
    boolean reconnect;

    /**
     * Lets threads wait until a message endpoint has been received from the SSE stream.
     *
     * <p>Backed by a single-party {@link Phaser}: each {@link #countDown()} advances the phaser by one
     * phase. {@link #await(long, TimeUnit)} returns as soon as the phaser's current phase differs
     * from the phase recorded by the last {@link #reset()} (0 initially).
     */
    public class EndpointPhaser {

        private final Phaser phaser = new Phaser(1);
        private final AtomicInteger phase = new AtomicInteger(0);

        EndpointPhaser() {
        }

        /** Signals that a message endpoint has been received. */
        public void countDown() {
            phaser.arrive();
        }

        /**
         * Waits until an endpoint has been received since the last {@link #reset()}.
         *
         * @param timeout maximum time to wait
         * @param unit unit of {@code timeout}
         * @throws InterruptedException if the calling thread is interrupted while waiting
         * @throws TimeoutException if no endpoint arrives within the timeout
         */
        public void await(long timeout, TimeUnit unit) throws InterruptedException, TimeoutException {
            phaser.awaitAdvanceInterruptibly(phase.get(), timeout, unit);
        }

        /**
         * Makes subsequent {@link #await} calls block until the next {@link #countDown()}.
         *
         * <p>The phaser's current phase is recorded instead of incrementing a counter. This keeps the
         * recorded phase in step with the phaser even when an endpoint arrived several times, or not
         * at all, since the previous reset. Otherwise {@code await} could stop waiting entirely.
         */
        public void reset() {
            phase.set(phaser.getPhase());
        }
    }

    @Override
    public void run() {
        LOG.infof("Stdio -> SSE [sse: %s, timeout: %s, reconnect: %s, sleep: %s]", sseEndpoint, timeout, reconnect, sleep);

        // Keep the real stdout for forwarding SSE data; silence anything else written to System.out.
        InputStream stdin = System.in;
        PrintStream realStdout = System.out;
        System.setOut(new PrintStream(OutputStream.nullOutputStream()));

        Queue<String> queue = new ConcurrentLinkedQueue<>();
        ExecutorService executor = Executors.newFixedThreadPool(2);
        HttpClient httpClient = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(timeout)).build();
        AtomicReference<URI> messageEndpoint = new AtomicReference<>();
        EndpointPhaser endpointPhaser = new EndpointPhaser();

        SseClient sseClient = new SseClient(sseEndpoint) {
            @Override
            protected void process(SseClient.SseEvent sseEvent) {
                if (!"endpoint".equals(sseEvent.name())) {
                    realStdout.println(sseEvent.data());
                    return;
                }
                String endpoint = sseEvent.data().strip();
                StdioSseProxy.LOG.infof("Message endpoint received: %s", endpoint);
                messageEndpoint.set(StdioSseProxy.this.sseEndpoint.resolve(endpoint));
                endpointPhaser.countDown();
            }
        };
        sseClient.connect(httpClient, Map.of());

        executor.submit(() -> readStdin(stdin, queue));
        executor.submit(() -> forwardMessages(sseClient, endpointPhaser, httpClient, messageEndpoint, queue));
        try {
            Quarkus.waitForExit();
        } finally {
            // Interrupts the forwarding loop. The stdin reader may remain blocked in readLine().
            executor.shutdownNow();
        }
    }

    /**
     * Reads lines from {@code stdin} (decoded as UTF-8) and queues every non-blank one.
     * On EOF, asks the application to exit with status 0.
     *
     * @param stdin the stream to read
     * @param queue destination of the read lines
     */
    private static void readStdin(InputStream stdin, Queue<String> queue) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(stdin, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (!line.isBlank()) {
                    LOG.debugf("Line added to queue:\n%s", line);
                    queue.offer(line);
                }
            }
            LOG.debug("EOF received, exiting");
            Quarkus.asyncExit(0);
        } catch (IOException e) {
            LOG.error("Error reading stdio", e);
        }
    }

    /**
     * Repeatedly takes one queued line and POSTs it to the message endpoint, until interrupted.
     *
     * <p>Each iteration first waits up to {@link #timeout} seconds for an endpoint. A timeout is
     * logged but does not stop the loop, so queued messages can still reach a previously known
     * endpoint (and trigger a reconnect).
     */
    private void forwardMessages(SseClient sseClient, EndpointPhaser endpointPhaser, HttpClient httpClient,
            AtomicReference<URI> messageEndpoint, Queue<String> queue) {
        while (true) {
            try {
                try {
                    endpointPhaser.await(timeout, TimeUnit.SECONDS);
                } catch (TimeoutException e) {
                    LOG.errorf(e, "Message endpoint not received in %s seconds", timeout);
                }
                // With no endpoint yet there is nowhere to POST to. Leave the message queued rather than
                // fail with a NullPointerException, which would silently end this submitted task.
                if (messageEndpoint.get() != null) {
                    String message = queue.poll();
                    if (message != null) {
                        try {
                            sendData(sseClient, endpointPhaser, httpClient, messageEndpoint, message, false);
                        } catch (IOException | RuntimeException e) {
                            // Also catch unchecked failures: an escaping exception would kill the loop unnoticed.
                            LOG.errorf(e, "Unable to send POST request to %s", messageEndpoint.get());
                        }
                    }
                }
                TimeUnit.MILLISECONDS.sleep(sleep);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * POSTs {@code data} to the current message endpoint.
     *
     * <p>The SSE client is reconnected, a new endpoint is awaited and the message is sent once more
     * when all of the following hold:
     * <ul>
     *   <li>the endpoint answers 400,</li>
     *   <li>{@link #reconnect} is enabled,</li>
     *   <li>this call is not already a resend.</li>
     * </ul>
     * Any other non-2xx status is logged.
     *
     * @param data the message body
     * @param isResend {@code true} for the single retry after a reconnect
     * @throws IOException if the request cannot be sent
     * @throws InterruptedException if interrupted while sending or waiting for a new endpoint
     */
    private void sendData(SseClient sseClient, EndpointPhaser endpointPhaser, HttpClient httpClient,
            AtomicReference<URI> messageEndpoint, String data, boolean isResend)
            throws IOException, InterruptedException {
        LOG.debugf("%s data to SSE:\n%s", isResend ? "Resending" : "Sending", data);
        HttpRequest request = HttpRequest.newBuilder()
                .uri(messageEndpoint.get())
                .version(HttpClient.Version.HTTP_1_1)
                .POST(HttpRequest.BodyPublishers.ofString(data))
                .build();
        HttpResponse<Void> response = httpClient.send(request, HttpResponse.BodyHandlers.discarding());
        int status = response.statusCode();

        if (status == 400 && reconnect && !isResend) {
            LOG.infof("Message endpoint %s not found - reconnecting SSE client..", messageEndpoint.get());
            endpointPhaser.reset();
            sseClient.connect(httpClient, Map.of());
            try {
                endpointPhaser.await(timeout, TimeUnit.SECONDS);
                sendData(sseClient, endpointPhaser, httpClient, messageEndpoint, data, true);
            } catch (TimeoutException e) {
                LOG.errorf(e, "Message endpoint not received in %s seconds", timeout);
            }
        } else if (status < 200 || status > 299) {
            // Any 2xx (e.g. 202 Accepted) means the message was taken.
            LOG.errorf("Received erroneous status code: %s", status);
        }
    }
}
