From fabd7193bd687ee4b10ca826303399ff35e3d3dd Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sun, 1 Aug 2021 17:12:34 +0300 Subject: [PATCH] Code cleanups and improvements. --- programs/local/LocalServer.cpp | 9 +- programs/server/Server.cpp | 4 +- src/Access/ContextAccess.h | 1 + src/Access/Credentials.h | 2 + src/Bridge/IBridgeHelper.cpp | 1 + src/Core/MySQL/Authentication.cpp | 16 +- src/Core/MySQL/Authentication.h | 7 +- src/Core/PostgreSQLProtocol.h | 25 +- .../ClickHouseDictionarySource.cpp | 2 +- src/IO/ReadBufferFromFileDescriptor.cpp | 1 + src/Interpreters/Context.cpp | 66 +-- src/Interpreters/Context.h | 29 +- src/Interpreters/Session.cpp | 316 ++++++------ src/Interpreters/Session.h | 89 ++-- src/Server/GRPCServer.cpp | 44 +- src/Server/HTTPHandler.cpp | 68 +-- src/Server/HTTPHandler.h | 16 +- src/Server/MySQLHandler.cpp | 52 +- src/Server/MySQLHandler.h | 2 +- src/Server/PostgreSQLHandler.cpp | 38 +- src/Server/PostgreSQLHandler.h | 7 +- src/Server/TCPHandler.cpp | 479 +++++++++--------- src/Server/TCPHandler.h | 39 +- src/TableFunctions/TableFunctionMySQL.cpp | 3 +- .../test.py | 1 + .../01455_opentelemetry_distributed.reference | 12 + .../01455_opentelemetry_distributed.sh | 30 +- 27 files changed, 677 insertions(+), 682 deletions(-) diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 7f1bbe77d9c..44e9880fabb 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include #include @@ -377,11 +376,13 @@ void LocalServer::processQueries() /// we can't mutate global global_context (can lead to races, as it was already passed to some background threads) /// so we can't reuse it safely as a query context and need a copy here - Session session(global_context, ClientInfo::Interface::TCP); - session.setUser("default", "", Poco::Net::SocketAddress{}); + auto context = Context::createCopy(global_context); - auto context = 
session.makeQueryContext(""); + context->makeSessionContext(); + context->makeQueryContext(); + context->authenticate("default", "", Poco::Net::SocketAddress{}); + context->setCurrentQueryId(""); applyCmdSettings(context); /// Use the same query_id (and thread group) for all queries diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index 98c63f9896a..c30ef52f46a 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -47,13 +47,13 @@ #include #include #include -#include #include #include #include #include #include #include +#include #include #include #include @@ -1429,7 +1429,7 @@ if (ThreadFuzzer::instance().isEffective()) /// Must be done after initialization of `servers`, because async_metrics will access `servers` variable from its thread. async_metrics.start(); - Session::enableNamedSessions(); + Session::startupNamedSessions(); { String level_str = config().getString("text_log.level", ""); diff --git a/src/Access/ContextAccess.h b/src/Access/ContextAccess.h index 70145b0a3ef..cde69471800 100644 --- a/src/Access/ContextAccess.h +++ b/src/Access/ContextAccess.h @@ -70,6 +70,7 @@ public: /// Returns the current user. The function can return nullptr. UserPtr getUser() const; String getUserName() const; + std::optional getUserID() const { return getParams().user_id; } /// Returns information about current and enabled roles. std::shared_ptr getRolesInfo() const; diff --git a/src/Access/Credentials.h b/src/Access/Credentials.h index 5e9fd1589e0..256ed3853ab 100644 --- a/src/Access/Credentials.h +++ b/src/Access/Credentials.h @@ -26,6 +26,8 @@ protected: String user_name; }; +/// Does not check the password/credentials and that the specified host is allowed. 
+/// (Used only internally in cluster, if the secret matches) class AlwaysAllowCredentials : public Credentials { diff --git a/src/Bridge/IBridgeHelper.cpp b/src/Bridge/IBridgeHelper.cpp index 5c884a2ca3d..984641be3d2 100644 --- a/src/Bridge/IBridgeHelper.cpp +++ b/src/Bridge/IBridgeHelper.cpp @@ -5,6 +5,7 @@ #include #include #include +#include namespace fs = std::filesystem; diff --git a/src/Core/MySQL/Authentication.cpp b/src/Core/MySQL/Authentication.cpp index bc34b5637d6..aeb9a411082 100644 --- a/src/Core/MySQL/Authentication.cpp +++ b/src/Core/MySQL/Authentication.cpp @@ -2,8 +2,6 @@ #include #include #include -#include -#include #include #include @@ -74,7 +72,7 @@ Native41::Native41(const String & password, const String & auth_plugin_data) } void Native41::authenticate( - const String & user_name, std::optional auth_response, Session & session, + const String & user_name, Session & session, std::optional auth_response, std::shared_ptr packet_endpoint, bool, const Poco::Net::SocketAddress & address) { if (!auth_response) @@ -87,7 +85,7 @@ void Native41::authenticate( if (auth_response->empty()) { - session.setUser(user_name, "", address); + session.authenticate(user_name, "", address); return; } @@ -97,9 +95,7 @@ void Native41::authenticate( + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.", ErrorCodes::UNKNOWN_EXCEPTION); - const auto user_authentication = session.getUserAuthentication(user_name); - - Poco::SHA1Engine::Digest double_sha1_value = user_authentication.getPasswordDoubleSHA1(); + Poco::SHA1Engine::Digest double_sha1_value = session.getPasswordDoubleSHA1(user_name); assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE); Poco::SHA1Engine engine; @@ -112,7 +108,7 @@ void Native41::authenticate( { password_sha1[i] = digest[i] ^ static_cast((*auth_response)[i]); } - session.setUser(user_name, password_sha1, address); + session.authenticate(user_name, password_sha1, address); } #if USE_SSL @@ -137,7 +133,7 @@ 
Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logg } void Sha256Password::authenticate( - const String & user_name, std::optional auth_response, Session & session, + const String & user_name, Session & session, std::optional auth_response, std::shared_ptr packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) { if (!auth_response) @@ -232,7 +228,7 @@ void Sha256Password::authenticate( password.pop_back(); } - session.setUser(user_name, password, address); + session.authenticate(user_name, password, address); } #endif diff --git a/src/Core/MySQL/Authentication.h b/src/Core/MySQL/Authentication.h index 0dde8d10c0e..a60e769434e 100644 --- a/src/Core/MySQL/Authentication.h +++ b/src/Core/MySQL/Authentication.h @@ -15,6 +15,7 @@ namespace DB { +class Session; namespace MySQLProtocol { @@ -32,7 +33,7 @@ public: virtual String getAuthPluginData() = 0; virtual void authenticate( - const String & user_name, std::optional auth_response, Session & session, + const String & user_name, Session & session, std::optional auth_response, std::shared_ptr packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0; }; @@ -49,7 +50,7 @@ public: String getAuthPluginData() override { return scramble; } void authenticate( - const String & user_name, std::optional auth_response, Session & session, + const String & user_name, Session & session, std::optional auth_response, std::shared_ptr packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override; private: @@ -69,7 +70,7 @@ public: String getAuthPluginData() override { return scramble; } void authenticate( - const String & user_name, std::optional auth_response, Session & session, + const String & user_name, Session & session, std::optional auth_response, std::shared_ptr packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override; private: diff --git 
a/src/Core/PostgreSQLProtocol.h b/src/Core/PostgreSQLProtocol.h index 19bcc727105..aef0ed6ab25 100644 --- a/src/Core/PostgreSQLProtocol.h +++ b/src/Core/PostgreSQLProtocol.h @@ -1,14 +1,11 @@ #pragma once -#include -#include #include -#include -#include #include #include #include #include +#include #include #include #include @@ -808,8 +805,9 @@ protected: Messaging::MessageTransport & mt, const Poco::Net::SocketAddress & address) { - try { - session.setUser(user_name, password, address); + try + { + session.authenticate(user_name, password, address); } catch (const Exception &) { @@ -841,7 +839,7 @@ public: Messaging::MessageTransport & mt, const Poco::Net::SocketAddress & address) override { - setPassword(user_name, "", session, mt, address); + return setPassword(user_name, "", session, mt, address); } Authentication::Type getType() const override @@ -865,7 +863,7 @@ public: if (type == Messaging::FrontMessageType::PASSWORD_MESSAGE) { std::unique_ptr password = mt.receive(); - setPassword(user_name, password->password, session, mt, address); + return setPassword(user_name, password->password, session, mt, address); } else throw Exception( @@ -902,16 +900,7 @@ public: Messaging::MessageTransport & mt, const Poco::Net::SocketAddress & address) { - Authentication::Type user_auth_type; - try - { - user_auth_type = session.getUserAuthentication(user_name).getType(); - } - catch (const std::exception & e) - { - session.onLogInFailure(user_name, e); - throw; - } + Authentication::Type user_auth_type = session.getAuthenticationType(user_name); if (type_to_method.find(user_auth_type) != type_to_method.end()) { diff --git a/src/Dictionaries/ClickHouseDictionarySource.cpp b/src/Dictionaries/ClickHouseDictionarySource.cpp index 0f085a7c1a2..d4f01dee8b2 100644 --- a/src/Dictionaries/ClickHouseDictionarySource.cpp +++ b/src/Dictionaries/ClickHouseDictionarySource.cpp @@ -255,7 +255,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory) /// We should set 
user info even for the case when the dictionary is loaded in-process (without TCP communication). if (configuration.is_local) { - context_copy->setUser(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0)); + context_copy->authenticate(configuration.user, configuration.password, Poco::Net::SocketAddress("127.0.0.1", 0)); context_copy = copyContextAndApplySettings(config_prefix, context_copy, config); } diff --git a/src/IO/ReadBufferFromFileDescriptor.cpp b/src/IO/ReadBufferFromFileDescriptor.cpp index fdb538d4a49..e60ec335ca1 100644 --- a/src/IO/ReadBufferFromFileDescriptor.cpp +++ b/src/IO/ReadBufferFromFileDescriptor.cpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace ProfileEvents diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 601127c99b5..4d918d0fbb6 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -588,48 +588,45 @@ ConfigurationPtr Context::getUsersConfig() } -void Context::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address) +void Context::authenticate(const String & name, const String & password, const Poco::Net::SocketAddress & address) { - auto lock = getLock(); + authenticate(BasicCredentials(name, password), address); +} + +void Context::authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address) +{ + auto authenticated_user_id = getAccessControlManager().login(credentials, address.host()); client_info.current_user = credentials.getUserName(); client_info.current_address = address; #if defined(ARCADIA_BUILD) /// This is harmful field that is used only in foreign "Arcadia" build. - client_info.current_password.clear(); if (const auto * basic_credentials = dynamic_cast(&credentials)) client_info.current_password = basic_credentials->getPassword(); #endif - /// Find a user with such name and check the credentials. 
- auto new_user_id = getAccessControlManager().login(credentials, address.host()); - auto new_access = getAccessControlManager().getContextAccess( - new_user_id, /* current_roles = */ {}, /* use_default_roles = */ true, - settings, current_database, client_info); + setUser(authenticated_user_id); +} - user_id = new_user_id; - access = std::move(new_access); +void Context::setUser(const UUID & user_id_) +{ + auto lock = getLock(); + + user_id = user_id_; + + access = getAccessControlManager().getContextAccess( + user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info); auto user = access->getUser(); current_roles = std::make_shared>(user->granted_roles.findGranted(user->default_roles)); - if (!user->default_database.empty()) - setCurrentDatabase(user->default_database); - auto default_profile_info = access->getDefaultProfileInfo(); settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs(); applySettingsChanges(default_profile_info->settings); -} -void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address) -{ - setUser(BasicCredentials(name, password), address); -} - -void Context::setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address) -{ - setUser(AlwaysAllowCredentials(name), address); + if (!user->default_database.empty()) + setCurrentDatabase(user->default_database); } std::shared_ptr Context::getUser() const @@ -637,12 +634,6 @@ std::shared_ptr Context::getUser() const return getAccess()->getUser(); } -void Context::setQuotaKey(String quota_key_) -{ - auto lock = getLock(); - client_info.quota_key = std::move(quota_key_); -} - String Context::getUserName() const { return getAccess()->getUserName(); @@ -655,6 +646,13 @@ std::optional Context::getUserID() const } +void Context::setQuotaKey(String quota_key_) +{ + auto lock = getLock(); + client_info.quota_key = std::move(quota_key_); +} + 
+ void Context::setCurrentRoles(const std::vector & current_roles_) { auto lock = getLock(); @@ -736,10 +734,13 @@ ASTPtr Context::getRowPolicyCondition(const String & database, const String & ta void Context::setInitialRowPolicy() { auto lock = getLock(); - auto initial_user_id = getAccessControlManager().find(client_info.initial_user); initial_row_policy = nullptr; - if (initial_user_id) - initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}); + if (client_info.initial_user == client_info.current_user) + return; + auto initial_user_id = getAccessControlManager().find(client_info.initial_user); + if (!initial_user_id) + return; + initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}); } @@ -1180,6 +1181,9 @@ void Context::setCurrentQueryId(const String & query_id) } client_info.current_query_id = query_id_to_set; + + if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + client_info.initial_query_id = client_info.current_query_id; } void Context::killCurrentQuery() diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 0bb32bb7b43..4e378dacf01 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -14,21 +14,16 @@ #include #include #include -#include #include #if !defined(ARCADIA_BUILD) # include "config_core.h" #endif -#include -#include -#include #include #include #include #include -#include namespace Poco::Net { class IPAddress; } @@ -67,6 +62,7 @@ class ProcessList; class QueryStatus; class Macros; struct Progress; +struct FileProgress; class Clusters; class QueryLog; class QueryThreadLog; @@ -366,23 +362,21 @@ public: void setUsersConfig(const ConfigurationPtr & config); ConfigurationPtr getUsersConfig(); - /// Sets the current user, checks the credentials and that the specified host is allowed. - /// Must be called before getClientInfo() can be called. 
- void setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address); - void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address); + /// Sets the current user, checks the credentials and that the specified address is allowed to connect from. + /// The function throws an exception if there is no such user or the password is wrong. + void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address); + void authenticate(const Credentials & credentials, const Poco::Net::SocketAddress & address); - /// Sets the current user, *does not check the password/credentials and that the specified host is allowed*. - /// Must be called before getClientInfo. - /// - /// (Used only internally in cluster, if the secret matches) - void setUserWithoutCheckingPassword(const String & name, const Poco::Net::SocketAddress & address); - - void setQuotaKey(String quota_key_); + /// Sets the current user assuming that he/she is already authenticated. + /// WARNING: This function doesn't check password! Don't use unless it's necessary! + void setUser(const UUID & user_id_); UserPtr getUser() const; String getUserName() const; std::optional getUserID() const; + void setQuotaKey(String quota_key_); + void setCurrentRoles(const std::vector & current_roles_); void setCurrentRolesDefault(); boost::container::flat_set getCurrentRoles() const; @@ -590,8 +584,6 @@ public: std::optional getTCPPortSecure() const; - std::shared_ptr acquireNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check); - /// For methods below you may need to acquire the context lock by yourself.
ContextMutablePtr getQueryContext() const; @@ -602,7 +594,6 @@ public: bool hasSessionContext() const { return !session_context.expired(); } ContextMutablePtr getGlobalContext() const; - bool hasGlobalContext() const { return !global_context.expired(); } bool isGlobalContext() const { diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index acebc182a64..7334f2e7640 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -1,24 +1,22 @@ #include #include -#include #include +#include #include #include #include #include #include -#include - -#include -#include -#include #include -#include +#include #include +#include +#include #include + namespace DB { @@ -27,13 +25,13 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int SESSION_NOT_FOUND; extern const int SESSION_IS_LOCKED; - extern const int NOT_IMPLEMENTED; } + class NamedSessionsStorage; -/// User name and session identifier. Named sessions are local to users. -using NamedSessionKey = std::pair; +/// User ID and session identifier. Named sessions are local to users. +using NamedSessionKey = std::pair; /// Named sessions. The user could specify session identifier to reuse settings and temporary tables in subsequent requests. struct NamedSessionData @@ -75,21 +73,16 @@ public: } /// Find existing session or create a new. 
- std::shared_ptr acquireSession( + std::pair, bool> acquireSession( + const ContextPtr & global_context, + const UUID & user_id, const String & session_id, - ContextMutablePtr context, std::chrono::steady_clock::duration timeout, bool throw_if_not_found) { std::unique_lock lock(mutex); - const auto & client_info = context->getClientInfo(); - const auto & user_name = client_info.current_user; - - if (user_name.empty()) - throw Exception("Empty user name.", ErrorCodes::LOGICAL_ERROR); - - Key key(user_name, session_id); + Key key{user_id, session_id}; auto it = sessions.find(key); if (it == sessions.end()) @@ -98,22 +91,20 @@ public: throw Exception("Session not found.", ErrorCodes::SESSION_NOT_FOUND); /// Create a new session from current context. + auto context = Context::createCopy(global_context); it = sessions.insert(std::make_pair(key, std::make_shared(key, context, timeout, *this))).first; + const auto & session = it->second; + return {session, true}; } - else if (it->second->key.first != client_info.current_user) + else { - throw Exception("Session belongs to a different user", ErrorCodes::SESSION_IS_LOCKED); + /// Use existing session. + const auto & session = it->second; + + if (!session.unique()) + throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED); + return {session, false}; } - - /// Use existing session. 
- const auto & session = it->second; - - if (!session.unique()) - throw Exception("Session is locked by a concurrent client.", ErrorCodes::SESSION_IS_LOCKED); - - session->context->getClientInfo() = client_info; - - return session; } void releaseSession(NamedSessionData & session) @@ -229,164 +220,195 @@ void NamedSessionData::release() std::optional Session::named_sessions = std::nullopt; -void Session::enableNamedSessions() +void Session::startupNamedSessions() { named_sessions.emplace(); } -Session::Session(const ContextPtr & context_to_copy, ClientInfo::Interface interface, std::optional default_format) - : session_context(Context::createCopy(context_to_copy)), - initial_session_context(session_context) +Session::Session(const ContextPtr & global_context_, ClientInfo::Interface interface_) + : global_context(global_context_) { - session_context->makeSessionContext(); - session_context->getClientInfo().interface = interface; - - if (default_format) - session_context->setDefaultFormat(*default_format); + prepared_client_info.emplace(); + prepared_client_info->interface = interface_; } Session::Session(Session &&) = default; Session::~Session() { - releaseNamedSession(); - - if (access) - { - auto user = access->getUser(); - if (user) - onLogOut(); - } -} - -Authentication Session::getUserAuthentication(const String & user_name) const -{ - return session_context->getAccessControlManager().read(user_name)->authentication; -} - -void Session::setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address) -{ - try - { - session_context->setUser(credentials, address); - - // Caching access just in case if context is going to be replaced later (e.g. with context of NamedSessionData) - access = session_context->getAccess(); - - // Check if this is a not an intercluster session, but the real one. 
- if (access && access->getUser() && dynamic_cast(&credentials)) - { - onLogInSuccess(); - } - } - catch (const std::exception & e) - { - onLogInFailure(credentials.getUserName(), e); - throw; - } -} - -void Session::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address) -{ - setUser(BasicCredentials(name, password), address); -} - -void Session::onLogInSuccess() -{ -} - -void Session::onLogInFailure(const String & /* user_name */, const std::exception & /* failure_reason */) -{ -} - -void Session::onLogOut() -{ -} - -void Session::promoteToNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check) -{ - if (!named_sessions) - throw Exception("Support for named sessions is not enabled", ErrorCodes::NOT_IMPLEMENTED); - - auto new_named_session = named_sessions->acquireSession(session_id, session_context, timeout, session_check); - - // Must retain previous client info cause otherwise source client address and port, - // and other stuff are reused from previous user of the said session. - const ClientInfo prev_client_info = session_context->getClientInfo(); - - session_context = new_named_session->context; - session_context->getClientInfo() = prev_client_info; - session_context->makeSessionContext(); - - named_session.swap(new_named_session); -} - -/// Early release a NamedSessionData. -void Session::releaseNamedSession() -{ + /// Early release a NamedSessionData. 
if (named_session) - { named_session->release(); - named_session.reset(); - } - - session_context = initial_session_context; } -ContextMutablePtr Session::makeQueryContext(const String & query_id) const +Authentication::Type Session::getAuthenticationType(const String & user_name) const { - ContextMutablePtr new_query_context = Context::createCopy(session_context); - - new_query_context->setCurrentQueryId(query_id); - new_query_context->setSessionContext(session_context); - new_query_context->makeQueryContext(); - - ClientInfo & client_info = new_query_context->getClientInfo(); - client_info.initial_user = client_info.current_user; - client_info.initial_query_id = client_info.current_query_id; - client_info.initial_address = client_info.current_address; - - return new_query_context; + return global_context->getAccessControlManager().read(user_name)->authentication.getType(); } -ContextPtr Session::sessionContext() const +Authentication::Digest Session::getPasswordDoubleSHA1(const String & user_name) const { - return session_context; + return global_context->getAccessControlManager().read(user_name)->authentication.getPasswordDoubleSHA1(); } -ContextMutablePtr Session::mutableSessionContext() +void Session::authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address) { - return session_context; + authenticate(BasicCredentials{user_name, password}, address); +} + +void Session::authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_) +{ + if (session_context) + throw Exception("If there is a session context it must be created after authentication", ErrorCodes::LOGICAL_ERROR); + + user_id = global_context->getAccessControlManager().login(credentials_, address_.host()); + + prepared_client_info->current_user = credentials_.getUserName(); + prepared_client_info->current_address = address_; + +#if defined(ARCADIA_BUILD) + /// This is harmful field that is used only in foreign "Arcadia" build. 
+ if (const auto * basic_credentials = dynamic_cast(&credentials_)) + prepared_client_info->current_password = basic_credentials->getPassword(); +#endif } ClientInfo & Session::getClientInfo() { - return session_context->getClientInfo(); + return session_context ? session_context->getClientInfo() : *prepared_client_info; } const ClientInfo & Session::getClientInfo() const { - return session_context->getClientInfo(); + return session_context ? session_context->getClientInfo() : *prepared_client_info; } -const Settings & Session::getSettings() const +ContextMutablePtr Session::makeSessionContext() { - return session_context->getSettingsRef(); + if (session_context) + throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR); + if (query_context_created) + throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR); + + /// Make a new session context. + ContextMutablePtr new_session_context; + new_session_context = Context::createCopy(global_context); + new_session_context->makeSessionContext(); + + /// Copy prepared client info to the new session context. + auto & res_client_info = new_session_context->getClientInfo(); + res_client_info = std::move(prepared_client_info).value(); + prepared_client_info.reset(); + + /// Set user information for the new context: current profiles, roles, access rights. + if (user_id) + new_session_context->setUser(*user_id); + + /// Session context is ready.
+ session_context = new_session_context; + user = session_context->getUser(); + + return session_context; } -void Session::setQuotaKey(const String & quota_key) +ContextMutablePtr Session::makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_) { - session_context->setQuotaKey(quota_key); + if (session_context) + throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR); + if (query_context_created) + throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR); + if (!named_sessions) + throw Exception("Support for named sessions is not enabled", ErrorCodes::LOGICAL_ERROR); + + /// Make a new session context OR + /// if the `session_id` and `user_id` were used before then just get a previously created session context. + std::shared_ptr new_named_session; + bool new_named_session_created = false; + std::tie(new_named_session, new_named_session_created) + = named_sessions->acquireSession(global_context, user_id.value_or(UUID{}), session_id_, timeout_, session_check_); + + auto new_session_context = new_named_session->context; + new_session_context->makeSessionContext(); + + /// Copy prepared client info to the session context, regardless of whether it has just been created or not. + /// If we continue using a previously created session context found by session ID + /// it's necessary to replace the client info in it anyway, because the prepared client info contains the current connection information (client address, etc.) + auto & res_client_info = new_session_context->getClientInfo(); + res_client_info = std::move(prepared_client_info).value(); + prepared_client_info.reset(); + + /// Set user information for the new context: current profiles, roles, access rights. + if (user_id && !new_session_context->getUser()) + new_session_context->setUser(*user_id); + + /// Session context is ready.
+ session_context = new_session_context; + session_id = session_id_; + named_session = new_named_session; + named_session_created = new_named_session_created; + user = session_context->getUser(); + + return session_context; } -String Session::getCurrentDatabase() const +ContextMutablePtr Session::makeQueryContext(const ClientInfo & query_client_info) const { - return session_context->getCurrentDatabase(); + return makeQueryContextImpl(&query_client_info, nullptr); } -void Session::setCurrentDatabase(const String & name) +ContextMutablePtr Session::makeQueryContext(ClientInfo && query_client_info) const { - session_context->setCurrentDatabase(name); + return makeQueryContextImpl(nullptr, &query_client_info); +} + +ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const +{ + /// We can create a query context either from a session context or from a global context. + bool from_session_context = static_cast(session_context); + + /// Create a new query context. + ContextMutablePtr query_context = Context::createCopy(from_session_context ? session_context : global_context); + query_context->makeQueryContext(); + + /// Copy the specified client info to the new query context. + auto & res_client_info = query_context->getClientInfo(); + if (client_info_to_move) + res_client_info = std::move(*client_info_to_move); + else if (client_info_to_copy && (client_info_to_copy != &getClientInfo())) + res_client_info = *client_info_to_copy; + + /// Copy current user's name and address if it was authenticated after query_client_info was initialized. 
+ if (prepared_client_info && !prepared_client_info->current_user.empty()) + { + res_client_info.current_user = prepared_client_info->current_user; + res_client_info.current_address = prepared_client_info->current_address; +#if defined(ARCADIA_BUILD) + res_client_info.current_password = prepared_client_info->current_password; +#endif + } + + /// Set parameters of initial query. + if (res_client_info.query_kind == ClientInfo::QueryKind::NO_QUERY) + res_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; + + if (res_client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + { + res_client_info.initial_user = res_client_info.current_user; + res_client_info.initial_address = res_client_info.current_address; + } + + /// Sets that row policies from the initial user should be used too. + query_context->setInitialRowPolicy(); + + /// Set user information for the new context: current profiles, roles, access rights. + if (user_id && !query_context->getUser()) + query_context->setUser(*user_id); + + /// Query context is ready. + query_context_created = true; + user = query_context->getUser(); + + return query_context; } } diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index 300ed779c49..58370aad2d0 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -1,8 +1,9 @@ #pragma once -#include -#include +#include +#include #include +#include #include #include @@ -13,77 +14,77 @@ namespace Poco::Net { class SocketAddress; } namespace DB { class Credentials; -class ContextAccess; -struct Settings; class Authentication; struct NamedSessionData; class NamedSessionsStorage; +struct User; +using UserPtr = std::shared_ptr; /** Represents user-session from the server perspective, * basically it is just a smaller subset of Context API, simplifies Context management. * * Holds session context, facilitates acquisition of NamedSession and proper creation of query contexts. 
- * Adds log in, log out and login failure events to the SessionLog. */ class Session { - static std::optional named_sessions; - public: /// Allow to use named sessions. The thread will be run to cleanup sessions after timeout has expired. /// The method must be called at the server startup. - static void enableNamedSessions(); + static void startupNamedSessions(); -// static Session makeSessionFromCopyOfContext(const ContextPtr & _context_to_copy); - Session(const ContextPtr & context_to_copy, ClientInfo::Interface interface, std::optional default_format = std::nullopt); - virtual ~Session(); + Session(const ContextPtr & global_context_, ClientInfo::Interface interface_); + Session(Session &&); + ~Session(); Session(const Session &) = delete; Session& operator=(const Session &) = delete; - Session(Session &&); -// Session& operator=(Session &&); + /// Provides information about the authentication type of a specified user. + Authentication::Type getAuthenticationType(const String & user_name) const; + Authentication::Digest getPasswordDoubleSHA1(const String & user_name) const; - Authentication getUserAuthentication(const String & user_name) const; - void setUser(const Credentials & credentials, const Poco::Net::SocketAddress & address); - void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address); - - /// Handle login and logout events. - void onLogInSuccess(); - void onLogInFailure(const String & user_name, const std::exception & /* failure_reason */); - void onLogOut(); - - /** Propmotes current session to a named session. - * - * that is: re-uses or creates NamedSession and then piggybacks on it's context, - * retaining ClientInfo of current session_context. - * Acquired named_session is then released in the destructor. - */ - void promoteToNamedSession(const String & session_id, std::chrono::steady_clock::duration timeout, bool session_check); - /// Early release a NamedSession. 
- void releaseNamedSession(); - - ContextMutablePtr makeQueryContext(const String & query_id) const; - - ContextPtr sessionContext() const; - ContextMutablePtr mutableSessionContext(); + /// Sets the current user, checks the credentials and that the specified address is allowed to connect from. + /// The function throws an exception if there is no such user or password is wrong. + void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address); + void authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_); + /// Returns a reference to session ClientInfo. ClientInfo & getClientInfo(); const ClientInfo & getClientInfo() const; - const Settings & getSettings() const; + /// Makes a session context, can be used one or zero times. + /// The function also assigns a user to this context. + ContextMutablePtr makeSessionContext(); + ContextMutablePtr makeSessionContext(const String & session_id_, std::chrono::steady_clock::duration timeout_, bool session_check_); + ContextMutablePtr sessionContext() { return session_context; } + ContextPtr sessionContext() const { return session_context; } - void setQuotaKey(const String & quota_key); - - String getCurrentDatabase() const; - void setCurrentDatabase(const String & name); + /// Makes a query context, can be used multiple times, with or without makeSessionContext() called earlier. + /// The query context will be created from a copy of a session context if it exists, or from a copy of + /// a global context otherwise. In the latter case the function also assigns a user to this context.
+ ContextMutablePtr makeQueryContext() const { return makeQueryContext(getClientInfo()); } + ContextMutablePtr makeQueryContext(const ClientInfo & query_client_info) const; + ContextMutablePtr makeQueryContext(ClientInfo && query_client_info) const; private: + ContextMutablePtr makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const; + + const ContextPtr global_context; + + /// ClientInfo that will be copied to a session context when it's created. + std::optional prepared_client_info; + + mutable UserPtr user; + std::optional user_id; + ContextMutablePtr session_context; - // So that Session can be used after forced release of named_session. - const ContextMutablePtr initial_session_context; - std::shared_ptr access; + mutable bool query_context_created = false; + + String session_id; std::shared_ptr named_session; + bool named_session_created = false; + + static std::optional named_sessions; }; } diff --git a/src/Server/GRPCServer.cpp b/src/Server/GRPCServer.cpp index f03d0ae5f9f..f0c6e208323 100644 --- a/src/Server/GRPCServer.cpp +++ b/src/Server/GRPCServer.cpp @@ -11,9 +11,9 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -55,7 +55,6 @@ namespace ErrorCodes extern const int NETWORK_ERROR; extern const int NO_DATA_TO_INSERT; extern const int SUPPORT_IS_DISABLED; - extern const int UNKNOWN_DATABASE; } namespace @@ -561,7 +560,7 @@ namespace IServer & iserver; Poco::Logger * log = nullptr; - std::shared_ptr session; + std::optional session; ContextMutablePtr query_context; std::optional query_scope; String query_text; @@ -690,32 +689,20 @@ namespace password = ""; } - /// Create context. - session = std::make_shared(iserver.context(), ClientInfo::Interface::GRPC); /// Authentication. 
- session->setUser(user, password, user_address); - if (!quota_key.empty()) - session->setQuotaKey(quota_key); + session.emplace(iserver.context(), ClientInfo::Interface::GRPC); + session->authenticate(user, password, user_address); + session->getClientInfo().quota_key = quota_key; /// The user could specify session identifier and session timeout. /// It allows to modify settings, create temporary tables and reuse them in subsequent requests. if (!query_info.session_id().empty()) { - session->promoteToNamedSession( - query_info.session_id(), - getSessionTimeout(query_info, iserver.config()), - query_info.session_check()); + session->makeSessionContext( + query_info.session_id(), getSessionTimeout(query_info, iserver.config()), query_info.session_check()); } - query_context = session->makeQueryContext(query_info.query_id()); - query_scope.emplace(query_context); - - /// Set client info. - ClientInfo & client_info = query_context->getClientInfo(); - client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; - client_info.initial_user = client_info.current_user; - client_info.initial_query_id = client_info.current_query_id; - client_info.initial_address = client_info.current_address; + query_context = session->makeQueryContext(); /// Prepare settings. SettingsChanges settings_changes; @@ -725,11 +712,14 @@ namespace } query_context->checkSettingsConstraints(settings_changes); query_context->applySettingsChanges(settings_changes); - const Settings & settings = query_context->getSettingsRef(); + + query_context->setCurrentQueryId(query_info.query_id()); + query_scope.emplace(query_context); /// Prepare for sending exceptions and logs. 
- send_exception_with_stacktrace = query_context->getSettingsRef().calculate_text_stack_trace; - const auto client_logs_level = query_context->getSettingsRef().send_logs_level; + const Settings & settings = query_context->getSettingsRef(); + send_exception_with_stacktrace = settings.calculate_text_stack_trace; + const auto client_logs_level = settings.send_logs_level; if (client_logs_level != LogsLevel::none) { logs_queue = std::make_shared(); @@ -740,14 +730,10 @@ namespace /// Set the current database if specified. if (!query_info.database().empty()) - { - if (!DatabaseCatalog::instance().isDatabaseExist(query_info.database())) - throw Exception("Database " + query_info.database() + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE); query_context->setCurrentDatabase(query_info.database()); - } /// The interactive delay will be used to show progress. - interactive_delay = query_context->getSettingsRef().interactive_delay; + interactive_delay = settings.interactive_delay; query_context->setProgressCallback([this](const Progress & value) { return progress.incrementPiecewiseAtomically(value); }); /// Parse the query. diff --git a/src/Server/HTTPHandler.cpp b/src/Server/HTTPHandler.cpp index 0e6b7d57b7c..0492b58dc88 100644 --- a/src/Server/HTTPHandler.cpp +++ b/src/Server/HTTPHandler.cpp @@ -19,9 +19,9 @@ #include #include #include -#include #include #include +#include #include #include #include @@ -262,6 +262,7 @@ void HTTPHandler::pushDelayedResults(Output & used_output) HTTPHandler::HTTPHandler(IServer & server_, const std::string & name) : server(server_) , log(&Poco::Logger::get(name)) + , default_settings(server.context()->getSettingsRef()) { server_display_name = server.config().getString("display_name", getFQDNOrHostName()); } @@ -269,10 +270,7 @@ HTTPHandler::HTTPHandler(IServer & server_, const std::string & name) /// We need d-tor to be present in this translation unit to make it play well with some /// forward decls in the header. 
Other than that, the default d-tor would be OK. -HTTPHandler::~HTTPHandler() -{ - (void)this; -} +HTTPHandler::~HTTPHandler() = default; bool HTTPHandler::authenticateUser( @@ -352,7 +350,7 @@ bool HTTPHandler::authenticateUser( else { if (!request_credentials) - request_credentials = request_session->sessionContext()->makeGSSAcceptorContext(); + request_credentials = server.context()->makeGSSAcceptorContext(); auto * gss_acceptor_context = dynamic_cast(request_credentials.get()); if (!gss_acceptor_context) @@ -378,9 +376,7 @@ bool HTTPHandler::authenticateUser( } /// Set client info. It will be used for quota accounting parameters in 'setUser' method. - - ClientInfo & client_info = request_session->getClientInfo(); - client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; + ClientInfo & client_info = session->getClientInfo(); ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; if (request.getMethod() == HTTPServerRequest::HTTP_GET) @@ -392,10 +388,11 @@ bool HTTPHandler::authenticateUser( client_info.http_user_agent = request.get("User-Agent", ""); client_info.http_referer = request.get("Referer", ""); client_info.forwarded_for = request.get("X-Forwarded-For", ""); + client_info.quota_key = quota_key; try { - request_session->setUser(*request_credentials, request.clientAddress()); + session->authenticate(*request_credentials, request.clientAddress()); } catch (const Authentication::Require & required_credentials) { @@ -412,7 +409,7 @@ bool HTTPHandler::authenticateUser( } catch (const Authentication::Require & required_credentials) { - request_credentials = request_session->sessionContext()->makeGSSAcceptorContext(); + request_credentials = server.context()->makeGSSAcceptorContext(); if (required_credentials.getRealm().empty()) response.set("WWW-Authenticate", "Negotiate"); @@ -425,14 +422,6 @@ bool HTTPHandler::authenticateUser( } request_credentials.reset(); - - if (!quota_key.empty()) - request_session->setQuotaKey(quota_key); - - /// 
Query sent through HTTP interface is initial. - client_info.initial_user = client_info.current_user; - client_info.initial_address = client_info.current_address; - return true; } @@ -463,20 +452,16 @@ void HTTPHandler::processQuery( session_id = params.get("session_id"); session_timeout = parseSessionTimeout(config, params); std::string session_check = params.get("session_check", ""); - request_session->promoteToNamedSession(session_id, session_timeout, session_check == "1"); + session->makeSessionContext(session_id, session_timeout, session_check == "1"); } - SCOPE_EXIT({ - request_session->releaseNamedSession(); - }); - // Parse the OpenTelemetry traceparent header. // Disable in Arcadia -- it interferes with the // test_clickhouse.TestTracing.test_tracing_via_http_proxy[traceparent] test. + ClientInfo client_info = session->getClientInfo(); #if !defined(ARCADIA_BUILD) if (request.has("traceparent")) { - ClientInfo & client_info = request_session->getClientInfo(); std::string opentelemetry_traceparent = request.get("traceparent"); std::string error; if (!client_info.client_trace_context.parseTraceparentHeader( @@ -486,16 +471,11 @@ void HTTPHandler::processQuery( "Failed to parse OpenTelemetry traceparent header '{}': {}", opentelemetry_traceparent, error); } - client_info.client_trace_context.tracestate = request.get("tracestate", ""); } #endif - // Set the query id supplied by the user, if any, and also update the OpenTelemetry fields. - auto context = request_session->makeQueryContext(params.get("query_id", request.get("X-ClickHouse-Query-Id", ""))); - - ClientInfo & client_info = context->getClientInfo(); - client_info.initial_query_id = client_info.current_query_id; + auto context = session->makeQueryContext(std::move(client_info)); /// The client can pass a HTTP header indicating supported compression method (gzip or deflate). 
String http_response_compression_methods = request.get("Accept-Encoding", ""); @@ -560,7 +540,7 @@ void HTTPHandler::processQuery( if (buffer_until_eof) { - const std::string tmp_path(context->getTemporaryVolume()->getDisk()->getPath()); + const std::string tmp_path(server.context()->getTemporaryVolume()->getDisk()->getPath()); const std::string tmp_path_template(tmp_path + "http_buffers/"); auto create_tmp_disk_buffer = [tmp_path_template] (const WriteBufferPtr &) @@ -706,6 +686,9 @@ void HTTPHandler::processQuery( context->checkSettingsConstraints(settings_changes); context->applySettingsChanges(settings_changes); + // Set the query id supplied by the user, if any, and also update the OpenTelemetry fields. + context->setCurrentQueryId(params.get("query_id", request.get("X-ClickHouse-Query-Id", ""))); + const auto & query = getQuery(request, params, context); std::unique_ptr in_param = std::make_unique(query); in = has_external_data ? std::move(in_param) : std::make_unique(*in_param, *in_post_maybe_compressed); @@ -856,23 +839,10 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse setThreadName("HTTPHandler"); ThreadStatus thread_status; - SCOPE_EXIT({ - // If there is no request_credentials instance waiting for the next round, then the request is processed, - // so no need to preserve request_session either. - // Needs to be performed with respect to the other destructors in the scope though. - if (!request_credentials) - request_session.reset(); - }); - - if (!request_session) - { - // Context should be initialized before anything, for correct memory accounting. - request_session = std::make_shared(server.context(), ClientInfo::Interface::HTTP); - request_credentials.reset(); - } - - /// Cannot be set here, since query_id is unknown. 
+ session = std::make_unique(server.context(), ClientInfo::Interface::HTTP); + SCOPE_EXIT({ session.reset(); }); std::optional query_scope; + Output used_output; /// In case of exception, send stack trace to client. @@ -886,7 +856,7 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse if (request.getVersion() == HTTPServerRequest::HTTP_1_1) response.setChunkedTransferEncoding(true); - HTMLForm params(request_session->getSettings(), request); + HTMLForm params(default_settings, request); with_stacktrace = params.getParsed("stacktrace", false); /// FIXME: maybe this check is already unnecessary. diff --git a/src/Server/HTTPHandler.h b/src/Server/HTTPHandler.h index bca73ca7cb8..98f573f8cef 100644 --- a/src/Server/HTTPHandler.h +++ b/src/Server/HTTPHandler.h @@ -21,6 +21,7 @@ namespace DB class Session; class Credentials; class IServer; +struct Settings; class WriteBufferFromHTTPServerResponse; using CompiledRegexPtr = std::shared_ptr; @@ -72,15 +73,22 @@ private: CurrentMetrics::Increment metric_increment{CurrentMetrics::HTTPConnection}; - // The request_session and the request_credentials instances may outlive a single request/response loop. + /// Reference to the immutable settings in the global context. + /// Those settings are used only to extract an HTTP request's parameters. + /// See settings http_max_fields, http_max_field_name_size, http_max_field_value_size in HTMLForm. + const Settings & default_settings; + + // session is reset at the end of each request/response. + std::unique_ptr session; + + // The request_credentials instance may outlive a single request/response loop. // This happens only when the authentication mechanism requires more than a single request/response exchange (e.g., SPNEGO).
- std::shared_ptr request_session; std::unique_ptr request_credentials; // Returns true when the user successfully authenticated, - // the request_session instance will be configured accordingly, and the request_credentials instance will be dropped. + // the session instance will be configured accordingly, and the request_credentials instance will be dropped. // Returns false when the user is not authenticated yet, and the 'Negotiate' response is sent, - // the request_session and request_credentials instances are preserved. + // the request_credentials instance is preserved for the next round (the session itself is reset at the end of each request). // Throws an exception if authentication failed. bool authenticateUser( HTTPServerRequest & request, diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index f2ac1184640..93f4bff46c2 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -3,11 +3,11 @@ #include #include #include -#include #include #include #include #include +#include #include #include #include @@ -19,9 +19,8 @@ #include #include #include -#include -#include #include +#include #if !defined(ARCADIA_BUILD) # include @@ -88,12 +87,10 @@ void MySQLHandler::run() setThreadName("MySQLHandler"); ThreadStatus thread_status; - session = std::make_shared(server.context(), ClientInfo::Interface::MYSQL, "MySQLWire"); - auto & session_client_info = session->getClientInfo(); + session = std::make_unique(server.context(), ClientInfo::Interface::MYSQL); + SCOPE_EXIT({ session.reset(); }); - session_client_info.current_address = socket().peerAddress(); - session_client_info.connection_id = connection_id; - session_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; + session->getClientInfo().connection_id = connection_id; in = std::make_shared(socket()); out = std::make_shared(socket()); @@ -127,12 +124,12 @@ void MySQLHandler::run() authenticate(handshake_response.username, handshake_response.auth_plugin_name, handshake_response.auth_response); -
session_client_info.initial_user = handshake_response.username; - try { + session->makeSessionContext(); + session->sessionContext()->setDefaultFormat("MySQLWire"); if (!handshake_response.database.empty()) - session->setCurrentDatabase(handshake_response.database); + session->sessionContext()->setCurrentDatabase(handshake_response.database); } catch (const Exception & exc) { @@ -246,26 +243,23 @@ void MySQLHandler::finishHandshake(MySQLProtocol::ConnectionPhase::HandshakeResp void MySQLHandler::authenticate(const String & user_name, const String & auth_plugin_name, const String & initial_auth_response) { - // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. - DB::Authentication::Type user_auth_type; try { - user_auth_type = session->getUserAuthentication(user_name).getType(); + // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. + if (session->getAuthenticationType(user_name) == DB::Authentication::SHA256_PASSWORD) + { + authPluginSSL(); + } + + std::optional auth_response = auth_plugin_name == auth_plugin->getName() ? std::make_optional(initial_auth_response) : std::nullopt; + auth_plugin->authenticate(user_name, *session, auth_response, packet_endpoint, secure_connection, socket().peerAddress()); } - catch (const std::exception & e) + catch (const Exception & exc) { - session->onLogInFailure(user_name, e); + LOG_ERROR(log, "Authentication for user {} failed.", user_name); + packet_endpoint->sendPacket(ERRPacket(exc.code(), "00000", exc.message()), true); throw; } - - if (user_auth_type == DB::Authentication::SHA256_PASSWORD) - { - authPluginSSL(); - } - - std::optional auth_response = auth_plugin_name == auth_plugin->getName() ? 
std::make_optional(initial_auth_response) : std::nullopt; - auth_plugin->authenticate(user_name, auth_response, *session, packet_endpoint, secure_connection, socket().peerAddress()); - LOG_DEBUG(log, "Authentication for user {} succeeded.", user_name); } @@ -274,7 +268,7 @@ void MySQLHandler::comInitDB(ReadBuffer & payload) String database; readStringUntilEOF(database, payload); LOG_DEBUG(log, "Setting current database to {}", database); - session->setCurrentDatabase(database); + session->sessionContext()->setCurrentDatabase(database); packet_endpoint->sendPacket(OKPacket(0, client_capabilities, 0, 0, 1), true); } @@ -331,7 +325,9 @@ void MySQLHandler::comQuery(ReadBuffer & payload) ReadBufferFromString replacement(replacement_query); - auto query_context = session->makeQueryContext(Poco::format("mysql:%lu", connection_id)); + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(Poco::format("mysql:%lu", connection_id)); + CurrentThread::QueryScope query_scope{query_context}; std::atomic affected_rows {0}; auto prev = query_context->getProgressCallback(); @@ -343,8 +339,6 @@ void MySQLHandler::comQuery(ReadBuffer & payload) affected_rows += progress.written_rows; }); - CurrentThread::QueryScope query_scope{query_context}; - FormatSettings format_settings; format_settings.mysql_wire.client_capabilities = client_capabilities; format_settings.mysql_wire.max_packet_size = max_packet_size; diff --git a/src/Server/MySQLHandler.h b/src/Server/MySQLHandler.h index c57cb7d8f65..5258862cf23 100644 --- a/src/Server/MySQLHandler.h +++ b/src/Server/MySQLHandler.h @@ -63,7 +63,7 @@ protected: uint8_t sequence_id = 0; MySQLProtocol::PacketEndpointPtr packet_endpoint; - std::shared_ptr session; + std::unique_ptr session; using ReplacementFn = std::function; using Replacements = std::unordered_map; diff --git a/src/Server/PostgreSQLHandler.cpp b/src/Server/PostgreSQLHandler.cpp index ae21d387e73..0716d828520 100644 --- 
a/src/Server/PostgreSQLHandler.cpp +++ b/src/Server/PostgreSQLHandler.cpp @@ -2,8 +2,8 @@ #include #include #include +#include #include -#include #include "PostgreSQLHandler.h" #include #include @@ -53,14 +53,12 @@ void PostgreSQLHandler::run() setThreadName("PostgresHandler"); ThreadStatus thread_status; - Session session(server.context(), ClientInfo::Interface::POSTGRESQL, "PostgreSQLWire"); - auto & session_client_info = session.getClientInfo(); - - session_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; + session = std::make_unique(server.context(), ClientInfo::Interface::POSTGRESQL); + SCOPE_EXIT({ session.reset(); }); try { - if (!startup(session)) + if (!startup()) return; while (true) @@ -71,7 +69,7 @@ void PostgreSQLHandler::run() switch (message_type) { case PostgreSQLProtocol::Messaging::FrontMessageType::QUERY: - processQuery(session); + processQuery(); break; case PostgreSQLProtocol::Messaging::FrontMessageType::TERMINATE: LOG_DEBUG(log, "Client closed the connection"); @@ -110,7 +108,7 @@ void PostgreSQLHandler::run() } -bool PostgreSQLHandler::startup(Session & session) +bool PostgreSQLHandler::startup() { Int32 payload_size; Int32 info; @@ -119,17 +117,20 @@ bool PostgreSQLHandler::startup(Session & session) if (static_cast(info) == PostgreSQLProtocol::Messaging::FrontMessageType::CANCEL_REQUEST) { LOG_DEBUG(log, "Client issued request canceling"); - cancelRequest(session); + cancelRequest(); return false; } std::unique_ptr start_up_msg = receiveStartupMessage(payload_size); - authentication_manager.authenticate(start_up_msg->user, session, *message_transport, socket().peerAddress()); + const auto & user_name = start_up_msg->user; + authentication_manager.authenticate(user_name, *session, *message_transport, socket().peerAddress()); try { + session->makeSessionContext(); + session->sessionContext()->setDefaultFormat("PostgreSQLWire"); if (!start_up_msg->database.empty()) - session.setCurrentDatabase(start_up_msg->database); + 
session->sessionContext()->setCurrentDatabase(start_up_msg->database); } catch (const Exception & exc) { @@ -207,18 +208,16 @@ void PostgreSQLHandler::sendParameterStatusData(PostgreSQLProtocol::Messaging::S message_transport->flush(); } -void PostgreSQLHandler::cancelRequest(Session & session) +void PostgreSQLHandler::cancelRequest() { - // TODO (nemkov): maybe run cancellation query with session context? - auto query_context = session.makeQueryContext(std::string{}); - query_context->setDefaultFormat("Null"); - std::unique_ptr msg = message_transport->receiveWithPayloadSize(8); String query = Poco::format("KILL QUERY WHERE query_id = 'postgres:%d:%d'", msg->process_id, msg->secret_key); ReadBufferFromString replacement(query); + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(""); executeQuery(replacement, *out, true, query_context, {}); } @@ -242,7 +241,7 @@ inline std::unique_ptr PostgreSQL return message; } -void PostgreSQLHandler::processQuery(Session & session) +void PostgreSQLHandler::processQuery() { try { @@ -265,7 +264,7 @@ void PostgreSQLHandler::processQuery(Session & session) return; } - const auto & settings = session.getSettings(); + const auto & settings = session->sessionContext()->getSettingsRef(); std::vector queries; auto parse_res = splitMultipartQuery(query->query, queries, settings.max_query_size, settings.max_parser_depth); if (!parse_res.second) @@ -278,7 +277,8 @@ void PostgreSQLHandler::processQuery(Session & session) for (const auto & spl_query : queries) { secret_key = dis(gen); - auto query_context = session.makeQueryContext(Poco::format("postgres:%d:%d", connection_id, secret_key)); + auto query_context = session->makeQueryContext(); + query_context->setCurrentQueryId(Poco::format("postgres:%d:%d", connection_id, secret_key)); CurrentThread::QueryScope query_scope{query_context}; ReadBufferFromString read_buf(spl_query); diff --git a/src/Server/PostgreSQLHandler.h b/src/Server/PostgreSQLHandler.h 
index cf4a6620063..36dd62d3dec 100644 --- a/src/Server/PostgreSQLHandler.h +++ b/src/Server/PostgreSQLHandler.h @@ -39,6 +39,7 @@ private: Poco::Logger * log = &Poco::Logger::get("PostgreSQLHandler"); IServer & server; + std::unique_ptr session; bool ssl_enabled = false; Int32 connection_id = 0; Int32 secret_key = 0; @@ -57,7 +58,7 @@ private: void changeIO(Poco::Net::StreamSocket & socket); - bool startup(Session & session); + bool startup(); void establishSecureConnection(Int32 & payload_size, Int32 & info); @@ -65,11 +66,11 @@ private: void sendParameterStatusData(PostgreSQLProtocol::Messaging::StartupMessage & start_up_message); - void cancelRequest(Session & session); + void cancelRequest(); std::unique_ptr receiveStartupMessage(int payload_size); - void processQuery(DB::Session & session); + void processQuery(); static bool isEmptyQuery(const String & query); }; diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index de14f117981..b2db65e22bc 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -20,16 +20,16 @@ #include #include #include -#include #include #include #include -#include #include +#include #include #include #include #include +#include #include #include #include @@ -75,7 +75,6 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int ATTEMPT_TO_READ_AFTER_EOF; extern const int CLIENT_HAS_CONNECTED_TO_WRONG_PORT; - extern const int UNKNOWN_DATABASE; extern const int UNKNOWN_EXCEPTION; extern const int UNKNOWN_PACKET_FROM_CLIENT; extern const int POCO_EXCEPTION; @@ -90,7 +89,6 @@ TCPHandler::TCPHandler(IServer & server_, const Poco::Net::StreamSocket & socket , server(server_) , parse_proxy_protocol(parse_proxy_protocol_) , log(&Poco::Logger::get("TCPHandler")) - , query_context(Context::createCopy(server.context())) , server_display_name(std::move(server_display_name_)) { } @@ -115,16 +113,10 @@ void TCPHandler::runImpl() ThreadStatus thread_status; session = std::make_unique(server.context(), 
ClientInfo::Interface::TCP); - const auto session_context = session->sessionContext(); + extractConnectionSettingsFromContext(server.context()); - /// These timeouts can be changed after receiving query. - const auto & settings = session->getSettings(); - - auto global_receive_timeout = settings.receive_timeout; - auto global_send_timeout = settings.send_timeout; - - socket().setReceiveTimeout(global_receive_timeout); - socket().setSendTimeout(global_send_timeout); + socket().setReceiveTimeout(receive_timeout); + socket().setSendTimeout(send_timeout); socket().setNoDelay(true); in = std::make_shared(socket()); @@ -162,33 +154,27 @@ void TCPHandler::runImpl() try { /// We try to send error information to the client. - sendException(e, session->getSettings().calculate_text_stack_trace); + sendException(e, send_exception_with_stack_trace); } catch (...) {} throw; } - /// When connecting, the default database can be specified. - if (!default_database.empty()) - { - if (!DatabaseCatalog::instance().isDatabaseExist(default_database)) - { - Exception e("Database " + backQuote(default_database) + " doesn't exist", ErrorCodes::UNKNOWN_DATABASE); - LOG_ERROR(log, getExceptionMessage(e, true)); - sendException(e, settings.calculate_text_stack_trace); - return; - } - - session->setCurrentDatabase(default_database); - } - - UInt64 idle_connection_timeout = settings.idle_connection_timeout; - UInt64 poll_interval = settings.poll_interval; - sendHello(); - session->mutableSessionContext()->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); }); + if (!is_interserver_mode) /// In interserver mode queries are executed without a session context. + { + session->makeSessionContext(); + + /// If session created, then settings in session context has been updated. + /// So it's better to update the connection settings for flexibility. 
+ extractConnectionSettingsFromContext(session->sessionContext()); + + /// When connecting, the default database could be specified. + if (!default_database.empty()) + session->sessionContext()->setCurrentDatabase(default_database); + } while (true) { @@ -210,10 +196,6 @@ void TCPHandler::runImpl() if (server.isCancelled() || in->eof()) break; - /// Set context of request. - /// TODO (nemkov): create query later in receiveQuery - query_context = session->makeQueryContext(std::string{}); // proper query_id is set later in receiveQuery - Stopwatch watch; state.reset(); @@ -226,8 +208,6 @@ void TCPHandler::runImpl() std::optional exception; bool network_error = false; - bool send_exception_with_stack_trace = true; - try { /// If a user passed query-local timeouts, reset socket to initial state at the end of the query @@ -240,23 +220,22 @@ void TCPHandler::runImpl() if (!receivePacket()) continue; - /** If Query received, then settings in query_context has been updated - * So, update some other connection settings, for flexibility. - */ - { - const Settings & query_settings = query_context->getSettingsRef(); - idle_connection_timeout = query_settings.idle_connection_timeout; - poll_interval = query_settings.poll_interval; - } - /** If part_uuids got received in previous packet, trying to read again. */ - if (state.empty() && state.part_uuids && !receivePacket()) + if (state.empty() && state.part_uuids_to_ignore && !receivePacket()) continue; query_scope.emplace(query_context); - send_exception_with_stack_trace = query_context->getSettingsRef().calculate_text_stack_trace; + /// If query received, then settings in query_context has been updated. + /// So it's better to update the connection settings for flexibility. 
+ extractConnectionSettingsFromContext(query_context); + + /// Sync timeouts on client and server during current query to avoid dangling queries on server + /// NOTE: We use send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter), + /// because send_timeout is client-side setting which has opposite meaning on the server side. + /// NOTE: these settings are applied only for current connection (not for distributed tables' connections) + state.timeout_setter = std::make_unique(socket(), receive_timeout, send_timeout); /// Should we send internal logs to client? const auto client_logs_level = query_context->getSettingsRef().send_logs_level; @@ -269,20 +248,18 @@ void TCPHandler::runImpl() CurrentThread::setFatalErrorCallback([this]{ sendLogs(); }); } - query_context->setExternalTablesInitializer([&settings, this] (ContextPtr context) + query_context->setExternalTablesInitializer([this] (ContextPtr context) { if (context != query_context) throw Exception("Unexpected context in external tables initializer", ErrorCodes::LOGICAL_ERROR); /// Get blocks of temporary tables - readData(settings); + readData(); /// Reset the input stream, as we received an empty block while receiving external table data. /// So, the stream has been marked as cancelled and we can't read from it anymore. state.block_in.reset(); state.maybe_compressed_in.reset(); /// For more accurate accounting by MemoryTracker. 
- - state.temporary_tables_read = true; }); /// Send structure of columns to client for function input() @@ -306,15 +283,12 @@ void TCPHandler::runImpl() sendData(state.input_header); }); - query_context->setInputBlocksReaderCallback([&settings, this] (ContextPtr context) -> Block + query_context->setInputBlocksReaderCallback([this] (ContextPtr context) -> Block { if (context != query_context) throw Exception("Unexpected context in InputBlocksReader", ErrorCodes::LOGICAL_ERROR); - size_t poll_interval_ms; - int receive_timeout; - std::tie(poll_interval_ms, receive_timeout) = getReadTimeouts(settings); - if (!readDataNext(poll_interval_ms, receive_timeout)) + if (!readDataNext()) { state.block_in.reset(); state.maybe_compressed_in.reset(); @@ -337,15 +311,13 @@ void TCPHandler::runImpl() /// Processing Query state.io = executeQuery(state.query, query_context, false, state.stage, may_have_embedded_data); - unknown_packet_in_send_data = query_context->getSettingsRef().unknown_packet_in_send_data; - after_check_cancelled.restart(); after_send_progress.restart(); if (state.io.out) { state.need_receive_data_for_insert = true; - processInsertQuery(settings); + processInsertQuery(); } else if (state.need_receive_data_for_input) // It implies pipeline execution { @@ -461,16 +433,17 @@ void TCPHandler::runImpl() try { - if (exception && !state.temporary_tables_read) - query_context->initializeExternalTablesIfSet(); + /// A query packet is always followed by one or more data packets. + /// If some of those data packets are left, try to skip them. + if (exception && !state.empty() && !state.read_all_data) + skipData(); } catch (...) 
{ network_error = true; - LOG_WARNING(log, "Can't read external tables after query failure."); + LOG_WARNING(log, "Can't skip data packets after query failure."); } - try { /// QueryState should be cleared before QueryScope, since otherwise @@ -501,75 +474,94 @@ void TCPHandler::runImpl() } -bool TCPHandler::readDataNext(size_t poll_interval, time_t receive_timeout) +void TCPHandler::extractConnectionSettingsFromContext(const ContextPtr & context) +{ + const auto & settings = context->getSettingsRef(); + send_exception_with_stack_trace = settings.calculate_text_stack_trace; + send_timeout = settings.send_timeout; + receive_timeout = settings.receive_timeout; + poll_interval = settings.poll_interval; + idle_connection_timeout = settings.idle_connection_timeout; + interactive_delay = settings.interactive_delay; + sleep_in_send_tables_status = settings.sleep_in_send_tables_status_ms; + unknown_packet_in_send_data = settings.unknown_packet_in_send_data; + sleep_in_receive_cancel = settings.sleep_in_receive_cancel_ms; +} + + +bool TCPHandler::readDataNext() { Stopwatch watch(CLOCK_MONOTONIC_COARSE); + /// Poll interval should not be greater than receive_timeout + constexpr UInt64 min_timeout_ms = 5000; // 5 ms; NOTE: despite the `_ms` suffix, this value and `timeout_ms` are in microseconds + UInt64 timeout_ms = std::max(min_timeout_ms, std::min(poll_interval * 1000000, static_cast(receive_timeout.totalMicroseconds()))); + bool read_ok = false; + /// We are waiting for a packet from the client. Thus, every `POLL_INTERVAL` seconds check whether we need to shut down. while (true) { - if (static_cast(*in).poll(poll_interval)) + if (static_cast(*in).poll(timeout_ms)) + { + /// If client disconnected. + if (in->eof()) + { + LOG_INFO(log, "Client has dropped the connection, cancel the query."); + state.is_connection_closed = true; + break; + } + + /// We accept and process data. + read_ok = receivePacket(); + break; + } /// Do we need to shut down? if (server.isCancelled()) - return false; + break; /** Have we waited for data for too long?
* If we periodically poll, the receive_timeout of the socket itself does not work. * Therefore, an additional check is added. */ Float64 elapsed = watch.elapsedSeconds(); - if (elapsed > static_cast(receive_timeout)) + if (elapsed > static_cast(receive_timeout.totalSeconds())) { throw Exception(ErrorCodes::SOCKET_TIMEOUT, "Timeout exceeded while receiving data from client. Waited for {} seconds, timeout is {} seconds.", - static_cast(elapsed), receive_timeout); + static_cast(elapsed), receive_timeout.totalSeconds()); } } - /// If client disconnected. - if (in->eof()) - { - LOG_INFO(log, "Client has dropped the connection, cancel the query."); - state.is_connection_closed = true; - return false; - } + if (read_ok) + sendLogs(); + else + state.read_all_data = true; - /// We accept and process data. And if they are over, then we leave. - if (!receivePacket()) - return false; - - sendLogs(); - return true; + return read_ok; } -std::tuple TCPHandler::getReadTimeouts(const Settings & connection_settings) +void TCPHandler::readData() { - const auto receive_timeout = query_context->getSettingsRef().receive_timeout.value; - - /// Poll interval should not be greater than receive_timeout - const size_t default_poll_interval = connection_settings.poll_interval * 1000000; - size_t current_poll_interval = static_cast(receive_timeout.totalMicroseconds()); - constexpr size_t min_poll_interval = 5000; // 5 ms - size_t poll_interval = std::max(min_poll_interval, std::min(default_poll_interval, current_poll_interval)); - - return std::make_tuple(poll_interval, receive_timeout.totalSeconds()); -} - - -void TCPHandler::readData(const Settings & connection_settings) -{ - auto [poll_interval, receive_timeout] = getReadTimeouts(connection_settings); sendLogs(); - while (readDataNext(poll_interval, receive_timeout)) + while (readDataNext()) ; } -void TCPHandler::processInsertQuery(const Settings & connection_settings) +void TCPHandler::skipData() +{ + state.skipping_data = true; + 
SCOPE_EXIT({ state.skipping_data = false; }); + + while (readDataNext()) + ; +} + + +void TCPHandler::processInsertQuery() { /** Made above the rest of the lines, so that in case of `writePrefix` function throws an exception, * client receive exception before sending data. @@ -595,7 +587,7 @@ void TCPHandler::processInsertQuery(const Settings & connection_settings) try { - readData(connection_settings); + readData(); } catch (...) { @@ -634,7 +626,7 @@ void TCPHandler::processOrdinaryQuery() break; } - if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) + if (after_send_progress.elapsed() / 1000 >= interactive_delay) { /// Some time passed. after_send_progress.restart(); @@ -643,7 +635,7 @@ void TCPHandler::processOrdinaryQuery() sendLogs(); - if (async_in.poll(query_context->getSettingsRef().interactive_delay / 1000)) + if (async_in.poll(interactive_delay / 1000)) { const auto block = async_in.read(); if (!block) @@ -698,7 +690,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors() CurrentMetrics::Increment query_thread_metric_increment{CurrentMetrics::QueryThread}; Block block; - while (executor.pull(block, query_context->getSettingsRef().interactive_delay / 1000)) + while (executor.pull(block, interactive_delay / 1000)) { std::lock_guard lock(task_callback_mutex); @@ -709,7 +701,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors() break; } - if (after_send_progress.elapsed() / 1000 >= query_context->getSettingsRef().interactive_delay) + if (after_send_progress.elapsed() / 1000 >= interactive_delay) { /// Some time passed and there is a progress. after_send_progress.restart(); @@ -755,13 +747,14 @@ void TCPHandler::processTablesStatusRequest() { TablesStatusRequest request; request.read(*in, client_tcp_protocol_version); - const auto session_context = session->sessionContext(); + + ContextPtr context_to_resolve_table_names = session->sessionContext() ? 
session->sessionContext() : server.context(); TablesStatusResponse response; for (const QualifiedTableName & table_name: request.tables) { - auto resolved_id = session_context->tryResolveStorageID({table_name.database, table_name.table}); - StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, session_context); + auto resolved_id = context_to_resolve_table_names->tryResolveStorageID({table_name.database, table_name.table}); + StoragePtr table = DatabaseCatalog::instance().tryGetTable(resolved_id, context_to_resolve_table_names); if (!table) continue; @@ -781,11 +774,10 @@ void TCPHandler::processTablesStatusRequest() writeVarUInt(Protocol::Server::TablesStatusResponse, *out); /// For testing hedged requests - const Settings & settings = query_context->getSettingsRef(); - if (settings.sleep_in_send_tables_status_ms.totalMilliseconds()) + if (sleep_in_send_tables_status.totalMilliseconds()) { out->next(); - std::chrono::milliseconds ms(settings.sleep_in_send_tables_status_ms.totalMilliseconds()); + std::chrono::milliseconds ms(sleep_in_send_tables_status.totalMilliseconds()); std::this_thread::sleep_for(ms); } @@ -977,22 +969,21 @@ void TCPHandler::receiveHello() (!user.empty() ? 
", user: " + user : "") ); - if (user != USER_INTERSERVER_MARKER) - { - auto & client_info = session->getClientInfo(); - client_info.interface = ClientInfo::Interface::TCP; - client_info.client_name = client_name; - client_info.client_version_major = client_version_major; - client_info.client_version_minor = client_version_minor; - client_info.client_version_patch = client_version_patch; - client_info.client_tcp_protocol_version = client_tcp_protocol_version; + auto & client_info = session->getClientInfo(); + client_info.client_name = client_name; + client_info.client_version_major = client_version_major; + client_info.client_version_minor = client_version_minor; + client_info.client_version_patch = client_version_patch; + client_info.client_tcp_protocol_version = client_tcp_protocol_version; - session->setUser(user, password, socket().peerAddress()); - } - else + is_interserver_mode = (user == USER_INTERSERVER_MARKER); + if (is_interserver_mode) { receiveClusterNameAndSalt(); + return; } + + session->authenticate(user, password, socket().peerAddress()); } @@ -1039,8 +1030,11 @@ bool TCPHandler::receivePacket() { case Protocol::Client::IgnoredPartUUIDs: /// Part uuids packet if any comes before query. + if (!state.empty() || state.part_uuids_to_ignore) + receiveUnexpectedIgnoredPartUUIDs(); receiveIgnoredPartUUIDs(); return true; + case Protocol::Client::Query: if (!state.empty()) receiveUnexpectedQuery(); @@ -1049,8 +1043,10 @@ bool TCPHandler::receivePacket() case Protocol::Client::Data: case Protocol::Client::Scalar: + if (state.skipping_data) + return receiveUnexpectedData(false); if (state.empty()) - receiveUnexpectedData(); + receiveUnexpectedData(true); return receiveData(packet_type == Protocol::Client::Scalar); case Protocol::Client::Ping: @@ -1061,10 +1057,9 @@ bool TCPHandler::receivePacket() case Protocol::Client::Cancel: { /// For testing connection collector. 
- const Settings & settings = query_context->getSettingsRef(); - if (settings.sleep_in_receive_cancel_ms.totalMilliseconds()) + if (sleep_in_receive_cancel.totalMilliseconds()) { - std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); + std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds()); std::this_thread::sleep_for(ms); } @@ -1086,14 +1081,18 @@ bool TCPHandler::receivePacket() } } + void TCPHandler::receiveIgnoredPartUUIDs() { - state.part_uuids = true; - std::vector uuids; - readVectorBinary(uuids, *in); + readVectorBinary(state.part_uuids_to_ignore.emplace(), *in); +} - if (!uuids.empty()) - query_context->getIgnoredPartUUIDs()->add(uuids); + +void TCPHandler::receiveUnexpectedIgnoredPartUUIDs() +{ + std::vector skip_part_uuids; + readVectorBinary(skip_part_uuids, *in); + throw NetException("Unexpected packet IgnoredPartUUIDs received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); } @@ -1107,10 +1106,9 @@ String TCPHandler::receiveReadTaskResponseAssumeLocked() { state.is_cancelled = true; /// For testing connection collector. - const Settings & settings = query_context->getSettingsRef(); - if (settings.sleep_in_receive_cancel_ms.totalMilliseconds()) + if (sleep_in_receive_cancel.totalMilliseconds()) { - std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); + std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds()); std::this_thread::sleep_for(ms); } return {}; @@ -1141,14 +1139,14 @@ void TCPHandler::receiveClusterNameAndSalt() if (salt.empty()) throw NetException("Empty salt is not allowed", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); - cluster_secret = query_context->getCluster(cluster)->getSecret(); + cluster_secret = server.context()->getCluster(cluster)->getSecret(); } catch (const Exception & e) { try { /// We try to send error information to the client. 
- sendException(e, session->getSettings().calculate_text_stack_trace); + sendException(e, send_exception_with_stack_trace); } catch (...) {} @@ -1163,27 +1161,12 @@ void TCPHandler::receiveQuery() state.is_empty = false; readStringBinary(state.query_id, *in); -// query_context = session->makeQueryContext(state.query_id); - /// Client info - ClientInfo & client_info = query_context->getClientInfo(); + /// Read client info. + ClientInfo client_info = session->getClientInfo(); if (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) client_info.read(*in, client_tcp_protocol_version); - /// For better support of old clients, that does not send ClientInfo. - if (client_info.query_kind == ClientInfo::QueryKind::NO_QUERY) - { - client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; - client_info.client_name = client_name; - client_info.client_version_major = client_version_major; - client_info.client_version_minor = client_version_minor; - client_info.client_version_patch = client_version_patch; - client_info.client_tcp_protocol_version = client_tcp_protocol_version; - } - - /// Set fields, that are known apriori. - client_info.interface = ClientInfo::Interface::TCP; - /// Per query settings are also passed via TCP. /// We need to check them before applying due to they can violate the settings constraints. auto settings_format = (client_tcp_protocol_version >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) @@ -1204,12 +1187,11 @@ void TCPHandler::receiveQuery() readVarUInt(compression, *in); state.compression = static_cast(compression); + last_block_in.compression = state.compression; readStringBinary(state.query, *in); - /// It is OK to check only when query != INITIAL_QUERY, - /// since only in that case the actions will be done. 
- if (!cluster.empty() && client_info.query_kind != ClientInfo::QueryKind::INITIAL_QUERY) + if (is_interserver_mode) { #if USE_SSL std::string data(salt); @@ -1231,26 +1213,33 @@ void TCPHandler::receiveQuery() /// i.e. when the INSERT is done with the global context (w/o user). if (!client_info.initial_user.empty()) { - query_context->setUserWithoutCheckingPassword(client_info.initial_user, client_info.initial_address); - LOG_DEBUG(log, "User (initial): {}", query_context->getUserName()); + LOG_DEBUG(log, "User (initial): {}", client_info.initial_user); + session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address); } - /// No need to update connection_context, since it does not requires user (it will not be used for query execution) #else throw Exception( "Inter-server secret support is disabled, because ClickHouse was built without SSL library", ErrorCodes::SUPPORT_IS_DISABLED); #endif } - else - { - query_context->setInitialRowPolicy(); - } + + query_context = session->makeQueryContext(std::move(client_info)); + + /// Sets the default database if it wasn't set earlier for the session context. + if (!default_database.empty() && !session->sessionContext()) + query_context->setCurrentDatabase(default_database); + + if (state.part_uuids_to_ignore) + query_context->getIgnoredPartUUIDs()->add(*state.part_uuids_to_ignore); + + query_context->setProgressCallback([this] (const Progress & value) { return this->updateProgress(value); }); /// /// Settings /// auto settings_changes = passed_settings.changes(); - if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + auto query_kind = query_context->getClientInfo().query_kind; + if (query_kind == ClientInfo::QueryKind::INITIAL_QUERY) { /// Throw an exception if the passed settings violate the constraints. 
query_context->checkSettingsConstraints(settings_changes); @@ -1262,40 +1251,24 @@ void TCPHandler::receiveQuery() } query_context->applySettingsChanges(settings_changes); + /// Use the received query id, or generate a random default. It is convenient + /// to also generate the default OpenTelemetry trace id at the same time, and + /// set the trace parent. + /// Notes: + /// 1) ClientInfo might contain upstream trace id, so we decide whether to use + /// the default ids after we have received the ClientInfo. + /// 2) There is the opentelemetry_start_trace_probability setting that + /// controls when we start a new trace. It can be changed via Native protocol, + /// so we have to apply the changes first. + query_context->setCurrentQueryId(state.query_id); + /// Disable function name normalization when it's a secondary query, because queries are either /// already normalized on initiator node, or not normalized and should remain unnormalized for /// compatibility. - if (client_info.query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) + if (query_kind == ClientInfo::QueryKind::SECONDARY_QUERY) { query_context->setSetting("normalize_function_names", Field(0)); } - - // Use the received query id, or generate a random default. It is convenient - // to also generate the default OpenTelemetry trace id at the same time, and - // set the trace parent. - // Why is this done here and not earlier: - // 1) ClientInfo might contain upstream trace id, so we decide whether to use - // the default ids after we have received the ClientInfo. - // 2) There is the opentelemetry_start_trace_probability setting that - // controls when we start a new trace. It can be changed via Native protocol, - // so we have to apply the changes first. - query_context->setCurrentQueryId(state.query_id); - - // Set parameters of initial query. - if (client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) - { - /// 'Current' fields was set at receiveHello. 
- client_info.initial_user = client_info.current_user; - client_info.initial_query_id = client_info.current_query_id; - client_info.initial_address = client_info.current_address; - } - - /// Sync timeouts on client and server during current query to avoid dangling queries on server - /// NOTE: We use settings.send_timeout for the receive timeout and vice versa (change arguments ordering in TimeoutSetter), - /// because settings.send_timeout is client-side setting which has opposite meaning on the server side. - /// NOTE: these settings are applied only for current connection (not for distributed tables' connections) - const Settings & settings = query_context->getSettingsRef(); - state.timeout_setter = std::make_unique(socket(), settings.receive_timeout, settings.send_timeout); } void TCPHandler::receiveUnexpectedQuery() @@ -1320,7 +1293,10 @@ void TCPHandler::receiveUnexpectedQuery() readStringBinary(skip_hash, *in, 32); readVarUInt(skip_uint_64, *in); + readVarUInt(skip_uint_64, *in); + last_block_in.compression = static_cast(skip_uint_64); + readStringBinary(skip_string, *in); throw NetException("Unexpected packet Query received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); @@ -1337,73 +1313,77 @@ bool TCPHandler::receiveData(bool scalar) /// Read one block from the network and write it down Block block = state.block_in->read(); - if (block) + if (!block) { - if (scalar) - { - /// Scalar value - query_context->addScalar(temporary_id.table_name, block); - } - else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input) - { - /// Data for external tables + state.read_all_data = true; + return false; + } - auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal); - StoragePtr storage; - /// If such a table does not exist, create it. 
- if (resolved) - { - storage = DatabaseCatalog::instance().getTable(resolved, query_context); - } - else - { - NamesAndTypesList columns = block.getNamesAndTypesList(); - auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {}); - storage = temporary_table.getTable(); - query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table)); - } - auto metadata_snapshot = storage->getInMemoryMetadataPtr(); - /// The data will be written directly to the table. - auto temporary_table_out = std::make_shared(storage->write(ASTPtr(), metadata_snapshot, query_context)); - temporary_table_out->write(block); - temporary_table_out->writeSuffix(); + if (scalar) + { + /// Scalar value + query_context->addScalar(temporary_id.table_name, block); + } + else if (!state.need_receive_data_for_insert && !state.need_receive_data_for_input) + { + /// Data for external tables - } - else if (state.need_receive_data_for_input) + auto resolved = query_context->tryResolveStorageID(temporary_id, Context::ResolveExternal); + StoragePtr storage; + /// If such a table does not exist, create it. + if (resolved) { - /// 'input' table function. - state.block_for_input = block; + storage = DatabaseCatalog::instance().getTable(resolved, query_context); } else { - /// INSERT query. - state.io.out->write(block); + NamesAndTypesList columns = block.getNamesAndTypesList(); + auto temporary_table = TemporaryTableHolder(query_context, ColumnsDescription{columns}, {}); + storage = temporary_table.getTable(); + query_context->addExternalTable(temporary_id.table_name, std::move(temporary_table)); } - return true; + auto metadata_snapshot = storage->getInMemoryMetadataPtr(); + /// The data will be written directly to the table. 
+ auto temporary_table_out = std::make_shared(storage->write(ASTPtr(), metadata_snapshot, query_context)); + temporary_table_out->write(block); + temporary_table_out->writeSuffix(); + + } + else if (state.need_receive_data_for_input) + { + /// 'input' table function. + state.block_for_input = block; } else - return false; + { + /// INSERT query. + state.io.out->write(block); + } + return true; } -void TCPHandler::receiveUnexpectedData() + +bool TCPHandler::receiveUnexpectedData(bool throw_exception) { String skip_external_table_name; readStringBinary(skip_external_table_name, *in); std::shared_ptr maybe_compressed_in; - if (last_block_in.compression == Protocol::Compression::Enable) maybe_compressed_in = std::make_shared(*in, /* allow_different_codecs */ true); else maybe_compressed_in = in; - auto skip_block_in = std::make_shared( - *maybe_compressed_in, - last_block_in.header, - client_tcp_protocol_version); + auto skip_block_in = std::make_shared(*maybe_compressed_in, client_tcp_protocol_version); + bool read_ok = skip_block_in->read(); - skip_block_in->read(); - throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + if (!read_ok) + state.read_all_data = true; + + if (throw_exception) + throw NetException("Unexpected packet Data received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); + + return read_ok; } void TCPHandler::initBlockInput() @@ -1424,9 +1404,6 @@ void TCPHandler::initBlockInput() else if (state.need_receive_data_for_input) header = state.input_header; - last_block_in.header = header; - last_block_in.compression = state.compression; - state.block_in = std::make_shared( *state.maybe_compressed_in, header, @@ -1439,10 +1416,9 @@ void TCPHandler::initBlockOutput(const Block & block) { if (!state.block_out) { + const Settings & query_settings = query_context->getSettingsRef(); if (!state.maybe_compressed_out) { - const Settings & query_settings = query_context->getSettingsRef(); - 
std::string method = Poco::toUpper(query_settings.network_compression_method.toString()); std::optional level; if (method == "ZSTD") @@ -1463,7 +1439,7 @@ void TCPHandler::initBlockOutput(const Block & block) *state.maybe_compressed_out, client_tcp_protocol_version, block.cloneEmpty(), - !session->getSettings().low_cardinality_allow_in_native_format); + !query_settings.low_cardinality_allow_in_native_format); } } @@ -1472,11 +1448,12 @@ void TCPHandler::initLogsBlockOutput(const Block & block) if (!state.logs_block_out) { /// Use uncompressed stream since log blocks usually contain only one row + const Settings & query_settings = query_context->getSettingsRef(); state.logs_block_out = std::make_shared( *out, client_tcp_protocol_version, block.cloneEmpty(), - !session->getSettings().low_cardinality_allow_in_native_format); + !query_settings.low_cardinality_allow_in_native_format); } } @@ -1486,7 +1463,7 @@ bool TCPHandler::isQueryCancelled() if (state.is_cancelled || state.sent_all_data) return true; - if (after_check_cancelled.elapsed() / 1000 < query_context->getSettingsRef().interactive_delay) + if (after_check_cancelled.elapsed() / 1000 < interactive_delay) return false; after_check_cancelled.restart(); @@ -1514,10 +1491,9 @@ bool TCPHandler::isQueryCancelled() state.is_cancelled = true; /// For testing connection collector. 
{ - const Settings & settings = query_context->getSettingsRef(); - if (settings.sleep_in_receive_cancel_ms.totalMilliseconds()) + if (sleep_in_receive_cancel.totalMilliseconds()) { - std::chrono::milliseconds ms(settings.sleep_in_receive_cancel_ms.totalMilliseconds()); + std::chrono::milliseconds ms(sleep_in_receive_cancel.totalMilliseconds()); std::this_thread::sleep_for(ms); } } @@ -1555,11 +1531,10 @@ void TCPHandler::sendData(const Block & block) writeStringBinary("", *out); /// For testing hedged requests - const Settings & settings = query_context->getSettingsRef(); - if (block.rows() > 0 && settings.sleep_in_send_data_ms.totalMilliseconds()) + if (block.rows() > 0 && query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds()) { out->next(); - std::chrono::milliseconds ms(settings.sleep_in_send_data_ms.totalMilliseconds()); + std::chrono::milliseconds ms(query_context->getSettingsRef().sleep_in_send_data_ms.totalMilliseconds()); std::this_thread::sleep_for(ms); } diff --git a/src/Server/TCPHandler.h b/src/Server/TCPHandler.h index d8e156ee7be..7f75d0ac04b 100644 --- a/src/Server/TCPHandler.h +++ b/src/Server/TCPHandler.h @@ -27,7 +27,9 @@ namespace DB { class Session; +struct Settings; class ColumnsDescription; +struct BlockStreamProfileInfo; /// State of query processing. struct QueryState @@ -66,11 +68,11 @@ struct QueryState bool sent_all_data = false; /// Request requires data from the client (INSERT, but not INSERT SELECT). bool need_receive_data_for_insert = false; - /// Temporary tables read - bool temporary_tables_read = false; + /// All data packets of the current query were received (or reading was stopped), so nothing is left to skip after a failure.
+ bool read_all_data = false; /// A state got uuids to exclude from a query - bool part_uuids = false; + std::optional> part_uuids_to_ignore; /// Request requires data from client for function input() bool need_receive_data_for_input = false; @@ -79,6 +81,9 @@ struct QueryState /// sample block from StorageInput Block input_header; + /// If true, the data packets will be skipped instead of reading. Used to recover after errors. + bool skipping_data = false; + /// To output progress, the difference after the previous sending of progress. Progress progress; @@ -100,7 +105,6 @@ struct QueryState struct LastBlockInputParameters { Protocol::Compression compression = Protocol::Compression::Disable; - Block header; }; class TCPHandler : public Poco::Net::TCPServerConnection @@ -133,11 +137,20 @@ private: UInt64 client_version_patch = 0; UInt64 client_tcp_protocol_version = 0; + /// Connection settings, which are extracted from a context. + bool send_exception_with_stack_trace = true; + Poco::Timespan send_timeout = DBMS_DEFAULT_SEND_TIMEOUT_SEC; + Poco::Timespan receive_timeout = DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC; + UInt64 poll_interval = DBMS_DEFAULT_POLL_INTERVAL; + UInt64 idle_connection_timeout = 3600; + UInt64 interactive_delay = 100000; + Poco::Timespan sleep_in_send_tables_status; + UInt64 unknown_packet_in_send_data = 0; + Poco::Timespan sleep_in_receive_cancel; + std::unique_ptr session; ContextMutablePtr query_context; - size_t unknown_packet_in_send_data = 0; - /// Streams for reading/writing from/to client connection socket. 
std::shared_ptr in; std::shared_ptr out; @@ -149,6 +162,7 @@ private: String default_database; /// For inter-server secret (remote_server.*.secret) + bool is_interserver_mode = false; String salt; String cluster; String cluster_secret; @@ -168,6 +182,8 @@ private: void runImpl(); + void extractConnectionSettingsFromContext(const ContextPtr & context); + bool receiveProxyHeader(); void receiveHello(); bool receivePacket(); @@ -175,18 +191,19 @@ private: void receiveIgnoredPartUUIDs(); String receiveReadTaskResponseAssumeLocked(); bool receiveData(bool scalar); - bool readDataNext(size_t poll_interval, time_t receive_timeout); - void readData(const Settings & connection_settings); + bool readDataNext(); + void readData(); + void skipData(); void receiveClusterNameAndSalt(); - std::tuple getReadTimeouts(const Settings & connection_settings); - [[noreturn]] void receiveUnexpectedData(); + bool receiveUnexpectedData(bool throw_exception = true); [[noreturn]] void receiveUnexpectedQuery(); + [[noreturn]] void receiveUnexpectedIgnoredPartUUIDs(); [[noreturn]] void receiveUnexpectedHello(); [[noreturn]] void receiveUnexpectedTablesStatusRequest(); /// Process INSERT query - void processInsertQuery(const Settings & connection_settings); + void processInsertQuery(); /// Process a request that does not require the receiving of data blocks from the client void processOrdinaryQuery(); diff --git a/src/TableFunctions/TableFunctionMySQL.cpp b/src/TableFunctions/TableFunctionMySQL.cpp index 92387b13d55..09f9cf8b1f5 100644 --- a/src/TableFunctions/TableFunctionMySQL.cpp +++ b/src/TableFunctions/TableFunctionMySQL.cpp @@ -61,9 +61,8 @@ void TableFunctionMySQL::parseArguments(const ASTPtr & ast_function, ContextPtr user_name = args[3]->as().value.safeGet(); password = args[4]->as().value.safeGet(); - const auto & settings = context->getSettingsRef(); /// Split into replicas if needed. 
3306 is the default MySQL port number - const size_t max_addresses = settings.glob_expansion_max_elements; + size_t max_addresses = context->getSettingsRef().glob_expansion_max_elements; auto addresses = parseRemoteDescriptionForExternalDatabase(host_port, max_addresses, 3306); pool.emplace(remote_database_name, addresses, user_name, password); diff --git a/tests/integration/test_read_temporary_tables_on_failure/test.py b/tests/integration/test_read_temporary_tables_on_failure/test.py index e62c7c9eaec..ae59fb31641 100644 --- a/tests/integration/test_read_temporary_tables_on_failure/test.py +++ b/tests/integration/test_read_temporary_tables_on_failure/test.py @@ -24,3 +24,4 @@ def test_different_versions(start_cluster): node.query("SELECT 1", settings={'max_concurrent_queries_for_user': 1}) assert node.contains_in_log('Too many simultaneous queries for user') assert not node.contains_in_log('Unknown packet') + assert not node.contains_in_log('Unexpected packet') diff --git a/tests/queries/0_stateless/01455_opentelemetry_distributed.reference b/tests/queries/0_stateless/01455_opentelemetry_distributed.reference index b40e4f87c13..f45f1ab6104 100644 --- a/tests/queries/0_stateless/01455_opentelemetry_distributed.reference +++ b/tests/queries/0_stateless/01455_opentelemetry_distributed.reference @@ -1,8 +1,20 @@ ===http=== +{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1} +{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1} +{"query":"DESC TABLE system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1} +{"query":"SELECT 1 FROM system.one","status":"QueryFinish","tracestate":"some custom state","sorted_by_start_time":1} +{"query":"DESC TABLE system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1} +{"query":"DESC TABLE 
system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1} +{"query":"SELECT 1 FROM system.one","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1} +{"query":"select 1 from remote('127.0.0.2', system, one) format Null\n","query_status":"QueryFinish","tracestate":"some custom state","sorted_by_finish_time":1} {"total spans":"4","unique spans":"4","unique non-zero parent spans":"3"} {"initial query spans with proper parent":"1"} {"unique non-empty tracestate values":"1"} ===native=== +{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1} +{"query":"select 1 format Null\n","status":"QueryFinish","tracestate":"another custom state","sorted_by_start_time":1} +{"query":"select 1 format Null\n","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1} +{"query":"select * from url('http:\/\/127.0.0.2:8123\/?query=select%201%20format%20Null', CSV, 'a int')","query_status":"QueryFinish","tracestate":"another custom state","sorted_by_finish_time":1} {"total spans":"2","unique spans":"2","unique non-zero parent spans":"2"} {"initial query spans with proper parent":"1"} {"unique non-empty tracestate values":"1"} diff --git a/tests/queries/0_stateless/01455_opentelemetry_distributed.sh b/tests/queries/0_stateless/01455_opentelemetry_distributed.sh index 8f034b0bf61..59cd1b57d1e 100755 --- a/tests/queries/0_stateless/01455_opentelemetry_distributed.sh +++ b/tests/queries/0_stateless/01455_opentelemetry_distributed.sh @@ -12,6 +12,28 @@ function check_log ${CLICKHOUSE_CLIENT} --format=JSONEachRow -nq " system flush logs; +-- Show queries sorted by start time. 
+select attribute['db.statement'] as query, + attribute['clickhouse.query_status'] as status, + attribute['clickhouse.tracestate'] as tracestate, + 1 as sorted_by_start_time + from system.opentelemetry_span_log + where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id'))) + and operation_name = 'query' + order by start_time_us + ; + +-- Show queries sorted by finish time. +select attribute['db.statement'] as query, + attribute['clickhouse.query_status'] as query_status, + attribute['clickhouse.tracestate'] as tracestate, + 1 as sorted_by_finish_time + from system.opentelemetry_span_log + where trace_id = reinterpretAsUUID(reverse(unhex('$trace_id'))) + and operation_name = 'query' + order by finish_time_us + ; + -- Check the number of query spans with given trace id, to verify it was -- propagated. select count(*) "'"'"total spans"'"'", @@ -89,10 +111,10 @@ check_log echo "===sampled===" query_id=$(${CLICKHOUSE_CLIENT} -q "select lower(hex(reverse(reinterpretAsString(generateUUIDv4()))))") -for i in {1..200} +for i in {1..20} do ${CLICKHOUSE_CLIENT} \ - --opentelemetry_start_trace_probability=0.1 \ + --opentelemetry_start_trace_probability=0.5 \ --query_id "$query_id-$i" \ --query "select 1 from remote('127.0.0.2', system, one) format Null" \ & @@ -108,8 +130,8 @@ wait ${CLICKHOUSE_CLIENT} -q "system flush logs" ${CLICKHOUSE_CLIENT} -q " - -- expect 200 * 0.1 = 20 sampled events on average - select if(count() > 1 and count() < 50, 'OK', 'Fail') + -- expect 20 * 0.5 = 10 sampled events on average + select if(2 <= count() and count() <= 18, 'OK', 'Fail') from system.opentelemetry_span_log where operation_name = 'query' and parent_span_id = 0 -- only account for the initial queries