ISSUES-4006 split mysql protocol

This commit is contained in:
zhang2014 2020-08-13 20:41:36 +08:00
parent 34f4c8972e
commit 96bd3ac34b
26 changed files with 95 additions and 127 deletions

View File

@ -73,13 +73,13 @@ Native41::Native41(const String & password, const String & auth_plugin_data)
void Native41::authenticate( void Native41::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context, const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool, const Poco::Net::SocketAddress & address) std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address)
{ {
if (!auth_response) if (!auth_response)
{ {
packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true); packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
AuthSwitchResponse response; AuthSwitchResponse response;
packet_sender->receivePacket(response); packet_endpoint->receivePacket(response);
auth_response = response.value; auth_response = response.value;
} }
@ -134,18 +134,18 @@ Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logg
void Sha256Password::authenticate( void Sha256Password::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context, const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address)
{ {
if (!auth_response) if (!auth_response)
{ {
packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true); packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
if (packet_sender->in->eof()) if (packet_endpoint->in->eof())
throw Exception("Client doesn't support authentication method " + getName() + " used by ClickHouse. Specifying user password using 'password_double_sha1_hex' may fix the problem.", throw Exception("Client doesn't support authentication method " + getName() + " used by ClickHouse. Specifying user password using 'password_double_sha1_hex' may fix the problem.",
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES); ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
AuthSwitchResponse response; AuthSwitchResponse response;
packet_sender->receivePacket(response); packet_endpoint->receivePacket(response);
auth_response = response.value; auth_response = response.value;
LOG_TRACE(log, "Authentication method mismatch."); LOG_TRACE(log, "Authentication method mismatch.");
} }
@ -174,11 +174,11 @@ void Sha256Password::authenticate(
LOG_TRACE(log, "Key: {}", pem); LOG_TRACE(log, "Key: {}", pem);
AuthMoreData data(pem); AuthMoreData data(pem);
packet_sender->sendPacket(data, true); packet_endpoint->sendPacket(data, true);
sent_public_key = true; sent_public_key = true;
AuthSwitchResponse response; AuthSwitchResponse response;
packet_sender->receivePacket(response); packet_endpoint->receivePacket(response);
auth_response = response.value; auth_response = response.value;
} }
else else

View File

@ -33,7 +33,7 @@ public:
virtual void authenticate( virtual void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context, const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0; std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
}; };
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html /// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
@ -50,7 +50,7 @@ public:
void authenticate( void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context, const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override; std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
private: private:
String scramble; String scramble;
@ -70,7 +70,7 @@ public:
void authenticate( void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context, const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) override; std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
private: private:
RSA & public_key; RSA & public_key;

View File

@ -1,6 +1,6 @@
#include <Core/MySQL/IMySQLReadPacket.h> #include <Core/MySQL/IMySQLReadPacket.h>
#include <sstream> #include <sstream>
#include <Core/MySQL/PacketPayloadReadBuffer.h> #include <IO/MySQLPacketPayloadReadBuffer.h>
#include <IO/LimitReadBuffer.h> #include <IO/LimitReadBuffer.h>
namespace DB namespace DB
@ -16,7 +16,7 @@ namespace MySQLProtocol
void IMySQLReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id) void IMySQLReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{ {
PacketPayloadReadBuffer payload(in, sequence_id); MySQLPacketPayloadReadBuffer payload(in, sequence_id);
payload.next(); payload.next();
readPayloadImpl(payload); readPayloadImpl(payload);
if (!payload.eof()) if (!payload.eof())

View File

@ -1,5 +1,5 @@
#include <Core/MySQL/IMySQLWritePacket.h> #include <Core/MySQL/IMySQLWritePacket.h>
#include <Core/MySQL/PacketPayloadWriteBuffer.h> #include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <sstream> #include <sstream>
namespace DB namespace DB
@ -10,7 +10,7 @@ namespace MySQLProtocol
void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{ {
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id); MySQLPacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf); writePayloadImpl(buf);
buf.next(); buf.next();
if (buf.remainingPayloadSize()) if (buf.remainingPayloadSize())

View File

@ -5,7 +5,7 @@
#include <Core/MySQL/PacketsConnection.h> #include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h> #include <Core/MySQL/PacketsProtocolText.h>
#include <Core/MySQL/PacketsReplication.h> #include <Core/MySQL/PacketsReplication.h>
#include <Core/MySQLReplication.h> #include <Core/MySQL/MySQLReplication.h>
namespace DB namespace DB
{ {
@ -53,7 +53,7 @@ void MySQLClient::connect()
in = std::make_shared<ReadBufferFromPocoSocket>(*socket); in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket); out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, seq); packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq);
handshake(); handshake();
} }
@ -71,7 +71,7 @@ void MySQLClient::disconnect()
void MySQLClient::handshake() void MySQLClient::handshake()
{ {
Handshake handshake; Handshake handshake;
packet_sender->receivePacket(handshake); packet_endpoint->receivePacket(handshake);
if (handshake.auth_plugin_name != mysql_native_password) if (handshake.auth_plugin_name != mysql_native_password)
{ {
throw Exception( throw Exception(
@ -83,12 +83,12 @@ void MySQLClient::handshake()
String auth_plugin_data = native41.getAuthPluginData(); String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse handshake_response( HandshakeResponse handshake_response(
client_capability_flags, max_packet_size, charset_utf8, user, "", auth_plugin_data, mysql_native_password); client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_sender->sendPacket<HandshakeResponse>(handshake_response, true); packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
ResponsePacket packet_response(client_capability_flags, true); ResponsePacket packet_response(client_capability_flags, true);
packet_sender->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
packet_sender->resetSequenceId(); packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR) if (packet_response.getType() == PACKET_ERR)
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER); throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
@ -99,10 +99,10 @@ void MySQLClient::handshake()
void MySQLClient::writeCommand(char command, String query) void MySQLClient::writeCommand(char command, String query)
{ {
WriteCommand write_command(command, query); WriteCommand write_command(command, query);
packet_sender->sendPacket<WriteCommand>(write_command, true); packet_endpoint->sendPacket<WriteCommand>(write_command, true);
ResponsePacket packet_response(client_capability_flags); ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
switch (packet_response.getType()) switch (packet_response.getType())
{ {
case PACKET_ERR: case PACKET_ERR:
@ -112,17 +112,17 @@ void MySQLClient::writeCommand(char command, String query)
default: default:
break; break;
} }
packet_sender->resetSequenceId(); packet_endpoint->resetSequenceId();
} }
void MySQLClient::registerSlaveOnMaster(UInt32 slave_id) void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
{ {
RegisterSlave register_slave(slave_id); RegisterSlave register_slave(slave_id);
packet_sender->sendPacket<RegisterSlave>(register_slave, true); packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
ResponsePacket packet_response(client_capability_flags); ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
packet_sender->resetSequenceId(); packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR) if (packet_response.getType() == PACKET_ERR)
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER); throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
} }
@ -150,12 +150,12 @@ void MySQLClient::startBinlogDump(UInt32 slave_id, String replicate_db, String b
binlog_pos = binlog_pos < 4 ? 4 : binlog_pos; binlog_pos = binlog_pos < 4 ? 4 : binlog_pos;
BinlogDump binlog_dump(binlog_pos, binlog_file_name, slave_id); BinlogDump binlog_dump(binlog_pos, binlog_file_name, slave_id);
packet_sender->sendPacket<BinlogDump>(binlog_dump, true); packet_endpoint->sendPacket<BinlogDump>(binlog_dump, true);
} }
BinlogEventPtr MySQLClient::readOneBinlogEvent(UInt64 milliseconds) BinlogEventPtr MySQLClient::readOneBinlogEvent(UInt64 milliseconds)
{ {
if (packet_sender->tryReceivePacket(replication, milliseconds)) if (packet_endpoint->tryReceivePacket(replication, milliseconds))
return replication.readOneEvent(); return replication.readOneEvent();
return {}; return {};

View File

@ -1,6 +1,6 @@
#pragma once #pragma once
#include <Core/Types.h> #include <Core/Types.h>
#include <Core/MySQLReplication.h> #include <Core/MySQL/MySQLReplication.h>
#include <IO/ReadBufferFromPocoSocket.h> #include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromPocoSocket.h> #include <IO/WriteBufferFromPocoSocket.h>
@ -44,7 +44,6 @@ private:
uint8_t seq = 0; uint8_t seq = 0;
const UInt8 charset_utf8 = 33; const UInt8 charset_utf8 = 33;
const UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
const String mysql_native_password = "mysql_native_password"; const String mysql_native_password = "mysql_native_password";
MySQLFlavor replication; MySQLFlavor replication;
@ -52,7 +51,7 @@ private:
std::shared_ptr<WriteBuffer> out; std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket; std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address; std::optional<Poco::Net::SocketAddress> address;
std::shared_ptr<PacketEndpoint> packet_sender; std::shared_ptr<PacketEndpoint> packet_endpoint;
void handshake(); void handshake();
void registerSlaveOnMaster(UInt32 slave_id); void registerSlaveOnMaster(UInt32 slave_id);

View File

@ -18,9 +18,9 @@ PacketEndpoint::PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & s
{ {
} }
PacketPayloadReadBuffer PacketEndpoint::getPayload() MySQLPacketPayloadReadBuffer PacketEndpoint::getPayload()
{ {
return PacketPayloadReadBuffer(*in, sequence_id); return MySQLPacketPayloadReadBuffer(*in, sequence_id);
} }
void PacketEndpoint::receivePacket(IMySQLReadPacket & packet) void PacketEndpoint::receivePacket(IMySQLReadPacket & packet)

View File

@ -4,7 +4,7 @@
#include <IO/WriteBuffer.h> #include <IO/WriteBuffer.h>
#include "IMySQLReadPacket.h" #include "IMySQLReadPacket.h"
#include "IMySQLWritePacket.h" #include "IMySQLWritePacket.h"
#include "PacketPayloadReadBuffer.h" #include "IO/MySQLPacketPayloadReadBuffer.h"
namespace DB namespace DB
{ {
@ -21,7 +21,6 @@ public:
uint8_t & sequence_id; uint8_t & sequence_id;
ReadBuffer * in; ReadBuffer * in;
WriteBuffer * out; WriteBuffer * out;
size_t max_packet_size = MAX_PACKET_LENGTH;
/// For writing. /// For writing.
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_); PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
@ -29,7 +28,7 @@ public:
/// For reading and writing. /// For reading and writing.
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_); PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
PacketPayloadReadBuffer getPayload(); MySQLPacketPayloadReadBuffer getPayload();
void receivePacket(IMySQLReadPacket & packet); void receivePacket(IMySQLReadPacket & packet);

View File

@ -12,6 +12,8 @@ namespace MySQLProtocol
namespace Generic namespace Generic
{ {
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
enum StatusFlags enum StatusFlags
{ {
SERVER_SESSION_STATE_CHANGED = 0x4000 SERVER_SESSION_STATE_CHANGED = 0x4000

View File

@ -59,13 +59,4 @@ void BinlogDump::writePayloadImpl(WriteBuffer & buffer) const
} }
} }
}
namespace DB::MySQLProtocol
{
extern const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
} }

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <Core/MySQL/PacketPayloadReadBuffer.h> #include <IO/MySQLPacketPayloadReadBuffer.h>
#include <Core/MySQL/PacketPayloadWriteBuffer.h> #include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <Core/MySQL/PacketEndpoint.h> #include <Core/MySQL/PacketEndpoint.h>
/// Implementation of MySQL wire protocol. /// Implementation of MySQL wire protocol.

View File

@ -1,6 +1,6 @@
#include <string> #include <string>
#include <Core/MySQLClient.h> #include <Core/MySQL/MySQLClient.h>
#include <Core/MySQL/Authentication.h> #include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h> #include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h> #include <Core/MySQL/PacketsConnection.h>
@ -27,7 +27,7 @@ int main(int argc, char ** argv)
String database; String database;
UInt8 charset_utf8 = 33; UInt8 charset_utf8 = 33;
UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH; UInt32 max_packet_size = MAX_PACKET_LENGTH;
String mysql_native_password = "mysql_native_password"; String mysql_native_password = "mysql_native_password";
UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH

View File

@ -18,7 +18,7 @@
#endif #endif
#if USE_MYSQL #if USE_MYSQL
# include <Core/MySQLClient.h> # include <Core/MySQL/MySQLClient.h>
# include <Databases/MySQL/DatabaseConnectionMySQL.h> # include <Databases/MySQL/DatabaseConnectionMySQL.h>
# include <Databases/MySQL/MaterializeMySQLSettings.h> # include <Databases/MySQL/MaterializeMySQLSettings.h>
# include <Databases/MySQL/DatabaseMaterializeMySQL.h> # include <Databases/MySQL/DatabaseMaterializeMySQL.h>

View File

@ -5,7 +5,7 @@
#if USE_MYSQL #if USE_MYSQL
#include <mysqlxx/Pool.h> #include <mysqlxx/Pool.h>
#include <Core/MySQLClient.h> #include <Core/MySQL/MySQLClient.h>
#include <Databases/IDatabase.h> #include <Databases/IDatabase.h>
#include <Databases/MySQL/MaterializeMySQLSettings.h> #include <Databases/MySQL/MaterializeMySQLSettings.h>
#include <Databases/MySQL/MaterializeMySQLSyncThread.h> #include <Databases/MySQL/MaterializeMySQLSyncThread.h>

View File

@ -7,7 +7,7 @@
#if USE_MYSQL #if USE_MYSQL
#include <Core/Types.h> #include <Core/Types.h>
#include <Core/MySQLReplication.h> #include <Core/MySQL/MySQLReplication.h>
#include <mysqlxx/Connection.h> #include <mysqlxx/Connection.h>
#include <mysqlxx/PoolWithFailover.h> #include <mysqlxx/PoolWithFailover.h>

View File

@ -8,7 +8,7 @@
# include <mutex> # include <mutex>
# include <Core/BackgroundSchedulePool.h> # include <Core/BackgroundSchedulePool.h>
# include <Core/MySQLClient.h> # include <Core/MySQL/MySQLClient.h>
# include <DataStreams/BlockIO.h> # include <DataStreams/BlockIO.h>
# include <DataTypes/DataTypeString.h> # include <DataTypes/DataTypeString.h>
# include <DataTypes/DataTypesNumber.h> # include <DataTypes/DataTypesNumber.h>

View File

@ -1,4 +1,4 @@
#include <Core/MySQL/PacketPayloadReadBuffer.h> #include <IO/MySQLPacketPayloadReadBuffer.h>
#include <sstream> #include <sstream>
namespace DB namespace DB
@ -9,17 +9,16 @@ namespace ErrorCodes
extern const int UNKNOWN_PACKET_FROM_CLIENT; extern const int UNKNOWN_PACKET_FROM_CLIENT;
} }
namespace MySQLProtocol const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
{
PacketPayloadReadBuffer::PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_) MySQLPacketPayloadReadBuffer::MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_)
: ReadBuffer(in_.position(), 0), in(in_), sequence_id(sequence_id_) // not in.buffer().begin(), because working buffer may include previous packet : ReadBuffer(in_.position(), 0), in(in_), sequence_id(sequence_id_) // not in.buffer().begin(), because working buffer may include previous packet
{ {
} }
bool PacketPayloadReadBuffer::nextImpl() bool MySQLPacketPayloadReadBuffer::nextImpl()
{ {
if (!has_read_header || (payload_length == max_packet_size && offset == payload_length)) if (!has_read_header || (payload_length == MAX_PACKET_LENGTH && offset == payload_length))
{ {
has_read_header = true; has_read_header = true;
working_buffer.resize(0); working_buffer.resize(0);
@ -27,7 +26,7 @@ bool PacketPayloadReadBuffer::nextImpl()
payload_length = 0; payload_length = 0;
in.readStrict(reinterpret_cast<char *>(&payload_length), 3); in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size) if (payload_length > MAX_PACKET_LENGTH)
{ {
std::ostringstream tmp; std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length; tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
@ -64,5 +63,3 @@ bool PacketPayloadReadBuffer::nextImpl()
} }
} }
}

