mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-25 17:12:03 +00:00
Merge pull request #5811 from yurriy/mysql
Reading and writing MySQL packets in parts
This commit is contained in:
commit
41eaeb3e3d
@ -1,24 +1,24 @@
|
||||
#include <DataStreams/copyData.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <Storages/IStorage.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <Core/NamesAndTypes.h>
|
||||
#include "MySQLHandler.h"
|
||||
|
||||
#include <limits>
|
||||
#include <ext/scope_guard.h>
|
||||
#include <openssl/rsa.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/config_version.h>
|
||||
#include <Common/NetException.h>
|
||||
#include <Common/OpenSSLHelpers.h>
|
||||
#include <Poco/Crypto/RSAKey.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <Core/NamesAndTypes.h>
|
||||
#include <DataStreams/copyData.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <Poco/Crypto/CipherFactory.h>
|
||||
#include <Poco/Crypto/RSAKey.h>
|
||||
#include <Poco/Net/SecureStreamSocket.h>
|
||||
#include <Poco/Net/SSLManager.h>
|
||||
#include "MySQLHandler.h"
|
||||
#include <limits>
|
||||
#include <ext/scope_guard.h>
|
||||
|
||||
#include <openssl/rsa.h>
|
||||
#include <Storages/IStorage.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
@ -53,33 +53,29 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so
|
||||
|
||||
void MySQLHandler::run()
|
||||
{
|
||||
connection_context = server.context();
|
||||
connection_context.makeSessionContext();
|
||||
connection_context.setDefaultFormat("MySQLWire");
|
||||
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.mysql.sequence_id);
|
||||
|
||||
try
|
||||
{
|
||||
String scramble = generateScramble();
|
||||
|
||||
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
|
||||
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
|
||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
|
||||
*/
|
||||
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0');
|
||||
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, Authentication::Native, scramble + '\0');
|
||||
packet_sender->sendPacket<Handshake>(handshake, true);
|
||||
|
||||
LOG_TRACE(log, "Sent handshake");
|
||||
|
||||
HandshakeResponse handshake_response = finishHandshake();
|
||||
connection_context.client_capabilities = handshake_response.capability_flags;
|
||||
HandshakeResponse handshake_response;
|
||||
finishHandshake(handshake_response);
|
||||
connection_context.mysql.client_capabilities = handshake_response.capability_flags;
|
||||
if (handshake_response.max_packet_size)
|
||||
connection_context.max_packet_size = handshake_response.max_packet_size;
|
||||
if (!connection_context.max_packet_size)
|
||||
connection_context.max_packet_size = MAX_PACKET_LENGTH;
|
||||
connection_context.mysql.max_packet_size = handshake_response.max_packet_size;
|
||||
if (!connection_context.mysql.max_packet_size)
|
||||
connection_context.mysql.max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
|
||||
<< "\nmax_packet_size: "
|
||||
@ -110,9 +106,15 @@ void MySQLHandler::run()
|
||||
while (true)
|
||||
{
|
||||
packet_sender->resetSequenceId();
|
||||
String payload = packet_sender->receivePacketPayload();
|
||||
int command = payload[0];
|
||||
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
|
||||
PacketPayloadReadBuffer payload = packet_sender->getPayload();
|
||||
|
||||
char command = 0;
|
||||
payload.readStrict(command);
|
||||
|
||||
// For commands which are executed without MemoryTracker.
|
||||
LimitReadBuffer limited_payload(payload, 10000, true, "too long MySQL packet.");
|
||||
|
||||
LOG_DEBUG(log, "Received command: " << static_cast<int>(static_cast<unsigned char>(command)) << ". Connection id: " << connection_id << ".");
|
||||
try
|
||||
{
|
||||
switch (command)
|
||||
@ -120,13 +122,13 @@ void MySQLHandler::run()
|
||||
case COM_QUIT:
|
||||
return;
|
||||
case COM_INIT_DB:
|
||||
comInitDB(payload);
|
||||
comInitDB(limited_payload);
|
||||
break;
|
||||
case COM_QUERY:
|
||||
comQuery(payload);
|
||||
break;
|
||||
case COM_FIELD_LIST:
|
||||
comFieldList(payload);
|
||||
comFieldList(limited_payload);
|
||||
break;
|
||||
case COM_PING:
|
||||
comPing();
|
||||
@ -147,7 +149,7 @@ void MySQLHandler::run()
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Poco::Exception & exc)
|
||||
catch (const Poco::Exception & exc)
|
||||
{
|
||||
log->log(exc);
|
||||
}
|
||||
@ -157,9 +159,8 @@ void MySQLHandler::run()
|
||||
* Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
|
||||
* If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
|
||||
*/
|
||||
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
|
||||
void MySQLHandler::finishHandshake(MySQLProtocol::HandshakeResponse & packet)
|
||||
{
|
||||
HandshakeResponse packet;
|
||||
size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
|
||||
|
||||
/// Buffer for SSLRequest or part of HandshakeResponse.
|
||||
@ -187,16 +188,18 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
|
||||
{
|
||||
read_bytes(packet_size); /// Reading rest SSLRequest.
|
||||
SSLRequest ssl_request;
|
||||
ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE));
|
||||
connection_context.client_capabilities = ssl_request.capability_flags;
|
||||
connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
|
||||
ReadBufferFromMemory payload(buf, pos);
|
||||
payload.ignore(PACKET_HEADER_SIZE);
|
||||
ssl_request.readPayload(payload);
|
||||
connection_context.mysql.client_capabilities = ssl_request.capability_flags;
|
||||
connection_context.mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
|
||||
secure_connection = true;
|
||||
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
|
||||
connection_context.sequence_id = 2;
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
|
||||
packet_sender->max_packet_size = connection_context.max_packet_size;
|
||||
connection_context.mysql.sequence_id = 2;
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.mysql.sequence_id);
|
||||
packet_sender->max_packet_size = connection_context.mysql.max_packet_size;
|
||||
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
|
||||
}
|
||||
else
|
||||
@ -206,10 +209,11 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
|
||||
WriteBufferFromOwnString buf_for_handshake_response;
|
||||
buf_for_handshake_response.write(buf, pos);
|
||||
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
|
||||
packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE));
|
||||
ReadBufferFromString payload(buf_for_handshake_response.str());
|
||||
payload.ignore(PACKET_HEADER_SIZE);
|
||||
packet.readPayload(payload);
|
||||
packet_sender->sequence_id++;
|
||||
}
|
||||
return packet;
|
||||
}
|
||||
|
||||
String MySQLHandler::generateScramble()
|
||||
@ -230,6 +234,10 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
|
||||
AuthSwitchResponse response;
|
||||
if (handshake_response.auth_plugin_name != Authentication::SHA256)
|
||||
{
|
||||
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
|
||||
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
|
||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
|
||||
*/
|
||||
packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true);
|
||||
if (in->eof())
|
||||
throw Exception(
|
||||
@ -315,7 +323,7 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
|
||||
connection_context.setUser(handshake_response.username, password, socket().address(), "");
|
||||
if (!handshake_response.database.empty()) connection_context.setCurrentDatabase(handshake_response.database);
|
||||
connection_context.setCurrentQueryId("");
|
||||
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
|
||||
LOG_INFO(log, "Authentication for user " << handshake_response.username << " succeeded.");
|
||||
}
|
||||
catch (const Exception & exc)
|
||||
{
|
||||
@ -325,15 +333,16 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLHandler::comInitDB(const String & payload)
|
||||
void MySQLHandler::comInitDB(ReadBuffer & payload)
|
||||
{
|
||||
String database = payload.substr(1);
|
||||
String database;
|
||||
readStringUntilEOF(database, payload);
|
||||
LOG_DEBUG(log, "Setting current database to " << database);
|
||||
connection_context.setCurrentDatabase(database);
|
||||
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
|
||||
}
|
||||
|
||||
void MySQLHandler::comFieldList(const String & payload)
|
||||
void MySQLHandler::comFieldList(ReadBuffer & payload)
|
||||
{
|
||||
ComFieldList packet;
|
||||
packet.readPayload(payload);
|
||||
@ -354,22 +363,26 @@ void MySQLHandler::comPing()
|
||||
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
||||
void MySQLHandler::comQuery(const String & payload)
|
||||
void MySQLHandler::comQuery(ReadBuffer & payload)
|
||||
{
|
||||
bool with_output = false;
|
||||
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
|
||||
with_output = true;
|
||||
};
|
||||
|
||||
String query = payload.substr(1);
|
||||
const String query("select ''");
|
||||
ReadBufferFromString empty_select(query);
|
||||
|
||||
bool should_replace = false;
|
||||
// Translate query from MySQL to ClickHouse.
|
||||
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
|
||||
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
|
||||
query = "select ''";
|
||||
if (std::string(payload.position(), payload.buffer().end()) == "select @@version_comment limit 1") // MariaDB client starts session with that query
|
||||
{
|
||||
should_replace = true;
|
||||
}
|
||||
|
||||
executeQuery(should_replace ? empty_select : payload, *out, true, connection_context, set_content_type, nullptr);
|
||||
|
||||
ReadBufferFromString buf(query);
|
||||
executeQuery(buf, *out, true, connection_context, set_content_type, nullptr);
|
||||
if (!with_output)
|
||||
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
@ -20,15 +20,15 @@ public:
|
||||
|
||||
private:
|
||||
/// Enables SSL, if client requested.
|
||||
MySQLProtocol::HandshakeResponse finishHandshake();
|
||||
void finishHandshake(MySQLProtocol::HandshakeResponse &);
|
||||
|
||||
void comQuery(const String & payload);
|
||||
void comQuery(ReadBuffer & payload);
|
||||
|
||||
void comFieldList(const String & payload);
|
||||
void comFieldList(ReadBuffer & payload);
|
||||
|
||||
void comPing();
|
||||
|
||||
void comInitDB(const String & payload);
|
||||
void comInitDB(ReadBuffer & payload);
|
||||
|
||||
static String generateScramble();
|
||||
|
||||
@ -48,11 +48,11 @@ private:
|
||||
RSA & public_key;
|
||||
RSA & private_key;
|
||||
|
||||
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
|
||||
std::shared_ptr<ReadBuffer> in;
|
||||
std::shared_ptr<WriteBuffer> out;
|
||||
|
||||
bool secure_connection = false;
|
||||
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -7,9 +7,7 @@
|
||||
#include "MySQLProtocol.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace MySQLProtocol
|
||||
namespace DB::MySQLProtocol
|
||||
{
|
||||
|
||||
void PacketSender::resetSequenceId()
|
||||
@ -17,7 +15,7 @@ void PacketSender::resetSequenceId()
|
||||
sequence_id = 0;
|
||||
}
|
||||
|
||||
String PacketSender::packetToText(String payload)
|
||||
String PacketSender::packetToText(const String & payload)
|
||||
{
|
||||
String result;
|
||||
for (auto c : payload)
|
||||
@ -28,11 +26,11 @@ String PacketSender::packetToText(String payload)
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t readLengthEncodedNumber(std::istringstream & ss)
|
||||
uint64_t readLengthEncodedNumber(ReadBuffer & ss)
|
||||
{
|
||||
char c{};
|
||||
uint64_t buf = 0;
|
||||
ss.get(c);
|
||||
ss.readStrict(c);
|
||||
auto cc = static_cast<uint8_t>(c);
|
||||
if (cc < 0xfc)
|
||||
{
|
||||
@ -40,55 +38,65 @@ uint64_t readLengthEncodedNumber(std::istringstream & ss)
|
||||
}
|
||||
else if (cc < 0xfd)
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 2);
|
||||
ss.readStrict(reinterpret_cast<char *>(&buf), 2);
|
||||
}
|
||||
else if (cc < 0xfe)
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 3);
|
||||
ss.readStrict(reinterpret_cast<char *>(&buf), 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 8);
|
||||
ss.readStrict(reinterpret_cast<char *>(&buf), 8);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
std::string writeLengthEncodedNumber(uint64_t x)
|
||||
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer)
|
||||
{
|
||||
std::string result;
|
||||
if (x < 251)
|
||||
{
|
||||
result.append(1, static_cast<char>(x));
|
||||
buffer.write(static_cast<char>(x));
|
||||
}
|
||||
else if (x < (1 << 16))
|
||||
{
|
||||
result.append(1, 0xfc);
|
||||
result.append(reinterpret_cast<char *>(&x), 2);
|
||||
buffer.write(0xfc);
|
||||
buffer.write(reinterpret_cast<char *>(&x), 2);
|
||||
}
|
||||
else if (x < (1 << 24))
|
||||
{
|
||||
result.append(1, 0xfd);
|
||||
result.append(reinterpret_cast<char *>(&x), 3);
|
||||
buffer.write(0xfd);
|
||||
buffer.write(reinterpret_cast<char *>(&x), 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
result.append(1, 0xfe);
|
||||
result.append(reinterpret_cast<char *>(&x), 8);
|
||||
buffer.write(0xfe);
|
||||
buffer.write(reinterpret_cast<char *>(&x), 8);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void writeLengthEncodedString(std::string & payload, const std::string & s)
|
||||
size_t getLengthEncodedNumberSize(uint64_t x)
|
||||
{
|
||||
payload.append(writeLengthEncodedNumber(s.length()));
|
||||
payload.append(s);
|
||||
if (x < 251)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else if (x < (1 << 16))
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if (x < (1 << 24))
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 9;
|
||||
}
|
||||
}
|
||||
|
||||
void writeNulTerminatedString(std::string & payload, const std::string & s)
|
||||
size_t getLengthEncodedStringSize(const String & s)
|
||||
{
|
||||
payload.append(s);
|
||||
payload.append(1, 0);
|
||||
return getLengthEncodedNumberSize(s.size()) + s.size();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <Common/MemoryTracker.h>
|
||||
#include <Common/PODArray.h>
|
||||
#include <Core/Types.h>
|
||||
#include <IO/copyData.h>
|
||||
#include <IO/ReadBuffer.h>
|
||||
#include <IO/ReadBufferFromMemory.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Poco/Net/StreamSocket.h>
|
||||
#include <Poco/RandomStream.h>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include <IO/LimitReadBuffer.h>
|
||||
|
||||
/// Implementation of MySQL wire protocol
|
||||
/// Implementation of MySQL wire protocol.
|
||||
/// Works only on little-endian architecture.
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -124,23 +131,204 @@ public:
|
||||
};
|
||||
|
||||
|
||||
class WritePacket
|
||||
/** Reading packets.
|
||||
* Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
|
||||
*/
|
||||
class PacketPayloadReadBuffer : public ReadBuffer
|
||||
{
|
||||
public:
|
||||
virtual String getPayload() const = 0;
|
||||
PacketPayloadReadBuffer(ReadBuffer & in, uint8_t & sequence_id)
|
||||
: ReadBuffer(in.position(), 0) // not in.buffer().begin(), because working buffer may include previous packet
|
||||
, in(in)
|
||||
, sequence_id(sequence_id)
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~WritePacket() = default;
|
||||
private:
|
||||
ReadBuffer & in;
|
||||
uint8_t & sequence_id;
|
||||
const size_t max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
// Size of packet which is being read now.
|
||||
size_t payload_length = 0;
|
||||
|
||||
// Offset in packet payload.
|
||||
size_t offset = 0;
|
||||
|
||||
protected:
|
||||
bool nextImpl() override
|
||||
{
|
||||
if (payload_length == 0 || (payload_length == max_packet_size && offset == payload_length))
|
||||
{
|
||||
working_buffer.resize(0);
|
||||
offset = 0;
|
||||
payload_length = 0;
|
||||
in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
|
||||
|
||||
if (payload_length > max_packet_size)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
else if (payload_length == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t packet_sequence_id = 0;
|
||||
in.read(reinterpret_cast<char &>(packet_sequence_id));
|
||||
if (packet_sequence_id != sequence_id)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << static_cast<unsigned int>(sequence_id) << '.';
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
sequence_id++;
|
||||
}
|
||||
else if (offset == payload_length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
in.nextIfAtEnd();
|
||||
working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end());
|
||||
size_t count = std::min(in.available(), payload_length - offset);
|
||||
working_buffer.resize(count);
|
||||
in.ignore(count);
|
||||
|
||||
offset += count;
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class ReadPacket
|
||||
class ClientPacket
|
||||
{
|
||||
public:
|
||||
ReadPacket() = default;
|
||||
ReadPacket(const ReadPacket &) = default;
|
||||
virtual void readPayload(String payload) = 0;
|
||||
ClientPacket() = default;
|
||||
ClientPacket(ClientPacket &&) = default;
|
||||
|
||||
virtual ~ReadPacket() = default;
|
||||
virtual void read(ReadBuffer & in, uint8_t & sequence_id)
|
||||
{
|
||||
PacketPayloadReadBuffer payload(in, sequence_id);
|
||||
readPayload(payload);
|
||||
if (!payload.eof())
|
||||
{
|
||||
std::stringstream tmp;
|
||||
tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void readPayload(ReadBuffer & buf) = 0;
|
||||
|
||||
virtual ~ClientPacket() = default;
|
||||
};
|
||||
|
||||
|
||||
class LimitedClientPacket : public ClientPacket
|
||||
{
|
||||
public:
|
||||
void read(ReadBuffer & in, uint8_t & sequence_id) override
|
||||
{
|
||||
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
|
||||
ClientPacket::read(limited, sequence_id);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/** Writing packets.
|
||||
* https://dev.mysql.com/doc/internals/en/mysql-packet.html
|
||||
*/
|
||||
class PacketPayloadWriteBuffer : public WriteBuffer
|
||||
{
|
||||
public:
|
||||
PacketPayloadWriteBuffer(WriteBuffer & out, size_t payload_length, uint8_t & sequence_id)
|
||||
: WriteBuffer(out.position(), 0)
|
||||
, out(out)
|
||||
, sequence_id(sequence_id)
|
||||
, total_left(payload_length)
|
||||
{
|
||||
startPacket();
|
||||
}
|
||||
|
||||
void checkPayloadSize()
|
||||
{
|
||||
if (bytes_written + offset() < payload_length)
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << "Incomplete payload. Written " << bytes << " bytes, expected " << payload_length << " bytes.";
|
||||
throw Exception(ss.str(), 0);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
~PacketPayloadWriteBuffer() override { next(); }
|
||||
private:
|
||||
WriteBuffer & out;
|
||||
uint8_t & sequence_id;
|
||||
|
||||
size_t total_left = 0;
|
||||
size_t payload_length = 0;
|
||||
size_t bytes_written = 0;
|
||||
|
||||
void startPacket()
|
||||
{
|
||||
payload_length = std::min(total_left, MAX_PACKET_LENGTH);
|
||||
bytes_written = 0;
|
||||
total_left -= payload_length;
|
||||
|
||||
out.write(reinterpret_cast<char *>(&payload_length), 3);
|
||||
out.write(sequence_id++);
|
||||
|
||||
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
|
||||
pos = working_buffer.begin();
|
||||
}
|
||||
protected:
|
||||
void nextImpl() override
|
||||
{
|
||||
int written = pos - working_buffer.begin();
|
||||
out.position() += written;
|
||||
bytes_written += written;
|
||||
|
||||
if (bytes_written < payload_length)
|
||||
{
|
||||
out.nextIfAtEnd();
|
||||
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
|
||||
}
|
||||
else if (total_left > 0 || payload_length == MAX_PACKET_LENGTH)
|
||||
{
|
||||
// Starting new packet, since packets of size greater than MAX_PACKET_LENGTH should be split.
|
||||
startPacket();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Finished writing packet. Buffer is set to empty to prevent rewriting (pos will be set to the beginning of a working buffer in next()).
|
||||
// Further attempts to write will stall in the infinite loop.
|
||||
working_buffer = WriteBuffer::Buffer(out.position(), out.position());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class WritePacket
|
||||
{
|
||||
public:
|
||||
virtual void writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
|
||||
{
|
||||
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
|
||||
writePayloadImpl(buf);
|
||||
buf.checkPayloadSize();
|
||||
}
|
||||
|
||||
virtual ~WritePacket() = default;
|
||||
|
||||
protected:
|
||||
virtual size_t getPayloadSize() const = 0;
|
||||
|
||||
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
|
||||
};
|
||||
|
||||
|
||||
@ -150,13 +338,13 @@ public:
|
||||
class PacketSender
|
||||
{
|
||||
public:
|
||||
size_t & sequence_id;
|
||||
uint8_t & sequence_id;
|
||||
ReadBuffer * in;
|
||||
WriteBuffer * out;
|
||||
size_t max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
/// For reading and writing.
|
||||
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id)
|
||||
PacketSender(ReadBuffer & in, WriteBuffer & out, uint8_t & sequence_id)
|
||||
: sequence_id(sequence_id)
|
||||
, in(&in)
|
||||
, out(&out)
|
||||
@ -164,91 +352,59 @@ public:
|
||||
}
|
||||
|
||||
/// For writing.
|
||||
PacketSender(WriteBuffer & out, size_t & sequence_id)
|
||||
PacketSender(WriteBuffer & out, uint8_t & sequence_id)
|
||||
: sequence_id(sequence_id)
|
||||
, in(nullptr)
|
||||
, out(&out)
|
||||
{
|
||||
}
|
||||
|
||||
String receivePacketPayload()
|
||||
void receivePacket(ClientPacket & packet)
|
||||
{
|
||||
WriteBufferFromOwnString buf;
|
||||
|
||||
size_t payload_length = 0;
|
||||
size_t packet_sequence_id = 0;
|
||||
|
||||
// packets which are larger than or equal to 16MB are splitted
|
||||
do
|
||||
{
|
||||
in->readStrict(reinterpret_cast<char *>(&payload_length), 3);
|
||||
|
||||
if (payload_length > max_packet_size)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
|
||||
in->readStrict(reinterpret_cast<char *>(&packet_sequence_id), 1);
|
||||
|
||||
if (packet_sequence_id != sequence_id)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << sequence_id << '.';
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
sequence_id++;
|
||||
|
||||
copyData(*in, static_cast<WriteBuffer &>(buf), payload_length);
|
||||
} while (payload_length == max_packet_size);
|
||||
|
||||
return std::move(buf.str());
|
||||
}
|
||||
|
||||
void receivePacket(ReadPacket & packet)
|
||||
{
|
||||
packet.readPayload(receivePacketPayload());
|
||||
packet.read(*in, sequence_id);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void sendPacket(const T & packet, bool flush = false)
|
||||
{
|
||||
static_assert(std::is_base_of<WritePacket, T>());
|
||||
String payload = packet.getPayload();
|
||||
size_t pos = 0;
|
||||
do
|
||||
{
|
||||
size_t payload_length = std::min(payload.length() - pos, max_packet_size);
|
||||
|
||||
out->write(reinterpret_cast<const char *>(&payload_length), 3);
|
||||
out->write(reinterpret_cast<const char *>(&sequence_id), 1);
|
||||
out->write(payload.data() + pos, payload_length);
|
||||
|
||||
pos += payload_length;
|
||||
sequence_id++;
|
||||
} while (pos < payload.length());
|
||||
|
||||
packet.writePayload(*out, sequence_id);
|
||||
if (flush)
|
||||
out->next();
|
||||
}
|
||||
|
||||
PacketPayloadReadBuffer getPayload()
|
||||
{
|
||||
return PacketPayloadReadBuffer(*in, sequence_id);
|
||||
}
|
||||
|
||||
/// Sets sequence-id to 0. Must be called before each command phase.
|
||||
void resetSequenceId();
|
||||
|
||||
private:
|
||||
/// Converts packet to text. Is used for debug output.
|
||||
static String packetToText(String payload);
|
||||
static String packetToText(const String & payload);
|
||||
};
|
||||
|
||||
|
||||
uint64_t readLengthEncodedNumber(std::istringstream & ss);
|
||||
uint64_t readLengthEncodedNumber(ReadBuffer & ss);
|
||||
|
||||
String writeLengthEncodedNumber(uint64_t x);
|
||||
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer);
|
||||
|
||||
void writeLengthEncodedString(String & payload, const String & s);
|
||||
inline void writeLengthEncodedString(const String & s, WriteBuffer & buffer)
|
||||
{
|
||||
writeLengthEncodedNumber(s.size(), buffer);
|
||||
buffer.write(s.data(), s.size());
|
||||
}
|
||||
|
||||
void writeNulTerminatedString(String & payload, const String & s);
|
||||
inline void writeNulTerminatedString(const String & s, WriteBuffer & buffer)
|
||||
{
|
||||
buffer.write(s.data(), s.size());
|
||||
buffer.write(0);
|
||||
}
|
||||
|
||||
size_t getLengthEncodedNumberSize(uint64_t x);
|
||||
|
||||
size_t getLengthEncodedStringSize(const String & s);
|
||||
|
||||
|
||||
class Handshake : public WritePacket
|
||||
@ -259,66 +415,69 @@ class Handshake : public WritePacket
|
||||
uint32_t capability_flags;
|
||||
uint8_t character_set;
|
||||
uint32_t status_flags;
|
||||
String auth_plugin_name;
|
||||
String auth_plugin_data;
|
||||
public:
|
||||
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
|
||||
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_name, String auth_plugin_data)
|
||||
: protocol_version(0xa)
|
||||
, server_version(std::move(server_version))
|
||||
, connection_id(connection_id)
|
||||
, capability_flags(capability_flags)
|
||||
, character_set(CharacterSet::utf8_general_ci)
|
||||
, status_flags(0)
|
||||
, auth_plugin_data(auth_plugin_data)
|
||||
, auth_plugin_name(std::move(auth_plugin_name))
|
||||
, auth_plugin_data(std::move(auth_plugin_data))
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, protocol_version);
|
||||
writeNulTerminatedString(result, server_version);
|
||||
result.append(reinterpret_cast<const char *>(&connection_id), 4);
|
||||
writeNulTerminatedString(result, auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH));
|
||||
result.append(reinterpret_cast<const char *>(&capability_flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&character_set), 1);
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
result.append((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
|
||||
result.append(1, auth_plugin_data.size());
|
||||
result.append(10, 0x0);
|
||||
result.append(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH));
|
||||
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(static_cast<char>(protocol_version));
|
||||
writeNulTerminatedString(server_version, buffer);
|
||||
buffer.write(reinterpret_cast<const char *>(&connection_id), 4);
|
||||
writeNulTerminatedString(auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
|
||||
buffer.write(reinterpret_cast<const char *>(&capability_flags), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
|
||||
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
buffer.write((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
|
||||
buffer.write(static_cast<char>(auth_plugin_data.size()));
|
||||
writeChar(0x0, 10, buffer);
|
||||
writeString(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
|
||||
// A workaround for PHP mysqlnd extension bug which occurs when sha256_password is used as a default authentication plugin.
|
||||
// Instead of using client response for mysql_native_password plugin, the server will always generate authentication method mismatch
|
||||
// and switch to sha256_password to simulate that mysql_native_password is used as a default plugin.
|
||||
result.append(Authentication::Native);
|
||||
|
||||
result.append(1, 0x0);
|
||||
return result;
|
||||
writeString(auth_plugin_name, buffer);
|
||||
writeChar(0x0, 1, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
class SSLRequest : public ReadPacket
|
||||
class SSLRequest : public ClientPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags;
|
||||
uint32_t max_packet_size;
|
||||
uint8_t character_set;
|
||||
|
||||
void readPayload(String s) override
|
||||
void readPayload(ReadBuffer & buf) override
|
||||
{
|
||||
std::istringstream ss(s);
|
||||
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
|
||||
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
|
||||
}
|
||||
};
|
||||
|
||||
class HandshakeResponse : public ReadPacket
|
||||
class HandshakeResponse : public LimitedClientPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags;
|
||||
uint32_t max_packet_size;
|
||||
uint8_t character_set;
|
||||
uint32_t capability_flags = 0;
|
||||
uint32_t max_packet_size = 0;
|
||||
uint8_t character_set = 0;
|
||||
String username;
|
||||
String auth_response;
|
||||
String database;
|
||||
@ -326,45 +485,41 @@ public:
|
||||
|
||||
HandshakeResponse() = default;
|
||||
|
||||
HandshakeResponse(const HandshakeResponse &) = default;
|
||||
|
||||
void readPayload(String s) override
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
{
|
||||
std::istringstream ss(s);
|
||||
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
|
||||
payload.ignore(23);
|
||||
|
||||
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
|
||||
ss.ignore(23);
|
||||
|
||||
std::getline(ss, username, static_cast<char>(0x0));
|
||||
readNullTerminated(username, payload);
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||
{
|
||||
auto len = readLengthEncodedNumber(ss);
|
||||
auto len = readLengthEncodedNumber(payload);
|
||||
auth_response.resize(len);
|
||||
ss.read(auth_response.data(), static_cast<std::streamsize>(len));
|
||||
payload.readStrict(auth_response.data(), len);
|
||||
}
|
||||
else if (capability_flags & CLIENT_SECURE_CONNECTION)
|
||||
{
|
||||
uint8_t len;
|
||||
ss.read(reinterpret_cast<char *>(&len), 1);
|
||||
auth_response.resize(len);
|
||||
ss.read(auth_response.data(), len);
|
||||
char len;
|
||||
payload.readStrict(len);
|
||||
auth_response.resize(static_cast<unsigned int>(len));
|
||||
payload.readStrict(auth_response.data(), len);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::getline(ss, auth_response, static_cast<char>(0x0));
|
||||
readNullTerminated(auth_response, payload);
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_CONNECT_WITH_DB)
|
||||
{
|
||||
std::getline(ss, database, static_cast<char>(0x0));
|
||||
readNullTerminated(database, payload);
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
std::getline(ss, auth_plugin_name, static_cast<char>(0x0));
|
||||
readNullTerminated(auth_plugin_name, payload);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -379,24 +534,28 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xfe);
|
||||
writeNulTerminatedString(result, plugin_name);
|
||||
result.append(auth_plugin_data);
|
||||
return result;
|
||||
return 2 + plugin_name.size() + auth_plugin_data.size();
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0xfe);
|
||||
writeNulTerminatedString(plugin_name, buffer);
|
||||
writeString(auth_plugin_data, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
class AuthSwitchResponse : public ReadPacket
|
||||
class AuthSwitchResponse : public LimitedClientPacket
|
||||
{
|
||||
public:
|
||||
String value;
|
||||
|
||||
void readPayload(String s) override
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
{
|
||||
value = std::move(s);
|
||||
readStringUntilEOF(value, payload);
|
||||
}
|
||||
};
|
||||
|
||||
@ -404,33 +563,21 @@ class AuthMoreData : public WritePacket
|
||||
{
|
||||
String data;
|
||||
public:
|
||||
AuthMoreData(String data): data(std::move(data)) {}
|
||||
explicit AuthMoreData(String data): data(std::move(data)) {}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0x01);
|
||||
result.append(data);
|
||||
return result;
|
||||
return 1 + data.size();
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0x01);
|
||||
writeString(data, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
/// Packet with a single null-terminated string. Is used for clear text authentication.
|
||||
class NullTerminatedString : public ReadPacket
|
||||
{
|
||||
public:
|
||||
String value;
|
||||
|
||||
void readPayload(String s) override
|
||||
{
|
||||
if (s.length() == 0 || s.back() != 0)
|
||||
{
|
||||
throw ProtocolError("String is not null terminated.", ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
value = s;
|
||||
value.pop_back();
|
||||
}
|
||||
};
|
||||
|
||||
class OK_Packet : public WritePacket
|
||||
{
|
||||
@ -455,43 +602,65 @@ public:
|
||||
, warnings(warnings)
|
||||
, status_flags(status_flags)
|
||||
, session_state_changes(std::move(session_state_changes))
|
||||
, info(info)
|
||||
, info(std::move(info))
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, header);
|
||||
result.append(writeLengthEncodedNumber(affected_rows));
|
||||
result.append(writeLengthEncodedNumber(0)); /// last insert-id
|
||||
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
|
||||
|
||||
if (capabilities & CLIENT_PROTOCOL_41)
|
||||
{
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&warnings), 2);
|
||||
result += 4;
|
||||
}
|
||||
else if (capabilities & CLIENT_TRANSACTIONS)
|
||||
{
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
result += 2;
|
||||
}
|
||||
|
||||
if (capabilities & CLIENT_SESSION_TRACK)
|
||||
{
|
||||
result.append(writeLengthEncodedNumber(info.length()));
|
||||
result.append(info);
|
||||
result += getLengthEncodedStringSize(info);
|
||||
if (status_flags & SERVER_SESSION_STATE_CHANGED)
|
||||
{
|
||||
result.append(writeLengthEncodedNumber(session_state_changes.length()));
|
||||
result.append(session_state_changes);
|
||||
}
|
||||
result += getLengthEncodedStringSize(session_state_changes);
|
||||
}
|
||||
else
|
||||
{
|
||||
result.append(info);
|
||||
result += info.size();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(header);
|
||||
writeLengthEncodedNumber(affected_rows, buffer);
|
||||
writeLengthEncodedNumber(0, buffer); /// last insert-id
|
||||
|
||||
if (capabilities & CLIENT_PROTOCOL_41)
|
||||
{
|
||||
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
|
||||
}
|
||||
else if (capabilities & CLIENT_TRANSACTIONS)
|
||||
{
|
||||
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
}
|
||||
|
||||
if (capabilities & CLIENT_SESSION_TRACK)
|
||||
{
|
||||
writeLengthEncodedString(info, buffer);
|
||||
if (status_flags & SERVER_SESSION_STATE_CHANGED)
|
||||
writeLengthEncodedString(session_state_changes, buffer);
|
||||
}
|
||||
else
|
||||
{
|
||||
writeString(info, buffer);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class EOF_Packet : public WritePacket
|
||||
@ -502,13 +671,17 @@ public:
|
||||
EOF_Packet(int warnings, int status_flags) : warnings(warnings), status_flags(status_flags)
|
||||
{}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xfe); // EOF header
|
||||
result.append(reinterpret_cast<const char *>(&warnings), 2);
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
return result;
|
||||
return 5;
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0xfe); // EOF header
|
||||
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
}
|
||||
};
|
||||
|
||||
@ -523,15 +696,19 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xff);
|
||||
result.append(reinterpret_cast<const char *>(&error_code), 2);
|
||||
result.append("#", 1);
|
||||
result.append(sql_state.data(), sql_state.length());
|
||||
result.append(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
|
||||
return result;
|
||||
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
buffer.write(0xff);
|
||||
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
|
||||
buffer.write('#');
|
||||
buffer.write(sql_state.data(), sql_state.length());
|
||||
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
|
||||
}
|
||||
};
|
||||
|
||||
@ -579,37 +756,41 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
writeLengthEncodedString(result, "def"); /// always "def"
|
||||
writeLengthEncodedString(result, ""); /// schema
|
||||
writeLengthEncodedString(result, ""); /// table
|
||||
writeLengthEncodedString(result, ""); /// org_table
|
||||
writeLengthEncodedString(result, name);
|
||||
writeLengthEncodedString(result, ""); /// org_name
|
||||
result.append(writeLengthEncodedNumber(next_length));
|
||||
result.append(reinterpret_cast<const char *>(&character_set), 2);
|
||||
result.append(reinterpret_cast<const char *>(&column_length), 4);
|
||||
result.append(reinterpret_cast<const char *>(&column_type), 1);
|
||||
result.append(reinterpret_cast<const char *>(&flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&decimals), 2);
|
||||
result.append(2, 0x0);
|
||||
return result;
|
||||
return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \
|
||||
getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length);
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
|
||||
writeLengthEncodedString(schema, buffer);
|
||||
writeLengthEncodedString(table, buffer);
|
||||
writeLengthEncodedString(org_table, buffer);
|
||||
writeLengthEncodedString(name, buffer);
|
||||
writeLengthEncodedString(org_name, buffer);
|
||||
writeLengthEncodedNumber(next_length, buffer);
|
||||
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
|
||||
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
|
||||
buffer.write(reinterpret_cast<const char *>(&flags), 2);
|
||||
buffer.write(reinterpret_cast<const char *>(&decimals), 2);
|
||||
writeChar(0x0, 2, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
class ComFieldList : public ReadPacket
|
||||
class ComFieldList : public LimitedClientPacket
|
||||
{
|
||||
public:
|
||||
String table, field_wildcard;
|
||||
|
||||
void readPayload(String payload)
|
||||
void readPayload(ReadBuffer & payload) override
|
||||
{
|
||||
std::istringstream ss(payload);
|
||||
ss.ignore(1); // command byte
|
||||
std::getline(ss, table, static_cast<char>(0x0));
|
||||
field_wildcard = payload.substr(table.length() + 2); // rest of the packet
|
||||
// Command byte has been already read from payload.
|
||||
readNullTerminated(table, payload);
|
||||
readStringUntilEOF(field_wildcard, payload);
|
||||
}
|
||||
};
|
||||
|
||||
@ -617,37 +798,45 @@ class LengthEncodedNumber : public WritePacket
|
||||
{
|
||||
uint64_t value;
|
||||
public:
|
||||
LengthEncodedNumber(uint64_t value): value(value)
|
||||
explicit LengthEncodedNumber(uint64_t value): value(value)
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
return writeLengthEncodedNumber(value);
|
||||
return getLengthEncodedNumberSize(value);
|
||||
}
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
writeLengthEncodedNumber(value, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
class ResultsetRow : public WritePacket
|
||||
{
|
||||
std::vector<String> columns;
|
||||
size_t payload_size = 0;
|
||||
public:
|
||||
ResultsetRow()
|
||||
{
|
||||
}
|
||||
ResultsetRow() = default;
|
||||
|
||||
void appendColumn(String value)
|
||||
void appendColumn(String && value)
|
||||
{
|
||||
payload_size += getLengthEncodedStringSize(value);
|
||||
columns.emplace_back(std::move(value));
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
protected:
|
||||
size_t getPayloadSize() const override
|
||||
{
|
||||
String result;
|
||||
for (const String & column : columns)
|
||||
{
|
||||
writeLengthEncodedString(result, column);
|
||||
return payload_size;
|
||||
}
|
||||
return result;
|
||||
|
||||
void writePayloadImpl(WriteBuffer & buffer) const override
|
||||
{
|
||||
for (const String & column : columns)
|
||||
writeLengthEncodedString(column, buffer);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -12,9 +12,9 @@ using namespace MySQLProtocol;
|
||||
MySQLWireBlockOutputStream::MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
|
||||
: header(header)
|
||||
, context(context)
|
||||
, packet_sender(std::make_shared<PacketSender>(buf, context.sequence_id))
|
||||
, packet_sender(std::make_shared<PacketSender>(buf, context.mysql.sequence_id))
|
||||
{
|
||||
packet_sender->max_packet_size = context.max_packet_size;
|
||||
packet_sender->max_packet_size = context.mysql.max_packet_size;
|
||||
}
|
||||
|
||||
void MySQLWireBlockOutputStream::writePrefix()
|
||||
@ -30,7 +30,7 @@ void MySQLWireBlockOutputStream::writePrefix()
|
||||
packet_sender->sendPacket(column_definition);
|
||||
}
|
||||
|
||||
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
{
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0));
|
||||
}
|
||||
@ -45,12 +45,9 @@ void MySQLWireBlockOutputStream::write(const Block & block)
|
||||
ResultsetRow row_packet;
|
||||
for (const ColumnWithTypeAndName & column : block)
|
||||
{
|
||||
String column_value;
|
||||
WriteBufferFromString ostr(column_value);
|
||||
WriteBufferFromOwnString ostr;
|
||||
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
|
||||
ostr.finish();
|
||||
|
||||
row_packet.appendColumn(std::move(column_value));
|
||||
row_packet.appendColumn(std::move(ostr.str()));
|
||||
}
|
||||
packet_sender->sendPacket(row_packet);
|
||||
}
|
||||
@ -70,10 +67,10 @@ void MySQLWireBlockOutputStream::writeSuffix()
|
||||
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
|
||||
|
||||
if (header.columns() == 0)
|
||||
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
packet_sender->sendPacket(OK_Packet(0x0, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
if (context.mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0), true);
|
||||
}
|
||||
|
@ -173,6 +173,12 @@ inline void appendToStringOrVector(PaddedPODArray<UInt8> & s, ReadBuffer & rb, c
|
||||
s.insert(rb.position(), end);
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void appendToStringOrVector(PODArray<char> & s, ReadBuffer & rb, const char * end)
|
||||
{
|
||||
s.insert(rb.position(), end);
|
||||
}
|
||||
|
||||
template <typename Vector>
|
||||
void readStringInto(Vector & s, ReadBuffer & buf)
|
||||
{
|
||||
@ -188,6 +194,25 @@ void readStringInto(Vector & s, ReadBuffer & buf)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Vector>
|
||||
void readNullTerminated(Vector & s, ReadBuffer & buf)
|
||||
{
|
||||
while (!buf.eof())
|
||||
{
|
||||
char * next_pos = find_first_symbols<'\0'>(buf.position(), buf.buffer().end());
|
||||
|
||||
appendToStringOrVector(s, buf, next_pos);
|
||||
buf.position() = next_pos;
|
||||
|
||||
if (buf.hasPendingData())
|
||||
break;
|
||||
}
|
||||
buf.ignore();
|
||||
}
|
||||
|
||||
template void readNullTerminated<PODArray<char>>(PODArray<char> & s, ReadBuffer & buf);
|
||||
template void readNullTerminated<String>(String & s, ReadBuffer & buf);
|
||||
|
||||
void readString(String & s, ReadBuffer & buf)
|
||||
{
|
||||
s.clear();
|
||||
|
@ -425,6 +425,9 @@ void readCSVString(String & s, ReadBuffer & buf, const FormatSettings::CSV & set
|
||||
template <typename Vector>
|
||||
void readStringInto(Vector & s, ReadBuffer & buf);
|
||||
|
||||
template <typename Vector>
|
||||
void readNullTerminated(Vector & s, ReadBuffer & buf);
|
||||
|
||||
template <typename Vector>
|
||||
void readEscapedStringInto(Vector & s, ReadBuffer & buf);
|
||||
|
||||
|
@ -47,6 +47,18 @@ inline void writeChar(char x, WriteBuffer & buf)
|
||||
++buf.position();
|
||||
}
|
||||
|
||||
/// Write the same character n times.
|
||||
inline void writeChar(char c, size_t n, WriteBuffer & buf)
|
||||
{
|
||||
while (n)
|
||||
{
|
||||
buf.nextIfAtEnd();
|
||||
size_t count = std::min(n, buf.available());
|
||||
memset(buf.position(), c, count);
|
||||
n -= count;
|
||||
buf.position() += count;
|
||||
}
|
||||
}
|
||||
|
||||
/// Write POD-type in native format. It's recommended to use only with packed (dense) data types.
|
||||
template <typename T>
|
||||
|
@ -499,10 +499,14 @@ public:
|
||||
IHostContextPtr & getHostContext();
|
||||
const IHostContextPtr & getHostContext() const;
|
||||
|
||||
/// MySQL wire protocol state.
|
||||
size_t sequence_id = 0;
|
||||
struct MySQLWireContext
|
||||
{
|
||||
uint8_t sequence_id = 0;
|
||||
uint32_t client_capabilities = 0;
|
||||
size_t max_packet_size = 0;
|
||||
};
|
||||
|
||||
MySQLWireContext mysql;
|
||||
private:
|
||||
/** Check if the current client has access to the specified database.
|
||||
* If access is denied, throw an exception.
|
||||
|
@ -17,7 +17,7 @@ using namespace MySQLProtocol;
|
||||
MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header, const Context & context, const FormatSettings & settings)
|
||||
: IOutputFormat(header, out_)
|
||||
, context(context)
|
||||
, packet_sender(std::make_shared<PacketSender>(out, const_cast<size_t &>(context.sequence_id))) /// TODO: fix it
|
||||
, packet_sender(std::make_shared<PacketSender>(out, const_cast<uint8_t &>(context.mysql.sequence_id))) /// TODO: fix it
|
||||
, format_settings(settings)
|
||||
{
|
||||
}
|
||||
@ -43,7 +43,7 @@ void MySQLOutputFormat::consume(Chunk chunk)
|
||||
packet_sender->sendPacket(column_definition);
|
||||
}
|
||||
|
||||
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
{
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0));
|
||||
}
|
||||
@ -85,10 +85,10 @@ void MySQLOutputFormat::finalize()
|
||||
auto & header = getPort(PortKind::Main).getHeader();
|
||||
|
||||
if (header.columns() == 0)
|
||||
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
packet_sender->sendPacket(OK_Packet(0x0, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
if (context.mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0), true);
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
config_dir = os.path.join(SCRIPT_DIR, './configs')
|
||||
cluster = ClickHouseCluster(__file__)
|
||||
node = cluster.add_instance('node', config_dir=config_dir)
|
||||
node = cluster.add_instance('node', config_dir=config_dir, env_variables={'UBSAN_OPTIONS': 'print_stacktrace=1'})
|
||||
|
||||
server_port = 9001
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user