diff --git a/base/common/StringRef.h b/base/common/StringRef.h index 961aab58980..076b8982b1d 100644 --- a/base/common/StringRef.h +++ b/base/common/StringRef.h @@ -27,17 +27,17 @@ struct StringRef size_t size = 0; template > - StringRef(const CharT * data_, size_t size_) : data(reinterpret_cast(data_)), size(size_) {} + constexpr StringRef(const CharT * data_, size_t size_) : data(reinterpret_cast(data_)), size(size_) {} StringRef(const std::string & s) : data(s.data()), size(s.size()) {} - StringRef(const std::string_view & s) : data(s.data()), size(s.size()) {} - explicit StringRef(const char * data_) : data(data_), size(strlen(data_)) {} - StringRef() = default; + constexpr StringRef(const std::string_view & s) : data(s.data()), size(s.size()) {} + constexpr StringRef(const char * data_) : StringRef(std::string_view{data_}) {} + constexpr StringRef() = default; std::string toString() const { return std::string(data, size); } explicit operator std::string() const { return toString(); } - explicit operator std::string_view() const { return {data, size}; } + constexpr explicit operator std::string_view() const { return {data, size}; } }; using StringRefs = std::vector; diff --git a/base/ext/scope_guard.h b/base/ext/scope_guard.h index f5b986e7ab6..79bad56f360 100644 --- a/base/ext/scope_guard.h +++ b/base/ext/scope_guard.h @@ -12,20 +12,20 @@ class [[nodiscard]] basic_scope_guard { public: constexpr basic_scope_guard() = default; - constexpr basic_scope_guard(basic_scope_guard && src) : function{std::exchange(src.function, {})} {} + constexpr basic_scope_guard(basic_scope_guard && src) : function{src.release()} {} constexpr basic_scope_guard & operator=(basic_scope_guard && src) { if (this != &src) { invoke(); - function = std::exchange(src.function, {}); + function = src.release(); } return *this; } template , void>> - constexpr basic_scope_guard(basic_scope_guard && src) : function{std::exchange(src.function, {})} {} + constexpr basic_scope_guard(basic_scope_guard && src) : function{src.release()} {} template , void>> constexpr basic_scope_guard & operator=(basic_scope_guard && src) @@ -33,7 +33,7 @@ public: if (this != &src) { invoke(); - function = std::exchange(src.function, {}); + function = src.release(); } return *this; } @@ -46,14 +46,26 @@ public: ~basic_scope_guard() { invoke(); } + static constexpr bool is_nullable = std::is_constructible_v; + explicit operator bool() const { - if constexpr (std::is_constructible_v) + if constexpr (is_nullable) return static_cast(function); return true; } - void reset() { function = {}; } + void reset() + { + invoke(); + release(); + } + + F release() + { + static_assert(is_nullable); + return std::exchange(function, {}); + } template , void>> basic_scope_guard & join(basic_scope_guard && other) @@ -62,14 +74,14 @@ public: { if (function) { - function = [x = std::make_shared>(std::move(function), std::exchange(other.function, {}))]() + function = [x = std::make_shared>(std::move(function), other.release())]() { std::move(x->first)(); std::move(x->second)(); }; } else - function = std::exchange(other.function, {}); + function = other.release(); } return *this; } @@ -77,7 +89,7 @@ public: private: void invoke() { - if constexpr (std::is_constructible_v) + if constexpr (is_nullable) { if (!function) return; diff --git a/dbms/programs/client/Client.cpp b/dbms/programs/client/Client.cpp index 142f5edc4da..21768911beb 100644 --- a/dbms/programs/client/Client.cpp +++ b/dbms/programs/client/Client.cpp @@ -225,11 +225,11 @@ private: context.setQueryParameters(query_parameters); /// settings and limits could be specified in config file, but passed settings has higher priority - for (auto && setting : context.getSettingsRef()) + for (const auto & setting : context.getSettingsRef()) { const String & name = setting.getName().toString(); if (config().has(name) && !setting.isChanged()) - setting.setValue(config().getString(name)); + context.setSetting(name, config().getString(name)); } /// Set path for format schema files @@ -1736,8 +1736,8 @@ public: ("server_logs_file", po::value(), "put server logs into specified file") ; - context.makeGlobalContext(); - context.getSettingsRef().addProgramOptions(main_description); + Settings cmd_settings; + cmd_settings.addProgramOptions(main_description); /// Commandline options related to external tables. po::options_description external_description = createOptionsDescription("External tables options", terminal_width); @@ -1805,6 +1805,9 @@ public: } } + context.makeGlobalContext(); + context.setSettings(cmd_settings); + /// Copy settings-related program options to config. /// TODO: Is this code necessary? for (const auto & setting : context.getSettingsRef()) diff --git a/dbms/programs/copier/ClusterCopier.cpp b/dbms/programs/copier/ClusterCopier.cpp index 4431362913d..fcb2a69d2a4 100644 --- a/dbms/programs/copier/ClusterCopier.cpp +++ b/dbms/programs/copier/ClusterCopier.cpp @@ -216,7 +216,7 @@ void ClusterCopier::reloadTaskDescription() /// Setup settings task_cluster->reloadSettings(*config); - context.getSettingsRef() = task_cluster->settings_common; + context.setSettings(task_cluster->settings_common); task_cluster_current_config = config; task_descprtion_current_stat = stat; @@ -964,8 +964,8 @@ PartitionTaskStatus ClusterCopier::processPartitionTaskImpl(const ConnectionTime { Context local_context = context; // Use pull (i.e. readonly) settings, but fetch data from destination servers - local_context.getSettingsRef() = task_cluster->settings_pull; - local_context.getSettingsRef().skip_unavailable_shards = true; + local_context.setSettings(task_cluster->settings_pull); + local_context.setSetting("skip_unavailable_shards", true); Block block = getBlockWithAllStreamData(InterpreterFactory::get(query_select_ast, local_context)->execute().in); count = (block) ? block.safeGetByPosition(0).column->getUInt(0) : 0; @@ -1053,10 +1053,10 @@ PartitionTaskStatus ClusterCopier::processPartitionTaskImpl(const ConnectionTime { /// Custom INSERT SELECT implementation Context context_select = context; - context_select.getSettingsRef() = task_cluster->settings_pull; + context_select.setSettings(task_cluster->settings_pull); Context context_insert = context; - context_insert.getSettingsRef() = task_cluster->settings_push; + context_insert.setSettings(task_cluster->settings_push); BlockInputStreamPtr input; BlockOutputStreamPtr output; diff --git a/dbms/programs/local/LocalServer.cpp b/dbms/programs/local/LocalServer.cpp index 2d93c792350..ec2c01924f6 100644 --- a/dbms/programs/local/LocalServer.cpp +++ b/dbms/programs/local/LocalServer.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -92,7 +93,7 @@ void LocalServer::initialize(Poco::Util::Application & self) void LocalServer::applyCmdSettings() { - context->getSettingsRef().copyChangesFrom(cmd_settings); + context->applySettingsChanges(cmd_settings.changes()); } /// If path is specified and not empty, will try to setup server environment and load existing metadata diff --git a/dbms/programs/server/HTTPHandler.cpp b/dbms/programs/server/HTTPHandler.cpp index 0d447a56740..65b605c993d 100644 --- a/dbms/programs/server/HTTPHandler.cpp +++ b/dbms/programs/server/HTTPHandler.cpp @@ -439,13 +439,13 @@ void HTTPHandler::processQuery( /// In theory if initially readonly = 0, the client can change any setting and then set readonly /// to some other value. - auto & settings = context.getSettingsRef(); + const auto & settings = context.getSettingsRef(); /// Only readonly queries are allowed for HTTP GET requests. if (request.getMethod() == Poco::Net::HTTPServerRequest::HTTP_GET) { if (settings.readonly == 0) - settings.readonly = 2; + context.setSetting("readonly", 2); } bool has_external_data = startsWith(request.getContentType(), "multipart/form-data"); diff --git a/dbms/programs/server/Server.cpp b/dbms/programs/server/Server.cpp index 86f65fb09f1..96ba2883480 100644 --- a/dbms/programs/server/Server.cpp +++ b/dbms/programs/server/Server.cpp @@ -527,7 +527,7 @@ int Server::main(const std::vector & /*args*/) /// Load global settings from default_profile and system_profile. global_context->setDefaultProfiles(config()); - Settings & settings = global_context->getSettingsRef(); + const Settings & settings = global_context->getSettingsRef(); /// Size of cache for marks (index of MergeTree family of tables). It is mandatory. size_t mark_cache_size = config().getUInt64("mark_cache_size"); diff --git a/dbms/programs/server/TCPHandler.cpp b/dbms/programs/server/TCPHandler.cpp index a5ecf2963ea..f9df1e4cf9a 100644 --- a/dbms/programs/server/TCPHandler.cpp +++ b/dbms/programs/server/TCPHandler.cpp @@ -950,11 +950,11 @@ void TCPHandler::receiveUnexpectedQuery() readStringBinary(skip_string, *in); - ClientInfo & skip_client_info = query_context->getClientInfo(); + ClientInfo skip_client_info; if (client_revision >= DBMS_MIN_REVISION_WITH_CLIENT_INFO) skip_client_info.read(*in, client_revision); - Settings & skip_settings = query_context->getSettingsRef(); + Settings skip_settings; auto settings_format = (client_revision >= DBMS_MIN_REVISION_WITH_SETTINGS_SERIALIZED_AS_STRINGS) ? SettingsBinaryFormat::STRINGS : SettingsBinaryFormat::OLD; skip_settings.deserialize(*in, settings_format); diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index b3854e69eec..b5e06549c28 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -3,10 +3,15 @@ #include #include #include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include namespace DB @@ -27,12 +32,55 @@ namespace } +class AccessControlManager::ContextAccessCache +{ +public: + explicit ContextAccessCache(const AccessControlManager & manager_) : manager(manager_) {} + + std::shared_ptr getContextAccess( + const UUID & user_id, + const std::vector & current_roles, + bool use_default_roles, + const Settings & settings, + const String & current_database, + const ClientInfo & client_info) + { + ContextAccess::Params params; + params.user_id = user_id; + params.current_roles = current_roles; + params.use_default_roles = use_default_roles; + params.current_database = current_database; + params.readonly = settings.readonly; + params.allow_ddl = settings.allow_ddl; + params.allow_introspection = settings.allow_introspection_functions; + params.interface = client_info.interface; + params.http_method = client_info.http_method; + params.address = client_info.current_address.host(); + params.quota_key = client_info.quota_key; + + std::lock_guard lock{mutex}; + auto x = cache.get(params); + if (x) + return *x; + auto res = std::shared_ptr(new ContextAccess(manager, params)); + cache.add(params, res); + return res; + } + +private: + const AccessControlManager & manager; + Poco::ExpireCache> cache; + std::mutex mutex; +}; + + AccessControlManager::AccessControlManager() : MultipleAccessStorage(createStorages()), - access_rights_context_factory(std::make_unique(*this)), - role_context_factory(std::make_unique(*this)), - row_policy_context_factory(std::make_unique(*this)), - quota_context_factory(std::make_unique(*this)) + context_access_cache(std::make_unique(*this)), + role_cache(std::make_unique(*this)), + row_policy_cache(std::make_unique(*this)), + quota_cache(std::make_unique(*this)), + settings_profiles_cache(std::make_unique(*this)) { } @@ -54,7 +102,13 @@ void AccessControlManager::setUsersConfig(const Poco::Util::AbstractConfiguratio } -AccessRightsContextPtr AccessControlManager::getAccessRightsContext( +void AccessControlManager::setDefaultProfileName(const String & default_profile_name) +{ + settings_profiles_cache->setDefaultProfileName(default_profile_name); +} + + +std::shared_ptr AccessControlManager::getContextAccess( const UUID & user_id, const std::vector & current_roles, bool use_default_roles, @@ -62,34 +116,49 @@ AccessRightsContextPtr AccessControlManager::getAccessRightsContext( const String & current_database, const ClientInfo & client_info) const { - return access_rights_context_factory->createContext(user_id, current_roles, use_default_roles, settings, current_database, client_info); + return context_access_cache->getContextAccess(user_id, current_roles, use_default_roles, settings, current_database, client_info); } -RoleContextPtr AccessControlManager::getRoleContext( +std::shared_ptr AccessControlManager::getEnabledRoles( const std::vector & current_roles, const std::vector & current_roles_with_admin_option) const { - return role_context_factory->createContext(current_roles, current_roles_with_admin_option); + return role_cache->getEnabledRoles(current_roles, current_roles_with_admin_option); } -RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const UUID & user_id, const std::vector & enabled_roles) const +std::shared_ptr AccessControlManager::getEnabledRowPolicies(const UUID & user_id, const std::vector & enabled_roles) const { - return row_policy_context_factory->createContext(user_id, enabled_roles); + return row_policy_cache->getEnabledRowPolicies(user_id, enabled_roles); } -QuotaContextPtr AccessControlManager::getQuotaContext( - const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & custom_quota_key) const +std::shared_ptr AccessControlManager::getEnabledQuota( + const UUID & user_id, const String & user_name, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & custom_quota_key) const { - return quota_context_factory->createContext(user_name, user_id, enabled_roles, address, custom_quota_key); + return quota_cache->getEnabledQuota(user_id, user_name, enabled_roles, address, custom_quota_key); } std::vector AccessControlManager::getQuotaUsageInfo() const { - return quota_context_factory->getUsageInfo(); + return quota_cache->getUsageInfo(); +} + + +std::shared_ptr AccessControlManager::getEnabledSettings( + const UUID & user_id, + const SettingsProfileElements & settings_from_user, + const std::vector & enabled_roles, + const SettingsProfileElements & settings_from_enabled_roles) const +{ + return settings_profiles_cache->getEnabledSettings(user_id, settings_from_user, enabled_roles, settings_from_enabled_roles); +} + +std::shared_ptr AccessControlManager::getProfileSettings(const String & profile_name) const +{ + return settings_profiles_cache->getProfileSettings(profile_name); } } diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index bd5720bb0f4..810970a8379 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -19,21 +19,21 @@ namespace Poco namespace DB { -class AccessRightsContext; -using AccessRightsContextPtr = std::shared_ptr; -class AccessRightsContextFactory; +class ContextAccess; struct User; using UserPtr = std::shared_ptr; -class RoleContext; -using RoleContextPtr = std::shared_ptr; -class RoleContextFactory; -class RowPolicyContext; -using RowPolicyContextPtr = std::shared_ptr; -class RowPolicyContextFactory; -class QuotaContext; -using QuotaContextPtr = std::shared_ptr; -class QuotaContextFactory; +class EnabledRoles; +class RoleCache; +class EnabledRowPolicies; +class RowPolicyCache; +class EnabledQuota; +class QuotaCache; struct QuotaUsageInfo; +struct SettingsProfile; +using SettingsProfilePtr = std::shared_ptr; +class EnabledSettings; +class SettingsProfilesCache; +class SettingsProfileElements; class ClientInfo; struct Settings; @@ -47,8 +47,9 @@ public: void setLocalDirectory(const String & directory); void setUsersConfig(const Poco::Util::AbstractConfiguration & users_config); + void setDefaultProfileName(const String & default_profile_name); - AccessRightsContextPtr getAccessRightsContext( + std::shared_ptr getContextAccess( const UUID & user_id, const std::vector & current_roles, bool use_default_roles, @@ -56,28 +57,37 @@ public: const String & current_database, const ClientInfo & client_info) const; - RoleContextPtr getRoleContext( + std::shared_ptr getEnabledRoles( const std::vector & current_roles, const std::vector & current_roles_with_admin_option) const; - RowPolicyContextPtr getRowPolicyContext( + std::shared_ptr getEnabledRowPolicies( const UUID & user_id, const std::vector & enabled_roles) const; - QuotaContextPtr getQuotaContext( - const String & user_name, + std::shared_ptr getEnabledQuota( const UUID & user_id, + const String & user_name, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; std::vector getQuotaUsageInfo() const; + std::shared_ptr getEnabledSettings(const UUID & user_id, + const SettingsProfileElements & settings_from_user, + const std::vector & enabled_roles, + const SettingsProfileElements & settings_from_enabled_roles) const; + + std::shared_ptr getProfileSettings(const String & profile_name) const; + private: - std::unique_ptr access_rights_context_factory; - std::unique_ptr role_context_factory; - std::unique_ptr row_policy_context_factory; - std::unique_ptr quota_context_factory; + class ContextAccessCache; + std::unique_ptr context_access_cache; + std::unique_ptr role_cache; + std::unique_ptr row_policy_cache; + std::unique_ptr quota_cache; + std::unique_ptr settings_profiles_cache; }; } diff --git a/dbms/src/Access/AccessFlags.h b/dbms/src/Access/AccessFlags.h index 5af804ddc48..f15e7d1e274 100644 --- a/dbms/src/Access/AccessFlags.h +++ b/dbms/src/Access/AccessFlags.h @@ -63,15 +63,24 @@ public: /// Returns a list of keywords. std::vector toKeywords() const; - /// Returns the access types which could be granted on the database level. - /// For example, SELECT can be granted on the database level, but CREATE_USER cannot. - static AccessFlags databaseLevel(); + /// Returns all the flags. + /// These are the same as (allGlobalFlags() | allDatabaseFlags() | allTableFlags() | allColumnsFlags() | allDictionaryFlags()). + static AccessFlags allFlags(); - /// Returns the access types which could be granted on the table/dictionary level. - static AccessFlags tableLevel(); + /// Returns all the global flags. + static AccessFlags allGlobalFlags(); - /// Returns the access types which could be granted on the column/attribute level. - static AccessFlags columnLevel(); + /// Returns all the flags related to a database. + static AccessFlags allDatabaseFlags(); + + /// Returns all the flags related to a table. + static AccessFlags allTableFlags(); + + /// Returns all the flags related to a column. + static AccessFlags allColumnFlags(); + + /// Returns all the flags related to a dictionary. + static AccessFlags allDictionaryFlags(); private: static constexpr size_t NUM_FLAGS = 128; @@ -158,22 +167,27 @@ public: return str; } - const Flags & getDatabaseLevelFlags() const { return all_grantable_on_level[DATABASE_LEVEL]; } - const Flags & getTableLevelFlags() const { return all_grantable_on_level[TABLE_LEVEL]; } - const Flags & getColumnLevelFlags() const { return all_grantable_on_level[COLUMN_LEVEL]; } + const Flags & getAllFlags() const { return all_flags; } + const Flags & getGlobalFlags() const { return all_flags_for_target[GLOBAL]; } + const Flags & getDatabaseFlags() const { return all_flags_for_target[DATABASE]; } + const Flags & getTableFlags() const { return all_flags_for_target[TABLE]; } + const Flags & getColumnFlags() const { return all_flags_for_target[COLUMN]; } + const Flags & getDictionaryFlags() const { return all_flags_for_target[DICTIONARY]; } private: - enum Level + enum Target { - UNKNOWN_LEVEL = -1, - GLOBAL_LEVEL = 0, - DATABASE_LEVEL = 1, - TABLE_LEVEL = 2, - VIEW_LEVEL = 2, - DICTIONARY_LEVEL = 2, - COLUMN_LEVEL = 3, + UNKNOWN_TARGET, + GLOBAL, + DATABASE, + TABLE, + VIEW = TABLE, + COLUMN, + DICTIONARY, }; + static constexpr size_t NUM_TARGETS = static_cast(DICTIONARY) + 1; + struct Node; using NodePtr = std::unique_ptr; using Nodes = std::vector; @@ -191,11 +205,11 @@ private: std::string_view keyword; std::vector aliases; Flags flags; - Level level = UNKNOWN_LEVEL; + Target target = UNKNOWN_TARGET; Nodes children; - Node(std::string_view keyword_, size_t flag_, Level level_) - : keyword(keyword_), level(level_) + Node(std::string_view keyword_, size_t flag_, Target target_) + : keyword(keyword_), target(target_) { flags.set(flag_); } @@ -229,216 +243,225 @@ private: } } - static void makeFlagsToKeywordTree(NodePtr & flags_to_keyword_tree_) + static NodePtr makeFlagsToKeywordTree() { size_t next_flag = 0; Nodes all; - auto show = std::make_unique("SHOW", next_flag++, COLUMN_LEVEL); - auto exists = std::make_unique("EXISTS", next_flag++, COLUMN_LEVEL); - ext::push_back(all, std::move(show), std::move(exists)); + auto show_databases = std::make_unique("SHOW DATABASES", next_flag++, DATABASE); + auto show_tables = std::make_unique("SHOW TABLES", next_flag++, TABLE); + auto show_columns = std::make_unique("SHOW COLUMNS", next_flag++, COLUMN); + auto show_dictionaries = std::make_unique("SHOW DICTIONARIES", next_flag++, DICTIONARY); + auto show = std::make_unique("SHOW", std::move(show_databases), std::move(show_tables), std::move(show_columns), std::move(show_dictionaries)); + ext::push_back(all, std::move(show)); - auto select = std::make_unique("SELECT", next_flag++, COLUMN_LEVEL); - auto insert = std::make_unique("INSERT", next_flag++, COLUMN_LEVEL); + auto select = std::make_unique("SELECT", next_flag++, COLUMN); + auto insert = std::make_unique("INSERT", next_flag++, COLUMN); ext::push_back(all, std::move(select), std::move(insert)); - auto update = std::make_unique("UPDATE", next_flag++, COLUMN_LEVEL); + auto update = std::make_unique("UPDATE", next_flag++, COLUMN); ext::push_back(update->aliases, "ALTER UPDATE"); - auto delet = std::make_unique("DELETE", next_flag++, TABLE_LEVEL); + auto delet = std::make_unique("DELETE", next_flag++, TABLE); ext::push_back(delet->aliases, "ALTER DELETE"); - auto add_column = std::make_unique("ADD COLUMN", next_flag++, COLUMN_LEVEL); + auto add_column = std::make_unique("ADD COLUMN", next_flag++, COLUMN); add_column->aliases.push_back("ALTER ADD COLUMN"); - auto modify_column = std::make_unique("MODIFY COLUMN", next_flag++, COLUMN_LEVEL); + auto modify_column = std::make_unique("MODIFY COLUMN", next_flag++, COLUMN); modify_column->aliases.push_back("ALTER MODIFY COLUMN"); - auto drop_column = std::make_unique("DROP COLUMN", next_flag++, COLUMN_LEVEL); + auto drop_column = std::make_unique("DROP COLUMN", next_flag++, COLUMN); drop_column->aliases.push_back("ALTER DROP COLUMN"); - auto comment_column = std::make_unique("COMMENT COLUMN", next_flag++, COLUMN_LEVEL); + auto comment_column = std::make_unique("COMMENT COLUMN", next_flag++, COLUMN); comment_column->aliases.push_back("ALTER COMMENT COLUMN"); - auto clear_column = std::make_unique("CLEAR COLUMN", next_flag++, COLUMN_LEVEL); + auto clear_column = std::make_unique("CLEAR COLUMN", next_flag++, COLUMN); clear_column->aliases.push_back("ALTER CLEAR COLUMN"); auto alter_column = std::make_unique("ALTER COLUMN", std::move(add_column), std::move(modify_column), std::move(drop_column), std::move(comment_column), std::move(clear_column)); - auto alter_order_by = std::make_unique("ALTER ORDER BY", next_flag++, TABLE_LEVEL); + auto alter_order_by = std::make_unique("ALTER ORDER BY", next_flag++, TABLE); alter_order_by->aliases.push_back("MODIFY ORDER BY"); alter_order_by->aliases.push_back("ALTER MODIFY ORDER BY"); - auto add_index = std::make_unique("ADD INDEX", next_flag++, TABLE_LEVEL); + auto add_index = std::make_unique("ADD INDEX", next_flag++, TABLE); add_index->aliases.push_back("ALTER ADD INDEX"); - auto drop_index = std::make_unique("DROP INDEX", next_flag++, TABLE_LEVEL); + auto drop_index = std::make_unique("DROP INDEX", next_flag++, TABLE); drop_index->aliases.push_back("ALTER DROP INDEX"); - auto materialize_index = std::make_unique("MATERIALIZE INDEX", next_flag++, TABLE_LEVEL); + auto materialize_index = std::make_unique("MATERIALIZE INDEX", next_flag++, TABLE); materialize_index->aliases.push_back("ALTER MATERIALIZE INDEX"); - auto clear_index = std::make_unique("CLEAR INDEX", next_flag++, TABLE_LEVEL); + auto clear_index = std::make_unique("CLEAR INDEX", next_flag++, TABLE); clear_index->aliases.push_back("ALTER CLEAR INDEX"); auto index = std::make_unique("INDEX", std::move(alter_order_by), std::move(add_index), std::move(drop_index), std::move(materialize_index), std::move(clear_index)); index->aliases.push_back("ALTER INDEX"); - auto add_constraint = std::make_unique("ADD CONSTRAINT", next_flag++, TABLE_LEVEL); + auto add_constraint = std::make_unique("ADD CONSTRAINT", next_flag++, TABLE); add_constraint->aliases.push_back("ALTER ADD CONSTRAINT"); - auto drop_constraint = std::make_unique("DROP CONSTRAINT", next_flag++, TABLE_LEVEL); + auto drop_constraint = std::make_unique("DROP CONSTRAINT", next_flag++, TABLE); drop_constraint->aliases.push_back("ALTER DROP CONSTRAINT"); auto alter_constraint = std::make_unique("CONSTRAINT", std::move(add_constraint), std::move(drop_constraint)); alter_constraint->aliases.push_back("ALTER CONSTRAINT"); - auto modify_ttl = std::make_unique("MODIFY TTL", next_flag++, TABLE_LEVEL); + auto modify_ttl = std::make_unique("MODIFY TTL", next_flag++, TABLE); modify_ttl->aliases.push_back("ALTER MODIFY TTL"); - auto materialize_ttl = std::make_unique("MATERIALIZE TTL", next_flag++, TABLE_LEVEL); + auto materialize_ttl = std::make_unique("MATERIALIZE TTL", next_flag++, TABLE); materialize_ttl->aliases.push_back("ALTER MATERIALIZE TTL"); - auto modify_setting = std::make_unique("MODIFY SETTING", next_flag++, TABLE_LEVEL); + auto modify_setting = std::make_unique("MODIFY SETTING", next_flag++, TABLE); modify_setting->aliases.push_back("ALTER MODIFY SETTING"); - auto move_partition = std::make_unique("MOVE PARTITION", next_flag++, TABLE_LEVEL); + auto move_partition = std::make_unique("MOVE PARTITION", next_flag++, TABLE); ext::push_back(move_partition->aliases, "ALTER MOVE PARTITION", "MOVE PART", "ALTER MOVE PART"); - auto fetch_partition = std::make_unique("FETCH PARTITION", next_flag++, TABLE_LEVEL); + auto fetch_partition = std::make_unique("FETCH PARTITION", next_flag++, TABLE); ext::push_back(fetch_partition->aliases, "ALTER FETCH PARTITION"); - auto freeze_partition = std::make_unique("FREEZE PARTITION", next_flag++, TABLE_LEVEL); + auto freeze_partition = std::make_unique("FREEZE PARTITION", next_flag++, TABLE); ext::push_back(freeze_partition->aliases, "ALTER FREEZE PARTITION"); auto alter_table = std::make_unique("ALTER TABLE", std::move(update), std::move(delet), std::move(alter_column), std::move(index), std::move(alter_constraint), std::move(modify_ttl), std::move(materialize_ttl), std::move(modify_setting), std::move(move_partition), std::move(fetch_partition), std::move(freeze_partition)); - auto refresh_view = std::make_unique("REFRESH VIEW", next_flag++, VIEW_LEVEL); + auto refresh_view = std::make_unique("REFRESH VIEW", next_flag++, VIEW); ext::push_back(refresh_view->aliases, "ALTER LIVE VIEW REFRESH"); - auto modify_view_query = std::make_unique("MODIFY VIEW QUERY", next_flag++, VIEW_LEVEL); + auto modify_view_query = std::make_unique("MODIFY VIEW QUERY", next_flag++, VIEW); auto alter_view = std::make_unique("ALTER VIEW", std::move(refresh_view), std::move(modify_view_query)); auto alter = std::make_unique("ALTER", std::move(alter_table), std::move(alter_view)); ext::push_back(all, std::move(alter)); - auto create_database = std::make_unique("CREATE DATABASE", next_flag++, DATABASE_LEVEL); - auto create_table = std::make_unique("CREATE TABLE", next_flag++, TABLE_LEVEL); - auto create_view = std::make_unique("CREATE VIEW", next_flag++, VIEW_LEVEL); - auto create_dictionary = std::make_unique("CREATE DICTIONARY", next_flag++, DICTIONARY_LEVEL); + auto create_database = std::make_unique("CREATE DATABASE", next_flag++, DATABASE); + auto create_table = std::make_unique("CREATE TABLE", next_flag++, TABLE); + auto create_view = std::make_unique("CREATE VIEW", next_flag++, VIEW); + auto create_dictionary = std::make_unique("CREATE DICTIONARY", next_flag++, DICTIONARY); auto create = std::make_unique("CREATE", std::move(create_database), std::move(create_table), std::move(create_view), std::move(create_dictionary)); ext::push_back(all, std::move(create)); - auto create_temporary_table = std::make_unique("CREATE TEMPORARY TABLE", next_flag++, GLOBAL_LEVEL); + auto create_temporary_table = std::make_unique("CREATE TEMPORARY TABLE", next_flag++, GLOBAL); ext::push_back(all, std::move(create_temporary_table)); - auto drop_database = std::make_unique("DROP DATABASE", next_flag++, DATABASE_LEVEL); - auto drop_table = std::make_unique("DROP TABLE", next_flag++, TABLE_LEVEL); - auto drop_view = std::make_unique("DROP VIEW", next_flag++, VIEW_LEVEL); - auto drop_dictionary = std::make_unique("DROP DICTIONARY", next_flag++, DICTIONARY_LEVEL); + auto drop_database = std::make_unique("DROP DATABASE", next_flag++, DATABASE); + auto drop_table = std::make_unique("DROP TABLE", next_flag++, TABLE); + auto drop_view = std::make_unique("DROP VIEW", next_flag++, VIEW); + auto drop_dictionary = std::make_unique("DROP DICTIONARY", next_flag++, DICTIONARY); auto drop = std::make_unique("DROP", std::move(drop_database), std::move(drop_table), std::move(drop_view), std::move(drop_dictionary)); ext::push_back(all, std::move(drop)); - auto truncate_table = std::make_unique("TRUNCATE TABLE", next_flag++, TABLE_LEVEL); - auto truncate_view = std::make_unique("TRUNCATE VIEW", next_flag++, VIEW_LEVEL); + auto truncate_table = std::make_unique("TRUNCATE TABLE", next_flag++, TABLE); + auto truncate_view = std::make_unique("TRUNCATE VIEW", next_flag++, VIEW); auto truncate = std::make_unique("TRUNCATE", std::move(truncate_table), std::move(truncate_view)); ext::push_back(all, std::move(truncate)); - auto optimize = std::make_unique("OPTIMIZE", next_flag++, TABLE_LEVEL); + auto optimize = std::make_unique("OPTIMIZE", next_flag++, TABLE); optimize->aliases.push_back("OPTIMIZE TABLE"); ext::push_back(all, std::move(optimize)); - auto kill_query = std::make_unique("KILL QUERY", next_flag++, GLOBAL_LEVEL); - auto kill_mutation = std::make_unique("KILL MUTATION", next_flag++, TABLE_LEVEL); - auto kill = std::make_unique("KILL", std::move(kill_query), std::move(kill_mutation)); - ext::push_back(all, std::move(kill)); + auto kill_query = std::make_unique("KILL QUERY", next_flag++, GLOBAL); + ext::push_back(all, std::move(kill_query)); - auto create_user = std::make_unique("CREATE USER", next_flag++, GLOBAL_LEVEL); - auto alter_user = std::make_unique("ALTER USER", next_flag++, GLOBAL_LEVEL); - auto drop_user = std::make_unique("DROP USER", next_flag++, GLOBAL_LEVEL); - auto create_role = std::make_unique("CREATE ROLE", next_flag++, GLOBAL_LEVEL); - auto drop_role = std::make_unique("DROP ROLE", next_flag++, GLOBAL_LEVEL); - auto create_policy = std::make_unique("CREATE POLICY", next_flag++, GLOBAL_LEVEL); - auto alter_policy = std::make_unique("ALTER POLICY", next_flag++, GLOBAL_LEVEL); - auto drop_policy = std::make_unique("DROP POLICY", next_flag++, GLOBAL_LEVEL); - auto create_quota = std::make_unique("CREATE QUOTA", next_flag++, GLOBAL_LEVEL); - auto alter_quota = std::make_unique("ALTER QUOTA", next_flag++, GLOBAL_LEVEL); - auto drop_quota = std::make_unique("DROP QUOTA", next_flag++, GLOBAL_LEVEL); - auto role_admin = std::make_unique("ROLE ADMIN", next_flag++, GLOBAL_LEVEL); - ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota), std::move(role_admin)); + auto create_user = std::make_unique("CREATE USER", next_flag++, GLOBAL); + auto alter_user = std::make_unique("ALTER USER", next_flag++, GLOBAL); + auto drop_user = std::make_unique("DROP USER", next_flag++, GLOBAL); + auto create_role = std::make_unique("CREATE ROLE", next_flag++, GLOBAL); + auto alter_role = std::make_unique("ALTER ROLE", next_flag++, GLOBAL); + auto drop_role = std::make_unique("DROP ROLE", next_flag++, GLOBAL); + auto create_policy = std::make_unique("CREATE POLICY", next_flag++, GLOBAL); + auto alter_policy = std::make_unique("ALTER POLICY", next_flag++, GLOBAL); + auto drop_policy = std::make_unique("DROP POLICY", next_flag++, GLOBAL); + auto create_quota = std::make_unique("CREATE QUOTA", next_flag++, GLOBAL); + auto alter_quota = std::make_unique("ALTER QUOTA", next_flag++, GLOBAL); + auto drop_quota = std::make_unique("DROP QUOTA", next_flag++, GLOBAL); + auto create_profile = std::make_unique("CREATE SETTINGS PROFILE", next_flag++, GLOBAL); + ext::push_back(create_profile->aliases, "CREATE PROFILE"); + auto alter_profile = std::make_unique("ALTER SETTINGS PROFILE", next_flag++, GLOBAL); + ext::push_back(alter_profile->aliases, "ALTER PROFILE"); + auto drop_profile = std::make_unique("DROP SETTINGS PROFILE", next_flag++, GLOBAL); + ext::push_back(drop_profile->aliases, "DROP PROFILE"); + auto role_admin = std::make_unique("ROLE ADMIN", next_flag++, GLOBAL); + ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(alter_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota), std::move(create_profile), std::move(alter_profile), std::move(drop_profile), std::move(role_admin)); - auto shutdown = std::make_unique("SHUTDOWN", next_flag++, GLOBAL_LEVEL); + auto shutdown = std::make_unique("SHUTDOWN", next_flag++, GLOBAL); ext::push_back(shutdown->aliases, "SYSTEM SHUTDOWN", "SYSTEM KILL"); - auto drop_cache = std::make_unique("DROP CACHE", next_flag++, GLOBAL_LEVEL); + auto drop_cache = std::make_unique("DROP CACHE", next_flag++, GLOBAL); ext::push_back(drop_cache->aliases, "SYSTEM DROP CACHE", "DROP DNS CACHE", "SYSTEM DROP DNS CACHE", "DROP MARK CACHE", "SYSTEM DROP MARK CACHE", "DROP UNCOMPRESSED CACHE", "SYSTEM DROP UNCOMPRESSED CACHE", "DROP COMPILED EXPRESSION CACHE", "SYSTEM DROP COMPILED EXPRESSION CACHE"); - auto reload_config = std::make_unique("RELOAD CONFIG", next_flag++, GLOBAL_LEVEL); + auto reload_config = std::make_unique("RELOAD CONFIG", next_flag++, GLOBAL); ext::push_back(reload_config->aliases, "SYSTEM RELOAD CONFIG"); - auto reload_dictionary = std::make_unique("RELOAD DICTIONARY", next_flag++, GLOBAL_LEVEL); + auto reload_dictionary = std::make_unique("RELOAD DICTIONARY", next_flag++, GLOBAL); ext::push_back(reload_dictionary->aliases, "SYSTEM RELOAD DICTIONARY", "RELOAD DICTIONARIES", "SYSTEM RELOAD DICTIONARIES", "RELOAD EMBEDDED DICTIONARIES", "SYSTEM RELOAD EMBEDDED DICTIONARIES"); - auto stop_merges = std::make_unique("STOP MERGES", next_flag++, TABLE_LEVEL); + auto stop_merges = std::make_unique("STOP MERGES", next_flag++, TABLE); ext::push_back(stop_merges->aliases, "SYSTEM STOP MERGES", "START MERGES", "SYSTEM START MERGES"); - auto stop_ttl_merges = std::make_unique("STOP TTL MERGES", next_flag++, TABLE_LEVEL); + auto stop_ttl_merges = std::make_unique("STOP TTL MERGES", next_flag++, TABLE); ext::push_back(stop_ttl_merges->aliases, "SYSTEM STOP TTL MERGES", "START TTL MERGES", "SYSTEM START TTL MERGES"); - auto stop_fetches = std::make_unique("STOP FETCHES", next_flag++, TABLE_LEVEL); + auto stop_fetches = std::make_unique("STOP FETCHES", next_flag++, TABLE); ext::push_back(stop_fetches->aliases, "SYSTEM STOP FETCHES", "START FETCHES", "SYSTEM START FETCHES"); - auto stop_moves = std::make_unique("STOP MOVES", next_flag++, TABLE_LEVEL); + auto stop_moves = std::make_unique("STOP MOVES", next_flag++, TABLE); ext::push_back(stop_moves->aliases, "SYSTEM STOP MOVES", "START MOVES", "SYSTEM START MOVES"); - auto stop_distributed_sends = std::make_unique("STOP DISTRIBUTED SENDS", next_flag++, TABLE_LEVEL); + auto stop_distributed_sends = std::make_unique("STOP DISTRIBUTED SENDS", next_flag++, TABLE); ext::push_back(stop_distributed_sends->aliases, "SYSTEM STOP DISTRIBUTED SENDS", "START DISTRIBUTED SENDS", "SYSTEM START DISTRIBUTED SENDS"); - auto stop_replicated_sends = std::make_unique("STOP REPLICATED SENDS", next_flag++, TABLE_LEVEL); + auto stop_replicated_sends = std::make_unique("STOP REPLICATED SENDS", next_flag++, TABLE); ext::push_back(stop_replicated_sends->aliases, "SYSTEM STOP REPLICATED SENDS", "START REPLICATED SENDS", "SYSTEM START REPLICATED SENDS"); - auto stop_replication_queues = std::make_unique("STOP REPLICATION QUEUES", next_flag++, TABLE_LEVEL); + auto stop_replication_queues = std::make_unique("STOP REPLICATION QUEUES", next_flag++, TABLE); ext::push_back(stop_replication_queues->aliases, "SYSTEM STOP REPLICATION QUEUES", "START REPLICATION QUEUES", "SYSTEM START REPLICATION QUEUES"); - auto sync_replica = std::make_unique("SYNC REPLICA", next_flag++, TABLE_LEVEL); + auto sync_replica = std::make_unique("SYNC REPLICA", next_flag++, TABLE); ext::push_back(sync_replica->aliases, "SYSTEM SYNC REPLICA"); - auto restart_replica = std::make_unique("RESTART REPLICA", next_flag++, TABLE_LEVEL); + auto restart_replica = std::make_unique("RESTART REPLICA", next_flag++, TABLE); ext::push_back(restart_replica->aliases, "SYSTEM RESTART REPLICA"); - auto flush_distributed = std::make_unique("FLUSH DISTRIBUTED", next_flag++, TABLE_LEVEL); + auto flush_distributed = std::make_unique("FLUSH DISTRIBUTED", next_flag++, TABLE); ext::push_back(flush_distributed->aliases, "SYSTEM FLUSH DISTRIBUTED"); - auto flush_logs = std::make_unique("FLUSH LOGS", next_flag++, GLOBAL_LEVEL); + auto flush_logs = std::make_unique("FLUSH LOGS", next_flag++, GLOBAL); ext::push_back(flush_logs->aliases, "SYSTEM FLUSH LOGS"); auto system = std::make_unique("SYSTEM", std::move(shutdown), std::move(drop_cache), std::move(reload_config), std::move(reload_dictionary), std::move(stop_merges), std::move(stop_ttl_merges), std::move(stop_fetches), std::move(stop_moves), std::move(stop_distributed_sends), std::move(stop_replicated_sends), std::move(stop_replication_queues), std::move(sync_replica), std::move(restart_replica), std::move(flush_distributed), std::move(flush_logs)); ext::push_back(all, std::move(system)); - auto dict_get = std::make_unique("dictGet()", next_flag++, DICTIONARY_LEVEL); + auto dict_get = std::make_unique("dictGet()", next_flag++, DICTIONARY); dict_get->aliases.push_back("dictHas()"); dict_get->aliases.push_back("dictGetHierarchy()"); dict_get->aliases.push_back("dictIsIn()"); ext::push_back(all, std::move(dict_get)); - auto address_to_line = std::make_unique("addressToLine()", next_flag++, GLOBAL_LEVEL); - auto address_to_symbol = std::make_unique("addressToSymbol()", next_flag++, GLOBAL_LEVEL); - auto demangle = std::make_unique("demangle()", next_flag++, GLOBAL_LEVEL); + auto address_to_line = std::make_unique("addressToLine()", next_flag++, GLOBAL); + auto address_to_symbol = std::make_unique("addressToSymbol()", next_flag++, GLOBAL); + auto demangle = std::make_unique("demangle()", next_flag++, GLOBAL); auto introspection = std::make_unique("INTROSPECTION", std::move(address_to_line), std::move(address_to_symbol), std::move(demangle)); ext::push_back(introspection->aliases, "INTROSPECTION FUNCTIONS"); ext::push_back(all, std::move(introspection)); - auto file = std::make_unique("file()", next_flag++, GLOBAL_LEVEL); - auto url = std::make_unique("url()", next_flag++, GLOBAL_LEVEL); - auto input = std::make_unique("input()", next_flag++, GLOBAL_LEVEL); - auto values = std::make_unique("values()", next_flag++, GLOBAL_LEVEL); - auto numbers = std::make_unique("numbers()", next_flag++, GLOBAL_LEVEL); - auto zeros = std::make_unique("zeros()", next_flag++, GLOBAL_LEVEL); - auto merge = std::make_unique("merge()", next_flag++, DATABASE_LEVEL); - auto remote = std::make_unique("remote()", next_flag++, GLOBAL_LEVEL); + auto file = std::make_unique("file()", next_flag++, GLOBAL); + auto url = std::make_unique("url()", next_flag++, GLOBAL); + auto input = std::make_unique("input()", next_flag++, GLOBAL); + auto values = std::make_unique("values()", next_flag++, GLOBAL); + auto numbers = std::make_unique("numbers()", next_flag++, GLOBAL); + auto zeros = std::make_unique("zeros()", next_flag++, GLOBAL); + auto merge = std::make_unique("merge()", next_flag++, DATABASE); + auto remote = std::make_unique("remote()", next_flag++, GLOBAL); ext::push_back(remote->aliases, "remoteSecure()", "cluster()"); - auto mysql = std::make_unique("mysql()", next_flag++, GLOBAL_LEVEL); - auto odbc = std::make_unique("odbc()", next_flag++, GLOBAL_LEVEL); - auto jdbc = std::make_unique("jdbc()", next_flag++, GLOBAL_LEVEL); - auto hdfs = std::make_unique("hdfs()", next_flag++, GLOBAL_LEVEL); - auto s3 = std::make_unique("s3()", next_flag++, GLOBAL_LEVEL); + auto mysql = std::make_unique("mysql()", next_flag++, GLOBAL); + auto odbc = std::make_unique("odbc()", next_flag++, GLOBAL); + auto jdbc = std::make_unique("jdbc()", next_flag++, GLOBAL); + auto hdfs = std::make_unique("hdfs()", next_flag++, GLOBAL); + auto s3 = std::make_unique("s3()", next_flag++, GLOBAL); auto table_functions = std::make_unique("TABLE FUNCTIONS", std::move(file), std::move(url), std::move(input), std::move(values), std::move(numbers), std::move(zeros), std::move(merge), std::move(remote), std::move(mysql), std::move(odbc), std::move(jdbc), std::move(hdfs), std::move(s3)); ext::push_back(all, std::move(table_functions)); - flags_to_keyword_tree_ = std::make_unique("ALL", std::move(all)); - flags_to_keyword_tree_->aliases.push_back("ALL PRIVILEGES"); + auto node_all = std::make_unique("ALL", std::move(all)); + node_all->aliases.push_back("ALL PRIVILEGES"); + return node_all; } - void makeKeywordToFlagsMap(std::unordered_map & keyword_to_flags_map_, Node * start_node = nullptr) + void makeKeywordToFlagsMap(Node * start_node = nullptr) { if (!start_node) { start_node = flags_to_keyword_tree.get(); - keyword_to_flags_map_["USAGE"] = {}; - keyword_to_flags_map_["NONE"] = {}; - keyword_to_flags_map_["NO PRIVILEGES"] = {}; + keyword_to_flags_map["USAGE"] = {}; + keyword_to_flags_map["NONE"] = {}; + keyword_to_flags_map["NO PRIVILEGES"] = {}; } start_node->aliases.emplace_back(start_node->keyword); for (auto & alias : start_node->aliases) { boost::to_upper(alias); - keyword_to_flags_map_[alias] = start_node->flags; + keyword_to_flags_map[alias] = start_node->flags; } for (auto & child : start_node->children) - makeKeywordToFlagsMap(keyword_to_flags_map_, child.get()); + makeKeywordToFlagsMap(child.get()); } - void makeAccessTypeToFlagsMapping(std::vector & access_type_to_flags_mapping_) + void makeAccessTypeToFlagsMapping() { - access_type_to_flags_mapping_.resize(MAX_ACCESS_TYPE); + access_type_to_flags_mapping.resize(MAX_ACCESS_TYPE); for (auto access_type : ext::range_with_static_cast(0, MAX_ACCESS_TYPE)) { auto str = toKeyword(access_type); @@ -449,35 +472,36 @@ private: boost::to_upper(uppercased); it = keyword_to_flags_map.find(uppercased); } - access_type_to_flags_mapping_[static_cast(access_type)] = it->second; + access_type_to_flags_mapping[static_cast(access_type)] = it->second; } } - void collectAllGrantableOnLevel(std::vector & all_grantable_on_level_, const Node * start_node = nullptr) + void collectAllFlags(const Node * start_node = nullptr) { if (!start_node) { start_node = flags_to_keyword_tree.get(); - all_grantable_on_level.resize(COLUMN_LEVEL + 1); + all_flags = start_node->flags; } - for (int i = 0; i <= start_node->level; ++i) - all_grantable_on_level_[i] |= start_node->flags; + if (start_node->target != UNKNOWN_TARGET) + all_flags_for_target[start_node->target] |= start_node->flags; for (const auto & child : start_node->children) - collectAllGrantableOnLevel(all_grantable_on_level_, child.get()); + collectAllFlags(child.get()); } Impl() { - makeFlagsToKeywordTree(flags_to_keyword_tree); - makeKeywordToFlagsMap(keyword_to_flags_map); - makeAccessTypeToFlagsMapping(access_type_to_flags_mapping); - collectAllGrantableOnLevel(all_grantable_on_level); + flags_to_keyword_tree = makeFlagsToKeywordTree(); + makeKeywordToFlagsMap(); + makeAccessTypeToFlagsMapping(); + collectAllFlags(); } std::unique_ptr flags_to_keyword_tree; std::unordered_map keyword_to_flags_map; std::vector access_type_to_flags_mapping; - std::vector all_grantable_on_level; + Flags all_flags; + Flags all_flags_for_target[NUM_TARGETS]; }; @@ -487,9 +511,12 @@ inline AccessFlags::AccessFlags(const std::vector & keywords) inline AccessFlags::AccessFlags(const Strings & keywords) : flags(Impl<>::instance().keywordsToFlags(keywords)) {} inline String AccessFlags::toString() const { return Impl<>::instance().flagsToString(flags); } inline std::vector AccessFlags::toKeywords() const { return Impl<>::instance().flagsToKeywords(flags); } -inline AccessFlags AccessFlags::databaseLevel() { return Impl<>::instance().getDatabaseLevelFlags(); } -inline AccessFlags AccessFlags::tableLevel() { return Impl<>::instance().getTableLevelFlags(); } -inline AccessFlags AccessFlags::columnLevel() { return Impl<>::instance().getColumnLevelFlags(); } +inline AccessFlags AccessFlags::allFlags() { return Impl<>::instance().getAllFlags(); } +inline AccessFlags AccessFlags::allGlobalFlags() { return Impl<>::instance().getGlobalFlags(); } +inline AccessFlags AccessFlags::allDatabaseFlags() { return Impl<>::instance().getDatabaseFlags(); } +inline AccessFlags AccessFlags::allTableFlags() { return Impl<>::instance().getTableFlags(); } +inline AccessFlags AccessFlags::allColumnFlags() { return Impl<>::instance().getColumnFlags(); } +inline AccessFlags AccessFlags::allDictionaryFlags() { return Impl<>::instance().getDictionaryFlags(); } inline AccessFlags operator |(AccessType left, AccessType right) { return AccessFlags(left) | right; } inline AccessFlags operator &(AccessType left, AccessType right) { return AccessFlags(left) & right; } diff --git a/dbms/src/Access/AccessRights.cpp b/dbms/src/Access/AccessRights.cpp index 80de185ed8f..6f94cfac286 100644 --- a/dbms/src/Access/AccessRights.cpp +++ b/dbms/src/Access/AccessRights.cpp @@ -23,13 +23,6 @@ namespace COLUMN_LEVEL, }; - enum RevokeMode - { - NORMAL_REVOKE_MODE, /// for AccessRights::revoke() - PARTIAL_REVOKE_MODE, /// for AccessRights::partialRevoke() - FULL_REVOKE_MODE, /// for AccessRights::fullRevoke() - }; - struct Helper { static const Helper & instance() @@ -38,13 +31,28 @@ namespace return res; } - const AccessFlags database_level_flags = AccessFlags::databaseLevel(); - const AccessFlags table_level_flags = AccessFlags::tableLevel(); - const AccessFlags column_level_flags = AccessFlags::columnLevel(); - const AccessFlags show_flag = AccessType::SHOW; - const AccessFlags exists_flag = AccessType::EXISTS; + const AccessFlags all_flags = AccessFlags::allFlags(); + const AccessFlags database_flags = AccessFlags::allDatabaseFlags(); + const AccessFlags table_flags = AccessFlags::allTableFlags(); + const AccessFlags column_flags = AccessFlags::allColumnFlags(); + const AccessFlags dictionary_flags = AccessFlags::allDictionaryFlags(); + const AccessFlags column_level_flags = column_flags; + const AccessFlags table_level_flags = table_flags | dictionary_flags | column_level_flags; + const AccessFlags database_level_flags = database_flags | table_level_flags; + + const AccessFlags show_databases_flag = AccessType::SHOW_DATABASES; + const AccessFlags show_tables_flag = AccessType::SHOW_TABLES; + const AccessFlags show_columns_flag = AccessType::SHOW_COLUMNS; + const AccessFlags show_dictionaries_flag = AccessType::SHOW_DICTIONARIES; const AccessFlags create_table_flag = AccessType::CREATE_TABLE; + const AccessFlags create_view_flag = AccessType::CREATE_VIEW; const AccessFlags create_temporary_table_flag = AccessType::CREATE_TEMPORARY_TABLE; + const AccessFlags alter_table_flag = AccessType::ALTER_TABLE; + const AccessFlags alter_view_flag = AccessType::ALTER_VIEW; + const AccessFlags truncate_table_flag = AccessType::TRUNCATE_TABLE; + const AccessFlags truncate_view_flag = AccessType::TRUNCATE_VIEW; + const AccessFlags drop_table_flag = AccessType::DROP_TABLE; + const AccessFlags drop_view_flag = AccessType::DROP_VIEW; }; std::string_view checkCurrentDatabase(const std::string_view & current_database) @@ -61,13 +69,10 @@ struct AccessRights::Node public: std::shared_ptr node_name; Level level = GLOBAL_LEVEL; - AccessFlags explicit_grants; - AccessFlags partial_revokes; - AccessFlags inherited_access; /// the access inherited from the parent node - AccessFlags raw_access; /// raw_access = (inherited_access - partial_revokes) | explicit_grants - AccessFlags access; /// access = raw_access | implicit_access - AccessFlags min_access; /// min_access = access & child[0].access & ... | child[N-1].access - AccessFlags max_access; /// max_access = access | child[0].access | ... | child[N-1].access + AccessFlags access; /// access = (inherited_access - partial_revokes) | explicit_grants + AccessFlags final_access; /// final_access = access | implicit_access + AccessFlags min_access; /// min_access = final_access & child[0].final_access & ... & child[N-1].final_access + AccessFlags max_access; /// max_access = final_access | child[0].final_access | ... | child[N-1].final_access std::unique_ptr> children; Node() = default; @@ -80,11 +85,8 @@ public: node_name = src.node_name; level = src.level; - inherited_access = src.inherited_access; - explicit_grants = src.explicit_grants; - partial_revokes = src.partial_revokes; - raw_access = src.raw_access; access = src.access; + final_access = src.final_access; min_access = src.min_access; max_access = src.max_access; if (src.children) @@ -94,9 +96,9 @@ public: return *this; } - void grant(AccessFlags access_to_grant, const Helper & helper) + void grant(AccessFlags flags, const Helper & helper) { - if (!access_to_grant) + if (!flags) return; if (level == GLOBAL_LEVEL) @@ -105,126 +107,77 @@ public: } else if (level == DATABASE_LEVEL) { - AccessFlags grantable = access_to_grant & helper.database_level_flags; + AccessFlags grantable = flags & helper.database_level_flags; if (!grantable) - throw Exception(access_to_grant.toString() + " cannot be granted on the database level", ErrorCodes::INVALID_GRANT); - access_to_grant = grantable; + throw Exception(flags.toString() + " cannot be granted on the database level", ErrorCodes::INVALID_GRANT); + flags = grantable; } else if (level == TABLE_LEVEL) { - AccessFlags grantable = access_to_grant & helper.table_level_flags; + AccessFlags grantable = flags & helper.table_level_flags; if (!grantable) - throw Exception(access_to_grant.toString() + " cannot be granted on the table level", ErrorCodes::INVALID_GRANT); - access_to_grant = grantable; + throw Exception(flags.toString() + " cannot be granted on the table level", ErrorCodes::INVALID_GRANT); + flags = grantable; } else if (level == COLUMN_LEVEL) { - AccessFlags grantable = access_to_grant & helper.column_level_flags; + AccessFlags grantable = flags & helper.column_level_flags; if (!grantable) - throw Exception(access_to_grant.toString() + " cannot be granted on the column level", ErrorCodes::INVALID_GRANT); - access_to_grant = grantable; + throw Exception(flags.toString() + " cannot be granted on the column level", ErrorCodes::INVALID_GRANT); + flags = grantable; } - AccessFlags new_explicit_grants = access_to_grant - partial_revokes; - if (level == TABLE_LEVEL) - removeExplicitGrantsRec(new_explicit_grants); - removePartialRevokesRec(access_to_grant); - explicit_grants |= new_explicit_grants; - - calculateAllAccessRec(helper); + addGrantsRec(flags); + calculateFinalAccessRec(helper); } template - void grant(const AccessFlags & access_to_grant, const Helper & helper, const std::string_view & name, const Args &... subnames) + void grant(const AccessFlags & flags, const Helper & helper, const std::string_view & name, const Args &... subnames) { auto & child = getChild(name); - child.grant(access_to_grant, helper, subnames...); - eraseChildIfEmpty(child); - calculateImplicitAccess(helper); - calculateMinAndMaxAccess(); + child.grant(flags, helper, subnames...); + eraseChildIfPossible(child); + calculateFinalAccess(helper); } template - void grant(const AccessFlags & access_to_grant, const Helper & helper, const std::vector & names) + void grant(const AccessFlags & flags, const Helper & helper, const std::vector & names) { for (const auto & name : names) { auto & child = getChild(name); - child.grant(access_to_grant, helper); - eraseChildIfEmpty(child); + child.grant(flags, helper); + eraseChildIfPossible(child); } - calculateImplicitAccess(helper); - calculateMinAndMaxAccess(); + calculateFinalAccess(helper); } - template - void revoke(const AccessFlags & access_to_revoke, const Helper & helper) + void revoke(const AccessFlags & flags, const Helper & helper) { - if constexpr (mode == NORMAL_REVOKE_MODE) - { // NOLINT - if (level == TABLE_LEVEL) - removeExplicitGrantsRec(access_to_revoke); - else - removeExplicitGrants(access_to_revoke); - } - else if constexpr (mode == PARTIAL_REVOKE_MODE) - { - if (level == TABLE_LEVEL) - removeExplicitGrantsRec(access_to_revoke); - else - removeExplicitGrants(access_to_revoke); - - AccessFlags new_partial_revokes = access_to_revoke - explicit_grants; - removePartialRevokesRec(new_partial_revokes); - partial_revokes |= new_partial_revokes; - } - else /// mode == FULL_REVOKE_MODE - { - AccessFlags new_partial_revokes = access_to_revoke - explicit_grants; - removeExplicitGrantsRec(access_to_revoke); - removePartialRevokesRec(new_partial_revokes); - partial_revokes |= new_partial_revokes; - } - calculateAllAccessRec(helper); + removeGrantsRec(flags); + calculateFinalAccessRec(helper); } - template - void revoke(const AccessFlags & access_to_revoke, const Helper & helper, const std::string_view & name, const Args &... subnames) + template + void revoke(const AccessFlags & flags, const Helper & helper, const std::string_view & name, const Args &... subnames) { - Node * child; - if (mode == NORMAL_REVOKE_MODE) - { - if (!(child = tryGetChild(name))) - return; - } - else - child = &getChild(name); + auto & child = getChild(name); - child->revoke(access_to_revoke, helper, subnames...); - eraseChildIfEmpty(*child); - calculateImplicitAccess(helper); - calculateMinAndMaxAccess(); + child.revoke(flags, helper, subnames...); + eraseChildIfPossible(child); + calculateFinalAccess(helper); } - template - void revoke(const AccessFlags & access_to_revoke, const Helper & helper, const std::vector & names) + template + void revoke(const AccessFlags & flags, const Helper & helper, const std::vector & names) { - Node * child; for (const auto & name : names) { - if (mode == NORMAL_REVOKE_MODE) - { - if (!(child = tryGetChild(name))) - continue; - } - else - child = &getChild(name); - - child->revoke(access_to_revoke, helper); - eraseChildIfEmpty(*child); + auto & child = getChild(name); + child.revoke(flags, helper); + eraseChildIfPossible(child); } - calculateImplicitAccess(helper); - calculateMinAndMaxAccess(); + calculateFinalAccess(helper); } bool isGranted(const AccessFlags & flags) const @@ -244,7 +197,7 @@ public: if (child) return child->isGranted(flags, subnames...); else - return access.contains(flags); + return final_access.contains(flags); } template @@ -265,7 +218,7 @@ public: } else { - if (!access.contains(flags)) + if (!final_access.contains(flags)) return false; } } @@ -274,7 +227,7 @@ public: friend bool operator ==(const Node & left, const Node & right) { - if ((left.explicit_grants != right.explicit_grants) || (left.partial_revokes != right.partial_revokes)) + if (left.access != right.access) return false; if (!left.children) @@ -287,33 +240,24 @@ public: friend bool operator!=(const Node & left, const Node & right) { return !(left == right); } - bool isEmpty() const - { - return !explicit_grants && !partial_revokes && !children; - } - void merge(const Node & other, const Helper & helper) { - mergeRawAccessRec(other); - calculateGrantsAndPartialRevokesRec(); - calculateAllAccessRec(helper); + mergeAccessRec(other); + calculateFinalAccessRec(helper); } - void traceTree(Poco::Logger * log) const + void logTree(Poco::Logger * log) const { LOG_TRACE(log, "Tree(" << level << "): name=" << (node_name ? *node_name : "NULL") - << ", explicit_grants=" << explicit_grants.toString() - << ", partial_revokes=" << partial_revokes.toString() - << ", inherited_access=" << inherited_access.toString() - << ", raw_access=" << raw_access.toString() << ", access=" << access.toString() + << ", final_access=" << final_access.toString() << ", min_access=" << min_access.toString() << ", max_access=" << max_access.toString() << ", num_children=" << (children ? children->size() : 0)); if (children) { for (auto & child : *children | boost::adaptors::map_values) - child.traceTree(log); + child.logTree(log); } } @@ -349,14 +293,13 @@ private: Node & new_child = (*children)[*new_child_name]; new_child.node_name = std::move(new_child_name); new_child.level = static_cast(level + 1); - new_child.inherited_access = raw_access; - new_child.raw_access = raw_access; + new_child.access = access; return new_child; } - void eraseChildIfEmpty(Node & child) + void eraseChildIfPossible(Node & child) { - if (!child.isEmpty()) + if (!canEraseChild(child)) return; auto it = children->find(*child.node_name); children->erase(it); @@ -364,46 +307,59 @@ private: children = nullptr; } - void calculateImplicitAccess(const Helper & helper) + bool canEraseChild(const Node & child) const { - access = raw_access; - if (access & helper.database_level_flags) - access |= helper.show_flag | helper.exists_flag; - else if ((level >= DATABASE_LEVEL) && children) - access |= helper.exists_flag; - - if ((level == GLOBAL_LEVEL) && (access & helper.create_table_flag)) - access |= helper.create_temporary_table_flag; + return (access == child.access) && !child.children; } - void calculateMinAndMaxAccess() + void addGrantsRec(const AccessFlags & flags) { - min_access = access; - max_access = access; + access |= flags; if (children) { - for (const auto & child : *children | boost::adaptors::map_values) + for (auto it = children->begin(); it != children->end();) { - min_access &= child.min_access; - max_access |= child.max_access; + auto & child = it->second; + child.addGrantsRec(flags); + if (canEraseChild(child)) + it = children->erase(it); + else + ++it; } + if (children->empty()) + children = nullptr; } } - void calculateAllAccessRec(const Helper & helper) + void removeGrantsRec(const AccessFlags & flags) { - partial_revokes &= inherited_access; - raw_access = (inherited_access - partial_revokes) | explicit_grants; + access &= ~flags; + if (children) + { + for (auto it = children->begin(); it != children->end();) + { + auto & child = it->second; + child.removeGrantsRec(flags); + if (canEraseChild(child)) + it = children->erase(it); + else + ++it; + } + if (children->empty()) + children = nullptr; + } + } + void calculateFinalAccessRec(const Helper & helper) + { /// Traverse tree. if (children) { for (auto it = children->begin(); it != children->end();) { auto & child = it->second; - child.inherited_access = raw_access; - child.calculateAllAccessRec(helper); - if (child.isEmpty()) + child.calculateFinalAccessRec(helper); + if (canEraseChild(child)) it = children->erase(it); else ++it; @@ -412,64 +368,95 @@ private: children = nullptr; } - calculateImplicitAccess(helper); - calculateMinAndMaxAccess(); + calculateFinalAccess(helper); } - void removeExplicitGrants(const AccessFlags & change) + void calculateFinalAccess(const Helper & helper) { - explicit_grants -= change; - } - - void removeExplicitGrantsRec(const AccessFlags & change) - { - removeExplicitGrants(change); + /// Calculate min and max access among children. + AccessFlags min_access_among_children = helper.all_flags; + AccessFlags max_access_among_children; if (children) { - for (auto & child : *children | boost::adaptors::map_values) - child.removeExplicitGrantsRec(change); + for (const auto & child : *children | boost::adaptors::map_values) + { + min_access_among_children &= child.min_access; + max_access_among_children |= child.max_access; + } } - } - void removePartialRevokesRec(const AccessFlags & change) - { - partial_revokes -= change; - if (children) + /// Calculate implicit access: + AccessFlags implicit_access; + + if (level <= DATABASE_LEVEL) { - for (auto & child : *children | boost::adaptors::map_values) - child.removePartialRevokesRec(change); + if (access & helper.database_flags) + implicit_access |= helper.show_databases_flag; } + if (level <= TABLE_LEVEL) + { + if (access & helper.table_flags) + implicit_access |= helper.show_tables_flag; + if (access & helper.dictionary_flags) + implicit_access |= helper.show_dictionaries_flag; + } + if (level <= COLUMN_LEVEL) + { + if (access & helper.column_flags) + implicit_access |= helper.show_columns_flag; + } + if (children && max_access_among_children) + { + if (level == DATABASE_LEVEL) + implicit_access |= helper.show_databases_flag; + else if (level == TABLE_LEVEL) + implicit_access |= helper.show_tables_flag; + } + + if ((level == GLOBAL_LEVEL) && ((access | max_access_among_children) & helper.create_table_flag)) + implicit_access |= helper.create_temporary_table_flag; + + if (level <= TABLE_LEVEL) + { + if (access & helper.create_table_flag) + implicit_access |= helper.create_view_flag; + + if (access & helper.drop_table_flag) + implicit_access |= helper.drop_view_flag; + + if (access & helper.alter_table_flag) + implicit_access |= helper.alter_view_flag; + + if (access & helper.truncate_table_flag) + implicit_access |= helper.truncate_view_flag; + } + + final_access = access | implicit_access; + + /// Calculate min and max access: + /// min_access = final_access & child[0].final_access & ... & child[N-1].final_access + /// max_access = final_access | child[0].final_access | ... | child[N-1].final_access + min_access = final_access & min_access_among_children; + max_access = final_access | max_access_among_children; } - void mergeRawAccessRec(const Node & rhs) + void mergeAccessRec(const Node & rhs) { if (rhs.children) { for (const auto & [rhs_childname, rhs_child] : *rhs.children) - getChild(rhs_childname).mergeRawAccessRec(rhs_child); + getChild(rhs_childname).mergeAccessRec(rhs_child); } - raw_access |= rhs.raw_access; + access |= rhs.access; if (children) { for (auto & [lhs_childname, lhs_child] : *children) { - lhs_child.inherited_access = raw_access; if (!rhs.tryGetChild(lhs_childname)) - lhs_child.raw_access |= rhs.raw_access; + lhs_child.access |= rhs.access; } } } - - void calculateGrantsAndPartialRevokesRec() - { - explicit_grants = raw_access - inherited_access; - partial_revokes = inherited_access - raw_access; - if (children) - { - for (auto & child : *children | boost::adaptors::map_values) - child.calculateGrantsAndPartialRevokesRec(); - } - } }; @@ -514,165 +501,150 @@ void AccessRights::clear() template -void AccessRights::grantImpl(const AccessFlags & access, const Args &... args) +void AccessRights::grantImpl(const AccessFlags & flags, const Args &... args) { if (!root) root = std::make_unique(); - root->grant(access, Helper::instance(), args...); - if (root->isEmpty()) + root->grant(flags, Helper::instance(), args...); + if (!root->access && !root->children) root = nullptr; } -void AccessRights::grantImpl(const AccessRightsElement & element, std::string_view current_database) +void AccessRights::grant(const AccessFlags & flags) { grantImpl(flags); } +void AccessRights::grant(const AccessFlags & flags, const std::string_view & database) { grantImpl(flags, database); } +void AccessRights::grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { grantImpl(flags, database, table); } +void AccessRights::grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { grantImpl(flags, database, table, column); } +void AccessRights::grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { grantImpl(flags, database, table, columns); } +void AccessRights::grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { grantImpl(flags, database, table, columns); } + +void AccessRights::grant(const AccessRightsElement & element, std::string_view current_database) { if (element.any_database) { - grantImpl(element.access_flags); + grant(element.access_flags); } else if (element.any_table) { if (element.database.empty()) - grantImpl(element.access_flags, checkCurrentDatabase(current_database)); + grant(element.access_flags, checkCurrentDatabase(current_database)); else - grantImpl(element.access_flags, element.database); + grant(element.access_flags, element.database); } else if (element.any_column) { if (element.database.empty()) - grantImpl(element.access_flags, checkCurrentDatabase(current_database), element.table); + grant(element.access_flags, checkCurrentDatabase(current_database), element.table); else - grantImpl(element.access_flags, element.database, element.table); + grant(element.access_flags, element.database, element.table); } else { if (element.database.empty()) - grantImpl(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); + grant(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); else - grantImpl(element.access_flags, element.database, element.table, element.columns); + grant(element.access_flags, element.database, element.table, element.columns); } } -void AccessRights::grantImpl(const AccessRightsElements & elements, std::string_view current_database) +void AccessRights::grant(const AccessRightsElements & elements, std::string_view current_database) { for (const auto & element : elements) - grantImpl(element, current_database); + grant(element, current_database); } -void AccessRights::grant(const AccessFlags & access) { grantImpl(access); } -void AccessRights::grant(const AccessFlags & access, const std::string_view & database) { grantImpl(access, database); } -void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { grantImpl(access, database, table); } -void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { grantImpl(access, database, table, column); } -void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { grantImpl(access, database, table, columns); } -void AccessRights::grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { grantImpl(access, database, table, columns); } -void AccessRights::grant(const AccessRightsElement & element, std::string_view current_database) { grantImpl(element, current_database); } -void AccessRights::grant(const AccessRightsElements & elements, std::string_view current_database) { grantImpl(elements, current_database); } -template -void AccessRights::revokeImpl(const AccessFlags & access, const Args &... args) +template +void AccessRights::revokeImpl(const AccessFlags & flags, const Args &... args) { if (!root) return; - root->revoke(access, Helper::instance(), args...); - if (root->isEmpty()) + root->revoke(flags, Helper::instance(), args...); + if (!root->access && !root->children) root = nullptr; } -template -void AccessRights::revokeImpl(const AccessRightsElement & element, std::string_view current_database) +void AccessRights::revoke(const AccessFlags & flags) { revokeImpl(flags); } +void AccessRights::revoke(const AccessFlags & flags, const std::string_view & database) { revokeImpl(flags, database); } +void AccessRights::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) { revokeImpl(flags, database, table); } +void AccessRights::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(flags, database, table, column); } +void AccessRights::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(flags, database, table, columns); } +void AccessRights::revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(flags, database, table, columns); } + + +void AccessRights::revoke(const AccessRightsElement & element, std::string_view current_database) { if (element.any_database) { - revokeImpl(element.access_flags); + revoke(element.access_flags); } else if (element.any_table) { if (element.database.empty()) - revokeImpl(element.access_flags, checkCurrentDatabase(current_database)); + revoke(element.access_flags, checkCurrentDatabase(current_database)); else - revokeImpl(element.access_flags, element.database); + revoke(element.access_flags, element.database); } else if (element.any_column) { if (element.database.empty()) - revokeImpl(element.access_flags, checkCurrentDatabase(current_database), element.table); + revoke(element.access_flags, checkCurrentDatabase(current_database), element.table); else - revokeImpl(element.access_flags, element.database, element.table); + revoke(element.access_flags, element.database, element.table); } else { if (element.database.empty()) - revokeImpl(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); + revoke(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); else - revokeImpl(element.access_flags, element.database, element.table, element.columns); + revoke(element.access_flags, element.database, element.table, element.columns); } } -template -void AccessRights::revokeImpl(const AccessRightsElements & elements, std::string_view current_database) +void AccessRights::revoke(const AccessRightsElements & elements, std::string_view current_database) { for (const auto & element : elements) - revokeImpl(element, current_database); + revoke(element, current_database); } -void AccessRights::revoke(const AccessFlags & access) { revokeImpl(access); } -void AccessRights::revoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } -void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } -void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } -void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::revoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } -void AccessRights::revoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } - -void AccessRights::partialRevoke(const AccessFlags & access) { revokeImpl(access); } -void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } -void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } -void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } -void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::partialRevoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } -void AccessRights::partialRevoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } - -void AccessRights::fullRevoke(const AccessFlags & access) { revokeImpl(access); } -void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database) { revokeImpl(access, database); } -void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table) { revokeImpl(access, database, table); } -void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) { revokeImpl(access, database, table, column); } -void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) { revokeImpl(access, database, table, columns); } -void AccessRights::fullRevoke(const AccessRightsElement & element, std::string_view current_database) { revokeImpl(element, current_database); } -void AccessRights::fullRevoke(const AccessRightsElements & elements, std::string_view current_database) { revokeImpl(elements, current_database); } - AccessRights::Elements AccessRights::getElements() const { if (!root) return {}; Elements res; - if (root->explicit_grants) - res.grants.push_back({root->explicit_grants}); + auto global_access = root->access; + if (global_access) + res.grants.push_back({global_access}); if (root->children) { for (const auto & [db_name, db_node] : *root->children) { - if (db_node.partial_revokes) - res.partial_revokes.push_back({db_node.partial_revokes, db_name}); - if (db_node.explicit_grants) - res.grants.push_back({db_node.explicit_grants, db_name}); + auto db_grants = db_node.access - global_access; + auto db_partial_revokes = global_access - db_node.access; + if (db_partial_revokes) + res.partial_revokes.push_back({db_partial_revokes, db_name}); + if (db_grants) + res.grants.push_back({db_grants, db_name}); if (db_node.children) { for (const auto & [table_name, table_node] : *db_node.children) { - if (table_node.partial_revokes) - res.partial_revokes.push_back({table_node.partial_revokes, db_name, table_name}); - if (table_node.explicit_grants) - res.grants.push_back({table_node.explicit_grants, db_name, table_name}); + auto table_grants = table_node.access - db_node.access; + auto table_partial_revokes = db_node.access - table_node.access; + if (table_partial_revokes) + res.partial_revokes.push_back({table_partial_revokes, db_name, table_name}); + if (table_grants) + res.grants.push_back({table_grants, db_name, table_name}); if (table_node.children) { for (const auto & [column_name, column_node] : *table_node.children) { - if (column_node.partial_revokes) - res.partial_revokes.push_back({column_node.partial_revokes, db_name, table_name, column_name}); - if (column_node.explicit_grants) - res.grants.push_back({column_node.explicit_grants, db_name, table_name, column_name}); + auto column_grants = column_node.access - table_node.access; + auto column_partial_revokes = table_node.access - column_node.access; + if (column_partial_revokes) + res.partial_revokes.push_back({column_partial_revokes, db_name, table_name, column_name}); + if (column_grants) + res.grants.push_back({column_grants, db_name, table_name, column_name}); } } } @@ -706,59 +678,57 @@ String AccessRights::toString() const template -bool AccessRights::isGrantedImpl(const AccessFlags & access, const Args &... args) const +bool AccessRights::isGrantedImpl(const AccessFlags & flags, const Args &... args) const { if (!root) - return access.isEmpty(); - return root->isGranted(access, args...); + return flags.isEmpty(); + return root->isGranted(flags, args...); } -bool AccessRights::isGrantedImpl(const AccessRightsElement & element, std::string_view current_database) const +bool AccessRights::isGranted(const AccessFlags & flags) const { return isGrantedImpl(flags); } +bool AccessRights::isGranted(const AccessFlags & flags, const std::string_view & database) const { return isGrantedImpl(flags, database); } +bool AccessRights::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return isGrantedImpl(flags, database, table); } +bool AccessRights::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isGrantedImpl(flags, database, table, column); } +bool AccessRights::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isGrantedImpl(flags, database, table, columns); } +bool AccessRights::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isGrantedImpl(flags, database, table, columns); } + +bool AccessRights::isGranted(const AccessRightsElement & element, std::string_view current_database) const { if (element.any_database) { - return isGrantedImpl(element.access_flags); + return isGranted(element.access_flags); } else if (element.any_table) { if (element.database.empty()) - return isGrantedImpl(element.access_flags, checkCurrentDatabase(current_database)); + return isGranted(element.access_flags, checkCurrentDatabase(current_database)); else - return isGrantedImpl(element.access_flags, element.database); + return isGranted(element.access_flags, element.database); } else if (element.any_column) { if (element.database.empty()) - return isGrantedImpl(element.access_flags, checkCurrentDatabase(current_database), element.table); + return isGranted(element.access_flags, checkCurrentDatabase(current_database), element.table); else - return isGrantedImpl(element.access_flags, element.database, element.table); + return isGranted(element.access_flags, element.database, element.table); } else { if (element.database.empty()) - return isGrantedImpl(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); + return isGranted(element.access_flags, checkCurrentDatabase(current_database), element.table, element.columns); else - return isGrantedImpl(element.access_flags, element.database, element.table, element.columns); + return isGranted(element.access_flags, element.database, element.table, element.columns); } } -bool AccessRights::isGrantedImpl(const AccessRightsElements & elements, std::string_view current_database) const +bool AccessRights::isGranted(const AccessRightsElements & elements, std::string_view current_database) const { for (const auto & element : elements) - if (!isGrantedImpl(element, current_database)) + if (!isGranted(element, current_database)) return false; return true; } -bool AccessRights::isGranted(const AccessFlags & access) const { return isGrantedImpl(access); } -bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database) const { return isGrantedImpl(access, database); } -bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return isGrantedImpl(access, database, table); } -bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return isGrantedImpl(access, database, table, column); } -bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return isGrantedImpl(access, database, table, columns); } -bool AccessRights::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return isGrantedImpl(access, database, table, columns); } -bool AccessRights::isGranted(const AccessRightsElement & element, std::string_view current_database) const { return isGrantedImpl(element, current_database); } -bool AccessRights::isGranted(const AccessRightsElements & elements, std::string_view current_database) const { return isGrantedImpl(elements, current_database); } - bool operator ==(const AccessRights & left, const AccessRights & right) { @@ -780,17 +750,17 @@ void AccessRights::merge(const AccessRights & other) if (other.root) { root->merge(*other.root, Helper::instance()); - if (root->isEmpty()) + if (!root->access && !root->children) root = nullptr; } } -void AccessRights::traceTree() const +void AccessRights::logTree() const { auto * log = &Poco::Logger::get("AccessRights"); if (root) - root->traceTree(log); + root->logTree(log); else LOG_TRACE(log, "Tree: NULL"); } diff --git a/dbms/src/Access/AccessRights.h b/dbms/src/Access/AccessRights.h index 67d205ec6dc..133038f2d44 100644 --- a/dbms/src/Access/AccessRights.h +++ b/dbms/src/Access/AccessRights.h @@ -23,60 +23,31 @@ public: bool isEmpty() const; - /// Revokes everything. It's the same as fullRevoke(AccessType::ALL). + /// Revokes everything. It's the same as revoke(AccessType::ALL). void clear(); /// Grants access on a specified database/table/column. /// Does nothing if the specified access has been already granted. - void grant(const AccessFlags & access); - void grant(const AccessFlags & access, const std::string_view & database); - void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table); - void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); - void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); - void grant(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + void grant(const AccessFlags & flags); + void grant(const AccessFlags & flags, const std::string_view & database); + void grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table); + void grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void grant(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns); void grant(const AccessRightsElement & element, std::string_view current_database = {}); void grant(const AccessRightsElements & elements, std::string_view current_database = {}); /// Revokes a specified access granted earlier on a specified database/table/column. - /// Does nothing if the specified access is not granted. - /// If the specified access is granted but on upper level (e.g. database for table, table for columns) - /// or lower level, the function also does nothing. - /// This function implements the standard SQL REVOKE behaviour. - void revoke(const AccessFlags & access); - void revoke(const AccessFlags & access, const std::string_view & database); - void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); - void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); - void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); - void revoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); + /// For example, revoke(AccessType::ALL) revokes all grants at all, just like clear(); + void revoke(const AccessFlags & flags); + void revoke(const AccessFlags & flags, const std::string_view & database); + void revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table); + void revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column); + void revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns); + void revoke(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns); void revoke(const AccessRightsElement & element, std::string_view current_database = {}); void revoke(const AccessRightsElements & elements, std::string_view current_database = {}); - /// Revokes a specified access granted earlier on a specified database/table/column or on lower levels. - /// The function also restricts access if it's granted on upper level. - /// For example, an access could be granted on a database and then revoked on a table in this database. - /// This function implements the MySQL REVOKE behaviour with partial_revokes is ON. - void partialRevoke(const AccessFlags & access); - void partialRevoke(const AccessFlags & access, const std::string_view & database); - void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); - void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); - void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); - void partialRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); - void partialRevoke(const AccessRightsElement & element, std::string_view current_database = {}); - void partialRevoke(const AccessRightsElements & elements, std::string_view current_database = {}); - - /// Revokes a specified access granted earlier on a specified database/table/column or on lower levels. - /// The function also restricts access if it's granted on upper level. - /// For example, fullRevoke(AccessType::ALL) revokes all grants at all, just like clear(); - /// fullRevoke(AccessType::SELECT, db) means it's not allowed to execute SELECT in that database anymore (from any table). - void fullRevoke(const AccessFlags & access); - void fullRevoke(const AccessFlags & access, const std::string_view & database); - void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table); - void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column); - void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns); - void fullRevoke(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns); - void fullRevoke(const AccessRightsElement & element, std::string_view current_database = {}); - void fullRevoke(const AccessRightsElements & elements, std::string_view current_database = {}); - /// Returns the information about all the access granted. struct Elements { @@ -89,12 +60,12 @@ public: String toString() const; /// Whether a specified access granted. - bool isGranted(const AccessFlags & access) const; - bool isGranted(const AccessFlags & access, const std::string_view & database) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(const AccessFlags & flags) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; bool isGranted(const AccessRightsElement & element, std::string_view current_database = {}) const; bool isGranted(const AccessRightsElements & elements, std::string_view current_database = {}) const; @@ -107,22 +78,13 @@ public: private: template - void grantImpl(const AccessFlags & access, const Args &... args); - - void grantImpl(const AccessRightsElement & element, std::string_view current_database); - void grantImpl(const AccessRightsElements & elements, std::string_view current_database); - - template - void revokeImpl(const AccessFlags & access, const Args &... args); - - template - void revokeImpl(const AccessRightsElement & element, std::string_view current_database); - - template - void revokeImpl(const AccessRightsElements & elements, std::string_view current_database); + void grantImpl(const AccessFlags & flags, const Args &... args); template - bool isGrantedImpl(const AccessFlags & access, const Args &... args) const; + void revokeImpl(const AccessFlags & flags, const Args &... args); + + template + bool isGrantedImpl(const AccessFlags & flags, const Args &... args) const; bool isGrantedImpl(const AccessRightsElement & element, std::string_view current_database) const; bool isGrantedImpl(const AccessRightsElements & elements, std::string_view current_database) const; @@ -130,7 +92,7 @@ private: template AccessFlags getAccessImpl(const Args &... args) const; - void traceTree() const; + void logTree() const; struct Node; std::unique_ptr root; diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp deleted file mode 100644 index 9e781cbe280..00000000000 --- a/dbms/src/Access/AccessRightsContext.cpp +++ /dev/null @@ -1,586 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ -namespace ErrorCodes -{ - extern const int ACCESS_DENIED; - extern const int READONLY; - extern const int QUERY_IS_PROHIBITED; - extern const int FUNCTION_NOT_ALLOWED; - extern const int UNKNOWN_USER; -} - - -namespace -{ - enum CheckAccessRightsMode - { - RETURN_FALSE_IF_ACCESS_DENIED, - LOG_WARNING_IF_ACCESS_DENIED, - THROW_IF_ACCESS_DENIED, - }; - - - String formatSkippedMessage() - { - return ""; - } - - String formatSkippedMessage(const std::string_view & database) - { - return ". Skipped database " + backQuoteIfNeed(database); - } - - String formatSkippedMessage(const std::string_view & database, const std::string_view & table) - { - String str = ". Skipped table "; - if (!database.empty()) - str += backQuoteIfNeed(database) + "."; - str += backQuoteIfNeed(table); - return str; - } - - String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::string_view & column) - { - String str = ". Skipped column " + backQuoteIfNeed(column) + " ON "; - if (!database.empty()) - str += backQuoteIfNeed(database) + "."; - str += backQuoteIfNeed(table); - return str; - } - - template - String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::vector & columns) - { - if (columns.size() == 1) - return formatSkippedMessage(database, table, columns[0]); - - String str = ". Skipped columns "; - bool need_comma = false; - for (const auto & column : columns) - { - if (std::exchange(need_comma, true)) - str += ", "; - str += backQuoteIfNeed(column); - } - str += " ON "; - if (!database.empty()) - str += backQuoteIfNeed(database) + "."; - str += backQuoteIfNeed(table); - return str; - } -} - - -AccessRightsContext::AccessRightsContext() -{ - auto everything_granted = boost::make_shared(); - everything_granted->grant(AccessType::ALL); - boost::range::fill(result_access_cache, everything_granted); - - enabled_roles_with_admin_option = boost::make_shared>(); - - row_policy_context = std::make_shared(); - quota_context = std::make_shared(); -} - - -AccessRightsContext::AccessRightsContext(const AccessControlManager & manager_, const Params & params_) - : manager(&manager_) - , params(params_) -{ - subscription_for_user_change = manager->subscribeForChanges( - *params.user_id, [this](const UUID &, const AccessEntityPtr & entity) - { - UserPtr changed_user = entity ? typeid_cast(entity) : nullptr; - std::lock_guard lock{mutex}; - setUser(changed_user); - }); - - setUser(manager->read(*params.user_id)); -} - - -void AccessRightsContext::setUser(const UserPtr & user_) const -{ - user = user_; - if (!user) - { - /// User has been dropped. - auto nothing_granted = boost::make_shared(); - boost::range::fill(result_access_cache, nothing_granted); - subscription_for_user_change = {}; - subscription_for_roles_info_change = {}; - role_context = nullptr; - enabled_roles_with_admin_option = boost::make_shared>(); - row_policy_context = std::make_shared(); - quota_context = std::make_shared(); - return; - } - - user_name = user->getName(); - trace_log = &Poco::Logger::get("AccessRightsContext (" + user_name + ")"); - - std::vector current_roles, current_roles_with_admin_option; - if (params.use_default_roles) - { - for (const UUID & id : user->granted_roles) - { - if (user->default_roles.match(id)) - current_roles.push_back(id); - } - boost::range::set_intersection(current_roles, user->granted_roles_with_admin_option, - std::back_inserter(current_roles_with_admin_option)); - } - else - { - current_roles.reserve(params.current_roles.size()); - for (const auto & id : params.current_roles) - { - if (user->granted_roles.contains(id)) - current_roles.push_back(id); - if (user->granted_roles_with_admin_option.contains(id)) - current_roles_with_admin_option.push_back(id); - } - } - - subscription_for_roles_info_change = {}; - role_context = manager->getRoleContext(current_roles, current_roles_with_admin_option); - subscription_for_roles_info_change = role_context->subscribeForChanges([this](const CurrentRolesInfoPtr & roles_info_) - { - std::lock_guard lock{mutex}; - setRolesInfo(roles_info_); - }); - - setRolesInfo(role_context->getInfo()); -} - - -void AccessRightsContext::setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const -{ - assert(roles_info_); - roles_info = roles_info_; - enabled_roles_with_admin_option.store(nullptr /* need to recalculate */); - boost::range::fill(result_access_cache, nullptr /* need recalculate */); - row_policy_context = manager->getRowPolicyContext(*params.user_id, roles_info->enabled_roles); - quota_context = manager->getQuotaContext(user_name, *params.user_id, roles_info->enabled_roles, params.address, params.quota_key); -} - - -bool AccessRightsContext::isCorrectPassword(const String & password) const -{ - std::lock_guard lock{mutex}; - if (!user) - return false; - return user->authentication.isCorrectPassword(password); -} - -bool AccessRightsContext::isClientHostAllowed() const -{ - std::lock_guard lock{mutex}; - if (!user) - return false; - return user->allowed_client_hosts.contains(params.address); -} - - -template -bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const -{ - auto result_access = calculateResultAccess(grant_option); - bool is_granted = result_access->isGranted(access, args...); - - if (trace_log) - LOG_TRACE(trace_log, "Access " << (is_granted ? "granted" : "denied") << ": " << (AccessRightsElement{access, args...}.toString())); - - if (is_granted) - return true; - - if constexpr (mode == RETURN_FALSE_IF_ACCESS_DENIED) - return false; - - if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) - { - if (!log_) - return false; - } - - auto show_error = [&](const String & msg, [[maybe_unused]] int error_code) - { - if constexpr (mode == THROW_IF_ACCESS_DENIED) - throw Exception(user_name + ": " + msg, error_code); - else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) - LOG_WARNING(log_, user_name + ": " + msg + formatSkippedMessage(args...)); - }; - - if (!user) - { - show_error("User has been dropped", ErrorCodes::UNKNOWN_USER); - } - else if (grant_option && calculateResultAccess(false, params.readonly, params.allow_ddl, params.allow_introspection)->isGranted(access, args...)) - { - show_error( - "Not enough privileges. " - "The required privileges have been granted, but without grant option. " - "To execute this query it's necessary to have the grant " - + AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION", - ErrorCodes::ACCESS_DENIED); - } - else if (params.readonly && calculateResultAccess(false, false, params.allow_ddl, params.allow_introspection)->isGranted(access, args...)) - { - if (params.interface == ClientInfo::Interface::HTTP && params.http_method == ClientInfo::HTTPMethod::GET) - show_error( - "Cannot execute query in readonly mode. " - "For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries", - ErrorCodes::READONLY); - else - show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY); - } - else if (!params.allow_ddl && calculateResultAccess(false, params.readonly, true, params.allow_introspection)->isGranted(access, args...)) - { - show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); - } - else if (!params.allow_introspection && calculateResultAccess(false, params.readonly, params.allow_ddl, true)->isGranted(access, args...)) - { - show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED); - } - else - { - show_error( - "Not enough privileges. To execute this query it's necessary to have the grant " - + AccessRightsElement{access, args...}.toString() + (grant_option ? " WITH GRANT OPTION" : ""), - ErrorCodes::ACCESS_DENIED); - } - - return false; -} - - -template -bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const -{ - if (element.any_database) - { - return checkAccessImpl(log_, element.access_flags); - } - else if (element.any_table) - { - if (element.database.empty()) - return checkAccessImpl(log_, element.access_flags, params.current_database); - else - return checkAccessImpl(log_, element.access_flags, element.database); - } - else if (element.any_column) - { - if (element.database.empty()) - return checkAccessImpl(log_, element.access_flags, params.current_database, element.table); - else - return checkAccessImpl(log_, element.access_flags, element.database, element.table); - } - else - { - if (element.database.empty()) - return checkAccessImpl(log_, element.access_flags, params.current_database, element.table, element.columns); - else - return checkAccessImpl(log_, element.access_flags, element.database, element.table, element.columns); - } -} - - -template -bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const -{ - for (const auto & element : elements) - if (!checkAccessImpl(log_, element)) - return false; - return true; -} - - -void AccessRightsContext::checkAccess(const AccessFlags & access) const { checkAccessImpl(nullptr, access); } -void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl(nullptr, access, database); } -void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, access, database, table); } -void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, access, database, table, column); } -void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkAccess(const AccessRightsElement & access) const { checkAccessImpl(nullptr, access); } -void AccessRightsContext::checkAccess(const AccessRightsElements & access) const { checkAccessImpl(nullptr, access); } - -bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkAccessImpl(nullptr, access); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(nullptr, access, database); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(nullptr, access, database, table); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(nullptr, access, database, table, column); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(nullptr, access, database, table, columns); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(nullptr, access, database, table, columns); } -bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkAccessImpl(nullptr, access); } -bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkAccessImpl(nullptr, access); } - -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkAccessImpl(log_, access); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(log_, access, database); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(log_, access, database, table); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(log_, access, database, table, column); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(log_, access, database, table, columns); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(log_, access, database, table, columns); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkAccessImpl(log_, access); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkAccessImpl(log_, access); } - -void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkAccessImpl(nullptr, access); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl(nullptr, access, database); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, access, database, table); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, access, database, table, column); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkAccessImpl(nullptr, access); } -void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkAccessImpl(nullptr, access); } - - -void AccessRightsContext::checkAdminOption(const UUID & role_id) const -{ - if (isGranted(AccessType::ROLE_ADMIN)) - return; - - boost::shared_ptr> enabled_roles = enabled_roles_with_admin_option.load(); - if (!enabled_roles) - { - std::lock_guard lock{mutex}; - enabled_roles = enabled_roles_with_admin_option.load(); - if (!enabled_roles) - { - if (roles_info) - enabled_roles = boost::make_shared>(roles_info->enabled_roles_with_admin_option.begin(), roles_info->enabled_roles_with_admin_option.end()); - else - enabled_roles = boost::make_shared>(); - enabled_roles_with_admin_option.store(enabled_roles); - } - } - - if (enabled_roles->contains(role_id)) - return; - - std::optional role_name = manager->readName(role_id); - if (!role_name) - role_name = "ID {" + toString(role_id) + "}"; - throw Exception( - getUserName() + ": Not enough privileges. To execute this query it's necessary to have the grant " + backQuoteIfNeed(*role_name) - + " WITH ADMIN OPTION ", - ErrorCodes::ACCESS_DENIED); -} - - -boost::shared_ptr AccessRightsContext::calculateResultAccess(bool grant_option) const -{ - return calculateResultAccess(grant_option, params.readonly, params.allow_ddl, params.allow_introspection); -} - - -boost::shared_ptr AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const -{ - size_t cache_index = static_cast(readonly_ != params.readonly) - + static_cast(allow_ddl_ != params.allow_ddl) * 2 + - + static_cast(allow_introspection_ != params.allow_introspection) * 3 - + static_cast(grant_option) * 4; - assert(cache_index < std::size(result_access_cache)); - auto cached = result_access_cache[cache_index].load(); - if (cached) - return cached; - - std::lock_guard lock{mutex}; - cached = result_access_cache[cache_index].load(); - if (cached) - return cached; - - auto result_ptr = boost::make_shared(); - auto & result = *result_ptr; - - if (grant_option) - { - result = user->access_with_grant_option; - if (roles_info) - result.merge(roles_info->access_with_grant_option); - } - else - { - result = user->access; - if (roles_info) - result.merge(roles_info->access); - } - - static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW - | AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW - | AccessType::TRUNCATE; - static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY; - static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; - static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; - static const AccessFlags all_dcl = AccessType::CREATE_USER | AccessType::CREATE_ROLE | AccessType::CREATE_POLICY - | AccessType::CREATE_QUOTA | AccessType::ALTER_USER | AccessType::ALTER_POLICY | AccessType::ALTER_QUOTA | AccessType::DROP_USER - | AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA | AccessType::ROLE_ADMIN; - - /// Anyone has access to the "system" database. - if (!result.isGranted(AccessType::SELECT, DatabaseCatalog::SYSTEM_DATABASE)) - result.grant(AccessType::SELECT, DatabaseCatalog::SYSTEM_DATABASE); - - /// User has access to temporary or external table if such table was resolved in session or query context - if (!result.isGranted(AccessType::SELECT, DatabaseCatalog::TEMPORARY_DATABASE)) - result.grant(AccessType::SELECT, DatabaseCatalog::TEMPORARY_DATABASE); - - if (readonly_) - result.fullRevoke(write_table_access | all_dcl | AccessType::SYSTEM | AccessType::KILL); - - if (readonly_ || !allow_ddl_) - result.fullRevoke(table_and_dictionary_ddl); - - if (readonly_ && grant_option) - result.fullRevoke(AccessType::ALL); - - if (readonly_ == 1) - { - /// Table functions are forbidden in readonly mode. - /// For example, for readonly = 2 - allowed. - result.fullRevoke(AccessType::CREATE_TEMPORARY_TABLE | AccessType::TABLE_FUNCTIONS); - } - else if (readonly_ == 2) - { - /// Allow INSERT into temporary tables - result.grant(AccessType::INSERT, DatabaseCatalog::TEMPORARY_DATABASE); - } - - if (!allow_introspection_) - result.fullRevoke(AccessType::INTROSPECTION); - - result_access_cache[cache_index].store(result_ptr); - - if (trace_log && (params.readonly == readonly_) && (params.allow_ddl == allow_ddl_) && (params.allow_introspection == allow_introspection_)) - { - LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : "")); - if (roles_info && !roles_info->getCurrentRolesNames().empty()) - { - LOG_TRACE( - trace_log, - "Current_roles: " << boost::algorithm::join(roles_info->getCurrentRolesNames(), ", ") - << ", enabled_roles: " << boost::algorithm::join(roles_info->getEnabledRolesNames(), ", ")); - } - } - - return result_ptr; -} - - -UserPtr AccessRightsContext::getUser() const -{ - std::lock_guard lock{mutex}; - return user; -} - -String AccessRightsContext::getUserName() const -{ - std::lock_guard lock{mutex}; - return user_name; -} - -CurrentRolesInfoPtr AccessRightsContext::getRolesInfo() const -{ - std::lock_guard lock{mutex}; - return roles_info; -} - -std::vector AccessRightsContext::getCurrentRoles() const -{ - std::lock_guard lock{mutex}; - return roles_info ? roles_info->current_roles : std::vector{}; -} - -Strings AccessRightsContext::getCurrentRolesNames() const -{ - std::lock_guard lock{mutex}; - return roles_info ? roles_info->getCurrentRolesNames() : Strings{}; -} - -std::vector AccessRightsContext::getEnabledRoles() const -{ - std::lock_guard lock{mutex}; - return roles_info ? roles_info->enabled_roles : std::vector{}; -} - -Strings AccessRightsContext::getEnabledRolesNames() const -{ - std::lock_guard lock{mutex}; - return roles_info ? roles_info->getEnabledRolesNames() : Strings{}; -} - -RowPolicyContextPtr AccessRightsContext::getRowPolicy() const -{ - std::lock_guard lock{mutex}; - return row_policy_context; -} - -QuotaContextPtr AccessRightsContext::getQuota() const -{ - std::lock_guard lock{mutex}; - return quota_context; -} - - -bool operator <(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs) -{ -#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \ - if (lhs.field < rhs.field) \ - return true; \ - if (lhs.field > rhs.field) \ - return false - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method); - return false; -#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER -} - - -bool operator ==(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs) -{ -#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \ - if (lhs.field != rhs.field) \ - return false - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface); - ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method); - return true; -#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER -} - -} diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h deleted file mode 100644 index 8fc5066cfe4..00000000000 --- a/dbms/src/Access/AccessRightsContext.h +++ /dev/null @@ -1,157 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace Poco { class Logger; } - -namespace DB -{ -struct User; -using UserPtr = std::shared_ptr; -struct CurrentRolesInfo; -using CurrentRolesInfoPtr = std::shared_ptr; -class RoleContext; -using RoleContextPtr = std::shared_ptr; -class RowPolicyContext; -using RowPolicyContextPtr = std::shared_ptr; -class QuotaContext; -using QuotaContextPtr = std::shared_ptr; -struct Settings; -class AccessControlManager; - - -class AccessRightsContext -{ -public: - struct Params - { - std::optional user_id; - std::vector current_roles; - bool use_default_roles = false; - UInt64 readonly = 0; - bool allow_ddl = false; - bool allow_introspection = false; - String current_database; - ClientInfo::Interface interface = ClientInfo::Interface::TCP; - ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; - Poco::Net::IPAddress address; - String quota_key; - - friend bool operator ==(const Params & lhs, const Params & rhs); - friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } - friend bool operator <(const Params & lhs, const Params & rhs); - friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } - friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } - friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } - }; - - /// Default constructor creates access rights' context which allows everything. - AccessRightsContext(); - - const Params & getParams() const { return params; } - UserPtr getUser() const; - String getUserName() const; - - bool isCorrectPassword(const String & password) const; - bool isClientHostAllowed() const; - - CurrentRolesInfoPtr getRolesInfo() const; - std::vector getCurrentRoles() const; - Strings getCurrentRolesNames() const; - std::vector getEnabledRoles() const; - Strings getEnabledRolesNames() const; - - RowPolicyContextPtr getRowPolicy() const; - QuotaContextPtr getQuota() const; - - /// Checks if a specified access is granted, and throws an exception if not. - /// Empty database means the current database. - void checkAccess(const AccessFlags & access) const; - void checkAccess(const AccessFlags & access, const std::string_view & database) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - void checkAccess(const AccessRightsElement & access) const; - void checkAccess(const AccessRightsElements & access) const; - - /// Checks if a specified access is granted. - bool isGranted(const AccessFlags & access) const; - bool isGranted(const AccessFlags & access, const std::string_view & database) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - bool isGranted(const AccessRightsElement & access) const; - bool isGranted(const AccessRightsElements & access) const; - - /// Checks if a specified access is granted, and logs a warning if not. - bool isGranted(Poco::Logger * log_, const AccessFlags & access) const; - bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const; - bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const; - bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const; - - /// Checks if a specified access is granted with grant option, and throws an exception if not. - void checkGrantOption(const AccessFlags & access) const; - void checkGrantOption(const AccessFlags & access, const std::string_view & database) const; - void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - void checkGrantOption(const AccessRightsElement & access) const; - void checkGrantOption(const AccessRightsElements & access) const; - - /// Checks if a specified role is granted with admin option, and throws an exception if not. - void checkAdminOption(const UUID & role_id) const; - -private: - friend class AccessRightsContextFactory; - friend struct ext::shared_ptr_helper; - AccessRightsContext(const AccessControlManager & manager_, const Params & params_); /// AccessRightsContext should be created by AccessRightsContextFactory. - - void setUser(const UserPtr & user_) const; - void setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const; - - template - bool checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; - - template - bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const; - - template - bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const; - - boost::shared_ptr calculateResultAccess(bool grant_option) const; - boost::shared_ptr calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const; - - const AccessControlManager * manager = nullptr; - const Params params; - mutable Poco::Logger * trace_log = nullptr; - mutable UserPtr user; - mutable String user_name; - mutable ext::scope_guard subscription_for_user_change; - mutable RoleContextPtr role_context; - mutable ext::scope_guard subscription_for_roles_info_change; - mutable CurrentRolesInfoPtr roles_info; - mutable boost::atomic_shared_ptr> enabled_roles_with_admin_option; - mutable boost::atomic_shared_ptr result_access_cache[7]; - mutable RowPolicyContextPtr row_policy_context; - mutable QuotaContextPtr quota_context; - mutable std::mutex mutex; -}; - -using AccessRightsContextPtr = std::shared_ptr; - -} diff --git a/dbms/src/Access/AccessRightsContextFactory.cpp b/dbms/src/Access/AccessRightsContextFactory.cpp deleted file mode 100644 index 8d542a5f439..00000000000 --- a/dbms/src/Access/AccessRightsContextFactory.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include -#include -#include - - -namespace DB -{ -AccessRightsContextFactory::AccessRightsContextFactory(const AccessControlManager & manager_) - : manager(manager_), cache(600000 /* 10 minutes */) {} - -AccessRightsContextFactory::~AccessRightsContextFactory() = default; - - -AccessRightsContextPtr AccessRightsContextFactory::createContext(const Params & params) -{ - std::lock_guard lock{mutex}; - auto x = cache.get(params); - if (x) - return *x; - auto res = ext::shared_ptr_helper::create(manager, params); - cache.add(params, res); - return res; -} - -AccessRightsContextPtr AccessRightsContextFactory::createContext( - const UUID & user_id, - const std::vector & current_roles, - bool use_default_roles, - const Settings & settings, - const String & current_database, - const ClientInfo & client_info) -{ - Params params; - params.user_id = user_id; - params.current_roles = current_roles; - params.use_default_roles = use_default_roles; - params.current_database = current_database; - params.readonly = settings.readonly; - params.allow_ddl = settings.allow_ddl; - params.allow_introspection = settings.allow_introspection_functions; - params.interface = client_info.interface; - params.http_method = client_info.http_method; - params.address = client_info.current_address.host(); - params.quota_key = client_info.quota_key; - return createContext(params); -} - -} diff --git a/dbms/src/Access/AccessRightsContextFactory.h b/dbms/src/Access/AccessRightsContextFactory.h deleted file mode 100644 index c480307757a..00000000000 --- a/dbms/src/Access/AccessRightsContextFactory.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include -#include - - -namespace DB -{ -class AccessControlManager; - - -class AccessRightsContextFactory -{ -public: - AccessRightsContextFactory(const AccessControlManager & manager_); - ~AccessRightsContextFactory(); - - using Params = AccessRightsContext::Params; - AccessRightsContextPtr createContext(const Params & params); - AccessRightsContextPtr createContext(const UUID & user_id, const std::vector & current_roles, bool use_default_roles, const Settings & settings, const String & current_database, const ClientInfo & client_info); - -private: - const AccessControlManager & manager; - Poco::ExpireCache cache; - std::mutex mutex; -}; - -} diff --git a/dbms/src/Access/AccessType.h b/dbms/src/Access/AccessType.h index 205840eecdf..27892076d59 100644 --- a/dbms/src/Access/AccessType.h +++ b/dbms/src/Access/AccessType.h @@ -14,12 +14,11 @@ enum class AccessType NONE, /// no access ALL, /// full access - SHOW, /// allows to execute SHOW TABLES, SHOW CREATE TABLE, SHOW DATABASES and so on - /// (granted implicitly with any other grant) - - EXISTS, /// allows to execute EXISTS, USE, i.e. to check existence - /// (granted implicitly on the database level with any other grant on the database and lower levels, - /// e.g. "GRANT SELECT(x) ON db.table" also grants EXISTS on db.*) + SHOW_DATABASES, /// allows to execute SHOW DATABASES, SHOW CREATE DATABASE, USE + SHOW_TABLES, /// allows to execute SHOW TABLES, EXISTS , CHECK
+ SHOW_COLUMNS, /// allows to execute SHOW CREATE TABLE, DESCRIBE + SHOW_DICTIONARIES, /// allows to execute SHOW DICTIONARIES, SHOW CREATE DICTIONARY, EXISTS + SHOW, /// allows to execute SHOW, USE, EXISTS, CHECK, DESCRIBE SELECT, INSERT, @@ -80,13 +79,12 @@ enum class AccessType OPTIMIZE, /// allows to execute OPTIMIZE TABLE KILL_QUERY, /// allows to kill a query started by another user (anyone can kill his own queries) - KILL_MUTATION, /// allows to kill a mutation - KILL, /// allows to execute KILL {MUTATION|QUERY} CREATE_USER, ALTER_USER, DROP_USER, CREATE_ROLE, + ALTER_ROLE, DROP_ROLE, CREATE_POLICY, ALTER_POLICY, @@ -94,6 +92,9 @@ enum class AccessType CREATE_QUOTA, ALTER_QUOTA, DROP_QUOTA, + CREATE_SETTINGS_PROFILE, + ALTER_SETTINGS_PROFILE, + DROP_SETTINGS_PROFILE, ROLE_ADMIN, /// allows to grant and revoke any roles. @@ -179,8 +180,12 @@ namespace impl ACCESS_TYPE_TO_KEYWORD_CASE(NONE); ACCESS_TYPE_TO_KEYWORD_CASE(ALL); + + ACCESS_TYPE_TO_KEYWORD_CASE(SHOW_DATABASES); + ACCESS_TYPE_TO_KEYWORD_CASE(SHOW_TABLES); + ACCESS_TYPE_TO_KEYWORD_CASE(SHOW_COLUMNS); + ACCESS_TYPE_TO_KEYWORD_CASE(SHOW_DICTIONARIES); ACCESS_TYPE_TO_KEYWORD_CASE(SHOW); - ACCESS_TYPE_TO_KEYWORD_CASE(EXISTS); ACCESS_TYPE_TO_KEYWORD_CASE(SELECT); ACCESS_TYPE_TO_KEYWORD_CASE(INSERT); @@ -241,13 +246,12 @@ namespace impl ACCESS_TYPE_TO_KEYWORD_CASE(OPTIMIZE); ACCESS_TYPE_TO_KEYWORD_CASE(KILL_QUERY); - ACCESS_TYPE_TO_KEYWORD_CASE(KILL_MUTATION); - ACCESS_TYPE_TO_KEYWORD_CASE(KILL); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_USER); ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_USER); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_USER); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_ROLE); + ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_ROLE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_ROLE); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_POLICY); ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_POLICY); @@ -255,6 +259,9 @@ namespace impl ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_QUOTA); + ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_SETTINGS_PROFILE); + ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_SETTINGS_PROFILE); + ACCESS_TYPE_TO_KEYWORD_CASE(DROP_SETTINGS_PROFILE); ACCESS_TYPE_TO_KEYWORD_CASE(ROLE_ADMIN); ACCESS_TYPE_TO_KEYWORD_CASE(SHUTDOWN); diff --git a/dbms/src/Access/ContextAccess.cpp b/dbms/src/Access/ContextAccess.cpp new file mode 100644 index 00000000000..f5f4ccfe6ac --- /dev/null +++ b/dbms/src/Access/ContextAccess.cpp @@ -0,0 +1,552 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ACCESS_DENIED; + extern const int READONLY; + extern const int QUERY_IS_PROHIBITED; + extern const int FUNCTION_NOT_ALLOWED; + extern const int UNKNOWN_USER; +} + + +namespace +{ + enum CheckAccessRightsMode + { + RETURN_FALSE_IF_ACCESS_DENIED, + LOG_WARNING_IF_ACCESS_DENIED, + THROW_IF_ACCESS_DENIED, + }; + + + String formatSkippedMessage() + { + return ""; + } + + String formatSkippedMessage(const std::string_view & database) + { + return ". Skipped database " + backQuoteIfNeed(database); + } + + String formatSkippedMessage(const std::string_view & database, const std::string_view & table) + { + String str = ". Skipped table "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } + + String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::string_view & column) + { + String str = ". Skipped column " + backQuoteIfNeed(column) + " ON "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } + + template + String formatSkippedMessage(const std::string_view & database, const std::string_view & table, const std::vector & columns) + { + if (columns.size() == 1) + return formatSkippedMessage(database, table, columns[0]); + + String str = ". Skipped columns "; + bool need_comma = false; + for (const auto & column : columns) + { + if (std::exchange(need_comma, true)) + str += ", "; + str += backQuoteIfNeed(column); + } + str += " ON "; + if (!database.empty()) + str += backQuoteIfNeed(database) + "."; + str += backQuoteIfNeed(table); + return str; + } +} + + +ContextAccess::ContextAccess(const AccessControlManager & manager_, const Params & params_) + : manager(&manager_) + , params(params_) +{ + subscription_for_user_change = manager->subscribeForChanges( + *params.user_id, [this](const UUID &, const AccessEntityPtr & entity) + { + UserPtr changed_user = entity ? typeid_cast(entity) : nullptr; + std::lock_guard lock{mutex}; + setUser(changed_user); + }); + + setUser(manager->read(*params.user_id)); +} + + +void ContextAccess::setUser(const UserPtr & user_) const +{ + user = user_; + if (!user) + { + /// User has been dropped. + auto nothing_granted = boost::make_shared(); + boost::range::fill(result_access, nothing_granted); + subscription_for_user_change = {}; + subscription_for_roles_changes = {}; + enabled_roles = nullptr; + roles_info = nullptr; + roles_with_admin_option = nullptr; + enabled_row_policies = nullptr; + enabled_quota = nullptr; + enabled_settings = nullptr; + return; + } + + user_name = user->getName(); + trace_log = &Poco::Logger::get("ContextAccess (" + user_name + ")"); + + std::vector current_roles, current_roles_with_admin_option; + if (params.use_default_roles) + { + for (const UUID & id : user->granted_roles) + { + if (user->default_roles.match(id)) + current_roles.push_back(id); + } + boost::range::set_intersection(current_roles, user->granted_roles_with_admin_option, + std::back_inserter(current_roles_with_admin_option)); + } + else + { + current_roles.reserve(params.current_roles.size()); + for (const auto & id : params.current_roles) + { + if (user->granted_roles.contains(id)) + current_roles.push_back(id); + if (user->granted_roles_with_admin_option.contains(id)) + current_roles_with_admin_option.push_back(id); + } + } + + subscription_for_roles_changes = {}; + enabled_roles = manager->getEnabledRoles(current_roles, current_roles_with_admin_option); + subscription_for_roles_changes = enabled_roles->subscribeForChanges([this](const std::shared_ptr & roles_info_) + { + std::lock_guard lock{mutex}; + setRolesInfo(roles_info_); + }); + + setRolesInfo(enabled_roles->getRolesInfo()); +} + + +void ContextAccess::setRolesInfo(const std::shared_ptr & roles_info_) const +{ + assert(roles_info_); + roles_info = roles_info_; + roles_with_admin_option.store(boost::make_shared>(roles_info->enabled_roles_with_admin_option.begin(), roles_info->enabled_roles_with_admin_option.end())); + boost::range::fill(result_access, nullptr /* need recalculate */); + enabled_row_policies = manager->getEnabledRowPolicies(*params.user_id, roles_info->enabled_roles); + enabled_quota = manager->getEnabledQuota(*params.user_id, user_name, roles_info->enabled_roles, params.address, params.quota_key); + enabled_settings = manager->getEnabledSettings(*params.user_id, user->settings, roles_info->enabled_roles, roles_info->settings_from_enabled_roles); +} + + +bool ContextAccess::isCorrectPassword(const String & password) const +{ + std::lock_guard lock{mutex}; + if (!user) + return false; + return user->authentication.isCorrectPassword(password); +} + +bool ContextAccess::isClientHostAllowed() const +{ + std::lock_guard lock{mutex}; + if (!user) + return false; + return user->allowed_client_hosts.contains(params.address); +} + + +template +bool ContextAccess::checkAccessImpl(Poco::Logger * log_, const AccessFlags & flags, const Args &... args) const +{ + auto access = calculateResultAccess(grant_option); + bool is_granted = access->isGranted(flags, args...); + + if (trace_log) + LOG_TRACE(trace_log, "Access " << (is_granted ? "granted" : "denied") << ": " << (AccessRightsElement{flags, args...}.toString())); + + if (is_granted) + return true; + + if constexpr (mode == RETURN_FALSE_IF_ACCESS_DENIED) + return false; + + if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) + { + if (!log_) + return false; + } + + auto show_error = [&](const String & msg, [[maybe_unused]] int error_code) + { + if constexpr (mode == THROW_IF_ACCESS_DENIED) + throw Exception(user_name + ": " + msg, error_code); + else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) + LOG_WARNING(log_, user_name + ": " + msg + formatSkippedMessage(args...)); + }; + + if (!user) + { + show_error("User has been dropped", ErrorCodes::UNKNOWN_USER); + } + else if (grant_option && calculateResultAccess(false, params.readonly, params.allow_ddl, params.allow_introspection)->isGranted(flags, args...)) + { + show_error( + "Not enough privileges. " + "The required privileges have been granted, but without grant option. " + "To execute this query it's necessary to have the grant " + + AccessRightsElement{flags, args...}.toString() + " WITH GRANT OPTION", + ErrorCodes::ACCESS_DENIED); + } + else if (params.readonly && calculateResultAccess(false, false, params.allow_ddl, params.allow_introspection)->isGranted(flags, args...)) + { + if (params.interface == ClientInfo::Interface::HTTP && params.http_method == ClientInfo::HTTPMethod::GET) + show_error( + "Cannot execute query in readonly mode. " + "For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries", + ErrorCodes::READONLY); + else + show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY); + } + else if (!params.allow_ddl && calculateResultAccess(false, params.readonly, true, params.allow_introspection)->isGranted(flags, args...)) + { + show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); + } + else if (!params.allow_introspection && calculateResultAccess(false, params.readonly, params.allow_ddl, true)->isGranted(flags, args...)) + { + show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED); + } + else + { + show_error( + "Not enough privileges. To execute this query it's necessary to have the grant " + + AccessRightsElement{flags, args...}.toString() + (grant_option ? " WITH GRANT OPTION" : ""), + ErrorCodes::ACCESS_DENIED); + } + + return false; +} + + +template +bool ContextAccess::checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const +{ + if (element.any_database) + { + return checkAccessImpl(log_, element.access_flags); + } + else if (element.any_table) + { + if (element.database.empty()) + return checkAccessImpl(log_, element.access_flags, params.current_database); + else + return checkAccessImpl(log_, element.access_flags, element.database); + } + else if (element.any_column) + { + if (element.database.empty()) + return checkAccessImpl(log_, element.access_flags, params.current_database, element.table); + else + return checkAccessImpl(log_, element.access_flags, element.database, element.table); + } + else + { + if (element.database.empty()) + return checkAccessImpl(log_, element.access_flags, params.current_database, element.table, element.columns); + else + return checkAccessImpl(log_, element.access_flags, element.database, element.table, element.columns); + } +} + + +template +bool ContextAccess::checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const +{ + for (const auto & element : elements) + if (!checkAccessImpl(log_, element)) + return false; + return true; +} + + +void ContextAccess::checkAccess(const AccessFlags & flags) const { checkAccessImpl(nullptr, flags); } +void ContextAccess::checkAccess(const AccessFlags & flags, const std::string_view & database) const { checkAccessImpl(nullptr, flags, database); } +void ContextAccess::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, flags, database, table); } +void ContextAccess::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, flags, database, table, column); } +void ContextAccess::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, flags, database, table, columns); } +void ContextAccess::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, flags, database, table, columns); } +void ContextAccess::checkAccess(const AccessRightsElement & element) const { checkAccessImpl(nullptr, element); } +void ContextAccess::checkAccess(const AccessRightsElements & elements) const { checkAccessImpl(nullptr, elements); } + +bool ContextAccess::isGranted(const AccessFlags & flags) const { return checkAccessImpl(nullptr, flags); } +bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database) const { return checkAccessImpl(nullptr, flags, database); } +bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(nullptr, flags, database, table); } +bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(nullptr, flags, database, table, column); } +bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(nullptr, flags, database, table, columns); } +bool ContextAccess::isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(nullptr, flags, database, table, columns); } +bool ContextAccess::isGranted(const AccessRightsElement & element) const { return checkAccessImpl(nullptr, element); } +bool ContextAccess::isGranted(const AccessRightsElements & elements) const { return checkAccessImpl(nullptr, elements); } + +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags) const { return checkAccessImpl(log_, flags); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database) const { return checkAccessImpl(log_, flags, database); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(log_, flags, database, table); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(log_, flags, database, table, column); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(log_, flags, database, table, columns); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(log_, flags, database, table, columns); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessRightsElement & element) const { return checkAccessImpl(log_, element); } +bool ContextAccess::isGranted(Poco::Logger * log_, const AccessRightsElements & elements) const { return checkAccessImpl(log_, elements); } + +void ContextAccess::checkGrantOption(const AccessFlags & flags) const { checkAccessImpl(nullptr, flags); } +void ContextAccess::checkGrantOption(const AccessFlags & flags, const std::string_view & database) const { checkAccessImpl(nullptr, flags, database); } +void ContextAccess::checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, flags, database, table); } +void ContextAccess::checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, flags, database, table, column); } +void ContextAccess::checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, flags, database, table, columns); } +void ContextAccess::checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, flags, database, table, columns); } +void ContextAccess::checkGrantOption(const AccessRightsElement & element) const { checkAccessImpl(nullptr, element); } +void ContextAccess::checkGrantOption(const AccessRightsElements & elements) const { checkAccessImpl(nullptr, elements); } + + +void ContextAccess::checkAdminOption(const UUID & role_id) const +{ + if (isGranted(AccessType::ROLE_ADMIN)) + return; + + auto roles_with_admin_option_loaded = roles_with_admin_option.load(); + if (roles_with_admin_option_loaded && roles_with_admin_option_loaded->contains(role_id)) + return; + + std::optional role_name = manager->readName(role_id); + if (!role_name) + role_name = "ID {" + toString(role_id) + "}"; + throw Exception( + getUserName() + ": Not enough privileges. To execute this query it's necessary to have the grant " + backQuoteIfNeed(*role_name) + + " WITH ADMIN OPTION ", + ErrorCodes::ACCESS_DENIED); +} + + +boost::shared_ptr ContextAccess::calculateResultAccess(bool grant_option) const +{ + return calculateResultAccess(grant_option, params.readonly, params.allow_ddl, params.allow_introspection); +} + + +boost::shared_ptr ContextAccess::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const +{ + size_t cache_index = static_cast(readonly_ != params.readonly) + + static_cast(allow_ddl_ != params.allow_ddl) * 2 + + + static_cast(allow_introspection_ != params.allow_introspection) * 3 + + static_cast(grant_option) * 4; + assert(cache_index < std::size(result_access)); + auto res = result_access[cache_index].load(); + if (res) + return res; + + std::lock_guard lock{mutex}; + res = result_access[cache_index].load(); + if (res) + return res; + + auto merged_access = boost::make_shared(); + + if (grant_option) + { + *merged_access = user->access_with_grant_option; + if (roles_info) + merged_access->merge(roles_info->access_with_grant_option); + } + else + { + *merged_access = user->access; + if (roles_info) + merged_access->merge(roles_info->access); + } + + static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW + | AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW + | AccessType::TRUNCATE; + static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY; + static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; + static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; + static const AccessFlags all_dcl = AccessType::CREATE_USER | AccessType::CREATE_ROLE | AccessType::CREATE_POLICY + | AccessType::CREATE_QUOTA | AccessType::ALTER_USER | AccessType::ALTER_POLICY | AccessType::ALTER_QUOTA | AccessType::DROP_USER + | AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA | AccessType::ROLE_ADMIN; + + if (readonly_) + merged_access->revoke(write_table_access | all_dcl | AccessType::SYSTEM | AccessType::KILL_QUERY); + + if (readonly_ || !allow_ddl_) + merged_access->revoke(table_and_dictionary_ddl); + + if (readonly_ == 1) + { + /// Table functions are forbidden in readonly mode. + /// For example, for readonly = 2 - allowed. + merged_access->revoke(AccessType::CREATE_TEMPORARY_TABLE | AccessType::TABLE_FUNCTIONS); + } + + if (!allow_introspection_) + merged_access->revoke(AccessType::INTROSPECTION); + + /// Anyone has access to the "system" database. + merged_access->grant(AccessType::SELECT, DatabaseCatalog::SYSTEM_DATABASE); + + if (readonly_ != 1) + { + /// User has access to temporary or external table if such table was resolved in session or query context + merged_access->grant(AccessFlags::allTableFlags() | AccessFlags::allColumnFlags(), DatabaseCatalog::TEMPORARY_DATABASE); + } + + if (readonly_ && grant_option) + { + /// No grant option in readonly mode. + merged_access->revoke(AccessType::ALL); + } + + if (trace_log && (params.readonly == readonly_) && (params.allow_ddl == allow_ddl_) && (params.allow_introspection == allow_introspection_)) + { + LOG_TRACE(trace_log, "List of all grants: " << merged_access->toString() << (grant_option ? " WITH GRANT OPTION" : "")); + if (roles_info && !roles_info->getCurrentRolesNames().empty()) + { + LOG_TRACE( + trace_log, + "Current_roles: " << boost::algorithm::join(roles_info->getCurrentRolesNames(), ", ") + << ", enabled_roles: " << boost::algorithm::join(roles_info->getEnabledRolesNames(), ", ")); + } + } + + res = std::move(merged_access); + result_access[cache_index].store(res); + return res; +} + + +UserPtr ContextAccess::getUser() const +{ + std::lock_guard lock{mutex}; + return user; +} + +String ContextAccess::getUserName() const +{ + std::lock_guard lock{mutex}; + return user_name; +} + +std::shared_ptr ContextAccess::getRolesInfo() const +{ + std::lock_guard lock{mutex}; + return roles_info; +} + +std::vector ContextAccess::getCurrentRoles() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->current_roles : std::vector{}; +} + +Strings ContextAccess::getCurrentRolesNames() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->getCurrentRolesNames() : Strings{}; +} + +std::vector ContextAccess::getEnabledRoles() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->enabled_roles : std::vector{}; +} + +Strings ContextAccess::getEnabledRolesNames() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->getEnabledRolesNames() : Strings{}; +} + +std::shared_ptr ContextAccess::getRowPolicies() const +{ + std::lock_guard lock{mutex}; + return enabled_row_policies; +} + +ASTPtr ContextAccess::getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType index, const ASTPtr & extra_condition) const +{ + std::lock_guard lock{mutex}; + return enabled_row_policies ? enabled_row_policies->getCondition(database, table_name, index, extra_condition) : nullptr; +} + +std::shared_ptr ContextAccess::getQuota() const +{ + std::lock_guard lock{mutex}; + return enabled_quota; +} + + +std::shared_ptr ContextAccess::getFullAccess() +{ + static const std::shared_ptr res = [] + { + auto full_access = std::shared_ptr(new ContextAccess); + auto everything_granted = boost::make_shared(); + everything_granted->grant(AccessType::ALL); + boost::range::fill(full_access->result_access, everything_granted); + full_access->enabled_quota = EnabledQuota::getUnlimitedQuota(); + return full_access; + }(); + return res; +} + + +std::shared_ptr ContextAccess::getDefaultSettings() const +{ + std::lock_guard lock{mutex}; + return enabled_settings->getSettings(); +} + + +std::shared_ptr ContextAccess::getSettingsConstraints() const +{ + std::lock_guard lock{mutex}; + return enabled_settings->getConstraints(); +} + +} diff --git a/dbms/src/Access/ContextAccess.h b/dbms/src/Access/ContextAccess.h new file mode 100644 index 00000000000..bee63103793 --- /dev/null +++ b/dbms/src/Access/ContextAccess.h @@ -0,0 +1,162 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace Poco { class Logger; } + +namespace DB +{ +struct User; +using UserPtr = std::shared_ptr; +struct EnabledRolesInfo; +class EnabledRoles; +class EnabledRowPolicies; +class EnabledQuota; +class EnabledSettings; +struct Settings; +class SettingsConstraints; +class AccessControlManager; +class IAST; +using ASTPtr = std::shared_ptr; + + +class ContextAccess +{ +public: + struct Params + { + std::optional user_id; + std::vector current_roles; + bool use_default_roles = false; + UInt64 readonly = 0; + bool allow_ddl = false; + bool allow_introspection = false; + String current_database; + ClientInfo::Interface interface = ClientInfo::Interface::TCP; + ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; + Poco::Net::IPAddress address; + String quota_key; + + auto toTuple() const { return std::tie(user_id, current_roles, use_default_roles, readonly, allow_ddl, allow_introspection, current_database, interface, http_method, address, quota_key); } + friend bool operator ==(const Params & lhs, const Params & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; + + const Params & getParams() const { return params; } + UserPtr getUser() const; + String getUserName() const; + + bool isCorrectPassword(const String & password) const; + bool isClientHostAllowed() const; + + std::shared_ptr getRolesInfo() const; + std::vector getCurrentRoles() const; + Strings getCurrentRolesNames() const; + std::vector getEnabledRoles() const; + Strings getEnabledRolesNames() const; + + std::shared_ptr getRowPolicies() const; + ASTPtr getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType index, const ASTPtr & extra_condition = nullptr) const; + std::shared_ptr getQuota() const; + std::shared_ptr getDefaultSettings() const; + std::shared_ptr getSettingsConstraints() const; + + /// Checks if a specified access is granted, and throws an exception if not. + /// Empty database means the current database. + void checkAccess(const AccessFlags & flags) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void checkAccess(const AccessRightsElement & element) const; + void checkAccess(const AccessRightsElements & elements) const; + + /// Checks if a specified access is granted. + bool isGranted(const AccessFlags & flags) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(const AccessRightsElement & element) const; + bool isGranted(const AccessRightsElements & elements) const; + + /// Checks if a specified access is granted, and logs a warning if not. + bool isGranted(Poco::Logger * log_, const AccessFlags & flags) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + bool isGranted(Poco::Logger * log_, const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + bool isGranted(Poco::Logger * log_, const AccessRightsElement & element) const; + bool isGranted(Poco::Logger * log_, const AccessRightsElements & elements) const; + + /// Checks if a specified access is granted with grant option, and throws an exception if not. + void checkGrantOption(const AccessFlags & flags) const; + void checkGrantOption(const AccessFlags & flags, const std::string_view & database) const; + void checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + void checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void checkGrantOption(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void checkGrantOption(const AccessRightsElement & element) const; + void checkGrantOption(const AccessRightsElements & elements) const; + + /// Checks if a specified role is granted with admin option, and throws an exception if not. + void checkAdminOption(const UUID & role_id) const; + + /// Returns an instance of ContextAccess which has full access to everything. + static std::shared_ptr getFullAccess(); + +private: + friend class AccessControlManager; + ContextAccess() {} + ContextAccess(const AccessControlManager & manager_, const Params & params_); + + void setUser(const UserPtr & user_) const; + void setRolesInfo(const std::shared_ptr & roles_info_) const; + void setSettingsAndConstraints() const; + + template + bool checkAccessImpl(Poco::Logger * log_, const AccessFlags & flags, const Args &... args) const; + + template + bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const; + + template + bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const; + + boost::shared_ptr calculateResultAccess(bool grant_option) const; + boost::shared_ptr calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const; + + const AccessControlManager * manager = nullptr; + const Params params; + mutable Poco::Logger * trace_log = nullptr; + mutable UserPtr user; + mutable String user_name; + mutable ext::scope_guard subscription_for_user_change; + mutable std::shared_ptr enabled_roles; + mutable ext::scope_guard subscription_for_roles_changes; + mutable std::shared_ptr roles_info; + mutable boost::atomic_shared_ptr> roles_with_admin_option; + mutable boost::atomic_shared_ptr result_access[7]; + mutable std::shared_ptr enabled_row_policies; + mutable std::shared_ptr enabled_quota; + mutable std::shared_ptr enabled_settings; + mutable std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/DiskAccessStorage.cpp b/dbms/src/Access/DiskAccessStorage.cpp index f5f42e1ff80..12c65e7df1e 100644 --- a/dbms/src/Access/DiskAccessStorage.cpp +++ b/dbms/src/Access/DiskAccessStorage.cpp @@ -8,15 +8,18 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -24,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -64,6 +68,8 @@ namespace return true; if (ParserCreateQuotaQuery{}.enableAttachMode(true).parse(pos, node, expected)) return true; + if (ParserCreateSettingsProfileQuery{}.enableAttachMode(true).parse(pos, node, expected)) + return true; if (ParserGrantQuery{}.enableAttachMode(true).parse(pos, node, expected)) return true; return false; @@ -97,6 +103,7 @@ namespace std::shared_ptr role; std::shared_ptr policy; std::shared_ptr quota; + std::shared_ptr profile; AccessEntityPtr res; for (const auto & query : queries) @@ -129,6 +136,13 @@ namespace res = quota = std::make_unique(); InterpreterCreateQuotaQuery::updateQuotaFromQuery(*quota, *create_quota_query); } + else if (auto create_profile_query = query->as()) + { + if (res) + throw Exception("Two access entities are attached in the same file " + file_path.string(), ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); + res = profile = std::make_unique(); + InterpreterCreateSettingsProfileQuery::updateSettingsProfileFromQuery(*profile, *create_profile_query); + } else if (auto grant_query = query->as()) { if (!user && !role) @@ -139,7 +153,7 @@ namespace InterpreterGrantQuery::updateRoleFromQuery(*role, *grant_query); } else - throw Exception("Two access entities are attached in the same file " + file_path.string(), ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); + throw Exception("No interpreter found for query " + query->getID(), ErrorCodes::INCORRECT_ACCESS_ENTITY_DEFINITION); } if (!res) @@ -149,6 +163,20 @@ namespace } + AccessEntityPtr tryReadAccessEntityFile(const std::filesystem::path & file_path, Poco::Logger & log) + { + try + { + return readAccessEntityFile(file_path); + } + catch (...) + { + tryLogCurrentException(&log, "Could not parse " + file_path.string()); + return nullptr; + } + } + + /// Writes ATTACH queries for building a specified access entity to a file. void writeAccessEntityFile(const std::filesystem::path & file_path, const IAccessEntity & entity) { @@ -238,6 +266,8 @@ namespace file_name = "quotas"; else if (type == typeid(RowPolicy)) file_name = "row_policies"; + else if (type == typeid(SettingsProfile)) + file_name = "settings_profiles"; else throw Exception("Unexpected type of access entity: " + IAccessEntity::getTypeName(type), ErrorCodes::LOGICAL_ERROR); @@ -254,13 +284,6 @@ namespace } - const std::vector & getAllAccessEntityTypes() - { - static const std::vector res = {typeid(User), typeid(Role), typeid(RowPolicy), typeid(Quota)}; - return res; - } - - bool tryParseUUID(const String & str, UUID & id) { try @@ -273,13 +296,20 @@ namespace return false; } } + + + const std::vector & getAllAccessEntityTypes() + { + static const std::vector res = {typeid(User), typeid(Role), typeid(RowPolicy), typeid(Quota), typeid(SettingsProfile)}; + return res; + } } DiskAccessStorage::DiskAccessStorage() : IAccessStorage("disk") { - for (const auto & type : getAllAccessEntityTypes()) + for (auto type : getAllAccessEntityTypes()) name_to_id_maps[type]; } @@ -340,10 +370,10 @@ void DiskAccessStorage::initialize(const String & directory_path_, Notifications bool DiskAccessStorage::readLists() { assert(id_to_entry_map.empty()); - assert(name_to_id_maps.size() == getAllAccessEntityTypes().size()); bool ok = true; - for (auto & [type, name_to_id_map] : name_to_id_maps) + for (auto type : getAllAccessEntityTypes()) { + auto & name_to_id_map = name_to_id_maps.at(type); auto file_path = getListFilePath(directory_path, type); if (!std::filesystem::exists(file_path)) { @@ -362,6 +392,7 @@ bool DiskAccessStorage::readLists() ok = false; break; } + for (const auto & [name, id] : name_to_id_map) id_to_entry_map.emplace(id, Entry{name, type}); } @@ -376,11 +407,14 @@ bool DiskAccessStorage::readLists() } -void DiskAccessStorage::writeLists() +bool DiskAccessStorage::writeLists() { - if (failed_to_write_lists || types_of_lists_to_write.empty()) - return; /// We don't try to write list files after the first fail. - /// The next restart of the server will invoke rebuilding of the list files. + if (failed_to_write_lists) + return false; /// We don't try to write list files after the first fail. + /// The next restart of the server will invoke rebuilding of the list files. + + if (types_of_lists_to_write.empty()) + return true; for (const auto & type : types_of_lists_to_write) { @@ -395,13 +429,14 @@ void DiskAccessStorage::writeLists() tryLogCurrentException(getLogger(), "Could not write " + file_path.string()); failed_to_write_lists = true; types_of_lists_to_write.clear(); - return; + return false; } } /// The list files was successfully written, we don't need the 'need_rebuild_lists.mark' file any longer. std::filesystem::remove(getNeedRebuildListsMarkFilePath(directory_path)); types_of_lists_to_write.clear(); + return true; } @@ -465,10 +500,11 @@ void DiskAccessStorage::listsWritingThreadFunc() /// Reads and parses all the ".sql" files from a specified directory /// and then saves the files "users.list", "roles.list", etc. to the same directory. -void DiskAccessStorage::rebuildLists() +bool DiskAccessStorage::rebuildLists() { LOG_WARNING(getLogger(), "Recovering lists in directory " + directory_path); assert(id_to_entry_map.empty()); + for (const auto & directory_entry : std::filesystem::directory_iterator(directory_path)) { if (!directory_entry.is_regular_file()) @@ -481,14 +517,21 @@ void DiskAccessStorage::rebuildLists() if (!tryParseUUID(path.stem(), id)) continue; - auto entity = readAccessEntityFile(getAccessEntityFilePath(directory_path, id)); + const auto access_entity_file_path = getAccessEntityFilePath(directory_path, id); + auto entity = tryReadAccessEntityFile(access_entity_file_path, *getLogger()); + if (!entity) + continue; + auto type = entity->getType(); - auto & name_to_id_map = name_to_id_maps[type]; + auto & name_to_id_map = name_to_id_maps.at(type); auto it_by_name = name_to_id_map.emplace(entity->getFullName(), id).first; id_to_entry_map.emplace(id, Entry{it_by_name->first, type}); } - boost::range::copy(getAllAccessEntityTypes(), std::inserter(types_of_lists_to_write, types_of_lists_to_write.end())); + for (auto type : getAllAccessEntityTypes()) + types_of_lists_to_write.insert(type); + + return true; } @@ -499,6 +542,7 @@ std::optional DiskAccessStorage::findImpl(std::type_index type, const Stri auto it = name_to_id_map.find(name); if (it == name_to_id_map.end()) return {}; + return it->second; } diff --git a/dbms/src/Access/DiskAccessStorage.h b/dbms/src/Access/DiskAccessStorage.h index 935cebfece9..104c0f1fa38 100644 --- a/dbms/src/Access/DiskAccessStorage.h +++ b/dbms/src/Access/DiskAccessStorage.h @@ -33,9 +33,9 @@ private: void initialize(const String & directory_path_, Notifications & notifications); bool readLists(); - void writeLists(); + bool writeLists(); void scheduleWriteLists(std::type_index type); - void rebuildLists(); + bool rebuildLists(); void startListsWritingThread(); void stopListsWritingThread(); diff --git a/dbms/src/Access/QuotaContext.cpp b/dbms/src/Access/EnabledQuota.cpp similarity index 77% rename from dbms/src/Access/QuotaContext.cpp rename to dbms/src/Access/EnabledQuota.cpp index a48c41dc419..92257ce0002 100644 --- a/dbms/src/Access/QuotaContext.cpp +++ b/dbms/src/Access/EnabledQuota.cpp @@ -1,4 +1,5 @@ -#include +#include +#include #include #include #include @@ -14,7 +15,7 @@ namespace ErrorCodes extern const int QUOTA_EXPIRED; } -struct QuotaContext::Impl +struct EnabledQuota::Impl { [[noreturn]] static void throwQuotaExceed( const String & user_name, @@ -133,7 +134,7 @@ struct QuotaContext::Impl }; -QuotaContext::Interval & QuotaContext::Interval::operator =(const Interval & src) +EnabledQuota::Interval & EnabledQuota::Interval::operator =(const Interval & src) { if (this == &src) return *this; @@ -150,7 +151,7 @@ QuotaContext::Interval & QuotaContext::Interval::operator =(const Interval & src } -QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock::time_point current_time) const +QuotaUsageInfo EnabledQuota::Intervals::getUsageInfo(std::chrono::system_clock::time_point current_time) const { QuotaUsageInfo info; info.quota_id = quota_id; @@ -174,97 +175,85 @@ QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock:: } -QuotaContext::QuotaContext() - : intervals(boost::make_shared()) /// Unlimited quota. +EnabledQuota::EnabledQuota(const Params & params_) : params(params_) { } - -QuotaContext::QuotaContext( - const String & user_name_, - const UUID & user_id_, - const std::vector & enabled_roles_, - const Poco::Net::IPAddress & address_, - const String & client_key_) - : user_name(user_name_), user_id(user_id_), enabled_roles(enabled_roles_), address(address_), client_key(client_key_) -{ -} +EnabledQuota::~EnabledQuota() = default; -QuotaContext::~QuotaContext() = default; - - -void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) const +void EnabledQuota::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) const { used({resource_type, amount}, check_exceeded); } -void QuotaContext::used(const std::pair & resource, bool check_exceeded) const +void EnabledQuota::used(const std::pair & resource, bool check_exceeded) const { auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource.first, resource.second, current_time, check_exceeded); } -void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded) const +void EnabledQuota::used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded) const { auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded); - Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource1.first, resource1.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource2.first, resource2.second, current_time, check_exceeded); } -void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded) const +void EnabledQuota::used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded) const { auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded); - Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded); - Impl::used(user_name, *loaded, resource3.first, resource3.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource1.first, resource1.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource2.first, resource2.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource3.first, resource3.second, current_time, check_exceeded); } -void QuotaContext::used(const std::vector> & resources, bool check_exceeded) const +void EnabledQuota::used(const std::vector> & resources, bool check_exceeded) const { auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); for (const auto & resource : resources) - Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded); + Impl::used(getUserName(), *loaded, resource.first, resource.second, current_time, check_exceeded); } -void QuotaContext::checkExceeded() const +void EnabledQuota::checkExceeded() const { auto loaded = intervals.load(); - Impl::checkExceeded(user_name, *loaded, std::chrono::system_clock::now()); + Impl::checkExceeded(getUserName(), *loaded, std::chrono::system_clock::now()); } -void QuotaContext::checkExceeded(ResourceType resource_type) const +void EnabledQuota::checkExceeded(ResourceType resource_type) const { auto loaded = intervals.load(); - Impl::checkExceeded(user_name, *loaded, resource_type, std::chrono::system_clock::now()); + Impl::checkExceeded(getUserName(), *loaded, resource_type, std::chrono::system_clock::now()); } -QuotaUsageInfo QuotaContext::getUsageInfo() const +QuotaUsageInfo EnabledQuota::getUsageInfo() const { auto loaded = intervals.load(); return loaded->getUsageInfo(std::chrono::system_clock::now()); } -QuotaUsageInfo::QuotaUsageInfo() : quota_id(UUID(UInt128(0))) +std::shared_ptr EnabledQuota::getUnlimitedQuota() { + static const std::shared_ptr res = [] + { + auto unlimited_quota = std::shared_ptr(new EnabledQuota); + unlimited_quota->intervals = boost::make_shared(); + return unlimited_quota; + }(); + return res; } - -QuotaUsageInfo::Interval::Interval() -{ - boost::range::fill(used, 0); - boost::range::fill(max, 0); -} } diff --git a/dbms/src/Access/QuotaContext.h b/dbms/src/Access/EnabledQuota.h similarity index 60% rename from dbms/src/Access/QuotaContext.h rename to dbms/src/Access/EnabledQuota.h index d788a08ea17..5a624c651af 100644 --- a/dbms/src/Access/QuotaContext.h +++ b/dbms/src/Access/EnabledQuota.h @@ -3,7 +3,6 @@ #include #include #include -#include #include #include #include @@ -16,17 +15,31 @@ namespace DB struct QuotaUsageInfo; -/// Instances of `QuotaContext` are used to track resource consumption. -class QuotaContext : public boost::noncopyable +/// Instances of `EnabledQuota` are used to track resource consumption. +class EnabledQuota : public boost::noncopyable { public: + struct Params + { + UUID user_id; + String user_name; + std::vector enabled_roles; + Poco::Net::IPAddress client_address; + String client_key; + + auto toTuple() const { return std::tie(user_id, enabled_roles, user_name, client_address, client_key); } + friend bool operator ==(const Params & lhs, const Params & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; + using ResourceType = Quota::ResourceType; using ResourceAmount = Quota::ResourceAmount; - /// Default constructors makes an unlimited quota. - QuotaContext(); - - ~QuotaContext(); + ~EnabledQuota(); /// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception. void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true) const; @@ -39,15 +52,18 @@ public: void checkExceeded() const; void checkExceeded(ResourceType resource_type) const; - /// Returns the information about this quota context. + /// Returns the information about quota consumption. QuotaUsageInfo getUsageInfo() const; -private: - friend class QuotaContextFactory; - friend struct ext::shared_ptr_helper; + /// Returns an instance of EnabledQuota which is never exceeded. + static std::shared_ptr getUnlimitedQuota(); - /// Instances of this class are created by QuotaContextFactory. - QuotaContext(const String & user_name_, const UUID & user_id_, const std::vector & enabled_roles_, const Poco::Net::IPAddress & address_, const String & client_key_); +private: + friend class QuotaCache; + EnabledQuota(const Params & params_); + EnabledQuota() {} + + const String & getUserName() const { return params.user_name; } static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; @@ -76,38 +92,8 @@ private: struct Impl; - const String user_name; - const UUID user_id; - const std::vector enabled_roles; - const Poco::Net::IPAddress address; - const String client_key; + const Params params; boost::atomic_shared_ptr intervals; /// atomically changed by QuotaUsageManager }; -using QuotaContextPtr = std::shared_ptr; - - -/// The information about a quota context. -struct QuotaUsageInfo -{ - using ResourceType = Quota::ResourceType; - using ResourceAmount = Quota::ResourceAmount; - static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; - - struct Interval - { - ResourceAmount used[MAX_RESOURCE_TYPE]; - ResourceAmount max[MAX_RESOURCE_TYPE]; - std::chrono::seconds duration = std::chrono::seconds::zero(); - bool randomize_interval = false; - std::chrono::system_clock::time_point end_of_interval; - Interval(); - }; - - std::vector intervals; - UUID quota_id; - String quota_name; - String quota_key; - QuotaUsageInfo(); -}; } diff --git a/dbms/src/Access/EnabledRoles.cpp b/dbms/src/Access/EnabledRoles.cpp new file mode 100644 index 00000000000..fd48eb6830a --- /dev/null +++ b/dbms/src/Access/EnabledRoles.cpp @@ -0,0 +1,51 @@ +#include +#include +#include +#include + + +namespace DB +{ +EnabledRoles::EnabledRoles(const Params & params_) : params(params_) +{ +} + +EnabledRoles::~EnabledRoles() = default; + + +std::shared_ptr EnabledRoles::getRolesInfo() const +{ + std::lock_guard lock{mutex}; + return info; +} + + +ext::scope_guard EnabledRoles::subscribeForChanges(const OnChangeHandler & handler) const +{ + std::lock_guard lock{mutex}; + handlers.push_back(handler); + auto it = std::prev(handlers.end()); + + return [this, it] + { + std::lock_guard lock2{mutex}; + handlers.erase(it); + }; +} + + +void EnabledRoles::setRolesInfo(const std::shared_ptr & info_) +{ + std::vector handlers_to_notify; + SCOPE_EXIT({ for (const auto & handler : handlers_to_notify) handler(info_); }); + + std::lock_guard lock{mutex}; + + if (info && info_ && *info == *info_) + return; + + info = info_; + boost::range::copy(handlers, std::back_inserter(handlers_to_notify)); +} + +} diff --git a/dbms/src/Access/EnabledRoles.h b/dbms/src/Access/EnabledRoles.h new file mode 100644 index 00000000000..122b1a16fe3 --- /dev/null +++ b/dbms/src/Access/EnabledRoles.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include +#include +#include + + +namespace DB +{ +struct EnabledRolesInfo; + +class EnabledRoles +{ +public: + struct Params + { + std::vector current_roles; + std::vector current_roles_with_admin_option; + + auto toTuple() const { return std::tie(current_roles, current_roles_with_admin_option); } + friend bool operator ==(const Params & lhs, const Params & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; + + ~EnabledRoles(); + + /// Returns all the roles specified in the constructor. + std::shared_ptr getRolesInfo() const; + + using OnChangeHandler = std::function & info)>; + + /// Called when either the specified roles or the roles granted to the specified roles are changed. + ext::scope_guard subscribeForChanges(const OnChangeHandler & handler) const; + +private: + friend class RoleCache; + EnabledRoles(const Params & params_); + + void setRolesInfo(const std::shared_ptr & info_); + + const Params params; + mutable std::shared_ptr info; + mutable std::list handlers; + mutable std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/CurrentRolesInfo.cpp b/dbms/src/Access/EnabledRolesInfo.cpp similarity index 67% rename from dbms/src/Access/CurrentRolesInfo.cpp rename to dbms/src/Access/EnabledRolesInfo.cpp index f4cbd739021..01b90d6fa1e 100644 --- a/dbms/src/Access/CurrentRolesInfo.cpp +++ b/dbms/src/Access/EnabledRolesInfo.cpp @@ -1,10 +1,10 @@ -#include +#include namespace DB { -Strings CurrentRolesInfo::getCurrentRolesNames() const +Strings EnabledRolesInfo::getCurrentRolesNames() const { Strings result; result.reserve(current_roles.size()); @@ -14,7 +14,7 @@ Strings CurrentRolesInfo::getCurrentRolesNames() const } -Strings CurrentRolesInfo::getEnabledRolesNames() const +Strings EnabledRolesInfo::getEnabledRolesNames() const { Strings result; result.reserve(enabled_roles.size()); @@ -24,11 +24,12 @@ Strings CurrentRolesInfo::getEnabledRolesNames() const } -bool operator==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs) +bool operator==(const EnabledRolesInfo & lhs, const EnabledRolesInfo & rhs) { return (lhs.current_roles == rhs.current_roles) && (lhs.enabled_roles == rhs.enabled_roles) && (lhs.enabled_roles_with_admin_option == rhs.enabled_roles_with_admin_option) && (lhs.names_of_roles == rhs.names_of_roles) - && (lhs.access == rhs.access) && (lhs.access_with_grant_option == rhs.access_with_grant_option); + && (lhs.access == rhs.access) && (lhs.access_with_grant_option == rhs.access_with_grant_option) + && (lhs.settings_from_enabled_roles == rhs.settings_from_enabled_roles); } } diff --git a/dbms/src/Access/CurrentRolesInfo.h b/dbms/src/Access/EnabledRolesInfo.h similarity index 60% rename from dbms/src/Access/CurrentRolesInfo.h rename to dbms/src/Access/EnabledRolesInfo.h index a4dd26be0f7..837d4b74ad5 100644 --- a/dbms/src/Access/CurrentRolesInfo.h +++ b/dbms/src/Access/EnabledRolesInfo.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -10,7 +11,7 @@ namespace DB { /// Information about a role. -struct CurrentRolesInfo +struct EnabledRolesInfo { std::vector current_roles; std::vector enabled_roles; @@ -18,14 +19,13 @@ struct CurrentRolesInfo std::unordered_map names_of_roles; AccessRights access; AccessRights access_with_grant_option; + SettingsProfileElements settings_from_enabled_roles; Strings getCurrentRolesNames() const; Strings getEnabledRolesNames() const; - friend bool operator ==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs); - friend bool operator !=(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs) { return !(lhs == rhs); } + friend bool operator ==(const EnabledRolesInfo & lhs, const EnabledRolesInfo & rhs); + friend bool operator !=(const EnabledRolesInfo & lhs, const EnabledRolesInfo & rhs) { return !(lhs == rhs); } }; -using CurrentRolesInfoPtr = std::shared_ptr; - } diff --git a/dbms/src/Access/RowPolicyContext.cpp b/dbms/src/Access/EnabledRowPolicies.cpp similarity index 55% rename from dbms/src/Access/RowPolicyContext.cpp rename to dbms/src/Access/EnabledRowPolicies.cpp index 661a6cb4b5f..a525fb65606 100644 --- a/dbms/src/Access/RowPolicyContext.cpp +++ b/dbms/src/Access/EnabledRowPolicies.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include #include @@ -8,55 +8,50 @@ namespace DB { -size_t RowPolicyContext::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const +size_t EnabledRowPolicies::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const { return std::hash{}(database_and_table_name.first) - std::hash{}(database_and_table_name.second); } -RowPolicyContext::RowPolicyContext() - : map_of_mixed_conditions(boost::make_shared()) +EnabledRowPolicies::EnabledRowPolicies(const Params & params_) + : params(params_) { } - -RowPolicyContext::~RowPolicyContext() = default; +EnabledRowPolicies::~EnabledRowPolicies() = default; -RowPolicyContext::RowPolicyContext(const UUID & user_id_, const std::vector & enabled_roles_) - : user_id(user_id_), enabled_roles(enabled_roles_) -{} - - -ASTPtr RowPolicyContext::getCondition(const String & database, const String & table_name, ConditionIndex index) const +ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType type) const { /// We don't lock `mutex` here. auto loaded = map_of_mixed_conditions.load(); auto it = loaded->find({database, table_name}); if (it == loaded->end()) return {}; - return it->second.mixed_conditions[index]; + return it->second.mixed_conditions[type]; } -ASTPtr RowPolicyContext::combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs) +ASTPtr EnabledRowPolicies::getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const { - if (!lhs) - return rhs; - if (!rhs) - return lhs; + ASTPtr main_condition = getCondition(database, table_name, type); + if (!main_condition) + return extra_condition; + if (!extra_condition) + return main_condition; auto function = std::make_shared(); auto exp_list = std::make_shared(); function->name = "and"; function->arguments = exp_list; function->children.push_back(exp_list); - exp_list->children.push_back(lhs); - exp_list->children.push_back(rhs); + exp_list->children.push_back(main_condition); + exp_list->children.push_back(extra_condition); return function; } -std::vector RowPolicyContext::getCurrentPolicyIDs() const +std::vector EnabledRowPolicies::getCurrentPolicyIDs() const { /// We don't lock `mutex` here. auto loaded = map_of_mixed_conditions.load(); @@ -67,7 +62,7 @@ std::vector RowPolicyContext::getCurrentPolicyIDs() const } -std::vector RowPolicyContext::getCurrentPolicyIDs(const String & database, const String & table_name) const +std::vector EnabledRowPolicies::getCurrentPolicyIDs(const String & database, const String & table_name) const { /// We don't lock `mutex` here. auto loaded = map_of_mixed_conditions.load(); @@ -76,4 +71,5 @@ std::vector RowPolicyContext::getCurrentPolicyIDs(const String & database, return {}; return it->second.policy_ids; } + } diff --git a/dbms/src/Access/RowPolicyContext.h b/dbms/src/Access/EnabledRowPolicies.h similarity index 55% rename from dbms/src/Access/RowPolicyContext.h rename to dbms/src/Access/EnabledRowPolicies.h index 2042b85bf7a..9befb65ff0b 100644 --- a/dbms/src/Access/RowPolicyContext.h +++ b/dbms/src/Access/EnabledRowPolicies.h @@ -15,23 +15,32 @@ using ASTPtr = std::shared_ptr; /// Provides fast access to row policies' conditions for a specific user and tables. -class RowPolicyContext +class EnabledRowPolicies { public: - /// Default constructor makes a row policy usage context which restricts nothing. - RowPolicyContext(); + struct Params + { + UUID user_id; + std::vector enabled_roles; - ~RowPolicyContext(); + auto toTuple() const { return std::tie(user_id, enabled_roles); } + friend bool operator ==(const Params & lhs, const Params & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; - using ConditionIndex = RowPolicy::ConditionIndex; + ~EnabledRowPolicies(); + + using ConditionType = RowPolicy::ConditionType; /// Returns prepared filter for a specific table and operations. /// The function can return nullptr, that means there is no filters applied. /// The returned filter can be a combination of the filters defined by multiple row policies. - ASTPtr getCondition(const String & database, const String & table_name, ConditionIndex index) const; - - /// Combines two conditions into one by using the logical AND operator. - static ASTPtr combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs); + ASTPtr getCondition(const String & database, const String & table_name, ConditionType type) const; + ASTPtr getCondition(const String & database, const String & table_name, ConditionType type, const ASTPtr & extra_condition) const; /// Returns IDs of all the policies used by the current user. std::vector getCurrentPolicyIDs() const; @@ -40,9 +49,8 @@ public: std::vector getCurrentPolicyIDs(const String & database, const String & table_name) const; private: - friend class RowPolicyContextFactory; - friend struct ext::shared_ptr_helper; - RowPolicyContext(const UUID & user_id_, const std::vector & enabled_roles_); /// RowPolicyContext should be created by RowPolicyContextFactory. + friend class RowPolicyCache; + EnabledRowPolicies(const Params & params_); using DatabaseAndTableName = std::pair; using DatabaseAndTableNameRef = std::pair; @@ -50,8 +58,8 @@ private: { size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const; }; - static constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX; - using ParsedConditions = std::array; + static constexpr size_t MAX_CONDITION_TYPE = RowPolicy::MAX_CONDITION_TYPE; + using ParsedConditions = std::array; struct MixedConditions { std::unique_ptr database_and_table_name_keeper; @@ -60,11 +68,8 @@ private: }; using MapOfMixedConditions = std::unordered_map; - const UUID user_id; - const std::vector enabled_roles; + const Params params; mutable boost::atomic_shared_ptr map_of_mixed_conditions; }; - -using RowPolicyContextPtr = std::shared_ptr; } diff --git a/dbms/src/Access/EnabledSettings.cpp b/dbms/src/Access/EnabledSettings.cpp new file mode 100644 index 00000000000..65e38e4827f --- /dev/null +++ b/dbms/src/Access/EnabledSettings.cpp @@ -0,0 +1,36 @@ +#include + + +namespace DB +{ + +EnabledSettings::EnabledSettings(const Params & params_) : params(params_) +{ +} + +EnabledSettings::~EnabledSettings() = default; + + +std::shared_ptr EnabledSettings::getSettings() const +{ + std::lock_guard lock{mutex}; + return settings; +} + + +std::shared_ptr EnabledSettings::getConstraints() const +{ + std::lock_guard lock{mutex}; + return constraints; +} + + +void EnabledSettings::setSettingsAndConstraints( + const std::shared_ptr & settings_, const std::shared_ptr & constraints_) +{ + std::lock_guard lock{mutex}; + settings = settings_; + constraints = constraints_; +} + +} diff --git a/dbms/src/Access/EnabledSettings.h b/dbms/src/Access/EnabledSettings.h new file mode 100644 index 00000000000..d8e969d685d --- /dev/null +++ b/dbms/src/Access/EnabledSettings.h @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +/// Watches settings profiles for a specific user and roles. +class EnabledSettings +{ +public: + struct Params + { + UUID user_id; + std::vector enabled_roles; + SettingsProfileElements settings_from_enabled_roles; + SettingsProfileElements settings_from_user; + + auto toTuple() const { return std::tie(user_id, enabled_roles, settings_from_enabled_roles, settings_from_user); } + friend bool operator ==(const Params & lhs, const Params & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; + + ~EnabledSettings(); + + /// Returns the default settings come from settings profiles defined for the user + /// and the roles passed in the constructor. + std::shared_ptr getSettings() const; + + /// Returns the constraints come from settings profiles defined for the user + /// and the roles passed in the constructor. + std::shared_ptr getConstraints() const; + +private: + friend class SettingsProfilesCache; + EnabledSettings(const Params & params_); + + void setSettingsAndConstraints( + const std::shared_ptr & settings_, const std::shared_ptr & constraints_); + + const Params params; + SettingsProfileElements settings_from_enabled; + std::shared_ptr settings; + std::shared_ptr constraints; + mutable std::mutex mutex; +}; +} diff --git a/dbms/src/Access/GenericRoleSet.cpp b/dbms/src/Access/ExtendedRoleSet.cpp similarity index 65% rename from dbms/src/Access/GenericRoleSet.cpp rename to dbms/src/Access/ExtendedRoleSet.cpp index 1e751f995c1..b59dc7ac232 100644 --- a/dbms/src/Access/GenericRoleSet.cpp +++ b/dbms/src/Access/ExtendedRoleSet.cpp @@ -1,8 +1,8 @@ -#include +#include #include #include #include -#include +#include #include #include #include @@ -17,57 +17,59 @@ namespace ErrorCodes { extern const int LOGICAL_ERROR; } -GenericRoleSet::GenericRoleSet() = default; -GenericRoleSet::GenericRoleSet(const GenericRoleSet & src) = default; -GenericRoleSet & GenericRoleSet::operator =(const GenericRoleSet & src) = default; -GenericRoleSet::GenericRoleSet(GenericRoleSet && src) = default; -GenericRoleSet & GenericRoleSet::operator =(GenericRoleSet && src) = default; -GenericRoleSet::GenericRoleSet(AllTag) +ExtendedRoleSet::ExtendedRoleSet() = default; +ExtendedRoleSet::ExtendedRoleSet(const ExtendedRoleSet & src) = default; +ExtendedRoleSet & ExtendedRoleSet::operator =(const ExtendedRoleSet & src) = default; +ExtendedRoleSet::ExtendedRoleSet(ExtendedRoleSet && src) = default; +ExtendedRoleSet & ExtendedRoleSet::operator =(ExtendedRoleSet && src) = default; + + +ExtendedRoleSet::ExtendedRoleSet(AllTag) { all = true; } -GenericRoleSet::GenericRoleSet(const UUID & id) +ExtendedRoleSet::ExtendedRoleSet(const UUID & id) { add(id); } -GenericRoleSet::GenericRoleSet(const std::vector & ids_) +ExtendedRoleSet::ExtendedRoleSet(const std::vector & ids_) { add(ids_); } -GenericRoleSet::GenericRoleSet(const boost::container::flat_set & ids_) +ExtendedRoleSet::ExtendedRoleSet(const boost::container::flat_set & ids_) { add(ids_); } -GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast) +ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast) { init(ast, nullptr, nullptr); } -GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const UUID & current_user_id) +ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const UUID & current_user_id) { init(ast, nullptr, ¤t_user_id); } -GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager) +ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager) { init(ast, &manager, nullptr); } -GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const UUID & current_user_id) +ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager, const UUID & current_user_id) { init(ast, &manager, ¤t_user_id); } -void GenericRoleSet::init(const ASTGenericRoleSet & ast, const AccessControlManager * manager, const UUID * current_user_id) +void ExtendedRoleSet::init(const ASTExtendedRoleSet & ast, const AccessControlManager * manager, const UUID * current_user_id) { all = ast.all; @@ -113,9 +115,9 @@ void GenericRoleSet::init(const ASTGenericRoleSet & ast, const AccessControlMana } -std::shared_ptr GenericRoleSet::toAST() const +std::shared_ptr ExtendedRoleSet::toAST() const { - auto ast = std::make_shared(); + auto ast = std::make_shared(); ast->id_mode = true; ast->all = all; @@ -137,14 +139,14 @@ std::shared_ptr GenericRoleSet::toAST() const } -String GenericRoleSet::toString() const +String ExtendedRoleSet::toString() const { auto ast = toAST(); return serializeAST(*ast); } -Strings GenericRoleSet::toStrings() const +Strings ExtendedRoleSet::toStrings() const { if (all || !except_ids.empty()) return {toString()}; @@ -157,9 +159,9 @@ Strings GenericRoleSet::toStrings() const } -std::shared_ptr GenericRoleSet::toASTWithNames(const AccessControlManager & manager) const +std::shared_ptr ExtendedRoleSet::toASTWithNames(const AccessControlManager & manager) const { - auto ast = std::make_shared(); + auto ast = std::make_shared(); ast->all = all; if (!ids.empty()) @@ -190,14 +192,14 @@ std::shared_ptr GenericRoleSet::toASTWithNames(const AccessCo } -String GenericRoleSet::toStringWithNames(const AccessControlManager & manager) const +String ExtendedRoleSet::toStringWithNames(const AccessControlManager & manager) const { auto ast = toASTWithNames(manager); return serializeAST(*ast); } -Strings GenericRoleSet::toStringsWithNames(const AccessControlManager & manager) const +Strings ExtendedRoleSet::toStringsWithNames(const AccessControlManager & manager) const { if (all || !except_ids.empty()) return {toStringWithNames(manager)}; @@ -215,13 +217,13 @@ Strings GenericRoleSet::toStringsWithNames(const AccessControlManager & manager) } -bool GenericRoleSet::empty() const +bool ExtendedRoleSet::empty() const { return ids.empty() && !all; } -void GenericRoleSet::clear() +void ExtendedRoleSet::clear() { ids.clear(); all = false; @@ -229,33 +231,33 @@ void GenericRoleSet::clear() } -void GenericRoleSet::add(const UUID & id) +void ExtendedRoleSet::add(const UUID & id) { ids.insert(id); } -void GenericRoleSet::add(const std::vector & ids_) +void ExtendedRoleSet::add(const std::vector & ids_) { for (const auto & id : ids_) add(id); } -void GenericRoleSet::add(const boost::container::flat_set & ids_) +void ExtendedRoleSet::add(const boost::container::flat_set & ids_) { for (const auto & id : ids_) add(id); } -bool GenericRoleSet::match(const UUID & id) const +bool ExtendedRoleSet::match(const UUID & id) const { return (all || ids.contains(id)) && !except_ids.contains(id); } -bool GenericRoleSet::match(const UUID & user_id, const std::vector & enabled_roles) const +bool ExtendedRoleSet::match(const UUID & user_id, const std::vector & enabled_roles) const { if (!all && !ids.contains(user_id)) { @@ -274,7 +276,7 @@ bool GenericRoleSet::match(const UUID & user_id, const std::vector & enabl } -bool GenericRoleSet::match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const +bool ExtendedRoleSet::match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const { if (!all && !ids.contains(user_id)) { @@ -293,17 +295,17 @@ bool GenericRoleSet::match(const UUID & user_id, const boost::container::flat_se } -std::vector GenericRoleSet::getMatchingIDs() const +std::vector ExtendedRoleSet::getMatchingIDs() const { if (all) - throw Exception("getAllMatchingIDs() can't get ALL ids", ErrorCodes::LOGICAL_ERROR); + throw Exception("getAllMatchingIDs() can't get ALL ids without manager", ErrorCodes::LOGICAL_ERROR); std::vector res; boost::range::set_difference(ids, except_ids, std::back_inserter(res)); return res; } -std::vector GenericRoleSet::getMatchingUsers(const AccessControlManager & manager) const +std::vector ExtendedRoleSet::getMatchingIDs(const AccessControlManager & manager) const { if (!all) return getMatchingIDs(); @@ -314,37 +316,17 @@ std::vector GenericRoleSet::getMatchingUsers(const AccessControlManager & if (match(id)) res.push_back(id); } - return res; -} - - -std::vector GenericRoleSet::getMatchingRoles(const AccessControlManager & manager) const -{ - if (!all) - return getMatchingIDs(); - - std::vector res; for (const UUID & id : manager.findAll()) { if (match(id)) res.push_back(id); } + return res; } -std::vector GenericRoleSet::getMatchingUsersAndRoles(const AccessControlManager & manager) const -{ - if (!all) - return getMatchingIDs(); - - std::vector vec = getMatchingUsers(manager); - boost::range::push_back(vec, getMatchingRoles(manager)); - return vec; -} - - -bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs) +bool operator ==(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs) { return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids); } diff --git a/dbms/src/Access/ExtendedRoleSet.h b/dbms/src/Access/ExtendedRoleSet.h new file mode 100644 index 00000000000..61a4db6e0ae --- /dev/null +++ b/dbms/src/Access/ExtendedRoleSet.h @@ -0,0 +1,75 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +class ASTExtendedRoleSet; +class AccessControlManager; + + +/// Represents a set of users/roles like +/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] +/// Similar to ASTExtendedRoleSet, but with IDs instead of names. +struct ExtendedRoleSet +{ + ExtendedRoleSet(); + ExtendedRoleSet(const ExtendedRoleSet & src); + ExtendedRoleSet & operator =(const ExtendedRoleSet & src); + ExtendedRoleSet(ExtendedRoleSet && src); + ExtendedRoleSet & operator =(ExtendedRoleSet && src); + + struct AllTag {}; + ExtendedRoleSet(AllTag); + + ExtendedRoleSet(const UUID & id); + ExtendedRoleSet(const std::vector & ids_); + ExtendedRoleSet(const boost::container::flat_set & ids_); + + /// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`. + ExtendedRoleSet(const ASTExtendedRoleSet & ast); + ExtendedRoleSet(const ASTExtendedRoleSet & ast, const UUID & current_user_id); + ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager); + ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager, const UUID & current_user_id); + + std::shared_ptr toAST() const; + String toString() const; + Strings toStrings() const; + + std::shared_ptr toASTWithNames(const AccessControlManager & manager) const; + String toStringWithNames(const AccessControlManager & manager) const; + Strings toStringsWithNames(const AccessControlManager & manager) const; + + bool empty() const; + void clear(); + void add(const UUID & id); + void add(const std::vector & ids_); + void add(const boost::container::flat_set & ids_); + + /// Checks if a specified ID matches this ExtendedRoleSet. + bool match(const UUID & id) const; + bool match(const UUID & user_id, const std::vector & enabled_roles) const; + bool match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const; + + /// Returns a list of matching IDs. The function must not be called if `all` == `true`. + std::vector getMatchingIDs() const; + + /// Returns a list of matching users and roles. + std::vector getMatchingIDs(const AccessControlManager & manager) const; + + friend bool operator ==(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs); + friend bool operator !=(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs) { return !(lhs == rhs); } + + boost::container::flat_set ids; + bool all = false; + boost::container::flat_set except_ids; + +private: + void init(const ASTExtendedRoleSet & ast, const AccessControlManager * manager = nullptr, const UUID * current_user_id = nullptr); +}; + +} diff --git a/dbms/src/Access/GenericRoleSet.h b/dbms/src/Access/GenericRoleSet.h deleted file mode 100644 index e276eb4066a..00000000000 --- a/dbms/src/Access/GenericRoleSet.h +++ /dev/null @@ -1,77 +0,0 @@ -#pragma once - -#include -#include -#include -#include - - -namespace DB -{ -class ASTGenericRoleSet; -class AccessControlManager; - - -/// Represents a set of users/roles like -/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] -/// Similar to ASTGenericRoleSet, but with IDs instead of names. -struct GenericRoleSet -{ - GenericRoleSet(); - GenericRoleSet(const GenericRoleSet & src); - GenericRoleSet & operator =(const GenericRoleSet & src); - GenericRoleSet(GenericRoleSet && src); - GenericRoleSet & operator =(GenericRoleSet && src); - - struct AllTag {}; - GenericRoleSet(AllTag); - - GenericRoleSet(const UUID & id); - GenericRoleSet(const std::vector & ids_); - GenericRoleSet(const boost::container::flat_set & ids_); - - /// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`. - GenericRoleSet(const ASTGenericRoleSet & ast); - GenericRoleSet(const ASTGenericRoleSet & ast, const UUID & current_user_id); - GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager); - GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const UUID & current_user_id); - - std::shared_ptr toAST() const; - String toString() const; - Strings toStrings() const; - - std::shared_ptr toASTWithNames(const AccessControlManager & manager) const; - String toStringWithNames(const AccessControlManager & manager) const; - Strings toStringsWithNames(const AccessControlManager & manager) const; - - bool empty() const; - void clear(); - void add(const UUID & id); - void add(const std::vector & ids_); - void add(const boost::container::flat_set & ids_); - - /// Checks if a specified ID matches this GenericRoleSet. - bool match(const UUID & id) const; - bool match(const UUID & user_id, const std::vector & enabled_roles) const; - bool match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const; - - /// Returns a list of matching IDs. The function must not be called if `all` == `true`. - std::vector getMatchingIDs() const; - - /// Returns a list of matching users. - std::vector getMatchingUsers(const AccessControlManager & manager) const; - std::vector getMatchingRoles(const AccessControlManager & manager) const; - std::vector getMatchingUsersAndRoles(const AccessControlManager & manager) const; - - friend bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs); - friend bool operator !=(const GenericRoleSet & lhs, const GenericRoleSet & rhs) { return !(lhs == rhs); } - - boost::container::flat_set ids; - bool all = false; - boost::container::flat_set except_ids; - -private: - void init(const ASTGenericRoleSet & ast, const AccessControlManager * manager = nullptr, const UUID * current_user_id = nullptr); -}; - -} diff --git a/dbms/src/Access/IAccessEntity.cpp b/dbms/src/Access/IAccessEntity.cpp index 361946863b2..5dbc056b71c 100644 --- a/dbms/src/Access/IAccessEntity.cpp +++ b/dbms/src/Access/IAccessEntity.cpp @@ -2,6 +2,8 @@ #include #include #include +#include +#include #include @@ -15,9 +17,30 @@ String IAccessEntity::getTypeName(std::type_index type) return "Quota"; if (type == typeid(RowPolicy)) return "Row policy"; + if (type == typeid(Role)) + return "Role"; + if (type == typeid(SettingsProfile)) + return "Settings profile"; return demangle(type.name()); } + +const char * IAccessEntity::getKeyword(std::type_index type) +{ + if (type == typeid(User)) + return "USER"; + if (type == typeid(Quota)) + return "QUOTA"; + if (type == typeid(RowPolicy)) + return "ROW POLICY"; + if (type == typeid(Role)) + return "ROLE"; + if (type == typeid(SettingsProfile)) + return "SETTINGS PROFILE"; + __builtin_unreachable(); +} + + bool IAccessEntity::equal(const IAccessEntity & other) const { return (full_name == other.full_name) && (getType() == other.getType()); diff --git a/dbms/src/Access/IAccessEntity.h b/dbms/src/Access/IAccessEntity.h index 272fde006ac..9214d64aa8c 100644 --- a/dbms/src/Access/IAccessEntity.h +++ b/dbms/src/Access/IAccessEntity.h @@ -20,6 +20,8 @@ struct IAccessEntity std::type_index getType() const { return typeid(*this); } static String getTypeName(std::type_index type); const String getTypeName() const { return getTypeName(getType()); } + static const char * getKeyword(std::type_index type); + const char * getKeyword() const { return getKeyword(getType()); } template bool isTypeOf() const { return isTypeOf(typeid(EntityType)); } diff --git a/dbms/src/Access/Quota.cpp b/dbms/src/Access/Quota.cpp index d9e9e0b35fc..e3a9e11eb10 100644 --- a/dbms/src/Access/Quota.cpp +++ b/dbms/src/Access/Quota.cpp @@ -23,7 +23,7 @@ bool Quota::equal(const IAccessEntity & other) const if (!IAccessEntity::equal(other)) return false; const auto & other_quota = typeid_cast(other); - return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles); + return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (to_roles == other_quota.to_roles); } diff --git a/dbms/src/Access/Quota.h b/dbms/src/Access/Quota.h index 4968e5d92c9..714d582e98f 100644 --- a/dbms/src/Access/Quota.h +++ b/dbms/src/Access/Quota.h @@ -1,7 +1,7 @@ -#pragma once +#pragma once #include -#include +#include #include @@ -63,7 +63,7 @@ struct Quota : public IAccessEntity KeyType key_type = KeyType::NONE; /// Which roles or users should use this quota. - GenericRoleSet roles; + ExtendedRoleSet to_roles; bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } diff --git a/dbms/src/Access/QuotaContextFactory.cpp b/dbms/src/Access/QuotaCache.cpp similarity index 67% rename from dbms/src/Access/QuotaContextFactory.cpp rename to dbms/src/Access/QuotaCache.cpp index 7c585bdddee..cdf298d0e57 100644 --- a/dbms/src/Access/QuotaContextFactory.cpp +++ b/dbms/src/Access/QuotaCache.cpp @@ -1,5 +1,6 @@ -#include -#include +#include +#include +#include #include #include #include @@ -8,7 +9,6 @@ #include #include #include -#include #include @@ -31,58 +31,53 @@ namespace } -void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUID & quota_id_) +void QuotaCache::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUID & quota_id_) { quota = quota_; quota_id = quota_id_; - roles = "a->roles; + roles = "a->to_roles; rebuildAllIntervals(); } -bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const -{ - return roles->match(context.user_id, context.enabled_roles); -} - - -String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context) const +String QuotaCache::QuotaInfo::calculateKey(const EnabledQuota & enabled) const { + const auto & params = enabled.params; using KeyType = Quota::KeyType; switch (quota->key_type) { case KeyType::NONE: return ""; case KeyType::USER_NAME: - return context.user_name; + return params.user_name; case KeyType::IP_ADDRESS: - return context.address.toString(); + return params.client_address.toString(); case KeyType::CLIENT_KEY: { - if (!context.client_key.empty()) - return context.client_key; + if (!params.client_key.empty()) + return params.client_key; throw Exception( - "Quota " + quota->getName() + " (for user " + context.user_name + ") requires a client supplied key.", + "Quota " + quota->getName() + " (for user " + params.user_name + ") requires a client supplied key.", ErrorCodes::QUOTA_REQUIRES_CLIENT_KEY); } case KeyType::CLIENT_KEY_OR_USER_NAME: { - if (!context.client_key.empty()) - return context.client_key; - return context.user_name; + if (!params.client_key.empty()) + return params.client_key; + return params.user_name; } case KeyType::CLIENT_KEY_OR_IP_ADDRESS: { - if (!context.client_key.empty()) - return context.client_key; - return context.address.toString(); + if (!params.client_key.empty()) + return params.client_key; + return params.client_address.toString(); } } __builtin_unreachable(); } -boost::shared_ptr QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key) +boost::shared_ptr QuotaCache::QuotaInfo::getOrBuildIntervals(const String & key) { auto it = key_to_intervals.find(key); if (it != key_to_intervals.end()) @@ -91,14 +86,14 @@ boost::shared_ptr QuotaContextFactory::QuotaInfo: } -void QuotaContextFactory::QuotaInfo::rebuildAllIntervals() +void QuotaCache::QuotaInfo::rebuildAllIntervals() { for (const String & key : key_to_intervals | boost::adaptors::map_keys) rebuildIntervals(key); } -boost::shared_ptr QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key) +boost::shared_ptr QuotaCache::QuotaInfo::rebuildIntervals(const String & key) { auto new_intervals = boost::make_shared(); new_intervals->quota_name = quota->getName(); @@ -164,27 +159,42 @@ boost::shared_ptr QuotaContextFactory::QuotaInfo: } -QuotaContextFactory::QuotaContextFactory(const AccessControlManager & access_control_manager_) +QuotaCache::QuotaCache(const AccessControlManager & access_control_manager_) : access_control_manager(access_control_manager_) { } - -QuotaContextFactory::~QuotaContextFactory() = default; +QuotaCache::~QuotaCache() = default; -QuotaContextPtr QuotaContextFactory::createContext(const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key) +std::shared_ptr QuotaCache::getEnabledQuota(const UUID & user_id, const String & user_name, const std::vector & enabled_roles, const Poco::Net::IPAddress & client_address, const String & client_key) { std::lock_guard lock{mutex}; ensureAllQuotasRead(); - auto context = ext::shared_ptr_helper::create(user_name, user_id, enabled_roles, address, client_key); - contexts.push_back(context); - chooseQuotaForContext(context); - return context; + + EnabledQuota::Params params; + params.user_id = user_id; + params.user_name = user_name; + params.enabled_roles = enabled_roles; + params.client_address = client_address; + params.client_key = client_key; + auto it = enabled_quotas.find(params); + if (it != enabled_quotas.end()) + { + auto from_cache = it->second.lock(); + if (from_cache) + return from_cache; + enabled_quotas.erase(it); + } + + auto res = std::shared_ptr(new EnabledQuota(params)); + enabled_quotas.emplace(std::move(params), res); + chooseQuotaToConsumeFor(*res); + return res; } -void QuotaContextFactory::ensureAllQuotasRead() +void QuotaCache::ensureAllQuotasRead() { /// `mutex` is already locked. if (all_quotas_read) @@ -209,7 +219,7 @@ void QuotaContextFactory::ensureAllQuotasRead() } -void QuotaContextFactory::quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr & new_quota) +void QuotaCache::quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr & new_quota) { std::lock_guard lock{mutex}; auto it = all_quotas.find(quota_id); @@ -225,42 +235,42 @@ void QuotaContextFactory::quotaAddedOrChanged(const UUID & quota_id, const std:: auto & info = it->second; info.setQuota(new_quota, quota_id); - chooseQuotaForAllContexts(); + chooseQuotaToConsume(); } -void QuotaContextFactory::quotaRemoved(const UUID & quota_id) +void QuotaCache::quotaRemoved(const UUID & quota_id) { std::lock_guard lock{mutex}; all_quotas.erase(quota_id); - chooseQuotaForAllContexts(); + chooseQuotaToConsume(); } -void QuotaContextFactory::chooseQuotaForAllContexts() +void QuotaCache::chooseQuotaToConsume() { /// `mutex` is already locked. - boost::range::remove_erase_if( - contexts, - [&](const std::weak_ptr & weak) + std::erase_if( + enabled_quotas, + [&](const std::pair> & pr) { - auto context = weak.lock(); - if (!context) - return true; // remove from the `contexts` list. - chooseQuotaForContext(context); - return false; // keep in the `contexts` list. + auto elem = pr.second.lock(); + if (!elem) + return true; // remove from the `enabled_quotas` list. + chooseQuotaToConsumeFor(*elem); + return false; // keep in the `enabled_quotas` list. }); } -void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr & context) +void QuotaCache::chooseQuotaToConsumeFor(EnabledQuota & enabled) { /// `mutex` is already locked. boost::shared_ptr intervals; for (auto & info : all_quotas | boost::adaptors::map_values) { - if (info.canUseWithContext(*context)) + if (info.roles->match(enabled.params.user_id, enabled.params.enabled_roles)) { - String key = info.calculateKey(*context); + String key = info.calculateKey(enabled); intervals = info.getOrBuildIntervals(key); break; } @@ -269,11 +279,11 @@ void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr(); /// No quota == no limits. - context->intervals.store(intervals); + enabled.intervals.store(intervals); } -std::vector QuotaContextFactory::getUsageInfo() const +std::vector QuotaCache::getUsageInfo() const { std::lock_guard lock{mutex}; std::vector all_infos; diff --git a/dbms/src/Access/QuotaContextFactory.h b/dbms/src/Access/QuotaCache.h similarity index 60% rename from dbms/src/Access/QuotaContextFactory.h rename to dbms/src/Access/QuotaCache.h index c130da4f2cd..81734f385c1 100644 --- a/dbms/src/Access/QuotaContextFactory.h +++ b/dbms/src/Access/QuotaCache.h @@ -1,11 +1,11 @@ #pragma once -#include +#include #include #include #include +#include #include -#include namespace DB @@ -14,47 +14,46 @@ class AccessControlManager; /// Stores information how much amount of resources have been consumed and how much are left. -class QuotaContextFactory +class QuotaCache { public: - QuotaContextFactory(const AccessControlManager & access_control_manager_); - ~QuotaContextFactory(); + QuotaCache(const AccessControlManager & access_control_manager_); + ~QuotaCache(); - QuotaContextPtr createContext(const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key); + std::shared_ptr getEnabledQuota(const UUID & user_id, const String & user_name, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key); std::vector getUsageInfo() const; private: - using Interval = QuotaContext::Interval; - using Intervals = QuotaContext::Intervals; + using Interval = EnabledQuota::Interval; + using Intervals = EnabledQuota::Intervals; struct QuotaInfo { QuotaInfo(const QuotaPtr & quota_, const UUID & quota_id_) { setQuota(quota_, quota_id_); } void setQuota(const QuotaPtr & quota_, const UUID & quota_id_); - bool canUseWithContext(const QuotaContext & context) const; - String calculateKey(const QuotaContext & context) const; + String calculateKey(const EnabledQuota & enabled_quota) const; boost::shared_ptr getOrBuildIntervals(const String & key); boost::shared_ptr rebuildIntervals(const String & key); void rebuildAllIntervals(); QuotaPtr quota; UUID quota_id; - const GenericRoleSet * roles = nullptr; + const ExtendedRoleSet * roles = nullptr; std::unordered_map> key_to_intervals; }; void ensureAllQuotasRead(); void quotaAddedOrChanged(const UUID & quota_id, const std::shared_ptr & new_quota); void quotaRemoved(const UUID & quota_id); - void chooseQuotaForAllContexts(); - void chooseQuotaForContext(const std::shared_ptr & context); + void chooseQuotaToConsume(); + void chooseQuotaToConsumeFor(EnabledQuota & enabled_quota); const AccessControlManager & access_control_manager; mutable std::mutex mutex; std::unordered_map all_quotas; bool all_quotas_read = false; ext::scope_guard subscription; - std::vector> contexts; + std::map> enabled_quotas; }; } diff --git a/dbms/src/Access/QuotaUsageInfo.cpp b/dbms/src/Access/QuotaUsageInfo.cpp new file mode 100644 index 00000000000..bcdf2b50062 --- /dev/null +++ b/dbms/src/Access/QuotaUsageInfo.cpp @@ -0,0 +1,17 @@ +#include +#include + + +namespace DB +{ +QuotaUsageInfo::QuotaUsageInfo() : quota_id(UUID(UInt128(0))) +{ +} + + +QuotaUsageInfo::Interval::Interval() +{ + boost::range::fill(used, 0); + boost::range::fill(max, 0); +} +} diff --git a/dbms/src/Access/QuotaUsageInfo.h b/dbms/src/Access/QuotaUsageInfo.h new file mode 100644 index 00000000000..94e16fb9f69 --- /dev/null +++ b/dbms/src/Access/QuotaUsageInfo.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + + +namespace DB +{ +/// The information about a quota consumption. +struct QuotaUsageInfo +{ + using ResourceType = Quota::ResourceType; + using ResourceAmount = Quota::ResourceAmount; + static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; + + struct Interval + { + ResourceAmount used[MAX_RESOURCE_TYPE]; + ResourceAmount max[MAX_RESOURCE_TYPE]; + std::chrono::seconds duration = std::chrono::seconds::zero(); + bool randomize_interval = false; + std::chrono::system_clock::time_point end_of_interval; + Interval(); + }; + + std::vector intervals; + UUID quota_id; + String quota_name; + String quota_key; + QuotaUsageInfo(); +}; +} diff --git a/dbms/src/Access/Role.cpp b/dbms/src/Access/Role.cpp index 7b1a395feec..f20ef9b9bfa 100644 --- a/dbms/src/Access/Role.cpp +++ b/dbms/src/Access/Role.cpp @@ -10,7 +10,8 @@ bool Role::equal(const IAccessEntity & other) const return false; const auto & other_role = typeid_cast(other); return (access == other_role.access) && (access_with_grant_option == other_role.access_with_grant_option) - && (granted_roles == other_role.granted_roles) && (granted_roles_with_admin_option == other_role.granted_roles_with_admin_option); + && (granted_roles == other_role.granted_roles) && (granted_roles_with_admin_option == other_role.granted_roles_with_admin_option) + && (settings == other_role.settings); } } diff --git a/dbms/src/Access/Role.h b/dbms/src/Access/Role.h index eaeb8debd3a..04330ba85f5 100644 --- a/dbms/src/Access/Role.h +++ b/dbms/src/Access/Role.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -15,6 +16,7 @@ struct Role : public IAccessEntity AccessRights access_with_grant_option; boost::container::flat_set granted_roles; boost::container::flat_set granted_roles_with_admin_option; + SettingsProfileElements settings; bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } diff --git a/dbms/src/Access/RoleCache.cpp b/dbms/src/Access/RoleCache.cpp new file mode 100644 index 00000000000..63e19a3cb40 --- /dev/null +++ b/dbms/src/Access/RoleCache.cpp @@ -0,0 +1,187 @@ +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + struct CollectedRoleInfo + { + RolePtr role; + bool is_current_role = false; + bool with_admin_option = false; + }; + + + void collectRoles(boost::container::flat_map & collected_roles, + const std::function & get_role_function, + const UUID & role_id, + bool is_current_role, + bool with_admin_option) + { + auto it = collected_roles.find(role_id); + if (it != collected_roles.end()) + { + it->second.is_current_role |= is_current_role; + it->second.with_admin_option |= with_admin_option; + return; + } + + auto role = get_role_function(role_id); + collected_roles[role_id] = CollectedRoleInfo{role, is_current_role, with_admin_option}; + + if (!role) + return; + + for (const auto & granted_role : role->granted_roles) + collectRoles(collected_roles, get_role_function, granted_role, false, false); + + for (const auto & granted_role : role->granted_roles_with_admin_option) + collectRoles(collected_roles, get_role_function, granted_role, false, true); + } + + + std::shared_ptr collectInfoForRoles(const boost::container::flat_map & roles) + { + auto new_info = std::make_shared(); + for (const auto & [role_id, collect_info] : roles) + { + const auto & role = collect_info.role; + if (!role) + continue; + if (collect_info.is_current_role) + new_info->current_roles.emplace_back(role_id); + new_info->enabled_roles.emplace_back(role_id); + if (collect_info.with_admin_option) + new_info->enabled_roles_with_admin_option.emplace_back(role_id); + new_info->names_of_roles[role_id] = role->getName(); + new_info->access.merge(role->access); + new_info->access_with_grant_option.merge(role->access_with_grant_option); + new_info->settings_from_enabled_roles.merge(role->settings); + } + return new_info; + } +} + + +RoleCache::RoleCache(const AccessControlManager & manager_) + : manager(manager_), cache(600000 /* 10 minutes */) {} + + +RoleCache::~RoleCache() = default; + + +std::shared_ptr RoleCache::getEnabledRoles( + const std::vector & roles, const std::vector & roles_with_admin_option) +{ + std::lock_guard lock{mutex}; + + EnabledRoles::Params params; + params.current_roles = roles; + params.current_roles_with_admin_option = roles_with_admin_option; + auto it = enabled_roles.find(params); + if (it != enabled_roles.end()) + { + auto from_cache = it->second.lock(); + if (from_cache) + return from_cache; + enabled_roles.erase(it); + } + + auto res = std::shared_ptr(new EnabledRoles(params)); + collectRolesInfoFor(*res); + enabled_roles.emplace(std::move(params), res); + return res; +} + + +void RoleCache::collectRolesInfo() +{ + /// `mutex` is already locked. + + std::erase_if( + enabled_roles, + [&](const std::pair> & pr) + { + auto elem = pr.second.lock(); + if (!elem) + return true; // remove from the `enabled_roles` map. + collectRolesInfoFor(*elem); + return false; // keep in the `enabled_roles` map. + }); +} + + +void RoleCache::collectRolesInfoFor(EnabledRoles & enabled) +{ + /// `mutex` is already locked. + + /// Collect roles in use. That includes the current roles, the roles granted to the current roles, and so on. + boost::container::flat_map collected_roles; + auto get_role_function = [this](const UUID & id) { return getRole(id); }; + for (const auto & current_role : enabled.params.current_roles) + collectRoles(collected_roles, get_role_function, current_role, true, false); + + for (const auto & current_role : enabled.params.current_roles_with_admin_option) + collectRoles(collected_roles, get_role_function, current_role, true, true); + + /// Collect data from the collected roles. + enabled.setRolesInfo(collectInfoForRoles(collected_roles)); +} + + +RolePtr RoleCache::getRole(const UUID & role_id) +{ + /// `mutex` is already locked. + + auto role_from_cache = cache.get(role_id); + if (role_from_cache) + return role_from_cache->first; + + auto subscription = manager.subscribeForChanges(role_id, + [this, role_id](const UUID &, const AccessEntityPtr & entity) + { + auto changed_role = entity ? typeid_cast(entity) : nullptr; + if (changed_role) + roleChanged(role_id, changed_role); + else + roleRemoved(role_id); + }); + + auto role = manager.tryRead(role_id); + if (role) + { + auto cache_value = Poco::SharedPtr>( + new std::pair{role, std::move(subscription)}); + cache.add(role_id, cache_value); + return role; + } + + return nullptr; +} + + +void RoleCache::roleChanged(const UUID & role_id, const RolePtr & changed_role) +{ + std::lock_guard lock{mutex}; + auto role_from_cache = cache.get(role_id); + if (!role_from_cache) + return; + role_from_cache->first = changed_role; + cache.update(role_id, role_from_cache); + collectRolesInfo(); +} + + +void RoleCache::roleRemoved(const UUID & role_id) +{ + std::lock_guard lock{mutex}; + cache.remove(role_id); + collectRolesInfo(); +} + +} diff --git a/dbms/src/Access/RoleCache.h b/dbms/src/Access/RoleCache.h new file mode 100644 index 00000000000..69f4cb2ebe8 --- /dev/null +++ b/dbms/src/Access/RoleCache.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +class AccessControlManager; +struct Role; +using RolePtr = std::shared_ptr; + +class RoleCache +{ +public: + RoleCache(const AccessControlManager & manager_); + ~RoleCache(); + + std::shared_ptr getEnabledRoles(const std::vector & current_roles, const std::vector & current_roles_with_admin_option); + +private: + void collectRolesInfo(); + void collectRolesInfoFor(EnabledRoles & enabled); + RolePtr getRole(const UUID & role_id); + void roleChanged(const UUID & role_id, const RolePtr & changed_role); + void roleRemoved(const UUID & role_id); + + const AccessControlManager & manager; + Poco::ExpireCache> cache; + std::map> enabled_roles; + mutable std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/RoleContext.cpp b/dbms/src/Access/RoleContext.cpp deleted file mode 100644 index 291b44027d4..00000000000 --- a/dbms/src/Access/RoleContext.cpp +++ /dev/null @@ -1,200 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ -namespace -{ - void makeUnique(std::vector & vec) - { - boost::range::sort(vec); - vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); - } -} - - -RoleContext::RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_) - : manager(&manager_), current_role(current_role_), with_admin_option(with_admin_option_) -{ - update(); -} - - -RoleContext::RoleContext(std::vector && children_) - : children(std::move(children_)) -{ - update(); -} - - -RoleContext::~RoleContext() = default; - - -void RoleContext::update() -{ - std::vector handlers_to_notify; - CurrentRolesInfoPtr info_to_notify; - - { - std::lock_guard lock{mutex}; - auto old_info = info; - - updateImpl(); - - if (!handlers.empty() && (!old_info || (*old_info != *info))) - { - boost::range::copy(handlers, std::back_inserter(handlers_to_notify)); - info_to_notify = info; - } - } - - for (const auto & handler : handlers_to_notify) - handler(info_to_notify); -} - - -void RoleContext::updateImpl() -{ - if (!current_role && children.empty()) - { - info = std::make_shared(); - return; - } - - if (!children.empty()) - { - if (subscriptions_for_change_children.empty()) - { - for (const auto & child : children) - subscriptions_for_change_children.emplace_back( - child->subscribeForChanges([this](const CurrentRolesInfoPtr &) { update(); })); - } - - auto new_info = std::make_shared(); - auto & new_info_ref = *new_info; - - for (const auto & child : children) - { - auto child_info = child->getInfo(); - new_info_ref.access.merge(child_info->access); - new_info_ref.access_with_grant_option.merge(child_info->access_with_grant_option); - boost::range::copy(child_info->current_roles, std::back_inserter(new_info_ref.current_roles)); - boost::range::copy(child_info->enabled_roles, std::back_inserter(new_info_ref.enabled_roles)); - boost::range::copy(child_info->enabled_roles_with_admin_option, std::back_inserter(new_info_ref.enabled_roles_with_admin_option)); - boost::range::copy(child_info->names_of_roles, std::inserter(new_info_ref.names_of_roles, new_info_ref.names_of_roles.end())); - } - makeUnique(new_info_ref.current_roles); - makeUnique(new_info_ref.enabled_roles); - makeUnique(new_info_ref.enabled_roles_with_admin_option); - info = new_info; - return; - } - - assert(current_role); - traverseRoles(*current_role, with_admin_option); - - auto new_info = std::make_shared(); - auto & new_info_ref = *new_info; - - for (auto it = roles_map.begin(); it != roles_map.end();) - { - const auto & id = it->first; - auto & entry = it->second; - if (!entry.in_use) - { - it = roles_map.erase(it); - continue; - } - - if (id == *current_role) - new_info_ref.current_roles.push_back(id); - - new_info_ref.enabled_roles.push_back(id); - - if (entry.with_admin_option) - new_info_ref.enabled_roles_with_admin_option.push_back(id); - - new_info_ref.access.merge(entry.role->access); - new_info_ref.access_with_grant_option.merge(entry.role->access_with_grant_option); - new_info_ref.names_of_roles[id] = entry.role->getName(); - - entry.in_use = false; - entry.with_admin_option = false; - ++it; - } - - info = new_info; -} - - -void RoleContext::traverseRoles(const UUID & id_, bool with_admin_option_) -{ - auto it = roles_map.find(id_); - if (it == roles_map.end()) - { - assert(manager); - auto subscription = manager->subscribeForChanges(id_, [this, id_](const UUID &, const AccessEntityPtr & entity) - { - { - std::lock_guard lock{mutex}; - auto it2 = roles_map.find(id_); - if (it2 == roles_map.end()) - return; - if (entity) - it2->second.role = typeid_cast(entity); - else - roles_map.erase(it2); - } - update(); - }); - - auto role = manager->tryRead(id_); - if (!role) - return; - - RoleEntry new_entry; - new_entry.role = role; - new_entry.subscription_for_change_role = std::move(subscription); - it = roles_map.emplace(id_, std::move(new_entry)).first; - } - - RoleEntry & entry = it->second; - entry.with_admin_option |= with_admin_option_; - if (entry.in_use) - return; - - entry.in_use = true; - for (const auto & granted_role : entry.role->granted_roles) - traverseRoles(granted_role, false); - - for (const auto & granted_role : entry.role->granted_roles_with_admin_option) - traverseRoles(granted_role, true); -} - - -CurrentRolesInfoPtr RoleContext::getInfo() const -{ - std::lock_guard lock{mutex}; - return info; -} - - -ext::scope_guard RoleContext::subscribeForChanges(const OnChangeHandler & handler) const -{ - std::lock_guard lock{mutex}; - handlers.push_back(handler); - auto it = std::prev(handlers.end()); - - return [this, it] - { - std::lock_guard lock2{mutex}; - handlers.erase(it); - }; -} -} diff --git a/dbms/src/Access/RoleContext.h b/dbms/src/Access/RoleContext.h deleted file mode 100644 index 5f19adc56de..00000000000 --- a/dbms/src/Access/RoleContext.h +++ /dev/null @@ -1,64 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ -struct Role; -using RolePtr = std::shared_ptr; -struct CurrentRolesInfo; -using CurrentRolesInfoPtr = std::shared_ptr; -class AccessControlManager; - - -class RoleContext -{ -public: - ~RoleContext(); - - /// Returns all the roles specified in the constructor. - CurrentRolesInfoPtr getInfo() const; - - using OnChangeHandler = std::function; - - /// Called when either the specified roles or the roles granted to the specified roles are changed. - ext::scope_guard subscribeForChanges(const OnChangeHandler & handler) const; - -private: - friend struct ext::shared_ptr_helper; - RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_); - RoleContext(std::vector> && children_); - - void update(); - void updateImpl(); - - void traverseRoles(const UUID & id_, bool with_admin_option_); - - const AccessControlManager * manager = nullptr; - std::optional current_role; - bool with_admin_option = false; - std::vector> children; - std::vector subscriptions_for_change_children; - - struct RoleEntry - { - RolePtr role; - ext::scope_guard subscription_for_change_role; - bool with_admin_option = false; - bool in_use = false; - }; - mutable std::unordered_map roles_map; - mutable CurrentRolesInfoPtr info; - mutable std::list handlers; - mutable std::mutex mutex; -}; - -using RoleContextPtr = std::shared_ptr; -} diff --git a/dbms/src/Access/RoleContextFactory.cpp b/dbms/src/Access/RoleContextFactory.cpp deleted file mode 100644 index 3356bc238db..00000000000 --- a/dbms/src/Access/RoleContextFactory.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include -#include - - -namespace DB -{ - -RoleContextFactory::RoleContextFactory(const AccessControlManager & manager_) - : manager(manager_), cache(600000 /* 10 minutes */) {} - - -RoleContextFactory::~RoleContextFactory() = default; - - -RoleContextPtr RoleContextFactory::createContext( - const std::vector & roles, const std::vector & roles_with_admin_option) -{ - if (roles.size() == 1 && roles_with_admin_option.empty()) - return createContextImpl(roles[0], false); - - if (roles.size() == 1 && roles_with_admin_option == roles) - return createContextImpl(roles[0], true); - - std::vector children; - children.reserve(roles.size()); - for (const auto & role : roles_with_admin_option) - children.push_back(createContextImpl(role, true)); - - boost::container::flat_set roles_with_admin_option_set{roles_with_admin_option.begin(), roles_with_admin_option.end()}; - for (const auto & role : roles) - { - if (!roles_with_admin_option_set.contains(role)) - children.push_back(createContextImpl(role, false)); - } - - return ext::shared_ptr_helper::create(std::move(children)); -} - - -RoleContextPtr RoleContextFactory::createContextImpl(const UUID & id, bool with_admin_option) -{ - std::lock_guard lock{mutex}; - auto key = std::make_pair(id, with_admin_option); - auto x = cache.get(key); - if (x) - return *x; - auto res = ext::shared_ptr_helper::create(manager, id, with_admin_option); - cache.add(key, res); - return res; -} - -} diff --git a/dbms/src/Access/RoleContextFactory.h b/dbms/src/Access/RoleContextFactory.h deleted file mode 100644 index 659c9a218a1..00000000000 --- a/dbms/src/Access/RoleContextFactory.h +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include -#include - - -namespace DB -{ -class AccessControlManager; - - -class RoleContextFactory -{ -public: - RoleContextFactory(const AccessControlManager & manager_); - ~RoleContextFactory(); - - RoleContextPtr createContext(const std::vector & roles, const std::vector & roles_with_admin_option); - -private: - RoleContextPtr createContextImpl(const UUID & id, bool with_admin_option); - - const AccessControlManager & manager; - Poco::ExpireCache, RoleContextPtr> cache; - std::mutex mutex; -}; - -} diff --git a/dbms/src/Access/RowPolicy.cpp b/dbms/src/Access/RowPolicy.cpp index d5a28d14bb8..65b9451a453 100644 --- a/dbms/src/Access/RowPolicy.cpp +++ b/dbms/src/Access/RowPolicy.cpp @@ -77,11 +77,11 @@ bool RowPolicy::equal(const IAccessEntity & other) const const auto & other_policy = typeid_cast(other); return (database == other_policy.database) && (table_name == other_policy.table_name) && (policy_name == other_policy.policy_name) && boost::range::equal(conditions, other_policy.conditions) && restrictive == other_policy.restrictive - && (roles == other_policy.roles); + && (to_roles == other_policy.to_roles); } -const char * RowPolicy::conditionIndexToString(ConditionIndex index) +const char * RowPolicy::conditionTypeToString(ConditionType index) { switch (index) { @@ -95,7 +95,7 @@ const char * RowPolicy::conditionIndexToString(ConditionIndex index) } -const char * RowPolicy::conditionIndexToColumnName(ConditionIndex index) +const char * RowPolicy::conditionTypeToColumnName(ConditionType index) { switch (index) { diff --git a/dbms/src/Access/RowPolicy.h b/dbms/src/Access/RowPolicy.h index 6bc51a2481c..08219edb46b 100644 --- a/dbms/src/Access/RowPolicy.h +++ b/dbms/src/Access/RowPolicy.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include namespace DB @@ -37,7 +37,7 @@ struct RowPolicy : public IAccessEntity /// Check is a SQL condition expression used to check whether a row can be written into /// the table. If the expression returns NULL or false an exception is thrown. /// If a conditional expression here is empty it means no filtering is applied. - enum ConditionIndex + enum ConditionType { SELECT_FILTER, INSERT_CHECK, @@ -45,11 +45,11 @@ struct RowPolicy : public IAccessEntity UPDATE_CHECK, DELETE_FILTER, }; - static constexpr size_t MAX_CONDITION_INDEX = 5; - static const char * conditionIndexToString(ConditionIndex index); - static const char * conditionIndexToColumnName(ConditionIndex index); + static constexpr size_t MAX_CONDITION_TYPE = 5; + static const char * conditionTypeToString(ConditionType index); + static const char * conditionTypeToColumnName(ConditionType index); - String conditions[MAX_CONDITION_INDEX]; + String conditions[MAX_CONDITION_TYPE]; /// Sets that the policy is permissive. /// A row is only accessible if at least one of the permissive policies passes, @@ -67,7 +67,7 @@ struct RowPolicy : public IAccessEntity std::shared_ptr clone() const override { return cloneImpl(); } /// Which roles or users should use this row policy. - GenericRoleSet roles; + ExtendedRoleSet to_roles; private: String database; diff --git a/dbms/src/Access/RowPolicyContextFactory.cpp b/dbms/src/Access/RowPolicyCache.cpp similarity index 68% rename from dbms/src/Access/RowPolicyContextFactory.cpp rename to dbms/src/Access/RowPolicyCache.cpp index e891f43b5eb..9509923adbf 100644 --- a/dbms/src/Access/RowPolicyContextFactory.cpp +++ b/dbms/src/Access/RowPolicyCache.cpp @@ -1,5 +1,5 @@ -#include -#include +#include +#include #include #include #include @@ -92,8 +92,8 @@ namespace } - using ConditionIndex = RowPolicy::ConditionIndex; - constexpr size_t MAX_CONDITION_INDEX = RowPolicy::MAX_CONDITION_INDEX; + using ConditionType = RowPolicy::ConditionType; + constexpr size_t MAX_CONDITION_TYPE = RowPolicy::MAX_CONDITION_TYPE; /// Accumulates conditions from multiple row policies and joins them using the AND logical operation. @@ -124,24 +124,24 @@ namespace } -void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_) +void RowPolicyCache::PolicyInfo::setPolicy(const RowPolicyPtr & policy_) { policy = policy_; - roles = &policy->roles; + roles = &policy->to_roles; - for (auto index : ext::range_with_static_cast(0, MAX_CONDITION_INDEX)) + for (auto type : ext::range_with_static_cast(0, MAX_CONDITION_TYPE)) { - parsed_conditions[index] = nullptr; - const String & condition = policy->conditions[index]; + parsed_conditions[type] = nullptr; + const String & condition = policy->conditions[type]; if (condition.empty()) continue; - auto previous_range = std::pair(std::begin(policy->conditions), std::begin(policy->conditions) + index); + auto previous_range = std::pair(std::begin(policy->conditions), std::begin(policy->conditions) + type); auto previous_it = std::find(previous_range.first, previous_range.second, condition); if (previous_it != previous_range.second) { /// The condition is already parsed before. - parsed_conditions[index] = parsed_conditions[previous_it - previous_range.first]; + parsed_conditions[type] = parsed_conditions[previous_it - previous_range.first]; continue; } @@ -149,45 +149,52 @@ void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_ try { ParserExpression parser; - parsed_conditions[index] = parseQuery(parser, condition, 0); + parsed_conditions[type] = parseQuery(parser, condition, 0); } catch (...) { tryLogCurrentException( &Poco::Logger::get("RowPolicy"), - String("Could not parse the condition ") + RowPolicy::conditionIndexToString(index) + " of row policy " + String("Could not parse the condition ") + RowPolicy::conditionTypeToString(type) + " of row policy " + backQuote(policy->getFullName())); } } } -bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const -{ - return roles->match(context.user_id, context.enabled_roles); -} - - -RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & access_control_manager_) +RowPolicyCache::RowPolicyCache(const AccessControlManager & access_control_manager_) : access_control_manager(access_control_manager_) { } -RowPolicyContextFactory::~RowPolicyContextFactory() = default; +RowPolicyCache::~RowPolicyCache() = default; -RowPolicyContextPtr RowPolicyContextFactory::createContext(const UUID & user_id, const std::vector & enabled_roles) +std::shared_ptr RowPolicyCache::getEnabledRowPolicies(const UUID & user_id, const std::vector & enabled_roles) { std::lock_guard lock{mutex}; ensureAllRowPoliciesRead(); - auto context = ext::shared_ptr_helper::create(user_id, enabled_roles); - contexts.push_back(context); - mixConditionsForContext(*context); - return context; + + EnabledRowPolicies::Params params; + params.user_id = user_id; + params.enabled_roles = enabled_roles; + auto it = enabled_row_policies.find(params); + if (it != enabled_row_policies.end()) + { + auto from_cache = it->second.lock(); + if (from_cache) + return from_cache; + enabled_row_policies.erase(it); + } + + auto res = std::shared_ptr(new EnabledRowPolicies(params)); + enabled_row_policies.emplace(std::move(params), res); + mixConditionsFor(*res); + return res; } -void RowPolicyContextFactory::ensureAllRowPoliciesRead() +void RowPolicyCache::ensureAllRowPoliciesRead() { /// `mutex` is already locked. if (all_policies_read) @@ -212,7 +219,7 @@ void RowPolicyContextFactory::ensureAllRowPoliciesRead() } -void RowPolicyContextFactory::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy) +void RowPolicyCache::rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy) { std::lock_guard lock{mutex}; auto it = all_policies.find(policy_id); @@ -228,46 +235,46 @@ void RowPolicyContextFactory::rowPolicyAddedOrChanged(const UUID & policy_id, co auto & info = it->second; info.setPolicy(new_policy); - mixConditionsForAllContexts(); + mixConditions(); } -void RowPolicyContextFactory::rowPolicyRemoved(const UUID & policy_id) +void RowPolicyCache::rowPolicyRemoved(const UUID & policy_id) { std::lock_guard lock{mutex}; all_policies.erase(policy_id); - mixConditionsForAllContexts(); + mixConditions(); } -void RowPolicyContextFactory::mixConditionsForAllContexts() +void RowPolicyCache::mixConditions() { /// `mutex` is already locked. - boost::range::remove_erase_if( - contexts, - [&](const std::weak_ptr & weak) + std::erase_if( + enabled_row_policies, + [&](const std::pair> & pr) { - auto context = weak.lock(); - if (!context) - return true; // remove from the `contexts` list. - mixConditionsForContext(*context); - return false; // keep in the `contexts` list. + auto elem = pr.second.lock(); + if (!elem) + return true; // remove from the `enabled_row_policies` map. + mixConditionsFor(*elem); + return false; // keep in the `enabled_row_policies` map. }); } -void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context) +void RowPolicyCache::mixConditionsFor(EnabledRowPolicies & enabled) { /// `mutex` is already locked. struct Mixers { - ConditionsMixer mixers[MAX_CONDITION_INDEX]; + ConditionsMixer mixers[MAX_CONDITION_TYPE]; std::vector policy_ids; }; - using MapOfMixedConditions = RowPolicyContext::MapOfMixedConditions; - using DatabaseAndTableName = RowPolicyContext::DatabaseAndTableName; - using DatabaseAndTableNameRef = RowPolicyContext::DatabaseAndTableNameRef; - using Hash = RowPolicyContext::Hash; + using MapOfMixedConditions = EnabledRowPolicies::MapOfMixedConditions; + using DatabaseAndTableName = EnabledRowPolicies::DatabaseAndTableName; + using DatabaseAndTableNameRef = EnabledRowPolicies::DatabaseAndTableNameRef; + using Hash = EnabledRowPolicies::Hash; std::unordered_map map_of_mixers; @@ -275,12 +282,12 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context { const auto & policy = *info.policy; auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}]; - if (info.canUseWithContext(context)) + if (info.roles->match(enabled.params.user_id, enabled.params.enabled_roles)) { mixers.policy_ids.push_back(policy_id); - for (auto index : ext::range(0, MAX_CONDITION_INDEX)) - if (info.parsed_conditions[index]) - mixers.mixers[index].add(info.parsed_conditions[index], policy.isRestrictive()); + for (auto type : ext::range(0, MAX_CONDITION_TYPE)) + if (info.parsed_conditions[type]) + mixers.mixers[type].add(info.parsed_conditions[type], policy.isRestrictive()); } } @@ -294,11 +301,11 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context database_and_table_name_keeper->second}]; mixed_conditions.database_and_table_name_keeper = std::move(database_and_table_name_keeper); mixed_conditions.policy_ids = std::move(mixers.policy_ids); - for (auto index : ext::range(0, MAX_CONDITION_INDEX)) - mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult(); + for (auto type : ext::range(0, MAX_CONDITION_TYPE)) + mixed_conditions.mixed_conditions[type] = std::move(mixers.mixers[type]).getResult(); } - context.map_of_mixed_conditions.store(map_of_mixed_conditions); + enabled.map_of_mixed_conditions.store(map_of_mixed_conditions); } } diff --git a/dbms/src/Access/RowPolicyContextFactory.h b/dbms/src/Access/RowPolicyCache.h similarity index 56% rename from dbms/src/Access/RowPolicyContextFactory.h rename to dbms/src/Access/RowPolicyCache.h index d93d1626b24..d0ec74b9ab8 100644 --- a/dbms/src/Access/RowPolicyContextFactory.h +++ b/dbms/src/Access/RowPolicyCache.h @@ -1,8 +1,9 @@ #pragma once -#include +#include #include #include +#include #include @@ -11,39 +12,38 @@ namespace DB class AccessControlManager; /// Stores read and parsed row policies. -class RowPolicyContextFactory +class RowPolicyCache { public: - RowPolicyContextFactory(const AccessControlManager & access_control_manager_); - ~RowPolicyContextFactory(); + RowPolicyCache(const AccessControlManager & access_control_manager_); + ~RowPolicyCache(); - RowPolicyContextPtr createContext(const UUID & user_id, const std::vector & enabled_roles); + std::shared_ptr getEnabledRowPolicies(const UUID & user_id, const std::vector & enabled_roles); private: - using ParsedConditions = RowPolicyContext::ParsedConditions; + using ParsedConditions = EnabledRowPolicies::ParsedConditions; struct PolicyInfo { PolicyInfo(const RowPolicyPtr & policy_) { setPolicy(policy_); } void setPolicy(const RowPolicyPtr & policy_); - bool canUseWithContext(const RowPolicyContext & context) const; RowPolicyPtr policy; - const GenericRoleSet * roles = nullptr; + const ExtendedRoleSet * roles = nullptr; ParsedConditions parsed_conditions; }; void ensureAllRowPoliciesRead(); void rowPolicyAddedOrChanged(const UUID & policy_id, const RowPolicyPtr & new_policy); void rowPolicyRemoved(const UUID & policy_id); - void mixConditionsForAllContexts(); - void mixConditionsForContext(RowPolicyContext & context); + void mixConditions(); + void mixConditionsFor(EnabledRowPolicies & enabled); const AccessControlManager & access_control_manager; std::unordered_map all_policies; bool all_policies_read = false; ext::scope_guard subscription; - std::vector> contexts; + std::map> enabled_row_policies; std::mutex mutex; }; diff --git a/dbms/src/Access/SettingsProfile.cpp b/dbms/src/Access/SettingsProfile.cpp new file mode 100644 index 00000000000..c2f868502c0 --- /dev/null +++ b/dbms/src/Access/SettingsProfile.cpp @@ -0,0 +1,13 @@ +#include + + +namespace DB +{ +bool SettingsProfile::equal(const IAccessEntity & other) const +{ + if (!IAccessEntity::equal(other)) + return false; + const auto & other_profile = typeid_cast(other); + return (elements == other_profile.elements) && (to_roles == other_profile.to_roles); +} +} diff --git a/dbms/src/Access/SettingsProfile.h b/dbms/src/Access/SettingsProfile.h new file mode 100644 index 00000000000..b73b45d57cf --- /dev/null +++ b/dbms/src/Access/SettingsProfile.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include + + +namespace DB +{ +/// Represents a settings profile created by command +/// CREATE SETTINGS PROFILE name SETTINGS x=value MIN=min MAX=max READONLY... TO roles +struct SettingsProfile : public IAccessEntity +{ + SettingsProfileElements elements; + + /// Which roles or users should use this settings profile. + ExtendedRoleSet to_roles; + + bool equal(const IAccessEntity & other) const override; + std::shared_ptr clone() const override { return cloneImpl(); } +}; + +using SettingsProfilePtr = std::shared_ptr; +} diff --git a/dbms/src/Access/SettingsProfileElement.cpp b/dbms/src/Access/SettingsProfileElement.cpp new file mode 100644 index 00000000000..b052f8b5e75 --- /dev/null +++ b/dbms/src/Access/SettingsProfileElement.cpp @@ -0,0 +1,170 @@ +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +SettingsProfileElement::SettingsProfileElement(const ASTSettingsProfileElement & ast) +{ + init(ast, nullptr); +} + +SettingsProfileElement::SettingsProfileElement(const ASTSettingsProfileElement & ast, const AccessControlManager & manager) +{ + init(ast, &manager); +} + +void SettingsProfileElement::init(const ASTSettingsProfileElement & ast, const AccessControlManager * manager) +{ + auto name_to_id = [id_mode{ast.id_mode}, manager](const String & name_) -> UUID + { + if (id_mode) + return parse(name_); + assert(manager); + return manager->getID(name_); + }; + + if (!ast.parent_profile.empty()) + parent_profile = name_to_id(ast.parent_profile); + + if (!ast.name.empty()) + { + name = ast.name; + value = ast.value; + min_value = ast.min_value; + max_value = ast.max_value; + readonly = ast.readonly; + + size_t index = Settings::findIndexStrict(name); + if (!value.isNull()) + value = Settings::valueToCorrespondingType(index, value); + if (!min_value.isNull()) + min_value = Settings::valueToCorrespondingType(index, min_value); + if (!max_value.isNull()) + max_value = Settings::valueToCorrespondingType(index, max_value); + } +} + + +std::shared_ptr SettingsProfileElement::toAST() const +{ + auto ast = std::make_shared(); + ast->id_mode = true; + + if (parent_profile) + ast->parent_profile = ::DB::toString(*parent_profile); + + ast->name = name; + ast->value = value; + ast->min_value = min_value; + ast->max_value = max_value; + ast->readonly = readonly; + + return ast; +} + + +std::shared_ptr SettingsProfileElement::toASTWithNames(const AccessControlManager & manager) const +{ + auto ast = std::make_shared(); + + if (parent_profile) + { + auto parent_profile_name = manager.tryReadName(*parent_profile); + if (parent_profile_name) + ast->parent_profile = *parent_profile_name; + } + + ast->name = name; + ast->value = value; + ast->min_value = min_value; + ast->max_value = max_value; + ast->readonly = readonly; + + return ast; +} + + +SettingsProfileElements::SettingsProfileElements(const ASTSettingsProfileElements & ast) +{ + for (const auto & ast_element : ast.elements) + emplace_back(*ast_element); +} + +SettingsProfileElements::SettingsProfileElements(const ASTSettingsProfileElements & ast, const AccessControlManager & manager) +{ + for (const auto & ast_element : ast.elements) + emplace_back(*ast_element, manager); +} + + +std::shared_ptr SettingsProfileElements::toAST() const +{ + auto res = std::make_shared(); + for (const auto & element : *this) + res->elements.push_back(element.toAST()); + return res; +} + +std::shared_ptr SettingsProfileElements::toASTWithNames(const AccessControlManager & manager) const +{ + auto res = std::make_shared(); + for (const auto & element : *this) + res->elements.push_back(element.toASTWithNames(manager)); + return res; +} + + +void SettingsProfileElements::merge(const SettingsProfileElements & other) +{ + insert(end(), other.begin(), other.end()); +} + + +Settings SettingsProfileElements::toSettings() const +{ + Settings res; + for (const auto & elem : *this) + { + if (!elem.name.empty() && !elem.value.isNull()) + res.set(elem.name, elem.value); + } + return res; +} + +SettingsChanges SettingsProfileElements::toSettingsChanges() const +{ + SettingsChanges res; + for (const auto & elem : *this) + { + if (!elem.name.empty() && !elem.value.isNull()) + res.push_back({elem.name, elem.value}); + } + return res; +} + +SettingsConstraints SettingsProfileElements::toSettingsConstraints() const +{ + SettingsConstraints res; + for (const auto & elem : *this) + { + if (!elem.name.empty()) + { + if (!elem.min_value.isNull()) + res.setMinValue(elem.name, elem.min_value); + if (!elem.max_value.isNull()) + res.setMaxValue(elem.name, elem.max_value); + if (elem.readonly) + res.setReadOnly(elem.name, *elem.readonly); + } + } + return res; +} + +} diff --git a/dbms/src/Access/SettingsProfileElement.h b/dbms/src/Access/SettingsProfileElement.h new file mode 100644 index 00000000000..abcac2567c8 --- /dev/null +++ b/dbms/src/Access/SettingsProfileElement.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +struct Settings; +struct SettingChange; +using SettingsChanges = std::vector; +class SettingsConstraints; +class ASTSettingsProfileElement; +class ASTSettingsProfileElements; +class AccessControlManager; + + +struct SettingsProfileElement +{ + std::optional parent_profile; + String name; + Field value; + Field min_value; + Field max_value; + std::optional readonly; + + auto toTuple() const { return std::tie(parent_profile, name, value, min_value, max_value, readonly); } + friend bool operator==(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return lhs.toTuple() == rhs.toTuple(); } + friend bool operator!=(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return !(lhs == rhs); } + friend bool operator <(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return lhs.toTuple() < rhs.toTuple(); } + friend bool operator >(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return rhs < lhs; } + friend bool operator <=(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return !(rhs < lhs); } + friend bool operator >=(const SettingsProfileElement & lhs, const SettingsProfileElement & rhs) { return !(lhs < rhs); } + + SettingsProfileElement() {} + + /// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`. + SettingsProfileElement(const ASTSettingsProfileElement & ast); + SettingsProfileElement(const ASTSettingsProfileElement & ast, const AccessControlManager & manager); + std::shared_ptr toAST() const; + std::shared_ptr toASTWithNames(const AccessControlManager & manager) const; + +private: + void init(const ASTSettingsProfileElement & ast, const AccessControlManager * manager); +}; + + +class SettingsProfileElements : public std::vector +{ +public: + SettingsProfileElements() {} + + /// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`. + SettingsProfileElements(const ASTSettingsProfileElements & ast); + SettingsProfileElements(const ASTSettingsProfileElements & ast, const AccessControlManager & manager); + std::shared_ptr toAST() const; + std::shared_ptr toASTWithNames(const AccessControlManager & manager) const; + + void merge(const SettingsProfileElements & other); + + Settings toSettings() const; + SettingsChanges toSettingsChanges() const; + SettingsConstraints toSettingsConstraints() const; +}; + +} diff --git a/dbms/src/Access/SettingsProfilesCache.cpp b/dbms/src/Access/SettingsProfilesCache.cpp new file mode 100644 index 00000000000..552ed324635 --- /dev/null +++ b/dbms/src/Access/SettingsProfilesCache.cpp @@ -0,0 +1,234 @@ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int THERE_IS_NO_PROFILE; +} + + +SettingsProfilesCache::SettingsProfilesCache(const AccessControlManager & manager_) + : manager(manager_) {} + +SettingsProfilesCache::~SettingsProfilesCache() = default; + + +void SettingsProfilesCache::ensureAllProfilesRead() +{ + /// `mutex` is already locked. + if (all_profiles_read) + return; + all_profiles_read = true; + + subscription = manager.subscribeForChanges( + [&](const UUID & id, const AccessEntityPtr & entity) + { + if (entity) + profileAddedOrChanged(id, typeid_cast(entity)); + else + profileRemoved(id); + }); + + for (const UUID & id : manager.findAll()) + { + auto profile = manager.tryRead(id); + if (profile) + { + all_profiles.emplace(id, profile); + profiles_by_name[profile->getName()] = id; + } + } +} + + +void SettingsProfilesCache::profileAddedOrChanged(const UUID & profile_id, const SettingsProfilePtr & new_profile) +{ + std::lock_guard lock{mutex}; + auto it = all_profiles.find(profile_id); + if (it == all_profiles.end()) + { + all_profiles.emplace(profile_id, new_profile); + profiles_by_name[new_profile->getName()] = profile_id; + } + else + { + auto old_profile = it->second; + it->second = new_profile; + if (old_profile->getName() != new_profile->getName()) + profiles_by_name.erase(old_profile->getName()); + profiles_by_name[new_profile->getName()] = profile_id; + } + settings_for_profiles.clear(); + mergeSettingsAndConstraints(); +} + + +void SettingsProfilesCache::profileRemoved(const UUID & profile_id) +{ + std::lock_guard lock{mutex}; + auto it = all_profiles.find(profile_id); + if (it == all_profiles.end()) + return; + profiles_by_name.erase(it->second->getName()); + all_profiles.erase(it); + settings_for_profiles.clear(); + mergeSettingsAndConstraints(); +} + + +void SettingsProfilesCache::setDefaultProfileName(const String & default_profile_name) +{ + std::lock_guard lock{mutex}; + ensureAllProfilesRead(); + + if (default_profile_name.empty()) + { + default_profile_id = {}; + return; + } + + auto it = profiles_by_name.find(default_profile_name); + if (it == profiles_by_name.end()) + throw Exception("Settings profile " + backQuote(default_profile_name) + " not found", ErrorCodes::THERE_IS_NO_PROFILE); + + default_profile_id = it->second; +} + +void SettingsProfilesCache::mergeSettingsAndConstraints() +{ + /// `mutex` is already locked. + std::erase_if( + enabled_settings, + [&](const std::pair> & pr) + { + auto enabled = pr.second.lock(); + if (!enabled) + return true; // remove from the `enabled_settings` list. + mergeSettingsAndConstraintsFor(*enabled); + return false; // keep in the `enabled_settings` list. + }); +} + + +void SettingsProfilesCache::mergeSettingsAndConstraintsFor(EnabledSettings & enabled) const +{ + SettingsProfileElements merged_settings; + if (default_profile_id) + { + SettingsProfileElement new_element; + new_element.parent_profile = *default_profile_id; + merged_settings.emplace_back(new_element); + } + + for (const auto & [profile_id, profile] : all_profiles) + if (profile->to_roles.match(enabled.params.user_id, enabled.params.enabled_roles)) + { + SettingsProfileElement new_element; + new_element.parent_profile = profile_id; + merged_settings.emplace_back(new_element); + } + + merged_settings.merge(enabled.params.settings_from_enabled_roles); + merged_settings.merge(enabled.params.settings_from_user); + + substituteProfiles(merged_settings); + + enabled.setSettingsAndConstraints( + std::make_shared(merged_settings.toSettings()), + std::make_shared(merged_settings.toSettingsConstraints())); +} + + +void SettingsProfilesCache::substituteProfiles(SettingsProfileElements & elements) const +{ + bool stop_substituting = false; + boost::container::flat_set already_substituted; + while (!stop_substituting) + { + stop_substituting = true; + for (size_t i = 0; i != elements.size(); ++i) + { + auto & element = elements[i]; + if (!element.parent_profile) + continue; + + auto parent_profile_id = *element.parent_profile; + element.parent_profile.reset(); + if (already_substituted.contains(parent_profile_id)) + continue; + + already_substituted.insert(parent_profile_id); + auto parent_profile = all_profiles.find(parent_profile_id); + if (parent_profile == all_profiles.end()) + continue; + + const auto & parent_profile_elements = parent_profile->second->elements; + elements.insert(elements.begin() + i + 1, parent_profile_elements.begin(), parent_profile_elements.end()); + i += parent_profile_elements.size(); + stop_substituting = false; + } + } +} + + +std::shared_ptr SettingsProfilesCache::getEnabledSettings( + const UUID & user_id, + const SettingsProfileElements & settings_from_user, + const std::vector & enabled_roles, + const SettingsProfileElements & settings_from_enabled_roles) +{ + std::lock_guard lock{mutex}; + ensureAllProfilesRead(); + + EnabledSettings::Params params; + params.user_id = user_id; + params.settings_from_user = settings_from_user; + params.enabled_roles = enabled_roles; + params.settings_from_enabled_roles = settings_from_enabled_roles; + + auto it = enabled_settings.find(params); + if (it != enabled_settings.end()) + { + auto from_cache = it->second.lock(); + if (from_cache) + return from_cache; + enabled_settings.erase(it); + } + + std::shared_ptr res(new EnabledSettings(params)); + enabled_settings.emplace(std::move(params), res); + mergeSettingsAndConstraintsFor(*res); + return res; +} + + +std::shared_ptr SettingsProfilesCache::getProfileSettings(const String & profile_name) +{ + std::lock_guard lock{mutex}; + ensureAllProfilesRead(); + + auto it = profiles_by_name.find(profile_name); + if (it == profiles_by_name.end()) + throw Exception("Settings profile " + backQuote(profile_name) + " not found", ErrorCodes::THERE_IS_NO_PROFILE); + const UUID profile_id = it->second; + + auto it2 = settings_for_profiles.find(profile_id); + if (it2 != settings_for_profiles.end()) + return it2->second; + + SettingsProfileElements elements = all_profiles[profile_id]->elements; + substituteProfiles(elements); + auto res = std::make_shared(elements.toSettingsChanges()); + settings_for_profiles.emplace(profile_id, res); + return res; +} + + +} diff --git a/dbms/src/Access/SettingsProfilesCache.h b/dbms/src/Access/SettingsProfilesCache.h new file mode 100644 index 00000000000..656ffc6fce6 --- /dev/null +++ b/dbms/src/Access/SettingsProfilesCache.h @@ -0,0 +1,55 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +class AccessControlManager; +struct SettingsProfile; +using SettingsProfilePtr = std::shared_ptr; +class SettingsProfileElements; +class EnabledSettings; + + +/// Reads and caches all the settings profiles. +class SettingsProfilesCache +{ +public: + SettingsProfilesCache(const AccessControlManager & manager_); + ~SettingsProfilesCache(); + + void setDefaultProfileName(const String & default_profile_name); + + std::shared_ptr getEnabledSettings( + const UUID & user_id, + const SettingsProfileElements & settings_from_user_, + const std::vector & enabled_roles, + const SettingsProfileElements & settings_from_enabled_roles_); + + std::shared_ptr getProfileSettings(const String & profile_name); + +private: + void ensureAllProfilesRead(); + void profileAddedOrChanged(const UUID & profile_id, const SettingsProfilePtr & new_profile); + void profileRemoved(const UUID & profile_id); + void mergeSettingsAndConstraints(); + void mergeSettingsAndConstraintsFor(EnabledSettings & enabled) const; + void substituteProfiles(SettingsProfileElements & elements) const; + + const AccessControlManager & manager; + std::unordered_map all_profiles; + std::unordered_map profiles_by_name; + bool all_profiles_read = false; + ext::scope_guard subscription; + std::map> enabled_settings; + std::optional default_profile_id; + std::unordered_map> settings_for_profiles; + mutable std::mutex mutex; +}; +} diff --git a/dbms/src/Access/User.cpp b/dbms/src/Access/User.cpp index bc5b062db6a..4a751c31e25 100644 --- a/dbms/src/Access/User.cpp +++ b/dbms/src/Access/User.cpp @@ -12,7 +12,7 @@ bool User::equal(const IAccessEntity & other) const return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts) && (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option) && (granted_roles == other_user.granted_roles) && (granted_roles_with_admin_option == other_user.granted_roles_with_admin_option) - && (default_roles == other_user.default_roles) && (profile == other_user.profile); + && (default_roles == other_user.default_roles) && (settings == other_user.settings); } } diff --git a/dbms/src/Access/User.h b/dbms/src/Access/User.h index 3a9b3cd7014..6df3b3e4d3c 100644 --- a/dbms/src/Access/User.h +++ b/dbms/src/Access/User.h @@ -4,7 +4,8 @@ #include #include #include -#include +#include +#include #include #include @@ -21,8 +22,8 @@ struct User : public IAccessEntity AccessRights access_with_grant_option; boost::container::flat_set granted_roles; boost::container::flat_set granted_roles_with_admin_option; - GenericRoleSet default_roles = GenericRoleSet::AllTag{}; - String profile; + ExtendedRoleSet default_roles = ExtendedRoleSet::AllTag{}; + SettingsProfileElements settings; bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } diff --git a/dbms/src/Access/UsersConfigAccessStorage.cpp b/dbms/src/Access/UsersConfigAccessStorage.cpp index 20ee2a628a6..13102528108 100644 --- a/dbms/src/Access/UsersConfigAccessStorage.cpp +++ b/dbms/src/Access/UsersConfigAccessStorage.cpp @@ -2,11 +2,15 @@ #include #include #include +#include #include #include #include #include #include +#include +#include +#include #include @@ -16,6 +20,7 @@ namespace ErrorCodes { extern const int BAD_ARGUMENTS; extern const int UNKNOWN_ADDRESS_PATTERN_TYPE; + extern const int NOT_IMPLEMENTED; } @@ -29,6 +34,8 @@ namespace return 'Q'; if (type == typeid(RowPolicy)) return 'P'; + if (type == typeid(SettingsProfile)) + return 'S'; return 0; } @@ -82,7 +89,14 @@ namespace user->authentication.setPasswordHashHex(config.getString(user_config + ".password_double_sha1_hex")); } - user->profile = config.getString(user_config + ".profile"); + const auto profile_name_config = user_config + ".profile"; + if (config.has(profile_name_config)) + { + auto profile_name = config.getString(profile_name_config); + SettingsProfileElement profile_element; + profile_element.parent_profile = generateID(typeid(SettingsProfile), profile_name); + user->settings.push_back(std::move(profile_element)); + } /// Fill list of allowed hosts. const auto networks_config = user_config + ".networks"; @@ -141,19 +155,18 @@ namespace if (databases) { - user->access.fullRevoke(AccessFlags::databaseLevel()); + user->access.revoke(AccessFlags::allFlags() - AccessFlags::allGlobalFlags()); + user->access.grant(AccessFlags::allDictionaryFlags(), IDictionary::NO_DATABASE_TAG); for (const String & database : *databases) - user->access.grant(AccessFlags::databaseLevel(), database); + user->access.grant(AccessFlags::allFlags(), database); } if (dictionaries) { - user->access.fullRevoke(AccessType::dictGet, IDictionary::NO_DATABASE_TAG); + user->access.revoke(AccessFlags::allDictionaryFlags(), IDictionary::NO_DATABASE_TAG); for (const String & dictionary : *dictionaries) - user->access.grant(AccessType::dictGet, IDictionary::NO_DATABASE_TAG, dictionary); + user->access.grant(AccessFlags::allDictionaryFlags(), IDictionary::NO_DATABASE_TAG, dictionary); } - else if (databases) - user->access.grant(AccessType::dictGet, IDictionary::NO_DATABASE_TAG); user->access_with_grant_option = user->access; @@ -225,7 +238,7 @@ namespace limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED)); } - quota->roles.add(user_ids); + quota->to_roles.add(user_ids); return quota; } @@ -325,12 +338,99 @@ namespace auto policy = std::make_shared(); policy->setFullName(database, table_name, user_name); policy->conditions[RowPolicy::SELECT_FILTER] = filter; - policy->roles.add(generateID(typeid(User), user_name)); + policy->to_roles.add(generateID(typeid(User), user_name)); policies.push_back(policy); } } return policies; } + + + SettingsProfileElements parseSettingsConstraints(const Poco::Util::AbstractConfiguration & config, + const String & path_to_constraints) + { + SettingsProfileElements profile_elements; + Poco::Util::AbstractConfiguration::Keys names; + config.keys(path_to_constraints, names); + for (const String & name : names) + { + SettingsProfileElement profile_element; + profile_element.name = name; + Poco::Util::AbstractConfiguration::Keys constraint_types; + String path_to_name = path_to_constraints + "." + name; + config.keys(path_to_name, constraint_types); + for (const String & constraint_type : constraint_types) + { + if (constraint_type == "min") + profile_element.min_value = config.getString(path_to_name + "." + constraint_type); + else if (constraint_type == "max") + profile_element.max_value = config.getString(path_to_name + "." + constraint_type); + else if (constraint_type == "readonly") + profile_element.readonly = true; + else + throw Exception("Setting " + constraint_type + " value for " + name + " isn't supported", ErrorCodes::NOT_IMPLEMENTED); + } + profile_elements.push_back(std::move(profile_element)); + } + return profile_elements; + } + + std::shared_ptr parseSettingsProfile( + const Poco::Util::AbstractConfiguration & config, + const String & profile_name) + { + auto profile = std::make_shared(); + profile->setName(profile_name); + String profile_config = "profiles." + profile_name; + + Poco::Util::AbstractConfiguration::Keys keys; + config.keys(profile_config, keys); + + for (const std::string & key : keys) + { + if (key == "profile" || key.starts_with("profile[")) + { + String parent_profile_name = config.getString(profile_config + "." + key); + SettingsProfileElement profile_element; + profile_element.parent_profile = generateID(typeid(SettingsProfile), parent_profile_name); + profile->elements.emplace_back(std::move(profile_element)); + continue; + } + + if (key == "constraints" || key.starts_with("constraints[")) + { + profile->elements.merge(parseSettingsConstraints(config, profile_config + "." + key)); + continue; + } + + SettingsProfileElement profile_element; + profile_element.name = key; + profile_element.value = config.getString(profile_config + "." + key); + profile->elements.emplace_back(std::move(profile_element)); + } + + return profile; + } + + + std::vector parseSettingsProfiles(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log) + { + std::vector profiles; + Poco::Util::AbstractConfiguration::Keys profile_names; + config.keys("profiles", profile_names); + for (const auto & profile_name : profile_names) + { + try + { + profiles.push_back(parseSettingsProfile(config, profile_name)); + } + catch (...) + { + tryLogCurrentException(log, "Could not parse profile " + backQuote(profile_name)); + } + } + return profiles; + } } @@ -348,6 +448,8 @@ void UsersConfigAccessStorage::setConfiguration(const Poco::Util::AbstractConfig all_entities.emplace_back(generateID(*entity), entity); for (const auto & entity : parseRowPolicies(config, getLogger())) all_entities.emplace_back(generateID(*entity), entity); + for (const auto & entity : parseSettingsProfiles(config, getLogger())) + all_entities.emplace_back(generateID(*entity), entity); memory_storage.setAll(all_entities); } diff --git a/dbms/src/Core/Settings.h b/dbms/src/Core/Settings.h index 445641b0e29..72d74abd95c 100644 --- a/dbms/src/Core/Settings.h +++ b/dbms/src/Core/Settings.h @@ -395,7 +395,6 @@ struct Settings : public SettingsCollection M(SettingBool, allow_experimental_alter_materialized_view_structure, false, "Allow atomic alter on Materialized views. Work in progress.", 0) \ M(SettingBool, enable_early_constant_folding, true, "Enable query optimization where we analyze function and subqueries results and rewrite query if there're constants there", 0) \ \ - M(SettingBool, partial_revokes, false, "Makes it possible to revoke privileges partially.", 0) \ M(SettingBool, deduplicate_blocks_in_dependent_materialized_views, false, "Should deduplicate blocks for materialized views if the block is not a duplicate for the table. Use true to always deduplicate in dependent tables.", 0) \ M(SettingBool, use_compact_format_in_distributed_parts_names, false, "Changes format of directories names for distributed table insert parts.", 0) \ M(SettingUInt64, multiple_joins_rewriter_version, 1, "1 or 2. Second rewriter version knows about table columns and keep not clashed names as is.", 0) \ diff --git a/dbms/src/Core/SettingsCollection.cpp b/dbms/src/Core/SettingsCollection.cpp index d45c082eb0b..6d879b27181 100644 --- a/dbms/src/Core/SettingsCollection.cpp +++ b/dbms/src/Core/SettingsCollection.cpp @@ -165,7 +165,7 @@ void SettingMaxThreads::set(const Field & x) if (x.getType() == Field::Types::String) set(get(x)); else - set(safeGet(x)); + set(applyVisitor(FieldVisitorConvertToNumber(), x)); } void SettingMaxThreads::set(const String & x) @@ -246,7 +246,7 @@ void SettingTimespan::set(const Field & x) if (x.getType() == Field::Types::String) set(get(x)); else - set(safeGet(x)); + set(applyVisitor(FieldVisitorConvertToNumber(), x)); } template diff --git a/dbms/src/DataStreams/IBlockInputStream.cpp b/dbms/src/DataStreams/IBlockInputStream.cpp index 4bccdff6848..733bafdcf71 100644 --- a/dbms/src/DataStreams/IBlockInputStream.cpp +++ b/dbms/src/DataStreams/IBlockInputStream.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include diff --git a/dbms/src/DataStreams/IBlockInputStream.h b/dbms/src/DataStreams/IBlockInputStream.h index 6fe8be079d8..aacd12bacd9 100644 --- a/dbms/src/DataStreams/IBlockInputStream.h +++ b/dbms/src/DataStreams/IBlockInputStream.h @@ -21,8 +21,7 @@ namespace ErrorCodes } class ProcessListElement; -class QuotaContext; -using QuotaContextPtr = std::shared_ptr; +class EnabledQuota; class QueryStatus; struct SortColumnDescription; using SortDescription = std::vector; @@ -219,7 +218,7 @@ public: /** Set the quota. If you set a quota on the amount of raw data, * then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits. */ - virtual void setQuota(const QuotaContextPtr & quota_) + virtual void setQuota(const std::shared_ptr & quota_) { quota = quota_; } @@ -277,7 +276,7 @@ private: LocalLimits limits; - QuotaContextPtr quota; /// If nullptr - the quota is not used. + std::shared_ptr quota; /// If nullptr - the quota is not used. UInt64 prev_elapsed = 0; /// The approximate total number of rows to read. For progress bar. diff --git a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp b/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp index f6dbf0b6c0b..991d206777a 100644 --- a/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp +++ b/dbms/src/DataStreams/PushingToViewsBlockOutputStream.cpp @@ -42,7 +42,7 @@ PushingToViewsBlockOutputStream::PushingToViewsBlockOutputStream( views_context = std::make_unique(context); // Do not deduplicate insertions into MV if the main insertion is Ok if (disable_deduplication_for_children) - views_context->getSettingsRef().insert_deduplicate = false; + views_context->setSetting("insert_deduplicate", false); } for (const auto & database_table : dependencies) diff --git a/dbms/src/Dictionaries/ClickHouseDictionarySource.cpp b/dbms/src/Dictionaries/ClickHouseDictionarySource.cpp index 97ae125abbb..e7f38173d8a 100644 --- a/dbms/src/Dictionaries/ClickHouseDictionarySource.cpp +++ b/dbms/src/Dictionaries/ClickHouseDictionarySource.cpp @@ -74,7 +74,7 @@ ClickHouseDictionarySource::ClickHouseDictionarySource( /// We should set user info even for the case when the dictionary is loaded in-process (without TCP communication). context.setUser(user, password, Poco::Net::SocketAddress("127.0.0.1", 0), {}); /// Processors are not supported here yet. - context.getSettingsRef().experimental_use_processors = false; + context.setSetting("experimental_use_processors", false); /// Query context is needed because some code in executeQuery function may assume it exists. /// Current example is Context::getSampleBlockCache from InterpreterSelectWithUnionQuery::getSampleBlock. context.makeQueryContext(); diff --git a/dbms/src/Functions/currentQuota.cpp b/dbms/src/Functions/currentQuota.cpp index d292627d1ca..b16a8a7c1ec 100644 --- a/dbms/src/Functions/currentQuota.cpp +++ b/dbms/src/Functions/currentQuota.cpp @@ -3,7 +3,8 @@ #include #include #include -#include +#include +#include #include diff --git a/dbms/src/Functions/currentRowPolicies.cpp b/dbms/src/Functions/currentRowPolicies.cpp index dfebf1552bc..0248f77c9b5 100644 --- a/dbms/src/Functions/currentRowPolicies.cpp +++ b/dbms/src/Functions/currentRowPolicies.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -65,17 +65,20 @@ public: auto database_column = ColumnString::create(); auto table_name_column = ColumnString::create(); auto policy_name_column = ColumnString::create(); - for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs()) + if (auto policies = context.getRowPolicies()) { - const auto policy = context.getAccessControlManager().tryRead(policy_id); - if (policy) + for (const auto & policy_id : policies->getCurrentPolicyIDs()) { - const String database = policy->getDatabase(); - const String table_name = policy->getTableName(); - const String policy_name = policy->getName(); - database_column->insertData(database.data(), database.length()); - table_name_column->insertData(table_name.data(), table_name.length()); - policy_name_column->insertData(policy_name.data(), policy_name.length()); + const auto policy = context.getAccessControlManager().tryRead(policy_id); + if (policy) + { + const String database = policy->getDatabase(); + const String table_name = policy->getTableName(); + const String policy_name = policy->getName(); + database_column->insertData(database.data(), database.length()); + table_name_column->insertData(table_name.data(), table_name.length()); + policy_name_column->insertData(policy_name.data(), policy_name.length()); + } } } auto offset_column = ColumnArray::ColumnOffsets::create(); @@ -113,13 +116,16 @@ public: { String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase(); String table_name = table_name_column->getDataAt(i).toString(); - for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name)) + if (auto policies = context.getRowPolicies()) { - const auto policy = context.getAccessControlManager().tryRead(policy_id); - if (policy) + for (const auto & policy_id : policies->getCurrentPolicyIDs(database, table_name)) { - const String policy_name = policy->getName(); - policy_name_column->insertData(policy_name.data(), policy_name.length()); + const auto policy = context.getAccessControlManager().tryRead(policy_id); + if (policy) + { + const String policy_name = policy->getName(); + policy_name_column->insertData(policy_name.data(), policy_name.length()); + } } } offset_column->insertValue(policy_name_column->size()); @@ -169,8 +175,11 @@ public: if (arguments.empty()) { auto policy_id_column = ColumnVector::create(); - for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs()) - policy_id_column->insertValue(policy_id); + if (auto policies = context.getRowPolicies()) + { + for (const auto & policy_id : policies->getCurrentPolicyIDs()) + policy_id_column->insertValue(policy_id); + } auto offset_column = ColumnArray::ColumnOffsets::create(); offset_column->insertValue(policy_id_column->size()); block.getByPosition(result_pos).column @@ -203,8 +212,11 @@ public: { String database = database_column ? database_column->getDataAt(i).toString() : context.getCurrentDatabase(); String table_name = table_name_column->getDataAt(i).toString(); - for (const auto & policy_id : context.getRowPolicy()->getCurrentPolicyIDs(database, table_name)) - policy_id_column->insertValue(policy_id); + if (auto policies = context.getRowPolicies()) + { + for (const auto & policy_id : policies->getCurrentPolicyIDs(database, table_name)) + policy_id_column->insertValue(policy_id); + } offset_column->insertValue(policy_id_column->size()); } diff --git a/dbms/src/IO/WriteHelpers.h b/dbms/src/IO/WriteHelpers.h index 44b0322ee83..10918fb7b61 100644 --- a/dbms/src/IO/WriteHelpers.h +++ b/dbms/src/IO/WriteHelpers.h @@ -239,11 +239,6 @@ inline void writeFloatText(T x, WriteBuffer & buf) } -inline void writeString(const String & s, WriteBuffer & buf) -{ - buf.write(s.data(), s.size()); -} - inline void writeString(const char * data, size_t size, WriteBuffer & buf) { buf.write(data, size); diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index ab9b4a2c31b..6ca4b4a0a2e 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -27,9 +27,10 @@ #include #include #include -#include -#include +#include +#include #include +#include #include #include #include @@ -444,8 +445,6 @@ Context & Context::operator=(const Context &) = default; Context Context::createGlobal() { Context res; - res.access_rights = std::make_shared(); - res.initial_row_policy = std::make_shared(); res.shared = std::make_shared(); return res; } @@ -632,38 +631,38 @@ void Context::setUser(const String & name, const String & password, const Poco:: client_info.quota_key = quota_key; auto new_user_id = getAccessControlManager().find(name); - AccessRightsContextPtr new_access_rights; + std::shared_ptr new_access; if (new_user_id) { - new_access_rights = getAccessControlManager().getAccessRightsContext(*new_user_id, {}, true, settings, current_database, client_info); - if (!new_access_rights->isClientHostAllowed() || !new_access_rights->isCorrectPassword(password)) + new_access = getAccessControlManager().getContextAccess(*new_user_id, {}, true, {}, current_database, client_info); + if (!new_access->isClientHostAllowed() || !new_access->isCorrectPassword(password)) { new_user_id = {}; - new_access_rights = nullptr; + new_access = nullptr; } } - if (!new_user_id || !new_access_rights) + if (!new_user_id || !new_access) throw Exception(name + ": Authentication failed: password is incorrect or there is no user with such name", ErrorCodes::AUTHENTICATION_FAILED); user_id = new_user_id; - access_rights = std::move(new_access_rights); + access = std::move(new_access); current_roles.clear(); use_default_roles = true; - calculateUserSettings(); + setSettings(*access->getDefaultSettings()); } std::shared_ptr Context::getUser() const { auto lock = getLock(); - return access_rights->getUser(); + return access->getUser(); } String Context::getUserName() const { auto lock = getLock(); - return access_rights->getUserName(); + return access->getUserName(); } UUID Context::getUserID() const @@ -697,22 +696,22 @@ void Context::setCurrentRolesDefault() std::vector Context::getCurrentRoles() const { - return getAccessRights()->getCurrentRoles(); + return getAccess()->getCurrentRoles(); } Strings Context::getCurrentRolesNames() const { - return getAccessRights()->getCurrentRolesNames(); + return getAccess()->getCurrentRolesNames(); } std::vector Context::getEnabledRoles() const { - return getAccessRights()->getEnabledRoles(); + return getAccess()->getEnabledRoles(); } Strings Context::getEnabledRolesNames() const { - return getAccessRights()->getEnabledRolesNames(); + return getAccess()->getEnabledRolesNames(); } @@ -720,98 +719,67 @@ void Context::calculateAccessRights() { auto lock = getLock(); if (user_id) - access_rights = getAccessControlManager().getAccessRightsContext(*user_id, current_roles, use_default_roles, settings, current_database, client_info); + access = getAccessControlManager().getContextAccess(*user_id, current_roles, use_default_roles, settings, current_database, client_info); } template void Context::checkAccessImpl(const Args &... args) const { - getAccessRights()->checkAccess(args...); + return getAccess()->checkAccess(args...); } -void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(access, database, table, columns); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); } -void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); } -void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); } +void Context::checkAccess(const AccessFlags & flags) const { return checkAccessImpl(flags); } +void Context::checkAccess(const AccessFlags & flags, const std::string_view & database) const { return checkAccessImpl(flags, database); } +void Context::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(flags, database, table); } +void Context::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(flags, database, table, column); } +void Context::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(flags, database, table, columns); } +void Context::checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(flags, database, table, columns); } +void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id) const { checkAccessImpl(flags, table_id.getDatabaseName(), table_id.getTableName()); } +void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id, const std::string_view & column) const { checkAccessImpl(flags, table_id.getDatabaseName(), table_id.getTableName(), column); } +void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id, const std::vector & columns) const { checkAccessImpl(flags, table_id.getDatabaseName(), table_id.getTableName(), columns); } +void Context::checkAccess(const AccessFlags & flags, const StorageID & table_id, const Strings & columns) const { checkAccessImpl(flags, table_id.getDatabaseName(), table_id.getTableName(), columns); } +void Context::checkAccess(const AccessRightsElement & element) const { return checkAccessImpl(element); } +void Context::checkAccess(const AccessRightsElements & elements) const { return checkAccessImpl(elements); } -void Context::checkAccess(const AccessFlags & access, const StorageID & table_id) const { checkAccessImpl(access, table_id.getDatabaseName(), table_id.getTableName()); } -void Context::checkAccess(const AccessFlags & access, const StorageID & table_id, const std::string_view & column) const { checkAccessImpl(access, table_id.getDatabaseName(), table_id.getTableName(), column); } -void Context::checkAccess(const AccessFlags & access, const StorageID & table_id, const std::vector & columns) const { checkAccessImpl(access, table_id.getDatabaseName(), table_id.getTableName(), columns); } -void Context::checkAccess(const AccessFlags & access, const StorageID & table_id, const Strings & columns) const { checkAccessImpl(access, table_id.getDatabaseName(), table_id.getTableName(), columns); } -AccessRightsContextPtr Context::getAccessRights() const +std::shared_ptr Context::getAccess() const { auto lock = getLock(); - return access_rights; + return access ? access : ContextAccess::getFullAccess(); } -RowPolicyContextPtr Context::getRowPolicy() const +ASTPtr Context::getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType type) const { - return getAccessRights()->getRowPolicy(); + auto lock = getLock(); + auto initial_condition = initial_row_policy ? initial_row_policy->getCondition(database, table_name, type) : nullptr; + return getAccess()->getRowPolicyCondition(database, table_name, type, initial_condition); +} + +std::shared_ptr Context::getRowPolicies() const +{ + return getAccess()->getRowPolicies(); } void Context::setInitialRowPolicy() { auto lock = getLock(); auto initial_user_id = getAccessControlManager().find(client_info.initial_user); + initial_row_policy = nullptr; if (initial_user_id) - initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id, {}); -} - -RowPolicyContextPtr Context::getInitialRowPolicy() const -{ - auto lock = getLock(); - return initial_row_policy; + initial_row_policy = getAccessControlManager().getEnabledRowPolicies(*initial_user_id, {}); } -QuotaContextPtr Context::getQuota() const +std::shared_ptr Context::getQuota() const { - return getAccessRights()->getQuota(); + return getAccess()->getQuota(); } -void Context::calculateUserSettings() +void Context::setProfile(const String & profile_name) { - auto lock = getLock(); - String profile = getUser()->profile; - - bool old_readonly = settings.readonly; - bool old_allow_ddl = settings.allow_ddl; - bool old_allow_introspection_functions = settings.allow_introspection_functions; - - /// 1) Set default settings (hardcoded values) - /// NOTE: we ignore global_context settings (from which it is usually copied) - /// NOTE: global_context settings are immutable and not auto updated - settings = Settings(); - settings_constraints = nullptr; - - /// 2) Apply settings from default profile - auto default_profile_name = getDefaultProfileName(); - if (profile != default_profile_name) - setProfile(default_profile_name); - - /// 3) Apply settings from current user - setProfile(profile); - - /// 4) Recalculate access rights if it's necessary. - if ((settings.readonly != old_readonly) || (settings.allow_ddl != old_allow_ddl) || (settings.allow_introspection_functions != old_allow_introspection_functions)) - calculateAccessRights(); -} - -void Context::setProfile(const String & profile) -{ - settings.setProfile(profile, *shared->users_config); - - auto new_constraints - = settings_constraints ? std::make_shared(*settings_constraints) : std::make_shared(); - new_constraints->setProfile(profile, *shared->users_config); - settings_constraints = std::move(new_constraints); + applySettingsChanges(*getAccessControlManager().getProfileSettings(profile_name)); } @@ -936,9 +904,9 @@ Settings Context::getSettings() const void Context::setSettings(const Settings & settings_) { auto lock = getLock(); - bool old_readonly = settings.readonly; - bool old_allow_ddl = settings.allow_ddl; - bool old_allow_introspection_functions = settings.allow_introspection_functions; + auto old_readonly = settings.readonly; + auto old_allow_ddl = settings.allow_ddl; + auto old_allow_introspection_functions = settings.allow_introspection_functions; settings = settings_; @@ -947,7 +915,7 @@ void Context::setSettings(const Settings & settings_) } -void Context::setSetting(const String & name, const String & value) +void Context::setSetting(const StringRef & name, const String & value) { auto lock = getLock(); if (name == "profile") @@ -962,7 +930,7 @@ void Context::setSetting(const String & name, const String & value) } -void Context::setSetting(const String & name, const Field & value) +void Context::setSetting(const StringRef & name, const Field & value) { auto lock = getLock(); if (name == "profile") @@ -993,30 +961,37 @@ void Context::applySettingsChanges(const SettingsChanges & changes) void Context::checkSettingsConstraints(const SettingChange & change) const { - if (settings_constraints) + if (auto settings_constraints = getSettingsConstraints()) settings_constraints->check(settings, change); } void Context::checkSettingsConstraints(const SettingsChanges & changes) const { - if (settings_constraints) + if (auto settings_constraints = getSettingsConstraints()) settings_constraints->check(settings, changes); } void Context::clampToSettingsConstraints(SettingChange & change) const { - if (settings_constraints) + if (auto settings_constraints = getSettingsConstraints()) settings_constraints->clamp(settings, change); } void Context::clampToSettingsConstraints(SettingsChanges & changes) const { - if (settings_constraints) + if (auto settings_constraints = getSettingsConstraints()) settings_constraints->clamp(settings, changes); } +std::shared_ptr Context::getSettingsConstraints() const +{ + auto lock = getLock(); + return access->getSettingsConstraints(); +} + + String Context::getCurrentDatabase() const { auto lock = getLock(); @@ -1877,8 +1852,10 @@ void Context::setApplicationType(ApplicationType type) void Context::setDefaultProfiles(const Poco::Util::AbstractConfiguration & config) { shared->default_profile_name = config.getString("default_profile", "default"); + getAccessControlManager().setDefaultProfileName(shared->default_profile_name); + shared->system_profile_name = config.getString("system_profile", shared->default_profile_name); - setSetting("profile", shared->system_profile_name); + setProfile(shared->system_profile_name); } String Context::getDefaultProfileName() const diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 5d8351ed598..331c89294d0 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -44,14 +45,11 @@ namespace DB struct ContextShared; class Context; -class AccessRightsContext; -using AccessRightsContextPtr = std::shared_ptr; +class ContextAccess; struct User; using UserPtr = std::shared_ptr; -class RowPolicyContext; -using RowPolicyContextPtr = std::shared_ptr; -class QuotaContext; -using QuotaContextPtr = std::shared_ptr; +class EnabledRowPolicies; +class EnabledQuota; class AccessFlags; struct AccessRightsElement; class AccessRightsElements; @@ -151,11 +149,10 @@ private: std::optional user_id; std::vector current_roles; bool use_default_roles = false; - AccessRightsContextPtr access_rights; - RowPolicyContextPtr initial_row_policy; + std::shared_ptr access; + std::shared_ptr initial_row_policy; String current_database; Settings settings; /// Setting for query execution. - std::shared_ptr settings_constraints; using ProgressCallback = std::function; ProgressCallback progress_callback; /// Callback for tracking progress of query execution. QueryStatus * process_list_elem = nullptr; /// For tracking total resource usage for query. @@ -246,31 +243,30 @@ public: /// Checks access rights. /// Empty database means the current database. - void checkAccess(const AccessFlags & access) const; - void checkAccess(const AccessFlags & access, const std::string_view & database) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - void checkAccess(const AccessRightsElement & access) const; - void checkAccess(const AccessRightsElements & access) const; + void checkAccess(const AccessFlags & flags) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void checkAccess(const AccessFlags & flags, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void checkAccess(const AccessFlags & flags, const StorageID & table_id) const; + void checkAccess(const AccessFlags & flags, const StorageID & table_id, const std::string_view & column) const; + void checkAccess(const AccessFlags & flags, const StorageID & table_id, const std::vector & columns) const; + void checkAccess(const AccessFlags & flags, const StorageID & table_id, const Strings & columns) const; + void checkAccess(const AccessRightsElement & element) const; + void checkAccess(const AccessRightsElements & elements) const; - void checkAccess(const AccessFlags & access, const StorageID & table_id) const; - void checkAccess(const AccessFlags & access, const StorageID & table_id, const std::string_view & column) const; - void checkAccess(const AccessFlags & access, const StorageID & table_id, const std::vector & columns) const; - void checkAccess(const AccessFlags & access, const StorageID & table_id, const Strings & columns) const; + std::shared_ptr getAccess() const; - AccessRightsContextPtr getAccessRights() const; - - RowPolicyContextPtr getRowPolicy() const; + std::shared_ptr getRowPolicies() const; + ASTPtr getRowPolicyCondition(const String & database, const String & table_name, RowPolicy::ConditionType type) const; /// Sets an extra row policy based on `client_info.initial_user`, if it exists. /// TODO: we need a better solution here. It seems we should pass the initial row policy /// because a shard is allowed to don't have the initial user or it may be another user with the same name. void setInitialRowPolicy(); - RowPolicyContextPtr getInitialRowPolicy() const; - QuotaContextPtr getQuota() const; + std::shared_ptr getQuota() const; /// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once. void setExternalTablesInitializer(ExternalTablesInitializer && initializer); @@ -344,8 +340,8 @@ public: void setSettings(const Settings & settings_); /// Set settings by name. - void setSetting(const String & name, const String & value); - void setSetting(const String & name, const Field & value); + void setSetting(const StringRef & name, const String & value); + void setSetting(const StringRef & name, const Field & value); void applySettingChange(const SettingChange & change); void applySettingsChanges(const SettingsChanges & changes); @@ -356,7 +352,7 @@ public: void clampToSettingsConstraints(SettingsChanges & changes) const; /// Returns the current constraints (can return null). - std::shared_ptr getSettingsConstraints() const { return settings_constraints; } + std::shared_ptr getSettingsConstraints() const; const EmbeddedDictionaries & getEmbeddedDictionaries() const; const ExternalDictionariesLoader & getExternalDictionariesLoader() const; @@ -427,7 +423,6 @@ public: } const Settings & getSettingsRef() const { return settings; } - Settings & getSettingsRef() { return settings; } void setProgressCallback(ProgressCallback callback); /// Used in InterpreterSelectQuery to pass it to the IBlockInputStream. @@ -597,7 +592,6 @@ private: std::unique_lock getLock() const; /// Compute and set actual user settings, client_info.current_user should be set - void calculateUserSettings(); void calculateAccessRights(); template diff --git a/dbms/src/Interpreters/InterpreterAlterQuery.cpp b/dbms/src/Interpreters/InterpreterAlterQuery.cpp index 5462fc16a81..315527765ef 100644 --- a/dbms/src/Interpreters/InterpreterAlterQuery.cpp +++ b/dbms/src/Interpreters/InterpreterAlterQuery.cpp @@ -13,7 +13,7 @@ #include #include #include - +#include #include @@ -125,155 +125,162 @@ AccessRightsElements InterpreterAlterQuery::getRequiredAccess() const { AccessRightsElements required_access; const auto & alter = query_ptr->as(); - for (ASTAlterCommand * command_ast : alter.command_list->commands) - { - auto column_name = [&]() -> String { return getIdentifierName(command_ast->column); }; - auto column_name_from_col_decl = [&]() -> std::string_view { return command_ast->col_decl->as().name; }; - auto column_names_from_update_assignments = [&]() -> std::vector - { - std::vector column_names; - for (const ASTPtr & assignment_ast : command_ast->update_assignments->children) - column_names.emplace_back(assignment_ast->as().column_name); - return column_names; - }; + for (ASTAlterCommand * command : alter.command_list->commands) + boost::range::push_back(required_access, getRequiredAccessForCommand(*command, alter.database, alter.table)); + return required_access; +} - switch (command_ast->type) + +AccessRightsElements InterpreterAlterQuery::getRequiredAccessForCommand(const ASTAlterCommand & command, const String & database, const String & table) +{ + AccessRightsElements required_access; + + auto column_name = [&]() -> String { return getIdentifierName(command.column); }; + auto column_name_from_col_decl = [&]() -> std::string_view { return command.col_decl->as().name; }; + auto column_names_from_update_assignments = [&]() -> std::vector + { + std::vector column_names; + for (const ASTPtr & assignment_ast : command.update_assignments->children) + column_names.emplace_back(assignment_ast->as().column_name); + return column_names; + }; + + switch (command.type) + { + case ASTAlterCommand::UPDATE: { - case ASTAlterCommand::UPDATE: - { - required_access.emplace_back(AccessType::UPDATE, alter.database, alter.table, column_names_from_update_assignments()); - break; - } - case ASTAlterCommand::DELETE: - { - required_access.emplace_back(AccessType::DELETE, alter.database, alter.table); - break; - } - case ASTAlterCommand::ADD_COLUMN: - { - required_access.emplace_back(AccessType::ADD_COLUMN, alter.database, alter.table, column_name_from_col_decl()); - break; - } - case ASTAlterCommand::DROP_COLUMN: - { - if (command_ast->clear_column) - required_access.emplace_back(AccessType::CLEAR_COLUMN, alter.database, alter.table, column_name()); - else - required_access.emplace_back(AccessType::DROP_COLUMN, alter.database, alter.table, column_name()); - break; - } - case ASTAlterCommand::MODIFY_COLUMN: - { - required_access.emplace_back(AccessType::MODIFY_COLUMN, alter.database, alter.table, column_name_from_col_decl()); - break; - } - case ASTAlterCommand::COMMENT_COLUMN: - { - required_access.emplace_back(AccessType::COMMENT_COLUMN, alter.database, alter.table, column_name()); - break; - } - case ASTAlterCommand::MODIFY_ORDER_BY: - { - required_access.emplace_back(AccessType::ALTER_ORDER_BY, alter.database, alter.table); - break; - } - case ASTAlterCommand::ADD_INDEX: - { - required_access.emplace_back(AccessType::ADD_INDEX, alter.database, alter.table); - break; - } - case ASTAlterCommand::DROP_INDEX: - { - if (command_ast->clear_index) - required_access.emplace_back(AccessType::CLEAR_INDEX, alter.database, alter.table); - else - required_access.emplace_back(AccessType::DROP_INDEX, alter.database, alter.table); - break; - } - case ASTAlterCommand::MATERIALIZE_INDEX: - { - required_access.emplace_back(AccessType::MATERIALIZE_INDEX, alter.database, alter.table); - break; - } - case ASTAlterCommand::ADD_CONSTRAINT: - { - required_access.emplace_back(AccessType::ADD_CONSTRAINT, alter.database, alter.table); - break; - } - case ASTAlterCommand::DROP_CONSTRAINT: - { - required_access.emplace_back(AccessType::DROP_CONSTRAINT, alter.database, alter.table); - break; - } - case ASTAlterCommand::MODIFY_TTL: - { - required_access.emplace_back(AccessType::MODIFY_TTL, alter.database, alter.table); - break; - } - case ASTAlterCommand::MATERIALIZE_TTL: - { - required_access.emplace_back(AccessType::MATERIALIZE_TTL, alter.database, alter.table); - break; - } - case ASTAlterCommand::MODIFY_SETTING: - { - required_access.emplace_back(AccessType::MODIFY_SETTING, alter.database, alter.table); - break; - } - case ASTAlterCommand::ATTACH_PARTITION: - { - required_access.emplace_back(AccessType::INSERT, alter.database, alter.table); - break; - } - case ASTAlterCommand::DROP_PARTITION: [[fallthrough]]; - case ASTAlterCommand::DROP_DETACHED_PARTITION: - { - required_access.emplace_back(AccessType::DELETE, alter.database, alter.table); - break; - } - case ASTAlterCommand::MOVE_PARTITION: - { - if ((command_ast->move_destination_type == PartDestinationType::DISK) - || (command_ast->move_destination_type == PartDestinationType::VOLUME)) - { - required_access.emplace_back(AccessType::MOVE_PARTITION, alter.database, alter.table); - } - else if (command_ast->move_destination_type == PartDestinationType::TABLE) - { - required_access.emplace_back(AccessType::SELECT | AccessType::DELETE, alter.database, alter.table); - required_access.emplace_back(AccessType::INSERT, command_ast->to_database, command_ast->to_table); - } - break; - } - case ASTAlterCommand::REPLACE_PARTITION: - { - required_access.emplace_back(AccessType::SELECT, command_ast->from_database, command_ast->from_table); - required_access.emplace_back(AccessType::DELETE | AccessType::INSERT, alter.database, alter.table); - break; - } - case ASTAlterCommand::FETCH_PARTITION: - { - required_access.emplace_back(AccessType::FETCH_PARTITION, alter.database, alter.table); - break; - } - case ASTAlterCommand::FREEZE_PARTITION: [[fallthrough]]; - case ASTAlterCommand::FREEZE_ALL: - { - required_access.emplace_back(AccessType::FREEZE_PARTITION, alter.database, alter.table); - break; - } - case ASTAlterCommand::MODIFY_QUERY: - { - required_access.emplace_back(AccessType::MODIFY_VIEW_QUERY, alter.database, alter.table); - break; - } - case ASTAlterCommand::LIVE_VIEW_REFRESH: - { - required_access.emplace_back(AccessType::REFRESH_VIEW, alter.database, alter.table); - break; - } - case ASTAlterCommand::NO_TYPE: break; + required_access.emplace_back(AccessType::UPDATE, database, table, column_names_from_update_assignments()); + break; } + case ASTAlterCommand::DELETE: + { + required_access.emplace_back(AccessType::DELETE, database, table); + break; + } + case ASTAlterCommand::ADD_COLUMN: + { + required_access.emplace_back(AccessType::ADD_COLUMN, database, table, column_name_from_col_decl()); + break; + } + case ASTAlterCommand::DROP_COLUMN: + { + if (command.clear_column) + required_access.emplace_back(AccessType::CLEAR_COLUMN, database, table, column_name()); + else + required_access.emplace_back(AccessType::DROP_COLUMN, database, table, column_name()); + break; + } + case ASTAlterCommand::MODIFY_COLUMN: + { + required_access.emplace_back(AccessType::MODIFY_COLUMN, database, table, column_name_from_col_decl()); + break; + } + case ASTAlterCommand::COMMENT_COLUMN: + { + required_access.emplace_back(AccessType::COMMENT_COLUMN, database, table, column_name()); + break; + } + case ASTAlterCommand::MODIFY_ORDER_BY: + { + required_access.emplace_back(AccessType::ALTER_ORDER_BY, database, table); + break; + } + case ASTAlterCommand::ADD_INDEX: + { + required_access.emplace_back(AccessType::ADD_INDEX, database, table); + break; + } + case ASTAlterCommand::DROP_INDEX: + { + if (command.clear_index) + required_access.emplace_back(AccessType::CLEAR_INDEX, database, table); + else + required_access.emplace_back(AccessType::DROP_INDEX, database, table); + break; + } + case ASTAlterCommand::MATERIALIZE_INDEX: + { + required_access.emplace_back(AccessType::MATERIALIZE_INDEX, database, table); + break; + } + case ASTAlterCommand::ADD_CONSTRAINT: + { + required_access.emplace_back(AccessType::ADD_CONSTRAINT, database, table); + break; + } + case ASTAlterCommand::DROP_CONSTRAINT: + { + required_access.emplace_back(AccessType::DROP_CONSTRAINT, database, table); + break; + } + case ASTAlterCommand::MODIFY_TTL: + { + required_access.emplace_back(AccessType::MODIFY_TTL, database, table); + break; + } + case ASTAlterCommand::MATERIALIZE_TTL: + { + required_access.emplace_back(AccessType::MATERIALIZE_TTL, database, table); + break; + } + case ASTAlterCommand::MODIFY_SETTING: + { + required_access.emplace_back(AccessType::MODIFY_SETTING, database, table); + break; + } + case ASTAlterCommand::ATTACH_PARTITION: + { + required_access.emplace_back(AccessType::INSERT, database, table); + break; + } + case ASTAlterCommand::DROP_PARTITION: [[fallthrough]]; + case ASTAlterCommand::DROP_DETACHED_PARTITION: + { + required_access.emplace_back(AccessType::DELETE, database, table); + break; + } + case ASTAlterCommand::MOVE_PARTITION: + { + if ((command.move_destination_type == PartDestinationType::DISK) + || (command.move_destination_type == PartDestinationType::VOLUME)) + { + required_access.emplace_back(AccessType::MOVE_PARTITION, database, table); + } + else if (command.move_destination_type == PartDestinationType::TABLE) + { + required_access.emplace_back(AccessType::SELECT | AccessType::DELETE, database, table); + required_access.emplace_back(AccessType::INSERT, command.to_database, command.to_table); + } + break; + } + case ASTAlterCommand::REPLACE_PARTITION: + { + required_access.emplace_back(AccessType::SELECT, command.from_database, command.from_table); + required_access.emplace_back(AccessType::DELETE | AccessType::INSERT, database, table); + break; + } + case ASTAlterCommand::FETCH_PARTITION: + { + required_access.emplace_back(AccessType::FETCH_PARTITION, database, table); + break; + } + case ASTAlterCommand::FREEZE_PARTITION: [[fallthrough]]; + case ASTAlterCommand::FREEZE_ALL: + { + required_access.emplace_back(AccessType::FREEZE_PARTITION, database, table); + break; + } + case ASTAlterCommand::MODIFY_QUERY: + { + required_access.emplace_back(AccessType::MODIFY_VIEW_QUERY, database, table); + break; + } + case ASTAlterCommand::LIVE_VIEW_REFRESH: + { + required_access.emplace_back(AccessType::REFRESH_VIEW, database, table); + break; + } + case ASTAlterCommand::NO_TYPE: break; } return required_access; diff --git a/dbms/src/Interpreters/InterpreterAlterQuery.h b/dbms/src/Interpreters/InterpreterAlterQuery.h index fd395a0de52..a7609eb81f1 100644 --- a/dbms/src/Interpreters/InterpreterAlterQuery.h +++ b/dbms/src/Interpreters/InterpreterAlterQuery.h @@ -8,6 +8,7 @@ namespace DB { class Context; class AccessRightsElements; +class ASTAlterCommand; /** Allows you add or remove a column in the table. @@ -20,6 +21,8 @@ public: BlockIO execute() override; + static AccessRightsElements getRequiredAccessForCommand(const ASTAlterCommand & command, const String & database, const String & table); + private: AccessRightsElements getRequiredAccess() const; diff --git a/dbms/src/Interpreters/InterpreterCheckQuery.cpp b/dbms/src/Interpreters/InterpreterCheckQuery.cpp index 25fac1d4982..b8f7203e607 100644 --- a/dbms/src/Interpreters/InterpreterCheckQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCheckQuery.cpp @@ -40,7 +40,7 @@ BlockIO InterpreterCheckQuery::execute() const auto & check = query_ptr->as(); auto table_id = context.resolveStorageID(check, Context::ResolveOrdinary); - context.checkAccess(AccessType::SHOW, table_id); + context.checkAccess(AccessType::SHOW_TABLES, table_id); StoragePtr table = DatabaseCatalog::instance().getTable(table_id); auto check_results = table->checkData(query_ptr, context); diff --git a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp index 10c52a5b4fb..4b64615dd36 100644 --- a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include @@ -14,7 +14,7 @@ namespace DB { namespace { -void updateQuotaFromQueryImpl(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional & roles_from_query = {}) +void updateQuotaFromQueryImpl(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional & roles_from_query = {}) { if (query.alter) { @@ -61,15 +61,15 @@ void updateQuotaFromQueryImpl(Quota & quota, const ASTCreateQuotaQuery & query, } } - const GenericRoleSet * roles = nullptr; - std::optional temp_role_set; + const ExtendedRoleSet * roles = nullptr; + std::optional temp_role_set; if (roles_from_query) roles = &*roles_from_query; else if (query.roles) roles = &temp_role_set.emplace(*query.roles); if (roles) - quota.roles = *roles; + quota.to_roles = *roles; } } @@ -80,9 +80,9 @@ BlockIO InterpreterCreateQuotaQuery::execute() auto & access_control = context.getAccessControlManager(); context.checkAccess(query.alter ? AccessType::ALTER_QUOTA : AccessType::CREATE_QUOTA); - std::optional roles_from_query; + std::optional roles_from_query; if (query.roles) - roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()}; + roles_from_query = ExtendedRoleSet{*query.roles, access_control, context.getUserID()}; if (query.alter) { diff --git a/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp b/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp index f1c58f9d9bd..f64462d443b 100644 --- a/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp @@ -7,21 +7,53 @@ namespace DB { +namespace +{ + void updateRoleFromQueryImpl( + Role & role, + const ASTCreateRoleQuery & query, + const std::optional & settings_from_query = {}) + { + if (query.alter) + { + if (!query.new_name.empty()) + role.setName(query.new_name); + } + else + role.setName(query.name); + + const SettingsProfileElements * settings = nullptr; + std::optional temp_settings; + if (settings_from_query) + settings = &*settings_from_query; + else if (query.settings) + settings = &temp_settings.emplace(*query.settings); + + if (settings) + role.settings = *settings; + } +} + + BlockIO InterpreterCreateRoleQuery::execute() { const auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); if (query.alter) - context.checkAccess(AccessType::CREATE_ROLE | AccessType::DROP_ROLE); + context.checkAccess(AccessType::ALTER_ROLE); else context.checkAccess(AccessType::CREATE_ROLE); + std::optional settings_from_query; + if (query.settings) + settings_from_query = SettingsProfileElements{*query.settings, access_control}; + if (query.alter) { auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto updated_role = typeid_cast>(entity->clone()); - updateRoleFromQuery(*updated_role, query); + updateRoleFromQueryImpl(*updated_role, query, settings_from_query); return updated_role; }; if (query.if_exists) @@ -35,7 +67,7 @@ BlockIO InterpreterCreateRoleQuery::execute() else { auto new_role = std::make_shared(); - updateRoleFromQuery(*new_role, query); + updateRoleFromQueryImpl(*new_role, query, settings_from_query); if (query.if_not_exists) access_control.tryInsert(new_role); @@ -51,12 +83,6 @@ BlockIO InterpreterCreateRoleQuery::execute() void InterpreterCreateRoleQuery::updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query) { - if (query.alter) - { - if (!query.new_name.empty()) - role.setName(query.new_name); - } - else - role.setName(query.name); + updateRoleFromQueryImpl(role, query); } } diff --git a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp index 2d0e23d284e..9ea47aba7bb 100644 --- a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include @@ -27,7 +27,7 @@ namespace void updateRowPolicyFromQueryImpl( RowPolicy & policy, const ASTCreateRowPolicyQuery & query, - const std::optional & roles_from_query = {}, + const std::optional & roles_from_query = {}, const String & current_database = {}) { if (query.alter) @@ -48,15 +48,15 @@ namespace for (const auto & [index, condition] : query.conditions) policy.conditions[index] = condition ? serializeAST(*condition) : String{}; - const GenericRoleSet * roles = nullptr; - std::optional temp_role_set; + const ExtendedRoleSet * roles = nullptr; + std::optional temp_role_set; if (roles_from_query) roles = &*roles_from_query; else if (query.roles) roles = &temp_role_set.emplace(*query.roles); if (roles) - policy.roles = *roles; + policy.to_roles = *roles; } } @@ -67,9 +67,9 @@ BlockIO InterpreterCreateRowPolicyQuery::execute() auto & access_control = context.getAccessControlManager(); context.checkAccess(query.alter ? AccessType::ALTER_POLICY : AccessType::CREATE_POLICY); - std::optional roles_from_query; + std::optional roles_from_query; if (query.roles) - roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()}; + roles_from_query = ExtendedRoleSet{*query.roles, access_control, context.getUserID()}; const String current_database = context.getCurrentDatabase(); diff --git a/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.cpp b/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.cpp new file mode 100644 index 00000000000..9d110a69516 --- /dev/null +++ b/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.cpp @@ -0,0 +1,104 @@ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + void updateSettingsProfileFromQueryImpl( + SettingsProfile & profile, + const ASTCreateSettingsProfileQuery & query, + const std::optional & settings_from_query = {}, + const std::optional & roles_from_query = {}) + { + if (query.alter) + { + if (!query.new_name.empty()) + profile.setName(query.new_name); + } + else + profile.setName(query.name); + + const SettingsProfileElements * settings = nullptr; + std::optional temp_settings; + if (settings_from_query) + settings = &*settings_from_query; + else if (query.settings) + settings = &temp_settings.emplace(*query.settings); + + if (settings) + profile.elements = *settings; + + const ExtendedRoleSet * roles = nullptr; + std::optional temp_role_set; + if (roles_from_query) + roles = &*roles_from_query; + else if (query.to_roles) + roles = &temp_role_set.emplace(*query.to_roles); + + if (roles) + profile.to_roles = *roles; + } +} + + +BlockIO InterpreterCreateSettingsProfileQuery::execute() +{ + const auto & query = query_ptr->as(); + auto & access_control = context.getAccessControlManager(); + if (query.alter) + context.checkAccess(AccessType::ALTER_SETTINGS_PROFILE); + else + context.checkAccess(AccessType::CREATE_SETTINGS_PROFILE); + + std::optional settings_from_query; + if (query.settings) + settings_from_query = SettingsProfileElements{*query.settings, access_control}; + + std::optional roles_from_query; + if (query.to_roles) + roles_from_query = ExtendedRoleSet{*query.to_roles, access_control, context.getUserID()}; + + if (query.alter) + { + auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr + { + auto updated_profile = typeid_cast>(entity->clone()); + updateSettingsProfileFromQueryImpl(*updated_profile, query, settings_from_query, roles_from_query); + return updated_profile; + }; + if (query.if_exists) + { + if (auto id = access_control.find(query.name)) + access_control.tryUpdate(*id, update_func); + } + else + access_control.update(access_control.getID(query.name), update_func); + } + else + { + auto new_profile = std::make_shared(); + updateSettingsProfileFromQueryImpl(*new_profile, query, settings_from_query, roles_from_query); + + if (query.if_not_exists) + access_control.tryInsert(new_profile); + else if (query.or_replace) + access_control.insertOrReplace(new_profile); + else + access_control.insert(new_profile); + } + + return {}; +} + + +void InterpreterCreateSettingsProfileQuery::updateSettingsProfileFromQuery(SettingsProfile & SettingsProfile, const ASTCreateSettingsProfileQuery & query) +{ + updateSettingsProfileFromQueryImpl(SettingsProfile, query); +} +} diff --git a/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.h b/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.h new file mode 100644 index 00000000000..fd420779cf4 --- /dev/null +++ b/dbms/src/Interpreters/InterpreterCreateSettingsProfileQuery.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + + +namespace DB +{ +class ASTCreateSettingsProfileQuery; +struct SettingsProfile; + + +class InterpreterCreateSettingsProfileQuery : public IInterpreter +{ +public: + InterpreterCreateSettingsProfileQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {} + + BlockIO execute() override; + + static void updateSettingsProfileFromQuery(SettingsProfile & profile, const ASTCreateSettingsProfileQuery & query); + +private: + ASTPtr query_ptr; + Context & context; +}; +} diff --git a/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp b/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp index 6219a493b27..5dba1fefc9c 100644 --- a/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp @@ -4,8 +4,8 @@ #include #include #include -#include -#include +#include +#include #include @@ -13,7 +13,11 @@ namespace DB { namespace { - void updateUserFromQueryImpl(User & user, const ASTCreateUserQuery & query, const std::optional & default_roles_from_query = {}) + void updateUserFromQueryImpl( + User & user, + const ASTCreateUserQuery & query, + const std::optional & default_roles_from_query = {}, + const std::optional & settings_from_query = {}) { if (query.alter) { @@ -33,8 +37,8 @@ namespace if (query.add_hosts) user.allowed_client_hosts.add(*query.add_hosts); - const GenericRoleSet * default_roles = nullptr; - std::optional temp_role_set; + const ExtendedRoleSet * default_roles = nullptr; + std::optional temp_role_set; if (default_roles_from_query) default_roles = &*default_roles_from_query; else if (query.default_roles) @@ -48,8 +52,15 @@ namespace InterpreterSetRoleQuery::updateUserSetDefaultRoles(user, *default_roles); } - if (query.profile) - user.profile = *query.profile; + const SettingsProfileElements * settings = nullptr; + std::optional temp_settings; + if (settings_from_query) + settings = &*settings_from_query; + else if (query.settings) + settings = &temp_settings.emplace(*query.settings); + + if (settings) + user.settings = *settings; } } @@ -58,25 +69,30 @@ BlockIO InterpreterCreateUserQuery::execute() { const auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); - context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER); + auto access = context.getAccess(); + access->checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER); - std::optional default_roles_from_query; + std::optional default_roles_from_query; if (query.default_roles) { - default_roles_from_query = GenericRoleSet{*query.default_roles, access_control}; + default_roles_from_query = ExtendedRoleSet{*query.default_roles, access_control}; if (!query.alter && !default_roles_from_query->all) { for (const UUID & role : default_roles_from_query->getMatchingIDs()) - context.getAccessRights()->checkAdminOption(role); + access->checkAdminOption(role); } } + std::optional settings_from_query; + if (query.settings) + settings_from_query = SettingsProfileElements{*query.settings, access_control}; + if (query.alter) { auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto updated_user = typeid_cast>(entity->clone()); - updateUserFromQueryImpl(*updated_user, query, default_roles_from_query); + updateUserFromQueryImpl(*updated_user, query, default_roles_from_query, settings_from_query); return updated_user; }; if (query.if_exists) @@ -90,7 +106,7 @@ BlockIO InterpreterCreateUserQuery::execute() else { auto new_user = std::make_shared(); - updateUserFromQueryImpl(*new_user, query, default_roles_from_query); + updateUserFromQueryImpl(*new_user, query, default_roles_from_query, settings_from_query); if (query.if_not_exists) access_control.tryInsert(new_user); diff --git a/dbms/src/Interpreters/InterpreterDescribeQuery.cpp b/dbms/src/Interpreters/InterpreterDescribeQuery.cpp index c2660f63169..1353c01ebf6 100644 --- a/dbms/src/Interpreters/InterpreterDescribeQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDescribeQuery.cpp @@ -85,8 +85,7 @@ BlockInputStreamPtr InterpreterDescribeQuery::executeImpl() else { auto table_id = context.resolveStorageID(table_expression.database_and_table_name); - context.checkAccess(AccessType::SHOW, table_id); - + context.checkAccess(AccessType::SHOW_COLUMNS, table_id); table = DatabaseCatalog::instance().getTable(table_id); } diff --git a/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp index c69ce3ade45..12f33250188 100644 --- a/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp @@ -7,64 +7,69 @@ #include #include #include +#include #include namespace DB { +namespace +{ + using Kind = ASTDropAccessEntityQuery::Kind; + + std::type_index getType(Kind kind) + { + switch (kind) + { + case Kind::USER: return typeid(User); + case Kind::ROLE: return typeid(Role); + case Kind::QUOTA: return typeid(Quota); + case Kind::ROW_POLICY: return typeid(RowPolicy); + case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile); + } + __builtin_unreachable(); + } + + AccessType getRequiredAccessType(Kind kind) + { + switch (kind) + { + case Kind::USER: return AccessType::DROP_USER; + case Kind::ROLE: return AccessType::DROP_ROLE; + case Kind::QUOTA: return AccessType::DROP_QUOTA; + case Kind::ROW_POLICY: return AccessType::DROP_POLICY; + case Kind::SETTINGS_PROFILE: return AccessType::DROP_SETTINGS_PROFILE; + } + __builtin_unreachable(); + } +} + BlockIO InterpreterDropAccessEntityQuery::execute() { const auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); - using Kind = ASTDropAccessEntityQuery::Kind; - switch (query.kind) + std::type_index type = getType(query.kind); + context.checkAccess(getRequiredAccessType(query.kind)); + + if (query.kind == Kind::ROW_POLICY) { - case Kind::USER: - { - context.checkAccess(AccessType::DROP_USER); - if (query.if_exists) - access_control.tryRemove(access_control.find(query.names)); - else - access_control.remove(access_control.getIDs(query.names)); - return {}; - } - - case Kind::ROLE: - { - context.checkAccess(AccessType::DROP_ROLE); - if (query.if_exists) - access_control.tryRemove(access_control.find(query.names)); - else - access_control.remove(access_control.getIDs(query.names)); - return {}; - } - - case Kind::QUOTA: - { - context.checkAccess(AccessType::DROP_QUOTA); - if (query.if_exists) - access_control.tryRemove(access_control.find(query.names)); - else - access_control.remove(access_control.getIDs(query.names)); - return {}; - } - - case Kind::ROW_POLICY: - { - context.checkAccess(AccessType::DROP_POLICY); - Strings full_names; - boost::range::transform( - query.row_policies_names, std::back_inserter(full_names), - [this](const RowPolicy::FullNameParts & row_policy_name) { return row_policy_name.getFullName(context); }); - if (query.if_exists) - access_control.tryRemove(access_control.find(full_names)); - else - access_control.remove(access_control.getIDs(full_names)); - return {}; - } + Strings full_names; + boost::range::transform( + query.row_policies_names, std::back_inserter(full_names), + [this](const RowPolicy::FullNameParts & row_policy_name) { return row_policy_name.getFullName(context); }); + if (query.if_exists) + access_control.tryRemove(access_control.find(full_names)); + else + access_control.remove(access_control.getIDs(full_names)); + return {}; } - __builtin_unreachable(); + if (query.if_exists) + access_control.tryRemove(access_control.find(type, query.names)); + else + access_control.remove(access_control.getIDs(type, query.names)); + return {}; } + } diff --git a/dbms/src/Interpreters/InterpreterExistsQuery.cpp b/dbms/src/Interpreters/InterpreterExistsQuery.cpp index 7cd864fddb7..993b3631e06 100644 --- a/dbms/src/Interpreters/InterpreterExistsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterExistsQuery.cpp @@ -44,13 +44,12 @@ BlockInputStreamPtr InterpreterExistsQuery::executeImpl() { if (exists_query->temporary) { - context.checkAccess(AccessType::EXISTS, "", exists_query->table); result = context.tryResolveStorageID({"", exists_query->table}, Context::ResolveExternal); } else { String database = context.resolveDatabase(exists_query->database); - context.checkAccess(AccessType::EXISTS, database, exists_query->table); + context.checkAccess(AccessType::SHOW_TABLES, database, exists_query->table); result = DatabaseCatalog::instance().isTableExist({database, exists_query->table}); } } @@ -59,7 +58,7 @@ BlockInputStreamPtr InterpreterExistsQuery::executeImpl() if (exists_query->temporary) throw Exception("Temporary dictionaries are not possible.", ErrorCodes::SYNTAX_ERROR); String database = context.resolveDatabase(exists_query->database); - context.checkAccess(AccessType::EXISTS, database, exists_query->table); + context.checkAccess(AccessType::SHOW_DICTIONARIES, database, exists_query->table); result = DatabaseCatalog::instance().isDictionaryExist({database, exists_query->table}); } diff --git a/dbms/src/Interpreters/InterpreterFactory.cpp b/dbms/src/Interpreters/InterpreterFactory.cpp index b3b1fd498db..0c34d6ed79f 100644 --- a/dbms/src/Interpreters/InterpreterFactory.cpp +++ b/dbms/src/Interpreters/InterpreterFactory.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -204,6 +206,10 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, Context & { return std::make_unique(query, context); } + else if (query->as()) + { + return std::make_unique(query, context); + } else if (query->as()) { return std::make_unique(query, context); diff --git a/dbms/src/Interpreters/InterpreterGrantQuery.cpp b/dbms/src/Interpreters/InterpreterGrantQuery.cpp index 6d1b2262637..5d215ff3a93 100644 --- a/dbms/src/Interpreters/InterpreterGrantQuery.cpp +++ b/dbms/src/Interpreters/InterpreterGrantQuery.cpp @@ -2,8 +2,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include @@ -14,7 +14,7 @@ namespace DB namespace { template - void updateFromQueryImpl(T & grantee, const ASTGrantQuery & query, const std::vector & roles_from_query, const String & current_database, bool partial_revokes) + void updateFromQueryImpl(T & grantee, const ASTGrantQuery & query, const std::vector & roles_from_query, const String & current_database) { using Kind = ASTGrantQuery::Kind; if (!query.access_rights_elements.empty()) @@ -25,12 +25,6 @@ namespace if (query.grant_option) grantee.access_with_grant_option.grant(query.access_rights_elements, current_database); } - else if (partial_revokes) - { - grantee.access_with_grant_option.partialRevoke(query.access_rights_elements, current_database); - if (!query.grant_option) - grantee.access.partialRevoke(query.access_rights_elements, current_database); - } else { grantee.access_with_grant_option.revoke(query.access_rights_elements, current_database); @@ -67,31 +61,31 @@ BlockIO InterpreterGrantQuery::execute() { const auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); - context.getAccessRights()->checkGrantOption(query.access_rights_elements); + auto access = context.getAccess(); + access->checkGrantOption(query.access_rights_elements); std::vector roles_from_query; if (query.roles) { - roles_from_query = GenericRoleSet{*query.roles, access_control}.getMatchingRoles(access_control); + roles_from_query = ExtendedRoleSet{*query.roles, access_control}.getMatchingIDs(access_control); for (const UUID & role_from_query : roles_from_query) - context.getAccessRights()->checkAdminOption(role_from_query); + access->checkAdminOption(role_from_query); } - std::vector to_roles = GenericRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingUsersAndRoles(access_control); + std::vector to_roles = ExtendedRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingIDs(access_control); String current_database = context.getCurrentDatabase(); - bool partial_revokes = context.getSettingsRef().partial_revokes; auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto clone = entity->clone(); if (auto user = typeid_cast>(clone)) { - updateFromQueryImpl(*user, query, roles_from_query, current_database, partial_revokes); + updateFromQueryImpl(*user, query, roles_from_query, current_database); return user; } else if (auto role = typeid_cast>(clone)) { - updateFromQueryImpl(*role, query, roles_from_query, current_database, partial_revokes); + updateFromQueryImpl(*role, query, roles_from_query, current_database); return role; } else @@ -108,8 +102,8 @@ void InterpreterGrantQuery::updateUserFromQuery(User & user, const ASTGrantQuery { std::vector roles_from_query; if (query.roles) - roles_from_query = GenericRoleSet{*query.roles}.getMatchingIDs(); - updateFromQueryImpl(user, query, roles_from_query, {}, true); + roles_from_query = ExtendedRoleSet{*query.roles}.getMatchingIDs(); + updateFromQueryImpl(user, query, roles_from_query, {}); } @@ -117,8 +111,8 @@ void InterpreterGrantQuery::updateRoleFromQuery(Role & role, const ASTGrantQuery { std::vector roles_from_query; if (query.roles) - roles_from_query = GenericRoleSet{*query.roles}.getMatchingIDs(); - updateFromQueryImpl(role, query, roles_from_query, {}, true); + roles_from_query = ExtendedRoleSet{*query.roles}.getMatchingIDs(); + updateFromQueryImpl(role, query, roles_from_query, {}); } } diff --git a/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp b/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp index 81a093f4eae..196b2b4eef1 100644 --- a/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp +++ b/dbms/src/Interpreters/InterpreterKillQueryQuery.cpp @@ -6,7 +6,11 @@ #include #include #include -#include +#include +#include +#include +#include +#include #include #include #include @@ -55,7 +59,7 @@ struct QueryDescriptor size_t source_num; bool processed = false; - QueryDescriptor(String && query_id_, String && user_, size_t source_num_, bool processed_ = false) + QueryDescriptor(String query_id_, String user_, size_t source_num_, bool processed_ = false) : query_id(std::move(query_id_)), user(std::move(user_)), source_num(source_num_), processed(processed_) {} }; @@ -79,8 +83,20 @@ static QueryDescriptors extractQueriesExceptMeAndCheckAccess(const Block & proce const ColumnString & query_id_col = typeid_cast(*processes_block.getByName("query_id").column); const ColumnString & user_col = typeid_cast(*processes_block.getByName("user").column); const ClientInfo & my_client = context.getProcessListElement()->getClientInfo(); - std::optional can_kill_query_started_by_another_user; + + std::optional can_kill_query_started_by_another_user_cached; + auto can_kill_query_started_by_another_user = [&]() -> bool + { + if (!can_kill_query_started_by_another_user_cached) + { + can_kill_query_started_by_another_user_cached + = context.getAccess()->isGranted(&Poco::Logger::get("InterpreterKillQueryQuery"), AccessType::KILL_QUERY); + } + return *can_kill_query_started_by_another_user_cached; + }; + String query_user; + bool access_denied = false; for (size_t i = 0; i < num_processes; ++i) { @@ -91,18 +107,16 @@ static QueryDescriptors extractQueriesExceptMeAndCheckAccess(const Block & proce auto query_id = query_id_col.getDataAt(i).toString(); query_user = user_col.getDataAt(i).toString(); - if (my_client.current_user != query_user) + if ((my_client.current_user != query_user) && !can_kill_query_started_by_another_user()) { - if (!can_kill_query_started_by_another_user) - can_kill_query_started_by_another_user = context.getAccessRights()->isGranted(&Poco::Logger::get("InterpreterKillQueryQuery"), AccessType::KILL_QUERY); - if (!can_kill_query_started_by_another_user.value()) - continue; + access_denied = true; + continue; } - res.emplace_back(std::move(query_id), std::move(query_user), i, false); + res.emplace_back(std::move(query_id), query_user, i, false); } - if (res.empty() && !query_user.empty()) // NOLINT + if (res.empty() && access_denied) throw Exception("User " + my_client.current_user + " attempts to kill query created by " + query_user, ErrorCodes::ACCESS_DENIED); return res; @@ -221,19 +235,23 @@ BlockIO InterpreterKillQueryQuery::execute() } case ASTKillQueryQuery::Type::Mutation: { - Block mutations_block = getSelectResult("database, table, mutation_id", "system.mutations"); + Block mutations_block = getSelectResult("database, table, mutation_id, command", "system.mutations"); if (!mutations_block) return res_io; const ColumnString & database_col = typeid_cast(*mutations_block.getByName("database").column); const ColumnString & table_col = typeid_cast(*mutations_block.getByName("table").column); const ColumnString & mutation_id_col = typeid_cast(*mutations_block.getByName("mutation_id").column); + const ColumnString & command_col = typeid_cast(*mutations_block.getByName("command").column); auto header = mutations_block.cloneEmpty(); header.insert(0, {ColumnString::create(), std::make_shared(), "kill_status"}); MutableColumns res_columns = header.cloneEmptyColumns(); auto table_id = StorageID::createEmpty(); + AccessRightsElements required_access_rights; + auto access = context.getAccess(); + bool access_denied = false; for (size_t i = 0; i < mutations_block.rows(); ++i) { @@ -248,8 +266,14 @@ BlockIO InterpreterKillQueryQuery::execute() code = CancellationCode::NotFound; else { - if (!context.getAccessRights()->isGranted(&Poco::Logger::get("InterpreterKillQueryQuery"), AccessType::KILL_MUTATION, table_id.database_name, table_id.table_name)) + ParserAlterCommand parser; + auto command_ast = parseQuery(parser, command_col.getDataAt(i).toString(), 0); + required_access_rights = InterpreterAlterQuery::getRequiredAccessForCommand(command_ast->as(), table_id.database_name, table_id.table_name); + if (!access->isGranted(&Poco::Logger::get("InterpreterKillQueryQuery"), required_access_rights)) + { + access_denied = true; continue; + } code = storage->killMutation(mutation_id); } } @@ -257,9 +281,9 @@ BlockIO InterpreterKillQueryQuery::execute() insertResultRow(i, code, mutations_block, header, res_columns); } - if (res_columns[0]->empty() && table_id) + if (res_columns[0]->empty() && access_denied) throw Exception( - "Not allowed to kill mutation on " + table_id.getNameForLogs(), + "Not allowed to kill mutation. To execute this query it's necessary to have the grant " + required_access_rights.toString(), ErrorCodes::ACCESS_DENIED); res_io.in = std::make_shared(header.cloneWithColumns(std::move(res_columns))); @@ -295,7 +319,7 @@ AccessRightsElements InterpreterKillQueryQuery::getRequiredAccessForDDLOnCluster if (query.type == ASTKillQueryQuery::Type::Query) required_access.emplace_back(AccessType::KILL_QUERY); else if (query.type == ASTKillQueryQuery::Type::Mutation) - required_access.emplace_back(AccessType::KILL_MUTATION); + required_access.emplace_back(AccessType::UPDATE | AccessType::DELETE | AccessType::MATERIALIZE_INDEX | AccessType::MATERIALIZE_TTL); return required_access; } diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 07043301325..8a5f3c1bcb7 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -38,7 +38,6 @@ #include #include -#include #include #include @@ -361,8 +360,7 @@ InterpreterSelectQuery::InterpreterSelectQuery( source_header = storage->getSampleBlockForColumns(required_columns); /// Fix source_header for filter actions. - auto row_policy_filter = context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); - row_policy_filter = RowPolicyContext::combineConditionsUsingAnd(row_policy_filter, context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)); + auto row_policy_filter = context->getRowPolicyCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); if (row_policy_filter) { filter_info = std::make_shared(); @@ -490,8 +488,7 @@ Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere) /// PREWHERE optimization. /// Turn off, if the table filter (row-level security) is applied. - if (!context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER) - && !context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)) + if (!context->getRowPolicyCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)) { auto optimize_prewhere = [&](auto & merge_tree) { @@ -1128,7 +1125,7 @@ void InterpreterSelectQuery::executeFetchColumns( if (storage) { /// Append columns from the table filter to required - auto row_policy_filter = context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); + auto row_policy_filter = context->getRowPolicyCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); if (row_policy_filter) { auto initial_required_columns = required_columns; diff --git a/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp b/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp index 567c626cb90..2a6f2317a9c 100644 --- a/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp @@ -1,8 +1,8 @@ #include #include -#include +#include #include -#include +#include #include #include @@ -38,7 +38,7 @@ void InterpreterSetRoleQuery::setRole(const ASTSetRoleQuery & query) } else { - GenericRoleSet roles_from_query{*query.roles, access_control}; + ExtendedRoleSet roles_from_query{*query.roles, access_control}; std::vector new_current_roles; if (roles_from_query.all) { @@ -65,8 +65,8 @@ void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query) context.checkAccess(AccessType::CREATE_USER | AccessType::DROP_USER); auto & access_control = context.getAccessControlManager(); - std::vector to_users = GenericRoleSet{*query.to_users, access_control, context.getUserID()}.getMatchingUsers(access_control); - GenericRoleSet roles_from_query{*query.roles, access_control}; + std::vector to_users = ExtendedRoleSet{*query.to_users, access_control, context.getUserID()}.getMatchingIDs(access_control); + ExtendedRoleSet roles_from_query{*query.roles, access_control}; auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { @@ -79,7 +79,7 @@ void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query) } -void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query) +void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const ExtendedRoleSet & roles_from_query) { if (!roles_from_query.all) { diff --git a/dbms/src/Interpreters/InterpreterSetRoleQuery.h b/dbms/src/Interpreters/InterpreterSetRoleQuery.h index cace6b22c24..afb53014c87 100644 --- a/dbms/src/Interpreters/InterpreterSetRoleQuery.h +++ b/dbms/src/Interpreters/InterpreterSetRoleQuery.h @@ -7,7 +7,7 @@ namespace DB { class ASTSetRoleQuery; -struct GenericRoleSet; +struct ExtendedRoleSet; struct User; @@ -18,7 +18,7 @@ public: BlockIO execute() override; - static void updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query); + static void updateUserSetDefaultRoles(User & user, const ExtendedRoleSet & roles_from_query); private: void setRole(const ASTSetRoleQuery & query); diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index 8c8658d820c..52126b0507e 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -4,15 +4,18 @@ #include #include #include +#include #include -#include +#include #include #include #include #include -#include +#include +#include #include #include +#include #include #include #include @@ -42,15 +45,12 @@ namespace if (user.allowed_client_hosts != AllowedClientHosts::AnyHostTag{}) query->hosts = user.allowed_client_hosts; - if (!user.profile.empty()) - query->profile = user.profile; - - if (user.default_roles != GenericRoleSet::AllTag{}) + if (user.default_roles != ExtendedRoleSet::AllTag{}) { if (attach_mode) - query->default_roles = GenericRoleSet{user.default_roles}.toAST(); + query->default_roles = user.default_roles.toAST(); else - query->default_roles = GenericRoleSet{user.default_roles}.toASTWithNames(*manager); + query->default_roles = user.default_roles.toASTWithNames(*manager); } if (attach_mode && (user.authentication.getType() != Authentication::NO_PASSWORD)) @@ -58,15 +58,59 @@ namespace /// We don't show password unless it's an ATTACH statement. query->authentication = user.authentication; } + + if (!user.settings.empty()) + { + if (attach_mode) + query->settings = user.settings.toAST(); + else + query->settings = user.settings.toASTWithNames(*manager); + } + return query; } - ASTPtr getCreateQueryImpl(const Role & role, const AccessControlManager *, bool attach_mode = false) + ASTPtr getCreateQueryImpl(const Role & role, const AccessControlManager * manager, bool attach_mode = false) { auto query = std::make_shared(); query->name = role.getName(); query->attach = attach_mode; + + if (!role.settings.empty()) + { + if (attach_mode) + query->settings = role.settings.toAST(); + else + query->settings = role.settings.toASTWithNames(*manager); + } + + return query; + } + + + ASTPtr getCreateQueryImpl(const SettingsProfile & profile, const AccessControlManager * manager, bool attach_mode = false) + { + auto query = std::make_shared(); + query->name = profile.getName(); + query->attach = attach_mode; + + if (!profile.elements.empty()) + { + if (attach_mode) + query->settings = profile.elements.toAST(); + else + query->settings = profile.elements.toASTWithNames(*manager); + } + + if (!profile.to_roles.empty()) + { + if (attach_mode) + query->to_roles = profile.to_roles.toAST(); + else + query->to_roles = profile.to_roles.toASTWithNames(*manager); + } + return query; } @@ -94,12 +138,12 @@ namespace query->all_limits.push_back(create_query_limits); } - if (!quota.roles.empty()) + if (!quota.to_roles.empty()) { if (attach_mode) - query->roles = quota.roles.toAST(); + query->roles = quota.to_roles.toAST(); else - query->roles = quota.roles.toASTWithNames(*manager); + query->roles = quota.to_roles.toASTWithNames(*manager); } return query; @@ -118,7 +162,7 @@ namespace if (policy.isRestrictive()) query->is_restrictive = policy.isRestrictive(); - for (auto index : ext::range_with_static_cast(RowPolicy::MAX_CONDITION_INDEX)) + for (auto index : ext::range_with_static_cast(RowPolicy::MAX_CONDITION_TYPE)) { const auto & condition = policy.conditions[index]; if (!condition.empty()) @@ -129,12 +173,12 @@ namespace } } - if (!policy.roles.empty()) + if (!policy.to_roles.empty()) { if (attach_mode) - query->roles = policy.roles.toAST(); + query->roles = policy.to_roles.toAST(); else - query->roles = policy.roles.toASTWithNames(*manager); + query->roles = policy.to_roles.toASTWithNames(*manager); } return query; @@ -153,8 +197,25 @@ namespace return getCreateQueryImpl(*policy, manager, attach_mode); if (const Quota * quota = typeid_cast(&entity)) return getCreateQueryImpl(*quota, manager, attach_mode); + if (const SettingsProfile * profile = typeid_cast(&entity)) + return getCreateQueryImpl(*profile, manager, attach_mode); throw Exception("Unexpected type of access entity: " + entity.getTypeName(), ErrorCodes::LOGICAL_ERROR); } + + using Kind = ASTShowCreateAccessEntityQuery::Kind; + + std::type_index getType(Kind kind) + { + switch (kind) + { + case Kind::USER: return typeid(User); + case Kind::ROLE: return typeid(Role); + case Kind::QUOTA: return typeid(Quota); + case Kind::ROW_POLICY: return typeid(RowPolicy); + case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile); + } + __builtin_unreachable(); + } } @@ -195,36 +256,28 @@ BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl() ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const ASTShowCreateAccessEntityQuery & show_query) const { const auto & access_control = context.getAccessControlManager(); - using Kind = ASTShowCreateAccessEntityQuery::Kind; - switch (show_query.kind) + + if (show_query.current_user) { - case Kind::USER: - { - UserPtr user; - if (show_query.current_user) - user = context.getUser(); - else - user = access_control.read(show_query.name); - return getCreateQueryImpl(*user, &access_control); - } - - case Kind::QUOTA: - { - QuotaPtr quota; - if (show_query.current_quota) - quota = access_control.read(context.getQuota()->getUsageInfo().quota_id); - else - quota = access_control.read(show_query.name); - return getCreateQueryImpl(*quota, &access_control); - } - - case Kind::ROW_POLICY: - { - RowPolicyPtr policy = access_control.read(show_query.row_policy_name.getFullName(context)); - return getCreateQueryImpl(*policy, &access_control); - } + auto user = context.getUser(); + return getCreateQueryImpl(*user, &access_control); } - __builtin_unreachable(); + + if (show_query.current_quota) + { + auto quota = access_control.read(context.getQuota()->getUsageInfo().quota_id); + return getCreateQueryImpl(*quota, &access_control); + } + + auto type = getType(show_query.kind); + if (show_query.kind == Kind::ROW_POLICY) + { + RowPolicyPtr policy = access_control.read(show_query.row_policy_name.getFullName(context)); + return getCreateQueryImpl(*policy, &access_control); + } + + auto entity = access_control.read(access_control.getID(type, show_query.name)); + return getCreateQueryImpl(*entity, &access_control); } diff --git a/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp index 71f08f12f5d..8bee0b88fe8 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateQuery.cpp @@ -49,7 +49,7 @@ BlockInputStreamPtr InterpreterShowCreateQuery::executeImpl() { auto resolve_table_type = show_query->temporary ? Context::ResolveExternal : Context::ResolveOrdinary; auto table_id = context.resolveStorageID(*show_query, resolve_table_type); - context.checkAccess(AccessType::SHOW, table_id); + context.checkAccess(AccessType::SHOW_COLUMNS, table_id); create_query = DatabaseCatalog::instance().getDatabase(table_id.database_name)->getCreateTableQuery(context, table_id.table_name); } else if ((show_query = query_ptr->as())) @@ -57,7 +57,7 @@ BlockInputStreamPtr InterpreterShowCreateQuery::executeImpl() if (show_query->temporary) throw Exception("Temporary databases are not possible.", ErrorCodes::SYNTAX_ERROR); show_query->database = context.resolveDatabase(show_query->database); - context.checkAccess(AccessType::SHOW, show_query->database); + context.checkAccess(AccessType::SHOW_DATABASES, show_query->database); create_query = DatabaseCatalog::instance().getDatabase(show_query->database)->getCreateDatabaseQuery(context); } else if ((show_query = query_ptr->as())) @@ -65,7 +65,7 @@ BlockInputStreamPtr InterpreterShowCreateQuery::executeImpl() if (show_query->temporary) throw Exception("Temporary dictionaries are not possible.", ErrorCodes::SYNTAX_ERROR); show_query->database = context.resolveDatabase(show_query->database); - context.checkAccess(AccessType::SHOW, show_query->database, show_query->table); + context.checkAccess(AccessType::SHOW_DICTIONARIES, show_query->database, show_query->table); create_query = DatabaseCatalog::instance().getDatabase(show_query->database)->getCreateDictionaryQuery(context, show_query->table); } diff --git a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp index cbd4b3636ac..da1d46f0cab 100644 --- a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include #include @@ -62,7 +62,7 @@ namespace { ASTs res; - std::shared_ptr to_roles = std::make_shared(); + std::shared_ptr to_roles = std::make_shared(); to_roles->names.push_back(grantee.getName()); for (bool grant_option : {true, false}) @@ -104,9 +104,9 @@ namespace grant_query->admin_option = admin_option; grant_query->to_roles = to_roles; if (attach_mode) - grant_query->roles = GenericRoleSet{roles}.toAST(); + grant_query->roles = ExtendedRoleSet{roles}.toAST(); else - grant_query->roles = GenericRoleSet{roles}.toASTWithNames(*manager); + grant_query->roles = ExtendedRoleSet{roles}.toASTWithNames(*manager); res.push_back(std::move(grant_query)); } diff --git a/dbms/src/Interpreters/InterpreterSystemQuery.cpp b/dbms/src/Interpreters/InterpreterSystemQuery.cpp index e479a53cb03..87ed4a1f749 100644 --- a/dbms/src/Interpreters/InterpreterSystemQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSystemQuery.cpp @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include #include @@ -137,17 +137,17 @@ void InterpreterSystemQuery::startStopAction(StorageActionBlockType action_type, } else { + auto access = context.getAccess(); for (auto & elem : DatabaseCatalog::instance().getDatabases()) { for (auto iterator = elem.second->getTablesIterator(context); iterator->isValid(); iterator->next()) { - if (context.getAccessRights()->isGranted(log, getRequiredAccessType(action_type), elem.first, iterator->name())) - { - if (start) - manager->remove(iterator->table(), action_type); - else - manager->add(iterator->table(), action_type); - } + if (!access->isGranted(log, getRequiredAccessType(action_type), elem.first, iterator->name())) + continue; + if (start) + manager->remove(iterator->table(), action_type); + else + manager->add(iterator->table(), action_type); } } } diff --git a/dbms/src/Interpreters/InterpreterUseQuery.cpp b/dbms/src/Interpreters/InterpreterUseQuery.cpp index 0cddaf26c11..58f5b6c9a32 100644 --- a/dbms/src/Interpreters/InterpreterUseQuery.cpp +++ b/dbms/src/Interpreters/InterpreterUseQuery.cpp @@ -11,7 +11,7 @@ namespace DB BlockIO InterpreterUseQuery::execute() { const String & new_database = query_ptr->as().database; - context.checkAccess(AccessType::EXISTS, new_database); + context.checkAccess(AccessType::SHOW_DATABASES, new_database); context.getSessionContext().setCurrentDatabase(new_database); return {}; } diff --git a/dbms/src/Interpreters/MutationsInterpreter.cpp b/dbms/src/Interpreters/MutationsInterpreter.cpp index 3acd04d99f6..056fd5b597a 100644 --- a/dbms/src/Interpreters/MutationsInterpreter.cpp +++ b/dbms/src/Interpreters/MutationsInterpreter.cpp @@ -164,8 +164,8 @@ bool isStorageTouchedByMutations( return true; } - context_copy.getSettingsRef().max_streams_to_max_threads_ratio = 1; - context_copy.getSettingsRef().max_threads = 1; + context_copy.setSetting("max_streams_to_max_threads_ratio", 1); + context_copy.setSetting("max_threads", 1); ASTPtr select_query = prepareQueryAffectedAST(commands); diff --git a/dbms/src/Interpreters/executeQuery.cpp b/dbms/src/Interpreters/executeQuery.cpp index fefca6b580f..5c8c587fcc8 100644 --- a/dbms/src/Interpreters/executeQuery.cpp +++ b/dbms/src/Interpreters/executeQuery.cpp @@ -24,7 +24,7 @@ #include -#include +#include #include #include #include @@ -148,7 +148,8 @@ static void logException(Context & context, QueryLogElement & elem) static void onExceptionBeforeStart(const String & query_for_logging, Context & context, time_t current_time) { /// Exception before the query execution. - context.getQuota()->used(Quota::ERRORS, 1, /* check_exceeded = */ false); + if (auto quota = context.getQuota()) + quota->used(Quota::ERRORS, 1, /* check_exceeded = */ false); const Settings & settings = context.getSettingsRef(); @@ -307,12 +308,15 @@ static std::tuple executeQueryImpl( auto interpreter = InterpreterFactory::get(ast, context, stage); bool use_processors = settings.experimental_use_processors && allow_processors && interpreter->canExecuteWithProcessors(); - QuotaContextPtr quota; + std::shared_ptr quota; if (!interpreter->ignoreQuota()) { quota = context.getQuota(); - quota->used(Quota::QUERIES, 1); - quota->checkExceeded(Quota::ERRORS); + if (quota) + { + quota->used(Quota::QUERIES, 1); + quota->checkExceeded(Quota::ERRORS); + } } IBlockInputStream::LocalLimits limits; @@ -486,9 +490,10 @@ static std::tuple executeQueryImpl( } }; - auto exception_callback = [elem, &context, log_queries] () mutable + auto exception_callback = [elem, &context, log_queries, quota(quota)] () mutable { - context.getQuota()->used(Quota::ERRORS, 1, /* check_exceeded = */ false); + if (quota) + quota->used(Quota::ERRORS, 1, /* check_exceeded = */ false); elem.type = QueryLogElement::EXCEPTION_WHILE_PROCESSING; diff --git a/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp b/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp index 47dcb406114..9a6d7ca4162 100644 --- a/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp +++ b/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp @@ -1170,9 +1170,7 @@ TestResult check(const TestEntry & entry) database->attachTable("visits_all", storage_distributed_visits); database->attachTable("hits_all", storage_distributed_hits); context.setCurrentDatabase("test"); - - auto & settings = context.getSettingsRef(); - settings.distributed_product_mode = entry.mode; + context.setSetting("distributed_product_mode", entry.mode); /// Parse and process the incoming query. DB::ASTPtr ast_input; diff --git a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp index 7e1017ae0c3..7613fce6167 100644 --- a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -94,7 +94,7 @@ namespace } } - void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTExtendedRoleSet & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); @@ -143,7 +143,7 @@ void ASTCreateQuotaQuery::formatImpl(const FormatSettings & settings, FormatStat formatAllLimits(all_limits, settings); - if (roles) + if (roles && (!roles->empty() || alter)) formatToRoles(*roles, settings); } } diff --git a/dbms/src/Parsers/ASTCreateQuotaQuery.h b/dbms/src/Parsers/ASTCreateQuotaQuery.h index 71b1b95d894..2968c2cc607 100644 --- a/dbms/src/Parsers/ASTCreateQuotaQuery.h +++ b/dbms/src/Parsers/ASTCreateQuotaQuery.h @@ -6,7 +6,7 @@ namespace DB { -class ASTGenericRoleSet; +class ASTExtendedRoleSet; /** CREATE QUOTA [IF NOT EXISTS | OR REPLACE] name @@ -53,7 +53,7 @@ public: }; std::vector all_limits; - std::shared_ptr roles; + std::shared_ptr roles; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTCreateRoleQuery.cpp b/dbms/src/Parsers/ASTCreateRoleQuery.cpp index b511a466d2f..3d69e4dac59 100644 --- a/dbms/src/Parsers/ASTCreateRoleQuery.cpp +++ b/dbms/src/Parsers/ASTCreateRoleQuery.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -11,6 +12,12 @@ namespace settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "") << quoteString(new_name); } + + void formatSettings(const ASTSettingsProfileElements & settings, const IAST::FormatSettings & format) + { + format.ostr << (format.hilite ? IAST::hilite_keyword : "") << " SETTINGS " << (format.hilite ? IAST::hilite_none : ""); + settings.format(format); + } } @@ -26,28 +33,32 @@ ASTPtr ASTCreateRoleQuery::clone() const } -void ASTCreateRoleQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +void ASTCreateRoleQuery::formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const { if (attach) { - settings.ostr << (settings.hilite ? hilite_keyword : "") << "ATTACH ROLE" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << "ATTACH ROLE" << (format.hilite ? hilite_none : ""); } else { - settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER ROLE" : "CREATE ROLE") - << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << (alter ? "ALTER ROLE" : "CREATE ROLE") + << (format.hilite ? hilite_none : ""); } if (if_exists) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " IF EXISTS" << (format.hilite ? hilite_none : ""); else if (if_not_exists) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (format.hilite ? hilite_none : ""); else if (or_replace) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " OR REPLACE" << (format.hilite ? hilite_none : ""); - settings.ostr << " " << backQuoteIfNeed(name); + format.ostr << " " << backQuoteIfNeed(name); if (!new_name.empty()) - formatRenameTo(new_name, settings); + formatRenameTo(new_name, format); + + if (settings && (!settings->empty() || alter)) + formatSettings(*settings, format); } + } diff --git a/dbms/src/Parsers/ASTCreateRoleQuery.h b/dbms/src/Parsers/ASTCreateRoleQuery.h index 5109492fc9e..69bb9896fa3 100644 --- a/dbms/src/Parsers/ASTCreateRoleQuery.h +++ b/dbms/src/Parsers/ASTCreateRoleQuery.h @@ -5,10 +5,15 @@ namespace DB { +class ASTSettingsProfileElements; + + /** CREATE ROLE [IF NOT EXISTS | OR REPLACE] name + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] * * ALTER ROLE [IF EXISTS] name - * [RENAME TO new_name] + * [RENAME TO new_name] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] */ class ASTCreateRoleQuery : public IAST { @@ -23,8 +28,10 @@ public: String name; String new_name; + std::shared_ptr settings; + String getID(char) const override; ASTPtr clone() const override; - void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; + void formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; } diff --git a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp index 0e3002d385f..ac3d859e66f 100644 --- a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -10,7 +10,7 @@ namespace DB { namespace { - using ConditionIndex = RowPolicy::ConditionIndex; + using ConditionType = RowPolicy::ConditionType; void formatRenameTo(const String & new_policy_name, const IAST::FormatSettings & settings) { @@ -37,13 +37,13 @@ namespace } - std::vector> - conditionalExpressionsToStrings(const std::vector> & exprs, const IAST::FormatSettings & settings) + std::vector> + conditionalExpressionsToStrings(const std::vector> & exprs, const IAST::FormatSettings & settings) { - std::vector> result; + std::vector> result; std::stringstream ss; IAST::FormatSettings temp_settings(ss, settings); - boost::range::transform(exprs, std::back_inserter(result), [&](const std::pair & in) + boost::range::transform(exprs, std::back_inserter(result), [&](const std::pair & in) { formatConditionalExpression(in.second, temp_settings); auto out = std::pair{in.first, ss.str()}; @@ -70,9 +70,9 @@ namespace } - void formatMultipleConditions(const std::vector> & conditions, bool alter, const IAST::FormatSettings & settings) + void formatMultipleConditions(const std::vector> & conditions, bool alter, const IAST::FormatSettings & settings) { - std::optional scond[RowPolicy::MAX_CONDITION_INDEX]; + std::optional scond[RowPolicy::MAX_CONDITION_TYPE]; for (const auto & [index, scondition] : conditionalExpressionsToStrings(conditions, settings)) scond[index] = scondition; @@ -112,7 +112,7 @@ namespace } } - void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTExtendedRoleSet & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); @@ -165,7 +165,7 @@ void ASTCreateRowPolicyQuery::formatImpl(const FormatSettings & settings, Format formatMultipleConditions(conditions, alter, settings); - if (roles) + if (roles && (!roles->empty() || alter)) formatToRoles(*roles, settings); } } diff --git a/dbms/src/Parsers/ASTCreateRowPolicyQuery.h b/dbms/src/Parsers/ASTCreateRowPolicyQuery.h index 9c233799639..e58ed0ec46c 100644 --- a/dbms/src/Parsers/ASTCreateRowPolicyQuery.h +++ b/dbms/src/Parsers/ASTCreateRowPolicyQuery.h @@ -8,7 +8,7 @@ namespace DB { -class ASTGenericRoleSet; +class ASTExtendedRoleSet; /** CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table * [AS {PERMISSIVE | RESTRICTIVE}] @@ -39,10 +39,10 @@ public: String new_policy_name; std::optional is_restrictive; - using ConditionIndex = RowPolicy::ConditionIndex; - std::vector> conditions; + using ConditionType = RowPolicy::ConditionType; + std::vector> conditions; - std::shared_ptr roles; + std::shared_ptr roles; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTCreateSettingsProfileQuery.cpp b/dbms/src/Parsers/ASTCreateSettingsProfileQuery.cpp new file mode 100644 index 00000000000..a5a5556baf3 --- /dev/null +++ b/dbms/src/Parsers/ASTCreateSettingsProfileQuery.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + void formatRenameTo(const String & new_name, const IAST::FormatSettings & settings) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "") + << quoteString(new_name); + } + + void formatSettings(const ASTSettingsProfileElements & settings, const IAST::FormatSettings & format) + { + format.ostr << (format.hilite ? IAST::hilite_keyword : "") << " SETTINGS " << (format.hilite ? IAST::hilite_none : ""); + settings.format(format); + } + + void formatToRoles(const ASTExtendedRoleSet & roles, const IAST::FormatSettings & settings) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); + roles.format(settings); + } +} + + +String ASTCreateSettingsProfileQuery::getID(char) const +{ + return "CreateSettingsProfileQuery"; +} + + +ASTPtr ASTCreateSettingsProfileQuery::clone() const +{ + return std::make_shared(*this); +} + + +void ASTCreateSettingsProfileQuery::formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const +{ + if (attach) + { + format.ostr << (format.hilite ? hilite_keyword : "") << "ATTACH SETTINGS PROFILE" << (format.hilite ? hilite_none : ""); + } + else + { + format.ostr << (format.hilite ? hilite_keyword : "") << (alter ? "ALTER SETTINGS PROFILE" : "CREATE SETTINGS PROFILE") + << (format.hilite ? hilite_none : ""); + } + + if (if_exists) + format.ostr << (format.hilite ? hilite_keyword : "") << " IF EXISTS" << (format.hilite ? hilite_none : ""); + else if (if_not_exists) + format.ostr << (format.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (format.hilite ? hilite_none : ""); + else if (or_replace) + format.ostr << (format.hilite ? hilite_keyword : "") << " OR REPLACE" << (format.hilite ? hilite_none : ""); + + format.ostr << " " << backQuoteIfNeed(name); + + if (!new_name.empty()) + formatRenameTo(new_name, format); + + if (settings && (!settings->empty() || alter)) + formatSettings(*settings, format); + + if (to_roles && (!to_roles->empty() || alter)) + formatToRoles(*to_roles, format); +} + +} diff --git a/dbms/src/Parsers/ASTCreateSettingsProfileQuery.h b/dbms/src/Parsers/ASTCreateSettingsProfileQuery.h new file mode 100644 index 00000000000..b3a60853e57 --- /dev/null +++ b/dbms/src/Parsers/ASTCreateSettingsProfileQuery.h @@ -0,0 +1,40 @@ +#pragma once + +#include + + +namespace DB +{ +class ASTSettingsProfileElements; +class ASTExtendedRoleSet; + + +/** CREATE SETTINGS PROFILE [IF NOT EXISTS | OR REPLACE] name + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] + * + * ALTER SETTINGS PROFILE [IF EXISTS] name + * [RENAME TO new_name] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] + */ +class ASTCreateSettingsProfileQuery : public IAST +{ +public: + bool alter = false; + bool attach = false; + + bool if_exists = false; + bool if_not_exists = false; + bool or_replace = false; + + String name; + String new_name; + + std::shared_ptr settings; + + std::shared_ptr to_roles; + + String getID(char) const override; + ASTPtr clone() const override; + void formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; +}; +} diff --git a/dbms/src/Parsers/ASTCreateUserQuery.cpp b/dbms/src/Parsers/ASTCreateUserQuery.cpp index e848a5e0abb..0631d08ae74 100644 --- a/dbms/src/Parsers/ASTCreateUserQuery.cpp +++ b/dbms/src/Parsers/ASTCreateUserQuery.cpp @@ -1,5 +1,6 @@ #include -#include +#include +#include #include @@ -135,17 +136,17 @@ namespace } - void formatDefaultRoles(const ASTGenericRoleSet & default_roles, const IAST::FormatSettings & settings) + void formatDefaultRoles(const ASTExtendedRoleSet & default_roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " DEFAULT ROLE " << (settings.hilite ? IAST::hilite_none : ""); default_roles.format(settings); } - void formatProfile(const String & profile_name, const IAST::FormatSettings & settings) + void formatSettings(const ASTSettingsProfileElements & settings, const IAST::FormatSettings & format) { - settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " PROFILE " << (settings.hilite ? IAST::hilite_none : "") - << quoteString(profile_name); + format.ostr << (format.hilite ? IAST::hilite_keyword : "") << " SETTINGS " << (format.hilite ? IAST::hilite_none : ""); + settings.format(format); } } @@ -162,44 +163,44 @@ ASTPtr ASTCreateUserQuery::clone() const } -void ASTCreateUserQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +void ASTCreateUserQuery::formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const { if (attach) { - settings.ostr << (settings.hilite ? hilite_keyword : "") << "ATTACH USER" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << "ATTACH USER" << (format.hilite ? hilite_none : ""); } else { - settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER USER" : "CREATE USER") - << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << (alter ? "ALTER USER" : "CREATE USER") + << (format.hilite ? hilite_none : ""); } if (if_exists) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " IF EXISTS" << (format.hilite ? hilite_none : ""); else if (if_not_exists) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (format.hilite ? hilite_none : ""); else if (or_replace) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : ""); + format.ostr << (format.hilite ? hilite_keyword : "") << " OR REPLACE" << (format.hilite ? hilite_none : ""); - settings.ostr << " " << backQuoteIfNeed(name); + format.ostr << " " << backQuoteIfNeed(name); if (!new_name.empty()) - formatRenameTo(new_name, settings); + formatRenameTo(new_name, format); if (authentication) - formatAuthentication(*authentication, settings); + formatAuthentication(*authentication, format); if (hosts) - formatHosts(nullptr, *hosts, settings); + formatHosts(nullptr, *hosts, format); if (add_hosts) - formatHosts("ADD", *add_hosts, settings); + formatHosts("ADD", *add_hosts, format); if (remove_hosts) - formatHosts("REMOVE", *remove_hosts, settings); + formatHosts("REMOVE", *remove_hosts, format); if (default_roles) - formatDefaultRoles(*default_roles, settings); + formatDefaultRoles(*default_roles, format); - if (profile) - formatProfile(*profile, settings); + if (settings && (!settings->empty() || alter)) + formatSettings(*settings, format); } } diff --git a/dbms/src/Parsers/ASTCreateUserQuery.h b/dbms/src/Parsers/ASTCreateUserQuery.h index d6db56a408f..fc2aa0121ed 100644 --- a/dbms/src/Parsers/ASTCreateUserQuery.h +++ b/dbms/src/Parsers/ASTCreateUserQuery.h @@ -7,20 +7,21 @@ namespace DB { -class ASTGenericRoleSet; +class ASTExtendedRoleSet; +class ASTSettingsProfileElements; /** CREATE USER [IF NOT EXISTS | OR REPLACE] name - * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] - * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] - * [DEFAULT ROLE role [,...]] - * [PROFILE 'profile_name'] + * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] + * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] + * [DEFAULT ROLE role [,...]] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] * * ALTER USER [IF EXISTS] name * [RENAME TO new_name] * [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}] * [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] * [DEFAULT ROLE role [,...] | ALL | ALL EXCEPT role [,...] ] - * [PROFILE 'profile_name'] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] */ class ASTCreateUserQuery : public IAST { @@ -41,12 +42,12 @@ public: std::optional add_hosts; std::optional remove_hosts; - std::shared_ptr default_roles; + std::shared_ptr default_roles; - std::optional profile; + std::shared_ptr settings; String getID(char) const override; ASTPtr clone() const override; - void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; + void formatImpl(const FormatSettings & format, FormatState &, FormatStateStacked) const override; }; } diff --git a/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp b/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp index 0b6bae7575e..3896128ceb5 100644 --- a/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp @@ -8,14 +8,15 @@ namespace { using Kind = ASTDropAccessEntityQuery::Kind; - const char * kindToKeyword(Kind kind) + const char * getKeyword(Kind kind) { switch (kind) { case Kind::USER: return "USER"; case Kind::ROLE: return "ROLE"; case Kind::QUOTA: return "QUOTA"; - case Kind::ROW_POLICY: return "POLICY"; + case Kind::ROW_POLICY: return "ROW POLICY"; + case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE"; } __builtin_unreachable(); } @@ -23,14 +24,14 @@ namespace ASTDropAccessEntityQuery::ASTDropAccessEntityQuery(Kind kind_) - : kind(kind_), keyword(kindToKeyword(kind_)) + : kind(kind_) { } String ASTDropAccessEntityQuery::getID(char) const { - return String("DROP ") + keyword + " query"; + return String("DROP ") + getKeyword(kind) + " query"; } @@ -43,7 +44,7 @@ ASTPtr ASTDropAccessEntityQuery::clone() const void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { settings.ostr << (settings.hilite ? hilite_keyword : "") - << "DROP " << keyword + << "DROP " << getKeyword(kind) << (if_exists ? " IF EXISTS" : "") << (settings.hilite ? hilite_none : ""); diff --git a/dbms/src/Parsers/ASTDropAccessEntityQuery.h b/dbms/src/Parsers/ASTDropAccessEntityQuery.h index eea40fd5343..5f0b46bd896 100644 --- a/dbms/src/Parsers/ASTDropAccessEntityQuery.h +++ b/dbms/src/Parsers/ASTDropAccessEntityQuery.h @@ -11,6 +11,7 @@ namespace DB * DROP ROLE [IF EXISTS] name [,...] * DROP QUOTA [IF EXISTS] name [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] + * DROP [SETTINGS] PROFILE [IF EXISTS] name [,...] */ class ASTDropAccessEntityQuery : public IAST { @@ -21,11 +22,10 @@ public: ROLE, QUOTA, ROW_POLICY, + SETTINGS_PROFILE, }; const Kind kind; - const char * const keyword; - bool if_exists = false; Strings names; std::vector row_policies_names; diff --git a/dbms/src/Parsers/ASTGenericRoleSet.cpp b/dbms/src/Parsers/ASTExtendedRoleSet.cpp similarity index 93% rename from dbms/src/Parsers/ASTGenericRoleSet.cpp rename to dbms/src/Parsers/ASTExtendedRoleSet.cpp index 50f2b0adc7e..3ac1052897d 100644 --- a/dbms/src/Parsers/ASTGenericRoleSet.cpp +++ b/dbms/src/Parsers/ASTExtendedRoleSet.cpp @@ -1,4 +1,4 @@ -#include +#include #include @@ -20,7 +20,7 @@ namespace } } -void ASTGenericRoleSet::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +void ASTExtendedRoleSet::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { if (empty()) { diff --git a/dbms/src/Parsers/ASTGenericRoleSet.h b/dbms/src/Parsers/ASTExtendedRoleSet.h similarity index 74% rename from dbms/src/Parsers/ASTGenericRoleSet.h rename to dbms/src/Parsers/ASTExtendedRoleSet.h index b9a1ab99248..84190211087 100644 --- a/dbms/src/Parsers/ASTGenericRoleSet.h +++ b/dbms/src/Parsers/ASTExtendedRoleSet.h @@ -1,14 +1,13 @@ #pragma once #include -#include namespace DB { /// Represents a set of users/roles like /// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] -class ASTGenericRoleSet : public IAST +class ASTExtendedRoleSet : public IAST { public: Strings names; @@ -20,8 +19,8 @@ public: bool empty() const { return names.empty() && !current_user && !all; } - String getID(char) const override { return "GenericRoleSet"; } - ASTPtr clone() const override { return std::make_shared(*this); } + String getID(char) const override { return "ExtendedRoleSet"; } + ASTPtr clone() const override { return std::make_shared(*this); } void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; }; } diff --git a/dbms/src/Parsers/ASTGrantQuery.cpp b/dbms/src/Parsers/ASTGrantQuery.cpp index 9365e1b96b7..94521d790f2 100644 --- a/dbms/src/Parsers/ASTGrantQuery.cpp +++ b/dbms/src/Parsers/ASTGrantQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -97,7 +97,7 @@ namespace } - void formatToRoles(const ASTGenericRoleSet & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings) + void formatToRoles(const ASTExtendedRoleSet & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings) { using Kind = ASTGrantQuery::Kind; settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") diff --git a/dbms/src/Parsers/ASTGrantQuery.h b/dbms/src/Parsers/ASTGrantQuery.h index 8ce3d9c20dc..95b5f0b8448 100644 --- a/dbms/src/Parsers/ASTGrantQuery.h +++ b/dbms/src/Parsers/ASTGrantQuery.h @@ -6,7 +6,7 @@ namespace DB { -class ASTGenericRoleSet; +class ASTExtendedRoleSet; /** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION] @@ -26,8 +26,8 @@ public: Kind kind = Kind::GRANT; bool attach = false; AccessRightsElements access_rights_elements; - std::shared_ptr roles; - std::shared_ptr to_roles; + std::shared_ptr roles; + std::shared_ptr to_roles; bool grant_option = false; bool admin_option = false; diff --git a/dbms/src/Parsers/ASTSetRoleQuery.cpp b/dbms/src/Parsers/ASTSetRoleQuery.cpp index de61f5a3113..0c8842fdac6 100644 --- a/dbms/src/Parsers/ASTSetRoleQuery.cpp +++ b/dbms/src/Parsers/ASTSetRoleQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include diff --git a/dbms/src/Parsers/ASTSetRoleQuery.h b/dbms/src/Parsers/ASTSetRoleQuery.h index ad22d30e287..8f1fb357d86 100644 --- a/dbms/src/Parsers/ASTSetRoleQuery.h +++ b/dbms/src/Parsers/ASTSetRoleQuery.h @@ -5,7 +5,7 @@ namespace DB { -class ASTGenericRoleSet; +class ASTExtendedRoleSet; /** SET ROLE {DEFAULT | NONE | role [,...] | ALL | ALL EXCEPT role [,...]} * SET DEFAULT ROLE {NONE | role [,...] | ALL | ALL EXCEPT role [,...]} TO {user|CURRENT_USER} [,...] @@ -21,8 +21,8 @@ public: }; Kind kind = Kind::SET_ROLE; - std::shared_ptr roles; - std::shared_ptr to_users; + std::shared_ptr roles; + std::shared_ptr to_users; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTSettingsProfileElement.cpp b/dbms/src/Parsers/ASTSettingsProfileElement.cpp new file mode 100644 index 00000000000..b3f4032d14c --- /dev/null +++ b/dbms/src/Parsers/ASTSettingsProfileElement.cpp @@ -0,0 +1,88 @@ +#include +#include +#include + + +namespace DB +{ +namespace +{ + void formatProfileNameOrID(const String & str, bool is_id, const IAST::FormatSettings & settings) + { + if (is_id) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ID" << (settings.hilite ? IAST::hilite_none : "") << "(" + << quoteString(str) << ")"; + } + else + { + settings.ostr << backQuoteIfNeed(str); + } + } +} + +void ASTSettingsProfileElement::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +{ + if (!parent_profile.empty()) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "PROFILE " << (settings.hilite ? IAST::hilite_none : ""); + formatProfileNameOrID(parent_profile, id_mode, settings); + return; + } + + settings.ostr << name; + + if (!value.isNull()) + { + settings.ostr << " = " << applyVisitor(FieldVisitorToString{}, value); + } + + if (!min_value.isNull()) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " MIN " << (settings.hilite ? IAST::hilite_none : "") + << applyVisitor(FieldVisitorToString{}, min_value); + } + + if (!max_value.isNull()) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " MAX " << (settings.hilite ? IAST::hilite_none : "") + << applyVisitor(FieldVisitorToString{}, max_value); + } + + if (readonly) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << (*readonly ? " READONLY" : " WRITABLE") + << (settings.hilite ? IAST::hilite_none : ""); + } +} + + +bool ASTSettingsProfileElements::empty() const +{ + for (const auto & element : elements) + if (!element->empty()) + return false; + return true; +} + + +void ASTSettingsProfileElements::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +{ + if (empty()) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NONE" << (settings.hilite ? IAST::hilite_none : ""); + return; + } + + bool need_comma = false; + for (const auto & element : elements) + { + if (need_comma) + settings.ostr << ", "; + need_comma = true; + + element->format(settings); + } +} + +} diff --git a/dbms/src/Parsers/ASTSettingsProfileElement.h b/dbms/src/Parsers/ASTSettingsProfileElement.h new file mode 100644 index 00000000000..0470b51cf85 --- /dev/null +++ b/dbms/src/Parsers/ASTSettingsProfileElement.h @@ -0,0 +1,45 @@ +#pragma once + +#include +#include + + +namespace DB +{ +/** Represents a settings profile's element like the following + * {variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE]} | PROFILE 'profile_name' + */ +class ASTSettingsProfileElement : public IAST +{ +public: + String parent_profile; + String name; + Field value; + Field min_value; + Field max_value; + std::optional readonly; + bool id_mode = false; /// If true then `parent_profile` keeps UUID, not a name. + + bool empty() const { return parent_profile.empty() && name.empty(); } + + String getID(char) const override { return "SettingsProfileElement"; } + ASTPtr clone() const override { return std::make_shared(*this); } + void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; +}; + + +/** Represents settings profile's elements like the following + * {{variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE]} | PROFILE 'profile_name'} [,...] + */ +class ASTSettingsProfileElements : public IAST +{ +public: + std::vector> elements; + + bool empty() const; + + String getID(char) const override { return "SettingsProfileElements"; } + ASTPtr clone() const override { return std::make_shared(*this); } + void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; +}; +} diff --git a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp index 4201a733f43..9e562043f09 100644 --- a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp @@ -8,13 +8,15 @@ namespace { using Kind = ASTShowCreateAccessEntityQuery::Kind; - const char * kindToKeyword(Kind kind) + const char * getKeyword(Kind kind) { switch (kind) { case Kind::USER: return "USER"; + case Kind::ROLE: return "ROLE"; case Kind::QUOTA: return "QUOTA"; - case Kind::ROW_POLICY: return "POLICY"; + case Kind::ROW_POLICY: return "ROW POLICY"; + case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE"; } __builtin_unreachable(); } @@ -22,14 +24,14 @@ namespace ASTShowCreateAccessEntityQuery::ASTShowCreateAccessEntityQuery(Kind kind_) - : kind(kind_), keyword(kindToKeyword(kind_)) + : kind(kind_) { } String ASTShowCreateAccessEntityQuery::getID(char) const { - return String("SHOW CREATE ") + keyword + " query"; + return String("SHOW CREATE ") + getKeyword(kind) + " query"; } @@ -42,13 +44,13 @@ ASTPtr ASTShowCreateAccessEntityQuery::clone() const void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { settings.ostr << (settings.hilite ? hilite_keyword : "") - << "SHOW CREATE " << keyword + << "SHOW CREATE " << getKeyword(kind) << (settings.hilite ? hilite_none : ""); - if ((kind == Kind::USER) && current_user) + if (current_user) { } - else if ((kind == Kind::QUOTA) && current_quota) + else if (current_quota) settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : ""); else if (kind == Kind::ROW_POLICY) { diff --git a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.h b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.h index 43fa215f64c..e76a9177979 100644 --- a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.h +++ b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.h @@ -9,6 +9,8 @@ namespace DB /** SHOW CREATE QUOTA [name | CURRENT] * SHOW CREATE [ROW] POLICY name ON [database.]table * SHOW CREATE USER [name | CURRENT_USER] + * SHOW CREATE ROLE name + * SHOW CREATE [SETTINGS] PROFILE name */ class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput { @@ -16,12 +18,13 @@ public: enum class Kind { USER, + ROLE, QUOTA, ROW_POLICY, + SETTINGS_PROFILE, }; - const Kind kind; - const char * const keyword; + const Kind kind; String name; bool current_quota = false; bool current_user = false; diff --git a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp index c03fb14874c..9a6afec6941 100644 --- a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp @@ -3,10 +3,10 @@ #include #include #include -#include +#include #include #include -#include +#include #include #include @@ -187,15 +187,15 @@ namespace }); } - bool parseToRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr node; - if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserGenericRoleSet{}.enableIDMode(id_mode).parse(pos, node, expected)) + if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserExtendedRoleSet{}.useIDMode(id_mode).parse(pos, node, expected)) return false; - roles = std::static_pointer_cast(node); + roles = std::static_pointer_cast(node); return true; }); } @@ -205,12 +205,10 @@ namespace bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { bool alter = false; - bool attach = false; if (attach_mode) { if (!ParserKeyword{"ATTACH QUOTA"}.ignore(pos, expected)) return false; - attach = true; } else { @@ -243,7 +241,6 @@ bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expe String new_name; std::optional key_type; std::vector all_limits; - std::shared_ptr roles; while (true) { @@ -256,12 +253,12 @@ bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expe if (parseAllLimits(pos, expected, alter, all_limits)) continue; - if (!roles && parseToRoles(pos, expected, attach, roles)) - continue; - break; } + std::shared_ptr roles; + parseToRoles(pos, expected, attach_mode, roles); + auto query = std::make_shared(); node = query; diff --git a/dbms/src/Parsers/ParserCreateRoleQuery.cpp b/dbms/src/Parsers/ParserCreateRoleQuery.cpp index 5a4ef016f77..e2b42c976b4 100644 --- a/dbms/src/Parsers/ParserCreateRoleQuery.cpp +++ b/dbms/src/Parsers/ParserCreateRoleQuery.cpp @@ -3,6 +3,8 @@ #include #include #include +#include +#include #include @@ -20,18 +22,35 @@ namespace return parseRoleName(pos, expected, new_name); }); } + + bool parseSettings(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & settings) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"SETTINGS"}.ignore(pos, expected)) + return false; + + ASTPtr new_settings_ast; + if (!ParserSettingsProfileElements{}.useIDMode(id_mode).parse(pos, new_settings_ast, expected)) + return false; + + if (!settings) + settings = std::make_shared(); + const auto & new_settings = new_settings_ast->as(); + settings->elements.insert(settings->elements.end(), new_settings.elements.begin(), new_settings.elements.end()); + return true; + }); + } } bool ParserCreateRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { - bool attach = false; bool alter = false; if (attach_mode) { if (!ParserKeyword{"ATTACH ROLE"}.ignore(pos, expected)) return false; - attach = true; } else { @@ -62,19 +81,29 @@ bool ParserCreateRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec return false; String new_name; - if (alter) - parseRenameTo(pos, expected, new_name); + std::shared_ptr settings; + while (true) + { + if (alter && parseRenameTo(pos, expected, new_name)) + continue; + + if (parseSettings(pos, expected, attach_mode, settings)) + continue; + + break; + } auto query = std::make_shared(); node = query; query->alter = alter; - query->attach = attach; + query->attach = attach_mode; query->if_exists = if_exists; query->if_not_exists = if_not_exists; query->or_replace = or_replace; query->name = std::move(name); query->new_name = std::move(new_name); + query->settings = std::move(settings); return true; } diff --git a/dbms/src/Parsers/ParserCreateRoleQuery.h b/dbms/src/Parsers/ParserCreateRoleQuery.h index a1690687282..2afeb7f7ec4 100644 --- a/dbms/src/Parsers/ParserCreateRoleQuery.h +++ b/dbms/src/Parsers/ParserCreateRoleQuery.h @@ -7,9 +7,11 @@ namespace DB { /** Parses queries like * CREATE ROLE [IF NOT EXISTS | OR REPLACE] name + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] * * ALTER ROLE [IF EXISTS] name - * [RENAME TO new_name] + * [RENAME TO new_name] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] */ class ParserCreateRoleQuery : public IParserBase { diff --git a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp index ff865f3644f..ab0fbc87e12 100644 --- a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp @@ -1,8 +1,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include @@ -19,7 +19,7 @@ namespace ErrorCodes namespace { - using ConditionIndex = RowPolicy::ConditionIndex; + using ConditionType = RowPolicy::ConditionType; bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name) { @@ -73,7 +73,7 @@ namespace }); } - bool parseConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) + bool parseConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) { return IParserBase::wrapParseImpl(pos, [&] { @@ -136,14 +136,14 @@ namespace if (filter && !check && !alter) check = filter; - auto set_condition = [&](ConditionIndex index, const ASTPtr & condition) + auto set_condition = [&](ConditionType index, const ASTPtr & condition) { - auto it = std::find_if(conditions.begin(), conditions.end(), [index](const std::pair & element) + auto it = std::find_if(conditions.begin(), conditions.end(), [index](const std::pair & element) { return element.first == index; }); if (it == conditions.end()) - it = conditions.insert(conditions.end(), std::pair{index, nullptr}); + it = conditions.insert(conditions.end(), std::pair{index, nullptr}); it->second = condition; }; @@ -170,11 +170,11 @@ namespace }); } - bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) + bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) { return IParserBase::wrapParseImpl(pos, [&] { - std::vector> res_conditions; + std::vector> res_conditions; do { if (!parseConditions(pos, expected, alter, res_conditions)) @@ -187,16 +187,16 @@ namespace }); } - bool parseToRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr ast; if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) - || !ParserGenericRoleSet{}.enableIDMode(id_mode).parse(pos, ast, expected)) + || !ParserExtendedRoleSet{}.useIDMode(id_mode).parse(pos, ast, expected)) return false; - roles = std::static_pointer_cast(ast); + roles = std::static_pointer_cast(ast); return true; }); } @@ -206,12 +206,10 @@ namespace bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { bool alter = false; - bool attach = false; if (attach_mode) { if (!ParserKeyword{"ATTACH POLICY"}.ignore(pos, expected) && !ParserKeyword{"ATTACH ROW POLICY"}.ignore(pos, expected)) return false; - attach = true; } else { @@ -247,8 +245,7 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & String new_policy_name; std::optional is_restrictive; - std::vector> conditions; - std::shared_ptr roles; + std::vector> conditions; while (true) { @@ -261,17 +258,17 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & if (parseMultipleConditions(pos, expected, alter, conditions)) continue; - if (!roles && parseToRoles(pos, expected, attach, roles)) - continue; - break; } + std::shared_ptr roles; + parseToRoles(pos, expected, attach_mode, roles); + auto query = std::make_shared(); node = query; query->alter = alter; - query->attach = attach; + query->attach = attach_mode; query->if_exists = if_exists; query->if_not_exists = if_not_exists; query->or_replace = or_replace; diff --git a/dbms/src/Parsers/ParserCreateSettingsProfileQuery.cpp b/dbms/src/Parsers/ParserCreateSettingsProfileQuery.cpp new file mode 100644 index 00000000000..c7c9e064f6c --- /dev/null +++ b/dbms/src/Parsers/ParserCreateSettingsProfileQuery.cpp @@ -0,0 +1,130 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) + return false; + + return parseIdentifierOrStringLiteral(pos, expected, new_name); + }); + } + + bool parseSettings(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & settings) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"SETTINGS"}.ignore(pos, expected)) + return false; + + ASTPtr new_settings_ast; + if (!ParserSettingsProfileElements{}.useIDMode(id_mode).parse(pos, new_settings_ast, expected)) + return false; + + if (!settings) + settings = std::make_shared(); + const auto & new_settings = new_settings_ast->as(); + settings->elements.insert(settings->elements.end(), new_settings.elements.begin(), new_settings.elements.end()); + return true; + }); + } + + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) + { + return IParserBase::wrapParseImpl(pos, [&] + { + ASTPtr ast; + if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) + || !ParserExtendedRoleSet{}.useIDMode(id_mode).parse(pos, ast, expected)) + return false; + + roles = std::static_pointer_cast(ast); + return true; + }); + } +} + + +bool ParserCreateSettingsProfileQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + bool alter = false; + if (attach_mode) + { + if (!ParserKeyword{"ATTACH SETTINGS PROFILE"}.ignore(pos, expected) && !ParserKeyword{"ATTACH PROFILE"}.ignore(pos, expected)) + return false; + } + else + { + if (ParserKeyword{"ALTER SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"ALTER PROFILE"}.ignore(pos, expected)) + alter = true; + else if (!ParserKeyword{"CREATE SETTINGS PROFILE"}.ignore(pos, expected) && !ParserKeyword{"CREATE PROFILE"}.ignore(pos, expected)) + return false; + } + + bool if_exists = false; + bool if_not_exists = false; + bool or_replace = false; + if (alter) + { + if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected)) + if_exists = true; + } + else + { + if (ParserKeyword{"IF NOT EXISTS"}.ignore(pos, expected)) + if_not_exists = true; + else if (ParserKeyword{"OR REPLACE"}.ignore(pos, expected)) + or_replace = true; + } + + String name; + if (!parseIdentifierOrStringLiteral(pos, expected, name)) + return false; + + String new_name; + std::shared_ptr settings; + while (true) + { + if (alter && parseRenameTo(pos, expected, new_name)) + continue; + + if (parseSettings(pos, expected, attach_mode, settings)) + continue; + + break; + } + + std::shared_ptr to_roles; + parseToRoles(pos, expected, attach_mode, to_roles); + + auto query = std::make_shared(); + node = query; + + query->alter = alter; + query->attach = attach_mode; + query->if_exists = if_exists; + query->if_not_exists = if_not_exists; + query->or_replace = or_replace; + query->name = std::move(name); + query->new_name = std::move(new_name); + query->settings = std::move(settings); + query->to_roles = std::move(to_roles); + + return true; +} +} diff --git a/dbms/src/Parsers/ParserCreateSettingsProfileQuery.h b/dbms/src/Parsers/ParserCreateSettingsProfileQuery.h new file mode 100644 index 00000000000..6797fc884fa --- /dev/null +++ b/dbms/src/Parsers/ParserCreateSettingsProfileQuery.h @@ -0,0 +1,28 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses queries like + * CREATE SETTINGS PROFILE [IF NOT EXISTS | OR REPLACE] name + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] + * + * ALTER SETTINGS PROFILE [IF EXISTS] name + * [RENAME TO new_name] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] + */ +class ParserCreateSettingsProfileQuery : public IParserBase +{ +public: + ParserCreateSettingsProfileQuery & enableAttachMode(bool enable) { attach_mode = enable; return *this; } + +protected: + const char * getName() const override { return "CREATE SETTINGS PROFILE or ALTER SETTINGS PROFILE query"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool attach_mode = false; +}; +} diff --git a/dbms/src/Parsers/ParserCreateUserQuery.cpp b/dbms/src/Parsers/ParserCreateUserQuery.cpp index c0c4196acee..a7cc6550644 100644 --- a/dbms/src/Parsers/ParserCreateUserQuery.cpp +++ b/dbms/src/Parsers/ParserCreateUserQuery.cpp @@ -5,8 +5,10 @@ #include #include #include -#include -#include +#include +#include +#include +#include #include #include @@ -208,7 +210,7 @@ namespace } - bool parseDefaultRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & default_roles) + bool parseDefaultRoles(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & default_roles) { return IParserBase::wrapParseImpl(pos, [&] { @@ -216,27 +218,30 @@ namespace return false; ASTPtr ast; - if (!ParserGenericRoleSet{}.enableCurrentUserKeyword(false).enableIDMode(id_mode).parse(pos, ast, expected)) + if (!ParserExtendedRoleSet{}.enableCurrentUserKeyword(false).useIDMode(id_mode).parse(pos, ast, expected)) return false; - default_roles = typeid_cast>(ast); + default_roles = typeid_cast>(ast); return true; }); } - bool parseProfileName(IParserBase::Pos & pos, Expected & expected, std::optional & profile) + bool parseSettings(IParserBase::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & settings) { return IParserBase::wrapParseImpl(pos, [&] { - if (!ParserKeyword{"PROFILE"}.ignore(pos, expected)) + if (!ParserKeyword{"SETTINGS"}.ignore(pos, expected)) return false; - ASTPtr ast; - if (!ParserStringLiteral{}.parse(pos, ast, expected)) + ASTPtr new_settings_ast; + if (!ParserSettingsProfileElements{}.useIDMode(id_mode).parse(pos, new_settings_ast, expected)) return false; - profile = ast->as().value.safeGet(); + if (!settings) + settings = std::make_shared(); + const auto & new_settings = new_settings_ast->as(); + settings->elements.insert(settings->elements.end(), new_settings.elements.begin(), new_settings.elements.end()); return true; }); } @@ -246,12 +251,9 @@ namespace bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { bool alter = false; - bool attach = false; if (attach_mode) { - if (ParserKeyword{"ATTACH USER"}.ignore(pos, expected)) - attach = true; - else + if (!ParserKeyword{"ATTACH USER"}.ignore(pos, expected)) return false; } else @@ -289,8 +291,8 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec std::optional hosts; std::optional add_hosts; std::optional remove_hosts; - std::shared_ptr default_roles; - std::optional profile; + std::shared_ptr default_roles; + std::shared_ptr settings; while (true) { @@ -300,10 +302,10 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (parseHosts(pos, expected, nullptr, hosts)) continue; - if (!profile && parseProfileName(pos, expected, profile)) + if (parseSettings(pos, expected, attach_mode, settings)) continue; - if (!default_roles && parseDefaultRoles(pos, expected, attach, default_roles)) + if (!default_roles && parseDefaultRoles(pos, expected, attach_mode, default_roles)) continue; if (alter) @@ -330,7 +332,7 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec node = query; query->alter = alter; - query->attach = attach; + query->attach = attach_mode; query->if_exists = if_exists; query->if_not_exists = if_not_exists; query->or_replace = or_replace; @@ -341,7 +343,7 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec query->add_hosts = std::move(add_hosts); query->remove_hosts = std::move(remove_hosts); query->default_roles = std::move(default_roles); - query->profile = std::move(profile); + query->settings = std::move(settings); return true; } diff --git a/dbms/src/Parsers/ParserCreateUserQuery.h b/dbms/src/Parsers/ParserCreateUserQuery.h index 85e0ada7cf6..bd6ab74d53f 100644 --- a/dbms/src/Parsers/ParserCreateUserQuery.h +++ b/dbms/src/Parsers/ParserCreateUserQuery.h @@ -7,15 +7,15 @@ namespace DB { /** Parses queries like * CREATE USER [IF NOT EXISTS | OR REPLACE] name - * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] - * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] - * [PROFILE 'profile_name'] + * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] + * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] * * ALTER USER [IF EXISTS] name - * [RENAME TO new_name] - * [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}] - * [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] - * [PROFILE 'profile_name'] + * [RENAME TO new_name] + * [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}] + * [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] + * [SETTINGS variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE] | PROFILE 'profile_name'] [,...] */ class ParserCreateUserQuery : public IParserBase { diff --git a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp index f257dc0fd64..23e18d7d32c 100644 --- a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp @@ -4,7 +4,6 @@ #include #include #include -#include namespace DB @@ -90,6 +89,8 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & kind = Kind::QUOTA; else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) kind = Kind::ROW_POLICY; + else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) + kind = Kind::SETTINGS_PROFILE; else return false; @@ -112,7 +113,6 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & } else { - assert(kind == Kind::QUOTA); if (!parseNames(pos, expected, names)) return false; } diff --git a/dbms/src/Parsers/ParserDropAccessEntityQuery.h b/dbms/src/Parsers/ParserDropAccessEntityQuery.h index e4fb323d5f6..fd9149ba03a 100644 --- a/dbms/src/Parsers/ParserDropAccessEntityQuery.h +++ b/dbms/src/Parsers/ParserDropAccessEntityQuery.h @@ -9,12 +9,13 @@ namespace DB * DROP USER [IF EXISTS] name [,...] * DROP ROLE [IF EXISTS] name [,...] * DROP QUOTA [IF EXISTS] name [,...] + * DROP [SETTINGS] PROFILE [IF EXISTS] name [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] */ class ParserDropAccessEntityQuery : public IParserBase { protected: - const char * getName() const override { return "DROP QUOTA query"; } + const char * getName() const override { return "DROP access entity query"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; }; } diff --git a/dbms/src/Parsers/ParserGenericRoleSet.cpp b/dbms/src/Parsers/ParserExtendedRoleSet.cpp similarity index 94% rename from dbms/src/Parsers/ParserGenericRoleSet.cpp rename to dbms/src/Parsers/ParserExtendedRoleSet.cpp index a58c638e36d..80f05c45f5b 100644 --- a/dbms/src/Parsers/ParserGenericRoleSet.cpp +++ b/dbms/src/Parsers/ParserExtendedRoleSet.cpp @@ -1,8 +1,8 @@ -#include +#include #include #include #include -#include +#include #include #include @@ -109,7 +109,7 @@ namespace } -bool ParserGenericRoleSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +bool ParserExtendedRoleSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { Strings names; bool current_user = false; @@ -125,7 +125,7 @@ bool ParserGenericRoleSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expect if (all) names.clear(); - auto result = std::make_shared(); + auto result = std::make_shared(); result->names = std::move(names); result->current_user = current_user; result->all = all; diff --git a/dbms/src/Parsers/ParserExtendedRoleSet.h b/dbms/src/Parsers/ParserExtendedRoleSet.h new file mode 100644 index 00000000000..df723786bd9 --- /dev/null +++ b/dbms/src/Parsers/ParserExtendedRoleSet.h @@ -0,0 +1,28 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses a string like this: + * {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...] + */ +class ParserExtendedRoleSet : public IParserBase +{ +public: + ParserExtendedRoleSet & enableAllKeyword(bool enable_) { all_keyword = enable_; return *this; } + ParserExtendedRoleSet & enableCurrentUserKeyword(bool enable_) { current_user_keyword = enable_; return *this; } + ParserExtendedRoleSet & useIDMode(bool enable_) { id_mode = enable_; return *this; } + +protected: + const char * getName() const override { return "ExtendedRoleSet"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool all_keyword = true; + bool current_user_keyword = true; + bool id_mode = false; +}; + +} diff --git a/dbms/src/Parsers/ParserGenericRoleSet.h b/dbms/src/Parsers/ParserGenericRoleSet.h deleted file mode 100644 index b209cb22350..00000000000 --- a/dbms/src/Parsers/ParserGenericRoleSet.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include - - -namespace DB -{ -/** Parses a string like this: - * {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...] - */ -class ParserGenericRoleSet : public IParserBase -{ -public: - ParserGenericRoleSet & enableAllKeyword(bool enable_) { all_keyword = enable_; return *this; } - ParserGenericRoleSet & enableCurrentUserKeyword(bool enable_) { current_user_keyword = enable_; return *this; } - ParserGenericRoleSet & enableIDMode(bool enable_) { id_mode = enable_; return *this; } - -protected: - const char * getName() const override { return "GenericRoleSet"; } - bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; - -private: - bool all_keyword = true; - bool current_user_keyword = true; - bool id_mode = false; -}; - -} diff --git a/dbms/src/Parsers/ParserGrantQuery.cpp b/dbms/src/Parsers/ParserGrantQuery.cpp index dc2fbc5f260..f8533c27d88 100644 --- a/dbms/src/Parsers/ParserGrantQuery.cpp +++ b/dbms/src/Parsers/ParserGrantQuery.cpp @@ -1,10 +1,10 @@ #include #include #include -#include +#include #include #include -#include +#include #include @@ -209,21 +209,21 @@ namespace } - bool parseRoles(IParser::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) + bool parseRoles(IParser::Pos & pos, Expected & expected, bool id_mode, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr ast; - if (!ParserGenericRoleSet{}.enableAllKeyword(false).enableCurrentUserKeyword(false).enableIDMode(id_mode).parse(pos, ast, expected)) + if (!ParserExtendedRoleSet{}.enableAllKeyword(false).enableCurrentUserKeyword(false).useIDMode(id_mode).parse(pos, ast, expected)) return false; - roles = typeid_cast>(ast); + roles = typeid_cast>(ast); return true; }); } - bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) + bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) { return IParserBase::wrapParseImpl(pos, [&] { @@ -240,10 +240,10 @@ namespace } ASTPtr ast; - if (!ParserGenericRoleSet{}.enableAllKeyword(kind == Kind::REVOKE).parse(pos, ast, expected)) + if (!ParserExtendedRoleSet{}.enableAllKeyword(kind == Kind::REVOKE).parse(pos, ast, expected)) return false; - to_roles = typeid_cast>(ast); + to_roles = typeid_cast>(ast); return true; }); } @@ -280,11 +280,11 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) } AccessRightsElements elements; - std::shared_ptr roles; + std::shared_ptr roles; if (!parseAccessRightsElements(pos, expected, elements) && !parseRoles(pos, expected, attach, roles)) return false; - std::shared_ptr to_roles; + std::shared_ptr to_roles; if (!parseToRoles(pos, expected, kind, to_roles)) return false; diff --git a/dbms/src/Parsers/ParserQuery.cpp b/dbms/src/Parsers/ParserQuery.cpp index a157a3ca354..144c309927b 100644 --- a/dbms/src/Parsers/ParserQuery.cpp +++ b/dbms/src/Parsers/ParserQuery.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ParserCreateRoleQuery create_role_p; ParserCreateQuotaQuery create_quota_p; ParserCreateRowPolicyQuery create_row_policy_p; + ParserCreateSettingsProfileQuery create_settings_profile_p; ParserDropAccessEntityQuery drop_access_entity_p; ParserGrantQuery grant_p; ParserSetRoleQuery set_role_p; @@ -47,6 +49,7 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) || create_role_p.parse(pos, node, expected) || create_quota_p.parse(pos, node, expected) || create_row_policy_p.parse(pos, node, expected) + || create_settings_profile_p.parse(pos, node, expected) || drop_access_entity_p.parse(pos, node, expected) || grant_p.parse(pos, node, expected); diff --git a/dbms/src/Parsers/ParserSetRoleQuery.cpp b/dbms/src/Parsers/ParserSetRoleQuery.cpp index 3031bf8ad01..e6ff7893891 100644 --- a/dbms/src/Parsers/ParserSetRoleQuery.cpp +++ b/dbms/src/Parsers/ParserSetRoleQuery.cpp @@ -1,28 +1,28 @@ #include #include #include -#include -#include +#include +#include namespace DB { namespace { - bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr ast; - if (!ParserGenericRoleSet{}.enableCurrentUserKeyword(false).parse(pos, ast, expected)) + if (!ParserExtendedRoleSet{}.enableCurrentUserKeyword(false).parse(pos, ast, expected)) return false; - roles = typeid_cast>(ast); + roles = typeid_cast>(ast); return true; }); } - bool parseToUsers(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & to_users) + bool parseToUsers(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & to_users) { return IParserBase::wrapParseImpl(pos, [&] { @@ -30,10 +30,10 @@ namespace return false; ASTPtr ast; - if (!ParserGenericRoleSet{}.enableAllKeyword(false).parse(pos, ast, expected)) + if (!ParserExtendedRoleSet{}.enableAllKeyword(false).parse(pos, ast, expected)) return false; - to_users = typeid_cast>(ast); + to_users = typeid_cast>(ast); return true; }); } @@ -53,8 +53,8 @@ bool ParserSetRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected else return false; - std::shared_ptr roles; - std::shared_ptr to_users; + std::shared_ptr roles; + std::shared_ptr to_users; if ((kind == Kind::SET_ROLE) || (kind == Kind::SET_DEFAULT_ROLE)) { diff --git a/dbms/src/Parsers/ParserSettingsProfileElement.cpp b/dbms/src/Parsers/ParserSettingsProfileElement.cpp new file mode 100644 index 00000000000..06fa58fde4e --- /dev/null +++ b/dbms/src/Parsers/ParserSettingsProfileElement.cpp @@ -0,0 +1,164 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + bool parseProfileNameOrID(IParserBase::Pos & pos, Expected & expected, bool parse_id, String & res) + { + return IParserBase::wrapParseImpl(pos, [&] + { + ASTPtr ast; + if (!parse_id) + return parseIdentifierOrStringLiteral(pos, expected, res); + + if (!ParserKeyword{"ID"}.ignore(pos, expected)) + return false; + if (!ParserToken(TokenType::OpeningRoundBracket).ignore(pos, expected)) + return false; + if (!ParserStringLiteral{}.parse(pos, ast, expected)) + return false; + String id = ast->as().value.safeGet(); + if (!ParserToken(TokenType::ClosingRoundBracket).ignore(pos, expected)) + return false; + + res = std::move(id); + return true; + }); + } + + + bool parseValue(IParserBase::Pos & pos, Expected & expected, Field & res) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserToken{TokenType::Equals}.ignore(pos, expected)) + return false; + + ASTPtr ast; + if (!ParserLiteral{}.parse(pos, ast, expected)) + return false; + + res = ast->as().value; + return true; + }); + } + + + bool parseMinMaxValue(IParserBase::Pos & pos, Expected & expected, Field & min_value, Field & max_value) + { + return IParserBase::wrapParseImpl(pos, [&] + { + bool is_min_value = ParserKeyword{"MIN"}.ignore(pos, expected); + bool is_max_value = !is_min_value && ParserKeyword{"MAX"}.ignore(pos, expected); + if (!is_min_value && !is_max_value) + return false; + + ParserToken{TokenType::Equals}.ignore(pos, expected); + + ASTPtr ast; + if (!ParserLiteral{}.parse(pos, ast, expected)) + return false; + + auto min_or_max_value = ast->as().value; + + if (is_min_value) + min_value = min_or_max_value; + else + max_value = min_or_max_value; + return true; + }); + } + + + bool parseReadonlyOrWritableKeyword(IParserBase::Pos & pos, Expected & expected, std::optional & readonly) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (ParserKeyword{"READONLY"}.ignore(pos, expected)) + { + readonly = true; + return true; + } + else if (ParserKeyword{"READONLY"}.ignore(pos, expected)) + { + readonly = false; + return true; + } + else + return false; + }); + } +} + + +bool ParserSettingsProfileElement::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + String parent_profile; + String name; + Field value; + Field min_value; + Field max_value; + std::optional readonly; + + if (ParserKeyword{"PROFILE"}.ignore(pos, expected)) + { + if (!parseProfileNameOrID(pos, expected, id_mode, parent_profile)) + return false; + } + else + { + ASTPtr name_ast; + if (!ParserIdentifier{}.parse(pos, name_ast, expected)) + return false; + name = getIdentifierName(name_ast); + + while (parseValue(pos, expected, value) || parseMinMaxValue(pos, expected, min_value, max_value) + || parseReadonlyOrWritableKeyword(pos, expected, readonly)) + ; + } + + auto result = std::make_shared(); + result->parent_profile = std::move(parent_profile); + result->name = std::move(name); + result->value = std::move(value); + result->min_value = std::move(min_value); + result->max_value = std::move(max_value); + result->readonly = readonly; + result->id_mode = id_mode; + node = result; + return true; +} + + +bool ParserSettingsProfileElements::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + std::vector> elements; + + if (!ParserKeyword{"NONE"}.ignore(pos, expected)) + { + do + { + ASTPtr ast; + if (!ParserSettingsProfileElement{}.useIDMode(id_mode).parse(pos, ast, expected)) + return false; + auto element = typeid_cast>(ast); + elements.push_back(std::move(element)); + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + } + + auto result = std::make_shared(); + result->elements = std::move(elements); + node = result; + return true; +} + +} diff --git a/dbms/src/Parsers/ParserSettingsProfileElement.h b/dbms/src/Parsers/ParserSettingsProfileElement.h new file mode 100644 index 00000000000..ec8e1abb5b5 --- /dev/null +++ b/dbms/src/Parsers/ParserSettingsProfileElement.h @@ -0,0 +1,36 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses a string like this: + * {variable [= value] [MIN [=] min_value] [MAX [=] max_value] [READONLY|WRITABLE]} | PROFILE 'profile_name' + */ +class ParserSettingsProfileElement : public IParserBase +{ +public: + ParserSettingsProfileElement & useIDMode(bool enable_) { id_mode = enable_; return *this; } + +protected: + const char * getName() const override { return "SettingsProfileElement"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool id_mode = false; +}; + + +class ParserSettingsProfileElements : public IParserBase +{ +public: + ParserSettingsProfileElements & useIDMode(bool enable_) { id_mode = enable_; return *this; } + +protected: + const char * getName() const override { return "SettingsProfileElements"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool id_mode = false; +};} diff --git a/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp b/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp index d1e6bc45478..faf9a0a1554 100644 --- a/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp @@ -22,6 +22,10 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe kind = Kind::QUOTA; else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) kind = Kind::ROW_POLICY; + else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) + kind = Kind::ROLE; + else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) + kind = Kind::SETTINGS_PROFILE; else return false; @@ -35,6 +39,11 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user)) current_user = true; } + else if (kind == Kind::ROLE) + { + if (!parseRoleName(pos, expected, name)) + return false; + } else if (kind == Kind::ROW_POLICY) { String & database = row_policy_name.database; @@ -44,9 +53,8 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe || !parseDatabaseAndTableName(pos, expected, database, table_name)) return false; } - else + else if (kind == Kind::QUOTA) { - assert(kind == Kind::QUOTA); if (ParserKeyword{"CURRENT"}.ignore(pos, expected)) { /// SHOW CREATE QUOTA CURRENT @@ -62,6 +70,11 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe current_quota = true; } } + else if (kind == Kind::SETTINGS_PROFILE) + { + if (!parseIdentifierOrStringLiteral(pos, expected, name)) + return false; + } auto query = std::make_shared(kind); node = query; diff --git a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp index e0242533518..ee5b254ccf9 100644 --- a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp +++ b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp @@ -335,7 +335,7 @@ void TreeExecutorBlockInputStream::setLimits(const IBlockInputStream::LocalLimit source->setLimits(limits_); } -void TreeExecutorBlockInputStream::setQuota(const QuotaContextPtr & quota_) +void TreeExecutorBlockInputStream::setQuota(const std::shared_ptr & quota_) { for (auto & source : sources_with_progress) source->setQuota(quota_); diff --git a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h index 3ab8dde6948..24cab387eb8 100644 --- a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h +++ b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h @@ -43,7 +43,7 @@ public: void setProgressCallback(const ProgressCallback & callback) final; void setProcessListElement(QueryStatus * elem) final; void setLimits(const LocalLimits & limits_) final; - void setQuota(const QuotaContextPtr & quota_) final; + void setQuota(const std::shared_ptr & quota_) final; void addTotalRowsApprox(size_t value) final; protected: diff --git a/dbms/src/Processors/Pipe.cpp b/dbms/src/Processors/Pipe.cpp index 7b35c351d2f..f3ffb6ee201 100644 --- a/dbms/src/Processors/Pipe.cpp +++ b/dbms/src/Processors/Pipe.cpp @@ -106,7 +106,7 @@ void Pipe::setLimits(const ISourceWithProgress::LocalLimits & limits) } } -void Pipe::setQuota(const QuotaContextPtr & quota) +void Pipe::setQuota(const std::shared_ptr & quota) { for (auto & processor : processors) { diff --git a/dbms/src/Processors/Pipe.h b/dbms/src/Processors/Pipe.h index 20f5eb038a3..f30eaef678f 100644 --- a/dbms/src/Processors/Pipe.h +++ b/dbms/src/Processors/Pipe.h @@ -40,7 +40,7 @@ public: /// Specify quotas and limits for every ISourceWithProgress. void setLimits(const SourceWithProgress::LocalLimits & limits); - void setQuota(const QuotaContextPtr & quota); + void setQuota(const std::shared_ptr & quota); /// Set information about preferred executor number for sources. void pinSources(size_t executor_number); diff --git a/dbms/src/Processors/Sources/SourceFromInputStream.h b/dbms/src/Processors/Sources/SourceFromInputStream.h index 83e7f9929c9..0fc92164059 100644 --- a/dbms/src/Processors/Sources/SourceFromInputStream.h +++ b/dbms/src/Processors/Sources/SourceFromInputStream.h @@ -28,7 +28,7 @@ public: /// Implementation for methods from ISourceWithProgress. void setLimits(const LocalLimits & limits_) final { stream->setLimits(limits_); } - void setQuota(const QuotaContextPtr & quota_) final { stream->setQuota(quota_); } + void setQuota(const std::shared_ptr & quota_) final { stream->setQuota(quota_); } void setProcessListElement(QueryStatus * elem) final { stream->setProcessListElement(elem); } void setProgressCallback(const ProgressCallback & callback) final { stream->setProgressCallback(callback); } void addTotalRowsApprox(size_t value) final { stream->addTotalRowsApprox(value); } diff --git a/dbms/src/Processors/Sources/SourceWithProgress.cpp b/dbms/src/Processors/Sources/SourceWithProgress.cpp index 0cac415aedb..80844da16cd 100644 --- a/dbms/src/Processors/Sources/SourceWithProgress.cpp +++ b/dbms/src/Processors/Sources/SourceWithProgress.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include namespace DB { diff --git a/dbms/src/Processors/Sources/SourceWithProgress.h b/dbms/src/Processors/Sources/SourceWithProgress.h index d22a2bf087a..4778c50e49d 100644 --- a/dbms/src/Processors/Sources/SourceWithProgress.h +++ b/dbms/src/Processors/Sources/SourceWithProgress.h @@ -21,7 +21,7 @@ public: /// Set the quota. If you set a quota on the amount of raw data, /// then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits. - virtual void setQuota(const QuotaContextPtr & quota_) = 0; + virtual void setQuota(const std::shared_ptr & quota_) = 0; /// Set the pointer to the process list item. /// General information about the resources spent on the request will be written into it. @@ -49,7 +49,7 @@ public: using LimitsMode = IBlockInputStream::LimitsMode; void setLimits(const LocalLimits & limits_) final { limits = limits_; } - void setQuota(const QuotaContextPtr & quota_) final { quota = quota_; } + void setQuota(const std::shared_ptr & quota_) final { quota = quota_; } void setProcessListElement(QueryStatus * elem) final { process_list_elem = elem; } void setProgressCallback(const ProgressCallback & callback) final { progress_callback = callback; } void addTotalRowsApprox(size_t value) final { total_rows_approx += value; } @@ -62,7 +62,7 @@ protected: private: LocalLimits limits; - QuotaContextPtr quota; + std::shared_ptr quota; ProgressCallback progress_callback; QueryStatus * process_list_elem = nullptr; diff --git a/dbms/src/Processors/Transforms/LimitsCheckingTransform.cpp b/dbms/src/Processors/Transforms/LimitsCheckingTransform.cpp index 3ead146abc1..c3ac019f2b6 100644 --- a/dbms/src/Processors/Transforms/LimitsCheckingTransform.cpp +++ b/dbms/src/Processors/Transforms/LimitsCheckingTransform.cpp @@ -1,5 +1,5 @@ #include -#include +#include namespace DB { diff --git a/dbms/src/Processors/Transforms/LimitsCheckingTransform.h b/dbms/src/Processors/Transforms/LimitsCheckingTransform.h index bfc5c338da1..3014c259487 100644 --- a/dbms/src/Processors/Transforms/LimitsCheckingTransform.h +++ b/dbms/src/Processors/Transforms/LimitsCheckingTransform.h @@ -33,7 +33,7 @@ public: String getName() const override { return "LimitsCheckingTransform"; } - void setQuota(const QuotaContextPtr & quota_) { quota = quota_; } + void setQuota(const std::shared_ptr & quota_) { quota = quota_; } protected: void transform(Chunk & chunk) override; @@ -41,7 +41,7 @@ protected: private: LocalLimits limits; - QuotaContextPtr quota; + std::shared_ptr quota; UInt64 prev_elapsed = 0; ProcessorProfileInfo info; diff --git a/dbms/src/Storages/MergeTree/MergeTreeDataMergerMutator.cpp b/dbms/src/Storages/MergeTree/MergeTreeDataMergerMutator.cpp index ac7d5e4b541..7857720f862 100644 --- a/dbms/src/Storages/MergeTree/MergeTreeDataMergerMutator.cpp +++ b/dbms/src/Storages/MergeTree/MergeTreeDataMergerMutator.cpp @@ -958,8 +958,8 @@ MergeTreeData::MutableDataPartPtr MergeTreeDataMergerMutator::mutatePartToTempor auto storage_from_source_part = StorageFromMergeTreeDataPart::create(source_part); auto context_for_reading = context; - context_for_reading.getSettingsRef().max_streams_to_max_threads_ratio = 1; - context_for_reading.getSettingsRef().max_threads = 1; + context_for_reading.setSetting("max_streams_to_max_threads_ratio", 1); + context_for_reading.setSetting("max_threads", 1); MutationCommands commands_for_part; for (const auto & command : commands) diff --git a/dbms/src/Storages/StorageMerge.cpp b/dbms/src/Storages/StorageMerge.cpp index 46e76a3fcde..e23aa608ec7 100644 --- a/dbms/src/Storages/StorageMerge.cpp +++ b/dbms/src/Storages/StorageMerge.cpp @@ -186,7 +186,7 @@ Pipes StorageMerge::read( * since there is no certainty that it works when one of table is MergeTree and other is not. */ auto modified_context = std::make_shared(context); - modified_context->getSettingsRef().optimize_move_to_prewhere = false; + modified_context->setSetting("optimize_move_to_prewhere", false); /// What will be result structure depending on query processed stage in source tables? Block header = getQueryHeader(column_names, query_info, context, processed_stage); @@ -300,8 +300,8 @@ Pipes StorageMerge::createSources(const SelectQueryInfo & query_info, const Quer modified_query_info.query->as()->replaceDatabaseAndTable(source_database, table_name); /// Maximum permissible parallelism is streams_num - modified_context->getSettingsRef().max_threads = UInt64(streams_num); - modified_context->getSettingsRef().max_streams_to_max_threads_ratio = 1; + modified_context->setSetting("max_threads", streams_num); + modified_context->setSetting("max_streams_to_max_threads_ratio", 1); InterpreterSelectQuery interpreter{modified_query_info.query, *modified_context, SelectQueryOptions(processed_stage)}; diff --git a/dbms/src/Storages/System/StorageSystemColumns.cpp b/dbms/src/Storages/System/StorageSystemColumns.cpp index 9bf96f95264..cbf6ada9ed3 100644 --- a/dbms/src/Storages/System/StorageSystemColumns.cpp +++ b/dbms/src/Storages/System/StorageSystemColumns.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include @@ -62,12 +62,12 @@ public: ColumnPtr databases_, ColumnPtr tables_, Storages storages_, - const AccessRightsContextPtr & access_rights_, + const std::shared_ptr & access_, String query_id_) : SourceWithProgress(header_) , columns_mask(std::move(columns_mask_)), max_block_size(max_block_size_) , databases(std::move(databases_)), tables(std::move(tables_)), storages(std::move(storages_)) - , query_id(std::move(query_id_)), total_tables(tables->size()), access_rights(access_rights_) + , query_id(std::move(query_id_)), total_tables(tables->size()), access(access_) { } @@ -82,7 +82,7 @@ protected: MutableColumns res_columns = getPort().getHeader().cloneEmptyColumns(); size_t rows_count = 0; - const bool check_access_for_tables = !access_rights->isGranted(AccessType::SHOW); + const bool check_access_for_tables = !access->isGranted(AccessType::SHOW_COLUMNS); while (rows_count < max_block_size && db_table_num < total_tables) { @@ -128,14 +128,14 @@ protected: column_sizes = storage->getColumnSizes(); } - bool check_access_for_columns = check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, database_name, table_name); + bool check_access_for_columns = check_access_for_tables && !access->isGranted(AccessType::SHOW_COLUMNS, database_name, table_name); for (const auto & column : columns) { if (column.is_virtual) continue; - if (check_access_for_columns && !access_rights->isGranted(AccessType::SHOW, database_name, table_name, column.name)) + if (check_access_for_columns && !access->isGranted(AccessType::SHOW_COLUMNS, database_name, table_name, column.name)) continue; size_t src_index = 0; @@ -230,7 +230,7 @@ private: String query_id; size_t db_table_num = 0; size_t total_tables; - AccessRightsContextPtr access_rights; + std::shared_ptr access; }; @@ -332,7 +332,7 @@ Pipes StorageSystemColumns::read( pipes.emplace_back(std::make_shared( std::move(columns_mask), std::move(header), max_block_size, std::move(filtered_database_column), std::move(filtered_table_column), std::move(storages), - context.getAccessRights(), context.getCurrentQueryId())); + context.getAccess(), context.getCurrentQueryId())); return pipes; } diff --git a/dbms/src/Storages/System/StorageSystemDatabases.cpp b/dbms/src/Storages/System/StorageSystemDatabases.cpp index 4588fd28482..5a35e079a5b 100644 --- a/dbms/src/Storages/System/StorageSystemDatabases.cpp +++ b/dbms/src/Storages/System/StorageSystemDatabases.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include @@ -20,13 +20,13 @@ NamesAndTypesList StorageSystemDatabases::getNamesAndTypes() void StorageSystemDatabases::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo &) const { - const auto access_rights = context.getAccessRights(); - const bool check_access_for_databases = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_DATABASES); auto databases = DatabaseCatalog::instance().getDatabases(); for (const auto & database : databases) { - if (check_access_for_databases && !access_rights->isGranted(AccessType::SHOW, database.first)) + if (check_access_for_databases && !access->isGranted(AccessType::SHOW_DATABASES, database.first)) continue; res_columns[0]->insert(database.first); diff --git a/dbms/src/Storages/System/StorageSystemDictionaries.cpp b/dbms/src/Storages/System/StorageSystemDictionaries.cpp index 87a11387e4d..400b1074250 100644 --- a/dbms/src/Storages/System/StorageSystemDictionaries.cpp +++ b/dbms/src/Storages/System/StorageSystemDictionaries.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include #include @@ -49,8 +49,8 @@ NamesAndTypesList StorageSystemDictionaries::getNamesAndTypes() void StorageSystemDictionaries::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo & /*query_info*/) const { - const auto access_rights = context.getAccessRights(); - const bool check_access_for_dictionaries = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_dictionaries = !access->isGranted(AccessType::SHOW_DICTIONARIES); const auto & external_dictionaries = context.getExternalDictionariesLoader(); for (const auto & load_result : external_dictionaries.getCurrentLoadResults()) @@ -74,7 +74,7 @@ void StorageSystemDictionaries::fillData(MutableColumns & res_columns, const Con } if (check_access_for_dictionaries - && !access_rights->isGranted(AccessType::SHOW, database.empty() ? IDictionary::NO_DATABASE_TAG : database, short_name)) + && !access->isGranted(AccessType::SHOW_DICTIONARIES, database.empty() ? IDictionary::NO_DATABASE_TAG : database, short_name)) continue; size_t i = 0; diff --git a/dbms/src/Storages/System/StorageSystemMerges.cpp b/dbms/src/Storages/System/StorageSystemMerges.cpp index 7e71d90120a..39d22bd00ca 100644 --- a/dbms/src/Storages/System/StorageSystemMerges.cpp +++ b/dbms/src/Storages/System/StorageSystemMerges.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include namespace DB @@ -36,12 +36,12 @@ NamesAndTypesList StorageSystemMerges::getNamesAndTypes() void StorageSystemMerges::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo &) const { - const auto access_rights = context.getAccessRights(); - const bool check_access_for_tables = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_tables = !access->isGranted(AccessType::SHOW_TABLES); for (const auto & merge : context.getMergeList().get()) { - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, merge.database, merge.table)) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, merge.database, merge.table)) continue; size_t i = 0; diff --git a/dbms/src/Storages/System/StorageSystemMutations.cpp b/dbms/src/Storages/System/StorageSystemMutations.cpp index 51c5bd47c6d..e7d9cc38671 100644 --- a/dbms/src/Storages/System/StorageSystemMutations.cpp +++ b/dbms/src/Storages/System/StorageSystemMutations.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include @@ -37,8 +37,8 @@ NamesAndTypesList StorageSystemMutations::getNamesAndTypes() void StorageSystemMutations::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo & query_info) const { - const auto access_rights = context.getAccessRights(); - const bool check_access_for_databases = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_TABLES); /// Collect a set of *MergeTree tables. std::map> merge_tree_tables; @@ -48,14 +48,14 @@ void StorageSystemMutations::fillData(MutableColumns & res_columns, const Contex if (db.second->getEngineName() == "Lazy") continue; - const bool check_access_for_tables = check_access_for_databases && !access_rights->isGranted(AccessType::SHOW, db.first); + const bool check_access_for_tables = check_access_for_databases && !access->isGranted(AccessType::SHOW_TABLES, db.first); for (auto iterator = db.second->getTablesIterator(context); iterator->isValid(); iterator->next()) { if (!dynamic_cast(iterator->table().get())) continue; - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, db.first, iterator->name())) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, db.first, iterator->name())) continue; merge_tree_tables[db.first][iterator->name()] = iterator->table(); diff --git a/dbms/src/Storages/System/StorageSystemPartsBase.cpp b/dbms/src/Storages/System/StorageSystemPartsBase.cpp index 4bf3737dcd4..c5edde3e5d9 100644 --- a/dbms/src/Storages/System/StorageSystemPartsBase.cpp +++ b/dbms/src/Storages/System/StorageSystemPartsBase.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -73,8 +73,8 @@ StoragesInfoStream::StoragesInfoStream(const SelectQueryInfo & query_info, const MutableColumnPtr engine_column_mut = ColumnString::create(); MutableColumnPtr active_column_mut = ColumnUInt8::create(); - const auto access_rights = context.getAccessRights(); - const bool check_access_for_tables = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_tables = !access->isGranted(AccessType::SHOW_TABLES); { Databases databases = DatabaseCatalog::instance().getDatabases(); @@ -119,7 +119,7 @@ StoragesInfoStream::StoragesInfoStream(const SelectQueryInfo & query_info, const if (!dynamic_cast(storage.get())) continue; - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, database_name, table_name)) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, database_name, table_name)) continue; storages[std::make_pair(database_name, iterator->name())] = storage; diff --git a/dbms/src/Storages/System/StorageSystemQuotaUsage.cpp b/dbms/src/Storages/System/StorageSystemQuotaUsage.cpp index 8835e77eeb5..53afb1d563a 100644 --- a/dbms/src/Storages/System/StorageSystemQuotaUsage.cpp +++ b/dbms/src/Storages/System/StorageSystemQuotaUsage.cpp @@ -6,7 +6,8 @@ #include #include #include -#include +#include +#include #include diff --git a/dbms/src/Storages/System/StorageSystemQuotas.cpp b/dbms/src/Storages/System/StorageSystemQuotas.cpp index 81969ab2364..228339ea305 100644 --- a/dbms/src/Storages/System/StorageSystemQuotas.cpp +++ b/dbms/src/Storages/System/StorageSystemQuotas.cpp @@ -87,7 +87,7 @@ void StorageSystemQuotas::fillData(MutableColumns & res_columns, const Context & storage_name_column.insert(storage_name); key_type_column.insert(static_cast(quota->key_type)); - for (const String & role : quota->roles.toStringsWithNames(access_control)) + for (const String & role : quota->to_roles.toStringsWithNames(access_control)) roles_data.insert(role); roles_offsets.push_back(roles_data.size()); diff --git a/dbms/src/Storages/System/StorageSystemReplicas.cpp b/dbms/src/Storages/System/StorageSystemReplicas.cpp index 16a2a8d07de..251b45e44b6 100644 --- a/dbms/src/Storages/System/StorageSystemReplicas.cpp +++ b/dbms/src/Storages/System/StorageSystemReplicas.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include #include @@ -65,8 +65,8 @@ Pipes StorageSystemReplicas::read( { check(column_names); - const auto access_rights = context.getAccessRights(); - const bool check_access_for_databases = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_TABLES); /// We collect a set of replicated tables. std::map> replicated_tables; @@ -75,12 +75,12 @@ Pipes StorageSystemReplicas::read( /// Lazy database can not contain replicated tables if (db.second->getEngineName() == "Lazy") continue; - const bool check_access_for_tables = check_access_for_databases && !access_rights->isGranted(AccessType::SHOW, db.first); + const bool check_access_for_tables = check_access_for_databases && !access->isGranted(AccessType::SHOW_TABLES, db.first); for (auto iterator = db.second->getTablesIterator(context); iterator->isValid(); iterator->next()) { if (!dynamic_cast(iterator->table().get())) continue; - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, db.first, iterator->name())) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, db.first, iterator->name())) continue; replicated_tables[db.first][iterator->name()] = iterator->table(); } diff --git a/dbms/src/Storages/System/StorageSystemReplicationQueue.cpp b/dbms/src/Storages/System/StorageSystemReplicationQueue.cpp index 5148e0a9ec8..2c188cf3734 100644 --- a/dbms/src/Storages/System/StorageSystemReplicationQueue.cpp +++ b/dbms/src/Storages/System/StorageSystemReplicationQueue.cpp @@ -8,7 +8,7 @@ #include #include #include -#include +#include #include #include @@ -48,8 +48,8 @@ NamesAndTypesList StorageSystemReplicationQueue::getNamesAndTypes() void StorageSystemReplicationQueue::fillData(MutableColumns & res_columns, const Context & context, const SelectQueryInfo & query_info) const { - const auto access_rights = context.getAccessRights(); - const bool check_access_for_databases = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_TABLES); std::map> replicated_tables; for (const auto & db : DatabaseCatalog::instance().getDatabases()) @@ -58,13 +58,13 @@ void StorageSystemReplicationQueue::fillData(MutableColumns & res_columns, const if (db.second->getEngineName() == "Lazy") continue; - const bool check_access_for_tables = check_access_for_databases && !access_rights->isGranted(AccessType::SHOW, db.first); + const bool check_access_for_tables = check_access_for_databases && !access->isGranted(AccessType::SHOW_TABLES, db.first); for (auto iterator = db.second->getTablesIterator(context); iterator->isValid(); iterator->next()) { if (!dynamic_cast(iterator->table().get())) continue; - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, db.first, iterator->name())) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, db.first, iterator->name())) continue; replicated_tables[db.first][iterator->name()] = iterator->table(); } diff --git a/dbms/src/Storages/System/StorageSystemRowPolicies.cpp b/dbms/src/Storages/System/StorageSystemRowPolicies.cpp index 8ac4ac1b755..bd302cba3cf 100644 --- a/dbms/src/Storages/System/StorageSystemRowPolicies.cpp +++ b/dbms/src/Storages/System/StorageSystemRowPolicies.cpp @@ -24,8 +24,8 @@ NamesAndTypesList StorageSystemRowPolicies::getNamesAndTypes() {"restrictive", std::make_shared()}, }; - for (auto index : ext::range_with_static_cast(RowPolicy::MAX_CONDITION_INDEX)) - names_and_types.push_back({RowPolicy::conditionIndexToColumnName(index), std::make_shared()}); + for (auto index : ext::range_with_static_cast(RowPolicy::MAX_CONDITION_TYPE)) + names_and_types.push_back({RowPolicy::conditionTypeToColumnName(index), std::make_shared()}); return names_and_types; } @@ -52,7 +52,7 @@ void StorageSystemRowPolicies::fillData(MutableColumns & res_columns, const Cont res_columns[i++]->insert(storage ? storage->getStorageName() : ""); res_columns[i++]->insert(policy->isRestrictive()); - for (auto index : ext::range(RowPolicy::MAX_CONDITION_INDEX)) + for (auto index : ext::range(RowPolicy::MAX_CONDITION_TYPE)) res_columns[i++]->insert(policy->conditions[index]); } } diff --git a/dbms/src/Storages/System/StorageSystemTables.cpp b/dbms/src/Storages/System/StorageSystemTables.cpp index 7bc0799e795..cb72d3408df 100644 --- a/dbms/src/Storages/System/StorageSystemTables.cpp +++ b/dbms/src/Storages/System/StorageSystemTables.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include #include @@ -105,8 +105,8 @@ protected: MutableColumns res_columns = getPort().getHeader().cloneEmptyColumns(); - const auto access_rights = context.getAccessRights(); - const bool check_access_for_databases = !access_rights->isGranted(AccessType::SHOW); + const auto access = context.getAccess(); + const bool check_access_for_databases = !access->isGranted(AccessType::SHOW_TABLES); size_t rows_count = 0; while (rows_count < max_block_size) @@ -196,7 +196,7 @@ protected: return Chunk(std::move(res_columns), num_rows); } - const bool check_access_for_tables = check_access_for_databases && !access_rights->isGranted(AccessType::SHOW, database_name); + const bool check_access_for_tables = check_access_for_databases && !access->isGranted(AccessType::SHOW_TABLES, database_name); if (!tables_it || !tables_it->isValid()) tables_it = database->getTablesWithDictionaryTablesIterator(context); @@ -206,7 +206,7 @@ protected: for (; rows_count < max_block_size && tables_it->isValid(); tables_it->next()) { auto table_name = tables_it->name(); - if (check_access_for_tables && !access_rights->isGranted(AccessType::SHOW, database_name, table_name)) + if (check_access_for_tables && !access->isGranted(AccessType::SHOW_TABLES, database_name, table_name)) continue; StoragePtr table = nullptr; diff --git a/dbms/tests/integration/test_authentication/test.py b/dbms/tests/integration/test_authentication/test.py index 11ca967fbee..483b59813e5 100644 --- a/dbms/tests/integration/test_authentication/test.py +++ b/dbms/tests/integration/test_authentication/test.py @@ -10,8 +10,8 @@ def setup_nodes(): try: cluster.start() - instance.query("CREATE USER sasha PROFILE 'default'") - instance.query("CREATE USER masha IDENTIFIED BY 'qwerty' PROFILE 'default'") + instance.query("CREATE USER sasha") + instance.query("CREATE USER masha IDENTIFIED BY 'qwerty'") yield cluster diff --git a/dbms/tests/integration/test_disk_access_storage/test.py b/dbms/tests/integration/test_disk_access_storage/test.py index 169c0a35414..d5e1f283167 100644 --- a/dbms/tests/integration/test_disk_access_storage/test.py +++ b/dbms/tests/integration/test_disk_access_storage/test.py @@ -16,9 +16,11 @@ def started_cluster(): def create_entities(): - instance.query("CREATE USER u1") - instance.query("CREATE ROLE rx") + instance.query("CREATE SETTINGS PROFILE s1 SETTINGS max_memory_usage = 123456789 MIN 100000000 MAX 200000000") + instance.query("CREATE USER u1 SETTINGS PROFILE s1") + instance.query("CREATE ROLE rx SETTINGS PROFILE s1") instance.query("CREATE USER u2 IDENTIFIED BY 'qwerty' HOST LOCAL DEFAULT ROLE rx") + instance.query("CREATE SETTINGS PROFILE s2 SETTINGS PROFILE s1 TO u2") instance.query("CREATE ROW POLICY p ON mydb.mytable FOR SELECT USING a<1000 TO u1, u2") instance.query("CREATE QUOTA q FOR INTERVAL 1 HOUR SET MAX QUERIES = 100 TO ALL EXCEPT rx") @@ -29,19 +31,23 @@ def drop_entities(): instance.query("DROP ROLE IF EXISTS rx, ry") instance.query("DROP ROW POLICY IF EXISTS p ON mydb.mytable") instance.query("DROP QUOTA IF EXISTS q") + instance.query("DROP SETTINGS PROFILE IF EXISTS s1, s2") def test_create(): create_entities() def check(): - assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n" + assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1 SETTINGS PROFILE s1\n" assert instance.query("SHOW CREATE USER u2") == "CREATE USER u2 HOST LOCAL DEFAULT ROLE rx\n" assert instance.query("SHOW CREATE ROW POLICY p ON mydb.mytable") == "CREATE POLICY p ON mydb.mytable FOR SELECT USING a < 1000 TO u1, u2\n" assert instance.query("SHOW CREATE QUOTA q") == "CREATE QUOTA q KEYED BY \\'none\\' FOR INTERVAL 1 HOUR MAX QUERIES = 100 TO ALL EXCEPT rx\n" assert instance.query("SHOW GRANTS FOR u1") == "" assert instance.query("SHOW GRANTS FOR u2") == "GRANT rx TO u2\n" + assert instance.query("SHOW CREATE ROLE rx") == "CREATE ROLE rx SETTINGS PROFILE s1\n" assert instance.query("SHOW GRANTS FOR rx") == "" + assert instance.query("SHOW CREATE SETTINGS PROFILE s1") == "CREATE SETTINGS PROFILE s1 SETTINGS max_memory_usage = 123456789 MIN 100000000 MAX 200000000\n" + assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2 SETTINGS PROFILE s1 TO u2\n" check() instance.restart_clickhouse() # Check persistency @@ -56,16 +62,22 @@ def test_alter(): instance.query("GRANT ry TO u2") instance.query("ALTER USER u2 DEFAULT ROLE ry") instance.query("GRANT rx TO ry WITH ADMIN OPTION") + instance.query("ALTER ROLE rx SETTINGS PROFILE s2") instance.query("GRANT SELECT ON mydb.mytable TO u1") instance.query("GRANT SELECT ON mydb.* TO rx WITH GRANT OPTION") + instance.query("ALTER SETTINGS PROFILE s1 SETTINGS max_memory_usage = 987654321 READONLY") def check(): - assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n" + assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1 SETTINGS PROFILE s1\n" assert instance.query("SHOW CREATE USER u2") == "CREATE USER u2 HOST LOCAL DEFAULT ROLE ry\n" assert instance.query("SHOW GRANTS FOR u1") == "GRANT SELECT ON mydb.mytable TO u1\n" assert instance.query("SHOW GRANTS FOR u2") == "GRANT rx, ry TO u2\n" + assert instance.query("SHOW CREATE ROLE rx") == "CREATE ROLE rx SETTINGS PROFILE s2\n" + assert instance.query("SHOW CREATE ROLE ry") == "CREATE ROLE ry\n" assert instance.query("SHOW GRANTS FOR rx") == "GRANT SELECT ON mydb.* TO rx WITH GRANT OPTION\n" assert instance.query("SHOW GRANTS FOR ry") == "GRANT rx TO ry WITH ADMIN OPTION\n" + assert instance.query("SHOW CREATE SETTINGS PROFILE s1") == "CREATE SETTINGS PROFILE s1 SETTINGS max_memory_usage = 987654321 READONLY\n" + assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2 SETTINGS PROFILE s1 TO u2\n" check() instance.restart_clickhouse() # Check persistency @@ -80,9 +92,11 @@ def test_drop(): instance.query("DROP ROLE rx") instance.query("DROP ROW POLICY p ON mydb.mytable") instance.query("DROP QUOTA q") + instance.query("DROP SETTINGS PROFILE s1") def check(): assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n" + assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2\n" assert "User `u2` not found" in instance.query_and_get_error("SHOW CREATE USER u2") assert "Row policy `p ON mydb.mytable` not found" in instance.query_and_get_error("SHOW CREATE ROW POLICY p ON mydb.mytable") assert "Quota `q` not found" in instance.query_and_get_error("SHOW CREATE QUOTA q") diff --git a/dbms/tests/integration/test_format_avro_confluent/test.py b/dbms/tests/integration/test_format_avro_confluent/test.py index 42ebf05d161..a93b5585f8d 100644 --- a/dbms/tests/integration/test_format_avro_confluent/test.py +++ b/dbms/tests/integration/test_format_avro_confluent/test.py @@ -29,12 +29,14 @@ def cluster(): cluster.shutdown() -def run_query(instance, query, stdin=None, settings=None): +def run_query(instance, query, data=None, settings=None): # type: (ClickHouseInstance, str, object, dict) -> str logging.info("Running query '{}'...".format(query)) # use http to force parsing on server - result = instance.http_query(query, data=stdin, params=settings) + if not data: + data = " " # make POST request + result = instance.http_query(query, data=data, params=settings) logging.info("Query finished") return result @@ -64,7 +66,7 @@ def test_select(cluster): 'test_subject', schema, {'value': x} ) buf.write(message) - stdin = buf.getvalue() + data = buf.getvalue() instance = cluster.instances["dummy"] # type: ClickHouseInstance schema_registry_url = "http://{}:{}".format( @@ -74,7 +76,7 @@ def test_select(cluster): run_query(instance, "create table avro_data(value Int64) engine = Memory()") settings = {'format_avro_schema_registry_url': schema_registry_url} - run_query(instance, "insert into avro_data format AvroConfluent", stdin, settings) + run_query(instance, "insert into avro_data format AvroConfluent", data, settings) stdout = run_query(instance, "select * from avro_data") assert list(map(str.split, stdout.splitlines())) == [ ["0"], diff --git a/dbms/tests/integration/test_grant_and_revoke/test.py b/dbms/tests/integration/test_grant_and_revoke/test.py index 132e62f3db0..25e0e9882de 100644 --- a/dbms/tests/integration/test_grant_and_revoke/test.py +++ b/dbms/tests/integration/test_grant_and_revoke/test.py @@ -30,14 +30,14 @@ def reset_users_and_roles(): def test_login(): - instance.query("CREATE USER A PROFILE 'default'") - instance.query("CREATE USER B PROFILE 'default'") + instance.query("CREATE USER A") + instance.query("CREATE USER B") assert instance.query("SELECT 1", user='A') == "1\n" assert instance.query("SELECT 1", user='B') == "1\n" def test_grant_and_revoke(): - instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER A") assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') instance.query('GRANT SELECT ON test_table TO A') @@ -48,8 +48,8 @@ def test_grant_and_revoke(): def test_grant_option(): - instance.query("CREATE USER A PROFILE 'default'") - instance.query("CREATE USER B PROFILE 'default'") + instance.query("CREATE USER A") + instance.query("CREATE USER B") instance.query('GRANT SELECT ON test_table TO A') assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" @@ -63,7 +63,7 @@ def test_grant_option(): def test_create_role(): - instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER A") instance.query('CREATE ROLE R1') assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') @@ -79,7 +79,7 @@ def test_create_role(): def test_grant_role_to_role(): - instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER A") instance.query('CREATE ROLE R1') instance.query('CREATE ROLE R2') @@ -96,7 +96,7 @@ def test_grant_role_to_role(): def test_combine_privileges(): - instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER A ") instance.query('CREATE ROLE R1') instance.query('CREATE ROLE R2') @@ -113,8 +113,8 @@ def test_combine_privileges(): def test_admin_option(): - instance.query("CREATE USER A PROFILE 'default'") - instance.query("CREATE USER B PROFILE 'default'") + instance.query("CREATE USER A") + instance.query("CREATE USER B") instance.query('CREATE ROLE R1') instance.query('GRANT SELECT ON test_table TO R1') diff --git a/dbms/tests/integration/test_settings_profile/__init__.py b/dbms/tests/integration/test_settings_profile/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbms/tests/integration/test_settings_profile/test.py b/dbms/tests/integration/test_settings_profile/test.py new file mode 100644 index 00000000000..592ab5b92d6 --- /dev/null +++ b/dbms/tests/integration/test_settings_profile/test.py @@ -0,0 +1,106 @@ +import pytest +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__) +instance = cluster.add_instance('instance') + + +@pytest.fixture(scope="module", autouse=True) +def setup_nodes(): + try: + cluster.start() + + instance.query("CREATE USER robin") + + yield cluster + + finally: + cluster.shutdown() + + +@pytest.fixture(autouse=True) +def reset_after_test(): + try: + yield + finally: + instance.query("CREATE USER OR REPLACE robin") + instance.query("DROP ROLE IF EXISTS worker") + instance.query("DROP SETTINGS PROFILE IF EXISTS xyz, alpha") + + +def test_settings_profile(): + # Set settings and constraints via CREATE SETTINGS PROFILE ... TO user + instance.query("CREATE SETTINGS PROFILE xyz SETTINGS max_memory_usage = 100000001 MIN 90000000 MAX 110000000 TO robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000001\n" + assert "Setting max_memory_usage shouldn't be less than 90000000" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + assert "Setting max_memory_usage shouldn't be greater than 110000000" in instance.query_and_get_error("SET max_memory_usage = 120000000", user="robin") + + instance.query("ALTER SETTINGS PROFILE xyz TO NONE") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") + + # Set settings and constraints via CREATE USER ... SETTINGS PROFILE + instance.query("ALTER USER robin SETTINGS PROFILE xyz") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000001\n" + assert "Setting max_memory_usage shouldn't be less than 90000000" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + assert "Setting max_memory_usage shouldn't be greater than 110000000" in instance.query_and_get_error("SET max_memory_usage = 120000000", user="robin") + + instance.query("ALTER USER robin SETTINGS NONE") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") + + +def test_settings_profile_from_granted_role(): + # Set settings and constraints via granted role + instance.query("CREATE SETTINGS PROFILE xyz SETTINGS max_memory_usage = 100000001 MIN 90000000 MAX 110000000") + instance.query("CREATE ROLE worker SETTINGS PROFILE xyz") + instance.query("GRANT worker TO robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000001\n" + assert "Setting max_memory_usage shouldn't be less than 90000000" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + assert "Setting max_memory_usage shouldn't be greater than 110000000" in instance.query_and_get_error("SET max_memory_usage = 120000000", user="robin") + + instance.query("REVOKE worker FROM robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") + + instance.query("ALTER ROLE worker SETTINGS NONE") + instance.query("GRANT worker TO robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") + + # Set settings and constraints via CREATE SETTINGS PROFILE ... TO granted role + instance.query("ALTER SETTINGS PROFILE xyz TO worker") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000001\n" + assert "Setting max_memory_usage shouldn't be less than 90000000" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + assert "Setting max_memory_usage shouldn't be greater than 110000000" in instance.query_and_get_error("SET max_memory_usage = 120000000", user="robin") + + instance.query("ALTER SETTINGS PROFILE xyz TO NONE") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") + + +def test_inheritance_of_settings_profile(): + instance.query("CREATE SETTINGS PROFILE xyz SETTINGS max_memory_usage = 100000002 READONLY") + instance.query("CREATE SETTINGS PROFILE alpha SETTINGS PROFILE xyz TO robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000002\n" + assert "Setting max_memory_usage should not be changed" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + + +def test_alter_and_drop(): + instance.query("CREATE SETTINGS PROFILE xyz SETTINGS max_memory_usage = 100000003 MIN 90000000 MAX 110000000 TO robin") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "100000003\n" + assert "Setting max_memory_usage shouldn't be less than 90000000" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + assert "Setting max_memory_usage shouldn't be greater than 110000000" in instance.query_and_get_error("SET max_memory_usage = 120000000", user="robin") + + instance.query("ALTER SETTINGS PROFILE xyz SETTINGS readonly=1") + assert "Cannot modify 'max_memory_usage' setting in readonly mode" in instance.query_and_get_error("SET max_memory_usage = 80000000", user="robin") + + instance.query("DROP SETTINGS PROFILE xyz") + assert instance.query("SELECT value FROM system.settings WHERE name = 'max_memory_usage'", user="robin") == "10000000000\n" + instance.query("SET max_memory_usage = 80000000", user="robin") + instance.query("SET max_memory_usage = 120000000", user="robin") diff --git a/dbms/tests/queries/0_stateless/00834_kill_mutation.reference b/dbms/tests/queries/0_stateless/00834_kill_mutation.reference index cbee44069d8..1e4a67b66ea 100644 --- a/dbms/tests/queries/0_stateless/00834_kill_mutation.reference +++ b/dbms/tests/queries/0_stateless/00834_kill_mutation.reference @@ -1,7 +1,7 @@ *** Create and kill a single invalid mutation *** 1 -waiting test kill_mutation mutation_3.txt +waiting test kill_mutation mutation_3.txt DELETE WHERE toUInt32(s) = 1 *** Create and kill invalid mutation that blocks another mutation *** 1 -waiting test kill_mutation mutation_4.txt +waiting test kill_mutation mutation_4.txt DELETE WHERE toUInt32(s) = 1 2001-01-01 2 b diff --git a/dbms/tests/queries/0_stateless/00834_kill_mutation_replicated_zookeeper.reference b/dbms/tests/queries/0_stateless/00834_kill_mutation_replicated_zookeeper.reference index a997ebe1dc9..d6a82e48836 100644 --- a/dbms/tests/queries/0_stateless/00834_kill_mutation_replicated_zookeeper.reference +++ b/dbms/tests/queries/0_stateless/00834_kill_mutation_replicated_zookeeper.reference @@ -1,9 +1,9 @@ *** Create and kill a single invalid mutation *** 1 Mutation 0000000000 was killed -waiting test kill_mutation_r1 0000000000 +waiting test kill_mutation_r1 0000000000 DELETE WHERE toUInt32(s) = 1 0 *** Create and kill invalid mutation that blocks another mutation *** 1 -waiting test kill_mutation_r1 0000000001 +waiting test kill_mutation_r1 0000000001 DELETE WHERE toUInt32(s) = 1 2001-01-01 2 b diff --git a/dbms/tests/queries/0_stateless/01074_partial_revokes.reference b/dbms/tests/queries/0_stateless/01074_partial_revokes.reference index e64d439b5b2..19a70679143 100644 --- a/dbms/tests/queries/0_stateless/01074_partial_revokes.reference +++ b/dbms/tests/queries/0_stateless/01074_partial_revokes.reference @@ -1,5 +1,2 @@ -A -GRANT SELECT ON *.* TO test_user_01074 -B GRANT SELECT ON *.* TO test_user_01074 REVOKE SELECT ON db.* FROM test_user_01074 diff --git a/dbms/tests/queries/0_stateless/01074_partial_revokes.sql b/dbms/tests/queries/0_stateless/01074_partial_revokes.sql index af7048a0815..4406341cc4f 100644 --- a/dbms/tests/queries/0_stateless/01074_partial_revokes.sql +++ b/dbms/tests/queries/0_stateless/01074_partial_revokes.sql @@ -1,15 +1,8 @@ DROP USER IF EXISTS test_user_01074; CREATE USER test_user_01074; -SELECT 'A'; -SET partial_revokes=0; GRANT SELECT ON *.* TO test_user_01074; REVOKE SELECT ON db.* FROM test_user_01074; SHOW GRANTS FOR test_user_01074; -SELECT 'B'; -SET partial_revokes=1; -REVOKE SELECT ON db.* FROM test_user_01074; -SHOW GRANTS FOR test_user_01074; - DROP USER test_user_01074;