Merge pull request #51772 from vitlibar/no-keep-context-lock-while-calculating-access

Avoid keeping lock Context::getLock() while calculating access rights
This commit is contained in:
Nikita Mikhaylov 2023-07-20 14:43:18 +02:00 committed by GitHub
commit 0d4af0d5e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 786 additions and 351 deletions

View File

@ -1173,12 +1173,12 @@ void Client::processOptions(const OptionsDescription & options_description,
{
String traceparent = options["opentelemetry-traceparent"].as<std::string>();
String error;
if (!global_context->getClientInfo().client_trace_context.parseTraceparentHeader(traceparent, error))
if (!global_context->getClientTraceContext().parseTraceparentHeader(traceparent, error))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Cannot parse OpenTelemetry traceparent '{}': {}", traceparent, error);
}
if (options.count("opentelemetry-tracestate"))
global_context->getClientInfo().client_trace_context.tracestate = options["opentelemetry-tracestate"].as<std::string>();
global_context->getClientTraceContext().tracestate = options["opentelemetry-tracestate"].as<std::string>();
}
@ -1238,10 +1238,9 @@ void Client::processConfig()
global_context->getSettingsRef().max_insert_block_size);
}
ClientInfo & client_info = global_context->getClientInfo();
client_info.setInitialQuery();
client_info.quota_key = config().getString("quota_key", "");
client_info.query_kind = query_kind;
global_context->setQueryKindInitial();
global_context->setQuotaClientKey(config().getString("quota_key", ""));
global_context->setQueryKind(query_kind);
}

View File

@ -737,9 +737,8 @@ void LocalServer::processConfig()
for (const auto & [key, value] : prompt_substitutions)
boost::replace_all(prompt_by_server_display_name, "{" + key + "}", value);
ClientInfo & client_info = global_context->getClientInfo();
client_info.setInitialQuery();
client_info.query_kind = query_kind;
global_context->setQueryKindInitial();
global_context->setQueryKind(query_kind);
}

View File

