Merge branch 'master' into fix_expressions_in_metadata

This commit is contained in:
Alexander Tokmakov 2020-02-22 16:42:45 +03:00
commit ac0e2c2256
119 changed files with 3532 additions and 1213 deletions

View File

@ -218,7 +218,7 @@ void MySQLHandler::authenticate(const String & user_name, const String & auth_pl
try try
{ {
// For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used. // For compatibility with JavaScript MySQL client, Native41 authentication plugin is used when possible (if password is specified using double SHA1). Otherwise SHA256 plugin is used.
auto user = connection_context.getAccessControlManager().getUser(user_name); auto user = connection_context.getAccessControlManager().read<User>(user_name);
const DB::Authentication::Type user_auth_type = user->authentication.getType(); const DB::Authentication::Type user_auth_type = user->authentication.getType();
if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD) if (user_auth_type != DB::Authentication::DOUBLE_SHA1_PASSWORD && user_auth_type != DB::Authentication::PLAINTEXT_PASSWORD && user_auth_type != DB::Authentication::NO_PASSWORD)
{ {

View File

@ -902,7 +902,7 @@ void TCPHandler::receiveQuery()
} }
else else
{ {
query_context->switchRowPolicy(); query_context->setInitialRowPolicy();
} }
} }

View File

@ -2,10 +2,10 @@
#include <Access/MultipleAccessStorage.h> #include <Access/MultipleAccessStorage.h>
#include <Access/MemoryAccessStorage.h> #include <Access/MemoryAccessStorage.h>
#include <Access/UsersConfigAccessStorage.h> #include <Access/UsersConfigAccessStorage.h>
#include <Access/User.h> #include <Access/AccessRightsContextFactory.h>
#include <Access/QuotaContextFactory.h> #include <Access/RoleContextFactory.h>
#include <Access/RowPolicyContextFactory.h> #include <Access/RowPolicyContextFactory.h>
#include <Access/AccessRightsContext.h> #include <Access/QuotaContextFactory.h>
namespace DB namespace DB
@ -24,8 +24,10 @@ namespace
AccessControlManager::AccessControlManager() AccessControlManager::AccessControlManager()
: MultipleAccessStorage(createStorages()), : MultipleAccessStorage(createStorages()),
quota_context_factory(std::make_unique<QuotaContextFactory>(*this)), access_rights_context_factory(std::make_unique<AccessRightsContextFactory>(*this)),
row_policy_context_factory(std::make_unique<RowPolicyContextFactory>(*this)) role_context_factory(std::make_unique<RoleContextFactory>(*this)),
row_policy_context_factory(std::make_unique<RowPolicyContextFactory>(*this)),
quota_context_factory(std::make_unique<QuotaContextFactory>(*this))
{ {
} }
@ -35,53 +37,6 @@ AccessControlManager::~AccessControlManager()
} }
UserPtr AccessControlManager::getUser(
const String & user_name, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
return getUser(getID<User>(user_name), std::move(on_change), subscription);
}
UserPtr AccessControlManager::getUser(
const UUID & user_id, std::function<void(const UserPtr &)> on_change, ext::scope_guard * subscription) const
{
if (on_change && subscription)
{
*subscription = subscribeForChanges(user_id, [on_change](const UUID &, const AccessEntityPtr & user)
{
if (user)
on_change(typeid_cast<UserPtr>(user));
});
}
return read<User>(user_id);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const String & user_name,
const String & password,
const Poco::Net::IPAddress & address,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
return authorizeAndGetUser(getID<User>(user_name), password, address, std::move(on_change), subscription);
}
UserPtr AccessControlManager::authorizeAndGetUser(
const UUID & user_id,
const String & password,
const Poco::Net::IPAddress & address,
std::function<void(const UserPtr &)> on_change,
ext::scope_guard * subscription) const
{
auto user = getUser(user_id, on_change, subscription);
user->allowed_client_hosts.checkContains(address, user->getName());
user->authentication.checkPassword(password, user->getName());
return user;
}
void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config) void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguration & users_config)
{ {
auto & users_config_access_storage = dynamic_cast<UsersConfigAccessStorage &>(getStorageByIndex(1)); auto & users_config_access_storage = dynamic_cast<UsersConfigAccessStorage &>(getStorageByIndex(1));
@ -89,16 +44,36 @@ void AccessControlManager::loadFromConfig(const Poco::Util::AbstractConfiguratio
} }
std::shared_ptr<const AccessRightsContext> AccessControlManager::getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database) AccessRightsContextPtr AccessControlManager::getAccessRightsContext(
const UUID & user_id,
const std::vector<UUID> & current_roles,
bool use_default_roles,
const Settings & settings,
const String & current_database,
const ClientInfo & client_info) const
{ {
return std::make_shared<AccessRightsContext>(user, client_info, settings, current_database); return access_rights_context_factory->createContext(user_id, current_roles, use_default_roles, settings, current_database, client_info);
} }
std::shared_ptr<QuotaContext> AccessControlManager::createQuotaContext( RoleContextPtr AccessControlManager::getRoleContext(
const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key) const std::vector<UUID> & current_roles,
const std::vector<UUID> & current_roles_with_admin_option) const
{ {
return quota_context_factory->createContext(user_name, address, custom_quota_key); return role_context_factory->createContext(current_roles, current_roles_with_admin_option);
}
RowPolicyContextPtr AccessControlManager::getRowPolicyContext(const UUID & user_id, const std::vector<UUID> & enabled_roles) const
{
return row_policy_context_factory->createContext(user_id, enabled_roles);
}
QuotaContextPtr AccessControlManager::getQuotaContext(
const String & user_name, const UUID & user_id, const std::vector<UUID> & enabled_roles, const Poco::Net::IPAddress & address, const String & custom_quota_key) const
{
return quota_context_factory->createContext(user_name, user_id, enabled_roles, address, custom_quota_key);
} }
@ -107,10 +82,4 @@ std::vector<QuotaUsageInfo> AccessControlManager::getQuotaUsageInfo() const
return quota_context_factory->getUsageInfo(); return quota_context_factory->getUsageInfo();
} }
std::shared_ptr<RowPolicyContext> AccessControlManager::getRowPolicyContext(const String & user_name) const
{
return row_policy_context_factory->createContext(user_name);
}
} }

View File

@ -2,7 +2,6 @@
#include <Access/MultipleAccessStorage.h> #include <Access/MultipleAccessStorage.h>
#include <Poco/AutoPtr.h> #include <Poco/AutoPtr.h>
#include <ext/scope_guard.h>
#include <memory> #include <memory>
@ -20,15 +19,21 @@ namespace Poco
namespace DB namespace DB
{ {
class AccessRightsContext;
using AccessRightsContextPtr = std::shared_ptr<const AccessRightsContext>;
class AccessRightsContextFactory;
struct User; struct User;
using UserPtr = std::shared_ptr<const User>; using UserPtr = std::shared_ptr<const User>;
struct RoleContext;
using RoleContextPtr = std::shared_ptr<const RoleContext>;
class RoleContextFactory;
class RowPolicyContext;
using RowPolicyContextPtr = std::shared_ptr<const RowPolicyContext>;
class RowPolicyContextFactory;
class QuotaContext; class QuotaContext;
using QuotaContextPtr = std::shared_ptr<const QuotaContext>;
class QuotaContextFactory; class QuotaContextFactory;
struct QuotaUsageInfo; struct QuotaUsageInfo;
class RowPolicyContext;
class RowPolicyContextFactory;
class AccessRights;
class AccessRightsContext;
class ClientInfo; class ClientInfo;
struct Settings; struct Settings;
@ -42,23 +47,36 @@ public:
void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config); void loadFromConfig(const Poco::Util::AbstractConfiguration & users_config);
UserPtr getUser(const String & user_name, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const; AccessRightsContextPtr getAccessRightsContext(
UserPtr getUser(const UUID & user_id, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const; const UUID & user_id,
UserPtr authorizeAndGetUser(const String & user_name, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const; const std::vector<UUID> & current_roles,
UserPtr authorizeAndGetUser(const UUID & user_id, const String & password, const Poco::Net::IPAddress & address, std::function<void(const UserPtr &)> on_change = {}, ext::scope_guard * subscription = nullptr) const; bool use_default_roles,
const Settings & settings,
const String & current_database,
const ClientInfo & client_info) const;
std::shared_ptr<const AccessRightsContext> getAccessRightsContext(const UserPtr & user, const ClientInfo & client_info, const Settings & settings, const String & current_database); RoleContextPtr getRoleContext(
const std::vector<UUID> & current_roles,
const std::vector<UUID> & current_roles_with_admin_option) const;
std::shared_ptr<QuotaContext> RowPolicyContextPtr getRowPolicyContext(
createQuotaContext(const String & user_name, const Poco::Net::IPAddress & address, const String & custom_quota_key); const UUID & user_id,
const std::vector<UUID> & enabled_roles) const;
QuotaContextPtr getQuotaContext(
const String & user_name,
const UUID & user_id,
const std::vector<UUID> & enabled_roles,
const Poco::Net::IPAddress & address,
const String & custom_quota_key) const;
std::vector<QuotaUsageInfo> getQuotaUsageInfo() const; std::vector<QuotaUsageInfo> getQuotaUsageInfo() const;
std::shared_ptr<RowPolicyContext> getRowPolicyContext(const String & user_name) const;
private: private:
std::unique_ptr<QuotaContextFactory> quota_context_factory; std::unique_ptr<AccessRightsContextFactory> access_rights_context_factory;
std::unique_ptr<RoleContextFactory> role_context_factory;
std::unique_ptr<RowPolicyContextFactory> row_policy_context_factory; std::unique_ptr<RowPolicyContextFactory> row_policy_context_factory;
std::unique_ptr<QuotaContextFactory> quota_context_factory;
}; };
} }

View File

@ -304,15 +304,10 @@ private:
ext::push_back(all, std::move(alter)); ext::push_back(all, std::move(alter));
auto create_database = std::make_unique<Node>("CREATE DATABASE", next_flag++, DATABASE_LEVEL); auto create_database = std::make_unique<Node>("CREATE DATABASE", next_flag++, DATABASE_LEVEL);
ext::push_back(create_database->aliases, "ATTACH DATABASE");
auto create_table = std::make_unique<Node>("CREATE TABLE", next_flag++, TABLE_LEVEL); auto create_table = std::make_unique<Node>("CREATE TABLE", next_flag++, TABLE_LEVEL);
ext::push_back(create_table->aliases, "ATTACH TABLE");
auto create_view = std::make_unique<Node>("CREATE VIEW", next_flag++, VIEW_LEVEL); auto create_view = std::make_unique<Node>("CREATE VIEW", next_flag++, VIEW_LEVEL);
ext::push_back(create_view->aliases, "ATTACH VIEW");
auto create_dictionary = std::make_unique<Node>("CREATE DICTIONARY", next_flag++, DICTIONARY_LEVEL); auto create_dictionary = std::make_unique<Node>("CREATE DICTIONARY", next_flag++, DICTIONARY_LEVEL);
ext::push_back(create_dictionary->aliases, "ATTACH DICTIONARY");
auto create = std::make_unique<Node>("CREATE", std::move(create_database), std::move(create_table), std::move(create_view), std::move(create_dictionary)); auto create = std::make_unique<Node>("CREATE", std::move(create_database), std::move(create_table), std::move(create_view), std::move(create_dictionary));
ext::push_back(create->aliases, "ATTACH");
ext::push_back(all, std::move(create)); ext::push_back(all, std::move(create));
auto create_temporary_table = std::make_unique<Node>("CREATE TEMPORARY TABLE", next_flag++, GLOBAL_LEVEL); auto create_temporary_table = std::make_unique<Node>("CREATE TEMPORARY TABLE", next_flag++, GLOBAL_LEVEL);
@ -325,13 +320,6 @@ private:
auto drop = std::make_unique<Node>("DROP", std::move(drop_database), std::move(drop_table), std::move(drop_view), std::move(drop_dictionary)); auto drop = std::make_unique<Node>("DROP", std::move(drop_database), std::move(drop_table), std::move(drop_view), std::move(drop_dictionary));
ext::push_back(all, std::move(drop)); ext::push_back(all, std::move(drop));
auto detach_database = std::make_unique<Node>("DETACH DATABASE", next_flag++, DATABASE_LEVEL);
auto detach_table = std::make_unique<Node>("DETACH TABLE", next_flag++, TABLE_LEVEL);
auto detach_view = std::make_unique<Node>("DETACH VIEW", next_flag++, VIEW_LEVEL);
auto detach_dictionary = std::make_unique<Node>("DETACH DICTIONARY", next_flag++, DICTIONARY_LEVEL);
auto detach = std::make_unique<Node>("DETACH", std::move(detach_database), std::move(detach_table), std::move(detach_view), std::move(detach_dictionary));
ext::push_back(all, std::move(detach));
auto truncate_table = std::make_unique<Node>("TRUNCATE TABLE", next_flag++, TABLE_LEVEL); auto truncate_table = std::make_unique<Node>("TRUNCATE TABLE", next_flag++, TABLE_LEVEL);
auto truncate_view = std::make_unique<Node>("TRUNCATE VIEW", next_flag++, VIEW_LEVEL); auto truncate_view = std::make_unique<Node>("TRUNCATE VIEW", next_flag++, VIEW_LEVEL);
auto truncate = std::make_unique<Node>("TRUNCATE", std::move(truncate_table), std::move(truncate_view)); auto truncate = std::make_unique<Node>("TRUNCATE", std::move(truncate_table), std::move(truncate_view));
@ -347,8 +335,18 @@ private:
ext::push_back(all, std::move(kill)); ext::push_back(all, std::move(kill));
auto create_user = std::make_unique<Node>("CREATE USER", next_flag++, GLOBAL_LEVEL); auto create_user = std::make_unique<Node>("CREATE USER", next_flag++, GLOBAL_LEVEL);
ext::push_back(create_user->aliases, "ALTER USER", "DROP USER", "CREATE ROLE", "DROP ROLE", "CREATE POLICY", "ALTER POLICY", "DROP POLICY", "CREATE QUOTA", "ALTER QUOTA", "DROP QUOTA"); auto alter_user = std::make_unique<Node>("ALTER USER", next_flag++, GLOBAL_LEVEL);
ext::push_back(all, std::move(create_user)); auto drop_user = std::make_unique<Node>("DROP USER", next_flag++, GLOBAL_LEVEL);
auto create_role = std::make_unique<Node>("CREATE ROLE", next_flag++, GLOBAL_LEVEL);
auto drop_role = std::make_unique<Node>("DROP ROLE", next_flag++, GLOBAL_LEVEL);
auto create_policy = std::make_unique<Node>("CREATE POLICY", next_flag++, GLOBAL_LEVEL);
auto alter_policy = std::make_unique<Node>("ALTER POLICY", next_flag++, GLOBAL_LEVEL);
auto drop_policy = std::make_unique<Node>("DROP POLICY", next_flag++, GLOBAL_LEVEL);
auto create_quota = std::make_unique<Node>("CREATE QUOTA", next_flag++, GLOBAL_LEVEL);
auto alter_quota = std::make_unique<Node>("ALTER QUOTA", next_flag++, GLOBAL_LEVEL);
auto drop_quota = std::make_unique<Node>("DROP QUOTA", next_flag++, GLOBAL_LEVEL);
auto role_admin = std::make_unique<Node>("ROLE ADMIN", next_flag++, GLOBAL_LEVEL);
ext::push_back(all, std::move(create_user), std::move(alter_user), std::move(drop_user), std::move(create_role), std::move(drop_role), std::move(create_policy), std::move(alter_policy), std::move(drop_policy), std::move(create_quota), std::move(alter_quota), std::move(drop_quota), std::move(role_admin));
auto shutdown = std::make_unique<Node>("SHUTDOWN", next_flag++, GLOBAL_LEVEL); auto shutdown = std::make_unique<Node>("SHUTDOWN", next_flag++, GLOBAL_LEVEL);
ext::push_back(shutdown->aliases, "SYSTEM SHUTDOWN", "SYSTEM KILL"); ext::push_back(shutdown->aliases, "SYSTEM SHUTDOWN", "SYSTEM KILL");

View File

@ -1,11 +1,20 @@
#include <Access/AccessRightsContext.h> #include <Access/AccessRightsContext.h>
#include <Access/AccessControlManager.h>
#include <Access/RoleContext.h>
#include <Access/RowPolicyContext.h>
#include <Access/QuotaContext.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/CurrentRolesInfo.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <IO/WriteHelpers.h>
#include <Poco/Logger.h> #include <Poco/Logger.h>
#include <common/logger_useful.h> #include <common/logger_useful.h>
#include <boost/algorithm/string/join.hpp>
#include <boost/smart_ptr/make_shared_object.hpp> #include <boost/smart_ptr/make_shared_object.hpp>
#include <boost/range/algorithm/fill.hpp>
#include <boost/range/algorithm/set_algorithm.hpp>
#include <assert.h> #include <assert.h>
@ -17,6 +26,7 @@ namespace ErrorCodes
extern const int READONLY; extern const int READONLY;
extern const int QUERY_IS_PROHIBITED; extern const int QUERY_IS_PROHIBITED;
extern const int FUNCTION_NOT_ALLOWED; extern const int FUNCTION_NOT_ALLOWED;
extern const int UNKNOWN_USER;
} }
@ -85,25 +95,116 @@ AccessRightsContext::AccessRightsContext()
{ {
auto everything_granted = boost::make_shared<AccessRights>(); auto everything_granted = boost::make_shared<AccessRights>();
everything_granted->grant(AccessType::ALL); everything_granted->grant(AccessType::ALL);
result_access_cache[0] = std::move(everything_granted); boost::range::fill(result_access_cache, everything_granted);
enabled_roles_with_admin_option = boost::make_shared<boost::container::flat_set<UUID>>();
row_policy_context = std::make_shared<RowPolicyContext>();
quota_context = std::make_shared<QuotaContext>();
} }
AccessRightsContext::AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_) AccessRightsContext::AccessRightsContext(const AccessControlManager & manager_, const Params & params_)
: user(user_) : manager(&manager_)
, readonly(settings.readonly) , params(params_)
, allow_ddl(settings.allow_ddl)
, allow_introspection(settings.allow_introspection_functions)
, current_database(current_database_)
, interface(client_info_.interface)
, http_method(client_info_.http_method)
, trace_log(&Poco::Logger::get("AccessRightsContext (" + user_->getName() + ")"))
{ {
subscription_for_user_change = manager->subscribeForChanges(
*params.user_id, [this](const UUID &, const AccessEntityPtr & entity)
{
UserPtr changed_user = entity ? typeid_cast<UserPtr>(entity) : nullptr;
std::lock_guard lock{mutex};
setUser(changed_user);
});
setUser(manager->read<User>(*params.user_id));
}
void AccessRightsContext::setUser(const UserPtr & user_) const
{
user = user_;
if (!user)
{
/// User has been dropped.
auto nothing_granted = boost::make_shared<AccessRights>();
boost::range::fill(result_access_cache, nothing_granted);
subscription_for_user_change = {};
subscription_for_roles_info_change = {};
role_context = nullptr;
enabled_roles_with_admin_option = boost::make_shared<boost::container::flat_set<UUID>>();
row_policy_context = std::make_shared<RowPolicyContext>();
quota_context = std::make_shared<QuotaContext>();
return;
}
user_name = user->getName();
trace_log = &Poco::Logger::get("AccessRightsContext (" + user_name + ")");
std::vector<UUID> current_roles, current_roles_with_admin_option;
if (params.use_default_roles)
{
for (const UUID & id : user->granted_roles)
{
if (user->default_roles.match(id))
current_roles.push_back(id);
}
boost::range::set_intersection(current_roles, user->granted_roles_with_admin_option,
std::back_inserter(current_roles_with_admin_option));
}
else
{
current_roles.reserve(params.current_roles.size());
for (const auto & id : params.current_roles)
{
if (user->granted_roles.contains(id))
current_roles.push_back(id);
if (user->granted_roles_with_admin_option.contains(id))
current_roles_with_admin_option.push_back(id);
}
}
subscription_for_roles_info_change = {};
role_context = manager->getRoleContext(current_roles, current_roles_with_admin_option);
subscription_for_roles_info_change = role_context->subscribeForChanges([this](const CurrentRolesInfoPtr & roles_info_)
{
std::lock_guard lock{mutex};
setRolesInfo(roles_info_);
});
setRolesInfo(role_context->getInfo());
}
void AccessRightsContext::setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const
{
assert(roles_info_);
roles_info = roles_info_;
enabled_roles_with_admin_option.store(nullptr /* need to recalculate */);
boost::range::fill(result_access_cache, nullptr /* need recalculate */);
row_policy_context = manager->getRowPolicyContext(*params.user_id, roles_info->enabled_roles);
quota_context = manager->getQuotaContext(user_name, *params.user_id, roles_info->enabled_roles, params.address, params.quota_key);
}
void AccessRightsContext::checkPassword(const String & password) const
{
std::lock_guard lock{mutex};
if (!user)
throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER);
user->authentication.checkPassword(password, user_name);
}
void AccessRightsContext::checkHostIsAllowed() const
{
std::lock_guard lock{mutex};
if (!user)
throw Exception(user_name + ": User has been dropped", ErrorCodes::UNKNOWN_USER);
user->allowed_client_hosts.checkContains(params.address, user_name);
} }
template <int mode, bool grant_option, typename... Args> template <int mode, bool grant_option, typename... Args>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const
{ {
auto result_access = calculateResultAccess(grant_option); auto result_access = calculateResultAccess(grant_option);
bool is_granted = result_access->isGranted(access, args...); bool is_granted = result_access->isGranted(access, args...);
@ -126,12 +227,16 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
auto show_error = [&](const String & msg, [[maybe_unused]] int error_code) auto show_error = [&](const String & msg, [[maybe_unused]] int error_code)
{ {
if constexpr (mode == THROW_IF_ACCESS_DENIED) if constexpr (mode == THROW_IF_ACCESS_DENIED)
throw Exception(user->getName() + ": " + msg, error_code); throw Exception(user_name + ": " + msg, error_code);
else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED) else if constexpr (mode == LOG_WARNING_IF_ACCESS_DENIED)
LOG_WARNING(log_, user->getName() + ": " + msg + formatSkippedMessage(args...)); LOG_WARNING(log_, user_name + ": " + msg + formatSkippedMessage(args...));
}; };
if (grant_option && calculateResultAccess(false, readonly, allow_ddl, allow_introspection)->isGranted(access, args...)) if (!user)
{
show_error("User has been dropped", ErrorCodes::UNKNOWN_USER);
}
else if (grant_option && calculateResultAccess(false, params.readonly, params.allow_ddl, params.allow_introspection)->isGranted(access, args...))
{ {
show_error( show_error(
"Not enough privileges. " "Not enough privileges. "
@ -140,9 +245,9 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
+ AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION", + AccessRightsElement{access, args...}.toString() + " WITH GRANT OPTION",
ErrorCodes::ACCESS_DENIED); ErrorCodes::ACCESS_DENIED);
} }
else if (readonly && calculateResultAccess(false, false, allow_ddl, allow_introspection)->isGranted(access, args...)) else if (params.readonly && calculateResultAccess(false, false, params.allow_ddl, params.allow_introspection)->isGranted(access, args...))
{ {
if (interface == ClientInfo::Interface::HTTP && http_method == ClientInfo::HTTPMethod::GET) if (params.interface == ClientInfo::Interface::HTTP && params.http_method == ClientInfo::HTTPMethod::GET)
show_error( show_error(
"Cannot execute query in readonly mode. " "Cannot execute query in readonly mode. "
"For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries", "For queries over HTTP, method GET implies readonly. You should use method POST for modifying queries",
@ -150,11 +255,11 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
else else
show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY); show_error("Cannot execute query in readonly mode", ErrorCodes::READONLY);
} }
else if (!allow_ddl && calculateResultAccess(false, readonly, true, allow_introspection)->isGranted(access, args...)) else if (!params.allow_ddl && calculateResultAccess(false, params.readonly, true, params.allow_introspection)->isGranted(access, args...))
{ {
show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED); show_error("Cannot execute query. DDL queries are prohibited for the user", ErrorCodes::QUERY_IS_PROHIBITED);
} }
else if (!allow_introspection && calculateResultAccess(false, readonly, allow_ddl, true)->isGranted(access, args...)) else if (!params.allow_introspection && calculateResultAccess(false, params.readonly, params.allow_ddl, true)->isGranted(access, args...))
{ {
show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED); show_error("Introspection functions are disabled, because setting 'allow_introspection_functions' is set to 0", ErrorCodes::FUNCTION_NOT_ALLOWED);
} }
@ -171,94 +276,127 @@ bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessFlags & acc
template <int mode, bool grant_option> template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElement & element) const bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & element) const
{ {
if (element.any_database) if (element.any_database)
{ {
return checkImpl<mode, grant_option>(log_, element.access_flags); return checkAccessImpl<mode, grant_option>(log_, element.access_flags);
} }
else if (element.any_table) else if (element.any_table)
{ {
if (element.database.empty()) if (element.database.empty())
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, params.current_database);
else else
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, element.database);
} }
else if (element.any_column) else if (element.any_column)
{ {
if (element.database.empty()) if (element.database.empty())
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, params.current_database, element.table);
else else
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table);
} }
else else
{ {
if (element.database.empty()) if (element.database.empty())
return checkImpl<mode, grant_option>(log_, element.access_flags, current_database, element.table, element.columns); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, params.current_database, element.table, element.columns);
else else
return checkImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table, element.columns); return checkAccessImpl<mode, grant_option>(log_, element.access_flags, element.database, element.table, element.columns);
} }
} }
template <int mode, bool grant_option> template <int mode, bool grant_option>
bool AccessRightsContext::checkImpl(Poco::Logger * log_, const AccessRightsElements & elements) const bool AccessRightsContext::checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & elements) const
{ {
for (const auto & element : elements) for (const auto & element : elements)
if (!checkImpl<mode, grant_option>(log_, element)) if (!checkAccessImpl<mode, grant_option>(log_, element))
return false; return false;
return true; return true;
} }
void AccessRightsContext::check(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); } void AccessRightsContext::checkAccess(const AccessFlags & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database); } void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table); } void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); } void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); } void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); } void AccessRightsContext::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
void AccessRightsContext::check(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); } void AccessRightsContext::checkAccess(const AccessRightsElement & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
void AccessRightsContext::check(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); } void AccessRightsContext::checkAccess(const AccessRightsElements & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); } bool AccessRightsContext::isGranted(const AccessFlags & access) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database); } bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table); } bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); } bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, column); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); } bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); } bool AccessRightsContext::isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access, database, table, columns); }
bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); } bool AccessRightsContext::isGranted(const AccessRightsElement & access) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); } bool AccessRightsContext::isGranted(const AccessRightsElements & access) const { return checkAccessImpl<RETURN_FALSE_IF_ACCESS_DENIED, false>(nullptr, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, column); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, column); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access, database, table, columns); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElement & access) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); } bool AccessRightsContext::isGranted(Poco::Logger * log_, const AccessRightsElements & access) const { return checkAccessImpl<LOG_WARNING_IF_ACCESS_DENIED, false>(log_, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); } void AccessRightsContext::checkGrantOption(const AccessFlags & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database); } void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table); } void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, column); } void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, column); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); } void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); } void AccessRightsContext::checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access, database, table, columns); }
void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); } void AccessRightsContext::checkGrantOption(const AccessRightsElement & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); } void AccessRightsContext::checkGrantOption(const AccessRightsElements & access) const { checkAccessImpl<THROW_IF_ACCESS_DENIED, true>(nullptr, access); }
void AccessRightsContext::checkAdminOption(const UUID & role_id) const
{
if (isGranted(AccessType::ROLE_ADMIN))
return;
boost::shared_ptr<const boost::container::flat_set<UUID>> enabled_roles = enabled_roles_with_admin_option.load();
if (!enabled_roles)
{
std::lock_guard lock{mutex};
enabled_roles = enabled_roles_with_admin_option.load();
if (!enabled_roles)
{
if (roles_info)
enabled_roles = boost::make_shared<boost::container::flat_set<UUID>>(roles_info->enabled_roles_with_admin_option.begin(), roles_info->enabled_roles_with_admin_option.end());
else
enabled_roles = boost::make_shared<boost::container::flat_set<UUID>>();
enabled_roles_with_admin_option.store(enabled_roles);
}
}
if (enabled_roles->contains(role_id))
return;
std::optional<String> role_name = manager->readName(role_id);
if (!role_name)
role_name = "ID {" + toString(role_id) + "}";
throw Exception(
getUserName() + ": Not enough privileges. To execute this query it's necessary to have the grant " + backQuoteIfNeed(*role_name)
+ " WITH ADMIN OPTION ",
ErrorCodes::ACCESS_DENIED);
}
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option) const boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option) const
{ {
return calculateResultAccess(grant_option, readonly, allow_ddl, allow_introspection); return calculateResultAccess(grant_option, params.readonly, params.allow_ddl, params.allow_introspection);
} }
boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const
{ {
size_t cache_index = static_cast<size_t>(readonly_ != readonly) size_t cache_index = static_cast<size_t>(readonly_ != params.readonly)
+ static_cast<size_t>(allow_ddl_ != allow_ddl) * 2 + + static_cast<size_t>(allow_ddl_ != params.allow_ddl) * 2 +
+ static_cast<size_t>(allow_introspection_ != allow_introspection) * 3 + static_cast<size_t>(allow_introspection_ != params.allow_introspection) * 3
+ static_cast<size_t>(grant_option) * 4; + static_cast<size_t>(grant_option) * 4;
assert(cache_index < std::size(result_access_cache)); assert(cache_index < std::size(result_access_cache));
auto cached = result_access_cache[cache_index].load(); auto cached = result_access_cache[cache_index].load();
@ -273,20 +411,35 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
auto result_ptr = boost::make_shared<AccessRights>(); auto result_ptr = boost::make_shared<AccessRights>();
auto & result = *result_ptr; auto & result = *result_ptr;
result = grant_option ? user->access_with_grant_option : user->access; if (grant_option)
{
result = user->access_with_grant_option;
if (roles_info)
result.merge(roles_info->access_with_grant_option);
}
else
{
result = user->access;
if (roles_info)
result.merge(roles_info->access);
}
static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW static const AccessFlags table_ddl = AccessType::CREATE_DATABASE | AccessType::CREATE_TABLE | AccessType::CREATE_VIEW
| AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW | AccessType::ALTER_TABLE | AccessType::ALTER_VIEW | AccessType::DROP_DATABASE | AccessType::DROP_TABLE | AccessType::DROP_VIEW
| AccessType::DETACH_DATABASE | AccessType::DETACH_TABLE | AccessType::DETACH_VIEW | AccessType::TRUNCATE; | AccessType::TRUNCATE;
static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY | AccessType::DETACH_DICTIONARY; static const AccessFlags dictionary_ddl = AccessType::CREATE_DICTIONARY | AccessType::DROP_DICTIONARY;
static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl; static const AccessFlags table_and_dictionary_ddl = table_ddl | dictionary_ddl;
static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE; static const AccessFlags write_table_access = AccessType::INSERT | AccessType::OPTIMIZE;
static const AccessFlags all_dcl = AccessType::CREATE_USER | AccessType::CREATE_ROLE | AccessType::CREATE_POLICY
| AccessType::CREATE_QUOTA | AccessType::ALTER_USER | AccessType::ALTER_POLICY | AccessType::ALTER_QUOTA | AccessType::DROP_USER
| AccessType::DROP_ROLE | AccessType::DROP_POLICY | AccessType::DROP_QUOTA | AccessType::ROLE_ADMIN;
/// Anyone has access to the "system" database. /// Anyone has access to the "system" database.
result.grant(AccessType::SELECT, "system"); if (!result.isGranted(AccessType::SELECT, "system"))
result.grant(AccessType::SELECT, "system");
if (readonly_) if (readonly_)
result.fullRevoke(write_table_access | AccessType::SYSTEM); result.fullRevoke(write_table_access | all_dcl | AccessType::SYSTEM | AccessType::KILL);
if (readonly_ || !allow_ddl_) if (readonly_ || !allow_ddl_)
result.fullRevoke(table_and_dictionary_ddl); result.fullRevoke(table_and_dictionary_ddl);
@ -306,10 +459,118 @@ boost::shared_ptr<const AccessRights> AccessRightsContext::calculateResultAccess
result_access_cache[cache_index].store(result_ptr); result_access_cache[cache_index].store(result_ptr);
if (trace_log && (readonly == readonly_) && (allow_ddl == allow_ddl_) && (allow_introspection == allow_introspection_)) if (trace_log && (params.readonly == readonly_) && (params.allow_ddl == allow_ddl_) && (params.allow_introspection == allow_introspection_))
{
LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : "")); LOG_TRACE(trace_log, "List of all grants: " << result_ptr->toString() << (grant_option ? " WITH GRANT OPTION" : ""));
if (roles_info && !roles_info->getCurrentRolesNames().empty())
{
LOG_TRACE(
trace_log,
"Current_roles: " << boost::algorithm::join(roles_info->getCurrentRolesNames(), ", ")
<< ", enabled_roles: " << boost::algorithm::join(roles_info->getEnabledRolesNames(), ", "));
}
}
return result_ptr; return result_ptr;
} }
UserPtr AccessRightsContext::getUser() const
{
std::lock_guard lock{mutex};
return user;
}
String AccessRightsContext::getUserName() const
{
std::lock_guard lock{mutex};
return user_name;
}
CurrentRolesInfoPtr AccessRightsContext::getRolesInfo() const
{
std::lock_guard lock{mutex};
return roles_info;
}
std::vector<UUID> AccessRightsContext::getCurrentRoles() const
{
std::lock_guard lock{mutex};
return roles_info ? roles_info->current_roles : std::vector<UUID>{};
}
Strings AccessRightsContext::getCurrentRolesNames() const
{
std::lock_guard lock{mutex};
return roles_info ? roles_info->getCurrentRolesNames() : Strings{};
}
std::vector<UUID> AccessRightsContext::getEnabledRoles() const
{
std::lock_guard lock{mutex};
return roles_info ? roles_info->enabled_roles : std::vector<UUID>{};
}
Strings AccessRightsContext::getEnabledRolesNames() const
{
std::lock_guard lock{mutex};
return roles_info ? roles_info->getEnabledRolesNames() : Strings{};
}
RowPolicyContextPtr AccessRightsContext::getRowPolicy() const
{
std::lock_guard lock{mutex};
return row_policy_context;
}
QuotaContextPtr AccessRightsContext::getQuota() const
{
std::lock_guard lock{mutex};
return quota_context;
}
bool operator <(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs)
{
#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \
if (lhs.field < rhs.field) \
return true; \
if (lhs.field > rhs.field) \
return false
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method);
return false;
#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER
}
bool operator ==(const AccessRightsContext::Params & lhs, const AccessRightsContext::Params & rhs)
{
#define ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(field) \
if (lhs.field != rhs.field) \
return false
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(user_id);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_roles);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(use_default_roles);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(address);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(quota_key);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(current_database);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(readonly);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_ddl);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(allow_introspection);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(interface);
ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER(http_method);
return true;
#undef ACCESS_RIGHTS_CONTEXT_PARAMS_COMPARE_HELPER
}
} }

