package io.r2dbc.mssql.client.ssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.ssl.SslHandler;
import io.r2dbc.mssql.client.ConnectionContext;
import io.r2dbc.mssql.client.TdsEncoder;
import io.r2dbc.mssql.message.header.Header;
import io.r2dbc.mssql.message.header.HeaderOptions;
import io.r2dbc.mssql.message.header.PacketIdProvider;
import io.r2dbc.mssql.message.header.Status;
import io.r2dbc.mssql.message.header.Type;
import io.r2dbc.mssql.message.tds.ContextualTdsFragment;
import io.r2dbc.mssql.message.tds.TdsFragment;
import io.r2dbc.mssql.util.Assert;
import java.security.GeneralSecurityException;
import reactor.util.Logger;
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;

/**
 * Channel handler that tunnels a TLS handshake through TDS {@code PRE_LOGIN} packets and, once the
 * handshake has completed, routes traffic through a Netty {@link SslHandler}.
 *
 * <p>Behaviour by phase, as implemented below:
 * <ul>
 *   <li><b>Handshake</b> (state {@link SslState#LOGIN_ONLY} or {@link SslState#CONNECTION} while the
 *   handshake is not done yet): outbound bytes are accumulated in {@code outputBuffer} and written as
 *   a single message on {@link #flush(ChannelHandlerContext)}. Inbound TDS packets are reassembled
 *   (see {@link Chunk}), and {@code PRE_LOGIN} payloads are fed to the {@link SslHandler}.</li>
 *   <li><b>After the handshake</b>: writes go through the {@link SslHandler}. In
 *   {@link SslState#LOGIN_ONLY} mode only the first such write does so, after which the state moves
 *   to {@link SslState#AFTER_LOGIN_ONLY} and writes pass through unchanged. In
 *   {@link SslState#CONNECTION} mode reads also go through the {@link SslHandler}.</li>
 * </ul>
 *
 * <p>The handler keeps per-connection state, so each channel needs its own instance. It is
 * nevertheless annotated {@link ChannelHandler.Sharable}, presumably so that Netty accepts the same
 * instance being re-added after {@code pipeline().remove(this)} when switching to full-connection
 * TLS.
 */
@ChannelHandler.Sharable
public final class TdsSslHandler extends ChannelDuplexHandler {

    private static final Logger LOGGER = Loggers.getLogger(TdsSslHandler.class);

    public static final boolean DEBUG_ENABLED = LOGGER.isDebugEnabled();

    /** Size in bytes of a TDS packet header; a packet's length field includes it. */
    private static final int TDS_HEADER_LENGTH = 8;

    private final ConnectionContext connectionContext;

    private final PacketIdProvider packetIdProvider;

    private final SslConfiguration sslConfiguration;

    private volatile SslHandler sslHandler;

    /** Context of the {@link ContextProxy} handler; the {@link SslHandler} is driven through it. */
    private ChannelHandlerContext context;

    /** Accumulates outbound bytes during the handshake until the next flush. */
    private ByteBuf outputBuffer;

    private SslState state = SslState.OFF;

    private boolean handshakeDone;

    /** Partially received inbound TDS message during the handshake, or {@code null}. */
    @Nullable
    private Chunk chunk;

    /**
     * Reassembles a TDS message whose packets are spread across several inbound reads.
     *
     * <p>{@link #fullMessage} receives the packet payloads (headers stripped). {@link #aggregator}
     * holds bytes that have been received but not yet copied.
     */
    static class Chunk {

        /** Header of the packet currently being read. */
        Header header;

        /** Concatenated payloads of all packets consumed so far. */
        final ByteBuf fullMessage;

        /** Bytes received but not yet copied into {@link #fullMessage}. */
        final CompositeByteBuf aggregator;

        /** Number of payload bytes in {@link #fullMessage} that belong to packets before {@link #header}. */
        int decoded = 0;

        Chunk(Header header, ByteBuf byteBuf, CompositeByteBuf compositeByteBuf) {
            this.header = header;
            this.fullMessage = byteBuf;
            this.aggregator = compositeByteBuf;
        }