@ -76,11 +76,13 @@ public:
auto x = cache.get(params);
if (x)
{
if ((*x)->tryGetUser())
if ((*x)->getUserID() && !(*x)->tryGetUser())
cache.remove(params); /// The user has been dropped while it was in the cache.
else
return *x;
/// No user, probably the user has been dropped while it was in the cache.
cache.remove(params);
}
/// TODO: There is no need to keep the `ContextAccessCache::mutex` locked while we're calculating access rights.
auto res = std::make_shared<ContextAccess>(access_control, params);
res->initialize();
cache.add(params, res);
@ -713,35 +715,6 @@ int AccessControl::getBcryptWorkfactor() const
}
std::shared_ptr<const ContextAccess> AccessControl::getContextAccess(
const UUID & user_id,
const std::vector<UUID> & current_roles,
bool use_default_roles,
const Settings & settings,
const String & current_database,
const ClientInfo & client_info) const
{
ContextAccessParams params;
params.user_id = user_id;
params.current_roles.insert(current_roles.begin(), current_roles.end());
params.use_default_roles = use_default_roles;
params.current_database = current_database;
params.readonly = settings.readonly;
params.allow_ddl = settings.allow_ddl;
params.allow_introspection = settings.allow_introspection_functions;
params.interface = client_info.interface;
params.http_method = client_info.http_method;
params.address = client_info.current_address.host();
params.quota_key = client_info.quota_key;
/// Extract the last entry from comma separated list of X-Forwarded-For addresses.
/// Only the last proxy can be trusted (if any).
params.forwarded_address = client_info.getLastForwardedFor();
return getContextAccess(params);
}
std::shared_ptr<const ContextAccess> AccessControl::getContextAccess(const ContextAccessParams & params) const
{
return context_access_cache->getContextAccess(params);

View File

@ -25,7 +25,7 @@ namespace Poco
namespace DB
{
class ContextAccess;
struct ContextAccessParams;
class ContextAccessParams;
struct User;
using UserPtr = std::shared_ptr<const User>;
class EnabledRoles;
@ -181,14 +181,6 @@ public:
void setSettingsConstraintsReplacePrevious(bool enable) { settings_constraints_replace_previous = enable; }
bool doesSettingsConstraintsReplacePrevious() const { return settings_constraints_replace_previous; }
std::shared_ptr<const ContextAccess> getContextAccess(
const UUID & user_id,
const std::vector<UUID> & current_roles,
bool use_default_roles,
const Settings & settings,
const String & current_database,
const ClientInfo & client_info) const;
std::shared_ptr<const ContextAccess> getContextAccess(const ContextAccessParams & params) const;
std::shared_ptr<const EnabledRoles> getEnabledRoles(

View File

@ -10,6 +10,7 @@
#include <Access/EnabledSettings.h>
#include <Access/SettingsProfilesInfo.h>
#include <Interpreters/DatabaseCatalog.h>
#include <Interpreters/Context.h>
#include <Common/Exception.h>
#include <Common/quoteString.h>
#include <Core/Settings.h>
@ -221,6 +222,12 @@ namespace
}
std::shared_ptr<const ContextAccess> ContextAccess::fromContext(const ContextPtr & context)
{
return context->getAccess();
}
ContextAccess::ContextAccess(const AccessControl & access_control_, const Params & params_)
: access_control(&access_control_)
, params(params_)
@ -228,48 +235,44 @@ ContextAccess::ContextAccess(const AccessControl & access_control_, const Params
}
ContextAccess::ContextAccess(FullAccess)
: is_full_access(true), access(std::make_shared<AccessRights>(AccessRights::getFullAccess())), access_with_implicit(access)
{
}
ContextAccess::~ContextAccess()
{
enabled_settings.reset();
enabled_quota.reset();
enabled_row_policies.reset();
access_with_implicit.reset();
access.reset();
roles_info.reset();
subscription_for_roles_changes.reset();
enabled_roles.reset();
subscription_for_user_change.reset();
user.reset();
}
ContextAccess::~ContextAccess() = default;
void ContextAccess::initialize()
{
std::lock_guard lock{mutex};
subscription_for_user_change = access_control->subscribeForChanges(
*params.user_id, [weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity)
{
auto ptr = weak_ptr.lock();
if (!ptr)
return;
UserPtr changed_user = entity ? typeid_cast<UserPtr>(entity) : nullptr;
std::lock_guard lock2{ptr->mutex};
ptr->setUser(changed_user);
});
setUser(access_control->read<User>(*params.user_id));
std::lock_guard lock{mutex};
if (params.full_access)
{
access = std::make_shared<AccessRights>(AccessRights::getFullAccess());
access_with_implicit = access;
return;
}
if (!params.user_id)
throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug");
subscription_for_user_change = access_control->subscribeForChanges(
*params.user_id,
[weak_ptr = weak_from_this()](const UUID &, const AccessEntityPtr & entity)
{
auto ptr = weak_ptr.lock();
if (!ptr)
return;
UserPtr changed_user = entity ? typeid_cast<UserPtr>(entity) : nullptr;
std::lock_guard lock2{ptr->mutex};
ptr->setUser(changed_user);
});
setUser(access_control->read<User>(*params.user_id));
}
void ContextAccess::setUser(const UserPtr & user_) const
{
user = user_;
if (!user)
if (!user_)
{
/// User has been dropped.
user_was_dropped = true;
@ -280,6 +283,7 @@ void ContextAccess::setUser(const UserPtr & user_) const
enabled_roles = nullptr;
roles_info = nullptr;
enabled_row_policies = nullptr;
row_policies_of_initial_user = nullptr;
enabled_quota = nullptr;
enabled_settings = nullptr;
return;
@ -294,10 +298,10 @@ void ContextAccess::setUser(const UserPtr & user_) const
current_roles = user->granted_roles.findGranted(user->default_roles);
current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(user->default_roles);
}
else
else if (params.current_roles)
{
current_roles = user->granted_roles.findGranted(params.current_roles);
current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(params.current_roles);
current_roles = user->granted_roles.findGranted(*params.current_roles);
current_roles_with_admin_option = user->granted_roles.findGrantedWithAdminOption(*params.current_roles);
}
subscription_for_roles_changes.reset();
@ -309,6 +313,11 @@ void ContextAccess::setUser(const UserPtr & user_) const
});
setRolesInfo(enabled_roles->getRolesInfo());
std::optional<UUID> initial_user_id;
if (!params.initial_user.empty())
initial_user_id = access_control->find<User>(params.initial_user);
row_policies_of_initial_user = initial_user_id ? access_control->tryGetDefaultRowPolicies(*initial_user_id) : nullptr;
}
@ -316,12 +325,15 @@ void ContextAccess::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> &
{
assert(roles_info_);
roles_info = roles_info_;
enabled_row_policies = access_control->getEnabledRowPolicies(
*params.user_id, roles_info->enabled_roles);
enabled_row_policies = access_control->getEnabledRowPolicies(*params.user_id, roles_info->enabled_roles);
enabled_quota = access_control->getEnabledQuota(
*params.user_id, user_name, roles_info->enabled_roles, params.address, params.forwarded_address, params.quota_key);
enabled_settings = access_control->getEnabledSettings(
*params.user_id, user->settings, roles_info->enabled_roles, roles_info->settings_from_enabled_roles);
calculateAccessRights();
}
@ -381,21 +393,24 @@ std::shared_ptr<const EnabledRolesInfo> ContextAccess::getRolesInfo() const
return no_roles;
}
std::shared_ptr<const EnabledRowPolicies> ContextAccess::getEnabledRowPolicies() const
RowPolicyFilterPtr ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
{
std::lock_guard lock{mutex};
if (enabled_row_policies)
return enabled_row_policies;
static const auto no_row_policies = std::make_shared<EnabledRowPolicies>();
return no_row_policies;
}
RowPolicyFilterPtr ContextAccess::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter) const
{
std::lock_guard lock{mutex};
RowPolicyFilterPtr filter;
if (enabled_row_policies)
return enabled_row_policies->getFilter(database, table_name, filter_type, combine_with_filter);
return combine_with_filter;
filter = enabled_row_policies->getFilter(database, table_name, filter_type);
if (row_policies_of_initial_user)
{
/// Find and set extra row policies to be used based on `client_info.initial_user`, if the initial user exists.
/// TODO: we need a better solution here. It seems we should pass the initial row policy
/// because a shard is allowed to not have the initial user or it might be another user
/// with the same name.
filter = row_policies_of_initial_user->getFilter(database, table_name, filter_type, filter);
}
return filter;
}
std::shared_ptr<const EnabledQuota> ContextAccess::getQuota() const
@ -417,14 +432,6 @@ std::optional<QuotaUsage> ContextAccess::getQuotaUsage() const
}
std::shared_ptr<const ContextAccess> ContextAccess::getFullAccess()
{
static const std::shared_ptr<const ContextAccess> res =
[] { return std::shared_ptr<ContextAccess>(new ContextAccess{kFullAccess}); }();
return res;
}
SettingsChanges ContextAccess::getDefaultSettings() const
{
std::lock_guard lock{mutex};
@ -478,7 +485,7 @@ bool ContextAccess::checkAccessImplHelper(AccessFlags flags, const Args &... arg
throw Exception(ErrorCodes::UNKNOWN_USER, "{}: User has been dropped", getUserName());
}
if (is_full_access)
if (params.full_access)
return true;
auto access_granted = [&]
@ -706,7 +713,7 @@ bool ContextAccess::checkAdminOptionImplHelper(const Container & role_ids, const
return false;
};
if (is_full_access)
if (params.full_access)
return true;
if (user_was_dropped)
@ -806,7 +813,7 @@ void ContextAccess::checkAdminOption(const std::vector<UUID> & role_ids, const s
void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const
{
if (is_full_access)
if (params.full_access)
return;
auto current_user = getUser();
@ -816,7 +823,7 @@ void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccess
void ContextAccess::checkGranteesAreAllowed(const std::vector<UUID> & grantee_ids) const
{
if (is_full_access)
if (params.full_access)
return;
auto current_user = getUser();

View File

@ -1,6 +1,7 @@
#pragma once
#include <Access/AccessRights.h>
#include <Access/ContextAccessParams.h>
#include <Access/EnabledRowPolicies.h>
#include <Interpreters/ClientInfo.h>
#include <Core/UUID.h>
@ -30,47 +31,18 @@ class AccessControl;
class IAST;
struct IAccessEntity;
using ASTPtr = std::shared_ptr<IAST>;
struct ContextAccessParams
{
std::optional<UUID> user_id;
boost::container::flat_set<UUID> current_roles;
bool use_default_roles = false;
UInt64 readonly = 0;
bool allow_ddl = false;
bool allow_introspection = false;
String current_database;
ClientInfo::Interface interface = ClientInfo::Interface::TCP;
ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
Poco::Net::IPAddress address;
String forwarded_address;
String quota_key;
auto toTuple() const
{
return std::tie(
user_id, current_roles, use_default_roles, readonly, allow_ddl, allow_introspection,
current_database, interface, http_method, address, forwarded_address, quota_key);
}
friend bool operator ==(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return lhs.toTuple() == rhs.toTuple(); }
friend bool operator !=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(lhs == rhs); }
friend bool operator <(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return lhs.toTuple() < rhs.toTuple(); }
friend bool operator >(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return rhs < lhs; }
friend bool operator <=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(rhs < lhs); }
friend bool operator >=(const ContextAccessParams & lhs, const ContextAccessParams & rhs) { return !(lhs < rhs); }
};
class Context;
using ContextPtr = std::shared_ptr<const Context>;
class ContextAccess : public std::enable_shared_from_this<ContextAccess>
{
public:
static std::shared_ptr<const ContextAccess> fromContext(const ContextPtr & context);
using Params = ContextAccessParams;
const Params & getParams() const { return params; }
ContextAccess(const AccessControl & access_control_, const Params & params_);
/// Returns the current user. Throws if user is nullptr.
UserPtr getUser() const;
/// Same as above, but can return nullptr.
@ -81,12 +53,9 @@ public:
/// Returns information about current and enabled roles.
std::shared_ptr<const EnabledRolesInfo> getRolesInfo() const;
/// Returns information about enabled row policies.
std::shared_ptr<const EnabledRowPolicies> getEnabledRowPolicies() const;
/// Returns the row policy filter for a specified table.
/// The function returns nullptr if there is no filter to apply.
RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type, RowPolicyFilterPtr combine_with_filter = {}) const;
RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
/// Returns the quota to track resource consumption.
std::shared_ptr<const EnabledQuota> getQuota() const;
@ -161,22 +130,12 @@ public:
/// Checks if grantees are allowed for the current user, throws an exception if not.
void checkGranteesAreAllowed(const std::vector<UUID> & grantee_ids) const;
/// Makes an instance of ContextAccess which provides full access to everything
/// without any limitations. This is used for the global context.
static std::shared_ptr<const ContextAccess> getFullAccess();
ContextAccess(const AccessControl & access_control_, const Params & params_);
~ContextAccess();
private:
friend class AccessControl;
struct FullAccess {};
static const FullAccess kFullAccess;
/// Makes an instance of ContextAccess which provides full access to everything
/// without any limitations. This is used for the global context.
explicit ContextAccess(FullAccess);
void initialize();
void setUser(const UserPtr & user_) const TSA_REQUIRES(mutex);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & roles_info_) const TSA_REQUIRES(mutex);
@ -223,7 +182,6 @@ private:
const AccessControl * access_control = nullptr;
const Params params;
const bool is_full_access = false;
mutable std::atomic<bool> user_was_dropped = false;
mutable std::atomic<Poco::Logger *> trace_log = nullptr;
@ -237,6 +195,7 @@ private:
mutable std::shared_ptr<const AccessRights> access TSA_GUARDED_BY(mutex);
mutable std::shared_ptr<const AccessRights> access_with_implicit TSA_GUARDED_BY(mutex);
mutable std::shared_ptr<const EnabledRowPolicies> enabled_row_policies TSA_GUARDED_BY(mutex);
mutable std::shared_ptr<const EnabledRowPolicies> row_policies_of_initial_user TSA_GUARDED_BY(mutex);
mutable std::shared_ptr<const EnabledQuota> enabled_quota TSA_GUARDED_BY(mutex);
mutable std::shared_ptr<const EnabledSettings> enabled_settings TSA_GUARDED_BY(mutex);

View File

@ -0,0 +1,177 @@
#include <Access/ContextAccessParams.h>
#include <Core/Settings.h>
#include <Common/typeid_cast.h>
namespace DB
{
ContextAccessParams::ContextAccessParams(
const std::optional<UUID> user_id_,
bool full_access_,
bool use_default_roles_,
const std::shared_ptr<const std::vector<UUID>> & current_roles_,
const Settings & settings_,
const String & current_database_,
const ClientInfo & client_info_)
: user_id(user_id_)
, full_access(full_access_)
, use_default_roles(use_default_roles_)
, current_roles(current_roles_)
, readonly(settings_.readonly)
, allow_ddl(settings_.allow_ddl)
, allow_introspection(settings_.allow_introspection_functions)
, current_database(current_database_)
, interface(client_info_.interface)
, http_method(client_info_.http_method)
, address(client_info_.current_address.host())
, forwarded_address(client_info_.getLastForwardedFor())
, quota_key(client_info_.quota_key)
, initial_user((client_info_.initial_user != client_info_.current_user) ? client_info_.initial_user : "")
{
}
String ContextAccessParams::toString() const
{
WriteBufferFromOwnString out;
auto separator = [&] { return out.stringView().empty() ? "" : ", "; };
if (user_id)
out << separator() << "user_id = " << *user_id;
if (full_access)
out << separator() << "full_access = " << full_access;
if (use_default_roles)
out << separator() << "use_default_roles = " << use_default_roles;
if (current_roles && !current_roles->empty())
{
out << separator() << "current_roles = [";
for (size_t i = 0; i != current_roles->size(); ++i)
{
if (i)
out << ", ";
out << (*current_roles)[i];
}
out << "]";
}
if (readonly)
out << separator() << "readonly = " << readonly;
if (allow_ddl)
out << separator() << "allow_ddl = " << allow_ddl;
if (allow_introspection)
out << separator() << "allow_introspection = " << allow_introspection;
if (!current_database.empty())
out << separator() << "current_database = " << current_database;
out << separator() << "interface = " << magic_enum::enum_name(interface);
if (http_method != ClientInfo::HTTPMethod::UNKNOWN)
out << separator() << "http_method = " << magic_enum::enum_name(http_method);
if (!address.isWildcard())
out << separator() << "address = " << address.toString();
if (!forwarded_address.empty())
out << separator() << "forwarded_address = " << forwarded_address;
if (!quota_key.empty())
out << separator() << "quota_key = " << quota_key;
if (!initial_user.empty())
out << separator() << "initial_user = " << initial_user;
return out.str();
}
bool operator ==(const ContextAccessParams & left, const ContextAccessParams & right)
{
auto check_equals = [](const auto & x, const auto & y)
{
if constexpr (::detail::is_shared_ptr_v<std::remove_cvref_t<decltype(x)>>)
{
if (!x)
return !y;
else if (!y)
return false;
else
return *x == *y;
}
else
{
return x == y;
}
};
#define CONTEXT_ACCESS_PARAMS_EQUALS(name) \
if (!check_equals(left.name, right.name)) \
return false;
CONTEXT_ACCESS_PARAMS_EQUALS(user_id)
CONTEXT_ACCESS_PARAMS_EQUALS(full_access)
CONTEXT_ACCESS_PARAMS_EQUALS(use_default_roles)
CONTEXT_ACCESS_PARAMS_EQUALS(current_roles)
CONTEXT_ACCESS_PARAMS_EQUALS(readonly)
CONTEXT_ACCESS_PARAMS_EQUALS(allow_ddl)
CONTEXT_ACCESS_PARAMS_EQUALS(allow_introspection)
CONTEXT_ACCESS_PARAMS_EQUALS(current_database)
CONTEXT_ACCESS_PARAMS_EQUALS(interface)
CONTEXT_ACCESS_PARAMS_EQUALS(http_method)
CONTEXT_ACCESS_PARAMS_EQUALS(address)
CONTEXT_ACCESS_PARAMS_EQUALS(forwarded_address)
CONTEXT_ACCESS_PARAMS_EQUALS(quota_key)
CONTEXT_ACCESS_PARAMS_EQUALS(initial_user)
#undef CONTEXT_ACCESS_PARAMS_EQUALS
return true; /// All fields are equal, operator == must return true.
}
bool operator <(const ContextAccessParams & left, const ContextAccessParams & right)
{
auto check_less = [](const auto & x, const auto & y)
{
if constexpr (::detail::is_shared_ptr_v<std::remove_cvref_t<decltype(x)>>)
{
if (!x)
return y ? -1 : 0;
else if (!y)
return 1;
else if (*x == *y)
return 0;
else if (*x < *y)
return -1;
else
return 1;
}
else
{
if (x == y)
return 0;
else if (x < y)
return -1;
else
return 1;
}
};
#define CONTEXT_ACCESS_PARAMS_LESS(name) \
if (auto cmp = check_less(left.name, right.name); cmp != 0) \
return cmp < 0;
CONTEXT_ACCESS_PARAMS_LESS(user_id)
CONTEXT_ACCESS_PARAMS_LESS(full_access)
CONTEXT_ACCESS_PARAMS_LESS(use_default_roles)
CONTEXT_ACCESS_PARAMS_LESS(current_roles)
CONTEXT_ACCESS_PARAMS_LESS(readonly)
CONTEXT_ACCESS_PARAMS_LESS(allow_ddl)
CONTEXT_ACCESS_PARAMS_LESS(allow_introspection)
CONTEXT_ACCESS_PARAMS_LESS(current_database)
CONTEXT_ACCESS_PARAMS_LESS(interface)
CONTEXT_ACCESS_PARAMS_LESS(http_method)
CONTEXT_ACCESS_PARAMS_LESS(address)
CONTEXT_ACCESS_PARAMS_LESS(forwarded_address)
CONTEXT_ACCESS_PARAMS_LESS(quota_key)
CONTEXT_ACCESS_PARAMS_LESS(initial_user)
#undef CONTEXT_ACCESS_PARAMS_LESS
return false; /// All fields are equal, operator < must return false.
}
bool ContextAccessParams::dependsOnSettingName(std::string_view setting_name)
{
return (setting_name == "readonly") || (setting_name == "allow_ddl") || (setting_name == "allow_introspection_functions");
}
}

View File

@ -0,0 +1,67 @@
#pragma once
#include <Interpreters/ClientInfo.h>
#include <Core/UUID.h>
#include <optional>
#include <vector>
namespace DB
{
struct Settings;
/// Parameters which are used to calculate access rights and some related stuff like roles or constraints.
class ContextAccessParams
{
public:
ContextAccessParams(
const std::optional<UUID> user_id_,
bool full_access_,
bool use_default_roles_,
const std::shared_ptr<const std::vector<UUID>> & current_roles_,
const Settings & settings_,
const String & current_database_,
const ClientInfo & client_info_);
const std::optional<UUID> user_id;
/// Full access to everything without any limitations.
/// This is used for the global context.
const bool full_access;
const bool use_default_roles;
const std::shared_ptr<const std::vector<UUID>> current_roles;
const UInt64 readonly;
const bool allow_ddl;
const bool allow_introspection;
const String current_database;
const ClientInfo::Interface interface;
const ClientInfo::HTTPMethod http_method;
const Poco::Net::IPAddress address;
/// The last entry from comma separated list of X-Forwarded-For addresses.
/// Only the last proxy can be trusted (if any).
const String forwarded_address;
const String quota_key;
/// Initial user is used to combine row policies with.
const String initial_user;
/// Outputs `ContextAccessParams` to string for logging.
String toString() const;
friend bool operator <(const ContextAccessParams & left, const ContextAccessParams & right);
friend bool operator ==(const ContextAccessParams & left, const ContextAccessParams & right);
friend bool operator !=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(left == right); }
friend bool operator >(const ContextAccessParams & left, const ContextAccessParams & right) { return right < left; }
friend bool operator <=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(right < left); }
friend bool operator >=(const ContextAccessParams & left, const ContextAccessParams & right) { return !(left < right); }
static bool dependsOnSettingName(std::string_view setting_name);
};
}

