From 6cac6a4f7687ec511f423b5d799f00ee1c3469eb Mon Sep 17 00:00:00 2001 From: BohuTANG Date: Wed, 22 Apr 2020 15:26:12 +0800 Subject: [PATCH] add mysql protocol test --- src/Core/MySQLClient.cpp | 12 +- src/Core/MySQLProtocol.h | 18 +-- src/Core/tests/CMakeLists.txt | 3 + src/Core/tests/gtest_MySQLProtocol.cpp | 31 ------ src/Core/tests/mysql_client.cpp | 17 --- src/Core/tests/mysql_protocol.cpp | 146 +++++++++++++++++++++++++ 6 files changed, 164 insertions(+), 63 deletions(-) delete mode 100644 src/Core/tests/gtest_MySQLProtocol.cpp delete mode 100644 src/Core/tests/mysql_client.cpp create mode 100644 src/Core/tests/mysql_protocol.cpp diff --git a/src/Core/MySQLClient.cpp b/src/Core/MySQLClient.cpp index 1ebffefdd1e..430ee902143 100644 --- a/src/Core/MySQLClient.cpp +++ b/src/Core/MySQLClient.cpp @@ -58,19 +58,19 @@ bool MySQLClient::handshake() Native41 native41(password, handshake.auth_plugin_data); String auth_plugin_data = native41.getAuthPluginData(); - HandshakeResponse handshakeResponse( + HandshakeResponse handshake_response( client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password); - packet_sender->sendPacket(handshakeResponse, true); + packet_sender->sendPacket(handshake_response, true); - PacketResponse packetResponse(handshake.capability_flags); - packet_sender->receivePacket(packetResponse); - if (packetResponse.getType() != PACKET_ERR) + PacketResponse packet_response(handshake.capability_flags); + packet_sender->receivePacket(packet_response); + if (packet_response.getType() != PACKET_ERR) { return true; } else { - last_error = packetResponse.err.error_message; + last_error = packet_response.err.error_message; return false; } } diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index 35c4c3da1b9..979b5beef32 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -857,8 +857,9 @@ public: class EOF_Packet : public WritePacket, public ReadPacket { public: - int warnings; - int status_flags; + UInt8 header = 0xfe; + int warnings = 0; + int status_flags = 0; EOF_Packet() = default; EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) @@ -871,14 +872,13 @@ public: void writePayloadImpl(WriteBuffer & buffer) const override { - buffer.write(0xfe); // EOF header + buffer.write(header); // EOF header buffer.write(reinterpret_cast(&warnings), 2); buffer.write(reinterpret_cast(&status_flags), 2); } void readPayloadImpl(ReadBuffer & payload) override { - UInt8 header = 0; payload.readStrict(reinterpret_cast(&header), 1); assert(header == 0xfe); payload.readStrict(reinterpret_cast(&warnings), 2); @@ -889,6 +889,7 @@ public: class ERR_Packet : public WritePacket, public ReadPacket { public: + UInt8 header = 0xff; int error_code = 0; String sql_state; String error_message; @@ -906,7 +907,7 @@ public: void writePayloadImpl(WriteBuffer & buffer) const override { - buffer.write(0xff); + buffer.write(header); buffer.write(reinterpret_cast(&error_code), 2); buffer.write('#'); buffer.write(sql_state.data(), sql_state.length()); @@ -915,7 +916,6 @@ public: void readPayloadImpl(ReadBuffer & payload) override { - UInt8 header = 0; payload.readStrict(reinterpret_cast(&header), 1); assert(header == 0xff); @@ -923,9 +923,9 @@ public: payload.ignore(1); sql_state.resize(5); - payload.readStrict(reinterpret_cast(&sql_state), 5); + payload.readStrict(reinterpret_cast(sql_state.data()), 5); - readStringUntilEOF(error_message, payload); + readNullTerminated(error_message, payload); } }; @@ -941,7 +941,7 @@ public: void readPayloadImpl(ReadBuffer & payload) override { - UInt8 header = *payload.position(); + UInt8 header = static_cast(*payload.position()); switch (header) { case PACKET_OK: diff --git a/src/Core/tests/CMakeLists.txt b/src/Core/tests/CMakeLists.txt index c1af57b5b85..a5d694b7358 100644 --- a/src/Core/tests/CMakeLists.txt +++ b/src/Core/tests/CMakeLists.txt @@ -18,3 +18,6 @@ endif () add_executable (mysql_client mysql_client.cpp) target_link_libraries (mysql_client PRIVATE dbms) + +add_executable (mysql_protocol mysql_protocol.cpp) +target_link_libraries (mysql_protocol PRIVATE dbms) diff --git a/src/Core/tests/gtest_MySQLProtocol.cpp b/src/Core/tests/gtest_MySQLProtocol.cpp deleted file mode 100644 index f29bd8738f7..00000000000 --- a/src/Core/tests/gtest_MySQLProtocol.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include -#include - -#include -#include -#include - -using namespace DB; -using namespace MySQLProtocol; - -TEST(MySQLProtocol, Handshake) -{ - UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH - | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; - - std::string s; - WriteBufferFromString out(s); - Handshake server_handshake(server_capability_flags, 0, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa"); - server_handshake.writePayloadImpl(out); - - ReadBufferFromString in(s); - Handshake client_handshake; - client_handshake.readPayloadImpl(in); - - EXPECT_EQ(server_handshake.capability_flags, client_handshake.capability_flags); - EXPECT_EQ(server_handshake.status_flags, client_handshake.status_flags); - EXPECT_EQ(server_handshake.server_version, client_handshake.server_version); - EXPECT_EQ(server_handshake.protocol_version, client_handshake.protocol_version); - EXPECT_EQ(server_handshake.auth_plugin_data.substr(0, 20), client_handshake.auth_plugin_data); - EXPECT_EQ(server_handshake.auth_plugin_name, client_handshake.auth_plugin_name); -} diff --git a/src/Core/tests/mysql_client.cpp b/src/Core/tests/mysql_client.cpp deleted file mode 100644 index bce50befaa0..00000000000 --- a/src/Core/tests/mysql_client.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include -#include - - -int main(int, char **) -{ - using namespace DB; - - UInt16 port = 9001; - String host = "127.0.0.1", user = "default", password = "123"; - MySQLClient client(host, port, user, password, ""); - if (!client.connect()) - { - std::cerr << "Connect Error: " << client.error() << std::endl; - } - return 0; -} diff --git a/src/Core/tests/mysql_protocol.cpp b/src/Core/tests/mysql_protocol.cpp new file mode 100644 index 00000000000..870eb4966c6 --- /dev/null +++ b/src/Core/tests/mysql_protocol.cpp @@ -0,0 +1,146 @@ +#include + +#include +#include +#include + + +int main(int, char **) +{ + using namespace DB; + using namespace MySQLProtocol; + using namespace MySQLProtocol::Authentication; + + /* + UInt16 port = 9001; + String host = "127.0.0.1", user = "default", password = "123"; + MySQLClient client(host, port, user, password, ""); + if (!client.connect()) + { + std::cerr << "Connect Error: " << client.error() << std::endl; + } + */ + String user = "default"; + String password = "123"; + String database = ""; + + UInt8 charset_utf8 = 33; + UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH; + String mysql_native_password = "mysql_native_password"; + + UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; + + UInt32 client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION; + + /// Handshake packet + { + /// 1. Greeting: + /// 1.1 Server writes greeting to client + std::string s0; + WriteBufferFromString out0(s0); + + Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa"); + server_handshake.writePayloadImpl(out0); + + /// 1.2 Client reads the greeting + ReadBufferFromString in0(s0); + Handshake client_handshake; + client_handshake.readPayloadImpl(in0); + + /// Check packet + ASSERT(server_handshake.capability_flags == client_handshake.capability_flags); + ASSERT(server_handshake.status_flags == client_handshake.status_flags); + ASSERT(server_handshake.server_version == client_handshake.server_version); + ASSERT(server_handshake.protocol_version == client_handshake.protocol_version); + ASSERT(server_handshake.auth_plugin_data.substr(0, 20) == client_handshake.auth_plugin_data); + ASSERT(server_handshake.auth_plugin_name == client_handshake.auth_plugin_name); + + /// 2. Greeting Response: + std::string s1; + WriteBufferFromString out1(s1); + + /// 2.1 Client writes to server + Native41 native41(password, client_handshake.auth_plugin_data); + String auth_plugin_data = native41.getAuthPluginData(); + HandshakeResponse client_handshake_response( + client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password); + client_handshake_response.writePayloadImpl(out1); + + /// 2.2 Server reads the response + ReadBufferFromString in1(s1); + HandshakeResponse server_handshake_response; + server_handshake_response.readPayloadImpl(in1); + + /// Check + ASSERT(server_handshake_response.capability_flags == client_handshake_response.capability_flags); + ASSERT(server_handshake_response.character_set == client_handshake_response.character_set); + ASSERT(server_handshake_response.username == client_handshake_response.username); + ASSERT(server_handshake_response.database == client_handshake_response.database); + ASSERT(server_handshake_response.auth_response == client_handshake_response.auth_response); + ASSERT(server_handshake_response.auth_plugin_name == client_handshake_response.auth_plugin_name); + } + + /// OK Packet + { + // 1. Server writes packet + std::string s0; + WriteBufferFromString out0(s0); + OK_Packet server(0x00, server_capability_flags, 0, 0, 0, "", ""); + server.writePayloadImpl(out0); + + // 2. Client reads packet + ReadBufferFromString in1(s0); + PacketResponse client(server_capability_flags); + client.readPayloadImpl(in1); + + // Check + ASSERT(client.getType() == PACKET_OK); + ASSERT(client.ok.header == server.header); + ASSERT(client.ok.status_flags == server.status_flags); + ASSERT(client.ok.capabilities == server.capabilities); + } + + /// ERR Packet + { + // 1. Server writes packet + std::string s0; + WriteBufferFromString out0(s0); + ERR_Packet server(123, "12345", "This is the error message"); + server.writePayloadImpl(out0); + + // 2. Client reads packet + ReadBufferFromString in1(s0); + PacketResponse client(server_capability_flags); + client.readPayloadImpl(in1); + + // Check + ASSERT(client.getType() == PACKET_ERR); + ASSERT(client.err.header == server.header); + ASSERT(client.err.error_code == server.error_code); + ASSERT(client.err.sql_state == server.sql_state); + ASSERT(client.err.error_message == server.error_message); + } + + /// EOF Packet + { + // 1. Server writes packet + std::string s0; + WriteBufferFromString out0(s0); + EOF_Packet server(1, 1); + server.writePayloadImpl(out0); + + // 2. Client reads packet + ReadBufferFromString in1(s0); + PacketResponse client(server_capability_flags); + client.readPayloadImpl(in1); + + // Check + ASSERT(client.getType() == PACKET_EOF); + ASSERT(client.eof.header == server.header); + ASSERT(client.eof.warnings == server.warnings); + ASSERT(client.eof.status_flags == server.status_flags); + } + + return 0; +}