package org.springframework.graphql.server.webflux;

import java.net.InetSocketAddress;
import java.net.URI;
import java.security.Principal;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.postgresql.core.Oid;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import org.springframework.aot.hint.annotation.RegisterReflectionForBinding;
import org.springframework.graphql.server.WebGraphQlHandler;
import org.springframework.graphql.server.WebGraphQlResponse;
import org.springframework.graphql.server.WebSocketGraphQlInterceptor;
import org.springframework.graphql.server.WebSocketGraphQlRequest;
import org.springframework.graphql.server.WebSocketSessionInfo;
import org.springframework.graphql.server.support.GraphQlWebSocketMessage;
import org.springframework.http.HttpHeaders;
import org.springframework.http.codec.CodecConfigurer;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * WebFlux {@link WebSocketHandler} for GraphQL over WebSocket using the
 * {@code graphql-transport-ws} sub-protocol.
 *
 * <p>Both {@code graphql-transport-ws} and the legacy {@code graphql-ws} sub-protocols are
 * advertised, but a session that negotiates {@code graphql-ws} is closed immediately with
 * status 4400.
 *
 * <p>Per session, the handler tracks the {@code connection_init} payload and the upstream
 * subscription of each active operation, closes the session with 4408 if no
 * {@code connection_init} arrives within the configured timeout, and optionally merges
 * keep-alive pings into the outbound stream.
 *
 * <p>NOTE: reconstructed from decompiled {@code spring-graphql-1.3.2.jar} bytecode.
 */