View File

@ -814,8 +814,8 @@ void DatabaseReplicated::recoverLostReplica(const ZooKeeperPtr & current_zookeep
{
auto query_context = Context::createCopy(getContext());
query_context->makeQueryContext();
query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
query_context->getClientInfo().is_replicated_database_internal = true;
query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
query_context->setQueryKindReplicatedDatabaseInternal();
query_context->setCurrentDatabase(getDatabaseName());
query_context->setCurrentQueryId("");
auto txn = std::make_shared<ZooKeeperMetadataTransaction>(current_zookeeper, zookeeper_path, false, "");

View File

@ -60,7 +60,7 @@ static ContextMutablePtr createQueryContext(ContextPtr context)
query_context->setSettings(new_query_settings);
query_context->setInternalQuery(true);
query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
query_context->setCurrentQueryId(""); // generate random query_id
return query_context;
}

View File

@ -426,12 +426,10 @@ try
auto insert_query_id = insert_context->getCurrentQueryId();
auto query_start_time = std::chrono::system_clock::now();
Stopwatch start_watch{CLOCK_MONOTONIC};
ClientInfo & client_info = insert_context->getClientInfo();
client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
client_info.initial_query_start_time = timeInSeconds(query_start_time);
client_info.initial_query_start_time_microseconds = timeInMicroseconds(query_start_time);
client_info.current_query_id = insert_query_id;
client_info.initial_query_id = insert_query_id;
insert_context->setQueryKind(ClientInfo::QueryKind::INITIAL_QUERY);
insert_context->setInitialQueryStartTime(query_start_time);
insert_context->setCurrentQueryId(insert_query_id);
insert_context->setInitialQueryId(insert_query_id);
size_t log_queries_cut_to_length = insert_context->getSettingsRef().log_queries_cut_to_length;
String query_for_logging = insert_query.hasSecretParts()
? insert_query.formatForLogging(log_queries_cut_to_length)

View File

@ -171,7 +171,7 @@ void executeQuery(
SelectStreamFactory::Shards remote_shards;
auto new_context = updateSettingsForCluster(*query_info.getCluster(), context, settings, main_table, &query_info, log);
new_context->getClientInfo().distributed_depth += 1;
new_context->increaseDistributedDepth();
size_t shards = query_info.getCluster()->getShardCount();
for (const auto & shard_info : query_info.getCluster()->getShardsInfo())

View File

@ -1059,25 +1059,54 @@ ConfigurationPtr Context::getUsersConfig()
return shared->users_config;
}
void Context::setUser(const UUID & user_id_)
void Context::setUser(const UUID & user_id_, bool set_current_profiles_, bool set_current_roles_, bool set_current_database_)
{
/// Prepare lists of user's profiles, constraints, settings, roles.
std::shared_ptr<const User> user;
std::shared_ptr<const ContextAccess> temp_access;
if (set_current_profiles_ || set_current_roles_ || set_current_database_)
{
std::optional<ContextAccessParams> params;
{
auto lock = getLock();
params.emplace(ContextAccessParams{user_id_, /* full_access= */ false, /* use_default_roles = */ true, {}, settings, current_database, client_info});
}
/// `temp_access` is used here only to extract information about the user, not to actually check access.
/// NOTE: AccessControl::getContextAccess() may require some IO work, so Context::getLock() must be unlocked while we're doing this.
temp_access = getAccessControl().getContextAccess(*params);
user = temp_access->getUser();
}
std::shared_ptr<const SettingsProfilesInfo> profiles;
if (set_current_profiles_)
profiles = temp_access->getDefaultProfileInfo();
std::optional<std::vector<UUID>> roles;
if (set_current_roles_)
roles = user->granted_roles.findGranted(user->default_roles);
String database;
if (set_current_database_)
database = user->default_database;
/// Apply user's profiles, constraints, settings, roles.
auto lock = getLock();
user_id = user_id_;
setUserID(user_id_);
access = getAccessControl().getContextAccess(
user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info);
if (profiles)
{
/// A profile can specify a value and a readonly constraint for same setting at the same time,
/// so we shouldn't check constraints here.
setCurrentProfiles(*profiles, /* check_constraints= */ false);
}
auto user = access->getUser();
if (roles)
setCurrentRoles(*roles);
current_roles = std::make_shared<std::vector<UUID>>(user->granted_roles.findGranted(user->default_roles));
auto default_profile_info = access->getDefaultProfileInfo();
settings_constraints_and_current_profiles = default_profile_info->getConstraintsAndProfileIDs();
applySettingsChanges(default_profile_info->settings);
if (!user->default_database.empty())
setCurrentDatabase(user->default_database);
if (!database.empty())
setCurrentDatabase(database);
}
std::shared_ptr<const User> Context::getUser() const
@ -1090,6 +1119,13 @@ String Context::getUserName() const
return getAccess()->getUserName();
}
void Context::setUserID(const UUID & user_id_)
{
auto lock = getLock();
user_id = user_id_;
need_recalculate_access = true;
}
std::optional<UUID> Context::getUserID() const
{
auto lock = getLock();
@ -1107,10 +1143,11 @@ void Context::setQuotaKey(String quota_key_)
void Context::setCurrentRoles(const std::vector<UUID> & current_roles_)
{
auto lock = getLock();
if (current_roles ? (*current_roles == current_roles_) : current_roles_.empty())
return;
current_roles = std::make_shared<std::vector<UUID>>(current_roles_);
calculateAccessRights();
if (current_roles_.empty())
current_roles = nullptr;
else
current_roles = std::make_shared<std::vector<UUID>>(current_roles_);
need_recalculate_access = true;
}
void Context::setCurrentRolesDefault()
@ -1135,20 +1172,6 @@ std::shared_ptr<const EnabledRolesInfo> Context::getRolesInfo() const
}
void Context::calculateAccessRights()
{
auto lock = getLock();
if (user_id)
access = getAccessControl().getContextAccess(
*user_id,
current_roles ? *current_roles : std::vector<UUID>{},
/* use_default_roles = */ false,
settings,
current_database,
client_info);
}
template <typename... Args>
void Context::checkAccessImpl(const Args &... args) const
{
@ -1168,32 +1191,55 @@ void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id,
void Context::checkAccess(const AccessRightsElement & element) const { return checkAccessImpl(element); }
void Context::checkAccess(const AccessRightsElements & elements) const { return checkAccessImpl(elements); }
std::shared_ptr<const ContextAccess> Context::getAccess() const
{
auto lock = getLock();
return access ? access : ContextAccess::getFullAccess();
/// A helper function to collect parameters for calculating access rights, called with Context::getLock() acquired.
auto get_params = [this]()
{
/// If setUserID() was never called then this must be the global context with the full access.
bool full_access = !user_id;
return ContextAccessParams{user_id, full_access, /* use_default_roles= */ false, current_roles, settings, current_database, client_info};
};
/// Check if the current access rights are still valid, otherwise get parameters for recalculating access rights.
std::optional<ContextAccessParams> params;
{
auto lock = getLock();
if (access && !need_recalculate_access)
return access; /// No need to recalculate access rights.
params.emplace(get_params());
if (access && (access->getParams() == *params))
{
need_recalculate_access = false;
return access; /// No need to recalculate access rights.
}
}
/// Calculate new access rights according to the collected parameters.
/// NOTE: AccessControl::getContextAccess() may require some IO work, so Context::getLock() must be unlocked while we're doing this.
auto res = getAccessControl().getContextAccess(*params);
{
/// If the parameters of access rights were not changed while we were calculated them
/// then we store the new access rights in the Context to allow reusing it later.
auto lock = getLock();
if (get_params() == *params)
{
access = res;
need_recalculate_access = false;
}
}
return res;
}
RowPolicyFilterPtr Context::getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const
{
auto lock = getLock();
RowPolicyFilterPtr row_filter_of_initial_user;
if (row_policies_of_initial_user)
row_filter_of_initial_user = row_policies_of_initial_user->getFilter(database, table_name, filter_type);
return getAccess()->getRowPolicyFilter(database, table_name, filter_type, row_filter_of_initial_user);
}
void Context::enableRowPoliciesOfInitialUser()
{
auto lock = getLock();
row_policies_of_initial_user = nullptr;
if (client_info.initial_user == client_info.current_user)
return;
auto initial_user_id = getAccessControl().find<User>(client_info.initial_user);
if (!initial_user_id)
return;
row_policies_of_initial_user = getAccessControl().tryGetDefaultRowPolicies(*initial_user_id);
return getAccess()->getRowPolicyFilter(database, table_name, filter_type);
}
@ -1209,13 +1255,12 @@ std::optional<QuotaUsage> Context::getQuotaUsage() const
}
void Context::setCurrentProfile(const String & profile_name)
void Context::setCurrentProfile(const String & profile_name, bool check_constraints)
{
auto lock = getLock();
try
{
UUID profile_id = getAccessControl().getID<SettingsProfile>(profile_name);
setCurrentProfile(profile_id);
setCurrentProfile(profile_id, check_constraints);
}
catch (Exception & e)
{
@ -1224,15 +1269,20 @@ void Context::setCurrentProfile(const String & profile_name)
}
}
void Context::setCurrentProfile(const UUID & profile_id)
void Context::setCurrentProfile(const UUID & profile_id, bool check_constraints)
{
auto lock = getLock();
auto profile_info = getAccessControl().getSettingsProfileInfo(profile_id);
checkSettingsConstraints(profile_info->settings);
applySettingsChanges(profile_info->settings);
settings_constraints_and_current_profiles = profile_info->getConstraintsAndProfileIDs(settings_constraints_and_current_profiles);
setCurrentProfiles(*profile_info, check_constraints);
}
void Context::setCurrentProfiles(const SettingsProfilesInfo & profiles_info, bool check_constraints)
{
auto lock = getLock();
if (check_constraints)
checkSettingsConstraints(profiles_info.settings);
applySettingsChanges(profiles_info.settings);
settings_constraints_and_current_profiles = profiles_info.getConstraintsAndProfileIDs(settings_constraints_and_current_profiles);
}
std::vector<UUID> Context::getCurrentProfiles() const
{
@ -1706,27 +1756,8 @@ Settings Context::getSettings() const
void Context::setSettings(const Settings & settings_)
{
auto lock = getLock();
const auto old_readonly = settings.readonly;
const auto old_allow_ddl = settings.allow_ddl;
const auto old_allow_introspection_functions = settings.allow_introspection_functions;
const auto old_display_secrets = settings.format_display_secrets_in_show_and_select;
settings = settings_;
if ((settings.readonly != old_readonly)
|| (settings.allow_ddl != old_allow_ddl)
|| (settings.allow_introspection_functions != old_allow_introspection_functions)
|| (settings.format_display_secrets_in_show_and_select != old_display_secrets))
calculateAccessRights();
}
void Context::recalculateAccessRightsIfNeeded(std::string_view name)
{
if (name == "readonly"
|| name == "allow_ddl"
|| name == "allow_introspection_functions"
|| name == "format_display_secrets_in_show_and_select")
calculateAccessRights();
need_recalculate_access = true;
}
void Context::setSetting(std::string_view name, const String & value)
@ -1738,7 +1769,8 @@ void Context::setSetting(std::string_view name, const String & value)
return;
}
settings.set(name, value);
recalculateAccessRightsIfNeeded(name);
if (ContextAccessParams::dependsOnSettingName(name))
need_recalculate_access = true;
}
void Context::setSetting(std::string_view name, const Field & value)
@ -1750,7 +1782,8 @@ void Context::setSetting(std::string_view name, const Field & value)
return;
}
settings.set(name, value);
recalculateAccessRightsIfNeeded(name);
if (ContextAccessParams::dependsOnSettingName(name))
need_recalculate_access = true;
}
void Context::applySettingChange(const SettingChange & change)
@ -1859,7 +1892,7 @@ void Context::setCurrentDatabase(const String & name)
DatabaseCatalog::instance().assertDatabaseExists(name);
auto lock = getLock();
current_database = name;
calculateAccessRights();
need_recalculate_access = true;
}
void Context::setCurrentQueryId(const String & query_id)
@ -3833,6 +3866,129 @@ void Context::resetInputCallbacks()
}
void Context::setClientInfo(const ClientInfo & client_info_)
{
client_info = client_info_;
need_recalculate_access = true;
}
void Context::setClientName(const String & client_name)
{
client_info.client_name = client_name;
}
void Context::setClientInterface(ClientInfo::Interface interface)
{
client_info.interface = interface;
need_recalculate_access = true;
}
void Context::setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version)
{
client_info.client_version_major = client_version_major;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
}
void Context::setClientConnectionId(uint32_t connection_id_)
{
client_info.connection_id = connection_id_;
}
void Context::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
{
client_info.http_method = http_method;
client_info.http_user_agent = http_user_agent;
client_info.http_referer = http_referer;
need_recalculate_access = true;
}
void Context::setForwardedFor(const String & forwarded_for)
{
client_info.forwarded_for = forwarded_for;
need_recalculate_access = true;
}
void Context::setQueryKind(ClientInfo::QueryKind query_kind)
{
client_info.query_kind = query_kind;
}
void Context::setQueryKindInitial()
{
/// TODO: Try to combine this function with setQueryKind().
client_info.setInitialQuery();
}
void Context::setQueryKindReplicatedDatabaseInternal()
{
/// TODO: Try to combine this function with setQueryKind().
client_info.is_replicated_database_internal = true;
}
void Context::setCurrentUserName(const String & current_user_name)
{
/// TODO: Try to combine this function with setUser().
client_info.current_user = current_user_name;
need_recalculate_access = true;
}
void Context::setCurrentAddress(const Poco::Net::SocketAddress & current_address)
{
client_info.current_address = current_address;
need_recalculate_access = true;
}
void Context::setInitialUserName(const String & initial_user_name)
{
client_info.initial_user = initial_user_name;
need_recalculate_access = true;
}
void Context::setInitialAddress(const Poco::Net::SocketAddress & initial_address)
{
client_info.initial_address = initial_address;
}
void Context::setInitialQueryId(const String & initial_query_id)
{
client_info.initial_query_id = initial_query_id;
}
void Context::setInitialQueryStartTime(std::chrono::time_point<std::chrono::system_clock> initial_query_start_time)
{
client_info.initial_query_start_time = timeInSeconds(initial_query_start_time);
client_info.initial_query_start_time_microseconds = timeInMicroseconds(initial_query_start_time);
}
void Context::setQuotaClientKey(const String & quota_key_)
{
client_info.quota_key = quota_key_;
need_recalculate_access = true;
}
void Context::setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version)
{
client_info.connection_client_version_major = client_version_major;
client_info.connection_client_version_minor = client_version_minor;
client_info.connection_client_version_patch = client_version_patch;
client_info.connection_tcp_protocol_version = client_tcp_protocol_version;
}
void Context::setReplicaInfo(bool collaborate_with_initiator, size_t all_replicas_count, size_t number_of_current_replica)
{
client_info.collaborate_with_initiator = collaborate_with_initiator;
client_info.count_participating_replicas = all_replicas_count;
client_info.number_of_current_replica = number_of_current_replica;
}
void Context::increaseDistributedDepth()
{
++client_info.distributed_depth;
}
StorageID Context::resolveStorageID(StorageID storage_id, StorageNamespace where) const
{
if (storage_id.uuid != UUIDHelpers::Nil)

View File

@ -51,8 +51,8 @@ struct ContextSharedPart;
class ContextAccess;
struct User;
using UserPtr = std::shared_ptr<const User>;
struct SettingsProfilesInfo;
struct EnabledRolesInfo;
class EnabledRowPolicies;
struct RowPolicyFilter;
using RowPolicyFilterPtr = std::shared_ptr<const RowPolicyFilter>;
class EnabledQuota;
@ -249,8 +249,8 @@ private:
std::optional<UUID> user_id;
std::shared_ptr<std::vector<UUID>> current_roles;
std::shared_ptr<const SettingsConstraintsAndProfileIDs> settings_constraints_and_current_profiles;
std::shared_ptr<const ContextAccess> access;
std::shared_ptr<const EnabledRowPolicies> row_policies_of_initial_user;
mutable std::shared_ptr<const ContextAccess> access;
mutable bool need_recalculate_access = true;
String current_database;
Settings settings; /// Setting for query execution.
@ -530,12 +530,14 @@ public:
/// Sets the current user assuming that he/she is already authenticated.
/// WARNING: This function doesn't check password!
void setUser(const UUID & user_id_);
void setUser(const UUID & user_id_, bool set_current_profiles_ = true, bool set_current_roles_ = true, bool set_current_database_ = true);
UserPtr getUser() const;
String getUserName() const;
void setUserID(const UUID & user_id_);
std::optional<UUID> getUserID() const;
String getUserName() const;
void setQuotaKey(String quota_key_);
void setCurrentRoles(const std::vector<UUID> & current_roles_);
@ -544,8 +546,9 @@ public:
boost::container::flat_set<UUID> getEnabledRoles() const;
std::shared_ptr<const EnabledRolesInfo> getRolesInfo() const;
void setCurrentProfile(const String & profile_name);
void setCurrentProfile(const UUID & profile_id);
void setCurrentProfile(const String & profile_name, bool check_constraints = true);
void setCurrentProfile(const UUID & profile_id, bool check_constraints = true);
void setCurrentProfiles(const SettingsProfilesInfo & profiles_info, bool check_constraints = true);
std::vector<UUID> getCurrentProfiles() const;
std::vector<UUID> getEnabledProfiles() const;
@ -568,13 +571,6 @@ public:
RowPolicyFilterPtr getRowPolicyFilter(const String & database, const String & table_name, RowPolicyFilterType filter_type) const;
/// Finds and sets extra row policies to be used based on `client_info.initial_user`,
/// if the initial user exists.
/// TODO: we need a better solution here. It seems we should pass the initial row policy
/// because a shard is allowed to not have the initial user or it might be another user
/// with the same name.
void enableRowPoliciesOfInitialUser();
std::shared_ptr<const EnabledQuota> getQuota() const;
std::optional<QuotaUsage> getQuotaUsage() const;
@ -598,9 +594,33 @@ public:
InputBlocksReader getInputBlocksReaderCallback() const;
void resetInputCallbacks();
ClientInfo & getClientInfo() { return client_info; }
/// Returns information about the client executing a query.
const ClientInfo & getClientInfo() const { return client_info; }
/// Modify stored in the context information about the client executing a query.
void setClientInfo(const ClientInfo & client_info_);
void setClientName(const String & client_name);
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setForwardedFor(const String & forwarded_for);
void setQueryKind(ClientInfo::QueryKind query_kind);
void setQueryKindInitial();
void setQueryKindReplicatedDatabaseInternal();
void setCurrentUserName(const String & current_user_name);
void setCurrentAddress(const Poco::Net::SocketAddress & current_address);
void setInitialUserName(const String & initial_user_name);
void setInitialAddress(const Poco::Net::SocketAddress & initial_address);
void setInitialQueryId(const String & initial_query_id);
void setInitialQueryStartTime(std::chrono::time_point<std::chrono::system_clock> initial_query_start_time);
void setQuotaClientKey(const String & quota_key);
void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setReplicaInfo(bool collaborate_with_initiator, size_t all_replicas_count, size_t number_of_current_replica);
void increaseDistributedDepth();
const OpenTelemetry::TracingContext & getClientTraceContext() const { return client_info.client_trace_context; }
OpenTelemetry::TracingContext & getClientTraceContext() { return client_info.client_trace_context; }
enum StorageNamespace
{
ResolveGlobal = 1u, /// Database name must be specified
@ -1154,10 +1174,6 @@ private:
void initGlobal();
/// Compute and set actual user settings, client_info.current_user should be set
void calculateAccessRights();
void recalculateAccessRightsIfNeeded(std::string_view setting_name);
template <typename... Args>
void checkAccessImpl(const Args &... args) const;

View File

@ -199,7 +199,7 @@ ContextMutablePtr DDLTaskBase::makeQueryContext(ContextPtr from_context, const Z
auto query_context = Context::createCopy(from_context);
query_context->makeQueryContext();
query_context->setCurrentQueryId(""); // generate random query_id
query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
if (entry.settings)
query_context->applySettingsChanges(*entry.settings);
return query_context;
@ -439,8 +439,8 @@ void DatabaseReplicatedTask::parseQueryFromEntry(ContextPtr context)
ContextMutablePtr DatabaseReplicatedTask::makeQueryContext(ContextPtr from_context, const ZooKeeperPtr & zookeeper)
{
auto query_context = DDLTaskBase::makeQueryContext(from_context, zookeeper);
query_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
query_context->getClientInfo().is_replicated_database_internal = true;
query_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
query_context->setQueryKindReplicatedDatabaseInternal();
query_context->setCurrentDatabase(database->getDatabaseName());
auto txn = std::make_shared<ZooKeeperMetadataTransaction>(zookeeper, database->zookeeper_path, is_initial_query, entry_path);

View File

@ -476,7 +476,7 @@ bool DDLWorker::tryExecuteQuery(DDLTaskBase & task, const ZooKeeperPtr & zookeep
query_context->setSetting("implicit_transaction", Field{0});
}
query_context->getClientInfo().initial_query_id = task.entry.initial_query_id;
query_context->setInitialQueryId(task.entry.initial_query_id);
if (!task.is_initial_query)
query_scope.emplace(query_context);

View File

@ -451,11 +451,11 @@ void InterpreterDropQuery::executeDropQuery(ASTDropQuery::Kind kind, ContextPtr
auto drop_context = Context::createCopy(global_context);
if (ignore_sync_setting)
drop_context->setSetting("database_atomic_wait_for_drop_and_detach_synchronously", false);
drop_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
drop_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
if (auto txn = current_context->getZooKeeperMetadataTransaction())
{
/// For Replicated database
drop_context->getClientInfo().is_replicated_database_internal = true;
drop_context->setQueryKindReplicatedDatabaseInternal();
drop_context->setQueryContext(std::const_pointer_cast<Context>(current_context));
drop_context->initZooKeeperMetadataTransaction(txn, true);
}

View File

@ -3183,7 +3183,7 @@ void InterpreterSelectQuery::initSettings()
if (query.settings())
InterpreterSetQuery(query.settings(), context).executeForCurrentContext(options.ignore_setting_constraints);
auto & client_info = context->getClientInfo();
const auto & client_info = context->getClientInfo();
auto min_major = DBMS_MIN_MAJOR_VERSION_WITH_CURRENT_AGGREGATION_VARIANT_SELECTION_METHOD;
auto min_minor = DBMS_MIN_MINOR_VERSION_WITH_CURRENT_AGGREGATION_VARIANT_SELECTION_METHOD;

View File

@ -299,7 +299,10 @@ Session::~Session()
if (notified_session_log_about_login)
{
if (auto session_log = getSessionLog())
{
/// TODO: We have to ensure that the same info is added to the session log on a LoginSuccess event and on the corresponding Logout event.
session_log->addLogOut(auth_id, user, getClientInfo());
}
}
}
@ -368,17 +371,117 @@ void Session::onAuthenticationFailure(const std::optional<String> & user_name, c
}
}
ClientInfo & Session::getClientInfo()
{
/// FIXME it may produce different info for LoginSuccess and the corresponding Logout entries in the session log
return session_context ? session_context->getClientInfo() : *prepared_client_info;
}
const ClientInfo & Session::getClientInfo() const
{
return session_context ? session_context->getClientInfo() : *prepared_client_info;
}
void Session::setClientInfo(const ClientInfo & client_info)
{
if (session_context)
session_context->setClientInfo(client_info);
else
prepared_client_info = client_info;
}
void Session::setClientName(const String & client_name)
{
if (session_context)
session_context->setClientName(client_name);
else
prepared_client_info->client_name = client_name;
}
void Session::setClientInterface(ClientInfo::Interface interface)
{
if (session_context)
session_context->setClientInterface(interface);
else
prepared_client_info->interface = interface;
}
void Session::setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version)
{
if (session_context)
{
session_context->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
}
else
{
prepared_client_info->client_version_major = client_version_major;
prepared_client_info->client_version_minor = client_version_minor;
prepared_client_info->client_version_patch = client_version_patch;
prepared_client_info->client_tcp_protocol_version = client_tcp_protocol_version;
}
}
void Session::setClientConnectionId(uint32_t connection_id)
{
if (session_context)
session_context->setClientConnectionId(connection_id);
else
prepared_client_info->connection_id = connection_id;
}
void Session::setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer)
{
if (session_context)
{
session_context->setHttpClientInfo(http_method, http_user_agent, http_referer);
}
else
{
prepared_client_info->http_method = http_method;
prepared_client_info->http_user_agent = http_user_agent;
prepared_client_info->http_referer = http_referer;
}
}
void Session::setForwardedFor(const String & forwarded_for)
{
if (session_context)
session_context->setForwardedFor(forwarded_for);
else
prepared_client_info->forwarded_for = forwarded_for;
}
void Session::setQuotaClientKey(const String & quota_key)
{
if (session_context)
session_context->setQuotaClientKey(quota_key);
else
prepared_client_info->quota_key = quota_key;
}
void Session::setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version)
{
if (session_context)
{
session_context->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
}
else
{
prepared_client_info->connection_client_version_major = client_version_major;
prepared_client_info->connection_client_version_minor = client_version_minor;
prepared_client_info->connection_client_version_patch = client_version_patch;
prepared_client_info->connection_tcp_protocol_version = client_tcp_protocol_version;
}
}
const OpenTelemetry::TracingContext & Session::getClientTraceContext() const
{
if (session_context)
return session_context->getClientTraceContext();
return prepared_client_info->client_trace_context;
}
OpenTelemetry::TracingContext & Session::getClientTraceContext()
{
if (session_context)
return session_context->getClientTraceContext();
return prepared_client_info->client_trace_context;
}
ContextMutablePtr Session::makeSessionContext()
{
if (session_context)
@ -396,8 +499,7 @@ ContextMutablePtr Session::makeSessionContext()
new_session_context->makeSessionContext();
/// Copy prepared client info to the new session context.
auto & res_client_info = new_session_context->getClientInfo();
res_client_info = std::move(prepared_client_info).value();
new_session_context->setClientInfo(*prepared_client_info);
prepared_client_info.reset();
/// Set user information for the new context: current profiles, roles, access rights.
@ -436,8 +538,7 @@ ContextMutablePtr Session::makeSessionContext(const String & session_name_, std:
/// Copy prepared client info to the session context, no matter it's been just created or not.
/// If we continue using a previously created session context found by session ID
/// it's necessary to replace the client info in it anyway, because it contains actual connection information (client address, etc.)
auto & res_client_info = new_session_context->getClientInfo();
res_client_info = std::move(prepared_client_info).value();
new_session_context->setClientInfo(*prepared_client_info);
prepared_client_info.reset();
/// Set user information for the new context: current profiles, roles, access rights.
@ -492,32 +593,28 @@ ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_t
}
/// Copy the specified client info to the new query context.
auto & res_client_info = query_context->getClientInfo();
if (client_info_to_move)
res_client_info = std::move(*client_info_to_move);
query_context->setClientInfo(*client_info_to_move);
else if (client_info_to_copy && (client_info_to_copy != &getClientInfo()))
res_client_info = *client_info_to_copy;
query_context->setClientInfo(*client_info_to_copy);
/// Copy current user's name and address if it was authenticated after query_client_info was initialized.
if (prepared_client_info && !prepared_client_info->current_user.empty())
{
res_client_info.current_user = prepared_client_info->current_user;
res_client_info.current_address = prepared_client_info->current_address;
query_context->setCurrentUserName(prepared_client_info->current_user);
query_context->setCurrentAddress(prepared_client_info->current_address);
}
/// Set parameters of initial query.
if (res_client_info.query_kind == ClientInfo::QueryKind::NO_QUERY)
res_client_info.query_kind = ClientInfo::QueryKind::INITIAL_QUERY;
if (query_context->getClientInfo().query_kind == ClientInfo::QueryKind::NO_QUERY)
query_context->setQueryKind(ClientInfo::QueryKind::INITIAL_QUERY);
if (res_client_info.query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
if (query_context->getClientInfo().query_kind == ClientInfo::QueryKind::INITIAL_QUERY)
{
res_client_info.initial_user = res_client_info.current_user;
res_client_info.initial_address = res_client_info.current_address;
query_context->setInitialUserName(query_context->getClientInfo().current_user);
query_context->setInitialAddress(query_context->getClientInfo().current_address);
}
/// Sets that row policies of the initial user should be used too.
query_context->enableRowPoliciesOfInitialUser();
/// Set user information for the new context: current profiles, roles, access rights.
if (user_id && !query_context->getAccess()->tryGetUser())
query_context->setUser(*user_id);
@ -566,4 +663,3 @@ void Session::closeSession(const String & session_id)
}
}

