From c1332834a944ea10f121f0bb118cf8025fe48777 Mon Sep 17 00:00:00 2001 From: BohuTANG Date: Mon, 20 Apr 2020 19:27:58 +0800 Subject: [PATCH] add ERR packet parse --- src/Core/MySQLClient.cpp | 22 +++++++++-- src/Core/MySQLClient.h | 4 +- src/Core/MySQLProtocol.h | 67 +++++++++++++++++++++++++++++---- src/Core/tests/mysql_client.cpp | 2 +- 4 files changed, 81 insertions(+), 14 deletions(-) diff --git a/src/Core/MySQLClient.cpp b/src/Core/MySQLClient.cpp index df4fabc512a..0fbe3a5b067 100644 --- a/src/Core/MySQLClient.cpp +++ b/src/Core/MySQLClient.cpp @@ -36,7 +36,7 @@ void MySQLClient::connect() packet_sender = std::make_shared(*in, *out, seq); connected = true; - handshake(*in); + handshake(); } catch (Poco::Net::NetException & e) { @@ -61,10 +61,10 @@ void MySQLClient::close() } /// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html -void MySQLClient::handshake(ReadBuffer & payload) +void MySQLClient::handshake() { Handshake handshake; - handshake.readPayloadImpl(payload); + packet_sender->receivePacket(handshake); if (handshake.auth_plugin_name != mysql_native_password) { throw Exception( @@ -75,7 +75,21 @@ void MySQLClient::handshake(ReadBuffer & payload) Native41 native41(password, handshake.auth_plugin_data); String auth_plugin_data = native41.getAuthPluginData(); - HandshakeResponse handshakeResponse(client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password); + HandshakeResponse handshakeResponse( + client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password); packet_sender->sendPacket(handshakeResponse, true); + + PacketResponse packetResponse; + packet_sender->receivePacket(packetResponse); + + switch (packetResponse.getType()) { + case PACKET_OK: + break; + case PACKET_ERR: + throw Exception(packetResponse.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER); + break; + case PACKET_EOF: + break; + } } } diff --git a/src/Core/MySQLClient.h b/src/Core/MySQLClient.h index 043f88e2710..813af4b91ba 100644 --- a/src/Core/MySQLClient.h +++ b/src/Core/MySQLClient.h @@ -33,7 +33,7 @@ private: bool connected = false; UInt32 client_capability_flags = 0; - uint8_t seq = 1; + uint8_t seq = 0; UInt8 charset_utf8 = 33; UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH; String mysql_native_password = "mysql_native_password"; @@ -43,7 +43,7 @@ private: std::unique_ptr socket; std::optional address; - void handshake(ReadBuffer & payload); + void handshake(); protected: std::shared_ptr packet_sender; diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index 4699656ebba..ce62e111060 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -136,6 +136,12 @@ enum ColumnType MYSQL_TYPE_GEOMETRY = 0xff }; +enum ResponsePacketType +{ + PACKET_OK = 0x00, + PACKET_ERR = 0xff, + PACKET_EOF = 0xfe, +}; // https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html enum ColumnDefinitionFlags @@ -237,6 +243,7 @@ public: virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id) { PacketPayloadReadBuffer payload(in, sequence_id); + payload.next(); readPayloadImpl(payload); if (!payload.eof()) { @@ -484,8 +491,6 @@ public: void readPayloadImpl(ReadBuffer & buffer) override { - buffer.ignore(4); - /// 1-byte: [0a] protocol version buffer.readStrict(reinterpret_cast(&protocol_version), 1); @@ -587,7 +592,9 @@ public: , username(std::move(username_)) , database(std::move(database_)) , auth_response(std::move(auth_response_)) - , auth_plugin_name(std::move(auth_plugin_name_)){}; + , auth_plugin_name(std::move(auth_plugin_name_)) + { + } size_t getPayloadSize() const override { @@ -851,18 +858,19 @@ protected: } }; -class ERR_Packet : public WritePacket +class ERR_Packet : public WritePacket, public ReadPacket { - int error_code; +public: + int error_code = 0; String sql_state; String error_message; -public: + + ERR_Packet() = default; ERR_Packet(int error_code_, String sql_state_, String error_message_) : error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_)) { } -protected: size_t getPayloadSize() const override { return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE); @@ -876,6 +884,51 @@ protected: buffer.write(sql_state.data(), sql_state.length()); buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE)); } + + void readPayloadImpl(ReadBuffer & payload) override + { + UInt8 header = 0; + payload.readStrict(reinterpret_cast(&header), 1); + assert(header == 0xff); + payload.readStrict(reinterpret_cast(&error_code), 2); + payload.ignore(1); + payload.readStrict(reinterpret_cast(&sql_state), 5); + readString(error_message, payload); + } +}; + +/// https://dev.mysql.com/doc/internals/en/generic-response-packets.html +class PacketResponse : public ReadPacket +{ +public: + OK_Packet * ok; + ERR_Packet err; + EOF_Packet * eof; + + PacketResponse() = default; + + void readPayloadImpl(ReadBuffer & payload) override + { + UInt8 header = *payload.position(); + switch (header) + { + case PACKET_OK: + packetType = PACKET_OK; + break; + case PACKET_ERR: + packetType = PACKET_ERR; + err.readPayloadImpl(payload); + break; + case PACKET_EOF: + packetType = PACKET_EOF; + break; + }; + } + + ResponsePacketType getType() { return packetType; } + +private: + ResponsePacketType packetType = PACKET_OK; }; class ColumnDefinition : public WritePacket diff --git a/src/Core/tests/mysql_client.cpp b/src/Core/tests/mysql_client.cpp index 435df518dce..eaf85ac5b5b 100644 --- a/src/Core/tests/mysql_client.cpp +++ b/src/Core/tests/mysql_client.cpp @@ -6,7 +6,7 @@ int main(int, char **) using namespace DB; UInt16 port = 4407; - String host = "127.0.0.1", user = "root", password = "mock"; + String host = "127.0.0.1", user = "mock", password = "mock"; MySQLClient client(host, port, user, password, ""); client.connect(); return 0;