added MySQL wire protocol presentational format

This commit is contained in:
Yuriy 2019-05-16 06:34:04 +03:00
parent 8a3e75d92f
commit ff4937859e
9 changed files with 228 additions and 106 deletions

View File

@ -1,4 +1,5 @@
#include <DataStreams/copyData.h> #include <DataStreams/copyData.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadBufferFromPocoSocket.h> #include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBufferFromPocoSocket.h> #include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/executeQuery.h> #include <Interpreters/executeQuery.h>
@ -34,8 +35,11 @@ uint32_t MySQLHandler::last_connection_id = 0;
void MySQLHandler::run() void MySQLHandler::run()
{ {
connection_context = server.context(); connection_context = server.context();
connection_context.setDefaultFormat("MySQL");
packet_sender = PacketSender(socket()); in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id, "MySQLHandler");
try try
{ {
@ -47,11 +51,12 @@ void MySQLHandler::run()
*/ */
Handshake handshake(connection_id, VERSION_STRING, scramble + '\0'); Handshake handshake(connection_id, VERSION_STRING, scramble + '\0');
packet_sender.sendPacket<Handshake>(handshake, true); packet_sender->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake"); LOG_TRACE(log, "Sent handshake");
HandshakeResponse handshake_response = finishHandshake(); HandshakeResponse handshake_response = finishHandshake();
connection_context.client_capabilities = handshake_response.capability_flags;
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
<< "\nmax_packet_size: " << "\nmax_packet_size: "
@ -81,12 +86,12 @@ void MySQLHandler::run()
authenticate(handshake_response, scramble); authenticate(handshake_response, scramble);
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0, 0, ""); OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0, 0, "");
packet_sender.sendPacket(ok_packet, true); packet_sender->sendPacket(ok_packet, true);
while (true) while (true)
{ {
packet_sender.resetSequenceId(); packet_sender->resetSequenceId();
String payload = packet_sender.receivePacketPayload(); String payload = packet_sender->receivePacketPayload();
int command = payload[0]; int command = payload[0];
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << "."); LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
try try
@ -119,7 +124,7 @@ void MySQLHandler::run()
catch (const Exception & exc) catch (const Exception & exc)
{ {
log->log(exc); log->log(exc);
packet_sender.sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true); packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
} }
} }
} }
@ -129,35 +134,42 @@ void MySQLHandler::run()
} }
} }
/** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest.
* Using ReadBufferFromPocoSocket would be less convenient here, because we would have to resize internal buffer many times to prevent reading SSL handshake.
* If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
*/
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake() MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
{ {
/** Size of SSLRequest packet is 32 bytes.
* If we read more, then we will read part of SSL handshake, and it will be impossible to start SSL connection using Poco.
*/
HandshakeResponse packet; HandshakeResponse packet;
char b[100]; /// Client can send either SSLRequest or HandshakeResponse. char b[100]; /// Buffer for SSLRequest or HandshakeResponse.
size_t pos = 0; size_t pos = 0;
while (pos < 3) while (pos < 3)
{ {
int ret = socket().receiveBytes(b + pos, 36 - pos); int ret = socket().receiveBytes(b + pos, 36 - pos);
if (ret == 0) if (ret == 0)
{ {
throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 36.", ErrorCodes::CANNOT_READ_ALL_DATA); throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 3.", ErrorCodes::CANNOT_READ_ALL_DATA);
} }
pos += ret; pos += ret;
} }
size_t packet_size = *reinterpret_cast<uint32_t *>(b) & 0xFFFFFF; size_t packet_size = *reinterpret_cast<uint32_t *>(b) & 0xFFFFFFu;
LOG_TRACE(log, "packet size: " << packet_size); LOG_TRACE(log, "packet size: " << packet_size);
/// Check if it is SSLRequest.
if (packet_size == 32) if (packet_size == 32)
{ {
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
packet_sender = PacketSender(*ss, 2);
secure_connection = true; secure_connection = true;
packet_sender.receivePacket(packet); ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.sequence_id = 2;
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id, "MySQLHandler");
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
} }
else else
{ {
/// Reading rest of HandshakeResponse.
while (pos < 4 + packet_size) while (pos < 4 + packet_size)
{ {
int ret = socket().receiveBytes(b + pos, 4 + packet_size - pos); int ret = socket().receiveBytes(b + pos, 4 + packet_size - pos);
@ -168,7 +180,7 @@ MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
pos += ret; pos += ret;
} }
packet.readPayload(std::string(b + 4, packet_size)); packet.readPayload(std::string(b + 4, packet_size));
packet_sender.sequence_id++; packet_sender->sequence_id++;
} }
return packet; return packet;
} }
@ -191,8 +203,8 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
AuthSwitchResponse response; AuthSwitchResponse response;
if (handshake_response.auth_plugin_name != Authentication::CachingSHA2) if (handshake_response.auth_plugin_name != Authentication::CachingSHA2)
{ {
packet_sender.sendPacket(AuthSwitchRequest(Authentication::CachingSHA2, scramble + '\0'), true); packet_sender->sendPacket(AuthSwitchRequest(Authentication::CachingSHA2, scramble + '\0'), true);
packet_sender.receivePacket(response); packet_sender->receivePacket(response);
auth_response = response.value; auth_response = response.value;
LOG_TRACE(log, "Authentication method mismatch."); LOG_TRACE(log, "Authentication method mismatch.");
} }
@ -204,9 +216,9 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
/// Caching SHA2 plugin is used instead of SHA256 only because it can work without OpenSLL. /// Caching SHA2 plugin is used instead of SHA256 only because it can work without OpenSLL.
/// Fast auth path is not used, because otherwise it would be possible to authenticate using data from users.xml. /// Fast auth path is not used, because otherwise it would be possible to authenticate using data from users.xml.
packet_sender.sendPacket(AuthMoreData("\4"), true); packet_sender->sendPacket(AuthMoreData("\4"), true);
packet_sender.receivePacket(response); packet_sender->receivePacket(response);
auth_response = response.value; auth_response = response.value;
auto getOpenSSLError = []() -> String auto getOpenSSLError = []() -> String
@ -238,8 +250,8 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
LOG_TRACE(log, "Key: " << pem); LOG_TRACE(log, "Key: " << pem);
AuthMoreData data(pem); AuthMoreData data(pem);
packet_sender.sendPacket(data, true); packet_sender->sendPacket(data, true);
packet_sender.receivePacket(response); packet_sender->receivePacket(response);
auth_response = response.value; auth_response = response.value;
} }
else else
@ -301,20 +313,20 @@ void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, co
catch (const Exception & exc) catch (const Exception & exc)
{ {
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed."); LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed.");
packet_sender.sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true); packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
throw; throw;
} }
} }
void MySQLHandler::comInitDB(String payload) void MySQLHandler::comInitDB(const String & payload)
{ {
String database = payload.substr(1); String database = payload.substr(1);
LOG_DEBUG(log, "Setting current database to " << database); LOG_DEBUG(log, "Setting current database to " << database);
connection_context.setCurrentDatabase(database); connection_context.setCurrentDatabase(database);
packet_sender.sendPacket(OK_Packet(0, capabilities, 0, 0, 0, 1, ""), true); packet_sender->sendPacket(OK_Packet(0, capabilities, 0, 0, 0, 1, ""), true);
} }
void MySQLHandler::comFieldList(String payload) void MySQLHandler::comFieldList(const String & payload)
{ {
ComFieldList packet; ComFieldList packet;
packet.readPayload(payload); packet.readPayload(payload);
@ -325,78 +337,26 @@ void MySQLHandler::comFieldList(String payload)
ColumnDefinition column_definition( ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0 database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
); );
packet_sender.sendPacket(column_definition); packet_sender->sendPacket(column_definition);
} }
packet_sender.sendPacket(OK_Packet(0xfe, capabilities, 0, 0, 0, 0, ""), true); packet_sender->sendPacket(OK_Packet(0xfe, capabilities, 0, 0, 0, 0, ""), true);
} }
void MySQLHandler::comPing() void MySQLHandler::comPing()
{ {
packet_sender.sendPacket(OK_Packet(0x0, capabilities, 0, 0, 0, 0, ""), true); packet_sender->sendPacket(OK_Packet(0x0, capabilities, 0, 0, 0, 0, ""), true);
} }
void MySQLHandler::comQuery(String payload) void MySQLHandler::comQuery(const String & payload)
{ {
BlockIO res = executeQuery(payload.substr(1), connection_context); bool with_output = false;
FormatSettings format_settings; std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
if (res.in) with_output = true;
{ };
LOG_TRACE(log, "Executing query with output."); ReadBufferFromMemory query(payload.data() + 1, payload.size() - 1);
executeQuery(query, *out, true, connection_context, set_content_type, nullptr);
Block header = res.in->getHeader(); if (!with_output) {
packet_sender.sendPacket(LengthEncodedNumber(header.columns())); packet_sender->sendPacket(OK_Packet(0x00, capabilities, 0, 0, 0, 0, ""), true);
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
{
ColumnDefinition column_definition(column.name, CharacterSet::binary, std::numeric_limits<uint32_t>::max(),
ColumnType::MYSQL_TYPE_STRING, 0, 0);
packet_sender.sendPacket(column_definition);
LOG_TRACE(log, "Sent " << column.name << " column definition");
}
LOG_TRACE(log, "Sent columns definitions.");
if (!(capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender.sendPacket(EOF_Packet(0, 0));
}
while (Block block = res.in->read())
{
size_t rows = block.rows();
for (size_t i = 0; i < rows; i++)
{
ResultsetRow row_packet;
for (ColumnWithTypeAndName & column : block)
{
column.column = column.column->convertToFullColumnIfConst();
String column_value;
WriteBufferFromString ostr(column_value);
LOG_TRACE(log, "Sending value of type " << column.type->getName() << " of column " << column.column->getName());
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
ostr.finish();
row_packet.appendColumn(std::move(column_value));
}
packet_sender.sendPacket(row_packet);
}
}
LOG_TRACE(log, "Sent rows.");
}
if (capabilities & CLIENT_DEPRECATE_EOF)
{
packet_sender.sendPacket(OK_Packet(0xfe, capabilities, 0, 0, 0, 0, ""), true);
}
else
{
packet_sender.sendPacket(EOF_Packet(0, 0), true);
} }
} }

