From 018546a57d4553c44613c11aa3b0eb616461e60c Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sat, 1 Jul 2023 17:39:50 +0200 Subject: [PATCH 1/5] Avoid keeping lock Context::getLock() while recalculating access rights of a connected user. --- src/Access/AccessControl.cpp | 53 +++------ src/Access/AccessControl.h | 10 +- src/Access/ContextAccess.cpp | 74 +++++++------ src/Access/ContextAccess.h | 51 +-------- src/Access/ContextAccessParams.cpp | 172 +++++++++++++++++++++++++++++ src/Access/ContextAccessParams.h | 64 +++++++++++ src/Interpreters/Context.cpp | 102 +++++++++-------- src/Interpreters/Context.h | 7 +- 8 files changed, 362 insertions(+), 171 deletions(-) create mode 100644 src/Access/ContextAccessParams.cpp create mode 100644 src/Access/ContextAccessParams.h diff --git a/src/Access/AccessControl.cpp b/src/Access/AccessControl.cpp index 6179c823b56..41ac3f42ee2 100644 --- a/src/Access/AccessControl.cpp +++ b/src/Access/AccessControl.cpp @@ -72,18 +72,26 @@ public: std::shared_ptr getContextAccess(const ContextAccessParams & params) { - std::lock_guard lock{mutex}; - auto x = cache.get(params); - if (x) { - if ((*x)->tryGetUser()) - return *x; - /// No user, probably the user has been dropped while it was in the cache. - cache.remove(params); + std::lock_guard lock{mutex}; + auto x = cache.get(params); + if (x) + { + if ((*x)->getUserID() && !(*x)->tryGetUser()) + cache.remove(params); /// The user has been dropped while it was in the cache. + else + return *x; + } } + auto res = std::make_shared(access_control, params); res->initialize(); - cache.add(params, res); + + { + std::lock_guard lock{mutex}; + cache.add(params, res); + } + return res; } @@ -713,35 +721,6 @@ int AccessControl::getBcryptWorkfactor() const } -std::shared_ptr AccessControl::getContextAccess( - const UUID & user_id, - const std::vector & current_roles, - bool use_default_roles, - const Settings & settings, - const String & current_database, - const ClientInfo & client_info) const -{ - ContextAccessParams params; - params.user_id = user_id; - params.current_roles.insert(current_roles.begin(), current_roles.end()); - params.use_default_roles = use_default_roles; - params.current_database = current_database; - params.readonly = settings.readonly; - params.allow_ddl = settings.allow_ddl; - params.allow_introspection = settings.allow_introspection_functions; - params.interface = client_info.interface; - params.http_method = client_info.http_method; - params.address = client_info.current_address.host(); - params.quota_key = client_info.quota_key; - - /// Extract the last entry from comma separated list of X-Forwarded-For addresses. - /// Only the last proxy can be trusted (if any). - params.forwarded_address = client_info.getLastForwardedFor(); - - return getContextAccess(params); -} - - std::shared_ptr AccessControl::getContextAccess(const ContextAccessParams & params) const { return context_access_cache->getContextAccess(params); diff --git a/src/Access/AccessControl.h b/src/Access/AccessControl.h index 2a8293a49e7..74816090f88 100644 --- a/src/Access/AccessControl.h +++ b/src/Access/AccessControl.h @@ -25,7 +25,7 @@ namespace Poco namespace DB { class ContextAccess; -struct ContextAccessParams; +class ContextAccessParams; struct User; using UserPtr = std::shared_ptr; class EnabledRoles; @@ -181,14 +181,6 @@ public: void setSettingsConstraintsReplacePrevious(bool enable) { settings_constraints_replace_previous = enable; } bool doesSettingsConstraintsReplacePrevious() const { return settings_constraints_replace_previous; } - std::shared_ptr getContextAccess( - const UUID & user_id, - const std::vector & current_roles, - bool use_default_roles, - const Settings & settings, - const String & current_database, - const ClientInfo & client_info) const; - std::shared_ptr getContextAccess(const ContextAccessParams & params) const; std::shared_ptr getEnabledRoles( diff --git a/src/Access/ContextAccess.cpp b/src/Access/ContextAccess.cpp index 9c57853679f..cb8f1a5a48e 100644 --- a/src/Access/ContextAccess.cpp +++ b/src/Access/ContextAccess.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -221,15 +222,15 @@ namespace } -ContextAccess::ContextAccess(const AccessControl & access_control_, const Params & params_) - : access_control(&access_control_) - , params(params_) +std::shared_ptr ContextAccess::fromContext(const ContextPtr & context) { + return context->getAccess(); } -ContextAccess::ContextAccess(FullAccess) - : is_full_access(true), access(std::make_shared(AccessRights::getFullAccess())), access_with_implicit(access) +ContextAccess::ContextAccess(const AccessControl & access_control_, const Params & params_) + : access_control(&access_control_) + , params(params_) { } @@ -251,18 +252,31 @@ ContextAccess::~ContextAccess() void ContextAccess::initialize() { - std::lock_guard lock{mutex}; - subscription_for_user_change = access_control->subscribeForChanges( - *params.user_id, [weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity) - { - auto ptr = weak_ptr.lock(); - if (!ptr) - return; - UserPtr changed_user = entity ? typeid_cast(entity) : nullptr; - std::lock_guard lock2{ptr->mutex}; - ptr->setUser(changed_user); - }); - setUser(access_control->read(*params.user_id)); + std::lock_guard lock{mutex}; + + if (params.full_access) + { + access = std::make_shared(AccessRights::getFullAccess()); + access_with_implicit = access; + return; + } + + if (!params.user_id) + throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug"); + + subscription_for_user_change = access_control->subscribeForChanges( + *params.user_id, + [weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity) + { + auto ptr = weak_ptr.lock(); + if (!ptr) + return; + UserPtr changed_user = entity ? typeid_cast(entity) : nullptr; + std::lock_guard lock2{ptr->mutex}; + ptr->setUser(changed_user); + }); + + setUser(access_control->read(*params.user_id)); } @@ -294,10 +308,10 @@ void ContextAccess::setUser(const UserPtr & user_) const current_roles = user->granted_roles.findGranted(user->default_roles); current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(user->default_roles); } - else + else if (params.current_roles) { - current_roles = user->granted_roles.findGranted(params.current_roles); - current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(params.current_roles); + current_roles = user->granted_roles.findGranted(*params.current_roles); + current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(*params.current_roles); } subscription_for_roles_changes.reset(); @@ -316,12 +330,16 @@ void ContextAccess::setRolesInfo(const std::shared_ptr & { assert(roles_info_); roles_info = roles_info_; + enabled_row_policies = access_control->getEnabledRowPolicies( *params.user_id, roles_info->enabled_roles); + enabled_quota = access_control->getEnabledQuota( *params.user_id, user_name, roles_info->enabled_roles, params.address, params.forwarded_address, params.quota_key); + enabled_settings = access_control->getEnabledSettings( *params.user_id, user->settings, roles_info->enabled_roles, roles_info->settings_from_enabled_roles); + calculateAccessRights(); } @@ -417,14 +435,6 @@ std::optional ContextAccess::getQuotaUsage() const } -std::shared_ptr ContextAccess::getFullAccess() -{ - static const std::shared_ptr res = - [] { return std::shared_ptr(new ContextAccess{kFullAccess}); }(); - return res; -} - - SettingsChanges ContextAccess::getDefaultSettings() const { std::lock_guard lock{mutex}; @@ -478,7 +488,7 @@ bool ContextAccess::checkAccessImplHelper(AccessFlags flags, const Args &... arg throw Exception(ErrorCodes::UNKNOWN_USER, "{}: User has been dropped", getUserName()); } - if (is_full_access) + if (params.full_access) return true; auto access_granted = [&] @@ -706,7 +716,7 @@ bool ContextAccess::checkAdminOptionImplHelper(const Container & role_ids, const return false; }; - if (is_full_access) + if (params.full_access) return true; if (user_was_dropped) @@ -806,7 +816,7 @@ void ContextAccess::checkAdminOption(const std::vector & role_ids, const s void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const { - if (is_full_access) + if (params.full_access) return; auto current_user = getUser(); @@ -816,7 +826,7 @@ void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccess void ContextAccess::checkGranteesAreAllowed(const std::vector & grantee_ids) const { - if (is_full_access) + if (params.full_access) return; auto current_user = getUser(); diff --git a/src/Access/ContextAccess.h b/src/Access/ContextAccess.h index 60bad0118fc..4c96ef5c11f 100644 --- a/src/Access/ContextAccess.h +++ b/src/Access/ContextAccess.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -30,47 +31,18 @@ class AccessControl; class IAST; struct IAccessEntity; using ASTPtr = std::shared_ptr; - - -struct ContextAccessParams -{ - std::optional user_id; - boost::container::flat_set current_roles; - bool use_default_roles = false; - UInt64 readonly = 0; - bool allow_ddl = false; - bool allow_introspection = false; - String current_database; - ClientInfo::Interface interface = ClientInfo::Interface::TCP; - ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; - Poco::Net::IPAddress address; - String forwarded_address; - String quota_key; - - auto toTuple() const - { - return std::tie( - user_id, current_roles, use_default_roles, readonly, allow_ddl, allow_introspection, - current_database, interface, http_method, address, forwarded_address, quota_key); - } - - friend bool operator ==(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return lhs.toTuple() == rhs.toTuple(); } - friend bool operator !=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(lhs == rhs); } - friend bool operator <(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return lhs.toTuple() < rhs.toTuple(); } - friend bool operator >(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return rhs < lhs; } - friend bool operator <=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(rhs < lhs); } - friend bool operator >=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(lhs < rhs); } -}; +class Context; +using ContextPtr = std::shared_ptr; class ContextAccess : public std::enable_shared_from_this { public: + static std::shared_ptr fromContext(const ContextPtr & context); + using Params = ContextAccessParams; const Params & getParams() const { return params; } - ContextAccess(const AccessControl & access_control_, const Params & params_); - /// Returns the current user. Throws if user is nullptr. UserPtr getUser() const; /// Same as above, but can return nullptr. @@ -161,22 +133,12 @@ public: /// Checks if grantees are allowed for the current user, throws an exception if not. void checkGranteesAreAllowed(const std::vector & grantee_ids) const; - /// Makes an instance of ContextAccess which provides full access to everything - /// without any limitations. This is used for the global context. - static std::shared_ptr getFullAccess(); - + ContextAccess(const AccessControl & access_control_, const Params & params_); ~ContextAccess(); private: friend class AccessControl; - struct FullAccess {}; - static const FullAccess kFullAccess; - - /// Makes an instance of ContextAccess which provides full access to everything - /// without any limitations. This is used for the global context. - explicit ContextAccess(FullAccess); - void initialize(); void setUser(const UserPtr & user_) const TSA_REQUIRES(mutex); void setRolesInfo(const std::shared_ptr & roles_info_) const TSA_REQUIRES(mutex); @@ -223,7 +185,6 @@ private: const AccessControl * access_control = nullptr; const Params params; - const bool is_full_access = false; mutable std::atomic user_was_dropped = false; mutable std::atomic trace_log = nullptr; diff --git a/src/Access/ContextAccessParams.cpp b/src/Access/ContextAccessParams.cpp new file mode 100644 index 00000000000..7963e83dddf --- /dev/null +++ b/src/Access/ContextAccessParams.cpp @@ -0,0 +1,172 @@ +#include +#include +#include + + +namespace DB +{ + +ContextAccessParams::ContextAccessParams( + const std::optional user_id_, + bool full_access_, + bool use_default_roles_, + const std::shared_ptr> & current_roles_, + const Settings & settings_, + const String & current_database_, + const ClientInfo & client_info_) + : user_id(user_id_) + , full_access(full_access_) + , use_default_roles(use_default_roles_) + , current_roles(current_roles_) + , readonly(settings_.readonly) + , allow_ddl(settings_.allow_ddl) + , allow_introspection(settings_.allow_introspection_functions) + , current_database(current_database_) + , interface(client_info_.interface) + , http_method(client_info_.http_method) + , address(client_info_.current_address.host()) + , forwarded_address(client_info_.getLastForwardedFor()) + , quota_key(client_info_.quota_key) +{ +} + +String ContextAccessParams::toString() const +{ + WriteBufferFromOwnString out; + auto separator = [&] { return out.stringView().empty() ? "" : ", "; }; + if (user_id) + out << separator() << "user_id = " << *user_id; + if (full_access) + out << separator() << "full_access = " << full_access; + if (use_default_roles) + out << separator() << "use_default_roles = " << use_default_roles; + if (current_roles && !current_roles->empty()) + { + out << separator() << "current_roles = ["; + for (size_t i = 0; i != current_roles->size(); ++i) + { + if (i) + out << ", "; + out << (*current_roles)[i]; + } + out << "]"; + } + if (readonly) + out << separator() << "readonly = " << readonly; + if (allow_ddl) + out << separator() << "allow_ddl = " << allow_ddl; + if (allow_introspection) + out << separator() << "allow_introspection = " << allow_introspection; + if (!current_database.empty()) + out << separator() << "current_database = " << current_database; + out << separator() << "interface = " << magic_enum::enum_name(interface); + if (http_method != ClientInfo::HTTPMethod::UNKNOWN) + out << separator() << "http_method = " << magic_enum::enum_name(http_method); + if (!address.isWildcard()) + out << separator() << "address = " << address.toString(); + if (!forwarded_address.empty()) + out << separator() << "forwarded_address = " << forwarded_address; + if (!quota_key.empty()) + out << separator() << "quota_key = " << quota_key; + return out.str(); +} + +bool operator ==(const ContextAccessParams & left, const ContextAccessParams & right) +{ + auto check_equals = [](const auto & x, const auto & y) + { + if constexpr (::detail::is_shared_ptr_v>) + { + if (!x) + return !y; + else if (!y) + return false; + else + return *x == *y; + } + else + { + return x == y; + } + }; + + #define CONTEXT_ACCESS_PARAMS_EQUALS(name) \ + if (!check_equals(left.name, right.name)) \ + return false; + + CONTEXT_ACCESS_PARAMS_EQUALS(user_id) + CONTEXT_ACCESS_PARAMS_EQUALS(full_access) + CONTEXT_ACCESS_PARAMS_EQUALS(use_default_roles) + CONTEXT_ACCESS_PARAMS_EQUALS(current_roles) + CONTEXT_ACCESS_PARAMS_EQUALS(readonly) + CONTEXT_ACCESS_PARAMS_EQUALS(allow_ddl) + CONTEXT_ACCESS_PARAMS_EQUALS(allow_introspection) + CONTEXT_ACCESS_PARAMS_EQUALS(current_database) + CONTEXT_ACCESS_PARAMS_EQUALS(interface) + CONTEXT_ACCESS_PARAMS_EQUALS(http_method) + CONTEXT_ACCESS_PARAMS_EQUALS(address) + CONTEXT_ACCESS_PARAMS_EQUALS(forwarded_address) + CONTEXT_ACCESS_PARAMS_EQUALS(quota_key) + + #undef CONTEXT_ACCESS_PARAMS_EQUALS + + return true; /// All fields are equal, operator == must return true. +} + +bool operator <(const ContextAccessParams & left, const ContextAccessParams & right) +{ + auto check_less = [](const auto & x, const auto & y) + { + if constexpr (::detail::is_shared_ptr_v>) + { + if (!x) + return y ? -1 : 0; + else if (!y) + return 1; + else if (*x == *y) + return 0; + else if (*x < *y) + return -1; + else + return 1; + } + else + { + if (x == y) + return 0; + else if (x < y) + return -1; + else + return 1; + } + }; + + #define CONTEXT_ACCESS_PARAMS_LESS(name) \ + if (auto cmp = check_less(left.name, right.name); cmp != 0) \ + return cmp < 0; + + CONTEXT_ACCESS_PARAMS_LESS(user_id) + CONTEXT_ACCESS_PARAMS_LESS(full_access) + CONTEXT_ACCESS_PARAMS_LESS(use_default_roles) + CONTEXT_ACCESS_PARAMS_LESS(current_roles) + CONTEXT_ACCESS_PARAMS_LESS(readonly) + CONTEXT_ACCESS_PARAMS_LESS(allow_ddl) + CONTEXT_ACCESS_PARAMS_LESS(allow_introspection) + CONTEXT_ACCESS_PARAMS_LESS(current_database) + CONTEXT_ACCESS_PARAMS_LESS(interface) + CONTEXT_ACCESS_PARAMS_LESS(http_method) + CONTEXT_ACCESS_PARAMS_LESS(address) + CONTEXT_ACCESS_PARAMS_LESS(forwarded_address) + CONTEXT_ACCESS_PARAMS_LESS(quota_key) + + #undef CONTEXT_ACCESS_PARAMS_LESS + + return false; /// All fields are equal, operator < must return false. +} + +bool ContextAccessParams::dependsOnSettingName(std::string_view setting_name) +{ + return (setting_name == "readonly") || (setting_name == "allow_ddl") || (setting_name == "allow_introspection_functions"); +} + +} diff --git a/src/Access/ContextAccessParams.h b/src/Access/ContextAccessParams.h new file mode 100644 index 00000000000..740ec997964 --- /dev/null +++ b/src/Access/ContextAccessParams.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +struct Settings; + +/// Parameters which are used to calculate access rights and some related stuff like roles or constraints. +class ContextAccessParams +{ +public: + ContextAccessParams( + const std::optional user_id_, + bool full_access_, + bool use_default_roles_, + const std::shared_ptr> & current_roles_, + const Settings & settings_, + const String & current_database_, + const ClientInfo & client_info_); + + const std::optional user_id; + + /// Full access to everything without any limitations. + /// This is used for the global context. + const bool full_access; + + const bool use_default_roles; + const std::shared_ptr> current_roles; + + const UInt64 readonly; + const bool allow_ddl; + const bool allow_introspection; + + const String current_database; + + const ClientInfo::Interface interface; + const ClientInfo::HTTPMethod http_method; + const Poco::Net::IPAddress address; + + /// The last entry from comma separated list of X-Forwarded-For addresses. + /// Only the last proxy can be trusted (if any). + const String forwarded_address; + + const String quota_key; + + /// Outputs `ContextAccessParams` to string for logging. + String toString() const; + + friend bool operator <(const ContextAccessParams & left, const ContextAccessParams & right); + friend bool operator ==(const ContextAccessParams & left, const ContextAccessParams & right); + friend bool operator !=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(left == right); } + friend bool operator >(const ContextAccessParams & left, const ContextAccessParams & right) { return right < left; } + friend bool operator <=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(right < left); } + friend bool operator >=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(left < right); } + + static bool dependsOnSettingName(std::string_view setting_name); +}; + +} diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 5019933c2af..abc33c1b8d4 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -1063,8 +1063,16 @@ void Context::setUser(const UUID & user_id_) user_id = user_id_; - access = getAccessControl().getContextAccess( - user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info); + ContextAccessParams params{ + user_id, + /* full_access= */ false, + /* use_default_roles = */ true, + /* current_roles = */ nullptr, + settings, + current_database, + client_info}; + + access = getAccessControl().getContextAccess(params); auto user = access->getUser(); @@ -1108,7 +1116,7 @@ void Context::setCurrentRoles(const std::vector & current_roles_) if (current_roles ? (*current_roles == current_roles_) : current_roles_.empty()) return; current_roles = std::make_shared>(current_roles_); - calculateAccessRights(); + need_recalculate_access = true; } void Context::setCurrentRolesDefault() @@ -1133,20 +1141,6 @@ std::shared_ptr Context::getRolesInfo() const } -void Context::calculateAccessRights() -{ - auto lock = getLock(); - if (user_id) - access = getAccessControl().getContextAccess( - *user_id, - current_roles ? *current_roles : std::vector{}, - /* use_default_roles = */ false, - settings, - current_database, - client_info); -} - - template void Context::checkAccessImpl(const Args &... args) const { @@ -1166,11 +1160,50 @@ void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id, void Context::checkAccess(const AccessRightsElement & element) const { return checkAccessImpl(element); } void Context::checkAccess(const AccessRightsElements & elements) const { return checkAccessImpl(elements); } - std::shared_ptr Context::getAccess() const { - auto lock = getLock(); - return access ? access : ContextAccess::getFullAccess(); + /// A helper function to collect parameters for calculating access rights, called with Context::getLock() acquired. + auto get_params = [this]() + { + /// If setUserID() was never called then this must be the global context with the full access. + bool full_access = !user_id; + + return ContextAccessParams{user_id, full_access, /* use_default_roles= */ false, current_roles, settings, current_database, client_info}; + }; + + /// Check if the current access rights are still valid, otherwise get parameters for recalculating access rights. + std::optional params; + + { + auto lock = getLock(); + if (access && !need_recalculate_access) + return access; /// No need to recalculate access rights. + + params.emplace(get_params()); + + if (access && (access->getParams() == *params)) + { + need_recalculate_access = false; + return access; /// No need to recalculate access rights. + } + } + + /// Calculate new access rights according to the collected parameters. + /// NOTE: AccessControl::getContextAccess() may require some IO work, so Context::getLock() must be unlocked while we're doing this. + auto res = getAccessControl().getContextAccess(*params); + + { + /// If the parameters of access rights were not changed while we were calculated them + /// then we store the new access rights in the Context to allow reusing it later. + auto lock = getLock(); + if (get_params() == *params) + { + access = res; + need_recalculate_access = false; + } + } + + return res; } RowPolicyFilterPtr Context::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const @@ -1700,27 +1733,8 @@ Settings Context::getSettings() const void Context::setSettings(const Settings & settings_) { auto lock = getLock(); - const auto old_readonly = settings.readonly; - const auto old_allow_ddl = settings.allow_ddl; - const auto old_allow_introspection_functions = settings.allow_introspection_functions; - const auto old_display_secrets = settings.format_display_secrets_in_show_and_select; - settings = settings_; - - if ((settings.readonly != old_readonly) - || (settings.allow_ddl != old_allow_ddl) - || (settings.allow_introspection_functions != old_allow_introspection_functions) - || (settings.format_display_secrets_in_show_and_select != old_display_secrets)) - calculateAccessRights(); -} - -void Context::recalculateAccessRightsIfNeeded(std::string_view name) -{ - if (name == "readonly" - || name == "allow_ddl" - || name == "allow_introspection_functions" - || name == "format_display_secrets_in_show_and_select") - calculateAccessRights(); + need_recalculate_access = true; } void Context::setSetting(std::string_view name, const String & value) @@ -1732,7 +1746,8 @@ void Context::setSetting(std::string_view name, const String & value) return; } settings.set(name, value); - recalculateAccessRightsIfNeeded(name); + if (ContextAccessParams::dependsOnSettingName(name)) + need_recalculate_access = true; } void Context::setSetting(std::string_view name, const Field & value) @@ -1744,7 +1759,8 @@ void Context::setSetting(std::string_view name, const Field & value) return; } settings.set(name, value); - recalculateAccessRightsIfNeeded(name); + if (ContextAccessParams::dependsOnSettingName(name)) + need_recalculate_access = true; } void Context::applySettingChange(const SettingChange & change) @@ -1853,7 +1869,7 @@ void Context::setCurrentDatabase(const String & name) DatabaseCatalog::instance().assertDatabaseExists(name); auto lock = getLock(); current_database = name; - calculateAccessRights(); + need_recalculate_access = true; } void Context::setCurrentQueryId(const String & query_id) diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 6cbb0e58911..2c32ad28d01 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -248,7 +248,8 @@ private: std::optional user_id; std::shared_ptr> current_roles; std::shared_ptr settings_constraints_and_current_profiles; - std::shared_ptr access; + mutable std::shared_ptr access; + mutable bool need_recalculate_access = true; std::shared_ptr row_policies_of_initial_user; String current_database; Settings settings; /// Setting for query execution. @@ -1149,10 +1150,6 @@ private: void initGlobal(); - /// Compute and set actual user settings, client_info.current_user should be set - void calculateAccessRights(); - void recalculateAccessRightsIfNeeded(std::string_view setting_name); - template void checkAccessImpl(const Args &... args) const; From 0e4b75a282f14d812763a43cb3519c94beeb138c Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sat, 1 Jul 2023 19:58:38 +0200 Subject: [PATCH 2/5] Avoid keeping lock Context::getLock() while calculating access rights when a user logs in. --- src/Access/ContextAccess.cpp | 35 +++++---- src/Access/ContextAccess.h | 6 +- src/Access/ContextAccessParams.cpp | 5 ++ src/Access/ContextAccessParams.h | 3 + src/Interpreters/Context.cpp | 115 +++++++++++++++++------------ src/Interpreters/Context.h | 23 +++--- src/Interpreters/Session.cpp | 3 - 7 files changed, 107 insertions(+), 83 deletions(-) diff --git a/src/Access/ContextAccess.cpp b/src/Access/ContextAccess.cpp index cb8f1a5a48e..51bb7794735 100644 --- a/src/Access/ContextAccess.cpp +++ b/src/Access/ContextAccess.cpp @@ -240,6 +240,7 @@ ContextAccess::~ContextAccess() enabled_settings.reset(); enabled_quota.reset(); enabled_row_policies.reset(); + row_policies_of_initial_user.reset(); access_with_implicit.reset(); access.reset(); roles_info.reset(); @@ -264,6 +265,12 @@ void ContextAccess::initialize() if (!params.user_id) throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug"); + if (!params.initial_user.empty()) + { + if (auto initial_user_id = access_control->find(params.initial_user)) + row_policies_of_initial_user = access_control->tryGetDefaultRowPolicies(*initial_user_id); + } + subscription_for_user_change = access_control->subscribeForChanges( *params.user_id, [weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity) @@ -331,8 +338,7 @@ void ContextAccess::setRolesInfo(const std::shared_ptr & assert(roles_info_); roles_info = roles_info_; - enabled_row_policies = access_control->getEnabledRowPolicies( - *params.user_id, roles_info->enabled_roles); + enabled_row_policies = access_control->getEnabledRowPolicies(*params.user_id, roles_info->enabled_roles); enabled_quota = access_control->getEnabledQuota( *params.user_id, user_name, roles_info->enabled_roles, params.address, params.forwarded_address, params.quota_key); @@ -399,21 +405,24 @@ std::shared_ptr ContextAccess::getRolesInfo() const return no_roles; } -std::shared_ptr ContextAccess::getEnabledRowPolicies() const +RowPolicyFilterPtr ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const { std::lock_guard lock{mutex}; - if (enabled_row_policies) - return enabled_row_policies; - static const auto no_row_policies = std::make_shared(); - return no_row_policies; -} -RowPolicyFilterPtr ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter) const -{ - std::lock_guard lock{mutex}; + RowPolicyFilterPtr filter; if (enabled_row_policies) - return enabled_row_policies->getFilter(database, table_name, filter_type, combine_with_filter); - return combine_with_filter; + filter = enabled_row_policies->getFilter(database, table_name, filter_type); + + if (row_policies_of_initial_user) + { + /// Find and set extra row policies to be used based on `client_info.initial_user`, if the initial user exists. + /// TODO: we need a better solution here. It seems we should pass the initial row policy + /// because a shard is allowed to not have the initial user or it might be another user + /// with the same name. + filter = row_policies_of_initial_user->getFilter(database, table_name, filter_type, filter); + } + + return filter; } std::shared_ptr ContextAccess::getQuota() const diff --git a/src/Access/ContextAccess.h b/src/Access/ContextAccess.h index 4c96ef5c11f..4bd67f8881b 100644 --- a/src/Access/ContextAccess.h +++ b/src/Access/ContextAccess.h @@ -53,12 +53,9 @@ public: /// Returns information about current and enabled roles. std::shared_ptr getRolesInfo() const; - /// Returns information about enabled row policies. - std::shared_ptr getEnabledRowPolicies() const; - /// Returns the row policy filter for a specified table. /// The function returns nullptr if there is no filter to apply. - RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter = {}) const; + RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const; /// Returns the quota to track resource consumption. std::shared_ptr getQuota() const; @@ -198,6 +195,7 @@ private: mutable std::shared_ptr access TSA_GUARDED_BY(mutex); mutable std::shared_ptr access_with_implicit TSA_GUARDED_BY(mutex); mutable std::shared_ptr enabled_row_policies TSA_GUARDED_BY(mutex); + mutable std::shared_ptr row_policies_of_initial_user TSA_GUARDED_BY(mutex); mutable std::shared_ptr enabled_quota TSA_GUARDED_BY(mutex); mutable std::shared_ptr enabled_settings TSA_GUARDED_BY(mutex); diff --git a/src/Access/ContextAccessParams.cpp b/src/Access/ContextAccessParams.cpp index 7963e83dddf..ec839a37b1a 100644 --- a/src/Access/ContextAccessParams.cpp +++ b/src/Access/ContextAccessParams.cpp @@ -27,6 +27,7 @@ ContextAccessParams::ContextAccessParams( , address(client_info_.current_address.host()) , forwarded_address(client_info_.getLastForwardedFor()) , quota_key(client_info_.quota_key) + , initial_user((client_info_.initial_user != client_info_.current_user) ? client_info_.initial_user : "") { } @@ -68,6 +69,8 @@ String ContextAccessParams::toString() const out << separator() << "forwarded_address = " << forwarded_address; if (!quota_key.empty()) out << separator() << "quota_key = " << quota_key; + if (!initial_user.empty()) + out << separator() << "initial_user = " << initial_user; return out.str(); } @@ -107,6 +110,7 @@ bool operator ==(const ContextAccessParams & left, const ContextAccessParams & r CONTEXT_ACCESS_PARAMS_EQUALS(address) CONTEXT_ACCESS_PARAMS_EQUALS(forwarded_address) CONTEXT_ACCESS_PARAMS_EQUALS(quota_key) + CONTEXT_ACCESS_PARAMS_EQUALS(initial_user) #undef CONTEXT_ACCESS_PARAMS_EQUALS @@ -158,6 +162,7 @@ bool operator <(const ContextAccessParams & left, const ContextAccessParams & ri CONTEXT_ACCESS_PARAMS_LESS(address) CONTEXT_ACCESS_PARAMS_LESS(forwarded_address) CONTEXT_ACCESS_PARAMS_LESS(quota_key) + CONTEXT_ACCESS_PARAMS_LESS(initial_user) #undef CONTEXT_ACCESS_PARAMS_LESS diff --git a/src/Access/ContextAccessParams.h b/src/Access/ContextAccessParams.h index 740ec997964..8b68fa44ed4 100644 --- a/src/Access/ContextAccessParams.h +++ b/src/Access/ContextAccessParams.h @@ -48,6 +48,9 @@ public: const String quota_key; + /// Initial user is used to combine row policies with. + const String initial_user; + /// Outputs `ContextAccessParams` to string for logging. String toString() const; diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index abc33c1b8d4..5fae9374705 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -1057,33 +1057,54 @@ ConfigurationPtr Context::getUsersConfig() return shared->users_config; } -void Context::setUser(const UUID & user_id_) +void Context::setUser(const UUID & user_id_, bool set_current_profiles_, bool set_current_roles_, bool set_current_database_) { + /// Prepare lists of user's profiles, constraints, settings, roles. + + std::shared_ptr user; + std::shared_ptr temp_access; + if (set_current_profiles_ || set_current_roles_ || set_current_database_) + { + std::optional params; + { + auto lock = getLock(); + params.emplace(ContextAccessParams{user_id_, /* full_access= */ false, /* use_default_roles = */ true, {}, settings, current_database, client_info}); + } + /// `temp_access` is used here only to extract information about the user, not to actually check access. + /// NOTE: AccessControl::getContextAccess() may require some IO work, so Context::getLock() must be unlocked while we're doing this. + temp_access = getAccessControl().getContextAccess(*params); + user = temp_access->getUser(); + } + + std::shared_ptr profiles; + if (set_current_profiles_) + profiles = temp_access->getDefaultProfileInfo(); + + std::optional> roles; + if (set_current_roles_) + roles = user->granted_roles.findGranted(user->default_roles); + + String database; + if (set_current_database_) + database = user->default_database; + + /// Apply user's profiles, constraints, settings, roles. auto lock = getLock(); - user_id = user_id_; + setUserID(user_id_); - ContextAccessParams params{ - user_id, - /* full_access= */ false, - /* use_default_roles = */ true, - /* current_roles = */ nullptr, - settings, - current_database, - client_info}; + if (profiles) + { + /// A profile can specify a value and a readonly constraint for same setting at the same time, + /// so we shouldn't check constraints here. + setCurrentProfiles(*profiles, /* check_constraints= */ false); + } - access = getAccessControl().getContextAccess(params); + if (roles) + setCurrentRoles(*roles); - auto user = access->getUser(); - - current_roles = std::make_shared>(user->granted_roles.findGranted(user->default_roles)); - - auto default_profile_info = access->getDefaultProfileInfo(); - settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs(); - applySettingsChanges(default_profile_info->settings); - - if (!user->default_database.empty()) - setCurrentDatabase(user->default_database); + if (!database.empty()) + setCurrentDatabase(database); } std::shared_ptr Context::getUser() const @@ -1096,6 +1117,13 @@ String Context::getUserName() const return getAccess()->getUserName(); } +void Context::setUserID(const UUID & user_id_) +{ + auto lock = getLock(); + user_id = user_id_; + need_recalculate_access = true; +} + std::optional Context::getUserID() const { auto lock = getLock(); @@ -1113,9 +1141,10 @@ void Context::setQuotaKey(String quota_key_) void Context::setCurrentRoles(const std::vector & current_roles_) { auto lock = getLock(); - if (current_roles ? (*current_roles == current_roles_) : current_roles_.empty()) - return; - current_roles = std::make_shared>(current_roles_); + if (current_roles_.empty()) + current_roles = nullptr; + else + current_roles = std::make_shared>(current_roles_); need_recalculate_access = true; } @@ -1208,23 +1237,7 @@ std::shared_ptr Context::getAccess() const RowPolicyFilterPtr Context::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const { - auto lock = getLock(); - RowPolicyFilterPtr row_filter_of_initial_user; - if (row_policies_of_initial_user) - row_filter_of_initial_user = row_policies_of_initial_user->getFilter(database, table_name, filter_type); - return getAccess()->getRowPolicyFilter(database, table_name, filter_type, row_filter_of_initial_user); -} - -void Context::enableRowPoliciesOfInitialUser() -{ - auto lock = getLock(); - row_policies_of_initial_user = nullptr; - if (client_info.initial_user == client_info.current_user) - return; - auto initial_user_id = getAccessControl().find(client_info.initial_user); - if (!initial_user_id) - return; - row_policies_of_initial_user = getAccessControl().tryGetDefaultRowPolicies(*initial_user_id); + return getAccess()->getRowPolicyFilter(database, table_name, filter_type); } @@ -1240,13 +1253,12 @@ std::optional Context::getQuotaUsage() const } -void Context::setCurrentProfile(const String & profile_name) +void Context::setCurrentProfile(const String & profile_name, bool check_constraints) { - auto lock = getLock(); try { UUID profile_id = getAccessControl().getID(profile_name); - setCurrentProfile(profile_id); + setCurrentProfile(profile_id, check_constraints); } catch (Exception & e) { @@ -1255,15 +1267,20 @@ void Context::setCurrentProfile(const String & profile_name) } } -void Context::setCurrentProfile(const UUID & profile_id) +void Context::setCurrentProfile(const UUID & profile_id, bool check_constraints) { - auto lock = getLock(); auto profile_info = getAccessControl().getSettingsProfileInfo(profile_id); - checkSettingsConstraints(profile_info->settings); - applySettingsChanges(profile_info->settings); - settings_constraints_and_current_profiles = profile_info->getConstraintsAndProfileIDs(settings_constraints_and_current_profiles); + setCurrentProfiles(*profile_info, check_constraints); } +void Context::setCurrentProfiles(const SettingsProfilesInfo & profiles_info, bool check_constraints) +{ + auto lock = getLock(); + if (check_constraints) + checkSettingsConstraints(profiles_info.settings); + applySettingsChanges(profiles_info.settings); + settings_constraints_and_current_profiles = profiles_info.getConstraintsAndProfileIDs(settings_constraints_and_current_profiles); +} std::vector Context::getCurrentProfiles() const { diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 2c32ad28d01..172f3818dfd 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -50,8 +50,8 @@ struct ContextSharedPart; class ContextAccess; struct User; using UserPtr = std::shared_ptr; +struct SettingsProfilesInfo; struct EnabledRolesInfo; -class EnabledRowPolicies; struct RowPolicyFilter; using RowPolicyFilterPtr = std::shared_ptr; class EnabledQuota; @@ -250,7 +250,6 @@ private: std::shared_ptr settings_constraints_and_current_profiles; mutable std::shared_ptr access; mutable bool need_recalculate_access = true; - std::shared_ptr row_policies_of_initial_user; String current_database; Settings settings; /// Setting for query execution. @@ -530,12 +529,14 @@ public: /// Sets the current user assuming that he/she is already authenticated. /// WARNING: This function doesn't check password! - void setUser(const UUID & user_id_); - + void setUser(const UUID & user_id_, bool set_current_profiles_ = true, bool set_current_roles_ = true, bool set_current_database_ = true); UserPtr getUser() const; - String getUserName() const; + + void setUserID(const UUID & user_id_); std::optional getUserID() const; + String getUserName() const; + void setQuotaKey(String quota_key_); void setCurrentRoles(const std::vector & current_roles_); @@ -544,8 +545,9 @@ public: boost::container::flat_set getEnabledRoles() const; std::shared_ptr getRolesInfo() const; - void setCurrentProfile(const String & profile_name); - void setCurrentProfile(const UUID & profile_id); + void setCurrentProfile(const String & profile_name, bool check_constraints = true); + void setCurrentProfile(const UUID & profile_id, bool check_constraints = true); + void setCurrentProfiles(const SettingsProfilesInfo & profiles_info, bool check_constraints = true); std::vector getCurrentProfiles() const; std::vector getEnabledProfiles() const; @@ -568,13 +570,6 @@ public: RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const; - /// Finds and sets extra row policies to be used based on `client_info.initial_user`, - /// if the initial user exists. - /// TODO: we need a better solution here. It seems we should pass the initial row policy - /// because a shard is allowed to not have the initial user or it might be another user - /// with the same name. - void enableRowPoliciesOfInitialUser(); - std::shared_ptr getQuota() const; std::optional getQuotaUsage() const; diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 64f7b4fc934..8571f20b91e 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -515,9 +515,6 @@ ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_t res_client_info.initial_address = res_client_info.current_address; } - /// Sets that row policies of the initial user should be used too. - query_context->enableRowPoliciesOfInitialUser(); - /// Set user information for the new context: current profiles, roles, access rights. if (user_id && !query_context->getAccess()->tryGetUser()) query_context->setUser(*user_id); From 815a3857de74b92d3071128a6b5fbc5cb0a53c93 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 7 Jul 2023 12:49:50 +0200 Subject: [PATCH 3/5] Remove non-const function Context::getClientInfo(). --- programs/client/Client.cpp | 11 +- programs/local/LocalServer.cpp | 5 +- src/Databases/DatabaseReplicated.cpp | 4 +- .../MySQL/MaterializedMySQLSyncThread.cpp | 2 +- src/Interpreters/AsynchronousInsertQueue.cpp | 10 +- .../ClusterProxy/executeQuery.cpp | 2 +- src/Interpreters/Context.cpp | 123 +++++++++++++++ src/Interpreters/Context.h | 26 +++- src/Interpreters/DDLTask.cpp | 6 +- src/Interpreters/DDLWorker.cpp | 2 +- src/Interpreters/InterpreterDropQuery.cpp | 4 +- src/Interpreters/InterpreterSelectQuery.cpp | 2 +- src/Interpreters/Session.cpp | 141 +++++++++++++++--- src/Interpreters/Session.h | 17 ++- src/Interpreters/executeQuery.cpp | 5 +- .../QueryPlan/DistributedCreateLocalPlan.cpp | 12 +- src/Server/GRPCServer.cpp | 2 +- src/Server/HTTPHandler.cpp | 19 +-- src/Server/MySQLHandler.cpp | 2 +- src/Server/PostgreSQLHandler.cpp | 2 +- src/Server/TCPHandler.cpp | 25 +--- src/Storages/Distributed/DistributedSink.cpp | 2 +- src/Storages/StorageDistributed.cpp | 4 +- src/Storages/StorageReplicatedMergeTree.cpp | 2 +- src/Storages/WindowView/StorageWindowView.cpp | 2 +- 25 files changed, 335 insertions(+), 97 deletions(-) diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index f791c39bad1..e1a33231592 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -1173,12 +1173,12 @@ void Client::processOptions(const OptionsDescription & options_description, { String traceparent = options["opentelemetry-traceparent"].as(); String error; - if (!global_context->getClientInfo().client_trace_context.parseTraceparentHeader(traceparent, error)) + if (!global_context->getClientTraceContext().parseTraceparentHeader(traceparent, error)) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot parse OpenTelemetry traceparent '{}': {}", traceparent, error); } if (options.count("opentelemetry-tracestate")) - global_context->getClientInfo().client_trace_context.tracestate = options["opentelemetry-tracestate"].as(); + global_context->getClientTraceContext().tracestate = options["opentelemetry-tracestate"].as(); } @@ -1238,10 +1238,9 @@ void Client::processConfig() global_context->getSettingsRef().max_insert_block_size); } - ClientInfo & client_info = global_context->getClientInfo(); - client_info.setInitialQuery(); - client_info.quota_key = config().getString("quota_key", ""); - client_info.query_kind = query_kind; + global_context->setQueryKindInitial(); + global_context->setQuotaClientKey(config().getString("quota_key", "")); + global_context->setQueryKind(query_kind); } diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index 96924e3c8d9..3c2a8ae3152 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -737,9 +737,8 @@ void LocalServer::processConfig() for (const auto & [key, value] : prompt_substitutions) boost::replace_all(prompt_by_server_display_name, "{" + key + "}", value); - ClientInfo & client_info = global_context->getClientInfo(); - client_info.setInitialQuery(); - client_info.query_kind = query_kind; + global_context->setQueryKindInitial(); + global_context->setQueryKind(query_kind); } diff --git a/src/Databases/DatabaseReplicated.cpp b/src/Databases/DatabaseReplicated.cpp index 661afc6bf1f..25c23e2be17 100644 --- a/src/Databases/DatabaseReplicated.cpp +++ b/src/Databases/DatabaseReplicated.cpp @@ -814,8 +814,8 @@ void DatabaseReplicated::recoverLostReplica(const ZooKeeperPtr & current_zookeep { auto query_context = Context::createCopy(getContext()); query_context->makeQueryContext(); - query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; - query_context->getClientInfo().is_replicated_database_internal = true; + query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); + query_context->setQueryKindReplicatedDatabaseInternal(); query_context->setCurrentDatabase(getDatabaseName()); query_context->setCurrentQueryId(""); auto txn = std::make_shared(current_zookeeper, zookeeper_path, false, ""); diff --git a/src/Databases/MySQL/MaterializedMySQLSyncThread.cpp b/src/Databases/MySQL/MaterializedMySQLSyncThread.cpp index a01ab2a15a8..379e6ef5097 100644 --- a/src/Databases/MySQL/MaterializedMySQLSyncThread.cpp +++ b/src/Databases/MySQL/MaterializedMySQLSyncThread.cpp @@ -59,7 +59,7 @@ static ContextMutablePtr createQueryContext(ContextPtr context) query_context->setSettings(new_query_settings); query_context->setInternalQuery(true); - query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; + query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); query_context->setCurrentQueryId(""); // generate random query_id return query_context; } diff --git a/src/Interpreters/AsynchronousInsertQueue.cpp b/src/Interpreters/AsynchronousInsertQueue.cpp index 0da762699d2..6081919a120 100644 --- a/src/Interpreters/AsynchronousInsertQueue.cpp +++ b/src/Interpreters/AsynchronousInsertQueue.cpp @@ -421,12 +421,10 @@ try auto insert_query_id = insert_context->getCurrentQueryId(); auto query_start_time = std::chrono::system_clock::now(); Stopwatch start_watch{CLOCK_MONOTONIC}; - ClientInfo & client_info = insert_context->getClientInfo(); - client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; - client_info.initial_query_start_time = timeInSeconds(query_start_time); - client_info.initial_query_start_time_microseconds = timeInMicroseconds(query_start_time); - client_info.current_query_id = insert_query_id; - client_info.initial_query_id = insert_query_id; + insert_context->setQueryKind(ClientInfo::QueryKind::INITIAL_QUERY); + insert_context->setInitialQueryStartTime(query_start_time); + insert_context->setCurrentQueryId(insert_query_id); + insert_context->setInitialQueryId(insert_query_id); size_t log_queries_cut_to_length = insert_context->getSettingsRef().log_queries_cut_to_length; String query_for_logging = insert_query.hasSecretParts() ? insert_query.formatForLogging(log_queries_cut_to_length) diff --git a/src/Interpreters/ClusterProxy/executeQuery.cpp b/src/Interpreters/ClusterProxy/executeQuery.cpp index e2f1dfe8ba7..3dea52faf46 100644 --- a/src/Interpreters/ClusterProxy/executeQuery.cpp +++ b/src/Interpreters/ClusterProxy/executeQuery.cpp @@ -171,7 +171,7 @@ void executeQuery( SelectStreamFactory::Shards remote_shards; auto new_context = updateSettingsForCluster(*query_info.getCluster(), context, settings, main_table, &query_info, log); - new_context->getClientInfo().distributed_depth += 1; + new_context->increaseDistributedDepth(); size_t shards = query_info.getCluster()->getShardCount(); for (const auto & shard_info : query_info.getCluster()->getShardsInfo()) diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 5fae9374705..c097eeb87f1 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -3850,6 +3850,129 @@ void Context::resetInputCallbacks() } +void Context::setClientInfo(const ClientInfo & client_info_) +{ + client_info = client_info_; + need_recalculate_access = true; +} + +void Context::setClientName(const String & client_name) +{ + client_info.client_name = client_name; +} + +void Context::setClientInterface(ClientInfo::Interface interface) +{ + client_info.interface = interface; + need_recalculate_access = true; +} + +void Context::setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version) +{ + client_info.client_version_major = client_version_major; + client_info.client_version_minor = client_version_minor; + client_info.client_version_patch = client_version_patch; + client_info.client_tcp_protocol_version = client_tcp_protocol_version; +} + +void Context::setClientConnectionId(uint32_t connection_id_) +{ + client_info.connection_id = connection_id_; +} + +void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer) +{ + client_info.http_method = http_method; + client_info.http_user_agent = http_user_agent; + client_info.http_referer = http_referer; + need_recalculate_access = true; +} + +void Context::setForwardedFor(const String & forwarded_for) +{ + client_info.forwarded_for = forwarded_for; + need_recalculate_access = true; +} + +void Context::setQueryKind(ClientInfo::QueryKind query_kind) +{ + client_info.query_kind = query_kind; +} + +void Context::setQueryKindInitial() +{ + /// TODO: Try to combine this function with setQueryKind(). + client_info.setInitialQuery(); +} + +void Context::setQueryKindReplicatedDatabaseInternal() +{ + /// TODO: Try to combine this function with setQueryKind(). + client_info.is_replicated_database_internal = true; +} + +void Context::setCurrentUserName(const String & current_user_name) +{ + /// TODO: Try to combine this function with setUser(). + client_info.current_user = current_user_name; + need_recalculate_access = true; +} + +void Context::setCurrentAddress(const Poco::Net::SocketAddress & current_address) +{ + client_info.current_address = current_address; + need_recalculate_access = true; +} + +void Context::setInitialUserName(const String & initial_user_name) +{ + client_info.initial_user = initial_user_name; + need_recalculate_access = true; +} + +void Context::setInitialAddress(const Poco::Net::SocketAddress & initial_address) +{ + client_info.initial_address = initial_address; +} + +void Context::setInitialQueryId(const String & initial_query_id) +{ + client_info.initial_query_id = initial_query_id; +} + +void Context::setInitialQueryStartTime(std::chrono::time_point initial_query_start_time) +{ + client_info.initial_query_start_time = timeInSeconds(initial_query_start_time); + client_info.initial_query_start_time_microseconds = timeInMicroseconds(initial_query_start_time); +} + +void Context::setQuotaClientKey(const String & quota_key_) +{ + client_info.quota_key = quota_key_; + need_recalculate_access = true; +} + +void Context::setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version) +{ + client_info.connection_client_version_major = client_version_major; + client_info.connection_client_version_minor = client_version_minor; + client_info.connection_client_version_patch = client_version_patch; + client_info.connection_tcp_protocol_version = client_tcp_protocol_version; +} + +void Context::setReplicaInfo(bool collaborate_with_initiator, size_t all_replicas_count, size_t number_of_current_replica) +{ + client_info.collaborate_with_initiator = collaborate_with_initiator; + client_info.count_participating_replicas = all_replicas_count; + client_info.number_of_current_replica = number_of_current_replica; +} + +void Context::increaseDistributedDepth() +{ + ++client_info.distributed_depth; +} + + StorageID Context::resolveStorageID(StorageID storage_id, StorageNamespace where) const { if (storage_id.uuid != UUIDHelpers::Nil) diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 172f3818dfd..afc4bfde6a8 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -593,9 +593,33 @@ public: InputBlocksReader getInputBlocksReaderCallback() const; void resetInputCallbacks(); - ClientInfo & getClientInfo() { return client_info; } + /// Returns information about the client executing a query. const ClientInfo & getClientInfo() const { return client_info; } + /// Modify stored in the context information about the client executing a query. + void setClientInfo(const ClientInfo & client_info_); + void setClientName(const String & client_name); + void setClientInterface(ClientInfo::Interface interface); + void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); + void setClientConnectionId(uint32_t connection_id); + void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer); + void setForwardedFor(const String & forwarded_for); + void setQueryKind(ClientInfo::QueryKind query_kind); + void setQueryKindInitial(); + void setQueryKindReplicatedDatabaseInternal(); + void setCurrentUserName(const String & current_user_name); + void setCurrentAddress(const Poco::Net::SocketAddress & current_address); + void setInitialUserName(const String & initial_user_name); + void setInitialAddress(const Poco::Net::SocketAddress & initial_address); + void setInitialQueryId(const String & initial_query_id); + void setInitialQueryStartTime(std::chrono::time_point initial_query_start_time); + void setQuotaClientKey(const String & quota_key); + void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); + void setReplicaInfo(bool collaborate_with_initiator, size_t all_replicas_count, size_t number_of_current_replica); + void increaseDistributedDepth(); + const OpenTelemetry::TracingContext & getClientTraceContext() const { return client_info.client_trace_context; } + OpenTelemetry::TracingContext & getClientTraceContext() { return client_info.client_trace_context; } + enum StorageNamespace { ResolveGlobal = 1u, /// Database name must be specified diff --git a/src/Interpreters/DDLTask.cpp b/src/Interpreters/DDLTask.cpp index b24856a6146..4e684f5899f 100644 --- a/src/Interpreters/DDLTask.cpp +++ b/src/Interpreters/DDLTask.cpp @@ -199,7 +199,7 @@ ContextMutablePtr DDLTaskBase::makeQueryContext(ContextPtr from_context, const Z auto query_context = Context::createCopy(from_context); query_context->makeQueryContext(); query_context->setCurrentQueryId(""); // generate random query_id - query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; + query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); if (entry.settings) query_context->applySettingsChanges(*entry.settings); return query_context; @@ -439,8 +439,8 @@ void DatabaseReplicatedTask::parseQueryFromEntry(ContextPtr context) ContextMutablePtr DatabaseReplicatedTask::makeQueryContext(ContextPtr from_context, const ZooKeeperPtr & zookeeper) { auto query_context = DDLTaskBase::makeQueryContext(from_context, zookeeper); - query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; - query_context->getClientInfo().is_replicated_database_internal = true; + query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); + query_context->setQueryKindReplicatedDatabaseInternal(); query_context->setCurrentDatabase(database->getDatabaseName()); auto txn = std::make_shared(zookeeper, database->zookeeper_path, is_initial_query, entry_path); diff --git a/src/Interpreters/DDLWorker.cpp b/src/Interpreters/DDLWorker.cpp index 81c78000ac3..193bb5b6ab0 100644 --- a/src/Interpreters/DDLWorker.cpp +++ b/src/Interpreters/DDLWorker.cpp @@ -476,7 +476,7 @@ bool DDLWorker::tryExecuteQuery(DDLTaskBase & task, const ZooKeeperPtr & zookeep query_context->setSetting("implicit_transaction", Field{0}); } - query_context->getClientInfo().initial_query_id = task.entry.initial_query_id; + query_context->setInitialQueryId(task.entry.initial_query_id); if (!task.is_initial_query) query_scope.emplace(query_context); diff --git a/src/Interpreters/InterpreterDropQuery.cpp b/src/Interpreters/InterpreterDropQuery.cpp index 0beb4492aef..616cf80a446 100644 --- a/src/Interpreters/InterpreterDropQuery.cpp +++ b/src/Interpreters/InterpreterDropQuery.cpp @@ -451,11 +451,11 @@ void InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind kind, ContextPtr auto drop_context = Context::createCopy(global_context); if (ignore_sync_setting) drop_context->setSetting("database_atomic_wait_for_drop_and_detach_synchronously", false); - drop_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; + drop_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); if (auto txn = current_context->getZooKeeperMetadataTransaction()) { /// For Replicated database - drop_context->getClientInfo().is_replicated_database_internal = true; + drop_context->setQueryKindReplicatedDatabaseInternal(); drop_context->setQueryContext(std::const_pointer_cast(current_context)); drop_context->initZooKeeperMetadataTransaction(txn, true); } diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 32812151b59..d07a6521544 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -3183,7 +3183,7 @@ void InterpreterSelectQuery::initSettings() if (query.settings()) InterpreterSetQuery(query.settings(), context).executeForCurrentContext(options.ignore_setting_constraints); - auto & client_info = context->getClientInfo(); + const auto & client_info = context->getClientInfo(); auto min_major = DBMS_MIN_MAJOR_VERSION_WITH_CURRENT_AGGREGATION_VARIANT_SELECTION_METHOD; auto min_minor = DBMS_MIN_MINOR_VERSION_WITH_CURRENT_AGGREGATION_VARIANT_SELECTION_METHOD; diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 8571f20b91e..97b056cfc32 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -299,7 +299,10 @@ Session::~Session() if (notified_session_log_about_login) { if (auto session_log = getSessionLog()) + { + /// TODO: We have to ensure that the same info is added to the session log on a LoginSuccess event and on the corresponding Logout event. session_log->addLogOut(auth_id, user, getClientInfo()); + } } } @@ -368,17 +371,117 @@ void Session::onAuthenticationFailure(const std::optional & user_name, c } } -ClientInfo & Session::getClientInfo() -{ - /// FIXME it may produce different info for LoginSuccess and the corresponding Logout entries in the session log - return session_context ? session_context->getClientInfo() : *prepared_client_info; -} - const ClientInfo & Session::getClientInfo() const { return session_context ? session_context->getClientInfo() : *prepared_client_info; } +void Session::setClientInfo(const ClientInfo & client_info) +{ + if (session_context) + session_context->setClientInfo(client_info); + else + prepared_client_info = client_info; +} + +void Session::setClientName(const String & client_name) +{ + if (session_context) + session_context->setClientName(client_name); + else + prepared_client_info->client_name = client_name; +} + +void Session::setClientInterface(ClientInfo::Interface interface) +{ + if (session_context) + session_context->setClientInterface(interface); + else + prepared_client_info->interface = interface; +} + +void Session::setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version) +{ + if (session_context) + { + session_context->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + } + else + { + prepared_client_info->client_version_major = client_version_major; + prepared_client_info->client_version_minor = client_version_minor; + prepared_client_info->client_version_patch = client_version_patch; + prepared_client_info->client_tcp_protocol_version = client_tcp_protocol_version; + } +} + +void Session::setClientConnectionId(uint32_t connection_id) +{ + if (session_context) + session_context->setClientConnectionId(connection_id); + else + prepared_client_info->connection_id = connection_id; +} + +void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer) +{ + if (session_context) + { + session_context->setHttpClientInfo(http_method, http_user_agent, http_referer); + } + else + { + prepared_client_info->http_method = http_method; + prepared_client_info->http_user_agent = http_user_agent; + prepared_client_info->http_referer = http_referer; + } +} + +void Session::setForwardedFor(const String & forwarded_for) +{ + if (session_context) + session_context->setForwardedFor(forwarded_for); + else + prepared_client_info->forwarded_for = forwarded_for; +} + +void Session::setQuotaClientKey(const String & quota_key) +{ + if (session_context) + session_context->setQuotaClientKey(quota_key); + else + prepared_client_info->quota_key = quota_key; +} + +void Session::setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version) +{ + if (session_context) + { + session_context->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + } + else + { + prepared_client_info->connection_client_version_major = client_version_major; + prepared_client_info->connection_client_version_minor = client_version_minor; + prepared_client_info->connection_client_version_patch = client_version_patch; + prepared_client_info->connection_tcp_protocol_version = client_tcp_protocol_version; + } +} + +const OpenTelemetry::TracingContext & Session::getClientTraceContext() const +{ + if (session_context) + return session_context->getClientTraceContext(); + return prepared_client_info->client_trace_context; +} + +OpenTelemetry::TracingContext & Session::getClientTraceContext() +{ + if (session_context) + return session_context->getClientTraceContext(); + return prepared_client_info->client_trace_context; +} + ContextMutablePtr Session::makeSessionContext() { if (session_context) @@ -396,8 +499,7 @@ ContextMutablePtr Session::makeSessionContext() new_session_context->makeSessionContext(); /// Copy prepared client info to the new session context. - auto & res_client_info = new_session_context->getClientInfo(); - res_client_info = std::move(prepared_client_info).value(); + new_session_context->setClientInfo(*prepared_client_info); prepared_client_info.reset(); /// Set user information for the new context: current profiles, roles, access rights. @@ -436,8 +538,7 @@ ContextMutablePtr Session::makeSessionContext(const String & session_name_, std: /// Copy prepared client info to the session context, no matter it's been just created or not. /// If we continue using a previously created session context found by session ID /// it's necessary to replace the client info in it anyway, because it contains actual connection information (client address, etc.) - auto & res_client_info = new_session_context->getClientInfo(); - res_client_info = std::move(prepared_client_info).value(); + new_session_context->setClientInfo(*prepared_client_info); prepared_client_info.reset(); /// Set user information for the new context: current profiles, roles, access rights. @@ -492,27 +593,26 @@ ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_t } /// Copy the specified client info to the new query context. - auto & res_client_info = query_context->getClientInfo(); if (client_info_to_move) - res_client_info = std::move(*client_info_to_move); + query_context->setClientInfo(*client_info_to_move); else if (client_info_to_copy && (client_info_to_copy != &getClientInfo())) - res_client_info = *client_info_to_copy; + query_context->setClientInfo(*client_info_to_copy); /// Copy current user's name and address if it was authenticated after query_client_info was initialized. if (prepared_client_info && !prepared_client_info->current_user.empty()) { - res_client_info.current_user = prepared_client_info->current_user; - res_client_info.current_address = prepared_client_info->current_address; + query_context->setCurrentUserName(prepared_client_info->current_user); + query_context->setCurrentAddress(prepared_client_info->current_address); } /// Set parameters of initial query. - if (res_client_info.query_kind == ClientInfo::QueryKind::NO_QUERY) - res_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY; + if (query_context->getClientInfo().query_kind == ClientInfo::QueryKind::NO_QUERY) + query_context->setQueryKind(ClientInfo::QueryKind::INITIAL_QUERY); - if (res_client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY) + if (query_context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY) { - res_client_info.initial_user = res_client_info.current_user; - res_client_info.initial_address = res_client_info.current_address; + query_context->setInitialUserName(query_context->getClientInfo().current_user); + query_context->setInitialAddress(query_context->getClientInfo().current_address); } /// Set user information for the new context: current profiles, roles, access rights. @@ -563,4 +663,3 @@ void Session::closeSession(const String & session_id) } } - diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index d7c06a60464..36f811ccd24 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -54,10 +54,23 @@ public: /// Writes a row about login failure into session log (if enabled) void onAuthenticationFailure(const std::optional & user_name, const Poco::Net::SocketAddress & address_, const Exception & e); - /// Returns a reference to session ClientInfo. - ClientInfo & getClientInfo(); + /// Returns a reference to the session's ClientInfo. const ClientInfo & getClientInfo() const; + /// Modify the session's ClientInfo. + void setClientInfo(const ClientInfo & client_info); + void setClientName(const String & client_name); + void setClientInterface(ClientInfo::Interface interface); + void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); + void setClientConnectionId(uint32_t connection_id); + void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer); + void setForwardedFor(const String & forwarded_for); + void setQuotaClientKey(const String & quota_key); + void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version); + + const OpenTelemetry::TracingContext & getClientTraceContext() const; + OpenTelemetry::TracingContext & getClientTraceContext(); + /// Makes a session context, can be used one or zero times. /// The function also assigns an user to this context. ContextMutablePtr makeSessionContext(); diff --git a/src/Interpreters/executeQuery.cpp b/src/Interpreters/executeQuery.cpp index 4b76d20f31d..2c74039463e 100644 --- a/src/Interpreters/executeQuery.cpp +++ b/src/Interpreters/executeQuery.cpp @@ -655,7 +655,7 @@ static std::tuple executeQueryImpl( /// the value passed by the client Stopwatch start_watch{CLOCK_MONOTONIC}; - auto & client_info = context->getClientInfo(); + const auto & client_info = context->getClientInfo(); if (!internal) { @@ -667,8 +667,7 @@ static std::tuple executeQueryImpl( // On the other hand, if it's initialized then take it as the start of the query if (client_info.initial_query_start_time == 0) { - client_info.initial_query_start_time = timeInSeconds(query_start_time); - client_info.initial_query_start_time_microseconds = timeInMicroseconds(query_start_time); + context->setInitialQueryStartTime(query_start_time); } else { diff --git a/src/Processors/QueryPlan/DistributedCreateLocalPlan.cpp b/src/Processors/QueryPlan/DistributedCreateLocalPlan.cpp index 9b9cc221ca8..b251eec2d28 100644 --- a/src/Processors/QueryPlan/DistributedCreateLocalPlan.cpp +++ b/src/Processors/QueryPlan/DistributedCreateLocalPlan.cpp @@ -72,14 +72,10 @@ std::unique_ptr createLocalPlan( if (coordinator) { new_context->parallel_reading_coordinator = coordinator; - new_context->getClientInfo().interface = ClientInfo::Interface::LOCAL; - new_context->getClientInfo().collaborate_with_initiator = true; - new_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY; - new_context->getClientInfo().count_participating_replicas = replica_count; - new_context->getClientInfo().number_of_current_replica = replica_num; - new_context->getClientInfo().connection_client_version_major = DBMS_VERSION_MAJOR; - new_context->getClientInfo().connection_client_version_minor = DBMS_VERSION_MINOR; - new_context->getClientInfo().connection_tcp_protocol_version = DBMS_TCP_PROTOCOL_VERSION; + new_context->setClientInterface(ClientInfo::Interface::LOCAL); + new_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY); + new_context->setReplicaInfo(true, replica_count, replica_num); + new_context->setConnectionClientVersion(DBMS_VERSION_MAJOR, DBMS_VERSION_MINOR, DBMS_VERSION_PATCH, DBMS_TCP_PROTOCOL_VERSION); new_context->setParallelReplicasGroupUUID(group_uuid); new_context->setMergeTreeAllRangesCallback([coordinator](InitialAllRangesAnnouncement announcement) { diff --git a/src/Server/GRPCServer.cpp b/src/Server/GRPCServer.cpp index bf9ba20a5cf..67d30012b0e 100644 --- a/src/Server/GRPCServer.cpp +++ b/src/Server/GRPCServer.cpp @@ -798,7 +798,7 @@ namespace /// Authentication. session.emplace(iserver.context(), ClientInfo::Interface::GRPC); session->authenticate(user, password, user_address); - session->getClientInfo().quota_key = quota_key; + session->setQuotaClientKey(quota_key); ClientInfo client_info = session->getClientInfo(); diff --git a/src/Server/HTTPHandler.cpp b/src/Server/HTTPHandler.cpp index f7cdb905710..069670c84a5 100644 --- a/src/Server/HTTPHandler.cpp +++ b/src/Server/HTTPHandler.cpp @@ -474,7 +474,6 @@ bool HTTPHandler::authenticateUser( } /// Set client info. It will be used for quota accounting parameters in 'setUser' method. - ClientInfo & client_info = session->getClientInfo(); ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; if (request.getMethod() == HTTPServerRequest::HTTP_GET) @@ -482,15 +481,13 @@ bool HTTPHandler::authenticateUser( else if (request.getMethod() == HTTPServerRequest::HTTP_POST) http_method = ClientInfo::HTTPMethod::POST; - client_info.http_method = http_method; - client_info.http_user_agent = request.get("User-Agent", ""); - client_info.http_referer = request.get("Referer", ""); - client_info.forwarded_for = request.get("X-Forwarded-For", ""); - client_info.quota_key = quota_key; + session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", "")); + session->setForwardedFor(request.get("X-Forwarded-For", "")); + session->setQuotaClientKey(quota_key); /// Extract the last entry from comma separated list of forwarded_for addresses. /// Only the last proxy can be trusted (if any). - String forwarded_address = client_info.getLastForwardedFor(); + String forwarded_address = session->getClientInfo().getLastForwardedFor(); try { if (!forwarded_address.empty() && server.config().getBool("auth_use_forwarded_address", false)) @@ -988,22 +985,22 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse } // Parse the OpenTelemetry traceparent header. - ClientInfo& client_info = session->getClientInfo(); + auto & client_trace_context = session->getClientTraceContext(); if (request.has("traceparent")) { std::string opentelemetry_traceparent = request.get("traceparent"); std::string error; - if (!client_info.client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error)) + if (!client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error)) { LOG_DEBUG(log, "Failed to parse OpenTelemetry traceparent header '{}': {}", opentelemetry_traceparent, error); } - client_info.client_trace_context.tracestate = request.get("tracestate", ""); + client_trace_context.tracestate = request.get("tracestate", ""); } // Setup tracing context for this thread auto context = session->sessionOrGlobalContext(); thread_trace_context = std::make_unique("HTTPHandler", - client_info.client_trace_context, + client_trace_context, context->getSettingsRef(), context->getOpenTelemetrySpanLog()); thread_trace_context->root_span.kind = OpenTelemetry::SERVER; diff --git a/src/Server/MySQLHandler.cpp b/src/Server/MySQLHandler.cpp index 7318b0ad89b..f98b86e6cf8 100644 --- a/src/Server/MySQLHandler.cpp +++ b/src/Server/MySQLHandler.cpp @@ -94,7 +94,7 @@ void MySQLHandler::run() session = std::make_unique(server.context(), ClientInfo::Interface::MYSQL); SCOPE_EXIT({ session.reset(); }); - session->getClientInfo().connection_id = connection_id; + session->setClientConnectionId(connection_id); in = std::make_shared(socket()); out = std::make_shared(socket()); diff --git a/src/Server/PostgreSQLHandler.cpp b/src/Server/PostgreSQLHandler.cpp index 36b05932979..7b078154252 100644 --- a/src/Server/PostgreSQLHandler.cpp +++ b/src/Server/PostgreSQLHandler.cpp @@ -58,7 +58,7 @@ void PostgreSQLHandler::run() session = std::make_unique(server.context(), ClientInfo::Interface::POSTGRESQL); SCOPE_EXIT({ session.reset(); }); - session->getClientInfo().connection_id = connection_id; + session->setClientConnectionId(connection_id); try { diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index 36566832ebc..a747f06f1ce 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -1177,21 +1177,12 @@ std::unique_ptr TCPHandler::makeSession() auto res = std::make_unique(server.context(), interface, socket().secure(), certificate); - auto & client_info = res->getClientInfo(); - client_info.forwarded_for = forwarded_for; - client_info.client_name = client_name; - client_info.client_version_major = client_version_major; - client_info.client_version_minor = client_version_minor; - client_info.client_version_patch = client_version_patch; - client_info.client_tcp_protocol_version = client_tcp_protocol_version; - - client_info.connection_client_version_major = client_version_major; - client_info.connection_client_version_minor = client_version_minor; - client_info.connection_client_version_patch = client_version_patch; - client_info.connection_tcp_protocol_version = client_tcp_protocol_version; - - client_info.quota_key = quota_key; - client_info.interface = interface; + res->setForwardedFor(forwarded_for); + res->setClientName(client_name); + res->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + res->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version); + res->setQuotaClientKey(quota_key); + res->setClientInterface(interface); return res; } @@ -1253,7 +1244,7 @@ void TCPHandler::receiveHello() } session = makeSession(); - auto & client_info = session->getClientInfo(); + const auto & client_info = session->getClientInfo(); #if USE_SSL /// Authentication with SSL user certificate @@ -1286,7 +1277,7 @@ void TCPHandler::receiveAddendum() { readStringBinary(quota_key, *in); if (!is_interserver_mode) - session->getClientInfo().quota_key = quota_key; + session->setQuotaClientKey(quota_key); } } diff --git a/src/Storages/Distributed/DistributedSink.cpp b/src/Storages/Distributed/DistributedSink.cpp index 875764f7633..0dcdae01ba9 100644 --- a/src/Storages/Distributed/DistributedSink.cpp +++ b/src/Storages/Distributed/DistributedSink.cpp @@ -132,7 +132,7 @@ DistributedSink::DistributedSink( const auto & settings = context->getSettingsRef(); if (settings.max_distributed_depth && context->getClientInfo().distributed_depth >= settings.max_distributed_depth) throw Exception(ErrorCodes::TOO_LARGE_DISTRIBUTED_DEPTH, "Maximum distributed depth exceeded"); - context->getClientInfo().distributed_depth += 1; + context->increaseDistributedDepth(); random_shard_insert = settings.insert_distributed_one_random_shard && !storage.has_sharding_key; } diff --git a/src/Storages/StorageDistributed.cpp b/src/Storages/StorageDistributed.cpp index c46192ab43b..c028cf5ec77 100644 --- a/src/Storages/StorageDistributed.cpp +++ b/src/Storages/StorageDistributed.cpp @@ -914,7 +914,7 @@ std::optional StorageDistributed::distributedWriteBetweenDistribu QueryPipeline pipeline; ContextMutablePtr query_context = Context::createCopy(local_context); - ++query_context->getClientInfo().distributed_depth; + query_context->increaseDistributedDepth(); for (size_t shard_index : collections::range(0, shards_info.size())) { @@ -976,7 +976,7 @@ std::optional StorageDistributed::distributedWriteFromClusterStor QueryPipeline pipeline; ContextMutablePtr query_context = Context::createCopy(local_context); - ++query_context->getClientInfo().distributed_depth; + query_context->increaseDistributedDepth(); /// Here we take addresses from destination cluster and assume source table exists on these nodes for (const auto & replicas : getCluster()->getShardsAddresses()) diff --git a/src/Storages/StorageReplicatedMergeTree.cpp b/src/Storages/StorageReplicatedMergeTree.cpp index 6894368841f..066f5a42f46 100644 --- a/src/Storages/StorageReplicatedMergeTree.cpp +++ b/src/Storages/StorageReplicatedMergeTree.cpp @@ -5079,7 +5079,7 @@ std::optional StorageReplicatedMergeTree::distributedWriteFromClu QueryPipeline pipeline; ContextMutablePtr query_context = Context::createCopy(local_context); - ++query_context->getClientInfo().distributed_depth; + query_context->increaseDistributedDepth(); for (const auto & replicas : src_cluster->getShardsAddresses()) { diff --git a/src/Storages/WindowView/StorageWindowView.cpp b/src/Storages/WindowView/StorageWindowView.cpp index 242e8e5d570..0f506040cd9 100644 --- a/src/Storages/WindowView/StorageWindowView.cpp +++ b/src/Storages/WindowView/StorageWindowView.cpp @@ -992,7 +992,7 @@ void StorageWindowView::cleanup() auto cleanup_context = Context::createCopy(getContext()); cleanup_context->makeQueryContext(); cleanup_context->setCurrentQueryId(""); - cleanup_context->getClientInfo().is_replicated_database_internal = true; + cleanup_context->setQueryKindReplicatedDatabaseInternal(); InterpreterAlterQuery interpreter_alter(alter_query, cleanup_context); interpreter_alter.execute(); From 2379d8c9d5f73be50c78978c73aea1c2c87044e0 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 17 Jul 2023 14:52:17 +0200 Subject: [PATCH 4/5] Revert unnecessary improving in ContextAccessCache for now. --- src/Access/AccessControl.cpp | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/Access/AccessControl.cpp b/src/Access/AccessControl.cpp index 41ac3f42ee2..bf0a2a0fbba 100644 --- a/src/Access/AccessControl.cpp +++ b/src/Access/AccessControl.cpp @@ -72,26 +72,20 @@ public: std::shared_ptr getContextAccess(const ContextAccessParams & params) { + std::lock_guard lock{mutex}; + auto x = cache.get(params); + if (x) { - std::lock_guard lock{mutex}; - auto x = cache.get(params); - if (x) - { - if ((*x)->getUserID() && !(*x)->tryGetUser()) - cache.remove(params); /// The user has been dropped while it was in the cache. - else - return *x; - } + if ((*x)->getUserID() && !(*x)->tryGetUser()) + cache.remove(params); /// The user has been dropped while it was in the cache. + else + return *x; } + /// TODO: There is no need to keep the `ContextAccessCache::mutex` locked while we're calculating access rights. auto res = std::make_shared(access_control, params); res->initialize(); - - { - std::lock_guard lock{mutex}; - cache.add(params, res); - } - + cache.add(params, res); return res; } From fff1ae73691bc3bbd409b1743a2c85d18412f868 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 17 Jul 2023 17:08:36 +0200 Subject: [PATCH 5/5] Use default destructor for ContextAccess. --- src/Access/ContextAccess.cpp | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/src/Access/ContextAccess.cpp b/src/Access/ContextAccess.cpp index 51bb7794735..9e9d8644539 100644 --- a/src/Access/ContextAccess.cpp +++ b/src/Access/ContextAccess.cpp @@ -235,20 +235,7 @@ ContextAccess::ContextAccess(const AccessControl & access_control_, const Params } -ContextAccess::~ContextAccess() -{ - enabled_settings.reset(); - enabled_quota.reset(); - enabled_row_policies.reset(); - row_policies_of_initial_user.reset(); - access_with_implicit.reset(); - access.reset(); - roles_info.reset(); - subscription_for_roles_changes.reset(); - enabled_roles.reset(); - subscription_for_user_change.reset(); - user.reset(); -} +ContextAccess::~ContextAccess() = default; void ContextAccess::initialize() @@ -265,12 +252,6 @@ void ContextAccess::initialize() if (!params.user_id) throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug"); - if (!params.initial_user.empty()) - { - if (auto initial_user_id = access_control->find(params.initial_user)) - row_policies_of_initial_user = access_control->tryGetDefaultRowPolicies(*initial_user_id); - } - subscription_for_user_change = access_control->subscribeForChanges( *params.user_id, [weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity) @@ -290,7 +271,8 @@ void ContextAccess::initialize() void ContextAccess::setUser(const UserPtr & user_) const { user = user_; - if (!user) + + if (!user_) { /// User has been dropped. user_was_dropped = true; @@ -301,6 +283,7 @@ void ContextAccess::setUser(const UserPtr & user_) const enabled_roles = nullptr; roles_info = nullptr; enabled_row_policies = nullptr; + row_policies_of_initial_user = nullptr; enabled_quota = nullptr; enabled_settings = nullptr; return; @@ -330,6 +313,11 @@ void ContextAccess::setUser(const UserPtr & user_) const }); setRolesInfo(enabled_roles->getRolesInfo()); + + std::optional initial_user_id; + if (!params.initial_user.empty()) + initial_user_id = access_control->find(params.initial_user); + row_policies_of_initial_user = initial_user_id ? access_control->tryGetDefaultRowPolicies(*initial_user_id) : nullptr; }