View File

@ -5,20 +5,14 @@
namespace DB namespace DB
{ {
namespace MySQLProtocol
{
extern const size_t MAX_PACKET_LENGTH;
/** Reading packets. /** Reading packets.
* Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload. * Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
*/ */
class PacketPayloadReadBuffer : public ReadBuffer class MySQLPacketPayloadReadBuffer : public ReadBuffer
{ {
private: private:
ReadBuffer & in; ReadBuffer & in;
uint8_t & sequence_id; uint8_t & sequence_id;
const size_t max_packet_size = MAX_PACKET_LENGTH;
bool has_read_header = false; bool has_read_header = false;
@ -32,10 +26,8 @@ protected:
bool nextImpl() override; bool nextImpl() override;
public: public:
PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_); MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_);
}; };
} }
}

View File

@ -1,14 +1,11 @@
#include <Core/MySQL/PacketPayloadWriteBuffer.h> #include <IO/MySQLPacketPayloadWriteBuffer.h>
namespace DB namespace DB
{ {
namespace MySQLProtocol const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
{
extern const size_t MAX_PACKET_LENGTH; MySQLPacketPayloadWriteBuffer::MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
PacketPayloadWriteBuffer::PacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
: WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_) : WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_)
{ {
startNewPacket(); startNewPacket();
@ -16,7 +13,7 @@ PacketPayloadWriteBuffer::PacketPayloadWriteBuffer(WriteBuffer & out_, size_t pa
pos = out.position(); pos = out.position();
} }
void PacketPayloadWriteBuffer::startNewPacket() void MySQLPacketPayloadWriteBuffer::startNewPacket()
{ {
payload_length = std::min(total_left, MAX_PACKET_LENGTH); payload_length = std::min(total_left, MAX_PACKET_LENGTH);
bytes_written = 0; bytes_written = 0;
@ -27,7 +24,7 @@ void PacketPayloadWriteBuffer::startNewPacket()
bytes += 4; bytes += 4;
} }
void PacketPayloadWriteBuffer::setWorkingBuffer() void MySQLPacketPayloadWriteBuffer::setWorkingBuffer()
{ {
out.nextIfAtEnd(); out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available())); working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
@ -40,7 +37,7 @@ void PacketPayloadWriteBuffer::setWorkingBuffer()
} }
} }
void PacketPayloadWriteBuffer::nextImpl() void MySQLPacketPayloadWriteBuffer::nextImpl()
{ {
const int written = pos - working_buffer.begin(); const int written = pos - working_buffer.begin();
if (eof) if (eof)
@ -57,5 +54,3 @@ void PacketPayloadWriteBuffer::nextImpl()
} }
} }
}