View File

@ -54,10 +54,23 @@ public:
/// Writes a row about login failure into session log (if enabled)
void onAuthenticationFailure(const std::optional<String> & user_name, const Poco::Net::SocketAddress & address_, const Exception & e);
/// Returns a reference to session ClientInfo.
ClientInfo & getClientInfo();
/// Returns a reference to the session's ClientInfo.
const ClientInfo & getClientInfo() const;
/// Modify the session's ClientInfo.
void setClientInfo(const ClientInfo & client_info);
void setClientName(const String & client_name);
void setClientInterface(ClientInfo::Interface interface);
void setClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
void setClientConnectionId(uint32_t connection_id);
void setHttpClientInfo(ClientInfo::HTTPMethod http_method, const String & http_user_agent, const String & http_referer);
void setForwardedFor(const String & forwarded_for);
void setQuotaClientKey(const String & quota_key);
void setConnectionClientVersion(UInt64 client_version_major, UInt64 client_version_minor, UInt64 client_version_patch, unsigned client_tcp_protocol_version);
const OpenTelemetry::TracingContext & getClientTraceContext() const;
OpenTelemetry::TracingContext & getClientTraceContext();
/// Makes a session context, can be used one or zero times.
/// The function also assigns an user to this context.
ContextMutablePtr makeSessionContext();

