ISSUES-4006 split msql protocol & fix build

This commit is contained in:
zhang2014 2020-08-13 14:30:29 +08:00
parent 0162c39838
commit 688836cdc4
14 changed files with 108 additions and 879 deletions

View File

@ -12,19 +12,19 @@ namespace MySQLProtocol
namespace Authentication
{
class IPlugin
{
public:
virtual String getName() = 0;
virtual String getAuthPluginData() = 0;
virtual void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketSender> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
virtual ~IPlugin() = default;
};
//class IPlugin
//{
//public:
// virtual String getName() = 0;
//
// virtual String getAuthPluginData() = 0;
//
// virtual void authenticate(
// const String & user_name, std::optional<String> auth_response, Context & context,
// std::shared_ptr<PacketSender> packet_sender, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
//
// virtual ~IPlugin() = default;
//};
}

View File

@ -7,7 +7,7 @@ namespace DB
namespace MySQLProtocol
{
void IMySQLReadPacket::readPayload(ReadBuffer &in, uint8_t &sequence_id)
void IMySQLReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
PacketPayloadReadBuffer payload(in, sequence_id);
payload.next();
@ -20,10 +20,21 @@ void IMySQLReadPacket::readPayload(ReadBuffer &in, uint8_t &sequence_id)
}
}
void IMySQLReadPacket::readPayloadWithUnpacked(ReadBuffer & in)
{
readPayloadImpl(in);
}
void LimitedReadPacket::readPayload(ReadBuffer &in, uint8_t &sequence_id)
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
ReadPacket::readPayload(limited, sequence_id);
IMySQLReadPacket::readPayload(limited, sequence_id);
}
void LimitedReadPacket::readPayloadWithUnpacked(ReadBuffer & in)
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
IMySQLReadPacket::readPayloadWithUnpacked(limited);
}
}

View File

@ -19,6 +19,8 @@ public:
virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id);
virtual void readPayloadWithUnpacked(ReadBuffer & in);
protected:
virtual void readPayloadImpl(ReadBuffer & buf) = 0;
};
@ -27,6 +29,8 @@ class LimitedReadPacket : public IMySQLReadPacket
{
public:
void readPayload(ReadBuffer & in, uint8_t & sequence_id) override;
void readPayloadWithUnpacked(ReadBuffer & in) override;
};
uint64_t readLengthEncodedNumber(ReadBuffer & ss);

View File

@ -7,7 +7,7 @@ namespace DB
namespace MySQLProtocol
{
void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t &sequence_id) const
void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);

View File

