diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index 0a4474fdae0..237929ed068 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -475,11 +475,6 @@ public: { } - size_t getPayloadSize() const override - { - return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size(); - } - void writePayloadImpl(WriteBuffer & buffer) const override { buffer.write(static_cast(protocol_version)); @@ -536,6 +531,12 @@ public: readNullTerminated(auth_plugin_name, payload); } } + +protected: + size_t getPayloadSize() const override + { + return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size(); + } }; class SSLRequest : public ReadPacket @@ -556,15 +557,15 @@ public: class HandshakeResponse : public WritePacket, public ReadPacket { public: - uint32_t capability_flags = 0; - uint32_t max_packet_size = 0; - uint8_t character_set = 0; + uint32_t capability_flags; + uint32_t max_packet_size; + uint8_t character_set; String username; String database; String auth_response; String auth_plugin_name; - HandshakeResponse() = default; + HandshakeResponse() : capability_flags(0x00), max_packet_size(0x00), character_set(0x00) { } HandshakeResponse( UInt32 capability_flags_, @@ -584,35 +585,6 @@ public: { } - size_t getPayloadSize() const override - { - size_t size = 0; - size += 4 + 4 + 1 + 23; - size += username.size() + 1; - - if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) - { - size += getLengthEncodedStringSize(auth_response); - } - else if (capability_flags & CLIENT_SECURE_CONNECTION) - { - size += (1 + auth_response.size()); - } - else - { - size += (auth_response.size() + 1); - } - if (capability_flags & CLIENT_CONNECT_WITH_DB) - { - size += (database.size() + 1); - } - if (capability_flags & CLIENT_PLUGIN_AUTH) - { - size += (auth_plugin_name.size() + 1); - } - return size; - } - void writePayloadImpl(WriteBuffer & buffer) const override { buffer.write(reinterpret_cast(&capability_flags), 4); @@ -681,6 +653,36 @@ public: readNullTerminated(auth_plugin_name, payload); } } + +protected: + size_t getPayloadSize() const override + { + size_t size = 0; + size += 4 + 4 + 1 + 23; + size += username.size() + 1; + + if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) + { + size += getLengthEncodedStringSize(auth_response); + } + else if (capability_flags & CLIENT_SECURE_CONNECTION) + { + size += (1 + auth_response.size()); + } + else + { + size += (auth_response.size() + 1); + } + if (capability_flags & CLIENT_CONNECT_WITH_DB) + { + size += (database.size() + 1); + } + if (capability_flags & CLIENT_PLUGIN_AUTH) + { + size += (auth_plugin_name.size() + 1); + } + return size; + } }; class AuthSwitchRequest : public WritePacket @@ -750,7 +752,11 @@ public: String session_state_changes; String info; - OK_Packet(uint32_t capabilities_) : header(0x00), capabilities(capabilities_), affected_rows(0), last_insert_id(0), status_flags(0) { } + OK_Packet(uint32_t capabilities_) + : header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00) + { + } + OK_Packet( uint8_t header_, uint32_t capabilities_, @@ -770,33 +776,6 @@ public: { } - size_t getPayloadSize() const override - { - size_t result = 2 + getLengthEncodedNumberSize(affected_rows); - - if (capabilities & CLIENT_PROTOCOL_41) - { - result += 4; - } - else if (capabilities & CLIENT_TRANSACTIONS) - { - result += 2; - } - - if (capabilities & CLIENT_SESSION_TRACK) - { - result += getLengthEncodedStringSize(info); - if (status_flags & SERVER_SESSION_STATE_CHANGED) - result += getLengthEncodedStringSize(session_state_changes); - } - else - { - result += info.size(); - } - - return result; - } - void writePayloadImpl(WriteBuffer & buffer) const override { buffer.write(header); @@ -854,23 +833,46 @@ public: readString(info, payload); } } + +protected: + size_t getPayloadSize() const override + { + size_t result = 2 + getLengthEncodedNumberSize(affected_rows); + + if (capabilities & CLIENT_PROTOCOL_41) + { + result += 4; + } + else if (capabilities & CLIENT_TRANSACTIONS) + { + result += 2; + } + + if (capabilities & CLIENT_SESSION_TRACK) + { + result += getLengthEncodedStringSize(info); + if (status_flags & SERVER_SESSION_STATE_CHANGED) + result += getLengthEncodedStringSize(session_state_changes); + } + else + { + result += info.size(); + } + + return result; + } }; class EOF_Packet : public WritePacket, public ReadPacket { public: UInt8 header = 0xfe; - int warnings = 0; - int status_flags = 0; + int warnings; + int status_flags; - EOF_Packet() = default; - EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) - {} + EOF_Packet() : warnings(0x00), status_flags(0x00) { } - size_t getPayloadSize() const override - { - return 5; - } + EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) { } void writePayloadImpl(WriteBuffer & buffer) const override { @@ -886,27 +888,29 @@ public: payload.readStrict(reinterpret_cast(&warnings), 2); payload.readStrict(reinterpret_cast(&status_flags), 2); } + +protected: + size_t getPayloadSize() const override + { + return 5; + } }; class ERR_Packet : public WritePacket, public ReadPacket { public: UInt8 header = 0xff; - int error_code = 0; + int error_code; String sql_state; String error_message; - ERR_Packet() = default; + ERR_Packet() : error_code(0x00) { } + ERR_Packet(int error_code_, String sql_state_, String error_message_) : error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_)) { } - size_t getPayloadSize() const override - { - return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE); - } - void writePayloadImpl(WriteBuffer & buffer) const override { buffer.write(header); @@ -933,6 +937,12 @@ public: } readString(error_message, payload); } + +protected: + size_t getPayloadSize() const override + { + return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE); + } }; /// https://dev.mysql.com/doc/internals/en/generic-response-packets.html @@ -967,6 +977,7 @@ public: packetType = PACKET_LOCALINFILE; break; default: + packetType = PACKET_OK; column_length = readLengthEncodedNumber(payload); } } @@ -1024,12 +1035,6 @@ public: { } - size_t getPayloadSize() const override - { - return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \ - getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length); - } - void writePayloadImpl(WriteBuffer & buffer) const override { writeLengthEncodedString(std::string("def"), buffer); /// always "def" @@ -1065,6 +1070,13 @@ public: payload.readStrict(reinterpret_cast(&decimals), 2); payload.ignore(2); } + +protected: + size_t getPayloadSize() const override + { + return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \ + getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length); + } }; class ComFieldList : public LimitedReadPacket