View File

@ -2,7 +2,11 @@
#include <Access/AccessRights.h> #include <Access/AccessRights.h>
#include <Interpreters/ClientInfo.h> #include <Interpreters/ClientInfo.h>
#include <Core/UUID.h>
#include <ext/scope_guard.h>
#include <ext/shared_ptr_helper.h>
#include <boost/smart_ptr/atomic_shared_ptr.hpp> #include <boost/smart_ptr/atomic_shared_ptr.hpp>
#include <boost/container/flat_set.hpp>
#include <mutex> #include <mutex>
@ -10,31 +14,76 @@ namespace Poco { class Logger; }
namespace DB namespace DB
{ {
struct Settings;
struct User; struct User;
using UserPtr = std::shared_ptr<const User>; using UserPtr = std::shared_ptr<const User>;
struct CurrentRolesInfo;
using CurrentRolesInfoPtr = std::shared_ptr<const CurrentRolesInfo>;
class RoleContext;
using RoleContextPtr = std::shared_ptr<const RoleContext>;
struct RowPolicyContext;
using RowPolicyContextPtr = std::shared_ptr<const RowPolicyContext>;
struct QuotaContext;
using QuotaContextPtr = std::shared_ptr<const QuotaContext>;
struct Settings;
class AccessControlManager;
class AccessRightsContext class AccessRightsContext
{ {
public: public:
struct Params
{
std::optional<UUID> user_id;
std::vector<UUID> current_roles;
bool use_default_roles = false;
UInt64 readonly = 0;
bool allow_ddl = false;
bool allow_introspection = false;
String current_database;
ClientInfo::Interface interface = ClientInfo::Interface::TCP;
ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN;
Poco::Net::IPAddress address;
String quota_key;
friend bool operator ==(const Params & lhs, const Params & rhs);
friend bool operator !=(const Params & lhs, const Params & rhs) { return !(lhs == rhs); }
friend bool operator <(const Params & lhs, const Params & rhs);
friend bool operator >(const Params & lhs, const Params & rhs) { return rhs < lhs; }
friend bool operator <=(const Params & lhs, const Params & rhs) { return !(rhs < lhs); }
friend bool operator >=(const Params & lhs, const Params & rhs) { return !(lhs < rhs); }
};
/// Default constructor creates access rights' context which allows everything. /// Default constructor creates access rights' context which allows everything.
AccessRightsContext(); AccessRightsContext();
AccessRightsContext(const UserPtr & user_, const ClientInfo & client_info_, const Settings & settings, const String & current_database_); const Params & getParams() const { return params; }
UserPtr getUser() const;
String getUserName() const;
/// Checks if a specified access granted, and throws an exception if not. void checkPassword(const String & password) const;
void checkHostIsAllowed() const;
CurrentRolesInfoPtr getRolesInfo() const;
std::vector<UUID> getCurrentRoles() const;
Strings getCurrentRolesNames() const;
std::vector<UUID> getEnabledRoles() const;
Strings getEnabledRolesNames() const;
RowPolicyContextPtr getRowPolicy() const;
QuotaContextPtr getQuota() const;
/// Checks if a specified access is granted, and throws an exception if not.
/// Empty database means the current database. /// Empty database means the current database.
void check(const AccessFlags & access) const; void checkAccess(const AccessFlags & access) const;
void check(const AccessFlags & access, const std::string_view & database) const; void checkAccess(const AccessFlags & access, const std::string_view & database) const;
void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const; void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const;
void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const; void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const;
void check(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const; void checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const;
void check(const AccessRightsElement & access) const; void checkAccess(const AccessRightsElement & access) const;
void check(const AccessRightsElements & access) const; void checkAccess(const AccessRightsElements & access) const;
/// Checks if a specified access granted. /// Checks if a specified access is granted.
bool isGranted(const AccessFlags & access) const; bool isGranted(const AccessFlags & access) const;
bool isGranted(const AccessFlags & access, const std::string_view & database) const; bool isGranted(const AccessFlags & access, const std::string_view & database) const;
bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; bool isGranted(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
@ -44,7 +93,7 @@ public:
bool isGranted(const AccessRightsElement & access) const; bool isGranted(const AccessRightsElement & access) const;
bool isGranted(const AccessRightsElements & access) const; bool isGranted(const AccessRightsElements & access) const;
/// Checks if a specified access granted, and logs a warning if not. /// Checks if a specified access is granted, and logs a warning if not.
bool isGranted(Poco::Logger * log_, const AccessFlags & access) const; bool isGranted(Poco::Logger * log_, const AccessFlags & access) const;
bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const; bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database) const;
bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; bool isGranted(Poco::Logger * log_, const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
@ -54,7 +103,7 @@ public:
bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const; bool isGranted(Poco::Logger * log_, const AccessRightsElement & access) const;
bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const; bool isGranted(Poco::Logger * log_, const AccessRightsElements & access) const;
/// Checks if a specified access granted with grant option, and throws an exception if not. /// Checks if a specified access is granted with grant option, and throws an exception if not.
void checkGrantOption(const AccessFlags & access) const; void checkGrantOption(const AccessFlags & access) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database) const; void checkGrantOption(const AccessFlags & access, const std::string_view & database) const;
void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const; void checkGrantOption(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const;
@ -64,29 +113,45 @@ public:
void checkGrantOption(const AccessRightsElement & access) const; void checkGrantOption(const AccessRightsElement & access) const;
void checkGrantOption(const AccessRightsElements & access) const; void checkGrantOption(const AccessRightsElements & access) const;
/// Checks if a specified role is granted with admin option, and throws an exception if not.
void checkAdminOption(const UUID & role_id) const;
private: private:
friend class AccessRightsContextFactory;
friend struct ext::shared_ptr_helper<AccessRightsContext>;
AccessRightsContext(const AccessControlManager & manager_, const Params & params_); /// AccessRightsContext should be created by AccessRightsContextFactory.
void setUser(const UserPtr & user_) const;
void setRolesInfo(const CurrentRolesInfoPtr & roles_info_) const;
template <int mode, bool grant_option, typename... Args> template <int mode, bool grant_option, typename... Args>
bool checkImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const; bool checkAccessImpl(Poco::Logger * log_, const AccessFlags & access, const Args &... args) const;
template <int mode, bool grant_option> template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElement & access) const; bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElement & access) const;
template <int mode, bool grant_option> template <int mode, bool grant_option>
bool checkImpl(Poco::Logger * log_, const AccessRightsElements & access) const; bool checkAccessImpl(Poco::Logger * log_, const AccessRightsElements & access) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option) const; boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option) const;
boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const; boost::shared_ptr<const AccessRights> calculateResultAccess(bool grant_option, UInt64 readonly_, bool allow_ddl_, bool allow_introspection_) const;
const UserPtr user; const AccessControlManager * manager = nullptr;
const UInt64 readonly = 0; const Params params;
const bool allow_ddl = true; mutable Poco::Logger * trace_log = nullptr;
const bool allow_introspection = true; mutable UserPtr user;
const String current_database; mutable String user_name;
const ClientInfo::Interface interface = ClientInfo::Interface::TCP; mutable ext::scope_guard subscription_for_user_change;
const ClientInfo::HTTPMethod http_method = ClientInfo::HTTPMethod::UNKNOWN; mutable RoleContextPtr role_context;
Poco::Logger * const trace_log = nullptr; mutable ext::scope_guard subscription_for_roles_info_change;
mutable CurrentRolesInfoPtr roles_info;
mutable boost::atomic_shared_ptr<const boost::container::flat_set<UUID>> enabled_roles_with_admin_option;
mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[7]; mutable boost::atomic_shared_ptr<const AccessRights> result_access_cache[7];
mutable RowPolicyContextPtr row_policy_context;
mutable QuotaContextPtr quota_context;
mutable std::mutex mutex; mutable std::mutex mutex;
}; };
using AccessRightsContextPtr = std::shared_ptr<const AccessRightsContext>;
} }

View File

@ -0,0 +1,48 @@
#include <Access/AccessRightsContextFactory.h>
#include <Access/AccessControlManager.h>
#include <Core/Settings.h>
namespace DB
{
AccessRightsContextFactory::AccessRightsContextFactory(const AccessControlManager & manager_)
: manager(manager_), cache(600000 /* 10 minutes */) {}
AccessRightsContextFactory::~AccessRightsContextFactory() = default;
AccessRightsContextPtr AccessRightsContextFactory::createContext(const Params & params)
{
std::lock_guard lock{mutex};
auto x = cache.get(params);
if (x)
return *x;
auto res = ext::shared_ptr_helper<AccessRightsContext>::create(manager, params);
cache.add(params, res);
return res;
}
AccessRightsContextPtr AccessRightsContextFactory::createContext(
const UUID & user_id,
const std::vector<UUID> & current_roles,
bool use_default_roles,
const Settings & settings,
const String & current_database,
const ClientInfo & client_info)
{
Params params;
params.user_id = user_id;
params.current_roles = current_roles;
params.use_default_roles = use_default_roles;
params.current_database = current_database;
params.readonly = settings.readonly;
params.allow_ddl = settings.allow_ddl;
params.allow_introspection = settings.allow_introspection_functions;
params.interface = client_info.interface;
params.http_method = client_info.http_method;
params.address = client_info.current_address.host();
params.quota_key = client_info.quota_key;
return createContext(params);
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <Access/AccessRightsContext.h>
#include <Poco/ExpireCache.h>
#include <mutex>
namespace DB
{
class AccessControlManager;
class AccessRightsContextFactory
{
public:
AccessRightsContextFactory(const AccessControlManager & manager_);
~AccessRightsContextFactory();
using Params = AccessRightsContext::Params;
AccessRightsContextPtr createContext(const Params & params);
AccessRightsContextPtr createContext(const UUID & user_id, const std::vector<UUID> & current_roles, bool use_default_roles, const Settings & settings, const String & current_database, const ClientInfo & client_info);
private:
const AccessControlManager & manager;
Poco::ExpireCache<Params, AccessRightsContextPtr> cache;
std::mutex mutex;
};
}

View File

@ -66,24 +66,12 @@ enum class AccessType
CREATE_TEMPORARY_TABLE, /// allows to create and manipulate temporary tables and views. CREATE_TEMPORARY_TABLE, /// allows to create and manipulate temporary tables and views.
CREATE, /// allows to execute {CREATE|ATTACH} [TEMPORARY] {DATABASE|TABLE|VIEW|DICTIONARY} CREATE, /// allows to execute {CREATE|ATTACH} [TEMPORARY] {DATABASE|TABLE|VIEW|DICTIONARY}
ATTACH_DATABASE, /// allows to execute {CREATE|ATTACH} DATABASE
ATTACH_TABLE, /// allows to execute {CREATE|ATTACH} TABLE
ATTACH_VIEW, /// allows to execute {CREATE|ATTACH} VIEW
ATTACH_DICTIONARY, /// allows to execute {CREATE|ATTACH} DICTIONARY
ATTACH, /// allows to execute {CREATE|ATTACH} {DATABASE|TABLE|VIEW|DICTIONARY}
DROP_DATABASE, DROP_DATABASE,
DROP_TABLE, DROP_TABLE,
DROP_VIEW, DROP_VIEW,
DROP_DICTIONARY, DROP_DICTIONARY,
DROP, /// allows to execute DROP {DATABASE|TABLE|VIEW|DICTIONARY} DROP, /// allows to execute DROP {DATABASE|TABLE|VIEW|DICTIONARY}
DETACH_DATABASE,
DETACH_TABLE,
DETACH_VIEW,
DETACH_DICTIONARY,
DETACH, /// allows to execute DETACH {DATABASE|TABLE|VIEW|DICTIONARY}
TRUNCATE_TABLE, TRUNCATE_TABLE,
TRUNCATE_VIEW, TRUNCATE_VIEW,
TRUNCATE, /// allows to execute TRUNCATE {TABLE|VIEW} TRUNCATE, /// allows to execute TRUNCATE {TABLE|VIEW}
@ -94,7 +82,7 @@ enum class AccessType
KILL_MUTATION, /// allows to kill a mutation KILL_MUTATION, /// allows to kill a mutation
KILL, /// allows to execute KILL {MUTATION|QUERY} KILL, /// allows to execute KILL {MUTATION|QUERY}
CREATE_USER, /// allows to create, alter and drop users, roles, quotas, row policies. CREATE_USER,
ALTER_USER, ALTER_USER,
DROP_USER, DROP_USER,
CREATE_ROLE, CREATE_ROLE,
@ -106,6 +94,8 @@ enum class AccessType
ALTER_QUOTA, ALTER_QUOTA,
DROP_QUOTA, DROP_QUOTA,
ROLE_ADMIN, /// allows to grant and revoke any roles.
SHUTDOWN, SHUTDOWN,
DROP_CACHE, DROP_CACHE,
RELOAD_CONFIG, RELOAD_CONFIG,
@ -235,24 +225,12 @@ namespace impl
ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_TEMPORARY_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_TEMPORARY_TABLE);
ACCESS_TYPE_TO_KEYWORD_CASE(CREATE); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE);
ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_DATABASE);
ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_TABLE);
ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_VIEW);
ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH_DICTIONARY);
ACCESS_TYPE_TO_KEYWORD_CASE(ATTACH);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DATABASE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DATABASE);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_TABLE);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_VIEW); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_VIEW);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DICTIONARY); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_DICTIONARY);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP); ACCESS_TYPE_TO_KEYWORD_CASE(DROP);
ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_DATABASE);
ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_TABLE);
ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_VIEW);
ACCESS_TYPE_TO_KEYWORD_CASE(DETACH_DICTIONARY);
ACCESS_TYPE_TO_KEYWORD_CASE(DETACH);
ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_TABLE); ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_TABLE);
ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_VIEW); ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE_VIEW);
ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE); ACCESS_TYPE_TO_KEYWORD_CASE(TRUNCATE);
@ -274,6 +252,7 @@ namespace impl
ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(CREATE_QUOTA);
ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(ALTER_QUOTA);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_QUOTA); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_QUOTA);
ACCESS_TYPE_TO_KEYWORD_CASE(ROLE_ADMIN);
ACCESS_TYPE_TO_KEYWORD_CASE(SHUTDOWN); ACCESS_TYPE_TO_KEYWORD_CASE(SHUTDOWN);
ACCESS_TYPE_TO_KEYWORD_CASE(DROP_CACHE); ACCESS_TYPE_TO_KEYWORD_CASE(DROP_CACHE);

View File

@ -46,7 +46,7 @@ public:
struct AnyHostTag {}; struct AnyHostTag {};
AllowedClientHosts() {} AllowedClientHosts() {}
explicit AllowedClientHosts(AnyHostTag) { addAnyHost(); } AllowedClientHosts(AnyHostTag) { addAnyHost(); }
~AllowedClientHosts() {} ~AllowedClientHosts() {}
AllowedClientHosts(const AllowedClientHosts & src) = default; AllowedClientHosts(const AllowedClientHosts & src) = default;

View File

@ -0,0 +1,34 @@
#include <Access/CurrentRolesInfo.h>
namespace DB
{
Strings CurrentRolesInfo::getCurrentRolesNames() const
{
Strings result;
result.reserve(current_roles.size());
for (const auto & id : current_roles)
result.emplace_back(names_of_roles.at(id));
return result;
}
Strings CurrentRolesInfo::getEnabledRolesNames() const
{
Strings result;
result.reserve(enabled_roles.size());
for (const auto & id : enabled_roles)
result.emplace_back(names_of_roles.at(id));
return result;
}
bool operator==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs)
{
return (lhs.current_roles == rhs.current_roles) && (lhs.enabled_roles == rhs.enabled_roles)
&& (lhs.enabled_roles_with_admin_option == rhs.enabled_roles_with_admin_option) && (lhs.names_of_roles == rhs.names_of_roles)
&& (lhs.access == rhs.access) && (lhs.access_with_grant_option == rhs.access_with_grant_option);
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#include <Access/AccessRights.h>
#include <Core/UUID.h>
#include <unordered_map>
#include <vector>
namespace DB
{
/// Information about a role.
struct CurrentRolesInfo
{
std::vector<UUID> current_roles;
std::vector<UUID> enabled_roles;
std::vector<UUID> enabled_roles_with_admin_option;
std::unordered_map<UUID, String> names_of_roles;
AccessRights access;
AccessRights access_with_grant_option;
Strings getCurrentRolesNames() const;
Strings getEnabledRolesNames() const;
friend bool operator ==(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs);
friend bool operator !=(const CurrentRolesInfo & lhs, const CurrentRolesInfo & rhs) { return !(lhs == rhs); }
};
using CurrentRolesInfoPtr = std::shared_ptr<const CurrentRolesInfo>;
}

View File

@ -0,0 +1,288 @@
#include <Access/GenericRoleSet.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
#include <Access/Role.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/formatAST.h>
#include <boost/range/algorithm/set_algorithm.hpp>
#include <boost/range/algorithm/sort.hpp>
#include <boost/range/algorithm_ext/push_back.hpp>
namespace DB
{
GenericRoleSet::GenericRoleSet() = default;
GenericRoleSet::GenericRoleSet(const GenericRoleSet & src) = default;
GenericRoleSet & GenericRoleSet::operator =(const GenericRoleSet & src) = default;
GenericRoleSet::GenericRoleSet(GenericRoleSet && src) = default;
GenericRoleSet & GenericRoleSet::operator =(GenericRoleSet && src) = default;
GenericRoleSet::GenericRoleSet(AllTag)
{
all = true;
}
GenericRoleSet::GenericRoleSet(const UUID & id)
{
add(id);
}
GenericRoleSet::GenericRoleSet(const std::vector<UUID> & ids_)
{
add(ids_);
}
GenericRoleSet::GenericRoleSet(const boost::container::flat_set<UUID> & ids_)
{
add(ids_);
}
GenericRoleSet::GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id)
{
all = ast.all;
if (!ast.names.empty() && !all)
{
ids.reserve(ast.names.size());
for (const String & name : ast.names)
{
auto id = manager.find<User>(name);
if (!id)
id = manager.getID<Role>(name);
ids.insert(*id);
}
}
if (ast.current_user && !all)
{
if (!current_user_id)
throw Exception("Current user is unknown", ErrorCodes::LOGICAL_ERROR);
ids.insert(*current_user_id);
}
if (!ast.except_names.empty())
{
except_ids.reserve(ast.except_names.size());
for (const String & except_name : ast.except_names)
{
auto except_id = manager.find<User>(except_name);
if (!except_id)
except_id = manager.getID<Role>(except_name);
except_ids.insert(*except_id);
}
}
if (ast.except_current_user)
{
if (!current_user_id)
throw Exception("Current user is unknown", ErrorCodes::LOGICAL_ERROR);
except_ids.insert(*current_user_id);
}
for (const UUID & except_id : except_ids)
ids.erase(except_id);
}
std::shared_ptr<ASTGenericRoleSet> GenericRoleSet::toAST(const AccessControlManager & manager) const
{
auto ast = std::make_shared<ASTGenericRoleSet>();
ast->all = all;
if (!ids.empty())
{
ast->names.reserve(ids.size());
for (const UUID & id : ids)
{
auto name = manager.tryReadName(id);
if (name)
ast->names.emplace_back(std::move(*name));
}
boost::range::sort(ast->names);
}
if (!except_ids.empty())
{
ast->except_names.reserve(except_ids.size());
for (const UUID & except_id : except_ids)
{
auto except_name = manager.tryReadName(except_id);
if (except_name)
ast->except_names.emplace_back(std::move(*except_name));
}
boost::range::sort(ast->except_names);
}
return ast;
}
String GenericRoleSet::toString(const AccessControlManager & manager) const
{
auto ast = toAST(manager);
return serializeAST(*ast);
}
Strings GenericRoleSet::toStrings(const AccessControlManager & manager) const
{
if (all || !except_ids.empty())
return {toString(manager)};
Strings names;
names.reserve(ids.size());
for (const UUID & id : ids)
{
auto name = manager.tryReadName(id);
if (name)
names.emplace_back(std::move(*name));
}
boost::range::sort(names);
return names;
}
bool GenericRoleSet::empty() const
{
return ids.empty() && !all;
}
void GenericRoleSet::clear()
{
ids.clear();
all = false;
except_ids.clear();
}
void GenericRoleSet::add(const UUID & id)
{
ids.insert(id);
}
void GenericRoleSet::add(const std::vector<UUID> & ids_)
{
for (const auto & id : ids_)
add(id);
}
void GenericRoleSet::add(const boost::container::flat_set<UUID> & ids_)
{
for (const auto & id : ids_)
add(id);
}
bool GenericRoleSet::match(const UUID & id) const
{
return (all || ids.contains(id)) && !except_ids.contains(id);
}
bool GenericRoleSet::match(const UUID & user_id, const std::vector<UUID> & enabled_roles) const
{
if (!all && !ids.contains(user_id))
{
bool found_enabled_role = std::any_of(
enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return ids.contains(enabled_role); });
if (!found_enabled_role)
return false;
}
if (except_ids.contains(user_id))
return false;
bool in_except_list = std::any_of(
enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return except_ids.contains(enabled_role); });
if (in_except_list)
return false;
return true;
}
bool GenericRoleSet::match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const
{
if (!all && !ids.contains(user_id))
{
bool found_enabled_role = std::any_of(
enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return ids.contains(enabled_role); });
if (!found_enabled_role)
return false;
}
if (except_ids.contains(user_id))
return false;
bool in_except_list = std::any_of(
enabled_roles.begin(), enabled_roles.end(), [this](const UUID & enabled_role) { return except_ids.contains(enabled_role); });
if (in_except_list)
return false;
return true;
}
std::vector<UUID> GenericRoleSet::getMatchingIDs() const
{
if (all)
throw Exception("getAllMatchingIDs() can't get ALL ids", ErrorCodes::LOGICAL_ERROR);
std::vector<UUID> res;
boost::range::set_difference(ids, except_ids, std::back_inserter(res));
return res;
}
std::vector<UUID> GenericRoleSet::getMatchingUsers(const AccessControlManager & manager) const
{
if (!all)
return getMatchingIDs();
std::vector<UUID> res;
for (const UUID & id : manager.findAll<User>())
{
if (match(id))
res.push_back(id);
}
return res;
}
std::vector<UUID> GenericRoleSet::getMatchingRoles(const AccessControlManager & manager) const
{
if (!all)
return getMatchingIDs();
std::vector<UUID> res;
for (const UUID & id : manager.findAll<Role>())
{
if (match(id))
res.push_back(id);
}
return res;
}
std::vector<UUID> GenericRoleSet::getMatchingUsersAndRoles(const AccessControlManager & manager) const
{
if (!all)
return getMatchingIDs();
std::vector<UUID> vec = getMatchingUsers(manager);
boost::range::push_back(vec, getMatchingRoles(manager));
return vec;
}
bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs)
{
return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids);
}
}

