mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-28 20:50:49 +00:00
add mysql protocol test
This commit is contained in:
parent
b9e2c0d72c
commit
6cac6a4f76
@ -58,19 +58,19 @@ bool MySQLClient::handshake()
|
||||
Native41 native41(password, handshake.auth_plugin_data);
|
||||
String auth_plugin_data = native41.getAuthPluginData();
|
||||
|
||||
HandshakeResponse handshakeResponse(
|
||||
HandshakeResponse handshake_response(
|
||||
client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password);
|
||||
packet_sender->sendPacket<HandshakeResponse>(handshakeResponse, true);
|
||||
packet_sender->sendPacket<HandshakeResponse>(handshake_response, true);
|
||||
|
||||
PacketResponse packetResponse(handshake.capability_flags);
|
||||
packet_sender->receivePacket(packetResponse);
|
||||
if (packetResponse.getType() != PACKET_ERR)
|
||||
PacketResponse packet_response(handshake.capability_flags);
|
||||
packet_sender->receivePacket(packet_response);
|
||||
if (packet_response.getType() != PACKET_ERR)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
last_error = packetResponse.err.error_message;
|
||||
last_error = packet_response.err.error_message;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -857,8 +857,9 @@ public:
|
||||
class EOF_Packet : public WritePacket, public ReadPacket
|
||||
{
|
||||
public:
|
||||
int warnings;
|
||||
int status_flags;
|
||||
UInt8 header = 0xfe;
|
||||
int warnings = 0;
|
||||
int status_flags = 0;
|
||||
|
||||
EOF_Packet() = default;
|
||||
EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_)
|
||||
@ -871,14 +872,13 @@ public:
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0xfe); // EOF header
|
||||
buffer.write(header); // EOF header
|
||||
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
}
|
||||
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
UInt8 header = 0;
|
||||
payload.readStrict(reinterpret_cast<char *>(&header), 1);
|
||||
assert(header == 0xfe);
|
||||
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
|
||||
@ -889,6 +889,7 @@ public:
|
||||
class ERR_Packet : public WritePacket, public ReadPacket
|
||||
{
|
||||
public:
|
||||
UInt8 header = 0xff;
|
||||
int error_code = 0;
|
||||
String sql_state;
|
||||
String error_message;
|
||||
@ -906,7 +907,7 @@ public:
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0xff);
|
||||
buffer.write(header);
|
||||
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
|
||||
buffer.write('#');
|
||||
buffer.write(sql_state.data(), sql_state.length());
|
||||
@ -915,7 +916,6 @@ public:
|
||||
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
UInt8 header = 0;
|
||||
payload.readStrict(reinterpret_cast<char *>(&header), 1);
|
||||
assert(header == 0xff);
|
||||
|
||||
@ -923,9 +923,9 @@ public:
|
||||
payload.ignore(1);
|
||||
|
||||
sql_state.resize(5);
|
||||
payload.readStrict(reinterpret_cast<char *>(&sql_state), 5);
|
||||
payload.readStrict(reinterpret_cast<char *>(sql_state.data()), 5);
|
||||
|
||||
readStringUntilEOF(error_message, payload);
|
||||
readNullTerminated(error_message, payload);
|
||||
}
|
||||
};
|
||||
|
||||
@ -941,7 +941,7 @@ public:
|
||||
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
UInt8 header = *payload.position();
|
||||
UInt8 header = static_cast<unsigned char>(*payload.position());
|
||||
switch (header)
|
||||
{
|
||||
case PACKET_OK:
|
||||
|
@ -18,3 +18,6 @@ endif ()
|
||||
|
||||
add_executable (mysql_client mysql_client.cpp)
|
||||
target_link_libraries (mysql_client PRIVATE dbms)
|
||||
|
||||
add_executable (mysql_protocol mysql_protocol.cpp)
|
||||
target_link_libraries (mysql_protocol PRIVATE dbms)
|
||||
|
@ -1,31 +0,0 @@
|
||||
#include <string>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
|
||||
using namespace DB;
|
||||
using namespace MySQLProtocol;
|
||||
|
||||
TEST(MySQLProtocol, Handshake)
|
||||
{
|
||||
UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
|
||||
| CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
|
||||
|
||||
std::string s;
|
||||
WriteBufferFromString out(s);
|
||||
Handshake server_handshake(server_capability_flags, 0, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa");
|
||||
server_handshake.writePayloadImpl(out);
|
||||
|
||||
ReadBufferFromString in(s);
|
||||
Handshake client_handshake;
|
||||
client_handshake.readPayloadImpl(in);
|
||||
|
||||
EXPECT_EQ(server_handshake.capability_flags, client_handshake.capability_flags);
|
||||
EXPECT_EQ(server_handshake.status_flags, client_handshake.status_flags);
|
||||
EXPECT_EQ(server_handshake.server_version, client_handshake.server_version);
|
||||
EXPECT_EQ(server_handshake.protocol_version, client_handshake.protocol_version);
|
||||
EXPECT_EQ(server_handshake.auth_plugin_data.substr(0, 20), client_handshake.auth_plugin_data);
|
||||
EXPECT_EQ(server_handshake.auth_plugin_name, client_handshake.auth_plugin_name);
|
||||
}
|
@ -1,17 +0,0 @@
|
||||
#include <iostream>
|
||||
#include <Core/MySQLClient.h>
|
||||
|
||||
|
||||
int main(int, char **)
|
||||
{
|
||||
using namespace DB;
|
||||
|
||||
UInt16 port = 9001;
|
||||
String host = "127.0.0.1", user = "default", password = "123";
|
||||
MySQLClient client(host, port, user, password, "");
|
||||
if (!client.connect())
|
||||
{
|
||||
std::cerr << "Connect Error: " << client.error() << std::endl;
|
||||
}
|
||||
return 0;
|
||||
}
|
146
src/Core/tests/mysql_protocol.cpp
Normal file
146
src/Core/tests/mysql_protocol.cpp
Normal file
@ -0,0 +1,146 @@
|
||||
#include <string>
|
||||
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
|
||||
|
||||
int main(int, char **)
|
||||
{
|
||||
using namespace DB;
|
||||
using namespace MySQLProtocol;
|
||||
using namespace MySQLProtocol::Authentication;
|
||||
|
||||
/*
|
||||
UInt16 port = 9001;
|
||||
String host = "127.0.0.1", user = "default", password = "123";
|
||||
MySQLClient client(host, port, user, password, "");
|
||||
if (!client.connect())
|
||||
{
|
||||
std::cerr << "Connect Error: " << client.error() << std::endl;
|
||||
}
|
||||
*/
|
||||
String user = "default";
|
||||
String password = "123";
|
||||
String database = "";
|
||||
|
||||
UInt8 charset_utf8 = 33;
|
||||
UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
|
||||
String mysql_native_password = "mysql_native_password";
|
||||
|
||||
UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
|
||||
| CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
|
||||
|
||||
UInt32 client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
|
||||
|
||||
/// Handshake packet
|
||||
{
|
||||
/// 1. Greeting:
|
||||
/// 1.1 Server writes greeting to client
|
||||
std::string s0;
|
||||
WriteBufferFromString out0(s0);
|
||||
|
||||
Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa");
|
||||
server_handshake.writePayloadImpl(out0);
|
||||
|
||||
/// 1.2 Client reads the greeting
|
||||
ReadBufferFromString in0(s0);
|
||||
Handshake client_handshake;
|
||||
client_handshake.readPayloadImpl(in0);
|
||||
|
||||
/// Check packet
|
||||
ASSERT(server_handshake.capability_flags == client_handshake.capability_flags);
|
||||
ASSERT(server_handshake.status_flags == client_handshake.status_flags);
|
||||
ASSERT(server_handshake.server_version == client_handshake.server_version);
|
||||
ASSERT(server_handshake.protocol_version == client_handshake.protocol_version);
|
||||
ASSERT(server_handshake.auth_plugin_data.substr(0, 20) == client_handshake.auth_plugin_data);
|
||||
ASSERT(server_handshake.auth_plugin_name == client_handshake.auth_plugin_name);
|
||||
|
||||
/// 2. Greeting Response:
|
||||
std::string s1;
|
||||
WriteBufferFromString out1(s1);
|
||||
|
||||
/// 2.1 Client writes to server
|
||||
Native41 native41(password, client_handshake.auth_plugin_data);
|
||||
String auth_plugin_data = native41.getAuthPluginData();
|
||||
HandshakeResponse client_handshake_response(
|
||||
client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password);
|
||||
client_handshake_response.writePayloadImpl(out1);
|
||||
|
||||
/// 2.2 Server reads the response
|
||||
ReadBufferFromString in1(s1);
|
||||
HandshakeResponse server_handshake_response;
|
||||
server_handshake_response.readPayloadImpl(in1);
|
||||
|
||||
/// Check
|
||||
ASSERT(server_handshake_response.capability_flags == client_handshake_response.capability_flags);
|
||||
ASSERT(server_handshake_response.character_set == client_handshake_response.character_set);
|
||||
ASSERT(server_handshake_response.username == client_handshake_response.username);
|
||||
ASSERT(server_handshake_response.database == client_handshake_response.database);
|
||||
ASSERT(server_handshake_response.auth_response == client_handshake_response.auth_response);
|
||||
ASSERT(server_handshake_response.auth_plugin_name == client_handshake_response.auth_plugin_name);
|
||||
}
|
||||
|
||||
/// OK Packet
|
||||
{
|
||||
// 1. Server writes packet
|
||||
std::string s0;
|
||||
WriteBufferFromString out0(s0);
|
||||
OK_Packet server(0x00, server_capability_flags, 0, 0, 0, "", "");
|
||||
server.writePayloadImpl(out0);
|
||||
|
||||
// 2. Client reads packet
|
||||
ReadBufferFromString in1(s0);
|
||||
PacketResponse client(server_capability_flags);
|
||||
client.readPayloadImpl(in1);
|
||||
|
||||
// Check
|
||||
ASSERT(client.getType() == PACKET_OK);
|
||||
ASSERT(client.ok.header == server.header);
|
||||
ASSERT(client.ok.status_flags == server.status_flags);
|
||||
ASSERT(client.ok.capabilities == server.capabilities);
|
||||
}
|
||||
|
||||
/// ERR Packet
|
||||
{
|
||||
// 1. Server writes packet
|
||||
std::string s0;
|
||||
WriteBufferFromString out0(s0);
|
||||
ERR_Packet server(123, "12345", "This is the error message");
|
||||
server.writePayloadImpl(out0);
|
||||
|
||||
// 2. Client reads packet
|
||||
ReadBufferFromString in1(s0);
|
||||
PacketResponse client(server_capability_flags);
|
||||
client.readPayloadImpl(in1);
|
||||
|
||||
// Check
|
||||
ASSERT(client.getType() == PACKET_ERR);
|
||||
ASSERT(client.err.header == server.header);
|
||||
ASSERT(client.err.error_code == server.error_code);
|
||||
ASSERT(client.err.sql_state == server.sql_state);
|
||||
ASSERT(client.err.error_message == server.error_message);
|
||||
}
|
||||
|
||||
/// EOF Packet
|
||||
{
|
||||
// 1. Server writes packet
|
||||
std::string s0;
|
||||
WriteBufferFromString out0(s0);
|
||||
EOF_Packet server(1, 1);
|
||||
server.writePayloadImpl(out0);
|
||||
|
||||
// 2. Client reads packet
|
||||
ReadBufferFromString in1(s0);
|
||||
PacketResponse client(server_capability_flags);
|
||||
client.readPayloadImpl(in1);
|
||||
|
||||
// Check
|
||||
ASSERT(client.getType() == PACKET_EOF);
|
||||
ASSERT(client.eof.header == server.header);
|
||||
ASSERT(client.eof.warnings == server.warnings);
|
||||
ASSERT(client.eof.status_flags == server.status_flags);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user