diff --git a/src/Core/MySQL/IMySQLReadPacket.cpp b/src/Core/MySQL/IMySQLReadPacket.cpp index 351a7deb698..4ad7a22b87f 100644 --- a/src/Core/MySQL/IMySQLReadPacket.cpp +++ b/src/Core/MySQL/IMySQLReadPacket.cpp @@ -37,6 +37,38 @@ void LimitedReadPacket::readPayloadWithUnpacked(ReadBuffer & in) IMySQLReadPacket::readPayloadWithUnpacked(limited); } +uint64_t readLengthEncodedNumber(ReadBuffer & buffer) +{ + char c{}; + uint64_t buf = 0; + buffer.readStrict(c); + auto cc = static_cast(c); + if (cc < 0xfc) + { + return cc; + } + else if (cc < 0xfd) + { + buffer.readStrict(reinterpret_cast(&buf), 2); + } + else if (cc < 0xfe) + { + buffer.readStrict(reinterpret_cast(&buf), 3); + } + else + { + buffer.readStrict(reinterpret_cast(&buf), 8); + } + return buf; +} + +void readLengthEncodedString(String & s, ReadBuffer & buffer) +{ + uint64_t len = readLengthEncodedNumber(buffer); + s.resize(len); + buffer.readStrict(reinterpret_cast(s.data()), len); +} + } } diff --git a/src/Core/MySQL/IMySQLReadPacket.h b/src/Core/MySQL/IMySQLReadPacket.h index 54a0df52224..874aa31151c 100644 --- a/src/Core/MySQL/IMySQLReadPacket.h +++ b/src/Core/MySQL/IMySQLReadPacket.h @@ -33,7 +33,7 @@ public: void readPayloadWithUnpacked(ReadBuffer & in) override; }; -uint64_t readLengthEncodedNumber(ReadBuffer & ss); +uint64_t readLengthEncodedNumber(ReadBuffer & buffer); void readLengthEncodedString(String & s, ReadBuffer & buffer); //inline void readLengthEncodedString(String & s, ReadBuffer & buffer) diff --git a/src/Core/MySQL/IMySQLWritePacket.cpp b/src/Core/MySQL/IMySQLWritePacket.cpp index 84931d08922..4ebc8a21584 100644 --- a/src/Core/MySQL/IMySQLWritePacket.cpp +++ b/src/Core/MySQL/IMySQLWritePacket.cpp @@ -20,6 +20,66 @@ void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id } } +size_t getLengthEncodedNumberSize(uint64_t x) +{ + if (x < 251) + { + return 1; + } + else if (x < (1 << 16)) + { + return 3; + } + else if (x < (1 << 24)) + { + return 4; + } + else + { + return 9; + } +} + +size_t getLengthEncodedStringSize(const String & s) +{ + return getLengthEncodedNumberSize(s.size()) + s.size(); +} + +void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer) +{ + if (x < 251) + { + buffer.write(static_cast(x)); + } + else if (x < (1 << 16)) + { + buffer.write(0xfc); + buffer.write(reinterpret_cast(&x), 2); + } + else if (x < (1 << 24)) + { + buffer.write(0xfd); + buffer.write(reinterpret_cast(&x), 3); + } + else + { + buffer.write(0xfe); + buffer.write(reinterpret_cast(&x), 8); + } +} + +void writeLengthEncodedString(const String & s, WriteBuffer & buffer) +{ + writeLengthEncodedNumber(s.size(), buffer); + buffer.write(s.data(), s.size()); +} + +void writeNulTerminatedString(const String & s, WriteBuffer & buffer) +{ + buffer.write(s.data(), s.size()); + buffer.write(0); +} + } } diff --git a/src/Core/MySQL/IMySQLWritePacket.h b/src/Core/MySQL/IMySQLWritePacket.h index a89529f370e..b00ffbaacae 100644 --- a/src/Core/MySQL/IMySQLWritePacket.h +++ b/src/Core/MySQL/IMySQLWritePacket.h @@ -25,6 +25,13 @@ protected: virtual void writePayloadImpl(WriteBuffer & buffer) const = 0; }; +size_t getLengthEncodedNumberSize(uint64_t x); +size_t getLengthEncodedStringSize(const String & s); + +void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer); +void writeLengthEncodedString(const String & s, WriteBuffer & buffer); +void writeNulTerminatedString(const String & s, WriteBuffer & buffer); + } } diff --git a/src/Core/MySQL/MySQLPackets.cpp b/src/Core/MySQL/MySQLPackets.cpp index 2370c2e22d0..f858463b91b 100644 --- a/src/Core/MySQL/MySQLPackets.cpp +++ b/src/Core/MySQL/MySQLPackets.cpp @@ -235,75 +235,6 @@ void ResponsePacket::readPayloadImpl(ReadBuffer & payload) } } -ColumnDefinitionPacket::ColumnDefinitionPacket() - : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00) -{ -} - -ColumnDefinitionPacket::ColumnDefinitionPacket( - String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_, - ColumnType column_type_, uint16_t flags_, uint8_t decimals_) - : schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)), - org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_), - flags(flags_), decimals(decimals_) -{ -} - -ColumnDefinitionPacket::ColumnDefinitionPacket( - String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_) - : ColumnDefinitionPacket("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_) -{ -} - -size_t ColumnDefinitionPacket::getPayloadSize() const -{ - return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \ - getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length); -} - -void ColumnDefinitionPacket::readPayloadImpl(ReadBuffer & payload) -{ - String def; - readLengthEncodedString(def, payload); - assert(def == "def"); - readLengthEncodedString(schema, payload); - readLengthEncodedString(table, payload); - readLengthEncodedString(org_table, payload); - readLengthEncodedString(name, payload); - readLengthEncodedString(org_name, payload); - next_length = readLengthEncodedNumber(payload); - payload.readStrict(reinterpret_cast(&character_set), 2); - payload.readStrict(reinterpret_cast(&column_length), 4); - payload.readStrict(reinterpret_cast(&column_type), 1); - payload.readStrict(reinterpret_cast(&flags), 2); - payload.readStrict(reinterpret_cast(&decimals), 2); - payload.ignore(2); -} - -void ColumnDefinitionPacket::writePayloadImpl(WriteBuffer & buffer) const -{ - writeLengthEncodedString(std::string("def"), buffer); /// always "def" - writeLengthEncodedString(schema, buffer); - writeLengthEncodedString(table, buffer); - writeLengthEncodedString(org_table, buffer); - writeLengthEncodedString(name, buffer); - writeLengthEncodedString(org_name, buffer); - writeLengthEncodedNumber(next_length, buffer); - buffer.write(reinterpret_cast(&character_set), 2); - buffer.write(reinterpret_cast(&column_length), 4); - buffer.write(reinterpret_cast(&column_type), 1); - buffer.write(reinterpret_cast(&flags), 2); - buffer.write(reinterpret_cast(&decimals), 2); - writeChar(0x0, 2, buffer); -} - -void ComFieldList::readPayloadImpl(ReadBuffer & payload) -{ - // Command byte has been already read from payload. - readNullTerminated(table, payload); - readStringUntilEOF(field_wildcard, payload); -} - LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_) { } diff --git a/src/Core/MySQL/MySQLPackets.h b/src/Core/MySQL/MySQLPackets.h index 60c16794b6f..daaadd66281 100644 --- a/src/Core/MySQL/MySQLPackets.h +++ b/src/Core/MySQL/MySQLPackets.h @@ -136,84 +136,6 @@ public: ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_); }; -enum ColumnType -{ - MYSQL_TYPE_DECIMAL = 0x00, - MYSQL_TYPE_TINY = 0x01, - MYSQL_TYPE_SHORT = 0x02, - MYSQL_TYPE_LONG = 0x03, - MYSQL_TYPE_FLOAT = 0x04, - MYSQL_TYPE_DOUBLE = 0x05, - MYSQL_TYPE_NULL = 0x06, - MYSQL_TYPE_TIMESTAMP = 0x07, - MYSQL_TYPE_LONGLONG = 0x08, - MYSQL_TYPE_INT24 = 0x09, - MYSQL_TYPE_DATE = 0x0a, - MYSQL_TYPE_TIME = 0x0b, - MYSQL_TYPE_DATETIME = 0x0c, - MYSQL_TYPE_YEAR = 0x0d, - MYSQL_TYPE_NEWDATE = 0x0e, - MYSQL_TYPE_VARCHAR = 0x0f, - MYSQL_TYPE_BIT = 0x10, - MYSQL_TYPE_TIMESTAMP2 = 0x11, - MYSQL_TYPE_DATETIME2 = 0x12, - MYSQL_TYPE_TIME2 = 0x13, - MYSQL_TYPE_JSON = 0xf5, - MYSQL_TYPE_NEWDECIMAL = 0xf6, - MYSQL_TYPE_ENUM = 0xf7, - MYSQL_TYPE_SET = 0xf8, - MYSQL_TYPE_TINY_BLOB = 0xf9, - MYSQL_TYPE_MEDIUM_BLOB = 0xfa, - MYSQL_TYPE_LONG_BLOB = 0xfb, - MYSQL_TYPE_BLOB = 0xfc, - MYSQL_TYPE_VAR_STRING = 0xfd, - MYSQL_TYPE_STRING = 0xfe, - MYSQL_TYPE_GEOMETRY = 0xff -}; - -class ColumnDefinitionPacket : public IMySQLWritePacket, public IMySQLReadPacket -{ -public: - String schema; - String table; - String org_table; - String name; - String org_name; - size_t next_length = 0x0c; - uint16_t character_set; - uint32_t column_length; - ColumnType column_type; - uint16_t flags; - uint8_t decimals = 0x00; - -protected: - size_t getPayloadSize() const override; - - void readPayloadImpl(ReadBuffer & payload) override; - - void writePayloadImpl(WriteBuffer & buffer) const override; - -public: - ColumnDefinitionPacket(); - - ColumnDefinitionPacket( - String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_, - ColumnType column_type_, uint16_t flags_, uint8_t decimals_); - - /// Should be used when column metadata (original name, table, original table, database) is unknown. - ColumnDefinitionPacket( - String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_); - -}; - -class ComFieldList : public LimitedReadPacket -{ -public: - String table, field_wildcard; - - void readPayloadImpl(ReadBuffer & payload) override; -}; - class LengthEncodedNumber : public IMySQLWritePacket { protected: diff --git a/src/Core/MySQL/PacketsConnection.cpp b/src/Core/MySQL/PacketsConnection.cpp index 048a4d17023..8a508ed38b2 100644 --- a/src/Core/MySQL/PacketsConnection.cpp +++ b/src/Core/MySQL/PacketsConnection.cpp @@ -15,9 +15,10 @@ Handshake::Handshake() : connection_id(0x00), capability_flags(0x00), character_ } Handshake::Handshake( - uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_) + uint32_t capability_flags_, uint32_t connection_id_, + String server_version_, String auth_plugin_name_, String auth_plugin_data_, uint8_t charset_) : protocol_version(0xa), server_version(std::move(server_version_)), connection_id(connection_id_), capability_flags(capability_flags_), - character_set(CharacterSet::utf8_general_ci), status_flags(0), auth_plugin_name(std::move(auth_plugin_name_)), + character_set(charset_), status_flags(0), auth_plugin_name(std::move(auth_plugin_name_)), auth_plugin_data(std::move(auth_plugin_data_)) { } diff --git a/src/Core/MySQL/PacketsConnection.h b/src/Core/MySQL/PacketsConnection.h index 1a108fdb40a..30a4beb48f2 100644 --- a/src/Core/MySQL/PacketsConnection.h +++ b/src/Core/MySQL/PacketsConnection.h @@ -34,7 +34,9 @@ protected: public: Handshake(); - Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_); + Handshake( + uint32_t capability_flags_, uint32_t connection_id_, + String server_version_, String auth_plugin_name_, String auth_plugin_data_, uint8_t charset_); }; class HandshakeResponse : public IMySQLWritePacket, public IMySQLReadPacket diff --git a/src/Core/MySQL/PacketsProtocolText.cpp b/src/Core/MySQL/PacketsProtocolText.cpp index fa69f0a8d9b..766bcf636e4 100644 --- a/src/Core/MySQL/PacketsProtocolText.cpp +++ b/src/Core/MySQL/PacketsProtocolText.cpp @@ -1,5 +1,7 @@ #include -#include +#include +#include +#include namespace DB { @@ -46,6 +48,143 @@ void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const } } +void ComFieldList::readPayloadImpl(ReadBuffer & payload) +{ + // Command byte has been already read from payload. + readNullTerminated(table, payload); + readStringUntilEOF(field_wildcard, payload); +} + +ColumnDefinition::ColumnDefinition() + : character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00) +{ +} + +ColumnDefinition::ColumnDefinition( + String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_, + ColumnType column_type_, uint16_t flags_, uint8_t decimals_) + : schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)), + org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_), + flags(flags_), decimals(decimals_) +{ +} + +ColumnDefinition::ColumnDefinition( + String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_) + : ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_) +{ +} + +size_t ColumnDefinition::getPayloadSize() const +{ + return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \ + getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length); +} + +void ColumnDefinition::readPayloadImpl(ReadBuffer & payload) +{ + String def; + readLengthEncodedString(def, payload); + assert(def == "def"); + readLengthEncodedString(schema, payload); + readLengthEncodedString(table, payload); + readLengthEncodedString(org_table, payload); + readLengthEncodedString(name, payload); + readLengthEncodedString(org_name, payload); + next_length = readLengthEncodedNumber(payload); + payload.readStrict(reinterpret_cast(&character_set), 2); + payload.readStrict(reinterpret_cast(&column_length), 4); + payload.readStrict(reinterpret_cast(&column_type), 1); + payload.readStrict(reinterpret_cast(&flags), 2); + payload.readStrict(reinterpret_cast(&decimals), 2); + payload.ignore(2); +} + +void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const +{ + writeLengthEncodedString(std::string("def"), buffer); /// always "def" + writeLengthEncodedString(schema, buffer); + writeLengthEncodedString(table, buffer); + writeLengthEncodedString(org_table, buffer); + writeLengthEncodedString(name, buffer); + writeLengthEncodedString(org_name, buffer); + writeLengthEncodedNumber(next_length, buffer); + buffer.write(reinterpret_cast(&character_set), 2); + buffer.write(reinterpret_cast(&column_length), 4); + buffer.write(reinterpret_cast(&column_type), 1); + buffer.write(reinterpret_cast(&flags), 2); + buffer.write(reinterpret_cast(&decimals), 2); + writeChar(0x0, 2, buffer); +} + +ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index) +{ + ColumnType column_type; + CharacterSet charset = CharacterSet::binary; + int flags = 0; + switch (type_index) + { + case TypeIndex::UInt8: + column_type = ColumnType::MYSQL_TYPE_TINY; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt16: + column_type = ColumnType::MYSQL_TYPE_SHORT; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt32: + column_type = ColumnType::MYSQL_TYPE_LONG; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::UInt64: + column_type = ColumnType::MYSQL_TYPE_LONGLONG; + flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; + break; + case TypeIndex::Int8: + column_type = ColumnType::MYSQL_TYPE_TINY; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int16: + column_type = ColumnType::MYSQL_TYPE_SHORT; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int32: + column_type = ColumnType::MYSQL_TYPE_LONG; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Int64: + column_type = ColumnType::MYSQL_TYPE_LONGLONG; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Float32: + column_type = ColumnType::MYSQL_TYPE_FLOAT; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Float64: + column_type = ColumnType::MYSQL_TYPE_DOUBLE; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::Date: + column_type = ColumnType::MYSQL_TYPE_DATE; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::DateTime: + column_type = ColumnType::MYSQL_TYPE_DATETIME; + flags = ColumnDefinitionFlags::BINARY_FLAG; + break; + case TypeIndex::String: + case TypeIndex::FixedString: + column_type = ColumnType::MYSQL_TYPE_STRING; + charset = CharacterSet::utf8_general_ci; + break; + default: + column_type = ColumnType::MYSQL_TYPE_STRING; + charset = CharacterSet::utf8_general_ci; + break; + } + return ColumnDefinition(column_name, charset, 0, column_type, flags, 0); +} + } } diff --git a/src/Core/MySQL/PacketsProtocolText.h b/src/Core/MySQL/PacketsProtocolText.h index 504a3b98b3b..d449e94cff1 100644 --- a/src/Core/MySQL/PacketsProtocolText.h +++ b/src/Core/MySQL/PacketsProtocolText.h @@ -1,9 +1,10 @@ #pragma once -#include #include #include -#include + +#include +#include namespace DB { @@ -14,6 +15,54 @@ namespace MySQLProtocol namespace ProtocolText { +enum CharacterSet +{ + utf8_general_ci = 33, + binary = 63 +}; + +// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html +enum ColumnDefinitionFlags +{ + UNSIGNED_FLAG = 32, + BINARY_FLAG = 128 +}; + +enum ColumnType +{ + MYSQL_TYPE_DECIMAL = 0x00, + MYSQL_TYPE_TINY = 0x01, + MYSQL_TYPE_SHORT = 0x02, + MYSQL_TYPE_LONG = 0x03, + MYSQL_TYPE_FLOAT = 0x04, + MYSQL_TYPE_DOUBLE = 0x05, + MYSQL_TYPE_NULL = 0x06, + MYSQL_TYPE_TIMESTAMP = 0x07, + MYSQL_TYPE_LONGLONG = 0x08, + MYSQL_TYPE_INT24 = 0x09, + MYSQL_TYPE_DATE = 0x0a, + MYSQL_TYPE_TIME = 0x0b, + MYSQL_TYPE_DATETIME = 0x0c, + MYSQL_TYPE_YEAR = 0x0d, + MYSQL_TYPE_NEWDATE = 0x0e, + MYSQL_TYPE_VARCHAR = 0x0f, + MYSQL_TYPE_BIT = 0x10, + MYSQL_TYPE_TIMESTAMP2 = 0x11, + MYSQL_TYPE_DATETIME2 = 0x12, + MYSQL_TYPE_TIME2 = 0x13, + MYSQL_TYPE_JSON = 0xf5, + MYSQL_TYPE_NEWDECIMAL = 0xf6, + MYSQL_TYPE_ENUM = 0xf7, + MYSQL_TYPE_SET = 0xf8, + MYSQL_TYPE_TINY_BLOB = 0xf9, + MYSQL_TYPE_MEDIUM_BLOB = 0xfa, + MYSQL_TYPE_LONG_BLOB = 0xfb, + MYSQL_TYPE_BLOB = 0xfc, + MYSQL_TYPE_VAR_STRING = 0xfd, + MYSQL_TYPE_STRING = 0xfe, + MYSQL_TYPE_GEOMETRY = 0xff +}; + class ResultSetRow : public IMySQLWritePacket { protected: @@ -30,6 +79,51 @@ public: ResultSetRow(const DataTypes & data_types, const Columns & columns_, int row_num_); }; +class ComFieldList : public LimitedReadPacket +{ +public: + String table, field_wildcard; + + void readPayloadImpl(ReadBuffer & payload) override; +}; + +class ColumnDefinition : public IMySQLWritePacket, public IMySQLReadPacket +{ +public: + String schema; + String table; + String org_table; + String name; + String org_name; + size_t next_length = 0x0c; + uint16_t character_set; + uint32_t column_length; + ColumnType column_type; + uint16_t flags; + uint8_t decimals = 0x00; + +protected: + size_t getPayloadSize() const override; + + void readPayloadImpl(ReadBuffer & payload) override; + + void writePayloadImpl(WriteBuffer & buffer) const override; + +public: + ColumnDefinition(); + + ColumnDefinition( + String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_, + ColumnType column_type_, uint16_t flags_, uint8_t decimals_); + + /// Should be used when column metadata (original name, table, original table, database) is unknown. + ColumnDefinition( + String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_); + +}; + +ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index); + } } diff --git a/src/Core/MySQLProtocol.cpp b/src/Core/MySQLProtocol.cpp index 94eaa6d0790..f9d12e8b351 100644 --- a/src/Core/MySQLProtocol.cpp +++ b/src/Core/MySQLProtocol.cpp @@ -10,168 +10,5 @@ namespace DB::MySQLProtocol { - extern const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb - - - -uint64_t readLengthEncodedNumber(ReadBuffer & ss) -{ - char c{}; - uint64_t buf = 0; - ss.readStrict(c); - auto cc = static_cast(c); - if (cc < 0xfc) - { - return cc; - } - else if (cc < 0xfd) - { - ss.readStrict(reinterpret_cast(&buf), 2); - } - else if (cc < 0xfe) - { - ss.readStrict(reinterpret_cast(&buf), 3); - } - else - { - ss.readStrict(reinterpret_cast(&buf), 8); - } - return buf; -} - -void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer) -{ - if (x < 251) - { - buffer.write(static_cast(x)); - } - else if (x < (1 << 16)) - { - buffer.write(0xfc); - buffer.write(reinterpret_cast(&x), 2); - } - else if (x < (1 << 24)) - { - buffer.write(0xfd); - buffer.write(reinterpret_cast(&x), 3); - } - else - { - buffer.write(0xfe); - buffer.write(reinterpret_cast(&x), 8); - } -} - -size_t getLengthEncodedNumberSize(uint64_t x) -{ - if (x < 251) - { - return 1; - } - else if (x < (1 << 16)) - { - return 3; - } - else if (x < (1 << 24)) - { - return 4; - } - else - { - return 9; - } -} - -size_t getLengthEncodedStringSize(const String & s) -{ - return getLengthEncodedNumberSize(s.size()) + s.size(); -} - -ColumnDefinitionPacket getColumnDefinition(const String & column_name, const TypeIndex type_index) -{ - ColumnType column_type; - CharacterSet charset = CharacterSet::binary; - int flags = 0; - switch (type_index) - { - case TypeIndex::UInt8: - column_type = ColumnType::MYSQL_TYPE_TINY; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt16: - column_type = ColumnType::MYSQL_TYPE_SHORT; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt32: - column_type = ColumnType::MYSQL_TYPE_LONG; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::UInt64: - column_type = ColumnType::MYSQL_TYPE_LONGLONG; - flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG; - break; - case TypeIndex::Int8: - column_type = ColumnType::MYSQL_TYPE_TINY; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int16: - column_type = ColumnType::MYSQL_TYPE_SHORT; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int32: - column_type = ColumnType::MYSQL_TYPE_LONG; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Int64: - column_type = ColumnType::MYSQL_TYPE_LONGLONG; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Float32: - column_type = ColumnType::MYSQL_TYPE_FLOAT; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Float64: - column_type = ColumnType::MYSQL_TYPE_DOUBLE; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::Date: - column_type = ColumnType::MYSQL_TYPE_DATE; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::DateTime: - column_type = ColumnType::MYSQL_TYPE_DATETIME; - flags = ColumnDefinitionFlags::BINARY_FLAG; - break; - case TypeIndex::String: - case TypeIndex::FixedString: - column_type = ColumnType::MYSQL_TYPE_STRING; - charset = CharacterSet::utf8_general_ci; - break; - default: - column_type = ColumnType::MYSQL_TYPE_STRING; - charset = CharacterSet::utf8_general_ci; - break; - } - return ColumnDefinitionPacket(column_name, charset, 0, column_type, flags, 0); -} - -//void ReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id) -//{ -// PacketPayloadReadBuffer payload(in, sequence_id); -// payload.next(); -// readPayloadImpl(payload); -// if (!payload.eof()) -// { -// std::stringstream tmp; -// tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer."; -// throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT); -// } -//} -// -//void LimitedReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id) -//{ -// LimitReadBuffer limited(in, 10000, true, "too long MySQL packet."); -// ReadPacket::readPayload(limited, sequence_id); -//} - +extern const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb } diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index 2fe52c0431c..57e9b055d6b 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -62,13 +62,6 @@ const size_t MYSQL_ERRMSG_SIZE = 512; const size_t PACKET_HEADER_SIZE = 4; const size_t SSL_REQUEST_PAYLOAD_SIZE = 32; - -enum CharacterSet -{ - utf8_general_ci = 33, - binary = 63 -}; - enum StatusFlags { SERVER_SESSION_STATE_CHANGED = 0x4000 @@ -113,13 +106,6 @@ enum Command COM_DAEMON = 0x1d }; -// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html -enum ColumnDefinitionFlags -{ - UNSIGNED_FLAG = 32, - BINARY_FLAG = 128 -}; - class ProtocolError : public DB::Exception { @@ -128,35 +114,6 @@ public: }; -uint64_t readLengthEncodedNumber(ReadBuffer & ss); - -inline void readLengthEncodedString(String & s, ReadBuffer & buffer) -{ - uint64_t len = readLengthEncodedNumber(buffer); - s.resize(len); - buffer.readStrict(reinterpret_cast(s.data()), len); -} - -void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer); - -inline void writeLengthEncodedString(const String & s, WriteBuffer & buffer) -{ - writeLengthEncodedNumber(s.size(), buffer); - buffer.write(s.data(), s.size()); -} - -inline void writeNulTerminatedString(const String & s, WriteBuffer & buffer) -{ - buffer.write(s.data(), s.size()); - buffer.write(0); -} - -size_t getLengthEncodedNumberSize(uint64_t x); - -size_t getLengthEncodedStringSize(const String & s); - -ColumnDefinitionPacket getColumnDefinition(const String & column_name, const TypeIndex index); - namespace Replication { /// https://dev.mysql.com/doc/internals/en/com-register-slave.html diff --git a/src/Core/MySQLReplication.cpp b/src/Core/MySQLReplication.cpp index 68ac99225f7..d533262d7cb 100644 --- a/src/Core/MySQLReplication.cpp +++ b/src/Core/MySQLReplication.cpp @@ -15,6 +15,7 @@ namespace ErrorCodes namespace MySQLReplication { using namespace MySQLProtocol; + using namespace MySQLProtocol::ProtocolText; /// https://dev.mysql.com/doc/internals/en/binlog-event-header.html void EventHeader::parse(ReadBuffer & payload) diff --git a/src/Core/tests/mysql_protocol.cpp b/src/Core/tests/mysql_protocol.cpp index af265120ffe..5bb19798a3b 100644 --- a/src/Core/tests/mysql_protocol.cpp +++ b/src/Core/tests/mysql_protocol.cpp @@ -13,6 +13,7 @@ int main(int argc, char ** argv) using namespace MySQLProtocol; using namespace MySQLProtocol::Authentication; using namespace MySQLProtocol::ConnectionPhase; + using namespace MySQLProtocol::ProtocolText; uint8_t sequence_id = 1; @@ -36,7 +37,7 @@ int main(int argc, char ** argv) std::string s0; WriteBufferFromString out0(s0); - Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa"); + Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa", CharacterSet::utf8_general_ci); server_handshake.writePayload(out0, sequence_id); /// 1.2 Client reads the greeting @@ -143,12 +144,12 @@ int main(int argc, char ** argv) // 1. Server writes packet std::string s0; WriteBufferFromString out0(s0); - ColumnDefinitionPacket server("schema", "tbl", "org_tbl", "name", "org_name", 33, 0x00, MYSQL_TYPE_STRING, 0x00, 0x00); + ColumnDefinition server("schema", "tbl", "org_tbl", "name", "org_name", 33, 0x00, MYSQL_TYPE_STRING, 0x00, 0x00); server.writePayload(out0, sequence_id); // 2. Client reads packet ReadBufferFromString in0(s0); - ColumnDefinitionPacket client; + ColumnDefinition client; client.readPayload(in0, sequence_id); // Check diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp index 2a00e3393a8..f7af3ce898e 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp @@ -10,6 +10,7 @@ namespace DB { using namespace MySQLProtocol; +using namespace MySQLProtocol::ProtocolText; MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & settings_) diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index e2bfe958c29..ff72f139fb7 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -32,6 +32,7 @@ namespace DB using namespace MySQLProtocol; using namespace MySQLProtocol::ConnectionPhase; +using namespace MySQLProtocol::ProtocolText; #if USE_SSL using Poco::Net::SecureStreamSocket; @@ -79,7 +80,8 @@ void MySQLHandler::run() try { - Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, auth_plugin->getName(), auth_plugin->getAuthPluginData()); + Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, + auth_plugin->getName(), auth_plugin->getAuthPluginData(), CharacterSet::utf8_general_ci); packet_sender->sendPacket(handshake, true); LOG_TRACE(log, "Sent handshake"); @@ -266,7 +268,7 @@ void MySQLHandler::comFieldList(ReadBuffer & payload) auto metadata_snapshot = table_ptr->getInMemoryMetadataPtr(); for (const NameAndTypePair & column : metadata_snapshot->getColumns().getAll()) { - ColumnDefinitionPacket column_definition( + ColumnDefinition column_definition( database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0 ); packet_sender->sendPacket(column_definition);