View File

@ -5,16 +5,13 @@
namespace DB namespace DB
{ {
namespace MySQLProtocol
{
/** Writing packets. /** Writing packets.
* https://dev.mysql.com/doc/internals/en/mysql-packet.html * https://dev.mysql.com/doc/internals/en/mysql-packet.html
*/ */
class PacketPayloadWriteBuffer : public WriteBuffer class MySQLPacketPayloadWriteBuffer : public WriteBuffer
{ {
public: public:
PacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_); MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_);
bool remainingPayloadSize() { return total_left; } bool remainingPayloadSize() { return total_left; }
@ -37,5 +34,3 @@ private:
}; };
} }
}

View File

@ -30,17 +30,17 @@ void MySQLOutputFormat::initialize()
if (header.columns()) if (header.columns())
{ {
packet_sender->sendPacket(LengthEncodedNumber(header.columns())); packet_endpoint->sendPacket(LengthEncodedNumber(header.columns()));
for (size_t i = 0; i < header.columns(); i++) for (size_t i = 0; i < header.columns(); i++)
{ {
const auto & column_name = header.getColumnsWithTypeAndName()[i].name; const auto & column_name = header.getColumnsWithTypeAndName()[i].name;
packet_sender->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId())); packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
} }
if (!(context->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) if (!(context->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{ {
packet_sender->sendPacket(EOFPacket(0, 0)); packet_endpoint->sendPacket(EOFPacket(0, 0));
} }
} }
} }
@ -54,7 +54,7 @@ void MySQLOutputFormat::consume(Chunk chunk)
for (size_t i = 0; i < chunk.getNumRows(); i++) for (size_t i = 0; i < chunk.getNumRows(); i++)
{ {
ProtocolText::ResultSetRow row_packet(data_types, chunk.getColumns(), i); ProtocolText::ResultSetRow row_packet(data_types, chunk.getColumns(), i);
packet_sender->sendPacket(row_packet); packet_endpoint->sendPacket(row_packet);
} }
} }
@ -76,17 +76,17 @@ void MySQLOutputFormat::finalize()
const auto & header = getPort(PortKind::Main).getHeader(); const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0) if (header.columns() == 0)
packet_sender->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); packet_endpoint->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else else
if (context->mysql.client_capabilities & CLIENT_DEPRECATE_EOF) if (context->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); packet_endpoint->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else else
packet_sender->sendPacket(EOFPacket(0, 0), true); packet_endpoint->sendPacket(EOFPacket(0, 0), true);
} }
void MySQLOutputFormat::flush() void MySQLOutputFormat::flush()
{ {
packet_sender->out->next(); packet_endpoint->out->next();
} }
void registerOutputFormatProcessorMySQLWire(FormatFactory & factory) void registerOutputFormatProcessorMySQLWire(FormatFactory & factory)

