From ab8b847e66aadda369166bdc2b5047610a4a32bb Mon Sep 17 00:00:00 2001 From: BohuTANG Date: Mon, 20 Apr 2020 22:31:43 +0800 Subject: [PATCH] add OK/ERR packet parse --- src/Core/MySQLClient.cpp | 5 +-- src/Core/MySQLProtocol.h | 76 +++++++++++++++++++++++++++++++++------- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/src/Core/MySQLClient.cpp b/src/Core/MySQLClient.cpp index 0fbe3a5b067..44648551622 100644 --- a/src/Core/MySQLClient.cpp +++ b/src/Core/MySQLClient.cpp @@ -79,10 +79,11 @@ void MySQLClient::handshake() client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password); packet_sender->sendPacket(handshakeResponse, true); - PacketResponse packetResponse; + PacketResponse packetResponse(handshake.capability_flags); packet_sender->receivePacket(packetResponse); - switch (packetResponse.getType()) { + switch (packetResponse.getType()) + { case PACKET_OK: break; case PACKET_ERR: diff --git a/src/Core/MySQLProtocol.h b/src/Core/MySQLProtocol.h index ce62e111060..08f764e7b11 100644 --- a/src/Core/MySQLProtocol.h +++ b/src/Core/MySQLProtocol.h @@ -752,17 +752,21 @@ protected: }; -class OK_Packet : public WritePacket +class OK_Packet : public WritePacket, public ReadPacket { +public: uint8_t header; uint32_t capabilities; uint64_t affected_rows; + uint64_t last_insert_id; int16_t warnings = 0; uint32_t status_flags; String session_state_changes; String info; -public: - OK_Packet(uint8_t header_, + + OK_Packet(uint32_t capabilities_) : capabilities(capabilities_) { } + OK_Packet( + uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_, uint32_t status_flags_, @@ -779,7 +783,6 @@ public: { } -protected: size_t getPayloadSize() const override { size_t result = 2 + getLengthEncodedNumberSize(affected_rows); @@ -834,17 +837,51 @@ protected: writeString(info, buffer); } } + + void readPayloadImpl(ReadBuffer & payload) override + { + payload.readStrict(reinterpret_cast(&header), 1); + affected_rows = readLengthEncodedNumber(payload); + last_insert_id = readLengthEncodedNumber(payload); + + if (capabilities & CLIENT_PROTOCOL_41) + { + payload.readStrict(reinterpret_cast(&status_flags), 2); + payload.readStrict(reinterpret_cast(&warnings), 2); + } + else if (capabilities & CLIENT_TRANSACTIONS) + { + payload.readStrict(reinterpret_cast(&status_flags), 2); + } + + if (capabilities & CLIENT_SESSION_TRACK) + { + auto len = readLengthEncodedNumber(payload); + info.resize(len); + payload.readStrict(info.data(), len); + if (status_flags & SERVER_SESSION_STATE_CHANGED) { + len = readLengthEncodedNumber(payload); + session_state_changes.resize(len); + payload.readStrict(session_state_changes.data(), len); + } + } + else + { + readString(info, payload); + } + } }; -class EOF_Packet : public WritePacket +class EOF_Packet : public WritePacket, public ReadPacket { +public: int warnings; int status_flags; -public: + + EOF_Packet() = default; EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_) {} -protected: size_t getPayloadSize() const override { return 5; @@ -856,6 +893,15 @@ protected: buffer.write(reinterpret_cast(&warnings), 2); buffer.write(reinterpret_cast(&status_flags), 2); } + + void readPayloadImpl(ReadBuffer & payload) override + { + UInt8 header = 0; + payload.readStrict(reinterpret_cast(&header), 1); + assert(header == 0xfe); + payload.readStrict(reinterpret_cast(&warnings), 2); + payload.readStrict(reinterpret_cast(&status_flags), 2); + } }; class ERR_Packet : public WritePacket, public ReadPacket @@ -890,10 +936,14 @@ public: UInt8 header = 0; payload.readStrict(reinterpret_cast(&header), 1); assert(header == 0xff); + payload.readStrict(reinterpret_cast(&error_code), 2); payload.ignore(1); + + sql_state.resize(5); payload.readStrict(reinterpret_cast(&sql_state), 5); - readString(error_message, payload); + + readStringUntilEOF(error_message, payload); } }; @@ -901,11 +951,11 @@ public: class PacketResponse : public ReadPacket { public: - OK_Packet * ok; - ERR_Packet err; - EOF_Packet * eof; + OK_Packet ok; + ERR_Packet err; + EOF_Packet eof; - PacketResponse() = default; + PacketResponse(UInt32 server_capability_flags_) : ok(OK_Packet(server_capability_flags_)) { } void readPayloadImpl(ReadBuffer & payload) override { @@ -914,6 +964,7 @@ public: { case PACKET_OK: packetType = PACKET_OK; + ok.readPayloadImpl(payload); break; case PACKET_ERR: packetType = PACKET_ERR; @@ -921,6 +972,7 @@ public: break; case PACKET_EOF: packetType = PACKET_EOF; + eof.readPayloadImpl(payload); break; }; }