View File

@ -0,0 +1,66 @@
#pragma once
#include <Core/UUID.h>
#include <boost/container/flat_set.hpp>
#include <memory>
#include <optional>
namespace DB
{
class ASTGenericRoleSet;
class AccessControlManager;
/// Represents a set of users/roles like
/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...]
/// Similar to ASTGenericRoleSet, but with IDs instead of names.
struct GenericRoleSet
{
GenericRoleSet();
GenericRoleSet(const GenericRoleSet & src);
GenericRoleSet & operator =(const GenericRoleSet & src);
GenericRoleSet(GenericRoleSet && src);
GenericRoleSet & operator =(GenericRoleSet && src);
struct AllTag {};
GenericRoleSet(AllTag);
GenericRoleSet(const UUID & id);
GenericRoleSet(const std::vector<UUID> & ids_);
GenericRoleSet(const boost::container::flat_set<UUID> & ids_);
GenericRoleSet(const ASTGenericRoleSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id = {});
std::shared_ptr<ASTGenericRoleSet> toAST(const AccessControlManager & manager) const;
String toString(const AccessControlManager & manager) const;
Strings toStrings(const AccessControlManager & manager) const;
bool empty() const;
void clear();
void add(const UUID & id);
void add(const std::vector<UUID> & ids_);
void add(const boost::container::flat_set<UUID> & ids_);
/// Checks if a specified ID matches this GenericRoleSet.
bool match(const UUID & id) const;
bool match(const UUID & user_id, const std::vector<UUID> & enabled_roles) const;
bool match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const;
/// Returns a list of matching IDs. The function must not be called if `all` == `true`.
std::vector<UUID> getMatchingIDs() const;
/// Returns a list of matching users.
std::vector<UUID> getMatchingUsers(const AccessControlManager & manager) const;
std::vector<UUID> getMatchingRoles(const AccessControlManager & manager) const;
std::vector<UUID> getMatchingUsersAndRoles(const AccessControlManager & manager) const;
friend bool operator ==(const GenericRoleSet & lhs, const GenericRoleSet & rhs);
friend bool operator !=(const GenericRoleSet & lhs, const GenericRoleSet & rhs) { return !(lhs == rhs); }
boost::container::flat_set<UUID> ids;
bool all = false;
boost::container::flat_set<UUID> except_ids;
};
}

View File

@ -1,4 +1,6 @@
#include <Access/IAccessStorage.h> #include <Access/IAccessStorage.h>
#include <Access/User.h>
#include <Access/Role.h>
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
@ -15,6 +17,8 @@ namespace ErrorCodes
extern const int ACCESS_ENTITY_ALREADY_EXISTS; extern const int ACCESS_ENTITY_ALREADY_EXISTS;
extern const int ACCESS_ENTITY_FOUND_DUPLICATES; extern const int ACCESS_ENTITY_FOUND_DUPLICATES;
extern const int ACCESS_ENTITY_STORAGE_READONLY; extern const int ACCESS_ENTITY_STORAGE_READONLY;
extern const int UNKNOWN_USER;
extern const int UNKNOWN_ROLE;
} }
@ -365,8 +369,15 @@ void IAccessStorage::throwNotFound(const UUID & id) const
void IAccessStorage::throwNotFound(std::type_index type, const String & name) const void IAccessStorage::throwNotFound(std::type_index type, const String & name) const
{ {
throw Exception( int error_code;
getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND); if (type == typeid(User))
error_code = ErrorCodes::UNKNOWN_USER;
else if (type == typeid(Role))
error_code = ErrorCodes::UNKNOWN_ROLE;
else
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
throw Exception(getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), error_code);
} }

View File

@ -23,8 +23,7 @@ bool Quota::equal(const IAccessEntity & other) const
if (!IAccessEntity::equal(other)) if (!IAccessEntity::equal(other))
return false; return false;
const auto & other_quota = typeid_cast<const Quota &>(other); const auto & other_quota = typeid_cast<const Quota &>(other);
return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles) return (all_limits == other_quota.all_limits) && (key_type == other_quota.key_type) && (roles == other_quota.roles);
&& (all_roles == other_quota.all_roles) && (except_roles == other_quota.except_roles);
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/GenericRoleSet.h>
#include <chrono> #include <chrono>
@ -63,9 +64,7 @@ struct Quota : public IAccessEntity
KeyType key_type = KeyType::NONE; KeyType key_type = KeyType::NONE;
/// Which roles or users should use this quota. /// Which roles or users should use this quota.
Strings roles; GenericRoleSet roles;
bool all_roles = false;
Strings except_roles;
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }

View File

@ -3,6 +3,7 @@
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <ext/chrono_io.h> #include <ext/chrono_io.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/smart_ptr/make_shared.hpp>
#include <boost/range/algorithm/fill.hpp> #include <boost/range/algorithm/fill.hpp>
@ -171,16 +172,18 @@ QuotaUsageInfo QuotaContext::Intervals::getUsageInfo(std::chrono::system_clock::
QuotaContext::QuotaContext() QuotaContext::QuotaContext()
: atomic_intervals(std::make_shared<Intervals>()) /// Unlimited quota. : intervals(boost::make_shared<Intervals>()) /// Unlimited quota.
{ {
} }
QuotaContext::QuotaContext( QuotaContext::QuotaContext(
const String & user_name_, const String & user_name_,
const UUID & user_id_,
const std::vector<UUID> & enabled_roles_,
const Poco::Net::IPAddress & address_, const Poco::Net::IPAddress & address_,
const String & client_key_) const String & client_key_)
: user_name(user_name_), address(address_), client_key(client_key_) : user_name(user_name_), user_id(user_id_), enabled_roles(enabled_roles_), address(address_), client_key(client_key_)
{ {
} }
@ -188,66 +191,66 @@ QuotaContext::QuotaContext(
QuotaContext::~QuotaContext() = default; QuotaContext::~QuotaContext() = default;
void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) void QuotaContext::used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded) const
{ {
used({resource_type, amount}, check_exceeded); used({resource_type, amount}, check_exceeded);
} }
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded) void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded) const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
auto current_time = std::chrono::system_clock::now(); auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded);
} }
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded) void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded) const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
auto current_time = std::chrono::system_clock::now(); auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded);
} }
void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded) void QuotaContext::used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded) const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
auto current_time = std::chrono::system_clock::now(); auto current_time = std::chrono::system_clock::now();
Impl::used(user_name, *intervals_ptr, resource1.first, resource1.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource1.first, resource1.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource2.first, resource2.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource2.first, resource2.second, current_time, check_exceeded);
Impl::used(user_name, *intervals_ptr, resource3.first, resource3.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource3.first, resource3.second, current_time, check_exceeded);
} }
void QuotaContext::used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded) void QuotaContext::used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded) const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
auto current_time = std::chrono::system_clock::now(); auto current_time = std::chrono::system_clock::now();
for (const auto & resource : resources) for (const auto & resource : resources)
Impl::used(user_name, *intervals_ptr, resource.first, resource.second, current_time, check_exceeded); Impl::used(user_name, *loaded, resource.first, resource.second, current_time, check_exceeded);
} }
void QuotaContext::checkExceeded() void QuotaContext::checkExceeded() const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
Impl::checkExceeded(user_name, *intervals_ptr, std::chrono::system_clock::now()); Impl::checkExceeded(user_name, *loaded, std::chrono::system_clock::now());
} }
void QuotaContext::checkExceeded(ResourceType resource_type) void QuotaContext::checkExceeded(ResourceType resource_type) const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
Impl::checkExceeded(user_name, *intervals_ptr, resource_type, std::chrono::system_clock::now()); Impl::checkExceeded(user_name, *loaded, resource_type, std::chrono::system_clock::now());
} }
QuotaUsageInfo QuotaContext::getUsageInfo() const QuotaUsageInfo QuotaContext::getUsageInfo() const
{ {
auto intervals_ptr = std::atomic_load(&atomic_intervals); auto loaded = intervals.load();
return intervals_ptr->getUsageInfo(std::chrono::system_clock::now()); return loaded->getUsageInfo(std::chrono::system_clock::now());
} }

View File

@ -5,6 +5,7 @@
#include <Poco/Net/IPAddress.h> #include <Poco/Net/IPAddress.h>
#include <ext/shared_ptr_helper.h> #include <ext/shared_ptr_helper.h>
#include <boost/noncopyable.hpp> #include <boost/noncopyable.hpp>
#include <boost/smart_ptr/atomic_shared_ptr.hpp>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <memory> #include <memory>
@ -28,15 +29,15 @@ public:
~QuotaContext(); ~QuotaContext();
/// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception. /// Tracks resource consumption. If the quota exceeded and `check_exceeded == true`, throws an exception.
void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true); void used(ResourceType resource_type, ResourceAmount amount, bool check_exceeded = true) const;
void used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded = true); void used(const std::pair<ResourceType, ResourceAmount> & resource, bool check_exceeded = true) const;
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded = true); void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, bool check_exceeded = true) const;
void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded = true); void used(const std::pair<ResourceType, ResourceAmount> & resource1, const std::pair<ResourceType, ResourceAmount> & resource2, const std::pair<ResourceType, ResourceAmount> & resource3, bool check_exceeded = true) const;
void used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded = true); void used(const std::vector<std::pair<ResourceType, ResourceAmount>> & resources, bool check_exceeded = true) const;
/// Checks if the quota exceeded. If so, throws an exception. /// Checks if the quota exceeded. If so, throws an exception.
void checkExceeded(); void checkExceeded() const;
void checkExceeded(ResourceType resource_type); void checkExceeded(ResourceType resource_type) const;
/// Returns the information about this quota context. /// Returns the information about this quota context.
QuotaUsageInfo getUsageInfo() const; QuotaUsageInfo getUsageInfo() const;
@ -46,7 +47,7 @@ private:
friend struct ext::shared_ptr_helper<QuotaContext>; friend struct ext::shared_ptr_helper<QuotaContext>;
/// Instances of this class are created by QuotaContextFactory. /// Instances of this class are created by QuotaContextFactory.
QuotaContext(const String & user_name_, const Poco::Net::IPAddress & address_, const String & client_key_); QuotaContext(const String & user_name_, const UUID & user_id_, const std::vector<UUID> & enabled_roles_, const Poco::Net::IPAddress & address_, const String & client_key_);
static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE; static constexpr size_t MAX_RESOURCE_TYPE = Quota::MAX_RESOURCE_TYPE;
@ -76,12 +77,14 @@ private:
struct Impl; struct Impl;
const String user_name; const String user_name;
const UUID user_id;
const std::vector<UUID> enabled_roles;
const Poco::Net::IPAddress address; const Poco::Net::IPAddress address;
const String client_key; const String client_key;
std::shared_ptr<const Intervals> atomic_intervals; /// atomically changed by QuotaUsageManager boost::atomic_shared_ptr<const Intervals> intervals; /// atomically changed by QuotaUsageManager
}; };
using QuotaContextPtr = std::shared_ptr<QuotaContext>; using QuotaContextPtr = std::shared_ptr<const QuotaContext>;
/// The information about a quota context. /// The information about a quota context.

View File

@ -9,6 +9,7 @@
#include <boost/range/algorithm/lower_bound.hpp> #include <boost/range/algorithm/lower_bound.hpp>
#include <boost/range/algorithm/stable_sort.hpp> #include <boost/range/algorithm/stable_sort.hpp>
#include <boost/range/algorithm_ext/erase.hpp> #include <boost/range/algorithm_ext/erase.hpp>
#include <boost/smart_ptr/make_shared.hpp>
namespace DB namespace DB
@ -34,24 +35,14 @@ void QuotaContextFactory::QuotaInfo::setQuota(const QuotaPtr & quota_, const UUI
{ {
quota = quota_; quota = quota_;
quota_id = quota_id_; quota_id = quota_id_;
roles = &quota->roles;
boost::range::copy(quota->roles, std::inserter(roles, roles.end()));
all_roles = quota->all_roles;
boost::range::copy(quota->except_roles, std::inserter(except_roles, except_roles.end()));
rebuildAllIntervals(); rebuildAllIntervals();
} }
bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const bool QuotaContextFactory::QuotaInfo::canUseWithContext(const QuotaContext & context) const
{ {
if (roles.count(context.user_name)) return roles->match(context.user_id, context.enabled_roles);
return true;
if (all_roles && !except_roles.count(context.user_name))
return true;
return false;
} }
@ -91,7 +82,7 @@ String QuotaContextFactory::QuotaInfo::calculateKey(const QuotaContext & context
} }
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key) boost::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::getOrBuildIntervals(const String & key)
{ {
auto it = key_to_intervals.find(key); auto it = key_to_intervals.find(key);
if (it != key_to_intervals.end()) if (it != key_to_intervals.end())
@ -107,9 +98,9 @@ void QuotaContextFactory::QuotaInfo::rebuildAllIntervals()
} }
std::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key) boost::shared_ptr<const QuotaContext::Intervals> QuotaContextFactory::QuotaInfo::rebuildIntervals(const String & key)
{ {
auto new_intervals = std::make_shared<Intervals>(); auto new_intervals = boost::make_shared<Intervals>();
new_intervals->quota_name = quota->getName(); new_intervals->quota_name = quota->getName();
new_intervals->quota_id = quota_id; new_intervals->quota_id = quota_id;
new_intervals->quota_key = key; new_intervals->quota_key = key;
@ -184,11 +175,11 @@ QuotaContextFactory::~QuotaContextFactory()
} }
std::shared_ptr<QuotaContext> QuotaContextFactory::createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key) QuotaContextPtr QuotaContextFactory::createContext(const String & user_name, const UUID & user_id, const std::vector<UUID> & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key)
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
ensureAllQuotasRead(); ensureAllQuotasRead();
auto context = ext::shared_ptr_helper<QuotaContext>::create(user_name, address, client_key); auto context = ext::shared_ptr_helper<QuotaContext>::create(user_name, user_id, enabled_roles, address, client_key);
contexts.push_back(context); contexts.push_back(context);
chooseQuotaForContext(context); chooseQuotaForContext(context);
return context; return context;
@ -266,7 +257,7 @@ void QuotaContextFactory::chooseQuotaForAllContexts()
void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context) void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaContext> & context)
{ {
/// `mutex` is already locked. /// `mutex` is already locked.
std::shared_ptr<const Intervals> intervals; boost::shared_ptr<const Intervals> intervals;
for (auto & info : all_quotas | boost::adaptors::map_values) for (auto & info : all_quotas | boost::adaptors::map_values)
{ {
if (info.canUseWithContext(*context)) if (info.canUseWithContext(*context))
@ -278,9 +269,9 @@ void QuotaContextFactory::chooseQuotaForContext(const std::shared_ptr<QuotaConte
} }
if (!intervals) if (!intervals)
intervals = std::make_shared<Intervals>(); /// No quota == no limits. intervals = boost::make_shared<Intervals>(); /// No quota == no limits.
std::atomic_store(&context->atomic_intervals, intervals); context->intervals.store(intervals);
} }

View File

@ -20,7 +20,7 @@ public:
QuotaContextFactory(const AccessControlManager & access_control_manager_); QuotaContextFactory(const AccessControlManager & access_control_manager_);
~QuotaContextFactory(); ~QuotaContextFactory();
QuotaContextPtr createContext(const String & user_name, const Poco::Net::IPAddress & address, const String & client_key); QuotaContextPtr createContext(const String & user_name, const UUID & user_id, const std::vector<UUID> & enabled_roles, const Poco::Net::IPAddress & address, const String & client_key);
std::vector<QuotaUsageInfo> getUsageInfo() const; std::vector<QuotaUsageInfo> getUsageInfo() const;
private: private:
@ -34,16 +34,14 @@ private:
bool canUseWithContext(const QuotaContext & context) const; bool canUseWithContext(const QuotaContext & context) const;
String calculateKey(const QuotaContext & context) const; String calculateKey(const QuotaContext & context) const;
std::shared_ptr<const Intervals> getOrBuildIntervals(const String & key); boost::shared_ptr<const Intervals> getOrBuildIntervals(const String & key);
std::shared_ptr<const Intervals> rebuildIntervals(const String & key); boost::shared_ptr<const Intervals> rebuildIntervals(const String & key);
void rebuildAllIntervals(); void rebuildAllIntervals();
QuotaPtr quota; QuotaPtr quota;
UUID quota_id; UUID quota_id;
std::unordered_set<String> roles; const GenericRoleSet * roles = nullptr;
bool all_roles = false; std::unordered_map<String /* quota key */, boost::shared_ptr<const Intervals>> key_to_intervals;
std::unordered_set<String> except_roles;
std::unordered_map<String /* quota key */, std::shared_ptr<const Intervals>> key_to_intervals;
}; };
void ensureAllQuotasRead(); void ensureAllQuotasRead();

16
dbms/src/Access/Role.cpp Normal file
View File

@ -0,0 +1,16 @@
#include <Access/Role.h>
namespace DB
{
bool Role::equal(const IAccessEntity & other) const
{
if (!IAccessEntity::equal(other))
return false;
const auto & other_role = typeid_cast<const Role &>(other);
return (access == other_role.access) && (access_with_grant_option == other_role.access_with_grant_option)
&& (granted_roles == other_role.granted_roles) && (granted_roles_with_admin_option == other_role.granted_roles_with_admin_option);
}
}

24
dbms/src/Access/Role.h Normal file
View File

@ -0,0 +1,24 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <Access/AccessRights.h>
#include <Core/UUID.h>
#include <boost/container/flat_set.hpp>
namespace DB
{
struct Role : public IAccessEntity
{
AccessRights access;
AccessRights access_with_grant_option;
boost::container::flat_set<UUID> granted_roles;
boost::container::flat_set<UUID> granted_roles_with_admin_option;
bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Role>(); }
};
using RolePtr = std::shared_ptr<const Role>;
}

View File

@ -0,0 +1,200 @@
#include <Access/RoleContext.h>
#include <Access/Role.h>
#include <Access/CurrentRolesInfo.h>
#include <Access/AccessControlManager.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm/find.hpp>
#include <boost/range/algorithm/sort.hpp>
namespace DB
{
namespace
{
void makeUnique(std::vector<UUID> & vec)
{
boost::range::sort(vec);
vec.erase(std::unique(vec.begin(), vec.end()), vec.end());
}
}
RoleContext::RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_)
: manager(&manager_), current_role(current_role_), with_admin_option(with_admin_option_)
{
update();
}
RoleContext::RoleContext(std::vector<RoleContextPtr> && children_)
: children(std::move(children_))
{
update();
}
RoleContext::~RoleContext() = default;
void RoleContext::update()
{
std::vector<OnChangeHandler> handlers_to_notify;
CurrentRolesInfoPtr info_to_notify;
{
std::lock_guard lock{mutex};
auto old_info = info;
updateImpl();
if (!handlers.empty() && (!old_info || (*old_info != *info)))
{
boost::range::copy(handlers, std::back_inserter(handlers_to_notify));
info_to_notify = info;
}
}
for (const auto & handler : handlers_to_notify)
handler(info_to_notify);
}
void RoleContext::updateImpl()
{
if (!current_role && children.empty())
{
info = std::make_shared<CurrentRolesInfo>();
return;
}
if (!children.empty())
{
if (subscriptions_for_change_children.empty())
{
for (const auto & child : children)
subscriptions_for_change_children.emplace_back(
child->subscribeForChanges([this](const CurrentRolesInfoPtr &) { update(); }));
}
auto new_info = std::make_shared<CurrentRolesInfo>();
auto & new_info_ref = *new_info;
for (const auto & child : children)
{
auto child_info = child->getInfo();
new_info_ref.access.merge(child_info->access);
new_info_ref.access_with_grant_option.merge(child_info->access_with_grant_option);
boost::range::copy(child_info->current_roles, std::back_inserter(new_info_ref.current_roles));
boost::range::copy(child_info->enabled_roles, std::back_inserter(new_info_ref.enabled_roles));
boost::range::copy(child_info->enabled_roles_with_admin_option, std::back_inserter(new_info_ref.enabled_roles_with_admin_option));
boost::range::copy(child_info->names_of_roles, std::inserter(new_info_ref.names_of_roles, new_info_ref.names_of_roles.end()));
}
makeUnique(new_info_ref.current_roles);
makeUnique(new_info_ref.enabled_roles);
makeUnique(new_info_ref.enabled_roles_with_admin_option);
info = new_info;
return;
}
assert(current_role);
traverseRoles(*current_role, with_admin_option);
auto new_info = std::make_shared<CurrentRolesInfo>();
auto & new_info_ref = *new_info;
for (auto it = roles_map.begin(); it != roles_map.end();)
{
const auto & id = it->first;
auto & entry = it->second;
if (!entry.in_use)
{
it = roles_map.erase(it);
continue;
}
if (id == *current_role)
new_info_ref.current_roles.push_back(id);
new_info_ref.enabled_roles.push_back(id);
if (entry.with_admin_option)
new_info_ref.enabled_roles_with_admin_option.push_back(id);
new_info_ref.access.merge(entry.role->access);
new_info_ref.access_with_grant_option.merge(entry.role->access_with_grant_option);
new_info_ref.names_of_roles[id] = entry.role->getName();
entry.in_use = false;
entry.with_admin_option = false;
++it;
}
info = new_info;
}
void RoleContext::traverseRoles(const UUID & id_, bool with_admin_option_)
{
auto it = roles_map.find(id_);
if (it == roles_map.end())
{
assert(manager);
auto subscription = manager->subscribeForChanges(id_, [this, id_](const UUID &, const AccessEntityPtr & entity)
{
{
std::lock_guard lock{mutex};
auto it2 = roles_map.find(id_);
if (it2 == roles_map.end())
return;
if (entity)
it2->second.role = typeid_cast<RolePtr>(entity);
else
roles_map.erase(it2);
}
update();
});
auto role = manager->tryRead<Role>(id_);
if (!role)
return;
RoleEntry new_entry;
new_entry.role = role;
new_entry.subscription_for_change_role = std::move(subscription);
it = roles_map.emplace(id_, std::move(new_entry)).first;
}
RoleEntry & entry = it->second;
entry.with_admin_option |= with_admin_option_;
if (entry.in_use)
return;
entry.in_use = true;
for (const auto & granted_role : entry.role->granted_roles)
traverseRoles(granted_role, false);
for (const auto & granted_role : entry.role->granted_roles_with_admin_option)
traverseRoles(granted_role, true);
}
CurrentRolesInfoPtr RoleContext::getInfo() const
{
std::lock_guard lock{mutex};
return info;
}
ext::scope_guard RoleContext::subscribeForChanges(const OnChangeHandler & handler) const
{
std::lock_guard lock{mutex};
handlers.push_back(handler);
auto it = std::prev(handlers.end());
return [this, it]
{
std::lock_guard lock2{mutex};
handlers.erase(it);
};
}
}

View File

@ -0,0 +1,64 @@
#pragma once
#include <Core/UUID.h>
#include <ext/scope_guard.h>
#include <ext/shared_ptr_helper.h>
#include <list>
#include <mutex>
#include <unordered_map>
#include <vector>
namespace DB
{
struct Role;
using RolePtr = std::shared_ptr<const Role>;
class CurrentRolesInfo;
using CurrentRolesInfoPtr = std::shared_ptr<const CurrentRolesInfo>;
class AccessControlManager;
class RoleContext
{
public:
~RoleContext();
/// Returns all the roles specified in the constructor.
CurrentRolesInfoPtr getInfo() const;
using OnChangeHandler = std::function<void(const CurrentRolesInfoPtr & info)>;
/// Called when either the specified roles or the roles granted to the specified roles are changed.
ext::scope_guard subscribeForChanges(const OnChangeHandler & handler) const;
private:
friend struct ext::shared_ptr_helper<RoleContext>;
RoleContext(const AccessControlManager & manager_, const UUID & current_role_, bool with_admin_option_);
RoleContext(std::vector<std::shared_ptr<const RoleContext>> && children_);
void update();
void updateImpl();
void traverseRoles(const UUID & id_, bool with_admin_option_);
const AccessControlManager * manager = nullptr;
std::optional<UUID> current_role;
bool with_admin_option = false;
std::vector<std::shared_ptr<const RoleContext>> children;
std::vector<ext::scope_guard> subscriptions_for_change_children;
struct RoleEntry
{
RolePtr role;
ext::scope_guard subscription_for_change_role;
bool with_admin_option = false;
bool in_use = false;
};
mutable std::unordered_map<UUID, RoleEntry> roles_map;
mutable CurrentRolesInfoPtr info;
mutable std::list<OnChangeHandler> handlers;
mutable std::mutex mutex;
};
using RoleContextPtr = std::shared_ptr<const RoleContext>;
}

View File

