mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Merge pull request #26051 from ClickHouse/fix_21184
Fix sequence_id in MySQL protocol
This commit is contained in:
commit
c255f152aa
@ -26,13 +26,14 @@ namespace ErrorCodes
|
||||
MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_)
|
||||
: host(host_), port(port_), user(user_), password(std::move(password_))
|
||||
{
|
||||
client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
|
||||
mysql_context.client_capabilities = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION;
|
||||
}
|
||||
|
||||
MySQLClient::MySQLClient(MySQLClient && other)
|
||||
: host(std::move(other.host)), port(other.port), user(std::move(other.user)), password(std::move(other.password))
|
||||
, client_capability_flags(other.client_capability_flags)
|
||||
, mysql_context(other.mysql_context)
|
||||
{
|
||||
mysql_context.sequence_id = 0;
|
||||
}
|
||||
|
||||
void MySQLClient::connect()
|
||||
@ -56,7 +57,7 @@ void MySQLClient::connect()
|
||||
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
|
||||
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq);
|
||||
packet_endpoint = mysql_context.makeEndpoint(*in, *out);
|
||||
handshake();
|
||||
}
|
||||
|
||||
@ -68,7 +69,7 @@ void MySQLClient::disconnect()
|
||||
socket->close();
|
||||
socket = nullptr;
|
||||
connected = false;
|
||||
seq = 0;
|
||||
mysql_context.sequence_id = 0;
|
||||
}
|
||||
|
||||
/// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html
|
||||
@ -87,10 +88,10 @@ void MySQLClient::handshake()
|
||||
String auth_plugin_data = native41.getAuthPluginData();
|
||||
|
||||
HandshakeResponse handshake_response(
|
||||
client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
|
||||
mysql_context.client_capabilities, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
|
||||
packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
|
||||
|
||||
ResponsePacket packet_response(client_capability_flags, true);
|
||||
ResponsePacket packet_response(mysql_context.client_capabilities, true);
|
||||
packet_endpoint->receivePacket(packet_response);
|
||||
packet_endpoint->resetSequenceId();
|
||||
|
||||
@ -105,7 +106,7 @@ void MySQLClient::writeCommand(char command, String query)
|
||||
WriteCommand write_command(command, query);
|
||||
packet_endpoint->sendPacket<WriteCommand>(write_command, true);
|
||||
|
||||
ResponsePacket packet_response(client_capability_flags);
|
||||
ResponsePacket packet_response(mysql_context.client_capabilities);
|
||||
packet_endpoint->receivePacket(packet_response);
|
||||
switch (packet_response.getType())
|
||||
{
|
||||
@ -124,7 +125,7 @@ void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
|
||||
RegisterSlave register_slave(slave_id);
|
||||
packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
|
||||
|
||||
ResponsePacket packet_response(client_capability_flags);
|
||||
ResponsePacket packet_response(mysql_context.client_capabilities);
|
||||
packet_endpoint->receivePacket(packet_response);
|
||||
packet_endpoint->resetSequenceId();
|
||||
if (packet_response.getType() == PACKET_ERR)
|
||||
|
@ -45,9 +45,7 @@ private:
|
||||
String password;
|
||||
|
||||
bool connected = false;
|
||||
UInt32 client_capability_flags = 0;
|
||||
|
||||
uint8_t seq = 0;
|
||||
MySQLWireContext mysql_context;
|
||||
const UInt8 charset_utf8 = 33;
|
||||
const String mysql_native_password = "mysql_native_password";
|
||||
|
||||
|
@ -68,4 +68,15 @@ String PacketEndpoint::packetToText(const String & payload)
|
||||
|
||||
}
|
||||
|
||||
|
||||
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(WriteBuffer & out)
|
||||
{
|
||||
return MySQLProtocol::PacketEndpoint::create(out, sequence_id);
|
||||
}
|
||||
|
||||
MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(ReadBuffer & in, WriteBuffer & out)
|
||||
{
|
||||
return MySQLProtocol::PacketEndpoint::create(in, out, sequence_id);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "IMySQLReadPacket.h"
|
||||
#include "IMySQLWritePacket.h"
|
||||
#include "IO/MySQLPacketPayloadReadBuffer.h"
|
||||
#include <common/shared_ptr_helper.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -15,19 +16,13 @@ namespace MySQLProtocol
|
||||
/* Writes and reads packets, keeping sequence-id.
|
||||
* Throws ProtocolError, if packet with incorrect sequence-id was received.
|
||||
*/
|
||||
class PacketEndpoint
|
||||
class PacketEndpoint : public shared_ptr_helper<PacketEndpoint>
|
||||
{
|
||||
public:
|
||||
uint8_t & sequence_id;
|
||||
ReadBuffer * in;
|
||||
WriteBuffer * out;
|
||||
|
||||
/// For writing.
|
||||
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
|
||||
|
||||
/// For reading and writing.
|
||||
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
|
||||
|
||||
MySQLPacketPayloadReadBuffer getPayload();
|
||||
|
||||
void receivePacket(IMySQLReadPacket & packet);
|
||||
@ -48,8 +43,29 @@ public:
|
||||
|
||||
/// Converts packet to text. Is used for debug output.
|
||||
static String packetToText(const String & payload);
|
||||
|
||||
protected:
|
||||
/// For writing.
|
||||
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
|
||||
|
||||
/// For reading and writing.
|
||||
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
|
||||
|
||||
friend struct shared_ptr_helper<PacketEndpoint>;
|
||||
};
|
||||
|
||||
using PacketEndpointPtr = std::shared_ptr<PacketEndpoint>;
|
||||
|
||||
}
|
||||
|
||||
struct MySQLWireContext
|
||||
{
|
||||
uint8_t sequence_id = 0;
|
||||
uint32_t client_capabilities = 0;
|
||||
size_t max_packet_size = 0;
|
||||
|
||||
MySQLProtocol::PacketEndpointPtr makeEndpoint(WriteBuffer & out);
|
||||
MySQLProtocol::PacketEndpointPtr makeEndpoint(ReadBuffer & in, WriteBuffer & out);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -33,6 +33,7 @@ namespace ErrorCodes
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int FORMAT_IS_NOT_SUITABLE_FOR_INPUT;
|
||||
extern const int FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT;
|
||||
extern const int UNSUPPORTED_METHOD;
|
||||
}
|
||||
|
||||
const FormatFactory::Creators & FormatFactory::getCreators(const String & name) const
|
||||
@ -207,6 +208,9 @@ BlockOutputStreamPtr FormatFactory::getOutputStreamParallelIfPossible(
|
||||
WriteCallback callback,
|
||||
const std::optional<FormatSettings> & _format_settings) const
|
||||
{
|
||||
if (context->getMySQLProtocolContext() && name != "MySQLWire")
|
||||
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
|
||||
|
||||
const auto & output_getter = getCreators(name).output_processor_creator;
|
||||
|
||||
const Settings & settings = context->getSettingsRef();
|
||||
@ -309,7 +313,10 @@ OutputFormatPtr FormatFactory::getOutputFormatParallelIfPossible(
|
||||
{
|
||||
const auto & output_getter = getCreators(name).output_processor_creator;
|
||||
if (!output_getter)
|
||||
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT);
|
||||
throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
|
||||
|
||||
if (context->getMySQLProtocolContext() && name != "MySQLWire")
|
||||
throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats");
|
||||
|
||||
auto format_settings = _format_settings ? *_format_settings : getFormatSettings(context);
|
||||
|
||||
@ -344,7 +351,7 @@ OutputFormatPtr FormatFactory::getOutputFormat(
|
||||
{
|
||||
const auto & output_getter = getCreators(name).output_processor_creator;
|
||||
if (!output_getter)
|
||||
throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT);
|
||||
throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name);
|
||||
|
||||
if (context->hasQueryContext() && context->getSettingsRef().log_queries)
|
||||
context->getQueryContext()->addQueryFactoriesInfo(Context::QueryLogFactories::Format, name);
|
||||
|
@ -2355,11 +2355,6 @@ OutputFormatPtr Context::getOutputFormatParallelIfPossible(const String & name,
|
||||
return FormatFactory::instance().getOutputFormatParallelIfPossible(name, buf, sample, shared_from_this());
|
||||
}
|
||||
|
||||
OutputFormatPtr Context::getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const
|
||||
{
|
||||
return FormatFactory::instance().getOutputFormat(name, buf, sample, shared_from_this());
|
||||
}
|
||||
|
||||
|
||||
time_t Context::getUptimeSeconds() const
|
||||
{
|
||||
@ -2732,4 +2727,18 @@ PartUUIDsPtr Context::getIgnoredPartUUIDs() const
|
||||
return ignored_part_uuids;
|
||||
}
|
||||
|
||||
void Context::setMySQLProtocolContext(MySQLWireContext * mysql_context)
|
||||
{
|
||||
assert(session_context.lock().get() == this);
|
||||
assert(!mysql_protocol_context);
|
||||
assert(mysql_context);
|
||||
mysql_protocol_context = mysql_context;
|
||||
}
|
||||
|
||||
MySQLWireContext * Context::getMySQLProtocolContext() const
|
||||
{
|
||||
assert(!mysql_protocol_context || session_context.lock().get());
|
||||
return mysql_protocol_context;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -119,6 +119,8 @@ using ThrottlerPtr = std::shared_ptr<Throttler>;
|
||||
class ZooKeeperMetadataTransaction;
|
||||
using ZooKeeperMetadataTransactionPtr = std::shared_ptr<ZooKeeperMetadataTransaction>;
|
||||
|
||||
struct MySQLWireContext;
|
||||
|
||||
/// Callback for external tables initializer
|
||||
using ExternalTablesInitializer = std::function<void(ContextPtr)>;
|
||||
|
||||
@ -298,6 +300,8 @@ private:
|
||||
/// thousands of signatures.
|
||||
/// And I hope it will be replaced with more common Transaction sometime.
|
||||
|
||||
MySQLWireContext * mysql_protocol_context = nullptr;
|
||||
|
||||
Context();
|
||||
Context(const Context &);
|
||||
Context & operator=(const Context &);
|
||||
@ -538,7 +542,6 @@ public:
|
||||
BlockOutputStreamPtr getOutputStream(const String & name, WriteBuffer & buf, const Block & sample) const;
|
||||
|
||||
OutputFormatPtr getOutputFormatParallelIfPossible(const String & name, WriteBuffer & buf, const Block & sample) const;
|
||||
OutputFormatPtr getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const;
|
||||
|
||||
InterserverIOHandler & getInterserverIOHandler();
|
||||
|
||||
@ -794,14 +797,10 @@ public:
|
||||
/// Returns context of current distributed DDL query or nullptr.
|
||||
ZooKeeperMetadataTransactionPtr getZooKeeperMetadataTransaction() const;
|
||||
|
||||
struct MySQLWireContext
|
||||
{
|
||||
uint8_t sequence_id = 0;
|
||||
uint32_t client_capabilities = 0;
|
||||
size_t max_packet_size = 0;
|
||||
};
|
||||
|
||||
MySQLWireContext mysql;
|
||||
/// Caller is responsible for lifetime of mysql_context.
|
||||
/// Used in MySQLHandler for session context.
|
||||
void setMySQLProtocolContext(MySQLWireContext * mysql_context);
|
||||
MySQLWireContext * getMySQLProtocolContext() const;
|
||||
|
||||
PartUUIDsPtr getPartUUIDs() const;
|
||||
PartUUIDsPtr getIgnoredPartUUIDs() const;
|
||||
|
@ -17,6 +17,22 @@ MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_,
|
||||
{
|
||||
}
|
||||
|
||||
void MySQLOutputFormat::setContext(ContextPtr context_)
|
||||
{
|
||||
context = context_;
|
||||
/// MySQlWire is a special format that is usually used as output format for MySQL protocol connections.
|
||||
/// In this case we have to use the corresponding session context to set correct sequence_id.
|
||||
mysql_context = getContext()->getMySQLProtocolContext();
|
||||
if (!mysql_context)
|
||||
{
|
||||
/// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local.
|
||||
/// There is no MySQL protocol context in this case, so we create dummy one.
|
||||
own_mysql_context.emplace();
|
||||
mysql_context = &own_mysql_context.value();
|
||||
}
|
||||
packet_endpoint = mysql_context->makeEndpoint(out);
|
||||
}
|
||||
|
||||
void MySQLOutputFormat::initialize()
|
||||
{
|
||||
if (initialized)
|
||||
@ -40,7 +56,7 @@ void MySQLOutputFormat::initialize()
|
||||
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
|
||||
}
|
||||
|
||||
if (!(getContext()->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
if (!(mysql_context->client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
{
|
||||
packet_endpoint->sendPacket(EOFPacket(0, 0));
|
||||
}
|
||||
@ -79,10 +95,10 @@ void MySQLOutputFormat::finalize()
|
||||
const auto & header = getPort(PortKind::Main).getHeader();
|
||||
if (header.columns() == 0)
|
||||
packet_endpoint->sendPacket(
|
||||
OKPacket(0x0, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
|
||||
else if (getContext()->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
OKPacket(0x0, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
|
||||
else if (mysql_context->client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_endpoint->sendPacket(
|
||||
OKPacket(0xfe, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
|
||||
OKPacket(0xfe, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
|
||||
else
|
||||
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
|
||||
}
|
||||
|
@ -25,11 +25,7 @@ public:
|
||||
|
||||
String getName() const override { return "MySQLOutputFormat"; }
|
||||
|
||||
void setContext(ContextPtr context_)
|
||||
{
|
||||
context = context_;
|
||||
packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(getContext()->mysql.sequence_id)); /// TODO: fix it
|
||||
}
|
||||
void setContext(ContextPtr context_);
|
||||
|
||||
void consume(Chunk) override;
|
||||
void finalize() override;
|
||||
@ -41,7 +37,9 @@ public:
|
||||
private:
|
||||
bool initialized = false;
|
||||
|
||||
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
|
||||
std::optional<MySQLWireContext> own_mysql_context;
|
||||
MySQLWireContext * mysql_context = nullptr;
|
||||
MySQLProtocol::PacketEndpointPtr packet_endpoint;
|
||||
FormatSettings format_settings;
|
||||
DataTypes data_types;
|
||||
Serializations serializations;
|
||||
|
@ -95,10 +95,11 @@ void MySQLHandler::run()
|
||||
connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL;
|
||||
connection_context->setDefaultFormat("MySQLWire");
|
||||
connection_context->getClientInfo().connection_id = connection_id;
|
||||
connection_context->setMySQLProtocolContext(&connection_context_mysql);
|
||||
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
|
||||
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id);
|
||||
packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
|
||||
|
||||
try
|
||||
{
|
||||
@ -110,11 +111,11 @@ void MySQLHandler::run()
|
||||
|
||||
HandshakeResponse handshake_response;
|
||||
finishHandshake(handshake_response);
|
||||
connection_context->mysql.client_capabilities = handshake_response.capability_flags;
|
||||
connection_context_mysql.client_capabilities = handshake_response.capability_flags;
|
||||
if (handshake_response.max_packet_size)
|
||||
connection_context->mysql.max_packet_size = handshake_response.max_packet_size;
|
||||
if (!connection_context->mysql.max_packet_size)
|
||||
connection_context->mysql.max_packet_size = MAX_PACKET_LENGTH;
|
||||
connection_context_mysql.max_packet_size = handshake_response.max_packet_size;
|
||||
if (!connection_context_mysql.max_packet_size)
|
||||
connection_context_mysql.max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
LOG_TRACE(log,
|
||||
"Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}",
|
||||
@ -395,14 +396,14 @@ void MySQLHandlerSSL::finishHandshakeSSL(
|
||||
ReadBufferFromMemory payload(buf, pos);
|
||||
payload.ignore(PACKET_HEADER_SIZE);
|
||||
ssl_request.readPayloadWithUnpacked(payload);
|
||||
connection_context->mysql.client_capabilities = ssl_request.capability_flags;
|
||||
connection_context->mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
|
||||
connection_context_mysql.client_capabilities = ssl_request.capability_flags;
|
||||
connection_context_mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
|
||||
secure_connection = true;
|
||||
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
|
||||
connection_context->mysql.sequence_id = 2;
|
||||
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, connection_context->mysql.sequence_id);
|
||||
connection_context_mysql.sequence_id = 2;
|
||||
packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out);
|
||||
packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
|
||||
}
|
||||
|
||||
|
@ -56,9 +56,10 @@ private:
|
||||
protected:
|
||||
Poco::Logger * log;
|
||||
|
||||
MySQLWireContext connection_context_mysql;
|
||||
ContextMutablePtr connection_context;
|
||||
|
||||
std::shared_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
|
||||
MySQLProtocol::PacketEndpointPtr packet_endpoint;
|
||||
|
||||
private:
|
||||
UInt64 connection_id = 0;
|
||||
|
@ -22,5 +22,27 @@ expect "| dummy |"
|
||||
expect "| 0 |"
|
||||
expect "1 row in set"
|
||||
|
||||
# exception before start
|
||||
send -- "select * from table_that_does_not_exist;\r"
|
||||
expect "ERROR 60 (00000): Code: 60"
|
||||
|
||||
# exception after start
|
||||
send -- "select throwIf(number) from numbers(2) settings max_block_size=1;\r"
|
||||
expect "ERROR 395 (00000): Code: 395"
|
||||
|
||||
# other formats
|
||||
send -- "select * from system.one format TSV;\r"
|
||||
expect "ERROR 1 (00000): Code: 1"
|
||||
|
||||
send -- "select count(number), sum(number) from numbers(10);\r"
|
||||
expect "+---------------+-------------+"
|
||||
expect "| count(number) | sum(number) |"
|
||||
expect "+---------------+-------------+"
|
||||
expect "| 10 | 45 |"
|
||||
expect "+---------------+-------------+"
|
||||
expect "1 row in set"
|
||||
expect "Read 10 rows, 80.00 B"
|
||||
expect "mysql> "
|
||||
|
||||
send -- "quit;\r"
|
||||
expect eof
|
||||
|
Loading…
Reference in New Issue
Block a user