View File

@ -656,7 +656,7 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
/// the value passed by the client
Stopwatch start_watch{CLOCK_MONOTONIC};
auto & client_info = context->getClientInfo();
const auto & client_info = context->getClientInfo();
if (!internal)
{
@ -668,8 +668,7 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
// On the other hand, if it's initialized then take it as the start of the query
if (client_info.initial_query_start_time == 0)
{
client_info.initial_query_start_time = timeInSeconds(query_start_time);
client_info.initial_query_start_time_microseconds = timeInMicroseconds(query_start_time);
context->setInitialQueryStartTime(query_start_time);
}
else
{

View File

@ -72,14 +72,10 @@ std::unique_ptr<QueryPlan> createLocalPlan(
if (coordinator)
{
new_context->parallel_reading_coordinator = coordinator;
new_context->getClientInfo().interface = ClientInfo::Interface::LOCAL;
new_context->getClientInfo().collaborate_with_initiator = true;
new_context->getClientInfo().query_kind = ClientInfo::QueryKind::SECONDARY_QUERY;
new_context->getClientInfo().count_participating_replicas = replica_count;
new_context->getClientInfo().number_of_current_replica = replica_num;
new_context->getClientInfo().connection_client_version_major = DBMS_VERSION_MAJOR;
new_context->getClientInfo().connection_client_version_minor = DBMS_VERSION_MINOR;
new_context->getClientInfo().connection_tcp_protocol_version = DBMS_TCP_PROTOCOL_VERSION;
new_context->setClientInterface(ClientInfo::Interface::LOCAL);
new_context->setQueryKind(ClientInfo::QueryKind::SECONDARY_QUERY);
new_context->setReplicaInfo(true, replica_count, replica_num);
new_context->setConnectionClientVersion(DBMS_VERSION_MAJOR, DBMS_VERSION_MINOR, DBMS_VERSION_PATCH, DBMS_TCP_PROTOCOL_VERSION);
new_context->setParallelReplicasGroupUUID(group_uuid);
new_context->setMergeTreeAllRangesCallback([coordinator](InitialAllRangesAnnouncement announcement)
{

View File

@ -798,7 +798,7 @@ namespace
/// Authentication.
session.emplace(iserver.context(), ClientInfo::Interface::GRPC);
session->authenticate(user, password, user_address);
session->getClientInfo().quota_key = quota_key;
session->setQuotaClientKey(quota_key);
ClientInfo client_info = session->getClientInfo();

View File

@ -474,7 +474,6 @@ bool HTTPHandler::authenticateUser(
}
/// Set client info. It will be used for quota accounting parameters in 'setUser' method.
ClientInfo & client_info = session->getClientInfo();
ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
if (request.getMethod() == HTTPServerRequest::HTTP_GET)
@ -482,15 +481,13 @@ bool HTTPHandler::authenticateUser(
else if (request.getMethod() == HTTPServerRequest::HTTP_POST)
http_method = ClientInfo::HTTPMethod::POST;
client_info.http_method = http_method;
client_info.http_user_agent = request.get("User-Agent", "");
client_info.http_referer = request.get("Referer", "");
client_info.forwarded_for = request.get("X-Forwarded-For", "");
client_info.quota_key = quota_key;
session->setHttpClientInfo(http_method, request.get("User-Agent", ""), request.get("Referer", ""));
session->setForwardedFor(request.get("X-Forwarded-For", ""));
session->setQuotaClientKey(quota_key);
/// Extract the last entry from comma separated list of forwarded_for addresses.
/// Only the last proxy can be trusted (if any).
String forwarded_address = client_info.getLastForwardedFor();
String forwarded_address = session->getClientInfo().getLastForwardedFor();
try
{
if (!forwarded_address.empty() && server.config().getBool("auth_use_forwarded_address", false))
@ -988,22 +985,22 @@ void HTTPHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
}
// Parse the OpenTelemetry traceparent header.
ClientInfo& client_info = session->getClientInfo();
auto & client_trace_context = session->getClientTraceContext();
if (request.has("traceparent"))
{
std::string opentelemetry_traceparent = request.get("traceparent");
std::string error;
if (!client_info.client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error))
if (!client_trace_context.parseTraceparentHeader(opentelemetry_traceparent, error))
{
LOG_DEBUG(log, "Failed to parse OpenTelemetry traceparent header '{}': {}", opentelemetry_traceparent, error);
}
client_info.client_trace_context.tracestate = request.get("tracestate", "");
client_trace_context.tracestate = request.get("tracestate", "");
}
// Setup tracing context for this thread
auto context = session->sessionOrGlobalContext();
thread_trace_context = std::make_unique<OpenTelemetry::TracingContextHolder>("HTTPHandler",
client_info.client_trace_context,
client_trace_context,
context->getSettingsRef(),
context->getOpenTelemetrySpanLog());
thread_trace_context->root_span.kind = OpenTelemetry::SERVER;

View File

@ -94,7 +94,7 @@ void MySQLHandler::run()
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::MYSQL);
SCOPE_EXIT({ session.reset(); });
session->getClientInfo().connection_id = connection_id;
session->setClientConnectionId(connection_id);
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());

View File

@ -58,7 +58,7 @@ void PostgreSQLHandler::run()
session = std::make_unique<Session>(server.context(), ClientInfo::Interface::POSTGRESQL);
SCOPE_EXIT({ session.reset(); });
session->getClientInfo().connection_id = connection_id;
session->setClientConnectionId(connection_id);
try
{

View File

@ -1177,21 +1177,12 @@ std::unique_ptr<Session> TCPHandler::makeSession()
auto res = std::make_unique<Session>(server.context(), interface, socket().secure(), certificate);
auto & client_info = res->getClientInfo();
client_info.forwarded_for = forwarded_for;
client_info.client_name = client_name;
client_info.client_version_major = client_version_major;
client_info.client_version_minor = client_version_minor;
client_info.client_version_patch = client_version_patch;
client_info.client_tcp_protocol_version = client_tcp_protocol_version;
client_info.connection_client_version_major = client_version_major;
client_info.connection_client_version_minor = client_version_minor;
client_info.connection_client_version_patch = client_version_patch;
client_info.connection_tcp_protocol_version = client_tcp_protocol_version;
client_info.quota_key = quota_key;
client_info.interface = interface;
res->setForwardedFor(forwarded_for);
res->setClientName(client_name);
res->setClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
res->setConnectionClientVersion(client_version_major, client_version_minor, client_version_patch, client_tcp_protocol_version);
res->setQuotaClientKey(quota_key);
res->setClientInterface(interface);
return res;
}
@ -1253,7 +1244,7 @@ void TCPHandler::receiveHello()
}
session = makeSession();
auto & client_info = session->getClientInfo();
const auto & client_info = session->getClientInfo();
#if USE_SSL
/// Authentication with SSL user certificate
@ -1286,7 +1277,7 @@ void TCPHandler::receiveAddendum()
{
readStringBinary(quota_key, *in);
if (!is_interserver_mode)
session->getClientInfo().quota_key = quota_key;
session->setQuotaClientKey(quota_key);
}
}

View File

@ -132,7 +132,7 @@ DistributedSink::DistributedSink(
const auto & settings = context->getSettingsRef();
if (settings.max_distributed_depth && context->getClientInfo().distributed_depth >= settings.max_distributed_depth)
throw Exception(ErrorCodes::TOO_LARGE_DISTRIBUTED_DEPTH, "Maximum distributed depth exceeded");
context->getClientInfo().distributed_depth += 1;
context->increaseDistributedDepth();
random_shard_insert = settings.insert_distributed_one_random_shard && !storage.has_sharding_key;
}

View File

@ -914,7 +914,7 @@ std::optional<QueryPipeline> StorageDistributed::distributedWriteBetweenDistribu
QueryPipeline pipeline;
ContextMutablePtr query_context = Context::createCopy(local_context);
++query_context->getClientInfo().distributed_depth;
query_context->increaseDistributedDepth();
for (size_t shard_index : collections::range(0, shards_info.size()))
{
@ -976,7 +976,7 @@ std::optional<QueryPipeline> StorageDistributed::distributedWriteFromClusterStor
QueryPipeline pipeline;
ContextMutablePtr query_context = Context::createCopy(local_context);
++query_context->getClientInfo().distributed_depth;
query_context->increaseDistributedDepth();
/// Here we take addresses from destination cluster and assume source table exists on these nodes
for (const auto & replicas : getCluster()->getShardsAddresses())

View File

@ -5082,7 +5082,7 @@ std::optional<QueryPipeline> StorageReplicatedMergeTree::distributedWriteFromClu
QueryPipeline pipeline;
ContextMutablePtr query_context = Context::createCopy(local_context);
++query_context->getClientInfo().distributed_depth;
query_context->increaseDistributedDepth();
for (const auto & replicas : src_cluster->getShardsAddresses())
{

View File

@ -992,7 +992,7 @@ void StorageWindowView::cleanup()
auto cleanup_context = Context::createCopy(getContext());
cleanup_context->makeQueryContext();
cleanup_context->setCurrentQueryId("");
cleanup_context->getClientInfo().is_replicated_database_internal = true;
cleanup_context->setQueryKindReplicatedDatabaseInternal();
InterpreterAlterQuery interpreter_alter(alter_query, cleanup_context);
interpreter_alter.execute();