View File

@ -36,13 +36,13 @@ private:
/// Enables SSL, if client requested. /// Enables SSL, if client requested.
MySQLProtocol::HandshakeResponse finishHandshake(); MySQLProtocol::HandshakeResponse finishHandshake();
void comQuery(String payload); void comQuery(const String & payload);
void comFieldList(String payload); void comFieldList(const String & payload);
void comPing(); void comPing();
void comInitDB(String payload); void comInitDB(const String & payload);
static String generateScramble(); static String generateScramble();
@ -52,7 +52,7 @@ private:
Poco::Logger * log; Poco::Logger * log;
Context connection_context; Context connection_context;
MySQLProtocol::PacketSender packet_sender; std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
uint32_t connection_id = 0; uint32_t connection_id = 0;

View File

@ -5,11 +5,14 @@
#include <IO/copyData.h> #include <IO/copyData.h>
#include <IO/ReadBufferFromPocoSocket.h> #include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBufferFromPocoSocket.h> #include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <Core/Types.h> #include <Core/Types.h>
#include <Poco/RandomStream.h> #include <Poco/RandomStream.h>
#include <Poco/Net/StreamSocket.h> #include <Poco/Net/StreamSocket.h>
#include <random> #include <random>
#include <sstream> #include <sstream>
#include <common/logger_useful.h>
#include <Poco/Logger.h>
/// Implementation of MySQL wire protocol /// Implementation of MySQL wire protocol
@ -146,15 +149,25 @@ public:
class PacketSender class PacketSender
{ {
public: public:
size_t sequence_id = 0; size_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
PacketSender() = default; /// For reading and writing.
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id, const String logger_name)
explicit PacketSender(Poco::Net::StreamSocket & socket, size_t sequence_id=0)
: sequence_id(sequence_id) : sequence_id(sequence_id)
, in(std::make_shared<ReadBufferFromPocoSocket>(socket)) , in(&in)
, out(std::make_shared<WriteBufferFromPocoSocket>(socket)) , out(&out)
, log(&Poco::Logger::get("MySQLHandler")) , log(&Poco::Logger::get(logger_name))
{
}
/// For writing.
PacketSender(WriteBuffer & out, size_t & sequence_id, const String logger_name)
: sequence_id(sequence_id)
, in(nullptr)
, out(&out)
, log(&Poco::Logger::get(logger_name))
{ {
} }
@ -238,9 +251,6 @@ private:
/// Converts packet to text. Is used for debug output. /// Converts packet to text. Is used for debug output.
static String packetToText(String payload); static String packetToText(String payload);
std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out;
Poco::Logger * log; Poco::Logger * log;
}; };