@ -0,0 +1,52 @@
#include <Access/RoleContextFactory.h>
#include <boost/container/flat_set.hpp>
namespace DB
{
RoleContextFactory::RoleContextFactory(const AccessControlManager & manager_)
: manager(manager_), cache(600000 /* 10 minutes */) {}
RoleContextFactory::~RoleContextFactory() = default;
RoleContextPtr RoleContextFactory::createContext(
const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option)
{
if (roles.size() == 1 && roles_with_admin_option.empty())
return createContextImpl(roles[0], false);
if (roles.size() == 1 && roles_with_admin_option == roles)
return createContextImpl(roles[0], true);
std::vector<RoleContextPtr> children;
children.reserve(roles.size());
for (const auto & role : roles_with_admin_option)
children.push_back(createContextImpl(role, true));
boost::container::flat_set<UUID> roles_with_admin_option_set{roles_with_admin_option.begin(), roles_with_admin_option.end()};
for (const auto & role : roles)
{
if (!roles_with_admin_option_set.contains(role))
children.push_back(createContextImpl(role, false));
}
return ext::shared_ptr_helper<RoleContext>::create(std::move(children));
}
RoleContextPtr RoleContextFactory::createContextImpl(const UUID & id, bool with_admin_option)
{
std::lock_guard lock{mutex};
auto key = std::make_pair(id, with_admin_option);
auto x = cache.get(key);
if (x)
return *x;
auto res = ext::shared_ptr_helper<RoleContext>::create(manager, id, with_admin_option);
cache.add(key, res);
return res;
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <Access/RoleContext.h>
#include <Poco/ExpireCache.h>
#include <mutex>
namespace DB
{
class AccessControlManager;
class RoleContextFactory
{
public:
RoleContextFactory(const AccessControlManager & manager_);
~RoleContextFactory();
RoleContextPtr createContext(const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option);
private:
RoleContextPtr createContextImpl(const UUID & id, bool with_admin_option);
const AccessControlManager & manager;
Poco::ExpireCache<std::pair<UUID, bool>, RoleContextPtr> cache;
std::mutex mutex;
};
}

View File

@ -77,7 +77,7 @@ bool RowPolicy::equal(const IAccessEntity & other) const
const auto & other_policy = typeid_cast<const RowPolicy &>(other); const auto & other_policy = typeid_cast<const RowPolicy &>(other);
return (database == other_policy.database) && (table_name == other_policy.table_name) && (policy_name == other_policy.policy_name) return (database == other_policy.database) && (table_name == other_policy.table_name) && (policy_name == other_policy.policy_name)
&& boost::range::equal(conditions, other_policy.conditions) && restrictive == other_policy.restrictive && boost::range::equal(conditions, other_policy.conditions) && restrictive == other_policy.restrictive
&& (roles == other_policy.roles) && (all_roles == other_policy.all_roles) && (except_roles == other_policy.except_roles); && (roles == other_policy.roles);
} }

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/GenericRoleSet.h>
namespace DB namespace DB
@ -65,10 +66,8 @@ struct RowPolicy : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); }
/// Which roles or users should use this quota. /// Which roles or users should use this row policy.
Strings roles; GenericRoleSet roles;
bool all_roles = false;
Strings except_roles;
private: private:
String database; String database;

View File

@ -1,4 +1,7 @@
#include <Access/RowPolicyContext.h> #include <Access/RowPolicyContext.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTExpressionList.h>
#include <boost/smart_ptr/make_shared.hpp>
#include <boost/range/adaptor/map.hpp> #include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp> #include <boost/range/algorithm/copy.hpp>
@ -7,12 +10,12 @@ namespace DB
{ {
size_t RowPolicyContext::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const size_t RowPolicyContext::Hash::operator()(const DatabaseAndTableNameRef & database_and_table_name) const
{ {
return std::hash<StringRef>{}(database_and_table_name.first) - std::hash<StringRef>{}(database_and_table_name.second); return std::hash<std::string_view>{}(database_and_table_name.first) - std::hash<std::string_view>{}(database_and_table_name.second);
} }
RowPolicyContext::RowPolicyContext() RowPolicyContext::RowPolicyContext()
: atomic_map_of_mixed_conditions(std::make_shared<MapOfMixedConditions>()) : map_of_mixed_conditions(boost::make_shared<MapOfMixedConditions>())
{ {
} }
@ -20,28 +23,45 @@ RowPolicyContext::RowPolicyContext()
RowPolicyContext::~RowPolicyContext() = default; RowPolicyContext::~RowPolicyContext() = default;
RowPolicyContext::RowPolicyContext(const String & user_name_) RowPolicyContext::RowPolicyContext(const UUID & user_id_, const std::vector<UUID> & enabled_roles_)
: user_name(user_name_) : user_id(user_id_), enabled_roles(enabled_roles_)
{} {}
ASTPtr RowPolicyContext::getCondition(const String & database, const String & table_name, ConditionIndex index) const ASTPtr RowPolicyContext::getCondition(const String & database, const String & table_name, ConditionIndex index) const
{ {
/// We don't lock `mutex` here. /// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); auto loaded = map_of_mixed_conditions.load();
auto it = map_of_mixed_conditions->find({database, table_name}); auto it = loaded->find({database, table_name});
if (it == map_of_mixed_conditions->end()) if (it == loaded->end())
return {}; return {};
return it->second.mixed_conditions[index]; return it->second.mixed_conditions[index];
} }
ASTPtr RowPolicyContext::combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs)
{
if (!lhs)
return rhs;
if (!rhs)
return lhs;
auto function = std::make_shared<ASTFunction>();
auto exp_list = std::make_shared<ASTExpressionList>();
function->name = "and";
function->arguments = exp_list;
function->children.push_back(exp_list);
exp_list->children.push_back(lhs);
exp_list->children.push_back(rhs);
return function;
}
std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs() const std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs() const
{ {
/// We don't lock `mutex` here. /// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); auto loaded = map_of_mixed_conditions.load();
std::vector<UUID> policy_ids; std::vector<UUID> policy_ids;
for (const auto & mixed_conditions : *map_of_mixed_conditions | boost::adaptors::map_values) for (const auto & mixed_conditions : *loaded | boost::adaptors::map_values)
boost::range::copy(mixed_conditions.policy_ids, std::back_inserter(policy_ids)); boost::range::copy(mixed_conditions.policy_ids, std::back_inserter(policy_ids));
return policy_ids; return policy_ids;
} }
@ -50,9 +70,9 @@ std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs() const
std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs(const String & database, const String & table_name) const std::vector<UUID> RowPolicyContext::getCurrentPolicyIDs(const String & database, const String & table_name) const
{ {
/// We don't lock `mutex` here. /// We don't lock `mutex` here.
auto map_of_mixed_conditions = std::atomic_load(&atomic_map_of_mixed_conditions); auto loaded = map_of_mixed_conditions.load();
auto it = map_of_mixed_conditions->find({database, table_name}); auto it = loaded->find({database, table_name});
if (it == map_of_mixed_conditions->end()) if (it == loaded->end())
return {}; return {};
return it->second.policy_ids; return it->second.policy_ids;
} }

View File

@ -3,7 +3,7 @@
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <Core/Types.h> #include <Core/Types.h>
#include <Core/UUID.h> #include <Core/UUID.h>
#include <common/StringRef.h> #include <boost/smart_ptr/atomic_shared_ptr.hpp>
#include <memory> #include <memory>
#include <unordered_map> #include <unordered_map>
@ -30,6 +30,9 @@ public:
/// The returned filter can be a combination of the filters defined by multiple row policies. /// The returned filter can be a combination of the filters defined by multiple row policies.
ASTPtr getCondition(const String & database, const String & table_name, ConditionIndex index) const; ASTPtr getCondition(const String & database, const String & table_name, ConditionIndex index) const;
/// Combines two conditions into one by using the logical AND operator.
static ASTPtr combineConditionsUsingAnd(const ASTPtr & lhs, const ASTPtr & rhs);
/// Returns IDs of all the policies used by the current user. /// Returns IDs of all the policies used by the current user.
std::vector<UUID> getCurrentPolicyIDs() const; std::vector<UUID> getCurrentPolicyIDs() const;
@ -39,10 +42,10 @@ public:
private: private:
friend class RowPolicyContextFactory; friend class RowPolicyContextFactory;
friend struct ext::shared_ptr_helper<RowPolicyContext>; friend struct ext::shared_ptr_helper<RowPolicyContext>;
RowPolicyContext(const String & user_name_); /// RowPolicyContext should be created by RowPolicyContextFactory. RowPolicyContext(const UUID & user_id_, const std::vector<UUID> & enabled_roles_); /// RowPolicyContext should be created by RowPolicyContextFactory.
using DatabaseAndTableName = std::pair<String, String>; using DatabaseAndTableName = std::pair<String, String>;
using DatabaseAndTableNameRef = std::pair<StringRef, StringRef>; using DatabaseAndTableNameRef = std::pair<std::string_view, std::string_view>;
struct Hash struct Hash
{ {
size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const; size_t operator()(const DatabaseAndTableNameRef & database_and_table_name) const;
@ -57,10 +60,11 @@ private:
}; };
using MapOfMixedConditions = std::unordered_map<DatabaseAndTableNameRef, MixedConditions, Hash>; using MapOfMixedConditions = std::unordered_map<DatabaseAndTableNameRef, MixedConditions, Hash>;
const String user_name; const UUID user_id;
std::shared_ptr<const MapOfMixedConditions> atomic_map_of_mixed_conditions; /// Changed atomically, not protected by `mutex`. const std::vector<UUID> enabled_roles;
mutable boost::atomic_shared_ptr<const MapOfMixedConditions> map_of_mixed_conditions;
}; };
using RowPolicyContextPtr = std::shared_ptr<RowPolicyContext>; using RowPolicyContextPtr = std::shared_ptr<const RowPolicyContext>;
} }

View File

@ -8,6 +8,7 @@
#include <Common/Exception.h> #include <Common/Exception.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/smart_ptr/make_shared.hpp>
#include <boost/range/algorithm/copy.hpp> #include <boost/range/algorithm/copy.hpp>
#include <boost/range/algorithm_ext/erase.hpp> #include <boost/range/algorithm_ext/erase.hpp>
@ -110,13 +111,10 @@ namespace
ASTPtr getResult() && ASTPtr getResult() &&
{ {
/// Process permissive conditions. /// Process permissive conditions.
if (!permissions.empty()) restrictions.push_back(applyFunctionOR(std::move(permissions)));
restrictions.push_back(applyFunctionOR(std::move(permissions)));
/// Process restrictive conditions. /// Process restrictive conditions.
if (!restrictions.empty()) return applyFunctionAND(std::move(restrictions));
return applyFunctionAND(std::move(restrictions));
return nullptr;
} }
private: private:
@ -129,10 +127,7 @@ namespace
void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_) void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_)
{ {
policy = policy_; policy = policy_;
roles = &policy->roles;
boost::range::copy(policy->roles, std::inserter(roles, roles.end()));
all_roles = policy->all_roles;
boost::range::copy(policy->except_roles, std::inserter(except_roles, except_roles.end()));
for (auto index : ext::range_with_static_cast<ConditionIndex>(0, MAX_CONDITION_INDEX)) for (auto index : ext::range_with_static_cast<ConditionIndex>(0, MAX_CONDITION_INDEX))
{ {
@ -169,13 +164,7 @@ void RowPolicyContextFactory::PolicyInfo::setPolicy(const RowPolicyPtr & policy_
bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const bool RowPolicyContextFactory::PolicyInfo::canUseWithContext(const RowPolicyContext & context) const
{ {
if (roles.count(context.user_name)) return roles->match(context.user_id, context.enabled_roles);
return true;
if (all_roles && !except_roles.count(context.user_name))
return true;
return false;
} }
@ -187,11 +176,11 @@ RowPolicyContextFactory::RowPolicyContextFactory(const AccessControlManager & ac
RowPolicyContextFactory::~RowPolicyContextFactory() = default; RowPolicyContextFactory::~RowPolicyContextFactory() = default;
RowPolicyContextPtr RowPolicyContextFactory::createContext(const String & user_name) RowPolicyContextPtr RowPolicyContextFactory::createContext(const UUID & user_id, const std::vector<UUID> & enabled_roles)
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
ensureAllRowPoliciesRead(); ensureAllRowPoliciesRead();
auto context = ext::shared_ptr_helper<RowPolicyContext>::create(user_name); auto context = ext::shared_ptr_helper<RowPolicyContext>::create(user_id, enabled_roles);
contexts.push_back(context); contexts.push_back(context);
mixConditionsForContext(*context); mixConditionsForContext(*context);
return context; return context;
@ -284,10 +273,10 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context
for (const auto & [policy_id, info] : all_policies) for (const auto & [policy_id, info] : all_policies)
{ {
const auto & policy = *info.policy;
auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}];
if (info.canUseWithContext(context)) if (info.canUseWithContext(context))
{ {
const auto & policy = *info.policy;
auto & mixers = map_of_mixers[std::pair{policy.getDatabase(), policy.getTableName()}];
mixers.policy_ids.push_back(policy_id); mixers.policy_ids.push_back(policy_id);
for (auto index : ext::range(0, MAX_CONDITION_INDEX)) for (auto index : ext::range(0, MAX_CONDITION_INDEX))
if (info.parsed_conditions[index]) if (info.parsed_conditions[index])
@ -295,7 +284,7 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context
} }
} }
auto map_of_mixed_conditions = std::make_shared<MapOfMixedConditions>(); auto map_of_mixed_conditions = boost::make_shared<MapOfMixedConditions>();
for (auto & [database_and_table_name, mixers] : map_of_mixers) for (auto & [database_and_table_name, mixers] : map_of_mixers)
{ {
auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>(); auto database_and_table_name_keeper = std::make_unique<DatabaseAndTableName>();
@ -309,7 +298,7 @@ void RowPolicyContextFactory::mixConditionsForContext(RowPolicyContext & context
mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult(); mixed_conditions.mixed_conditions[index] = std::move(mixers.mixers[index]).getResult();
} }
std::atomic_store(&context.atomic_map_of_mixed_conditions, std::shared_ptr<const MapOfMixedConditions>{map_of_mixed_conditions}); context.map_of_mixed_conditions.store(map_of_mixed_conditions);
} }
} }

View File

@ -4,14 +4,12 @@
#include <ext/scope_guard.h> #include <ext/scope_guard.h>
#include <mutex> #include <mutex>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
namespace DB namespace DB
{ {
class AccessControlManager; class AccessControlManager;
/// Stores read and parsed row policies. /// Stores read and parsed row policies.
class RowPolicyContextFactory class RowPolicyContextFactory
{ {
@ -19,7 +17,7 @@ public:
RowPolicyContextFactory(const AccessControlManager & access_control_manager_); RowPolicyContextFactory(const AccessControlManager & access_control_manager_);
~RowPolicyContextFactory(); ~RowPolicyContextFactory();
RowPolicyContextPtr createContext(const String & user_name); RowPolicyContextPtr createContext(const UUID & user_id, const std::vector<UUID> & enabled_roles);
private: private:
using ParsedConditions = RowPolicyContext::ParsedConditions; using ParsedConditions = RowPolicyContext::ParsedConditions;
@ -31,9 +29,7 @@ private:
bool canUseWithContext(const RowPolicyContext & context) const; bool canUseWithContext(const RowPolicyContext & context) const;
RowPolicyPtr policy; RowPolicyPtr policy;
std::unordered_set<String> roles; const GenericRoleSet * roles = nullptr;
bool all_roles = false;
std::unordered_set<String> except_roles;
ParsedConditions parsed_conditions; ParsedConditions parsed_conditions;
}; };

View File

@ -11,7 +11,8 @@ bool User::equal(const IAccessEntity & other) const
const auto & other_user = typeid_cast<const User &>(other); const auto & other_user = typeid_cast<const User &>(other);
return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts) return (authentication == other_user.authentication) && (allowed_client_hosts == other_user.allowed_client_hosts)
&& (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option) && (access == other_user.access) && (access_with_grant_option == other_user.access_with_grant_option)
&& (profile == other_user.profile); && (granted_roles == other_user.granted_roles) && (granted_roles_with_admin_option == other_user.granted_roles_with_admin_option)
&& (default_roles == other_user.default_roles) && (profile == other_user.profile);
} }
} }

View File

@ -4,7 +4,9 @@
#include <Access/Authentication.h> #include <Access/Authentication.h>
#include <Access/AllowedClientHosts.h> #include <Access/AllowedClientHosts.h>
#include <Access/AccessRights.h> #include <Access/AccessRights.h>
#include <Core/Types.h> #include <Access/GenericRoleSet.h>
#include <Core/UUID.h>
#include <boost/container/flat_set.hpp>
namespace DB namespace DB
@ -14,9 +16,12 @@ namespace DB
struct User : public IAccessEntity struct User : public IAccessEntity
{ {
Authentication authentication; Authentication authentication;
AllowedClientHosts allowed_client_hosts{AllowedClientHosts::AnyHostTag{}}; AllowedClientHosts allowed_client_hosts = AllowedClientHosts::AnyHostTag{};
AccessRights access; AccessRights access;
AccessRights access_with_grant_option; AccessRights access_with_grant_option;
boost::container::flat_set<UUID> granted_roles;
boost::container::flat_set<UUID> granted_roles_with_admin_option;
GenericRoleSet default_roles = GenericRoleSet::AllTag{};
String profile; String profile;
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;

View File

@ -183,7 +183,7 @@ namespace
} }
QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const Strings & user_names) QuotaPtr parseQuota(const Poco::Util::AbstractConfiguration & config, const String & quota_name, const std::vector<UUID> & user_ids)
{ {
auto quota = std::make_shared<Quota>(); auto quota = std::make_shared<Quota>();
quota->setName(quota_name); quota->setName(quota_name);
@ -225,7 +225,7 @@ namespace
limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED)); limits.max[ResourceType::EXECUTION_TIME] = Quota::secondsToExecutionTime(config.getUInt64(interval_config + ".execution_time", Quota::UNLIMITED));
} }
quota->roles = user_names; quota->roles.add(user_ids);
return quota; return quota;
} }
@ -235,11 +235,11 @@ namespace
{ {
Poco::Util::AbstractConfiguration::Keys user_names; Poco::Util::AbstractConfiguration::Keys user_names;
config.keys("users", user_names); config.keys("users", user_names);
std::unordered_map<String, Strings> quota_to_user_names; std::unordered_map<String, std::vector<UUID>> quota_to_user_ids;
for (const auto & user_name : user_names) for (const auto & user_name : user_names)
{ {
if (config.has("users." + user_name + ".quota")) if (config.has("users." + user_name + ".quota"))
quota_to_user_names[config.getString("users." + user_name + ".quota")].push_back(user_name); quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(typeid(User), user_name));
} }
Poco::Util::AbstractConfiguration::Keys quota_names; Poco::Util::AbstractConfiguration::Keys quota_names;
@ -250,8 +250,8 @@ namespace
{ {
try try
{ {
auto it = quota_to_user_names.find(quota_name); auto it = quota_to_user_ids.find(quota_name);
const Strings quota_users = (it != quota_to_user_names.end()) ? std::move(it->second) : Strings{}; const std::vector<UUID> & quota_users = (it != quota_to_user_ids.end()) ? std::move(it->second) : std::vector<UUID>{};
quotas.push_back(parseQuota(config, quota_name, quota_users)); quotas.push_back(parseQuota(config, quota_name, quota_users));
} }
catch (...) catch (...)
@ -265,63 +265,70 @@ namespace
std::vector<AccessEntityPtr> parseRowPolicies(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log) std::vector<AccessEntityPtr> parseRowPolicies(const Poco::Util::AbstractConfiguration & config, Poco::Logger * log)
{ {
std::vector<AccessEntityPtr> policies; std::map<std::pair<String /* database */, String /* table */>, std::unordered_map<String /* user */, String /* filter */>> all_filters_map;
Poco::Util::AbstractConfiguration::Keys user_names; Poco::Util::AbstractConfiguration::Keys user_names;
config.keys("users", user_names);
for (const String & user_name : user_names) try
{ {
const String databases_config = "users." + user_name + ".databases"; config.keys("users", user_names);
if (config.has(databases_config)) for (const String & user_name : user_names)
{ {
Poco::Util::AbstractConfiguration::Keys databases; const String databases_config = "users." + user_name + ".databases";
config.keys(databases_config, databases); if (config.has(databases_config))
/// Read tables within databases
for (const String & database : databases)
{ {
const String database_config = databases_config + "." + database; Poco::Util::AbstractConfiguration::Keys databases;
Poco::Util::AbstractConfiguration::Keys keys_in_database_config; config.keys(databases_config, databases);
config.keys(database_config, keys_in_database_config);
/// Read table properties /// Read tables within databases
for (const String & key_in_database_config : keys_in_database_config) for (const String & database : databases)
{ {
String table_name = key_in_database_config; const String database_config = databases_config + "." + database;
String filter_config = database_config + "." + table_name + ".filter"; Poco::Util::AbstractConfiguration::Keys keys_in_database_config;
config.keys(database_config, keys_in_database_config);
if (key_in_database_config.starts_with("table[")) /// Read table properties
for (const String & key_in_database_config : keys_in_database_config)
{ {
const auto table_name_config = database_config + "." + table_name + "[@name]"; String table_name = key_in_database_config;
if (config.has(table_name_config)) String filter_config = database_config + "." + table_name + ".filter";
{
table_name = config.getString(table_name_config);
filter_config = database_config + ".table[@name='" + table_name + "']";
}
}
if (config.has(filter_config)) if (key_in_database_config.starts_with("table["))
{
try
{ {
auto policy = std::make_shared<RowPolicy>(); const auto table_name_config = database_config + "." + table_name + "[@name]";
policy->setFullName(database, table_name, user_name); if (config.has(table_name_config))
policy->conditions[RowPolicy::SELECT_FILTER] = config.getString(filter_config); {
policy->roles.push_back(user_name); table_name = config.getString(table_name_config);
policies.push_back(policy); filter_config = database_config + ".table[@name='" + table_name + "']";
} }
catch (...)
{
tryLogCurrentException(
log,
"Could not parse row policy " + backQuote(user_name) + " on table " + backQuoteIfNeed(database) + "."
+ backQuoteIfNeed(table_name));
} }
all_filters_map[{database, table_name}][user_name] = config.getString(filter_config);
} }
} }
} }
} }
} }
catch (...)
{
tryLogCurrentException(log, "Could not parse row policies");
}
std::vector<AccessEntityPtr> policies;
for (auto & [database_and_table_name, user_to_filters] : all_filters_map)
{
const auto & [database, table_name] = database_and_table_name;
for (const String & user_name : user_names)
{
auto it = user_to_filters.find(user_name);
String filter = (it != user_to_filters.end()) ? it->second : "1";
auto policy = std::make_shared<RowPolicy>();
policy->setFullName(database, table_name, user_name);
policy->conditions[RowPolicy::SELECT_FILTER] = filter;
policy->roles.add(generateID(typeid(User), user_name));
policies.push_back(policy);
}
}
return policies; return policies;
} }
} }

View File

@ -482,6 +482,8 @@ namespace ErrorCodes
extern const int UNKNOWN_ACCESS_TYPE = 508; extern const int UNKNOWN_ACCESS_TYPE = 508;
extern const int INVALID_GRANT = 509; extern const int INVALID_GRANT = 509;
extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510; extern const int CACHE_DICTIONARY_UPDATE_FAIL = 510;
extern const int UNKNOWN_ROLE = 511;
extern const int SET_NON_GRANTED_ROLE = 512;
extern const int KEEPER_EXCEPTION = 999; extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000; extern const int POCO_EXCEPTION = 1000;

View File

@ -953,7 +953,7 @@ public:
throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.", throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION); ErrorCodes::UNKNOWN_EXCEPTION);
auto user = context.getAccessControlManager().getUser(user_name); auto user = context.getAccessControlManager().read<User>(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1(); Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE); assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);

View File

@ -12,6 +12,7 @@ limitations under the License. */
#pragma once #pragma once
#include <DataStreams/IBlockInputStream.h> #include <DataStreams/IBlockInputStream.h>
#include <Processors/Sources/SourceWithProgress.h>
namespace DB namespace DB

View File

@ -24,6 +24,7 @@ namespace ErrorCodes
class ProcessListElement; class ProcessListElement;
class QuotaContext; class QuotaContext;
using QuotaContextPtr = std::shared_ptr<const QuotaContext>;
class QueryStatus; class QueryStatus;
struct SortColumnDescription; struct SortColumnDescription;
using SortDescription = std::vector<SortColumnDescription>; using SortDescription = std::vector<SortColumnDescription>;
@ -220,7 +221,7 @@ public:
/** Set the quota. If you set a quota on the amount of raw data, /** Set the quota. If you set a quota on the amount of raw data,
* then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits. * then you should also set mode = LIMITS_TOTAL to LocalLimits with setLimits.
*/ */
virtual void setQuota(const std::shared_ptr<QuotaContext> & quota_) virtual void setQuota(const QuotaContextPtr & quota_)
{ {
quota = quota_; quota = quota_;
} }
@ -278,7 +279,7 @@ private:
LocalLimits limits; LocalLimits limits;
std::shared_ptr<QuotaContext> quota; /// If nullptr - the quota is not used. QuotaContextPtr quota; /// If nullptr - the quota is not used.
UInt64 prev_elapsed = 0; UInt64 prev_elapsed = 0;
/// The approximate total number of rows to read. For progress bar. /// The approximate total number of rows to read. For progress bar.

View File

