diff --git a/programs/benchmark/Benchmark.cpp b/programs/benchmark/Benchmark.cpp index c5acd10f791..a5564f47784 100644 --- a/programs/benchmark/Benchmark.cpp +++ b/programs/benchmark/Benchmark.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,7 @@ namespace DB { using Ports = std::vector; +static constexpr std::string_view DEFAULT_CLIENT_NAME = "benchmark"; namespace ErrorCodes { @@ -122,7 +124,7 @@ public: default_database_, user_, password_, quota_key_, /* cluster_= */ "", /* cluster_secret_= */ "", - /* client_name_= */ "benchmark", + /* client_name_= */ std::string(DEFAULT_CLIENT_NAME), Protocol::Compression::Enable, secure)); @@ -135,6 +137,8 @@ public: global_context->makeGlobalContext(); global_context->setSettings(settings); + global_context->setClientName(std::string(DEFAULT_CLIENT_NAME)); + global_context->setQueryKindInitial(); std::cerr << std::fixed << std::setprecision(3); diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index e73f77819ad..929e59ed852 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -1243,6 +1243,7 @@ void Client::processConfig() global_context->getSettingsRef().max_insert_block_size); } + global_context->setClientName(std::string(DEFAULT_CLIENT_NAME)); global_context->setQueryKindInitial(); global_context->setQuotaClientKey(config().getString("quota_key", "")); global_context->setQueryKind(query_kind); diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index bd17318d1df..d877905302d 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -1,5 +1,6 @@ #pragma once +#include #include "Common/NamePrompter.h" #include #include @@ -24,6 +25,7 @@ namespace po = boost::program_options; namespace DB { +static constexpr std::string_view DEFAULT_CLIENT_NAME = "client"; static const NameSet exit_strings { diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index 3e12e60be08..859afb5ea44 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -1204,7 +1205,7 @@ ServerConnectionPtr Connection::createConnection(const ConnectionParameters & pa parameters.quota_key, "", /* cluster */ "", /* cluster_secret */ - "client", + std::string(DEFAULT_CLIENT_NAME), parameters.compression, parameters.security); } diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index a3277821111..f23685c37d1 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -584,6 +584,7 @@ M(699, INVALID_REDIS_TABLE_STRUCTURE) \ M(700, USER_SESSION_LIMIT_EXCEEDED) \ M(701, CLUSTER_DOESNT_EXIST) \ + M(702, CLIENT_INFO_DOES_NOT_MATCH) \ \ M(999, KEEPER_EXCEPTION) \ M(1000, POCO_EXCEPTION) \ diff --git a/src/Interpreters/ClientInfo.cpp b/src/Interpreters/ClientInfo.cpp index 6c09b327ca1..d007341a1ac 100644 --- a/src/Interpreters/ClientInfo.cpp +++ b/src/Interpreters/ClientInfo.cpp @@ -9,6 +9,7 @@ #include "config_version.h" +#include namespace DB { @@ -18,7 +19,6 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } - void ClientInfo::write(WriteBuffer & out, UInt64 server_protocol_revision) const { if (server_protocol_revision < DBMS_MIN_REVISION_WITH_CLIENT_INFO) @@ -199,6 +199,20 @@ void ClientInfo::setInitialQuery() client_name = (VERSION_NAME " ") + client_name; } +bool ClientInfo::clientVersionEquals(const ClientInfo & other, bool compare_patch) const +{ + bool patch_equals = compare_patch ? client_version_patch == other.client_version_patch : true; + return client_version_major == other.client_version_major && + client_version_minor == other.client_version_minor && + patch_equals && + client_tcp_protocol_version == other.client_tcp_protocol_version; +} + +String ClientInfo::getVersionStr() const +{ + return std::format("{}.{}.{} ({})", client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); +} + void ClientInfo::fillOSUserHostNameAndVersionInfo() { @@ -216,5 +230,27 @@ void ClientInfo::fillOSUserHostNameAndVersionInfo() client_tcp_protocol_version = DBMS_TCP_PROTOCOL_VERSION; } +String toString(ClientInfo::Interface interface) +{ + switch (interface) + { + case ClientInfo::Interface::TCP: + return "TCP"; + case ClientInfo::Interface::HTTP: + return "HTTP"; + case ClientInfo::Interface::GRPC: + return "GRPC"; + case ClientInfo::Interface::MYSQL: + return "MYSQL"; + case ClientInfo::Interface::POSTGRESQL: + return "POSTGRESQL"; + case ClientInfo::Interface::LOCAL: + return "LOCAL"; + case ClientInfo::Interface::TCP_INTERSERVER: + return "TCP_INTERSERVER"; + } + + return std::format("Unknown {}!\n", static_cast(interface)); +} } diff --git a/src/Interpreters/ClientInfo.h b/src/Interpreters/ClientInfo.h index 5c5a284d63b..798fc95954c 100644 --- a/src/Interpreters/ClientInfo.h +++ b/src/Interpreters/ClientInfo.h @@ -48,7 +48,6 @@ public: SECONDARY_QUERY = 2, /// Query that was initiated by another query for distributed or ON CLUSTER query execution. }; - QueryKind query_kind = QueryKind::NO_QUERY; /// Current values are not serialized, because it is passed separately. @@ -135,8 +134,14 @@ public: /// Initialize parameters on client initiating query. void setInitialQuery(); + bool clientVersionEquals(const ClientInfo & other, bool compare_patch) const; + + String getVersionStr() const; + private: void fillOSUserHostNameAndVersionInfo(); }; +String toString(ClientInfo::Interface interface); + } diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index e0b5db44593..439bf6056ba 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -302,7 +302,6 @@ Session::~Session() LOG_DEBUG(log, "{} Logout, user_id: {}", toString(auth_id), toString(*user_id)); if (auto session_log = getSessionLog()) { - /// TODO: We have to ensure that the same info is added to the session log on a LoginSuccess event and on the corresponding Logout event. session_log->addLogOut(auth_id, user, getClientInfo()); } } diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index ac3928b4abe..983d88b13fc 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -83,6 +83,22 @@ namespace ProfileEvents extern const Event MergeTreeAllRangesAnnouncementsSentElapsedMicroseconds; } +namespace DB::ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int ATTEMPT_TO_READ_AFTER_EOF; + extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT; + extern const int UNKNOWN_EXCEPTION; + extern const int UNKNOWN_PACKET_FROM_CLIENT; + extern const int POCO_EXCEPTION; + extern const int SOCKET_TIMEOUT; + extern const int UNEXPECTED_PACKET_FROM_CLIENT; + extern const int UNKNOWN_PROTOCOL; + extern const int AUTHENTICATION_FAILED; + extern const int QUERY_WAS_CANCELLED; + extern const int CLIENT_INFO_DOES_NOT_MATCH; +} + namespace { NameToNameMap convertToQueryParameters(const Settings & passed_params) @@ -98,26 +114,56 @@ NameToNameMap convertToQueryParameters(const Settings & passed_params) return query_parameters; } +void validateClientInfo(const ClientInfo & session_client_info, const ClientInfo & client_info) +{ + // Secondary query may contain different client_info. + // In the case of select from distributed table or 'select * from remote' from non-tcp handler. Server sends the initial client_info data. + // + // Example 1: curl -q -s --max-time 60 -sS "http://127.0.0.1:8123/?" -d "SELECT 1 FROM remote('127.0.0.1', system.one)" + // HTTP handler initiates TCP connection with remote 127.0.0.1 (session on remote 127.0.0.1 use TCP interface) + // HTTP handler sends client_info with HTTP interface and HTTP data by TCP protocol in Protocol::Client::Query message. + // + // Example 2: select * from --host shard_1 // distributed table has 2 shards: shard_1, shard_2 + // shard_1 receives a message with 'ClickHouse client' client_name + // shard_1 initiates TCP connection with shard_2 with 'ClickHouse server' client_name. + // shard_1 sends 'ClickHouse client' client_name in Protocol::Client::Query message to shard_2. + if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) + return; + + if (session_client_info.interface != client_info.interface) + { + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's interface does not match: {} not equal to {}", + toString(session_client_info.interface), + toString(client_info.interface)); + } + + if (session_client_info.interface == ClientInfo::Interface::TCP) + { + if (session_client_info.client_name != client_info.client_name) + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's client_name does not match: {} not equal to {}", + session_client_info.client_name, + client_info.client_name); + + // TCP handler got patch version 0 always for backward compatibility. + if (!session_client_info.clientVersionEquals(client_info, false)) + throw Exception( + DB::ErrorCodes::CLIENT_INFO_DOES_NOT_MATCH, + "Client info's version does not match: {} not equal to {}", + session_client_info.getVersionStr(), + client_info.getVersionStr()); + + // os_user, quota_key, client_trace_context can be different. + } +} } namespace DB { -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; - extern const int ATTEMPT_TO_READ_AFTER_EOF; - extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT; - extern const int UNKNOWN_EXCEPTION; - extern const int UNKNOWN_PACKET_FROM_CLIENT; - extern const int POCO_EXCEPTION; - extern const int SOCKET_TIMEOUT; - extern const int UNEXPECTED_PACKET_FROM_CLIENT; - extern const int UNKNOWN_PROTOCOL; - extern const int AUTHENTICATION_FAILED; - extern const int QUERY_WAS_CANCELLED; -} - TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_) : Poco::Net::TCPServerConnection(socket_) , server(server_) @@ -1484,7 +1530,10 @@ void TCPHandler::receiveQuery() /// Read client info. ClientInfo client_info = session->getClientInfo(); if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) + { client_info.read(*in, client_tcp_protocol_version); + validateClientInfo(session->getClientInfo(), client_info); + } /// Per query settings are also passed via TCP. /// We need to check them before applying due to they can violate the settings constraints. diff --git a/tests/queries/0_stateless/01601_proxy_protocol.sh b/tests/queries/0_stateless/01601_proxy_protocol.sh index 5f4ec6cc597..c8ee3ad1f7b 100755 --- a/tests/queries/0_stateless/01601_proxy_protocol.sh +++ b/tests/queries/0_stateless/01601_proxy_protocol.sh @@ -6,4 +6,4 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) # shellcheck source=../shell_config.sh . "$CURDIR"/../shell_config.sh -printf "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n\0\21ClickHouse client\24\r\253\251\3\0\7default\0\4\1\0\1\0\0\t0.0.0.0:0\1\tmilovidov\21milovidov-desktop\vClickHouse \24\r\253\251\3\0\1\0\0\0\2\1\25SELECT 'Hello, world'\2\0\247\203\254l\325\\z|\265\254F\275\333\206\342\24\202\24\0\0\0\n\0\0\0\240\1\0\2\377\377\377\377\0\0\0" | nc "${CLICKHOUSE_HOST}" "${CLICKHOUSE_PORT_TCP_WITH_PROXY}" | head -c150 | grep --text -o -F 'Hello, world' +printf "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n\0\21ClickHouse client\24\r\253\251\3\0\7default\0\4\1\0\1\0\0\t0.0.0.0:0\1\tmilovidov\21milovidov-desktop\21ClickHouse client\24\r\253\251\3\0\1\0\0\0\2\1\25SELECT 'Hello, world'\2\0\247\203\254l\325\\z|\265\254F\275\333\206\342\24\202\24\0\0\0\n\0\0\0\240\1\0\2\377\377\377\377\0\0\0" | nc "${CLICKHOUSE_HOST}" "${CLICKHOUSE_PORT_TCP_WITH_PROXY}" | head -c150 | grep --text -o -F 'Hello, world' diff --git a/tests/queries/0_stateless/02010_lc_native.python b/tests/queries/0_stateless/02010_lc_native.python index a197d32a3b9..6c4220855c8 100755 --- a/tests/queries/0_stateless/02010_lc_native.python +++ b/tests/queries/0_stateless/02010_lc_native.python @@ -8,6 +8,7 @@ import uuid CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1") CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "900000")) CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default") +CLIENT_NAME = "simple native protocol" def writeVarUInt(x, ba): @@ -86,7 +87,7 @@ def readStringBinary(s): def sendHello(s): ba = bytearray() writeVarUInt(0, ba) # Hello - writeStringBinary("simple native protocol", ba) + writeStringBinary(CLIENT_NAME, ba) writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) @@ -123,7 +124,7 @@ def serializeClientInfo(ba, query_id): ba.append(1) # TCP writeStringBinary("os_user", ba) # os_user writeStringBinary("client_hostname", ba) # client_hostname - writeStringBinary("client_name", ba) # client_name + writeStringBinary(CLIENT_NAME, ba) # client_name writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) diff --git a/tests/queries/0_stateless/02270_client_name.reference b/tests/queries/0_stateless/02270_client_name.reference index fbb2921010e..8d1f2863fad 100644 --- a/tests/queries/0_stateless/02270_client_name.reference +++ b/tests/queries/0_stateless/02270_client_name.reference @@ -1 +1 @@ -"ClickHouse" +"ClickHouse client" diff --git a/tests/queries/0_stateless/02458_insert_select_progress_tcp.python b/tests/queries/0_stateless/02458_insert_select_progress_tcp.python index 696eb01ff7e..92240e109c1 100644 --- a/tests/queries/0_stateless/02458_insert_select_progress_tcp.python +++ b/tests/queries/0_stateless/02458_insert_select_progress_tcp.python @@ -8,6 +8,7 @@ import json CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1") CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "900000")) CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default") +CLIENT_NAME = "simple native protocol" def writeVarUInt(x, ba): @@ -86,7 +87,7 @@ def readStringBinary(s): def sendHello(s): ba = bytearray() writeVarUInt(0, ba) # Hello - writeStringBinary("simple native protocol", ba) + writeStringBinary(CLIENT_NAME, ba) writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) @@ -123,7 +124,7 @@ def serializeClientInfo(ba, query_id): ba.append(1) # TCP writeStringBinary("os_user", ba) # os_user writeStringBinary("client_hostname", ba) # client_hostname - writeStringBinary("client_name", ba) # client_name + writeStringBinary(CLIENT_NAME, ba) # client_name writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) diff --git a/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python index 768fb2144e3..48b27d434ec 100644 --- a/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python +++ b/tests/queries/0_stateless/02750_settings_alias_tcp_protocol.python @@ -8,6 +8,7 @@ import json CLICKHOUSE_HOST = os.environ.get("CLICKHOUSE_HOST", "127.0.0.1") CLICKHOUSE_PORT = int(os.environ.get("CLICKHOUSE_PORT_TCP", "900000")) CLICKHOUSE_DATABASE = os.environ.get("CLICKHOUSE_DATABASE", "default") +CLIENT_NAME = "simple native protocol" def writeVarUInt(x, ba): @@ -86,7 +87,7 @@ def readStringBinary(s): def sendHello(s): ba = bytearray() writeVarUInt(0, ba) # Hello - writeStringBinary("simple native protocol", ba) + writeStringBinary(CLIENT_NAME, ba) writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) @@ -116,7 +117,7 @@ def serializeClientInfo(ba, query_id): ba.append(1) # TCP writeStringBinary("os_user", ba) # os_user writeStringBinary("client_hostname", ba) # client_hostname - writeStringBinary("client_name", ba) # client_name + writeStringBinary(CLIENT_NAME, ba) # client_name writeVarUInt(21, ba) writeVarUInt(9, ba) writeVarUInt(54449, ba) diff --git a/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.reference b/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.reference new file mode 100644 index 00000000000..1f966c6731b --- /dev/null +++ b/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.reference @@ -0,0 +1,2 @@ +client_name does not match +version does not match diff --git a/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.sh b/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.sh new file mode 100755 index 00000000000..fbbb7d11ec0 --- /dev/null +++ b/tests/queries/0_stateless/02865_tcp_proxy_query_packet_validation.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Tags: no-fasttest +# Tag no-fasttest: nc - command not found + +CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CURDIR"/../shell_config.sh + +printf "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n\0\21ClickHouse client\24\r\253\251\3\0\7default\0\4\1\0\1\0\0\t0.0.0.0:0\1\6hacker\16hacker-desktop\15Hacker client\24\r\253\251\3\0\1\0\0\0\2\1\25SELECT 'Hello, world'\2\0\247\203\254l\325\\z|\265\254F\275\333\206\342\24\202\24\0\0\0\n\0\0\0\240\1\0\2\377\377\377\377\0\0\0" | nc "${CLICKHOUSE_HOST}" "${CLICKHOUSE_PORT_TCP_WITH_PROXY}" | head -c250 | grep --text -o -F 'client_name does not match' +printf "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n\0\21ClickHouse client\24\r\253\251\3\0\7default\0\4\1\0\1\0\0\t0.0.0.0:0\1\6hacker\16hacker-desktop\21ClickHouse client\20\r\253\251\3\0\1\0\0\0\2\1\25SELECT 'Hello, world'\2\0\247\203\254l\325\\z|\265\254F\275\333\206\342\24\202\24\0\0\0\n\0\0\0\240\1\0\2\377\377\377\377\0\0\0" | nc "${CLICKHOUSE_HOST}" "${CLICKHOUSE_PORT_TCP_WITH_PROXY}" | head -c250 | grep --text -o -F 'version does not match'