MySQL: make MySQLProtocol work in server&client duplex mode

This commit is contained in:
BohuTANG 2020-04-16 21:27:06 +08:00 committed by zhang2014
parent c58d0b428d
commit 63c0f495b9
8 changed files with 331 additions and 40 deletions

View File

@ -84,7 +84,6 @@ add_subdirectory (compressor)
add_subdirectory (copier)
add_subdirectory (format)
add_subdirectory (obfuscator)
add_subdirectory (myrepl-client)
if (ENABLE_CLICKHOUSE_ODBC_BRIDGE)
add_subdirectory (odbc-bridge)

View File

@ -1,4 +0,0 @@
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(myrepl-client myrepl_client.cpp)
target_link_libraries(myrepl-client)

View File

@ -1,9 +0,0 @@
#include <iostream>
int main(int argc, char ** argv)
{
std::cout << "Try: " << argv[1] << std::endl;
return 0;
}

81
src/Core/MySQLClient.cpp Normal file
View File

@ -0,0 +1,81 @@
#include <Core/MySQLClient.h>
namespace DB
{
using namespace MySQLProtocol;
using namespace MySQLProtocol::Authentication;
namespace ErrorCodes
{
extern const int NETWORK_ERROR;
extern const int SOCKET_TIMEOUT;
extern const int UNKNOWN_PACKET_FROM_SERVER;
}
MySQLClient::MySQLClient(const String & _host, UInt16 _port, const String & _user, const String & _password, const String & _database)
: host(_host), port(_port), user(_user), password(_password), database(_database)
{
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
}
void MySQLClient::connect()
{
try
{
if (connected)
{
close();
}
socket = std::make_unique<Poco::Net::StreamSocket>();
address = DNSResolver::instance().resolveAddress(host, port);
socket->connect(*address);
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
connected = true;
handshake(*in);
}
catch (Poco::Net::NetException & e)
{
close();
throw NetException(e.displayText(), ErrorCodes::NETWORK_ERROR);
}
catch (Poco::TimeoutException & e)
{
close();
throw NetException(e.displayText(), ErrorCodes::SOCKET_TIMEOUT);
}
}
void MySQLClient::close()
{
in = nullptr;
out = nullptr;
if (socket)
socket->close();
socket = nullptr;
connected = false;
}
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
void MySQLClient::handshake(ReadBuffer & payload)
{
Handshake handshake;
handshake.readPayloadImpl(payload);
if (handshake.auth_plugin_name != mysql_native_password)
{
throw Exception(
"Only support " + mysql_native_password + " auth plugin name, but got " + handshake.auth_plugin_name,
ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
Native41 native41(password, handshake.auth_plugin_data);
String response = native41.getAuthPluginData();
HandshakeResponse handshakeResponse(client_capability_flags, 0, charset_utf8, user, database, handshake.auth_plugin_data, mysql_native_password);
packet_sender->sendPacket<HandshakeResponse>(handshakeResponse, true);
}
}

50
src/Core/MySQLClient.h Normal file
View File

@ -0,0 +1,50 @@
#pragma once
#include <Core/Types.h>
#include <Core/MySQLProtocol.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteHelpers.h>
#include <Poco/Net/NetException.h>
#include <Poco/Net/StreamSocket.h>
#include <Common/DNSResolver.h>
#include <Common/Exception.h>
#include <Common/NetException.h>
namespace DB
{
class MySQLClient
{
public:
MySQLClient(const String & _host, UInt16 _port, const String & _user, const String & _password, const String & _database);
void connect();
void close();
private:
String host;
UInt16 port;
String user;
String password;
String database;
bool connected = false;
UInt32 client_capability_flags = 0;
uint8_t seq = 0;
UInt8 charset_utf8 = 33;
String mysql_native_password = "mysql_native_password";
std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address;
void handshake(ReadBuffer & payload);
protected:
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
};
}

View File

@ -227,17 +227,17 @@ protected:
};
class ClientPacket
class ReadPacket
{
public:
ClientPacket() = default;
ReadPacket() = default;
ClientPacket(ClientPacket &&) = default;
ReadPacket(ReadPacket &&) = default;
virtual void read(ReadBuffer & in, uint8_t & sequence_id)
virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
PacketPayloadReadBuffer payload(in, sequence_id);
readPayload(payload);
readPayloadImpl(payload);
if (!payload.eof())
{
std::stringstream tmp;
@ -246,19 +246,19 @@ public:
}
}
virtual void readPayload(ReadBuffer & buf) = 0;
virtual void readPayloadImpl(ReadBuffer & buf) = 0;
virtual ~ClientPacket() = default;
virtual ~ReadPacket() = default;
};
class LimitedClientPacket : public ClientPacket
class LimitedReadPacket : public ReadPacket
{
public:
void read(ReadBuffer & in, uint8_t & sequence_id) override
void readPayload(ReadBuffer & in, uint8_t & sequence_id) override
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
ClientPacket::read(limited, sequence_id);
ReadPacket::readPayload(limited, sequence_id);
}
};
@ -359,7 +359,6 @@ protected:
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
};
/* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError, if packet with incorrect sequence-id was received.
*/
@ -387,9 +386,9 @@ public:
{
}
void receivePacket(ClientPacket & packet)
void receivePacket(ReadPacket & packet)
{
packet.read(*in, sequence_id);
packet.readPayload(*in, sequence_id);
}
template<class T>
@ -435,8 +434,9 @@ size_t getLengthEncodedNumberSize(uint64_t x);
size_t getLengthEncodedStringSize(const String & s);
class Handshake : public WritePacket
class Handshake : public WritePacket, ReadPacket
{
public:
int protocol_version = 0xa;
String server_version;
uint32_t connection_id;
@ -445,8 +445,10 @@ class Handshake : public WritePacket
uint32_t status_flags;
String auth_plugin_name;
String auth_plugin_data;
public:
explicit Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
Handshake() = default;
Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
: protocol_version(0xa)
, server_version(std::move(server_version_))
, connection_id(connection_id_)
@ -458,7 +460,6 @@ public:
{
}
protected:
size_t getPayloadSize() const override
{
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
@ -480,16 +481,76 @@ protected:
writeString(auth_plugin_name, buffer);
writeChar(0x0, 1, buffer);
}
void readPayloadImpl(ReadBuffer & buffer) override
{
buffer.ignore(4);
/// 1-byte: [0a] protocol version
buffer.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
/// string[NUL]: server version
readNullTerminated(server_version, buffer);
/// 4-bytes: connection id
buffer.readStrict(reinterpret_cast<char *>(&connection_id), 4);
/// 8-bytes: auth-plugin-data-part-1
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data), AUTH_PLUGIN_DATA_PART_1_LENGTH);
/// 1-byte: [00] filler
buffer.ignore(1);
/// 2-bytes: capability flags lower 2-bytes
buffer.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
/// 1-byte: character set
buffer.readStrict(reinterpret_cast<char *>(&character_set), 1);
/// 2-bytes: status flags(ignored)
buffer.readStrict(reinterpret_cast<char *>(&status_flags), 2);
/// 2-bytes: capability flags upper 2-bytes
buffer.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
UInt8 auth_plugin_data_length = 0;
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{
/// 1-byte: length of auth-plugin-data
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
}
else
{
/// 1-byte: [00]
buffer.ignore(1);
}
/// string[10] reserved (all [00])
buffer.ignore(10);
if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION)
{
UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13
? 13
: (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH);
buffer.readStrict((reinterpret_cast<char *>(&auth_plugin_data)) + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length);
}
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{
readNullTerminated(auth_plugin_name, buffer);
}
}
};
class SSLRequest : public ClientPacket
class SSLRequest : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
void readPayload(ReadBuffer & buf) override
void readPayloadImpl(ReadBuffer & buf) override
{
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
@ -497,20 +558,94 @@ public:
}
};
class HandshakeResponse : public LimitedClientPacket
class HandshakeResponse : public WritePacket, ReadPacket
{
public:
uint32_t capability_flags = 0;
uint32_t max_packet_size = 0;
uint8_t character_set = 0;
String username;
String auth_response;
String database;
String auth_response;
String auth_plugin_name;
HandshakeResponse() = default;
void readPayload(ReadBuffer & payload) override
HandshakeResponse(
UInt32 _capability_flags,
UInt32 _max_packet_size,
UInt8 _character_set,
const String & _username,
const String & _database,
const String & _auth_response,
const String & _auth_plugin_name)
: capability_flags(_capability_flags)
, max_packet_size(_max_packet_size)
, character_set(_character_set)
, username(_username)
, database(_database)
, auth_response(_auth_response)
, auth_plugin_name(_auth_plugin_name){};
size_t getPayloadSize() const override
{
size_t size = 0;
size += 4 + 4 + 1;
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
size += getLengthEncodedStringSize(auth_response);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
size += (1 + auth_response.size());
}
else
{
size += (auth_response.size() + 1);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
size += (database.size() + 1);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
size += (auth_plugin_name.size() + 1);
}
return size;
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(reinterpret_cast<const char *>(&capability_flags), 4);
buffer.write(reinterpret_cast<const char *>(&max_packet_size), 4);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
writeNulTerminatedString(username, buffer);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
writeLengthEncodedString(auth_response, buffer);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
writeString(auth_response, buffer);
}
else
{
writeNulTerminatedString(auth_response, buffer);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
writeNulTerminatedString(database, buffer);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
writeNulTerminatedString(auth_plugin_name, buffer);
}
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
@ -573,12 +708,12 @@ protected:
}
};
class AuthSwitchResponse : public LimitedClientPacket
class AuthSwitchResponse : public LimitedReadPacket
{
public:
String value;
void readPayload(ReadBuffer & payload) override
void readPayloadImpl(ReadBuffer & payload) override
{
readStringUntilEOF(value, payload);
}
@ -806,12 +941,12 @@ protected:
}
};
class ComFieldList : public LimitedClientPacket
class ComFieldList : public LimitedReadPacket
{
public:
String table, field_wildcard;
void readPayload(ReadBuffer & payload) override
void readPayloadImpl(ReadBuffer & payload) override
{
// Command byte has been already read from payload.
readNullTerminated(table, payload);
@ -931,6 +1066,29 @@ public:
}
}
Native41(const String & password, const String & auth_plugin_data)
{
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
/// SHA1( password ) XOR SHA1( "20-bytes random data from server" <concat> SHA1( SHA1( password ) ) )
Poco::SHA1Engine engine;
engine.update(password.data(), password.size());
const Poco::SHA1Engine::Digest & password_sha1 = engine.digest();
engine.update(password_sha1.data(), password_sha1.size());
const Poco::SHA1Engine::Digest & password_double_sha1 = engine.digest();
engine.reset();
engine.update(auth_plugin_data.data(), auth_plugin_data.size());
engine.update(password_double_sha1.data(), password_double_sha1.size());
const Poco::SHA1Engine::Digest & digest = engine.digest();
scramble.resize(Poco::SHA1Engine::DIGEST_SIZE);
for (size_t i = 0; i < scramble.size(); i++)
{
scramble[i] = password_sha1[i] ^ digest[i];
}
}
String getName() override
{
return "mysql_native_password";

View File

@ -15,3 +15,6 @@ if (ENABLE_FUZZING)
add_executable (names_and_types_fuzzer names_and_types_fuzzer.cpp)
target_link_libraries (names_and_types_fuzzer PRIVATE dbms ${LIB_FUZZING_ENGINE})
endif ()
add_executable (mysql_client mysql_client.cpp)
target_link_libraries (mysql_client PRIVATE dbms)

View File

@ -0,0 +1,13 @@
#include <Core/MySQLClient.h>
int main(int, char **)
{
using namespace DB;
UInt16 port = 3306;
String host = "127.0.0.1", user = "root", password = "";
MySQLClient client(host, port, user, password, "");
client.connect();
return 0;
}