package org.apache.shardingsphere.proxy.frontend.mysql.authentication;

import com.google.common.base.Strings;
import io.netty.channel.ChannelHandlerContext;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Optional;
import org.apache.shardingsphere.authority.checker.AuthorityChecker;
import org.apache.shardingsphere.authority.rule.AuthorityRule;
import org.apache.shardingsphere.db.protocol.constant.CommonConstants;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCapabilityFlag;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLCharacterSet;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConnectionPhase;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLConstants;
import org.apache.shardingsphere.db.protocol.mysql.constant.MySQLStatusFlag;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLErrPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.generic.MySQLOKPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchRequestPacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthSwitchResponsePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLAuthenticationPluginData;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakePacket;
import org.apache.shardingsphere.db.protocol.mysql.packet.handshake.MySQLHandshakeResponse41Packet;
import org.apache.shardingsphere.db.protocol.mysql.payload.MySQLPacketPayload;
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.dialect.mysql.vendor.MySQLVendorError;
import org.apache.shardingsphere.infra.metadata.user.Grantee;
import org.apache.shardingsphere.infra.metadata.user.ShardingSphereUser;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResult;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationResultBuilder;
import org.apache.shardingsphere.proxy.frontend.authentication.Authenticator;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticatorFactory;
import org.apache.shardingsphere.proxy.frontend.connection.ConnectionIdGenerator;
import org.apache.shardingsphere.proxy.frontend.mysql.authentication.authenticator.MySQLAuthenticatorType;
import org.apache.shardingsphere.proxy.frontend.mysql.command.query.binary.MySQLStatementIDGenerator;

/* loaded from: input_file:org/apache/shardingsphere/proxy/frontend/mysql/authentication/MySQLAuthenticationEngine.class */
/**
 * Authentication engine for MySQL front-end connections.
 *
 * <p>The engine sends the initial handshake packet, parses the client's handshake response, asks the
 * client to switch authentication plugin when the announced plugin does not match the one selected for
 * the user, and finally verifies the credentials and the database authorization against the global
 * {@link AuthorityRule}.</p>
 *
 * <p>Instances are stateful: the connection phase, the most recent authentication response and the
 * intermediate authentication result are kept in fields between calls.</p>
 */
public final class MySQLAuthenticationEngine implements AuthenticationEngine {
    /** Plugin data sent in the handshake and reused for the auth-switch request and the credential check. */
    private final MySQLAuthenticationPluginData authPluginData = new MySQLAuthenticationPluginData();
    /** Current step of the connection phase; advanced by {@link #handshake} and the fast-path step. */
    private MySQLConnectionPhase connectionPhase = MySQLConnectionPhase.INITIAL_HANDSHAKE;
    /** Authentication response bytes taken from the handshake response or from the auth-switch response. */
    private byte[] authResponse;
    /** Result produced by the fast-path step; supplies username and database to the final checks. */
    private AuthenticationResult currentAuthResult;

    /**
     * Starts the connection phase by sending the handshake packet to the client.
     *
     * @param channelHandlerContext context used to write the handshake packet
     * @return the newly generated connection id, which is also registered with the statement id generator
     */
    public int handshake(ChannelHandlerContext channelHandlerContext) {
        int connectionId = ConnectionIdGenerator.getInstance().nextId();
        this.connectionPhase = MySQLConnectionPhase.AUTH_PHASE_FAST_PATH;
        MySQLHandshakePacket handshakePacket = new MySQLHandshakePacket(connectionId, this.authPluginData);
        channelHandlerContext.writeAndFlush(handshakePacket);
        MySQLStatementIDGenerator.getInstance().registerConnection(connectionId);
        return connectionId;
    }

