From 1a470fb777d1658d24b042f85e216d78c2bd78a8 Mon Sep 17 00:00:00 2001 From: Alexander Tokmakov Date: Wed, 7 Jul 2021 18:46:56 +0300 Subject: [PATCH 1/3] fix sequence_id in MySQL protocol --- src/Core/MySQL/MySQLClient.cpp | 17 +++++----- src/Core/MySQL/MySQLClient.h | 4 +-- src/Core/MySQL/PacketEndpoint.cpp | 11 ++++++ src/Core/MySQL/PacketEndpoint.h | 34 ++++++++++++++----- src/Formats/FormatFactory.cpp | 11 ++++-- src/Interpreters/Context.cpp | 19 ++++++++--- src/Interpreters/Context.h | 17 +++++----- .../Formats/Impl/MySQLOutputFormat.cpp | 24 ++++++++++--- .../Formats/Impl/MySQLOutputFormat.h | 10 +++--- src/Server/MySQLHandler.cpp | 19 ++++++----- src/Server/MySQLHandler.h | 3 +- .../01176_mysql_client_interactive.expect | 22 ++++++++++++ 12 files changed, 135 insertions(+), 56 deletions(-) diff --git a/src/Core/MySQL/MySQLClient.cpp b/src/Core/MySQL/MySQLClient.cpp index 3650818c543..d103ea873e5 100644 --- a/src/Core/MySQL/MySQLClient.cpp +++ b/src/Core/MySQL/MySQLClient.cpp @@ -26,13 +26,14 @@ namespace ErrorCodes MySQLClient::MySQLClient(const String & host_, UInt16 port_, const String & user_, const String & password_) : host(host_), port(port_), user(user_), password(std::move(password_)) { - client_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION; + mysql_context.client_capabilities = CLIENT_PROTOCOL_41 | CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION; } MySQLClient::MySQLClient(MySQLClient && other) : host(std::move(other.host)), port(other.port), user(std::move(other.user)), password(std::move(other.password)) - , client_capability_flags(other.client_capability_flags) + , mysql_context(other.mysql_context) { + mysql_context.sequence_id = 0; } void MySQLClient::connect() @@ -56,7 +57,7 @@ void MySQLClient::connect() in = std::make_shared(*socket); out = std::make_shared(*socket); - packet_endpoint = std::make_shared(*in, *out, seq); + packet_endpoint = mysql_context.makeEndpoint(*in, *out); handshake(); } @@ -68,7 +69,7 @@ void MySQLClient::disconnect() socket->close(); socket = nullptr; connected = false; - seq = 0; + mysql_context.sequence_id = 0; } /// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html @@ -87,10 +88,10 @@ void MySQLClient::handshake() String auth_plugin_data = native41.getAuthPluginData(); HandshakeResponse handshake_response( - client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password); + mysql_context.client_capabilities, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password); packet_endpoint->sendPacket(handshake_response, true); - ResponsePacket packet_response(client_capability_flags, true); + ResponsePacket packet_response(mysql_context.client_capabilities, true); packet_endpoint->receivePacket(packet_response); packet_endpoint->resetSequenceId(); @@ -105,7 +106,7 @@ void MySQLClient::writeCommand(char command, String query) WriteCommand write_command(command, query); packet_endpoint->sendPacket(write_command, true); - ResponsePacket packet_response(client_capability_flags); + ResponsePacket packet_response(mysql_context.client_capabilities); packet_endpoint->receivePacket(packet_response); switch (packet_response.getType()) { @@ -124,7 +125,7 @@ void MySQLClient::registerSlaveOnMaster(UInt32 slave_id) RegisterSlave register_slave(slave_id); packet_endpoint->sendPacket(register_slave, true); - ResponsePacket packet_response(client_capability_flags); + ResponsePacket packet_response(mysql_context.client_capabilities); packet_endpoint->receivePacket(packet_response); packet_endpoint->resetSequenceId(); if (packet_response.getType() == PACKET_ERR) diff --git a/src/Core/MySQL/MySQLClient.h b/src/Core/MySQL/MySQLClient.h index e503c985584..6144b14690d 100644 --- a/src/Core/MySQL/MySQLClient.h +++ b/src/Core/MySQL/MySQLClient.h @@ -45,9 +45,7 @@ private: String password; bool connected = false; - UInt32 client_capability_flags = 0; - - uint8_t seq = 0; + MySQLWireContext mysql_context; const UInt8 charset_utf8 = 33; const String mysql_native_password = "mysql_native_password"; diff --git a/src/Core/MySQL/PacketEndpoint.cpp b/src/Core/MySQL/PacketEndpoint.cpp index 0bc5c585516..fa1d60034d2 100644 --- a/src/Core/MySQL/PacketEndpoint.cpp +++ b/src/Core/MySQL/PacketEndpoint.cpp @@ -68,4 +68,15 @@ String PacketEndpoint::packetToText(const String & payload) } + +MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(WriteBuffer & out) +{ + return MySQLProtocol::PacketEndpoint::create(out, sequence_id); +} + +MySQLProtocol::PacketEndpointPtr MySQLWireContext::makeEndpoint(ReadBuffer & in, WriteBuffer & out) +{ + return MySQLProtocol::PacketEndpoint::create(in, out, sequence_id); +} + } diff --git a/src/Core/MySQL/PacketEndpoint.h b/src/Core/MySQL/PacketEndpoint.h index d027934eafb..3aa76ac93de 100644 --- a/src/Core/MySQL/PacketEndpoint.h +++ b/src/Core/MySQL/PacketEndpoint.h @@ -5,6 +5,7 @@ #include "IMySQLReadPacket.h" #include "IMySQLWritePacket.h" #include "IO/MySQLPacketPayloadReadBuffer.h" +#include namespace DB { @@ -15,19 +16,13 @@ namespace MySQLProtocol /* Writes and reads packets, keeping sequence-id. * Throws ProtocolError, if packet with incorrect sequence-id was received. */ -class PacketEndpoint +class PacketEndpoint : public shared_ptr_helper { public: uint8_t & sequence_id; ReadBuffer * in; WriteBuffer * out; - /// For writing. - PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_); - - /// For reading and writing. - PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_); - MySQLPacketPayloadReadBuffer getPayload(); void receivePacket(IMySQLReadPacket & packet); @@ -48,8 +43,29 @@ public: /// Converts packet to text. Is used for debug output. static String packetToText(const String & payload); + +protected: + /// For writing. + PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_); + + /// For reading and writing. + PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_); + + friend struct shared_ptr_helper; +}; + +using PacketEndpointPtr = std::shared_ptr; + +} + +struct MySQLWireContext +{ + uint8_t sequence_id = 0; + uint32_t client_capabilities = 0; + size_t max_packet_size = 0; + + MySQLProtocol::PacketEndpointPtr makeEndpoint(WriteBuffer & out); + MySQLProtocol::PacketEndpointPtr makeEndpoint(ReadBuffer & in, WriteBuffer & out); }; } - -} diff --git a/src/Formats/FormatFactory.cpp b/src/Formats/FormatFactory.cpp index 8b7cf9635b4..a00839fc5f5 100644 --- a/src/Formats/FormatFactory.cpp +++ b/src/Formats/FormatFactory.cpp @@ -33,6 +33,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int FORMAT_IS_NOT_SUITABLE_FOR_INPUT; extern const int FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT; + extern const int UNSUPPORTED_METHOD; } const FormatFactory::Creators & FormatFactory::getCreators(const String & name) const @@ -207,6 +208,9 @@ BlockOutputStreamPtr FormatFactory::getOutputStreamParallelIfPossible( WriteCallback callback, const std::optional & _format_settings) const { + if (context->getMySQLProtocolContext() && name != "MySQLWire") + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats"); + const auto & output_getter = getCreators(name).output_processor_creator; const Settings & settings = context->getSettingsRef(); @@ -309,7 +313,10 @@ OutputFormatPtr FormatFactory::getOutputFormatParallelIfPossible( { const auto & output_getter = getCreators(name).output_processor_creator; if (!output_getter) - throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT); + throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name); + + if (context->getMySQLProtocolContext() && name != "MySQLWire") + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, "MySQL protocol does not support custom output formats"); auto format_settings = _format_settings ? *_format_settings : getFormatSettings(context); @@ -344,7 +351,7 @@ OutputFormatPtr FormatFactory::getOutputFormat( { const auto & output_getter = getCreators(name).output_processor_creator; if (!output_getter) - throw Exception("Format " + name + " is not suitable for output (with processors)", ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT); + throw Exception(ErrorCodes::FORMAT_IS_NOT_SUITABLE_FOR_OUTPUT, "Format {} is not suitable for output (with processors)", name); if (context->hasQueryContext() && context->getSettingsRef().log_queries) context->getQueryContext()->addQueryFactoriesInfo(Context::QueryLogFactories::Format, name); diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 9b204f12ab2..c597bc49c91 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -2335,11 +2335,6 @@ OutputFormatPtr Context::getOutputFormatParallelIfPossible(const String & name, return FormatFactory::instance().getOutputFormatParallelIfPossible(name, buf, sample, shared_from_this()); } -OutputFormatPtr Context::getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const -{ - return FormatFactory::instance().getOutputFormat(name, buf, sample, shared_from_this()); -} - time_t Context::getUptimeSeconds() const { @@ -2712,4 +2707,18 @@ PartUUIDsPtr Context::getIgnoredPartUUIDs() const return ignored_part_uuids; } +void Context::setMySQLProtocolContext(MySQLWireContext * mysql_context) +{ + assert(session_context.lock().get() == this); + assert(!mysql_protocol_context); + assert(mysql_context); + mysql_protocol_context = mysql_context; +} + +MySQLWireContext * Context::getMySQLProtocolContext() const +{ + assert((mysql_protocol_context == nullptr) == (session_context.lock().get() == nullptr)); + return mysql_protocol_context; +} + } diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 2b53c737915..9c14a50e0e1 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -119,6 +119,8 @@ using ThrottlerPtr = std::shared_ptr; class ZooKeeperMetadataTransaction; using ZooKeeperMetadataTransactionPtr = std::shared_ptr; +struct MySQLWireContext; + /// Callback for external tables initializer using ExternalTablesInitializer = std::function; @@ -298,6 +300,8 @@ private: /// thousands of signatures. /// And I hope it will be replaced with more common Transaction sometime. + MySQLWireContext * mysql_protocol_context = nullptr; + Context(); Context(const Context &); Context & operator=(const Context &); @@ -533,7 +537,6 @@ public: BlockOutputStreamPtr getOutputStream(const String & name, WriteBuffer & buf, const Block & sample) const; OutputFormatPtr getOutputFormatParallelIfPossible(const String & name, WriteBuffer & buf, const Block & sample) const; - OutputFormatPtr getOutputFormat(const String & name, WriteBuffer & buf, const Block & sample) const; InterserverIOHandler & getInterserverIOHandler(); @@ -789,14 +792,10 @@ public: /// Returns context of current distributed DDL query or nullptr. ZooKeeperMetadataTransactionPtr getZooKeeperMetadataTransaction() const; - struct MySQLWireContext - { - uint8_t sequence_id = 0; - uint32_t client_capabilities = 0; - size_t max_packet_size = 0; - }; - - MySQLWireContext mysql; + /// Caller is responsible for lifetime of mysql_context. + /// Used in MySQLHandler for session context. + void setMySQLProtocolContext(MySQLWireContext * mysql_context); + MySQLWireContext * getMySQLProtocolContext() const; PartUUIDsPtr getPartUUIDs() const; PartUUIDsPtr getIgnoredPartUUIDs() const; diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp index 0f73349c271..5f991dd0a3f 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp @@ -17,6 +17,22 @@ MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_, { } +void MySQLOutputFormat::setContext(ContextPtr context_) +{ + context = context_; + /// MySQlWire is a special format that is usually used as output format for MySQL protocol connections. + /// In this case we have to use the corresponding session context to set correct sequence_id. + mysql_context = getContext()->getMySQLProtocolContext(); + if (!mysql_context) + { + /// But it's also possible to specify MySQLWire as output format for clickhouse-client ot clickhouse-local. + /// There is no MySQL protocol context in this case, so we create dummy one. + own_mysql_context.emplace(); + mysql_context = &own_mysql_context.value(); + } + packet_endpoint = mysql_context->makeEndpoint(out); +} + void MySQLOutputFormat::initialize() { if (initialized) @@ -40,7 +56,7 @@ void MySQLOutputFormat::initialize() packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId())); } - if (!(getContext()->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) + if (!(mysql_context->client_capabilities & Capability::CLIENT_DEPRECATE_EOF)) { packet_endpoint->sendPacket(EOFPacket(0, 0)); } @@ -79,10 +95,10 @@ void MySQLOutputFormat::finalize() const auto & header = getPort(PortKind::Main).getHeader(); if (header.columns() == 0) packet_endpoint->sendPacket( - OKPacket(0x0, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); - else if (getContext()->mysql.client_capabilities & CLIENT_DEPRECATE_EOF) + OKPacket(0x0, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); + else if (mysql_context->client_capabilities & CLIENT_DEPRECATE_EOF) packet_endpoint->sendPacket( - OKPacket(0xfe, getContext()->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); + OKPacket(0xfe, mysql_context->client_capabilities, affected_rows, 0, 0, "", human_readable_info), true); else packet_endpoint->sendPacket(EOFPacket(0, 0), true); } diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.h b/src/Processors/Formats/Impl/MySQLOutputFormat.h index 7d67df3015e..fed2a431860 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.h +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.h @@ -25,11 +25,7 @@ public: String getName() const override { return "MySQLOutputFormat"; } - void setContext(ContextPtr context_) - { - context = context_; - packet_endpoint = std::make_unique(out, const_cast(getContext()->mysql.sequence_id)); /// TODO: fix it - } + void setContext(ContextPtr context_); void consume(Chunk) override; void finalize() override; @@ -41,7 +37,9 @@ public: private: bool initialized = false; - std::unique_ptr packet_endpoint; + std::optional own_mysql_context; + MySQLWireContext * mysql_context = nullptr; + MySQLProtocol::PacketEndpointPtr packet_endpoint; FormatSettings format_settings; DataTypes data_types; Serializations serializations; diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index beace5dd576..b8913f5e64f 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -95,10 +95,11 @@ void MySQLHandler::run() connection_context->getClientInfo().interface = ClientInfo::Interface::MYSQL; connection_context->setDefaultFormat("MySQLWire"); connection_context->getClientInfo().connection_id = connection_id; + connection_context->setMySQLProtocolContext(&connection_context_mysql); in = std::make_shared(socket()); out = std::make_shared(socket()); - packet_endpoint = std::make_shared(*in, *out, connection_context->mysql.sequence_id); + packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out); try { @@ -110,11 +111,11 @@ void MySQLHandler::run() HandshakeResponse handshake_response; finishHandshake(handshake_response); - connection_context->mysql.client_capabilities = handshake_response.capability_flags; + connection_context_mysql.client_capabilities = handshake_response.capability_flags; if (handshake_response.max_packet_size) - connection_context->mysql.max_packet_size = handshake_response.max_packet_size; - if (!connection_context->mysql.max_packet_size) - connection_context->mysql.max_packet_size = MAX_PACKET_LENGTH; + connection_context_mysql.max_packet_size = handshake_response.max_packet_size; + if (!connection_context_mysql.max_packet_size) + connection_context_mysql.max_packet_size = MAX_PACKET_LENGTH; LOG_TRACE(log, "Capabilities: {}, max_packet_size: {}, character_set: {}, user: {}, auth_response length: {}, database: {}, auth_plugin_name: {}", @@ -395,14 +396,14 @@ void MySQLHandlerSSL::finishHandshakeSSL( ReadBufferFromMemory payload(buf, pos); payload.ignore(PACKET_HEADER_SIZE); ssl_request.readPayloadWithUnpacked(payload); - connection_context->mysql.client_capabilities = ssl_request.capability_flags; - connection_context->mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH; + connection_context_mysql.client_capabilities = ssl_request.capability_flags; + connection_context_mysql.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH; secure_connection = true; ss = std::make_shared(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext())); in = std::make_shared(*ss); out = std::make_shared(*ss); - connection_context->mysql.sequence_id = 2; - packet_endpoint = std::make_shared(*in, *out, connection_context->mysql.sequence_id); + connection_context_mysql.sequence_id = 2; + packet_endpoint = connection_context_mysql.makeEndpoint(*in, *out); packet_endpoint->receivePacket(packet); /// Reading HandshakeResponse from secure socket. } diff --git a/src/Server/MySQLHandler.h b/src/Server/MySQLHandler.h index e681ad2e6f6..2ea5695a0a6 100644 --- a/src/Server/MySQLHandler.h +++ b/src/Server/MySQLHandler.h @@ -56,9 +56,10 @@ private: protected: Poco::Logger * log; + MySQLWireContext connection_context_mysql; ContextMutablePtr connection_context; - std::shared_ptr packet_endpoint; + MySQLProtocol::PacketEndpointPtr packet_endpoint; private: UInt64 connection_id = 0; diff --git a/tests/queries/0_stateless/01176_mysql_client_interactive.expect b/tests/queries/0_stateless/01176_mysql_client_interactive.expect index b2dc88a7795..2337b7d01fe 100755 --- a/tests/queries/0_stateless/01176_mysql_client_interactive.expect +++ b/tests/queries/0_stateless/01176_mysql_client_interactive.expect @@ -22,5 +22,27 @@ expect "| dummy |" expect "| 0 |" expect "1 row in set" +# exception before start +send -- "select * from table_that_does_not_exist;\r" +expect "ERROR 60 (00000): Code: 60" + +# exception after start +send -- "select throwIf(number) from numbers(2) settings max_block_size=1;\r" +expect "ERROR 395 (00000): Code: 395" + +# other formats +send -- "select * from system.one format TSV;\r" +expect "ERROR 1 (00000): Code: 1" + +send -- "select count(number), sum(number) from numbers(10);\r" +expect "+---------------+-------------+" +expect "| count(number) | sum(number) |" +expect "+---------------+-------------+" +expect "| 10 | 45 |" +expect "+---------------+-------------+" +expect "1 row in set" +expect "Read 10 rows, 80.00 B" +expect "mysql> " + send -- "quit;\r" expect eof From 8333a09258de2d6f74b5830ed53a184f74f8e4eb Mon Sep 17 00:00:00 2001 From: tavplubix Date: Thu, 8 Jul 2021 13:18:54 +0300 Subject: [PATCH 2/3] Update Context.cpp --- src/Interpreters/Context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index c597bc49c91..901c38791b9 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -2717,7 +2717,7 @@ void Context::setMySQLProtocolContext(MySQLWireContext * mysql_context) MySQLWireContext * Context::getMySQLProtocolContext() const { - assert((mysql_protocol_context == nullptr) == (session_context.lock().get() == nullptr)); + assert(!mysql_protocol_context || session_context.lock().get()); return mysql_protocol_context; } From dd0ad58deed709c63d54e3c7fda6d15b85c98bc2 Mon Sep 17 00:00:00 2001 From: tavplubix Date: Thu, 8 Jul 2021 13:20:06 +0300 Subject: [PATCH 3/3] Update MySQLOutputFormat.cpp --- src/Processors/Formats/Impl/MySQLOutputFormat.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp index 5f991dd0a3f..0f6d90b720e 100644 --- a/src/Processors/Formats/Impl/MySQLOutputFormat.cpp +++ b/src/Processors/Formats/Impl/MySQLOutputFormat.cpp @@ -25,7 +25,7 @@ void MySQLOutputFormat::setContext(ContextPtr context_) mysql_context = getContext()->getMySQLProtocolContext(); if (!mysql_context) { - /// But it's also possible to specify MySQLWire as output format for clickhouse-client ot clickhouse-local. + /// But it's also possible to specify MySQLWire as output format for clickhouse-client or clickhouse-local. /// There is no MySQL protocol context in this case, so we create dummy one. own_mysql_context.emplace(); mysql_context = &own_mysql_context.value();