From dd8b29b4fbd72f35b5bd62c42972583032432cc1 Mon Sep 17 00:00:00 2001 From: Vitaly Baranov Date: Sun, 3 May 2020 06:12:03 +0300 Subject: [PATCH] Use enum Type instead of std::type_index to represent the type of IAccessEntity. This change simplifies handling of access entities in access storages. --- src/Access/DiskAccessStorage.cpp | 259 +++++++++--------- src/Access/DiskAccessStorage.h | 25 +- src/Access/IAccessEntity.cpp | 38 +-- src/Access/IAccessEntity.h | 120 +++++++- src/Access/IAccessStorage.cpp | 104 ++++--- src/Access/IAccessStorage.h | 122 +++++---- src/Access/MemoryAccessStorage.cpp | 141 +++++----- src/Access/MemoryAccessStorage.h | 23 +- src/Access/MultipleAccessStorage.cpp | 20 +- src/Access/MultipleAccessStorage.h | 10 +- src/Access/Quota.h | 2 + src/Access/Role.cpp | 1 + src/Access/Role.h | 2 + src/Access/RowPolicy.h | 2 + src/Access/SettingsProfile.cpp | 2 + src/Access/SettingsProfile.h | 2 + src/Access/User.cpp | 1 + src/Access/User.h | 2 + src/Access/UsersConfigAccessStorage.cpp | 37 +-- src/Access/UsersConfigAccessStorage.h | 8 +- src/Common/ErrorCodes.cpp | 1 + .../InterpreterDropAccessEntityQuery.cpp | 62 ++--- .../InterpreterDropAccessEntityQuery.h | 4 + ...InterpreterShowCreateAccessEntityQuery.cpp | 57 ++-- .../InterpreterShowGrantsQuery.cpp | 2 +- src/Parsers/ASTDropAccessEntityQuery.cpp | 30 +- src/Parsers/ASTDropAccessEntityQuery.h | 12 +- .../ASTShowCreateAccessEntityQuery.cpp | 30 +- src/Parsers/ASTShowCreateAccessEntityQuery.h | 12 +- src/Parsers/ParserDropAccessEntityQuery.cpp | 35 +-- .../ParserShowCreateAccessEntityQuery.cpp | 42 +-- .../test_access_control_on_cluster/test.py | 6 +- .../test_disk_access_storage/test.py | 6 +- 33 files changed, 626 insertions(+), 594 deletions(-) diff --git a/src/Access/DiskAccessStorage.cpp b/src/Access/DiskAccessStorage.cpp index df9cbab891f..bc3e35a4fc7 100644 --- a/src/Access/DiskAccessStorage.cpp +++ b/src/Access/DiskAccessStorage.cpp @@ -53,6 +53,9 @@ namespace ErrorCodes namespace { + using EntityType = IAccessStorage::EntityType; + using EntityTypeInfo = IAccessStorage::EntityTypeInfo; + /// Special parser for the 'ATTACH access entity' queries. class ParserAttachAccessEntity : public IParserBase { @@ -79,7 +82,7 @@ namespace /// Reads a file containing ATTACH queries and then parses it to build an access entity. - AccessEntityPtr readAccessEntityFile(const std::filesystem::path & file_path) + AccessEntityPtr readEntityFile(const std::filesystem::path & file_path) { /// Read the file. ReadBufferFromFile in{file_path}; @@ -164,11 +167,11 @@ namespace } - AccessEntityPtr tryReadAccessEntityFile(const std::filesystem::path & file_path, Poco::Logger & log) + AccessEntityPtr tryReadEntityFile(const std::filesystem::path & file_path, Poco::Logger & log) { try { - return readAccessEntityFile(file_path); + return readEntityFile(file_path); } catch (...) { @@ -179,12 +182,12 @@ namespace /// Writes ATTACH queries for building a specified access entity to a file. - void writeAccessEntityFile(const std::filesystem::path & file_path, const IAccessEntity & entity) + void writeEntityFile(const std::filesystem::path & file_path, const IAccessEntity & entity) { /// Build list of ATTACH queries. ASTs queries; queries.push_back(InterpreterShowCreateAccessEntityQuery::getAttachQuery(entity)); - if (entity.getType() == typeid(User) || entity.getType() == typeid(Role)) + if ((entity.getType() == EntityType::USER) || (entity.getType() == EntityType::ROLE)) boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity)); /// Serialize the list of ATTACH queries to a string. @@ -213,21 +216,21 @@ namespace /// Calculates the path to a file named .sql for saving an access entity. - std::filesystem::path getAccessEntityFilePath(const String & directory_path, const UUID & id) + std::filesystem::path getEntityFilePath(const String & directory_path, const UUID & id) { return std::filesystem::path(directory_path).append(toString(id)).replace_extension(".sql"); } /// Reads a map of name of access entity to UUID for access entities of some type from a file. - std::unordered_map readListFile(const std::filesystem::path & file_path) + std::vector> readListFile(const std::filesystem::path & file_path) { ReadBufferFromFile in(file_path); size_t num; readVarUInt(num, in); - std::unordered_map res; - res.reserve(num); + std::vector> id_name_pairs; + id_name_pairs.reserve(num); for (size_t i = 0; i != num; ++i) { @@ -235,19 +238,19 @@ namespace readStringBinary(name, in); UUID id; readUUIDText(id, in); - res[name] = id; + id_name_pairs.emplace_back(id, std::move(name)); } - return res; + return id_name_pairs; } /// Writes a map of name of access entity to UUID for access entities of some type to a file. - void writeListFile(const std::filesystem::path & file_path, const std::unordered_map & map) + void writeListFile(const std::filesystem::path & file_path, const std::vector> & id_name_pairs) { WriteBufferFromFile out(file_path); - writeVarUInt(map.size(), out); - for (const auto & [name, id] : map) + writeVarUInt(id_name_pairs.size(), out); + for (const auto & [id, name] : id_name_pairs) { writeStringBinary(name, out); writeUUIDText(id, out); @@ -256,24 +259,10 @@ namespace /// Calculates the path for storing a map of name of access entity to UUID for access entities of some type. - std::filesystem::path getListFilePath(const String & directory_path, std::type_index type) + std::filesystem::path getListFilePath(const String & directory_path, EntityType type) { - std::string_view file_name; - if (type == typeid(User)) - file_name = "users"; - else if (type == typeid(Role)) - file_name = "roles"; - else if (type == typeid(Quota)) - file_name = "quotas"; - else if (type == typeid(RowPolicy)) - file_name = "row_policies"; - else if (type == typeid(SettingsProfile)) - file_name = "settings_profiles"; - else - throw Exception("Unexpected type of access entity: " + IAccessEntity::getTypeName(type), - ErrorCodes::LOGICAL_ERROR); - - return std::filesystem::path(directory_path).append(file_name).replace_extension(".list"); + std::string_view file_name = EntityTypeInfo::get(type).list_filename; + return std::filesystem::path(directory_path).append(file_name); } @@ -297,21 +286,12 @@ namespace return false; } } - - - const std::vector & getAllAccessEntityTypes() - { - static const std::vector res = {typeid(User), typeid(Role), typeid(RowPolicy), typeid(Quota), typeid(SettingsProfile)}; - return res; - } } DiskAccessStorage::DiskAccessStorage() : IAccessStorage("disk") { - for (auto type : getAllAccessEntityTypes()) - name_to_id_maps[type]; } @@ -363,18 +343,27 @@ void DiskAccessStorage::initialize(const String & directory_path_, Notifications writeLists(); } - for (const auto & [id, entry] : id_to_entry_map) + for (const auto & [id, entry] : entries_by_id) prepareNotifications(id, entry, false, notifications); } +void DiskAccessStorage::clear() +{ + entries_by_id.clear(); + for (auto type : ext::range(EntityType::MAX)) + entries_by_name_and_type[static_cast(type)].clear(); +} + + bool DiskAccessStorage::readLists() { - assert(id_to_entry_map.empty()); + clear(); + bool ok = true; - for (auto type : getAllAccessEntityTypes()) + for (auto type : ext::range(EntityType::MAX)) { - auto & name_to_id_map = name_to_id_maps.at(type); + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; auto file_path = getListFilePath(directory_path, type); if (!std::filesystem::exists(file_path)) { @@ -385,7 +374,14 @@ bool DiskAccessStorage::readLists() try { - name_to_id_map = readListFile(file_path); + for (const auto & [id, name] : readListFile(file_path)) + { + auto & entry = entries_by_id[id]; + entry.id = id; + entry.type = type; + entry.name = name; + entries_by_name[entry.name] = &entry; + } } catch (...) { @@ -393,17 +389,10 @@ bool DiskAccessStorage::readLists() ok = false; break; } - - for (const auto & [name, id] : name_to_id_map) - id_to_entry_map.emplace(id, Entry{name, type}); } if (!ok) - { - id_to_entry_map.clear(); - for (auto & name_to_id_map : name_to_id_maps | boost::adaptors::map_values) - name_to_id_map.clear(); - } + clear(); return ok; } @@ -419,11 +408,15 @@ bool DiskAccessStorage::writeLists() for (const auto & type : types_of_lists_to_write) { - const auto & name_to_id_map = name_to_id_maps.at(type); + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; auto file_path = getListFilePath(directory_path, type); try { - writeListFile(file_path, name_to_id_map); + std::vector> id_name_pairs; + id_name_pairs.reserve(entries_by_name.size()); + for (const auto * entry : entries_by_name | boost::adaptors::map_values) + id_name_pairs.emplace_back(entry->id, entry->name); + writeListFile(file_path, id_name_pairs); } catch (...) { @@ -441,7 +434,7 @@ bool DiskAccessStorage::writeLists() } -void DiskAccessStorage::scheduleWriteLists(std::type_index type) +void DiskAccessStorage::scheduleWriteLists(EntityType type) { if (failed_to_write_lists) return; @@ -504,7 +497,7 @@ void DiskAccessStorage::listsWritingThreadFunc() bool DiskAccessStorage::rebuildLists() { LOG_WARNING(getLogger(), "Recovering lists in directory " + directory_path); - assert(id_to_entry_map.empty()); + clear(); for (const auto & directory_entry : std::filesystem::directory_iterator(directory_path)) { @@ -518,58 +511,64 @@ bool DiskAccessStorage::rebuildLists() if (!tryParseUUID(path.stem(), id)) continue; - const auto access_entity_file_path = getAccessEntityFilePath(directory_path, id); - auto entity = tryReadAccessEntityFile(access_entity_file_path, *getLogger()); + const auto access_entity_file_path = getEntityFilePath(directory_path, id); + auto entity = tryReadEntityFile(access_entity_file_path, *getLogger()); if (!entity) continue; + const String & name = entity->getName(); auto type = entity->getType(); - auto & name_to_id_map = name_to_id_maps.at(type); - auto it_by_name = name_to_id_map.emplace(entity->getName(), id).first; - id_to_entry_map.emplace(id, Entry{it_by_name->first, type}); + auto & entry = entries_by_id[id]; + entry.id = id; + entry.type = type; + entry.name = name; + entry.entity = entity; + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + entries_by_name[entry.name] = &entry; } - for (auto type : getAllAccessEntityTypes()) + for (auto type : ext::range(EntityType::MAX)) types_of_lists_to_write.insert(type); return true; } -std::optional DiskAccessStorage::findImpl(std::type_index type, const String & name) const +std::optional DiskAccessStorage::findImpl(EntityType type, const String & name) const { std::lock_guard lock{mutex}; - const auto & name_to_id_map = name_to_id_maps.at(type); - auto it = name_to_id_map.find(name); - if (it == name_to_id_map.end()) + const auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + auto it = entries_by_name.find(name); + if (it == entries_by_name.end()) return {}; - return it->second; + return it->second->id; } -std::vector DiskAccessStorage::findAllImpl(std::type_index type) const +std::vector DiskAccessStorage::findAllImpl(EntityType type) const { std::lock_guard lock{mutex}; - const auto & name_to_id_map = name_to_id_maps.at(type); + const auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; std::vector res; - res.reserve(name_to_id_map.size()); - boost::range::copy(name_to_id_map | boost::adaptors::map_values, std::back_inserter(res)); + res.reserve(entries_by_name.size()); + for (const auto * entry : entries_by_name | boost::adaptors::map_values) + res.emplace_back(entry->id); return res; } bool DiskAccessStorage::existsImpl(const UUID & id) const { std::lock_guard lock{mutex}; - return id_to_entry_map.count(id); + return entries_by_id.count(id); } AccessEntityPtr DiskAccessStorage::readImpl(const UUID & id) const { std::lock_guard lock{mutex}; - auto it = id_to_entry_map.find(id); - if (it == id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); const auto & entry = it->second; @@ -582,8 +581,8 @@ AccessEntityPtr DiskAccessStorage::readImpl(const UUID & id) const String DiskAccessStorage::readNameImpl(const UUID & id) const { std::lock_guard lock{mutex}; - auto it = id_to_entry_map.find(id); - if (it == id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); return String{it->second.name}; } @@ -610,24 +609,24 @@ UUID DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool repl void DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications) { const String & name = new_entity->getName(); - std::type_index type = new_entity->getType(); + EntityType type = new_entity->getType(); if (!initialized) throw Exception( - "Cannot insert " + new_entity->getTypeName() + " " + backQuote(name) + " to " + getStorageName() - + " because the output directory is not set", + "Cannot insert " + new_entity->outputTypeAndName() + " to storage [" + getStorageName() + + "] because the output directory is not set", ErrorCodes::LOGICAL_ERROR); /// Check that we can insert. - auto it_by_id = id_to_entry_map.find(id); - if (it_by_id != id_to_entry_map.end()) + auto it_by_id = entries_by_id.find(id); + if (it_by_id != entries_by_id.end()) { const auto & existing_entry = it_by_id->second; throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName()); } - auto & name_to_id_map = name_to_id_maps.at(type); - auto it_by_name = name_to_id_map.find(name); - bool name_collision = (it_by_name != name_to_id_map.end()); + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + auto it_by_name = entries_by_name.find(name); + bool name_collision = (it_by_name != entries_by_name.end()); if (name_collision && !replace_if_exists) throwNameCollisionCannotInsert(type, name); @@ -636,13 +635,15 @@ void DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne writeAccessEntityToDisk(id, *new_entity); if (name_collision && replace_if_exists) - removeNoLock(it_by_name->second, notifications); + removeNoLock(it_by_name->second->id, notifications); /// Do insertion. - it_by_name = name_to_id_map.emplace(name, id).first; - it_by_id = id_to_entry_map.emplace(id, Entry{it_by_name->first, type}).first; - auto & entry = it_by_id->second; + auto & entry = entries_by_id[id]; + entry.id = id; + entry.type = type; + entry.name = name; entry.entity = new_entity; + entries_by_name[entry.name] = &entry; prepareNotifications(id, entry, false, notifications); } @@ -659,22 +660,21 @@ void DiskAccessStorage::removeImpl(const UUID & id) void DiskAccessStorage::removeNoLock(const UUID & id, Notifications & notifications) { - auto it = id_to_entry_map.find(id); - if (it == id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); Entry & entry = it->second; - String name{it->second.name}; - std::type_index type = it->second.type; + EntityType type = entry.type; scheduleWriteLists(type); deleteAccessEntityOnDisk(id); /// Do removing. prepareNotifications(id, entry, true, notifications); - id_to_entry_map.erase(it); - auto & name_to_id_map = name_to_id_maps.at(type); - name_to_id_map.erase(name); + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + entries_by_name.erase(entry.name); + entries_by_id.erase(it); } @@ -690,8 +690,8 @@ void DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_fu void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications) { - auto it = id_to_entry_map.find(id); - if (it == id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); Entry & entry = it->second; @@ -700,18 +700,22 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_ auto old_entity = entry.entity; auto new_entity = update_func(old_entity); + if (!new_entity->isTypeOf(old_entity->getType())) + throwBadCast(id, new_entity->getType(), new_entity->getName(), old_entity->getType()); + if (*new_entity == *old_entity) return; - String new_name = new_entity->getName(); - auto old_name = entry.name; - const std::type_index type = entry.type; + const String & new_name = new_entity->getName(); + const String & old_name = old_entity->getName(); + const EntityType type = entry.type; + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + bool name_changed = (new_name != old_name); if (name_changed) { - const auto & name_to_id_map = name_to_id_maps.at(type); - if (name_to_id_map.count(new_name)) - throwNameCollisionCannotRename(type, String{old_name}, new_name); + if (entries_by_name.count(new_name)) + throwNameCollisionCannotRename(type, old_name, new_name); scheduleWriteLists(type); } @@ -720,10 +724,9 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_ if (name_changed) { - auto & name_to_id_map = name_to_id_maps.at(type); - name_to_id_map.erase(String{old_name}); - auto it_by_name = name_to_id_map.emplace(new_name, id).first; - entry.name = it_by_name->first; + entries_by_name.erase(entry.name); + entry.name = new_name; + entries_by_name[entry.name] = &entry; } prepareNotifications(id, entry, false, notifications); @@ -732,19 +735,19 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_ AccessEntityPtr DiskAccessStorage::readAccessEntityFromDisk(const UUID & id) const { - return readAccessEntityFile(getAccessEntityFilePath(directory_path, id)); + return readEntityFile(getEntityFilePath(directory_path, id)); } void DiskAccessStorage::writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const { - writeAccessEntityFile(getAccessEntityFilePath(directory_path, id), entity); + writeEntityFile(getEntityFilePath(directory_path, id), entity); } void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const { - auto file_path = getAccessEntityFilePath(directory_path, id); + auto file_path = getEntityFilePath(directory_path, id); if (!std::filesystem::remove(file_path)) throw Exception("Couldn't delete " + file_path.string(), ErrorCodes::FILE_DOESNT_EXIST); } @@ -759,17 +762,16 @@ void DiskAccessStorage::prepareNotifications(const UUID & id, const Entry & entr for (const auto & handler : entry.handlers_by_id) notifications.push_back({handler, id, entity}); - auto range = handlers_by_type.equal_range(entry.type); - for (auto it = range.first; it != range.second; ++it) - notifications.push_back({it->second, id, entity}); + for (const auto & handler : handlers_by_type[static_cast(entry.type)]) + notifications.push_back({handler, id, entity}); } ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const { std::lock_guard lock{mutex}; - auto it = id_to_entry_map.find(id); - if (it == id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) return {}; const Entry & entry = it->second; auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler); @@ -777,8 +779,8 @@ ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, con return [this, id, handler_it] { std::lock_guard lock2{mutex}; - auto it2 = id_to_entry_map.find(id); - if (it2 != id_to_entry_map.end()) + auto it2 = entries_by_id.find(id); + if (it2 != entries_by_id.end()) { const Entry & entry2 = it2->second; entry2.handlers_by_id.erase(handler_it); @@ -786,23 +788,26 @@ ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, con }; } -ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const +ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const { std::lock_guard lock{mutex}; - auto handler_it = handlers_by_type.emplace(type, handler); + auto & handlers = handlers_by_type[static_cast(type)]; + handlers.push_back(handler); + auto handler_it = std::prev(handlers.end()); - return [this, handler_it] + return [this, type, handler_it] { std::lock_guard lock2{mutex}; - handlers_by_type.erase(handler_it); + auto & handlers2 = handlers_by_type[static_cast(type)]; + handlers2.erase(handler_it); }; } bool DiskAccessStorage::hasSubscriptionImpl(const UUID & id) const { std::lock_guard lock{mutex}; - auto it = id_to_entry_map.find(id); - if (it != id_to_entry_map.end()) + auto it = entries_by_id.find(id); + if (it != entries_by_id.end()) { const Entry & entry = it->second; return !entry.handlers_by_id.empty(); @@ -810,11 +815,11 @@ bool DiskAccessStorage::hasSubscriptionImpl(const UUID & id) const return false; } -bool DiskAccessStorage::hasSubscriptionImpl(std::type_index type) const +bool DiskAccessStorage::hasSubscriptionImpl(EntityType type) const { std::lock_guard lock{mutex}; - auto range = handlers_by_type.equal_range(type); - return range.first != range.second; + const auto & handlers = handlers_by_type[static_cast(type)]; + return !handlers.empty(); } } diff --git a/src/Access/DiskAccessStorage.h b/src/Access/DiskAccessStorage.h index 104c0f1fa38..79a11195318 100644 --- a/src/Access/DiskAccessStorage.h +++ b/src/Access/DiskAccessStorage.h @@ -17,8 +17,8 @@ public: void setDirectory(const String & directory_path_); private: - std::optional findImpl(std::type_index type, const String & name) const override; - std::vector findAllImpl(std::type_index type) const override; + std::optional findImpl(EntityType type, const String & name) const override; + std::vector findAllImpl(EntityType type) const override; bool existsImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override; @@ -27,14 +27,15 @@ private: void removeImpl(const UUID & id) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; - ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; + ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override; bool hasSubscriptionImpl(const UUID & id) const override; - bool hasSubscriptionImpl(std::type_index type) const override; + bool hasSubscriptionImpl(EntityType type) const override; void initialize(const String & directory_path_, Notifications & notifications); + void clear(); bool readLists(); bool writeLists(); - void scheduleWriteLists(std::type_index type); + void scheduleWriteLists(EntityType type); bool rebuildLists(); void startListsWritingThread(); @@ -52,9 +53,9 @@ private: using NameToIDMap = std::unordered_map; struct Entry { - Entry(const std::string_view & name_, std::type_index type_) : name(name_), type(type_) {} - std::string_view name; /// view points to a string in `name_to_id_maps`. - std::type_index type; + UUID id; + String name; + EntityType type; mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet. mutable std::list handlers_by_id; }; @@ -63,14 +64,14 @@ private: String directory_path; bool initialized = false; - std::unordered_map name_to_id_maps; - std::unordered_map id_to_entry_map; - boost::container::flat_set types_of_lists_to_write; + std::unordered_map entries_by_id; + std::unordered_map entries_by_name_and_type[static_cast(EntityType::MAX)]; + boost::container::flat_set types_of_lists_to_write; bool failed_to_write_lists = false; /// Whether writing of the list files has been failed since the recent restart of the server. ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread. std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit. std::atomic lists_writing_thread_exited = false; - mutable std::unordered_multimap handlers_by_type; + mutable std::list handlers_by_type[static_cast(EntityType::MAX)]; mutable std::mutex mutex; }; } diff --git a/src/Access/IAccessEntity.cpp b/src/Access/IAccessEntity.cpp index e5e484c4c33..5dc566fe456 100644 --- a/src/Access/IAccessEntity.cpp +++ b/src/Access/IAccessEntity.cpp @@ -1,48 +1,12 @@ #include -#include -#include -#include -#include -#include -#include namespace DB { -String IAccessEntity::getTypeName(std::type_index type) -{ - if (type == typeid(User)) - return "User"; - if (type == typeid(Quota)) - return "Quota"; - if (type == typeid(RowPolicy)) - return "Row policy"; - if (type == typeid(Role)) - return "Role"; - if (type == typeid(SettingsProfile)) - return "Settings profile"; - return demangle(type.name()); -} - - -const char * IAccessEntity::getKeyword(std::type_index type) -{ - if (type == typeid(User)) - return "USER"; - if (type == typeid(Quota)) - return "QUOTA"; - if (type == typeid(RowPolicy)) - return "ROW POLICY"; - if (type == typeid(Role)) - return "ROLE"; - if (type == typeid(SettingsProfile)) - return "SETTINGS PROFILE"; - __builtin_unreachable(); -} - bool IAccessEntity::equal(const IAccessEntity & other) const { return (name == other.name) && (getType() == other.getType()); } + } diff --git a/src/Access/IAccessEntity.h b/src/Access/IAccessEntity.h index 9a4e80fbde9..39a5cefa7d7 100644 --- a/src/Access/IAccessEntity.h +++ b/src/Access/IAccessEntity.h @@ -2,12 +2,24 @@ #include #include +#include +#include #include -#include namespace DB { +namespace ErrorCodes +{ + extern const int UNKNOWN_USER; + extern const int UNKNOWN_ROLE; + extern const int UNKNOWN_ROW_POLICY; + extern const int UNKNOWN_QUOTA; + extern const int THERE_IS_NO_PROFILE; + extern const int LOGICAL_ERROR; +} + + /// Access entity is a set of data which have a name and a type. Access entity control something related to the access control. /// Entities can be stored to a file or another storage, see IAccessStorage. struct IAccessEntity @@ -17,15 +29,39 @@ struct IAccessEntity virtual ~IAccessEntity() = default; virtual std::shared_ptr clone() const = 0; - std::type_index getType() const { return typeid(*this); } - static String getTypeName(std::type_index type); - const String getTypeName() const { return getTypeName(getType()); } - static const char * getKeyword(std::type_index type); - const char * getKeyword() const { return getKeyword(getType()); } + enum class Type + { + USER, + ROLE, + SETTINGS_PROFILE, + ROW_POLICY, + QUOTA, - template - bool isTypeOf() const { return isTypeOf(typeid(EntityType)); } - bool isTypeOf(std::type_index type) const { return type == getType(); } + MAX, + }; + + virtual Type getType() const = 0; + + struct TypeInfo + { + const char * const raw_name; + const String name; /// Uppercased with spaces instead of underscores, e.g. "SETTINGS PROFILE". + const String alias; /// Alias of the keyword or empty string, e.g. "PROFILE". + const String name_for_output_with_entity_name; /// Lowercased with spaces instead of underscores, e.g. "settings profile". + const char unique_char; /// Unique character for this type. E.g. 'P' for SETTINGS_PROFILE. + const String list_filename; /// Name of the file containing list of objects of this type, including the file extension ".list". + const int not_found_error_code; + + static const TypeInfo & get(Type type_); + String outputWithEntityName(const String & entity_name) const; + }; + + const TypeInfo & getTypeInfo() const { return TypeInfo::get(getType()); } + String outputTypeAndName() const { return getTypeInfo().outputWithEntityName(getName()); } + + template + bool isTypeOf() const { return isTypeOf(EntityClassT::TYPE); } + bool isTypeOf(Type type) const { return type == getType(); } virtual void setName(const String & name_) { name = name_; } const String & getName() const { return name; } @@ -39,12 +75,74 @@ protected: virtual bool equal(const IAccessEntity & other) const; /// Helper function to define clone() in the derived classes. - template + template std::shared_ptr cloneImpl() const { - return std::make_shared(typeid_cast(*this)); + return std::make_shared(typeid_cast(*this)); } }; using AccessEntityPtr = std::shared_ptr; + + +inline const IAccessEntity::TypeInfo & IAccessEntity::TypeInfo::get(Type type_) +{ + static constexpr auto make_info = [](const char * raw_name_, char unique_char_, const char * list_filename_, int not_found_error_code_) + { + String init_name = raw_name_; + boost::to_upper(init_name); + boost::replace_all(init_name, "_", " "); + String init_alias; + if (auto underscore_pos = init_name.find_first_of(" "); underscore_pos != String::npos) + init_alias = init_name.substr(underscore_pos + 1); + String init_name_for_output_with_entity_name = init_name; + boost::to_lower(init_name_for_output_with_entity_name); + return TypeInfo{raw_name_, std::move(init_name), std::move(init_alias), std::move(init_name_for_output_with_entity_name), unique_char_, list_filename_, not_found_error_code_}; + }; + + switch (type_) + { + case Type::USER: + { + static const auto info = make_info("USER", 'U', "users.list", ErrorCodes::UNKNOWN_USER); + return info; + } + case Type::ROLE: + { + static const auto info = make_info("ROLE", 'R', "roles.list", ErrorCodes::UNKNOWN_ROLE); + return info; + } + case Type::SETTINGS_PROFILE: + { + static const auto info = make_info("SETTINGS_PROFILE", 'S', "settings_profiles.list", ErrorCodes::THERE_IS_NO_PROFILE); + return info; + } + case Type::ROW_POLICY: + { + static const auto info = make_info("ROW_POLICY", 'P', "row_policies.list", ErrorCodes::UNKNOWN_ROW_POLICY); + return info; + } + case Type::QUOTA: + { + static const auto info = make_info("QUOTA", 'Q', "quotas.list", ErrorCodes::UNKNOWN_QUOTA); + return info; + } + case Type::MAX: break; + } + throw Exception("Unknown type: " + std::to_string(static_cast(type_)), ErrorCodes::LOGICAL_ERROR); +} + +inline String IAccessEntity::TypeInfo::outputWithEntityName(const String & entity_name) const +{ + String msg = name_for_output_with_entity_name; + msg += " "; + msg += backQuote(entity_name); + return msg; +} + +inline String toString(IAccessEntity::Type type) +{ + return IAccessEntity::TypeInfo::get(type).name; +} + } diff --git a/src/Access/IAccessStorage.cpp b/src/Access/IAccessStorage.cpp index 3dfc3e232ba..8e4314ec7c5 100644 --- a/src/Access/IAccessStorage.cpp +++ b/src/Access/IAccessStorage.cpp @@ -13,27 +13,44 @@ namespace DB namespace ErrorCodes { extern const int BAD_CAST; - extern const int ACCESS_ENTITY_NOT_FOUND; extern const int ACCESS_ENTITY_ALREADY_EXISTS; + extern const int ACCESS_ENTITY_NOT_FOUND; extern const int ACCESS_STORAGE_READONLY; - extern const int UNKNOWN_USER; - extern const int UNKNOWN_ROLE; } -std::vector IAccessStorage::findAll(std::type_index type) const +namespace +{ + using EntityType = IAccessStorage::EntityType; + using EntityTypeInfo = IAccessStorage::EntityTypeInfo; + + bool isNotFoundErrorCode(int error_code) + { + if (error_code == ErrorCodes::ACCESS_ENTITY_NOT_FOUND) + return true; + + for (auto type : ext::range(EntityType::MAX)) + if (error_code == EntityTypeInfo::get(type).not_found_error_code) + return true; + + return false; + } +} + + +std::vector IAccessStorage::findAll(EntityType type) const { return findAllImpl(type); } -std::optional IAccessStorage::find(std::type_index type, const String & name) const +std::optional IAccessStorage::find(EntityType type, const String & name) const { return findImpl(type, name); } -std::vector IAccessStorage::find(std::type_index type, const Strings & names) const +std::vector IAccessStorage::find(EntityType type, const Strings & names) const { std::vector ids; ids.reserve(names.size()); @@ -47,7 +64,7 @@ std::vector IAccessStorage::find(std::type_index type, const Strings & nam } -UUID IAccessStorage::getID(std::type_index type, const String & name) const +UUID IAccessStorage::getID(EntityType type, const String & name) const { auto id = findImpl(type, name); if (id) @@ -56,7 +73,7 @@ UUID IAccessStorage::getID(std::type_index type, const String & name) const } -std::vector IAccessStorage::getIDs(std::type_index type, const Strings & names) const +std::vector IAccessStorage::getIDs(EntityType type, const Strings & names) const { std::vector ids; ids.reserve(names.size()); @@ -190,6 +207,7 @@ void IAccessStorage::remove(const UUID & id) void IAccessStorage::remove(const std::vector & ids) { String error_message; + std::optional error_code; for (const auto & id : ids) { try @@ -198,13 +216,17 @@ void IAccessStorage::remove(const std::vector & ids) } catch (Exception & e) { - if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND) + if (!isNotFoundErrorCode(e.code())) throw; error_message += (error_message.empty() ? "" : ". ") + e.message(); + if (error_code && (*error_code != e.code())) + error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND; + else + error_code = e.code(); } } if (!error_message.empty()) - throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND); + throw Exception(error_message, *error_code); } @@ -250,6 +272,7 @@ void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func) void IAccessStorage::update(const std::vector & ids, const UpdateFunc & update_func) { String error_message; + std::optional error_code; for (const auto & id : ids) { try @@ -258,13 +281,17 @@ void IAccessStorage::update(const std::vector & ids, const UpdateFunc & up } catch (Exception & e) { - if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND) + if (!isNotFoundErrorCode(e.code())) throw; error_message += (error_message.empty() ? "" : ". ") + e.message(); + if (error_code && (*error_code != e.code())) + error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND; + else + error_code = e.code(); } } if (!error_message.empty()) - throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND); + throw Exception(error_message, *error_code); } @@ -301,7 +328,7 @@ std::vector IAccessStorage::tryUpdate(const std::vector & ids, const } -ext::scope_guard IAccessStorage::subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const +ext::scope_guard IAccessStorage::subscribeForChanges(EntityType type, const OnChangedHandler & handler) const { return subscribeForChangesImpl(type, handler); } @@ -322,7 +349,7 @@ ext::scope_guard IAccessStorage::subscribeForChanges(const std::vector & i } -bool IAccessStorage::hasSubscription(std::type_index type) const +bool IAccessStorage::hasSubscription(EntityType type) const { return hasSubscriptionImpl(type); } @@ -361,79 +388,72 @@ Poco::Logger * IAccessStorage::getLogger() const void IAccessStorage::throwNotFound(const UUID & id) const { - throw Exception("ID {" + toString(id) + "} not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND); + throw Exception("ID {" + toString(id) + "} not found in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_NOT_FOUND); } -void IAccessStorage::throwNotFound(std::type_index type, const String & name) const +void IAccessStorage::throwNotFound(EntityType type, const String & name) const { - int error_code; - if (type == typeid(User)) - error_code = ErrorCodes::UNKNOWN_USER; - else if (type == typeid(Role)) - error_code = ErrorCodes::UNKNOWN_ROLE; - else - error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND; - - throw Exception(getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), error_code); + int error_code = EntityTypeInfo::get(type).not_found_error_code; + throw Exception("There is no " + outputEntityTypeAndName(type, name) + " in [" + getStorageName() + "]", error_code); } -void IAccessStorage::throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) +void IAccessStorage::throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type) { throw Exception( - "ID {" + toString(id) + "}: " + getTypeName(type) + backQuote(name) + " expected to be of type " + getTypeName(required_type), + "ID {" + toString(id) + "}: " + outputEntityTypeAndName(type, name) + " expected to be of type " + toString(required_type), ErrorCodes::BAD_CAST); } -void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const +void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const { throw Exception( - getTypeName(type) + " " + backQuote(name) + ": cannot insert because the ID {" + toString(id) + "} is already used by " - + getTypeName(existing_type) + " " + backQuote(existing_name) + " in " + getStorageName(), + outputEntityTypeAndName(type, name) + ": cannot insert because the ID {" + toString(id) + "} is already used by " + + outputEntityTypeAndName(existing_type, existing_name) + " in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); } -void IAccessStorage::throwNameCollisionCannotInsert(std::type_index type, const String & name) const +void IAccessStorage::throwNameCollisionCannotInsert(EntityType type, const String & name) const { throw Exception( - getTypeName(type) + " " + backQuote(name) + ": cannot insert because " + getTypeName(type) + " " + backQuote(name) - + " already exists in " + getStorageName(), + outputEntityTypeAndName(type, name) + ": cannot insert because " + outputEntityTypeAndName(type, name) + " already exists in [" + + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); } -void IAccessStorage::throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const +void IAccessStorage::throwNameCollisionCannotRename(EntityType type, const String & old_name, const String & new_name) const { throw Exception( - getTypeName(type) + " " + backQuote(old_name) + ": cannot rename to " + backQuote(new_name) + " because " + getTypeName(type) + " " - + backQuote(new_name) + " already exists in " + getStorageName(), + outputEntityTypeAndName(type, old_name) + ": cannot rename to " + backQuote(new_name) + " because " + + outputEntityTypeAndName(type, new_name) + " already exists in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); } -void IAccessStorage::throwReadonlyCannotInsert(std::type_index type, const String & name) const +void IAccessStorage::throwReadonlyCannotInsert(EntityType type, const String & name) const { throw Exception( - "Cannot insert " + getTypeName(type) + " " + backQuote(name) + " to " + getStorageName() + " because this storage is readonly", + "Cannot insert " + outputEntityTypeAndName(type, name) + " to [" + getStorageName() + "] because this storage is readonly", ErrorCodes::ACCESS_STORAGE_READONLY); } -void IAccessStorage::throwReadonlyCannotUpdate(std::type_index type, const String & name) const +void IAccessStorage::throwReadonlyCannotUpdate(EntityType type, const String & name) const { throw Exception( - "Cannot update " + getTypeName(type) + " " + backQuote(name) + " in " + getStorageName() + " because this storage is readonly", + "Cannot update " + outputEntityTypeAndName(type, name) + " in [" + getStorageName() + "] because this storage is readonly", ErrorCodes::ACCESS_STORAGE_READONLY); } -void IAccessStorage::throwReadonlyCannotRemove(std::type_index type, const String & name) const +void IAccessStorage::throwReadonlyCannotRemove(EntityType type, const String & name) const { throw Exception( - "Cannot remove " + getTypeName(type) + " " + backQuote(name) + " from " + getStorageName() + " because this storage is readonly", + "Cannot remove " + outputEntityTypeAndName(type, name) + " from [" + getStorageName() + "] because this storage is readonly", ErrorCodes::ACCESS_STORAGE_READONLY); } } diff --git a/src/Access/IAccessStorage.h b/src/Access/IAccessStorage.h index d2413abc4a5..081fed87bd2 100644 --- a/src/Access/IAccessStorage.h +++ b/src/Access/IAccessStorage.h @@ -25,50 +25,53 @@ public: /// Returns the name of this storage. const String & getStorageName() const { return storage_name; } - /// Returns the identifiers of all the entities of a specified type contained in the storage. - std::vector findAll(std::type_index type) const; + using EntityType = IAccessEntity::Type; + using EntityTypeInfo = IAccessEntity::TypeInfo; - template - std::vector findAll() const { return findAll(typeid(EntityType)); } + /// Returns the identifiers of all the entities of a specified type contained in the storage. + std::vector findAll(EntityType type) const; + + template + std::vector findAll() const { return findAll(EntityClassT::TYPE); } /// Searchs for an entity with specified type and name. Returns std::nullopt if not found. - std::optional find(std::type_index type, const String & name) const; + std::optional find(EntityType type, const String & name) const; - template - std::optional find(const String & name) const { return find(typeid(EntityType), name); } + template + std::optional find(const String & name) const { return find(EntityClassT::TYPE, name); } - std::vector find(std::type_index type, const Strings & names) const; + std::vector find(EntityType type, const Strings & names) const; - template - std::vector find(const Strings & names) const { return find(typeid(EntityType), names); } + template + std::vector find(const Strings & names) const { return find(EntityClassT::TYPE, names); } /// Searchs for an entity with specified name and type. Throws an exception if not found. - UUID getID(std::type_index type, const String & name) const; + UUID getID(EntityType type, const String & name) const; - template - UUID getID(const String & name) const { return getID(typeid(EntityType), name); } + template + UUID getID(const String & name) const { return getID(EntityClassT::TYPE, name); } - std::vector getIDs(std::type_index type, const Strings & names) const; + std::vector getIDs(EntityType type, const Strings & names) const; - template - std::vector getIDs(const Strings & names) const { return getIDs(typeid(EntityType), names); } + template + std::vector getIDs(const Strings & names) const { return getIDs(EntityClassT::TYPE, names); } /// Returns whether there is an entity with such identifier in the storage. bool exists(const UUID & id) const; /// Reads an entity. Throws an exception if not found. - template - std::shared_ptr read(const UUID & id) const; + template + std::shared_ptr read(const UUID & id) const; - template - std::shared_ptr read(const String & name) const; + template + std::shared_ptr read(const String & name) const; /// Reads an entity. Returns nullptr if not found. - template - std::shared_ptr tryRead(const UUID & id) const; + template + std::shared_ptr tryRead(const UUID & id) const; - template - std::shared_ptr tryRead(const String & name) const; + template + std::shared_ptr tryRead(const String & name) const; /// Reads only name of an entity. String readName(const UUID & id) const; @@ -118,22 +121,22 @@ public: /// Subscribes for all changes. /// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only). - ext::scope_guard subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const; + ext::scope_guard subscribeForChanges(EntityType type, const OnChangedHandler & handler) const; - template - ext::scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(typeid(EntityType), handler); } + template + ext::scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); } /// Subscribes for changes of a specific entry. /// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only). ext::scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const; ext::scope_guard subscribeForChanges(const std::vector & ids, const OnChangedHandler & handler) const; - bool hasSubscription(std::type_index type) const; + bool hasSubscription(EntityType type) const; bool hasSubscription(const UUID & id) const; protected: - virtual std::optional findImpl(std::type_index type, const String & name) const = 0; - virtual std::vector findAllImpl(std::type_index type) const = 0; + virtual std::optional findImpl(EntityType type, const String & name) const = 0; + virtual std::vector findAllImpl(EntityType type) const = 0; virtual bool existsImpl(const UUID & id) const = 0; virtual AccessEntityPtr readImpl(const UUID & id) const = 0; virtual String readNameImpl(const UUID & id) const = 0; @@ -142,23 +145,23 @@ protected: virtual void removeImpl(const UUID & id) = 0; virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0; virtual ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0; - virtual ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const = 0; + virtual ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const = 0; virtual bool hasSubscriptionImpl(const UUID & id) const = 0; - virtual bool hasSubscriptionImpl(std::type_index type) const = 0; + virtual bool hasSubscriptionImpl(EntityType type) const = 0; static UUID generateRandomID(); Poco::Logger * getLogger() const; - static String getTypeName(std::type_index type) { return IAccessEntity::getTypeName(type); } + static String outputEntityTypeAndName(EntityType type, const String & name) { return EntityTypeInfo::get(type).outputWithEntityName(name); } [[noreturn]] void throwNotFound(const UUID & id) const; - [[noreturn]] void throwNotFound(std::type_index type, const String & name) const; - [[noreturn]] static void throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type); + [[noreturn]] void throwNotFound(EntityType type, const String & name) const; + [[noreturn]] static void throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type); [[noreturn]] void throwIDCollisionCannotInsert( - const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const; - [[noreturn]] void throwNameCollisionCannotInsert(std::type_index type, const String & name) const; - [[noreturn]] void throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const; - [[noreturn]] void throwReadonlyCannotInsert(std::type_index type, const String & name) const; - [[noreturn]] void throwReadonlyCannotUpdate(std::type_index type, const String & name) const; - [[noreturn]] void throwReadonlyCannotRemove(std::type_index type, const String & name) const; + const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const; + [[noreturn]] void throwNameCollisionCannotInsert(EntityType type, const String & name) const; + [[noreturn]] void throwNameCollisionCannotRename(EntityType type, const String & old_name, const String & new_name) const; + [[noreturn]] void throwReadonlyCannotInsert(EntityType type, const String & name) const; + [[noreturn]] void throwReadonlyCannotUpdate(EntityType type, const String & name) const; + [[noreturn]] void throwReadonlyCannotRemove(EntityType type, const String & name) const; using Notification = std::tuple; using Notifications = std::vector; @@ -172,38 +175,43 @@ private: }; -template -std::shared_ptr IAccessStorage::read(const UUID & id) const +template +std::shared_ptr IAccessStorage::read(const UUID & id) const { auto entity = readImpl(id); - auto ptr = typeid_cast>(entity); - if (ptr) - return ptr; - throwBadCast(id, entity->getType(), entity->getName(), typeid(EntityType)); + if constexpr (std::is_same_v) + return entity; + else + { + auto ptr = typeid_cast>(entity); + if (ptr) + return ptr; + throwBadCast(id, entity->getType(), entity->getName(), EntityClassT::TYPE); + } } -template -std::shared_ptr IAccessStorage::read(const String & name) const +template +std::shared_ptr IAccessStorage::read(const String & name) const { - return read(getID(name)); + return read(getID(name)); } -template -std::shared_ptr IAccessStorage::tryRead(const UUID & id) const +template +std::shared_ptr IAccessStorage::tryRead(const UUID & id) const { auto entity = tryReadBase(id); if (!entity) return nullptr; - return typeid_cast>(entity); + return typeid_cast>(entity); } -template -std::shared_ptr IAccessStorage::tryRead(const String & name) const +template +std::shared_ptr IAccessStorage::tryRead(const String & name) const { - auto id = find(name); - return id ? tryRead(*id) : nullptr; + auto id = find(name); + return id ? tryRead(*id) : nullptr; } } diff --git a/src/Access/MemoryAccessStorage.cpp b/src/Access/MemoryAccessStorage.cpp index 46efd25191b..720b82796b7 100644 --- a/src/Access/MemoryAccessStorage.cpp +++ b/src/Access/MemoryAccessStorage.cpp @@ -1,6 +1,8 @@ #include #include -#include +#include +#include +#include namespace DB @@ -11,11 +13,12 @@ MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_) } -std::optional MemoryAccessStorage::findImpl(std::type_index type, const String & name) const +std::optional MemoryAccessStorage::findImpl(EntityType type, const String & name) const { std::lock_guard lock{mutex}; - auto it = names.find({name, type}); - if (it == names.end()) + const auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + auto it = entries_by_name.find(name); + if (it == entries_by_name.end()) return {}; Entry & entry = *(it->second); @@ -23,12 +26,12 @@ std::optional MemoryAccessStorage::findImpl(std::type_index type, const St } -std::vector MemoryAccessStorage::findAllImpl(std::type_index type) const +std::vector MemoryAccessStorage::findAllImpl(EntityType type) const { std::lock_guard lock{mutex}; std::vector result; - result.reserve(entries.size()); - for (const auto & [id, entry] : entries) + result.reserve(entries_by_id.size()); + for (const auto & [id, entry] : entries_by_id) if (entry.entity->isTypeOf(type)) result.emplace_back(id); return result; @@ -38,15 +41,15 @@ std::vector MemoryAccessStorage::findAllImpl(std::type_index type) const bool MemoryAccessStorage::existsImpl(const UUID & id) const { std::lock_guard lock{mutex}; - return entries.count(id); + return entries_by_id.count(id); } AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const { std::lock_guard lock{mutex}; - auto it = entries.find(id); - if (it == entries.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); const Entry & entry = it->second; return entry.entity; @@ -74,18 +77,19 @@ UUID MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool re void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications) { const String & name = new_entity->getName(); - std::type_index type = new_entity->getType(); + EntityType type = new_entity->getType(); /// Check that we can insert. - auto it = entries.find(id); - if (it != entries.end()) + auto it = entries_by_id.find(id); + if (it != entries_by_id.end()) { const auto & existing_entry = it->second; throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName()); } - auto it2 = names.find({name, type}); - if (it2 != names.end()) + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + auto it2 = entries_by_name.find(name); + if (it2 != entries_by_name.end()) { const auto & existing_entry = *(it2->second); if (replace_if_exists) @@ -95,10 +99,10 @@ void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & } /// Do insertion. - auto & entry = entries[id]; + auto & entry = entries_by_id[id]; entry.id = id; entry.entity = new_entity; - names[std::pair{name, type}] = &entry; + entries_by_name[name] = &entry; prepareNotifications(entry, false, notifications); } @@ -115,19 +119,20 @@ void MemoryAccessStorage::removeImpl(const UUID & id) void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications) { - auto it = entries.find(id); - if (it == entries.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); Entry & entry = it->second; const String & name = entry.entity->getName(); - std::type_index type = entry.entity->getType(); + EntityType type = entry.entity->getType(); prepareNotifications(entry, true, notifications); /// Do removing. - names.erase({name, type}); - entries.erase(it); + auto & entries_by_name = entries_by_name_and_type[static_cast(type)]; + entries_by_name.erase(name); + entries_by_id.erase(it); } @@ -143,14 +148,17 @@ void MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_ void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications) { - auto it = entries.find(id); - if (it == entries.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) throwNotFound(id); Entry & entry = it->second; auto old_entity = entry.entity; auto new_entity = update_func(old_entity); + if (!new_entity->isTypeOf(old_entity->getType())) + throwBadCast(id, new_entity->getType(), new_entity->getName(), old_entity->getType()); + if (*new_entity == *old_entity) return; @@ -158,12 +166,12 @@ void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & updat if (new_entity->getName() != old_entity->getName()) { - auto it2 = names.find({new_entity->getName(), new_entity->getType()}); - if (it2 != names.end()) + auto & entries_by_name = entries_by_name_and_type[static_cast(old_entity->getType())]; + auto it2 = entries_by_name.find(new_entity->getName()); + if (it2 != entries_by_name.end()) throwNameCollisionCannotRename(old_entity->getType(), old_entity->getName(), new_entity->getName()); - names.erase({old_entity->getName(), old_entity->getType()}); - names[std::pair{new_entity->getName(), new_entity->getType()}] = &entry; + entries_by_name[new_entity->getName()] = &entry; } prepareNotifications(entry, false, notifications); @@ -192,43 +200,47 @@ void MemoryAccessStorage::setAll(const std::vector> & all_entities, Notifications & notifications) { - /// Get list of the currently used IDs. Later we will remove those of them which are not used anymore. - std::unordered_set not_used_ids; - for (const auto & id_and_entry : entries) - not_used_ids.emplace(id_and_entry.first); + boost::container::flat_set not_used_ids; + std::vector conflicting_ids; - /// Remove conflicting entities. + /// Get the list of currently used IDs. Later we will remove those of them which are not used anymore. + for (const auto & id : entries_by_id | boost::adaptors::map_keys) + not_used_ids.emplace(id); + + /// Get the list of conflicting IDs and update the list of currently used ones. for (const auto & [id, entity] : all_entities) { - auto it = entries.find(id); - if (it != entries.end()) + auto it = entries_by_id.find(id); + if (it != entries_by_id.end()) { not_used_ids.erase(id); /// ID is used. + Entry & entry = it->second; if (entry.entity->getType() != entity->getType()) - { - removeNoLock(id, notifications); - continue; - } + conflicting_ids.emplace_back(id); /// Conflict: same ID, different type. } - auto it2 = names.find({entity->getName(), entity->getType()}); - if (it2 != names.end()) + + const auto & entries_by_name = entries_by_name_and_type[static_cast(entity->getType())]; + auto it2 = entries_by_name.find(entity->getName()); + if (it2 != entries_by_name.end()) { Entry & entry = *(it2->second); if (entry.id != id) - removeNoLock(id, notifications); + conflicting_ids.emplace_back(entry.id); /// Conflict: same name and type, different ID. } } - /// Remove entities which are not used anymore. - for (const auto & id : not_used_ids) + /// Remove entities which are not used anymore and which are in conflict with new entities. + boost::container::flat_set ids_to_remove = std::move(not_used_ids); + boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end())); + for (const auto & id : ids_to_remove) removeNoLock(id, notifications); /// Insert or update entities. for (const auto & [id, entity] : all_entities) { - auto it = entries.find(id); - if (it != entries.end()) + auto it = entries_by_id.find(id); + if (it != entries_by_id.end()) { if (*(it->second.entity) != *entity) { @@ -244,24 +256,27 @@ void MemoryAccessStorage::setAllNoLock(const std::vectorgetType()); - for (auto it = range.first; it != range.second; ++it) - notifications.push_back({it->second, entry.id, remove ? nullptr : entry.entity}); + for (const auto & handler : handlers_by_type[static_cast(entry.entity->getType())]) + notifications.push_back({handler, entry.id, entity}); } -ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const +ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const { std::lock_guard lock{mutex}; - auto handler_it = handlers_by_type.emplace(type, handler); + auto & handlers = handlers_by_type[static_cast(type)]; + handlers.push_back(handler); + auto handler_it = std::prev(handlers.end()); - return [this, handler_it] + return [this, type, handler_it] { std::lock_guard lock2{mutex}; - handlers_by_type.erase(handler_it); + auto & handlers2 = handlers_by_type[static_cast(type)]; + handlers2.erase(handler_it); }; } @@ -269,8 +284,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(std::type_index ty ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const { std::lock_guard lock{mutex}; - auto it = entries.find(id); - if (it == entries.end()) + auto it = entries_by_id.find(id); + if (it == entries_by_id.end()) return {}; const Entry & entry = it->second; auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler); @@ -278,8 +293,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, c return [this, id, handler_it] { std::lock_guard lock2{mutex}; - auto it2 = entries.find(id); - if (it2 != entries.end()) + auto it2 = entries_by_id.find(id); + if (it2 != entries_by_id.end()) { const Entry & entry2 = it2->second; entry2.handlers_by_id.erase(handler_it); @@ -291,8 +306,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, c bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const { std::lock_guard lock{mutex}; - auto it = entries.find(id); - if (it != entries.end()) + auto it = entries_by_id.find(id); + if (it != entries_by_id.end()) { const Entry & entry = it->second; return !entry.handlers_by_id.empty(); @@ -301,10 +316,10 @@ bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const } -bool MemoryAccessStorage::hasSubscriptionImpl(std::type_index type) const +bool MemoryAccessStorage::hasSubscriptionImpl(EntityType type) const { std::lock_guard lock{mutex}; - auto range = handlers_by_type.equal_range(type); - return range.first != range.second; + const auto & handlers = handlers_by_type[static_cast(type)]; + return !handlers.empty(); } } diff --git a/src/Access/MemoryAccessStorage.h b/src/Access/MemoryAccessStorage.h index b93c2868d34..a2fdd0d0044 100644 --- a/src/Access/MemoryAccessStorage.h +++ b/src/Access/MemoryAccessStorage.h @@ -20,8 +20,8 @@ public: void setAll(const std::vector> & all_entities); private: - std::optional findImpl(std::type_index type, const String & name) const override; - std::vector findAllImpl(std::type_index type) const override; + std::optional findImpl(EntityType type, const String & name) const override; + std::vector findAllImpl(EntityType type) const override; bool existsImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override; @@ -30,9 +30,9 @@ private: void removeImpl(const UUID & id) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; - ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; + ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override; bool hasSubscriptionImpl(const UUID & id) const override; - bool hasSubscriptionImpl(std::type_index type) const override; + bool hasSubscriptionImpl(EntityType type) const override; struct Entry { @@ -47,18 +47,9 @@ private: void setAllNoLock(const std::vector> & all_entities, Notifications & notifications); void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const; - using NameTypePair = std::pair; - struct Hash - { - size_t operator()(const NameTypePair & key) const - { - return std::hash{}(key.first) - std::hash{}(key.second); - } - }; - mutable std::mutex mutex; - std::unordered_map entries; /// We want to search entries both by ID and by the pair of name and type. - std::unordered_map names; /// and by the pair of name and type. - mutable std::unordered_multimap handlers_by_type; + std::unordered_map entries_by_id; /// We want to search entries both by ID and by the pair of name and type. + std::unordered_map entries_by_name_and_type[static_cast(EntityType::MAX)]; + mutable std::list handlers_by_type[static_cast(EntityType::MAX)]; }; } diff --git a/src/Access/MultipleAccessStorage.cpp b/src/Access/MultipleAccessStorage.cpp index 740fe1dac04..0dd1f142f31 100644 --- a/src/Access/MultipleAccessStorage.cpp +++ b/src/Access/MultipleAccessStorage.cpp @@ -38,7 +38,7 @@ MultipleAccessStorage::MultipleAccessStorage( } -std::vector MultipleAccessStorage::findMultiple(std::type_index type, const String & name) const +std::vector MultipleAccessStorage::findMultiple(EntityType type, const String & name) const { std::vector ids; for (const auto & nested_storage : nested_storages) @@ -55,7 +55,7 @@ std::vector MultipleAccessStorage::findMultiple(std::type_index type, cons } -std::optional MultipleAccessStorage::findImpl(std::type_index type, const String & name) const +std::optional MultipleAccessStorage::findImpl(EntityType type, const String & name) const { auto ids = findMultiple(type, name); if (ids.empty()) @@ -72,13 +72,13 @@ std::optional MultipleAccessStorage::findImpl(std::type_index type, const } throw Exception( - "Found " + getTypeName(type) + " " + backQuote(name) + " in " + std::to_string(ids.size()) - + " storages: " + joinStorageNames(storages_with_duplicates), + "Found " + outputEntityTypeAndName(type, name) + " in " + std::to_string(ids.size()) + + " storages [" + joinStorageNames(storages_with_duplicates) + "]", ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES); } -std::vector MultipleAccessStorage::findAllImpl(std::type_index type) const +std::vector MultipleAccessStorage::findAllImpl(EntityType type) const { std::vector all_ids; for (const auto & nested_storage : nested_storages) @@ -180,11 +180,7 @@ UUID MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool repl } if (!nested_storage_for_insertion) - { - throw Exception( - "Not found a storage to insert " + entity->getTypeName() + backQuote(entity->getName()), - ErrorCodes::ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND); - } + throw Exception("Not found a storage to insert " + entity->outputTypeAndName(), ErrorCodes::ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND); auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity); std::lock_guard lock{ids_cache_mutex}; @@ -214,7 +210,7 @@ ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, } -ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const +ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const { ext::scope_guard subscriptions; for (const auto & nested_storage : nested_storages) @@ -234,7 +230,7 @@ bool MultipleAccessStorage::hasSubscriptionImpl(const UUID & id) const } -bool MultipleAccessStorage::hasSubscriptionImpl(std::type_index type) const +bool MultipleAccessStorage::hasSubscriptionImpl(EntityType type) const { for (const auto & nested_storage : nested_storages) { diff --git a/src/Access/MultipleAccessStorage.h b/src/Access/MultipleAccessStorage.h index 898d55d30de..ec8c8f2a101 100644 --- a/src/Access/MultipleAccessStorage.h +++ b/src/Access/MultipleAccessStorage.h @@ -15,7 +15,7 @@ public: MultipleAccessStorage(std::vector> nested_storages_); - std::vector findMultiple(std::type_index type, const String & name) const; + std::vector findMultiple(EntityType type, const String & name) const; template std::vector findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); } @@ -29,8 +29,8 @@ public: const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); } protected: - std::optional findImpl(std::type_index type, const String & name) const override; - std::vector findAllImpl(std::type_index type) const override; + std::optional findImpl(EntityType type, const String & name) const override; + std::vector findAllImpl(EntityType type) const override; bool existsImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override; String readNameImpl(const UUID &id) const override; @@ -39,9 +39,9 @@ protected: void removeImpl(const UUID & id) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; - ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; + ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override; bool hasSubscriptionImpl(const UUID & id) const override; - bool hasSubscriptionImpl(std::type_index type) const override; + bool hasSubscriptionImpl(EntityType type) const override; private: std::vector> nested_storages; diff --git a/src/Access/Quota.h b/src/Access/Quota.h index a3666aa9b52..317ed2dbc47 100644 --- a/src/Access/Quota.h +++ b/src/Access/Quota.h @@ -75,6 +75,8 @@ struct Quota : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } + static constexpr const Type TYPE = Type::QUOTA; + Type getType() const override { return TYPE; } static const char * getNameOfResourceType(ResourceType resource_type); static const char * resourceTypeToKeyword(ResourceType resource_type); diff --git a/src/Access/Role.cpp b/src/Access/Role.cpp index d7bec28c576..3df562ad1f0 100644 --- a/src/Access/Role.cpp +++ b/src/Access/Role.cpp @@ -11,4 +11,5 @@ bool Role::equal(const IAccessEntity & other) const const auto & other_role = typeid_cast(other); return (access == other_role.access) && (granted_roles == other_role.granted_roles) && (settings == other_role.settings); } + } diff --git a/src/Access/Role.h b/src/Access/Role.h index 01a5c6ea2ce..9acb97bdfbd 100644 --- a/src/Access/Role.h +++ b/src/Access/Role.h @@ -17,6 +17,8 @@ struct Role : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } + static constexpr const Type TYPE = Type::ROLE; + Type getType() const override { return TYPE; } }; using RolePtr = std::shared_ptr; diff --git a/src/Access/RowPolicy.h b/src/Access/RowPolicy.h index b3d490e2dbe..9d582d27045 100644 --- a/src/Access/RowPolicy.h +++ b/src/Access/RowPolicy.h @@ -69,6 +69,8 @@ struct RowPolicy : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } + static constexpr const Type TYPE = Type::ROW_POLICY; + Type getType() const override { return TYPE; } /// Which roles or users should use this row policy. ExtendedRoleSet to_roles; diff --git a/src/Access/SettingsProfile.cpp b/src/Access/SettingsProfile.cpp index c2f868502c0..64fb91eb66b 100644 --- a/src/Access/SettingsProfile.cpp +++ b/src/Access/SettingsProfile.cpp @@ -3,6 +3,7 @@ namespace DB { + bool SettingsProfile::equal(const IAccessEntity & other) const { if (!IAccessEntity::equal(other)) @@ -10,4 +11,5 @@ bool SettingsProfile::equal(const IAccessEntity & other) const const auto & other_profile = typeid_cast(other); return (elements == other_profile.elements) && (to_roles == other_profile.to_roles); } + } diff --git a/src/Access/SettingsProfile.h b/src/Access/SettingsProfile.h index b73b45d57cf..9589b5b3eb5 100644 --- a/src/Access/SettingsProfile.h +++ b/src/Access/SettingsProfile.h @@ -18,6 +18,8 @@ struct SettingsProfile : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } + static constexpr const Type TYPE = Type::SETTINGS_PROFILE; + Type getType() const override { return TYPE; } }; using SettingsProfilePtr = std::shared_ptr; diff --git a/src/Access/User.cpp b/src/Access/User.cpp index 459357731ed..f57ec7c1359 100644 --- a/src/Access/User.cpp +++ b/src/Access/User.cpp @@ -13,4 +13,5 @@ bool User::equal(const IAccessEntity & other) const && (access == other_user.access) && (granted_roles == other_user.granted_roles) && (default_roles == other_user.default_roles) && (settings == other_user.settings); } + } diff --git a/src/Access/User.h b/src/Access/User.h index b20f6538e4d..da2fb14e131 100644 --- a/src/Access/User.h +++ b/src/Access/User.h @@ -24,6 +24,8 @@ struct User : public IAccessEntity bool equal(const IAccessEntity & other) const override; std::shared_ptr clone() const override { return cloneImpl(); } + static constexpr const Type TYPE = Type::USER; + Type getType() const override { return TYPE; } }; using UserPtr = std::shared_ptr; diff --git a/src/Access/UsersConfigAccessStorage.cpp b/src/Access/UsersConfigAccessStorage.cpp index 7f96d838e32..43a9e355911 100644 --- a/src/Access/UsersConfigAccessStorage.cpp +++ b/src/Access/UsersConfigAccessStorage.cpp @@ -27,35 +27,24 @@ namespace ErrorCodes namespace { - char getTypeChar(std::type_index type) - { - if (type == typeid(User)) - return 'U'; - if (type == typeid(Quota)) - return 'Q'; - if (type == typeid(RowPolicy)) - return 'P'; - if (type == typeid(SettingsProfile)) - return 'S'; - return 0; - } + using EntityType = IAccessStorage::EntityType; + using EntityTypeInfo = IAccessStorage::EntityTypeInfo; - - UUID generateID(std::type_index type, const String & name) + UUID generateID(EntityType type, const String & name) { Poco::MD5Engine md5; md5.update(name); char type_storage_chars[] = " USRSXML"; - type_storage_chars[0] = getTypeChar(type); + type_storage_chars[0] = EntityTypeInfo::get(type).unique_char; md5.update(type_storage_chars, strlen(type_storage_chars)); UUID result; memcpy(&result, md5.digest().data(), md5.digestLength()); return result; } - UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getName()); } + UserPtr parseUser(const Poco::Util::AbstractConfiguration & config, const String & user_name) { auto user = std::make_shared(); @@ -95,7 +84,7 @@ namespace { auto profile_name = config.getString(profile_name_config); SettingsProfileElement profile_element; - profile_element.parent_profile = generateID(typeid(SettingsProfile), profile_name); + profile_element.parent_profile = generateID(EntityType::SETTINGS_PROFILE, profile_name); user->settings.push_back(std::move(profile_element)); } @@ -260,7 +249,7 @@ namespace for (const auto & user_name : user_names) { if (config.has("users." + user_name + ".quota")) - quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(typeid(User), user_name)); + quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(EntityType::USER, user_name)); } Poco::Util::AbstractConfiguration::Keys quota_names; @@ -346,7 +335,7 @@ namespace auto policy = std::make_shared(); policy->setNameParts(user_name, database, table_name); policy->conditions[RowPolicy::SELECT_FILTER] = filter; - policy->to_roles.add(generateID(typeid(User), user_name)); + policy->to_roles.add(generateID(EntityType::USER, user_name)); policies.push_back(policy); } } @@ -400,7 +389,7 @@ namespace { String parent_profile_name = config.getString(profile_config + "." + key); SettingsProfileElement profile_element; - profile_element.parent_profile = generateID(typeid(SettingsProfile), parent_profile_name); + profile_element.parent_profile = generateID(EntityType::SETTINGS_PROFILE, parent_profile_name); profile->elements.emplace_back(std::move(profile_element)); continue; } @@ -462,13 +451,13 @@ void UsersConfigAccessStorage::setConfiguration(const Poco::Util::AbstractConfig } -std::optional UsersConfigAccessStorage::findImpl(std::type_index type, const String & name) const +std::optional UsersConfigAccessStorage::findImpl(EntityType type, const String & name) const { return memory_storage.find(type, name); } -std::vector UsersConfigAccessStorage::findAllImpl(std::type_index type) const +std::vector UsersConfigAccessStorage::findAllImpl(EntityType type) const { return memory_storage.findAll(type); } @@ -518,7 +507,7 @@ ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & } -ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const +ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const { return memory_storage.subscribeForChanges(type, handler); } @@ -530,7 +519,7 @@ bool UsersConfigAccessStorage::hasSubscriptionImpl(const UUID & id) const } -bool UsersConfigAccessStorage::hasSubscriptionImpl(std::type_index type) const +bool UsersConfigAccessStorage::hasSubscriptionImpl(EntityType type) const { return memory_storage.hasSubscription(type); } diff --git a/src/Access/UsersConfigAccessStorage.h b/src/Access/UsersConfigAccessStorage.h index 773d8caa570..d7012cda4ff 100644 --- a/src/Access/UsersConfigAccessStorage.h +++ b/src/Access/UsersConfigAccessStorage.h @@ -23,8 +23,8 @@ public: void setConfiguration(const Poco::Util::AbstractConfiguration & config); private: - std::optional findImpl(std::type_index type, const String & name) const override; - std::vector findAllImpl(std::type_index type) const override; + std::optional findImpl(EntityType type, const String & name) const override; + std::vector findAllImpl(EntityType type) const override; bool existsImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override; @@ -33,9 +33,9 @@ private: void removeImpl(const UUID & id) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; - ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; + ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override; bool hasSubscriptionImpl(const UUID & id) const override; - bool hasSubscriptionImpl(std::type_index type) const override; + bool hasSubscriptionImpl(EntityType type) const override; MemoryAccessStorage memory_storage; }; diff --git a/src/Common/ErrorCodes.cpp b/src/Common/ErrorCodes.cpp index 120c9b93a78..cc8250158a3 100644 --- a/src/Common/ErrorCodes.cpp +++ b/src/Common/ErrorCodes.cpp @@ -493,6 +493,7 @@ namespace ErrorCodes extern const int NO_REMOTE_SHARD_AVAILABLE = 519; extern const int CANNOT_DETACH_DICTIONARY_AS_TABLE = 520; extern const int ATOMIC_RENAME_FAIL = 521; + extern const int UNKNOWN_ROW_POLICY = 522; extern const int KEEPER_EXCEPTION = 999; extern const int POCO_EXCEPTION = 1000; diff --git a/src/Interpreters/InterpreterDropAccessEntityQuery.cpp b/src/Interpreters/InterpreterDropAccessEntityQuery.cpp index 2a47639e15f..be82147a322 100644 --- a/src/Interpreters/InterpreterDropAccessEntityQuery.cpp +++ b/src/Interpreters/InterpreterDropAccessEntityQuery.cpp @@ -9,54 +9,28 @@ #include #include #include -#include namespace DB { -namespace +namespace ErrorCodes { - using Kind = ASTDropAccessEntityQuery::Kind; - - std::type_index getType(Kind kind) - { - switch (kind) - { - case Kind::USER: return typeid(User); - case Kind::ROLE: return typeid(Role); - case Kind::QUOTA: return typeid(Quota); - case Kind::ROW_POLICY: return typeid(RowPolicy); - case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile); - } - __builtin_unreachable(); - } - - AccessType getRequiredAccessType(Kind kind) - { - switch (kind) - { - case Kind::USER: return AccessType::DROP_USER; - case Kind::ROLE: return AccessType::DROP_ROLE; - case Kind::QUOTA: return AccessType::DROP_QUOTA; - case Kind::ROW_POLICY: return AccessType::DROP_ROW_POLICY; - case Kind::SETTINGS_PROFILE: return AccessType::DROP_SETTINGS_PROFILE; - } - __builtin_unreachable(); - } + extern const int NOT_IMPLEMENTED; } +using EntityType = IAccessEntity::Type; + + BlockIO InterpreterDropAccessEntityQuery::execute() { auto & query = query_ptr->as(); auto & access_control = context.getAccessControlManager(); - - std::type_index type = getType(query.kind); - context.checkAccess(getRequiredAccessType(query.kind)); + context.checkAccess(getRequiredAccess()); if (!query.cluster.empty()) return executeDDLQueryOnCluster(query_ptr, context); - if (query.kind == Kind::ROW_POLICY) + if (query.type == EntityType::ROW_POLICY) { Strings names; for (auto & name_parts : query.row_policies_name_parts) @@ -73,10 +47,28 @@ BlockIO InterpreterDropAccessEntityQuery::execute() } if (query.if_exists) - access_control.tryRemove(access_control.find(type, query.names)); + access_control.tryRemove(access_control.find(query.type, query.names)); else - access_control.remove(access_control.getIDs(type, query.names)); + access_control.remove(access_control.getIDs(query.type, query.names)); return {}; } + +AccessRightsElements InterpreterDropAccessEntityQuery::getRequiredAccess() const +{ + const auto & query = query_ptr->as(); + AccessRightsElements res; + switch (query.type) + { + case EntityType::USER: res.emplace_back(AccessType::DROP_USER); return res; + case EntityType::ROLE: res.emplace_back(AccessType::DROP_ROLE); return res; + case EntityType::SETTINGS_PROFILE: res.emplace_back(AccessType::DROP_SETTINGS_PROFILE); return res; + case EntityType::ROW_POLICY: res.emplace_back(AccessType::DROP_ROW_POLICY); return res; + case EntityType::QUOTA: res.emplace_back(AccessType::DROP_QUOTA); return res; + case EntityType::MAX: break; + } + throw Exception( + toString(query.type) + ": type is not supported by DROP query", ErrorCodes::NOT_IMPLEMENTED); +} + } diff --git a/src/Interpreters/InterpreterDropAccessEntityQuery.h b/src/Interpreters/InterpreterDropAccessEntityQuery.h index 2a0e749b265..0db68a0ad78 100644 --- a/src/Interpreters/InterpreterDropAccessEntityQuery.h +++ b/src/Interpreters/InterpreterDropAccessEntityQuery.h @@ -6,6 +6,8 @@ namespace DB { +class AccessRightsElements; + class InterpreterDropAccessEntityQuery : public IInterpreter { public: @@ -14,6 +16,8 @@ public: BlockIO execute() override; private: + AccessRightsElements getRequiredAccess() const; + ASTPtr query_ptr; Context & context; }; diff --git a/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp b/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp index beea1bf953e..11354d1f2a5 100644 --- a/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp +++ b/src/Interpreters/InterpreterShowCreateAccessEntityQuery.cpp @@ -30,9 +30,10 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; + extern const int NOT_IMPLEMENTED; } + namespace { ASTPtr getCreateQueryImpl( @@ -203,23 +204,10 @@ namespace return getCreateQueryImpl(*quota, manager, attach_mode); if (const SettingsProfile * profile = typeid_cast(&entity)) return getCreateQueryImpl(*profile, manager, attach_mode); - throw Exception("Unexpected type of access entity: " + entity.getTypeName(), ErrorCodes::LOGICAL_ERROR); + throw Exception(entity.outputTypeAndName() + ": type is not supported by SHOW CREATE query", ErrorCodes::NOT_IMPLEMENTED); } - using Kind = ASTShowCreateAccessEntityQuery::Kind; - - std::type_index getType(Kind kind) - { - switch (kind) - { - case Kind::USER: return typeid(User); - case Kind::ROLE: return typeid(Role); - case Kind::QUOTA: return typeid(Quota); - case Kind::ROW_POLICY: return typeid(RowPolicy); - case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile); - } - __builtin_unreachable(); - } + using EntityType = IAccessEntity::Type; } @@ -274,8 +262,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(ASTShowCreateAcces return getCreateQueryImpl(*quota, &access_control, false); } - auto type = getType(show_query.kind); - if (show_query.kind == Kind::ROW_POLICY) + if (show_query.type == Type::ROW_POLICY) { if (show_query.row_policy_name_parts.database.empty()) show_query.row_policy_name_parts.database = context.getCurrentDatabase(); @@ -283,30 +270,30 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(ASTShowCreateAcces return getCreateQueryImpl(*policy, &access_control, false); } - auto entity = access_control.read(access_control.getID(type, show_query.name)); + auto entity = access_control.read(access_control.getID(show_query.type, show_query.name)); return getCreateQueryImpl(*entity, &access_control, false); } -AccessRightsElements InterpreterShowCreateAccessEntityQuery::getRequiredAccess() const -{ - const auto & show_query = query_ptr->as(); - AccessRightsElements res; - switch (show_query.kind) - { - case Kind::USER: res.emplace_back(AccessType::SHOW_USERS); break; - case Kind::ROLE: res.emplace_back(AccessType::SHOW_ROLES); break; - case Kind::ROW_POLICY: res.emplace_back(AccessType::SHOW_ROW_POLICIES); break; - case Kind::SETTINGS_PROFILE: res.emplace_back(AccessType::SHOW_SETTINGS_PROFILES); break; - case Kind::QUOTA: res.emplace_back(AccessType::SHOW_QUOTAS); break; - } - return res; -} - - ASTPtr InterpreterShowCreateAccessEntityQuery::getAttachQuery(const IAccessEntity & entity) { return getCreateQueryImpl(entity, nullptr, true); } + +AccessRightsElements InterpreterShowCreateAccessEntityQuery::getRequiredAccess() const +{ + const auto & show_query = query_ptr->as(); + AccessRightsElements res; + switch (show_query.type) + { + case EntityType::USER: res.emplace_back(AccessType::SHOW_USERS); return res; + case EntityType::ROLE: res.emplace_back(AccessType::SHOW_ROLES); return res; + case EntityType::SETTINGS_PROFILE: res.emplace_back(AccessType::SHOW_SETTINGS_PROFILES); return res; + case EntityType::ROW_POLICY: res.emplace_back(AccessType::SHOW_ROW_POLICIES); return res; + case EntityType::QUOTA: res.emplace_back(AccessType::SHOW_QUOTAS); return res; + case EntityType::MAX: break; + } + throw Exception(toString(show_query.type) + ": type is not supported by SHOW CREATE query", ErrorCodes::NOT_IMPLEMENTED); +} } diff --git a/src/Interpreters/InterpreterShowGrantsQuery.cpp b/src/Interpreters/InterpreterShowGrantsQuery.cpp index aa139b8e10e..130749526c7 100644 --- a/src/Interpreters/InterpreterShowGrantsQuery.cpp +++ b/src/Interpreters/InterpreterShowGrantsQuery.cpp @@ -105,7 +105,7 @@ namespace return getGrantQueriesImpl(*user, manager, attach_mode); if (const Role * role = typeid_cast(&entity)) return getGrantQueriesImpl(*role, manager, attach_mode); - throw Exception("Unexpected type of access entity: " + entity.getTypeName(), ErrorCodes::LOGICAL_ERROR); + throw Exception(entity.outputTypeAndName() + " is expected to be user or role", ErrorCodes::LOGICAL_ERROR); } } diff --git a/src/Parsers/ASTDropAccessEntityQuery.cpp b/src/Parsers/ASTDropAccessEntityQuery.cpp index a0e6753af57..9f7a1d86221 100644 --- a/src/Parsers/ASTDropAccessEntityQuery.cpp +++ b/src/Parsers/ASTDropAccessEntityQuery.cpp @@ -4,34 +4,12 @@ namespace DB { -namespace -{ - using Kind = ASTDropAccessEntityQuery::Kind; - - const char * getKeyword(Kind kind) - { - switch (kind) - { - case Kind::USER: return "USER"; - case Kind::ROLE: return "ROLE"; - case Kind::QUOTA: return "QUOTA"; - case Kind::ROW_POLICY: return "ROW POLICY"; - case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE"; - } - __builtin_unreachable(); - } -} - - -ASTDropAccessEntityQuery::ASTDropAccessEntityQuery(Kind kind_) - : kind(kind_) -{ -} +using EntityTypeInfo = IAccessEntity::TypeInfo; String ASTDropAccessEntityQuery::getID(char) const { - return String("DROP ") + getKeyword(kind) + " query"; + return String("DROP ") + toString(type) + " query"; } @@ -44,11 +22,11 @@ ASTPtr ASTDropAccessEntityQuery::clone() const void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { settings.ostr << (settings.hilite ? hilite_keyword : "") - << "DROP " << getKeyword(kind) + << "DROP " << EntityTypeInfo::get(type).name << (if_exists ? " IF EXISTS" : "") << (settings.hilite ? hilite_none : ""); - if (kind == Kind::ROW_POLICY) + if (type == EntityType::ROW_POLICY) { bool need_comma = false; for (const auto & name_parts : row_policies_name_parts) diff --git a/src/Parsers/ASTDropAccessEntityQuery.h b/src/Parsers/ASTDropAccessEntityQuery.h index a630869e027..160b0c2e212 100644 --- a/src/Parsers/ASTDropAccessEntityQuery.h +++ b/src/Parsers/ASTDropAccessEntityQuery.h @@ -17,21 +17,13 @@ namespace DB class ASTDropAccessEntityQuery : public IAST, public ASTQueryWithOnCluster { public: - enum class Kind - { - USER, - ROLE, - QUOTA, - ROW_POLICY, - SETTINGS_PROFILE, - }; + using EntityType = IAccessEntity::Type; - const Kind kind; + EntityType type; bool if_exists = false; Strings names; std::vector row_policies_name_parts; - ASTDropAccessEntityQuery(Kind kind_); String getID(char) const override; ASTPtr clone() const override; void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; diff --git a/src/Parsers/ASTShowCreateAccessEntityQuery.cpp b/src/Parsers/ASTShowCreateAccessEntityQuery.cpp index eed5766bcd9..954fab673e7 100644 --- a/src/Parsers/ASTShowCreateAccessEntityQuery.cpp +++ b/src/Parsers/ASTShowCreateAccessEntityQuery.cpp @@ -4,34 +4,12 @@ namespace DB { -namespace -{ - using Kind = ASTShowCreateAccessEntityQuery::Kind; - - const char * getKeyword(Kind kind) - { - switch (kind) - { - case Kind::USER: return "USER"; - case Kind::ROLE: return "ROLE"; - case Kind::QUOTA: return "QUOTA"; - case Kind::ROW_POLICY: return "ROW POLICY"; - case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE"; - } - __builtin_unreachable(); - } -} - - -ASTShowCreateAccessEntityQuery::ASTShowCreateAccessEntityQuery(Kind kind_) - : kind(kind_) -{ -} +using EntityTypeInfo = IAccessEntity::TypeInfo; String ASTShowCreateAccessEntityQuery::getID(char) const { - return String("SHOW CREATE ") + getKeyword(kind) + " query"; + return String("SHOW CREATE ") + toString(type) + " query"; } @@ -44,7 +22,7 @@ ASTPtr ASTShowCreateAccessEntityQuery::clone() const void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const { settings.ostr << (settings.hilite ? hilite_keyword : "") - << "SHOW CREATE " << getKeyword(kind) + << "SHOW CREATE " << EntityTypeInfo::get(type).name << (settings.hilite ? hilite_none : ""); if (current_user) @@ -52,7 +30,7 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett } else if (current_quota) settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : ""); - else if (kind == Kind::ROW_POLICY) + else if (type == EntityType::ROW_POLICY) { const String & database = row_policy_name_parts.database; const String & table_name = row_policy_name_parts.table_name; diff --git a/src/Parsers/ASTShowCreateAccessEntityQuery.h b/src/Parsers/ASTShowCreateAccessEntityQuery.h index 7f82e9f5e34..df7be2e257c 100644 --- a/src/Parsers/ASTShowCreateAccessEntityQuery.h +++ b/src/Parsers/ASTShowCreateAccessEntityQuery.h @@ -15,22 +15,14 @@ namespace DB class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput { public: - enum class Kind - { - USER, - ROLE, - QUOTA, - ROW_POLICY, - SETTINGS_PROFILE, - }; + using EntityType = IAccessEntity::Type; - const Kind kind; + EntityType type; String name; bool current_quota = false; bool current_user = false; RowPolicy::NameParts row_policy_name_parts; - ASTShowCreateAccessEntityQuery(Kind kind_); String getID(char) const override; ASTPtr clone() const override; diff --git a/src/Parsers/ParserDropAccessEntityQuery.cpp b/src/Parsers/ParserDropAccessEntityQuery.cpp index 034124a42f0..15f8bbf0a62 100644 --- a/src/Parsers/ParserDropAccessEntityQuery.cpp +++ b/src/Parsers/ParserDropAccessEntityQuery.cpp @@ -4,12 +4,16 @@ #include #include #include +#include namespace DB { namespace { + using EntityType = IAccessEntity::Type; + using EntityTypeInfo = IAccessEntity::TypeInfo; + bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names) { return IParserBase::wrapParseImpl(pos, [&] @@ -79,19 +83,17 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & if (!ParserKeyword{"DROP"}.ignore(pos, expected)) return false; - using Kind = ASTDropAccessEntityQuery::Kind; - Kind kind; - if (ParserKeyword{"USER"}.ignore(pos, expected)) - kind = Kind::USER; - else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) - kind = Kind::ROLE; - else if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) - kind = Kind::QUOTA; - else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) - kind = Kind::ROW_POLICY; - else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) - kind = Kind::SETTINGS_PROFILE; - else + std::optional type; + for (auto type_i : ext::range(EntityType::MAX)) + { + const auto & type_info = EntityTypeInfo::get(type_i); + if (ParserKeyword{type_info.name.c_str()}.ignore(pos, expected) + || (!type_info.alias.empty() && ParserKeyword{type_info.alias.c_str()}.ignore(pos, expected))) + { + type = type_i; + } + } + if (!type) return false; bool if_exists = false; @@ -101,12 +103,12 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & Strings names; std::vector row_policies_name_parts; - if ((kind == Kind::USER) || (kind == Kind::ROLE)) + if ((type == EntityType::USER) || (type == EntityType::ROLE)) { if (!parseUserNames(pos, expected, names)) return false; } - else if (kind == Kind::ROW_POLICY) + else if (type == EntityType::ROW_POLICY) { if (!parseRowPolicyNames(pos, expected, row_policies_name_parts)) return false; @@ -124,9 +126,10 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & return false; } - auto query = std::make_shared(kind); + auto query = std::make_shared(); node = query; + query->type = *type; query->if_exists = if_exists; query->cluster = std::move(cluster); query->names = std::move(names); diff --git a/src/Parsers/ParserShowCreateAccessEntityQuery.cpp b/src/Parsers/ParserShowCreateAccessEntityQuery.cpp index 48cd36f68d3..308a1bd7795 100644 --- a/src/Parsers/ParserShowCreateAccessEntityQuery.cpp +++ b/src/Parsers/ParserShowCreateAccessEntityQuery.cpp @@ -4,29 +4,32 @@ #include #include #include +#include #include namespace DB { +using EntityType = IAccessEntity::Type; +using EntityTypeInfo = IAccessEntity::TypeInfo; + + bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { if (!ParserKeyword{"SHOW CREATE"}.ignore(pos, expected)) return false; - using Kind = ASTShowCreateAccessEntityQuery::Kind; - Kind kind; - if (ParserKeyword{"USER"}.ignore(pos, expected)) - kind = Kind::USER; - else if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) - kind = Kind::QUOTA; - else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) - kind = Kind::ROW_POLICY; - else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) - kind = Kind::ROLE; - else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) - kind = Kind::SETTINGS_PROFILE; - else + std::optional type; + for (auto type_i : ext::range(EntityType::MAX)) + { + const auto & type_info = EntityTypeInfo::get(type_i); + if (ParserKeyword{type_info.name.c_str()}.ignore(pos, expected) + || (!type_info.alias.empty() && ParserKeyword{type_info.alias.c_str()}.ignore(pos, expected))) + { + type = type_i; + } + } + if (!type) return false; String name; @@ -34,17 +37,17 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe bool current_user = false; RowPolicy::NameParts row_policy_name_parts; - if (kind == Kind::USER) + if (type == EntityType::USER) { if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user)) current_user = true; } - else if (kind == Kind::ROLE) + else if (type == EntityType::ROLE) { if (!parseRoleName(pos, expected, name)) return false; } - else if (kind == Kind::ROW_POLICY) + else if (type == EntityType::ROW_POLICY) { String & database = row_policy_name_parts.database; String & table_name = row_policy_name_parts.table_name; @@ -53,7 +56,7 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe || !parseDatabaseAndTableName(pos, expected, database, table_name)) return false; } - else if (kind == Kind::QUOTA) + else if (type == EntityType::QUOTA) { if (ParserKeyword{"CURRENT"}.ignore(pos, expected)) { @@ -70,15 +73,16 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe current_quota = true; } } - else if (kind == Kind::SETTINGS_PROFILE) + else if (type == EntityType::SETTINGS_PROFILE) { if (!parseIdentifierOrStringLiteral(pos, expected, name)) return false; } - auto query = std::make_shared(kind); + auto query = std::make_shared(); node = query; + query->type = *type; query->name = std::move(name); query->current_quota = current_quota; query->current_user = current_user; diff --git a/tests/integration/test_access_control_on_cluster/test.py b/tests/integration/test_access_control_on_cluster/test.py index 6ca4ac15398..4dc9baca0a0 100644 --- a/tests/integration/test_access_control_on_cluster/test.py +++ b/tests/integration/test_access_control_on_cluster/test.py @@ -35,7 +35,7 @@ def test_access_control_on_cluster(): assert ch3.query("SHOW GRANTS FOR Alex") == "" ch2.query("DROP USER Alex ON CLUSTER 'cluster'") - assert "User `Alex` not found" in ch1.query_and_get_error("SHOW CREATE USER Alex") - assert "User `Alex` not found" in ch2.query_and_get_error("SHOW CREATE USER Alex") - assert "User `Alex` not found" in ch3.query_and_get_error("SHOW CREATE USER Alex") + assert "There is no user `Alex`" in ch1.query_and_get_error("SHOW CREATE USER Alex") + assert "There is no user `Alex`" in ch2.query_and_get_error("SHOW CREATE USER Alex") + assert "There is no user `Alex`" in ch3.query_and_get_error("SHOW CREATE USER Alex") diff --git a/tests/integration/test_disk_access_storage/test.py b/tests/integration/test_disk_access_storage/test.py index babceee7c76..315440b4358 100644 --- a/tests/integration/test_disk_access_storage/test.py +++ b/tests/integration/test_disk_access_storage/test.py @@ -97,9 +97,9 @@ def test_drop(): def check(): assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n" assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2\n" - assert "User `u2` not found" in instance.query_and_get_error("SHOW CREATE USER u2") - assert "Row policy `p ON mydb.mytable` not found" in instance.query_and_get_error("SHOW CREATE ROW POLICY p ON mydb.mytable") - assert "Quota `q` not found" in instance.query_and_get_error("SHOW CREATE QUOTA q") + assert "There is no user `u2`" in instance.query_and_get_error("SHOW CREATE USER u2") + assert "There is no row policy `p ON mydb.mytable`" in instance.query_and_get_error("SHOW CREATE ROW POLICY p ON mydb.mytable") + assert "There is no quota `q`" in instance.query_and_get_error("SHOW CREATE QUOTA q") check() instance.restart_clickhouse() # Check persistency