        /**
         * Appends {@code byteBuf} to the aggregator (which takes ownership of it) and copies payload
         * bytes into {@link #fullMessage}. After each packet's remaining payload is copied, the next
         * packet header is decoded.
         *
         * <p>Stops when:
         * <ul>
         *   <li>not enough bytes are buffered for the current packet,</li>
         *   <li>the next header cannot be decoded yet, or</li>
         *   <li>the final ({@code EOM}) packet is complete.</li>
         * </ul>
         *
         * @param byteBuf newly received bytes
         */
        void defragment(ByteBuf byteBuf) {
            int remainingLength;
            this.aggregator.addComponent(true, byteBuf);
            while (this.aggregator.isReadable()
                    && this.aggregator.readableBytes() >= (remainingLength = getRemainingLength())) {
                this.fullMessage.writeBytes(this.aggregator, remainingLength);
                if (!Header.canDecode(this.aggregator)) {
                    return;
                }
                updateHeader(Header.decode(this.aggregator));
                if (isCompleteHandshakeAvailable()) {
                    return;
                }
            }
        }

        /**
         * Switches to the next packet's header.
         *
         * <p>Before switching, the previous packet's payload size (its length minus the header size)
         * is added to {@link #decoded}.
         */
        void updateHeader(Header header) {
            this.decoded += this.header.getLength() - TDS_HEADER_LENGTH;
            this.header = header;
        }

        /** Returns whether the current packet is the last one ({@code EOM}) and its payload is complete. */
        boolean isCompleteHandshakeAvailable() {
            return this.header.is(Status.StatusBit.EOM) && getRemainingLength() <= 0;
        }

        /** Returns the number of payload bytes of the current packet that have not arrived yet. */
        int getRemainingLength() {
            return this.header.getLength() - ((this.fullMessage.readableBytes() - this.decoded) + TDS_HEADER_LENGTH);
        }

        /**
         * Returns whether {@code byteBuf}, positioned just after {@code header}, already contains that
         * packet's complete payload.
         */
        static boolean isCompletePacketAvailable(Header header, ByteBuf byteBuf) {
            return byteBuf.readableBytes() + TDS_HEADER_LENGTH >= header.getLength();
        }
    }

    /**
     * Creates a new handler.
     *
     * @param packetIdProvider  source of packet ids for packets built in {@link #unwrap}
     * @param sslConfiguration  configuration used to create the {@link SslHandler}
     * @param connectionContext context used to prefix log messages
     * @throws IllegalArgumentException presumably, if any argument is {@code null} (via {@link Assert})
     */
    public TdsSslHandler(PacketIdProvider packetIdProvider, SslConfiguration sslConfiguration, ConnectionContext connectionContext) {
        Assert.requireNonNull(packetIdProvider, "PacketIdProvider must not be null");
        Assert.requireNonNull(sslConfiguration, "SslConfiguration must not be null");
        Assert.requireNonNull(connectionContext, "ConnectionContext must not be null");
        this.packetIdProvider = packetIdProvider;
        this.sslConfiguration = sslConfiguration;
        this.connectionContext = connectionContext;
    }

    void setSslHandler(SslHandler sslHandler) {
        this.sslHandler = sslHandler;
    }

    void setState(SslState sslState) {
        this.state = sslState;
    }

    private static SslHandler createSslHandler(SslConfiguration sslConfiguration, ByteBufAllocator byteBufAllocator) throws GeneralSecurityException {
        return new SslHandler(sslConfiguration.getSslProvider().getSslContext().newEngine(byteBufAllocator));
    }