@ -425,22 +425,22 @@ void ResponsePacket::readPayloadImpl(ReadBuffer & payload)
{
case PACKET_OK:
packetType = PACKET_OK;
ok.readPayloadImpl(payload);
ok.readPayloadWithUnpacked(payload);
break;
case PACKET_ERR:
packetType = PACKET_ERR;
err.readPayloadImpl(payload);
err.readPayloadWithUnpacked(payload);
break;
case PACKET_EOF:
if (is_handshake)
{
packetType = PACKET_AUTH_SWITCH;
auth_switch.readPayloadImpl(payload);
auth_switch.readPayloadWithUnpacked(payload);
}
else
{
packetType = PACKET_EOF;
eof.readPayloadImpl(payload);
eof.readPayloadWithUnpacked(payload);
}
break;
case PACKET_LOCALINFILE:

View File

@ -77,7 +77,7 @@ void MySQLClient::handshake()
client_capability_flags, max_packet_size, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_sender->sendPacket<HandshakeResponse>(handshake_response, true);
PacketResponse packet_response(client_capability_flags, true);
ResponsePacket packet_response(client_capability_flags, true);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
@ -92,7 +92,7 @@ void MySQLClient::writeCommand(char command, String query)
WriteCommand write_command(command, query);
packet_sender->sendPacket<WriteCommand>(write_command, true);
PacketResponse packet_response(client_capability_flags);
ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
switch (packet_response.getType())
{
@ -111,7 +111,7 @@ void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
RegisterSlave register_slave(slave_id);
packet_sender->sendPacket<RegisterSlave>(register_slave, true);
PacketResponse packet_response(client_capability_flags);
ResponsePacket packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
if (packet_response.getType() == PACKET_ERR)

View File

@ -64,7 +64,7 @@ private:
void writeCommand(char command, String query);
};
class WriteCommand : public WritePacket
class WriteCommand : public IMySQLWritePacket
{
public:
char command;

View File

@ -102,7 +102,7 @@ size_t getLengthEncodedStringSize(const String & s)
return getLengthEncodedNumberSize(s.size()) + s.size();
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
ColumnDefinitionPacket getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
CharacterSet charset = CharacterSet::binary;
@ -167,7 +167,7 @@ ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex
charset = CharacterSet::utf8_general_ci;
break;
}
return ColumnDefinition(column_name, charset, 0, column_type, flags, 0);
return ColumnDefinitionPacket(column_name, charset, 0, column_type, flags, 0);
}
bool PacketPayloadReadBuffer::nextImpl()
@ -221,23 +221,23 @@ PacketPayloadReadBuffer::PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & seq
{
}
void ReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
PacketPayloadReadBuffer payload(in, sequence_id);
payload.next();
readPayloadImpl(payload);
if (!payload.eof())
{
std::stringstream tmp;
tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
}
void LimitedReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
ReadPacket::readPayload(limited, sequence_id);
}
//void ReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
//{
// PacketPayloadReadBuffer payload(in, sequence_id);
// payload.next();
// readPayloadImpl(payload);
// if (!payload.eof())
// {
// std::stringstream tmp;
// tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
// throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
// }
//}
//
//void LimitedReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
//{
// LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
// ReadPacket::readPayload(limited, sequence_id);
//}
}

View File

