Merge pull request #5811 from yurriy/mysql

Reading and writing MySQL packets in parts
This commit is contained in:
alexey-milovidov 2019-07-28 11:55:09 +03:00 committed by GitHub
commit 41eaeb3e3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 568 additions and 317 deletions

View File

@ -1,24 +1,24 @@
#include <DataStreams/copyData.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/executeQuery.h>
#include <Storages/IStorage.h>
#include <Core/MySQLProtocol.h>
#include <Core/NamesAndTypes.h>
#include "MySQLHandler.h"
#include <limits>
#include <ext/scope_guard.h>
#include <openssl/rsa.h>
#include <Columns/ColumnVector.h>
#include <Common/config_version.h>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Poco/Crypto/RSAKey.h>
#include <Core/MySQLProtocol.h>
#include <Core/NamesAndTypes.h>
#include <DataStreams/copyData.h>
#include <Interpreters/executeQuery.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Poco/Crypto/CipherFactory.h>
#include <Poco/Crypto/RSAKey.h>
#include <Poco/Net/SecureStreamSocket.h>
#include <Poco/Net/SSLManager.h>
#include "MySQLHandler.h"
#include <limits>
#include <ext/scope_guard.h>
#include <openssl/rsa.h>
#include <Storages/IStorage.h>
namespace DB
@ -53,33 +53,29 @@ MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & so
void MySQLHandler::run()
{
connection_context = server.context();
connection_context.makeSessionContext();
connection_context.setDefaultFormat("MySQLWire");
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.mysql.sequence_id);
try
{
String scramble = generateScramble();
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
*/
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0');
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, Authentication::Native, scramble + '\0');
packet_sender->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake");
HandshakeResponse handshake_response = finishHandshake();
connection_context.client_capabilities = handshake_response.capability_flags;
HandshakeResponse handshake_response;
finishHandshake(handshake_response);
connection_context.mysql.client_capabilities = handshake_response.capability_flags;
if (handshake_response.max_packet_size)
connection_context.max_packet_size = handshake_response.max_packet_size;
if (!connection_context.max_packet_size)
connection_context.max_packet_size = MAX_PACKET_LENGTH;
connection_context.mysql.max_packet_size = handshake_response.max_packet_size;
if (!connection_context.mysql.max_packet_size)
connection_context.mysql.max_packet_size = MAX_PACKET_LENGTH;
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
<< "\nmax_packet_size: "
@ -110,9 +106,15 @@ void MySQLHandler::run()
while (true)
{
packet_sender->resetSequenceId();
String payload = packet_sender->receivePacketPayload();
int command = payload[0];
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
PacketPayloadReadBuffer payload = packet_sender->getPayload();
char command = 0;
payload.readStrict(command);
// For commands which are executed without MemoryTracker.
LimitReadBuffer limited_payload(payload, 10000, true, "too long MySQL packet.");
LOG_DEBUG(log, "Received command: " << static_cast<int>(static_cast<unsigned char>(command)) << ". Connection id: " << connection_id << ".");
try
{
switch (command)
@ -120,13 +122,13 @@ void MySQLHandler::run()
case COM_QUIT:
return;
case COM_INIT_DB:
comInitDB(payload);
comInitDB(limited_payload);
break;
case COM_QUERY:
comQuery(payload);
break;
case COM_FIELD_LIST:
comFieldList(payload);
comFieldList(limited_payload);
break;
case COM_PING:
comPing();
@ -147,7 +149,7 @@ void MySQLHandler::run()
}
}
}
catch (Poco::Exception & exc)
catch (const Poco::Exception & exc)
{
log->log(exc);
}
@ -157,9 +159,8 @@ void MySQLHandler::run()
* Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
* If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
*/
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
void MySQLHandler::finishHandshake(MySQLProtocol::HandshakeResponse & packet)
{
HandshakeResponse packet;
size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
/// Buffer for SSLRequest or part of HandshakeResponse.
@ -187,16 +188,18 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
{
read_bytes(packet_size); /// Reading rest SSLRequest.
SSLRequest ssl_request;
ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE));
connection_context.client_capabilities = ssl_request.capability_flags;
connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
ReadBufferFromMemory payload(buf, pos);
payload.ignore(PACKET_HEADER_SIZE);
ssl_request.readPayload(payload);
connection_context.mysql.client_capabilities = ssl_request.capability_flags;
connection_context.mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
secure_connection = true;
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.sequence_id = 2;
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
packet_sender->max_packet_size = connection_context.max_packet_size;
connection_context.mysql.sequence_id = 2;
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.mysql.sequence_id);
packet_sender->max_packet_size = connection_context.mysql.max_packet_size;
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
}
else
@ -206,10 +209,11 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
WriteBufferFromOwnString buf_for_handshake_response;
buf_for_handshake_response.write(buf, pos);
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE));
ReadBufferFromString payload(buf_for_handshake_response.str());
payload.ignore(PACKET_HEADER_SIZE);
packet.readPayload(payload);
packet_sender->sequence_id++;
}
return packet;
}
String MySQLHandler::generateScramble()
@ -230,6 +234,10 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
AuthSwitchResponse response;
if (handshake_response.auth_plugin_name != Authentication::SHA256)
{
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
*/
packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true);
if (in->eof())
throw Exception(
@ -315,7 +323,7 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
connection_context.setUser(handshake_response.username, password, socket().address(), "");
if (!handshake_response.database.empty()) connection_context.setCurrentDatabase(handshake_response.database);
connection_context.setCurrentQueryId("");
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
LOG_INFO(log, "Authentication for user " << handshake_response.username << " succeeded.");
}
catch (const Exception & exc)
{
@ -325,15 +333,16 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
}
}
void MySQLHandler::comInitDB(const String & payload)
void MySQLHandler::comInitDB(ReadBuffer & payload)
{
String database = payload.substr(1);
String database;
readStringUntilEOF(database, payload);
LOG_DEBUG(log, "Setting current database to " << database);
connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
}
void MySQLHandler::comFieldList(const String & payload)
void MySQLHandler::comFieldList(ReadBuffer & payload)
{
ComFieldList packet;
packet.readPayload(payload);
@ -354,22 +363,26 @@ void MySQLHandler::comPing()
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comQuery(const String & payload)
void MySQLHandler::comQuery(ReadBuffer & payload)
{
bool with_output = false;
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
with_output = true;
};
String query = payload.substr(1);
const String query("select ''");
ReadBufferFromString empty_select(query);
bool should_replace = false;
// Translate query from MySQL to ClickHouse.
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
query = "select ''";
if (std::string(payload.position(), payload.buffer().end()) == "select @@version_comment limit 1") // MariaDB client starts session with that query
{
should_replace = true;
}
executeQuery(should_replace ? empty_select : payload, *out, true, connection_context, set_content_type, nullptr);
ReadBufferFromString buf(query);
executeQuery(buf, *out, true, connection_context, set_content_type, nullptr);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
}

