Merge pull request #26051 from ClickHouse/fix_21184

Fix sequence_id in MySQL protocol
This commit is contained in:
Nikita Mikhaylov 2021-07-13 16:03:25 +03:00 committed by GitHub
commit c255f152aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 135 additions and 56 deletions

View File

@ -26,13 +26,14 @@ namespace ErrorCodes
MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_) MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_)
: host(host_), port(port_), user(user_), password(std::move(password_)) : host(host_), port(port_), user(user_), password(std::move(password_))
{ {
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION; mysql_context.client_capabilities = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
} }
MySQLClient::MySQLClient(MySQLClient && other) MySQLClient::MySQLClient(MySQLClient && other)
: host(std::move(other.host)), port(other.port), user(std::move(other.user)), password(std::move(other.password)) : host(std::move(other.host)), port(other.port), user(std::move(other.user)), password(std::move(other.password))
, client_capability_flags(other.client_capability_flags) , mysql_context(other.mysql_context)
{ {
mysql_context.sequence_id = 0;
} }
void MySQLClient::connect() void MySQLClient::connect()
@ -56,7 +57,7 @@ void MySQLClient::connect()
in = std::make_shared<ReadBufferFromPocoSocket>(*socket); in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket); out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq); packet_endpoint = mysql_context.makeEndpoint(*in, *out);
handshake(); handshake();
} }
@ -68,7 +69,7 @@ void MySQLClient::disconnect()
socket->close(); socket->close();
socket = nullptr; socket = nullptr;
connected = false; connected = false;
seq = 0; mysql_context.sequence_id = 0;
} }
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html /// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
@ -87,10 +88,10 @@ void MySQLClient::handshake()
String auth_plugin_data = native41.getAuthPluginData(); String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse handshake_response( HandshakeResponse handshake_response(
client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password); mysql_context.client_capabilities, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true); packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
ResponsePacket packet_response(client_capability_flags, true); ResponsePacket packet_response(mysql_context.client_capabilities, true);
packet_endpoint->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId(); packet_endpoint->resetSequenceId();
@ -105,7 +106,7 @@ void MySQLClient::writeCommand(char command, String query)
WriteCommand write_command(command, query); WriteCommand write_command(command, query);
packet_endpoint->sendPacket<WriteCommand>(write_command, true); packet_endpoint->sendPacket<WriteCommand>(write_command, true);
ResponsePacket packet_response(client_capability_flags); ResponsePacket packet_response(mysql_context.client_capabilities);
packet_endpoint->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
switch (packet_response.getType()) switch (packet_response.getType())
{ {
@ -124,7 +125,7 @@ void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
RegisterSlave register_slave(slave_id); RegisterSlave register_slave(slave_id);
packet_endpoint->sendPacket<RegisterSlave>(register_slave, true); packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
ResponsePacket packet_response(client_capability_flags); ResponsePacket packet_response(mysql_context.client_capabilities);
packet_endpoint->receivePacket(packet_response); packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId(); packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR) if (packet_response.getType() == PACKET_ERR)

View File

@ -45,9 +45,7 @@ private:
String password; String password;
bool connected = false; bool connected = false;
UInt32 client_capability_flags = 0; MySQLWireContext mysql_context;
uint8_t seq = 0;
const UInt8 charset_utf8 = 33; const UInt8 charset_utf8 = 33;
const String mysql_native_password = "mysql_native_password"; const String mysql_native_password = "mysql_native_password";

View File

@ -68,4 +68,15 @@ String PacketEndpoint::packetToText(const String & payload)
} }
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(WriteBuffer & out)
{
return MySQLProtocol::PacketEndpoint::create(out, sequence_id);
}
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(ReadBuffer & in, WriteBuffer & out)
{
return MySQLProtocol::PacketEndpoint::create(in, out, sequence_id);
}
} }

View File

