package org.eclipse.jetty.websocket.jakarta.common;

import jakarta.websocket.ClientEndpointConfig;
import jakarta.websocket.CloseReason;
import jakarta.websocket.Decoder;
import jakarta.websocket.EndpointConfig;
import jakarta.websocket.MessageHandler;
import jakarta.websocket.PongMessage;
import jakarta.websocket.server.ServerEndpointConfig;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import org.eclipse.jetty.util.BufferUtil;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.thread.AutoLock;
import org.eclipse.jetty.websocket.core.CloseStatus;
import org.eclipse.jetty.websocket.core.CoreSession;
import org.eclipse.jetty.websocket.core.Frame;
import org.eclipse.jetty.websocket.core.FrameHandler;
import org.eclipse.jetty.websocket.core.OpCode;
import org.eclipse.jetty.websocket.core.exception.ProtocolException;
import org.eclipse.jetty.websocket.core.exception.WebSocketException;
import org.eclipse.jetty.websocket.jakarta.common.decoders.AvailableDecoders;
import org.eclipse.jetty.websocket.jakarta.common.decoders.RegisteredDecoder;
import org.eclipse.jetty.websocket.jakarta.common.messages.DecodedBinaryMessageSink;
import org.eclipse.jetty.websocket.jakarta.common.messages.DecodedBinaryStreamMessageSink;
import org.eclipse.jetty.websocket.jakarta.common.messages.DecodedTextMessageSink;
import org.eclipse.jetty.websocket.jakarta.common.messages.DecodedTextStreamMessageSink;
import org.eclipse.jetty.websocket.util.InvokerUtils;
import org.eclipse.jetty.websocket.util.messages.MessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialByteArrayMessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialByteBufferMessageSink;
import org.eclipse.jetty.websocket.util.messages.PartialStringMessageSink;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/eclipse/jetty/websocket/jakarta/common/JakartaWebSocketFrameHandler.class */
public class JakartaWebSocketFrameHandler implements FrameHandler {
    /** Logger obtained for the runtime class of the endpoint instance. */
    private final Logger logger;
    /** Container used to build the session and to notify session listeners. */
    private final JakartaWebSocketContainer container;
    /** The user endpoint object; never a {@code ConfiguredEndpoint} wrapper. */
    private final Object endpointInstance;
    // Handler method handles; each may be null when the endpoint declares no such handler.
    // They are re-bound to the session in onOpen.
    private MethodHandle openHandle;
    private MethodHandle closeHandle;
    private MethodHandle errorHandle;
    private MethodHandle pongHandle;
    // Metadata describing the text/binary message handlers; null when none is registered.
    private JakartaWebSocketMessageMetadata textMetadata;
    private JakartaWebSocketMessageMetadata binaryMetadata;
    private UpgradeRequest upgradeRequest;
    /** Replaced in onOpen by a wrapper whose user-properties map reports writes to configListener. */
    private EndpointConfig endpointConfig;
    // Sinks that receive text/binary frames; null when no handler is registered for that type.
    private MessageSink textSink;
    private MessageSink binarySink;
    /** Sink of the data message currently being received; cleared after the FIN frame. */
    private MessageSink activeMessageSink;
    /** Created in onOpen; null before that. */
    private JakartaWebSocketSession session;
    private CoreSession coreSession;
    /** Guards registration and removal of entries in {@link #messageHandlerMap}. */
    private final AutoLock lock = new AutoLock();
    /** Ensures the endpoint's close handler is invoked at most once. */
    private final AtomicBoolean closeNotified = new AtomicBoolean();
    /** Programmatically registered message handlers, keyed by WebSocket opcode. */
    private final Map<Byte, RegisteredMessageHandler> messageHandlerMap = new HashMap<>();
    /** Opcode of the data message in progress (1 = text, 2 = binary); -1 when none is in progress. */
    protected byte dataType = -1;