@ -22,6 +22,7 @@
#include <Poco/Net/StreamSocket.h>
#include <Poco/RandomStream.h>
#include <Poco/SHA1Engine.h>
#include <Core/MySQL/MySQLPackets.h>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
@ -109,50 +110,6 @@ enum Command
COM_DAEMON = 0x1d
};
enum ColumnType
{
MYSQL_TYPE_DECIMAL = 0x00,
MYSQL_TYPE_TINY = 0x01,
MYSQL_TYPE_SHORT = 0x02,
MYSQL_TYPE_LONG = 0x03,
MYSQL_TYPE_FLOAT = 0x04,
MYSQL_TYPE_DOUBLE = 0x05,
MYSQL_TYPE_NULL = 0x06,
MYSQL_TYPE_TIMESTAMP = 0x07,
MYSQL_TYPE_LONGLONG = 0x08,
MYSQL_TYPE_INT24 = 0x09,
MYSQL_TYPE_DATE = 0x0a,
MYSQL_TYPE_TIME = 0x0b,
MYSQL_TYPE_DATETIME = 0x0c,
MYSQL_TYPE_YEAR = 0x0d,
MYSQL_TYPE_NEWDATE = 0x0e,
MYSQL_TYPE_VARCHAR = 0x0f,
MYSQL_TYPE_BIT = 0x10,
MYSQL_TYPE_TIMESTAMP2 = 0x11,
MYSQL_TYPE_DATETIME2 = 0x12,
MYSQL_TYPE_TIME2 = 0x13,
MYSQL_TYPE_JSON = 0xf5,
MYSQL_TYPE_NEWDECIMAL = 0xf6,
MYSQL_TYPE_ENUM = 0xf7,
MYSQL_TYPE_SET = 0xf8,
MYSQL_TYPE_TINY_BLOB = 0xf9,
MYSQL_TYPE_MEDIUM_BLOB = 0xfa,
MYSQL_TYPE_LONG_BLOB = 0xfb,
MYSQL_TYPE_BLOB = 0xfc,
MYSQL_TYPE_VAR_STRING = 0xfd,
MYSQL_TYPE_STRING = 0xfe,
MYSQL_TYPE_GEOMETRY = 0xff
};
enum ResponsePacketType
{
PACKET_OK = 0x00,
PACKET_ERR = 0xff,
PACKET_EOF = 0xfe,
PACKET_AUTH_SWITCH = 0xfe,
PACKET_LOCALINFILE = 0xfb,
};
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
enum ColumnDefinitionFlags
{
@ -194,28 +151,6 @@ public:
};
class ReadPacket
{
public:
ReadPacket() = default;
ReadPacket(ReadPacket &&) = default;
virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id);
virtual void readPayloadImpl(ReadBuffer & buf) = 0;
virtual ~ReadPacket() = default;
};
class LimitedReadPacket : public ReadPacket
{
public:
void readPayload(ReadBuffer & in, uint8_t & sequence_id) override;
};
/** Writing packets.
* https://dev.mysql.com/doc/internals/en/mysql-packet.html
*/
@ -287,31 +222,6 @@ protected:
}
};
class WritePacket
{
public:
virtual void writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);
buf.next();
if (buf.remainingPayloadSize())
{
std::stringstream ss;
ss << "Incomplete payload. Written " << getPayloadSize() - buf.remainingPayloadSize() << " bytes, expected " << getPayloadSize() << " bytes.";
throw Exception(ss.str(), 0);
}
}
virtual ~WritePacket() = default;
protected:
virtual size_t getPayloadSize() const = 0;
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
};
/* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError, if packet with incorrect sequence-id was received.
*/
@ -339,12 +249,12 @@ public:
{
}
void receivePacket(ReadPacket & packet)
void receivePacket(IMySQLReadPacket & packet)
{
packet.readPayload(*in, sequence_id);
}
bool tryReceivePacket(ReadPacket & packet, UInt64 millisecond = 0)
bool tryReceivePacket(IMySQLReadPacket & packet, UInt64 millisecond = 0)
{
if (millisecond != 0)
{
@ -364,7 +274,7 @@ public:
template<class T>
void sendPacket(const T & packet, bool flush = false)
{
static_assert(std::is_base_of<WritePacket, T>());
static_assert(std::is_base_of<IMySQLWritePacket, T>());
packet.writePayload(*out, sequence_id);
if (flush)
out->next();
@ -410,710 +320,13 @@ size_t getLengthEncodedNumberSize(uint64_t x);
size_t getLengthEncodedStringSize(const String & s);
class Handshake : public WritePacket, public ReadPacket
{
public:
int protocol_version = 0xa;
String server_version;
uint32_t connection_id;
uint32_t capability_flags;
uint8_t character_set;
uint32_t status_flags;
String auth_plugin_name;
String auth_plugin_data;
Handshake() : connection_id(0x00), capability_flags(0x00), character_set(0x00), status_flags(0x00) { }
Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
: protocol_version(0xa)
, server_version(std::move(server_version_))
, connection_id(connection_id_)
, capability_flags(capability_flags_)
, character_set(CharacterSet::utf8_general_ci)
, status_flags(0)
, auth_plugin_name(std::move(auth_plugin_name_))
, auth_plugin_data(std::move(auth_plugin_data_))
{
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(static_cast<char>(protocol_version));
writeNulTerminatedString(server_version, buffer);
buffer.write(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
buffer.write(reinterpret_cast<const char *>(&capability_flags), 2);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
buffer.write(static_cast<char>(auth_plugin_data.size()));
writeChar(0x0, 10, buffer);
writeString(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
writeString(auth_plugin_name, buffer);
writeChar(0x0, 1, buffer);
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
readNullTerminated(server_version, payload);
payload.readStrict(reinterpret_cast<char *>(&connection_id), 4);
auth_plugin_data.resize(AUTH_PLUGIN_DATA_PART_1_LENGTH);
payload.readStrict(auth_plugin_data.data(), AUTH_PLUGIN_DATA_PART_1_LENGTH);
payload.ignore(1);
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
UInt8 auth_plugin_data_length = 0;
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{
payload.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
}
else
{
payload.ignore(1);
}
payload.ignore(10);
if (capability_flags & MySQLProtocol::CLIENT_SECURE_CONNECTION)
{
UInt8 part2_length = (SCRAMBLE_LENGTH - AUTH_PLUGIN_DATA_PART_1_LENGTH);
auth_plugin_data.resize(SCRAMBLE_LENGTH);
payload.readStrict(auth_plugin_data.data() + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length);
payload.ignore(1);
}
if (capability_flags & MySQLProtocol::CLIENT_PLUGIN_AUTH)
{
readNullTerminated(auth_plugin_name, payload);
}
}
protected:
size_t getPayloadSize() const override
{
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
}
};
class SSLRequest : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
void readPayloadImpl(ReadBuffer & buf) override
{
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
}
};
class HandshakeResponse : public WritePacket, public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
String username;
String database;
String auth_response;
String auth_plugin_name;
HandshakeResponse() : capability_flags(0x00), max_packet_size(0x00), character_set(0x00) { }
HandshakeResponse(
UInt32 capability_flags_,
UInt32 max_packet_size_,
UInt8 character_set_,
const String & username_,
const String & database_,
const String & auth_response_,
const String & auth_plugin_name_)
: capability_flags(capability_flags_)
, max_packet_size(max_packet_size_)
, character_set(character_set_)
, username(std::move(username_))
, database(std::move(database_))
, auth_response(std::move(auth_response_))
, auth_plugin_name(std::move(auth_plugin_name_))
{
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(reinterpret_cast<const char *>(&capability_flags), 4);
buffer.write(reinterpret_cast<const char *>(&max_packet_size), 4);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
writeChar(0x0, 23, buffer);
writeNulTerminatedString(username, buffer);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
writeLengthEncodedString(auth_response, buffer);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
writeChar(auth_response.size(), buffer);
writeString(auth_response.data(), auth_response.size(), buffer);
}
else
{
writeNulTerminatedString(auth_response, buffer);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
writeNulTerminatedString(database, buffer);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
writeNulTerminatedString(auth_plugin_name, buffer);
}
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
payload.ignore(23);
readNullTerminated(username, payload);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
readLengthEncodedString(auth_response, payload);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
char len;
payload.readStrict(len);
auth_response.resize(static_cast<unsigned int>(len));
payload.readStrict(auth_response.data(), len);
}
else
{
readNullTerminated(auth_response, payload);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
readNullTerminated(database, payload);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
readNullTerminated(auth_plugin_name, payload);
}
}
protected:
size_t getPayloadSize() const override
{
size_t size = 0;
size += 4 + 4 + 1 + 23;
size += username.size() + 1;
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
size += getLengthEncodedStringSize(auth_response);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
size += (1 + auth_response.size());
}
else
{
size += (auth_response.size() + 1);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
size += (database.size() + 1);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
size += (auth_plugin_name.size() + 1);
}
return size;
}
};
class AuthSwitchRequest : public WritePacket
{
String plugin_name;
String auth_plugin_data;
public:
AuthSwitchRequest(String plugin_name_, String auth_plugin_data_)
: plugin_name(std::move(plugin_name_)), auth_plugin_data(std::move(auth_plugin_data_))
{
}
protected:
size_t getPayloadSize() const override
{
return 2 + plugin_name.size() + auth_plugin_data.size();
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0xfe);
writeNulTerminatedString(plugin_name, buffer);
writeString(auth_plugin_data, buffer);
}
};
class AuthSwitchResponse : public LimitedReadPacket
{
public:
String value;
void readPayloadImpl(ReadBuffer & payload) override
{
readStringUntilEOF(value, payload);
}
};
class AuthMoreData : public WritePacket
{
String data;
public:
explicit AuthMoreData(String data_): data(std::move(data_)) {}
protected:
size_t getPayloadSize() const override
{
return 1 + data.size();
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0x01);
writeString(data, buffer);
}
};
class OK_Packet : public WritePacket, public ReadPacket
{
public:
uint8_t header;
uint32_t capabilities;
uint64_t affected_rows;
uint64_t last_insert_id;
int16_t warnings = 0;
uint32_t status_flags;
String session_state_changes;
String info;
OK_Packet(uint32_t capabilities_)
: header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00)
{
}
OK_Packet(
uint8_t header_,
uint32_t capabilities_,
uint64_t affected_rows_,
uint32_t status_flags_,
int16_t warnings_,
String session_state_changes_ = "",
String info_ = "")
: header(header_)
, capabilities(capabilities_)
, affected_rows(affected_rows_)
, last_insert_id(0)
, warnings(warnings_)
, status_flags(status_flags_)
, session_state_changes(std::move(session_state_changes_))
, info(std::move(info_))
{
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
writeLengthEncodedString(info, buffer);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
writeLengthEncodedString(session_state_changes, buffer);
}
else
{
writeString(info, buffer);
}
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
affected_rows = readLengthEncodedNumber(payload);
last_insert_id = readLengthEncodedNumber(payload);
if (capabilities & CLIENT_PROTOCOL_41)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
readLengthEncodedString(info, payload);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
readLengthEncodedString(session_state_changes, payload);
}
}
else
{
readString(info, payload);
}
}
protected:
size_t getPayloadSize() const override
{
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
if (capabilities & CLIENT_PROTOCOL_41)
{
result += 4;
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result += 2;
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result += getLengthEncodedStringSize(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
result += getLengthEncodedStringSize(session_state_changes);
}
else
{
result += info.size();
}
return result;
}
};
class EOF_Packet : public WritePacket, public ReadPacket
{
public:
UInt8 header = 0xfe;
int warnings;
int status_flags;
EOF_Packet() : warnings(0x00), status_flags(0x00) { }
EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) { }
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(header); // EOF header
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
protected:
size_t getPayloadSize() const override
{
return 5;
}
};
class AuthSwitch_Packet : public ReadPacket
{
public:
String plugin_name;
AuthSwitch_Packet() { }
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
readStringUntilEOF(plugin_name, payload);
}
private:
UInt8 header = 0x00;
};
class ERR_Packet : public WritePacket, public ReadPacket
{
public:
UInt8 header = 0xff;
int error_code;
String sql_state;
String error_message;
ERR_Packet() : error_code(0x00) { }
ERR_Packet(int error_code_, String sql_state_, String error_message_)
: error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_))
{
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(header);
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
buffer.write('#');
buffer.write(sql_state.data(), sql_state.length());
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
}
void readPayloadImpl(ReadBuffer & payload) override
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xff);
payload.readStrict(reinterpret_cast<char *>(&error_code), 2);
/// SQL State [optional: # + 5bytes string]
UInt8 sharp = static_cast<unsigned char>(*payload.position());
if (sharp == 0x23)
{
payload.ignore(1);
sql_state.resize(5);
payload.readStrict(reinterpret_cast<char *>(sql_state.data()), 5);
}
readString(error_message, payload);
}
protected:
size_t getPayloadSize() const override
{
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
}
};
/// https://dev.mysql.com/doc/internals/en/generic-response-packets.html
class PacketResponse : public ReadPacket
{
public:
OK_Packet ok;
ERR_Packet err;
EOF_Packet eof;
AuthSwitch_Packet auth_switch;
UInt64 column_length = 0;
PacketResponse(UInt32 server_capability_flags_) : ok(OK_Packet(server_capability_flags_)) { }
PacketResponse(UInt32 server_capability_flags_, bool is_handshake_)
: ok(OK_Packet(server_capability_flags_)), is_handshake(is_handshake_)
{
}
void readPayloadImpl(ReadBuffer & payload) override
{
UInt16 header = static_cast<unsigned char>(*payload.position());
switch (header)
{
case PACKET_OK:
packetType = PACKET_OK;
ok.readPayloadImpl(payload);
break;
case PACKET_ERR:
packetType = PACKET_ERR;
err.readPayloadImpl(payload);
break;
case PACKET_EOF:
if (is_handshake)
{
packetType = PACKET_AUTH_SWITCH;
auth_switch.readPayloadImpl(payload);
}
else
{
packetType = PACKET_EOF;
eof.readPayloadImpl(payload);
}
break;
case PACKET_LOCALINFILE:
packetType = PACKET_LOCALINFILE;
break;
default:
packetType = PACKET_OK;
column_length = readLengthEncodedNumber(payload);
}
}
ResponsePacketType getType() { return packetType; }
private:
bool is_handshake = false;
ResponsePacketType packetType = PACKET_OK;
};
class ColumnDefinition : public WritePacket, public ReadPacket
{
public:
String schema;
String table;
String org_table;
String name;
String org_name;
size_t next_length = 0x0c;
uint16_t character_set;
uint32_t column_length;
ColumnType column_type;
uint16_t flags;
uint8_t decimals = 0x00;
ColumnDefinition() : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00) { }
ColumnDefinition(
String schema_,
String table_,
String org_table_,
String name_,
String org_name_,
uint16_t character_set_,
uint32_t column_length_,
ColumnType column_type_,
uint16_t flags_,
uint8_t decimals_)
: schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)),
org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_), flags(flags_),
decimals(decimals_)
{
}
/// Should be used when column metadata (original name, table, original table, database) is unknown.
ColumnDefinition(
String name_,
uint16_t character_set_,
uint32_t column_length_,
ColumnType column_type_,
uint16_t flags_,
uint8_t decimals_)
: ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_)
{
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
writeLengthEncodedString(schema, buffer);
writeLengthEncodedString(table, buffer);
writeLengthEncodedString(org_table, buffer);
writeLengthEncodedString(name, buffer);
writeLengthEncodedString(org_name, buffer);
writeLengthEncodedNumber(next_length, buffer);
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&decimals), 2);
writeChar(0x0, 2, buffer);
}
void readPayloadImpl(ReadBuffer & payload) override
{
String def;
readLengthEncodedString(def, payload);
assert(def == "def");
readLengthEncodedString(schema, payload);
readLengthEncodedString(table, payload);
readLengthEncodedString(org_table, payload);
readLengthEncodedString(name, payload);
readLengthEncodedString(org_name, payload);
next_length = readLengthEncodedNumber(payload);
payload.readStrict(reinterpret_cast<char *>(&character_set), 2);
payload.readStrict(reinterpret_cast<char *>(&column_length), 4);
payload.readStrict(reinterpret_cast<char *>(&column_type), 1);
payload.readStrict(reinterpret_cast<char *>(&flags), 2);
payload.readStrict(reinterpret_cast<char *>(&decimals), 2);
payload.ignore(2);
}
protected:
size_t getPayloadSize() const override
{
return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \
getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length);
}
};
class ComFieldList : public LimitedReadPacket
{
public:
String table, field_wildcard;
void readPayloadImpl(ReadBuffer & payload) override
{
// Command byte has been already read from payload.
readNullTerminated(table, payload);
readStringUntilEOF(field_wildcard, payload);
}
};
class LengthEncodedNumber : public WritePacket
{
uint64_t value;
public:
explicit LengthEncodedNumber(uint64_t value_): value(value_)
{
}
protected:
size_t getPayloadSize() const override
{
return getLengthEncodedNumberSize(value);
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
writeLengthEncodedNumber(value, buffer);
}
};
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
ColumnDefinitionPacket getColumnDefinition(const String & column_name, const TypeIndex index);
namespace ProtocolText
{
class ResultsetRow : public WritePacket
class ResultsetRow : public IMySQLWritePacket
{
const Columns & columns;
int row_num;
@ -1431,7 +644,7 @@ private:
namespace Replication
{
/// https://dev.mysql.com/doc/internals/en/com-register-slave.html
class RegisterSlave : public WritePacket
class RegisterSlave : public IMySQLWritePacket
{
public:
UInt8 header = COM_REGISTER_SLAVE;
@ -1466,7 +679,7 @@ namespace Replication
};
/// https://dev.mysql.com/doc/internals/en/com-binlog-dump.html
class BinlogDump : public WritePacket
class BinlogDump : public IMySQLWritePacket
{
public:
UInt8 header = COM_BINLOG_DUMP;

View File

@ -717,8 +717,8 @@ namespace MySQLReplication
case PACKET_EOF:
throw ReplicationError("Master maybe lost", ErrorCodes::UNKNOWN_EXCEPTION);
case PACKET_ERR:
ERR_Packet err;
err.readPayloadImpl(payload);
ERRPacket err;
err.readPayloadWithUnpacked(payload);
throw ReplicationError(err.error_message, ErrorCodes::UNKNOWN_EXCEPTION);
}
// skip the header flag.

View File

@ -466,7 +466,7 @@ namespace MySQLReplication
void updateLogName(String binlog) { binlog_name = std::move(binlog); }
};
class IFlavor : public MySQLProtocol::ReadPacket
class IFlavor : public MySQLProtocol::IMySQLReadPacket
{
public:
virtual String getName() const = 0;

View File

@ -14,6 +14,7 @@ int main(int argc, char ** argv)
using namespace MySQLProtocol::Authentication;
uint8_t sequence_id = 1;
String user = "default";
String password = "123";
String database;
@ -35,12 +36,12 @@ int main(int argc, char ** argv)
WriteBufferFromString out0(s0);
Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa");
server_handshake.writePayloadImpl(out0);
server_handshake.writePayload(out0, sequence_id);
/// 1.2 Client reads the greeting
ReadBufferFromString in0(s0);
Handshake client_handshake;
client_handshake.readPayloadImpl(in0);
client_handshake.readPayload(in0, sequence_id);
/// Check packet
ASSERT(server_handshake.capability_flags == client_handshake.capability_flags)
@ -59,12 +60,12 @@ int main(int argc, char ** argv)
String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse client_handshake_response(
client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password);
client_handshake_response.writePayloadImpl(out1);
client_handshake_response.writePayload(out1, sequence_id);
/// 2.2 Server reads the response
ReadBufferFromString in1(s1);
HandshakeResponse server_handshake_response;
server_handshake_response.readPayloadImpl(in1);
server_handshake_response.readPayload(in1, sequence_id);
/// Check
ASSERT(server_handshake_response.capability_flags == client_handshake_response.capability_flags)
@ -80,13 +81,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
OK_Packet server(0x00, server_capability_flags, 0, 0, 0, "", "");
server.writePayloadImpl(out0);
OKPacket server(0x00, server_capability_flags, 0, 0, 0, "", "");
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_OK)
@ -100,13 +101,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
ERR_Packet server(123, "12345", "This is the error message");
server.writePayloadImpl(out0);
ERRPacket server(123, "12345", "This is the error message");
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_ERR)
@ -121,13 +122,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
EOF_Packet server(1, 1);
server.writePayloadImpl(out0);
EOFPacket server(1, 1);
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_EOF)
@ -141,13 +142,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
ColumnDefinition server("schema", "tbl", "org_tbl", "name", "org_name", 33, 0x00, MYSQL_TYPE_STRING, 0x00, 0x00);
server.writePayloadImpl(out0);
ColumnDefinitionPacket server("schema", "tbl", "org_tbl", "name", "org_name", 33, 0x00, MYSQL_TYPE_STRING, 0x00, 0x00);
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
ColumnDefinition client;
client.readPayloadImpl(in0);
ColumnDefinitionPacket client;
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.column_type == server.column_type)