View File

@ -29,8 +29,7 @@ public:
void setContext(const Context & context_) void setContext(const Context & context_)
{ {
context = &context_; context = &context_;
packet_sender = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it
packet_sender->max_packet_size = context_.mysql.max_packet_size;
} }
void consume(Chunk) override; void consume(Chunk) override;
@ -45,7 +44,7 @@ private:
bool initialized = false; bool initialized = false;
const Context * context = nullptr; const Context * context = nullptr;
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_sender; std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
FormatSettings format_settings; FormatSettings format_settings;
DataTypes data_types; DataTypes data_types;
}; };

View File

@ -41,9 +41,9 @@ namespace DB
{ {
using namespace MySQLProtocol; using namespace MySQLProtocol;
using namespace MySQLProtocol::ConnectionPhase;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::Generic; using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
using namespace MySQLProtocol::ConnectionPhase;
#if USE_SSL #if USE_SSL
using Poco::Net::SecureStreamSocket; using Poco::Net::SecureStreamSocket;
@ -91,13 +91,13 @@ void MySQLHandler::run()
in = std::make_shared<ReadBufferFromPocoSocket>(socket()); in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket()); out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id); packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
try try
{ {
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME,
auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci); auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci);
packet_sender->sendPacket<Handshake>(handshake, true); packet_endpoint->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake"); LOG_TRACE(log, "Sent handshake");
@ -135,16 +135,16 @@ void MySQLHandler::run()
catch (const Exception & exc) catch (const Exception & exc)
{ {
log->log(exc); log->log(exc);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true); packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
} }
OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0); OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
packet_sender->sendPacket(ok_packet, true); packet_endpoint->sendPacket(ok_packet, true);
while (true) while (true)
{ {
packet_sender->resetSequenceId(); packet_endpoint->resetSequenceId();
PacketPayloadReadBuffer payload = packet_sender->getPayload(); MySQLPacketPayloadReadBuffer payload = packet_endpoint->getPayload();
char command = 0; char command = 0;
payload.readStrict(command); payload.readStrict(command);
@ -184,7 +184,7 @@ void MySQLHandler::run()
} }
catch (...) catch (...)
{ {
packet_sender->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true); packet_endpoint->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
} }
} }
} }
@ -233,11 +233,11 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp
packet_size = PACKET_HEADER_SIZE + payload_size; packet_size = PACKET_HEADER_SIZE + payload_size;
WriteBufferFromOwnString buf_for_handshake_response; WriteBufferFromOwnString buf_for_handshake_response;
buf_for_handshake_response.write(buf, pos); buf_for_handshake_response.write(buf, pos);
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos); copyData(*packet_endpoint->in, buf_for_handshake_response, packet_size - pos);
ReadBufferFromString payload(buf_for_handshake_response.str()); ReadBufferFromString payload(buf_for_handshake_response.str());
payload.ignore(PACKET_HEADER_SIZE); payload.ignore(PACKET_HEADER_SIZE);
packet.readPayloadWithUnpacked(payload); packet.readPayloadWithUnpacked(payload);
packet_sender->sequence_id++; packet_endpoint->sequence_id++;
} }
} }
@ -254,12 +254,12 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
} }
std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt; std::optional<String> auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional<String>(initial_auth_response) : std::nullopt;
auth_plugin->authenticate(user_name, auth_response, connection_context, packet_sender, secure_connection, socket().peerAddress()); auth_plugin->authenticate(user_name, auth_response, connection_context, packet_endpoint, secure_connection, socket().peerAddress());
} }
catch (const Exception & exc) catch (const Exception & exc)
{ {
LOG_ERROR(log, "Authentication for user {} failed.", user_name); LOG_ERROR(log, "Authentication for user {} failed.", user_name);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true); packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
throw; throw;
} }
LOG_INFO(log, "Authentication for user {} succeeded.", user_name); LOG_INFO(log, "Authentication for user {} succeeded.", user_name);
@ -271,7 +271,7 @@ void MySQLHandler::comInitDB(ReadBuffer & payload)
readStringUntilEOF(database, payload); readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to {}", database); LOG_DEBUG(log, "Setting current database to {}", database);
connection_context.setCurrentDatabase(database); connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OKPacket(0, client_capability_flags, 0, 0, 1), true); packet_endpoint->sendPacket(OKPacket(0, client_capability_flags, 0, 0, 1), true);
} }
void MySQLHandler::comFieldList(ReadBuffer & payload) void MySQLHandler::comFieldList(ReadBuffer & payload)
@ -286,14 +286,14 @@ void MySQLHandler::comFieldList(ReadBuffer & payload)
ColumnDefinition column_definition( ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0 database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
); );
packet_sender->sendPacket(column_definition); packet_endpoint->sendPacket(column_definition);
} }
packet_sender->sendPacket(OKPacket(0xfe, client_capability_flags, 0, 0, 0), true); packet_endpoint->sendPacket(OKPacket(0xfe, client_capability_flags, 0, 0, 0), true);
} }
void MySQLHandler::comPing() void MySQLHandler::comPing()
{ {
packet_sender->sendPacket(OKPacket(0x0, client_capability_flags, 0, 0, 0), true); packet_endpoint->sendPacket(OKPacket(0x0, client_capability_flags, 0, 0, 0), true);
} }
static bool isFederatedServerSetupSetCommand(const String & query); static bool isFederatedServerSetupSetCommand(const String & query);
@ -306,7 +306,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
// As Clickhouse doesn't support these statements, we just send OK packet in response. // As Clickhouse doesn't support these statements, we just send OK packet in response.
if (isFederatedServerSetupSetCommand(query)) if (isFederatedServerSetupSetCommand(query))
{ {
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true); packet_endpoint->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
} }
else else
{ {
@ -336,7 +336,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
); );
if (!with_output) if (!with_output)
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true); packet_endpoint->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
} }
} }
@ -380,9 +380,8 @@ void MySQLHandlerSSL::finishHandshakeSSL(
in = std::make_shared<ReadBufferFromPocoSocket>(*ss); in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss); out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.mysql.sequence_id = 2; connection_context.mysql.sequence_id = 2;
packet_sender = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id); packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context.mysql.sequence_id);
packet_sender->max_packet_size = connection_context.mysql.max_packet_size; packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
} }
#endif #endif

View File

@ -58,7 +58,7 @@ protected:
Context connection_context; Context connection_context;
std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_sender; std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
private: private:
size_t connection_id = 0; size_t connection_id = 0;