@ -27,11 +27,10 @@
#include <Interpreters/ActionLocksManager.h> #include <Interpreters/ActionLocksManager.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/AccessRightsContext.h>
#include <Access/RowPolicyContext.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/SettingsConstraints.h> #include <Access/SettingsConstraints.h>
#include <Access/QuotaContext.h>
#include <Access/RowPolicyContext.h>
#include <Access/AccessRightsContext.h>
#include <Interpreters/ExpressionJIT.h> #include <Interpreters/ExpressionJIT.h>
#include <Dictionaries/Embedded/GeoDictionariesLoader.h> #include <Dictionaries/Embedded/GeoDictionariesLoader.h>
#include <Interpreters/EmbeddedDictionaries.h> #include <Interpreters/EmbeddedDictionaries.h>
@ -320,9 +319,8 @@ Context & Context::operator=(const Context &) = default;
Context Context::createGlobal() Context Context::createGlobal()
{ {
Context res; Context res;
res.quota = std::make_shared<QuotaContext>();
res.row_policy = std::make_shared<RowPolicyContext>();
res.access_rights = std::make_shared<AccessRightsContext>(); res.access_rights = std::make_shared<AccessRightsContext>();
res.initial_row_policy = std::make_shared<RowPolicyContext>();
res.shared = std::make_shared<ContextShared>(); res.shared = std::make_shared<ContextShared>();
return res; return res;
} }
@ -617,37 +615,17 @@ const Poco::Util::AbstractConfiguration & Context::getConfigRef() const
return shared->config ? *shared->config : Poco::Util::Application::instance().config(); return shared->config ? *shared->config : Poco::Util::Application::instance().config();
} }
AccessControlManager & Context::getAccessControlManager() AccessControlManager & Context::getAccessControlManager()
{ {
auto lock = getLock();
return shared->access_control_manager; return shared->access_control_manager;
} }
const AccessControlManager & Context::getAccessControlManager() const const AccessControlManager & Context::getAccessControlManager() const
{ {
auto lock = getLock();
return shared->access_control_manager; return shared->access_control_manager;
} }
template <typename... Args>
void Context::checkAccessImpl(const Args &... args) const
{
getAccessRights()->check(args...);
}
void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkAccessImpl(access, database, table, columns); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); }
void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); }
void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); }
void Context::switchRowPolicy()
{
row_policy = getAccessControlManager().getRowPolicyContext(client_info.initial_user);
}
void Context::setUsersConfig(const ConfigurationPtr & config) void Context::setUsersConfig(const ConfigurationPtr & config)
{ {
@ -662,10 +640,155 @@ ConfigurationPtr Context::getUsersConfig()
return shared->users_config; return shared->users_config;
} }
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key)
{
auto lock = getLock();
client_info.current_user = name;
client_info.current_password = password;
client_info.current_address = address;
if (!quota_key.empty())
client_info.quota_key = quota_key;
auto new_user_id = getAccessControlManager().getID<User>(name);
auto new_access_rights = getAccessControlManager().getAccessRightsContext(new_user_id, {}, true, settings, current_database, client_info);
new_access_rights->checkHostIsAllowed();
new_access_rights->checkPassword(password);
user_id = new_user_id;
access_rights = std::move(new_access_rights);
current_roles.clear();
use_default_roles = true;
calculateUserSettings();
}
std::shared_ptr<const User> Context::getUser() const
{
auto lock = getLock();
return access_rights->getUser();
}
String Context::getUserName() const
{
auto lock = getLock();
return access_rights->getUserName();
}
UUID Context::getUserID() const
{
auto lock = getLock();
if (!user_id)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return *user_id;
}
void Context::setCurrentRoles(const std::vector<UUID> & current_roles_)
{
auto lock = getLock();
if (current_roles == current_roles_ && !use_default_roles)
return;
current_roles = current_roles_;
use_default_roles = false;
calculateAccessRights();
}
void Context::setCurrentRolesDefault()
{
auto lock = getLock();
if (use_default_roles)
return;
current_roles.clear();
use_default_roles = true;
calculateAccessRights();
}
std::vector<UUID> Context::getCurrentRoles() const
{
return getAccessRights()->getCurrentRoles();
}
Strings Context::getCurrentRolesNames() const
{
return getAccessRights()->getCurrentRolesNames();
}
std::vector<UUID> Context::getEnabledRoles() const
{
return getAccessRights()->getEnabledRoles();
}
Strings Context::getEnabledRolesNames() const
{
return getAccessRights()->getEnabledRolesNames();
}
void Context::calculateAccessRights()
{
auto lock = getLock();
if (user_id)
access_rights = getAccessControlManager().getAccessRightsContext(*user_id, current_roles, use_default_roles, settings, current_database, client_info);
}
template <typename... Args>
void Context::checkAccessImpl(const Args &... args) const
{
getAccessRights()->checkAccess(args...);
}
void Context::checkAccess(const AccessFlags & access) const { return checkAccessImpl(access); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database) const { return checkAccessImpl(access, database); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table) const { return checkAccessImpl(access, database, table); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::string_view & column) const { return checkAccessImpl(access, database, table, column); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const std::vector<std::string_view> & columns) const { return checkAccessImpl(access, database, table, columns); }
void Context::checkAccess(const AccessFlags & access, const std::string_view & database, const std::string_view & table, const Strings & columns) const { return checkAccessImpl(access, database, table, columns); }
void Context::checkAccess(const AccessRightsElement & access) const { return checkAccessImpl(access); }
void Context::checkAccess(const AccessRightsElements & access) const { return checkAccessImpl(access); }
AccessRightsContextPtr Context::getAccessRights() const
{
auto lock = getLock();
return access_rights;
}
RowPolicyContextPtr Context::getRowPolicy() const
{
return getAccessRights()->getRowPolicy();
}
void Context::setInitialRowPolicy()
{
auto lock = getLock();
auto initial_user_id = getAccessControlManager().find<User>(client_info.initial_user);
if (initial_user_id)
initial_row_policy = getAccessControlManager().getRowPolicyContext(*initial_user_id, {});
}
RowPolicyContextPtr Context::getInitialRowPolicy() const
{
auto lock = getLock();
return initial_row_policy;
}
QuotaContextPtr Context::getQuota() const
{
return getAccessRights()->getQuota();
}
void Context::calculateUserSettings() void Context::calculateUserSettings()
{ {
auto lock = getLock(); auto lock = getLock();
String profile = user->profile; String profile = getUser()->profile;
bool old_readonly = settings.readonly;
bool old_allow_ddl = settings.allow_ddl;
bool old_allow_introspection_functions = settings.allow_introspection_functions;
/// 1) Set default settings (hardcoded values) /// 1) Set default settings (hardcoded values)
/// NOTE: we ignore global_context settings (from which it is usually copied) /// NOTE: we ignore global_context settings (from which it is usually copied)
@ -680,13 +803,10 @@ void Context::calculateUserSettings()
/// 3) Apply settings from current user /// 3) Apply settings from current user
setProfile(profile); setProfile(profile);
}
void Context::calculateAccessRights() /// 4) Recalculate access rights if it's necessary.
{ if ((settings.readonly != old_readonly) || (settings.allow_ddl != old_allow_ddl) || (settings.allow_introspection_functions != old_allow_introspection_functions))
auto lock = getLock(); calculateAccessRights();
if (user)
std::atomic_store(&access_rights, getAccessControlManager().getAccessRightsContext(user, client_info, settings, current_database));
} }
void Context::setProfile(const String & profile) void Context::setProfile(const String & profile)
@ -699,50 +819,6 @@ void Context::setProfile(const String & profile)
settings_constraints = std::move(new_constraints); settings_constraints = std::move(new_constraints);
} }
std::shared_ptr<const User> Context::getUser() const
{
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user;
}
UUID Context::getUserID() const
{
if (!user)
throw Exception("No current user", ErrorCodes::LOGICAL_ERROR);
return user_id;
}
void Context::setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key)
{
auto lock = getLock();
client_info.current_user = name;
client_info.current_address = address;
client_info.current_password = password;
if (!quota_key.empty())
client_info.quota_key = quota_key;
user_id = shared->access_control_manager.getID<User>(name);
user = shared->access_control_manager.authorizeAndGetUser(
user_id,
password,
address.host(),
[this](const UserPtr & changed_user)
{
user = changed_user;
calculateAccessRights();
},
&subscription_for_user_change.subscription);
quota = getAccessControlManager().createQuotaContext(
client_info.current_user, client_info.current_address.host(), client_info.quota_key);
row_policy = getAccessControlManager().getRowPolicyContext(client_info.current_user);
calculateUserSettings();
calculateAccessRights();
}
void Context::addDependencyUnsafe(const StorageID & from, const StorageID & where) void Context::addDependencyUnsafe(const StorageID & from, const StorageID & where)
{ {

View File

@ -13,7 +13,6 @@
#include <Common/ThreadPool.h> #include <Common/ThreadPool.h>
#include "config_core.h" #include "config_core.h"
#include <Storages/IStorage_fwd.h> #include <Storages/IStorage_fwd.h>
#include <ext/scope_guard.h>
#include <atomic> #include <atomic>
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
@ -44,10 +43,14 @@ namespace DB
struct ContextShared; struct ContextShared;
class Context; class Context;
struct User;
class AccessRightsContext; class AccessRightsContext;
class QuotaContext; using AccessRightsContextPtr = std::shared_ptr<const AccessRightsContext>;
struct User;
using UserPtr = std::shared_ptr<const User>;
class RowPolicyContext; class RowPolicyContext;
using RowPolicyContextPtr = std::shared_ptr<const RowPolicyContext>;
class QuotaContext;
using QuotaContextPtr = std::shared_ptr<const QuotaContext>;
class AccessFlags; class AccessFlags;
struct AccessRightsElement; struct AccessRightsElement;
class AccessRightsElements; class AccessRightsElements;
@ -133,16 +136,6 @@ struct IHostContext
using IHostContextPtr = std::shared_ptr<IHostContext>; using IHostContextPtr = std::shared_ptr<IHostContext>;
/// Subscription for user's change. This subscription cannot be copied with the context,
/// that's why we had to move it into a separate structure.
struct SubscriptionForUserChange
{
ext::scope_guard subscription;
SubscriptionForUserChange() {}
SubscriptionForUserChange(const SubscriptionForUserChange &) {}
SubscriptionForUserChange & operator =(const SubscriptionForUserChange &) { subscription = {}; return *this; }
};
/** A set of known objects that can be used in the query. /** A set of known objects that can be used in the query.
* Consists of a shared part (always common to all sessions and queries) * Consists of a shared part (always common to all sessions and queries)
* and copied part (which can be its own for each session or query). * and copied part (which can be its own for each session or query).
@ -161,12 +154,11 @@ private:
InputInitializer input_initializer_callback; InputInitializer input_initializer_callback;
InputBlocksReader input_blocks_reader; InputBlocksReader input_blocks_reader;
std::shared_ptr<const User> user; std::optional<UUID> user_id;
UUID user_id; std::vector<UUID> current_roles;
SubscriptionForUserChange subscription_for_user_change; bool use_default_roles = false;
std::shared_ptr<const AccessRightsContext> access_rights; AccessRightsContextPtr access_rights;
std::shared_ptr<QuotaContext> quota; /// Current quota. By default - empty quota, that have no limits. RowPolicyContextPtr initial_row_policy;
std::shared_ptr<RowPolicyContext> row_policy;
String current_database; String current_database;
Settings settings; /// Setting for query execution. Settings settings; /// Setting for query execution.
std::shared_ptr<const SettingsConstraints> settings_constraints; std::shared_ptr<const SettingsConstraints> settings_constraints;
@ -237,7 +229,28 @@ public:
AccessControlManager & getAccessControlManager(); AccessControlManager & getAccessControlManager();
const AccessControlManager & getAccessControlManager() const; const AccessControlManager & getAccessControlManager() const;
std::shared_ptr<const AccessRightsContext> getAccessRights() const { return std::atomic_load(&access_rights); }
/** Take the list of users, quotas and configuration profiles from this config.
* The list of users is completely replaced.
* The accumulated quota values are not reset if the quota is not deleted.
*/
void setUsersConfig(const ConfigurationPtr & config);
ConfigurationPtr getUsersConfig();
/// Sets the current user, checks the password and that the specified host is allowed.
/// Must be called before getClientInfo.
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
UserPtr getUser() const;
String getUserName() const;
UUID getUserID() const;
void setCurrentRoles(const std::vector<UUID> & current_roles_);
void setCurrentRolesDefault();
std::vector<UUID> getCurrentRoles() const;
Strings getCurrentRolesNames() const;
std::vector<UUID> getEnabledRoles() const;
Strings getEnabledRolesNames() const;
/// Checks access rights. /// Checks access rights.
/// Empty database means the current database. /// Empty database means the current database.
@ -250,24 +263,17 @@ public:
void checkAccess(const AccessRightsElement & access) const; void checkAccess(const AccessRightsElement & access) const;
void checkAccess(const AccessRightsElements & access) const; void checkAccess(const AccessRightsElements & access) const;
std::shared_ptr<QuotaContext> getQuota() const { return quota; } AccessRightsContextPtr getAccessRights() const;
std::shared_ptr<RowPolicyContext> getRowPolicy() const { return row_policy; }
/// TODO: we need much better code for switching policies, quotas, access rights for initial user RowPolicyContextPtr getRowPolicy() const;
/// Switches row policy in case we have initial user in client info
void switchRowPolicy();
/** Take the list of users, quotas and configuration profiles from this config. /// Sets an extra row policy based on `client_info.initial_user`, if it exists.
* The list of users is completely replaced. /// TODO: we need a better solution here. It seems we should pass the initial row policy
* The accumulated quota values are not reset if the quota is not deleted. /// because a shard is allowed to don't have the initial user or it may be another user with the same name.
*/ void setInitialRowPolicy();
void setUsersConfig(const ConfigurationPtr & config); RowPolicyContextPtr getInitialRowPolicy() const;
ConfigurationPtr getUsersConfig();
/// Must be called before getClientInfo. QuotaContextPtr getQuota() const;
void setUser(const String & name, const String & password, const Poco::Net::SocketAddress & address, const String & quota_key);
std::shared_ptr<const User> getUser() const;
UUID getUserID() const;
/// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once. /// We have to copy external tables inside executeQuery() to track limits. Therefore, set callback for it. Must set once.
void setExternalTablesInitializer(ExternalTablesInitializer && initializer); void setExternalTablesInitializer(ExternalTablesInitializer && initializer);
@ -612,12 +618,6 @@ private:
void calculateUserSettings(); void calculateUserSettings();
void calculateAccessRights(); void calculateAccessRights();
/** Check if the current client has access to the specified database.
* If access is denied, throw an exception.
* NOTE: This method should always be called when the `shared->mutex` mutex is acquired.
*/
void checkDatabaseAccessRightsImpl(const std::string & database_name) const;
template <typename... Args> template <typename... Args>
void checkAccessImpl(const Args &... args) const; void checkAccessImpl(const Args &... args) const;

View File

@ -1,6 +1,6 @@
#include <Interpreters/InterpreterCreateQuotaQuery.h> #include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h> #include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/AccessFlags.h> #include <Access/AccessFlags.h>
@ -18,12 +18,16 @@ BlockIO InterpreterCreateQuotaQuery::execute()
auto & access_control = context.getAccessControlManager(); auto & access_control = context.getAccessControlManager();
context.checkAccess(query.alter ? AccessType::ALTER_QUOTA : AccessType::CREATE_QUOTA); context.checkAccess(query.alter ? AccessType::ALTER_QUOTA : AccessType::CREATE_QUOTA);
std::optional<GenericRoleSet> roles_from_query;
if (query.roles)
roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()};
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_quota = typeid_cast<std::shared_ptr<Quota>>(entity->clone()); auto updated_quota = typeid_cast<std::shared_ptr<Quota>>(entity->clone());
updateQuotaFromQuery(*updated_quota, query); updateQuotaFromQuery(*updated_quota, query, roles_from_query);
return updated_quota; return updated_quota;
}; };
if (query.if_exists) if (query.if_exists)
@ -37,7 +41,7 @@ BlockIO InterpreterCreateQuotaQuery::execute()
else else
{ {
auto new_quota = std::make_shared<Quota>(); auto new_quota = std::make_shared<Quota>();
updateQuotaFromQuery(*new_quota, query); updateQuotaFromQuery(*new_quota, query, roles_from_query);
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_quota); access_control.tryInsert(new_quota);
@ -51,7 +55,7 @@ BlockIO InterpreterCreateQuotaQuery::execute()
} }
void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query) void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional<GenericRoleSet> & roles_from_query)
{ {
if (query.alter) if (query.alter)
{ {
@ -98,25 +102,7 @@ void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTC
} }
} }
if (query.roles) if (roles_from_query)
{ quota.roles = *roles_from_query;
const auto & query_roles = *query.roles;
/// We keep `roles` sorted.
quota.roles = query_roles.roles;
if (query_roles.current_user)
quota.roles.push_back(context.getClientInfo().current_user);
boost::range::sort(quota.roles);
quota.roles.erase(std::unique(quota.roles.begin(), quota.roles.end()), quota.roles.end());
quota.all_roles = query_roles.all_roles;
/// We keep `except_roles` sorted.
quota.except_roles = query_roles.except_roles;
if (query_roles.except_current_user)
quota.except_roles.push_back(context.getClientInfo().current_user);
boost::range::sort(quota.except_roles);
quota.except_roles.erase(std::unique(quota.except_roles.begin(), quota.except_roles.end()), quota.except_roles.end());
}
} }
} }

View File

@ -2,12 +2,14 @@
#include <Interpreters/IInterpreter.h> #include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h> #include <Parsers/IAST_fwd.h>
#include <optional>
namespace DB namespace DB
{ {
class ASTCreateQuotaQuery; class ASTCreateQuotaQuery;
struct Quota; struct Quota;
struct GenericRoleSet;
class InterpreterCreateQuotaQuery : public IInterpreter class InterpreterCreateQuotaQuery : public IInterpreter
@ -21,7 +23,7 @@ public:
bool ignoreLimits() const override { return true; } bool ignoreLimits() const override { return true; }
private: private:
void updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query); void updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional<GenericRoleSet> & roles_from_query);
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;

View File

@ -0,0 +1,62 @@
#include <Interpreters/InterpreterCreateRoleQuery.h>
#include <Parsers/ASTCreateRoleQuery.h>
#include <Interpreters/Context.h>
#include <Access/AccessControlManager.h>
#include <Access/Role.h>
namespace DB
{
BlockIO InterpreterCreateRoleQuery::execute()
{
const auto & query = query_ptr->as<const ASTCreateRoleQuery &>();
auto & access_control = context.getAccessControlManager();
if (query.alter)
context.checkAccess(AccessType::CREATE_ROLE | AccessType::DROP_ROLE);
else
context.checkAccess(AccessType::CREATE_ROLE);
if (query.alter)
{
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_role = typeid_cast<std::shared_ptr<Role>>(entity->clone());
updateRoleFromQuery(*updated_role, query);
return updated_role;
};
if (query.if_exists)
{
if (auto id = access_control.find<Role>(query.name))
access_control.tryUpdate(*id, update_func);
}
else
access_control.update(access_control.getID<Role>(query.name), update_func);
}
else
{
auto new_role = std::make_shared<Role>();
updateRoleFromQuery(*new_role, query);
if (query.if_not_exists)
access_control.tryInsert(new_role);
else if (query.or_replace)
access_control.insertOrReplace(new_role);
else
access_control.insert(new_role);
}
return {};
}
void InterpreterCreateRoleQuery::updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query)
{
if (query.alter)
{
if (!query.new_name.empty())
role.setName(query.new_name);
}
else
role.setName(query.name);
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTCreateRoleQuery;
struct Role;
class InterpreterCreateRoleQuery : public IInterpreter
{
public:
InterpreterCreateRoleQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
private:
void updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -1,6 +1,6 @@
#include <Interpreters/InterpreterCreateRowPolicyQuery.h> #include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
@ -16,12 +16,16 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
auto & access_control = context.getAccessControlManager(); auto & access_control = context.getAccessControlManager();
context.checkAccess(query.alter ? AccessType::ALTER_POLICY : AccessType::CREATE_POLICY); context.checkAccess(query.alter ? AccessType::ALTER_POLICY : AccessType::CREATE_POLICY);
std::optional<GenericRoleSet> roles_from_query;
if (query.roles)
roles_from_query = GenericRoleSet{*query.roles, access_control, context.getUserID()};
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone()); auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone());
updateRowPolicyFromQuery(*updated_policy, query); updateRowPolicyFromQuery(*updated_policy, query, roles_from_query);
return updated_policy; return updated_policy;
}; };
String full_name = query.name_parts.getFullName(context); String full_name = query.name_parts.getFullName(context);
@ -36,7 +40,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
else else
{ {
auto new_policy = std::make_shared<RowPolicy>(); auto new_policy = std::make_shared<RowPolicy>();
updateRowPolicyFromQuery(*new_policy, query); updateRowPolicyFromQuery(*new_policy, query, roles_from_query);
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_policy); access_control.tryInsert(new_policy);
@ -50,7 +54,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
} }
void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query) void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query, const std::optional<GenericRoleSet> & roles_from_query)
{ {
if (query.alter) if (query.alter)
{ {
@ -70,25 +74,7 @@ void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & polic
for (const auto & [index, condition] : query.conditions) for (const auto & [index, condition] : query.conditions)
policy.conditions[index] = condition ? serializeAST(*condition) : String{}; policy.conditions[index] = condition ? serializeAST(*condition) : String{};
if (query.roles) if (roles_from_query)
{ policy.roles = *roles_from_query;
const auto & query_roles = *query.roles;
/// We keep `roles` sorted.
policy.roles = query_roles.roles;
if (query_roles.current_user)
policy.roles.push_back(context.getClientInfo().current_user);
boost::range::sort(policy.roles);
policy.roles.erase(std::unique(policy.roles.begin(), policy.roles.end()), policy.roles.end());
policy.all_roles = query_roles.all_roles;
/// We keep `except_roles` sorted.
policy.except_roles = query_roles.except_roles;
if (query_roles.except_current_user)
policy.except_roles.push_back(context.getClientInfo().current_user);
boost::range::sort(policy.except_roles);
policy.except_roles.erase(std::unique(policy.except_roles.begin(), policy.except_roles.end()), policy.except_roles.end());
}
} }
} }

View File

@ -2,12 +2,14 @@
#include <Interpreters/IInterpreter.h> #include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h> #include <Parsers/IAST_fwd.h>
#include <optional>
namespace DB namespace DB
{ {
class ASTCreateRowPolicyQuery; class ASTCreateRowPolicyQuery;
struct RowPolicy; struct RowPolicy;
struct GenericRoleSet;
class InterpreterCreateRowPolicyQuery : public IInterpreter class InterpreterCreateRowPolicyQuery : public IInterpreter
@ -18,7 +20,7 @@ public:
BlockIO execute() override; BlockIO execute() override;
private: private:
void updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query); void updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query, const std::optional<GenericRoleSet> & roles_from_query);
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;

View File

@ -1,24 +1,47 @@
#include <Interpreters/InterpreterCreateUserQuery.h> #include <Interpreters/InterpreterCreateUserQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/InterpreterSetRoleQuery.h>
#include <Parsers/ASTCreateUserQuery.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/GenericRoleSet.h>
#include <Access/AccessRightsContext.h>
#include <boost/range/algorithm/copy.hpp>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int SET_NON_GRANTED_ROLE;
}
BlockIO InterpreterCreateUserQuery::execute() BlockIO InterpreterCreateUserQuery::execute()
{ {
const auto & query = query_ptr->as<const ASTCreateUserQuery &>(); const auto & query = query_ptr->as<const ASTCreateUserQuery &>();
auto & access_control = context.getAccessControlManager(); auto & access_control = context.getAccessControlManager();
context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER); context.checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER);
GenericRoleSet * default_roles_from_query = nullptr;
GenericRoleSet temp_role_set;
if (query.default_roles)
{
default_roles_from_query = &temp_role_set;
*default_roles_from_query = GenericRoleSet{*query.default_roles, access_control};
if (!query.alter && !default_roles_from_query->all)
{
for (const UUID & role : default_roles_from_query->getMatchingIDs())
context.getAccessRights()->checkAdminOption(role);
}
}
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone()); auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
updateUserFromQuery(*updated_user, query); updateUserFromQuery(*updated_user, query, default_roles_from_query);
return updated_user; return updated_user;
}; };
if (query.if_exists) if (query.if_exists)
@ -32,7 +55,7 @@ BlockIO InterpreterCreateUserQuery::execute()
else else
{ {
auto new_user = std::make_shared<User>(); auto new_user = std::make_shared<User>();
updateUserFromQuery(*new_user, query); updateUserFromQuery(*new_user, query, default_roles_from_query);
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_user); access_control.tryInsert(new_user);
@ -46,7 +69,7 @@ BlockIO InterpreterCreateUserQuery::execute()
} }
void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query) void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query, const GenericRoleSet * default_roles_from_query)
{ {
if (query.alter) if (query.alter)
{ {
@ -66,7 +89,16 @@ void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreat
if (query.add_hosts) if (query.add_hosts)
user.allowed_client_hosts.add(*query.add_hosts); user.allowed_client_hosts.add(*query.add_hosts);
if (default_roles_from_query)
{
if (!query.alter && !default_roles_from_query->all)
boost::range::copy(default_roles_from_query->getMatchingIDs(), std::inserter(user.granted_roles, user.granted_roles.end()));
InterpreterSetRoleQuery::updateUserSetDefaultRoles(user, *default_roles_from_query);
}
if (query.profile) if (query.profile)
user.profile = *query.profile; user.profile = *query.profile;
} }
} }

View File