    /**
     * Creates a frame handler for a single endpoint instance.
     *
     * @param jakartaWebSocketContainer the owning container
     * @param obj the endpoint instance; must already be unwrapped (not a {@code ConfiguredEndpoint})
     * @param methodHandle handle invoked on open, or {@code null}
     * @param methodHandle2 handle invoked on close, or {@code null}
     * @param methodHandle3 handle invoked on error, or {@code null}
     * @param jakartaWebSocketMessageMetadata metadata of the text message handler, or {@code null}
     * @param jakartaWebSocketMessageMetadata2 metadata of the binary message handler, or {@code null}
     * @param methodHandle4 handle invoked on pong, or {@code null}
     * @param endpointConfig the endpoint configuration
     * @throws RuntimeException if {@code obj} is a {@code ConfiguredEndpoint}
     */
    public JakartaWebSocketFrameHandler(JakartaWebSocketContainer jakartaWebSocketContainer, Object obj, MethodHandle methodHandle, MethodHandle methodHandle2, MethodHandle methodHandle3, JakartaWebSocketMessageMetadata jakartaWebSocketMessageMetadata, JakartaWebSocketMessageMetadata jakartaWebSocketMessageMetadata2, MethodHandle methodHandle4, EndpointConfig endpointConfig) {
        logger = LoggerFactory.getLogger(obj.getClass());
        container = jakartaWebSocketContainer;

        // A wrapped endpoint here is a caller error: log it, then reject it.
        if (obj instanceof ConfiguredEndpoint) {
            RuntimeException unwrapFailure = new RuntimeException("ConfiguredEndpoint needs to be unwrapped");
            logger.warn("Unexpected ConfiguredEndpoint", unwrapFailure);
            throw unwrapFailure;
        }
        endpointInstance = obj;
        this.endpointConfig = endpointConfig;

        // Lifecycle handlers.
        openHandle = methodHandle;
        closeHandle = methodHandle2;
        errorHandle = methodHandle3;
        pongHandle = methodHandle4;

        // Data message handlers.
        textMetadata = jakartaWebSocketMessageMetadata;
        binaryMetadata = jakartaWebSocketMessageMetadata2;
    }

    /**
     * @return the endpoint instance supplied at construction
     */
    public Object getEndpoint() {
        return endpointInstance;
    }

    /**
     * @return the current endpoint configuration; after {@code onOpen} this is the wrapped
     *         configuration rather than the one passed to the constructor
     */
    public EndpointConfig getEndpointConfig() {
        return endpointConfig;
    }

    /**
     * @return the session created in {@code onOpen}, or {@code null} if it has not been created yet
     */
    public JakartaWebSocketSession getSession() {
        return session;
    }

    public void onOpen(CoreSession coreSession, Callback callback) {
        this.coreSession = coreSession;
        try {
            this.endpointConfig = getWrappedEndpointConfig();
            this.session = new JakartaWebSocketSession(this.container, coreSession, this, this.endpointConfig);
            this.openHandle = InvokerUtils.bindTo(this.openHandle, new Object[]{this.session, this.endpointConfig});
            this.closeHandle = InvokerUtils.bindTo(this.closeHandle, new Object[]{this.session});
            this.errorHandle = InvokerUtils.bindTo(this.errorHandle, new Object[]{this.session});
            this.pongHandle = InvokerUtils.bindTo(this.pongHandle, new Object[]{this.session});
            JakartaWebSocketMessageMetadata copyOf = JakartaWebSocketMessageMetadata.copyOf(this.textMetadata);
            if (copyOf != null) {
                if (copyOf.isMaxMessageSizeSet()) {
                    this.session.setMaxTextMessageBufferSize(copyOf.getMaxMessageSize());
                }
                copyOf.setMethodHandle(JakartaWebSocketFrameHandlerFactory.wrapNonVoidReturnType(InvokerUtils.bindTo(copyOf.getMethodHandle(), new Object[]{this.endpointInstance, this.endpointConfig, this.session}), this.session));
                this.textSink = JakartaWebSocketFrameHandlerFactory.createMessageSink(this.session, copyOf);
                this.textMetadata = copyOf;
            }
            JakartaWebSocketMessageMetadata copyOf2 = JakartaWebSocketMessageMetadata.copyOf(this.binaryMetadata);
            if (copyOf2 != null) {
                if (copyOf2.isMaxMessageSizeSet()) {
                    this.session.setMaxBinaryMessageBufferSize(copyOf2.getMaxMessageSize());
                }
                copyOf2.setMethodHandle(JakartaWebSocketFrameHandlerFactory.wrapNonVoidReturnType(InvokerUtils.bindTo(copyOf2.getMethodHandle(), new Object[]{this.endpointInstance, this.endpointConfig, this.session}), this.session));
                this.binarySink = JakartaWebSocketFrameHandlerFactory.createMessageSink(this.session, copyOf2);
                this.binaryMetadata = copyOf2;
            }
            if (this.openHandle != null) {
                (void) this.openHandle.invoke();
            }
            this.container.notifySessionListeners(jakartaWebSocketSessionListener -> {
                jakartaWebSocketSessionListener.onJakartaWebSocketSessionOpened(this.session);
            });
            callback.succeeded();
        } catch (Throwable th) {
            callback.failed(new WebSocketException(this.endpointInstance.getClass().getSimpleName() + " OPEN method error: " + th.getMessage(), th));
        }
    }