View File

@ -0,0 +1,72 @@
#include "MySQLBlockOutputStream.h"
#include <Core/MySQLProtocol.h>
namespace DB
{
using namespace MySQLProtocol;
MySQLBlockOutputStream::MySQLBlockOutputStream(WriteBuffer & buf, const Block & header, const uint32_t capabilities, size_t & sequence_id)
: header(header)
, capabilities(capabilities)
, packet_sender(new PacketSender(buf, sequence_id, "MySQLBlockOutputStream"))
{
}
void MySQLBlockOutputStream::writePrefix()
{
if (header.columns() == 0)
return;
packet_sender->sendPacket(LengthEncodedNumber(header.columns()));
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
{
ColumnDefinition column_definition(column.name, CharacterSet::binary, std::numeric_limits<uint32_t>::max(),
ColumnType::MYSQL_TYPE_STRING, 0, 0);
packet_sender->sendPacket(column_definition);
}
if (!(capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
}
}
void MySQLBlockOutputStream::write(const Block & block)
{
size_t rows = block.rows();
for (size_t i = 0; i < rows; i++)
{
ResultsetRow row_packet;
for (const ColumnWithTypeAndName & column : block)
{
String column_value;
WriteBufferFromString ostr(column_value);
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
ostr.finish();
row_packet.appendColumn(std::move(column_value));
}
packet_sender->sendPacket(row_packet);
}
}
void MySQLBlockOutputStream::writeSuffix()
{
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, capabilities, 0, 0, 0, 0, ""), true);
else
if (capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, capabilities, 0, 0, 0, 0, ""), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
}
void MySQLBlockOutputStream::flush() {
packet_sender->out->next();
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include "IBlockOutputStream.h"
#include <Core/MySQLProtocol.h>
#include <Formats/FormatFactory.h>
#include <Formats/FormatSettings.h>
#include <Interpreters/Context.h>
namespace DB
{
/** Interface for writing rows in MySQL Client/Server Protocol format.
*/
class MySQLBlockOutputStream : public IBlockOutputStream
{
public:
MySQLBlockOutputStream(WriteBuffer & buf, const Block & header, const uint32_t capabilities, size_t & sequence_id);
Block getHeader() const { return header; }
void write(const Block & block);
void writePrefix();
void writeSuffix();
void flush();
private:
Block header;
uint32_t capabilities;
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
FormatSettings format_settings;
};
using MySQLBlockOutputStreamPtr = std::shared_ptr<MySQLBlockOutputStream>;
}

View File

@ -129,6 +129,7 @@ void registerOutputFormatXML(FormatFactory & factory);
void registerOutputFormatODBCDriver(FormatFactory & factory); void registerOutputFormatODBCDriver(FormatFactory & factory);
void registerOutputFormatODBCDriver2(FormatFactory & factory); void registerOutputFormatODBCDriver2(FormatFactory & factory);
void registerOutputFormatNull(FormatFactory & factory); void registerOutputFormatNull(FormatFactory & factory);
void registerOutputFormatMySQL(FormatFactory & factory);
/// Input only formats. /// Input only formats.
@ -167,6 +168,7 @@ FormatFactory::FormatFactory()
registerOutputFormatODBCDriver(*this); registerOutputFormatODBCDriver(*this);
registerOutputFormatODBCDriver2(*this); registerOutputFormatODBCDriver2(*this);
registerOutputFormatNull(*this); registerOutputFormatNull(*this);
registerOutputFormatMySQL(*this);
} }
} }