@ -7,6 +7,7 @@
namespace DB namespace DB
{ {
class ASTCreateUserQuery; class ASTCreateUserQuery;
class GenericRoleSet;
struct User; struct User;
@ -18,7 +19,7 @@ public:
BlockIO execute() override; BlockIO execute() override;
private: private:
void updateUserFromQuery(User & quota, const ASTCreateUserQuery & query); void updateUserFromQuery(User & user, const ASTCreateUserQuery & query, const GenericRoleSet * default_roles_from_query);
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;

View File

@ -3,9 +3,10 @@
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/AccessFlags.h> #include <Access/AccessFlags.h>
#include <Access/User.h>
#include <Access/Role.h>
#include <Access/Quota.h> #include <Access/Quota.h>
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <Access/User.h>
#include <boost/range/algorithm/transform.hpp> #include <boost/range/algorithm/transform.hpp>
@ -29,6 +30,16 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
return {}; return {};
} }
case Kind::ROLE:
{
context.checkAccess(AccessType::DROP_ROLE);
if (query.if_exists)
access_control.tryRemove(access_control.find<Role>(query.names));
else
access_control.remove(access_control.getIDs<Role>(query.names));
return {};
}
case Kind::QUOTA: case Kind::QUOTA:
{ {
context.checkAccess(AccessType::DROP_QUOTA); context.checkAccess(AccessType::DROP_QUOTA);

View File

@ -87,7 +87,7 @@ BlockIO InterpreterDropQuery::executeToTable(
auto table_id = table->getStorageID(); auto table_id = table->getStorageID();
if (kind == ASTDropQuery::Kind::Detach) if (kind == ASTDropQuery::Kind::Detach)
{ {
context.checkAccess(table->isView() ? AccessType::DETACH_VIEW : AccessType::DETACH_TABLE, context.checkAccess(table->isView() ? AccessType::DROP_VIEW : AccessType::DROP_TABLE,
database_name, table_name); database_name, table_name);
table->shutdown(); table->shutdown();
/// If table was already dropped by anyone, an exception will be thrown /// If table was already dropped by anyone, an exception will be thrown
@ -187,7 +187,7 @@ BlockIO InterpreterDropQuery::executeToDictionary(
if (kind == ASTDropQuery::Kind::Detach) if (kind == ASTDropQuery::Kind::Detach)
{ {
/// Drop dictionary from memory, don't touch data and metadata /// Drop dictionary from memory, don't touch data and metadata
context.checkAccess(AccessType::DETACH_DICTIONARY, database_name, dictionary_name); context.checkAccess(AccessType::DROP_DICTIONARY, database_name, dictionary_name);
database->detachDictionary(dictionary_name, context); database->detachDictionary(dictionary_name, context);
} }
else if (kind == ASTDropQuery::Kind::Truncate) else if (kind == ASTDropQuery::Kind::Truncate)
@ -247,7 +247,7 @@ BlockIO InterpreterDropQuery::executeToDatabase(const String & database_name, AS
} }
else if (kind == ASTDropQuery::Kind::Detach) else if (kind == ASTDropQuery::Kind::Detach)
{ {
context.checkAccess(AccessType::DETACH_DATABASE, database_name); context.checkAccess(AccessType::DROP_DATABASE, database_name);
context.detachDatabase(database_name); context.detachDatabase(database_name);
database->shutdown(); database->shutdown();
} }
@ -324,14 +324,14 @@ AccessRightsElements InterpreterDropQuery::getRequiredAccessForDDLOnCluster() co
if (drop.table.empty()) if (drop.table.empty())
{ {
if (drop.kind == ASTDropQuery::Kind::Detach) if (drop.kind == ASTDropQuery::Kind::Detach)
required_access.emplace_back(AccessType::DETACH_DATABASE, drop.database); required_access.emplace_back(AccessType::DROP_DATABASE, drop.database);
else if (drop.kind == ASTDropQuery::Kind::Drop) else if (drop.kind == ASTDropQuery::Kind::Drop)
required_access.emplace_back(AccessType::DROP_DATABASE, drop.database); required_access.emplace_back(AccessType::DROP_DATABASE, drop.database);
} }
else if (drop.is_dictionary) else if (drop.is_dictionary)
{ {
if (drop.kind == ASTDropQuery::Kind::Detach) if (drop.kind == ASTDropQuery::Kind::Detach)
required_access.emplace_back(AccessType::DETACH_DICTIONARY, drop.database, drop.table); required_access.emplace_back(AccessType::DROP_DICTIONARY, drop.database, drop.table);
else if (drop.kind == ASTDropQuery::Kind::Drop) else if (drop.kind == ASTDropQuery::Kind::Drop)
required_access.emplace_back(AccessType::DROP_DICTIONARY, drop.database, drop.table); required_access.emplace_back(AccessType::DROP_DICTIONARY, drop.database, drop.table);
} }
@ -343,7 +343,7 @@ AccessRightsElements InterpreterDropQuery::getRequiredAccessForDDLOnCluster() co
else if (drop.kind == ASTDropQuery::Kind::Truncate) else if (drop.kind == ASTDropQuery::Kind::Truncate)
required_access.emplace_back(AccessType::TRUNCATE_TABLE | AccessType::TRUNCATE_VIEW, drop.database, drop.table); required_access.emplace_back(AccessType::TRUNCATE_TABLE | AccessType::TRUNCATE_VIEW, drop.database, drop.table);
else if (drop.kind == ASTDropQuery::Kind::Detach) else if (drop.kind == ASTDropQuery::Kind::Detach)
required_access.emplace_back(AccessType::DETACH_TABLE | AccessType::DETACH_VIEW, drop.database, drop.table); required_access.emplace_back(AccessType::DROP_TABLE | AccessType::DROP_VIEW, drop.database, drop.table);
} }
return required_access; return required_access;

View File

@ -2,6 +2,7 @@
#include <Parsers/ASTCheckQuery.h> #include <Parsers/ASTCheckQuery.h>
#include <Parsers/ASTCreateQuery.h> #include <Parsers/ASTCreateQuery.h>
#include <Parsers/ASTCreateUserQuery.h> #include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTCreateRoleQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h> #include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h> #include <Parsers/ASTDropAccessEntityQuery.h>
@ -13,6 +14,7 @@
#include <Parsers/ASTSelectQuery.h> #include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSelectWithUnionQuery.h> #include <Parsers/ASTSelectWithUnionQuery.h>
#include <Parsers/ASTSetQuery.h> #include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h> #include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowProcesslistQuery.h> #include <Parsers/ASTShowProcesslistQuery.h>
#include <Parsers/ASTShowGrantsQuery.h> #include <Parsers/ASTShowGrantsQuery.h>
@ -29,6 +31,7 @@
#include <Interpreters/InterpreterCheckQuery.h> #include <Interpreters/InterpreterCheckQuery.h>
#include <Interpreters/InterpreterCreateQuery.h> #include <Interpreters/InterpreterCreateQuery.h>
#include <Interpreters/InterpreterCreateUserQuery.h> #include <Interpreters/InterpreterCreateUserQuery.h>
#include <Interpreters/InterpreterCreateRoleQuery.h>
#include <Interpreters/InterpreterCreateQuotaQuery.h> #include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Interpreters/InterpreterCreateRowPolicyQuery.h> #include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Interpreters/InterpreterDescribeQuery.h> #include <Interpreters/InterpreterDescribeQuery.h>
@ -44,6 +47,7 @@
#include <Interpreters/InterpreterSelectQuery.h> #include <Interpreters/InterpreterSelectQuery.h>
#include <Interpreters/InterpreterSelectWithUnionQuery.h> #include <Interpreters/InterpreterSelectWithUnionQuery.h>
#include <Interpreters/InterpreterSetQuery.h> #include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/InterpreterSetRoleQuery.h>
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h> #include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowCreateQuery.h> #include <Interpreters/InterpreterShowCreateQuery.h>
#include <Interpreters/InterpreterShowProcesslistQuery.h> #include <Interpreters/InterpreterShowProcesslistQuery.h>
@ -126,6 +130,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
/// readonly is checked inside InterpreterSetQuery /// readonly is checked inside InterpreterSetQuery
return std::make_unique<InterpreterSetQuery>(query, context); return std::make_unique<InterpreterSetQuery>(query, context);
} }
else if (query->as<ASTSetRoleQuery>())
{
return std::make_unique<InterpreterSetRoleQuery>(query, context);
}
else if (query->as<ASTOptimizeQuery>()) else if (query->as<ASTOptimizeQuery>())
{ {
return std::make_unique<InterpreterOptimizeQuery>(query, context); return std::make_unique<InterpreterOptimizeQuery>(query, context);
@ -186,6 +194,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{ {
return std::make_unique<InterpreterCreateUserQuery>(query, context); return std::make_unique<InterpreterCreateUserQuery>(query, context);
} }
else if (query->as<ASTCreateRoleQuery>())
{
return std::make_unique<InterpreterCreateRoleQuery>(query, context);
}
else if (query->as<ASTCreateQuotaQuery>()) else if (query->as<ASTCreateQuotaQuery>())
{ {
return std::make_unique<InterpreterCreateQuotaQuery>(query, context); return std::make_unique<InterpreterCreateQuotaQuery>(query, context);

View File

@ -1,10 +1,12 @@
#include <Interpreters/InterpreterGrantQuery.h> #include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/AccessRightsContext.h> #include <Access/AccessRightsContext.h>
#include <Access/GenericRoleSet.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/Role.h>
#include <boost/range/algorithm/copy.hpp>
namespace DB namespace DB
@ -16,41 +18,93 @@ BlockIO InterpreterGrantQuery::execute()
context.getAccessRights()->checkGrantOption(query.access_rights_elements); context.getAccessRights()->checkGrantOption(query.access_rights_elements);
using Kind = ASTGrantQuery::Kind; using Kind = ASTGrantQuery::Kind;
std::vector<UUID> roles;
if (query.roles)
{
roles = GenericRoleSet{*query.roles, access_control}.getMatchingRoles(access_control);
for (const UUID & role : roles)
context.getAccessRights()->checkAdminOption(role);
}
if (query.to_roles->all_roles) std::vector<UUID> to_roles = GenericRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingUsersAndRoles(access_control);
throw Exception(
"Cannot " + String((query.kind == Kind::GRANT) ? "GRANT to" : "REVOKE from") + " ALL", ErrorCodes::NOT_IMPLEMENTED);
String current_database = context.getCurrentDatabase(); String current_database = context.getCurrentDatabase();
using Kind = ASTGrantQuery::Kind;
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone()); auto clone = entity->clone();
if (query.kind == Kind::GRANT) AccessRights * access = nullptr;
AccessRights * access_with_grant_option = nullptr;
boost::container::flat_set<UUID> * granted_roles = nullptr;
boost::container::flat_set<UUID> * granted_roles_with_admin_option = nullptr;
GenericRoleSet * default_roles = nullptr;
if (auto user = typeid_cast<std::shared_ptr<User>>(clone))
{ {
updated_user->access.grant(query.access_rights_elements, current_database); access = &user->access;
if (query.grant_option) access_with_grant_option = &user->access_with_grant_option;
updated_user->access_with_grant_option.grant(query.access_rights_elements, current_database); granted_roles = &user->granted_roles;
granted_roles_with_admin_option = &user->granted_roles_with_admin_option;
default_roles = &user->default_roles;
} }
else if (context.getSettingsRef().partial_revokes) else if (auto role = typeid_cast<std::shared_ptr<Role>>(clone))
{ {
updated_user->access_with_grant_option.partialRevoke(query.access_rights_elements, current_database); access = &role->access;
if (!query.grant_option) access_with_grant_option = &role->access_with_grant_option;
updated_user->access.partialRevoke(query.access_rights_elements, current_database); granted_roles = &role->granted_roles;
granted_roles_with_admin_option = &role->granted_roles_with_admin_option;
} }
else else
return entity;
if (!query.access_rights_elements.empty())
{ {
updated_user->access_with_grant_option.revoke(query.access_rights_elements, current_database); if (query.kind == Kind::GRANT)
if (!query.grant_option) {
updated_user->access.revoke(query.access_rights_elements, current_database); access->grant(query.access_rights_elements, current_database);
if (query.grant_option)
access_with_grant_option->grant(query.access_rights_elements, current_database);
}
else if (context.getSettingsRef().partial_revokes)
{
access_with_grant_option->partialRevoke(query.access_rights_elements, current_database);
if (!query.grant_option)
access->partialRevoke(query.access_rights_elements, current_database);
}
else
{
access_with_grant_option->revoke(query.access_rights_elements, current_database);
if (!query.grant_option)
access->revoke(query.access_rights_elements, current_database);
}
} }
return updated_user;
if (!roles.empty())
{
if (query.kind == Kind::GRANT)
{
boost::range::copy(roles, std::inserter(*granted_roles, granted_roles->end()));
if (query.admin_option)
boost::range::copy(roles, std::inserter(*granted_roles_with_admin_option, granted_roles_with_admin_option->end()));
}
else
{
for (const UUID & role : roles)
{
granted_roles_with_admin_option->erase(role);
if (!query.admin_option)
{
granted_roles->erase(role);
if (default_roles)
default_roles->ids.erase(role);
}
}
}
}
return clone;
}; };
std::vector<UUID> ids = access_control.getIDs<User>(query.to_roles->roles); access_control.update(to_roles, update_func);
if (query.to_roles->current_user)
ids.push_back(context.getUserID());
access_control.update(ids, update_func);
return {}; return {};
} }

View File

@ -373,6 +373,7 @@ InterpreterSelectQuery::InterpreterSelectQuery(
/// Fix source_header for filter actions. /// Fix source_header for filter actions.
auto row_policy_filter = context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER); auto row_policy_filter = context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER);
row_policy_filter = RowPolicyContext::combineConditionsUsingAnd(row_policy_filter, context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER));
if (row_policy_filter) if (row_policy_filter)
{ {
filter_info = std::make_shared<FilterInfo>(); filter_info = std::make_shared<FilterInfo>();
@ -516,7 +517,8 @@ Block InterpreterSelectQuery::getSampleBlockImpl(bool try_move_to_prewhere)
/// PREWHERE optimization. /// PREWHERE optimization.
/// Turn off, if the table filter (row-level security) is applied. /// Turn off, if the table filter (row-level security) is applied.
if (!context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)) if (!context->getRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER)
&& !context->getInitialRowPolicy()->getCondition(table_id.getDatabaseName(), table_id.getTableName(), RowPolicy::SELECT_FILTER))
{ {
auto optimize_prewhere = [&](auto & merge_tree) auto optimize_prewhere = [&](auto & merge_tree)
{ {

View File

@ -0,0 +1,95 @@
#include <Interpreters/InterpreterSetRoleQuery.h>
#include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Interpreters/Context.h>
#include <Access/GenericRoleSet.h>
#include <Access/AccessControlManager.h>
#include <Access/User.h>
namespace DB
{
namespace ErrorCodes
{
extern const int SET_NON_GRANTED_ROLE;
}
BlockIO InterpreterSetRoleQuery::execute()
{
const auto & query = query_ptr->as<const ASTSetRoleQuery &>();
if (query.kind == ASTSetRoleQuery::Kind::SET_DEFAULT_ROLE)
setDefaultRole(query);
else
setRole(query);
return {};
}
void InterpreterSetRoleQuery::setRole(const ASTSetRoleQuery & query)
{
auto & access_control = context.getAccessControlManager();
auto & session_context = context.getSessionContext();
auto user = session_context.getUser();
if (query.kind == ASTSetRoleQuery::Kind::SET_ROLE_DEFAULT)
{
session_context.setCurrentRolesDefault();
}
else
{
GenericRoleSet roles_from_query{*query.roles, access_control};
std::vector<UUID> new_current_roles;
if (roles_from_query.all)
{
for (const auto & id : user->granted_roles)
if (roles_from_query.match(id))
new_current_roles.push_back(id);
}
else
{
for (const auto & id : roles_from_query.getMatchingIDs())
{
if (!user->granted_roles.contains(id))
throw Exception("Role should be granted to set current", ErrorCodes::SET_NON_GRANTED_ROLE);
new_current_roles.push_back(id);
}
}
session_context.setCurrentRoles(new_current_roles);
}
}
void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query)
{
context.checkAccess(AccessType::CREATE_USER | AccessType::DROP_USER);
auto & access_control = context.getAccessControlManager();
std::vector<UUID> to_users = GenericRoleSet{*query.to_users, access_control, context.getUserID()}.getMatchingUsers(access_control);
GenericRoleSet roles_from_query{*query.roles, access_control};
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
updateUserSetDefaultRoles(*updated_user, roles_from_query);
return updated_user;
};
access_control.update(to_users, update_func);
}
void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query)
{
if (!roles_from_query.all)
{
for (const auto & id : roles_from_query.getMatchingIDs())
{
if (!user.granted_roles.contains(id))
throw Exception("Role should be granted to set default", ErrorCodes::SET_NON_GRANTED_ROLE);
}
}
user.default_roles = roles_from_query;
}
}

View File

@ -0,0 +1,30 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class ASTSetRoleQuery;
class GenericRoleSet;
struct User;
class InterpreterSetRoleQuery : public IInterpreter
{
public:
InterpreterSetRoleQuery(const ASTPtr & query_ptr_, Context & context_) : query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
static void updateUserSetDefaultRoles(User & user, const GenericRoleSet & roles_from_query);
private:
void setRole(const ASTSetRoleQuery & query);
void setDefaultRole(const ASTSetRoleQuery & query);
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -4,7 +4,7 @@
#include <Parsers/ASTCreateQuotaQuery.h> #include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h> #include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/ExpressionListParsers.h> #include <Parsers/ExpressionListParsers.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Parsers/parseQuery.h> #include <Parsers/parseQuery.h>
@ -74,17 +74,20 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateUserQuery(const ASTShowC
if (show_query.current_user) if (show_query.current_user)
user = context.getUser(); user = context.getUser();
else else
user = context.getAccessControlManager().getUser(show_query.name); user = context.getAccessControlManager().read<User>(show_query.name);
auto create_query = std::make_shared<ASTCreateUserQuery>(); auto create_query = std::make_shared<ASTCreateUserQuery>();
create_query->name = user->getName(); create_query->name = user->getName();
if (!user->allowed_client_hosts.containsAnyHost()) if (user->allowed_client_hosts != AllowedClientHosts::AnyHostTag{})
create_query->hosts = user->allowed_client_hosts; create_query->hosts = user->allowed_client_hosts;
if (!user->profile.empty()) if (!user->profile.empty())
create_query->profile = user->profile; create_query->profile = user->profile;
if (user->default_roles != GenericRoleSet::AllTag{})
create_query->default_roles = GenericRoleSet{user->default_roles}.toAST(context.getAccessControlManager());
return create_query; return create_query;
} }
@ -115,14 +118,8 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuotaQuery(const ASTShow
create_query->all_limits.push_back(create_query_limits); create_query->all_limits.push_back(create_query_limits);
} }
if (!quota->roles.empty() || quota->all_roles) if (!quota->roles.empty())
{ create_query->roles = quota->roles.toAST(access_control);
auto create_query_roles = std::make_shared<ASTRoleList>();
create_query_roles->roles = quota->roles;
create_query_roles->all_roles = quota->all_roles;
create_query_roles->except_roles = quota->except_roles;
create_query->roles = std::move(create_query_roles);
}
return create_query; return create_query;
} }
@ -149,14 +146,8 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateRowPolicyQuery(const AST
} }
} }
if (!policy->roles.empty() || policy->all_roles) if (!policy->roles.empty())
{ create_query->roles = policy->roles.toAST(access_control);
auto create_query_roles = std::make_shared<ASTRoleList>();
create_query_roles->roles = policy->roles;
create_query_roles->all_roles = policy->all_roles;
create_query_roles->except_roles = policy->except_roles;
create_query->roles = std::move(create_query_roles);
}
return create_query; return create_query;
} }

View File

@ -1,7 +1,7 @@
#include <Interpreters/InterpreterShowGrantsQuery.h> #include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Parsers/ASTShowGrantsQuery.h> #include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
@ -9,6 +9,7 @@
#include <DataTypes/DataTypeString.h> #include <DataTypes/DataTypeString.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/Role.h>
#include <boost/range/adaptor/map.hpp> #include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp> #include <boost/range/algorithm/copy.hpp>
@ -88,19 +89,44 @@ BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const
{ {
const auto & access_control = context.getAccessControlManager();
UserPtr user; UserPtr user;
RolePtr role;
if (show_query.current_user) if (show_query.current_user)
user = context.getUser(); user = context.getUser();
else else
user = context.getAccessControlManager().getUser(show_query.name); {
user = access_control.tryRead<User>(show_query.name);
if (!user)
role = access_control.read<Role>(show_query.name);
}
const AccessRights * access = nullptr;
const AccessRights * access_with_grant_option = nullptr;
const boost::container::flat_set<UUID> * granted_roles = nullptr;
const boost::container::flat_set<UUID> * granted_roles_with_admin_option = nullptr;
if (user)
{
access = &user->access;
access_with_grant_option = &user->access_with_grant_option;
granted_roles = &user->granted_roles;
granted_roles_with_admin_option = &user->granted_roles_with_admin_option;
}
else
{
access = &role->access;
access_with_grant_option = &role->access_with_grant_option;
granted_roles = &role->granted_roles;
granted_roles_with_admin_option = &role->granted_roles_with_admin_option;
}
ASTs res; ASTs res;
for (bool grant_option : {true, false}) for (bool grant_option : {true, false})
{ {
if (!grant_option && (user->access == user->access_with_grant_option)) if (!grant_option && (*access == *access_with_grant_option))
continue; continue;
const auto & access_rights = grant_option ? user->access_with_grant_option : user->access; const auto & access_rights = grant_option ? *access_with_grant_option : *access;
const auto grouped_elements = groupByTable(access_rights.getElements()); const auto grouped_elements = groupByTable(access_rights.getElements());
using Kind = ASTGrantQuery::Kind; using Kind = ASTGrantQuery::Kind;
@ -111,14 +137,33 @@ ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show
auto grant_query = std::make_shared<ASTGrantQuery>(); auto grant_query = std::make_shared<ASTGrantQuery>();
grant_query->kind = kind; grant_query->kind = kind;
grant_query->grant_option = grant_option; grant_query->grant_option = grant_option;
grant_query->to_roles = std::make_shared<ASTRoleList>(); grant_query->to_roles = std::make_shared<ASTGenericRoleSet>();
grant_query->to_roles->roles.push_back(user->getName()); grant_query->to_roles->names.push_back(show_query.name);
grant_query->access_rights_elements = elements; grant_query->access_rights_elements = elements;
res.push_back(std::move(grant_query)); res.push_back(std::move(grant_query));
} }
} }
} }
for (bool admin_option : {true, false})
{
if (!admin_option && (*granted_roles == *granted_roles_with_admin_option))
continue;
const auto & roles = admin_option ? *granted_roles_with_admin_option : *granted_roles;
if (roles.empty())
continue;
auto grant_query = std::make_shared<ASTGrantQuery>();
using Kind = ASTGrantQuery::Kind;
grant_query->kind = Kind::GRANT;
grant_query->admin_option = admin_option;
grant_query->to_roles = std::make_shared<ASTGenericRoleSet>();
grant_query->to_roles->names.push_back(show_query.name);
grant_query->roles = GenericRoleSet{roles}.toAST(access_control);
res.push_back(std::move(grant_query));
}
return res; return res;
} }
} }

View File

@ -218,7 +218,7 @@ void runOneTest(const TestDescriptor & test_descriptor)
try try
{ {
res = acl_manager.getUser(entry.user_name)->access.isGranted(DB::AccessType::ALL, entry.database_name); res = acl_manager.read<DB::User>(entry.user_name)->access.isGranted(DB::AccessType::ALL, entry.database_name);
} }
catch (const Poco::Exception &) catch (const Poco::Exception &)
{ {

View File

@ -1,5 +1,5 @@
#include <Parsers/ASTCreateQuotaQuery.h> #include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <Common/IntervalKind.h> #include <Common/IntervalKind.h>
#include <ext/range.h> #include <ext/range.h>
@ -94,7 +94,7 @@ namespace
} }
} }
void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings)
{ {
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : "");
roles.format(settings); roles.format(settings);
@ -137,6 +137,6 @@ void ASTCreateQuotaQuery::formatImpl(const FormatSettings & settings, FormatStat
formatAllLimits(all_limits, settings); formatAllLimits(all_limits, settings);
if (roles) if (roles)
formatRoles(*roles, settings); formatToRoles(*roles, settings);
} }
} }

View File

@ -6,7 +6,7 @@
namespace DB namespace DB
{ {
class ASTRoleList; class ASTGenericRoleSet;
/** CREATE QUOTA [IF NOT EXISTS | OR REPLACE] name /** CREATE QUOTA [IF NOT EXISTS | OR REPLACE] name
@ -53,7 +53,7 @@ public:
}; };
std::vector<Limits> all_limits; std::vector<Limits> all_limits;
std::shared_ptr<ASTRoleList> roles; std::shared_ptr<ASTGenericRoleSet> roles;
String getID(char) const override; String getID(char) const override;
ASTPtr clone() const override; ASTPtr clone() const override;

View File

@ -0,0 +1,46 @@
#include <Parsers/ASTCreateRoleQuery.h>
#include <Common/quoteString.h>
namespace DB
{
namespace
{
void formatRenameTo(const String & new_name, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " RENAME TO " << (settings.hilite ? IAST::hilite_none : "")
<< quoteString(new_name);
}
}
String ASTCreateRoleQuery::getID(char) const
{
return "CreateRoleQuery";
}
ASTPtr ASTCreateRoleQuery::clone() const
{
return std::make_shared<ASTCreateRoleQuery>(*this);
}
void ASTCreateRoleQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "") << (alter ? "ALTER ROLE" : "CREATE ROLE")
<< (settings.hilite ? hilite_none : "");
if (if_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF EXISTS" << (settings.hilite ? hilite_none : "");
else if (if_not_exists)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " IF NOT EXISTS" << (settings.hilite ? hilite_none : "");
else if (or_replace)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " OR REPLACE" << (settings.hilite ? hilite_none : "");
settings.ostr << " " << backQuoteIfNeed(name);
if (!new_name.empty())
formatRenameTo(new_name, settings);
}
}

View File

@ -0,0 +1,29 @@
#pragma once
#include <Parsers/IAST.h>
namespace DB
{
/** CREATE ROLE [IF NOT EXISTS | OR REPLACE] name
*
* ALTER ROLE [IF EXISTS] name
* [RENAME TO new_name]
*/
class ASTCreateRoleQuery : public IAST
{
public:
bool alter = false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
String name;
String new_name;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -1,5 +1,5 @@
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <boost/range/algorithm/transform.hpp> #include <boost/range/algorithm/transform.hpp>
@ -19,7 +19,7 @@ namespace
} }
void formatIsRestrictive(bool is_restrictive, const IAST::FormatSettings & settings) void formatAsRestrictiveOrPermissive(bool is_restrictive, const IAST::FormatSettings & settings)
{ {
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " AS " << (is_restrictive ? "RESTRICTIVE" : "PERMISSIVE") settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " AS " << (is_restrictive ? "RESTRICTIVE" : "PERMISSIVE")
<< (settings.hilite ? IAST::hilite_none : ""); << (settings.hilite ? IAST::hilite_none : "");
@ -112,7 +112,7 @@ namespace
} }
} }
void formatRoles(const ASTRoleList & roles, const IAST::FormatSettings & settings) void formatToRoles(const ASTGenericRoleSet & roles, const IAST::FormatSettings & settings)
{ {
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : ""); settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " TO " << (settings.hilite ? IAST::hilite_none : "");
roles.format(settings); roles.format(settings);
@ -154,11 +154,11 @@ void ASTCreateRowPolicyQuery::formatImpl(const FormatSettings & settings, Format
formatRenameTo(new_policy_name, settings); formatRenameTo(new_policy_name, settings);
if (is_restrictive) if (is_restrictive)
formatIsRestrictive(*is_restrictive, settings); formatAsRestrictiveOrPermissive(*is_restrictive, settings);
formatMultipleConditions(conditions, alter, settings); formatMultipleConditions(conditions, alter, settings);
if (roles) if (roles)
formatRoles(*roles, settings); formatToRoles(*roles, settings);
} }
} }

View File

@ -8,7 +8,7 @@
namespace DB namespace DB
{ {
class ASTRoleList; class ASTGenericRoleSet;
/** CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table /** CREATE [ROW] POLICY [IF NOT EXISTS | OR REPLACE] name ON [database.]table
* [AS {PERMISSIVE | RESTRICTIVE}] * [AS {PERMISSIVE | RESTRICTIVE}]
@ -41,7 +41,7 @@ public:
using ConditionIndex = RowPolicy::ConditionIndex; using ConditionIndex = RowPolicy::ConditionIndex;
std::vector<std::pair<ConditionIndex, ASTPtr>> conditions; std::vector<std::pair<ConditionIndex, ASTPtr>> conditions;
std::shared_ptr<ASTRoleList> roles; std::shared_ptr<ASTGenericRoleSet> roles;
String getID(char) const override; String getID(char) const override;
ASTPtr clone() const override; ASTPtr clone() const override;

View File

@ -1,4 +1,5 @@
#include <Parsers/ASTCreateUserQuery.h> #include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
@ -134,6 +135,13 @@ namespace
} }
void formatDefaultRoles(const ASTGenericRoleSet & default_roles, const IAST::FormatSettings & settings)
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " DEFAULT ROLE " << (settings.hilite ? IAST::hilite_none : "");
default_roles.format(settings);
}
void formatProfile(const String & profile_name, const IAST::FormatSettings & settings) void formatProfile(const String & profile_name, const IAST::FormatSettings & settings)
{ {
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " PROFILE " << (settings.hilite ? IAST::hilite_none : "") settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " PROFILE " << (settings.hilite ? IAST::hilite_none : "")
@ -181,6 +189,9 @@ void ASTCreateUserQuery::formatImpl(const FormatSettings & settings, FormatState
if (remove_hosts) if (remove_hosts)
formatHosts("REMOVE", *remove_hosts, settings); formatHosts("REMOVE", *remove_hosts, settings);
if (default_roles)
formatDefaultRoles(*default_roles, settings);
if (profile) if (profile)
formatProfile(*profile, settings); formatProfile(*profile, settings);
} }

View File

@ -7,15 +7,19 @@
namespace DB namespace DB
{ {
class ASTGenericRoleSet;
/** CREATE USER [IF NOT EXISTS | OR REPLACE] name /** CREATE USER [IF NOT EXISTS | OR REPLACE] name
* [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}] * [IDENTIFIED [WITH {NO_PASSWORD|PLAINTEXT_PASSWORD|SHA256_PASSWORD|SHA256_HASH|DOUBLE_SHA1_PASSWORD|DOUBLE_SHA1_HASH}] BY {'password'|'hash'}]
* [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] * [HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [DEFAULT ROLE role [,...]]
* [PROFILE 'profile_name'] * [PROFILE 'profile_name']
* *
* ALTER USER [IF EXISTS] name * ALTER USER [IF EXISTS] name
* [RENAME TO new_name] * [RENAME TO new_name]
* [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}] * [IDENTIFIED [WITH {PLAINTEXT_PASSWORD|SHA256_PASSWORD|DOUBLE_SHA1_PASSWORD}] BY {'password'|'hash'}]
* [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE] * [[ADD|REMOVE] HOST {LOCAL | NAME 'name' | NAME REGEXP 'name_regexp' | IP 'address' | LIKE 'pattern'} [,...] | ANY | NONE]
* [DEFAULT ROLE role [,...] | ALL | ALL EXCEPT role [,...] ]
* [PROFILE 'profile_name'] * [PROFILE 'profile_name']
*/ */
class ASTCreateUserQuery : public IAST class ASTCreateUserQuery : public IAST
@ -36,6 +40,8 @@ public:
std::optional<AllowedClientHosts> add_hosts; std::optional<AllowedClientHosts> add_hosts;
std::optional<AllowedClientHosts> remove_hosts; std::optional<AllowedClientHosts> remove_hosts;
std::shared_ptr<ASTGenericRoleSet> default_roles;
std::optional<String> profile; std::optional<String> profile;
String getID(char) const override; String getID(char) const override;

View File

@ -13,6 +13,7 @@ namespace
switch (kind) switch (kind)
{ {
case Kind::USER: return "USER"; case Kind::USER: return "USER";
case Kind::ROLE: return "ROLE";
case Kind::QUOTA: return "QUOTA"; case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "POLICY"; case Kind::ROW_POLICY: return "POLICY";
} }

View File

@ -7,9 +7,10 @@
namespace DB namespace DB
{ {
/** DROP QUOTA [IF EXISTS] name [,...] /** DROP USER [IF EXISTS] name [,...]
* DROP ROLE [IF EXISTS] name [,...]
* DROP QUOTA [IF EXISTS] name [,...]
* DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...]
* DROP USER [IF EXISTS] name [,...]
*/ */
class ASTDropAccessEntityQuery : public IAST class ASTDropAccessEntityQuery : public IAST
{ {
@ -17,6 +18,7 @@ public:
enum class Kind enum class Kind
{ {
USER, USER,
ROLE,
QUOTA, QUOTA,
ROW_POLICY, ROW_POLICY,
}; };

View File

@ -0,0 +1,59 @@
#include <Parsers/ASTGenericRoleSet.h>
#include <Common/quoteString.h>
namespace DB
{
void ASTGenericRoleSet::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
if (empty())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NONE" << (settings.hilite ? IAST::hilite_none : "");
return;
}
bool need_comma = false;
if (all)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ALL" << (settings.hilite ? IAST::hilite_none : "");
}
else
{
for (auto & role : names)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(role);
}
if (current_user)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : "");
}
}
if (except_current_user || !except_names.empty())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " EXCEPT " << (settings.hilite ? IAST::hilite_none : "");
need_comma = false;
for (auto & except_role : except_names)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(except_role);
}
if (except_current_user)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : "");
}
}
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/Quota.h>
namespace DB
{
/// Represents a set of users/roles like
/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...]
class ASTGenericRoleSet : public IAST
{
public:
Strings names;
bool current_user = false;
bool all = false;
Strings except_names;
bool except_current_user = false;
bool empty() const { return names.empty() && !current_user && !all; }
String getID(char) const override { return "GenericRoleSet"; }
ASTPtr clone() const override { return std::make_shared<ASTGenericRoleSet>(*this); }
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -1,5 +1,5 @@
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Common/quoteString.h> #include <Common/quoteString.h>
#include <boost/range/adaptor/map.hpp> #include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/sort.hpp> #include <boost/range/algorithm/sort.hpp>
@ -9,6 +9,11 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace namespace
{ {
using KeywordToColumnsMap = std::map<std::string_view /* keyword */, std::vector<std::string_view> /* columns */>; using KeywordToColumnsMap = std::map<std::string_view /* keyword */, std::vector<std::string_view> /* columns */>;
@ -71,6 +76,34 @@ namespace
} }
settings.ostr << ")"; settings.ostr << ")";
} }
void formatAccessRightsElements(const AccessRightsElements & elements, const IAST::FormatSettings & settings)
{
bool need_comma = false;
for (const auto & [database_and_table, keyword_to_columns] : prepareTableToAccessMap(elements))
{
for (const auto & [keyword, columns] : keyword_to_columns)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << keyword << (settings.hilite ? IAST::hilite_none : "");
formatColumnNames(columns, settings);
}
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " ON " << (settings.hilite ? IAST::hilite_none : "") << database_and_table;
}
}
void formatToRoles(const ASTGenericRoleSet & to_roles, ASTGrantQuery::Kind kind, const IAST::FormatSettings & settings)
{
using Kind = ASTGrantQuery::Kind;
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ")
<< (settings.hilite ? IAST::hilite_none : "");
to_roles.format(settings);
}
} }
@ -88,31 +121,33 @@ ASTPtr ASTGrantQuery::clone() const
void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const void ASTGrantQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{ {
settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE") settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << ((kind == Kind::GRANT) ? "GRANT" : "REVOKE")
<< (settings.hilite ? hilite_none : "") << " "; << (settings.hilite ? IAST::hilite_none : "") << " ";
if (grant_option && (kind == Kind::REVOKE)) if (kind == Kind::REVOKE)
settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : "");
bool need_comma = false;
for (const auto & [database_and_table, keyword_to_columns] : prepareTableToAccessMap(access_rights_elements))
{ {
for (const auto & [keyword, columns] : keyword_to_columns) if (grant_option)
{ settings.ostr << (settings.hilite ? hilite_keyword : "") << "GRANT OPTION FOR " << (settings.hilite ? hilite_none : "");
if (std::exchange(need_comma, true)) else if (admin_option)
settings.ostr << ", "; settings.ostr << (settings.hilite ? hilite_keyword : "") << "ADMIN OPTION FOR " << (settings.hilite ? hilite_none : "");
settings.ostr << (settings.hilite ? hilite_keyword : "") << keyword << (settings.hilite ? hilite_none : "");
formatColumnNames(columns, settings);
}
settings.ostr << (settings.hilite ? hilite_keyword : "") << " ON " << (settings.hilite ? hilite_none : "") << database_and_table;
} }
settings.ostr << (settings.hilite ? hilite_keyword : "") << ((kind == Kind::GRANT) ? " TO " : " FROM ") << (settings.hilite ? hilite_none : ""); if ((!!roles + !access_rights_elements.empty()) != 1)
to_roles->format(settings); throw Exception("Either roles or access rights elements should be set", ErrorCodes::LOGICAL_ERROR);
if (grant_option && (kind == Kind::GRANT)) if (roles)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : ""); roles->format(settings);
else
formatAccessRightsElements(access_rights_elements, settings);
formatToRoles(*to_roles, kind, settings);
if (kind == Kind::GRANT)
{
if (grant_option)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH GRANT OPTION" << (settings.hilite ? hilite_none : "");
else if (admin_option)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " WITH ADMIN OPTION" << (settings.hilite ? hilite_none : "");
}
} }
} }