@RegisterReflectionForBinding({GraphQlWebSocketMessage.class})
public class GraphQlWebSocketHandler implements WebSocketHandler {

    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);

    /** Advertised sub-protocols; unmodifiable because the same instance is returned to every caller. */
    private static final List<String> SUB_PROTOCOL_LIST =
            Collections.unmodifiableList(Arrays.asList("graphql-transport-ws", "graphql-ws"));

    private final WebGraphQlHandler graphQlHandler;

    private final WebSocketGraphQlInterceptor webSocketInterceptor;

    private final WebSocketCodecDelegate codecDelegate;

    private final Duration initTimeoutDuration;

    @Nullable
    private final Duration keepAliveDuration;

    /**
     * Close statuses used by this handler, plus a helper for closing a session from inside
     * a message-handling {@link Flux}.
     *
     * <p>NOTE: private in the original sources; the decompiler widened the access modifier.
     */
    public static final class GraphQlStatus {

        static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");

        static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");

        static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");

        static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        /**
         * Close the session with the given status and complete without emitting anything.
         * @param <V> element type expected by the caller, so the result fits any branch
         * @param session the session to close
         * @param status the close status to send
         * @return an empty {@link Flux} that completes once the close has completed
         */
        public static <V> Flux<V> close(WebSocketSession session, CloseStatus status) {
            return session.close(status).thenMany(Mono.empty());
        }
    }

    /**
     * Signals that a subscription with the same operation id is already registered for the
     * session.
     *
     * <p>NOTE: private in the original sources; the decompiler widened the access modifier.
     */
    public static final class SubscriptionExistsException extends RuntimeException {

        private SubscriptionExistsException() {
        }
    }

    /** {@link WebSocketSessionInfo} view over a WebFlux {@link WebSocketSession} and its handshake. */
    private static final class WebFluxSessionInfo implements WebSocketSessionInfo {

        private final WebSocketSession session;

        private WebFluxSessionInfo(WebSocketSession session) {
            this.session = session;
        }

        @Override
        public String getId() {
            return this.session.getId();
        }

        @Override
        public Map<String, Object> getAttributes() {
            return this.session.getAttributes();
        }

        @Override
        public URI getUri() {
            return this.session.getHandshakeInfo().getUri();
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.session.getHandshakeInfo().getHeaders();
        }

        @Override
        public Mono<Principal> getPrincipal() {
            return this.session.getHandshakeInfo().getPrincipal();
        }

        @Override
        public InetSocketAddress getRemoteAddress() {
            return this.session.getHandshakeInfo().getRemoteAddress();
        }
    }

    /**
     * Create a handler that does not send keep-alive pings.
     * @param graphQlHandler handler that executes the GraphQL requests
     * @param codecConfigurer codecs used to encode and decode WebSocket messages
     * @param connectionInitTimeout how long after the session starts to wait for a
     * {@code connection_init} message before closing the session with 4408
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer,
            Duration connectionInitTimeout) {

        this(graphQlHandler, codecConfigurer, connectionInitTimeout, null);
    }

    /**
     * Create a handler.
     * @param graphQlHandler handler that executes the GraphQL requests
     * @param codecConfigurer codecs used to encode and decode WebSocket messages
     * @param connectionInitTimeout how long after the session starts to wait for a
     * {@code connection_init} message before closing the session with 4408
     * @param keepAliveDuration interval at which a ping is considered once the connection is
     * initialised, or {@code null} to disable keep-alive pings
     * @throws IllegalArgumentException if {@code graphQlHandler} or
     * {@code connectionInitTimeout} is {@code null}
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler graphQlHandler, CodecConfigurer codecConfigurer,
            Duration connectionInitTimeout, @Nullable Duration keepAliveDuration) {

        Assert.notNull(graphQlHandler, "WebGraphQlHandler is required");
        // Fail fast: a null timeout would otherwise fail inside Mono.delay on every connection.
        Assert.notNull(connectionInitTimeout, "'connectionInitTimeout' is required");
        this.graphQlHandler = graphQlHandler;
        this.webSocketInterceptor = this.graphQlHandler.getWebSocketInterceptor();
        this.codecDelegate = new WebSocketCodecDelegate(codecConfigurer);
        this.initTimeoutDuration = connectionInitTimeout;
        this.keepAliveDuration = keepAliveDuration;
    }

    @Override
    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    /**
     * Handle one WebSocket session: reject the legacy {@code graphql-ws} sub-protocol, start
     * the connection-init timeout, register a close callback for the interceptor, and map each
     * inbound message to zero or more outbound messages.
     * @param session the WebSocket session to handle
     * @return completion of the outbound message stream
     */
    @Override
    public Mono<Void> handle(WebSocketSession session) {
        HandshakeInfo handshakeInfo = session.getHandshakeInfo();
        if ("graphql-ws".equalsIgnoreCase(handshakeInfo.getSubProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            return session.close(GraphQlStatus.INVALID_MESSAGE_STATUS);
        }

        WebSocketSessionInfo sessionInfo = new WebFluxSessionInfo(session);

        // Per-session state. initPayloadRef is null until connection_init arrives or the init
        // timeout claims it. subscriptions maps operation id -> upstream subscription.
        AtomicReference<Map<String, Object>> initPayloadRef = new AtomicReference<>();
        Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();

        // On timeout, claim the slot with an empty map so that a late connection_init is
        // rejected, then close with 4408. If connection_init already arrived, do nothing.
        Mono.delay(this.initTimeoutDuration)
                .then(Mono.defer(() -> initPayloadRef.compareAndSet(null, Collections.emptyMap()) ?
                        session.close(GraphQlStatus.INIT_TIMEOUT_STATUS) : Mono.empty()))
                .subscribe();

        // Notify the interceptor when the session closes, but only if the init slot was
        // filled. A missing close status is reported as 1005 (no status code received).
        session.closeStatus()
                .doOnSuccess(closeStatus -> {
                    Map<String, Object> initPayload = initPayloadRef.get();
                    if (initPayload == null) {
                        return;
                    }
                    int statusCode = (closeStatus != null) ?
                            closeStatus.getCode() : CloseStatus.NO_STATUS_CODE.getCode();
                    this.webSocketInterceptor.handleConnectionClosed(sessionInfo, statusCode, initPayload);
                })
                .subscribe();

        return session.send(session.receive().flatMap(webSocketMessage -> {
            GraphQlWebSocketMessage message = this.codecDelegate.decode(webSocketMessage);
            String id = message.getId();
            @SuppressWarnings("unchecked")
            Map<String, Object> payload = (Map<String, Object>) message.getPayload();
            switch (message.resolvedType()) {
                case SUBSCRIBE: {
                    // Operations require a prior connection_init (else 4401) and must carry
                    // both an id and a payload (else 4400).
                    if (initPayloadRef.get() == null) {
                        return GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                    }
                    if (id == null || payload == null) {
                        return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
                    }
                    WebSocketGraphQlRequest request = new WebSocketGraphQlRequest(
                            handshakeInfo.getUri(), handshakeInfo.getHeaders(), handshakeInfo.getCookies(),
                            handshakeInfo.getRemoteAddress(), handshakeInfo.getAttributes(),
                            payload, id, null, sessionInfo);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Executing: " + request);
                    }
                    return this.graphQlHandler.handleRequest(request)
                            .flatMapMany(response -> handleResponse(session, id, subscriptions, response))
                            .doOnTerminate(() -> subscriptions.remove(id));
                }
                case PING:
                    return Flux.just(this.codecDelegate.encode(session, GraphQlWebSocketMessage.pong(null)));
                case PONG:
                    return Flux.empty();
                case COMPLETE: {
                    if (id == null) {
                        return Flux.empty();
                    }
                    // Cancel the upstream subscription (if still active), then notify the interceptor.
                    Subscription subscription = subscriptions.remove(id);
                    if (subscription != null) {
                        subscription.cancel();
                    }
                    return this.webSocketInterceptor.handleCancelledSubscription(sessionInfo, id)
                            .thenMany(Flux.empty());
                }
                case CONNECTION_INIT: {
                    // The connection_init payload may be absent. Store an empty map instead of
                    // null, because compareAndSet(null, null) would "succeed" without marking
                    // the connection as initialised.
                    Map<String, Object> initPayload = (payload != null) ? payload : Collections.emptyMap();
                    if (!initPayloadRef.compareAndSet(null, initPayload)) {
                        return GraphQlStatus.close(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
                    }
                    Flux<WebSocketMessage> initFlux = this.webSocketInterceptor
                            .handleConnectionInitialization(sessionInfo, initPayload)
                            .defaultIfEmpty(Collections.emptyMap())
                            .map(ackPayload -> this.codecDelegate.encodeConnectionAck(session, ackPayload))
                            .flux();
                    if (this.keepAliveDuration != null) {
                        // Emit a ping on each tick unless the codec reports recent outbound traffic
                        // (presumably "encoded since last check" -- see WebSocketCodecDelegate).
                        initFlux = initFlux.mergeWith(Flux.interval(this.keepAliveDuration, this.keepAliveDuration)
                                .filter(tick -> !this.codecDelegate.checkMessagesEncodedAndClear())
                                .map(tick -> this.codecDelegate.encode(session, GraphQlWebSocketMessage.ping(null))));
                    }
                    // A failure in the interceptor's initialisation rejects the connection with 4401.
                    return initFlux.onErrorResume(ex -> GraphQlStatus.close(session, GraphQlStatus.UNAUTHORIZED_STATUS));
                }
                default:
                    return GraphQlStatus.close(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
            }
        }));
    }

    /**
     * Turn the result of one operation into outbound {@code next}/{@code complete} (or
     * {@code error}) messages.
     * @param session the session to encode messages for
     * @param id the operation id
     * @param subscriptions per-session registry of active upstream subscriptions by id
     * @param response the execution result
     * @return the encoded messages; the session is closed with 4409 instead if a subscription
     * with the same id is already registered
     */
    private Flux<WebSocketMessage> handleResponse(WebSocketSession session, String id,
            Map<String, Subscription> subscriptions, WebGraphQlResponse response) {

        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready" + (!CollectionUtils.isEmpty(response.getErrors()) ? " with errors: " + response.getErrors() : "") + ".");
        }

        Object data = response.getData();
        Flux<Map<String, Object>> resultFlux;
        if (data instanceof Publisher) {
            // Subscription: one "next" message per published result.
            // TODO(review): the decompiler erased the element type, so the lambda parameter is
            // Object and this call does not compile as written. The elements are presumably
            // graphql.ExecutionResult; restore (Publisher<ExecutionResult>) and
            // ExecutionResult::toSpecification once that import is added to the file.
            @SuppressWarnings({"rawtypes", "unchecked"})
            Flux<Map<String, Object>> published = Flux.from((Publisher) data)
                    .map((v0) -> v0.toSpecification());
            // Register the upstream subscription under the operation id so that a "complete"
            // message can cancel it. A duplicate id is rejected below with close status 4409.
            resultFlux = published.doOnSubscribe(subscription -> {
                if (subscriptions.putIfAbsent(id, subscription) != null) {
                    throw new SubscriptionExistsException();
                }
            });
        }
        else {
            // Query or mutation: a single "next" message carrying the whole result map.
            resultFlux = Flux.just(response.toMap());
        }

        return resultFlux
                .map(resultMap -> this.codecDelegate.encodeNext(session, id, resultMap))
                .concatWith(Mono.fromCallable(() -> this.codecDelegate.encodeComplete(session, id)))
                .onErrorResume(ex -> {
                    if (ex instanceof SubscriptionExistsException) {
                        CloseStatus status = new CloseStatus(4409, "Subscriber for " + id + " already exists");
                        return GraphQlStatus.close(session, status);
                    }
                    return Mono.fromCallable(() -> this.codecDelegate.encodeError(session, id, ex));
                });
    }
}