    /**
     * Wraps the current endpoint config so that its user-properties map reports every put to
     * {@link #configListener(String, Object)}.
     *
     * <p>The wrapper type follows the config's own type (server, client or plain), so an
     * {@code instanceof} check on the returned config gives the same answer as on the original.
     *
     * @return a wrapper around the current endpoint config
     */
    private EndpointConfig getWrappedEndpointConfig() {
        final PutListenerMap listenedProperties = new PutListenerMap(endpointConfig.getUserProperties(), this::configListener);

        if (endpointConfig instanceof ServerEndpointConfig) {
            return new ServerEndpointConfigWrapper((ServerEndpointConfig) endpointConfig) {
                @Override
                public Map<String, Object> getUserProperties() {
                    return listenedProperties;
                }
            };
        }

        if (endpointConfig instanceof ClientEndpointConfig) {
            return new ClientEndpointConfigWrapper((ClientEndpointConfig) endpointConfig) {
                @Override
                public Map<String, Object> getUserProperties() {
                    return listenedProperties;
                }
            };
        }

        return new EndpointConfigWrapper(endpointConfig) {
            @Override
            public Map<String, Object> getUserProperties() {
                return listenedProperties;
            }
        };
    }

    /**
     * Dispatches an incoming frame to the handler for its opcode.
     *
     * <p>TEXT and BINARY frames record their opcode in {@code dataType} so that following
     * CONTINUATION frames can be routed to the same kind of handler. The record is cleared once
     * the final (FIN) frame of a data message has been dispatched; control frames never clear it,
     * because they may arrive between the fragments of a data message.
     *
     * @param frame the received frame
     * @param callback completed by the selected handler; failed with an
     *                 {@link IllegalStateException} for opcodes that have no handler
     */
    @Override
    public void onFrame(Frame frame, Callback callback) {
        switch (frame.getOpCode()) {
            case OpCode.CONTINUATION:
                onContinuation(frame, callback);
                break;
            case OpCode.TEXT:
                dataType = OpCode.TEXT;
                onText(frame, callback);
                break;
            case OpCode.BINARY:
                dataType = OpCode.BINARY;
                onBinary(frame, callback);
                break;
            case OpCode.CLOSE:
                onClose(frame, callback);
                break;
            case OpCode.PING:
                onPing(frame, callback);
                break;
            case OpCode.PONG:
                onPong(frame, callback);
                break;
            default:
                callback.failed(new IllegalStateException("Unexpected opcode: " + OpCode.name(frame.getOpCode())));
                break;
        }

        if (frame.isFin() && !frame.isControlFrame()) {
            dataType = OpCode.UNDEFINED;
        }
    }

    /**
     * Handles a received CLOSE frame by notifying the endpoint with the status carried in the frame.
     *
     * @param frame the CLOSE frame
     * @param callback completed by {@code notifyOnClose}
     */
    public void onClose(Frame frame, Callback callback) {
        CloseStatus status = CloseStatus.getCloseStatus(frame);
        notifyOnClose(status, callback);
    }

    /**
     * Handles the final closed event: notifies the endpoint (unless it was already notified)
     * and then tells the container's session listeners that the session has closed.
     *
     * @param closeStatus the status the connection closed with
     * @param callback completed by {@code notifyOnClose}
     */
    @Override
    public void onClosed(CloseStatus closeStatus, Callback callback) {
        notifyOnClose(closeStatus, callback);
        container.notifySessionListeners(listener -> listener.onJakartaWebSocketSessionClosed(session));
    }