View File

@ -6,11 +6,14 @@
namespace DB namespace DB
{ {
class ASTRoleList; class ASTGenericRoleSet;
/** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name /** GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION]
* REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} FROM {user_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | CURRENT_USER} [,...]
*
* GRANT role [,...] TO {user_name | role_name | CURRENT_USER} [,...] [WITH ADMIN OPTION]
* REVOKE [ADMIN OPTION FOR] role [,...] FROM {user_name | role_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...]
*/ */
class ASTGrantQuery : public IAST class ASTGrantQuery : public IAST
{ {
@ -22,8 +25,10 @@ public:
}; };
Kind kind = Kind::GRANT; Kind kind = Kind::GRANT;
AccessRightsElements access_rights_elements; AccessRightsElements access_rights_elements;
std::shared_ptr<ASTRoleList> to_roles; std::shared_ptr<ASTGenericRoleSet> roles;
std::shared_ptr<ASTGenericRoleSet> to_roles;
bool grant_option = false; bool grant_option = false;
bool admin_option = false;
String getID(char) const override; String getID(char) const override;
ASTPtr clone() const override; ASTPtr clone() const override;

View File

@ -1,56 +0,0 @@
#include <Parsers/ASTRoleList.h>
#include <Common/quoteString.h>
namespace DB
{
void ASTRoleList::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
if (empty())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "NONE" << (settings.hilite ? IAST::hilite_none : "");
return;
}
bool need_comma = false;
if (current_user)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : "");
}
for (auto & role : roles)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(role);
}
if (all_roles)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "ALL" << (settings.hilite ? IAST::hilite_none : "");
if (except_current_user || !except_roles.empty())
{
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << " EXCEPT " << (settings.hilite ? IAST::hilite_none : "");
need_comma = false;
if (except_current_user)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << (settings.hilite ? IAST::hilite_keyword : "") << "CURRENT_USER" << (settings.hilite ? IAST::hilite_none : "");
}
for (auto & except_role : except_roles)
{
if (std::exchange(need_comma, true))
settings.ostr << ", ";
settings.ostr << backQuoteIfNeed(except_role);
}
}
}
}
}

View File

@ -1,25 +0,0 @@
#pragma once
#include <Parsers/IAST.h>
#include <Access/Quota.h>
namespace DB
{
/// {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...]
class ASTRoleList : public IAST
{
public:
Strings roles;
bool current_user = false;
bool all_roles = false;
Strings except_roles;
bool except_current_user = false;
bool empty() const { return roles.empty() && !current_user && !all_roles; }
String getID(char) const override { return "RoleList"; }
ASTPtr clone() const override { return std::make_shared<ASTRoleList>(*this); }
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -0,0 +1,43 @@
#include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Common/quoteString.h>
namespace DB
{
String ASTSetRoleQuery::getID(char) const
{
return "SetRoleQuery";
}
ASTPtr ASTSetRoleQuery::clone() const
{
return std::make_shared<ASTSetRoleQuery>(*this);
}
void ASTSetRoleQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{
settings.ostr << (settings.hilite ? hilite_keyword : "");
switch (kind)
{
case Kind::SET_ROLE: settings.ostr << "SET ROLE"; break;
case Kind::SET_ROLE_DEFAULT: settings.ostr << "SET ROLE DEFAULT"; break;
case Kind::SET_DEFAULT_ROLE: settings.ostr << "SET DEFAULT ROLE"; break;
}
settings.ostr << (settings.hilite ? hilite_none : "");
if (kind == Kind::SET_ROLE_DEFAULT)
return;
settings.ostr << " ";
roles->format(settings);
if (kind == Kind::SET_ROLE)
return;
settings.ostr << (settings.hilite ? hilite_keyword : "") << " TO " << (settings.hilite ? hilite_none : "");
to_users->format(settings);
}
}

View File

@ -0,0 +1,31 @@
#pragma once
#include <Parsers/IAST.h>
namespace DB
{
class ASTGenericRoleSet;
/** SET ROLE {DEFAULT | NONE | role [,...] | ALL | ALL EXCEPT role [,...]}
* SET DEFAULT ROLE {NONE | role [,...] | ALL | ALL EXCEPT role [,...]} TO {user|CURRENT_USER} [,...]
*/
class ASTSetRoleQuery : public IAST
{
public:
enum class Kind
{
SET_ROLE,
SET_ROLE_DEFAULT,
SET_DEFAULT_ROLE,
};
Kind kind = Kind::SET_ROLE;
std::shared_ptr<ASTGenericRoleSet> roles;
std::shared_ptr<ASTGenericRoleSet> to_users;
String getID(char) const override;
ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};
}

View File

@ -46,7 +46,8 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett
<< (settings.hilite ? hilite_none : ""); << (settings.hilite ? hilite_none : "");
if ((kind == Kind::USER) && current_user) if ((kind == Kind::USER) && current_user)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT_USER" << (settings.hilite ? hilite_none : ""); {
}
else if ((kind == Kind::QUOTA) && current_quota) else if ((kind == Kind::QUOTA) && current_quota)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : ""); settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
else if (kind == Kind::ROW_POLICY) else if (kind == Kind::ROW_POLICY)

View File

@ -18,13 +18,11 @@ ASTPtr ASTShowGrantsQuery::clone() const
void ASTShowGrantsQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const void ASTShowGrantsQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{ {
settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW GRANTS FOR " settings.ostr << (settings.hilite ? hilite_keyword : "") << "SHOW GRANTS"
<< (settings.hilite ? hilite_none : ""); << (settings.hilite ? hilite_none : "");
if (current_user) if (!current_user)
settings.ostr << (settings.hilite ? hilite_keyword : "") << "CURRENT_USER" settings.ostr << (settings.hilite ? hilite_keyword : "") << " FOR " << (settings.hilite ? hilite_none : "")
<< (settings.hilite ? hilite_none : ""); << backQuoteIfNeed(name);
else
settings.ostr << backQuoteIfNeed(name);
} }
} }

View File

@ -3,10 +3,10 @@
#include <Parsers/CommonParsers.h> #include <Parsers/CommonParsers.h>
#include <Parsers/parseIntervalKind.h> #include <Parsers/parseIntervalKind.h>
#include <Parsers/parseIdentifierOrStringLiteral.h> #include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/ParserRoleList.h> #include <Parsers/ParserGenericRoleSet.h>
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ASTLiteral.h> #include <Parsers/ASTLiteral.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/algorithm/string/predicate.hpp> #include <boost/algorithm/string/predicate.hpp>
@ -25,13 +25,10 @@ namespace
using ResourceType = Quota::ResourceType; using ResourceType = Quota::ResourceType;
using ResourceAmount = Quota::ResourceAmount; using ResourceAmount = Quota::ResourceAmount;
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name, bool alter) bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (!new_name.empty() || !alter)
return false;
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false; return false;
@ -43,9 +40,6 @@ namespace
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (key_type)
return false;
if (!ParserKeyword{"KEYED BY"}.ignore(pos, expected)) if (!ParserKeyword{"KEYED BY"}.ignore(pos, expected))
return false; return false;
@ -123,7 +117,7 @@ namespace
}); });
} }
bool parseLimits(IParserBase::Pos & pos, Expected & expected, ASTCreateQuotaQuery::Limits & limits, bool alter) bool parseLimits(IParserBase::Pos & pos, Expected & expected, bool alter, ASTCreateQuotaQuery::Limits & limits)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
@ -173,15 +167,19 @@ namespace
}); });
} }
bool parseAllLimits(IParserBase::Pos & pos, Expected & expected, std::vector<ASTCreateQuotaQuery::Limits> & all_limits, bool alter) bool parseAllLimits(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector<ASTCreateQuotaQuery::Limits> & all_limits)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
size_t old_size = all_limits.size();
do do
{ {
ASTCreateQuotaQuery::Limits limits; ASTCreateQuotaQuery::Limits limits;
if (!parseLimits(pos, expected, limits, alter)) if (!parseLimits(pos, expected, alter, limits))
{
all_limits.resize(old_size);
return false; return false;
}
all_limits.push_back(limits); all_limits.push_back(limits);
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
@ -189,15 +187,15 @@ namespace
}); });
} }
bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTRoleList> & roles) bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & roles)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
ASTPtr node; ASTPtr node;
if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, node, expected)) if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserGenericRoleSet{}.parse(pos, node, expected))
return false; return false;
roles = std::static_pointer_cast<ASTRoleList>(node); roles = std::static_pointer_cast<ASTGenericRoleSet>(node);
return true; return true;
}); });
} }
@ -237,11 +235,24 @@ bool ParserCreateQuotaQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expe
String new_name; String new_name;
std::optional<KeyType> key_type; std::optional<KeyType> key_type;
std::vector<ASTCreateQuotaQuery::Limits> all_limits; std::vector<ASTCreateQuotaQuery::Limits> all_limits;
std::shared_ptr<ASTRoleList> roles; std::shared_ptr<ASTGenericRoleSet> roles;
while (parseRenameTo(pos, expected, new_name, alter) || parseKeyType(pos, expected, key_type) while (true)
|| parseAllLimits(pos, expected, all_limits, alter) || parseRoles(pos, expected, roles)) {
; if (alter && new_name.empty() && parseRenameTo(pos, expected, new_name))
continue;
if (!key_type && parseKeyType(pos, expected, key_type))
continue;
if (parseAllLimits(pos, expected, alter, all_limits))
continue;
if (!roles && parseToRoles(pos, expected, roles))
continue;
break;
}
auto query = std::make_shared<ASTCreateQuotaQuery>(); auto query = std::make_shared<ASTCreateQuotaQuery>();
node = query; node = query;

View File

@ -0,0 +1,70 @@
#include <Parsers/ParserCreateRoleQuery.h>
#include <Parsers/ASTCreateRoleQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/parseUserName.h>
namespace DB
{
namespace
{
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false;
return parseRoleName(pos, expected, new_name);
});
}
}
bool ParserCreateRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
bool alter;
if (ParserKeyword{"CREATE ROLE"}.ignore(pos, expected))
alter = false;
else if (ParserKeyword{"ALTER ROLE"}.ignore(pos, expected))
alter = true;
else
return false;
bool if_exists = false;
bool if_not_exists = false;
bool or_replace = false;
if (alter)
{
if (ParserKeyword{"IF EXISTS"}.ignore(pos, expected))
if_exists = true;
}
else
{
if (ParserKeyword{"IF NOT EXISTS"}.ignore(pos, expected))
if_not_exists = true;
else if (ParserKeyword{"OR REPLACE"}.ignore(pos, expected))
or_replace = true;
}
String name;
if (!parseRoleName(pos, expected, name))
return false;
String new_name;
if (alter)
parseRenameTo(pos, expected, new_name);
auto query = std::make_shared<ASTCreateRoleQuery>();
node = query;
query->alter = alter;
query->if_exists = if_exists;
query->if_not_exists = if_not_exists;
query->or_replace = or_replace;
query->name = std::move(name);
query->new_name = std::move(new_name);
return true;
}
}

View File

@ -0,0 +1,20 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* CREATE ROLE [IF NOT EXISTS | OR REPLACE] name
*
* ALTER ROLE [IF EXISTS] name
* [RENAME TO new_name]
*/
class ParserCreateRoleQuery : public IParserBase
{
protected:
const char * getName() const override { return "CREATE ROLE or ALTER ROLE query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -1,8 +1,8 @@
#include <Parsers/ParserCreateRowPolicyQuery.h> #include <Parsers/ParserCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <Parsers/ParserRoleList.h> #include <Parsers/ParserGenericRoleSet.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/parseIdentifierOrStringLiteral.h> #include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/ExpressionListParsers.h> #include <Parsers/ExpressionListParsers.h>
@ -21,13 +21,10 @@ namespace
{ {
using ConditionIndex = RowPolicy::ConditionIndex; using ConditionIndex = RowPolicy::ConditionIndex;
bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name, bool alter) bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_policy_name)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (!new_policy_name.empty() || !alter)
return false;
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false; return false;
@ -35,46 +32,48 @@ namespace
}); });
} }
bool parseIsRestrictive(IParserBase::Pos & pos, Expected & expected, std::optional<bool> & is_restrictive) bool parseAsRestrictiveOrPermissive(IParserBase::Pos & pos, Expected & expected, std::optional<bool> & is_restrictive)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (is_restrictive)
return false;
if (!ParserKeyword{"AS"}.ignore(pos, expected)) if (!ParserKeyword{"AS"}.ignore(pos, expected))
return false; return false;
if (ParserKeyword{"RESTRICTIVE"}.ignore(pos, expected)) if (ParserKeyword{"RESTRICTIVE"}.ignore(pos, expected))
{
is_restrictive = true; is_restrictive = true;
else if (ParserKeyword{"PERMISSIVE"}.ignore(pos, expected)) return true;
is_restrictive = false; }
else
if (!ParserKeyword{"PERMISSIVE"}.ignore(pos, expected))
return false; return false;
is_restrictive = false;
return true; return true;
}); });
} }
bool parseConditionalExpression(IParserBase::Pos & pos, Expected & expected, std::optional<ASTPtr> & expr) bool parseConditionalExpression(IParserBase::Pos & pos, Expected & expected, std::optional<ASTPtr> & expr)
{ {
if (ParserKeyword("NONE").ignore(pos, expected)) return IParserBase::wrapParseImpl(pos, [&]
{
expr = nullptr;
return true;
}
ParserExpression parser;
ASTPtr x;
if (parser.parse(pos, x, expected))
{ {
if (ParserKeyword("NONE").ignore(pos, expected))
{
expr = nullptr;
return true;
}
ParserExpression parser;
ASTPtr x;
if (!parser.parse(pos, x, expected))
return false;
expr = x; expr = x;
return true; return true;
} });
expr.reset();
return false;
} }
bool parseConditions(IParserBase::Pos & pos, Expected & expected, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions, bool alter) bool parseConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
@ -171,29 +170,32 @@ namespace
}); });
} }
bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions, bool alter) bool parseMultipleConditions(IParserBase::Pos & pos, Expected & expected, bool alter, std::vector<std::pair<ConditionIndex, ASTPtr>> & conditions)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
std::vector<std::pair<ConditionIndex, ASTPtr>> res_conditions;
do do
{ {
if (!parseConditions(pos, expected, conditions, alter)) if (!parseConditions(pos, expected, alter, res_conditions))
return false; return false;
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
conditions = std::move(res_conditions);
return true; return true;
}); });
} }
bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTRoleList> & roles) bool parseToRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & roles)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
ASTPtr node; ASTPtr ast;
if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserRoleList{}.parse(pos, node, expected)) if (roles || !ParserKeyword{"TO"}.ignore(pos, expected) || !ParserGenericRoleSet{}.parse(pos, ast, expected))
return false; return false;
roles = std::static_pointer_cast<ASTRoleList>(node); roles = std::static_pointer_cast<ASTGenericRoleSet>(ast);
return true; return true;
}); });
} }
@ -237,11 +239,24 @@ bool ParserCreateRowPolicyQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
String new_policy_name; String new_policy_name;
std::optional<bool> is_restrictive; std::optional<bool> is_restrictive;
std::vector<std::pair<ConditionIndex, ASTPtr>> conditions; std::vector<std::pair<ConditionIndex, ASTPtr>> conditions;
std::shared_ptr<ASTRoleList> roles; std::shared_ptr<ASTGenericRoleSet> roles;
while (parseRenameTo(pos, expected, new_policy_name, alter) || parseIsRestrictive(pos, expected, is_restrictive) while (true)
|| parseMultipleConditions(pos, expected, conditions, alter) || parseRoles(pos, expected, roles)) {
; if (alter && new_policy_name.empty() && parseRenameTo(pos, expected, new_policy_name))
continue;
if (!is_restrictive && parseAsRestrictiveOrPermissive(pos, expected, is_restrictive))
continue;
if (parseMultipleConditions(pos, expected, alter, conditions))
continue;
if (!roles && parseToRoles(pos, expected, roles))
continue;
break;
}
auto query = std::make_shared<ASTCreateRowPolicyQuery>(); auto query = std::make_shared<ASTCreateRowPolicyQuery>();
node = query; node = query;

View File