View File

@ -0,0 +1,19 @@
#include <DataStreams/MySQLBlockOutputStream.h>
namespace DB
{
void registerOutputFormatMySQL(FormatFactory & factory)
{
factory.registerOutputFormat("MySQL", [](
WriteBuffer & buf,
const Block & sample,
const Context & context,
const FormatSettings &)
{
return std::make_shared<MySQLBlockOutputStream>(buf, sample, context.client_capabilities, const_cast<Context &>(context).sequence_id);
});
}
}

View File

@ -473,6 +473,9 @@ public:
IHostContextPtr & getHostContext(); IHostContextPtr & getHostContext();
const IHostContextPtr & getHostContext() const; const IHostContextPtr & getHostContext() const;
/// MySQL wire protocol state.
size_t sequence_id = 0;
uint32_t client_capabilities = 0;
private: private:
/** Check if the current client has access to the specified database. /** Check if the current client has access to the specified database.
* If access is denied, throw an exception. * If access is denied, throw an exception.

View File

@ -47,8 +47,8 @@ def test_mysql_client(mysql_client, server_address):
# type: (Container, str) -> None # type: (Container, str) -> None
code, (stdout, stderr) = mysql_client.exec_run(''' code, (stdout, stderr) = mysql_client.exec_run('''
mysql --protocol tcp -h {host} -P {port} default -u default --password=123 mysql --protocol tcp -h {host} -P {port} default -u default --password=123
-e "select 1 as a;" -e "SELECT 1 as a;"
-e "select 'тест' as b;" -e "SELECT 'тест' as b;"
'''.format(host=server_address, port=server_port), demux=True) '''.format(host=server_address, port=server_port), demux=True)
assert stdout == 'a\n1\nb\nтест\n' assert stdout == 'a\n1\nb\nтест\n'
@ -71,6 +71,18 @@ def test_mysql_client(mysql_client, server_address):
assert stderr == "mysql: [Warning] Using a password on the command line interface can be insecure.\n" \ assert stderr == "mysql: [Warning] Using a password on the command line interface can be insecure.\n" \
"ERROR 81 (00000) at line 1: Database system2 doesn't exist\n" "ERROR 81 (00000) at line 1: Database system2 doesn't exist\n"
code, (stdout, stderr) = mysql_client.exec_run('''
mysql --protocol tcp -h {host} -P {port} default -u default --password=123
-e "CREATE DATABASE x;"
-e "USE x;"
-e "CREATE TABLE table1 (a UInt32) ENGINE = Memory;"
-e "INSERT INTO table1 VALUES (0), (1), (5);"
-e "INSERT INTO table1 VALUES (0), (1), (5);"
-e "SELECT * FROM table1 ORDER BY a;"
'''.format(host=server_address, port=server_port), demux=True)
assert stdout == 'a\n0\n0\n1\n1\n5\n5\n'
def test_python_client(server_address): def test_python_client(server_address):
with pytest.raises(pymysql.InternalError) as exc_info: with pytest.raises(pymysql.InternalError) as exc_info:
@ -96,6 +108,14 @@ def test_python_client(server_address):
assert exc_info.value.args == (81, "Database system2 doesn't exist") assert exc_info.value.args == (81, "Database system2 doesn't exist")
client.select_db('x')
cursor = client.cursor(pymysql.cursors.DictCursor)
cursor.execute("TRUNCATE TABLE table1")
cursor.execute("INSERT INTO table1 VALUES (1), (3)")
cursor.execute("INSERT INTO table1 VALUES (1), (4)")
cursor.execute("SELECT * FROM table1 ORDER BY a")
assert cursor.fetchall() == [{'a': '1'}, {'a': '1'}, {'a': '3'}, {'a': '4'}]
def test_golang_client(server_address, golang_container): def test_golang_client(server_address, golang_container):
# type: (str, Container) -> None # type: (str, Container) -> None