View File

@ -20,15 +20,15 @@ public:
private:
/// Enables SSL, if client requested.
MySQLProtocol::HandshakeResponse finishHandshake();
void finishHandshake(MySQLProtocol::HandshakeResponse &);
void comQuery(const String & payload);
void comQuery(ReadBuffer & payload);
void comFieldList(const String & payload);
void comFieldList(ReadBuffer & payload);
void comPing();
void comInitDB(const String & payload);
void comInitDB(ReadBuffer & payload);
static String generateScramble();
@ -48,11 +48,11 @@ private:
RSA & public_key;
RSA & private_key;
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out;
bool secure_connection = false;
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
};
}

View File

@ -7,9 +7,7 @@
#include "MySQLProtocol.h"
namespace DB
{
namespace MySQLProtocol
namespace DB::MySQLProtocol
{
void PacketSender::resetSequenceId()
@ -17,7 +15,7 @@ void PacketSender::resetSequenceId()
sequence_id = 0;
}
String PacketSender::packetToText(String payload)
String PacketSender::packetToText(const String & payload)
{
String result;
for (auto c : payload)
@ -28,11 +26,11 @@ String PacketSender::packetToText(String payload)
return result;
}
uint64_t readLengthEncodedNumber(std::istringstream & ss)
uint64_t readLengthEncodedNumber(ReadBuffer & ss)
{
char c{};
uint64_t buf = 0;
ss.get(c);
ss.readStrict(c);
auto cc = static_cast<uint8_t>(c);
if (cc < 0xfc)
{
@ -40,55 +38,65 @@ uint64_t readLengthEncodedNumber(std::istringstream & ss)
}
else if (cc < 0xfd)
{
ss.read(reinterpret_cast<char *>(&buf), 2);
ss.readStrict(reinterpret_cast<char *>(&buf), 2);
}
else if (cc < 0xfe)
{
ss.read(reinterpret_cast<char *>(&buf), 3);
ss.readStrict(reinterpret_cast<char *>(&buf), 3);
}
else
{
ss.read(reinterpret_cast<char *>(&buf), 8);
ss.readStrict(reinterpret_cast<char *>(&buf), 8);
}
return buf;
}
std::string writeLengthEncodedNumber(uint64_t x)
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer)
{
std::string result;
if (x < 251)
{
result.append(1, static_cast<char>(x));
buffer.write(static_cast<char>(x));
}
else if (x < (1 << 16))
{
result.append(1, 0xfc);
result.append(reinterpret_cast<char *>(&x), 2);
buffer.write(0xfc);
buffer.write(reinterpret_cast<char *>(&x), 2);
}
else if (x < (1 << 24))
{
result.append(1, 0xfd);
result.append(reinterpret_cast<char *>(&x), 3);
buffer.write(0xfd);
buffer.write(reinterpret_cast<char *>(&x), 3);
}
else
{
result.append(1, 0xfe);
result.append(reinterpret_cast<char *>(&x), 8);
buffer.write(0xfe);
buffer.write(reinterpret_cast<char *>(&x), 8);
}
return result;
}
void writeLengthEncodedString(std::string & payload, const std::string & s)
size_t getLengthEncodedNumberSize(uint64_t x)
{
payload.append(writeLengthEncodedNumber(s.length()));
payload.append(s);
if (x < 251)
{
return 1;
}
else if (x < (1 << 16))
{
return 3;
}
else if (x < (1 << 24))
{
return 4;
}
else
{
return 9;
}
}
void writeNulTerminatedString(std::string & payload, const std::string & s)
size_t getLengthEncodedStringSize(const String & s)
{
payload.append(s);
payload.append(1, 0);
return getLengthEncodedNumberSize(s.size()) + s.size();
}
}
}