View File

@ -39,7 +39,7 @@ void MySQLOutputFormat::initialize()
if (!(context->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
packet_sender->sendPacket(EOFPacket(0, 0));
}
}
}
@ -75,12 +75,12 @@ void MySQLOutputFormat::finalize()
const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_sender->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
if (context->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_sender->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
packet_sender->sendPacket(EOFPacket(0, 0), true);
}
void MySQLOutputFormat::flush()

View File

@ -117,10 +117,10 @@ void MySQLHandler::run()
catch (const Exception & exc)
{
log->log(exc);
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
}
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
OKPacket ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
packet_sender->sendPacket(ok_packet, true);
while (true)
@ -166,7 +166,7 @@ void MySQLHandler::run()
}
catch (...)
{
packet_sender->sendPacket(ERR_Packet(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
packet_sender->sendPacket(ERRPacket(getCurrentExceptionCode(), "00000", getCurrentExceptionMessage(false)), true);
}
}
}
@ -218,7 +218,7 @@ void MySQLHandler::finishHandshake(MySQLProtocol::HandshakeResponse & packet)
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
ReadBufferFromString payload(buf_for_handshake_response.str());
payload.ignore(PACKET_HEADER_SIZE);
packet.readPayloadImpl(payload);
packet.readPayloadWithUnpacked(payload);
packet_sender->sequence_id++;
}
}
@ -241,7 +241,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
catch (const Exception & exc)
{
LOG_ERROR(log, "Authentication for user {} failed.", user_name);
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
packet_sender->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true);
throw;
}
LOG_INFO(log, "Authentication for user {} succeeded.", user_name);
@ -253,29 +253,29 @@ void MySQLHandler::comInitDB(ReadBuffer & payload)
readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to {}", database);
connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
packet_sender->sendPacket(OKPacket(0, client_capability_flags, 0, 0, 1), true);
}
void MySQLHandler::comFieldList(ReadBuffer & payload)
{
ComFieldList packet;
packet.readPayloadImpl(payload);
packet.readPayloadWithUnpacked(payload);
String database = connection_context.getCurrentDatabase();
StoragePtr table_ptr = DatabaseCatalog::instance().getTable({database, packet.table}, connection_context);
auto metadata_snapshot = table_ptr->getInMemoryMetadataPtr();
for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll())
{
ColumnDefinition column_definition(
ColumnDefinitionPacket column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
);
packet_sender->sendPacket(column_definition);
}
packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true);
packet_sender->sendPacket(OKPacket(0xfe, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comPing()
{
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
packet_sender->sendPacket(OKPacket(0x0, client_capability_flags, 0, 0, 0), true);
}
static bool isFederatedServerSetupSetCommand(const String & query);
@ -288,7 +288,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
// As Clickhouse doesn't support these statements, we just send OK packet in response.
if (isFederatedServerSetupSetCommand(query))
{
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
}
else
{
@ -318,7 +318,7 @@ void MySQLHandler::comQuery(ReadBuffer & payload)
);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
packet_sender->sendPacket(OKPacket(0x00, client_capability_flags, 0, 0, 0), true);
}
}
@ -350,7 +350,7 @@ void MySQLHandlerSSL::finishHandshakeSSL(size_t packet_size, char * buf, size_t
SSLRequest ssl_request;
ReadBufferFromMemory payload(buf, pos);
payload.ignore(PACKET_HEADER_SIZE);
ssl_request.readPayloadImpl(payload);
ssl_request.readPayloadWithUnpacked(payload);
connection_context.mysql.client_capabilities = ssl_request.capability_flags;
connection_context.mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
secure_connection = true;