diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index 643168018a0..35c4c3da1b9 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -517,11 +517,9 @@ public: buffer.ignore(10); if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION) { - UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13 - ? 13 - : (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH); - auth_plugin_data.resize(part2_length + AUTH_PLUGIN_DATA_PART_1_LENGTH - 1); - buffer.readStrict(auth_plugin_data.data() + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length - 1); + UInt8 part2_length = (SCRAMBLE_LENGTH - AUTH_PLUGIN_DATA_PART_1_LENGTH); + auth_plugin_data.resize(SCRAMBLE_LENGTH); + buffer.readStrict(auth_plugin_data.data() + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length); buffer.ignore(1); } @@ -958,7 +956,7 @@ public: packetType = PACKET_EOF; eof.readPayloadImpl(payload); break; - }; + } } ResponsePacketType getType() { return packetType; } diff --git a/src/Core/tests/gtest_MySQLProtocol.cpp b/src/Core/tests/gtest_MySQLProtocol.cpp new file mode 100644 index 00000000000..f29bd8738f7 --- /dev/null +++ b/src/Core/tests/gtest_MySQLProtocol.cpp @@ -0,0 +1,31 @@ +#include +#include + +#include +#include +#include + +using namespace DB; +using namespace MySQLProtocol; + +TEST(MySQLProtocol, Handshake) +{ + UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH + | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF; + + std::string s; + WriteBufferFromString out(s); + Handshake server_handshake(server_capability_flags, 0, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa"); + server_handshake.writePayloadImpl(out); + + ReadBufferFromString in(s); + Handshake client_handshake; + client_handshake.readPayloadImpl(in); + + EXPECT_EQ(server_handshake.capability_flags, client_handshake.capability_flags); + EXPECT_EQ(server_handshake.status_flags, client_handshake.status_flags); + EXPECT_EQ(server_handshake.server_version, client_handshake.server_version); + EXPECT_EQ(server_handshake.protocol_version, client_handshake.protocol_version); + EXPECT_EQ(server_handshake.auth_plugin_data.substr(0, 20), client_handshake.auth_plugin_data); + EXPECT_EQ(server_handshake.auth_plugin_name, client_handshake.auth_plugin_name); +}