View File

@ -1,18 +1,25 @@
#pragma once
#include <Common/MemoryTracker.h>
#include <Common/PODArray.h>
#include <Core/Types.h>
#include <IO/copyData.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/RandomStream.h>
#include <random>
#include <sstream>
#include <IO/LimitReadBuffer.h>
/// Implementation of MySQL wire protocol
/// Implementation of MySQL wire protocol.
/// Works only on little-endian architecture.
namespace DB
{
@ -124,23 +131,204 @@ public:
};
class WritePacket
/** Reading packets.
* Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
*/
class PacketPayloadReadBuffer : public ReadBuffer
{
public:
virtual String getPayload() const = 0;
PacketPayloadReadBuffer(ReadBuffer & in, uint8_t & sequence_id)
: ReadBuffer(in.position(), 0) // not in.buffer().begin(), because working buffer may include previous packet
, in(in)
, sequence_id(sequence_id)
{
}
virtual ~WritePacket() = default;
private:
ReadBuffer & in;
uint8_t & sequence_id;
const size_t max_packet_size = MAX_PACKET_LENGTH;
// Size of packet which is being read now.
size_t payload_length = 0;
// Offset in packet payload.
size_t offset = 0;
protected:
bool nextImpl() override
{
if (payload_length == 0 || (payload_length == max_packet_size && offset == payload_length))
{
working_buffer.resize(0);
offset = 0;
payload_length = 0;
in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
else if (payload_length == 0)
{
return false;
}
size_t packet_sequence_id = 0;
in.read(reinterpret_cast<char &>(packet_sequence_id));
if (packet_sequence_id != sequence_id)
{
std::ostringstream tmp;
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << static_cast<unsigned int>(sequence_id) << '.';
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
sequence_id++;
}
else if (offset == payload_length)
{
return false;
}
in.nextIfAtEnd();
working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end());
size_t count = std::min(in.available(), payload_length - offset);
working_buffer.resize(count);
in.ignore(count);
offset += count;
return true;
}
};
class ReadPacket
class ClientPacket
{
public:
ReadPacket() = default;
ReadPacket(const ReadPacket &) = default;
virtual void readPayload(String payload) = 0;
ClientPacket() = default;
ClientPacket(ClientPacket &&) = default;
virtual ~ReadPacket() = default;
virtual void read(ReadBuffer & in, uint8_t & sequence_id)
{
PacketPayloadReadBuffer payload(in, sequence_id);
readPayload(payload);
if (!payload.eof())
{
std::stringstream tmp;
tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
}
virtual void readPayload(ReadBuffer & buf) = 0;
virtual ~ClientPacket() = default;
};
class LimitedClientPacket : public ClientPacket
{
public:
void read(ReadBuffer & in, uint8_t & sequence_id) override
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
ClientPacket::read(limited, sequence_id);
}
};
/** Writing packets.
* https://dev.mysql.com/doc/internals/en/mysql-packet.html
*/
class PacketPayloadWriteBuffer : public WriteBuffer
{
public:
PacketPayloadWriteBuffer(WriteBuffer & out, size_t payload_length, uint8_t & sequence_id)
: WriteBuffer(out.position(), 0)
, out(out)
, sequence_id(sequence_id)
, total_left(payload_length)
{
startPacket();
}
void checkPayloadSize()
{
if (bytes_written + offset() < payload_length)
{
std::stringstream ss;
ss << "Incomplete payload. Written " << bytes << " bytes, expected " << payload_length << " bytes.";
throw Exception(ss.str(), 0);
}
}
~PacketPayloadWriteBuffer() override { next(); }
private:
WriteBuffer & out;
uint8_t & sequence_id;
size_t total_left = 0;
size_t payload_length = 0;
size_t bytes_written = 0;
void startPacket()
{
payload_length = std::min(total_left, MAX_PACKET_LENGTH);
bytes_written = 0;
total_left -= payload_length;
out.write(reinterpret_cast<char *>(&payload_length), 3);
out.write(sequence_id++);
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
pos = working_buffer.begin();
}
protected:
void nextImpl() override
{
int written = pos - working_buffer.begin();
out.position() += written;
bytes_written += written;
if (bytes_written < payload_length)
{
out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
}
else if (total_left > 0 || payload_length == MAX_PACKET_LENGTH)
{
// Starting new packet, since packets of size greater than MAX_PACKET_LENGTH should be split.
startPacket();
}
else
{
// Finished writing packet. Buffer is set to empty to prevent rewriting (pos will be set to the beginning of a working buffer in next()).
// Further attempts to write will stall in the infinite loop.
working_buffer = WriteBuffer::Buffer(out.position(), out.position());
}
}
};
class WritePacket
{
public:
virtual void writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{
PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);
buf.checkPayloadSize();
}
virtual ~WritePacket() = default;
protected:
virtual size_t getPayloadSize() const = 0;
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
};
@ -150,13 +338,13 @@ public:
class PacketSender
{
public:
size_t & sequence_id;
uint8_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
size_t max_packet_size = MAX_PACKET_LENGTH;
/// For reading and writing.
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id)
PacketSender(ReadBuffer & in, WriteBuffer & out, uint8_t & sequence_id)
: sequence_id(sequence_id)
, in(&in)
, out(&out)
@ -164,91 +352,59 @@ public:
}
/// For writing.
PacketSender(WriteBuffer & out, size_t & sequence_id)
PacketSender(WriteBuffer & out, uint8_t & sequence_id)
: sequence_id(sequence_id)
, in(nullptr)
, out(&out)
{
}
String receivePacketPayload()
void receivePacket(ClientPacket & packet)
{
WriteBufferFromOwnString buf;
size_t payload_length = 0;
size_t packet_sequence_id = 0;
// packets which are larger than or equal to 16MB are splitted
do
{
in->readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
in->readStrict(reinterpret_cast<char *>(&packet_sequence_id), 1);
if (packet_sequence_id != sequence_id)
{
std::ostringstream tmp;
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << sequence_id << '.';
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
sequence_id++;
copyData(*in, static_cast<WriteBuffer &>(buf), payload_length);
} while (payload_length == max_packet_size);
return std::move(buf.str());
}
void receivePacket(ReadPacket & packet)
{
packet.readPayload(receivePacketPayload());
packet.read(*in, sequence_id);
}
template<class T>
void sendPacket(const T & packet, bool flush = false)
{
static_assert(std::is_base_of<WritePacket, T>());
String payload = packet.getPayload();
size_t pos = 0;
do
{
size_t payload_length = std::min(payload.length() - pos, max_packet_size);
out->write(reinterpret_cast<const char *>(&payload_length), 3);
out->write(reinterpret_cast<const char *>(&sequence_id), 1);
out->write(payload.data() + pos, payload_length);
pos += payload_length;
sequence_id++;
} while (pos < payload.length());
packet.writePayload(*out, sequence_id);
if (flush)
out->next();
}
PacketPayloadReadBuffer getPayload()
{
return PacketPayloadReadBuffer(*in, sequence_id);
}
/// Sets sequence-id to 0. Must be called before each command phase.
void resetSequenceId();
private:
/// Converts packet to text. Is used for debug output.
static String packetToText(String payload);
static String packetToText(const String & payload);
};
uint64_t readLengthEncodedNumber(std::istringstream & ss);
uint64_t readLengthEncodedNumber(ReadBuffer & ss);
String writeLengthEncodedNumber(uint64_t x);
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer);
void writeLengthEncodedString(String & payload, const String & s);
inline void writeLengthEncodedString(const String & s, WriteBuffer & buffer)
{
writeLengthEncodedNumber(s.size(), buffer);
buffer.write(s.data(), s.size());
}
void writeNulTerminatedString(String & payload, const String & s);
inline void writeNulTerminatedString(const String & s, WriteBuffer & buffer)
{
buffer.write(s.data(), s.size());
buffer.write(0);
}
size_t getLengthEncodedNumberSize(uint64_t x);
size_t getLengthEncodedStringSize(const String & s);
class Handshake : public WritePacket
@ -259,66 +415,69 @@ class Handshake : public WritePacket
uint32_t capability_flags;
uint8_t character_set;
uint32_t status_flags;
String auth_plugin_name;
String auth_plugin_data;
public:
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_name, String auth_plugin_data)
: protocol_version(0xa)
, server_version(std::move(server_version))
, connection_id(connection_id)
, capability_flags(capability_flags)
, character_set(CharacterSet::utf8_general_ci)
, status_flags(0)
, auth_plugin_data(auth_plugin_data)
, auth_plugin_name(std::move(auth_plugin_name))
, auth_plugin_data(std::move(auth_plugin_data))
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, protocol_version);
writeNulTerminatedString(result, server_version);
result.append(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(result, auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH));
result.append(reinterpret_cast<const char *>(&capability_flags), 2);
result.append(reinterpret_cast<const char *>(&character_set), 1);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
result.append(1, auth_plugin_data.size());
result.append(10, 0x0);
result.append(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH));
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(static_cast<char>(protocol_version));
writeNulTerminatedString(server_version, buffer);
buffer.write(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
buffer.write(reinterpret_cast<const char *>(&capability_flags), 2);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
buffer.write(static_cast<char>(auth_plugin_data.size()));
writeChar(0x0, 10, buffer);
writeString(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
// A workaround for PHP mysqlnd extension bug which occurs when sha256_password is used as a default authentication plugin.
// Instead of using client response for mysql_native_password plugin, the server will always generate authentication method mismatch
// and switch to sha256_password to simulate that mysql_native_password is used as a default plugin.
result.append(Authentication::Native);
result.append(1, 0x0);
return result;
writeString(auth_plugin_name, buffer);
writeChar(0x0, 1, buffer);
}
};
class SSLRequest : public ReadPacket
class SSLRequest : public ClientPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
void readPayload(String s) override
void readPayload(ReadBuffer & buf) override
{
std::istringstream ss(s);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
}
};
class HandshakeResponse : public ReadPacket
class HandshakeResponse : public LimitedClientPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
uint32_t capability_flags = 0;
uint32_t max_packet_size = 0;
uint8_t character_set = 0;
String username;
String auth_response;
String database;
@ -326,45 +485,41 @@ public:
HandshakeResponse() = default;
HandshakeResponse(const HandshakeResponse &) = default;
void readPayload(String s) override
void readPayload(ReadBuffer & payload) override
{
std::istringstream ss(s);
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
payload.ignore(23);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
ss.ignore(23);
std::getline(ss, username, static_cast<char>(0x0));
readNullTerminated(username, payload);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
auto len = readLengthEncodedNumber(ss);
auto len = readLengthEncodedNumber(payload);
auth_response.resize(len);
ss.read(auth_response.data(), static_cast<std::streamsize>(len));
payload.readStrict(auth_response.data(), len);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
uint8_t len;
ss.read(reinterpret_cast<char *>(&len), 1);
auth_response.resize(len);
ss.read(auth_response.data(), len);
char len;
payload.readStrict(len);
auth_response.resize(static_cast<unsigned int>(len));
payload.readStrict(auth_response.data(), len);
}
else
{
std::getline(ss, auth_response, static_cast<char>(0x0));
readNullTerminated(auth_response, payload);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
std::getline(ss, database, static_cast<char>(0x0));
readNullTerminated(database, payload);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
std::getline(ss, auth_plugin_name, static_cast<char>(0x0));
readNullTerminated(auth_plugin_name, payload);
}
}
};
@ -379,24 +534,28 @@ public:
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, 0xfe);
writeNulTerminatedString(result, plugin_name);
result.append(auth_plugin_data);
return result;
return 2 + plugin_name.size() + auth_plugin_data.size();
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0xfe);
writeNulTerminatedString(plugin_name, buffer);
writeString(auth_plugin_data, buffer);
}
};
class AuthSwitchResponse : public ReadPacket
class AuthSwitchResponse : public LimitedClientPacket
{
public:
String value;
void readPayload(String s) override
void readPayload(ReadBuffer & payload) override
{
value = std::move(s);
readStringUntilEOF(value, payload);
}
};
@ -404,33 +563,21 @@ class AuthMoreData : public WritePacket
{
String data;
public:
AuthMoreData(String data): data(std::move(data)) {}
explicit AuthMoreData(String data): data(std::move(data)) {}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, 0x01);
result.append(data);
return result;
return 1 + data.size();
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0x01);
writeString(data, buffer);
}
};
/// Packet with a single null-terminated string. Is used for clear text authentication.
class NullTerminatedString : public ReadPacket
{
public:
String value;
void readPayload(String s) override
{
if (s.length() == 0 || s.back() != 0)
{
throw ProtocolError("String is not null terminated.", ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
value = s;
value.pop_back();
}
};
class OK_Packet : public WritePacket
{
@ -455,43 +602,65 @@ public:
, warnings(warnings)
, status_flags(status_flags)
, session_state_changes(std::move(session_state_changes))
, info(info)
, info(std::move(info))
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, header);
result.append(writeLengthEncodedNumber(affected_rows));
result.append(writeLengthEncodedNumber(0)); /// last insert-id
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
if (capabilities & CLIENT_PROTOCOL_41)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append(reinterpret_cast<const char *>(&warnings), 2);
result += 4;
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result += 2;
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result.append(writeLengthEncodedNumber(info.length()));
result.append(info);
result += getLengthEncodedStringSize(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
result.append(writeLengthEncodedNumber(session_state_changes.length()));
result.append(session_state_changes);
}
result += getLengthEncodedStringSize(session_state_changes);
}
else
{
result.append(info);
result += info.size();
}
return result;
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(0, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
writeLengthEncodedString(info, buffer);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
writeLengthEncodedString(session_state_changes, buffer);
}
else
{
writeString(info, buffer);
}
}
};
class EOF_Packet : public WritePacket
@ -502,13 +671,17 @@ public:
EOF_Packet(int warnings, int status_flags) : warnings(warnings), status_flags(status_flags)
{}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, 0xfe); // EOF header
result.append(reinterpret_cast<const char *>(&warnings), 2);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
return result;
return 5;
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0xfe); // EOF header
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
};
@ -523,15 +696,19 @@ public:
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
result.append(1, 0xff);
result.append(reinterpret_cast<const char *>(&error_code), 2);
result.append("#", 1);
result.append(sql_state.data(), sql_state.length());
result.append(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
return result;
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
buffer.write(0xff);
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
buffer.write('#');
buffer.write(sql_state.data(), sql_state.length());
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
}
};
@ -579,37 +756,41 @@ public:
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
writeLengthEncodedString(result, "def"); /// always "def"
writeLengthEncodedString(result, ""); /// schema
writeLengthEncodedString(result, ""); /// table
writeLengthEncodedString(result, ""); /// org_table
writeLengthEncodedString(result, name);
writeLengthEncodedString(result, ""); /// org_name
result.append(writeLengthEncodedNumber(next_length));
result.append(reinterpret_cast<const char *>(&character_set), 2);
result.append(reinterpret_cast<const char *>(&column_length), 4);
result.append(reinterpret_cast<const char *>(&column_type), 1);
result.append(reinterpret_cast<const char *>(&flags), 2);
result.append(reinterpret_cast<const char *>(&decimals), 2);
result.append(2, 0x0);
return result;
return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \
getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length);
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
writeLengthEncodedString(schema, buffer);
writeLengthEncodedString(table, buffer);
writeLengthEncodedString(org_table, buffer);
writeLengthEncodedString(name, buffer);
writeLengthEncodedString(org_name, buffer);
writeLengthEncodedNumber(next_length, buffer);
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&decimals), 2);
writeChar(0x0, 2, buffer);
}
};
class ComFieldList : public ReadPacket
class ComFieldList : public LimitedClientPacket
{
public:
String table, field_wildcard;
void readPayload(String payload)
void readPayload(ReadBuffer & payload) override
{
std::istringstream ss(payload);
ss.ignore(1); // command byte
std::getline(ss, table, static_cast<char>(0x0));
field_wildcard = payload.substr(table.length() + 2); // rest of the packet
// Command byte has been already read from payload.
readNullTerminated(table, payload);
readStringUntilEOF(field_wildcard, payload);
}
};
@ -617,37 +798,45 @@ class LengthEncodedNumber : public WritePacket
{
uint64_t value;
public:
LengthEncodedNumber(uint64_t value): value(value)
explicit LengthEncodedNumber(uint64_t value): value(value)
{
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
return writeLengthEncodedNumber(value);
return getLengthEncodedNumberSize(value);
}
void writePayloadImpl(WriteBuffer & buffer) const override
{
writeLengthEncodedNumber(value, buffer);
}
};
class ResultsetRow : public WritePacket
{
std::vector<String> columns;
size_t payload_size = 0;
public:
ResultsetRow()
{
}
ResultsetRow() = default;
void appendColumn(String value)
void appendColumn(String && value)
{
payload_size += getLengthEncodedStringSize(value);
columns.emplace_back(std::move(value));
}
String getPayload() const override
protected:
size_t getPayloadSize() const override
{
String result;
for (const String & column : columns)
{
writeLengthEncodedString(result, column);
return payload_size;
}
return result;
void writePayloadImpl(WriteBuffer & buffer) const override
{
for (const String & column : columns)
writeLengthEncodedString(column, buffer);
}
};

View File

@ -12,9 +12,9 @@ using namespace MySQLProtocol;
MySQLWireBlockOutputStream::MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
: header(header)
, context(context)
, packet_sender(std::make_shared<PacketSender>(buf, context.sequence_id))
, packet_sender(std::make_shared<PacketSender>(buf, context.mysql.sequence_id))
{
packet_sender->max_packet_size = context.max_packet_size;
packet_sender->max_packet_size = context.mysql.max_packet_size;
}
void MySQLWireBlockOutputStream::writePrefix()
@ -30,7 +30,7 @@ void MySQLWireBlockOutputStream::writePrefix()
packet_sender->sendPacket(column_definition);
}
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
}
@ -45,12 +45,9 @@ void MySQLWireBlockOutputStream::write(const Block & block)
ResultsetRow row_packet;
for (const ColumnWithTypeAndName & column : block)
{
String column_value;
WriteBufferFromString ostr(column_value);
WriteBufferFromOwnString ostr;
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
ostr.finish();
row_packet.appendColumn(std::move(column_value));
row_packet.appendColumn(std::move(ostr.str()));
}
packet_sender->sendPacket(row_packet);
}
@ -70,10 +67,10 @@ void MySQLWireBlockOutputStream::writeSuffix()
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
packet_sender->sendPacket(OK_Packet(0x0, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
if (context.mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
}

View File

@ -173,6 +173,12 @@ inline void appendToStringOrVector(PaddedPODArray<UInt8> & s, ReadBuffer & rb, c
s.insert(rb.position(), end);
}
template <>
inline void appendToStringOrVector(PODArray<char> & s, ReadBuffer & rb, const char * end)
{
s.insert(rb.position(), end);
}
template <typename Vector>
void readStringInto(Vector & s, ReadBuffer & buf)
{
@ -188,6 +194,25 @@ void readStringInto(Vector & s, ReadBuffer & buf)
}
}
template <typename Vector>
void readNullTerminated(Vector & s, ReadBuffer & buf)
{
while (!buf.eof())
{
char * next_pos = find_first_symbols<'\0'>(buf.position(), buf.buffer().end());
appendToStringOrVector(s, buf, next_pos);
buf.position() = next_pos;
if (buf.hasPendingData())
break;
}
buf.ignore();
}
template void readNullTerminated<PODArray<char>>(PODArray<char> & s, ReadBuffer & buf);
template void readNullTerminated<String>(String & s, ReadBuffer & buf);
void readString(String & s, ReadBuffer & buf)
{
s.clear();

View File

@ -425,6 +425,9 @@ void readCSVString(String & s, ReadBuffer & buf, const FormatSettings::CSV & set
template <typename Vector>
void readStringInto(Vector & s, ReadBuffer & buf);
template <typename Vector>
void readNullTerminated(Vector & s, ReadBuffer & buf);
template <typename Vector>
void readEscapedStringInto(Vector & s, ReadBuffer & buf);

View File

@ -47,6 +47,18 @@ inline void writeChar(char x, WriteBuffer & buf)
++buf.position();
}
/// Write the same character n times.
inline void writeChar(char c, size_t n, WriteBuffer & buf)
{
while (n)
{
buf.nextIfAtEnd();
size_t count = std::min(n, buf.available());
memset(buf.position(), c, count);
n -= count;
buf.position() += count;
}
}
/// Write POD-type in native format. It's recommended to use only with packed (dense) data types.
template <typename T>

View File

@ -499,10 +499,14 @@ public:
IHostContextPtr & getHostContext();
const IHostContextPtr & getHostContext() const;
/// MySQL wire protocol state.
size_t sequence_id = 0;
struct MySQLWireContext
{
uint8_t sequence_id = 0;
uint32_t client_capabilities = 0;
size_t max_packet_size = 0;
};
MySQLWireContext mysql;
private:
/** Check if the current client has access to the specified database.
* If access is denied, throw an exception.

View File

@ -17,7 +17,7 @@ using namespace MySQLProtocol;
MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header, const Context & context, const FormatSettings & settings)
: IOutputFormat(header, out_)
, context(context)
, packet_sender(std::make_shared<PacketSender>(out, const_cast<size_t &>(context.sequence_id))) /// TODO: fix it
, packet_sender(std::make_shared<PacketSender>(out, const_cast<uint8_t &>(context.mysql.sequence_id))) /// TODO: fix it
, format_settings(settings)
{
}
@ -43,7 +43,7 @@ void MySQLOutputFormat::consume(Chunk chunk)
packet_sender->sendPacket(column_definition);
}
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
if (!(context.mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
}
@ -85,10 +85,10 @@ void MySQLOutputFormat::finalize()
auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
packet_sender->sendPacket(OK_Packet(0x0, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
if (context.mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context.mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
}

View File

@ -15,7 +15,7 @@ SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
config_dir = os.path.join(SCRIPT_DIR, './configs')
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance('node', config_dir=config_dir)
node = cluster.add_instance('node', config_dir=config_dir, env_variables={'UBSAN_OPTIONS': 'print_stacktrace=1'})
server_port = 9001