mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
MySQL: make MySQLProtocol work in server&client duplex mode
This commit is contained in:
parent
c58d0b428d
commit
63c0f495b9
@ -84,7 +84,6 @@ add_subdirectory (compressor)
|
||||
add_subdirectory (copier)
|
||||
add_subdirectory (format)
|
||||
add_subdirectory (obfuscator)
|
||||
add_subdirectory (myrepl-client)
|
||||
|
||||
if (ENABLE_CLICKHOUSE_ODBC_BRIDGE)
|
||||
add_subdirectory (odbc-bridge)
|
||||
|
@ -1,4 +0,0 @@
|
||||
include_directories(${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
add_executable(myrepl-client myrepl_client.cpp)
|
||||
target_link_libraries(myrepl-client)
|
@ -1,9 +0,0 @@
|
||||
#include <iostream>
|
||||
|
||||
|
||||
int main(int argc, char ** argv)
|
||||
{
|
||||
std::cout << "Try: " << argv[1] << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
81
src/Core/MySQLClient.cpp
Normal file
81
src/Core/MySQLClient.cpp
Normal file
@ -0,0 +1,81 @@
|
||||
#include <Core/MySQLClient.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
using namespace MySQLProtocol;
|
||||
using namespace MySQLProtocol::Authentication;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NETWORK_ERROR;
|
||||
extern const int SOCKET_TIMEOUT;
|
||||
extern const int UNKNOWN_PACKET_FROM_SERVER;
|
||||
}
|
||||
|
||||
MySQLClient::MySQLClient(const String & _host, UInt16 _port, const String & _user, const String & _password, const String & _database)
|
||||
: host(_host), port(_port), user(_user), password(_password), database(_database)
|
||||
{
|
||||
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
|
||||
}
|
||||
|
||||
void MySQLClient::connect()
|
||||
{
|
||||
try
|
||||
{
|
||||
if (connected)
|
||||
{
|
||||
close();
|
||||
}
|
||||
|
||||
socket = std::make_unique<Poco::Net::StreamSocket>();
|
||||
address = DNSResolver::instance().resolveAddress(host, port);
|
||||
socket->connect(*address);
|
||||
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
|
||||
connected = true;
|
||||
|
||||
handshake(*in);
|
||||
}
|
||||
catch (Poco::Net::NetException & e)
|
||||
{
|
||||
close();
|
||||
throw NetException(e.displayText(), ErrorCodes::NETWORK_ERROR);
|
||||
}
|
||||
catch (Poco::TimeoutException & e)
|
||||
{
|
||||
close();
|
||||
throw NetException(e.displayText(), ErrorCodes::SOCKET_TIMEOUT);
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLClient::close()
|
||||
{
|
||||
in = nullptr;
|
||||
out = nullptr;
|
||||
if (socket)
|
||||
socket->close();
|
||||
socket = nullptr;
|
||||
connected = false;
|
||||
}
|
||||
|
||||
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
|
||||
void MySQLClient::handshake(ReadBuffer & payload)
|
||||
{
|
||||
Handshake handshake;
|
||||
handshake.readPayloadImpl(payload);
|
||||
if (handshake.auth_plugin_name != mysql_native_password)
|
||||
{
|
||||
throw Exception(
|
||||
"Only support " + mysql_native_password + " auth plugin name, but got " + handshake.auth_plugin_name,
|
||||
ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
|
||||
}
|
||||
|
||||
Native41 native41(password, handshake.auth_plugin_data);
|
||||
String response = native41.getAuthPluginData();
|
||||
|
||||
HandshakeResponse handshakeResponse(client_capability_flags, 0, charset_utf8, user, database, handshake.auth_plugin_data, mysql_native_password);
|
||||
packet_sender->sendPacket<HandshakeResponse>(handshakeResponse, true);
|
||||
}
|
||||
}
|
50
src/Core/MySQLClient.h
Normal file
50
src/Core/MySQLClient.h
Normal file
@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include <Core/Types.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/Net/NetException.h>
|
||||
#include <Poco/Net/StreamSocket.h>
|
||||
#include <Common/DNSResolver.h>
|
||||
#include <Common/Exception.h>
|
||||
#include <Common/NetException.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
class MySQLClient
|
||||
{
|
||||
public:
|
||||
MySQLClient(const String & _host, UInt16 _port, const String & _user, const String & _password, const String & _database);
|
||||
|
||||
void connect();
|
||||
|
||||
void close();
|
||||
|
||||
private:
|
||||
String host;
|
||||
UInt16 port;
|
||||
String user;
|
||||
String password;
|
||||
String database;
|
||||
|
||||
bool connected = false;
|
||||
UInt32 client_capability_flags = 0;
|
||||
|
||||
uint8_t seq = 0;
|
||||
UInt8 charset_utf8 = 33;
|
||||
String mysql_native_password = "mysql_native_password";
|
||||
|
||||
std::shared_ptr<ReadBuffer> in;
|
||||
std::shared_ptr<WriteBuffer> out;
|
||||
std::unique_ptr<Poco::Net::StreamSocket> socket;
|
||||
std::optional<Poco::Net::SocketAddress> address;
|
||||
|
||||
void handshake(ReadBuffer & payload);
|
||||
|
||||
protected:
|
||||
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
|
||||
};
|
||||
}
|
@ -227,17 +227,17 @@ protected:
|
||||
};
|
||||
|
||||
|
||||
class ClientPacket
|
||||
class ReadPacket
|
||||
{
|
||||
public:
|
||||
ClientPacket() = default;
|
||||
ReadPacket() = default;
|
||||
|
||||
ClientPacket(ClientPacket &&) = default;
|
||||
ReadPacket(ReadPacket &&) = default;
|
||||
|
||||
virtual void read(ReadBuffer & in, uint8_t & sequence_id)
|
||||
virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id)
|
||||
{
|
||||
PacketPayloadReadBuffer payload(in, sequence_id);
|
||||
readPayload(payload);
|
||||
readPayloadImpl(payload);
|
||||
if (!payload.eof())
|
||||
{
|
||||
std::stringstream tmp;
|
||||
@ -246,19 +246,19 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
virtual void readPayload(ReadBuffer & buf) = 0;
|
||||
virtual void readPayloadImpl(ReadBuffer & buf) = 0;
|
||||
|
||||
virtual ~ClientPacket() = default;
|
||||
virtual ~ReadPacket() = default;
|
||||
};
|
||||
|
||||
|
||||
class LimitedClientPacket : public ClientPacket
|
||||
class LimitedReadPacket : public ReadPacket
|
||||
{
|
||||
public:
|
||||
void read(ReadBuffer & in, uint8_t & sequence_id) override
|
||||
void readPayload(ReadBuffer & in, uint8_t & sequence_id) override
|
||||
{
|
||||
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
|
||||
ClientPacket::read(limited, sequence_id);
|
||||
ReadPacket::readPayload(limited, sequence_id);
|
||||
}
|
||||
};
|
||||
|
||||
@ -359,7 +359,6 @@ protected:
|
||||
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
|
||||
};
|
||||
|
||||
|
||||
/* Writes and reads packets, keeping sequence-id.
|
||||
* Throws ProtocolError, if packet with incorrect sequence-id was received.
|
||||
*/
|
||||
@ -387,9 +386,9 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
void receivePacket(ClientPacket & packet)
|
||||
void receivePacket(ReadPacket & packet)
|
||||
{
|
||||
packet.read(*in, sequence_id);
|
||||
packet.readPayload(*in, sequence_id);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
@ -435,8 +434,9 @@ size_t getLengthEncodedNumberSize(uint64_t x);
|
||||
size_t getLengthEncodedStringSize(const String & s);
|
||||
|
||||
|
||||
class Handshake : public WritePacket
|
||||
class Handshake : public WritePacket, ReadPacket
|
||||
{
|
||||
public:
|
||||
int protocol_version = 0xa;
|
||||
String server_version;
|
||||
uint32_t connection_id;
|
||||
@ -445,8 +445,10 @@ class Handshake : public WritePacket
|
||||
uint32_t status_flags;
|
||||
String auth_plugin_name;
|
||||
String auth_plugin_data;
|
||||
public:
|
||||
explicit Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
|
||||
|
||||
Handshake() = default;
|
||||
|
||||
Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
|
||||
: protocol_version(0xa)
|
||||
, server_version(std::move(server_version_))
|
||||
, connection_id(connection_id_)
|
||||
@ -458,7 +460,6 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
|
||||
@ -480,16 +481,76 @@ protected:
|
||||
writeString(auth_plugin_name, buffer);
|
||||
writeChar(0x0, 1, buffer);
|
||||
}
|
||||
|
||||
void readPayloadImpl(ReadBuffer & buffer) override
|
||||
{
|
||||
buffer.ignore(4);
|
||||
|
||||
/// 1-byte: [0a] protocol version
|
||||
buffer.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
|
||||
|
||||
/// string[NUL]: server version
|
||||
readNullTerminated(server_version, buffer);
|
||||
|
||||
/// 4-bytes: connection id
|
||||
buffer.readStrict(reinterpret_cast<char *>(&connection_id), 4);
|
||||
|
||||
/// 8-bytes: auth-plugin-data-part-1
|
||||
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data), AUTH_PLUGIN_DATA_PART_1_LENGTH);
|
||||
|
||||
/// 1-byte: [00] filler
|
||||
buffer.ignore(1);
|
||||
|
||||
/// 2-bytes: capability flags lower 2-bytes
|
||||
buffer.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
|
||||
|
||||
/// 1-byte: character set
|
||||
buffer.readStrict(reinterpret_cast<char *>(&character_set), 1);
|
||||
|
||||
/// 2-bytes: status flags(ignored)
|
||||
buffer.readStrict(reinterpret_cast<char *>(&status_flags), 2);
|
||||
|
||||
/// 2-bytes: capability flags upper 2-bytes
|
||||
buffer.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
|
||||
|
||||
UInt8 auth_plugin_data_length = 0;
|
||||
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
/// 1-byte: length of auth-plugin-data
|
||||
buffer.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
/// 1-byte: [00]
|
||||
buffer.ignore(1);
|
||||
}
|
||||
|
||||
/// string[10] reserved (all [00])
|
||||
buffer.ignore(10);
|
||||
|
||||
if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION)
|
||||
{
|
||||
UInt8 part2_length = (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH) > 13
|
||||
? 13
|
||||
: (auth_plugin_data_length - AUTH_PLUGIN_DATA_PART_1_LENGTH);
|
||||
buffer.readStrict((reinterpret_cast<char *>(&auth_plugin_data)) + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length);
|
||||
}
|
||||
|
||||
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
readNullTerminated(auth_plugin_name, buffer);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class SSLRequest : public ClientPacket
|
||||
class SSLRequest : public ReadPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags;
|
||||
uint32_t max_packet_size;
|
||||
uint8_t character_set;
|
||||
|
||||
void readPayload(ReadBuffer & buf) override
|
||||
void readPayloadImpl(ReadBuffer & buf) override
|
||||
{
|
||||
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
@ -497,20 +558,94 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class HandshakeResponse : public LimitedClientPacket
|
||||
class HandshakeResponse : public WritePacket, ReadPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags = 0;
|
||||
uint32_t max_packet_size = 0;
|
||||
uint8_t character_set = 0;
|
||||
String username;
|
||||
String auth_response;
|
||||
String database;
|
||||
String auth_response;
|
||||
String auth_plugin_name;
|
||||
|
||||
HandshakeResponse() = default;
|
||||
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
HandshakeResponse(
|
||||
UInt32 _capability_flags,
|
||||
UInt32 _max_packet_size,
|
||||
UInt8 _character_set,
|
||||
const String & _username,
|
||||
const String & _database,
|
||||
const String & _auth_response,
|
||||
const String & _auth_plugin_name)
|
||||
: capability_flags(_capability_flags)
|
||||
, max_packet_size(_max_packet_size)
|
||||
, character_set(_character_set)
|
||||
, username(_username)
|
||||
, database(_database)
|
||||
, auth_response(_auth_response)
|
||||
, auth_plugin_name(_auth_plugin_name){};
|
||||
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
size_t size = 0;
|
||||
size += 4 + 4 + 1;
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||
{
|
||||
size += getLengthEncodedStringSize(auth_response);
|
||||
}
|
||||
else if (capability_flags & CLIENT_SECURE_CONNECTION)
|
||||
{
|
||||
size += (1 + auth_response.size());
|
||||
}
|
||||
else
|
||||
{
|
||||
size += (auth_response.size() + 1);
|
||||
}
|
||||
if (capability_flags & CLIENT_CONNECT_WITH_DB)
|
||||
{
|
||||
size += (database.size() + 1);
|
||||
}
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
size += (auth_plugin_name.size() + 1);
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(reinterpret_cast<const char *>(&capability_flags), 4);
|
||||
buffer.write(reinterpret_cast<const char *>(&max_packet_size), 4);
|
||||
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
|
||||
writeNulTerminatedString(username, buffer);
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||
{
|
||||
writeLengthEncodedString(auth_response, buffer);
|
||||
}
|
||||
else if (capability_flags & CLIENT_SECURE_CONNECTION)
|
||||
{
|
||||
writeString(auth_response, buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
writeNulTerminatedString(auth_response, buffer);
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_CONNECT_WITH_DB)
|
||||
{
|
||||
writeNulTerminatedString(database, buffer);
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
writeNulTerminatedString(auth_plugin_name, buffer);
|
||||
}
|
||||
}
|
||||
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
@ -573,12 +708,12 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
class AuthSwitchResponse : public LimitedClientPacket
|
||||
class AuthSwitchResponse : public LimitedReadPacket
|
||||
{
|
||||
public:
|
||||
String value;
|
||||
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
readStringUntilEOF(value, payload);
|
||||
}
|
||||
@ -806,12 +941,12 @@ protected:
|
||||
}
|
||||
};
|
||||
|
||||
class ComFieldList : public LimitedClientPacket
|
||||
class ComFieldList : public LimitedReadPacket
|
||||
{
|
||||
public:
|
||||
String table, field_wildcard;
|
||||
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
void readPayloadImpl(ReadBuffer & payload) override
|
||||
{
|
||||
// Command byte has been already read from payload.
|
||||
readNullTerminated(table, payload);
|
||||
@ -931,6 +1066,29 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
Native41(const String & password, const String & auth_plugin_data)
|
||||
{
|
||||
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
/// SHA1( password ) XOR SHA1( "20-bytes random data from server" <concat> SHA1( SHA1( password ) ) )
|
||||
Poco::SHA1Engine engine;
|
||||
engine.update(password.data(), password.size());
|
||||
const Poco::SHA1Engine::Digest & password_sha1 = engine.digest();
|
||||
engine.update(password_sha1.data(), password_sha1.size());
|
||||
const Poco::SHA1Engine::Digest & password_double_sha1 = engine.digest();
|
||||
|
||||
engine.reset();
|
||||
|
||||
engine.update(auth_plugin_data.data(), auth_plugin_data.size());
|
||||
engine.update(password_double_sha1.data(), password_double_sha1.size());
|
||||
const Poco::SHA1Engine::Digest & digest = engine.digest();
|
||||
|
||||
scramble.resize(Poco::SHA1Engine::DIGEST_SIZE);
|
||||
for (size_t i = 0; i < scramble.size(); i++)
|
||||
{
|
||||
scramble[i] = password_sha1[i] ^ digest[i];
|
||||
}
|
||||
}
|
||||
|
||||
String getName() override
|
||||
{
|
||||
return "mysql_native_password";
|
||||
|
@ -15,3 +15,6 @@ if (ENABLE_FUZZING)
|
||||
add_executable (names_and_types_fuzzer names_and_types_fuzzer.cpp)
|
||||
target_link_libraries (names_and_types_fuzzer PRIVATE dbms ${LIB_FUZZING_ENGINE})
|
||||
endif ()
|
||||
|
||||
add_executable (mysql_client mysql_client.cpp)
|
||||
target_link_libraries (mysql_client PRIVATE dbms)
|
||||
|
13
src/Core/tests/mysql_client.cpp
Normal file
13
src/Core/tests/mysql_client.cpp
Normal file
@ -0,0 +1,13 @@
|
||||
#include <Core/MySQLClient.h>
|
||||
|
||||
|
||||
int main(int, char **)
|
||||
{
|
||||
using namespace DB;
|
||||
|
||||
UInt16 port = 3306;
|
||||
String host = "127.0.0.1", user = "root", password = "";
|
||||
MySQLClient client(host, port, user, password, "");
|
||||
client.connect();
|
||||
return 0;
|
||||
}
|
Loading…
Reference in New Issue
Block a user