@ -5,6 +5,7 @@
#include "IMySQLReadPacket.h" #include "IMySQLReadPacket.h"
#include "IMySQLWritePacket.h" #include "IMySQLWritePacket.h"
#include "IO/MySQLPacketPayloadReadBuffer.h" #include "IO/MySQLPacketPayloadReadBuffer.h"
#include <common/shared_ptr_helper.h>
namespace DB namespace DB
{ {
@ -15,19 +16,13 @@ namespace MySQLProtocol
/* Writes and reads packets, keeping sequence-id. /* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError, if packet with incorrect sequence-id was received. * Throws ProtocolError, if packet with incorrect sequence-id was received.
*/ */
class PacketEndpoint class PacketEndpoint : public shared_ptr_helper<PacketEndpoint>
{ {
public: public:
uint8_t & sequence_id; uint8_t & sequence_id;
ReadBuffer * in; ReadBuffer * in;
WriteBuffer * out; WriteBuffer * out;
/// For writing.
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
/// For reading and writing.
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
MySQLPacketPayloadReadBuffer getPayload(); MySQLPacketPayloadReadBuffer getPayload();
void receivePacket(IMySQLReadPacket & packet); void receivePacket(IMySQLReadPacket & packet);
@ -48,8 +43,29 @@ public:
/// Converts packet to text. Is used for debug output. /// Converts packet to text. Is used for debug output.
static String packetToText(const String & payload); static String packetToText(const String & payload);
protected:
/// For writing.
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
/// For reading and writing.
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
friend struct shared_ptr_helper<PacketEndpoint>;
};
using PacketEndpointPtr = std::shared_ptr<PacketEndpoint>;
}
struct MySQLWireContext
{
uint8_t sequence_id = 0;
uint32_t client_capabilities = 0;
size_t max_packet_size = 0;
MySQLProtocol::PacketEndpointPtr makeEndpoint(WriteBuffer & out);
MySQLProtocol::PacketEndpointPtr makeEndpoint(ReadBuffer & in, WriteBuffer & out);
}; };
} }
}

View File

