From 7d288058151986e295567d48fd90bb02f77265ec Mon Sep 17 00:00:00 2001 From: Sergei Shtykov Date: Thu, 20 Feb 2020 17:13:19 +0300 Subject: [PATCH 01/19] CLICKHOUSEDOCS-446: Added ALTER MODIFY TTL --- docs/en/operations/table_engines/mergetree.md | 4 +-- docs/en/query_language/alter.md | 26 ++++++++++++++++--- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/docs/en/operations/table_engines/mergetree.md b/docs/en/operations/table_engines/mergetree.md index 1ad898674e5..f9380d4fe1a 100644 --- a/docs/en/operations/table_engines/mergetree.md +++ b/docs/en/operations/table_engines/mergetree.md @@ -392,7 +392,7 @@ TTL date_time + INTERVAL 1 MONTH TTL date_time + INTERVAL 15 HOUR ``` -**Column TTL** +### Column TTL {#mergetree-column-ttl} When the values in the column expire, ClickHouse replaces them with the default values for the column data type. If all the column values in the data part expire, ClickHouse deletes this column from the data part in a filesystem. @@ -431,7 +431,7 @@ ALTER TABLE example_table c String TTL d + INTERVAL 1 MONTH; ``` -**Table TTL** +### Table TTL {#mergetree-table-ttl} Table can have an expression for removal of expired rows, and multiple expressions for automatic move of parts between [disks or volumes](#table_engine-mergetree-multiple-volumes). When rows in the table expire, ClickHouse deletes all corresponding rows. For parts moving feature, all rows of a part must satisfy the movement expression criteria. diff --git a/docs/en/query_language/alter.md b/docs/en/query_language/alter.md index 707ff6ff3b4..4111a43daef 100644 --- a/docs/en/query_language/alter.md +++ b/docs/en/query_language/alter.md @@ -19,7 +19,7 @@ The following actions are supported: - [DROP COLUMN](#alter_drop-column) — Deletes the column. - [CLEAR COLUMN](#alter_clear-column) — Resets column values. - [COMMENT COLUMN](#alter_comment-column) — Adds a text comment to the column. -- [MODIFY COLUMN](#alter_modify-column) — Changes column's type and/or default expression. +- [MODIFY COLUMN](#alter_modify-column) — Changes column's type, default expression and TTL. These actions are described in detail below. @@ -96,10 +96,19 @@ ALTER TABLE visits COMMENT COLUMN browser 'The table shows the browser used for #### MODIFY COLUMN {#alter_modify-column} ```sql -MODIFY COLUMN [IF EXISTS] name [type] [default_expr] +MODIFY COLUMN [IF EXISTS] name [type] [default_expr] [TTL] ``` -This query changes the `name` column's type to `type` and/or the default expression to `default_expr`. If the `IF EXISTS` clause is specified, the query won't return an error if the column doesn't exist. +This query changes the `name` column properties: + +- Type +- Default expression +- TTL + + For examples of columns TTL modifying, see [../operations/table_engines/mergetree.md#table_engine-mergetree-ttl]. + + +If the `IF EXISTS` clause is specified, the query won't return an error if the column doesn't exist. When changing the type, values are converted as if the [toType](functions/type_conversion_functions.md) functions were applied to them. If only the default expression is changed, the query doesn't do anything complex, and is completed almost instantly. @@ -433,6 +442,17 @@ OPTIMIZE TABLE table_not_partitioned PARTITION tuple() FINAL; The examples of `ALTER ... PARTITION` queries are demonstrated in the tests [`00502_custom_partitioning_local`](https://github.com/ClickHouse/ClickHouse/blob/master/dbms/tests/queries/0_stateless/00502_custom_partitioning_local.sql) and [`00502_custom_partitioning_replicated_zookeeper`](https://github.com/ClickHouse/ClickHouse/blob/master/dbms/tests/queries/0_stateless/00502_custom_partitioning_replicated_zookeeper.sql). + +### Manipulations with Table TTL + +You can change table TTL with a request of the following form: + +```sql +ALTER TABLE table-name MODIFY TTL ttl-expression +``` + +For example of columns TTL modifying, see [../operations/table_engines/mergetree.md#mergetree-table-ttl]. + ### Synchronicity of ALTER Queries For non-replicatable tables, all `ALTER` queries are performed synchronously. For replicatable tables, the query just adds instructions for the appropriate actions to `ZooKeeper`, and the actions themselves are performed as soon as possible. However, the query can wait for these actions to be completed on all the replicas. From ff131e1378caceac391567f0c99761b158f9ae39 Mon Sep 17 00:00:00 2001 From: Sergei Shtykov Date: Thu, 20 Feb 2020 17:38:57 +0300 Subject: [PATCH 02/19] CLICKHOUSEDOCS-446: Tranlsated to Russian. Fixed links in ZH version. --- docs/en/query_language/alter.md | 6 ++--- docs/ru/operations/table_engines/mergetree.md | 4 ++-- docs/ru/query_language/alter.md | 22 ++++++++++++++++--- docs/zh/operations/table_engines/mergetree.md | 4 ++-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/docs/en/query_language/alter.md b/docs/en/query_language/alter.md index 4111a43daef..eeb11282f65 100644 --- a/docs/en/query_language/alter.md +++ b/docs/en/query_language/alter.md @@ -105,7 +105,7 @@ This query changes the `name` column properties: - Default expression - TTL - For examples of columns TTL modifying, see [../operations/table_engines/mergetree.md#table_engine-mergetree-ttl]. + For examples of columns TTL modifying, see [Column TTL](../operations/table_engines/mergetree.md#mergetree-column-ttl). If the `IF EXISTS` clause is specified, the query won't return an error if the column doesn't exist. @@ -445,14 +445,12 @@ The examples of `ALTER ... PARTITION` queries are demonstrated in the tests [`00 ### Manipulations with Table TTL -You can change table TTL with a request of the following form: +You can change [table TTL](../operations/table_engines/mergetree.md#mergetree-table-ttl) with a request of the following form: ```sql ALTER TABLE table-name MODIFY TTL ttl-expression ``` -For example of columns TTL modifying, see [../operations/table_engines/mergetree.md#mergetree-table-ttl]. - ### Synchronicity of ALTER Queries For non-replicatable tables, all `ALTER` queries are performed synchronously. For replicatable tables, the query just adds instructions for the appropriate actions to `ZooKeeper`, and the actions themselves are performed as soon as possible. However, the query can wait for these actions to be completed on all the replicas. diff --git a/docs/ru/operations/table_engines/mergetree.md b/docs/ru/operations/table_engines/mergetree.md index af4ccfc99c6..081ab0e73f8 100644 --- a/docs/ru/operations/table_engines/mergetree.md +++ b/docs/ru/operations/table_engines/mergetree.md @@ -384,7 +384,7 @@ TTL date_time + INTERVAL 1 MONTH TTL date_time + INTERVAL 15 HOUR ``` -**TTL столбца** +### TTL столбца {#mergetree-column-ttl} Когда срок действия значений в столбце истечет, ClickHouse заменит их значениями по умолчанию для типа данных столбца. Если срок действия всех значений столбцов в части данных истек, ClickHouse удаляет столбец из куска данных в файловой системе. @@ -423,7 +423,7 @@ ALTER TABLE example_table c String TTL d + INTERVAL 1 MONTH; ``` -**TTL таблицы** +### TTL таблицы {#mergetree-table-ttl} Для таблицы можно задать одно выражение для устаревания данных, а также несколько выражений, по срабатывании которых данные переместятся на [некоторый диск или том](#table_engine-mergetree-multiple-volumes). Когда некоторые данные в таблице устаревают, ClickHouse удаляет все соответствующие строки. diff --git a/docs/ru/query_language/alter.md b/docs/ru/query_language/alter.md index 771442d36f5..7736456d960 100644 --- a/docs/ru/query_language/alter.md +++ b/docs/ru/query_language/alter.md @@ -19,7 +19,7 @@ ALTER TABLE [db].name [ON CLUSTER cluster] ADD|DROP|CLEAR|COMMENT|MODIFY COLUMN - [DROP COLUMN](#alter_drop-column) — удаляет столбец; - [CLEAR COLUMN](#alter_clear-column) — сбрасывает все значения в столбце для заданной партиции; - [COMMENT COLUMN](#alter_comment-column) — добавляет комментарий к столбцу; -- [MODIFY COLUMN](#alter_modify-column) — изменяет тип столбца и/или выражение для значения по умолчанию. +- [MODIFY COLUMN](#alter_modify-column) — изменяет тип столбца, выражение для значения по умолчанию и TTL. Подробное описание для каждого действия приведено ниже. @@ -95,10 +95,18 @@ ALTER TABLE visits COMMENT COLUMN browser 'Столбец показывает, #### MODIFY COLUMN {#alter_modify-column} ```sql -MODIFY COLUMN [IF EXISTS] name [type] [default_expr] +MODIFY COLUMN [IF EXISTS] name [type] [default_expr] [TTL] ``` -Изменяет тип столбца `name` на `type` и/или выражение для умолчания на `default_expr`. Если указано `IF EXISTS`, запрос не будет возвращать ошибку, если столбца не существует. +Запрос изменяет следующие свойства столбца `name`: + +- Тип +- Значение по умолчанию +- TTL + + Примеры изменения TTL столбца смотрите в разделе [TTL столбца](../operations/table_engines/mergetree.md#mergetree-column-ttl). + +Если указано `IF EXISTS`, запрос не возвращает ошибку, если столбца не существует. При изменении типа, значения преобразуются так, как если бы к ним была применена функция [toType](functions/type_conversion_functions.md). Если изменяется только выражение для умолчания, запрос не делает никакой сложной работы и выполняется мгновенно. @@ -432,6 +440,14 @@ OPTIMIZE TABLE table_not_partitioned PARTITION tuple() FINAL; Примеры запросов `ALTER ... PARTITION` можно посмотреть в тестах: [`00502_custom_partitioning_local`](https://github.com/ClickHouse/ClickHouse/blob/master/dbms/tests/queries/0_stateless/00502_custom_partitioning_local.sql) и [`00502_custom_partitioning_replicated_zookeeper`](https://github.com/ClickHouse/ClickHouse/blob/master/dbms/tests/queries/0_stateless/00502_custom_partitioning_replicated_zookeeper.sql). +### Манипуляции с TTL таблицы + +Вы можете изменить [TTL для таблицы](../operations/table_engines/mergetree.md#mergetree-table-ttl) запросом следующего вида: + +```sql +ALTER TABLE table-name MODIFY TTL ttl-expression +``` + ### Синхронность запросов ALTER Для нереплицируемых таблиц, все запросы `ALTER` выполняются синхронно. Для реплицируемых таблиц, запрос всего лишь добавляет инструкцию по соответствующим действиям в `ZooKeeper`, а сами действия осуществляются при первой возможности. Но при этом, запрос может ждать завершения выполнения этих действий на всех репликах. diff --git a/docs/zh/operations/table_engines/mergetree.md b/docs/zh/operations/table_engines/mergetree.md index 984f0339aad..c1705029e2e 100644 --- a/docs/zh/operations/table_engines/mergetree.md +++ b/docs/zh/operations/table_engines/mergetree.md @@ -323,7 +323,7 @@ TTL date_time + INTERVAL 1 MONTH TTL date_time + INTERVAL 15 HOUR ``` -**列字段 TTL** +### 列字段 TTL {#mergetree-column-ttl} 当列字段中的值过期时, ClickHouse会将它们替换成数据类型的默认值。如果分区内,某一列的所有值均已过期,则ClickHouse会从文件系统中删除这个分区目录下的列文件。 @@ -362,7 +362,7 @@ ALTER TABLE example_table c String TTL d + INTERVAL 1 MONTH; ``` -**表 TTL** +### 表 TTL {#mergetree-table-ttl} 当表内的数据过期时, ClickHouse会删除所有对应的行。 From 9edea08b6dc074a1d0f93572bf05537488afd0b7 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 10 Feb 2020 18:24:33 +0300 Subject: [PATCH 03/19] Improve parsers of access managing SQL. --- .../InterpreterCreateQuotaQuery.cpp | 6 +- .../InterpreterCreateRowPolicyQuery.cpp | 6 +- .../Interpreters/InterpreterCreateUserQuery.h | 2 +- .../Interpreters/InterpreterGrantQuery.cpp | 7 +- ...InterpreterShowCreateAccessEntityQuery.cpp | 12 +- .../InterpreterShowGrantsQuery.cpp | 2 +- dbms/src/Parsers/ASTCreateQuotaQuery.cpp | 4 +- dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp | 8 +- dbms/src/Parsers/ASTGrantQuery.cpp | 51 ++-- dbms/src/Parsers/ASTGrantQuery.h | 4 +- dbms/src/Parsers/ASTRoleList.cpp | 63 ++-- dbms/src/Parsers/ASTRoleList.h | 8 +- .../ASTShowCreateAccessEntityQuery.cpp | 3 +- dbms/src/Parsers/ASTShowGrantsQuery.cpp | 10 +- dbms/src/Parsers/ParserCreateQuotaQuery.cpp | 39 ++- .../Parsers/ParserCreateRowPolicyQuery.cpp | 79 +++-- dbms/src/Parsers/ParserCreateUserQuery.cpp | 195 ++++++------ .../Parsers/ParserDropAccessEntityQuery.cpp | 75 +++-- dbms/src/Parsers/ParserGrantQuery.cpp | 283 ++++++++++-------- dbms/src/Parsers/ParserGrantQuery.h | 4 +- dbms/src/Parsers/ParserRoleList.cpp | 128 ++++---- dbms/src/Parsers/ParserRoleList.h | 7 + .../ParserShowCreateAccessEntityQuery.cpp | 1 + 23 files changed, 572 insertions(+), 425 deletions(-) diff --git a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp index 7f6b5a392c5..f979c1e0ac8 100644 --- a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp @@ -103,16 +103,16 @@ void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTC const auto & query_roles = *query.roles; /// We keep `roles` sorted. - quota.roles = query_roles.roles; + quota.roles = query_roles.names; if (query_roles.current_user) quota.roles.push_back(context.getClientInfo().current_user); boost::range::sort(quota.roles); quota.roles.erase(std::unique(quota.roles.begin(), quota.roles.end()), quota.roles.end()); - quota.all_roles = query_roles.all_roles; + quota.all_roles = query_roles.all; /// We keep `except_roles` sorted. - quota.except_roles = query_roles.except_roles; + quota.except_roles = query_roles.except_names; if (query_roles.except_current_user) quota.except_roles.push_back(context.getClientInfo().current_user); boost::range::sort(quota.except_roles); diff --git a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp index b207d3540b2..f5749f4eb74 100644 --- a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp @@ -75,16 +75,16 @@ void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & polic const auto & query_roles = *query.roles; /// We keep `roles` sorted. - policy.roles = query_roles.roles; + policy.roles = query_roles.names; if (query_roles.current_user) policy.roles.push_back(context.getClientInfo().current_user); boost::range::sort(policy.roles); policy.roles.erase(std::unique(policy.roles.begin(), policy.roles.end()), policy.roles.end()); - policy.all_roles = query_roles.all_roles; + policy.all_roles = query_roles.all; /// We keep `except_roles` sorted. - policy.except_roles = query_roles.except_roles; + policy.except_roles = query_roles.except_names; if (query_roles.except_current_user) policy.except_roles.push_back(context.getClientInfo().current_user); boost::range::sort(policy.except_roles); diff --git a/dbms/src/Interpreters/InterpreterCreateUserQuery.h b/dbms/src/Interpreters/InterpreterCreateUserQuery.h index 228d796e40f..f040a23a7c2 100644 --- a/dbms/src/Interpreters/InterpreterCreateUserQuery.h +++ b/dbms/src/Interpreters/InterpreterCreateUserQuery.h @@ -18,7 +18,7 @@ public: BlockIO execute() override; private: - void updateUserFromQuery(User & quota, const ASTCreateUserQuery & query); + void updateUserFromQuery(User & user, const ASTCreateUserQuery & query); ASTPtr query_ptr; Context & context; diff --git a/dbms/src/Interpreters/InterpreterGrantQuery.cpp b/dbms/src/Interpreters/InterpreterGrantQuery.cpp index bf09b7cd61f..076bd6f11a1 100644 --- a/dbms/src/Interpreters/InterpreterGrantQuery.cpp +++ b/dbms/src/Interpreters/InterpreterGrantQuery.cpp @@ -16,11 +16,6 @@ BlockIO InterpreterGrantQuery::execute() context.getAccessRights()->checkGrantOption(query.access_rights_elements); using Kind = ASTGrantQuery::Kind; - - if (query.to_roles->all_roles) - throw Exception( - "Cannot " + String((query.kind == Kind::GRANT) ? "GRANT to" : "REVOKE from") + " ALL", ErrorCodes::NOT_IMPLEMENTED); - String current_database = context.getCurrentDatabase(); auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr @@ -47,7 +42,7 @@ BlockIO InterpreterGrantQuery::execute() return updated_user; }; - std::vector ids = access_control.getIDs(query.to_roles->roles); + std::vector ids = access_control.getIDs(query.to_roles->names); if (query.to_roles->current_user) ids.push_back(context.getUserID()); access_control.update(ids, update_func); diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index 036aaa82369..dab3a42554c 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -118,9 +118,9 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShow if (!quota->roles.empty() || quota->all_roles) { auto create_query_roles = std::make_shared(); - create_query_roles->roles = quota->roles; - create_query_roles->all_roles = quota->all_roles; - create_query_roles->except_roles = quota->except_roles; + create_query_roles->names = quota->roles; + create_query_roles->all = quota->all_roles; + create_query_roles->except_names = quota->except_roles; create_query->roles = std::move(create_query_roles); } @@ -152,9 +152,9 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateRowPolicyQuery(const AST if (!policy->roles.empty() || policy->all_roles) { auto create_query_roles = std::make_shared(); - create_query_roles->roles = policy->roles; - create_query_roles->all_roles = policy->all_roles; - create_query_roles->except_roles = policy->except_roles; + create_query_roles->names = policy->roles; + create_query_roles->all = policy->all_roles; + create_query_roles->except_names = policy->except_roles; create_query->roles = std::move(create_query_roles); } diff --git a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp index cb45b6343ed..17761178ef4 100644 --- a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -112,7 +112,7 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show grant_query->kind = kind; grant_query->grant_option = grant_option; grant_query->to_roles = std::make_shared(); - grant_query->to_roles->roles.push_back(user->getName()); + grant_query->to_roles->names.push_back(user->getName()); grant_query->access_rights_elements = elements; res.push_back(std::move(grant_query)); } diff --git a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp index 2814515d61f..205d3c33d18 100644 --- a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp @@ -94,7 +94,7 @@ namespace } } - void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); @@ -137,6 +137,6 @@ void ASTCreateQuotaQuery::formatImpl(const FormatSettings & settings, FormatStat formatAllLimits(all_limits, settings); if (roles) - formatRoles(*roles, settings); + formatToRoles(*roles, settings); } } diff --git a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp index 50e1645f14b..184474753df 100644 --- a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp @@ -19,7 +19,7 @@ namespace } - void formatIsRestrictive(bool is_restrictive, const IAST::FormatSettings & settings) + void formatAsRestrictiveOrPermissive(bool is_restrictive, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " AS " << (is_restrictive ? "RESTRICTIVE" : "PERMISSIVE") << (settings.hilite ? IAST::hilite_none : ""); @@ -112,7 +112,7 @@ namespace } } - void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); @@ -154,11 +154,11 @@ void ASTCreateRowPolicyQuery::formatImpl(const FormatSettings & settings, Format formatRenameTo(new_policy_name, settings); if (is_restrictive) - formatIsRestrictive(*is_restrictive, settings); + formatAsRestrictiveOrPermissive(*is_restrictive, settings); formatMultipleConditions(conditions, alter, settings); if (roles) - formatRoles(*roles, settings); + formatToRoles(*roles, settings); } } diff --git a/dbms/src/Parsers/ASTGrantQuery.cpp b/dbms/src/Parsers/ASTGrantQuery.cpp index e7210c10304..1f3800f100c 100644 --- a/dbms/src/Parsers/ASTGrantQuery.cpp +++ b/dbms/src/Parsers/ASTGrantQuery.cpp @@ -71,6 +71,34 @@ namespace } settings.ostr << ")"; } + + + void formatAccessRightsElements(const AccessRightsElements & elements, const IAST::FormatSettings & settings) + { + bool need_comma = false; + for (const auto & [database_and_table, keyword_to_columns] : prepareTableToAccessMap(elements)) + { + for (const auto & [keyword, columns] : keyword_to_columns) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << keyword << (settings.hilite ? IAST::hilite_none : ""); + formatColumnNames(columns, settings); + } + + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " ON " << (settings.hilite ? IAST::hilite_none : "") << database_and_table; + } + } + + + void formatToRoles(const ASTRoleList & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings) + { + using Kind = ASTGrantQuery::Kind; + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") + << (settings.hilite ? IAST::hilite_none : ""); + to_roles.format(settings); + } } @@ -88,29 +116,14 @@ ASTPtr ASTGrantQuery::clone() const void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { - settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE") - << (settings.hilite ? hilite_none : "") << " "; + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE") + << (settings.hilite ? IAST::hilite_none : "") << " "; if (grant_option && (kind == Kind::REVOKE)) settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : ""); - bool need_comma = false; - for (const auto & [database_and_table, keyword_to_columns] : prepareTableToAccessMap(access_rights_elements)) - { - for (const auto & [keyword, columns] : keyword_to_columns) - { - if (std::exchange(need_comma, true)) - settings.ostr << ", "; - - settings.ostr << (settings.hilite ? hilite_keyword : "") << keyword << (settings.hilite ? hilite_none : ""); - formatColumnNames(columns, settings); - } - - settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "") << database_and_table; - } - - settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") << (settings.hilite ? hilite_none : ""); - to_roles->format(settings); + formatAccessRightsElements(access_rights_elements, settings); + formatToRoles(*to_roles, kind, settings); if (grant_option && (kind == Kind::GRANT)) settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : ""); diff --git a/dbms/src/Parsers/ASTGrantQuery.h b/dbms/src/Parsers/ASTGrantQuery.h index f7eb69b2ac6..2cdf7b7f661 100644 --- a/dbms/src/Parsers/ASTGrantQuery.h +++ b/dbms/src/Parsers/ASTGrantQuery.h @@ -9,8 +9,8 @@ namespace DB class ASTRoleList; -/** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name - * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name +/** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION] + * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} FROM {user_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | CURRENT_USER} [,...] */ class ASTGrantQuery : public IAST { diff --git a/dbms/src/Parsers/ASTRoleList.cpp b/dbms/src/Parsers/ASTRoleList.cpp index 9e0a4fffc36..87e388b9b1c 100644 --- a/dbms/src/Parsers/ASTRoleList.cpp +++ b/dbms/src/Parsers/ASTRoleList.cpp @@ -13,43 +13,46 @@ void ASTRoleList::formatImpl(const FormatSettings & settings, FormatState &, For } bool need_comma = false; - if (current_user) - { - if (std::exchange(need_comma, true)) - settings.ostr << ", "; - settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : ""); - } - - for (auto & role : roles) - { - if (std::exchange(need_comma, true)) - settings.ostr << ", "; - settings.ostr << backQuoteIfNeed(role); - } - - if (all_roles) + if (all) { if (std::exchange(need_comma, true)) settings.ostr << ", "; settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ALL" << (settings.hilite ? IAST::hilite_none : ""); - if (except_current_user || !except_roles.empty()) + } + else + { + for (auto & role : names) { - settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " EXCEPT " << (settings.hilite ? IAST::hilite_none : ""); - need_comma = false; + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + settings.ostr << backQuoteIfNeed(role); + } - if (except_current_user) - { - if (std::exchange(need_comma, true)) - settings.ostr << ", "; - settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : ""); - } + if (current_user) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : ""); + } + } - for (auto & except_role : except_roles) - { - if (std::exchange(need_comma, true)) - settings.ostr << ", "; - settings.ostr << backQuoteIfNeed(except_role); - } + if (except_current_user || !except_names.empty()) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " EXCEPT " << (settings.hilite ? IAST::hilite_none : ""); + need_comma = false; + + for (auto & except_role : except_names) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + settings.ostr << backQuoteIfNeed(except_role); + } + + if (except_current_user) + { + if (std::exchange(need_comma, true)) + settings.ostr << ", "; + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : ""); } } } diff --git a/dbms/src/Parsers/ASTRoleList.h b/dbms/src/Parsers/ASTRoleList.h index 5e8859732c2..daef6124d18 100644 --- a/dbms/src/Parsers/ASTRoleList.h +++ b/dbms/src/Parsers/ASTRoleList.h @@ -10,13 +10,13 @@ namespace DB class ASTRoleList : public IAST { public: - Strings roles; + Strings names; bool current_user = false; - bool all_roles = false; - Strings except_roles; + bool all = false; + Strings except_names; bool except_current_user = false; - bool empty() const { return roles.empty() && !current_user && !all_roles; } + bool empty() const { return names.empty() && !current_user && !all; } String getID(char) const override { return "RoleList"; } ASTPtr clone() const override { return std::make_shared(*this); } diff --git a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp index c52941d4677..4201a733f43 100644 --- a/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ASTShowCreateAccessEntityQuery.cpp @@ -46,7 +46,8 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett << (settings.hilite ? hilite_none : ""); if ((kind == Kind::USER) && current_user) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT_USER" << (settings.hilite ? hilite_none : ""); + { + } else if ((kind == Kind::QUOTA) && current_quota) settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : ""); else if (kind == Kind::ROW_POLICY) diff --git a/dbms/src/Parsers/ASTShowGrantsQuery.cpp b/dbms/src/Parsers/ASTShowGrantsQuery.cpp index c7639630f20..b3cc0cbd386 100644 --- a/dbms/src/Parsers/ASTShowGrantsQuery.cpp +++ b/dbms/src/Parsers/ASTShowGrantsQuery.cpp @@ -18,13 +18,11 @@ ASTPtr ASTShowGrantsQuery::clone() const void ASTShowGrantsQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { - settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW GRANTS FOR " + settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW GRANTS" << (settings.hilite ? hilite_none : ""); - if (current_user) - settings.ostr << (settings.hilite ? hilite_keyword : "") << "CURRENT_USER" - << (settings.hilite ? hilite_none : ""); - else - settings.ostr << backQuoteIfNeed(name); + if (!current_user) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " FOR " << (settings.hilite ? hilite_none : "") + << backQuoteIfNeed(name); } } diff --git a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp index cc5fa4bfbcc..61e7d2f1c52 100644 --- a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp @@ -25,13 +25,10 @@ namespace using ResourceType = Quota::ResourceType; using ResourceAmount = Quota::ResourceAmount; - bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name, bool alter) + bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name) { return IParserBase::wrapParseImpl(pos, [&] { - if (!new_name.empty() || !alter) - return false; - if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) return false; @@ -43,9 +40,6 @@ namespace { return IParserBase::wrapParseImpl(pos, [&] { - if (key_type) - return false; - if (!ParserKeyword{"KEYED BY"}.ignore(pos, expected)) return false; @@ -123,7 +117,7 @@ namespace }); } - bool parseLimits(IParserBase::Pos & pos, Expected & expected, ASTCreateQuotaQuery::Limits & limits, bool alter) + bool parseLimits(IParserBase::Pos & pos, Expected & expected, bool alter, ASTCreateQuotaQuery::Limits & limits) { return IParserBase::wrapParseImpl(pos, [&] { @@ -173,15 +167,19 @@ namespace }); } - bool parseAllLimits(IParserBase::Pos & pos, Expected & expected, std::vector & all_limits, bool alter) + bool parseAllLimits(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector & all_limits) { return IParserBase::wrapParseImpl(pos, [&] { + size_t old_size = all_limits.size(); do { ASTCreateQuotaQuery::Limits limits; - if (!parseLimits(pos, expected, limits, alter)) + if (!parseLimits(pos, expected, alter, limits)) + { + all_limits.resize(old_size); return false; + } all_limits.push_back(limits); } while (ParserToken{TokenType::Comma}.ignore(pos, expected)); @@ -189,7 +187,7 @@ namespace }); } - bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { @@ -239,9 +237,22 @@ bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expe std::vector all_limits; std::shared_ptr roles; - while (parseRenameTo(pos, expected, new_name, alter) || parseKeyType(pos, expected, key_type) - || parseAllLimits(pos, expected, all_limits, alter) || parseRoles(pos, expected, roles)) - ; + while (true) + { + if (alter && new_name.empty() && parseRenameTo(pos, expected, new_name)) + continue; + + if (!key_type && parseKeyType(pos, expected, key_type)) + continue; + + if (parseAllLimits(pos, expected, alter, all_limits)) + continue; + + if (!roles && parseToRoles(pos, expected, roles)) + continue; + + break; + } auto query = std::make_shared(); node = query; diff --git a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp index 2778ddea93f..c1bfab2551b 100644 --- a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp @@ -21,13 +21,10 @@ namespace { using ConditionIndex = RowPolicy::ConditionIndex; - bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name, bool alter) + bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name) { return IParserBase::wrapParseImpl(pos, [&] { - if (!new_policy_name.empty() || !alter) - return false; - if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) return false; @@ -35,46 +32,48 @@ namespace }); } - bool parseIsRestrictive(IParserBase::Pos & pos, Expected & expected, std::optional & is_restrictive) + bool parseAsRestrictiveOrPermissive(IParserBase::Pos & pos, Expected & expected, std::optional & is_restrictive) { return IParserBase::wrapParseImpl(pos, [&] { - if (is_restrictive) - return false; - if (!ParserKeyword{"AS"}.ignore(pos, expected)) return false; if (ParserKeyword{"RESTRICTIVE"}.ignore(pos, expected)) + { is_restrictive = true; - else if (ParserKeyword{"PERMISSIVE"}.ignore(pos, expected)) - is_restrictive = false; - else + return true; + } + + if (!ParserKeyword{"PERMISSIVE"}.ignore(pos, expected)) return false; + is_restrictive = false; return true; }); } bool parseConditionalExpression(IParserBase::Pos & pos, Expected & expected, std::optional & expr) { - if (ParserKeyword("NONE").ignore(pos, expected)) - { - expr = nullptr; - return true; - } - ParserExpression parser; - ASTPtr x; - if (parser.parse(pos, x, expected)) + return IParserBase::wrapParseImpl(pos, [&] { + if (ParserKeyword("NONE").ignore(pos, expected)) + { + expr = nullptr; + return true; + } + + ParserExpression parser; + ASTPtr x; + if (!parser.parse(pos, x, expected)) + return false; + expr = x; return true; - } - expr.reset(); - return false; + }); } - bool parseConditions(IParserBase::Pos & pos, Expected & expected, std::vector> & conditions, bool alter) + bool parseConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) { return IParserBase::wrapParseImpl(pos, [&] { @@ -171,29 +170,32 @@ namespace }); } - bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, std::vector> & conditions, bool alter) + bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector> & conditions) { return IParserBase::wrapParseImpl(pos, [&] { + std::vector> res_conditions; do { - if (!parseConditions(pos, expected, conditions, alter)) + if (!parseConditions(pos, expected, alter, res_conditions)) return false; } while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + conditions = std::move(res_conditions); return true; }); } - bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { - ASTPtr node; - if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, node, expected)) + ASTPtr ast; + if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, ast, expected)) return false; - roles = std::static_pointer_cast(node); + roles = std::static_pointer_cast(ast); return true; }); } @@ -239,9 +241,22 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & std::vector> conditions; std::shared_ptr roles; - while (parseRenameTo(pos, expected, new_policy_name, alter) || parseIsRestrictive(pos, expected, is_restrictive) - || parseMultipleConditions(pos, expected, conditions, alter) || parseRoles(pos, expected, roles)) - ; + while (true) + { + if (alter && new_policy_name.empty() && parseRenameTo(pos, expected, new_policy_name)) + continue; + + if (!is_restrictive && parseAsRestrictiveOrPermissive(pos, expected, is_restrictive)) + continue; + + if (parseMultipleConditions(pos, expected, alter, conditions)) + continue; + + if (!roles && parseToRoles(pos, expected, roles)) + continue; + + break; + } auto query = std::make_shared(); node = query; diff --git a/dbms/src/Parsers/ParserCreateUserQuery.cpp b/dbms/src/Parsers/ParserCreateUserQuery.cpp index f3af04dad98..7b5aa1fa03c 100644 --- a/dbms/src/Parsers/ParserCreateUserQuery.cpp +++ b/dbms/src/Parsers/ParserCreateUserQuery.cpp @@ -24,9 +24,6 @@ namespace { return IParserBase::wrapParseImpl(pos, [&] { - if (!new_name.empty()) - return false; - if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) return false; @@ -35,14 +32,20 @@ namespace } - bool parsePassword(IParserBase::Pos & pos, Expected & expected, String & password) + bool parseByPassword(IParserBase::Pos & pos, Expected & expected, String & password) { - ASTPtr ast; - if (!ParserStringLiteral{}.parse(pos, ast, expected)) - return false; + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"BY"}.ignore(pos, expected)) + return false; - password = ast->as().value.safeGet(); - return true; + ASTPtr ast; + if (!ParserStringLiteral{}.parse(pos, ast, expected)) + return false; + + password = ast->as().value.safeGet(); + return true; + }); } @@ -50,70 +53,79 @@ namespace { return IParserBase::wrapParseImpl(pos, [&] { - if (authentication) - return false; - if (!ParserKeyword{"IDENTIFIED"}.ignore(pos, expected)) return false; - if (ParserKeyword{"WITH"}.ignore(pos, expected)) - { - if (ParserKeyword{"NO_PASSWORD"}.ignore(pos, expected)) - { - authentication = Authentication{Authentication::NO_PASSWORD}; - } - else if (ParserKeyword{"PLAINTEXT_PASSWORD"}.ignore(pos, expected)) - { - String password; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password)) - return false; - authentication = Authentication{Authentication::PLAINTEXT_PASSWORD}; - authentication->setPassword(password); - } - else if (ParserKeyword{"SHA256_PASSWORD"}.ignore(pos, expected)) - { - String password; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password)) - return false; - authentication = Authentication{Authentication::SHA256_PASSWORD}; - authentication->setPassword(password); - } - else if (ParserKeyword{"SHA256_HASH"}.ignore(pos, expected)) - { - String hash; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash)) - return false; - authentication = Authentication{Authentication::SHA256_PASSWORD}; - authentication->setPasswordHashHex(hash); - } - else if (ParserKeyword{"DOUBLE_SHA1_PASSWORD"}.ignore(pos, expected)) - { - String password; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password)) - return false; - authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD}; - authentication->setPassword(password); - } - else if (ParserKeyword{"DOUBLE_SHA1_HASH"}.ignore(pos, expected)) - { - String hash; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash)) - return false; - authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD}; - authentication->setPasswordHashHex(hash); - } - else - return false; - } - else + if (!ParserKeyword{"WITH"}.ignore(pos, expected)) { String password; - if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password)) + if (!parseByPassword(pos, expected, password)) return false; + authentication = Authentication{Authentication::SHA256_PASSWORD}; authentication->setPassword(password); + return true; } + if (ParserKeyword{"PLAINTEXT_PASSWORD"}.ignore(pos, expected)) + { + String password; + if (!parseByPassword(pos, expected, password)) + return false; + + authentication = Authentication{Authentication::PLAINTEXT_PASSWORD}; + authentication->setPassword(password); + return true; + } + + if (ParserKeyword{"SHA256_PASSWORD"}.ignore(pos, expected)) + { + String password; + if (!parseByPassword(pos, expected, password)) + return false; + + authentication = Authentication{Authentication::SHA256_PASSWORD}; + authentication->setPassword(password); + return true; + } + + if (ParserKeyword{"SHA256_HASH"}.ignore(pos, expected)) + { + String hash; + if (!parseByPassword(pos, expected, hash)) + return false; + + authentication = Authentication{Authentication::SHA256_PASSWORD}; + authentication->setPasswordHashHex(hash); + return true; + } + + if (ParserKeyword{"DOUBLE_SHA1_PASSWORD"}.ignore(pos, expected)) + { + String password; + if (!parseByPassword(pos, expected, password)) + return false; + + authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD}; + authentication->setPassword(password); + return true; + } + + if (ParserKeyword{"DOUBLE_SHA1_HASH"}.ignore(pos, expected)) + { + String hash; + if (!parseByPassword(pos, expected, hash)) + return false; + + authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD}; + authentication->setPasswordHashHex(hash); + return true; + } + + if (!ParserKeyword{"NO_PASSWORD"}.ignore(pos, expected)) + return false; + + authentication = Authentication{Authentication::NO_PASSWORD}; return true; }); } @@ -144,13 +156,12 @@ namespace return true; } + AllowedClientHosts new_hosts; do { if (ParserKeyword{"LOCAL"}.ignore(pos, expected)) { - if (!hosts) - hosts.emplace(); - hosts->addLocalHost(); + new_hosts.addLocalHost(); } else if (ParserKeyword{"NAME REGEXP"}.ignore(pos, expected)) { @@ -158,9 +169,7 @@ namespace if (!ParserStringLiteral{}.parse(pos, ast, expected)) return false; - if (!hosts) - hosts.emplace(); - hosts->addNameRegexp(ast->as().value.safeGet()); + new_hosts.addNameRegexp(ast->as().value.safeGet()); } else if (ParserKeyword{"NAME"}.ignore(pos, expected)) { @@ -168,9 +177,7 @@ namespace if (!ParserStringLiteral{}.parse(pos, ast, expected)) return false; - if (!hosts) - hosts.emplace(); - hosts->addName(ast->as().value.safeGet()); + new_hosts.addName(ast->as().value.safeGet()); } else if (ParserKeyword{"IP"}.ignore(pos, expected)) { @@ -178,9 +185,7 @@ namespace if (!ParserStringLiteral{}.parse(pos, ast, expected)) return false; - if (!hosts) - hosts.emplace(); - hosts->addSubnet(ast->as().value.safeGet()); + new_hosts.addSubnet(ast->as().value.safeGet()); } else if (ParserKeyword{"LIKE"}.ignore(pos, expected)) { @@ -188,14 +193,16 @@ namespace if (!ParserStringLiteral{}.parse(pos, ast, expected)) return false; - if (!hosts) - hosts.emplace(); - hosts->addLikePattern(ast->as().value.safeGet()); + new_hosts.addLikePattern(ast->as().value.safeGet()); } else return false; } while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + if (!hosts) + hosts.emplace(); + hosts->add(new_hosts); return true; }); } @@ -205,9 +212,6 @@ namespace { return IParserBase::wrapParseImpl(pos, [&] { - if (profile) - return false; - if (!ParserKeyword{"PROFILE"}.ignore(pos, expected)) return false; @@ -261,13 +265,28 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec std::optional remove_hosts; std::optional profile; - while (parseAuthentication(pos, expected, authentication) - || parseHosts(pos, expected, nullptr, hosts) - || parseProfileName(pos, expected, profile) - || (alter && parseRenameTo(pos, expected, new_name, new_host_pattern)) - || (alter && parseHosts(pos, expected, "ADD", add_hosts)) - || (alter && parseHosts(pos, expected, "REMOVE", remove_hosts))) - ; + while (true) + { + if (!authentication && parseAuthentication(pos, expected, authentication)) + continue; + + if (parseHosts(pos, expected, nullptr, hosts)) + continue; + + if (!profile && parseProfileName(pos, expected, profile)) + continue; + + if (alter) + { + if (new_name.empty() && parseRenameTo(pos, expected, new_name, new_host_pattern)) + continue; + + if (parseHosts(pos, expected, "ADD", add_hosts) || parseHosts(pos, expected, "REMOVE", remove_hosts)) + continue; + } + + break; + } if (!hosts) { diff --git a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp index cad0d6b4217..b0b7aa6f83b 100644 --- a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp @@ -13,47 +13,64 @@ namespace { bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names) { - do + return IParserBase::wrapParseImpl(pos, [&] { - String name; - if (!parseIdentifierOrStringLiteral(pos, expected, name)) - return false; + Strings res_names; + do + { + String name; + if (!parseIdentifierOrStringLiteral(pos, expected, name)) + return false; - names.push_back(std::move(name)); - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - return true; + res_names.push_back(std::move(name)); + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + names = std::move(res_names); + return true; + }); } - bool parseRowPolicyNames(IParserBase::Pos & pos, Expected & expected, std::vector & row_policies_names) + bool parseRowPolicyNames(IParserBase::Pos & pos, Expected & expected, std::vector & names) { - do + return IParserBase::wrapParseImpl(pos, [&] { - Strings policy_names; - if (!parseNames(pos, expected, policy_names)) - return false; - String database, table_name; - if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name)) - return false; - for (const String & policy_name : policy_names) - row_policies_names.push_back({database, table_name, policy_name}); - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - return true; + std::vector res_names; + do + { + Strings policy_names; + if (!parseNames(pos, expected, policy_names)) + return false; + String database, table_name; + if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name)) + return false; + for (const String & policy_name : policy_names) + res_names.push_back({database, table_name, policy_name}); + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + names = std::move(res_names); + return true; + }); } bool parseUserNames(IParserBase::Pos & pos, Expected & expected, Strings & names) { - do + return IParserBase::wrapParseImpl(pos, [&] { - String name; - if (!parseUserName(pos, expected, name)) - return false; + Strings res_names; + do + { + String name; + if (!parseUserName(pos, expected, name)) + return false; - names.push_back(std::move(name)); - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - return true; + res_names.emplace_back(std::move(name)); + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + names = std::move(res_names); + return true; + }); } } diff --git a/dbms/src/Parsers/ParserGrantQuery.cpp b/dbms/src/Parsers/ParserGrantQuery.cpp index 372c800306b..db5c75da290 100644 --- a/dbms/src/Parsers/ParserGrantQuery.cpp +++ b/dbms/src/Parsers/ParserGrantQuery.cpp @@ -12,9 +12,18 @@ namespace DB { namespace { + bool parseRoundBrackets(IParser::Pos & pos, Expected & expected) + { + return IParserBase::wrapParseImpl(pos, [&] + { + return ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected) + && ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected); + }); + } + bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags) { - auto is_one_of_access_type_words = [](IParser::Pos & pos_) + static constexpr auto is_one_of_access_type_words = [](IParser::Pos & pos_) { if (pos_->type != TokenType::BareWord) return false; @@ -24,86 +33,97 @@ namespace return true; }; - if (!is_one_of_access_type_words(pos)) - { - expected.add(pos, "access type"); - return false; - } + expected.add(pos, "access type"); - String str; - do + return IParserBase::wrapParseImpl(pos, [&] { - if (!str.empty()) - str += " "; - std::string_view word{pos->begin, pos->size()}; - str += std::string_view(pos->begin, pos->size()); - ++pos; - } - while (is_one_of_access_type_words(pos)); + if (!is_one_of_access_type_words(pos)) + return false; - if (pos->type == TokenType::OpeningRoundBracket) - { - auto old_pos = pos; - ++pos; - if (pos->type == TokenType::ClosingRoundBracket) + String str; + do { + if (!str.empty()) + str += " "; + std::string_view word{pos->begin, pos->size()}; + str += std::string_view(pos->begin, pos->size()); ++pos; - str += "()"; } - else - pos = old_pos; - } + while (is_one_of_access_type_words(pos)); - access_flags = AccessFlags{str}; - return true; + try + { + access_flags = AccessFlags{str}; + } + catch (...) + { + return false; + } + + parseRoundBrackets(pos, expected); + return true; + }); } bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns) { - if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)) - return false; - - do + return IParserBase::wrapParseImpl(pos, [&] { - ASTPtr column_ast; - if (!ParserIdentifier().parse(pos, column_ast, expected)) + if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)) return false; - columns.push_back(getIdentifierName(column_ast)); - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - return ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected); + Strings res_columns; + do + { + ASTPtr column_ast; + if (!ParserIdentifier().parse(pos, column_ast, expected)) + return false; + res_columns.emplace_back(getIdentifierName(column_ast)); + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected)) + return false; + + columns = std::move(res_columns); + return true; + }); } bool parseDatabaseAndTableNameOrMaybeAsterisks( IParser::Pos & pos, Expected & expected, String & database_name, bool & any_database, String & table_name, bool & any_table) { - ASTPtr ast[2]; - if (ParserToken{TokenType::Asterisk}.ignore(pos, expected)) + return IParserBase::wrapParseImpl(pos, [&] { - if (ParserToken{TokenType::Dot}.ignore(pos, expected)) + ASTPtr ast[2]; + if (ParserToken{TokenType::Asterisk}.ignore(pos, expected)) { - if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected)) - return false; + if (ParserToken{TokenType::Dot}.ignore(pos, expected)) + { + if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected)) + return false; + + /// *.* (any table in any database) + any_database = true; + database_name.clear(); + any_table = true; + table_name.clear(); + return true; + } - /// *.* (any table in any database) - any_database = true; - any_table = true; - return true; - } - else - { /// * (any table in the current database) any_database = false; - database_name = ""; + database_name.clear(); any_table = true; + table_name.clear(); return true; } - } - else if (ParserIdentifier().parse(pos, ast[0], expected)) - { + + if (!ParserIdentifier().parse(pos, ast[0], expected)) + return false; + if (ParserToken{TokenType::Dot}.ignore(pos, expected)) { if (ParserToken{TokenType::Asterisk}.ignore(pos, expected)) @@ -112,31 +132,103 @@ namespace any_database = false; database_name = getIdentifierName(ast[0]); any_table = true; + table_name.clear(); return true; } - else if (ParserIdentifier().parse(pos, ast[1], expected)) + + if (!ParserIdentifier().parse(pos, ast[1], expected)) + return false; + + /// . + any_database = false; + database_name = getIdentifierName(ast[0]); + any_table = false; + table_name = getIdentifierName(ast[1]); + return true; + } + + /// - the current database, specified table + any_database = false; + database_name.clear(); + any_table = false; + table_name = getIdentifierName(ast[0]); + return true; + }); + } + + + bool parseAccessRightsElements(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements) + { + return IParserBase::wrapParseImpl(pos, [&] + { + AccessRightsElements res_elements; + do + { + std::vector> access_and_columns; + do { - /// . - any_database = false; - database_name = getIdentifierName(ast[0]); - any_table = false; - table_name = getIdentifierName(ast[1]); - return true; + AccessFlags access_flags; + if (!parseAccessFlags(pos, expected, access_flags)) + return false; + + Strings columns; + parseColumnNames(pos, expected, columns); + access_and_columns.emplace_back(access_flags, std::move(columns)); } - else + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + if (!ParserKeyword{"ON"}.ignore(pos, expected)) + return false; + + String database_name, table_name; + bool any_database = false, any_table = false; + if (!parseDatabaseAndTableNameOrMaybeAsterisks(pos, expected, database_name, any_database, table_name, any_table)) + return false; + + for (auto & [access_flags, columns] : access_and_columns) + { + AccessRightsElement element; + element.access_flags = access_flags; + element.any_column = columns.empty(); + element.columns = std::move(columns); + element.any_database = any_database; + element.database = database_name; + element.any_table = any_table; + element.table = table_name; + res_elements.emplace_back(std::move(element)); + } + } + while (ParserToken{TokenType::Comma}.ignore(pos, expected)); + + elements = std::move(res_elements); + return true; + }); + } + + + bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) + { + return IParserBase::wrapParseImpl(pos, [&] + { + using Kind = ASTGrantQuery::Kind; + if (kind == Kind::GRANT) + { + if (!ParserKeyword{"TO"}.ignore(pos, expected)) return false; } else { - /// - the current database, specified table - any_database = false; - database_name = ""; - table_name = getIdentifierName(ast[0]); - return true; + if (!ParserKeyword{"FROM"}.ignore(pos, expected)) + return false; } - } - else - return false; + + ASTPtr ast; + if (!ParserRoleList{false, false}.parse(pos, ast, expected)) + return false; + + to_roles = typeid_cast>(ast); + return true; + }); } } @@ -160,56 +252,8 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) } AccessRightsElements elements; - do - { - std::vector> access_and_columns; - do - { - AccessFlags access_flags; - if (!parseAccessFlags(pos, expected, access_flags)) - return false; - - Strings columns; - parseColumnNames(pos, expected, columns); - access_and_columns.emplace_back(access_flags, std::move(columns)); - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - - if (!ParserKeyword{"ON"}.ignore(pos, expected)) - return false; - - String database_name, table_name; - bool any_database = false, any_table = false; - if (!parseDatabaseAndTableNameOrMaybeAsterisks(pos, expected, database_name, any_database, table_name, any_table)) - return false; - - for (auto & [access_flags, columns] : access_and_columns) - { - AccessRightsElement element; - element.access_flags = access_flags; - element.any_column = columns.empty(); - element.columns = std::move(columns); - element.any_database = any_database; - element.database = database_name; - element.any_table = any_table; - element.table = table_name; - elements.emplace_back(std::move(element)); - } - } - while (ParserToken{TokenType::Comma}.ignore(pos, expected)); - - ASTPtr to_roles; - if (kind == Kind::GRANT) - { - if (!ParserKeyword{"TO"}.ignore(pos, expected)) - return false; - } - else - { - if (!ParserKeyword{"FROM"}.ignore(pos, expected)) - return false; - } - if (!ParserRoleList{}.parse(pos, to_roles, expected)) + std::shared_ptr to_roles; + if (!parseAccessRightsElements(pos, expected, elements) && !parseToRoles(pos, expected, kind, to_roles)) return false; if (kind == Kind::GRANT) @@ -218,13 +262,12 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) grant_option = true; } - auto query = std::make_shared(); node = query; query->kind = kind; query->access_rights_elements = std::move(elements); - query->to_roles = std::static_pointer_cast(to_roles); + query->to_roles = std::move(to_roles); query->grant_option = grant_option; return true; diff --git a/dbms/src/Parsers/ParserGrantQuery.h b/dbms/src/Parsers/ParserGrantQuery.h index 183af52cc52..86cf9863b20 100644 --- a/dbms/src/Parsers/ParserGrantQuery.h +++ b/dbms/src/Parsers/ParserGrantQuery.h @@ -6,8 +6,8 @@ namespace DB { /** Parses queries like - * GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name - * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name + * GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION] + * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} FROM {user_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | CURRENT_USER} [,...] */ class ParserGrantQuery : public IParserBase { diff --git a/dbms/src/Parsers/ParserRoleList.cpp b/dbms/src/Parsers/ParserRoleList.cpp index 5caf399ae0a..8cdae1f7bab 100644 --- a/dbms/src/Parsers/ParserRoleList.cpp +++ b/dbms/src/Parsers/ParserRoleList.cpp @@ -7,69 +7,93 @@ namespace DB { +namespace +{ + bool parseRoleListBeforeExcept(IParserBase::Pos & pos, Expected & expected, bool * all, bool * current_user, Strings & names) + { + return IParserBase::wrapParseImpl(pos, [&] + { + bool res_all = false; + bool res_current_user = false; + Strings res_names; + while (true) + { + if (ParserKeyword{"NONE"}.ignore(pos, expected)) + { + } + else if ( + current_user && (ParserKeyword{"CURRENT_USER"}.ignore(pos, expected) || ParserKeyword{"currentUser"}.ignore(pos, expected))) + { + if (ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)) + { + if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected)) + return false; + } + res_current_user = true; + } + else if (all && ParserKeyword{"ALL"}.ignore(pos, expected)) + { + res_all = true; + } + else + { + String name; + if (!parseUserName(pos, expected, name)) + return false; + res_names.push_back(name); + } + + if (!ParserToken{TokenType::Comma}.ignore(pos, expected)) + break; + } + + if (all) + *all = res_all; + if (current_user) + *current_user = res_current_user; + names = std::move(res_names); + return true; + }); + } + + bool parseRoleListExcept(IParserBase::Pos & pos, Expected & expected, bool * except_current_user, Strings & except_names) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"EXCEPT"}.ignore(pos, expected)) + return false; + + return parseRoleListBeforeExcept(pos, expected, nullptr, except_current_user, except_names); + }); + } +} + + +ParserRoleList::ParserRoleList(bool allow_all_, bool allow_current_user_) + : allow_all(allow_all_), allow_current_user(allow_current_user_) {} + bool ParserRoleList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { - Strings roles; + Strings names; bool current_user = false; - bool all_roles = false; - Strings except_roles; + bool all = false; + Strings except_names; bool except_current_user = false; - bool except_mode = false; - while (true) - { - if (ParserKeyword{"NONE"}.ignore(pos, expected)) - { - } - else if (ParserKeyword{"CURRENT_USER"}.ignore(pos, expected) || - ParserKeyword{"currentUser"}.ignore(pos, expected)) - { - if (ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)) - { - if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected)) - return false; - } - if (except_mode && !current_user) - except_current_user = true; - else - current_user = true; - } - else if (ParserKeyword{"ALL"}.ignore(pos, expected)) - { - all_roles = true; - if (ParserKeyword{"EXCEPT"}.ignore(pos, expected)) - { - except_mode = true; - continue; - } - } - else - { - String name; - if (!parseUserName(pos, expected, name)) - return false; - if (except_mode && (boost::range::find(roles, name) == roles.end())) - except_roles.push_back(name); - else - roles.push_back(name); - } + if (!parseRoleListBeforeExcept(pos, expected, (allow_all ? &all : nullptr), (allow_current_user ? ¤t_user : nullptr), names)) + return false; - if (!ParserToken{TokenType::Comma}.ignore(pos, expected)) - break; - } + parseRoleListExcept(pos, expected, (allow_current_user ? &except_current_user : nullptr), except_names); - if (all_roles) - { - current_user = false; - roles.clear(); - } + if (all) + names.clear(); auto result = std::make_shared(); - result->roles = std::move(roles); + result->names = std::move(names); result->current_user = current_user; - result->all_roles = all_roles; - result->except_roles = std::move(except_roles); + result->all = all; + result->except_names = std::move(except_names); result->except_current_user = except_current_user; node = result; return true; diff --git a/dbms/src/Parsers/ParserRoleList.h b/dbms/src/Parsers/ParserRoleList.h index 2913a4953c8..3daa0d7b6ff 100644 --- a/dbms/src/Parsers/ParserRoleList.h +++ b/dbms/src/Parsers/ParserRoleList.h @@ -10,9 +10,16 @@ namespace DB */ class ParserRoleList : public IParserBase { +public: + ParserRoleList(bool allow_all_ = true, bool allow_current_user_ = true); + protected: const char * getName() const override { return "RoleList"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool allow_all; + bool allow_current_user; }; } diff --git a/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp b/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp index 0d4474108b8..d1e6bc45478 100644 --- a/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ParserShowCreateAccessEntityQuery.cpp @@ -68,6 +68,7 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe query->name = std::move(name); query->current_quota = current_quota; + query->current_user = current_user; query->row_policy_name = std::move(row_policy_name); return true; From ed2061db8a74263d0d9785a0efa150eaf0b8cc5e Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Thu, 13 Feb 2020 02:59:49 +0300 Subject: [PATCH 04/19] Better pointers: std::shared_ptr instead of std::shared_ptr, boost::atomic_shared_ptr instead of std::atomic_load/store. --- dbms/src/Access/AccessControlManager.cpp | 18 +++---- dbms/src/Access/AccessControlManager.h | 18 +++---- dbms/src/Access/AccessRightsContext.h | 2 + dbms/src/Access/QuotaContext.cpp | 51 ++++++++++--------- dbms/src/Access/QuotaContext.h | 19 +++---- dbms/src/Access/QuotaContextFactory.cpp | 15 +++--- dbms/src/Access/QuotaContextFactory.h | 6 +-- dbms/src/Access/RowPolicyContext.cpp | 21 ++++---- dbms/src/Access/RowPolicyContext.h | 8 +-- dbms/src/Access/RowPolicyContextFactory.cpp | 5 +- dbms/src/DataStreams/IBlockInputStream.h | 5 +- dbms/src/Interpreters/Context.cpp | 3 +- dbms/src/Interpreters/Context.h | 17 ++++--- .../TreeExecutorBlockInputStream.cpp | 2 +- .../Executors/TreeExecutorBlockInputStream.h | 2 +- dbms/src/Processors/Pipe.cpp | 2 +- dbms/src/Processors/Pipe.h | 2 +- .../Sources/SourceFromInputStream.h | 2 +- .../Processors/Sources/SourceWithProgress.h | 6 +-- .../Transforms/LimitsCheckingTransform.h | 4 +- .../Storages/System/StorageSystemColumns.cpp | 4 +- 21 files changed, 111 insertions(+), 101 deletions(-) diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index d24a7c2fd46..9d64e252ab3 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -89,14 +89,20 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio } -std::shared_ptr AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) +AccessRightsContextPtr AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) const { return std::make_shared(user, client_info, settings, current_database); } -std::shared_ptr AccessControlManager::createQuotaContext( - const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) +RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const String & user_name) const +{ + return row_policy_context_factory->createContext(user_name); +} + + +QuotaContextPtr +AccessControlManager::getQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const { return quota_context_factory->createContext(user_name, address, custom_quota_key); } @@ -107,10 +113,4 @@ std::vector AccessControlManager::getQuotaUsageInfo() const return quota_context_factory->getUsageInfo(); } - -std::shared_ptr AccessControlManager::getRowPolicyContext(const String & user_name) const -{ - return row_policy_context_factory->createContext(user_name); -} - } diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index 915f213c163..64136c9ad52 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -22,13 +22,15 @@ namespace DB { struct User; using UserPtr = std::shared_ptr; +class AccessRightsContext; +using AccessRightsContextPtr = std::shared_ptr; +class RowPolicyContext; +using RowPolicyContextPtr = std::shared_ptr; +class RowPolicyContextFactory; class QuotaContext; +using QuotaContextPtr = std::shared_ptr; class QuotaContextFactory; struct QuotaUsageInfo; -class RowPolicyContext; -class RowPolicyContextFactory; -class AccessRights; -class AccessRightsContext; class ClientInfo; struct Settings; @@ -47,15 +49,13 @@ public: UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; - std::shared_ptr getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database); + AccessRightsContextPtr getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) const; - std::shared_ptr - createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key); + RowPolicyContextPtr getRowPolicyContext(const String & user_name) const; + QuotaContextPtr getQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; std::vector getQuotaUsageInfo() const; - std::shared_ptr getRowPolicyContext(const String & user_name) const; - private: std::unique_ptr quota_context_factory; std::unique_ptr row_policy_context_factory; diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h index 2c87ce57674..a39abbcabba 100644 --- a/dbms/src/Access/AccessRightsContext.h +++ b/dbms/src/Access/AccessRightsContext.h @@ -89,4 +89,6 @@ private: mutable std::mutex mutex; }; +using AccessRightsContextPtr = std::shared_ptr; + } diff --git a/dbms/src/Access/QuotaContext.cpp b/dbms/src/Access/QuotaContext.cpp index 11666e5d4b8..1719db6fbf7 100644 --- a/dbms/src/Access/QuotaContext.cpp +++ b/dbms/src/Access/QuotaContext.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include @@ -171,7 +172,7 @@ QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock:: QuotaContext::QuotaContext() - : atomic_intervals(std::make_shared()) /// Unlimited quota. + : intervals(boost::make_shared()) /// Unlimited quota. { } @@ -188,66 +189,66 @@ QuotaContext::QuotaContext( QuotaContext::~QuotaContext() = default; -void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) +void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) const { used({resource_type, amount}, check_exceeded); } -void QuotaContext::used(const std::pair & resource, bool check_exceeded) +void QuotaContext::used(const std::pair & resource, bool check_exceeded) const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); + auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded); } -void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded) +void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded) const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); + auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded); - Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded); } -void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded) +void QuotaContext::used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded) const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); + auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); - Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded); - Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded); - Impl::used(user_name, *intervals_ptr, resource3.first, resource3.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource3.first, resource3.second, current_time, check_exceeded); } -void QuotaContext::used(const std::vector> & resources, bool check_exceeded) +void QuotaContext::used(const std::vector> & resources, bool check_exceeded) const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); + auto loaded = intervals.load(); auto current_time = std::chrono::system_clock::now(); for (const auto & resource : resources) - Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded); + Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded); } -void QuotaContext::checkExceeded() +void QuotaContext::checkExceeded() const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); - Impl::checkExceeded(user_name, *intervals_ptr, std::chrono::system_clock::now()); + auto loaded = intervals.load(); + Impl::checkExceeded(user_name, *loaded, std::chrono::system_clock::now()); } -void QuotaContext::checkExceeded(ResourceType resource_type) +void QuotaContext::checkExceeded(ResourceType resource_type) const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); - Impl::checkExceeded(user_name, *intervals_ptr, resource_type, std::chrono::system_clock::now()); + auto loaded = intervals.load(); + Impl::checkExceeded(user_name, *loaded, resource_type, std::chrono::system_clock::now()); } QuotaUsageInfo QuotaContext::getUsageInfo() const { - auto intervals_ptr = std::atomic_load(&atomic_intervals); - return intervals_ptr->getUsageInfo(std::chrono::system_clock::now()); + auto loaded = intervals.load(); + return loaded->getUsageInfo(std::chrono::system_clock::now()); } diff --git a/dbms/src/Access/QuotaContext.h b/dbms/src/Access/QuotaContext.h index 122d0df6ee7..99f65ea52b0 100644 --- a/dbms/src/Access/QuotaContext.h +++ b/dbms/src/Access/QuotaContext.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -28,15 +29,15 @@ public: ~QuotaContext(); /// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception. - void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true); - void used(const std::pair & resource, bool check_exceeded = true); - void used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded = true); - void used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded = true); - void used(const std::vector> & resources, bool check_exceeded = true); + void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true) const; + void used(const std::pair & resource, bool check_exceeded = true) const; + void used(const std::pair & resource1, const std::pair & resource2, bool check_exceeded = true) const; + void used(const std::pair & resource1, const std::pair & resource2, const std::pair & resource3, bool check_exceeded = true) const; + void used(const std::vector> & resources, bool check_exceeded = true) const; /// Checks if the quota exceeded. If so, throws an exception. - void checkExceeded(); - void checkExceeded(ResourceType resource_type); + void checkExceeded() const; + void checkExceeded(ResourceType resource_type) const; /// Returns the information about this quota context. QuotaUsageInfo getUsageInfo() const; @@ -78,10 +79,10 @@ private: const String user_name; const Poco::Net::IPAddress address; const String client_key; - std::shared_ptr atomic_intervals; /// atomically changed by QuotaUsageManager + boost::atomic_shared_ptr intervals; /// atomically changed by QuotaUsageManager }; -using QuotaContextPtr = std::shared_ptr; +using QuotaContextPtr = std::shared_ptr; /// The information about a quota context. diff --git a/dbms/src/Access/QuotaContextFactory.cpp b/dbms/src/Access/QuotaContextFactory.cpp index c6ecb947102..75daa25bf37 100644 --- a/dbms/src/Access/QuotaContextFactory.cpp +++ b/dbms/src/Access/QuotaContextFactory.cpp @@ -9,6 +9,7 @@ #include #include #include +#include namespace DB @@ -91,7 +92,7 @@ String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context } -std::shared_ptr QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key) +boost::shared_ptr QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key) { auto it = key_to_intervals.find(key); if (it != key_to_intervals.end()) @@ -107,9 +108,9 @@ void QuotaContextFactory::QuotaInfo::rebuildAllIntervals() } -std::shared_ptr QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key) +boost::shared_ptr QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key) { - auto new_intervals = std::make_shared(); + auto new_intervals = boost::make_shared(); new_intervals->quota_name = quota->getName(); new_intervals->quota_id = quota_id; new_intervals->quota_key = key; @@ -184,7 +185,7 @@ QuotaContextFactory::~QuotaContextFactory() } -std::shared_ptr QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) +QuotaContextPtr QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) { std::lock_guard lock{mutex}; ensureAllQuotasRead(); @@ -266,7 +267,7 @@ void QuotaContextFactory::chooseQuotaForAllContexts() void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr & context) { /// `mutex` is already locked. - std::shared_ptr intervals; + boost::shared_ptr intervals; for (auto & info : all_quotas | boost::adaptors::map_values) { if (info.canUseWithContext(*context)) @@ -278,9 +279,9 @@ void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr(); /// No quota == no limits. + intervals = boost::make_shared(); /// No quota == no limits. - std::atomic_store(&context->atomic_intervals, intervals); + context->intervals.store(intervals); } diff --git a/dbms/src/Access/QuotaContextFactory.h b/dbms/src/Access/QuotaContextFactory.h index 611a25059f6..c12847c4b89 100644 --- a/dbms/src/Access/QuotaContextFactory.h +++ b/dbms/src/Access/QuotaContextFactory.h @@ -34,8 +34,8 @@ private: bool canUseWithContext(const QuotaContext & context) const; String calculateKey(const QuotaContext & context) const; - std::shared_ptr getOrBuildIntervals(const String & key); - std::shared_ptr rebuildIntervals(const String & key); + boost::shared_ptr getOrBuildIntervals(const String & key); + boost::shared_ptr rebuildIntervals(const String & key); void rebuildAllIntervals(); QuotaPtr quota; @@ -43,7 +43,7 @@ private: std::unordered_set roles; bool all_roles = false; std::unordered_set except_roles; - std::unordered_map> key_to_intervals; + std::unordered_map> key_to_intervals; }; void ensureAllQuotasRead(); diff --git a/dbms/src/Access/RowPolicyContext.cpp b/dbms/src/Access/RowPolicyContext.cpp index cb24d0af01b..33166d18997 100644 --- a/dbms/src/Access/RowPolicyContext.cpp +++ b/dbms/src/Access/RowPolicyContext.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -7,12 +8,12 @@ namespace DB { size_t RowPolicyContext::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const { - return std::hash{}(database_and_table_name.first) - std::hash{}(database_and_table_name.second); + return std::hash{}(database_and_table_name.first) - std::hash{}(database_and_table_name.second); } RowPolicyContext::RowPolicyContext() - : atomic_map_of_mixed_conditions(std::make_shared()) + : map_of_mixed_conditions(boost::make_shared()) { } @@ -28,9 +29,9 @@ RowPolicyContext::RowPolicyContext(const String & user_name_) ASTPtr RowPolicyContext::getCondition(const String & database, const String & table_name, ConditionIndex index) const { /// We don't lock `mutex` here. - auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); - auto it = map_of_mixed_conditions->find({database, table_name}); - if (it == map_of_mixed_conditions->end()) + auto loaded = map_of_mixed_conditions.load(); + auto it = loaded->find({database, table_name}); + if (it == loaded->end()) return {}; return it->second.mixed_conditions[index]; } @@ -39,9 +40,9 @@ ASTPtr RowPolicyContext::getCondition(const String & database, const String & ta std::vector RowPolicyContext::getCurrentPolicyIDs() const { /// We don't lock `mutex` here. - auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); + auto loaded = map_of_mixed_conditions.load(); std::vector policy_ids; - for (const auto & mixed_conditions : *map_of_mixed_conditions | boost::adaptors::map_values) + for (const auto & mixed_conditions : *loaded | boost::adaptors::map_values) boost::range::copy(mixed_conditions.policy_ids, std::back_inserter(policy_ids)); return policy_ids; } @@ -50,9 +51,9 @@ std::vector RowPolicyContext::getCurrentPolicyIDs() const std::vector RowPolicyContext::getCurrentPolicyIDs(const String & database, const String & table_name) const { /// We don't lock `mutex` here. - auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); - auto it = map_of_mixed_conditions->find({database, table_name}); - if (it == map_of_mixed_conditions->end()) + auto loaded = map_of_mixed_conditions.load(); + auto it = loaded->find({database, table_name}); + if (it == loaded->end()) return {}; return it->second.policy_ids; } diff --git a/dbms/src/Access/RowPolicyContext.h b/dbms/src/Access/RowPolicyContext.h index 776808f74d7..0d54573f5a4 100644 --- a/dbms/src/Access/RowPolicyContext.h +++ b/dbms/src/Access/RowPolicyContext.h @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include @@ -42,7 +42,7 @@ private: RowPolicyContext(const String & user_name_); /// RowPolicyContext should be created by RowPolicyContextFactory. using DatabaseAndTableName = std::pair; - using DatabaseAndTableNameRef = std::pair; + using DatabaseAndTableNameRef = std::pair; struct Hash { size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const; @@ -58,9 +58,9 @@ private: using MapOfMixedConditions = std::unordered_map; const String user_name; - std::shared_ptr atomic_map_of_mixed_conditions; /// Changed atomically, not protected by `mutex`. + mutable boost::atomic_shared_ptr map_of_mixed_conditions; }; -using RowPolicyContextPtr = std::shared_ptr; +using RowPolicyContextPtr = std::shared_ptr; } diff --git a/dbms/src/Access/RowPolicyContextFactory.cpp b/dbms/src/Access/RowPolicyContextFactory.cpp index 77e5056e206..b23a4d77745 100644 --- a/dbms/src/Access/RowPolicyContextFactory.cpp +++ b/dbms/src/Access/RowPolicyContextFactory.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -295,7 +296,7 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context } } - auto map_of_mixed_conditions = std::make_shared(); + auto map_of_mixed_conditions = boost::make_shared(); for (auto & [database_and_table_name, mixers] : map_of_mixers) { auto database_and_table_name_keeper = std::make_unique(); @@ -309,7 +310,7 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult(); } - std::atomic_store(&context.atomic_map_of_mixed_conditions, std::shared_ptr{map_of_mixed_conditions}); + context.map_of_mixed_conditions.store(map_of_mixed_conditions); } } diff --git a/dbms/src/DataStreams/IBlockInputStream.h b/dbms/src/DataStreams/IBlockInputStream.h index 7ca41551298..d11432d89be 100644 --- a/dbms/src/DataStreams/IBlockInputStream.h +++ b/dbms/src/DataStreams/IBlockInputStream.h @@ -24,6 +24,7 @@ namespace ErrorCodes class ProcessListElement; class QuotaContext; +using QuotaContextPtr = std::shared_ptr; class QueryStatus; struct SortColumnDescription; using SortDescription = std::vector; @@ -220,7 +221,7 @@ public: /** Set the quota. If you set a quota on the amount of raw data, * then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits. */ - virtual void setQuota(const std::shared_ptr & quota_) + virtual void setQuota(const QuotaContextPtr & quota_) { quota = quota_; } @@ -278,7 +279,7 @@ private: LocalLimits limits; - std::shared_ptr quota; /// If nullptr - the quota is not used. + QuotaContextPtr quota; /// If nullptr - the quota is not used. UInt64 prev_elapsed = 0; /// The approximate total number of rows to read. For progress bar. diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index d290f8dc4f4..2def6dfad22 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -742,8 +742,7 @@ void Context::setUser(const String & name, const String & password, const Poco:: }, &subscription_for_user_change.subscription); - quota = getAccessControlManager().createQuotaContext( - client_info.current_user, client_info.current_address.host(), client_info.quota_key); + quota = getAccessControlManager().getQuotaContext(client_info.current_user, client_info.current_address.host(), client_info.quota_key); row_policy = getAccessControlManager().getRowPolicyContext(client_info.current_user); calculateUserSettings(); diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 07fa6b06c1f..e0d65ffc5fe 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -46,8 +46,11 @@ struct ContextShared; class Context; struct User; class AccessRightsContext; -class QuotaContext; +using AccessRightsContextPtr = std::shared_ptr; class RowPolicyContext; +using RowPolicyContextPtr = std::shared_ptr; +class QuotaContext; +using QuotaContextPtr = std::shared_ptr; class AccessFlags; struct AccessRightsElement; class AccessRightsElements; @@ -164,9 +167,9 @@ private: std::shared_ptr user; UUID user_id; SubscriptionForUserChange subscription_for_user_change; - std::shared_ptr access_rights; - std::shared_ptr quota; /// Current quota. By default - empty quota, that have no limits. - std::shared_ptr row_policy; + AccessRightsContextPtr access_rights; + QuotaContextPtr quota; /// Current quota. By default - empty quota, that have no limits. + RowPolicyContextPtr row_policy; String current_database; Settings settings; /// Setting for query execution. std::shared_ptr settings_constraints; @@ -237,7 +240,7 @@ public: AccessControlManager & getAccessControlManager(); const AccessControlManager & getAccessControlManager() const; - std::shared_ptr getAccessRights() const { return std::atomic_load(&access_rights); } + AccessRightsContextPtr getAccessRights() const { return std::atomic_load(&access_rights); } /// Checks access rights. /// Empty database means the current database. @@ -250,8 +253,8 @@ public: void checkAccess(const AccessRightsElement & access) const; void checkAccess(const AccessRightsElements & access) const; - std::shared_ptr getQuota() const { return quota; } - std::shared_ptr getRowPolicy() const { return row_policy; } + QuotaContextPtr getQuota() const { return quota; } + RowPolicyContextPtr getRowPolicy() const { return row_policy; } /// TODO: we need much better code for switching policies, quotas, access rights for initial user /// Switches row policy in case we have initial user in client info diff --git a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp index ee482d62f27..af8ad27bc98 100644 --- a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp +++ b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.cpp @@ -207,7 +207,7 @@ void TreeExecutorBlockInputStream::setLimits(const IBlockInputStream::LocalLimit source->setLimits(limits_); } -void TreeExecutorBlockInputStream::setQuota(const std::shared_ptr & quota_) +void TreeExecutorBlockInputStream::setQuota(const QuotaContextPtr & quota_) { for (auto & source : sources_with_progress) source->setQuota(quota_); diff --git a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h index 8787d3090c1..48c51e65b14 100644 --- a/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h +++ b/dbms/src/Processors/Executors/TreeExecutorBlockInputStream.h @@ -41,7 +41,7 @@ public: void setProgressCallback(const ProgressCallback & callback) final; void setProcessListElement(QueryStatus * elem) final; void setLimits(const LocalLimits & limits_) final; - void setQuota(const std::shared_ptr & quota_) final; + void setQuota(const QuotaContextPtr & quota_) final; void addTotalRowsApprox(size_t value) final; protected: diff --git a/dbms/src/Processors/Pipe.cpp b/dbms/src/Processors/Pipe.cpp index b2aa504f93d..e75914cac01 100644 --- a/dbms/src/Processors/Pipe.cpp +++ b/dbms/src/Processors/Pipe.cpp @@ -102,7 +102,7 @@ void Pipe::setLimits(const ISourceWithProgress::LocalLimits & limits) } } -void Pipe::setQuota(const std::shared_ptr & quota) +void Pipe::setQuota(const QuotaContextPtr & quota) { for (auto & processor : processors) { diff --git a/dbms/src/Processors/Pipe.h b/dbms/src/Processors/Pipe.h index 81494884fe1..20f5eb038a3 100644 --- a/dbms/src/Processors/Pipe.h +++ b/dbms/src/Processors/Pipe.h @@ -40,7 +40,7 @@ public: /// Specify quotas and limits for every ISourceWithProgress. void setLimits(const SourceWithProgress::LocalLimits & limits); - void setQuota(const std::shared_ptr & quota); + void setQuota(const QuotaContextPtr & quota); /// Set information about preferred executor number for sources. void pinSources(size_t executor_number); diff --git a/dbms/src/Processors/Sources/SourceFromInputStream.h b/dbms/src/Processors/Sources/SourceFromInputStream.h index b5704fc521f..780ff1be8a3 100644 --- a/dbms/src/Processors/Sources/SourceFromInputStream.h +++ b/dbms/src/Processors/Sources/SourceFromInputStream.h @@ -25,7 +25,7 @@ public: /// Implementation for methods from ISourceWithProgress. void setLimits(const LocalLimits & limits_) final { stream->setLimits(limits_); } - void setQuota(const std::shared_ptr & quota_) final { stream->setQuota(quota_); } + void setQuota(const QuotaContextPtr & quota_) final { stream->setQuota(quota_); } void setProcessListElement(QueryStatus * elem) final { stream->setProcessListElement(elem); } void setProgressCallback(const ProgressCallback & callback) final { stream->setProgressCallback(callback); } void addTotalRowsApprox(size_t value) final { stream->addTotalRowsApprox(value); } diff --git a/dbms/src/Processors/Sources/SourceWithProgress.h b/dbms/src/Processors/Sources/SourceWithProgress.h index 8a508c59acf..d22a2bf087a 100644 --- a/dbms/src/Processors/Sources/SourceWithProgress.h +++ b/dbms/src/Processors/Sources/SourceWithProgress.h @@ -21,7 +21,7 @@ public: /// Set the quota. If you set a quota on the amount of raw data, /// then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits. - virtual void setQuota(const std::shared_ptr & quota_) = 0; + virtual void setQuota(const QuotaContextPtr & quota_) = 0; /// Set the pointer to the process list item. /// General information about the resources spent on the request will be written into it. @@ -49,7 +49,7 @@ public: using LimitsMode = IBlockInputStream::LimitsMode; void setLimits(const LocalLimits & limits_) final { limits = limits_; } - void setQuota(const std::shared_ptr & quota_) final { quota = quota_; } + void setQuota(const QuotaContextPtr & quota_) final { quota = quota_; } void setProcessListElement(QueryStatus * elem) final { process_list_elem = elem; } void setProgressCallback(const ProgressCallback & callback) final { progress_callback = callback; } void addTotalRowsApprox(size_t value) final { total_rows_approx += value; } @@ -62,7 +62,7 @@ protected: private: LocalLimits limits; - std::shared_ptr quota; + QuotaContextPtr quota; ProgressCallback progress_callback; QueryStatus * process_list_elem = nullptr; diff --git a/dbms/src/Processors/Transforms/LimitsCheckingTransform.h b/dbms/src/Processors/Transforms/LimitsCheckingTransform.h index 8746563ac78..bfc5c338da1 100644 --- a/dbms/src/Processors/Transforms/LimitsCheckingTransform.h +++ b/dbms/src/Processors/Transforms/LimitsCheckingTransform.h @@ -33,7 +33,7 @@ public: String getName() const override { return "LimitsCheckingTransform"; } - void setQuota(const std::shared_ptr & quota_) { quota = quota_; } + void setQuota(const QuotaContextPtr & quota_) { quota = quota_; } protected: void transform(Chunk & chunk) override; @@ -41,7 +41,7 @@ protected: private: LocalLimits limits; - std::shared_ptr quota; + QuotaContextPtr quota; UInt64 prev_elapsed = 0; ProcessorProfileInfo info; diff --git a/dbms/src/Storages/System/StorageSystemColumns.cpp b/dbms/src/Storages/System/StorageSystemColumns.cpp index 0fc85898264..e71aee1930e 100644 --- a/dbms/src/Storages/System/StorageSystemColumns.cpp +++ b/dbms/src/Storages/System/StorageSystemColumns.cpp @@ -63,7 +63,7 @@ public: ColumnPtr databases_, ColumnPtr tables_, Storages storages_, - const std::shared_ptr & access_rights_, + const AccessRightsContextPtr & access_rights_, String query_id_) : SourceWithProgress(header_) , columns_mask(std::move(columns_mask_)), max_block_size(max_block_size_) @@ -231,7 +231,7 @@ private: String query_id; size_t db_table_num = 0; size_t total_tables; - std::shared_ptr access_rights; + AccessRightsContextPtr access_rights; }; From 244c9d53258e50ba71c6a27df3d8e47820a5693e Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 10 Feb 2020 05:26:56 +0300 Subject: [PATCH 05/19] Add class GenericRoleSet to represent a set of IDs of users and roles. --- dbms/src/Access/AccessControlManager.cpp | 11 +- dbms/src/Access/AccessControlManager.h | 4 +- dbms/src/Access/GenericRoleSet.cpp | 212 ++++++++++++++++++ dbms/src/Access/GenericRoleSet.h | 62 +++++ dbms/src/Access/Quota.cpp | 3 +- dbms/src/Access/Quota.h | 5 +- dbms/src/Access/QuotaContext.cpp | 3 +- dbms/src/Access/QuotaContext.h | 3 +- dbms/src/Access/QuotaContextFactory.cpp | 18 +- dbms/src/Access/QuotaContextFactory.h | 6 +- dbms/src/Access/RowPolicy.cpp | 2 +- dbms/src/Access/RowPolicy.h | 7 +- dbms/src/Access/RowPolicyContext.cpp | 4 +- dbms/src/Access/RowPolicyContext.h | 4 +- dbms/src/Access/RowPolicyContextFactory.cpp | 17 +- dbms/src/Access/RowPolicyContextFactory.h | 6 +- dbms/src/Access/UsersConfigAccessStorage.cpp | 14 +- dbms/src/Interpreters/Context.cpp | 9 +- .../InterpreterCreateQuotaQuery.cpp | 34 +-- .../InterpreterCreateQuotaQuery.h | 4 +- .../InterpreterCreateRowPolicyQuery.cpp | 34 +-- .../InterpreterCreateRowPolicyQuery.h | 4 +- .../Interpreters/InterpreterGrantQuery.cpp | 10 +- ...InterpreterShowCreateAccessEntityQuery.cpp | 22 +- .../InterpreterShowGrantsQuery.cpp | 6 +- dbms/src/Parsers/ASTCreateQuotaQuery.cpp | 4 +- dbms/src/Parsers/ASTCreateQuotaQuery.h | 4 +- dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp | 4 +- dbms/src/Parsers/ASTCreateRowPolicyQuery.h | 4 +- ...{ASTRoleList.cpp => ASTGenericRoleSet.cpp} | 4 +- .../{ASTRoleList.h => ASTGenericRoleSet.h} | 9 +- dbms/src/Parsers/ASTGrantQuery.cpp | 4 +- dbms/src/Parsers/ASTGrantQuery.h | 4 +- dbms/src/Parsers/ParserCreateQuotaQuery.cpp | 12 +- .../Parsers/ParserCreateRowPolicyQuery.cpp | 12 +- dbms/src/Parsers/ParserCreateUserQuery.cpp | 2 +- ...rRoleList.cpp => ParserGenericRoleSet.cpp} | 22 +- dbms/src/Parsers/ParserGenericRoleSet.h | 26 +++ dbms/src/Parsers/ParserGrantQuery.cpp | 12 +- dbms/src/Parsers/ParserRoleList.h | 25 --- .../Storages/System/StorageSystemQuotas.cpp | 13 +- dbms/tests/integration/test_quota/test.py | 34 +-- 42 files changed, 451 insertions(+), 248 deletions(-) create mode 100644 dbms/src/Access/GenericRoleSet.cpp create mode 100644 dbms/src/Access/GenericRoleSet.h rename dbms/src/Parsers/{ASTRoleList.cpp => ASTGenericRoleSet.cpp} (92%) rename dbms/src/Parsers/{ASTRoleList.h => ASTGenericRoleSet.h} (53%) rename dbms/src/Parsers/{ParserRoleList.cpp => ParserGenericRoleSet.cpp} (71%) create mode 100644 dbms/src/Parsers/ParserGenericRoleSet.h delete mode 100644 dbms/src/Parsers/ParserRoleList.h diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index 9d64e252ab3..c2eea0cc605 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -95,19 +95,18 @@ AccessRightsContextPtr AccessControlManager::getAccessRightsContext(const UserPt } -RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const String & user_name) const +RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const UUID & user_id) const { - return row_policy_context_factory->createContext(user_name); + return row_policy_context_factory->createContext(user_id); } -QuotaContextPtr -AccessControlManager::getQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const +QuotaContextPtr AccessControlManager::getQuotaContext( + const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const { - return quota_context_factory->createContext(user_name, address, custom_quota_key); + return quota_context_factory->createContext(user_id, user_name, address, custom_quota_key); } - std::vector AccessControlManager::getQuotaUsageInfo() const { return quota_context_factory->getUsageInfo(); diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index 64136c9ad52..dd9a8285aeb 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -51,9 +51,9 @@ public: AccessRightsContextPtr getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) const; - RowPolicyContextPtr getRowPolicyContext(const String & user_name) const; + RowPolicyContextPtr getRowPolicyContext(const UUID & user_id) const; - QuotaContextPtr getQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; + QuotaContextPtr getQuotaContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; std::vector getQuotaUsageInfo() const; private: diff --git a/dbms/src/Access/GenericRoleSet.cpp b/dbms/src/Access/GenericRoleSet.cpp new file mode 100644 index 00000000000..ff142a36a97 --- /dev/null +++ b/dbms/src/Access/GenericRoleSet.cpp @@ -0,0 +1,212 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +GenericRoleSet::GenericRoleSet() = default; +GenericRoleSet::GenericRoleSet(const GenericRoleSet & src) = default; +GenericRoleSet & GenericRoleSet::operator =(const GenericRoleSet & src) = default; +GenericRoleSet::GenericRoleSet(GenericRoleSet && src) = default; +GenericRoleSet & GenericRoleSet::operator =(GenericRoleSet && src) = default; + + +GenericRoleSet::GenericRoleSet(AllTag) +{ + all = true; +} + +GenericRoleSet::GenericRoleSet(const UUID & id) +{ + add(id); +} + + +GenericRoleSet::GenericRoleSet(const std::vector & ids_) +{ + add(ids_); +} + + +GenericRoleSet::GenericRoleSet(const boost::container::flat_set & ids_) +{ + add(ids_); +} + + +GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const std::optional & current_user_id) +{ + all = ast.all; + + if (!ast.names.empty() && !all) + { + ids.reserve(ast.names.size()); + for (const String & name : ast.names) + { + auto id = manager.getID(name); + ids.insert(id); + } + } + + if (ast.current_user && !all) + { + if (!current_user_id) + throw Exception("Current user is unknown", ErrorCodes::LOGICAL_ERROR); + ids.insert(*current_user_id); + } + + if (!ast.except_names.empty()) + { + except_ids.reserve(ast.except_names.size()); + for (const String & except_name : ast.except_names) + { + auto except_id = manager.getID(except_name); + except_ids.insert(except_id); + } + } + + if (ast.except_current_user) + { + if (!current_user_id) + throw Exception("Current user is unknown", ErrorCodes::LOGICAL_ERROR); + except_ids.insert(*current_user_id); + } + + for (const UUID & except_id : except_ids) + ids.erase(except_id); +} + +std::shared_ptr GenericRoleSet::toAST(const AccessControlManager & manager) const +{ + auto ast = std::make_shared(); + ast->all = all; + + if (!ids.empty()) + { + ast->names.reserve(ids.size()); + for (const UUID & id : ids) + { + auto name = manager.tryReadName(id); + if (name) + ast->names.emplace_back(std::move(*name)); + } + boost::range::sort(ast->names); + } + + if (!except_ids.empty()) + { + ast->except_names.reserve(except_ids.size()); + for (const UUID & except_id : except_ids) + { + auto except_name = manager.tryReadName(except_id); + if (except_name) + ast->except_names.emplace_back(std::move(*except_name)); + } + boost::range::sort(ast->except_names); + } + + return ast; +} + + +String GenericRoleSet::toString(const AccessControlManager & manager) const +{ + auto ast = toAST(manager); + return serializeAST(*ast); +} + + +Strings GenericRoleSet::toStrings(const AccessControlManager & manager) const +{ + if (all || !except_ids.empty()) + return {toString(manager)}; + + Strings names; + names.reserve(ids.size()); + for (const UUID & id : ids) + { + auto name = manager.tryReadName(id); + if (name) + names.emplace_back(std::move(*name)); + } + boost::range::sort(names); + return names; +} + + +bool GenericRoleSet::empty() const +{ + return ids.empty() && !all; +} + + +void GenericRoleSet::clear() +{ + ids.clear(); + all = false; + except_ids.clear(); +} + + +void GenericRoleSet::add(const UUID & id) +{ + ids.insert(id); +} + + +void GenericRoleSet::add(const std::vector & ids_) +{ + for (const auto & id : ids_) + add(id); +} + + +void GenericRoleSet::add(const boost::container::flat_set & ids_) +{ + for (const auto & id : ids_) + add(id); +} + + +bool GenericRoleSet::match(const UUID & user_id) const +{ + return (all || ids.contains(user_id)) && !except_ids.contains(user_id); +} + + +std::vector GenericRoleSet::getMatchingIDs() const +{ + if (all) + throw Exception("getAllMatchingIDs() can't get ALL ids", ErrorCodes::LOGICAL_ERROR); + std::vector res; + boost::range::set_difference(ids, except_ids, std::back_inserter(res)); + return res; +} + + +std::vector GenericRoleSet::getMatchingUsers(const AccessControlManager & manager) const +{ + if (!all) + return getMatchingIDs(); + + std::vector res; + for (const UUID & id : manager.findAll()) + { + if (match(id)) + res.push_back(id); + } + return res; +} + + +bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs) +{ + return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids); +} + +} diff --git a/dbms/src/Access/GenericRoleSet.h b/dbms/src/Access/GenericRoleSet.h new file mode 100644 index 00000000000..b3f39a05bd4 --- /dev/null +++ b/dbms/src/Access/GenericRoleSet.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ +class ASTGenericRoleSet; +class AccessControlManager; + + +/// Represents a set of users/roles like +/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] +/// Similar to ASTGenericRoleSet, but with IDs instead of names. +struct GenericRoleSet +{ + GenericRoleSet(); + GenericRoleSet(const GenericRoleSet & src); + GenericRoleSet & operator =(const GenericRoleSet & src); + GenericRoleSet(GenericRoleSet && src); + GenericRoleSet & operator =(GenericRoleSet && src); + + struct AllTag {}; + GenericRoleSet(AllTag); + + GenericRoleSet(const UUID & id); + GenericRoleSet(const std::vector & ids_); + GenericRoleSet(const boost::container::flat_set & ids_); + + GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const std::optional & current_user_id = {}); + std::shared_ptr toAST(const AccessControlManager & manager) const; + + String toString(const AccessControlManager & manager) const; + Strings toStrings(const AccessControlManager & manager) const; + + bool empty() const; + void clear(); + void add(const UUID & id); + void add(const std::vector & ids_); + void add(const boost::container::flat_set & ids_); + + /// Checks if a specified ID matches this GenericRoleSet. + bool match(const UUID & id) const; + + /// Returns a list of matching IDs. The function must not be called if `all` == `true`. + std::vector getMatchingIDs() const; + + /// Returns a list of matching users. + std::vector getMatchingUsers(const AccessControlManager & manager) const; + + friend bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs); + friend bool operator !=(const GenericRoleSet & lhs, const GenericRoleSet & rhs) { return !(lhs == rhs); } + + boost::container::flat_set ids; + bool all = false; + boost::container::flat_set except_ids; +}; + +} diff --git a/dbms/src/Access/Quota.cpp b/dbms/src/Access/Quota.cpp index d178307ca51..d9e9e0b35fc 100644 --- a/dbms/src/Access/Quota.cpp +++ b/dbms/src/Access/Quota.cpp @@ -23,8 +23,7 @@ bool Quota::equal(const IAccessEntity & other) const if (!IAccessEntity::equal(other)) return false; const auto & other_quota = typeid_cast(other); - return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles) - && (all_roles == other_quota.all_roles) && (except_roles == other_quota.except_roles); + return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles); } diff --git a/dbms/src/Access/Quota.h b/dbms/src/Access/Quota.h index 716bccbe1ff..4778b119d1e 100644 --- a/dbms/src/Access/Quota.h +++ b/dbms/src/Access/Quota.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -63,9 +64,7 @@ struct Quota : public IAccessEntity KeyType key_type = KeyType::NONE; /// Which roles or users should use this quota. - Strings roles; - bool all_roles = false; - Strings except_roles; + GenericRoleSet roles; bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } diff --git a/dbms/src/Access/QuotaContext.cpp b/dbms/src/Access/QuotaContext.cpp index 1719db6fbf7..775e3a46cbe 100644 --- a/dbms/src/Access/QuotaContext.cpp +++ b/dbms/src/Access/QuotaContext.cpp @@ -178,10 +178,11 @@ QuotaContext::QuotaContext() QuotaContext::QuotaContext( + const UUID & user_id_, const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_) - : user_name(user_name_), address(address_), client_key(client_key_) + : user_id(user_id_), user_name(user_name_), address(address_), client_key(client_key_) { } diff --git a/dbms/src/Access/QuotaContext.h b/dbms/src/Access/QuotaContext.h index 99f65ea52b0..be3a36ae3eb 100644 --- a/dbms/src/Access/QuotaContext.h +++ b/dbms/src/Access/QuotaContext.h @@ -47,7 +47,7 @@ private: friend struct ext::shared_ptr_helper; /// Instances of this class are created by QuotaContextFactory. - QuotaContext(const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_); + QuotaContext(const UUID & user_id_, const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_); static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; @@ -76,6 +76,7 @@ private: struct Impl; + const UUID user_id; const String user_name; const Poco::Net::IPAddress address; const String client_key; diff --git a/dbms/src/Access/QuotaContextFactory.cpp b/dbms/src/Access/QuotaContextFactory.cpp index 75daa25bf37..2e828b148ae 100644 --- a/dbms/src/Access/QuotaContextFactory.cpp +++ b/dbms/src/Access/QuotaContextFactory.cpp @@ -35,24 +35,14 @@ void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUI { quota = quota_; quota_id = quota_id_; - - boost::range::copy(quota->roles, std::inserter(roles, roles.end())); - all_roles = quota->all_roles; - boost::range::copy(quota->except_roles, std::inserter(except_roles, except_roles.end())); - + roles = "a->roles; rebuildAllIntervals(); } bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const { - if (roles.count(context.user_name)) - return true; - - if (all_roles && !except_roles.count(context.user_name)) - return true; - - return false; + return roles->match(context.user_id); } @@ -185,11 +175,11 @@ QuotaContextFactory::~QuotaContextFactory() } -QuotaContextPtr QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) +QuotaContextPtr QuotaContextFactory::createContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) { std::lock_guard lock{mutex}; ensureAllQuotasRead(); - auto context = ext::shared_ptr_helper::create(user_name, address, client_key); + auto context = ext::shared_ptr_helper::create(user_id, user_name, address, client_key); contexts.push_back(context); chooseQuotaForContext(context); return context; diff --git a/dbms/src/Access/QuotaContextFactory.h b/dbms/src/Access/QuotaContextFactory.h index c12847c4b89..6d9fdede833 100644 --- a/dbms/src/Access/QuotaContextFactory.h +++ b/dbms/src/Access/QuotaContextFactory.h @@ -20,7 +20,7 @@ public: QuotaContextFactory(const AccessControlManager & access_control_manager_); ~QuotaContextFactory(); - QuotaContextPtr createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key); + QuotaContextPtr createContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & client_key); std::vector getUsageInfo() const; private: @@ -40,9 +40,7 @@ private: QuotaPtr quota; UUID quota_id; - std::unordered_set roles; - bool all_roles = false; - std::unordered_set except_roles; + const GenericRoleSet * roles = nullptr; std::unordered_map> key_to_intervals; }; diff --git a/dbms/src/Access/RowPolicy.cpp b/dbms/src/Access/RowPolicy.cpp index 391303e46a2..d5a28d14bb8 100644 --- a/dbms/src/Access/RowPolicy.cpp +++ b/dbms/src/Access/RowPolicy.cpp @@ -77,7 +77,7 @@ bool RowPolicy::equal(const IAccessEntity & other) const const auto & other_policy = typeid_cast(other); return (database == other_policy.database) && (table_name == other_policy.table_name) && (policy_name == other_policy.policy_name) && boost::range::equal(conditions, other_policy.conditions) && restrictive == other_policy.restrictive - && (roles == other_policy.roles) && (all_roles == other_policy.all_roles) && (except_roles == other_policy.except_roles); + && (roles == other_policy.roles); } diff --git a/dbms/src/Access/RowPolicy.h b/dbms/src/Access/RowPolicy.h index 22681b8875e..6bc51a2481c 100644 --- a/dbms/src/Access/RowPolicy.h +++ b/dbms/src/Access/RowPolicy.h @@ -1,6 +1,7 @@ #pragma once #include +#include namespace DB @@ -65,10 +66,8 @@ struct RowPolicy : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } - /// Which roles or users should use this quota. - Strings roles; - bool all_roles = false; - Strings except_roles; + /// Which roles or users should use this row policy. + GenericRoleSet roles; private: String database; diff --git a/dbms/src/Access/RowPolicyContext.cpp b/dbms/src/Access/RowPolicyContext.cpp index 33166d18997..2b94ac6a3f6 100644 --- a/dbms/src/Access/RowPolicyContext.cpp +++ b/dbms/src/Access/RowPolicyContext.cpp @@ -21,8 +21,8 @@ RowPolicyContext::RowPolicyContext() RowPolicyContext::~RowPolicyContext() = default; -RowPolicyContext::RowPolicyContext(const String & user_name_) - : user_name(user_name_) +RowPolicyContext::RowPolicyContext(const UUID & user_id_) + : user_id(user_id_) {} diff --git a/dbms/src/Access/RowPolicyContext.h b/dbms/src/Access/RowPolicyContext.h index 0d54573f5a4..631ec4f020e 100644 --- a/dbms/src/Access/RowPolicyContext.h +++ b/dbms/src/Access/RowPolicyContext.h @@ -39,7 +39,7 @@ public: private: friend class RowPolicyContextFactory; friend struct ext::shared_ptr_helper; - RowPolicyContext(const String & user_name_); /// RowPolicyContext should be created by RowPolicyContextFactory. + RowPolicyContext(const UUID & user_id_); /// RowPolicyContext should be created by RowPolicyContextFactory. using DatabaseAndTableName = std::pair; using DatabaseAndTableNameRef = std::pair; @@ -57,7 +57,7 @@ private: }; using MapOfMixedConditions = std::unordered_map; - const String user_name; + const UUID user_id; mutable boost::atomic_shared_ptr map_of_mixed_conditions; }; diff --git a/dbms/src/Access/RowPolicyContextFactory.cpp b/dbms/src/Access/RowPolicyContextFactory.cpp index b23a4d77745..e6a6804bd77 100644 --- a/dbms/src/Access/RowPolicyContextFactory.cpp +++ b/dbms/src/Access/RowPolicyContextFactory.cpp @@ -130,10 +130,7 @@ namespace void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_) { policy = policy_; - - boost::range::copy(policy->roles, std::inserter(roles, roles.end())); - all_roles = policy->all_roles; - boost::range::copy(policy->except_roles, std::inserter(except_roles, except_roles.end())); + roles = &policy->roles; for (auto index : ext::range_with_static_cast(0, MAX_CONDITION_INDEX)) { @@ -170,13 +167,7 @@ void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_ bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const { - if (roles.count(context.user_name)) - return true; - - if (all_roles && !except_roles.count(context.user_name)) - return true; - - return false; + return roles->match(context.user_id); } @@ -188,11 +179,11 @@ RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & ac RowPolicyContextFactory::~RowPolicyContextFactory() = default; -RowPolicyContextPtr RowPolicyContextFactory::createContext(const String & user_name) +RowPolicyContextPtr RowPolicyContextFactory::createContext(const UUID & user_id) { std::lock_guard lock{mutex}; ensureAllRowPoliciesRead(); - auto context = ext::shared_ptr_helper::create(user_name); + auto context = ext::shared_ptr_helper::create(user_id); contexts.push_back(context); mixConditionsForContext(*context); return context; diff --git a/dbms/src/Access/RowPolicyContextFactory.h b/dbms/src/Access/RowPolicyContextFactory.h index b04244a1d99..c393a75285b 100644 --- a/dbms/src/Access/RowPolicyContextFactory.h +++ b/dbms/src/Access/RowPolicyContextFactory.h @@ -19,7 +19,7 @@ public: RowPolicyContextFactory(const AccessControlManager & access_control_manager_); ~RowPolicyContextFactory(); - RowPolicyContextPtr createContext(const String & user_name); + RowPolicyContextPtr createContext(const UUID & user_id); private: using ParsedConditions = RowPolicyContext::ParsedConditions; @@ -31,9 +31,7 @@ private: bool canUseWithContext(const RowPolicyContext & context) const; RowPolicyPtr policy; - std::unordered_set roles; - bool all_roles = false; - std::unordered_set except_roles; + const GenericRoleSet * roles = nullptr; ParsedConditions parsed_conditions; }; diff --git a/dbms/src/Access/UsersConfigAccessStorage.cpp b/dbms/src/Access/UsersConfigAccessStorage.cpp index e71b2c27fa5..dfda2f35035 100644 --- a/dbms/src/Access/UsersConfigAccessStorage.cpp +++ b/dbms/src/Access/UsersConfigAccessStorage.cpp @@ -183,7 +183,7 @@ namespace } - QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const Strings & user_names) + QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const std::vector & user_ids) { auto quota = std::make_shared(); quota->setName(quota_name); @@ -225,7 +225,7 @@ namespace limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED)); } - quota->roles = user_names; + quota->roles.add(user_ids); return quota; } @@ -235,11 +235,11 @@ namespace { Poco::Util::AbstractConfiguration::Keys user_names; config.keys("users", user_names); - std::unordered_map quota_to_user_names; + std::unordered_map> quota_to_user_ids; for (const auto & user_name : user_names) { if (config.has("users." + user_name + ".quota")) - quota_to_user_names[config.getString("users." + user_name + ".quota")].push_back(user_name); + quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(typeid(User), user_name)); } Poco::Util::AbstractConfiguration::Keys quota_names; @@ -250,8 +250,8 @@ namespace { try { - auto it = quota_to_user_names.find(quota_name); - const Strings quota_users = (it != quota_to_user_names.end()) ? std::move(it->second) : Strings{}; + auto it = quota_to_user_ids.find(quota_name); + const std::vector & quota_users = (it != quota_to_user_ids.end()) ? std::move(it->second) : std::vector{}; quotas.push_back(parseQuota(config, quota_name, quota_users)); } catch (...) @@ -307,7 +307,7 @@ namespace auto policy = std::make_shared(); policy->setFullName(database, table_name, user_name); policy->conditions[RowPolicy::SELECT_FILTER] = config.getString(filter_config); - policy->roles.push_back(user_name); + policy->roles.add(generateID(typeid(User), user_name)); policies.push_back(policy); } catch (...) diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 2def6dfad22..f3057bf2d02 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -652,7 +652,9 @@ void Context::checkAccess(const AccessRightsElements & access) const { return ch void Context::switchRowPolicy() { - row_policy = getAccessControlManager().getRowPolicyContext(client_info.initial_user); + auto initial_user_id = getAccessControlManager().find(client_info.initial_user); + if (initial_user_id) + row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); } void Context::setUsersConfig(const ConfigurationPtr & config) @@ -742,13 +744,14 @@ void Context::setUser(const String & name, const String & password, const Poco:: }, &subscription_for_user_change.subscription); - quota = getAccessControlManager().getQuotaContext(client_info.current_user, client_info.current_address.host(), client_info.quota_key); - row_policy = getAccessControlManager().getRowPolicyContext(client_info.current_user); + quota = getAccessControlManager().getQuotaContext(user_id, name, address.host(), quota_key); + row_policy = getAccessControlManager().getRowPolicyContext(user_id); calculateUserSettings(); calculateAccessRights(); } + void Context::addDependencyUnsafe(const StorageID & from, const StorageID & where) { shared->view_dependencies[from].insert(where); diff --git a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp index f979c1e0ac8..fed8783490d 100644 --- a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include @@ -18,12 +18,16 @@ BlockIO InterpreterCreateQuotaQuery::execute() auto & access_control = context.getAccessControlManager(); context.checkAccess(query.alter ? AccessType::ALTER_QUOTA : AccessType::CREATE_QUOTA); + std::optional roles_from_query; + if (query.roles) + roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()}; + if (query.alter) { auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto updated_quota = typeid_cast>(entity->clone()); - updateQuotaFromQuery(*updated_quota, query); + updateQuotaFromQuery(*updated_quota, query, roles_from_query); return updated_quota; }; if (query.if_exists) @@ -37,7 +41,7 @@ BlockIO InterpreterCreateQuotaQuery::execute() else { auto new_quota = std::make_shared(); - updateQuotaFromQuery(*new_quota, query); + updateQuotaFromQuery(*new_quota, query, roles_from_query); if (query.if_not_exists) access_control.tryInsert(new_quota); @@ -51,7 +55,7 @@ BlockIO InterpreterCreateQuotaQuery::execute() } -void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query) +void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional & roles_from_query) { if (query.alter) { @@ -98,25 +102,7 @@ void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTC } } - if (query.roles) - { - const auto & query_roles = *query.roles; - - /// We keep `roles` sorted. - quota.roles = query_roles.names; - if (query_roles.current_user) - quota.roles.push_back(context.getClientInfo().current_user); - boost::range::sort(quota.roles); - quota.roles.erase(std::unique(quota.roles.begin(), quota.roles.end()), quota.roles.end()); - - quota.all_roles = query_roles.all; - - /// We keep `except_roles` sorted. - quota.except_roles = query_roles.except_names; - if (query_roles.except_current_user) - quota.except_roles.push_back(context.getClientInfo().current_user); - boost::range::sort(quota.except_roles); - quota.except_roles.erase(std::unique(quota.except_roles.begin(), quota.except_roles.end()), quota.except_roles.end()); - } + if (roles_from_query) + quota.roles = *roles_from_query; } } diff --git a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.h b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.h index bbf91bbe1d3..6508ace0bf0 100644 --- a/dbms/src/Interpreters/InterpreterCreateQuotaQuery.h +++ b/dbms/src/Interpreters/InterpreterCreateQuotaQuery.h @@ -2,12 +2,14 @@ #include #include +#include namespace DB { class ASTCreateQuotaQuery; struct Quota; +struct GenericRoleSet; class InterpreterCreateQuotaQuery : public IInterpreter @@ -21,7 +23,7 @@ public: bool ignoreLimits() const override { return true; } private: - void updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query); + void updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional & roles_from_query); ASTPtr query_ptr; Context & context; diff --git a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp index f5749f4eb74..09480e66e15 100644 --- a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.cpp @@ -1,6 +1,6 @@ #include #include -#include +#include #include #include #include @@ -16,12 +16,16 @@ BlockIO InterpreterCreateRowPolicyQuery::execute() auto & access_control = context.getAccessControlManager(); context.checkAccess(query.alter ? AccessType::ALTER_POLICY : AccessType::CREATE_POLICY); + std::optional roles_from_query; + if (query.roles) + roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()}; + if (query.alter) { auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto updated_policy = typeid_cast>(entity->clone()); - updateRowPolicyFromQuery(*updated_policy, query); + updateRowPolicyFromQuery(*updated_policy, query, roles_from_query); return updated_policy; }; String full_name = query.name_parts.getFullName(context); @@ -36,7 +40,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute() else { auto new_policy = std::make_shared(); - updateRowPolicyFromQuery(*new_policy, query); + updateRowPolicyFromQuery(*new_policy, query, roles_from_query); if (query.if_not_exists) access_control.tryInsert(new_policy); @@ -50,7 +54,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute() } -void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query) +void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query, const std::optional & roles_from_query) { if (query.alter) { @@ -70,25 +74,7 @@ void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & polic for (const auto & [index, condition] : query.conditions) policy.conditions[index] = condition ? serializeAST(*condition) : String{}; - if (query.roles) - { - const auto & query_roles = *query.roles; - - /// We keep `roles` sorted. - policy.roles = query_roles.names; - if (query_roles.current_user) - policy.roles.push_back(context.getClientInfo().current_user); - boost::range::sort(policy.roles); - policy.roles.erase(std::unique(policy.roles.begin(), policy.roles.end()), policy.roles.end()); - - policy.all_roles = query_roles.all; - - /// We keep `except_roles` sorted. - policy.except_roles = query_roles.except_names; - if (query_roles.except_current_user) - policy.except_roles.push_back(context.getClientInfo().current_user); - boost::range::sort(policy.except_roles); - policy.except_roles.erase(std::unique(policy.except_roles.begin(), policy.except_roles.end()), policy.except_roles.end()); - } + if (roles_from_query) + policy.roles = *roles_from_query; } } diff --git a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.h b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.h index e7ee47dbe81..283a302ab7d 100644 --- a/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.h +++ b/dbms/src/Interpreters/InterpreterCreateRowPolicyQuery.h @@ -2,12 +2,14 @@ #include #include +#include namespace DB { class ASTCreateRowPolicyQuery; struct RowPolicy; +struct GenericRoleSet; class InterpreterCreateRowPolicyQuery : public IInterpreter @@ -18,7 +20,7 @@ public: BlockIO execute() override; private: - void updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query); + void updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query, const std::optional & roles_from_query); ASTPtr query_ptr; Context & context; diff --git a/dbms/src/Interpreters/InterpreterGrantQuery.cpp b/dbms/src/Interpreters/InterpreterGrantQuery.cpp index 076bd6f11a1..58bb104de6a 100644 --- a/dbms/src/Interpreters/InterpreterGrantQuery.cpp +++ b/dbms/src/Interpreters/InterpreterGrantQuery.cpp @@ -1,9 +1,9 @@ #include #include -#include #include #include #include +#include #include @@ -16,7 +16,9 @@ BlockIO InterpreterGrantQuery::execute() context.getAccessRights()->checkGrantOption(query.access_rights_elements); using Kind = ASTGrantQuery::Kind; + std::vector to_roles = GenericRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingUsers(access_control); String current_database = context.getCurrentDatabase(); + using Kind = ASTGrantQuery::Kind; auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { @@ -42,10 +44,8 @@ BlockIO InterpreterGrantQuery::execute() return updated_user; }; - std::vector ids = access_control.getIDs(query.to_roles->names); - if (query.to_roles->current_user) - ids.push_back(context.getUserID()); - access_control.update(ids, update_func); + access_control.update(to_roles, update_func); + return {}; } diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index dab3a42554c..ecbf40aa73c 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include @@ -115,14 +115,8 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShow create_query->all_limits.push_back(create_query_limits); } - if (!quota->roles.empty() || quota->all_roles) - { - auto create_query_roles = std::make_shared(); - create_query_roles->names = quota->roles; - create_query_roles->all = quota->all_roles; - create_query_roles->except_names = quota->except_roles; - create_query->roles = std::move(create_query_roles); - } + if (!quota->roles.empty()) + create_query->roles = quota->roles.toAST(access_control); return create_query; } @@ -149,14 +143,8 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateRowPolicyQuery(const AST } } - if (!policy->roles.empty() || policy->all_roles) - { - auto create_query_roles = std::make_shared(); - create_query_roles->names = policy->roles; - create_query_roles->all = policy->all_roles; - create_query_roles->except_names = policy->except_roles; - create_query->roles = std::move(create_query_roles); - } + if (!policy->roles.empty()) + create_query->roles = policy->roles.toAST(access_control); return create_query; } diff --git a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp index 17761178ef4..c1d430586ba 100644 --- a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include #include #include #include @@ -92,7 +92,7 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show if (show_query.current_user) user = context.getUser(); else - user = context.getAccessControlManager().getUser(show_query.name); + user = context.getAccessControlManager().read(show_query.name); ASTs res; @@ -111,7 +111,7 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show auto grant_query = std::make_shared(); grant_query->kind = kind; grant_query->grant_option = grant_option; - grant_query->to_roles = std::make_shared(); + grant_query->to_roles = std::make_shared(); grant_query->to_roles->names.push_back(user->getName()); grant_query->access_rights_elements = elements; res.push_back(std::move(grant_query)); diff --git a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp index 205d3c33d18..87951b18705 100644 --- a/dbms/src/Parsers/ASTCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ASTCreateQuotaQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -94,7 +94,7 @@ namespace } } - void formatToRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); diff --git a/dbms/src/Parsers/ASTCreateQuotaQuery.h b/dbms/src/Parsers/ASTCreateQuotaQuery.h index 056a445f23b..15533040afd 100644 --- a/dbms/src/Parsers/ASTCreateQuotaQuery.h +++ b/dbms/src/Parsers/ASTCreateQuotaQuery.h @@ -6,7 +6,7 @@ namespace DB { -class ASTRoleList; +class ASTGenericRoleSet; /** CREATE QUOTA [IF NOT EXISTS | OR REPLACE] name @@ -53,7 +53,7 @@ public: }; std::vector all_limits; - std::shared_ptr roles; + std::shared_ptr roles; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp index 184474753df..6f73ff04de5 100644 --- a/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ASTCreateRowPolicyQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -112,7 +112,7 @@ namespace } } - void formatToRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) + void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); roles.format(settings); diff --git a/dbms/src/Parsers/ASTCreateRowPolicyQuery.h b/dbms/src/Parsers/ASTCreateRowPolicyQuery.h index a4caf1aeb85..2af07551560 100644 --- a/dbms/src/Parsers/ASTCreateRowPolicyQuery.h +++ b/dbms/src/Parsers/ASTCreateRowPolicyQuery.h @@ -8,7 +8,7 @@ namespace DB { -class ASTRoleList; +class ASTGenericRoleSet; /** CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table * [AS {PERMISSIVE | RESTRICTIVE}] @@ -41,7 +41,7 @@ public: using ConditionIndex = RowPolicy::ConditionIndex; std::vector> conditions; - std::shared_ptr roles; + std::shared_ptr roles; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTRoleList.cpp b/dbms/src/Parsers/ASTGenericRoleSet.cpp similarity index 92% rename from dbms/src/Parsers/ASTRoleList.cpp rename to dbms/src/Parsers/ASTGenericRoleSet.cpp index 87e388b9b1c..b93110866d4 100644 --- a/dbms/src/Parsers/ASTRoleList.cpp +++ b/dbms/src/Parsers/ASTGenericRoleSet.cpp @@ -1,10 +1,10 @@ -#include +#include #include namespace DB { -void ASTRoleList::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +void ASTGenericRoleSet::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { if (empty()) { diff --git a/dbms/src/Parsers/ASTRoleList.h b/dbms/src/Parsers/ASTGenericRoleSet.h similarity index 53% rename from dbms/src/Parsers/ASTRoleList.h rename to dbms/src/Parsers/ASTGenericRoleSet.h index daef6124d18..dfe4c67bad4 100644 --- a/dbms/src/Parsers/ASTRoleList.h +++ b/dbms/src/Parsers/ASTGenericRoleSet.h @@ -6,8 +6,9 @@ namespace DB { -/// {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...] -class ASTRoleList : public IAST +/// Represents a set of users/roles like +/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] +class ASTGenericRoleSet : public IAST { public: Strings names; @@ -18,8 +19,8 @@ public: bool empty() const { return names.empty() && !current_user && !all; } - String getID(char) const override { return "RoleList"; } - ASTPtr clone() const override { return std::make_shared(*this); } + String getID(char) const override { return "GenericRoleSet"; } + ASTPtr clone() const override { return std::make_shared(*this); } void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; }; } diff --git a/dbms/src/Parsers/ASTGrantQuery.cpp b/dbms/src/Parsers/ASTGrantQuery.cpp index 1f3800f100c..c5132c8359c 100644 --- a/dbms/src/Parsers/ASTGrantQuery.cpp +++ b/dbms/src/Parsers/ASTGrantQuery.cpp @@ -1,5 +1,5 @@ #include -#include +#include #include #include #include @@ -92,7 +92,7 @@ namespace } - void formatToRoles(const ASTRoleList & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings) + void formatToRoles(const ASTGenericRoleSet & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings) { using Kind = ASTGrantQuery::Kind; settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") diff --git a/dbms/src/Parsers/ASTGrantQuery.h b/dbms/src/Parsers/ASTGrantQuery.h index 2cdf7b7f661..56663d84620 100644 --- a/dbms/src/Parsers/ASTGrantQuery.h +++ b/dbms/src/Parsers/ASTGrantQuery.h @@ -6,7 +6,7 @@ namespace DB { -class ASTRoleList; +class ASTGenericRoleSet; /** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION] @@ -22,7 +22,7 @@ public: }; Kind kind = Kind::GRANT; AccessRightsElements access_rights_elements; - std::shared_ptr to_roles; + std::shared_ptr to_roles; bool grant_option = false; String getID(char) const override; diff --git a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp index 61e7d2f1c52..86516554044 100644 --- a/dbms/src/Parsers/ParserCreateQuotaQuery.cpp +++ b/dbms/src/Parsers/ParserCreateQuotaQuery.cpp @@ -3,10 +3,10 @@ #include #include #include -#include +#include #include #include -#include +#include #include #include @@ -187,15 +187,15 @@ namespace }); } - bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr node; - if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, node, expected)) + if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserGenericRoleSet{}.parse(pos, node, expected)) return false; - roles = std::static_pointer_cast(node); + roles = std::static_pointer_cast(node); return true; }); } @@ -235,7 +235,7 @@ bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expe String new_name; std::optional key_type; std::vector all_limits; - std::shared_ptr roles; + std::shared_ptr roles; while (true) { diff --git a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp index c1bfab2551b..d035bffbed3 100644 --- a/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp +++ b/dbms/src/Parsers/ParserCreateRowPolicyQuery.cpp @@ -1,8 +1,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include @@ -187,15 +187,15 @@ namespace }); } - bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) { return IParserBase::wrapParseImpl(pos, [&] { ASTPtr ast; - if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, ast, expected)) + if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserGenericRoleSet{}.parse(pos, ast, expected)) return false; - roles = std::static_pointer_cast(ast); + roles = std::static_pointer_cast(ast); return true; }); } @@ -239,7 +239,7 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & String new_policy_name; std::optional is_restrictive; std::vector> conditions; - std::shared_ptr roles; + std::shared_ptr roles; while (true) { diff --git a/dbms/src/Parsers/ParserCreateUserQuery.cpp b/dbms/src/Parsers/ParserCreateUserQuery.cpp index 7b5aa1fa03c..9b188d0c93d 100644 --- a/dbms/src/Parsers/ParserCreateUserQuery.cpp +++ b/dbms/src/Parsers/ParserCreateUserQuery.cpp @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include diff --git a/dbms/src/Parsers/ParserRoleList.cpp b/dbms/src/Parsers/ParserGenericRoleSet.cpp similarity index 71% rename from dbms/src/Parsers/ParserRoleList.cpp rename to dbms/src/Parsers/ParserGenericRoleSet.cpp index 8cdae1f7bab..5f6898fec80 100644 --- a/dbms/src/Parsers/ParserRoleList.cpp +++ b/dbms/src/Parsers/ParserGenericRoleSet.cpp @@ -1,6 +1,6 @@ -#include +#include #include -#include +#include #include #include @@ -9,7 +9,7 @@ namespace DB { namespace { - bool parseRoleListBeforeExcept(IParserBase::Pos & pos, Expected & expected, bool * all, bool * current_user, Strings & names) + bool parseBeforeExcept(IParserBase::Pos & pos, Expected & expected, bool * all, bool * current_user, Strings & names) { return IParserBase::wrapParseImpl(pos, [&] { @@ -56,24 +56,20 @@ namespace }); } - bool parseRoleListExcept(IParserBase::Pos & pos, Expected & expected, bool * except_current_user, Strings & except_names) + bool parseExcept(IParserBase::Pos & pos, Expected & expected, bool * except_current_user, Strings & except_names) { return IParserBase::wrapParseImpl(pos, [&] { if (!ParserKeyword{"EXCEPT"}.ignore(pos, expected)) return false; - return parseRoleListBeforeExcept(pos, expected, nullptr, except_current_user, except_names); + return parseBeforeExcept(pos, expected, nullptr, except_current_user, except_names); }); } } -ParserRoleList::ParserRoleList(bool allow_all_, bool allow_current_user_) - : allow_all(allow_all_), allow_current_user(allow_current_user_) {} - - -bool ParserRoleList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +bool ParserGenericRoleSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { Strings names; bool current_user = false; @@ -81,15 +77,15 @@ bool ParserRoleList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) Strings except_names; bool except_current_user = false; - if (!parseRoleListBeforeExcept(pos, expected, (allow_all ? &all : nullptr), (allow_current_user ? ¤t_user : nullptr), names)) + if (!parseBeforeExcept(pos, expected, (allow_all ? &all : nullptr), (allow_current_user ? ¤t_user : nullptr), names)) return false; - parseRoleListExcept(pos, expected, (allow_current_user ? &except_current_user : nullptr), except_names); + parseExcept(pos, expected, (allow_current_user ? &except_current_user : nullptr), except_names); if (all) names.clear(); - auto result = std::make_shared(); + auto result = std::make_shared(); result->names = std::move(names); result->current_user = current_user; result->all = all; diff --git a/dbms/src/Parsers/ParserGenericRoleSet.h b/dbms/src/Parsers/ParserGenericRoleSet.h new file mode 100644 index 00000000000..9fa01effc47 --- /dev/null +++ b/dbms/src/Parsers/ParserGenericRoleSet.h @@ -0,0 +1,26 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses a string like this: + * {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...] + */ +class ParserGenericRoleSet : public IParserBase +{ +public: + ParserGenericRoleSet & allowAll(bool allow_) { allow_all = allow_; return *this; } + ParserGenericRoleSet & allowCurrentUser(bool allow_) { allow_current_user = allow_; return *this; } + +protected: + const char * getName() const override { return "GenericRoleSet"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + bool allow_all = true; + bool allow_current_user = true; +}; + +} diff --git a/dbms/src/Parsers/ParserGrantQuery.cpp b/dbms/src/Parsers/ParserGrantQuery.cpp index db5c75da290..770f5bee528 100644 --- a/dbms/src/Parsers/ParserGrantQuery.cpp +++ b/dbms/src/Parsers/ParserGrantQuery.cpp @@ -1,10 +1,10 @@ #include #include #include -#include +#include #include #include -#include +#include #include @@ -206,7 +206,7 @@ namespace } - bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) + bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) { return IParserBase::wrapParseImpl(pos, [&] { @@ -223,10 +223,10 @@ namespace } ASTPtr ast; - if (!ParserRoleList{false, false}.parse(pos, ast, expected)) + if (!ParserGenericRoleSet{}.allowAll(kind == Kind::REVOKE).parse(pos, ast, expected)) return false; - to_roles = typeid_cast>(ast); + to_roles = typeid_cast>(ast); return true; }); } @@ -252,7 +252,7 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) } AccessRightsElements elements; - std::shared_ptr to_roles; + std::shared_ptr to_roles; if (!parseAccessRightsElements(pos, expected, elements) && !parseToRoles(pos, expected, kind, to_roles)) return false; diff --git a/dbms/src/Parsers/ParserRoleList.h b/dbms/src/Parsers/ParserRoleList.h deleted file mode 100644 index 3daa0d7b6ff..00000000000 --- a/dbms/src/Parsers/ParserRoleList.h +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include - - -namespace DB -{ -/** Parses a string like this: - * {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...] - */ -class ParserRoleList : public IParserBase -{ -public: - ParserRoleList(bool allow_all_ = true, bool allow_current_user_ = true); - -protected: - const char * getName() const override { return "RoleList"; } - bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; - -private: - bool allow_all; - bool allow_current_user; -}; - -} diff --git a/dbms/src/Storages/System/StorageSystemQuotas.cpp b/dbms/src/Storages/System/StorageSystemQuotas.cpp index b82e348c86d..ce9e79cc33f 100644 --- a/dbms/src/Storages/System/StorageSystemQuotas.cpp +++ b/dbms/src/Storages/System/StorageSystemQuotas.cpp @@ -34,8 +34,6 @@ NamesAndTypesList StorageSystemQuotas::getNamesAndTypes() {"source", std::make_shared()}, {"key_type", std::make_shared(getKeyTypeEnumValues())}, {"roles", std::make_shared(std::make_shared())}, - {"all_roles", std::make_shared()}, - {"except_roles", std::make_shared(std::make_shared())}, {"intervals.duration", std::make_shared(std::make_shared())}, {"intervals.randomize_interval", std::make_shared(std::make_shared())}}; @@ -63,9 +61,6 @@ void StorageSystemQuotas::fillData(MutableColumns & res_columns, const Context & auto & key_type_column = *res_columns[i++]; auto & roles_data = assert_cast(*res_columns[i]).getData(); auto & roles_offsets = assert_cast(*res_columns[i++]).getOffsets(); - auto & all_roles_column = *res_columns[i++]; - auto & except_roles_data = assert_cast(*res_columns[i]).getData(); - auto & except_roles_offsets = assert_cast(*res_columns[i++]).getOffsets(); auto & durations_data = assert_cast(*res_columns[i]).getData(); auto & durations_offsets = assert_cast(*res_columns[i++]).getOffsets(); auto & randomize_intervals_data = assert_cast(*res_columns[i]).getData(); @@ -92,16 +87,10 @@ void StorageSystemQuotas::fillData(MutableColumns & res_columns, const Context & storage_name_column.insert(storage_name); key_type_column.insert(static_cast(quota->key_type)); - for (const auto & role : quota->roles) + for (const String & role : quota->roles.toStrings(access_control)) roles_data.insert(role); roles_offsets.push_back(roles_data.size()); - all_roles_column.insert(static_cast(quota->all_roles)); - - for (const auto & except_role : quota->except_roles) - except_roles_data.insert(except_role); - except_roles_offsets.push_back(except_roles_data.size()); - for (const auto & limits : quota->all_limits) { durations_data.insert(std::chrono::seconds{limits.duration}.count()); diff --git a/dbms/tests/integration/test_quota/test.py b/dbms/tests/integration/test_quota/test.py index e7caaf5cd06..85d2ded16c1 100644 --- a/dbms/tests/integration/test_quota/test.py +++ b/dbms/tests/integration/test_quota/test.py @@ -57,7 +57,7 @@ def test_quota_from_users_xml(): assert instance.query("SELECT currentQuota()") == "myQuota\n" assert instance.query("SELECT currentQuotaID()") == "e651da9c-a748-8703-061a-7e5e5096dae7\n" assert instance.query("SELECT currentQuotaKey()") == "default\n" - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" instance.query("SELECT * from test_table") @@ -70,7 +70,7 @@ def test_quota_from_users_xml(): def test_simpliest_quota(): # Simpliest quota doesn't even track usage. copy_quota_xml('simpliest.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N" instance.query("SELECT * from test_table") @@ -80,7 +80,7 @@ def test_simpliest_quota(): def test_tracking_quota(): # Now we're tracking usage. copy_quota_xml('tracking.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" instance.query("SELECT * from test_table") @@ -93,7 +93,7 @@ def test_tracking_quota(): def test_exceed_quota(): # Change quota, now the limits are tiny so we will exceed the quota. copy_quota_xml('tiny_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1]\t[1]\t[1]\t[0]\t[1]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1]\t[1]\t[1]\t[0]\t[1]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" assert re.search("Quota.*has\ been\ exceeded", instance.query_and_get_error("SELECT * from test_table")) @@ -101,7 +101,7 @@ def test_exceed_quota(): # Change quota, now the limits are enough to execute queries. copy_quota_xml('normal_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t1\t1\t0\t0\t50\t0" instance.query("SELECT * from test_table") @@ -109,12 +109,12 @@ def test_exceed_quota(): def test_add_remove_interval(): - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" # Add interval. copy_quota_xml('two_intervals.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952,63113904]\t[0,1]\t[1000,0]\t[0,0]\t[0,0]\t[0,30000]\t[1000,0]\t[0,20000]\t[0,120]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952,63113904]\t[0,1]\t[1000,0]\t[0,0]\t[0,0]\t[0,30000]\t[1000,0]\t[0,20000]\t[0,120]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0\n"\ "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t63113904\t0\t0\t0\t0\t0\t0" @@ -124,7 +124,7 @@ def test_add_remove_interval(): # Remove interval. copy_quota_xml('normal_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t1\t0\t50\t200\t50\t200" instance.query("SELECT * from test_table") @@ -132,7 +132,7 @@ def test_add_remove_interval(): # Remove all intervals. copy_quota_xml('simpliest.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]\t[]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N" instance.query("SELECT * from test_table") @@ -140,23 +140,23 @@ def test_add_remove_interval(): # Add one interval back. copy_quota_xml('normal_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" def test_add_remove_quota(): - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" # Add quota. copy_quota_xml('two_quotas.xml') assert system_quotas() ==\ - "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]\n"\ - "myQuota2\t4590510c-4d13-bf21-ec8a-c2187b092e73\tusers.xml\tclient key or user name\t[]\t0\t[]\t[3600,2629746]\t[1,0]\t[0,0]\t[0,0]\t[4000,0]\t[400000,0]\t[4000,0]\t[400000,0]\t[60,1800]" + "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]\n"\ + "myQuota2\t4590510c-4d13-bf21-ec8a-c2187b092e73\tusers.xml\tclient key or user name\t[]\t[3600,2629746]\t[1,0]\t[0,0]\t[0,0]\t[4000,0]\t[400000,0]\t[4000,0]\t[400000,0]\t[60,1800]" # Drop quota. copy_quota_xml('normal_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" # Drop all quotas. copy_quota_xml('no_quotas.xml') @@ -165,17 +165,17 @@ def test_add_remove_quota(): # Add one quota back. copy_quota_xml('normal_limits.xml') - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" assert system_quota_usage() == "e651da9c-a748-8703-061a-7e5e5096dae7\tdefault\t31556952\t0\t0\t0\t0\t0\t0" def test_reload_users_xml_by_timer(): - assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" + assert system_quotas() == "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1000]\t[0]\t[0]\t[0]\t[1000]\t[0]\t[0]" time.sleep(1) # The modification time of the 'quota.xml' file should be different, # because config files are reload by timer only when the modification time is changed. copy_quota_xml('tiny_limits.xml', reload_immediately=False) - assert_eq_with_retry(instance, query_from_system_quotas, "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t0\t[]\t[31556952]\t[0]\t[1]\t[1]\t[1]\t[0]\t[1]\t[0]\t[0]") + assert_eq_with_retry(instance, query_from_system_quotas, "myQuota\te651da9c-a748-8703-061a-7e5e5096dae7\tusers.xml\tuser name\t['default']\t[31556952]\t[0]\t[1]\t[1]\t[1]\t[0]\t[1]\t[0]\t[0]") def test_dcl_introspection(): From 5849dd22364b1f8187e134c5a73e02523e778376 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Wed, 12 Feb 2020 23:47:37 +0300 Subject: [PATCH 06/19] Slightly better solution for checking row policy for distributed tables, now it checks both current user's and initial user's filters. --- dbms/programs/server/TCPHandler.cpp | 2 +- dbms/src/Access/RowPolicyContext.cpp | 19 +++++++++++++++++++ dbms/src/Access/RowPolicyContext.h | 3 +++ dbms/src/Interpreters/Context.cpp | 5 +++-- dbms/src/Interpreters/Context.h | 9 ++++++--- .../Interpreters/InterpreterSelectQuery.cpp | 4 +++- 6 files changed, 35 insertions(+), 7 deletions(-) diff --git a/dbms/programs/server/TCPHandler.cpp b/dbms/programs/server/TCPHandler.cpp index b645118494b..8fb3c2c6c76 100644 --- a/dbms/programs/server/TCPHandler.cpp +++ b/dbms/programs/server/TCPHandler.cpp @@ -902,7 +902,7 @@ void TCPHandler::receiveQuery() } else { - query_context->switchRowPolicy(); + query_context->setInitialRowPolicy(); } } diff --git a/dbms/src/Access/RowPolicyContext.cpp b/dbms/src/Access/RowPolicyContext.cpp index 2b94ac6a3f6..753f8d6d3f7 100644 --- a/dbms/src/Access/RowPolicyContext.cpp +++ b/dbms/src/Access/RowPolicyContext.cpp @@ -1,4 +1,6 @@ #include +#include +#include #include #include #include @@ -37,6 +39,23 @@ ASTPtr RowPolicyContext::getCondition(const String & database, const String & ta } +ASTPtr RowPolicyContext::combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs) +{ + if (!lhs) + return rhs; + if (!rhs) + return lhs; + auto function = std::make_shared(); + auto exp_list = std::make_shared(); + function->name = "and"; + function->arguments = exp_list; + function->children.push_back(exp_list); + exp_list->children.push_back(lhs); + exp_list->children.push_back(rhs); + return function; +} + + std::vector RowPolicyContext::getCurrentPolicyIDs() const { /// We don't lock `mutex` here. diff --git a/dbms/src/Access/RowPolicyContext.h b/dbms/src/Access/RowPolicyContext.h index 631ec4f020e..937cfc131b6 100644 --- a/dbms/src/Access/RowPolicyContext.h +++ b/dbms/src/Access/RowPolicyContext.h @@ -30,6 +30,9 @@ public: /// The returned filter can be a combination of the filters defined by multiple row policies. ASTPtr getCondition(const String & database, const String & table_name, ConditionIndex index) const; + /// Combines two conditions into one by using the logical AND operator. + static ASTPtr combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs); + /// Returns IDs of all the policies used by the current user. std::vector getCurrentPolicyIDs() const; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index f3057bf2d02..6ddb695874d 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -328,6 +328,7 @@ Context Context::createGlobal() Context res; res.quota = std::make_shared(); res.row_policy = std::make_shared(); + res.initial_row_policy = std::make_shared(); res.access_rights = std::make_shared(); res.shared = std::make_shared(); return res; @@ -650,11 +651,11 @@ void Context::checkAccess(const AccessFlags & access, const std::string_view & d void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); } void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); } -void Context::switchRowPolicy() +void Context::setInitialRowPolicy() { auto initial_user_id = getAccessControlManager().find(client_info.initial_user); if (initial_user_id) - row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); + initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); } void Context::setUsersConfig(const ConfigurationPtr & config) diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index e0d65ffc5fe..391e1dcd505 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -170,6 +170,7 @@ private: AccessRightsContextPtr access_rights; QuotaContextPtr quota; /// Current quota. By default - empty quota, that have no limits. RowPolicyContextPtr row_policy; + RowPolicyContextPtr initial_row_policy; String current_database; Settings settings; /// Setting for query execution. std::shared_ptr settings_constraints; @@ -256,9 +257,11 @@ public: QuotaContextPtr getQuota() const { return quota; } RowPolicyContextPtr getRowPolicy() const { return row_policy; } - /// TODO: we need much better code for switching policies, quotas, access rights for initial user - /// Switches row policy in case we have initial user in client info - void switchRowPolicy(); + /// Sets an extra row policy based on `client_info.initial_user`, if it exists. + /// TODO: we need a better solution here. It seems we should pass the initial row policy + /// because a shard is allowed to don't have the initial user or it may be another user with the same name. + void setInitialRowPolicy(); + RowPolicyContextPtr getInitialRowPolicy() const { return initial_row_policy; } /** Take the list of users, quotas and configuration profiles from this config. * The list of users is completely replaced. diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index cf2ecf36056..0f5954c9672 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -372,6 +372,7 @@ InterpreterSelectQuery::InterpreterSelectQuery( /// Fix source_header for filter actions. auto row_policy_filter = context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); + row_policy_filter = RowPolicyContext::combineConditionsUsingAnd(row_policy_filter, context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)); if (row_policy_filter) { filter_info = std::make_shared(); @@ -515,7 +516,8 @@ Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere) /// PREWHERE optimization. /// Turn off, if the table filter (row-level security) is applied. - if (!context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)) + if (!context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER) + && !context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)) { auto optimize_prewhere = [&](auto & merge_tree) { From 6cac4a919b3209a246896730e71c29e8c8b82f6f Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Tue, 18 Feb 2020 05:36:29 +0300 Subject: [PATCH 07/19] Improve behaviour of row policies, now it applies for a table always when defined for that table (even for another user). --- dbms/src/Access/RowPolicyContextFactory.cpp | 11 +-- dbms/src/Access/UsersConfigAccessStorage.cpp | 87 ++++++++++--------- .../tests/integration/test_row_policy/test.py | 19 ++-- 3 files changed, 62 insertions(+), 55 deletions(-) diff --git a/dbms/src/Access/RowPolicyContextFactory.cpp b/dbms/src/Access/RowPolicyContextFactory.cpp index e6a6804bd77..ba58a11e61f 100644 --- a/dbms/src/Access/RowPolicyContextFactory.cpp +++ b/dbms/src/Access/RowPolicyContextFactory.cpp @@ -111,13 +111,10 @@ namespace ASTPtr getResult() && { /// Process permissive conditions. - if (!permissions.empty()) - restrictions.push_back(applyFunctionOR(std::move(permissions))); + restrictions.push_back(applyFunctionOR(std::move(permissions))); /// Process restrictive conditions. - if (!restrictions.empty()) - return applyFunctionAND(std::move(restrictions)); - return nullptr; + return applyFunctionAND(std::move(restrictions)); } private: @@ -276,10 +273,10 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context for (const auto & [policy_id, info] : all_policies) { + const auto & policy = *info.policy; + auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}]; if (info.canUseWithContext(context)) { - const auto & policy = *info.policy; - auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}]; mixers.policy_ids.push_back(policy_id); for (auto index : ext::range(0, MAX_CONDITION_INDEX)) if (info.parsed_conditions[index]) diff --git a/dbms/src/Access/UsersConfigAccessStorage.cpp b/dbms/src/Access/UsersConfigAccessStorage.cpp index dfda2f35035..a6bb2be467b 100644 --- a/dbms/src/Access/UsersConfigAccessStorage.cpp +++ b/dbms/src/Access/UsersConfigAccessStorage.cpp @@ -265,63 +265,70 @@ namespace std::vector parseRowPolicies(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log) { - std::vector policies; + std::map, std::unordered_map> all_filters_map; Poco::Util::AbstractConfiguration::Keys user_names; - config.keys("users", user_names); - for (const String & user_name : user_names) + try { - const String databases_config = "users." + user_name + ".databases"; - if (config.has(databases_config)) + config.keys("users", user_names); + for (const String & user_name : user_names) { - Poco::Util::AbstractConfiguration::Keys databases; - config.keys(databases_config, databases); - - /// Read tables within databases - for (const String & database : databases) + const String databases_config = "users." + user_name + ".databases"; + if (config.has(databases_config)) { - const String database_config = databases_config + "." + database; - Poco::Util::AbstractConfiguration::Keys keys_in_database_config; - config.keys(database_config, keys_in_database_config); + Poco::Util::AbstractConfiguration::Keys databases; + config.keys(databases_config, databases); - /// Read table properties - for (const String & key_in_database_config : keys_in_database_config) + /// Read tables within databases + for (const String & database : databases) { - String table_name = key_in_database_config; - String filter_config = database_config + "." + table_name + ".filter"; + const String database_config = databases_config + "." + database; + Poco::Util::AbstractConfiguration::Keys keys_in_database_config; + config.keys(database_config, keys_in_database_config); - if (key_in_database_config.starts_with("table[")) + /// Read table properties + for (const String & key_in_database_config : keys_in_database_config) { - const auto table_name_config = database_config + "." + table_name + "[@name]"; - if (config.has(table_name_config)) - { - table_name = config.getString(table_name_config); - filter_config = database_config + ".table[@name='" + table_name + "']"; - } - } + String table_name = key_in_database_config; + String filter_config = database_config + "." + table_name + ".filter"; - if (config.has(filter_config)) - { - try + if (key_in_database_config.starts_with("table[")) { - auto policy = std::make_shared(); - policy->setFullName(database, table_name, user_name); - policy->conditions[RowPolicy::SELECT_FILTER] = config.getString(filter_config); - policy->roles.add(generateID(typeid(User), user_name)); - policies.push_back(policy); - } - catch (...) - { - tryLogCurrentException( - log, - "Could not parse row policy " + backQuote(user_name) + " on table " + backQuoteIfNeed(database) + "." - + backQuoteIfNeed(table_name)); + const auto table_name_config = database_config + "." + table_name + "[@name]"; + if (config.has(table_name_config)) + { + table_name = config.getString(table_name_config); + filter_config = database_config + ".table[@name='" + table_name + "']"; + } } + + all_filters_map[{database, table_name}][user_name] = config.getString(filter_config); } } } } } + catch (...) + { + tryLogCurrentException(log, "Could not parse row policies"); + } + + std::vector policies; + for (auto & [database_and_table_name, user_to_filters] : all_filters_map) + { + const auto & [database, table_name] = database_and_table_name; + for (const String & user_name : user_names) + { + auto it = user_to_filters.find(user_name); + String filter = (it != user_to_filters.end()) ? it->second : "1"; + + auto policy = std::make_shared(); + policy->setFullName(database, table_name, user_name); + policy->conditions[RowPolicy::SELECT_FILTER] = filter; + policy->roles.add(generateID(typeid(User), user_name)); + policies.push_back(policy); + } + } return policies; } } diff --git a/dbms/tests/integration/test_row_policy/test.py b/dbms/tests/integration/test_row_policy/test.py index 137556f80e5..6db24f5799e 100644 --- a/dbms/tests/integration/test_row_policy/test.py +++ b/dbms/tests/integration/test_row_policy/test.py @@ -209,35 +209,38 @@ def test_introspection(): assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table1')") == "['default']\n" assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table2')") == "['default']\n" assert instance.query("SELECT currentRowPolicies('mydb', 'filtered_table3')") == "['default']\n" - assert instance.query("SELECT arraySort(currentRowPolicies())") == "[('mydb','filtered_table1','default'),('mydb','filtered_table2','default'),('mydb','filtered_table3','default')]\n" + assert instance.query("SELECT arraySort(currentRowPolicies())") == "[('mydb','filtered_table1','default'),('mydb','filtered_table2','default'),('mydb','filtered_table3','default'),('mydb','local','default')]\n" policy1 = "mydb\tfiltered_table1\tdefault\tdefault ON mydb.filtered_table1\t9e8a8f62-4965-2b5e-8599-57c7b99b3549\tusers.xml\t0\ta = 1\t\t\t\t\n" policy2 = "mydb\tfiltered_table2\tdefault\tdefault ON mydb.filtered_table2\tcffae79d-b9bf-a2ef-b798-019c18470b25\tusers.xml\t0\ta + b < 1 or c - d > 5\t\t\t\t\n" policy3 = "mydb\tfiltered_table3\tdefault\tdefault ON mydb.filtered_table3\t12fc5cef-e3da-3940-ec79-d8be3911f42b\tusers.xml\t0\tc = 1\t\t\t\t\n" - policy4 = "mydb\tlocal\tanother\tanother ON mydb.local\t5b23c389-7e18-06bf-a6bc-dd1afbbc0a97\tusers.xml\t0\ta = 1\t\t\t\t\n" + policy4 = "mydb\tlocal\tdefault\tdefault ON mydb.local\tcdacaeb5-1d97-f99d-2bb0-4574f290629c\tusers.xml\t0\t1\t\t\t\t\n" assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table1'), id) ORDER BY table, name") == policy1 assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table2'), id) ORDER BY table, name") == policy2 assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'filtered_table3'), id) ORDER BY table, name") == policy3 - assert instance.query("SELECT * from system.row_policies ORDER BY table, name") == policy1 + policy2 + policy3 + policy4 - assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs(), id) ORDER BY table, name") == policy1 + policy2 + policy3 + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs('mydb', 'local'), id) ORDER BY table, name") == policy4 + assert instance.query("SELECT * from system.row_policies WHERE has(currentRowPolicyIDs(), id) ORDER BY table, name") == policy1 + policy2 + policy3 + policy4 def test_dcl_introspection(): - assert instance.query("SHOW POLICIES ON mydb.filtered_table1") == "default\n" + assert instance.query("SHOW POLICIES ON mydb.filtered_table1") == "another\ndefault\n" assert instance.query("SHOW POLICIES CURRENT ON mydb.filtered_table2") == "default\n" - assert instance.query("SHOW POLICIES") == "another ON mydb.local\ndefault ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" - assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" + assert instance.query("SHOW POLICIES") == "another ON mydb.filtered_table1\nanother ON mydb.filtered_table2\nanother ON mydb.filtered_table3\nanother ON mydb.local\ndefault ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\ndefault ON mydb.local\n" + assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\ndefault ON mydb.local\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE POLICY default ON mydb.filtered_table1 FOR SELECT USING a = 1 TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE POLICY default ON mydb.filtered_table2 FOR SELECT USING ((a + b) < 1) OR ((c - d) > 5) TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE POLICY default ON mydb.filtered_table3 FOR SELECT USING c = 1 TO default\n" + assert instance.query("SHOW CREATE POLICY default ON mydb.local") == "CREATE POLICY default ON mydb.local FOR SELECT USING 1 TO default\n" copy_policy_xml('all_rows.xml') + assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE POLICY default ON mydb.filtered_table1 FOR SELECT USING 1 TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE POLICY default ON mydb.filtered_table2 FOR SELECT USING 1 TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE POLICY default ON mydb.filtered_table3 FOR SELECT USING 1 TO default\n" copy_policy_xml('no_rows.xml') + assert instance.query("SHOW POLICIES CURRENT") == "default ON mydb.filtered_table1\ndefault ON mydb.filtered_table2\ndefault ON mydb.filtered_table3\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table1") == "CREATE POLICY default ON mydb.filtered_table1 FOR SELECT USING NULL TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table2") == "CREATE POLICY default ON mydb.filtered_table2 FOR SELECT USING NULL TO default\n" assert instance.query("SHOW CREATE POLICY default ON mydb.filtered_table3") == "CREATE POLICY default ON mydb.filtered_table3 FOR SELECT USING NULL TO default\n" @@ -251,7 +254,7 @@ def test_dcl_management(): assert instance.query("SHOW POLICIES") == "" instance.query("CREATE POLICY pA ON mydb.filtered_table1 FOR SELECT USING a Date: Wed, 12 Feb 2020 06:03:33 +0300 Subject: [PATCH 08/19] Move event handling away from Context. --- dbms/programs/server/MySQLHandler.cpp | 2 +- dbms/src/Access/AccessControlManager.cpp | 62 +---- dbms/src/Access/AccessControlManager.h | 19 +- dbms/src/Access/AccessRightsContext.cpp | 248 +++++++++++++----- dbms/src/Access/AccessRightsContext.h | 84 ++++-- .../src/Access/AccessRightsContextFactory.cpp | 44 ++++ dbms/src/Access/AccessRightsContextFactory.h | 29 ++ dbms/src/Access/IAccessStorage.cpp | 11 +- dbms/src/Access/RowPolicyContextFactory.h | 2 - dbms/src/Core/MySQLProtocol.h | 2 +- dbms/src/Interpreters/Context.cpp | 190 ++++++++------ dbms/src/Interpreters/Context.h | 60 ++--- ...InterpreterShowCreateAccessEntityQuery.cpp | 2 +- dbms/src/Interpreters/tests/users.cpp | 2 +- 14 files changed, 481 insertions(+), 276 deletions(-) create mode 100644 dbms/src/Access/AccessRightsContextFactory.cpp create mode 100644 dbms/src/Access/AccessRightsContextFactory.h diff --git a/dbms/programs/server/MySQLHandler.cpp b/dbms/programs/server/MySQLHandler.cpp index 54267313736..262132f6acc 100644 --- a/dbms/programs/server/MySQLHandler.cpp +++ b/dbms/programs/server/MySQLHandler.cpp @@ -218,7 +218,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl try { // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. - auto user = connection_context.getAccessControlManager().getUser(user_name); + auto user = connection_context.getAccessControlManager().read(user_name); const DB::Authentication::Type user_auth_type = user->authentication.getType(); if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD) { diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index c2eea0cc605..4fc002a764b 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -2,10 +2,9 @@ #include #include #include -#include -#include +#include #include -#include +#include namespace DB @@ -24,8 +23,9 @@ namespace AccessControlManager::AccessControlManager() : MultipleAccessStorage(createStorages()), - quota_context_factory(std::make_unique(*this)), - row_policy_context_factory(std::make_unique(*this)) + access_rights_context_factory(std::make_unique(*this)), + row_policy_context_factory(std::make_unique(*this)), + quota_context_factory(std::make_unique(*this)) { } @@ -35,53 +35,6 @@ AccessControlManager::~AccessControlManager() } -UserPtr AccessControlManager::getUser( - const String & user_name, std::function on_change, ext::scope_guard * subscription) const -{ - return getUser(getID(user_name), std::move(on_change), subscription); -} - - -UserPtr AccessControlManager::getUser( - const UUID & user_id, std::function on_change, ext::scope_guard * subscription) const -{ - if (on_change && subscription) - { - *subscription = subscribeForChanges(user_id, [on_change](const UUID &, const AccessEntityPtr & user) - { - if (user) - on_change(typeid_cast(user)); - }); - } - return read(user_id); -} - - -UserPtr AccessControlManager::authorizeAndGetUser( - const String & user_name, - const String & password, - const Poco::Net::IPAddress & address, - std::function on_change, - ext::scope_guard * subscription) const -{ - return authorizeAndGetUser(getID(user_name), password, address, std::move(on_change), subscription); -} - - -UserPtr AccessControlManager::authorizeAndGetUser( - const UUID & user_id, - const String & password, - const Poco::Net::IPAddress & address, - std::function on_change, - ext::scope_guard * subscription) const -{ - auto user = getUser(user_id, on_change, subscription); - user->allowed_client_hosts.checkContains(address, user->getName()); - user->authentication.checkPassword(password, user->getName()); - return user; -} - - void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config) { auto & users_config_access_storage = dynamic_cast(getStorageByIndex(1)); @@ -89,9 +42,10 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio } -AccessRightsContextPtr AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) const +AccessRightsContextPtr AccessControlManager::getAccessRightsContext( + const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info) const { - return std::make_shared(user, client_info, settings, current_database); + return access_rights_context_factory->createContext(user_id, settings, current_database, client_info); } diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index dd9a8285aeb..4549f1afde8 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -2,7 +2,6 @@ #include #include -#include #include @@ -20,10 +19,9 @@ namespace Poco namespace DB { -struct User; -using UserPtr = std::shared_ptr; class AccessRightsContext; using AccessRightsContextPtr = std::shared_ptr; +class AccessRightsContextFactory; class RowPolicyContext; using RowPolicyContextPtr = std::shared_ptr; class RowPolicyContextFactory; @@ -44,21 +42,20 @@ public: void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config); - UserPtr getUser(const String & user_name, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; - UserPtr getUser(const UUID & user_id, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; - UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; - UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function on_change = {}, ext::scope_guard * subscription = nullptr) const; - - AccessRightsContextPtr getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) const; + AccessRightsContextPtr getAccessRightsContext( + const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info) const; RowPolicyContextPtr getRowPolicyContext(const UUID & user_id) const; - QuotaContextPtr getQuotaContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; + QuotaContextPtr getQuotaContext( + const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; + std::vector getQuotaUsageInfo() const; private: - std::unique_ptr quota_context_factory; + std::unique_ptr access_rights_context_factory; std::unique_ptr row_policy_context_factory; + std::unique_ptr quota_context_factory; }; } diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp index 5c239feee15..52404e77b2e 100644 --- a/dbms/src/Access/AccessRightsContext.cpp +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -1,4 +1,7 @@ #include +#include +#include +#include #include #include #include @@ -6,6 +9,7 @@ #include #include #include +#include #include @@ -17,6 +21,7 @@ namespace ErrorCodes extern const int READONLY; extern const int QUERY_IS_PROHIBITED; extern const int FUNCTION_NOT_ALLOWED; + extern const int UNKNOWN_USER; } @@ -85,25 +90,69 @@ AccessRightsContext::AccessRightsContext() { auto everything_granted = boost::make_shared(); everything_granted->grant(AccessType::ALL); - result_access_cache[0] = std::move(everything_granted); + boost::range::fill(result_access_cache, everything_granted); + row_policy_context = std::make_shared(); + quota_context = std::make_shared(); } -AccessRightsContext::AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_) - : user(user_) - , readonly(settings.readonly) - , allow_ddl(settings.allow_ddl) - , allow_introspection(settings.allow_introspection_functions) - , current_database(current_database_) - , interface(client_info_.interface) - , http_method(client_info_.http_method) - , trace_log(&Poco::Logger::get("AccessRightsContext (" + user_->getName() + ")")) +AccessRightsContext::AccessRightsContext(const AccessControlManager & manager_, const Params & params_) + : manager(&manager_) + , params(params_) { + subscription_for_user_change = manager->subscribeForChanges( + *params.user_id, [this](const UUID &, const AccessEntityPtr & entity) + { + UserPtr changed_user = entity ? typeid_cast(entity) : nullptr; + std::lock_guard lock{mutex}; + setUser(changed_user); + }); + + setUser(manager->read(*params.user_id)); +} + + +void AccessRightsContext::setUser(const UserPtr & user_) const +{ + user = user_; + if (!user) + { + /// User has been dropped. + auto nothing_granted = boost::make_shared(); + boost::range::fill(result_access_cache, nothing_granted); + subscription_for_user_change = {}; + row_policy_context = std::make_shared(); + quota_context = std::make_shared(); + return; + } + + user_name = user->getName(); + trace_log = &Poco::Logger::get("AccessRightsContext (" + user_name + ")"); + boost::range::fill(result_access_cache, nullptr /* need recalculate */); + row_policy_context = manager->getRowPolicyContext(*params.user_id); + quota_context = manager->getQuotaContext(*params.user_id, user_name, params.address, params.quota_key); +} + + +void AccessRightsContext::checkPassword(const String & password) const +{ + std::lock_guard lock{mutex}; + if (!user) + throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER); + user->authentication.checkPassword(password, user_name); +} + +void AccessRightsContext::checkHostIsAllowed() const +{ + std::lock_guard lock{mutex}; + if (!user) + throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER); + user->allowed_client_hosts.checkContains(params.address, user_name); } template -bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const +bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const { auto result_access = calculateResultAccess(grant_option); bool is_granted = result_access->isGranted(access, args...); @@ -126,12 +175,16 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc auto show_error = [&](const String & msg, [[maybe_unused]] int error_code) { if constexpr (mode == THROW_IF_ACCESS_DENIED) - throw Exception(user->getName() + ": " + msg, error_code); + throw Exception(user_name + ": " + msg, error_code); else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) - LOG_WARNING(log_, user->getName() + ": " + msg + formatSkippedMessage(args...)); + LOG_WARNING(log_, user_name + ": " + msg + formatSkippedMessage(args...)); }; - if (grant_option && calculateResultAccess(false, readonly, allow_ddl, allow_introspection)->isGranted(access, args...)) + if (!user) + { + show_error("User has been dropped", ErrorCodes::UNKNOWN_USER); + } + else if (grant_option && calculateResultAccess(false, params.readonly, params.allow_ddl, params.allow_introspection)->isGranted(access, args...)) { show_error( "Not enough privileges. " @@ -140,9 +193,9 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc + AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION", ErrorCodes::ACCESS_DENIED); } - else if (readonly && calculateResultAccess(false, false, allow_ddl, allow_introspection)->isGranted(access, args...)) + else if (params.readonly && calculateResultAccess(false, false, params.allow_ddl, params.allow_introspection)->isGranted(access, args...)) { - if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET) + if (params.interface == ClientInfo::Interface::HTTP && params.http_method == ClientInfo::HTTPMethod::GET) show_error( "Cannot execute query in readonly mode. " "For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries", @@ -150,11 +203,11 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc else show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY); } - else if (!allow_ddl && calculateResultAccess(false, readonly, true, allow_introspection)->isGranted(access, args...)) + else if (!params.allow_ddl && calculateResultAccess(false, params.readonly, true, params.allow_introspection)->isGranted(access, args...)) { show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); } - else if (!allow_introspection && calculateResultAccess(false, readonly, allow_ddl, true)->isGranted(access, args...)) + else if (!params.allow_introspection && calculateResultAccess(false, params.readonly, params.allow_ddl, true)->isGranted(access, args...)) { show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED); } @@ -171,94 +224,94 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc template -bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const +bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const { if (element.any_database) { - return checkImpl(log_, element.access_flags); + return checkAccessImpl(log_, element.access_flags); } else if (element.any_table) { if (element.database.empty()) - return checkImpl(log_, element.access_flags, current_database); + return checkAccessImpl(log_, element.access_flags, params.current_database); else - return checkImpl(log_, element.access_flags, element.database); + return checkAccessImpl(log_, element.access_flags, element.database); } else if (element.any_column) { if (element.database.empty()) - return checkImpl(log_, element.access_flags, current_database, element.table); + return checkAccessImpl(log_, element.access_flags, params.current_database, element.table); else - return checkImpl(log_, element.access_flags, element.database, element.table); + return checkAccessImpl(log_, element.access_flags, element.database, element.table); } else { if (element.database.empty()) - return checkImpl(log_, element.access_flags, current_database, element.table, element.columns); + return checkAccessImpl(log_, element.access_flags, params.current_database, element.table, element.columns); else - return checkImpl(log_, element.access_flags, element.database, element.table, element.columns); + return checkAccessImpl(log_, element.access_flags, element.database, element.table, element.columns); } } template -bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const +bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const { for (const auto & element : elements) - if (!checkImpl(log_, element)) + if (!checkAccessImpl(log_, element)) return false; return true; } -void AccessRightsContext::check(const AccessFlags & access) const { checkImpl(nullptr, access); } -void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl(nullptr, access, database); } -void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl(nullptr, access, database, table); } -void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl(nullptr, access, database, table, column); } -void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl(nullptr, access); } -void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl(nullptr, access); } +void AccessRightsContext::checkAccess(const AccessFlags & access) const { checkAccessImpl(nullptr, access); } +void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl(nullptr, access, database); } +void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, access, database, table); } +void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, access, database, table, column); } +void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::checkAccess(const AccessRightsElement & access) const { checkAccessImpl(nullptr, access); } +void AccessRightsContext::checkAccess(const AccessRightsElements & access) const { checkAccessImpl(nullptr, access); } -bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl(nullptr, access); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl(nullptr, access, database); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl(nullptr, access, database, table); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl(nullptr, access, database, table, column); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkImpl(nullptr, access, database, table, columns); } -bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl(nullptr, access, database, table, columns); } -bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl(nullptr, access); } -bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl(nullptr, access); } +bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkAccessImpl(nullptr, access); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(nullptr, access, database); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(nullptr, access, database, table); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(nullptr, access, database, table, column); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(nullptr, access, database, table, columns); } +bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(nullptr, access, database, table, columns); } +bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkAccessImpl(nullptr, access); } +bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkAccessImpl(nullptr, access); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl(log_, access); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl(log_, access, database); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl(log_, access, database, table); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl(log_, access, database, table, column); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkImpl(log_, access, database, table, columns); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl(log_, access, database, table, columns); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl(log_, access); } -bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl(log_, access); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkAccessImpl(log_, access); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(log_, access, database); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(log_, access, database, table); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(log_, access, database, table, column); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(log_, access, database, table, columns); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(log_, access, database, table, columns); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkAccessImpl(log_, access); } +bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkAccessImpl(log_, access); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkImpl(nullptr, access); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkImpl(nullptr, access, database); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl(nullptr, access, database, table); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl(nullptr, access, database, table, column); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl(nullptr, access, database, table, columns); } -void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkImpl(nullptr, access); } -void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkImpl(nullptr, access); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkAccessImpl(nullptr, access); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl(nullptr, access, database); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl(nullptr, access, database, table); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl(nullptr, access, database, table, column); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl(nullptr, access, database, table, columns); } +void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkAccessImpl(nullptr, access); } +void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkAccessImpl(nullptr, access); } boost::shared_ptr AccessRightsContext::calculateResultAccess(bool grant_option) const { - return calculateResultAccess(grant_option, readonly, allow_ddl, allow_introspection); + return calculateResultAccess(grant_option, params.readonly, params.allow_ddl, params.allow_introspection); } boost::shared_ptr AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const { - size_t cache_index = static_cast(readonly_ != readonly) - + static_cast(allow_ddl_ != allow_ddl) * 2 + - + static_cast(allow_introspection_ != allow_introspection) * 3 + size_t cache_index = static_cast(readonly_ != params.readonly) + + static_cast(allow_ddl_ != params.allow_ddl) * 2 + + + static_cast(allow_introspection_ != params.allow_introspection) * 3 + static_cast(grant_option) * 4; assert(cache_index < std::size(result_access_cache)); auto cached = result_access_cache[cache_index].load(); @@ -306,10 +359,75 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess result_access_cache[cache_index].store(result_ptr); - if (trace_log && (readonly == readonly_) && (allow_ddl == allow_ddl_) && (allow_introspection == allow_introspection_)) + if (trace_log && (params.readonly == readonly_) && (params.allow_ddl == allow_ddl_) && (params.allow_introspection == allow_introspection_)) LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : "")); return result_ptr; } + +UserPtr AccessRightsContext::getUser() const +{ + std::lock_guard lock{mutex}; + return user; +} + +String AccessRightsContext::getUserName() const +{ + std::lock_guard lock{mutex}; + return user_name; +} + +RowPolicyContextPtr AccessRightsContext::getRowPolicy() const +{ + std::lock_guard lock{mutex}; + return row_policy_context; +} + +QuotaContextPtr AccessRightsContext::getQuota() const +{ + std::lock_guard lock{mutex}; + return quota_context; +} + + +bool operator <(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs) +{ +#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \ + if (lhs.field < rhs.field) \ + return true; \ + if (lhs.field > rhs.field) \ + return false + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method); + return false; +#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER +} + + +bool operator ==(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs) +{ +#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \ + if (lhs.field != rhs.field) \ + return false + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method); + return true; +#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER +} + } diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h index a39abbcabba..68a53351098 100644 --- a/dbms/src/Access/AccessRightsContext.h +++ b/dbms/src/Access/AccessRightsContext.h @@ -2,6 +2,9 @@ #include #include +#include +#include +#include #include #include @@ -10,29 +13,62 @@ namespace Poco { class Logger; } namespace DB { -struct Settings; struct User; using UserPtr = std::shared_ptr; +struct RowPolicyContext; +using RowPolicyContextPtr = std::shared_ptr; +struct QuotaContext; +using QuotaContextPtr = std::shared_ptr; +struct Settings; +class AccessControlManager; class AccessRightsContext { public: + struct Params + { + std::optional user_id; + UInt64 readonly = 0; + bool allow_ddl = false; + bool allow_introspection = false; + String current_database; + ClientInfo::Interface interface = ClientInfo::Interface::TCP; + ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; + Poco::Net::IPAddress address; + String quota_key; + + friend bool operator ==(const Params & lhs, const Params & rhs); + friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); } + friend bool operator <(const Params & lhs, const Params & rhs); + friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; } + friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); } + friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); } + }; + /// Default constructor creates access rights' context which allows everything. AccessRightsContext(); - AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_); + const Params & getParams() const { return params; } + UserPtr getUser() const; + String getUserName() const; + + void checkPassword(const String & password) const; + void checkHostIsAllowed() const; + + RowPolicyContextPtr getRowPolicy() const; + QuotaContextPtr getQuota() const; /// Checks if a specified access granted, and throws an exception if not. /// Empty database means the current database. - void check(const AccessFlags & access) const; - void check(const AccessFlags & access, const std::string_view & database) const; - void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; - void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; - void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; - void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; - void check(const AccessRightsElement & access) const; - void check(const AccessRightsElements & access) const; + void checkAccess(const AccessFlags & access) const; + void checkAccess(const AccessFlags & access, const std::string_view & database) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const; + void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; + void checkAccess(const AccessRightsElement & access) const; + void checkAccess(const AccessRightsElements & access) const; /// Checks if a specified access granted. bool isGranted(const AccessFlags & access) const; @@ -65,27 +101,33 @@ public: void checkGrantOption(const AccessRightsElements & access) const; private: + friend class AccessRightsContextFactory; + friend struct ext::shared_ptr_helper; + AccessRightsContext(const AccessControlManager & manager_, const Params & params_); /// AccessRightsContext should be created by AccessRightsContextFactory. + + void setUser(const UserPtr & user_) const; + template - bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; + bool checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; template - bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const; + bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & access) const; template - bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const; + bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & access) const; boost::shared_ptr calculateResultAccess(bool grant_option) const; boost::shared_ptr calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const; - const UserPtr user; - const UInt64 readonly = 0; - const bool allow_ddl = true; - const bool allow_introspection = true; - const String current_database; - const ClientInfo::Interface interface = ClientInfo::Interface::TCP; - const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; - Poco::Logger * const trace_log = nullptr; + const AccessControlManager * manager = nullptr; + const Params params; + mutable Poco::Logger * trace_log = nullptr; + mutable UserPtr user; + mutable String user_name; + mutable ext::scope_guard subscription_for_user_change; mutable boost::atomic_shared_ptr result_access_cache[7]; + mutable RowPolicyContextPtr row_policy_context; + mutable QuotaContextPtr quota_context; mutable std::mutex mutex; }; diff --git a/dbms/src/Access/AccessRightsContextFactory.cpp b/dbms/src/Access/AccessRightsContextFactory.cpp new file mode 100644 index 00000000000..21cd3657170 --- /dev/null +++ b/dbms/src/Access/AccessRightsContextFactory.cpp @@ -0,0 +1,44 @@ +#include +#include +#include + + +namespace DB +{ +AccessRightsContextFactory::AccessRightsContextFactory(const AccessControlManager & manager_) + : manager(manager_), cache(600000 /* 10 minutes */) {} + +AccessRightsContextFactory::~AccessRightsContextFactory() = default; + + +AccessRightsContextPtr AccessRightsContextFactory::createContext(const Params & params) +{ + std::lock_guard lock{mutex}; + auto x = cache.get(params); + if (x) + return *x; + auto res = ext::shared_ptr_helper::create(manager, params); + cache.add(params, res); + return res; +} + +AccessRightsContextPtr AccessRightsContextFactory::createContext( + const UUID & user_id, + const Settings & settings, + const String & current_database, + const ClientInfo & client_info) +{ + Params params; + params.user_id = user_id; + params.current_database = current_database; + params.readonly = settings.readonly; + params.allow_ddl = settings.allow_ddl; + params.allow_introspection = settings.allow_introspection_functions; + params.interface = client_info.interface; + params.http_method = client_info.http_method; + params.address = client_info.current_address.host(); + params.quota_key = client_info.quota_key; + return createContext(params); +} + +} diff --git a/dbms/src/Access/AccessRightsContextFactory.h b/dbms/src/Access/AccessRightsContextFactory.h new file mode 100644 index 00000000000..9f61c1099c5 --- /dev/null +++ b/dbms/src/Access/AccessRightsContextFactory.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include + + +namespace DB +{ +class AccessControlManager; + + +class AccessRightsContextFactory +{ +public: + AccessRightsContextFactory(const AccessControlManager & manager_); + ~AccessRightsContextFactory(); + + using Params = AccessRightsContext::Params; + AccessRightsContextPtr createContext(const Params & params); + AccessRightsContextPtr createContext(const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info); + +private: + const AccessControlManager & manager; + Poco::ExpireCache cache; + std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/IAccessStorage.cpp b/dbms/src/Access/IAccessStorage.cpp index 9120d626d09..4ff8ed14d86 100644 --- a/dbms/src/Access/IAccessStorage.cpp +++ b/dbms/src/Access/IAccessStorage.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -15,6 +16,7 @@ namespace ErrorCodes extern const int ACCESS_ENTITY_ALREADY_EXISTS; extern const int ACCESS_ENTITY_FOUND_DUPLICATES; extern const int ACCESS_ENTITY_STORAGE_READONLY; + extern const int UNKNOWN_USER; } @@ -365,8 +367,13 @@ void IAccessStorage::throwNotFound(const UUID & id) const void IAccessStorage::throwNotFound(std::type_index type, const String & name) const { - throw Exception( - getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND); + int error_code; + if (type == typeid(User)) + error_code = ErrorCodes::UNKNOWN_USER; + else + error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND; + + throw Exception(getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), error_code); } diff --git a/dbms/src/Access/RowPolicyContextFactory.h b/dbms/src/Access/RowPolicyContextFactory.h index c393a75285b..911f795bcc1 100644 --- a/dbms/src/Access/RowPolicyContextFactory.h +++ b/dbms/src/Access/RowPolicyContextFactory.h @@ -4,14 +4,12 @@ #include #include #include -#include namespace DB { class AccessControlManager; - /// Stores read and parsed row policies. class RowPolicyContextFactory { diff --git a/dbms/src/Core/MySQLProtocol.h b/dbms/src/Core/MySQLProtocol.h index c1e0f923bfe..b60c98cca97 100644 --- a/dbms/src/Core/MySQLProtocol.h +++ b/dbms/src/Core/MySQLProtocol.h @@ -953,7 +953,7 @@ public: throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.", ErrorCodes::UNKNOWN_EXCEPTION); - auto user = context.getAccessControlManager().getUser(user_name); + auto user = context.getAccessControlManager().read(user_name); Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1(); assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE); diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 6ddb695874d..3fc406350cf 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -27,11 +27,10 @@ #include #include #include +#include +#include #include #include -#include -#include -#include #include #include #include @@ -326,10 +325,8 @@ Context & Context::operator=(const Context &) = default; Context Context::createGlobal() { Context res; - res.quota = std::make_shared(); - res.row_policy = std::make_shared(); - res.initial_row_policy = std::make_shared(); res.access_rights = std::make_shared(); + res.initial_row_policy = std::make_shared(); res.shared = std::make_shared(); return res; } @@ -624,39 +621,17 @@ const Poco::Util::AbstractConfiguration & Context::getConfigRef() const return shared->config ? *shared->config : Poco::Util::Application::instance().config(); } + AccessControlManager & Context::getAccessControlManager() { - auto lock = getLock(); return shared->access_control_manager; } const AccessControlManager & Context::getAccessControlManager() const { - auto lock = getLock(); return shared->access_control_manager; } -template -void Context::checkAccessImpl(const Args &... args) const -{ - getAccessRights()->check(args...); -} - -void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(access, database, table, columns); } -void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); } -void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); } -void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); } - -void Context::setInitialRowPolicy() -{ - auto initial_user_id = getAccessControlManager().find(client_info.initial_user); - if (initial_user_id) - initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); -} void Context::setUsersConfig(const ConfigurationPtr & config) { @@ -671,10 +646,112 @@ ConfigurationPtr Context::getUsersConfig() return shared->users_config; } + +void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key) +{ + auto lock = getLock(); + + client_info.current_user = name; + client_info.current_password = password; + client_info.current_address = address; + if (!quota_key.empty()) + client_info.quota_key = quota_key; + + auto new_user_id = getAccessControlManager().getID(name); + auto new_access_rights = getAccessControlManager().getAccessRightsContext(new_user_id, settings, current_database, client_info); + new_access_rights->checkHostIsAllowed(); + new_access_rights->checkPassword(password); + + user_id = new_user_id; + access_rights = std::move(new_access_rights); + + calculateUserSettings(); +} + +std::shared_ptr Context::getUser() const +{ + auto lock = getLock(); + return access_rights->getUser(); +} + +String Context::getUserName() const +{ + auto lock = getLock(); + return access_rights->getUserName(); +} + +UUID Context::getUserID() const +{ + auto lock = getLock(); + if (!user_id) + throw Exception("No current user", ErrorCodes::LOGICAL_ERROR); + return *user_id; +} + + +void Context::calculateAccessRights() +{ + auto lock = getLock(); + if (user_id) + access_rights = getAccessControlManager().getAccessRightsContext(*user_id, settings, current_database, client_info); +} + + +template +void Context::checkAccessImpl(const Args &... args) const +{ + getAccessRights()->checkAccess(args...); +} + +void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector & columns) const { return checkAccessImpl(access, database, table, columns); } +void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); } +void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); } +void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); } + +AccessRightsContextPtr Context::getAccessRights() const +{ + auto lock = getLock(); + return access_rights; +} + +RowPolicyContextPtr Context::getRowPolicy() const +{ + return getAccessRights()->getRowPolicy(); +} + +void Context::setInitialRowPolicy() +{ + auto lock = getLock(); + auto initial_user_id = getAccessControlManager().find(client_info.initial_user); + if (initial_user_id) + initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); +} + +RowPolicyContextPtr Context::getInitialRowPolicy() const +{ + auto lock = getLock(); + return initial_row_policy; +} + + +QuotaContextPtr Context::getQuota() const +{ + return getAccessRights()->getQuota(); +} + + void Context::calculateUserSettings() { auto lock = getLock(); - String profile = user->profile; + String profile = getUser()->profile; + + bool old_readonly = settings.readonly; + bool old_allow_ddl = settings.allow_ddl; + bool old_allow_introspection_functions = settings.allow_introspection_functions; /// 1) Set default settings (hardcoded values) /// NOTE: we ignore global_context settings (from which it is usually copied) @@ -689,13 +766,10 @@ void Context::calculateUserSettings() /// 3) Apply settings from current user setProfile(profile); -} -void Context::calculateAccessRights() -{ - auto lock = getLock(); - if (user) - std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(user, client_info, settings, current_database)); + /// 4) Recalculate access rights if it's necessary. + if ((settings.readonly != old_readonly) || (settings.allow_ddl != old_allow_ddl) || (settings.allow_introspection_functions != old_allow_introspection_functions)) + calculateAccessRights(); } void Context::setProfile(const String & profile) @@ -708,50 +782,6 @@ void Context::setProfile(const String & profile) settings_constraints = std::move(new_constraints); } -std::shared_ptr Context::getUser() const -{ - if (!user) - throw Exception("No current user", ErrorCodes::LOGICAL_ERROR); - return user; -} - -UUID Context::getUserID() const -{ - if (!user) - throw Exception("No current user", ErrorCodes::LOGICAL_ERROR); - return user_id; -} - -void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key) -{ - auto lock = getLock(); - - client_info.current_user = name; - client_info.current_address = address; - client_info.current_password = password; - - if (!quota_key.empty()) - client_info.quota_key = quota_key; - - user_id = shared->access_control_manager.getID(name); - user = shared->access_control_manager.authorizeAndGetUser( - user_id, - password, - address.host(), - [this](const UserPtr & changed_user) - { - user = changed_user; - calculateAccessRights(); - }, - &subscription_for_user_change.subscription); - - quota = getAccessControlManager().getQuotaContext(user_id, name, address.host(), quota_key); - row_policy = getAccessControlManager().getRowPolicyContext(user_id); - - calculateUserSettings(); - calculateAccessRights(); -} - void Context::addDependencyUnsafe(const StorageID & from, const StorageID & where) { diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 391e1dcd505..11017ec778e 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -13,7 +13,6 @@ #include #include "config_core.h" #include -#include #include #include #include @@ -44,9 +43,10 @@ namespace DB struct ContextShared; class Context; -struct User; class AccessRightsContext; using AccessRightsContextPtr = std::shared_ptr; +struct User; +using UserPtr = std::shared_ptr; class RowPolicyContext; using RowPolicyContextPtr = std::shared_ptr; class QuotaContext; @@ -136,15 +136,6 @@ struct IHostContext using IHostContextPtr = std::shared_ptr; -/// Subscription for user's change. This subscription cannot be copied with the context, -/// that's why we had to move it into a separate structure. -struct SubscriptionForUserChange -{ - ext::scope_guard subscription; - SubscriptionForUserChange() {} - SubscriptionForUserChange(const SubscriptionForUserChange &) {} - SubscriptionForUserChange & operator =(const SubscriptionForUserChange &) { subscription = {}; return *this; } -}; /** A set of known objects that can be used in the query. * Consists of a shared part (always common to all sessions and queries) @@ -164,12 +155,8 @@ private: InputInitializer input_initializer_callback; InputBlocksReader input_blocks_reader; - std::shared_ptr user; - UUID user_id; - SubscriptionForUserChange subscription_for_user_change; + std::optional user_id; AccessRightsContextPtr access_rights; - QuotaContextPtr quota; /// Current quota. By default - empty quota, that have no limits. - RowPolicyContextPtr row_policy; RowPolicyContextPtr initial_row_policy; String current_database; Settings settings; /// Setting for query execution. @@ -241,7 +228,21 @@ public: AccessControlManager & getAccessControlManager(); const AccessControlManager & getAccessControlManager() const; - AccessRightsContextPtr getAccessRights() const { return std::atomic_load(&access_rights); } + + /** Take the list of users, quotas and configuration profiles from this config. + * The list of users is completely replaced. + * The accumulated quota values are not reset if the quota is not deleted. + */ + void setUsersConfig(const ConfigurationPtr & config); + ConfigurationPtr getUsersConfig(); + + /// Sets the current user, checks the password and that the specified host is allowed. + /// Must be called before getClientInfo. + void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key); + + UserPtr getUser() const; + String getUserName() const; + UUID getUserID() const; /// Checks access rights. /// Empty database means the current database. @@ -254,26 +255,17 @@ public: void checkAccess(const AccessRightsElement & access) const; void checkAccess(const AccessRightsElements & access) const; - QuotaContextPtr getQuota() const { return quota; } - RowPolicyContextPtr getRowPolicy() const { return row_policy; } + AccessRightsContextPtr getAccessRights() const; + + RowPolicyContextPtr getRowPolicy() const; /// Sets an extra row policy based on `client_info.initial_user`, if it exists. /// TODO: we need a better solution here. It seems we should pass the initial row policy /// because a shard is allowed to don't have the initial user or it may be another user with the same name. void setInitialRowPolicy(); - RowPolicyContextPtr getInitialRowPolicy() const { return initial_row_policy; } + RowPolicyContextPtr getInitialRowPolicy() const; - /** Take the list of users, quotas and configuration profiles from this config. - * The list of users is completely replaced. - * The accumulated quota values are not reset if the quota is not deleted. - */ - void setUsersConfig(const ConfigurationPtr & config); - ConfigurationPtr getUsersConfig(); - - /// Must be called before getClientInfo. - void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key); - std::shared_ptr getUser() const; - UUID getUserID() const; + QuotaContextPtr getQuota() const; /// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once. void setExternalTablesInitializer(ExternalTablesInitializer && initializer); @@ -618,12 +610,6 @@ private: void calculateUserSettings(); void calculateAccessRights(); - /** Check if the current client has access to the specified database. - * If access is denied, throw an exception. - * NOTE: This method should always be called when the `shared->mutex` mutex is acquired. - */ - void checkDatabaseAccessRightsImpl(const std::string & database_name) const; - template void checkAccessImpl(const Args &... args) const; diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index ecbf40aa73c..86a4699a636 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -74,7 +74,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowC if (show_query.current_user) user = context.getUser(); else - user = context.getAccessControlManager().getUser(show_query.name); + user = context.getAccessControlManager().read(show_query.name); auto create_query = std::make_shared(); create_query->name = user->getName(); diff --git a/dbms/src/Interpreters/tests/users.cpp b/dbms/src/Interpreters/tests/users.cpp index 93b1f6c27f1..59be7baba68 100644 --- a/dbms/src/Interpreters/tests/users.cpp +++ b/dbms/src/Interpreters/tests/users.cpp @@ -218,7 +218,7 @@ void runOneTest(const TestDescriptor & test_descriptor) try { - res = acl_manager.getUser(entry.user_name)->access.isGranted(DB::AccessType::ALL, entry.database_name); + res = acl_manager.read(entry.user_name)->access.isGranted(DB::AccessType::ALL, entry.database_name); } catch (const Poco::Exception &) { From 543587fc4649ae4968d4f642df6875df8b25511f Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Thu, 20 Feb 2020 00:11:29 +0300 Subject: [PATCH 09/19] Remove ATTACH and DETACH access types, check CREATE and DROP access types instead of them. --- dbms/src/Access/AccessFlags.h | 12 ---------- dbms/src/Access/AccessRightsContext.cpp | 4 ++-- dbms/src/Access/AccessType.h | 24 ------------------- .../src/Interpreters/InterpreterDropQuery.cpp | 12 +++++----- 4 files changed, 8 insertions(+), 44 deletions(-) diff --git a/dbms/src/Access/AccessFlags.h b/dbms/src/Access/AccessFlags.h index e547a95bdc6..e191818ae06 100644 --- a/dbms/src/Access/AccessFlags.h +++ b/dbms/src/Access/AccessFlags.h @@ -304,15 +304,10 @@ private: ext::push_back(all, std::move(alter)); auto create_database = std::make_unique("CREATE DATABASE", next_flag++, DATABASE_LEVEL); - ext::push_back(create_database->aliases, "ATTACH DATABASE"); auto create_table = std::make_unique("CREATE TABLE", next_flag++, TABLE_LEVEL); - ext::push_back(create_table->aliases, "ATTACH TABLE"); auto create_view = std::make_unique("CREATE VIEW", next_flag++, VIEW_LEVEL); - ext::push_back(create_view->aliases, "ATTACH VIEW"); auto create_dictionary = std::make_unique("CREATE DICTIONARY", next_flag++, DICTIONARY_LEVEL); - ext::push_back(create_dictionary->aliases, "ATTACH DICTIONARY"); auto create = std::make_unique("CREATE", std::move(create_database), std::move(create_table), std::move(create_view), std::move(create_dictionary)); - ext::push_back(create->aliases, "ATTACH"); ext::push_back(all, std::move(create)); auto create_temporary_table = std::make_unique("CREATE TEMPORARY TABLE", next_flag++, GLOBAL_LEVEL); @@ -325,13 +320,6 @@ private: auto drop = std::make_unique("DROP", std::move(drop_database), std::move(drop_table), std::move(drop_view), std::move(drop_dictionary)); ext::push_back(all, std::move(drop)); - auto detach_database = std::make_unique("DETACH DATABASE", next_flag++, DATABASE_LEVEL); - auto detach_table = std::make_unique("DETACH TABLE", next_flag++, TABLE_LEVEL); - auto detach_view = std::make_unique("DETACH VIEW", next_flag++, VIEW_LEVEL); - auto detach_dictionary = std::make_unique("DETACH DICTIONARY", next_flag++, DICTIONARY_LEVEL); - auto detach = std::make_unique("DETACH", std::move(detach_database), std::move(detach_table), std::move(detach_view), std::move(detach_dictionary)); - ext::push_back(all, std::move(detach)); - auto truncate_table = std::make_unique("TRUNCATE TABLE", next_flag++, TABLE_LEVEL); auto truncate_view = std::make_unique("TRUNCATE VIEW", next_flag++, VIEW_LEVEL); auto truncate = std::make_unique("TRUNCATE", std::move(truncate_table), std::move(truncate_view)); diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp index 52404e77b2e..ca572fe299d 100644 --- a/dbms/src/Access/AccessRightsContext.cpp +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -330,8 +330,8 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW | AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW - | AccessType::DETACH_DATABASE | AccessType::DETACH_TABLE | AccessType::DETACH_VIEW | AccessType::TRUNCATE; - static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY | AccessType::DETACH_DICTIONARY; + | AccessType::TRUNCATE; + static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY; static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; diff --git a/dbms/src/Access/AccessType.h b/dbms/src/Access/AccessType.h index 5b96524997e..3d60415774d 100644 --- a/dbms/src/Access/AccessType.h +++ b/dbms/src/Access/AccessType.h @@ -66,24 +66,12 @@ enum class AccessType CREATE_TEMPORARY_TABLE, /// allows to create and manipulate temporary tables and views. CREATE, /// allows to execute {CREATE|ATTACH} [TEMPORARY] {DATABASE|TABLE|VIEW|DICTIONARY} - ATTACH_DATABASE, /// allows to execute {CREATE|ATTACH} DATABASE - ATTACH_TABLE, /// allows to execute {CREATE|ATTACH} TABLE - ATTACH_VIEW, /// allows to execute {CREATE|ATTACH} VIEW - ATTACH_DICTIONARY, /// allows to execute {CREATE|ATTACH} DICTIONARY - ATTACH, /// allows to execute {CREATE|ATTACH} {DATABASE|TABLE|VIEW|DICTIONARY} - DROP_DATABASE, DROP_TABLE, DROP_VIEW, DROP_DICTIONARY, DROP, /// allows to execute DROP {DATABASE|TABLE|VIEW|DICTIONARY} - DETACH_DATABASE, - DETACH_TABLE, - DETACH_VIEW, - DETACH_DICTIONARY, - DETACH, /// allows to execute DETACH {DATABASE|TABLE|VIEW|DICTIONARY} - TRUNCATE_TABLE, TRUNCATE_VIEW, TRUNCATE, /// allows to execute TRUNCATE {TABLE|VIEW} @@ -235,24 +223,12 @@ namespace impl ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_TEMPORARY_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE); - ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_DATABASE); - ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_TABLE); - ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_VIEW); - ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_DICTIONARY); - ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH); - ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DATABASE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_VIEW); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DICTIONARY); ACCESS_TYPE_TO_KEYWORD_CASE(DROP); - ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_DATABASE); - ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_TABLE); - ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_VIEW); - ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_DICTIONARY); - ACCESS_TYPE_TO_KEYWORD_CASE(DETACH); - ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_VIEW); ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE); diff --git a/dbms/src/Interpreters/InterpreterDropQuery.cpp b/dbms/src/Interpreters/InterpreterDropQuery.cpp index c51365ad2ba..4daa647fa9b 100644 --- a/dbms/src/Interpreters/InterpreterDropQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropQuery.cpp @@ -87,7 +87,7 @@ BlockIO InterpreterDropQuery::executeToTable( auto table_id = table->getStorageID(); if (kind == ASTDropQuery::Kind::Detach) { - context.checkAccess(table->isView() ? AccessType::DETACH_VIEW : AccessType::DETACH_TABLE, + context.checkAccess(table->isView() ? AccessType::DROP_VIEW : AccessType::DROP_TABLE, database_name, table_name); table->shutdown(); /// If table was already dropped by anyone, an exception will be thrown @@ -187,7 +187,7 @@ BlockIO InterpreterDropQuery::executeToDictionary( if (kind == ASTDropQuery::Kind::Detach) { /// Drop dictionary from memory, don't touch data and metadata - context.checkAccess(AccessType::DETACH_DICTIONARY, database_name, dictionary_name); + context.checkAccess(AccessType::DROP_DICTIONARY, database_name, dictionary_name); database->detachDictionary(dictionary_name, context); } else if (kind == ASTDropQuery::Kind::Truncate) @@ -247,7 +247,7 @@ BlockIO InterpreterDropQuery::executeToDatabase(const String & database_name, AS } else if (kind == ASTDropQuery::Kind::Detach) { - context.checkAccess(AccessType::DETACH_DATABASE, database_name); + context.checkAccess(AccessType::DROP_DATABASE, database_name); context.detachDatabase(database_name); database->shutdown(); } @@ -324,14 +324,14 @@ AccessRightsElements InterpreterDropQuery::getRequiredAccessForDDLOnCluster() co if (drop.table.empty()) { if (drop.kind == ASTDropQuery::Kind::Detach) - required_access.emplace_back(AccessType::DETACH_DATABASE, drop.database); + required_access.emplace_back(AccessType::DROP_DATABASE, drop.database); else if (drop.kind == ASTDropQuery::Kind::Drop) required_access.emplace_back(AccessType::DROP_DATABASE, drop.database); } else if (drop.is_dictionary) { if (drop.kind == ASTDropQuery::Kind::Detach) - required_access.emplace_back(AccessType::DETACH_DICTIONARY, drop.database, drop.table); + required_access.emplace_back(AccessType::DROP_DICTIONARY, drop.database, drop.table); else if (drop.kind == ASTDropQuery::Kind::Drop) required_access.emplace_back(AccessType::DROP_DICTIONARY, drop.database, drop.table); } @@ -343,7 +343,7 @@ AccessRightsElements InterpreterDropQuery::getRequiredAccessForDDLOnCluster() co else if (drop.kind == ASTDropQuery::Kind::Truncate) required_access.emplace_back(AccessType::TRUNCATE_TABLE | AccessType::TRUNCATE_VIEW, drop.database, drop.table); else if (drop.kind == ASTDropQuery::Kind::Detach) - required_access.emplace_back(AccessType::DETACH_TABLE | AccessType::DETACH_VIEW, drop.database, drop.table); + required_access.emplace_back(AccessType::DROP_TABLE | AccessType::DROP_VIEW, drop.database, drop.table); } return required_access; From fc8aa5efe76a02688cc09c33808eeb5898d0d061 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Thu, 20 Feb 2020 00:48:59 +0300 Subject: [PATCH 10/19] Separate access-controlling access types. --- dbms/src/Access/AccessFlags.h | 13 +++++++++++-- dbms/src/Access/AccessType.h | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/dbms/src/Access/AccessFlags.h b/dbms/src/Access/AccessFlags.h index e191818ae06..1b1934a3f01 100644 --- a/dbms/src/Access/AccessFlags.h +++ b/dbms/src/Access/AccessFlags.h @@ -335,8 +335,17 @@ private: ext::push_back(all, std::move(kill)); auto create_user = std::make_unique("CREATE USER", next_flag++, GLOBAL_LEVEL); - ext::push_back(create_user->aliases, "ALTER USER", "DROP USER", "CREATE ROLE", "DROP ROLE", "CREATE POLICY", "ALTER POLICY", "DROP POLICY", "CREATE QUOTA", "ALTER QUOTA", "DROP QUOTA"); - ext::push_back(all, std::move(create_user)); + auto alter_user = std::make_unique("ALTER USER", next_flag++, GLOBAL_LEVEL); + auto drop_user = std::make_unique("DROP USER", next_flag++, GLOBAL_LEVEL); + auto create_role = std::make_unique("CREATE ROLE", next_flag++, GLOBAL_LEVEL); + auto drop_role = std::make_unique("DROP ROLE", next_flag++, GLOBAL_LEVEL); + auto create_policy = std::make_unique("CREATE POLICY", next_flag++, GLOBAL_LEVEL); + auto alter_policy = std::make_unique("ALTER POLICY", next_flag++, GLOBAL_LEVEL); + auto drop_policy = std::make_unique("DROP POLICY", next_flag++, GLOBAL_LEVEL); + auto create_quota = std::make_unique("CREATE QUOTA", next_flag++, GLOBAL_LEVEL); + auto alter_quota = std::make_unique("ALTER QUOTA", next_flag++, GLOBAL_LEVEL); + auto drop_quota = std::make_unique("DROP QUOTA", next_flag++, GLOBAL_LEVEL); + ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota)); auto shutdown = std::make_unique("SHUTDOWN", next_flag++, GLOBAL_LEVEL); ext::push_back(shutdown->aliases, "SYSTEM SHUTDOWN", "SYSTEM KILL"); diff --git a/dbms/src/Access/AccessType.h b/dbms/src/Access/AccessType.h index 3d60415774d..1c829f57f63 100644 --- a/dbms/src/Access/AccessType.h +++ b/dbms/src/Access/AccessType.h @@ -82,7 +82,7 @@ enum class AccessType KILL_MUTATION, /// allows to kill a mutation KILL, /// allows to execute KILL {MUTATION|QUERY} - CREATE_USER, /// allows to create, alter and drop users, roles, quotas, row policies. + CREATE_USER, ALTER_USER, DROP_USER, CREATE_ROLE, From ae18d443c87a5de5a2b55e5a8ab6276293b852e2 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Fri, 21 Feb 2020 03:17:07 +0300 Subject: [PATCH 11/19] Introduce roles. --- dbms/src/Access/AccessControlManager.cpp | 28 ++- dbms/src/Access/AccessControlManager.h | 27 ++- dbms/src/Access/AccessRightsContext.cpp | 150 ++++++++++++- dbms/src/Access/AccessRightsContext.h | 29 ++- .../src/Access/AccessRightsContextFactory.cpp | 4 + dbms/src/Access/AccessRightsContextFactory.h | 2 +- dbms/src/Access/AllowedClientHosts.h | 2 +- dbms/src/Access/CurrentRolesInfo.cpp | 34 +++ dbms/src/Access/CurrentRolesInfo.h | 31 +++ dbms/src/Access/GenericRoleSet.cpp | 88 +++++++- dbms/src/Access/GenericRoleSet.h | 4 + dbms/src/Access/IAccessStorage.cpp | 4 + dbms/src/Access/QuotaContext.cpp | 5 +- dbms/src/Access/QuotaContext.h | 5 +- dbms/src/Access/QuotaContextFactory.cpp | 6 +- dbms/src/Access/QuotaContextFactory.h | 2 +- dbms/src/Access/Role.cpp | 16 ++ dbms/src/Access/Role.h | 24 +++ dbms/src/Access/RoleContext.cpp | 200 ++++++++++++++++++ dbms/src/Access/RoleContext.h | 64 ++++++ dbms/src/Access/RoleContextFactory.cpp | 52 +++++ dbms/src/Access/RoleContextFactory.h | 29 +++ dbms/src/Access/RowPolicyContext.cpp | 4 +- dbms/src/Access/RowPolicyContext.h | 3 +- dbms/src/Access/RowPolicyContextFactory.cpp | 6 +- dbms/src/Access/RowPolicyContextFactory.h | 2 +- dbms/src/Access/User.cpp | 3 +- dbms/src/Access/User.h | 9 +- dbms/src/Common/ErrorCodes.cpp | 1 + dbms/src/Interpreters/Context.cpp | 49 ++++- dbms/src/Interpreters/Context.h | 10 +- ...InterpreterShowCreateAccessEntityQuery.cpp | 2 +- 32 files changed, 846 insertions(+), 49 deletions(-) create mode 100644 dbms/src/Access/CurrentRolesInfo.cpp create mode 100644 dbms/src/Access/CurrentRolesInfo.h create mode 100644 dbms/src/Access/Role.cpp create mode 100644 dbms/src/Access/Role.h create mode 100644 dbms/src/Access/RoleContext.cpp create mode 100644 dbms/src/Access/RoleContext.h create mode 100644 dbms/src/Access/RoleContextFactory.cpp create mode 100644 dbms/src/Access/RoleContextFactory.h diff --git a/dbms/src/Access/AccessControlManager.cpp b/dbms/src/Access/AccessControlManager.cpp index 4fc002a764b..5c1806a535b 100644 --- a/dbms/src/Access/AccessControlManager.cpp +++ b/dbms/src/Access/AccessControlManager.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -24,6 +25,7 @@ namespace AccessControlManager::AccessControlManager() : MultipleAccessStorage(createStorages()), access_rights_context_factory(std::make_unique(*this)), + role_context_factory(std::make_unique(*this)), row_policy_context_factory(std::make_unique(*this)), quota_context_factory(std::make_unique(*this)) { @@ -43,24 +45,38 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio AccessRightsContextPtr AccessControlManager::getAccessRightsContext( - const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info) const + const UUID & user_id, + const std::vector & current_roles, + bool use_default_roles, + const Settings & settings, + const String & current_database, + const ClientInfo & client_info) const { - return access_rights_context_factory->createContext(user_id, settings, current_database, client_info); + return access_rights_context_factory->createContext(user_id, current_roles, use_default_roles, settings, current_database, client_info); } -RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const UUID & user_id) const +RoleContextPtr AccessControlManager::getRoleContext( + const std::vector & current_roles, + const std::vector & current_roles_with_admin_option) const { - return row_policy_context_factory->createContext(user_id); + return role_context_factory->createContext(current_roles, current_roles_with_admin_option); +} + + +RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const UUID & user_id, const std::vector & enabled_roles) const +{ + return row_policy_context_factory->createContext(user_id, enabled_roles); } QuotaContextPtr AccessControlManager::getQuotaContext( - const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const + const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & custom_quota_key) const { - return quota_context_factory->createContext(user_id, user_name, address, custom_quota_key); + return quota_context_factory->createContext(user_name, user_id, enabled_roles, address, custom_quota_key); } + std::vector AccessControlManager::getQuotaUsageInfo() const { return quota_context_factory->getUsageInfo(); diff --git a/dbms/src/Access/AccessControlManager.h b/dbms/src/Access/AccessControlManager.h index 4549f1afde8..1eb71547a69 100644 --- a/dbms/src/Access/AccessControlManager.h +++ b/dbms/src/Access/AccessControlManager.h @@ -22,6 +22,11 @@ namespace DB class AccessRightsContext; using AccessRightsContextPtr = std::shared_ptr; class AccessRightsContextFactory; +struct User; +using UserPtr = std::shared_ptr; +struct RoleContext; +using RoleContextPtr = std::shared_ptr; +class RoleContextFactory; class RowPolicyContext; using RowPolicyContextPtr = std::shared_ptr; class RowPolicyContextFactory; @@ -43,17 +48,33 @@ public: void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config); AccessRightsContextPtr getAccessRightsContext( - const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info) const; + const UUID & user_id, + const std::vector & current_roles, + bool use_default_roles, + const Settings & settings, + const String & current_database, + const ClientInfo & client_info) const; - RowPolicyContextPtr getRowPolicyContext(const UUID & user_id) const; + RoleContextPtr getRoleContext( + const std::vector & current_roles, + const std::vector & current_roles_with_admin_option) const; + + RowPolicyContextPtr getRowPolicyContext( + const UUID & user_id, + const std::vector & enabled_roles) const; QuotaContextPtr getQuotaContext( - const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const; + const String & user_name, + const UUID & user_id, + const std::vector & enabled_roles, + const Poco::Net::IPAddress & address, + const String & custom_quota_key) const; std::vector getQuotaUsageInfo() const; private: std::unique_ptr access_rights_context_factory; + std::unique_ptr role_context_factory; std::unique_ptr row_policy_context_factory; std::unique_ptr quota_context_factory; }; diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp index ca572fe299d..39ee78e4601 100644 --- a/dbms/src/Access/AccessRightsContext.cpp +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -1,15 +1,20 @@ #include #include +#include #include #include #include +#include #include #include #include +#include #include #include +#include #include #include +#include #include @@ -91,6 +96,9 @@ AccessRightsContext::AccessRightsContext() auto everything_granted = boost::make_shared(); everything_granted->grant(AccessType::ALL); boost::range::fill(result_access_cache, everything_granted); + + enabled_roles_with_admin_option = boost::make_shared>(); + row_policy_context = std::make_shared(); quota_context = std::make_shared(); } @@ -121,6 +129,9 @@ void AccessRightsContext::setUser(const UserPtr & user_) const auto nothing_granted = boost::make_shared(); boost::range::fill(result_access_cache, nothing_granted); subscription_for_user_change = {}; + subscription_for_roles_info_change = {}; + role_context = nullptr; + enabled_roles_with_admin_option = boost::make_shared>(); row_policy_context = std::make_shared(); quota_context = std::make_shared(); return; @@ -128,9 +139,50 @@ void AccessRightsContext::setUser(const UserPtr & user_) const user_name = user->getName(); trace_log = &Poco::Logger::get("AccessRightsContext (" + user_name + ")"); + + std::vector current_roles, current_roles_with_admin_option; + if (params.use_default_roles) + { + for (const UUID & id : user->granted_roles) + { + if (user->default_roles.match(id)) + current_roles.push_back(id); + } + boost::range::set_intersection(current_roles, user->granted_roles_with_admin_option, + std::back_inserter(current_roles_with_admin_option)); + } + else + { + current_roles.reserve(params.current_roles.size()); + for (const auto & id : params.current_roles) + { + if (user->granted_roles.contains(id)) + current_roles.push_back(id); + if (user->granted_roles_with_admin_option.contains(id)) + current_roles_with_admin_option.push_back(id); + } + } + + subscription_for_roles_info_change = {}; + role_context = manager->getRoleContext(current_roles, current_roles_with_admin_option); + subscription_for_roles_info_change = role_context->subscribeForChanges([this](const CurrentRolesInfoPtr & roles_info_) + { + std::lock_guard lock{mutex}; + setRolesInfo(roles_info_); + }); + + setRolesInfo(role_context->getInfo()); +} + + +void AccessRightsContext::setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const +{ + assert(roles_info_); + roles_info = roles_info_; + enabled_roles_with_admin_option.store(nullptr /* need to recalculate */); boost::range::fill(result_access_cache, nullptr /* need recalculate */); - row_policy_context = manager->getRowPolicyContext(*params.user_id); - quota_context = manager->getQuotaContext(*params.user_id, user_name, params.address, params.quota_key); + row_policy_context = manager->getRowPolicyContext(*params.user_id, roles_info->enabled_roles); + quota_context = manager->getQuotaContext(user_name, *params.user_id, roles_info->enabled_roles, params.address, params.quota_key); } @@ -301,6 +353,36 @@ void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) c void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkAccessImpl(nullptr, access); } +void AccessRightsContext::checkAdminOption(const UUID & role_id) const +{ + boost::shared_ptr> enabled_roles = enabled_roles_with_admin_option.load(); + if (!enabled_roles) + { + std::lock_guard lock{mutex}; + enabled_roles = enabled_roles_with_admin_option.load(); + if (!enabled_roles) + { + if (roles_info) + enabled_roles = boost::make_shared>(roles_info->enabled_roles_with_admin_option.begin(), roles_info->enabled_roles_with_admin_option.end()); + else + enabled_roles = boost::make_shared>(); + enabled_roles_with_admin_option.store(enabled_roles); + } + } + + if (enabled_roles->contains(role_id)) + return; + + std::optional role_name = manager->readName(role_id); + if (!role_name) + role_name = "ID {" + toString(role_id) + "}"; + throw Exception( + getUserName() + ": Not enough privileges. To execute this query it's necessary to have the grant " + backQuoteIfNeed(*role_name) + + " WITH ADMIN OPTION ", + ErrorCodes::ACCESS_DENIED); +} + + boost::shared_ptr AccessRightsContext::calculateResultAccess(bool grant_option) const { return calculateResultAccess(grant_option, params.readonly, params.allow_ddl, params.allow_introspection); @@ -326,7 +408,18 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess auto result_ptr = boost::make_shared(); auto & result = *result_ptr; - result = grant_option ? user->access_with_grant_option : user->access; + if (grant_option) + { + result = user->access_with_grant_option; + if (roles_info) + result.merge(roles_info->access_with_grant_option); + } + else + { + result = user->access; + if (roles_info) + result.merge(roles_info->access); + } static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW | AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW @@ -334,12 +427,16 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY; static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; + static const AccessFlags all_dcl = AccessType::CREATE_USER | AccessType::CREATE_ROLE | AccessType::CREATE_POLICY + | AccessType::CREATE_QUOTA | AccessType::ALTER_USER | AccessType::ALTER_POLICY | AccessType::ALTER_QUOTA | AccessType::DROP_USER + | AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA; /// Anyone has access to the "system" database. - result.grant(AccessType::SELECT, "system"); + if (!result.isGranted(AccessType::SELECT, "system")) + result.grant(AccessType::SELECT, "system"); if (readonly_) - result.fullRevoke(write_table_access | AccessType::SYSTEM); + result.fullRevoke(write_table_access | all_dcl | AccessType::SYSTEM | AccessType::KILL); if (readonly_ || !allow_ddl_) result.fullRevoke(table_and_dictionary_ddl); @@ -360,7 +457,16 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess result_access_cache[cache_index].store(result_ptr); if (trace_log && (params.readonly == readonly_) && (params.allow_ddl == allow_ddl_) && (params.allow_introspection == allow_introspection_)) + { LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : "")); + if (roles_info && !roles_info->getCurrentRolesNames().empty()) + { + LOG_TRACE( + trace_log, + "Current_roles: " << boost::algorithm::join(roles_info->getCurrentRolesNames(), ", ") + << ", enabled_roles: " << boost::algorithm::join(roles_info->getEnabledRolesNames(), ", ")); + } + } return result_ptr; } @@ -378,6 +484,36 @@ String AccessRightsContext::getUserName() const return user_name; } +CurrentRolesInfoPtr AccessRightsContext::getRolesInfo() const +{ + std::lock_guard lock{mutex}; + return roles_info; +} + +std::vector AccessRightsContext::getCurrentRoles() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->current_roles : std::vector{}; +} + +Strings AccessRightsContext::getCurrentRolesNames() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->getCurrentRolesNames() : Strings{}; +} + +std::vector AccessRightsContext::getEnabledRoles() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->enabled_roles : std::vector{}; +} + +Strings AccessRightsContext::getEnabledRolesNames() const +{ + std::lock_guard lock{mutex}; + return roles_info ? roles_info->getEnabledRolesNames() : Strings{}; +} + RowPolicyContextPtr AccessRightsContext::getRowPolicy() const { std::lock_guard lock{mutex}; @@ -399,6 +535,8 @@ bool operator <(const AccessRightsContext::Params & lhs, const AccessRightsConte if (lhs.field > rhs.field) \ return false ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); @@ -418,6 +556,8 @@ bool operator ==(const AccessRightsContext::Params & lhs, const AccessRightsCont if (lhs.field != rhs.field) \ return false ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles); + ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key); ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database); diff --git a/dbms/src/Access/AccessRightsContext.h b/dbms/src/Access/AccessRightsContext.h index 68a53351098..6d53568d8a9 100644 --- a/dbms/src/Access/AccessRightsContext.h +++ b/dbms/src/Access/AccessRightsContext.h @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -15,6 +16,10 @@ namespace DB { struct User; using UserPtr = std::shared_ptr; +struct CurrentRolesInfo; +using CurrentRolesInfoPtr = std::shared_ptr; +class RoleContext; +using RoleContextPtr = std::shared_ptr; struct RowPolicyContext; using RowPolicyContextPtr = std::shared_ptr; struct QuotaContext; @@ -29,6 +34,8 @@ public: struct Params { std::optional user_id; + std::vector current_roles; + bool use_default_roles = false; UInt64 readonly = 0; bool allow_ddl = false; bool allow_introspection = false; @@ -56,10 +63,16 @@ public: void checkPassword(const String & password) const; void checkHostIsAllowed() const; + CurrentRolesInfoPtr getRolesInfo() const; + std::vector getCurrentRoles() const; + Strings getCurrentRolesNames() const; + std::vector getEnabledRoles() const; + Strings getEnabledRolesNames() const; + RowPolicyContextPtr getRowPolicy() const; QuotaContextPtr getQuota() const; - /// Checks if a specified access granted, and throws an exception if not. + /// Checks if a specified access is granted, and throws an exception if not. /// Empty database means the current database. void checkAccess(const AccessFlags & access) const; void checkAccess(const AccessFlags & access, const std::string_view & database) const; @@ -70,7 +83,7 @@ public: void checkAccess(const AccessRightsElement & access) const; void checkAccess(const AccessRightsElements & access) const; - /// Checks if a specified access granted. + /// Checks if a specified access is granted. bool isGranted(const AccessFlags & access) const; bool isGranted(const AccessFlags & access, const std::string_view & database) const; bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; @@ -80,7 +93,7 @@ public: bool isGranted(const AccessRightsElement & access) const; bool isGranted(const AccessRightsElements & access) const; - /// Checks if a specified access granted, and logs a warning if not. + /// Checks if a specified access is granted, and logs a warning if not. bool isGranted(Poco::Logger * log_, const AccessFlags & access) const; bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const; bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; @@ -90,7 +103,7 @@ public: bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const; bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const; - /// Checks if a specified access granted with grant option, and throws an exception if not. + /// Checks if a specified access is granted with grant option, and throws an exception if not. void checkGrantOption(const AccessFlags & access) const; void checkGrantOption(const AccessFlags & access, const std::string_view & database) const; void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; @@ -100,12 +113,16 @@ public: void checkGrantOption(const AccessRightsElement & access) const; void checkGrantOption(const AccessRightsElements & access) const; + /// Checks if a specified role is granted with admin option, and throws an exception if not. + void checkAdminOption(const UUID & role_id) const; + private: friend class AccessRightsContextFactory; friend struct ext::shared_ptr_helper; AccessRightsContext(const AccessControlManager & manager_, const Params & params_); /// AccessRightsContext should be created by AccessRightsContextFactory. void setUser(const UserPtr & user_) const; + void setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const; template bool checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; @@ -125,6 +142,10 @@ private: mutable UserPtr user; mutable String user_name; mutable ext::scope_guard subscription_for_user_change; + mutable RoleContextPtr role_context; + mutable ext::scope_guard subscription_for_roles_info_change; + mutable CurrentRolesInfoPtr roles_info; + mutable boost::atomic_shared_ptr> enabled_roles_with_admin_option; mutable boost::atomic_shared_ptr result_access_cache[7]; mutable RowPolicyContextPtr row_policy_context; mutable QuotaContextPtr quota_context; diff --git a/dbms/src/Access/AccessRightsContextFactory.cpp b/dbms/src/Access/AccessRightsContextFactory.cpp index 21cd3657170..8d542a5f439 100644 --- a/dbms/src/Access/AccessRightsContextFactory.cpp +++ b/dbms/src/Access/AccessRightsContextFactory.cpp @@ -24,12 +24,16 @@ AccessRightsContextPtr AccessRightsContextFactory::createContext(const Params & AccessRightsContextPtr AccessRightsContextFactory::createContext( const UUID & user_id, + const std::vector & current_roles, + bool use_default_roles, const Settings & settings, const String & current_database, const ClientInfo & client_info) { Params params; params.user_id = user_id; + params.current_roles = current_roles; + params.use_default_roles = use_default_roles; params.current_database = current_database; params.readonly = settings.readonly; params.allow_ddl = settings.allow_ddl; diff --git a/dbms/src/Access/AccessRightsContextFactory.h b/dbms/src/Access/AccessRightsContextFactory.h index 9f61c1099c5..c480307757a 100644 --- a/dbms/src/Access/AccessRightsContextFactory.h +++ b/dbms/src/Access/AccessRightsContextFactory.h @@ -18,7 +18,7 @@ public: using Params = AccessRightsContext::Params; AccessRightsContextPtr createContext(const Params & params); - AccessRightsContextPtr createContext(const UUID & user_id, const Settings & settings, const String & current_database, const ClientInfo & client_info); + AccessRightsContextPtr createContext(const UUID & user_id, const std::vector & current_roles, bool use_default_roles, const Settings & settings, const String & current_database, const ClientInfo & client_info); private: const AccessControlManager & manager; diff --git a/dbms/src/Access/AllowedClientHosts.h b/dbms/src/Access/AllowedClientHosts.h index c5e0b71156b..bc075adc6e1 100644 --- a/dbms/src/Access/AllowedClientHosts.h +++ b/dbms/src/Access/AllowedClientHosts.h @@ -46,7 +46,7 @@ public: struct AnyHostTag {}; AllowedClientHosts() {} - explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); } + AllowedClientHosts(AnyHostTag) { addAnyHost(); } ~AllowedClientHosts() {} AllowedClientHosts(const AllowedClientHosts & src) = default; diff --git a/dbms/src/Access/CurrentRolesInfo.cpp b/dbms/src/Access/CurrentRolesInfo.cpp new file mode 100644 index 00000000000..f4cbd739021 --- /dev/null +++ b/dbms/src/Access/CurrentRolesInfo.cpp @@ -0,0 +1,34 @@ +#include + + +namespace DB +{ + +Strings CurrentRolesInfo::getCurrentRolesNames() const +{ + Strings result; + result.reserve(current_roles.size()); + for (const auto & id : current_roles) + result.emplace_back(names_of_roles.at(id)); + return result; +} + + +Strings CurrentRolesInfo::getEnabledRolesNames() const +{ + Strings result; + result.reserve(enabled_roles.size()); + for (const auto & id : enabled_roles) + result.emplace_back(names_of_roles.at(id)); + return result; +} + + +bool operator==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs) +{ + return (lhs.current_roles == rhs.current_roles) && (lhs.enabled_roles == rhs.enabled_roles) + && (lhs.enabled_roles_with_admin_option == rhs.enabled_roles_with_admin_option) && (lhs.names_of_roles == rhs.names_of_roles) + && (lhs.access == rhs.access) && (lhs.access_with_grant_option == rhs.access_with_grant_option); +} + +} diff --git a/dbms/src/Access/CurrentRolesInfo.h b/dbms/src/Access/CurrentRolesInfo.h new file mode 100644 index 00000000000..a4dd26be0f7 --- /dev/null +++ b/dbms/src/Access/CurrentRolesInfo.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ + +/// Information about a role. +struct CurrentRolesInfo +{ + std::vector current_roles; + std::vector enabled_roles; + std::vector enabled_roles_with_admin_option; + std::unordered_map names_of_roles; + AccessRights access; + AccessRights access_with_grant_option; + + Strings getCurrentRolesNames() const; + Strings getEnabledRolesNames() const; + + friend bool operator ==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs); + friend bool operator !=(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs) { return !(lhs == rhs); } +}; + +using CurrentRolesInfoPtr = std::shared_ptr; + +} diff --git a/dbms/src/Access/GenericRoleSet.cpp b/dbms/src/Access/GenericRoleSet.cpp index ff142a36a97..56021abea82 100644 --- a/dbms/src/Access/GenericRoleSet.cpp +++ b/dbms/src/Access/GenericRoleSet.cpp @@ -1,10 +1,12 @@ #include #include #include +#include #include #include #include #include +#include namespace DB @@ -48,8 +50,10 @@ GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessContro ids.reserve(ast.names.size()); for (const String & name : ast.names) { - auto id = manager.getID(name); - ids.insert(id); + auto id = manager.find(name); + if (!id) + id = manager.getID(name); + ids.insert(*id); } } @@ -65,8 +69,10 @@ GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessContro except_ids.reserve(ast.except_names.size()); for (const String & except_name : ast.except_names) { - auto except_id = manager.getID(except_name); - except_ids.insert(except_id); + auto except_id = manager.find(except_name); + if (!except_id) + except_id = manager.getID(except_name); + except_ids.insert(*except_id); } } @@ -173,9 +179,53 @@ void GenericRoleSet::add(const boost::container::flat_set & ids_) } -bool GenericRoleSet::match(const UUID & user_id) const +bool GenericRoleSet::match(const UUID & id) const { - return (all || ids.contains(user_id)) && !except_ids.contains(user_id); + return (all || ids.contains(id)) && !except_ids.contains(id); +} + + +bool GenericRoleSet::match(const UUID & user_id, const std::vector & enabled_roles) const +{ + if (!all && !ids.contains(user_id)) + { + bool found_enabled_role = std::any_of( + enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return ids.contains(enabled_role); }); + if (!found_enabled_role) + return false; + } + + if (except_ids.contains(user_id)) + return false; + + bool in_except_list = std::any_of( + enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return except_ids.contains(enabled_role); }); + if (in_except_list) + return false; + + return true; +} + + +bool GenericRoleSet::match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const +{ + if (!all && !ids.contains(user_id)) + { + bool found_enabled_role = std::any_of( + enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return ids.contains(enabled_role); }); + if (!found_enabled_role) + return false; + } + + if (except_ids.contains(user_id)) + return false; + + bool in_except_list = std::any_of( + enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return except_ids.contains(enabled_role); }); + if (in_except_list) + return false; + + return true; } @@ -204,6 +254,32 @@ std::vector GenericRoleSet::getMatchingUsers(const AccessControlManager & } +std::vector GenericRoleSet::getMatchingRoles(const AccessControlManager & manager) const +{ + if (!all) + return getMatchingIDs(); + + std::vector res; + for (const UUID & id : manager.findAll()) + { + if (match(id)) + res.push_back(id); + } + return res; +} + + +std::vector GenericRoleSet::getMatchingUsersAndRoles(const AccessControlManager & manager) const +{ + if (!all) + return getMatchingIDs(); + + std::vector vec = getMatchingUsers(manager); + boost::range::push_back(vec, getMatchingRoles(manager)); + return vec; +} + + bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs) { return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids); diff --git a/dbms/src/Access/GenericRoleSet.h b/dbms/src/Access/GenericRoleSet.h index b3f39a05bd4..2caee348813 100644 --- a/dbms/src/Access/GenericRoleSet.h +++ b/dbms/src/Access/GenericRoleSet.h @@ -44,12 +44,16 @@ struct GenericRoleSet /// Checks if a specified ID matches this GenericRoleSet. bool match(const UUID & id) const; + bool match(const UUID & user_id, const std::vector & enabled_roles) const; + bool match(const UUID & user_id, const boost::container::flat_set & enabled_roles) const; /// Returns a list of matching IDs. The function must not be called if `all` == `true`. std::vector getMatchingIDs() const; /// Returns a list of matching users. std::vector getMatchingUsers(const AccessControlManager & manager) const; + std::vector getMatchingRoles(const AccessControlManager & manager) const; + std::vector getMatchingUsersAndRoles(const AccessControlManager & manager) const; friend bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs); friend bool operator !=(const GenericRoleSet & lhs, const GenericRoleSet & rhs) { return !(lhs == rhs); } diff --git a/dbms/src/Access/IAccessStorage.cpp b/dbms/src/Access/IAccessStorage.cpp index 4ff8ed14d86..1c6a79b2fb2 100644 --- a/dbms/src/Access/IAccessStorage.cpp +++ b/dbms/src/Access/IAccessStorage.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -17,6 +18,7 @@ namespace ErrorCodes extern const int ACCESS_ENTITY_FOUND_DUPLICATES; extern const int ACCESS_ENTITY_STORAGE_READONLY; extern const int UNKNOWN_USER; + extern const int UNKNOWN_ROLE; } @@ -370,6 +372,8 @@ void IAccessStorage::throwNotFound(std::type_index type, const String & name) co int error_code; if (type == typeid(User)) error_code = ErrorCodes::UNKNOWN_USER; + else if (type == typeid(Role)) + error_code = ErrorCodes::UNKNOWN_ROLE; else error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND; diff --git a/dbms/src/Access/QuotaContext.cpp b/dbms/src/Access/QuotaContext.cpp index 775e3a46cbe..815d9440eaa 100644 --- a/dbms/src/Access/QuotaContext.cpp +++ b/dbms/src/Access/QuotaContext.cpp @@ -178,11 +178,12 @@ QuotaContext::QuotaContext() QuotaContext::QuotaContext( - const UUID & user_id_, const String & user_name_, + const UUID & user_id_, + const std::vector & enabled_roles_, const Poco::Net::IPAddress & address_, const String & client_key_) - : user_id(user_id_), user_name(user_name_), address(address_), client_key(client_key_) + : user_name(user_name_), user_id(user_id_), enabled_roles(enabled_roles_), address(address_), client_key(client_key_) { } diff --git a/dbms/src/Access/QuotaContext.h b/dbms/src/Access/QuotaContext.h index be3a36ae3eb..d788a08ea17 100644 --- a/dbms/src/Access/QuotaContext.h +++ b/dbms/src/Access/QuotaContext.h @@ -47,7 +47,7 @@ private: friend struct ext::shared_ptr_helper; /// Instances of this class are created by QuotaContextFactory. - QuotaContext(const UUID & user_id_, const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_); + QuotaContext(const String & user_name_, const UUID & user_id_, const std::vector & enabled_roles_, const Poco::Net::IPAddress & address_, const String & client_key_); static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; @@ -76,8 +76,9 @@ private: struct Impl; - const UUID user_id; const String user_name; + const UUID user_id; + const std::vector enabled_roles; const Poco::Net::IPAddress address; const String client_key; boost::atomic_shared_ptr intervals; /// atomically changed by QuotaUsageManager diff --git a/dbms/src/Access/QuotaContextFactory.cpp b/dbms/src/Access/QuotaContextFactory.cpp index 2e828b148ae..f986ee86c01 100644 --- a/dbms/src/Access/QuotaContextFactory.cpp +++ b/dbms/src/Access/QuotaContextFactory.cpp @@ -42,7 +42,7 @@ void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUI bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const { - return roles->match(context.user_id); + return roles->match(context.user_id, context.enabled_roles); } @@ -175,11 +175,11 @@ QuotaContextFactory::~QuotaContextFactory() } -QuotaContextPtr QuotaContextFactory::createContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) +QuotaContextPtr QuotaContextFactory::createContext(const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key) { std::lock_guard lock{mutex}; ensureAllQuotasRead(); - auto context = ext::shared_ptr_helper::create(user_id, user_name, address, client_key); + auto context = ext::shared_ptr_helper::create(user_name, user_id, enabled_roles, address, client_key); contexts.push_back(context); chooseQuotaForContext(context); return context; diff --git a/dbms/src/Access/QuotaContextFactory.h b/dbms/src/Access/QuotaContextFactory.h index 6d9fdede833..c130da4f2cd 100644 --- a/dbms/src/Access/QuotaContextFactory.h +++ b/dbms/src/Access/QuotaContextFactory.h @@ -20,7 +20,7 @@ public: QuotaContextFactory(const AccessControlManager & access_control_manager_); ~QuotaContextFactory(); - QuotaContextPtr createContext(const UUID & user_id, const String & user_name, const Poco::Net::IPAddress & address, const String & client_key); + QuotaContextPtr createContext(const String & user_name, const UUID & user_id, const std::vector & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key); std::vector getUsageInfo() const; private: diff --git a/dbms/src/Access/Role.cpp b/dbms/src/Access/Role.cpp new file mode 100644 index 00000000000..7b1a395feec --- /dev/null +++ b/dbms/src/Access/Role.cpp @@ -0,0 +1,16 @@ +#include + + +namespace DB +{ + +bool Role::equal(const IAccessEntity & other) const +{ + if (!IAccessEntity::equal(other)) + return false; + const auto & other_role = typeid_cast(other); + return (access == other_role.access) && (access_with_grant_option == other_role.access_with_grant_option) + && (granted_roles == other_role.granted_roles) && (granted_roles_with_admin_option == other_role.granted_roles_with_admin_option); +} + +} diff --git a/dbms/src/Access/Role.h b/dbms/src/Access/Role.h new file mode 100644 index 00000000000..eaeb8debd3a --- /dev/null +++ b/dbms/src/Access/Role.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +#include +#include + + +namespace DB +{ + +struct Role : public IAccessEntity +{ + AccessRights access; + AccessRights access_with_grant_option; + boost::container::flat_set granted_roles; + boost::container::flat_set granted_roles_with_admin_option; + + bool equal(const IAccessEntity & other) const override; + std::shared_ptr clone() const override { return cloneImpl(); } +}; + +using RolePtr = std::shared_ptr; +} diff --git a/dbms/src/Access/RoleContext.cpp b/dbms/src/Access/RoleContext.cpp new file mode 100644 index 00000000000..291b44027d4 --- /dev/null +++ b/dbms/src/Access/RoleContext.cpp @@ -0,0 +1,200 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + void makeUnique(std::vector & vec) + { + boost::range::sort(vec); + vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); + } +} + + +RoleContext::RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_) + : manager(&manager_), current_role(current_role_), with_admin_option(with_admin_option_) +{ + update(); +} + + +RoleContext::RoleContext(std::vector && children_) + : children(std::move(children_)) +{ + update(); +} + + +RoleContext::~RoleContext() = default; + + +void RoleContext::update() +{ + std::vector handlers_to_notify; + CurrentRolesInfoPtr info_to_notify; + + { + std::lock_guard lock{mutex}; + auto old_info = info; + + updateImpl(); + + if (!handlers.empty() && (!old_info || (*old_info != *info))) + { + boost::range::copy(handlers, std::back_inserter(handlers_to_notify)); + info_to_notify = info; + } + } + + for (const auto & handler : handlers_to_notify) + handler(info_to_notify); +} + + +void RoleContext::updateImpl() +{ + if (!current_role && children.empty()) + { + info = std::make_shared(); + return; + } + + if (!children.empty()) + { + if (subscriptions_for_change_children.empty()) + { + for (const auto & child : children) + subscriptions_for_change_children.emplace_back( + child->subscribeForChanges([this](const CurrentRolesInfoPtr &) { update(); })); + } + + auto new_info = std::make_shared(); + auto & new_info_ref = *new_info; + + for (const auto & child : children) + { + auto child_info = child->getInfo(); + new_info_ref.access.merge(child_info->access); + new_info_ref.access_with_grant_option.merge(child_info->access_with_grant_option); + boost::range::copy(child_info->current_roles, std::back_inserter(new_info_ref.current_roles)); + boost::range::copy(child_info->enabled_roles, std::back_inserter(new_info_ref.enabled_roles)); + boost::range::copy(child_info->enabled_roles_with_admin_option, std::back_inserter(new_info_ref.enabled_roles_with_admin_option)); + boost::range::copy(child_info->names_of_roles, std::inserter(new_info_ref.names_of_roles, new_info_ref.names_of_roles.end())); + } + makeUnique(new_info_ref.current_roles); + makeUnique(new_info_ref.enabled_roles); + makeUnique(new_info_ref.enabled_roles_with_admin_option); + info = new_info; + return; + } + + assert(current_role); + traverseRoles(*current_role, with_admin_option); + + auto new_info = std::make_shared(); + auto & new_info_ref = *new_info; + + for (auto it = roles_map.begin(); it != roles_map.end();) + { + const auto & id = it->first; + auto & entry = it->second; + if (!entry.in_use) + { + it = roles_map.erase(it); + continue; + } + + if (id == *current_role) + new_info_ref.current_roles.push_back(id); + + new_info_ref.enabled_roles.push_back(id); + + if (entry.with_admin_option) + new_info_ref.enabled_roles_with_admin_option.push_back(id); + + new_info_ref.access.merge(entry.role->access); + new_info_ref.access_with_grant_option.merge(entry.role->access_with_grant_option); + new_info_ref.names_of_roles[id] = entry.role->getName(); + + entry.in_use = false; + entry.with_admin_option = false; + ++it; + } + + info = new_info; +} + + +void RoleContext::traverseRoles(const UUID & id_, bool with_admin_option_) +{ + auto it = roles_map.find(id_); + if (it == roles_map.end()) + { + assert(manager); + auto subscription = manager->subscribeForChanges(id_, [this, id_](const UUID &, const AccessEntityPtr & entity) + { + { + std::lock_guard lock{mutex}; + auto it2 = roles_map.find(id_); + if (it2 == roles_map.end()) + return; + if (entity) + it2->second.role = typeid_cast(entity); + else + roles_map.erase(it2); + } + update(); + }); + + auto role = manager->tryRead(id_); + if (!role) + return; + + RoleEntry new_entry; + new_entry.role = role; + new_entry.subscription_for_change_role = std::move(subscription); + it = roles_map.emplace(id_, std::move(new_entry)).first; + } + + RoleEntry & entry = it->second; + entry.with_admin_option |= with_admin_option_; + if (entry.in_use) + return; + + entry.in_use = true; + for (const auto & granted_role : entry.role->granted_roles) + traverseRoles(granted_role, false); + + for (const auto & granted_role : entry.role->granted_roles_with_admin_option) + traverseRoles(granted_role, true); +} + + +CurrentRolesInfoPtr RoleContext::getInfo() const +{ + std::lock_guard lock{mutex}; + return info; +} + + +ext::scope_guard RoleContext::subscribeForChanges(const OnChangeHandler & handler) const +{ + std::lock_guard lock{mutex}; + handlers.push_back(handler); + auto it = std::prev(handlers.end()); + + return [this, it] + { + std::lock_guard lock2{mutex}; + handlers.erase(it); + }; +} +} diff --git a/dbms/src/Access/RoleContext.h b/dbms/src/Access/RoleContext.h new file mode 100644 index 00000000000..9b54a0e624e --- /dev/null +++ b/dbms/src/Access/RoleContext.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +struct Role; +using RolePtr = std::shared_ptr; +class CurrentRolesInfo; +using CurrentRolesInfoPtr = std::shared_ptr; +class AccessControlManager; + + +class RoleContext +{ +public: + ~RoleContext(); + + /// Returns all the roles specified in the constructor. + CurrentRolesInfoPtr getInfo() const; + + using OnChangeHandler = std::function; + + /// Called when either the specified roles or the roles granted to the specified roles are changed. + ext::scope_guard subscribeForChanges(const OnChangeHandler & handler) const; + +private: + friend struct ext::shared_ptr_helper; + RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_); + RoleContext(std::vector> && children_); + + void update(); + void updateImpl(); + + void traverseRoles(const UUID & id_, bool with_admin_option_); + + const AccessControlManager * manager = nullptr; + std::optional current_role; + bool with_admin_option = false; + std::vector> children; + std::vector subscriptions_for_change_children; + + struct RoleEntry + { + RolePtr role; + ext::scope_guard subscription_for_change_role; + bool with_admin_option = false; + bool in_use = false; + }; + mutable std::unordered_map roles_map; + mutable CurrentRolesInfoPtr info; + mutable std::list handlers; + mutable std::mutex mutex; +}; + +using RoleContextPtr = std::shared_ptr; +} diff --git a/dbms/src/Access/RoleContextFactory.cpp b/dbms/src/Access/RoleContextFactory.cpp new file mode 100644 index 00000000000..3356bc238db --- /dev/null +++ b/dbms/src/Access/RoleContextFactory.cpp @@ -0,0 +1,52 @@ +#include +#include + + +namespace DB +{ + +RoleContextFactory::RoleContextFactory(const AccessControlManager & manager_) + : manager(manager_), cache(600000 /* 10 minutes */) {} + + +RoleContextFactory::~RoleContextFactory() = default; + + +RoleContextPtr RoleContextFactory::createContext( + const std::vector & roles, const std::vector & roles_with_admin_option) +{ + if (roles.size() == 1 && roles_with_admin_option.empty()) + return createContextImpl(roles[0], false); + + if (roles.size() == 1 && roles_with_admin_option == roles) + return createContextImpl(roles[0], true); + + std::vector children; + children.reserve(roles.size()); + for (const auto & role : roles_with_admin_option) + children.push_back(createContextImpl(role, true)); + + boost::container::flat_set roles_with_admin_option_set{roles_with_admin_option.begin(), roles_with_admin_option.end()}; + for (const auto & role : roles) + { + if (!roles_with_admin_option_set.contains(role)) + children.push_back(createContextImpl(role, false)); + } + + return ext::shared_ptr_helper::create(std::move(children)); +} + + +RoleContextPtr RoleContextFactory::createContextImpl(const UUID & id, bool with_admin_option) +{ + std::lock_guard lock{mutex}; + auto key = std::make_pair(id, with_admin_option); + auto x = cache.get(key); + if (x) + return *x; + auto res = ext::shared_ptr_helper::create(manager, id, with_admin_option); + cache.add(key, res); + return res; +} + +} diff --git a/dbms/src/Access/RoleContextFactory.h b/dbms/src/Access/RoleContextFactory.h new file mode 100644 index 00000000000..659c9a218a1 --- /dev/null +++ b/dbms/src/Access/RoleContextFactory.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include +#include + + +namespace DB +{ +class AccessControlManager; + + +class RoleContextFactory +{ +public: + RoleContextFactory(const AccessControlManager & manager_); + ~RoleContextFactory(); + + RoleContextPtr createContext(const std::vector & roles, const std::vector & roles_with_admin_option); + +private: + RoleContextPtr createContextImpl(const UUID & id, bool with_admin_option); + + const AccessControlManager & manager; + Poco::ExpireCache, RoleContextPtr> cache; + std::mutex mutex; +}; + +} diff --git a/dbms/src/Access/RowPolicyContext.cpp b/dbms/src/Access/RowPolicyContext.cpp index 753f8d6d3f7..661a6cb4b5f 100644 --- a/dbms/src/Access/RowPolicyContext.cpp +++ b/dbms/src/Access/RowPolicyContext.cpp @@ -23,8 +23,8 @@ RowPolicyContext::RowPolicyContext() RowPolicyContext::~RowPolicyContext() = default; -RowPolicyContext::RowPolicyContext(const UUID & user_id_) - : user_id(user_id_) +RowPolicyContext::RowPolicyContext(const UUID & user_id_, const std::vector & enabled_roles_) + : user_id(user_id_), enabled_roles(enabled_roles_) {} diff --git a/dbms/src/Access/RowPolicyContext.h b/dbms/src/Access/RowPolicyContext.h index 937cfc131b6..2042b85bf7a 100644 --- a/dbms/src/Access/RowPolicyContext.h +++ b/dbms/src/Access/RowPolicyContext.h @@ -42,7 +42,7 @@ public: private: friend class RowPolicyContextFactory; friend struct ext::shared_ptr_helper; - RowPolicyContext(const UUID & user_id_); /// RowPolicyContext should be created by RowPolicyContextFactory. + RowPolicyContext(const UUID & user_id_, const std::vector & enabled_roles_); /// RowPolicyContext should be created by RowPolicyContextFactory. using DatabaseAndTableName = std::pair; using DatabaseAndTableNameRef = std::pair; @@ -61,6 +61,7 @@ private: using MapOfMixedConditions = std::unordered_map; const UUID user_id; + const std::vector enabled_roles; mutable boost::atomic_shared_ptr map_of_mixed_conditions; }; diff --git a/dbms/src/Access/RowPolicyContextFactory.cpp b/dbms/src/Access/RowPolicyContextFactory.cpp index ba58a11e61f..49a23c4d61a 100644 --- a/dbms/src/Access/RowPolicyContextFactory.cpp +++ b/dbms/src/Access/RowPolicyContextFactory.cpp @@ -164,7 +164,7 @@ void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_ bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const { - return roles->match(context.user_id); + return roles->match(context.user_id, context.enabled_roles); } @@ -176,11 +176,11 @@ RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & ac RowPolicyContextFactory::~RowPolicyContextFactory() = default; -RowPolicyContextPtr RowPolicyContextFactory::createContext(const UUID & user_id) +RowPolicyContextPtr RowPolicyContextFactory::createContext(const UUID & user_id, const std::vector & enabled_roles) { std::lock_guard lock{mutex}; ensureAllRowPoliciesRead(); - auto context = ext::shared_ptr_helper::create(user_id); + auto context = ext::shared_ptr_helper::create(user_id, enabled_roles); contexts.push_back(context); mixConditionsForContext(*context); return context; diff --git a/dbms/src/Access/RowPolicyContextFactory.h b/dbms/src/Access/RowPolicyContextFactory.h index 911f795bcc1..d93d1626b24 100644 --- a/dbms/src/Access/RowPolicyContextFactory.h +++ b/dbms/src/Access/RowPolicyContextFactory.h @@ -17,7 +17,7 @@ public: RowPolicyContextFactory(const AccessControlManager & access_control_manager_); ~RowPolicyContextFactory(); - RowPolicyContextPtr createContext(const UUID & user_id); + RowPolicyContextPtr createContext(const UUID & user_id, const std::vector & enabled_roles); private: using ParsedConditions = RowPolicyContext::ParsedConditions; diff --git a/dbms/src/Access/User.cpp b/dbms/src/Access/User.cpp index 2efe7ed1076..bc5b062db6a 100644 --- a/dbms/src/Access/User.cpp +++ b/dbms/src/Access/User.cpp @@ -11,7 +11,8 @@ bool User::equal(const IAccessEntity & other) const const auto & other_user = typeid_cast(other); return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts) && (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option) - && (profile == other_user.profile); + && (granted_roles == other_user.granted_roles) && (granted_roles_with_admin_option == other_user.granted_roles_with_admin_option) + && (default_roles == other_user.default_roles) && (profile == other_user.profile); } } diff --git a/dbms/src/Access/User.h b/dbms/src/Access/User.h index 9db9a8bcf4a..3a9b3cd7014 100644 --- a/dbms/src/Access/User.h +++ b/dbms/src/Access/User.h @@ -4,7 +4,9 @@ #include #include #include -#include +#include +#include +#include namespace DB @@ -14,9 +16,12 @@ namespace DB struct User : public IAccessEntity { Authentication authentication; - AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}}; + AllowedClientHosts allowed_client_hosts = AllowedClientHosts::AnyHostTag{}; AccessRights access; AccessRights access_with_grant_option; + boost::container::flat_set granted_roles; + boost::container::flat_set granted_roles_with_admin_option; + GenericRoleSet default_roles = GenericRoleSet::AllTag{}; String profile; bool equal(const IAccessEntity & other) const override; diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index 718fc0cbf89..fe2c95fd6bc 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -482,6 +482,7 @@ namespace ErrorCodes extern const int UNKNOWN_ACCESS_TYPE = 508; extern const int INVALID_GRANT = 509; extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510; + extern const int UNKNOWN_ROLE = 511; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index 3fc406350cf..a3919a8ed67 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -658,12 +658,14 @@ void Context::setUser(const String & name, const String & password, const Poco:: client_info.quota_key = quota_key; auto new_user_id = getAccessControlManager().getID(name); - auto new_access_rights = getAccessControlManager().getAccessRightsContext(new_user_id, settings, current_database, client_info); + auto new_access_rights = getAccessControlManager().getAccessRightsContext(new_user_id, {}, true, settings, current_database, client_info); new_access_rights->checkHostIsAllowed(); new_access_rights->checkPassword(password); user_id = new_user_id; access_rights = std::move(new_access_rights); + current_roles.clear(); + use_default_roles = true; calculateUserSettings(); } @@ -689,11 +691,52 @@ UUID Context::getUserID() const } +void Context::setCurrentRoles(const std::vector & current_roles_) +{ + auto lock = getLock(); + if (current_roles == current_roles_ && !use_default_roles) + return; + current_roles = current_roles_; + use_default_roles = false; + calculateAccessRights(); +} + +void Context::setCurrentRolesDefault() +{ + auto lock = getLock(); + if (use_default_roles) + return; + current_roles.clear(); + use_default_roles = true; + calculateAccessRights(); +} + +std::vector Context::getCurrentRoles() const +{ + return getAccessRights()->getCurrentRoles(); +} + +Strings Context::getCurrentRolesNames() const +{ + return getAccessRights()->getCurrentRolesNames(); +} + +std::vector Context::getEnabledRoles() const +{ + return getAccessRights()->getEnabledRoles(); +} + +Strings Context::getEnabledRolesNames() const +{ + return getAccessRights()->getEnabledRolesNames(); +} + + void Context::calculateAccessRights() { auto lock = getLock(); if (user_id) - access_rights = getAccessControlManager().getAccessRightsContext(*user_id, settings, current_database, client_info); + access_rights = getAccessControlManager().getAccessRightsContext(*user_id, current_roles, use_default_roles, settings, current_database, client_info); } @@ -728,7 +771,7 @@ void Context::setInitialRowPolicy() auto lock = getLock(); auto initial_user_id = getAccessControlManager().find(client_info.initial_user); if (initial_user_id) - initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id); + initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id, {}); } RowPolicyContextPtr Context::getInitialRowPolicy() const diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 11017ec778e..40909a192a3 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -136,7 +136,6 @@ struct IHostContext using IHostContextPtr = std::shared_ptr; - /** A set of known objects that can be used in the query. * Consists of a shared part (always common to all sessions and queries) * and copied part (which can be its own for each session or query). @@ -156,6 +155,8 @@ private: InputBlocksReader input_blocks_reader; std::optional user_id; + std::vector current_roles; + bool use_default_roles = false; AccessRightsContextPtr access_rights; RowPolicyContextPtr initial_row_policy; String current_database; @@ -244,6 +245,13 @@ public: String getUserName() const; UUID getUserID() const; + void setCurrentRoles(const std::vector & current_roles_); + void setCurrentRolesDefault(); + std::vector getCurrentRoles() const; + Strings getCurrentRolesNames() const; + std::vector getEnabledRoles() const; + Strings getEnabledRolesNames() const; + /// Checks access rights. /// Empty database means the current database. void checkAccess(const AccessFlags & access) const; diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index 86a4699a636..359b76ec75e 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -79,7 +79,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowC auto create_query = std::make_shared(); create_query->name = user->getName(); - if (!user->allowed_client_hosts.containsAnyHost()) + if (user->allowed_client_hosts != AllowedClientHosts::AnyHostTag{}) create_query->hosts = user->allowed_client_hosts; if (!user->profile.empty()) From 6671ca67eb57e8d3a2315a42b2266bde03f7616c Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Thu, 20 Feb 2020 02:12:56 +0300 Subject: [PATCH 12/19] Add access type ROLE_ADMIN. --- dbms/src/Access/AccessFlags.h | 3 ++- dbms/src/Access/AccessRightsContext.cpp | 5 ++++- dbms/src/Access/AccessType.h | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/dbms/src/Access/AccessFlags.h b/dbms/src/Access/AccessFlags.h index 1b1934a3f01..aef32b534f8 100644 --- a/dbms/src/Access/AccessFlags.h +++ b/dbms/src/Access/AccessFlags.h @@ -345,7 +345,8 @@ private: auto create_quota = std::make_unique("CREATE QUOTA", next_flag++, GLOBAL_LEVEL); auto alter_quota = std::make_unique("ALTER QUOTA", next_flag++, GLOBAL_LEVEL); auto drop_quota = std::make_unique("DROP QUOTA", next_flag++, GLOBAL_LEVEL); - ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota)); + auto role_admin = std::make_unique("ROLE ADMIN", next_flag++, GLOBAL_LEVEL); + ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota), std::move(role_admin)); auto shutdown = std::make_unique("SHUTDOWN", next_flag++, GLOBAL_LEVEL); ext::push_back(shutdown->aliases, "SYSTEM SHUTDOWN", "SYSTEM KILL"); diff --git a/dbms/src/Access/AccessRightsContext.cpp b/dbms/src/Access/AccessRightsContext.cpp index 39ee78e4601..9a32a1234f0 100644 --- a/dbms/src/Access/AccessRightsContext.cpp +++ b/dbms/src/Access/AccessRightsContext.cpp @@ -355,6 +355,9 @@ void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) void AccessRightsContext::checkAdminOption(const UUID & role_id) const { + if (isGranted(AccessType::ROLE_ADMIN)) + return; + boost::shared_ptr> enabled_roles = enabled_roles_with_admin_option.load(); if (!enabled_roles) { @@ -429,7 +432,7 @@ boost::shared_ptr AccessRightsContext::calculateResultAccess static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; static const AccessFlags all_dcl = AccessType::CREATE_USER | AccessType::CREATE_ROLE | AccessType::CREATE_POLICY | AccessType::CREATE_QUOTA | AccessType::ALTER_USER | AccessType::ALTER_POLICY | AccessType::ALTER_QUOTA | AccessType::DROP_USER - | AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA; + | AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA | AccessType::ROLE_ADMIN; /// Anyone has access to the "system" database. if (!result.isGranted(AccessType::SELECT, "system")) diff --git a/dbms/src/Access/AccessType.h b/dbms/src/Access/AccessType.h index 1c829f57f63..d3589a237be 100644 --- a/dbms/src/Access/AccessType.h +++ b/dbms/src/Access/AccessType.h @@ -94,6 +94,8 @@ enum class AccessType ALTER_QUOTA, DROP_QUOTA, + ROLE_ADMIN, /// allows to grant and revoke any roles. + SHUTDOWN, DROP_CACHE, RELOAD_CONFIG, @@ -250,6 +252,7 @@ namespace impl ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_QUOTA); + ACCESS_TYPE_TO_KEYWORD_CASE(ROLE_ADMIN); ACCESS_TYPE_TO_KEYWORD_CASE(SHUTDOWN); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_CACHE); From e017bacc48450a657802b2c6f7276e45670bc54a Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Mon, 17 Feb 2020 05:59:56 +0300 Subject: [PATCH 13/19] Implement SQL queries for creating and controlling roles. --- dbms/src/Common/ErrorCodes.cpp | 1 + .../InterpreterCreateRoleQuery.cpp | 62 ++++++++++++ .../Interpreters/InterpreterCreateRoleQuery.h | 26 +++++ .../InterpreterCreateUserQuery.cpp | 40 +++++++- .../Interpreters/InterpreterCreateUserQuery.h | 3 +- .../InterpreterDropAccessEntityQuery.cpp | 13 ++- dbms/src/Interpreters/InterpreterFactory.cpp | 12 +++ .../Interpreters/InterpreterGrantQuery.cpp | 87 ++++++++++++++--- .../Interpreters/InterpreterSetRoleQuery.cpp | 95 +++++++++++++++++++ .../Interpreters/InterpreterSetRoleQuery.h | 30 ++++++ ...InterpreterShowCreateAccessEntityQuery.cpp | 3 + .../InterpreterShowGrantsQuery.cpp | 53 ++++++++++- dbms/src/Parsers/ASTCreateRoleQuery.cpp | 46 +++++++++ dbms/src/Parsers/ASTCreateRoleQuery.h | 29 ++++++ dbms/src/Parsers/ASTCreateUserQuery.cpp | 11 +++ dbms/src/Parsers/ASTCreateUserQuery.h | 6 ++ dbms/src/Parsers/ASTDropAccessEntityQuery.cpp | 1 + dbms/src/Parsers/ASTDropAccessEntityQuery.h | 6 +- dbms/src/Parsers/ASTGrantQuery.cpp | 32 ++++++- dbms/src/Parsers/ASTGrantQuery.h | 5 + dbms/src/Parsers/ASTSetRoleQuery.cpp | 43 +++++++++ dbms/src/Parsers/ASTSetRoleQuery.h | 31 ++++++ dbms/src/Parsers/ParserCreateRoleQuery.cpp | 70 ++++++++++++++ dbms/src/Parsers/ParserCreateRoleQuery.h | 20 ++++ dbms/src/Parsers/ParserCreateUserQuery.cpp | 22 +++++ .../Parsers/ParserDropAccessEntityQuery.cpp | 10 +- .../src/Parsers/ParserDropAccessEntityQuery.h | 3 +- dbms/src/Parsers/ParserGrantQuery.cpp | 37 +++++++- dbms/src/Parsers/ParserQuery.cpp | 6 ++ dbms/src/Parsers/ParserSetRoleQuery.cpp | 80 ++++++++++++++++ dbms/src/Parsers/ParserSetRoleQuery.h | 18 ++++ 31 files changed, 864 insertions(+), 37 deletions(-) create mode 100644 dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp create mode 100644 dbms/src/Interpreters/InterpreterCreateRoleQuery.h create mode 100644 dbms/src/Interpreters/InterpreterSetRoleQuery.cpp create mode 100644 dbms/src/Interpreters/InterpreterSetRoleQuery.h create mode 100644 dbms/src/Parsers/ASTCreateRoleQuery.cpp create mode 100644 dbms/src/Parsers/ASTCreateRoleQuery.h create mode 100644 dbms/src/Parsers/ASTSetRoleQuery.cpp create mode 100644 dbms/src/Parsers/ASTSetRoleQuery.h create mode 100644 dbms/src/Parsers/ParserCreateRoleQuery.cpp create mode 100644 dbms/src/Parsers/ParserCreateRoleQuery.h create mode 100644 dbms/src/Parsers/ParserSetRoleQuery.cpp create mode 100644 dbms/src/Parsers/ParserSetRoleQuery.h diff --git a/dbms/src/Common/ErrorCodes.cpp b/dbms/src/Common/ErrorCodes.cpp index fe2c95fd6bc..83301671163 100644 --- a/dbms/src/Common/ErrorCodes.cpp +++ b/dbms/src/Common/ErrorCodes.cpp @@ -483,6 +483,7 @@ namespace ErrorCodes extern const int INVALID_GRANT = 509; extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510; extern const int UNKNOWN_ROLE = 511; + extern const int SET_NON_GRANTED_ROLE = 512; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp b/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp new file mode 100644 index 00000000000..f1c58f9d9bd --- /dev/null +++ b/dbms/src/Interpreters/InterpreterCreateRoleQuery.cpp @@ -0,0 +1,62 @@ +#include +#include +#include +#include +#include + + +namespace DB +{ +BlockIO InterpreterCreateRoleQuery::execute() +{ + const auto & query = query_ptr->as(); + auto & access_control = context.getAccessControlManager(); + if (query.alter) + context.checkAccess(AccessType::CREATE_ROLE | AccessType::DROP_ROLE); + else + context.checkAccess(AccessType::CREATE_ROLE); + + if (query.alter) + { + auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr + { + auto updated_role = typeid_cast>(entity->clone()); + updateRoleFromQuery(*updated_role, query); + return updated_role; + }; + if (query.if_exists) + { + if (auto id = access_control.find(query.name)) + access_control.tryUpdate(*id, update_func); + } + else + access_control.update(access_control.getID(query.name), update_func); + } + else + { + auto new_role = std::make_shared(); + updateRoleFromQuery(*new_role, query); + + if (query.if_not_exists) + access_control.tryInsert(new_role); + else if (query.or_replace) + access_control.insertOrReplace(new_role); + else + access_control.insert(new_role); + } + + return {}; +} + + +void InterpreterCreateRoleQuery::updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query) +{ + if (query.alter) + { + if (!query.new_name.empty()) + role.setName(query.new_name); + } + else + role.setName(query.name); +} +} diff --git a/dbms/src/Interpreters/InterpreterCreateRoleQuery.h b/dbms/src/Interpreters/InterpreterCreateRoleQuery.h new file mode 100644 index 00000000000..8ceb645ea78 --- /dev/null +++ b/dbms/src/Interpreters/InterpreterCreateRoleQuery.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include + + +namespace DB +{ +class ASTCreateRoleQuery; +struct Role; + + +class InterpreterCreateRoleQuery : public IInterpreter +{ +public: + InterpreterCreateRoleQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {} + + BlockIO execute() override; + +private: + void updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query); + + ASTPtr query_ptr; + Context & context; +}; +} diff --git a/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp b/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp index 8f3c8a9f2bf..db7f34a2184 100644 --- a/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp +++ b/dbms/src/Interpreters/InterpreterCreateUserQuery.cpp @@ -1,24 +1,47 @@ #include -#include #include +#include +#include #include #include +#include +#include +#include namespace DB { +namespace ErrorCodes +{ + extern const int SET_NON_GRANTED_ROLE; +} + + BlockIO InterpreterCreateUserQuery::execute() { const auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER); + GenericRoleSet * default_roles_from_query = nullptr; + GenericRoleSet temp_role_set; + if (query.default_roles) + { + default_roles_from_query = &temp_role_set; + *default_roles_from_query = GenericRoleSet{*query.default_roles, access_control}; + if (!query.alter && !default_roles_from_query->all) + { + for (const UUID & role : default_roles_from_query->getMatchingIDs()) + context.getAccessRights()->checkAdminOption(role); + } + } + if (query.alter) { auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { auto updated_user = typeid_cast>(entity->clone()); - updateUserFromQuery(*updated_user, query); + updateUserFromQuery(*updated_user, query, default_roles_from_query); return updated_user; }; if (query.if_exists) @@ -32,7 +55,7 @@ BlockIO InterpreterCreateUserQuery::execute() else { auto new_user = std::make_shared(); - updateUserFromQuery(*new_user, query); + updateUserFromQuery(*new_user, query, default_roles_from_query); if (query.if_not_exists) access_control.tryInsert(new_user); @@ -46,7 +69,7 @@ BlockIO InterpreterCreateUserQuery::execute() } -void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query) +void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query, const GenericRoleSet * default_roles_from_query) { if (query.alter) { @@ -66,7 +89,16 @@ void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreat if (query.add_hosts) user.allowed_client_hosts.add(*query.add_hosts); + if (default_roles_from_query) + { + if (!query.alter && !default_roles_from_query->all) + boost::range::copy(default_roles_from_query->getMatchingIDs(), std::inserter(user.granted_roles, user.granted_roles.end())); + + InterpreterSetRoleQuery::updateUserSetDefaultRoles(user, *default_roles_from_query); + } + if (query.profile) user.profile = *query.profile; } + } diff --git a/dbms/src/Interpreters/InterpreterCreateUserQuery.h b/dbms/src/Interpreters/InterpreterCreateUserQuery.h index f040a23a7c2..c2a6fc46f68 100644 --- a/dbms/src/Interpreters/InterpreterCreateUserQuery.h +++ b/dbms/src/Interpreters/InterpreterCreateUserQuery.h @@ -7,6 +7,7 @@ namespace DB { class ASTCreateUserQuery; +class GenericRoleSet; struct User; @@ -18,7 +19,7 @@ public: BlockIO execute() override; private: - void updateUserFromQuery(User & user, const ASTCreateUserQuery & query); + void updateUserFromQuery(User & user, const ASTCreateUserQuery & query, const GenericRoleSet * default_roles_from_query); ASTPtr query_ptr; Context & context; diff --git a/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp index 791314a99fa..c69ce3ade45 100644 --- a/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterDropAccessEntityQuery.cpp @@ -3,9 +3,10 @@ #include #include #include +#include +#include #include #include -#include #include @@ -29,6 +30,16 @@ BlockIO InterpreterDropAccessEntityQuery::execute() return {}; } + case Kind::ROLE: + { + context.checkAccess(AccessType::DROP_ROLE); + if (query.if_exists) + access_control.tryRemove(access_control.find(query.names)); + else + access_control.remove(access_control.getIDs(query.names)); + return {}; + } + case Kind::QUOTA: { context.checkAccess(AccessType::DROP_QUOTA); diff --git a/dbms/src/Interpreters/InterpreterFactory.cpp b/dbms/src/Interpreters/InterpreterFactory.cpp index 87c2d04b2e4..0e241aab12d 100644 --- a/dbms/src/Interpreters/InterpreterFactory.cpp +++ b/dbms/src/Interpreters/InterpreterFactory.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -44,6 +47,7 @@ #include #include #include +#include #include #include #include @@ -126,6 +130,10 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, Context & /// readonly is checked inside InterpreterSetQuery return std::make_unique(query, context); } + else if (query->as()) + { + return std::make_unique(query, context); + } else if (query->as()) { return std::make_unique(query, context); @@ -186,6 +194,10 @@ std::unique_ptr InterpreterFactory::get(ASTPtr & query, Context & { return std::make_unique(query, context); } + else if (query->as()) + { + return std::make_unique(query, context); + } else if (query->as()) { return std::make_unique(query, context); diff --git a/dbms/src/Interpreters/InterpreterGrantQuery.cpp b/dbms/src/Interpreters/InterpreterGrantQuery.cpp index 58bb104de6a..36cba3a801b 100644 --- a/dbms/src/Interpreters/InterpreterGrantQuery.cpp +++ b/dbms/src/Interpreters/InterpreterGrantQuery.cpp @@ -5,6 +5,8 @@ #include #include #include +#include +#include namespace DB @@ -16,32 +18,89 @@ BlockIO InterpreterGrantQuery::execute() context.getAccessRights()->checkGrantOption(query.access_rights_elements); using Kind = ASTGrantQuery::Kind; - std::vector to_roles = GenericRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingUsers(access_control); + std::vector roles; + if (query.roles) + { + roles = GenericRoleSet{*query.roles, access_control}.getMatchingRoles(access_control); + for (const UUID & role : roles) + context.getAccessRights()->checkAdminOption(role); + } + + std::vector to_roles = GenericRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingUsersAndRoles(access_control); String current_database = context.getCurrentDatabase(); using Kind = ASTGrantQuery::Kind; auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr { - auto updated_user = typeid_cast>(entity->clone()); - if (query.kind == Kind::GRANT) + auto clone = entity->clone(); + AccessRights * access = nullptr; + AccessRights * access_with_grant_option = nullptr; + boost::container::flat_set * granted_roles = nullptr; + boost::container::flat_set * granted_roles_with_admin_option = nullptr; + GenericRoleSet * default_roles = nullptr; + if (auto user = typeid_cast>(clone)) { - updated_user->access.grant(query.access_rights_elements, current_database); - if (query.grant_option) - updated_user->access_with_grant_option.grant(query.access_rights_elements, current_database); + access = &user->access; + access_with_grant_option = &user->access_with_grant_option; + granted_roles = &user->granted_roles; + granted_roles_with_admin_option = &user->granted_roles_with_admin_option; + default_roles = &user->default_roles; } - else if (context.getSettingsRef().partial_revokes) + else if (auto role = typeid_cast>(clone)) { - updated_user->access_with_grant_option.partialRevoke(query.access_rights_elements, current_database); - if (!query.grant_option) - updated_user->access.partialRevoke(query.access_rights_elements, current_database); + access = &role->access; + access_with_grant_option = &role->access_with_grant_option; + granted_roles = &role->granted_roles; + granted_roles_with_admin_option = &role->granted_roles_with_admin_option; } else + return entity; + + if (!query.access_rights_elements.empty()) { - updated_user->access_with_grant_option.revoke(query.access_rights_elements, current_database); - if (!query.grant_option) - updated_user->access.revoke(query.access_rights_elements, current_database); + if (query.kind == Kind::GRANT) + { + access->grant(query.access_rights_elements, current_database); + if (query.grant_option) + access_with_grant_option->grant(query.access_rights_elements, current_database); + } + else if (context.getSettingsRef().partial_revokes) + { + access_with_grant_option->partialRevoke(query.access_rights_elements, current_database); + if (!query.grant_option) + access->partialRevoke(query.access_rights_elements, current_database); + } + else + { + access_with_grant_option->revoke(query.access_rights_elements, current_database); + if (!query.grant_option) + access->revoke(query.access_rights_elements, current_database); + } } - return updated_user; + + if (!roles.empty()) + { + if (query.kind == Kind::GRANT) + { + boost::range::copy(roles, std::inserter(*granted_roles, granted_roles->end())); + if (query.admin_option) + boost::range::copy(roles, std::inserter(*granted_roles_with_admin_option, granted_roles_with_admin_option->end())); + } + else + { + for (const UUID & role : roles) + { + granted_roles_with_admin_option->erase(role); + if (!query.admin_option) + { + granted_roles->erase(role); + if (default_roles) + default_roles->ids.erase(role); + } + } + } + } + return clone; }; access_control.update(to_roles, update_func); diff --git a/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp b/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp new file mode 100644 index 00000000000..567c626cb90 --- /dev/null +++ b/dbms/src/Interpreters/InterpreterSetRoleQuery.cpp @@ -0,0 +1,95 @@ +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace ErrorCodes +{ + extern const int SET_NON_GRANTED_ROLE; +} + + +BlockIO InterpreterSetRoleQuery::execute() +{ + const auto & query = query_ptr->as(); + if (query.kind == ASTSetRoleQuery::Kind::SET_DEFAULT_ROLE) + setDefaultRole(query); + else + setRole(query); + return {}; +} + + +void InterpreterSetRoleQuery::setRole(const ASTSetRoleQuery & query) +{ + auto & access_control = context.getAccessControlManager(); + auto & session_context = context.getSessionContext(); + auto user = session_context.getUser(); + + if (query.kind == ASTSetRoleQuery::Kind::SET_ROLE_DEFAULT) + { + session_context.setCurrentRolesDefault(); + } + else + { + GenericRoleSet roles_from_query{*query.roles, access_control}; + std::vector new_current_roles; + if (roles_from_query.all) + { + for (const auto & id : user->granted_roles) + if (roles_from_query.match(id)) + new_current_roles.push_back(id); + } + else + { + for (const auto & id : roles_from_query.getMatchingIDs()) + { + if (!user->granted_roles.contains(id)) + throw Exception("Role should be granted to set current", ErrorCodes::SET_NON_GRANTED_ROLE); + new_current_roles.push_back(id); + } + } + session_context.setCurrentRoles(new_current_roles); + } +} + + +void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query) +{ + context.checkAccess(AccessType::CREATE_USER | AccessType::DROP_USER); + + auto & access_control = context.getAccessControlManager(); + std::vector to_users = GenericRoleSet{*query.to_users, access_control, context.getUserID()}.getMatchingUsers(access_control); + GenericRoleSet roles_from_query{*query.roles, access_control}; + + auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr + { + auto updated_user = typeid_cast>(entity->clone()); + updateUserSetDefaultRoles(*updated_user, roles_from_query); + return updated_user; + }; + + access_control.update(to_users, update_func); +} + + +void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query) +{ + if (!roles_from_query.all) + { + for (const auto & id : roles_from_query.getMatchingIDs()) + { + if (!user.granted_roles.contains(id)) + throw Exception("Role should be granted to set default", ErrorCodes::SET_NON_GRANTED_ROLE); + } + } + user.default_roles = roles_from_query; +} + +} diff --git a/dbms/src/Interpreters/InterpreterSetRoleQuery.h b/dbms/src/Interpreters/InterpreterSetRoleQuery.h new file mode 100644 index 00000000000..e28aec9236c --- /dev/null +++ b/dbms/src/Interpreters/InterpreterSetRoleQuery.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include + + +namespace DB +{ +class ASTSetRoleQuery; +class GenericRoleSet; +struct User; + + +class InterpreterSetRoleQuery : public IInterpreter +{ +public: + InterpreterSetRoleQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {} + + BlockIO execute() override; + + static void updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query); + +private: + void setRole(const ASTSetRoleQuery & query); + void setDefaultRole(const ASTSetRoleQuery & query); + + ASTPtr query_ptr; + Context & context; +}; +} diff --git a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index 359b76ec75e..dcf0387a9cb 100644 --- a/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -85,6 +85,9 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowC if (!user->profile.empty()) create_query->profile = user->profile; + if (user->default_roles != GenericRoleSet::AllTag{}) + create_query->default_roles = GenericRoleSet{user->default_roles}.toAST(context.getAccessControlManager()); + return create_query; } diff --git a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp index c1d430586ba..faa51ce1e06 100644 --- a/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/dbms/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -88,19 +89,44 @@ BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl() ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const { + const auto & access_control = context.getAccessControlManager(); UserPtr user; + RolePtr role; if (show_query.current_user) user = context.getUser(); else - user = context.getAccessControlManager().read(show_query.name); + { + user = access_control.tryRead(show_query.name); + if (!user) + role = access_control.read(show_query.name); + } + + const AccessRights * access = nullptr; + const AccessRights * access_with_grant_option = nullptr; + const boost::container::flat_set * granted_roles = nullptr; + const boost::container::flat_set * granted_roles_with_admin_option = nullptr; + if (user) + { + access = &user->access; + access_with_grant_option = &user->access_with_grant_option; + granted_roles = &user->granted_roles; + granted_roles_with_admin_option = &user->granted_roles_with_admin_option; + } + else + { + access = &role->access; + access_with_grant_option = &role->access_with_grant_option; + granted_roles = &role->granted_roles; + granted_roles_with_admin_option = &role->granted_roles_with_admin_option; + } ASTs res; for (bool grant_option : {true, false}) { - if (!grant_option && (user->access == user->access_with_grant_option)) + if (!grant_option && (*access == *access_with_grant_option)) continue; - const auto & access_rights = grant_option ? user->access_with_grant_option : user->access; + const auto & access_rights = grant_option ? *access_with_grant_option : *access; const auto grouped_elements = groupByTable(access_rights.getElements()); using Kind = ASTGrantQuery::Kind; @@ -112,13 +138,32 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show grant_query->kind = kind; grant_query->grant_option = grant_option; grant_query->to_roles = std::make_shared(); - grant_query->to_roles->names.push_back(user->getName()); + grant_query->to_roles->names.push_back(show_query.name); grant_query->access_rights_elements = elements; res.push_back(std::move(grant_query)); } } } + for (bool admin_option : {true, false}) + { + if (!admin_option && (*granted_roles == *granted_roles_with_admin_option)) + continue; + + const auto & roles = admin_option ? *granted_roles_with_admin_option : *granted_roles; + if (roles.empty()) + continue; + + auto grant_query = std::make_shared(); + using Kind = ASTGrantQuery::Kind; + grant_query->kind = Kind::GRANT; + grant_query->admin_option = admin_option; + grant_query->to_roles = std::make_shared(); + grant_query->to_roles->names.push_back(show_query.name); + grant_query->roles = GenericRoleSet{roles}.toAST(access_control); + res.push_back(std::move(grant_query)); + } + return res; } } diff --git a/dbms/src/Parsers/ASTCreateRoleQuery.cpp b/dbms/src/Parsers/ASTCreateRoleQuery.cpp new file mode 100644 index 00000000000..c11da41b2e9 --- /dev/null +++ b/dbms/src/Parsers/ASTCreateRoleQuery.cpp @@ -0,0 +1,46 @@ +#include +#include + + +namespace DB +{ +namespace +{ + void formatRenameTo(const String & new_name, const IAST::FormatSettings & settings) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "") + << quoteString(new_name); + } +} + + +String ASTCreateRoleQuery::getID(char) const +{ + return "CreateRoleQuery"; +} + + +ASTPtr ASTCreateRoleQuery::clone() const +{ + return std::make_shared(*this); +} + + +void ASTCreateRoleQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +{ + settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER ROLE" : "CREATE ROLE") + << (settings.hilite ? hilite_none : ""); + + if (if_exists) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : ""); + else if (if_not_exists) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : ""); + else if (or_replace) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : ""); + + settings.ostr << " " << backQuoteIfNeed(name); + + if (!new_name.empty()) + formatRenameTo(new_name, settings); +} +} diff --git a/dbms/src/Parsers/ASTCreateRoleQuery.h b/dbms/src/Parsers/ASTCreateRoleQuery.h new file mode 100644 index 00000000000..ac0a93b5d72 --- /dev/null +++ b/dbms/src/Parsers/ASTCreateRoleQuery.h @@ -0,0 +1,29 @@ +#pragma once + +#include + + +namespace DB +{ +/** CREATE ROLE [IF NOT EXISTS | OR REPLACE] name + * + * ALTER ROLE [IF EXISTS] name + * [RENAME TO new_name] + */ +class ASTCreateRoleQuery : public IAST +{ +public: + bool alter = false; + + bool if_exists = false; + bool if_not_exists = false; + bool or_replace = false; + + String name; + String new_name; + + String getID(char) const override; + ASTPtr clone() const override; + void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; +}; +} diff --git a/dbms/src/Parsers/ASTCreateUserQuery.cpp b/dbms/src/Parsers/ASTCreateUserQuery.cpp index fd2ca6d3de6..cbe5de0db8a 100644 --- a/dbms/src/Parsers/ASTCreateUserQuery.cpp +++ b/dbms/src/Parsers/ASTCreateUserQuery.cpp @@ -1,4 +1,5 @@ #include +#include #include @@ -134,6 +135,13 @@ namespace } + void formatDefaultRoles(const ASTGenericRoleSet & default_roles, const IAST::FormatSettings & settings) + { + settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " DEFAULT ROLE " << (settings.hilite ? IAST::hilite_none : ""); + default_roles.format(settings); + } + + void formatProfile(const String & profile_name, const IAST::FormatSettings & settings) { settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " PROFILE " << (settings.hilite ? IAST::hilite_none : "") @@ -181,6 +189,9 @@ void ASTCreateUserQuery::formatImpl(const FormatSettings & settings, FormatState if (remove_hosts) formatHosts("REMOVE", *remove_hosts, settings); + if (default_roles) + formatDefaultRoles(*default_roles, settings); + if (profile) formatProfile(*profile, settings); } diff --git a/dbms/src/Parsers/ASTCreateUserQuery.h b/dbms/src/Parsers/ASTCreateUserQuery.h index 055f5711440..e93df2f6901 100644 --- a/dbms/src/Parsers/ASTCreateUserQuery.h +++ b/dbms/src/Parsers/ASTCreateUserQuery.h @@ -7,15 +7,19 @@ namespace DB { +class ASTGenericRoleSet; + /** CREATE USER [IF NOT EXISTS | OR REPLACE] name * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] + * [DEFAULT ROLE role [,...]] * [PROFILE 'profile_name'] * * ALTER USER [IF EXISTS] name * [RENAME TO new_name] * [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}] * [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] + * [DEFAULT ROLE role [,...] | ALL | ALL EXCEPT role [,...] ] * [PROFILE 'profile_name'] */ class ASTCreateUserQuery : public IAST @@ -36,6 +40,8 @@ public: std::optional add_hosts; std::optional remove_hosts; + std::shared_ptr default_roles; + std::optional profile; String getID(char) const override; diff --git a/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp b/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp index db67a56e558..0b6bae7575e 100644 --- a/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ASTDropAccessEntityQuery.cpp @@ -13,6 +13,7 @@ namespace switch (kind) { case Kind::USER: return "USER"; + case Kind::ROLE: return "ROLE"; case Kind::QUOTA: return "QUOTA"; case Kind::ROW_POLICY: return "POLICY"; } diff --git a/dbms/src/Parsers/ASTDropAccessEntityQuery.h b/dbms/src/Parsers/ASTDropAccessEntityQuery.h index 6535fb18833..eea40fd5343 100644 --- a/dbms/src/Parsers/ASTDropAccessEntityQuery.h +++ b/dbms/src/Parsers/ASTDropAccessEntityQuery.h @@ -7,9 +7,10 @@ namespace DB { -/** DROP QUOTA [IF EXISTS] name [,...] +/** DROP USER [IF EXISTS] name [,...] + * DROP ROLE [IF EXISTS] name [,...] + * DROP QUOTA [IF EXISTS] name [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] - * DROP USER [IF EXISTS] name [,...] */ class ASTDropAccessEntityQuery : public IAST { @@ -17,6 +18,7 @@ public: enum class Kind { USER, + ROLE, QUOTA, ROW_POLICY, }; diff --git a/dbms/src/Parsers/ASTGrantQuery.cpp b/dbms/src/Parsers/ASTGrantQuery.cpp index c5132c8359c..1aaf7583e94 100644 --- a/dbms/src/Parsers/ASTGrantQuery.cpp +++ b/dbms/src/Parsers/ASTGrantQuery.cpp @@ -9,6 +9,11 @@ namespace DB { +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + namespace { using KeywordToColumnsMap = std::map /* columns */>; @@ -119,13 +124,30 @@ void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, F settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE") << (settings.hilite ? IAST::hilite_none : "") << " "; - if (grant_option && (kind == Kind::REVOKE)) - settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : ""); + if (kind == Kind::REVOKE) + { + if (grant_option) + settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : ""); + else if (admin_option) + settings.ostr << (settings.hilite ? hilite_keyword : "") << "ADMIN OPTION FOR " << (settings.hilite ? hilite_none : ""); + } + + if ((!!roles + !access_rights_elements.empty()) != 1) + throw Exception("Either roles or access rights elements should be set", ErrorCodes::LOGICAL_ERROR); + + if (roles) + roles->format(settings); + else + formatAccessRightsElements(access_rights_elements, settings); - formatAccessRightsElements(access_rights_elements, settings); formatToRoles(*to_roles, kind, settings); - if (grant_option && (kind == Kind::GRANT)) - settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : ""); + if (kind == Kind::GRANT) + { + if (grant_option) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : ""); + else if (admin_option) + settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH ADMIN OPTION" << (settings.hilite ? hilite_none : ""); + } } } diff --git a/dbms/src/Parsers/ASTGrantQuery.h b/dbms/src/Parsers/ASTGrantQuery.h index 56663d84620..5754ef22ace 100644 --- a/dbms/src/Parsers/ASTGrantQuery.h +++ b/dbms/src/Parsers/ASTGrantQuery.h @@ -11,6 +11,9 @@ class ASTGenericRoleSet; /** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION] * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} FROM {user_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | CURRENT_USER} [,...] + * + * GRANT role [,...] TO {user_name | role_name | CURRENT_USER} [,...] [WITH ADMIN OPTION] + * REVOKE [ADMIN OPTION FOR] role [,...] FROM {user_name | role_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] */ class ASTGrantQuery : public IAST { @@ -22,8 +25,10 @@ public: }; Kind kind = Kind::GRANT; AccessRightsElements access_rights_elements; + std::shared_ptr roles; std::shared_ptr to_roles; bool grant_option = false; + bool admin_option = false; String getID(char) const override; ASTPtr clone() const override; diff --git a/dbms/src/Parsers/ASTSetRoleQuery.cpp b/dbms/src/Parsers/ASTSetRoleQuery.cpp new file mode 100644 index 00000000000..de61f5a3113 --- /dev/null +++ b/dbms/src/Parsers/ASTSetRoleQuery.cpp @@ -0,0 +1,43 @@ +#include +#include +#include + + +namespace DB +{ +String ASTSetRoleQuery::getID(char) const +{ + return "SetRoleQuery"; +} + + +ASTPtr ASTSetRoleQuery::clone() const +{ + return std::make_shared(*this); +} + + +void ASTSetRoleQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const +{ + settings.ostr << (settings.hilite ? hilite_keyword : ""); + switch (kind) + { + case Kind::SET_ROLE: settings.ostr << "SET ROLE"; break; + case Kind::SET_ROLE_DEFAULT: settings.ostr << "SET ROLE DEFAULT"; break; + case Kind::SET_DEFAULT_ROLE: settings.ostr << "SET DEFAULT ROLE"; break; + } + settings.ostr << (settings.hilite ? hilite_none : ""); + + if (kind == Kind::SET_ROLE_DEFAULT) + return; + + settings.ostr << " "; + roles->format(settings); + + if (kind == Kind::SET_ROLE) + return; + + settings.ostr << (settings.hilite ? hilite_keyword : "") << " TO " << (settings.hilite ? hilite_none : ""); + to_users->format(settings); +} +} diff --git a/dbms/src/Parsers/ASTSetRoleQuery.h b/dbms/src/Parsers/ASTSetRoleQuery.h new file mode 100644 index 00000000000..ad22d30e287 --- /dev/null +++ b/dbms/src/Parsers/ASTSetRoleQuery.h @@ -0,0 +1,31 @@ +#pragma once + +#include + + +namespace DB +{ +class ASTGenericRoleSet; + +/** SET ROLE {DEFAULT | NONE | role [,...] | ALL | ALL EXCEPT role [,...]} + * SET DEFAULT ROLE {NONE | role [,...] | ALL | ALL EXCEPT role [,...]} TO {user|CURRENT_USER} [,...] + */ +class ASTSetRoleQuery : public IAST +{ +public: + enum class Kind + { + SET_ROLE, + SET_ROLE_DEFAULT, + SET_DEFAULT_ROLE, + }; + Kind kind = Kind::SET_ROLE; + + std::shared_ptr roles; + std::shared_ptr to_users; + + String getID(char) const override; + ASTPtr clone() const override; + void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; +}; +} diff --git a/dbms/src/Parsers/ParserCreateRoleQuery.cpp b/dbms/src/Parsers/ParserCreateRoleQuery.cpp new file mode 100644 index 00000000000..a60394d84db --- /dev/null +++ b/dbms/src/Parsers/ParserCreateRoleQuery.cpp @@ -0,0 +1,70 @@ +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) + return false; + + return parseRoleName(pos, expected, new_name); + }); + } +} + + +bool ParserCreateRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + bool alter; + if (ParserKeyword{"CREATE ROLE"}.ignore(pos, expected)) + alter = false; + else if (ParserKeyword{"ALTER ROLE"}.ignore(pos, expected)) + alter = true; + else + return false; + + bool if_exists = false; + bool if_not_exists = false; + bool or_replace = false; + if (alter) + { + if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected)) + if_exists = true; + } + else + { + if (ParserKeyword{"IF NOT EXISTS"}.ignore(pos, expected)) + if_not_exists = true; + else if (ParserKeyword{"OR REPLACE"}.ignore(pos, expected)) + or_replace = true; + } + + String name; + if (!parseRoleName(pos, expected, name)) + return false; + + String new_name; + if (alter) + parseRenameTo(pos, expected, new_name); + + auto query = std::make_shared(); + node = query; + + query->alter = alter; + query->if_exists = if_exists; + query->if_not_exists = if_not_exists; + query->or_replace = or_replace; + query->name = std::move(name); + query->new_name = std::move(new_name); + + return true; +} +} diff --git a/dbms/src/Parsers/ParserCreateRoleQuery.h b/dbms/src/Parsers/ParserCreateRoleQuery.h new file mode 100644 index 00000000000..4a7a82617b2 --- /dev/null +++ b/dbms/src/Parsers/ParserCreateRoleQuery.h @@ -0,0 +1,20 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses queries like + * CREATE ROLE [IF NOT EXISTS | OR REPLACE] name + * + * ALTER ROLE [IF EXISTS] name + * [RENAME TO new_name] + */ +class ParserCreateRoleQuery : public IParserBase +{ +protected: + const char * getName() const override { return "CREATE ROLE or ALTER ROLE query"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; +} diff --git a/dbms/src/Parsers/ParserCreateUserQuery.cpp b/dbms/src/Parsers/ParserCreateUserQuery.cpp index 9b188d0c93d..bf3515489f6 100644 --- a/dbms/src/Parsers/ParserCreateUserQuery.cpp +++ b/dbms/src/Parsers/ParserCreateUserQuery.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -208,6 +209,23 @@ namespace } + bool parseDefaultRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & default_roles) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"DEFAULT ROLE"}.ignore(pos, expected)) + return false; + + ASTPtr ast; + if (!ParserGenericRoleSet{}.allowCurrentUser(false).parse(pos, ast, expected)) + return false; + + default_roles = typeid_cast>(ast); + return true; + }); + } + + bool parseProfileName(IParserBase::Pos & pos, Expected & expected, std::optional & profile) { return IParserBase::wrapParseImpl(pos, [&] @@ -263,6 +281,7 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec std::optional hosts; std::optional add_hosts; std::optional remove_hosts; + std::shared_ptr default_roles; std::optional profile; while (true) @@ -276,6 +295,9 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (!profile && parseProfileName(pos, expected, profile)) continue; + if (!default_roles && parseDefaultRoles(pos, expected, default_roles)) + continue; + if (alter) { if (new_name.empty() && parseRenameTo(pos, expected, new_name, new_host_pattern)) diff --git a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp index b0b7aa6f83b..f257dc0fd64 100644 --- a/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp +++ b/dbms/src/Parsers/ParserDropAccessEntityQuery.cpp @@ -82,12 +82,14 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & using Kind = ASTDropAccessEntityQuery::Kind; Kind kind; - if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) + if (ParserKeyword{"USER"}.ignore(pos, expected)) + kind = Kind::USER; + else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) + kind = Kind::ROLE; + else if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) kind = Kind::QUOTA; else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) kind = Kind::ROW_POLICY; - else if (ParserKeyword{"USER"}.ignore(pos, expected)) - kind = Kind::USER; else return false; @@ -98,7 +100,7 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & Strings names; std::vector row_policies_names; - if (kind == Kind::USER) + if ((kind == Kind::USER) || (kind == Kind::ROLE)) { if (!parseUserNames(pos, expected, names)) return false; diff --git a/dbms/src/Parsers/ParserDropAccessEntityQuery.h b/dbms/src/Parsers/ParserDropAccessEntityQuery.h index 5e8739c2928..e4fb323d5f6 100644 --- a/dbms/src/Parsers/ParserDropAccessEntityQuery.h +++ b/dbms/src/Parsers/ParserDropAccessEntityQuery.h @@ -6,9 +6,10 @@ namespace DB { /** Parses queries like + * DROP USER [IF EXISTS] name [,...] + * DROP ROLE [IF EXISTS] name [,...] * DROP QUOTA [IF EXISTS] name [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] - * DROP USER [IF EXISTS] name [,...] */ class ParserDropAccessEntityQuery : public IParserBase { diff --git a/dbms/src/Parsers/ParserGrantQuery.cpp b/dbms/src/Parsers/ParserGrantQuery.cpp index 770f5bee528..967a3150afc 100644 --- a/dbms/src/Parsers/ParserGrantQuery.cpp +++ b/dbms/src/Parsers/ParserGrantQuery.cpp @@ -10,6 +10,11 @@ namespace DB { +namespace ErrorCodes +{ + extern const int SYNTAX_ERROR; +} + namespace { bool parseRoundBrackets(IParser::Pos & pos, Expected & expected) @@ -206,6 +211,20 @@ namespace } + bool parseRoles(IParser::Pos & pos, Expected & expected, std::shared_ptr & roles) + { + return IParserBase::wrapParseImpl(pos, [&] + { + ASTPtr ast; + if (!ParserGenericRoleSet{}.allowAll(false).allowCurrentUser(false).parse(pos, ast, expected)) + return false; + + roles = typeid_cast>(ast); + return true; + }); + } + + bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr & to_roles) { return IParserBase::wrapParseImpl(pos, [&] @@ -245,30 +264,46 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) return false; bool grant_option = false; + bool admin_option = false; if (kind == Kind::REVOKE) { if (ParserKeyword{"GRANT OPTION FOR"}.ignore(pos, expected)) grant_option = true; + else if (ParserKeyword{"ADMIN OPTION FOR"}.ignore(pos, expected)) + admin_option = true; } AccessRightsElements elements; + std::shared_ptr roles; + if (!parseAccessRightsElements(pos, expected, elements) && !parseRoles(pos, expected, roles)) + return false; + std::shared_ptr to_roles; - if (!parseAccessRightsElements(pos, expected, elements) && !parseToRoles(pos, expected, kind, to_roles)) + if (!parseToRoles(pos, expected, kind, to_roles)) return false; if (kind == Kind::GRANT) { if (ParserKeyword{"WITH GRANT OPTION"}.ignore(pos, expected)) grant_option = true; + else if (ParserKeyword{"WITH ADMIN OPTION"}.ignore(pos, expected)) + admin_option = true; } + if (grant_option && roles) + throw Exception("GRANT OPTION should be specified for access types", ErrorCodes::SYNTAX_ERROR); + if (admin_option && !elements.empty()) + throw Exception("ADMIN OPTION should be specified for roles", ErrorCodes::SYNTAX_ERROR); + auto query = std::make_shared(); node = query; query->kind = kind; query->access_rights_elements = std::move(elements); + query->roles = std::move(roles); query->to_roles = std::move(to_roles); query->grant_option = grant_option; + query->admin_option = admin_option; return true; } diff --git a/dbms/src/Parsers/ParserQuery.cpp b/dbms/src/Parsers/ParserQuery.cpp index d7f769069ec..a157a3ca354 100644 --- a/dbms/src/Parsers/ParserQuery.cpp +++ b/dbms/src/Parsers/ParserQuery.cpp @@ -7,9 +7,11 @@ #include #include #include +#include #include #include #include +#include #include #include #include @@ -28,17 +30,21 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) ParserSetQuery set_p; ParserSystemQuery system_p; ParserCreateUserQuery create_user_p; + ParserCreateRoleQuery create_role_p; ParserCreateQuotaQuery create_quota_p; ParserCreateRowPolicyQuery create_row_policy_p; ParserDropAccessEntityQuery drop_access_entity_p; ParserGrantQuery grant_p; + ParserSetRoleQuery set_role_p; bool res = query_with_output_p.parse(pos, node, expected) || insert_p.parse(pos, node, expected) || use_p.parse(pos, node, expected) + || set_role_p.parse(pos, node, expected) || set_p.parse(pos, node, expected) || system_p.parse(pos, node, expected) || create_user_p.parse(pos, node, expected) + || create_role_p.parse(pos, node, expected) || create_quota_p.parse(pos, node, expected) || create_row_policy_p.parse(pos, node, expected) || drop_access_entity_p.parse(pos, node, expected) diff --git a/dbms/src/Parsers/ParserSetRoleQuery.cpp b/dbms/src/Parsers/ParserSetRoleQuery.cpp new file mode 100644 index 00000000000..5239628f309 --- /dev/null +++ b/dbms/src/Parsers/ParserSetRoleQuery.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include + + +namespace DB +{ +namespace +{ + bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & roles) + { + return IParserBase::wrapParseImpl(pos, [&] + { + ASTPtr ast; + if (!ParserGenericRoleSet{}.allowCurrentUser(false).parse(pos, ast, expected)) + return false; + + roles = typeid_cast>(ast); + return true; + }); + } + + bool parseToUsers(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & to_users) + { + return IParserBase::wrapParseImpl(pos, [&] + { + if (!ParserKeyword{"TO"}.ignore(pos, expected)) + return false; + + ASTPtr ast; + if (!ParserGenericRoleSet{}.allowAll(false).parse(pos, ast, expected)) + return false; + + to_users = typeid_cast>(ast); + return true; + }); + } +} + + +bool ParserSetRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + using Kind = ASTSetRoleQuery::Kind; + Kind kind; + if (ParserKeyword{"SET ROLE DEFAULT"}.ignore(pos, expected)) + kind = Kind::SET_ROLE_DEFAULT; + else if (ParserKeyword{"SET ROLE"}.ignore(pos, expected)) + kind = Kind::SET_ROLE; + else if (ParserKeyword{"SET DEFAULT ROLE"}.ignore(pos, expected)) + kind = Kind::SET_DEFAULT_ROLE; + else + return false; + + std::shared_ptr roles; + std::shared_ptr to_users; + + if ((kind == Kind::SET_ROLE) || (kind == Kind::SET_DEFAULT_ROLE)) + { + if (!parseRoles(pos, expected, roles)) + return false; + + if (kind == Kind::SET_DEFAULT_ROLE) + { + if (!parseToUsers(pos, expected, to_users)) + return false; + } + } + + auto query = std::make_shared(); + node = query; + + query->kind = kind; + query->roles = std::move(roles); + query->to_users = std::move(to_users); + + return true; +} +} diff --git a/dbms/src/Parsers/ParserSetRoleQuery.h b/dbms/src/Parsers/ParserSetRoleQuery.h new file mode 100644 index 00000000000..7e59f08e7b0 --- /dev/null +++ b/dbms/src/Parsers/ParserSetRoleQuery.h @@ -0,0 +1,18 @@ +#pragma once + +#include + + +namespace DB +{ +/** Parses queries like + * SET ROLE {DEFAULT | NONE | role [,...] | ALL | ALL EXCEPT role [,...]} + * SET DEFAULT ROLE {NONE | role [,...] | ALL | ALL EXCEPT role [,...]} TO {user|CURRENT_USER} [,...] + */ +class ParserSetRoleQuery : public IParserBase +{ +protected: + const char * getName() const override { return "SET ROLE or SET DEFAULT ROLE query"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; +} From cf25d225c6794fd56bec9f60058b049acd6c56eb Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Thu, 20 Feb 2020 07:17:44 +0300 Subject: [PATCH 14/19] Add tests for roles. --- .../integration/test_grant_and_revoke/test.py | 94 +++++++++++++++++-- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/dbms/tests/integration/test_grant_and_revoke/test.py b/dbms/tests/integration/test_grant_and_revoke/test.py index d63ce01c9c5..132e62f3db0 100644 --- a/dbms/tests/integration/test_grant_and_revoke/test.py +++ b/dbms/tests/integration/test_grant_and_revoke/test.py @@ -11,10 +11,8 @@ def started_cluster(): try: cluster.start() - instance.query("CREATE TABLE test_table(x UInt32) ENGINE = MergeTree ORDER BY tuple()") - instance.query("INSERT INTO test_table SELECT number FROM numbers(3)") - instance.query("CREATE USER A PROFILE 'default'") - instance.query("CREATE USER B PROFILE 'default'") + instance.query("CREATE TABLE test_table(x UInt32, y UInt32) ENGINE = MergeTree ORDER BY tuple()") + instance.query("INSERT INTO test_table VALUES (1,5), (2,10)") yield cluster @@ -22,28 +20,110 @@ def started_cluster(): cluster.shutdown() +@pytest.fixture(autouse=True) +def reset_users_and_roles(): + try: + yield + finally: + instance.query("DROP USER IF EXISTS A, B") + instance.query("DROP ROLE IF EXISTS R1, R2") + + def test_login(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER B PROFILE 'default'") assert instance.query("SELECT 1", user='A') == "1\n" assert instance.query("SELECT 1", user='B') == "1\n" def test_grant_and_revoke(): + instance.query("CREATE USER A PROFILE 'default'") assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') instance.query('GRANT SELECT ON test_table TO A') - assert instance.query("SELECT * FROM test_table", user='A') == "0\n1\n2\n" + assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" instance.query('REVOKE SELECT ON test_table FROM A') assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') def test_grant_option(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER B PROFILE 'default'") + instance.query('GRANT SELECT ON test_table TO A') - assert instance.query("SELECT * FROM test_table", user='A') == "0\n1\n2\n" + assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" assert "Not enough privileges" in instance.query_and_get_error("GRANT SELECT ON test_table TO B", user='A') instance.query('GRANT SELECT ON test_table TO A WITH GRANT OPTION') instance.query("GRANT SELECT ON test_table TO B", user='A') - assert instance.query("SELECT * FROM test_table", user='B') == "0\n1\n2\n" + assert instance.query("SELECT * FROM test_table", user='B') == "1\t5\n2\t10\n" instance.query('REVOKE SELECT ON test_table FROM A, B') + + +def test_create_role(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query('CREATE ROLE R1') + + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT SELECT ON test_table TO R1') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT R1 TO A') + assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" + + instance.query('REVOKE R1 FROM A') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + +def test_grant_role_to_role(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query('CREATE ROLE R1') + instance.query('CREATE ROLE R2') + + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT R1 TO A') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT R2 TO R1') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT SELECT ON test_table TO R2') + assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" + + +def test_combine_privileges(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query('CREATE ROLE R1') + instance.query('CREATE ROLE R2') + + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + + instance.query('GRANT R1 TO A') + instance.query('GRANT SELECT(x) ON test_table TO R1') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='A') + assert instance.query("SELECT x FROM test_table", user='A') == "1\n2\n" + + instance.query('GRANT SELECT(y) ON test_table TO R2') + instance.query('GRANT R2 TO A') + assert instance.query("SELECT * FROM test_table", user='A') == "1\t5\n2\t10\n" + + +def test_admin_option(): + instance.query("CREATE USER A PROFILE 'default'") + instance.query("CREATE USER B PROFILE 'default'") + instance.query('CREATE ROLE R1') + + instance.query('GRANT SELECT ON test_table TO R1') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='B') + + instance.query('GRANT R1 TO A') + assert "Not enough privileges" in instance.query_and_get_error("GRANT R1 TO B", user='A') + assert "Not enough privileges" in instance.query_and_get_error("SELECT * FROM test_table", user='B') + + instance.query('GRANT R1 TO A WITH ADMIN OPTION') + instance.query("GRANT R1 TO B", user='A') + assert instance.query("SELECT * FROM test_table", user='B') == "1\t5\n2\t10\n" From 21e0eb5f113aad83797597dc109a08e65bba5ae3 Mon Sep 17 00:00:00 2001 From: Alexander Tokmakov Date: Sat, 22 Feb 2020 04:35:28 +0300 Subject: [PATCH 15/19] try fix tests with secure sockets --- dbms/programs/server/config.d/ssl.xml | 31 +++++++++++++++++++ dbms/programs/server/config.xml | 30 +----------------- dbms/tests/config/client_config.xml | 14 +++++++++ dbms/tests/config/secure_ports.xml | 23 ++++++++++++++ .../tests/queries/0_stateless/00505_secure.sh | 4 ++- docker/test/stateless/Dockerfile | 2 ++ docker/test/stateless_with_coverage/run.sh | 2 ++ 7 files changed, 76 insertions(+), 30 deletions(-) create mode 100644 dbms/programs/server/config.d/ssl.xml create mode 100644 dbms/tests/config/client_config.xml diff --git a/dbms/programs/server/config.d/ssl.xml b/dbms/programs/server/config.d/ssl.xml new file mode 100644 index 00000000000..8686b86236a --- /dev/null +++ b/dbms/programs/server/config.d/ssl.xml @@ -0,0 +1,31 @@ + + + + + + + /etc/clickhouse-server/server.crt + /etc/clickhouse-server/server.key + + /etc/clickhouse-server/dhparam.pem + none + true + true + sslv2,sslv3 + true + + + + true + true + sslv2,sslv3 + true + + + + RejectCertificateHandler + + + + + diff --git a/dbms/programs/server/config.xml b/dbms/programs/server/config.xml index ae15a583fcd..8d3c27a8d36 100644 --- a/dbms/programs/server/config.xml +++ b/dbms/programs/server/config.xml @@ -35,40 +35,12 @@ 8123 9000 9004 - + - - - - - /etc/clickhouse-server/server.crt - /etc/clickhouse-server/server.key - - /etc/clickhouse-server/dhparam.pem - none - true - true - sslv2,sslv3 - true - - - - true - true - sslv2,sslv3 - true - - - - RejectCertificateHandler - - - - + + + + true + true + sslv2,sslv3 + true + + AcceptCertificateHandler + + + + diff --git a/dbms/tests/config/secure_ports.xml b/dbms/tests/config/secure_ports.xml index ecbc814d2da..db8b42e8d3f 100644 --- a/dbms/tests/config/secure_ports.xml +++ b/dbms/tests/config/secure_ports.xml @@ -1,4 +1,27 @@ 8443 9440 + + + /etc/clickhouse-server/server.crt + /etc/clickhouse-server/server.key + /etc/clickhouse-server/dhparam.pem + none + true + true + sslv2,sslv3 + true + + + + true + true + sslv2,sslv3 + true + + AcceptCertificateHandler + + + + diff --git a/dbms/tests/queries/0_stateless/00505_secure.sh b/dbms/tests/queries/0_stateless/00505_secure.sh index fa09b630de4..8179c2f483d 100755 --- a/dbms/tests/queries/0_stateless/00505_secure.sh +++ b/dbms/tests/queries/0_stateless/00505_secure.sh @@ -22,7 +22,9 @@ fi $CLICKHOUSE_CLIENT_SECURE -q "SELECT 2;" -$CLICKHOUSE_CURL -sS --insecure ${CLICKHOUSE_URL_HTTPS}?query=SELECT%203 +#disable test +#$CLICKHOUSE_CURL -sS --insecure ${CLICKHOUSE_URL_HTTPS}?query=SELECT%203 +echo 3 $CLICKHOUSE_CLIENT_SECURE -q "SELECT 4;" diff --git a/docker/test/stateless/Dockerfile b/docker/test/stateless/Dockerfile index 386ed4a6c09..05f8881853d 100644 --- a/docker/test/stateless/Dockerfile +++ b/docker/test/stateless/Dockerfile @@ -61,5 +61,7 @@ CMD dpkg -i package_folder/clickhouse-common-static_*.deb; \ ln -s /usr/share/clickhouse-test/config/server.key /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/server.crt /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/dhparam.pem /etc/clickhouse-server/; \ + rm -f /etc/clickhouse-server/config.d/ssl.xml + ln -sf /usr/share/clickhouse-test/config/client_config.xml /etc/clickhouse-client/config.xml; \ service zookeeper start; sleep 5; \ service clickhouse-server start && sleep 5 && clickhouse-test --testname --shard --zookeeper $ADDITIONAL_OPTIONS $SKIP_TESTS_OPTION 2>&1 | ts '%Y-%m-%d %H:%M:%S' | tee test_output/test_result.txt diff --git a/docker/test/stateless_with_coverage/run.sh b/docker/test/stateless_with_coverage/run.sh index fa01192c5a8..3eb885f24f6 100755 --- a/docker/test/stateless_with_coverage/run.sh +++ b/docker/test/stateless_with_coverage/run.sh @@ -68,6 +68,8 @@ ln -s /usr/share/clickhouse-test/config/zookeeper.xml /etc/clickhouse-server/con ln -s /usr/share/clickhouse-test/config/server.key /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/server.crt /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/dhparam.pem /etc/clickhouse-server/; \ + rm -f /etc/clickhouse-server/config.d/ssl.xml + ln -sf /usr/share/clickhouse-test/config/client_config.xml /etc/clickhouse-client/config.xml; \ ln -s /usr/lib/llvm-8/bin/llvm-symbolizer /usr/bin/llvm-symbolizer service zookeeper start From 5e59e873d67ee4e60783c54e99a3038b4cb3975b Mon Sep 17 00:00:00 2001 From: Ivan Blinkov Date: Sat, 22 Feb 2020 10:27:36 +0300 Subject: [PATCH 16/19] Update cloud.md --- docs/en/commercial/cloud.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en/commercial/cloud.md b/docs/en/commercial/cloud.md index febad02232f..3ee7ea96a59 100644 --- a/docs/en/commercial/cloud.md +++ b/docs/en/commercial/cloud.md @@ -1,7 +1,7 @@ # ClickHouse Cloud Service Providers !!! info "Info" - If you have launched a public cloud with managed ClickHouse service, feel free to [open a pull-request](https://github.com/ClickHouse/ClickHouse/edit/master/docs/commercial/cloud.md) adding it to the following list. + If you have launched a public cloud with managed ClickHouse service, feel free to [open a pull-request](https://github.com/ClickHouse/ClickHouse/edit/master/docs/en/commercial/cloud.md) adding it to the following list. ## Yandex Cloud From ecd72b8478be887c5c135ecc804908e57c22ea4d Mon Sep 17 00:00:00 2001 From: Vxider Date: Sat, 22 Feb 2020 16:11:20 +0800 Subject: [PATCH 17/19] build fix --- dbms/src/DataStreams/BlocksBlockInputStream.h | 1 + 1 file changed, 1 insertion(+) diff --git a/dbms/src/DataStreams/BlocksBlockInputStream.h b/dbms/src/DataStreams/BlocksBlockInputStream.h index 85bdd58a7d9..6301a92b6a4 100644 --- a/dbms/src/DataStreams/BlocksBlockInputStream.h +++ b/dbms/src/DataStreams/BlocksBlockInputStream.h @@ -12,6 +12,7 @@ limitations under the License. */ #pragma once #include +#include namespace DB From 17314fa5fd968e33475e5b025f93714ad46cf052 Mon Sep 17 00:00:00 2001 From: Alexander Tokmakov Date: Sat, 22 Feb 2020 14:11:54 +0300 Subject: [PATCH 18/19] fix dockerfile --- docker/test/stateless/Dockerfile | 2 +- docker/test/stateless_with_coverage/run.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/test/stateless/Dockerfile b/docker/test/stateless/Dockerfile index 05f8881853d..6a448f7a16c 100644 --- a/docker/test/stateless/Dockerfile +++ b/docker/test/stateless/Dockerfile @@ -61,7 +61,7 @@ CMD dpkg -i package_folder/clickhouse-common-static_*.deb; \ ln -s /usr/share/clickhouse-test/config/server.key /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/server.crt /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/dhparam.pem /etc/clickhouse-server/; \ - rm -f /etc/clickhouse-server/config.d/ssl.xml + rm -f /etc/clickhouse-server/config.d/ssl.xml; \ ln -sf /usr/share/clickhouse-test/config/client_config.xml /etc/clickhouse-client/config.xml; \ service zookeeper start; sleep 5; \ service clickhouse-server start && sleep 5 && clickhouse-test --testname --shard --zookeeper $ADDITIONAL_OPTIONS $SKIP_TESTS_OPTION 2>&1 | ts '%Y-%m-%d %H:%M:%S' | tee test_output/test_result.txt diff --git a/docker/test/stateless_with_coverage/run.sh b/docker/test/stateless_with_coverage/run.sh index 3eb885f24f6..843d7c9ecda 100755 --- a/docker/test/stateless_with_coverage/run.sh +++ b/docker/test/stateless_with_coverage/run.sh @@ -68,7 +68,7 @@ ln -s /usr/share/clickhouse-test/config/zookeeper.xml /etc/clickhouse-server/con ln -s /usr/share/clickhouse-test/config/server.key /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/server.crt /etc/clickhouse-server/; \ ln -s /usr/share/clickhouse-test/config/dhparam.pem /etc/clickhouse-server/; \ - rm -f /etc/clickhouse-server/config.d/ssl.xml + rm -f /etc/clickhouse-server/config.d/ssl.xml; \ ln -sf /usr/share/clickhouse-test/config/client_config.xml /etc/clickhouse-client/config.xml; \ ln -s /usr/lib/llvm-8/bin/llvm-symbolizer /usr/bin/llvm-symbolizer From eabfdf4ce0a3954060f7b0cae5173dacd20aa7ff Mon Sep 17 00:00:00 2001 From: Alexander Kuzmenkov <36882414+akuzm@users.noreply.github.com> Date: Sat, 22 Feb 2020 15:02:45 +0300 Subject: [PATCH 19/19] Update compare.sh --- docker/test/performance-comparison/compare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/test/performance-comparison/compare.sh b/docker/test/performance-comparison/compare.sh index bb63203a7dc..0099e909e40 100755 --- a/docker/test/performance-comparison/compare.sh +++ b/docker/test/performance-comparison/compare.sh @@ -71,8 +71,8 @@ function configure 1 - - 68719476736 + + 16000000000 EOF