    /**
     * Processes one client packet of the authentication exchange.
     *
     * <p>In the fast-path phase the handshake response is parsed first; if that step does not finish,
     * its result is returned unchanged. In the method-mismatch phase only the auth-switch response is
     * read. Afterwards the credentials and, when a database was requested, the database authorization
     * are checked. A failed check writes an error packet, closes the context and yields a "continued"
     * result; success writes an OK packet and yields a "finished" result.</p>
     *
     * @param channelHandlerContext context of the client channel
     * @param packetPayload payload of the packet received from the client
     * @return the authentication result for this step
     */
    public AuthenticationResult authenticate(ChannelHandlerContext channelHandlerContext, PacketPayload packetPayload) {
        AuthorityRule rule = (AuthorityRule) ProxyContext.getInstance()
                .getContextManager()
                .getMetaDataContexts()
                .getMetaData()
                .getGlobalRuleMetaData()
                .getSingleRule(AuthorityRule.class);
        if (MySQLConnectionPhase.AUTH_PHASE_FAST_PATH == this.connectionPhase) {
            this.currentAuthResult = authenticatePhaseFastPath(channelHandlerContext, packetPayload, rule);
            if (!this.currentAuthResult.isFinished()) {
                // Either an error was already reported or an auth-switch request is pending.
                return this.currentAuthResult;
            }
        } else if (MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH == this.connectionPhase) {
            authenticateMismatchedMethod((MySQLPacketPayload) packetPayload);
        }
        Grantee grantee = new Grantee(this.currentAuthResult.getUsername(), getHostAddress(channelHandlerContext));
        if (!login(rule, grantee, this.authResponse)) {
            // The third message argument reports whether the client supplied a non-empty auth response.
            MySQLErrPacket accessDenied = new MySQLErrPacket(MySQLVendorError.ER_ACCESS_DENIED_ERROR, new Object[]{
                this.currentAuthResult.getUsername(),
                grantee.getHostname(),
                0 == this.authResponse.length ? "NO" : "YES"
            });
            writeErrorPacket(channelHandlerContext, accessDenied);
            return AuthenticationResultBuilder.continued();
        }
        if (!authorizeDatabase(rule, grantee, this.currentAuthResult.getDatabase())) {
            MySQLErrPacket databaseDenied = new MySQLErrPacket(MySQLVendorError.ER_DBACCESS_DENIED_ERROR, new Object[]{
                this.currentAuthResult.getUsername(),
                grantee.getHostname(),
                this.currentAuthResult.getDatabase()
            });
            writeErrorPacket(channelHandlerContext, databaseDenied);
            return AuthenticationResultBuilder.continued();
        }
        writeOKPacket(channelHandlerContext);
        return AuthenticationResultBuilder.finished(grantee.getUsername(), grantee.getHostname(), this.currentAuthResult.getDatabase());
    }

    /**
     * Handles the client's handshake response.
     *
     * <p>Stores the auth response, applies the client's character set to the channel, rejects a
     * requested database that does not exist, and requests an authentication-method switch when the
     * client supports plugin authentication but named a different plugin than the one selected for
     * the user.</p>
     *
     * @param context context of the client channel
     * @param payload payload carrying the handshake response
     * @param rule authority rule used to look up the user
     * @return a finished result when the final checks can proceed, otherwise a continued result
     */
    private AuthenticationResult authenticatePhaseFastPath(ChannelHandlerContext context, PacketPayload payload, AuthorityRule rule) {
        MySQLHandshakeResponse41Packet response = new MySQLHandshakeResponse41Packet((MySQLPacketPayload) payload);
        String database = response.getDatabase();
        this.authResponse = response.getAuthResponse();
        setCharacterSet(context, response);
        boolean databaseRequested = !Strings.isNullOrEmpty(database);
        if (databaseRequested && !ProxyContext.getInstance().databaseExists(database)) {
            writeErrorPacket(context, new MySQLErrPacket(MySQLVendorError.ER_BAD_DB_ERROR, new Object[]{database}));
            return AuthenticationResultBuilder.continued();
        }
        String username = response.getUsername();
        String hostAddress = getHostAddress(context);
        AuthenticatorFactory authenticatorFactory = new AuthenticatorFactory(MySQLAuthenticatorType.class, rule);
        // Unknown users fall back to a placeholder user with an empty password so a method can still be selected.
        ShardingSphereUser user = (ShardingSphereUser) rule.findUser(new Grantee(username, hostAddress))
                .orElseGet(() -> new ShardingSphereUser(username, "", hostAddress));
        Authenticator authenticator = authenticatorFactory.newInstance(user);
        boolean methodAccepted = !isClientPluginAuthenticate(response)
                || authenticator.getAuthenticationMethod().getMethodName().equals(response.getAuthPluginName());
        if (!methodAccepted) {
            this.connectionPhase = MySQLConnectionPhase.AUTHENTICATION_METHOD_MISMATCH;
            context.writeAndFlush(new MySQLAuthSwitchRequestPacket(authenticator.getAuthenticationMethod().getMethodName(), this.authPluginData));
            return AuthenticationResultBuilder.continued(username, hostAddress, database);
        }
        return AuthenticationResultBuilder.finished(username, hostAddress, database);
    }