@ -33,6 +33,7 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR; extern const int LOGICAL_ERROR;
extern const int FORMAT_IS_NOT_SUITABLE_FOR_INPUT; extern const int FORMAT_IS_NOT_SUITABLE_FOR_INPUT;
extern const int FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT; extern const int FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT;
extern const int UNSUPPORTED_METHOD;
} }
const FormatFactory::Creators & FormatFactory::getCreators(const String & name) const const FormatFactory::Creators & FormatFactory::getCreators(const String & name) const
@ -207,6 +208,9 @@ BlockOutputStreamPtr FormatFactory::getOutputStreamParallelIfPossible(
WriteCallback callback, WriteCallback callback,
const std::optional<FormatSettings> & _format_settings) const const std::optional<FormatSettings> & _format_settings) const
{ {
if (context->getMySQLProtocolContext() && name != "MySQLWire")
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
const auto & output_getter = getCreators(name).output_processor_creator; const auto & output_getter = getCreators(name).output_processor_creator;
const Settings & settings = context->getSettingsRef(); const Settings & settings = context->getSettingsRef();
@ -309,7 +313,10 @@ OutputFormatPtr FormatFactory::getOutputFormatParallelIfPossible(
{ {
const auto & output_getter = getCreators(name).output_processor_creator; const auto & output_getter = getCreators(name).output_processor_creator;
if (!output_getter) if (!output_getter)
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT); throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
if (context->getMySQLProtocolContext() && name != "MySQLWire")
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
auto format_settings = _format_settings ? *_format_settings : getFormatSettings(context); auto format_settings = _format_settings ? *_format_settings : getFormatSettings(context);
@ -344,7 +351,7 @@ OutputFormatPtr FormatFactory::getOutputFormat(
{ {
const auto & output_getter = getCreators(name).output_processor_creator; const auto & output_getter = getCreators(name).output_processor_creator;
if (!output_getter) if (!output_getter)
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT); throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
if (context->hasQueryContext() && context->getSettingsRef().log_queries) if (context->hasQueryContext() && context->getSettingsRef().log_queries)
context->getQueryContext()->addQueryFactoriesInfo(Context::QueryLogFactories::Format, name); context->getQueryContext()->addQueryFactoriesInfo(Context::QueryLogFactories::Format, name);

View File

@ -2355,11 +2355,6 @@ OutputFormatPtr Context::getOutputFormatParallelIfPossible(const String & name,
return FormatFactory::instance().getOutputFormatParallelIfPossible(name, buf, sample, shared_from_this()); return FormatFactory::instance().getOutputFormatParallelIfPossible(name, buf, sample, shared_from_this());
} }
OutputFormatPtr Context::getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const
{
return FormatFactory::instance().getOutputFormat(name, buf, sample, shared_from_this());
}
time_t Context::getUptimeSeconds() const time_t Context::getUptimeSeconds() const
{ {
@ -2732,4 +2727,18 @@ PartUUIDsPtr Context::getIgnoredPartUUIDs() const
return ignored_part_uuids; return ignored_part_uuids;
} }
void Context::setMySQLProtocolContext(MySQLWireContext * mysql_context)
{
assert(session_context.lock().get() == this);
assert(!mysql_protocol_context);
assert(mysql_context);
mysql_protocol_context = mysql_context;
}
MySQLWireContext * Context::getMySQLProtocolContext() const
{
assert(!mysql_protocol_context || session_context.lock().get());
return mysql_protocol_context;
}
} }

View File

@ -119,6 +119,8 @@ using ThrottlerPtr = std::shared_ptr<Throttler>;
class ZooKeeperMetadataTransaction; class ZooKeeperMetadataTransaction;
using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>; using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>;
struct MySQLWireContext;
/// Callback for external tables initializer /// Callback for external tables initializer
using ExternalTablesInitializer = std::function<void(ContextPtr)>; using ExternalTablesInitializer = std::function<void(ContextPtr)>;
@ -298,6 +300,8 @@ private:
/// thousands of signatures. /// thousands of signatures.
/// And I hope it will be replaced with more common Transaction sometime. /// And I hope it will be replaced with more common Transaction sometime.
MySQLWireContext * mysql_protocol_context = nullptr;
Context(); Context();
Context(const Context &); Context(const Context &);
Context & operator=(const Context &); Context & operator=(const Context &);
@ -538,7 +542,6 @@ public:
BlockOutputStreamPtr getOutputStream(const String & name, WriteBuffer & buf, const Block & sample) const; BlockOutputStreamPtr getOutputStream(const String & name, WriteBuffer & buf, const Block & sample) const;
OutputFormatPtr getOutputFormatParallelIfPossible(const String & name, WriteBuffer & buf, const Block & sample) const; OutputFormatPtr getOutputFormatParallelIfPossible(const String & name, WriteBuffer & buf, const Block & sample) const;
OutputFormatPtr getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const;
InterserverIOHandler & getInterserverIOHandler(); InterserverIOHandler & getInterserverIOHandler();
@ -794,14 +797,10 @@ public:
/// Returns context of current distributed DDL query or nullptr. /// Returns context of current distributed DDL query or nullptr.
ZooKeeperMetadataTransactionPtr getZooKeeperMetadataTransaction() const; ZooKeeperMetadataTransactionPtr getZooKeeperMetadataTransaction() const;
struct MySQLWireContext /// Caller is responsible for lifetime of mysql_context.
{ /// Used in MySQLHandler for session context.
uint8_t sequence_id = 0; void setMySQLProtocolContext(MySQLWireContext * mysql_context);
uint32_t client_capabilities = 0; MySQLWireContext * getMySQLProtocolContext() const;
size_t max_packet_size = 0;
};
MySQLWireContext mysql;
PartUUIDsPtr getPartUUIDs() const; PartUUIDsPtr getPartUUIDs() const;
PartUUIDsPtr getIgnoredPartUUIDs() const; PartUUIDsPtr getIgnoredPartUUIDs() const;

View File

@ -17,6 +17,22 @@ MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_,
{ {
} }
void MySQLOutputFormat::setContext(ContextPtr context_)
{
context = context_;
/// MySQlWire is a special format that is usually used as output format for MySQL protocol connections.
/// In this case we have to use the corresponding session context to set correct sequence_id.
mysql_context = getContext()->getMySQLProtocolContext();
if (!mysql_context)
{
/// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local.
/// There is no MySQL protocol context in this case, so we create dummy one.
own_mysql_context.emplace();
mysql_context = &own_mysql_context.value();
}
packet_endpoint = mysql_context->makeEndpoint(out);
}
void MySQLOutputFormat::initialize() void MySQLOutputFormat::initialize()
{ {
if (initialized) if (initialized)
@ -40,7 +56,7 @@ void MySQLOutputFormat::initialize()
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId())); packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
} }
if (!(getContext()->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) if (!(mysql_context->client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{ {
packet_endpoint->sendPacket(EOFPacket(0, 0)); packet_endpoint->sendPacket(EOFPacket(0, 0));
} }
@ -79,10 +95,10 @@ void MySQLOutputFormat::finalize()
const auto & header = getPort(PortKind::Main).getHeader(); const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0) if (header.columns() == 0)
packet_endpoint->sendPacket( packet_endpoint->sendPacket(
OKPacket(0x0, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); OKPacket(0x0, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else if (getContext()->mysql.client_capabilities & CLIENT_DEPRECATE_EOF) else if (mysql_context->client_capabilities & CLIENT_DEPRECATE_EOF)
packet_endpoint->sendPacket( packet_endpoint->sendPacket(
OKPacket(0xfe, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); OKPacket(0xfe, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else else
packet_endpoint->sendPacket(EOFPacket(0, 0), true); packet_endpoint->sendPacket(EOFPacket(0, 0), true);
} }

View File

@ -25,11 +25,7 @@ public:
String getName() const override { return "MySQLOutputFormat"; } String getName() const override { return "MySQLOutputFormat"; }
void setContext(ContextPtr context_) void setContext(ContextPtr context_);
{
context = context_;
packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(getContext()->mysql.sequence_id)); /// TODO: fix it
}
void consume(Chunk) override; void consume(Chunk) override;
void finalize() override; void finalize() override;
@ -41,7 +37,9 @@ public:
private: private:
bool initialized = false; bool initialized = false;
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint; std::optional<MySQLWireContext> own_mysql_context;
MySQLWireContext * mysql_context = nullptr;
MySQLProtocol::PacketEndpointPtr packet_endpoint;
FormatSettings format_settings; FormatSettings format_settings;
DataTypes data_types; DataTypes data_types;
Serializations serializations; Serializations serializations;

View File

@ -95,10 +95,11 @@ void MySQLHandler::run()
connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL; connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL;
connection_context->setDefaultFormat("MySQLWire"); connection_context->setDefaultFormat("MySQLWire");
connection_context->getClientInfo().connection_id = connection_id; connection_context->getClientInfo().connection_id = connection_id;
connection_context->setMySQLProtocolContext(&connection_context_mysql);
in = std::make_shared<ReadBufferFromPocoSocket>(socket()); in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket()); out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id); packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
try try
{ {
@ -110,11 +111,11 @@ void MySQLHandler::run()
HandshakeResponse handshake_response; HandshakeResponse handshake_response;
finishHandshake(handshake_response); finishHandshake(handshake_response);
connection_context->mysql.client_capabilities = handshake_response.capability_flags; connection_context_mysql.client_capabilities = handshake_response.capability_flags;
if (handshake_response.max_packet_size) if (handshake_response.max_packet_size)
connection_context->mysql.max_packet_size = handshake_response.max_packet_size; connection_context_mysql.max_packet_size = handshake_response.max_packet_size;
if (!connection_context->mysql.max_packet_size) if (!connection_context_mysql.max_packet_size)
connection_context->mysql.max_packet_size = MAX_PACKET_LENGTH; connection_context_mysql.max_packet_size = MAX_PACKET_LENGTH;
LOG_TRACE(log, LOG_TRACE(log,
"Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}", "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}",
@ -395,14 +396,14 @@ void MySQLHandlerSSL::finishHandshakeSSL(
ReadBufferFromMemory payload(buf, pos); ReadBufferFromMemory payload(buf, pos);
payload.ignore(PACKET_HEADER_SIZE); payload.ignore(PACKET_HEADER_SIZE);
ssl_request.readPayloadWithUnpacked(payload); ssl_request.readPayloadWithUnpacked(payload);
connection_context->mysql.client_capabilities = ssl_request.capability_flags; connection_context_mysql.client_capabilities = ssl_request.capability_flags;
connection_context->mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH; connection_context_mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
secure_connection = true; secure_connection = true;
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext())); ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
in = std::make_shared<ReadBufferFromPocoSocket>(*ss); in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss); out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context->mysql.sequence_id = 2; connection_context_mysql.sequence_id = 2;
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id); packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket. packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
} }

View File

@ -56,9 +56,10 @@ private:
protected: protected:
Poco::Logger * log; Poco::Logger * log;
MySQLWireContext connection_context_mysql;
ContextMutablePtr connection_context; ContextMutablePtr connection_context;
std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint; MySQLProtocol::PacketEndpointPtr packet_endpoint;
private: private:
UInt64 connection_id = 0; UInt64 connection_id = 0;

View File

@ -22,5 +22,27 @@ expect "| dummy |"
expect "| 0 |" expect "| 0 |"
expect "1 row in set" expect "1 row in set"
# exception before start
send -- "select * from table_that_does_not_exist;\r"
expect "ERROR 60 (00000): Code: 60"
# exception after start
send -- "select throwIf(number) from numbers(2) settings max_block_size=1;\r"
expect "ERROR 395 (00000): Code: 395"
# other formats
send -- "select * from system.one format TSV;\r"
expect "ERROR 1 (00000): Code: 1"
send -- "select count(number), sum(number) from numbers(10);\r"
expect "+---------------+-------------+"
expect "| count(number) | sum(number) |"
expect "+---------------+-------------+"
expect "| 10 | 45 |"
expect "+---------------+-------------+"
expect "1 row in set"
expect "Read 10 rows, 80.00 B"
expect "mysql> "
send -- "quit;\r" send -- "quit;\r"
expect eof expect eof