@ -5,7 +5,8 @@
#include <Parsers/parseIdentifierOrStringLiteral.h> #include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ASTLiteral.h> #include <Parsers/ASTLiteral.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/ParserGenericRoleSet.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/algorithm/string/predicate.hpp> #include <boost/algorithm/string/predicate.hpp>
@ -24,9 +25,6 @@ namespace
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (!new_name.empty())
return false;
if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected))
return false; return false;
@ -35,14 +33,20 @@ namespace
} }
bool parsePassword(IParserBase::Pos & pos, Expected & expected, String & password) bool parseByPassword(IParserBase::Pos & pos, Expected & expected, String & password)
{ {
ASTPtr ast; return IParserBase::wrapParseImpl(pos, [&]
if (!ParserStringLiteral{}.parse(pos, ast, expected)) {
return false; if (!ParserKeyword{"BY"}.ignore(pos, expected))
return false;
password = ast->as<const ASTLiteral &>().value.safeGet<String>(); ASTPtr ast;
return true; if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false;
password = ast->as<const ASTLiteral &>().value.safeGet<String>();
return true;
});
} }
@ -50,70 +54,79 @@ namespace
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (authentication)
return false;
if (!ParserKeyword{"IDENTIFIED"}.ignore(pos, expected)) if (!ParserKeyword{"IDENTIFIED"}.ignore(pos, expected))
return false; return false;
if (ParserKeyword{"WITH"}.ignore(pos, expected)) if (!ParserKeyword{"WITH"}.ignore(pos, expected))
{
if (ParserKeyword{"NO_PASSWORD"}.ignore(pos, expected))
{
authentication = Authentication{Authentication::NO_PASSWORD};
}
else if (ParserKeyword{"PLAINTEXT_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::PLAINTEXT_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"SHA256_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"SHA256_HASH"}.ignore(pos, expected))
{
String hash;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPasswordHashHex(hash);
}
else if (ParserKeyword{"DOUBLE_SHA1_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPassword(password);
}
else if (ParserKeyword{"DOUBLE_SHA1_HASH"}.ignore(pos, expected))
{
String hash;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPasswordHashHex(hash);
}
else
return false;
}
else
{ {
String password; String password;
if (!ParserKeyword{"BY"}.ignore(pos, expected) || !parsePassword(pos, expected, password)) if (!parseByPassword(pos, expected, password))
return false; return false;
authentication = Authentication{Authentication::SHA256_PASSWORD}; authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPassword(password); authentication->setPassword(password);
return true;
} }
if (ParserKeyword{"PLAINTEXT_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!parseByPassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::PLAINTEXT_PASSWORD};
authentication->setPassword(password);
return true;
}
if (ParserKeyword{"SHA256_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!parseByPassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPassword(password);
return true;
}
if (ParserKeyword{"SHA256_HASH"}.ignore(pos, expected))
{
String hash;
if (!parseByPassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::SHA256_PASSWORD};
authentication->setPasswordHashHex(hash);
return true;
}
if (ParserKeyword{"DOUBLE_SHA1_PASSWORD"}.ignore(pos, expected))
{
String password;
if (!parseByPassword(pos, expected, password))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPassword(password);
return true;
}
if (ParserKeyword{"DOUBLE_SHA1_HASH"}.ignore(pos, expected))
{
String hash;
if (!parseByPassword(pos, expected, hash))
return false;
authentication = Authentication{Authentication::DOUBLE_SHA1_PASSWORD};
authentication->setPasswordHashHex(hash);
return true;
}
if (!ParserKeyword{"NO_PASSWORD"}.ignore(pos, expected))
return false;
authentication = Authentication{Authentication::NO_PASSWORD};
return true; return true;
}); });
} }
@ -144,13 +157,12 @@ namespace
return true; return true;
} }
AllowedClientHosts new_hosts;
do do
{ {
if (ParserKeyword{"LOCAL"}.ignore(pos, expected)) if (ParserKeyword{"LOCAL"}.ignore(pos, expected))
{ {
if (!hosts) new_hosts.addLocalHost();
hosts.emplace();
hosts->addLocalHost();
} }
else if (ParserKeyword{"NAME REGEXP"}.ignore(pos, expected)) else if (ParserKeyword{"NAME REGEXP"}.ignore(pos, expected))
{ {
@ -158,9 +170,7 @@ namespace
if (!ParserStringLiteral{}.parse(pos, ast, expected)) if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false; return false;
if (!hosts) new_hosts.addNameRegexp(ast->as<const ASTLiteral &>().value.safeGet<String>());
hosts.emplace();
hosts->addNameRegexp(ast->as<const ASTLiteral &>().value.safeGet<String>());
} }
else if (ParserKeyword{"NAME"}.ignore(pos, expected)) else if (ParserKeyword{"NAME"}.ignore(pos, expected))
{ {
@ -168,9 +178,7 @@ namespace
if (!ParserStringLiteral{}.parse(pos, ast, expected)) if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false; return false;
if (!hosts) new_hosts.addName(ast->as<const ASTLiteral &>().value.safeGet<String>());
hosts.emplace();
hosts->addName(ast->as<const ASTLiteral &>().value.safeGet<String>());
} }
else if (ParserKeyword{"IP"}.ignore(pos, expected)) else if (ParserKeyword{"IP"}.ignore(pos, expected))
{ {
@ -178,9 +186,7 @@ namespace
if (!ParserStringLiteral{}.parse(pos, ast, expected)) if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false; return false;
if (!hosts) new_hosts.addSubnet(ast->as<const ASTLiteral &>().value.safeGet<String>());
hosts.emplace();
hosts->addSubnet(ast->as<const ASTLiteral &>().value.safeGet<String>());
} }
else if (ParserKeyword{"LIKE"}.ignore(pos, expected)) else if (ParserKeyword{"LIKE"}.ignore(pos, expected))
{ {
@ -188,14 +194,33 @@ namespace
if (!ParserStringLiteral{}.parse(pos, ast, expected)) if (!ParserStringLiteral{}.parse(pos, ast, expected))
return false; return false;
if (!hosts) new_hosts.addLikePattern(ast->as<const ASTLiteral &>().value.safeGet<String>());
hosts.emplace();
hosts->addLikePattern(ast->as<const ASTLiteral &>().value.safeGet<String>());
} }
else else
return false; return false;
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!hosts)
hosts.emplace();
hosts->add(new_hosts);
return true;
});
}
bool parseDefaultRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & default_roles)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserKeyword{"DEFAULT ROLE"}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserGenericRoleSet{}.allowCurrentUser(false).parse(pos, ast, expected))
return false;
default_roles = typeid_cast<std::shared_ptr<ASTGenericRoleSet>>(ast);
return true; return true;
}); });
} }
@ -205,9 +230,6 @@ namespace
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
{ {
if (profile)
return false;
if (!ParserKeyword{"PROFILE"}.ignore(pos, expected)) if (!ParserKeyword{"PROFILE"}.ignore(pos, expected))
return false; return false;
@ -259,15 +281,34 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec
std::optional<AllowedClientHosts> hosts; std::optional<AllowedClientHosts> hosts;
std::optional<AllowedClientHosts> add_hosts; std::optional<AllowedClientHosts> add_hosts;
std::optional<AllowedClientHosts> remove_hosts; std::optional<AllowedClientHosts> remove_hosts;
std::shared_ptr<ASTGenericRoleSet> default_roles;
std::optional<String> profile; std::optional<String> profile;
while (parseAuthentication(pos, expected, authentication) while (true)
|| parseHosts(pos, expected, nullptr, hosts) {
|| parseProfileName(pos, expected, profile) if (!authentication && parseAuthentication(pos, expected, authentication))
|| (alter && parseRenameTo(pos, expected, new_name, new_host_pattern)) continue;
|| (alter && parseHosts(pos, expected, "ADD", add_hosts))
|| (alter && parseHosts(pos, expected, "REMOVE", remove_hosts))) if (parseHosts(pos, expected, nullptr, hosts))
; continue;
if (!profile && parseProfileName(pos, expected, profile))
continue;
if (!default_roles && parseDefaultRoles(pos, expected, default_roles))
continue;
if (alter)
{
if (new_name.empty() && parseRenameTo(pos, expected, new_name, new_host_pattern))
continue;
if (parseHosts(pos, expected, "ADD", add_hosts) || parseHosts(pos, expected, "REMOVE", remove_hosts))
continue;
}
break;
}
if (!hosts) if (!hosts)
{ {

View File

@ -13,47 +13,64 @@ namespace
{ {
bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names) bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names)
{ {
do return IParserBase::wrapParseImpl(pos, [&]
{ {
String name; Strings res_names;
if (!parseIdentifierOrStringLiteral(pos, expected, name)) do
return false; {
String name;
if (!parseIdentifierOrStringLiteral(pos, expected, name))
return false;
names.push_back(std::move(name)); res_names.push_back(std::move(name));
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true;
names = std::move(res_names);
return true;
});
} }
bool parseRowPolicyNames(IParserBase::Pos & pos, Expected & expected, std::vector<RowPolicy::FullNameParts> & row_policies_names) bool parseRowPolicyNames(IParserBase::Pos & pos, Expected & expected, std::vector<RowPolicy::FullNameParts> & names)
{ {
do return IParserBase::wrapParseImpl(pos, [&]
{ {
Strings policy_names; std::vector<RowPolicy::FullNameParts> res_names;
if (!parseNames(pos, expected, policy_names)) do
return false; {
String database, table_name; Strings policy_names;
if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name)) if (!parseNames(pos, expected, policy_names))
return false; return false;
for (const String & policy_name : policy_names) String database, table_name;
row_policies_names.push_back({database, table_name, policy_name}); if (!ParserKeyword{"ON"}.ignore(pos, expected) || !parseDatabaseAndTableName(pos, expected, database, table_name))
} return false;
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); for (const String & policy_name : policy_names)
return true; res_names.push_back({database, table_name, policy_name});
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
names = std::move(res_names);
return true;
});
} }
bool parseUserNames(IParserBase::Pos & pos, Expected & expected, Strings & names) bool parseUserNames(IParserBase::Pos & pos, Expected & expected, Strings & names)
{ {
do return IParserBase::wrapParseImpl(pos, [&]
{ {
String name; Strings res_names;
if (!parseUserName(pos, expected, name)) do
return false; {
String name;
if (!parseUserName(pos, expected, name))
return false;
names.push_back(std::move(name)); res_names.emplace_back(std::move(name));
} }
while (ParserToken{TokenType::Comma}.ignore(pos, expected)); while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return true; names = std::move(res_names);
return true;
});
} }
} }
@ -65,12 +82,14 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
using Kind = ASTDropAccessEntityQuery::Kind; using Kind = ASTDropAccessEntityQuery::Kind;
Kind kind; Kind kind;
if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) if (ParserKeyword{"USER"}.ignore(pos, expected))
kind = Kind::USER;
else if (ParserKeyword{"ROLE"}.ignore(pos, expected))
kind = Kind::ROLE;
else if (ParserKeyword{"QUOTA"}.ignore(pos, expected))
kind = Kind::QUOTA; kind = Kind::QUOTA;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected))
kind = Kind::ROW_POLICY; kind = Kind::ROW_POLICY;
else if (ParserKeyword{"USER"}.ignore(pos, expected))
kind = Kind::USER;
else else
return false; return false;
@ -81,7 +100,7 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
Strings names; Strings names;
std::vector<RowPolicy::FullNameParts> row_policies_names; std::vector<RowPolicy::FullNameParts> row_policies_names;
if (kind == Kind::USER) if ((kind == Kind::USER) || (kind == Kind::ROLE))
{ {
if (!parseUserNames(pos, expected, names)) if (!parseUserNames(pos, expected, names))
return false; return false;

View File

@ -6,9 +6,10 @@
namespace DB namespace DB
{ {
/** Parses queries like /** Parses queries like
* DROP USER [IF EXISTS] name [,...]
* DROP ROLE [IF EXISTS] name [,...]
* DROP QUOTA [IF EXISTS] name [,...] * DROP QUOTA [IF EXISTS] name [,...]
* DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...] * DROP [ROW] POLICY [IF EXISTS] name [,...] ON [database.]table [,...]
* DROP USER [IF EXISTS] name [,...]
*/ */
class ParserDropAccessEntityQuery : public IParserBase class ParserDropAccessEntityQuery : public IParserBase
{ {

View File

@ -0,0 +1,98 @@
#include <Parsers/ParserGenericRoleSet.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/parseUserName.h>
#include <boost/range/algorithm/find.hpp>
namespace DB
{
namespace
{
bool parseBeforeExcept(IParserBase::Pos & pos, Expected & expected, bool * all, bool * current_user, Strings & names)
{
return IParserBase::wrapParseImpl(pos, [&]
{
bool res_all = false;
bool res_current_user = false;
Strings res_names;
while (true)
{
if (ParserKeyword{"NONE"}.ignore(pos, expected))
{
}
else if (
current_user && (ParserKeyword{"CURRENT_USER"}.ignore(pos, expected) || ParserKeyword{"currentUser"}.ignore(pos, expected)))
{
if (ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
{
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
}
res_current_user = true;
}
else if (all && ParserKeyword{"ALL"}.ignore(pos, expected))
{
res_all = true;
}
else
{
String name;
if (!parseUserName(pos, expected, name))
return false;
res_names.push_back(name);
}
if (!ParserToken{TokenType::Comma}.ignore(pos, expected))
break;
}
if (all)
*all = res_all;
if (current_user)
*current_user = res_current_user;
names = std::move(res_names);
return true;
});
}
bool parseExcept(IParserBase::Pos & pos, Expected & expected, bool * except_current_user, Strings & except_names)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserKeyword{"EXCEPT"}.ignore(pos, expected))
return false;
return parseBeforeExcept(pos, expected, nullptr, except_current_user, except_names);
});
}
}
bool ParserGenericRoleSet::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
Strings names;
bool current_user = false;
bool all = false;
Strings except_names;
bool except_current_user = false;
if (!parseBeforeExcept(pos, expected, (allow_all ? &all : nullptr), (allow_current_user ? &current_user : nullptr), names))
return false;
parseExcept(pos, expected, (allow_current_user ? &except_current_user : nullptr), except_names);
if (all)
names.clear();
auto result = std::make_shared<ASTGenericRoleSet>();
result->names = std::move(names);
result->current_user = current_user;
result->all = all;
result->except_names = std::move(except_names);
result->except_current_user = except_current_user;
node = result;
return true;
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses a string like this:
* {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...]
*/
class ParserGenericRoleSet : public IParserBase
{
public:
ParserGenericRoleSet & allowAll(bool allow_) { allow_all = allow_; return *this; }
ParserGenericRoleSet & allowCurrentUser(bool allow_) { allow_current_user = allow_; return *this; }
protected:
const char * getName() const override { return "GenericRoleSet"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
private:
bool allow_all = true;
bool allow_current_user = true;
};
}

View File

@ -1,20 +1,34 @@
#include <Parsers/ParserGrantQuery.h> #include <Parsers/ParserGrantQuery.h>
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTIdentifier.h> #include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTRoleList.h> #include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/CommonParsers.h> #include <Parsers/CommonParsers.h>
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/ExpressionElementParsers.h>
#include <Parsers/ParserRoleList.h> #include <Parsers/ParserGenericRoleSet.h>
#include <boost/algorithm/string/predicate.hpp> #include <boost/algorithm/string/predicate.hpp>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int SYNTAX_ERROR;
}
namespace namespace
{ {
bool parseRoundBrackets(IParser::Pos & pos, Expected & expected)
{
return IParserBase::wrapParseImpl(pos, [&]
{
return ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)
&& ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected);
});
}
bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags) bool parseAccessFlags(IParser::Pos & pos, Expected & expected, AccessFlags & access_flags)
{ {
auto is_one_of_access_type_words = [](IParser::Pos & pos_) static constexpr auto is_one_of_access_type_words = [](IParser::Pos & pos_)
{ {
if (pos_->type != TokenType::BareWord) if (pos_->type != TokenType::BareWord)
return false; return false;
@ -24,86 +38,97 @@ namespace
return true; return true;
}; };
if (!is_one_of_access_type_words(pos)) expected.add(pos, "access type");
{
expected.add(pos, "access type");
return false;
}
String str; return IParserBase::wrapParseImpl(pos, [&]
do
{ {
if (!str.empty()) if (!is_one_of_access_type_words(pos))
str += " "; return false;
std::string_view word{pos->begin, pos->size()};
str += std::string_view(pos->begin, pos->size());
++pos;
}
while (is_one_of_access_type_words(pos));
if (pos->type == TokenType::OpeningRoundBracket) String str;
{ do
auto old_pos = pos;
++pos;
if (pos->type == TokenType::ClosingRoundBracket)
{ {
if (!str.empty())
str += " ";
std::string_view word{pos->begin, pos->size()};
str += std::string_view(pos->begin, pos->size());
++pos; ++pos;
str += "()";
} }
else while (is_one_of_access_type_words(pos));
pos = old_pos;
}
access_flags = AccessFlags{str}; try
return true; {
access_flags = AccessFlags{str};
}
catch (...)
{
return false;
}
parseRoundBrackets(pos, expected);
return true;
});
} }
bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns) bool parseColumnNames(IParser::Pos & pos, Expected & expected, Strings & columns)
{ {
if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected)) return IParserBase::wrapParseImpl(pos, [&]
return false;
do
{ {
ASTPtr column_ast; if (!ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
if (!ParserIdentifier().parse(pos, column_ast, expected))
return false; return false;
columns.push_back(getIdentifierName(column_ast));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
return ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected); Strings res_columns;
do
{
ASTPtr column_ast;
if (!ParserIdentifier().parse(pos, column_ast, expected))
return false;
res_columns.emplace_back(getIdentifierName(column_ast));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
columns = std::move(res_columns);
return true;
});
} }
bool parseDatabaseAndTableNameOrMaybeAsterisks( bool parseDatabaseAndTableNameOrMaybeAsterisks(
IParser::Pos & pos, Expected & expected, String & database_name, bool & any_database, String & table_name, bool & any_table) IParser::Pos & pos, Expected & expected, String & database_name, bool & any_database, String & table_name, bool & any_table)
{ {
ASTPtr ast[2]; return IParserBase::wrapParseImpl(pos, [&]
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{ {
if (ParserToken{TokenType::Dot}.ignore(pos, expected)) ASTPtr ast[2];
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
{ {
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected)) if (ParserToken{TokenType::Dot}.ignore(pos, expected))
return false; {
if (!ParserToken{TokenType::Asterisk}.ignore(pos, expected))
return false;
/// *.* (any table in any database)
any_database = true;
database_name.clear();
any_table = true;
table_name.clear();
return true;
}
/// *.* (any table in any database)
any_database = true;
any_table = true;
return true;
}
else
{
/// * (any table in the current database) /// * (any table in the current database)
any_database = false; any_database = false;
database_name = ""; database_name.clear();
any_table = true; any_table = true;
table_name.clear();
return true; return true;
} }
}
else if (ParserIdentifier().parse(pos, ast[0], expected)) if (!ParserIdentifier().parse(pos, ast[0], expected))
{ return false;
if (ParserToken{TokenType::Dot}.ignore(pos, expected)) if (ParserToken{TokenType::Dot}.ignore(pos, expected))
{ {
if (ParserToken{TokenType::Asterisk}.ignore(pos, expected)) if (ParserToken{TokenType::Asterisk}.ignore(pos, expected))
@ -112,31 +137,117 @@ namespace
any_database = false; any_database = false;
database_name = getIdentifierName(ast[0]); database_name = getIdentifierName(ast[0]);
any_table = true; any_table = true;
table_name.clear();
return true; return true;
} }
else if (ParserIdentifier().parse(pos, ast[1], expected))
if (!ParserIdentifier().parse(pos, ast[1], expected))
return false;
/// <database_name>.<table_name>
any_database = false;
database_name = getIdentifierName(ast[0]);
any_table = false;
table_name = getIdentifierName(ast[1]);
return true;
}
/// <table_name> - the current database, specified table
any_database = false;
database_name.clear();
any_table = false;
table_name = getIdentifierName(ast[0]);
return true;
});
}
bool parseAccessRightsElements(IParser::Pos & pos, Expected & expected, AccessRightsElements & elements)
{
return IParserBase::wrapParseImpl(pos, [&]
{
AccessRightsElements res_elements;
do
{
std::vector<std::pair<AccessFlags, Strings>> access_and_columns;
do
{ {
/// <database_name>.<table_name> AccessFlags access_flags;
any_database = false; if (!parseAccessFlags(pos, expected, access_flags))
database_name = getIdentifierName(ast[0]); return false;
any_table = false;
table_name = getIdentifierName(ast[1]); Strings columns;
return true; parseColumnNames(pos, expected, columns);
access_and_columns.emplace_back(access_flags, std::move(columns));
} }
else while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!ParserKeyword{"ON"}.ignore(pos, expected))
return false;
String database_name, table_name;
bool any_database = false, any_table = false;
if (!parseDatabaseAndTableNameOrMaybeAsterisks(pos, expected, database_name, any_database, table_name, any_table))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
AccessRightsElement element;
element.access_flags = access_flags;
element.any_column = columns.empty();
element.columns = std::move(columns);
element.any_database = any_database;
element.database = database_name;
element.any_table = any_table;
element.table = table_name;
res_elements.emplace_back(std::move(element));
}
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
elements = std::move(res_elements);
return true;
});
}
bool parseRoles(IParser::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & roles)
{
return IParserBase::wrapParseImpl(pos, [&]
{
ASTPtr ast;
if (!ParserGenericRoleSet{}.allowAll(false).allowCurrentUser(false).parse(pos, ast, expected))
return false;
roles = typeid_cast<std::shared_ptr<ASTGenericRoleSet>>(ast);
return true;
});
}
bool parseToRoles(IParser::Pos & pos, Expected & expected, ASTGrantQuery::Kind kind, std::shared_ptr<ASTGenericRoleSet> & to_roles)
{
return IParserBase::wrapParseImpl(pos, [&]
{
using Kind = ASTGrantQuery::Kind;
if (kind == Kind::GRANT)
{
if (!ParserKeyword{"TO"}.ignore(pos, expected))
return false; return false;
} }
else else
{ {
/// <table_name> - the current database, specified table if (!ParserKeyword{"FROM"}.ignore(pos, expected))
any_database = false; return false;
database_name = "";
table_name = getIdentifierName(ast[0]);
return true;
} }
}
else ASTPtr ast;
return false; if (!ParserGenericRoleSet{}.allowAll(kind == Kind::REVOKE).parse(pos, ast, expected))
return false;
to_roles = typeid_cast<std::shared_ptr<ASTGenericRoleSet>>(ast);
return true;
});
} }
} }
@ -153,79 +264,46 @@ bool ParserGrantQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
return false; return false;
bool grant_option = false; bool grant_option = false;
bool admin_option = false;
if (kind == Kind::REVOKE) if (kind == Kind::REVOKE)
{ {
if (ParserKeyword{"GRANT OPTION FOR"}.ignore(pos, expected)) if (ParserKeyword{"GRANT OPTION FOR"}.ignore(pos, expected))
grant_option = true; grant_option = true;
else if (ParserKeyword{"ADMIN OPTION FOR"}.ignore(pos, expected))
admin_option = true;
} }
AccessRightsElements elements; AccessRightsElements elements;
do std::shared_ptr<ASTGenericRoleSet> roles;
{ if (!parseAccessRightsElements(pos, expected, elements) && !parseRoles(pos, expected, roles))
std::vector<std::pair<AccessFlags, Strings>> access_and_columns; return false;
do
{
AccessFlags access_flags;
if (!parseAccessFlags(pos, expected, access_flags))
return false;
Strings columns; std::shared_ptr<ASTGenericRoleSet> to_roles;
parseColumnNames(pos, expected, columns); if (!parseToRoles(pos, expected, kind, to_roles))
access_and_columns.emplace_back(access_flags, std::move(columns));
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
if (!ParserKeyword{"ON"}.ignore(pos, expected))
return false;
String database_name, table_name;
bool any_database = false, any_table = false;
if (!parseDatabaseAndTableNameOrMaybeAsterisks(pos, expected, database_name, any_database, table_name, any_table))
return false;
for (auto & [access_flags, columns] : access_and_columns)
{
AccessRightsElement element;
element.access_flags = access_flags;
element.any_column = columns.empty();
element.columns = std::move(columns);
element.any_database = any_database;
element.database = database_name;
element.any_table = any_table;
element.table = table_name;
elements.emplace_back(std::move(element));
}
}
while (ParserToken{TokenType::Comma}.ignore(pos, expected));
ASTPtr to_roles;
if (kind == Kind::GRANT)
{
if (!ParserKeyword{"TO"}.ignore(pos, expected))
return false;
}
else
{
if (!ParserKeyword{"FROM"}.ignore(pos, expected))
return false;
}
if (!ParserRoleList{}.parse(pos, to_roles, expected))
return false; return false;
if (kind == Kind::GRANT) if (kind == Kind::GRANT)
{ {
if (ParserKeyword{"WITH GRANT OPTION"}.ignore(pos, expected)) if (ParserKeyword{"WITH GRANT OPTION"}.ignore(pos, expected))
grant_option = true; grant_option = true;
else if (ParserKeyword{"WITH ADMIN OPTION"}.ignore(pos, expected))
admin_option = true;
} }
if (grant_option && roles)
throw Exception("GRANT OPTION should be specified for access types", ErrorCodes::SYNTAX_ERROR);
if (admin_option && !elements.empty())
throw Exception("ADMIN OPTION should be specified for roles", ErrorCodes::SYNTAX_ERROR);
auto query = std::make_shared<ASTGrantQuery>(); auto query = std::make_shared<ASTGrantQuery>();
node = query; node = query;
query->kind = kind; query->kind = kind;
query->access_rights_elements = std::move(elements); query->access_rights_elements = std::move(elements);
query->to_roles = std::static_pointer_cast<ASTRoleList>(to_roles); query->roles = std::move(roles);
query->to_roles = std::move(to_roles);
query->grant_option = grant_option; query->grant_option = grant_option;
query->admin_option = admin_option;
return true; return true;
} }

View File

@ -6,8 +6,8 @@
namespace DB namespace DB
{ {
/** Parses queries like /** Parses queries like
* GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name * GRANT access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO {user_name | CURRENT_USER} [,...] [WITH GRANT OPTION]
* REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} TO user_name * REVOKE access_type[(column_name [,...])] [,...] ON {db.table|db.*|*.*|table|*} FROM {user_name | CURRENT_USER} [,...] | ALL | ALL EXCEPT {user_name | CURRENT_USER} [,...]
*/ */
class ParserGrantQuery : public IParserBase class ParserGrantQuery : public IParserBase
{ {

View File

@ -7,9 +7,11 @@
#include <Parsers/ParserOptimizeQuery.h> #include <Parsers/ParserOptimizeQuery.h>
#include <Parsers/ParserUseQuery.h> #include <Parsers/ParserUseQuery.h>
#include <Parsers/ParserSetQuery.h> #include <Parsers/ParserSetQuery.h>
#include <Parsers/ParserSetRoleQuery.h>
#include <Parsers/ParserAlterQuery.h> #include <Parsers/ParserAlterQuery.h>
#include <Parsers/ParserSystemQuery.h> #include <Parsers/ParserSystemQuery.h>
#include <Parsers/ParserCreateUserQuery.h> #include <Parsers/ParserCreateUserQuery.h>
#include <Parsers/ParserCreateRoleQuery.h>
#include <Parsers/ParserCreateQuotaQuery.h> #include <Parsers/ParserCreateQuotaQuery.h>
#include <Parsers/ParserCreateRowPolicyQuery.h> #include <Parsers/ParserCreateRowPolicyQuery.h>
#include <Parsers/ParserDropAccessEntityQuery.h> #include <Parsers/ParserDropAccessEntityQuery.h>
@ -28,17 +30,21 @@ bool ParserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
ParserSetQuery set_p; ParserSetQuery set_p;
ParserSystemQuery system_p; ParserSystemQuery system_p;
ParserCreateUserQuery create_user_p; ParserCreateUserQuery create_user_p;
ParserCreateRoleQuery create_role_p;
ParserCreateQuotaQuery create_quota_p; ParserCreateQuotaQuery create_quota_p;
ParserCreateRowPolicyQuery create_row_policy_p; ParserCreateRowPolicyQuery create_row_policy_p;
ParserDropAccessEntityQuery drop_access_entity_p; ParserDropAccessEntityQuery drop_access_entity_p;
ParserGrantQuery grant_p; ParserGrantQuery grant_p;
ParserSetRoleQuery set_role_p;
bool res = query_with_output_p.parse(pos, node, expected) bool res = query_with_output_p.parse(pos, node, expected)
|| insert_p.parse(pos, node, expected) || insert_p.parse(pos, node, expected)
|| use_p.parse(pos, node, expected) || use_p.parse(pos, node, expected)
|| set_role_p.parse(pos, node, expected)
|| set_p.parse(pos, node, expected) || set_p.parse(pos, node, expected)
|| system_p.parse(pos, node, expected) || system_p.parse(pos, node, expected)
|| create_user_p.parse(pos, node, expected) || create_user_p.parse(pos, node, expected)
|| create_role_p.parse(pos, node, expected)
|| create_quota_p.parse(pos, node, expected) || create_quota_p.parse(pos, node, expected)
|| create_row_policy_p.parse(pos, node, expected) || create_row_policy_p.parse(pos, node, expected)
|| drop_access_entity_p.parse(pos, node, expected) || drop_access_entity_p.parse(pos, node, expected)

View File

@ -1,78 +0,0 @@
#include <Parsers/ParserRoleList.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ASTRoleList.h>
#include <Parsers/parseUserName.h>
#include <boost/range/algorithm/find.hpp>
namespace DB
{
bool ParserRoleList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
Strings roles;
bool current_user = false;
bool all_roles = false;
Strings except_roles;
bool except_current_user = false;
bool except_mode = false;
while (true)
{
if (ParserKeyword{"NONE"}.ignore(pos, expected))
{
}
else if (ParserKeyword{"CURRENT_USER"}.ignore(pos, expected) ||
ParserKeyword{"currentUser"}.ignore(pos, expected))
{
if (ParserToken{TokenType::OpeningRoundBracket}.ignore(pos, expected))
{
if (!ParserToken{TokenType::ClosingRoundBracket}.ignore(pos, expected))
return false;
}
if (except_mode && !current_user)
except_current_user = true;
else
current_user = true;
}
else if (ParserKeyword{"ALL"}.ignore(pos, expected))
{
all_roles = true;
if (ParserKeyword{"EXCEPT"}.ignore(pos, expected))
{
except_mode = true;
continue;
}
}
else
{
String name;
if (!parseUserName(pos, expected, name))
return false;
if (except_mode && (boost::range::find(roles, name) == roles.end()))
except_roles.push_back(name);
else
roles.push_back(name);
}
if (!ParserToken{TokenType::Comma}.ignore(pos, expected))
break;
}
if (all_roles)
{
current_user = false;
roles.clear();
}
auto result = std::make_shared<ASTRoleList>();
result->roles = std::move(roles);
result->current_user = current_user;
result->all_roles = all_roles;
result->except_roles = std::move(except_roles);
result->except_current_user = except_current_user;
node = result;
return true;
}
}

View File

@ -1,18 +0,0 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses a string like this:
* {role|CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {role|CURRENT_USER} [,...]
*/
class ParserRoleList : public IParserBase
{
protected:
const char * getName() const override { return "RoleList"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -0,0 +1,80 @@
#include <Parsers/ParserSetRoleQuery.h>
#include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/CommonParsers.h>
#include <Parsers/ASTGenericRoleSet.h>
#include <Parsers/ParserGenericRoleSet.h>
namespace DB
{
namespace
{
bool parseRoles(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & roles)
{
return IParserBase::wrapParseImpl(pos, [&]
{
ASTPtr ast;
if (!ParserGenericRoleSet{}.allowCurrentUser(false).parse(pos, ast, expected))
return false;
roles = typeid_cast<std::shared_ptr<ASTGenericRoleSet>>(ast);
return true;
});
}
bool parseToUsers(IParserBase::Pos & pos, Expected & expected, std::shared_ptr<ASTGenericRoleSet> & to_users)
{
return IParserBase::wrapParseImpl(pos, [&]
{
if (!ParserKeyword{"TO"}.ignore(pos, expected))
return false;
ASTPtr ast;
if (!ParserGenericRoleSet{}.allowAll(false).parse(pos, ast, expected))
return false;
to_users = typeid_cast<std::shared_ptr<ASTGenericRoleSet>>(ast);
return true;
});
}
}
bool ParserSetRoleQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{
using Kind = ASTSetRoleQuery::Kind;
Kind kind;
if (ParserKeyword{"SET ROLE DEFAULT"}.ignore(pos, expected))
kind = Kind::SET_ROLE_DEFAULT;
else if (ParserKeyword{"SET ROLE"}.ignore(pos, expected))
kind = Kind::SET_ROLE;
else if (ParserKeyword{"SET DEFAULT ROLE"}.ignore(pos, expected))
kind = Kind::SET_DEFAULT_ROLE;
else
return false;
std::shared_ptr<ASTGenericRoleSet> roles;
std::shared_ptr<ASTGenericRoleSet> to_users;
if ((kind == Kind::SET_ROLE) || (kind == Kind::SET_DEFAULT_ROLE))
{
if (!parseRoles(pos, expected, roles))
return false;
if (kind == Kind::SET_DEFAULT_ROLE)
{
if (!parseToUsers(pos, expected, to_users))
return false;
}
}
auto query = std::make_shared<ASTSetRoleQuery>();
node = query;
query->kind = kind;
query->roles = std::move(roles);
query->to_users = std::move(to_users);
return true;
}
}

View File

@ -0,0 +1,18 @@
#pragma once
#include <Parsers/IParserBase.h>
namespace DB
{
/** Parses queries like
* SET ROLE {DEFAULT | NONE | role [,...] | ALL | ALL EXCEPT role [,...]}
* SET DEFAULT ROLE {NONE | role [,...] | ALL | ALL EXCEPT role [,...]} TO {user|CURRENT_USER} [,...]
*/
class ParserSetRoleQuery : public IParserBase
{
protected:
const char * getName() const override { return "SET ROLE or SET DEFAULT ROLE query"; }
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override;
};
}

View File

@ -68,6 +68,7 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
query->name = std::move(name); query->name = std::move(name);
query->current_quota = current_quota; query->current_quota = current_quota;
query->current_user = current_user;
query->row_policy_name = std::move(row_policy_name); query->row_policy_name = std::move(row_policy_name);
return true; return true;

View File

@ -312,7 +312,7 @@ void TreeExecutorBlockInputStream::setLimits(const IBlockInputStream::LocalLimit
source->setLimits(limits_); source->setLimits(limits_);
} }
void TreeExecutorBlockInputStream::setQuota(const std::shared_ptr<QuotaContext> & quota_) void TreeExecutorBlockInputStream::setQuota(const QuotaContextPtr & quota_)
{ {
for (auto & source : sources_with_progress) for (auto & source : sources_with_progress)
source->setQuota(quota_); source->setQuota(quota_);

View File

@ -42,7 +42,7 @@ public:
void setProgressCallback(const ProgressCallback & callback) final; void setProgressCallback(const ProgressCallback & callback) final;
void setProcessListElement(QueryStatus * elem) final; void setProcessListElement(QueryStatus * elem) final;
void setLimits(const LocalLimits & limits_) final; void setLimits(const LocalLimits & limits_) final;
void setQuota(const std::shared_ptr<QuotaContext> & quota_) final; void setQuota(const QuotaContextPtr & quota_) final;
void addTotalRowsApprox(size_t value) final; void addTotalRowsApprox(size_t value) final;
protected: protected:

Some files were not shown because too many files have changed in this diff Show More