    /**
     * Stores the character set announced in the handshake response as channel attributes.
     *
     * @param context context whose channel receives the attributes
     * @param response parsed handshake response
     */
    private void setCharacterSet(ChannelHandlerContext context, MySQLHandshakeResponse41Packet response) {
        MySQLCharacterSet characterSet = MySQLCharacterSet.findById(response.getCharacterSet());
        context.channel().attr(CommonConstants.CHARSET_ATTRIBUTE_KEY).set(characterSet.getCharset());
        context.channel().attr(MySQLConstants.MYSQL_CHARACTER_SET_ATTRIBUTE_KEY).set(characterSet);
    }

    /**
     * Tells whether the client set the {@code CLIENT_PLUGIN_AUTH} capability flag.
     *
     * @param response parsed handshake response
     * @return {@code true} when the flag bit is present in the capability flags
     */
    private boolean isClientPluginAuthenticate(MySQLHandshakeResponse41Packet response) {
        return (response.getCapabilityFlags() & MySQLCapabilityFlag.CLIENT_PLUGIN_AUTH.getValue()) != 0;
    }

    /**
     * Reads the auth-switch response and keeps its plugin response as the current auth response.
     *
     * @param payload payload carrying the auth-switch response
     */
    private void authenticateMismatchedMethod(MySQLPacketPayload payload) {
        MySQLAuthSwitchResponsePacket switchResponse = new MySQLAuthSwitchResponsePacket(payload);
        this.authResponse = switchResponse.getAuthPluginResponse();
    }

    /**
     * Verifies the auth response of the given grantee.
     *
     * @param authorityRule authority rule used to look up the user and build the authenticator
     * @param grantee user name and host of the connecting client
     * @param bArr auth response bytes received from the client
     * @return {@code false} when no user matches the grantee, otherwise the authenticator's verdict
     */
    private boolean login(AuthorityRule authorityRule, Grantee grantee, byte[] bArr) {
        Optional<?> matchedUser = authorityRule.findUser(grantee);
        if (!matchedUser.isPresent()) {
            return false;
        }
        ShardingSphereUser user = (ShardingSphereUser) matchedUser.get();
        Authenticator authenticator = new AuthenticatorFactory(MySQLAuthenticatorType.class, authorityRule).newInstance(user);
        return authenticator.authenticate(user, new Object[]{bArr, this.authPluginData});
    }

    /**
     * Checks whether the grantee may use the requested database.
     *
     * @param authorityRule authority rule backing the check
     * @param grantee user name and host of the connecting client
     * @param databaseName requested database, or {@code null} when none was requested
     * @return {@code true} when no database was requested or the grantee is authorized for it
     */
    private boolean authorizeDatabase(AuthorityRule authorityRule, Grantee grantee, String databaseName) {
        if (null == databaseName) {
            return true;
        }
        return new AuthorityChecker(authorityRule, grantee).isAuthorized(databaseName);
    }

    /**
     * Resolves the textual address of the remote peer.
     *
     * @param context context of the client channel
     * @return the host address for an {@link InetSocketAddress}, otherwise the address's string form
     */
    private String getHostAddress(ChannelHandlerContext context) {
        SocketAddress remote = context.channel().remoteAddress();
        if (remote instanceof InetSocketAddress) {
            return ((InetSocketAddress) remote).getAddress().getHostAddress();
        }
        return remote.toString();
    }

    /**
     * Writes an error packet to the client and then closes the context.
     *
     * @param context context of the client channel
     * @param errPacket error packet to send
     */
    private void writeErrorPacket(ChannelHandlerContext context, MySQLErrPacket errPacket) {
        context.writeAndFlush(errPacket);
        context.close();
    }

    /**
     * Writes an OK packet carrying the autocommit status flag.
     *
     * @param context context of the client channel
     */
    private void writeOKPacket(ChannelHandlerContext context) {
        MySQLOKPacket okPacket = new MySQLOKPacket(MySQLStatusFlag.SERVER_STATUS_AUTOCOMMIT.getValue());
        context.writeAndFlush(okPacket);
    }
}