    /**
     * Reacts to SSL state events and then forwards every event downstream.
     *
     * <ul>
     *   <li>{@link SslState#LOGIN_ONLY} / {@link SslState#CONNECTION}: records the mode and creates
     *   the {@link SslHandler}. Then installs the {@link ContextProxy} and {@link SslEventHandler}
     *   helpers, writes a {@code PRE_LOGIN} {@link HeaderOptions}, and attaches the
     *   {@link SslHandler} to the proxy context.</li>
     *   <li>{@link SslState#NEGOTIATED}: marks the handshake as done and writes
     *   {@link TdsEncoder.ResetHeader}. In {@link SslState#CONNECTION} mode it then moves this
     *   handler to the head of the pipeline.</li>
     * </ul>
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        if (obj == SslState.LOGIN_ONLY || obj == SslState.CONNECTION) {
            this.state = (SslState) obj;
            this.sslHandler = createSslHandler(this.sslConfiguration, channelHandlerContext.alloc());
            LOGGER.debug(this.connectionContext.getMessage("Registering Context Proxy and SSL Event Handlers to propagate SSL events to channelRead()"));
            channelHandlerContext.pipeline().addAfter(getClass().getName(), ContextProxy.class.getName(), new ContextProxy());
            channelHandlerContext.pipeline().addAfter(ContextProxy.class.getName(), SslEventHandler.class.getName(), new SslEventHandler());
            this.context = channelHandlerContext.channel().pipeline().context(ContextProxy.class.getName());
            channelHandlerContext.write(HeaderOptions.create(Type.PRE_LOGIN, Status.empty()));
            this.sslHandler.handlerAdded(this.context);
        }
        if (obj == SslState.NEGOTIATED) {
            LOGGER.debug(this.connectionContext.getMessage("SSL Handshake done"));
            channelHandlerContext.write(TdsEncoder.ResetHeader.INSTANCE, channelHandlerContext.voidPromise());
            this.handshakeDone = true;
            if (this.state == SslState.CONNECTION) {
                LOGGER.debug(this.connectionContext.getMessage("Reordering handlers for full SSL usage"));
                channelHandlerContext.pipeline().remove(this);
                channelHandlerContext.pipeline().addFirst(this);
            }
        }
        super.userEventTriggered(channelHandlerContext, obj);
    }

    @Override
    public void handlerAdded(ChannelHandlerContext channelHandlerContext) {
        this.outputBuffer = channelHandlerContext.alloc().buffer();
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext channelHandlerContext) {
        if (this.outputBuffer != null) {
            this.outputBuffer.release();
            this.outputBuffer = null;
        }
    }

    /**
     * Notifies the {@link SslHandler} (if any) and releases any partially reassembled message.
     *
     * <p>When no {@link SslHandler} exists, the event is forwarded via {@code super}. When one exists,
     * it is expected to forward the event itself (its {@code channelInactive} ends by firing the
     * event onward, per Netty's {@code ByteToMessageDecoder}).
     */
    @Override
    public void channelInactive(ChannelHandlerContext channelHandlerContext) throws Exception {
        try {
            if (this.sslHandler != null) {
                this.sslHandler.channelInactive(channelHandlerContext);
            } else {
                // Without this, downstream handlers never learn the channel went inactive.
                super.channelInactive(channelHandlerContext);
            }
        } finally {
            Chunk chunk = this.chunk;
            if (chunk != null) {
                this.chunk = null;
                chunk.fullMessage.release();
                chunk.aggregator.release();
            }
        }
    }

    /**
     * Routes an outbound message according to the SSL phase.
     *
     * <ul>
     *   <li>After the handshake (state {@code NEGOTIATED}, {@code LOGIN_ONLY} or
     *   {@code CONNECTION}): TDS fragments are turned into buffers (see {@link #unwrap}), written
     *   through the {@link SslHandler} and flushed.</li>
     *   <li>Outside any wrapping state: passed through unchanged.</li>
     *   <li>During the handshake: the message, which must be a {@link ByteBuf}, is copied into the
     *   output buffer and released.</li>
     * </ul>
     */
    @Override
    public void write(ChannelHandlerContext channelHandlerContext, Object obj, ChannelPromise channelPromise) throws Exception {
        if (this.handshakeDone && (this.state == SslState.NEGOTIATED || this.state == SslState.LOGIN_ONLY || this.state == SslState.CONNECTION)) {
            this.sslHandler.write(channelHandlerContext, unwrap(channelHandlerContext.alloc(), obj), channelPromise);
            this.sslHandler.flush(channelHandlerContext);
            if (this.state == SslState.LOGIN_ONLY) {
                // Login-only encryption: only this first post-handshake write goes through TLS.
                this.state = SslState.AFTER_LOGIN_ONLY;
            }
            return;
        }
        if (!requiresWrapping()) {
            super.write(channelHandlerContext, obj, channelPromise);
            return;
        }
        if (DEBUG_ENABLED) {
            LOGGER.debug(this.connectionContext.getMessage("Write wrapping: Append to output buffer"));
        }
        // NOTE(review): channelPromise is never completed on this path; listeners registered on
        // buffered handshake writes will not fire. Confirm whether callers rely on that.
        ByteBuf byteBuf = (ByteBuf) obj;
        try {
            this.outputBuffer.writeBytes(byteBuf);
        } finally {
            byteBuf.release();
        }
    }