    private void notifyOnClose(CloseStatus closeStatus, Callback callback) {
        if (!this.closeNotified.compareAndSet(false, true)) {
            callback.succeeded();
            return;
        }
        try {
            if (this.closeHandle != null) {
                (void) this.closeHandle.invoke(new CloseReason(CloseReason.CloseCodes.getCloseCode(closeStatus.getCode()), closeStatus.getReason()));
            }
            callback.succeeded();
        } catch (Throwable th) {
            callback.failed(new WebSocketException(this.endpointInstance.getClass().getSimpleName() + " CLOSE method error: " + th.getMessage(), th));
        }
    }

    public void onError(Throwable th, Callback callback) {
        try {
            if (this.errorHandle != null) {
                (void) this.errorHandle.invoke(th);
            } else {
                this.logger.warn("Unhandled Error: " + this.endpointInstance, th);
            }
            callback.succeeded();
        } catch (Throwable th2) {
            WebSocketException webSocketException = new WebSocketException(this.endpointInstance.getClass().getSimpleName() + " ERROR method error: " + th.getMessage(), th2);
            webSocketException.addSuppressed(th);
            callback.failed(webSocketException);
        }
    }

    /**
     * @return an unmodifiable snapshot of the message handlers currently registered through
     *         {@code addMessageHandler}
     */
    public Set<MessageHandler> getMessageHandlers() {
        return messageHandlerMap.values().stream()
            .map(RegisteredMessageHandler::getMessageHandler)
            .collect(Collectors.toUnmodifiableSet());
    }

    /**
     * @return the live (mutable) map of registered message handlers, keyed by WebSocket opcode
     */
    public Map<Byte, RegisteredMessageHandler> getMessageHandlerMap() {
        return messageHandlerMap;
    }

    /**
     * @return the metadata of the binary message handler, or {@code null} if none is set
     */
    public JakartaWebSocketMessageMetadata getBinaryMetadata() {
        return binaryMetadata;
    }

    /**
     * @return the metadata of the text message handler, or {@code null} if none is set
     */
    public JakartaWebSocketMessageMetadata getTextMetadata() {
        return textMetadata;
    }

    /**
     * Registers a partial-message handler for {@code byte[]}, {@link ByteBuffer} or {@link String}
     * payloads.
     *
     * @param cls the payload type the handler accepts
     * @param partial the handler to register
     * @param <T> the payload type
     * @throws RuntimeException if {@code cls} is not one of the supported payload types
     * @throws IllegalStateException if the handler's {@code onMessage} method cannot be looked up
     *         or accessed, or a handler for the same basic message type is already registered
     */
    public <T> void addMessageHandler(Class<T> cls, MessageHandler.Partial<T> partial) {
        try {
            MethodType onMessageType = MethodType.methodType(void.class, Object.class, boolean.class);
            MethodHandle onMessage = JakartaWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Partial.class, "onMessage", onMessageType)
                .bindTo(partial);

            JakartaWebSocketMessageMetadata metadata = new JakartaWebSocketMessageMetadata();
            metadata.setMethodHandle(onMessage);

            // Pick the basic WebSocket message type and the matching partial sink from the payload type.
            final byte basicType;
            if (byte[].class.isAssignableFrom(cls)) {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteArrayMessageSink.class);
            } else if (ByteBuffer.class.isAssignableFrom(cls)) {
                basicType = OpCode.BINARY;
                metadata.setSinkClass(PartialByteBufferMessageSink.class);
            } else if (String.class.isAssignableFrom(cls)) {
                basicType = OpCode.TEXT;
                metadata.setSinkClass(PartialStringMessageSink.class);
            } else {
                throw new RuntimeException("Unable to add " + partial.getClass().getName() + " with type " + cls + ": only supported types byte[], " + ByteBuffer.class.getName() + ", " + String.class.getName());
            }