    /**
     * Converts TDS fragments into plain buffers for the {@link SslHandler}.
     *
     * <ul>
     *   <li>A {@link ContextualTdsFragment} becomes a new buffer: a TDS header (its type, its status
     *   plus {@code EOM}, and the next packet id) followed by its payload. The fragment's own buffer
     *   is released.</li>
     *   <li>A plain {@link TdsFragment} yields its buffer.</li>
     *   <li>Anything else is returned unchanged.</li>
     * </ul>
     */
    private Object unwrap(ByteBufAllocator byteBufAllocator, Object obj) {
        if (obj instanceof ContextualTdsFragment) {
            ContextualTdsFragment fragment = (ContextualTdsFragment) obj;
            ByteBuf payload = fragment.getByteBuf();
            HeaderOptions headerOptions = fragment.getHeaderOptions();
            try {
                Header header = new Header(headerOptions.getType(), headerOptions.getStatus().and(Status.StatusBit.EOM),
                    TDS_HEADER_LENGTH + payload.readableBytes(), 0, this.packetIdProvider.nextPacketId(), 0);
                ByteBuf packet = byteBufAllocator.buffer(header.getLength());
                header.encode(packet);
                packet.writeBytes(payload);
                return packet;
            } finally {
                payload.release();
            }
        }
        if (obj instanceof TdsFragment) {
            return ((TdsFragment) obj).getByteBuf();
        }
        return obj;
    }

    /**
     * In a wrapping state, swaps in a fresh output buffer, writes and flushes the accumulated bytes,
     * and re-enables auto-read. Otherwise delegates to {@code super}.
     */
    @Override
    public void flush(ChannelHandlerContext channelHandlerContext) throws Exception {
        if (!requiresWrapping()) {
            super.flush(channelHandlerContext);
            return;
        }
        if (DEBUG_ENABLED) {
            LOGGER.debug(this.connectionContext.getMessage("Write wrapping: Flushing output buffer and enable auto-read"));
        }
        ByteBuf byteBuf = this.outputBuffer;
        this.outputBuffer = channelHandlerContext.alloc().buffer();
        channelHandlerContext.writeAndFlush(byteBuf);
        channelHandlerContext.channel().config().setAutoRead(true);
    }

    /** During the handshake, flushes pending handshake output before forwarding the event. */
    @Override
    public void channelReadComplete(ChannelHandlerContext channelHandlerContext) throws Exception {
        if (isInHandshake() && this.outputBuffer.readableBytes() > 0) {
            flush(channelHandlerContext);
        }
        super.channelReadComplete(channelHandlerContext);
    }

    /**
     * Handles inbound data.
     *
     * <p>Outside the handshake, data goes to the {@link SslHandler} (full-connection TLS after the
     * handshake) or is forwarded unchanged.
     *
     * <p>During the handshake, TDS packets are decoded and reassembled across reads. A complete
     * {@code PRE_LOGIN} payload is passed to the {@link SslHandler} on the proxy context. Any other
     * complete payload, or a read too short to contain a TDS header, is released and dropped.
     */
    @Override
    public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        if (!isInHandshake()) {
            if (this.handshakeDone && this.state == SslState.CONNECTION) {
                this.sslHandler.channelRead(channelHandlerContext, obj);
            } else {
                super.channelRead(channelHandlerContext, obj);
            }
            return;
        }

        ByteBuf byteBuf = (ByteBuf) obj;
        Chunk chunk = this.chunk;
        Header header;

        if (chunk == null) {
            if (!Header.canDecode(byteBuf)) {
                // NOTE(review): a read shorter than a TDS header is dropped (as before), but it is
                // now released instead of leaked.
                byteBuf.release();
                return;
            }
            header = Header.decode(byteBuf);
            if (!Chunk.isCompletePacketAvailable(header, byteBuf)) {
                // Capture the allocator before release(): a released buffer must not be touched.
                ByteBufAllocator allocator = byteBuf.alloc();
                ByteBuf fullMessage = allocator.buffer(header.getLength());
                fullMessage.writeBytes(byteBuf);
                byteBuf.release();
                this.chunk = new Chunk(header, fullMessage, allocator.compositeBuffer());
                channelHandlerContext.read();
                return;
            }
        } else {
            chunk.defragment(byteBuf);
            if (!chunk.isCompleteHandshakeAvailable()) {
                return;
            }
            byteBuf = chunk.fullMessage;
            header = chunk.header;
            chunk.aggregator.release();
            this.chunk = null;
        }

        if (header.getType() == Type.PRE_LOGIN) {
            // The SslHandler takes ownership of byteBuf.
            this.sslHandler.channelRead(this.context, byteBuf);
        } else {
            // Nothing consumes non-PRE_LOGIN payloads during the handshake; release to avoid a leak.
            byteBuf.release();
        }
    }

    /** Returns whether a wrapping state is active and the TLS handshake has not completed yet. */
    private boolean isInHandshake() {
        return requiresWrapping() && !this.handshakeDone;
    }

    /** Returns whether TLS was requested for login only or for the whole connection. */
    private boolean requiresWrapping() {
        return this.state == SslState.LOGIN_ONLY || this.state == SslState.CONNECTION;
    }
}