            registerMessageHandler(cls, partial, basicType, metadata);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to access " + partial.getClass().getName(), e);
        } catch (NoSuchMethodException e2) {
            throw new IllegalStateException("Unable to find method", e2);
        }
    }

    /**
     * Registers a whole-message handler.
     *
     * <p>A handler for {@link PongMessage} is installed as the pong handler. For any other type
     * the first decoder registered for {@code cls} decides whether the handler receives binary or
     * text messages and which decoding sink is used.
     *
     * @param cls the message type the handler accepts
     * @param whole the handler to register
     * @param <T> the message type
     * @throws IllegalStateException if no decoder is registered for {@code cls}, the handler's
     *         {@code onMessage} method cannot be looked up or accessed, or a handler for the same
     *         basic message type is already registered
     * @throws RuntimeException if the decoder found implements none of the known decoder interfaces
     */
    public <T> void addMessageHandler(Class<T> cls, MessageHandler.Whole<T> whole) {
        try {
            MethodHandle onMessage = JakartaWebSocketFrameHandlerFactory.getServerMethodHandleLookup()
                .findVirtual(MessageHandler.Whole.class, "onMessage", MethodType.methodType(void.class, Object.class))
                .bindTo(whole);

            if (PongMessage.class.isAssignableFrom(cls)) {
                assertBasicTypeNotRegistered(OpCode.PONG, whole);
                pongHandle = onMessage;
                // Pong handlers have no message sink; only the registration is recorded.
                registerMessageHandler(OpCode.PONG, cls, whole, null);
                return;
            }

            AvailableDecoders decoders = session.getDecoders();
            RegisteredDecoder firstDecoder = decoders.getFirstRegisteredDecoder(cls);
            if (firstDecoder == null) {
                throw new IllegalStateException("Unable to find Decoder for type: " + cls);
            }

            JakartaWebSocketMessageMetadata metadata = new JakartaWebSocketMessageMetadata();
            metadata.setMethodHandle(onMessage);

            final byte basicType;
            if (firstDecoder.implementsInterface(Decoder.Binary.class)) {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(decoders.getBinaryDecoders(cls));
                metadata.setSinkClass(DecodedBinaryMessageSink.class);
            } else if (firstDecoder.implementsInterface(Decoder.BinaryStream.class)) {
                basicType = OpCode.BINARY;
                metadata.setRegisteredDecoders(decoders.getBinaryStreamDecoders(cls));
                metadata.setSinkClass(DecodedBinaryStreamMessageSink.class);
            } else if (firstDecoder.implementsInterface(Decoder.Text.class)) {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(decoders.getTextDecoders(cls));
                metadata.setSinkClass(DecodedTextMessageSink.class);
            } else if (firstDecoder.implementsInterface(Decoder.TextStream.class)) {
                basicType = OpCode.TEXT;
                metadata.setRegisteredDecoders(decoders.getTextStreamDecoders(cls));
                metadata.setSinkClass(DecodedTextStreamMessageSink.class);
            } else {
                throw new RuntimeException("Unable to add " + whole.getClass().getName() + ": type " + cls + " is unrecognized by declared decoders");
            }

            registerMessageHandler(cls, whole, basicType, metadata);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to access " + whole.getClass().getName(), e);
        } catch (NoSuchMethodException e2) {
            throw new IllegalStateException("Unable to find method", e2);
        }
    }

    /**
     * Verifies that nothing is installed yet for the given basic WebSocket message type.
     *
     * @param b the opcode of the basic message type: TEXT, BINARY or PONG
     * @param messageHandler the handler about to be registered; used only for the error message
     * @throws IllegalStateException if {@code b} is not TEXT, BINARY or PONG, or if a sink/handle
     *         for that type is already installed
     */
    private void assertBasicTypeNotRegistered(byte b, MessageHandler messageHandler) {
        // Text and binary are tracked by their sinks, pong by its method handle, so the common
        // type of what gets checked is Object.
        final Object registered;
        switch (b) {
            case OpCode.TEXT:
                registered = textSink;
                break;
            case OpCode.BINARY:
                registered = binarySink;
                break;
            case OpCode.PONG:
                registered = pongHandle;
                break;
            default:
                throw new IllegalStateException();
        }

        if (registered != null) {
            throw new IllegalStateException("Cannot register " + messageHandler.getClass().getName() + ": Basic WebSocket type " + OpCode.name(b) + " is already registered");
        }
    }

    /**
     * Creates the sink for a text or binary handler and installs sink, metadata and registration.
     *
     * @param cls the message type the handler accepts
     * @param messageHandler the handler being registered
     * @param b the opcode of the basic message type; only TEXT and BINARY are accepted
     * @param jakartaWebSocketMessageMetadata the metadata describing the handler
     * @throws IllegalStateException if the type is already registered or is neither TEXT nor BINARY
     */
    private void registerMessageHandler(Class<?> cls, MessageHandler messageHandler, byte b, JakartaWebSocketMessageMetadata jakartaWebSocketMessageMetadata) {
        assertBasicTypeNotRegistered(b, messageHandler);
        MessageSink sink = JakartaWebSocketFrameHandlerFactory.createMessageSink(session, jakartaWebSocketMessageMetadata);

        if (b == OpCode.TEXT) {
            textSink = registerMessageHandler(OpCode.TEXT, cls, messageHandler, sink);
            textMetadata = jakartaWebSocketMessageMetadata;
        } else if (b == OpCode.BINARY) {
            binarySink = registerMessageHandler(OpCode.BINARY, cls, messageHandler, sink);
            binaryMetadata = jakartaWebSocketMessageMetadata;
        } else {
            throw new IllegalStateException();
        }
    }

    /**
     * Records a handler registration for a basic message type while holding the registration lock.
     *
     * @param b the opcode of the basic message type
     * @param cls the message type the handler accepts
     * @param messageHandler the handler being registered
     * @param messageSink the sink to hand back to the caller; may be {@code null}
     * @param <T> the message type
     * @return {@code messageSink}, unchanged
     * @throws IllegalStateException if a handler is already registered for {@code b}
     */
    private <T> MessageSink registerMessageHandler(byte b, Class<T> cls, MessageHandler messageHandler, MessageSink messageSink) {
        try (AutoLock ignored = lock.lock()) {
            RegisteredMessageHandler existing = messageHandlerMap.get(b);
            if (existing != null) {
                throw new IllegalStateException(String.format("Cannot register %s: Basic WebSocket type %s is already registered to %s", messageHandler.getClass().getName(), OpCode.name(b), existing.getMessageHandler().getClass().getName()));
            }

            RegisteredMessageHandler registration = new RegisteredMessageHandler(b, cls, messageHandler);
            getMessageHandlerMap().put(registration.getWebsocketMessageType(), registration);
            return messageSink;
        }
    }

    /**
     * Removes a previously registered message handler together with the sink, metadata or pong
     * handle installed for its basic message type. Does nothing if the handler is not registered.
     *
     * @param messageHandler the handler to remove
     * @throws IllegalStateException if the handler was registered under an opcode other than
     *         TEXT, BINARY or PONG (the registration itself has been removed by then)
     */
    public void removeMessageHandler(MessageHandler messageHandler) {
        try (AutoLock ignored = lock.lock()) {
            Optional<Byte> registeredType = messageHandlerMap.entrySet().stream()
                .filter(entry -> entry.getValue().getMessageHandler().equals(messageHandler))
                .map(Map.Entry::getKey)
                .findFirst();
            if (!registeredType.isPresent()) {
                return;
            }

            Byte typeKey = registeredType.get();
            messageHandlerMap.remove(typeKey);

            byte type = typeKey.byteValue();
            switch (type) {
                case OpCode.TEXT:
                    textMetadata = null;
                    textSink = null;
                    break;
                case OpCode.BINARY:
                    binaryMetadata = null;
                    binarySink = null;
                    break;
                case OpCode.PONG:
                    pongHandle = null;
                    break;
                default:
                    throw new IllegalStateException("Invalid MessageHandler type " + OpCode.name(type));
            }
        }
    }

    /**
     * @return {@code SimpleClassName@hexHash[endpoint=<endpoint class name>]}
     */
    @Override
    public String toString() {
        String endpointName = (endpointInstance == null) ? "<null>" : endpointInstance.getClass().getName();
        return getClass().getSimpleName()
            + '@' + Integer.toHexString(hashCode())
            + "[endpoint=" + endpointName + ']';
    }

    /**
     * Hands a data frame to the active message sink, or succeeds the callback when there is none.
     * The active sink is cleared after the final (FIN) frame of a message has been handed over.
     *
     * @param frame the data frame
     * @param callback passed on to the sink, or succeeded directly when there is no active sink
     */
    private void acceptMessage(Frame frame, Callback callback) {
        if (activeMessageSink != null) {
            activeMessageSink.accept(frame, callback);
            if (frame.isFin()) {
                activeMessageSink = null;
            }
            return;
        }

        // No handler registered for this message type: drop the frame.
        callback.succeeded();
    }

    /**
     * Answers a PING with a PONG carrying a copy of the ping payload.
     *
     * @param frame the PING frame
     * @param callback succeeded as soon as the pong has been handed to the core session; the
     *                 outcome of sending it is not tracked
     */
    public void onPing(Frame frame, Callback callback) {
        Frame pong = new Frame(OpCode.PONG).setPayload(BufferUtil.copy(frame.getPayload()));
        coreSession.sendFrame(pong, Callback.NOOP, false);
        callback.succeeded();
    }

    public void onPong(Frame frame, Callback callback) {
        if (this.pongHandle != null) {
            try {
                ByteBuffer payload = frame.getPayload();
                if (payload == null) {
                    payload = BufferUtil.EMPTY_BUFFER;
                }
                (void) this.pongHandle.invoke(new JakartaWebSocketPongMessage(payload));
            } catch (Throwable th) {
                throw new WebSocketException(this.endpointInstance.getClass().getSimpleName() + " PONG method error: " + th.getMessage(), th);
            }
        }
        callback.succeeded();
    }

    /**
     * Routes a text frame (or a continuation of a text message) to the text sink.
     *
     * @param frame the frame to deliver
     * @param callback completed by {@code acceptMessage}
     */
    public void onText(Frame frame, Callback callback) {
        // Continuations arrive here mid-message; only select the sink when none is active yet.
        activeMessageSink = (activeMessageSink != null) ? activeMessageSink : textSink;
        acceptMessage(frame, callback);
    }

    /**
     * Routes a binary frame (or a continuation of a binary message) to the binary sink.
     *
     * @param frame the frame to deliver
     * @param callback completed by {@code acceptMessage}
     */
    public void onBinary(Frame frame, Callback callback) {
        // Continuations arrive here mid-message; only select the sink when none is active yet.
        activeMessageSink = (activeMessageSink != null) ? activeMessageSink : binarySink;
        acceptMessage(frame, callback);
    }

    /**
     * Routes a CONTINUATION frame to the handler of the data message currently in progress.
     *
     * @param frame the CONTINUATION frame
     * @param callback completed by the text or binary handler
     * @throws ProtocolException if no text or binary message is in progress
     */
    public void onContinuation(Frame frame, Callback callback) {
        if (dataType == OpCode.TEXT) {
            onText(frame, callback);
            return;
        }
        if (dataType == OpCode.BINARY) {
            onBinary(frame, callback);
            return;
        }
        throw new ProtocolException("Unable to process continuation during dataType " + dataType);
    }

    /**
     * Stores the upgrade request associated with this handler.
     *
     * @param upgradeRequest the upgrade request to remember
     */
    public void setUpgradeRequest(UpgradeRequest upgradeRequest) {
        this.upgradeRequest = upgradeRequest;
    }

    /**
     * @return the upgrade request set through {@code setUpgradeRequest}, or {@code null} if none was set
     */
    public UpgradeRequest getUpgradeRequest() {
        return upgradeRequest;
    }

    /**
     * Applies a Jetty-specific user property to the core session.
     *
     * <p>Installed as the put-listener of the wrapped endpoint config's user-properties map.
     * Keys outside the {@code org.eclipse.jetty.websocket.} namespace, and unknown keys inside
     * it, are ignored.
     *
     * @param str the property key
     * @param obj the property value; must be a {@link Boolean} for {@code autoFragment}, a
     *            {@link Long} for {@code maxFrameSize} and an {@link Integer} for the buffer sizes
     * @throws ClassCastException if the value has the wrong type for a recognised key
     */
    private void configListener(String str, Object obj) {
        if (!str.startsWith("org.eclipse.jetty.websocket.")) {
            return;
        }

        switch (str) {
            case "org.eclipse.jetty.websocket.autoFragment":
                coreSession.setAutoFragment((Boolean) obj);
                break;
            case "org.eclipse.jetty.websocket.maxFrameSize":
                coreSession.setMaxFrameSize((Long) obj);
                break;
            case "org.eclipse.jetty.websocket.outputBufferSize":
                coreSession.setOutputBufferSize((Integer) obj);
                break;
            case "org.eclipse.jetty.websocket.inputBufferSize":
                coreSession.setInputBufferSize((Integer) obj);
                break;
            default:
                break;
        }
    }
}
