Use enum Type instead of std::type_index to represent the type of IAccessEntity.

This change simplifies handling of access entities in access storages.
This commit is contained in:
Vitaly Baranov 2020-05-03 06:12:03 +03:00
parent b6fe726777
commit dd8b29b4fb
33 changed files with 626 additions and 594 deletions

View File

@ -53,6 +53,9 @@ namespace ErrorCodes
namespace namespace
{ {
using EntityType = IAccessStorage::EntityType;
using EntityTypeInfo = IAccessStorage::EntityTypeInfo;
/// Special parser for the 'ATTACH access entity' queries. /// Special parser for the 'ATTACH access entity' queries.
class ParserAttachAccessEntity : public IParserBase class ParserAttachAccessEntity : public IParserBase
{ {
@ -79,7 +82,7 @@ namespace
/// Reads a file containing ATTACH queries and then parses it to build an access entity. /// Reads a file containing ATTACH queries and then parses it to build an access entity.
AccessEntityPtr readAccessEntityFile(const std::filesystem::path & file_path) AccessEntityPtr readEntityFile(const std::filesystem::path & file_path)
{ {
/// Read the file. /// Read the file.
ReadBufferFromFile in{file_path}; ReadBufferFromFile in{file_path};
@ -164,11 +167,11 @@ namespace
} }
AccessEntityPtr tryReadAccessEntityFile(const std::filesystem::path & file_path, Poco::Logger & log) AccessEntityPtr tryReadEntityFile(const std::filesystem::path & file_path, Poco::Logger & log)
{ {
try try
{ {
return readAccessEntityFile(file_path); return readEntityFile(file_path);
} }
catch (...) catch (...)
{ {
@ -179,12 +182,12 @@ namespace
/// Writes ATTACH queries for building a specified access entity to a file. /// Writes ATTACH queries for building a specified access entity to a file.
void writeAccessEntityFile(const std::filesystem::path & file_path, const IAccessEntity & entity) void writeEntityFile(const std::filesystem::path & file_path, const IAccessEntity & entity)
{ {
/// Build list of ATTACH queries. /// Build list of ATTACH queries.
ASTs queries; ASTs queries;
queries.push_back(InterpreterShowCreateAccessEntityQuery::getAttachQuery(entity)); queries.push_back(InterpreterShowCreateAccessEntityQuery::getAttachQuery(entity));
if (entity.getType() == typeid(User) || entity.getType() == typeid(Role)) if ((entity.getType() == EntityType::USER) || (entity.getType() == EntityType::ROLE))
boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity)); boost::range::push_back(queries, InterpreterShowGrantsQuery::getAttachGrantQueries(entity));
/// Serialize the list of ATTACH queries to a string. /// Serialize the list of ATTACH queries to a string.
@ -213,21 +216,21 @@ namespace
/// Calculates the path to a file named <id>.sql for saving an access entity. /// Calculates the path to a file named <id>.sql for saving an access entity.
std::filesystem::path getAccessEntityFilePath(const String & directory_path, const UUID & id) std::filesystem::path getEntityFilePath(const String & directory_path, const UUID & id)
{ {
return std::filesystem::path(directory_path).append(toString(id)).replace_extension(".sql"); return std::filesystem::path(directory_path).append(toString(id)).replace_extension(".sql");
} }
/// Reads a map of name of access entity to UUID for access entities of some type from a file. /// Reads a map of name of access entity to UUID for access entities of some type from a file.
std::unordered_map<String, UUID> readListFile(const std::filesystem::path & file_path) std::vector<std::pair<UUID, String>> readListFile(const std::filesystem::path & file_path)
{ {
ReadBufferFromFile in(file_path); ReadBufferFromFile in(file_path);
size_t num; size_t num;
readVarUInt(num, in); readVarUInt(num, in);
std::unordered_map<String, UUID> res; std::vector<std::pair<UUID, String>> id_name_pairs;
res.reserve(num); id_name_pairs.reserve(num);
for (size_t i = 0; i != num; ++i) for (size_t i = 0; i != num; ++i)
{ {
@ -235,19 +238,19 @@ namespace
readStringBinary(name, in); readStringBinary(name, in);
UUID id; UUID id;
readUUIDText(id, in); readUUIDText(id, in);
res[name] = id; id_name_pairs.emplace_back(id, std::move(name));
} }
return res; return id_name_pairs;
} }
/// Writes a map of name of access entity to UUID for access entities of some type to a file. /// Writes a map of name of access entity to UUID for access entities of some type to a file.
void writeListFile(const std::filesystem::path & file_path, const std::unordered_map<String, UUID> & map) void writeListFile(const std::filesystem::path & file_path, const std::vector<std::pair<UUID, std::string_view>> & id_name_pairs)
{ {
WriteBufferFromFile out(file_path); WriteBufferFromFile out(file_path);
writeVarUInt(map.size(), out); writeVarUInt(id_name_pairs.size(), out);
for (const auto & [name, id] : map) for (const auto & [id, name] : id_name_pairs)
{ {
writeStringBinary(name, out); writeStringBinary(name, out);
writeUUIDText(id, out); writeUUIDText(id, out);
@ -256,24 +259,10 @@ namespace
/// Calculates the path for storing a map of name of access entity to UUID for access entities of some type. /// Calculates the path for storing a map of name of access entity to UUID for access entities of some type.
std::filesystem::path getListFilePath(const String & directory_path, std::type_index type) std::filesystem::path getListFilePath(const String & directory_path, EntityType type)
{ {
std::string_view file_name; std::string_view file_name = EntityTypeInfo::get(type).list_filename;
if (type == typeid(User)) return std::filesystem::path(directory_path).append(file_name);
file_name = "users";
else if (type == typeid(Role))
file_name = "roles";
else if (type == typeid(Quota))
file_name = "quotas";
else if (type == typeid(RowPolicy))
file_name = "row_policies";
else if (type == typeid(SettingsProfile))
file_name = "settings_profiles";
else
throw Exception("Unexpected type of access entity: " + IAccessEntity::getTypeName(type),
ErrorCodes::LOGICAL_ERROR);
return std::filesystem::path(directory_path).append(file_name).replace_extension(".list");
} }
@ -297,21 +286,12 @@ namespace
return false; return false;
} }
} }
const std::vector<std::type_index> & getAllAccessEntityTypes()
{
static const std::vector<std::type_index> res = {typeid(User), typeid(Role), typeid(RowPolicy), typeid(Quota), typeid(SettingsProfile)};
return res;
}
} }
DiskAccessStorage::DiskAccessStorage() DiskAccessStorage::DiskAccessStorage()
: IAccessStorage("disk") : IAccessStorage("disk")
{ {
for (auto type : getAllAccessEntityTypes())
name_to_id_maps[type];
} }
@ -363,18 +343,27 @@ void DiskAccessStorage::initialize(const String & directory_path_, Notifications
writeLists(); writeLists();
} }
for (const auto & [id, entry] : id_to_entry_map) for (const auto & [id, entry] : entries_by_id)
prepareNotifications(id, entry, false, notifications); prepareNotifications(id, entry, false, notifications);
} }
void DiskAccessStorage::clear()
{
entries_by_id.clear();
for (auto type : ext::range(EntityType::MAX))
entries_by_name_and_type[static_cast<size_t>(type)].clear();
}
bool DiskAccessStorage::readLists() bool DiskAccessStorage::readLists()
{ {
assert(id_to_entry_map.empty()); clear();
bool ok = true; bool ok = true;
for (auto type : getAllAccessEntityTypes()) for (auto type : ext::range(EntityType::MAX))
{ {
auto & name_to_id_map = name_to_id_maps.at(type); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
auto file_path = getListFilePath(directory_path, type); auto file_path = getListFilePath(directory_path, type);
if (!std::filesystem::exists(file_path)) if (!std::filesystem::exists(file_path))
{ {
@ -385,7 +374,14 @@ bool DiskAccessStorage::readLists()
try try
{ {
name_to_id_map = readListFile(file_path); for (const auto & [id, name] : readListFile(file_path))
{
auto & entry = entries_by_id[id];
entry.id = id;
entry.type = type;
entry.name = name;
entries_by_name[entry.name] = &entry;
}
} }
catch (...) catch (...)
{ {
@ -393,17 +389,10 @@ bool DiskAccessStorage::readLists()
ok = false; ok = false;
break; break;
} }
for (const auto & [name, id] : name_to_id_map)
id_to_entry_map.emplace(id, Entry{name, type});
} }
if (!ok) if (!ok)
{ clear();
id_to_entry_map.clear();
for (auto & name_to_id_map : name_to_id_maps | boost::adaptors::map_values)
name_to_id_map.clear();
}
return ok; return ok;
} }
@ -419,11 +408,15 @@ bool DiskAccessStorage::writeLists()
for (const auto & type : types_of_lists_to_write) for (const auto & type : types_of_lists_to_write)
{ {
const auto & name_to_id_map = name_to_id_maps.at(type); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
auto file_path = getListFilePath(directory_path, type); auto file_path = getListFilePath(directory_path, type);
try try
{ {
writeListFile(file_path, name_to_id_map); std::vector<std::pair<UUID, std::string_view>> id_name_pairs;
id_name_pairs.reserve(entries_by_name.size());
for (const auto * entry : entries_by_name | boost::adaptors::map_values)
id_name_pairs.emplace_back(entry->id, entry->name);
writeListFile(file_path, id_name_pairs);
} }
catch (...) catch (...)
{ {
@ -441,7 +434,7 @@ bool DiskAccessStorage::writeLists()
} }
void DiskAccessStorage::scheduleWriteLists(std::type_index type) void DiskAccessStorage::scheduleWriteLists(EntityType type)
{ {
if (failed_to_write_lists) if (failed_to_write_lists)
return; return;
@ -504,7 +497,7 @@ void DiskAccessStorage::listsWritingThreadFunc()
bool DiskAccessStorage::rebuildLists() bool DiskAccessStorage::rebuildLists()
{ {
LOG_WARNING(getLogger(), "Recovering lists in directory " + directory_path); LOG_WARNING(getLogger(), "Recovering lists in directory " + directory_path);
assert(id_to_entry_map.empty()); clear();
for (const auto & directory_entry : std::filesystem::directory_iterator(directory_path)) for (const auto & directory_entry : std::filesystem::directory_iterator(directory_path))
{ {
@ -518,58 +511,64 @@ bool DiskAccessStorage::rebuildLists()
if (!tryParseUUID(path.stem(), id)) if (!tryParseUUID(path.stem(), id))
continue; continue;
const auto access_entity_file_path = getAccessEntityFilePath(directory_path, id); const auto access_entity_file_path = getEntityFilePath(directory_path, id);
auto entity = tryReadAccessEntityFile(access_entity_file_path, *getLogger()); auto entity = tryReadEntityFile(access_entity_file_path, *getLogger());
if (!entity) if (!entity)
continue; continue;
const String & name = entity->getName();
auto type = entity->getType(); auto type = entity->getType();
auto & name_to_id_map = name_to_id_maps.at(type); auto & entry = entries_by_id[id];
auto it_by_name = name_to_id_map.emplace(entity->getName(), id).first; entry.id = id;
id_to_entry_map.emplace(id, Entry{it_by_name->first, type}); entry.type = type;
entry.name = name;
entry.entity = entity;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name[entry.name] = &entry;
} }
for (auto type : getAllAccessEntityTypes()) for (auto type : ext::range(EntityType::MAX))
types_of_lists_to_write.insert(type); types_of_lists_to_write.insert(type);
return true; return true;
} }
std::optional<UUID> DiskAccessStorage::findImpl(std::type_index type, const String & name) const std::optional<UUID> DiskAccessStorage::findImpl(EntityType type, const String & name) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
const auto & name_to_id_map = name_to_id_maps.at(type); const auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
auto it = name_to_id_map.find(name); auto it = entries_by_name.find(name);
if (it == name_to_id_map.end()) if (it == entries_by_name.end())
return {}; return {};
return it->second; return it->second->id;
} }
std::vector<UUID> DiskAccessStorage::findAllImpl(std::type_index type) const std::vector<UUID> DiskAccessStorage::findAllImpl(EntityType type) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
const auto & name_to_id_map = name_to_id_maps.at(type); const auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
std::vector<UUID> res; std::vector<UUID> res;
res.reserve(name_to_id_map.size()); res.reserve(entries_by_name.size());
boost::range::copy(name_to_id_map | boost::adaptors::map_values, std::back_inserter(res)); for (const auto * entry : entries_by_name | boost::adaptors::map_values)
res.emplace_back(entry->id);
return res; return res;
} }
bool DiskAccessStorage::existsImpl(const UUID & id) const bool DiskAccessStorage::existsImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return id_to_entry_map.count(id); return entries_by_id.count(id);
} }
AccessEntityPtr DiskAccessStorage::readImpl(const UUID & id) const AccessEntityPtr DiskAccessStorage::readImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it == id_to_entry_map.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
const auto & entry = it->second; const auto & entry = it->second;
@ -582,8 +581,8 @@ AccessEntityPtr DiskAccessStorage::readImpl(const UUID & id) const
String DiskAccessStorage::readNameImpl(const UUID & id) const String DiskAccessStorage::readNameImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it == id_to_entry_map.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
return String{it->second.name}; return String{it->second.name};
} }
@ -610,24 +609,24 @@ UUID DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool repl
void DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications) void DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications)
{ {
const String & name = new_entity->getName(); const String & name = new_entity->getName();
std::type_index type = new_entity->getType(); EntityType type = new_entity->getType();
if (!initialized) if (!initialized)
throw Exception( throw Exception(
"Cannot insert " + new_entity->getTypeName() + " " + backQuote(name) + " to " + getStorageName() "Cannot insert " + new_entity->outputTypeAndName() + " to storage [" + getStorageName()
+ " because the output directory is not set", + "] because the output directory is not set",
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
/// Check that we can insert. /// Check that we can insert.
auto it_by_id = id_to_entry_map.find(id); auto it_by_id = entries_by_id.find(id);
if (it_by_id != id_to_entry_map.end()) if (it_by_id != entries_by_id.end())
{ {
const auto & existing_entry = it_by_id->second; const auto & existing_entry = it_by_id->second;
throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName()); throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName());
} }
auto & name_to_id_map = name_to_id_maps.at(type); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
auto it_by_name = name_to_id_map.find(name); auto it_by_name = entries_by_name.find(name);
bool name_collision = (it_by_name != name_to_id_map.end()); bool name_collision = (it_by_name != entries_by_name.end());
if (name_collision && !replace_if_exists) if (name_collision && !replace_if_exists)
throwNameCollisionCannotInsert(type, name); throwNameCollisionCannotInsert(type, name);
@ -636,13 +635,15 @@ void DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
writeAccessEntityToDisk(id, *new_entity); writeAccessEntityToDisk(id, *new_entity);
if (name_collision && replace_if_exists) if (name_collision && replace_if_exists)
removeNoLock(it_by_name->second, notifications); removeNoLock(it_by_name->second->id, notifications);
/// Do insertion. /// Do insertion.
it_by_name = name_to_id_map.emplace(name, id).first; auto & entry = entries_by_id[id];
it_by_id = id_to_entry_map.emplace(id, Entry{it_by_name->first, type}).first; entry.id = id;
auto & entry = it_by_id->second; entry.type = type;
entry.name = name;
entry.entity = new_entity; entry.entity = new_entity;
entries_by_name[entry.name] = &entry;
prepareNotifications(id, entry, false, notifications); prepareNotifications(id, entry, false, notifications);
} }
@ -659,22 +660,21 @@ void DiskAccessStorage::removeImpl(const UUID & id)
void DiskAccessStorage::removeNoLock(const UUID & id, Notifications & notifications) void DiskAccessStorage::removeNoLock(const UUID & id, Notifications & notifications)
{ {
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it == id_to_entry_map.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
Entry & entry = it->second; Entry & entry = it->second;
String name{it->second.name}; EntityType type = entry.type;
std::type_index type = it->second.type;
scheduleWriteLists(type); scheduleWriteLists(type);
deleteAccessEntityOnDisk(id); deleteAccessEntityOnDisk(id);
/// Do removing. /// Do removing.
prepareNotifications(id, entry, true, notifications); prepareNotifications(id, entry, true, notifications);
id_to_entry_map.erase(it); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
auto & name_to_id_map = name_to_id_maps.at(type); entries_by_name.erase(entry.name);
name_to_id_map.erase(name); entries_by_id.erase(it);
} }
@ -690,8 +690,8 @@ void DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_fu
void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications) void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications)
{ {
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it == id_to_entry_map.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
Entry & entry = it->second; Entry & entry = it->second;
@ -700,18 +700,22 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
auto old_entity = entry.entity; auto old_entity = entry.entity;
auto new_entity = update_func(old_entity); auto new_entity = update_func(old_entity);
if (!new_entity->isTypeOf(old_entity->getType()))
throwBadCast(id, new_entity->getType(), new_entity->getName(), old_entity->getType());
if (*new_entity == *old_entity) if (*new_entity == *old_entity)
return; return;
String new_name = new_entity->getName(); const String & new_name = new_entity->getName();
auto old_name = entry.name; const String & old_name = old_entity->getName();
const std::type_index type = entry.type; const EntityType type = entry.type;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
bool name_changed = (new_name != old_name); bool name_changed = (new_name != old_name);
if (name_changed) if (name_changed)
{ {
const auto & name_to_id_map = name_to_id_maps.at(type); if (entries_by_name.count(new_name))
if (name_to_id_map.count(new_name)) throwNameCollisionCannotRename(type, old_name, new_name);
throwNameCollisionCannotRename(type, String{old_name}, new_name);
scheduleWriteLists(type); scheduleWriteLists(type);
} }
@ -720,10 +724,9 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
if (name_changed) if (name_changed)
{ {
auto & name_to_id_map = name_to_id_maps.at(type); entries_by_name.erase(entry.name);
name_to_id_map.erase(String{old_name}); entry.name = new_name;
auto it_by_name = name_to_id_map.emplace(new_name, id).first; entries_by_name[entry.name] = &entry;
entry.name = it_by_name->first;
} }
prepareNotifications(id, entry, false, notifications); prepareNotifications(id, entry, false, notifications);
@ -732,19 +735,19 @@ void DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
AccessEntityPtr DiskAccessStorage::readAccessEntityFromDisk(const UUID & id) const AccessEntityPtr DiskAccessStorage::readAccessEntityFromDisk(const UUID & id) const
{ {
return readAccessEntityFile(getAccessEntityFilePath(directory_path, id)); return readEntityFile(getEntityFilePath(directory_path, id));
} }
void DiskAccessStorage::writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const void DiskAccessStorage::writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const
{ {
writeAccessEntityFile(getAccessEntityFilePath(directory_path, id), entity); writeEntityFile(getEntityFilePath(directory_path, id), entity);
} }
void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const
{ {
auto file_path = getAccessEntityFilePath(directory_path, id); auto file_path = getEntityFilePath(directory_path, id);
if (!std::filesystem::remove(file_path)) if (!std::filesystem::remove(file_path))
throw Exception("Couldn't delete " + file_path.string(), ErrorCodes::FILE_DOESNT_EXIST); throw Exception("Couldn't delete " + file_path.string(), ErrorCodes::FILE_DOESNT_EXIST);
} }
@ -759,17 +762,16 @@ void DiskAccessStorage::prepareNotifications(const UUID & id, const Entry & entr
for (const auto & handler : entry.handlers_by_id) for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, id, entity}); notifications.push_back({handler, id, entity});
auto range = handlers_by_type.equal_range(entry.type); for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.type)])
for (auto it = range.first; it != range.second; ++it) notifications.push_back({handler, id, entity});
notifications.push_back({it->second, id, entity});
} }
ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it == id_to_entry_map.end()) if (it == entries_by_id.end())
return {}; return {};
const Entry & entry = it->second; const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler); auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
@ -777,8 +779,8 @@ ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, con
return [this, id, handler_it] return [this, id, handler_it]
{ {
std::lock_guard lock2{mutex}; std::lock_guard lock2{mutex};
auto it2 = id_to_entry_map.find(id); auto it2 = entries_by_id.find(id);
if (it2 != id_to_entry_map.end()) if (it2 != entries_by_id.end())
{ {
const Entry & entry2 = it2->second; const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it); entry2.handlers_by_id.erase(handler_it);
@ -786,23 +788,26 @@ ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, con
}; };
} }
ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const ext::scope_guard DiskAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto handler_it = handlers_by_type.emplace(type, handler); auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, handler_it] return [this, type, handler_it]
{ {
std::lock_guard lock2{mutex}; std::lock_guard lock2{mutex};
handlers_by_type.erase(handler_it); auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
}; };
} }
bool DiskAccessStorage::hasSubscriptionImpl(const UUID & id) const bool DiskAccessStorage::hasSubscriptionImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = id_to_entry_map.find(id); auto it = entries_by_id.find(id);
if (it != id_to_entry_map.end()) if (it != entries_by_id.end())
{ {
const Entry & entry = it->second; const Entry & entry = it->second;
return !entry.handlers_by_id.empty(); return !entry.handlers_by_id.empty();
@ -810,11 +815,11 @@ bool DiskAccessStorage::hasSubscriptionImpl(const UUID & id) const
return false; return false;
} }
bool DiskAccessStorage::hasSubscriptionImpl(std::type_index type) const bool DiskAccessStorage::hasSubscriptionImpl(EntityType type) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto range = handlers_by_type.equal_range(type); const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return range.first != range.second; return !handlers.empty();
} }
} }

View File

@ -17,8 +17,8 @@ public:
void setDirectory(const String & directory_path_); void setDirectory(const String & directory_path_);
private: private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override; std::optional<UUID> findImpl(EntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override; std::vector<UUID> findAllImpl(EntityType type) const override;
bool existsImpl(const UUID & id) const override; bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override;
@ -27,14 +27,15 @@ private:
void removeImpl(const UUID & id) override; void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override; bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override; bool hasSubscriptionImpl(EntityType type) const override;
void initialize(const String & directory_path_, Notifications & notifications); void initialize(const String & directory_path_, Notifications & notifications);
void clear();
bool readLists(); bool readLists();
bool writeLists(); bool writeLists();
void scheduleWriteLists(std::type_index type); void scheduleWriteLists(EntityType type);
bool rebuildLists(); bool rebuildLists();
void startListsWritingThread(); void startListsWritingThread();
@ -52,9 +53,9 @@ private:
using NameToIDMap = std::unordered_map<String, UUID>; using NameToIDMap = std::unordered_map<String, UUID>;
struct Entry struct Entry
{ {
Entry(const std::string_view & name_, std::type_index type_) : name(name_), type(type_) {} UUID id;
std::string_view name; /// view points to a string in `name_to_id_maps`. String name;
std::type_index type; EntityType type;
mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet. mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet.
mutable std::list<OnChangedHandler> handlers_by_id; mutable std::list<OnChangedHandler> handlers_by_id;
}; };
@ -63,14 +64,14 @@ private:
String directory_path; String directory_path;
bool initialized = false; bool initialized = false;
std::unordered_map<std::type_index, NameToIDMap> name_to_id_maps; std::unordered_map<UUID, Entry> entries_by_id;
std::unordered_map<UUID, Entry> id_to_entry_map; std::unordered_map<std::string_view, Entry *> entries_by_name_and_type[static_cast<size_t>(EntityType::MAX)];
boost::container::flat_set<std::type_index> types_of_lists_to_write; boost::container::flat_set<EntityType> types_of_lists_to_write;
bool failed_to_write_lists = false; /// Whether writing of the list files has been failed since the recent restart of the server. bool failed_to_write_lists = false; /// Whether writing of the list files has been failed since the recent restart of the server.
ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread. ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread.
std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit. std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit.
std::atomic<bool> lists_writing_thread_exited = false; std::atomic<bool> lists_writing_thread_exited = false;
mutable std::unordered_multimap<std::type_index, OnChangedHandler> handlers_by_type; mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(EntityType::MAX)];
mutable std::mutex mutex; mutable std::mutex mutex;
}; };
} }

View File

@ -1,48 +1,12 @@
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/Quota.h>
#include <Access/RowPolicy.h>
#include <Access/User.h>
#include <Access/Role.h>
#include <Access/SettingsProfile.h>
#include <common/demangle.h>
namespace DB namespace DB
{ {
String IAccessEntity::getTypeName(std::type_index type)
{
if (type == typeid(User))
return "User";
if (type == typeid(Quota))
return "Quota";
if (type == typeid(RowPolicy))
return "Row policy";
if (type == typeid(Role))
return "Role";
if (type == typeid(SettingsProfile))
return "Settings profile";
return demangle(type.name());
}
const char * IAccessEntity::getKeyword(std::type_index type)
{
if (type == typeid(User))
return "USER";
if (type == typeid(Quota))
return "QUOTA";
if (type == typeid(RowPolicy))
return "ROW POLICY";
if (type == typeid(Role))
return "ROLE";
if (type == typeid(SettingsProfile))
return "SETTINGS PROFILE";
__builtin_unreachable();
}
bool IAccessEntity::equal(const IAccessEntity & other) const bool IAccessEntity::equal(const IAccessEntity & other) const
{ {
return (name == other.name) && (getType() == other.getType()); return (name == other.name) && (getType() == other.getType());
} }
} }

View File

@ -2,12 +2,24 @@
#include <Core/Types.h> #include <Core/Types.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/quoteString.h>
#include <boost/algorithm/string.hpp>
#include <memory> #include <memory>
#include <typeindex>
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int UNKNOWN_USER;
extern const int UNKNOWN_ROLE;
extern const int UNKNOWN_ROW_POLICY;
extern const int UNKNOWN_QUOTA;
extern const int THERE_IS_NO_PROFILE;
extern const int LOGICAL_ERROR;
}
/// Access entity is a set of data which have a name and a type. Access entity control something related to the access control. /// Access entity is a set of data which have a name and a type. Access entity control something related to the access control.
/// Entities can be stored to a file or another storage, see IAccessStorage. /// Entities can be stored to a file or another storage, see IAccessStorage.
struct IAccessEntity struct IAccessEntity
@ -17,15 +29,39 @@ struct IAccessEntity
virtual ~IAccessEntity() = default; virtual ~IAccessEntity() = default;
virtual std::shared_ptr<IAccessEntity> clone() const = 0; virtual std::shared_ptr<IAccessEntity> clone() const = 0;
std::type_index getType() const { return typeid(*this); } enum class Type
static String getTypeName(std::type_index type); {
const String getTypeName() const { return getTypeName(getType()); } USER,
static const char * getKeyword(std::type_index type); ROLE,
const char * getKeyword() const { return getKeyword(getType()); } SETTINGS_PROFILE,
ROW_POLICY,
QUOTA,
template <typename EntityType> MAX,
bool isTypeOf() const { return isTypeOf(typeid(EntityType)); } };
bool isTypeOf(std::type_index type) const { return type == getType(); }
virtual Type getType() const = 0;
struct TypeInfo
{
const char * const raw_name;
const String name; /// Uppercased with spaces instead of underscores, e.g. "SETTINGS PROFILE".
const String alias; /// Alias of the keyword or empty string, e.g. "PROFILE".
const String name_for_output_with_entity_name; /// Lowercased with spaces instead of underscores, e.g. "settings profile".
const char unique_char; /// Unique character for this type. E.g. 'P' for SETTINGS_PROFILE.
const String list_filename; /// Name of the file containing list of objects of this type, including the file extension ".list".
const int not_found_error_code;
static const TypeInfo & get(Type type_);
String outputWithEntityName(const String & entity_name) const;
};
const TypeInfo & getTypeInfo() const { return TypeInfo::get(getType()); }
String outputTypeAndName() const { return getTypeInfo().outputWithEntityName(getName()); }
template <typename EntityClassT>
bool isTypeOf() const { return isTypeOf(EntityClassT::TYPE); }
bool isTypeOf(Type type) const { return type == getType(); }
virtual void setName(const String & name_) { name = name_; } virtual void setName(const String & name_) { name = name_; }
const String & getName() const { return name; } const String & getName() const { return name; }
@ -39,12 +75,74 @@ protected:
virtual bool equal(const IAccessEntity & other) const; virtual bool equal(const IAccessEntity & other) const;
/// Helper function to define clone() in the derived classes. /// Helper function to define clone() in the derived classes.
template <typename EntityType> template <typename EntityClassT>
std::shared_ptr<IAccessEntity> cloneImpl() const std::shared_ptr<IAccessEntity> cloneImpl() const
{ {
return std::make_shared<EntityType>(typeid_cast<const EntityType &>(*this)); return std::make_shared<EntityClassT>(typeid_cast<const EntityClassT &>(*this));
} }
}; };
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>; using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
inline const IAccessEntity::TypeInfo & IAccessEntity::TypeInfo::get(Type type_)
{
static constexpr auto make_info = [](const char * raw_name_, char unique_char_, const char * list_filename_, int not_found_error_code_)
{
String init_name = raw_name_;
boost::to_upper(init_name);
boost::replace_all(init_name, "_", " ");
String init_alias;
if (auto underscore_pos = init_name.find_first_of(" "); underscore_pos != String::npos)
init_alias = init_name.substr(underscore_pos + 1);
String init_name_for_output_with_entity_name = init_name;
boost::to_lower(init_name_for_output_with_entity_name);
return TypeInfo{raw_name_, std::move(init_name), std::move(init_alias), std::move(init_name_for_output_with_entity_name), unique_char_, list_filename_, not_found_error_code_};
};
switch (type_)
{
case Type::USER:
{
static const auto info = make_info("USER", 'U', "users.list", ErrorCodes::UNKNOWN_USER);
return info;
}
case Type::ROLE:
{
static const auto info = make_info("ROLE", 'R', "roles.list", ErrorCodes::UNKNOWN_ROLE);
return info;
}
case Type::SETTINGS_PROFILE:
{
static const auto info = make_info("SETTINGS_PROFILE", 'S', "settings_profiles.list", ErrorCodes::THERE_IS_NO_PROFILE);
return info;
}
case Type::ROW_POLICY:
{
static const auto info = make_info("ROW_POLICY", 'P', "row_policies.list", ErrorCodes::UNKNOWN_ROW_POLICY);
return info;
}
case Type::QUOTA:
{
static const auto info = make_info("QUOTA", 'Q', "quotas.list", ErrorCodes::UNKNOWN_QUOTA);
return info;
}
case Type::MAX: break;
}
throw Exception("Unknown type: " + std::to_string(static_cast<size_t>(type_)), ErrorCodes::LOGICAL_ERROR);
}
inline String IAccessEntity::TypeInfo::outputWithEntityName(const String & entity_name) const
{
String msg = name_for_output_with_entity_name;
msg += " ";
msg += backQuote(entity_name);
return msg;
}
inline String toString(IAccessEntity::Type type)
{
return IAccessEntity::TypeInfo::get(type).name;
}
} }

View File

@ -13,27 +13,44 @@ namespace DB
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int BAD_CAST; extern const int BAD_CAST;
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_ENTITY_ALREADY_EXISTS; extern const int ACCESS_ENTITY_ALREADY_EXISTS;
extern const int ACCESS_ENTITY_NOT_FOUND;
extern const int ACCESS_STORAGE_READONLY; extern const int ACCESS_STORAGE_READONLY;
extern const int UNKNOWN_USER;
extern const int UNKNOWN_ROLE;
} }
std::vector<UUID> IAccessStorage::findAll(std::type_index type) const namespace
{
using EntityType = IAccessStorage::EntityType;
using EntityTypeInfo = IAccessStorage::EntityTypeInfo;
bool isNotFoundErrorCode(int error_code)
{
if (error_code == ErrorCodes::ACCESS_ENTITY_NOT_FOUND)
return true;
for (auto type : ext::range(EntityType::MAX))
if (error_code == EntityTypeInfo::get(type).not_found_error_code)
return true;
return false;
}
}
std::vector<UUID> IAccessStorage::findAll(EntityType type) const
{ {
return findAllImpl(type); return findAllImpl(type);
} }
std::optional<UUID> IAccessStorage::find(std::type_index type, const String & name) const std::optional<UUID> IAccessStorage::find(EntityType type, const String & name) const
{ {
return findImpl(type, name); return findImpl(type, name);
} }
std::vector<UUID> IAccessStorage::find(std::type_index type, const Strings & names) const std::vector<UUID> IAccessStorage::find(EntityType type, const Strings & names) const
{ {
std::vector<UUID> ids; std::vector<UUID> ids;
ids.reserve(names.size()); ids.reserve(names.size());
@ -47,7 +64,7 @@ std::vector<UUID> IAccessStorage::find(std::type_index type, const Strings & nam
} }
UUID IAccessStorage::getID(std::type_index type, const String & name) const UUID IAccessStorage::getID(EntityType type, const String & name) const
{ {
auto id = findImpl(type, name); auto id = findImpl(type, name);
if (id) if (id)
@ -56,7 +73,7 @@ UUID IAccessStorage::getID(std::type_index type, const String & name) const
} }
std::vector<UUID> IAccessStorage::getIDs(std::type_index type, const Strings & names) const std::vector<UUID> IAccessStorage::getIDs(EntityType type, const Strings & names) const
{ {
std::vector<UUID> ids; std::vector<UUID> ids;
ids.reserve(names.size()); ids.reserve(names.size());
@ -190,6 +207,7 @@ void IAccessStorage::remove(const UUID & id)
void IAccessStorage::remove(const std::vector<UUID> & ids) void IAccessStorage::remove(const std::vector<UUID> & ids)
{ {
String error_message; String error_message;
std::optional<int> error_code;
for (const auto & id : ids) for (const auto & id : ids)
{ {
try try
@ -198,13 +216,17 @@ void IAccessStorage::remove(const std::vector<UUID> & ids)
} }
catch (Exception & e) catch (Exception & e)
{ {
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND) if (!isNotFoundErrorCode(e.code()))
throw; throw;
error_message += (error_message.empty() ? "" : ". ") + e.message(); error_message += (error_message.empty() ? "" : ". ") + e.message();
if (error_code && (*error_code != e.code()))
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
else
error_code = e.code();
} }
} }
if (!error_message.empty()) if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND); throw Exception(error_message, *error_code);
} }
@ -250,6 +272,7 @@ void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func)
void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func) void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{ {
String error_message; String error_message;
std::optional<int> error_code;
for (const auto & id : ids) for (const auto & id : ids)
{ {
try try
@ -258,13 +281,17 @@ void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & up
} }
catch (Exception & e) catch (Exception & e)
{ {
if (e.code() != ErrorCodes::ACCESS_ENTITY_NOT_FOUND) if (!isNotFoundErrorCode(e.code()))
throw; throw;
error_message += (error_message.empty() ? "" : ". ") + e.message(); error_message += (error_message.empty() ? "" : ". ") + e.message();
if (error_code && (*error_code != e.code()))
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
else
error_code = e.code();
} }
} }
if (!error_message.empty()) if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_NOT_FOUND); throw Exception(error_message, *error_code);
} }
@ -301,7 +328,7 @@ std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const
} }
ext::scope_guard IAccessStorage::subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const ext::scope_guard IAccessStorage::subscribeForChanges(EntityType type, const OnChangedHandler & handler) const
{ {
return subscribeForChangesImpl(type, handler); return subscribeForChangesImpl(type, handler);
} }
@ -322,7 +349,7 @@ ext::scope_guard IAccessStorage::subscribeForChanges(const std::vector<UUID> & i
} }
bool IAccessStorage::hasSubscription(std::type_index type) const bool IAccessStorage::hasSubscription(EntityType type) const
{ {
return hasSubscriptionImpl(type); return hasSubscriptionImpl(type);
} }
@ -361,79 +388,72 @@ Poco::Logger * IAccessStorage::getLogger() const
void IAccessStorage::throwNotFound(const UUID & id) const void IAccessStorage::throwNotFound(const UUID & id) const
{ {
throw Exception("ID {" + toString(id) + "} not found in " + getStorageName(), ErrorCodes::ACCESS_ENTITY_NOT_FOUND); throw Exception("ID {" + toString(id) + "} not found in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
} }
void IAccessStorage::throwNotFound(std::type_index type, const String & name) const void IAccessStorage::throwNotFound(EntityType type, const String & name) const
{ {
int error_code; int error_code = EntityTypeInfo::get(type).not_found_error_code;
if (type == typeid(User)) throw Exception("There is no " + outputEntityTypeAndName(type, name) + " in [" + getStorageName() + "]", error_code);
error_code = ErrorCodes::UNKNOWN_USER;
else if (type == typeid(Role))
error_code = ErrorCodes::UNKNOWN_ROLE;
else
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
throw Exception(getTypeName(type) + " " + backQuote(name) + " not found in " + getStorageName(), error_code);
} }
void IAccessStorage::throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type) void IAccessStorage::throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type)
{ {
throw Exception( throw Exception(
"ID {" + toString(id) + "}: " + getTypeName(type) + backQuote(name) + " expected to be of type " + getTypeName(required_type), "ID {" + toString(id) + "}: " + outputEntityTypeAndName(type, name) + " expected to be of type " + toString(required_type),
ErrorCodes::BAD_CAST); ErrorCodes::BAD_CAST);
} }
void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const
{ {
throw Exception( throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because the ID {" + toString(id) + "} is already used by " outputEntityTypeAndName(type, name) + ": cannot insert because the ID {" + toString(id) + "} is already used by "
+ getTypeName(existing_type) + " " + backQuote(existing_name) + " in " + getStorageName(), + outputEntityTypeAndName(existing_type, existing_name) + " in [" + getStorageName() + "]",
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
} }
void IAccessStorage::throwNameCollisionCannotInsert(std::type_index type, const String & name) const void IAccessStorage::throwNameCollisionCannotInsert(EntityType type, const String & name) const
{ {
throw Exception( throw Exception(
getTypeName(type) + " " + backQuote(name) + ": cannot insert because " + getTypeName(type) + " " + backQuote(name) outputEntityTypeAndName(type, name) + ": cannot insert because " + outputEntityTypeAndName(type, name) + " already exists in ["
+ " already exists in " + getStorageName(), + getStorageName() + "]",
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
} }
void IAccessStorage::throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const void IAccessStorage::throwNameCollisionCannotRename(EntityType type, const String & old_name, const String & new_name) const
{ {
throw Exception( throw Exception(
getTypeName(type) + " " + backQuote(old_name) + ": cannot rename to " + backQuote(new_name) + " because " + getTypeName(type) + " " outputEntityTypeAndName(type, old_name) + ": cannot rename to " + backQuote(new_name) + " because "
+ backQuote(new_name) + " already exists in " + getStorageName(), + outputEntityTypeAndName(type, new_name) + " already exists in [" + getStorageName() + "]",
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
} }
void IAccessStorage::throwReadonlyCannotInsert(std::type_index type, const String & name) const void IAccessStorage::throwReadonlyCannotInsert(EntityType type, const String & name) const
{ {
throw Exception( throw Exception(
"Cannot insert " + getTypeName(type) + " " + backQuote(name) + " to " + getStorageName() + " because this storage is readonly", "Cannot insert " + outputEntityTypeAndName(type, name) + " to [" + getStorageName() + "] because this storage is readonly",
ErrorCodes::ACCESS_STORAGE_READONLY); ErrorCodes::ACCESS_STORAGE_READONLY);
} }
void IAccessStorage::throwReadonlyCannotUpdate(std::type_index type, const String & name) const void IAccessStorage::throwReadonlyCannotUpdate(EntityType type, const String & name) const
{ {
throw Exception( throw Exception(
"Cannot update " + getTypeName(type) + " " + backQuote(name) + " in " + getStorageName() + " because this storage is readonly", "Cannot update " + outputEntityTypeAndName(type, name) + " in [" + getStorageName() + "] because this storage is readonly",
ErrorCodes::ACCESS_STORAGE_READONLY); ErrorCodes::ACCESS_STORAGE_READONLY);
} }
void IAccessStorage::throwReadonlyCannotRemove(std::type_index type, const String & name) const void IAccessStorage::throwReadonlyCannotRemove(EntityType type, const String & name) const
{ {
throw Exception( throw Exception(
"Cannot remove " + getTypeName(type) + " " + backQuote(name) + " from " + getStorageName() + " because this storage is readonly", "Cannot remove " + outputEntityTypeAndName(type, name) + " from [" + getStorageName() + "] because this storage is readonly",
ErrorCodes::ACCESS_STORAGE_READONLY); ErrorCodes::ACCESS_STORAGE_READONLY);
} }
} }

View File

@ -25,50 +25,53 @@ public:
/// Returns the name of this storage. /// Returns the name of this storage.
const String & getStorageName() const { return storage_name; } const String & getStorageName() const { return storage_name; }
/// Returns the identifiers of all the entities of a specified type contained in the storage. using EntityType = IAccessEntity::Type;
std::vector<UUID> findAll(std::type_index type) const; using EntityTypeInfo = IAccessEntity::TypeInfo;
template <typename EntityType> /// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll() const { return findAll(typeid(EntityType)); } std::vector<UUID> findAll(EntityType type) const;
template <typename EntityClassT>
std::vector<UUID> findAll() const { return findAll(EntityClassT::TYPE); }
/// Searchs for an entity with specified type and name. Returns std::nullopt if not found. /// Searchs for an entity with specified type and name. Returns std::nullopt if not found.
std::optional<UUID> find(std::type_index type, const String & name) const; std::optional<UUID> find(EntityType type, const String & name) const;
template <typename EntityType> template <typename EntityClassT>
std::optional<UUID> find(const String & name) const { return find(typeid(EntityType), name); } std::optional<UUID> find(const String & name) const { return find(EntityClassT::TYPE, name); }
std::vector<UUID> find(std::type_index type, const Strings & names) const; std::vector<UUID> find(EntityType type, const Strings & names) const;
template <typename EntityType> template <typename EntityClassT>
std::vector<UUID> find(const Strings & names) const { return find(typeid(EntityType), names); } std::vector<UUID> find(const Strings & names) const { return find(EntityClassT::TYPE, names); }
/// Searchs for an entity with specified name and type. Throws an exception if not found. /// Searchs for an entity with specified name and type. Throws an exception if not found.
UUID getID(std::type_index type, const String & name) const; UUID getID(EntityType type, const String & name) const;
template <typename EntityType> template <typename EntityClassT>
UUID getID(const String & name) const { return getID(typeid(EntityType), name); } UUID getID(const String & name) const { return getID(EntityClassT::TYPE, name); }
std::vector<UUID> getIDs(std::type_index type, const Strings & names) const; std::vector<UUID> getIDs(EntityType type, const Strings & names) const;
template <typename EntityType> template <typename EntityClassT>
std::vector<UUID> getIDs(const Strings & names) const { return getIDs(typeid(EntityType), names); } std::vector<UUID> getIDs(const Strings & names) const { return getIDs(EntityClassT::TYPE, names); }
/// Returns whether there is an entity with such identifier in the storage. /// Returns whether there is an entity with such identifier in the storage.
bool exists(const UUID & id) const; bool exists(const UUID & id) const;
/// Reads an entity. Throws an exception if not found. /// Reads an entity. Throws an exception if not found.
template <typename EntityType = IAccessEntity> template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityType> read(const UUID & id) const; std::shared_ptr<const EntityClassT> read(const UUID & id) const;
template <typename EntityType = IAccessEntity> template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityType> read(const String & name) const; std::shared_ptr<const EntityClassT> read(const String & name) const;
/// Reads an entity. Returns nullptr if not found. /// Reads an entity. Returns nullptr if not found.
template <typename EntityType = IAccessEntity> template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const UUID & id) const; std::shared_ptr<const EntityClassT> tryRead(const UUID & id) const;
template <typename EntityType = IAccessEntity> template <typename EntityClassT = IAccessEntity>
std::shared_ptr<const EntityType> tryRead(const String & name) const; std::shared_ptr<const EntityClassT> tryRead(const String & name) const;
/// Reads only name of an entity. /// Reads only name of an entity.
String readName(const UUID & id) const; String readName(const UUID & id) const;
@ -118,22 +121,22 @@ public:
/// Subscribes for all changes. /// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only). /// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
ext::scope_guard subscribeForChanges(std::type_index type, const OnChangedHandler & handler) const; ext::scope_guard subscribeForChanges(EntityType type, const OnChangedHandler & handler) const;
template <typename EntityType> template <typename EntityClassT>
ext::scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(typeid(EntityType), handler); } ext::scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry. /// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only). /// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
ext::scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const; ext::scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
ext::scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const; ext::scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
bool hasSubscription(std::type_index type) const; bool hasSubscription(EntityType type) const;
bool hasSubscription(const UUID & id) const; bool hasSubscription(const UUID & id) const;
protected: protected:
virtual std::optional<UUID> findImpl(std::type_index type, const String & name) const = 0; virtual std::optional<UUID> findImpl(EntityType type, const String & name) const = 0;
virtual std::vector<UUID> findAllImpl(std::type_index type) const = 0; virtual std::vector<UUID> findAllImpl(EntityType type) const = 0;
virtual bool existsImpl(const UUID & id) const = 0; virtual bool existsImpl(const UUID & id) const = 0;
virtual AccessEntityPtr readImpl(const UUID & id) const = 0; virtual AccessEntityPtr readImpl(const UUID & id) const = 0;
virtual String readNameImpl(const UUID & id) const = 0; virtual String readNameImpl(const UUID & id) const = 0;
@ -142,23 +145,23 @@ protected:
virtual void removeImpl(const UUID & id) = 0; virtual void removeImpl(const UUID & id) = 0;
virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0; virtual void updateImpl(const UUID & id, const UpdateFunc & update_func) = 0;
virtual ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0; virtual ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const = 0; virtual ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const = 0;
virtual bool hasSubscriptionImpl(const UUID & id) const = 0; virtual bool hasSubscriptionImpl(const UUID & id) const = 0;
virtual bool hasSubscriptionImpl(std::type_index type) const = 0; virtual bool hasSubscriptionImpl(EntityType type) const = 0;
static UUID generateRandomID(); static UUID generateRandomID();
Poco::Logger * getLogger() const; Poco::Logger * getLogger() const;
static String getTypeName(std::type_index type) { return IAccessEntity::getTypeName(type); } static String outputEntityTypeAndName(EntityType type, const String & name) { return EntityTypeInfo::get(type).outputWithEntityName(name); }
[[noreturn]] void throwNotFound(const UUID & id) const; [[noreturn]] void throwNotFound(const UUID & id) const;
[[noreturn]] void throwNotFound(std::type_index type, const String & name) const; [[noreturn]] void throwNotFound(EntityType type, const String & name) const;
[[noreturn]] static void throwBadCast(const UUID & id, std::type_index type, const String & name, std::type_index required_type); [[noreturn]] static void throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type);
[[noreturn]] void throwIDCollisionCannotInsert( [[noreturn]] void throwIDCollisionCannotInsert(
const UUID & id, std::type_index type, const String & name, std::type_index existing_type, const String & existing_name) const; const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const;
[[noreturn]] void throwNameCollisionCannotInsert(std::type_index type, const String & name) const; [[noreturn]] void throwNameCollisionCannotInsert(EntityType type, const String & name) const;
[[noreturn]] void throwNameCollisionCannotRename(std::type_index type, const String & old_name, const String & new_name) const; [[noreturn]] void throwNameCollisionCannotRename(EntityType type, const String & old_name, const String & new_name) const;
[[noreturn]] void throwReadonlyCannotInsert(std::type_index type, const String & name) const; [[noreturn]] void throwReadonlyCannotInsert(EntityType type, const String & name) const;
[[noreturn]] void throwReadonlyCannotUpdate(std::type_index type, const String & name) const; [[noreturn]] void throwReadonlyCannotUpdate(EntityType type, const String & name) const;
[[noreturn]] void throwReadonlyCannotRemove(std::type_index type, const String & name) const; [[noreturn]] void throwReadonlyCannotRemove(EntityType type, const String & name) const;
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>; using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>; using Notifications = std::vector<Notification>;
@ -172,38 +175,43 @@ private:
}; };
template <typename EntityType> template <typename EntityClassT>
std::shared_ptr<const EntityType> IAccessStorage::read(const UUID & id) const std::shared_ptr<const EntityClassT> IAccessStorage::read(const UUID & id) const
{ {
auto entity = readImpl(id); auto entity = readImpl(id);
auto ptr = typeid_cast<std::shared_ptr<const EntityType>>(entity); if constexpr (std::is_same_v<EntityClassT, IAccessEntity>)
if (ptr) return entity;
return ptr; else
throwBadCast(id, entity->getType(), entity->getName(), typeid(EntityType)); {
auto ptr = typeid_cast<std::shared_ptr<const EntityClassT>>(entity);
if (ptr)
return ptr;
throwBadCast(id, entity->getType(), entity->getName(), EntityClassT::TYPE);
}
} }
template <typename EntityType> template <typename EntityClassT>
std::shared_ptr<const EntityType> IAccessStorage::read(const String & name) const std::shared_ptr<const EntityClassT> IAccessStorage::read(const String & name) const
{ {
return read<EntityType>(getID<EntityType>(name)); return read<EntityClassT>(getID<EntityClassT>(name));
} }
template <typename EntityType> template <typename EntityClassT>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const UUID & id) const std::shared_ptr<const EntityClassT> IAccessStorage::tryRead(const UUID & id) const
{ {
auto entity = tryReadBase(id); auto entity = tryReadBase(id);
if (!entity) if (!entity)
return nullptr; return nullptr;
return typeid_cast<std::shared_ptr<const EntityType>>(entity); return typeid_cast<std::shared_ptr<const EntityClassT>>(entity);
} }
template <typename EntityType> template <typename EntityClassT>
std::shared_ptr<const EntityType> IAccessStorage::tryRead(const String & name) const std::shared_ptr<const EntityClassT> IAccessStorage::tryRead(const String & name) const
{ {
auto id = find<EntityType>(name); auto id = find<EntityClassT>(name);
return id ? tryRead<EntityType>(*id) : nullptr; return id ? tryRead<EntityClassT>(*id) : nullptr;
} }
} }

View File

@ -1,6 +1,8 @@
#include <Access/MemoryAccessStorage.h> #include <Access/MemoryAccessStorage.h>
#include <ext/scope_guard.h> #include <ext/scope_guard.h>
#include <unordered_set> #include <boost/container/flat_set.hpp>
#include <boost/range/adaptor/map.hpp>
#include <boost/range/algorithm/copy.hpp>
namespace DB namespace DB
@ -11,11 +13,12 @@ MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
} }
std::optional<UUID> MemoryAccessStorage::findImpl(std::type_index type, const String & name) const std::optional<UUID> MemoryAccessStorage::findImpl(EntityType type, const String & name) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = names.find({name, type}); const auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
if (it == names.end()) auto it = entries_by_name.find(name);
if (it == entries_by_name.end())
return {}; return {};
Entry & entry = *(it->second); Entry & entry = *(it->second);
@ -23,12 +26,12 @@ std::optional<UUID> MemoryAccessStorage::findImpl(std::type_index type, const St
} }
std::vector<UUID> MemoryAccessStorage::findAllImpl(std::type_index type) const std::vector<UUID> MemoryAccessStorage::findAllImpl(EntityType type) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
std::vector<UUID> result; std::vector<UUID> result;
result.reserve(entries.size()); result.reserve(entries_by_id.size());
for (const auto & [id, entry] : entries) for (const auto & [id, entry] : entries_by_id)
if (entry.entity->isTypeOf(type)) if (entry.entity->isTypeOf(type))
result.emplace_back(id); result.emplace_back(id);
return result; return result;
@ -38,15 +41,15 @@ std::vector<UUID> MemoryAccessStorage::findAllImpl(std::type_index type) const
bool MemoryAccessStorage::existsImpl(const UUID & id) const bool MemoryAccessStorage::existsImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return entries.count(id); return entries_by_id.count(id);
} }
AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it == entries.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
const Entry & entry = it->second; const Entry & entry = it->second;
return entry.entity; return entry.entity;
@ -74,18 +77,19 @@ UUID MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool re
void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications) void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, Notifications & notifications)
{ {
const String & name = new_entity->getName(); const String & name = new_entity->getName();
std::type_index type = new_entity->getType(); EntityType type = new_entity->getType();
/// Check that we can insert. /// Check that we can insert.
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it != entries.end()) if (it != entries_by_id.end())
{ {
const auto & existing_entry = it->second; const auto & existing_entry = it->second;
throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName()); throwIDCollisionCannotInsert(id, type, name, existing_entry.entity->getType(), existing_entry.entity->getName());
} }
auto it2 = names.find({name, type}); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
if (it2 != names.end()) auto it2 = entries_by_name.find(name);
if (it2 != entries_by_name.end())
{ {
const auto & existing_entry = *(it2->second); const auto & existing_entry = *(it2->second);
if (replace_if_exists) if (replace_if_exists)
@ -95,10 +99,10 @@ void MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
} }
/// Do insertion. /// Do insertion.
auto & entry = entries[id]; auto & entry = entries_by_id[id];
entry.id = id; entry.id = id;
entry.entity = new_entity; entry.entity = new_entity;
names[std::pair{name, type}] = &entry; entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications); prepareNotifications(entry, false, notifications);
} }
@ -115,19 +119,20 @@ void MemoryAccessStorage::removeImpl(const UUID & id)
void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications) void MemoryAccessStorage::removeNoLock(const UUID & id, Notifications & notifications)
{ {
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it == entries.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
Entry & entry = it->second; Entry & entry = it->second;
const String & name = entry.entity->getName(); const String & name = entry.entity->getName();
std::type_index type = entry.entity->getType(); EntityType type = entry.entity->getType();
prepareNotifications(entry, true, notifications); prepareNotifications(entry, true, notifications);
/// Do removing. /// Do removing.
names.erase({name, type}); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries.erase(it); entries_by_name.erase(name);
entries_by_id.erase(it);
} }
@ -143,14 +148,17 @@ void MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_
void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications) void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, Notifications & notifications)
{ {
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it == entries.end()) if (it == entries_by_id.end())
throwNotFound(id); throwNotFound(id);
Entry & entry = it->second; Entry & entry = it->second;
auto old_entity = entry.entity; auto old_entity = entry.entity;
auto new_entity = update_func(old_entity); auto new_entity = update_func(old_entity);
if (!new_entity->isTypeOf(old_entity->getType()))
throwBadCast(id, new_entity->getType(), new_entity->getName(), old_entity->getType());
if (*new_entity == *old_entity) if (*new_entity == *old_entity)
return; return;
@ -158,12 +166,12 @@ void MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & updat
if (new_entity->getName() != old_entity->getName()) if (new_entity->getName() != old_entity->getName())
{ {
auto it2 = names.find({new_entity->getName(), new_entity->getType()}); auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(old_entity->getType())];
if (it2 != names.end()) auto it2 = entries_by_name.find(new_entity->getName());
if (it2 != entries_by_name.end())
throwNameCollisionCannotRename(old_entity->getType(), old_entity->getName(), new_entity->getName()); throwNameCollisionCannotRename(old_entity->getType(), old_entity->getName(), new_entity->getName());
names.erase({old_entity->getName(), old_entity->getType()}); entries_by_name[new_entity->getName()] = &entry;
names[std::pair{new_entity->getName(), new_entity->getType()}] = &entry;
} }
prepareNotifications(entry, false, notifications); prepareNotifications(entry, false, notifications);
@ -192,43 +200,47 @@ void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityP
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications) void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{ {
/// Get list of the currently used IDs. Later we will remove those of them which are not used anymore. boost::container::flat_set<UUID> not_used_ids;
std::unordered_set<UUID> not_used_ids; std::vector<UUID> conflicting_ids;
for (const auto & id_and_entry : entries)
not_used_ids.emplace(id_and_entry.first);
/// Remove conflicting entities. /// Get the list of currently used IDs. Later we will remove those of them which are not used anymore.
for (const auto & id : entries_by_id | boost::adaptors::map_keys)
not_used_ids.emplace(id);
/// Get the list of conflicting IDs and update the list of currently used ones.
for (const auto & [id, entity] : all_entities) for (const auto & [id, entity] : all_entities)
{ {
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it != entries.end()) if (it != entries_by_id.end())
{ {
not_used_ids.erase(id); /// ID is used. not_used_ids.erase(id); /// ID is used.
Entry & entry = it->second; Entry & entry = it->second;
if (entry.entity->getType() != entity->getType()) if (entry.entity->getType() != entity->getType())
{ conflicting_ids.emplace_back(id); /// Conflict: same ID, different type.
removeNoLock(id, notifications);
continue;
}
} }
auto it2 = names.find({entity->getName(), entity->getType()});
if (it2 != names.end()) const auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(entity->getType())];
auto it2 = entries_by_name.find(entity->getName());
if (it2 != entries_by_name.end())
{ {
Entry & entry = *(it2->second); Entry & entry = *(it2->second);
if (entry.id != id) if (entry.id != id)
removeNoLock(id, notifications); conflicting_ids.emplace_back(entry.id); /// Conflict: same name and type, different ID.
} }
} }
/// Remove entities which are not used anymore. /// Remove entities which are not used anymore and which are in conflict with new entities.
for (const auto & id : not_used_ids) boost::container::flat_set<UUID> ids_to_remove = std::move(not_used_ids);
boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end()));
for (const auto & id : ids_to_remove)
removeNoLock(id, notifications); removeNoLock(id, notifications);
/// Insert or update entities. /// Insert or update entities.
for (const auto & [id, entity] : all_entities) for (const auto & [id, entity] : all_entities)
{ {
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it != entries.end()) if (it != entries_by_id.end())
{ {
if (*(it->second.entity) != *entity) if (*(it->second.entity) != *entity)
{ {
@ -244,24 +256,27 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{ {
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id) for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, remove ? nullptr : entry.entity}); notifications.push_back({handler, entry.id, entity});
auto range = handlers_by_type.equal_range(entry.entity->getType()); for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
for (auto it = range.first; it != range.second; ++it) notifications.push_back({handler, entry.id, entity});
notifications.push_back({it->second, entry.id, remove ? nullptr : entry.entity});
} }
ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto handler_it = handlers_by_type.emplace(type, handler); auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, handler_it] return [this, type, handler_it]
{ {
std::lock_guard lock2{mutex}; std::lock_guard lock2{mutex};
handlers_by_type.erase(handler_it); auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
}; };
} }
@ -269,8 +284,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(std::type_index ty
ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it == entries.end()) if (it == entries_by_id.end())
return {}; return {};
const Entry & entry = it->second; const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler); auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
@ -278,8 +293,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, c
return [this, id, handler_it] return [this, id, handler_it]
{ {
std::lock_guard lock2{mutex}; std::lock_guard lock2{mutex};
auto it2 = entries.find(id); auto it2 = entries_by_id.find(id);
if (it2 != entries.end()) if (it2 != entries_by_id.end())
{ {
const Entry & entry2 = it2->second; const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it); entry2.handlers_by_id.erase(handler_it);
@ -291,8 +306,8 @@ ext::scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, c
bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto it = entries.find(id); auto it = entries_by_id.find(id);
if (it != entries.end()) if (it != entries_by_id.end())
{ {
const Entry & entry = it->second; const Entry & entry = it->second;
return !entry.handlers_by_id.empty(); return !entry.handlers_by_id.empty();
@ -301,10 +316,10 @@ bool MemoryAccessStorage::hasSubscriptionImpl(const UUID & id) const
} }
bool MemoryAccessStorage::hasSubscriptionImpl(std::type_index type) const bool MemoryAccessStorage::hasSubscriptionImpl(EntityType type) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
auto range = handlers_by_type.equal_range(type); const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return range.first != range.second; return !handlers.empty();
} }
} }

View File

@ -20,8 +20,8 @@ public:
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities); void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
private: private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override; std::optional<UUID> findImpl(EntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override; std::vector<UUID> findAllImpl(EntityType type) const override;
bool existsImpl(const UUID & id) const override; bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override;
@ -30,9 +30,9 @@ private:
void removeImpl(const UUID & id) override; void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override; bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override; bool hasSubscriptionImpl(EntityType type) const override;
struct Entry struct Entry
{ {
@ -47,18 +47,9 @@ private:
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications); void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const; void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
using NameTypePair = std::pair<String, std::type_index>;
struct Hash
{
size_t operator()(const NameTypePair & key) const
{
return std::hash<String>{}(key.first) - std::hash<std::type_index>{}(key.second);
}
};
mutable std::mutex mutex; mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries; /// We want to search entries both by ID and by the pair of name and type. std::unordered_map<UUID, Entry> entries_by_id; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<NameTypePair, Entry *, Hash> names; /// and by the pair of name and type. std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(EntityType::MAX)];
mutable std::unordered_multimap<std::type_index, OnChangedHandler> handlers_by_type; mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(EntityType::MAX)];
}; };
} }

View File

@ -38,7 +38,7 @@ MultipleAccessStorage::MultipleAccessStorage(
} }
std::vector<UUID> MultipleAccessStorage::findMultiple(std::type_index type, const String & name) const std::vector<UUID> MultipleAccessStorage::findMultiple(EntityType type, const String & name) const
{ {
std::vector<UUID> ids; std::vector<UUID> ids;
for (const auto & nested_storage : nested_storages) for (const auto & nested_storage : nested_storages)
@ -55,7 +55,7 @@ std::vector<UUID> MultipleAccessStorage::findMultiple(std::type_index type, cons
} }
std::optional<UUID> MultipleAccessStorage::findImpl(std::type_index type, const String & name) const std::optional<UUID> MultipleAccessStorage::findImpl(EntityType type, const String & name) const
{ {
auto ids = findMultiple(type, name); auto ids = findMultiple(type, name);
if (ids.empty()) if (ids.empty())
@ -72,13 +72,13 @@ std::optional<UUID> MultipleAccessStorage::findImpl(std::type_index type, const
} }
throw Exception( throw Exception(
"Found " + getTypeName(type) + " " + backQuote(name) + " in " + std::to_string(ids.size()) "Found " + outputEntityTypeAndName(type, name) + " in " + std::to_string(ids.size())
+ " storages: " + joinStorageNames(storages_with_duplicates), + " storages [" + joinStorageNames(storages_with_duplicates) + "]",
ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES); ErrorCodes::ACCESS_ENTITY_FOUND_DUPLICATES);
} }
std::vector<UUID> MultipleAccessStorage::findAllImpl(std::type_index type) const std::vector<UUID> MultipleAccessStorage::findAllImpl(EntityType type) const
{ {
std::vector<UUID> all_ids; std::vector<UUID> all_ids;
for (const auto & nested_storage : nested_storages) for (const auto & nested_storage : nested_storages)
@ -180,11 +180,7 @@ UUID MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool repl
} }
if (!nested_storage_for_insertion) if (!nested_storage_for_insertion)
{ throw Exception("Not found a storage to insert " + entity->outputTypeAndName(), ErrorCodes::ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND);
throw Exception(
"Not found a storage to insert " + entity->getTypeName() + backQuote(entity->getName()),
ErrorCodes::ACCESS_STORAGE_FOR_INSERTION_NOT_FOUND);
}
auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity); auto id = replace_if_exists ? nested_storage_for_insertion->insertOrReplace(entity) : nested_storage_for_insertion->insert(entity);
std::lock_guard lock{ids_cache_mutex}; std::lock_guard lock{ids_cache_mutex};
@ -214,7 +210,7 @@ ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(const UUID & id,
} }
ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const ext::scope_guard MultipleAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const
{ {
ext::scope_guard subscriptions; ext::scope_guard subscriptions;
for (const auto & nested_storage : nested_storages) for (const auto & nested_storage : nested_storages)
@ -234,7 +230,7 @@ bool MultipleAccessStorage::hasSubscriptionImpl(const UUID & id) const
} }
bool MultipleAccessStorage::hasSubscriptionImpl(std::type_index type) const bool MultipleAccessStorage::hasSubscriptionImpl(EntityType type) const
{ {
for (const auto & nested_storage : nested_storages) for (const auto & nested_storage : nested_storages)
{ {

View File

@ -15,7 +15,7 @@ public:
MultipleAccessStorage(std::vector<std::unique_ptr<Storage>> nested_storages_); MultipleAccessStorage(std::vector<std::unique_ptr<Storage>> nested_storages_);
std::vector<UUID> findMultiple(std::type_index type, const String & name) const; std::vector<UUID> findMultiple(EntityType type, const String & name) const;
template <typename EntityType> template <typename EntityType>
std::vector<UUID> findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); } std::vector<UUID> findMultiple(const String & name) const { return findMultiple(EntityType::TYPE, name); }
@ -29,8 +29,8 @@ public:
const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); } const Storage & getStorageByIndex(size_t i) const { return *(nested_storages[i]); }
protected: protected:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override; std::optional<UUID> findImpl(EntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override; std::vector<UUID> findAllImpl(EntityType type) const override;
bool existsImpl(const UUID & id) const override; bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID &id) const override; String readNameImpl(const UUID &id) const override;
@ -39,9 +39,9 @@ protected:
void removeImpl(const UUID & id) override; void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override; bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override; bool hasSubscriptionImpl(EntityType type) const override;
private: private:
std::vector<std::unique_ptr<Storage>> nested_storages; std::vector<std::unique_ptr<Storage>> nested_storages;

View File

@ -75,6 +75,8 @@ struct Quota : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }
static constexpr const Type TYPE = Type::QUOTA;
Type getType() const override { return TYPE; }
static const char * getNameOfResourceType(ResourceType resource_type); static const char * getNameOfResourceType(ResourceType resource_type);
static const char * resourceTypeToKeyword(ResourceType resource_type); static const char * resourceTypeToKeyword(ResourceType resource_type);

View File

@ -11,4 +11,5 @@ bool Role::equal(const IAccessEntity & other) const
const auto & other_role = typeid_cast<const Role &>(other); const auto & other_role = typeid_cast<const Role &>(other);
return (access == other_role.access) && (granted_roles == other_role.granted_roles) && (settings == other_role.settings); return (access == other_role.access) && (granted_roles == other_role.granted_roles) && (settings == other_role.settings);
} }
} }

View File

@ -17,6 +17,8 @@ struct Role : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Role>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Role>(); }
static constexpr const Type TYPE = Type::ROLE;
Type getType() const override { return TYPE; }
}; };
using RolePtr = std::shared_ptr<const Role>; using RolePtr = std::shared_ptr<const Role>;

View File

@ -69,6 +69,8 @@ struct RowPolicy : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<RowPolicy>(); }
static constexpr const Type TYPE = Type::ROW_POLICY;
Type getType() const override { return TYPE; }
/// Which roles or users should use this row policy. /// Which roles or users should use this row policy.
ExtendedRoleSet to_roles; ExtendedRoleSet to_roles;

View File

@ -3,6 +3,7 @@
namespace DB namespace DB
{ {
bool SettingsProfile::equal(const IAccessEntity & other) const bool SettingsProfile::equal(const IAccessEntity & other) const
{ {
if (!IAccessEntity::equal(other)) if (!IAccessEntity::equal(other))
@ -10,4 +11,5 @@ bool SettingsProfile::equal(const IAccessEntity & other) const
const auto & other_profile = typeid_cast<const SettingsProfile &>(other); const auto & other_profile = typeid_cast<const SettingsProfile &>(other);
return (elements == other_profile.elements) && (to_roles == other_profile.to_roles); return (elements == other_profile.elements) && (to_roles == other_profile.to_roles);
} }
} }

View File

@ -18,6 +18,8 @@ struct SettingsProfile : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<SettingsProfile>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<SettingsProfile>(); }
static constexpr const Type TYPE = Type::SETTINGS_PROFILE;
Type getType() const override { return TYPE; }
}; };
using SettingsProfilePtr = std::shared_ptr<const SettingsProfile>; using SettingsProfilePtr = std::shared_ptr<const SettingsProfile>;

View File

@ -13,4 +13,5 @@ bool User::equal(const IAccessEntity & other) const
&& (access == other_user.access) && (granted_roles == other_user.granted_roles) && (default_roles == other_user.default_roles) && (access == other_user.access) && (granted_roles == other_user.granted_roles) && (default_roles == other_user.default_roles)
&& (settings == other_user.settings); && (settings == other_user.settings);
} }
} }

View File

@ -24,6 +24,8 @@ struct User : public IAccessEntity
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<User>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<User>(); }
static constexpr const Type TYPE = Type::USER;
Type getType() const override { return TYPE; }
}; };
using UserPtr = std::shared_ptr<const User>; using UserPtr = std::shared_ptr<const User>;

View File

@ -27,35 +27,24 @@ namespace ErrorCodes
namespace namespace
{ {
char getTypeChar(std::type_index type) using EntityType = IAccessStorage::EntityType;
{ using EntityTypeInfo = IAccessStorage::EntityTypeInfo;
if (type == typeid(User))
return 'U';
if (type == typeid(Quota))
return 'Q';
if (type == typeid(RowPolicy))
return 'P';
if (type == typeid(SettingsProfile))
return 'S';
return 0;
}
UUID generateID(EntityType type, const String & name)
UUID generateID(std::type_index type, const String & name)
{ {
Poco::MD5Engine md5; Poco::MD5Engine md5;
md5.update(name); md5.update(name);
char type_storage_chars[] = " USRSXML"; char type_storage_chars[] = " USRSXML";
type_storage_chars[0] = getTypeChar(type); type_storage_chars[0] = EntityTypeInfo::get(type).unique_char;
md5.update(type_storage_chars, strlen(type_storage_chars)); md5.update(type_storage_chars, strlen(type_storage_chars));
UUID result; UUID result;
memcpy(&result, md5.digest().data(), md5.digestLength()); memcpy(&result, md5.digest().data(), md5.digestLength());
return result; return result;
} }
UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getName()); } UUID generateID(const IAccessEntity & entity) { return generateID(entity.getType(), entity.getName()); }
UserPtr parseUser(const Poco::Util::AbstractConfiguration & config, const String & user_name) UserPtr parseUser(const Poco::Util::AbstractConfiguration & config, const String & user_name)
{ {
auto user = std::make_shared<User>(); auto user = std::make_shared<User>();
@ -95,7 +84,7 @@ namespace
{ {
auto profile_name = config.getString(profile_name_config); auto profile_name = config.getString(profile_name_config);
SettingsProfileElement profile_element; SettingsProfileElement profile_element;
profile_element.parent_profile = generateID(typeid(SettingsProfile), profile_name); profile_element.parent_profile = generateID(EntityType::SETTINGS_PROFILE, profile_name);
user->settings.push_back(std::move(profile_element)); user->settings.push_back(std::move(profile_element));
} }
@ -260,7 +249,7 @@ namespace
for (const auto & user_name : user_names) for (const auto & user_name : user_names)
{ {
if (config.has("users." + user_name + ".quota")) if (config.has("users." + user_name + ".quota"))
quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(typeid(User), user_name)); quota_to_user_ids[config.getString("users." + user_name + ".quota")].push_back(generateID(EntityType::USER, user_name));
} }
Poco::Util::AbstractConfiguration::Keys quota_names; Poco::Util::AbstractConfiguration::Keys quota_names;
@ -346,7 +335,7 @@ namespace
auto policy = std::make_shared<RowPolicy>(); auto policy = std::make_shared<RowPolicy>();
policy->setNameParts(user_name, database, table_name); policy->setNameParts(user_name, database, table_name);
policy->conditions[RowPolicy::SELECT_FILTER] = filter; policy->conditions[RowPolicy::SELECT_FILTER] = filter;
policy->to_roles.add(generateID(typeid(User), user_name)); policy->to_roles.add(generateID(EntityType::USER, user_name));
policies.push_back(policy); policies.push_back(policy);
} }
} }
@ -400,7 +389,7 @@ namespace
{ {
String parent_profile_name = config.getString(profile_config + "." + key); String parent_profile_name = config.getString(profile_config + "." + key);
SettingsProfileElement profile_element; SettingsProfileElement profile_element;
profile_element.parent_profile = generateID(typeid(SettingsProfile), parent_profile_name); profile_element.parent_profile = generateID(EntityType::SETTINGS_PROFILE, parent_profile_name);
profile->elements.emplace_back(std::move(profile_element)); profile->elements.emplace_back(std::move(profile_element));
continue; continue;
} }
@ -462,13 +451,13 @@ void UsersConfigAccessStorage::setConfiguration(const Poco::Util::AbstractConfig
} }
std::optional<UUID> UsersConfigAccessStorage::findImpl(std::type_index type, const String & name) const std::optional<UUID> UsersConfigAccessStorage::findImpl(EntityType type, const String & name) const
{ {
return memory_storage.find(type, name); return memory_storage.find(type, name);
} }
std::vector<UUID> UsersConfigAccessStorage::findAllImpl(std::type_index type) const std::vector<UUID> UsersConfigAccessStorage::findAllImpl(EntityType type) const
{ {
return memory_storage.findAll(type); return memory_storage.findAll(type);
} }
@ -518,7 +507,7 @@ ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(const UUID &
} }
ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const ext::scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const
{ {
return memory_storage.subscribeForChanges(type, handler); return memory_storage.subscribeForChanges(type, handler);
} }
@ -530,7 +519,7 @@ bool UsersConfigAccessStorage::hasSubscriptionImpl(const UUID & id) const
} }
bool UsersConfigAccessStorage::hasSubscriptionImpl(std::type_index type) const bool UsersConfigAccessStorage::hasSubscriptionImpl(EntityType type) const
{ {
return memory_storage.hasSubscription(type); return memory_storage.hasSubscription(type);
} }

View File

@ -23,8 +23,8 @@ public:
void setConfiguration(const Poco::Util::AbstractConfiguration & config); void setConfiguration(const Poco::Util::AbstractConfiguration & config);
private: private:
std::optional<UUID> findImpl(std::type_index type, const String & name) const override; std::optional<UUID> findImpl(EntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(std::type_index type) const override; std::vector<UUID> findAllImpl(EntityType type) const override;
bool existsImpl(const UUID & id) const override; bool existsImpl(const UUID & id) const override;
AccessEntityPtr readImpl(const UUID & id) const override; AccessEntityPtr readImpl(const UUID & id) const override;
String readNameImpl(const UUID & id) const override; String readNameImpl(const UUID & id) const override;
@ -33,9 +33,9 @@ private:
void removeImpl(const UUID & id) override; void removeImpl(const UUID & id) override;
void updateImpl(const UUID & id, const UpdateFunc & update_func) override; void updateImpl(const UUID & id, const UpdateFunc & update_func) override;
ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
ext::scope_guard subscribeForChangesImpl(std::type_index type, const OnChangedHandler & handler) const override; ext::scope_guard subscribeForChangesImpl(EntityType type, const OnChangedHandler & handler) const override;
bool hasSubscriptionImpl(const UUID & id) const override; bool hasSubscriptionImpl(const UUID & id) const override;
bool hasSubscriptionImpl(std::type_index type) const override; bool hasSubscriptionImpl(EntityType type) const override;
MemoryAccessStorage memory_storage; MemoryAccessStorage memory_storage;
}; };

View File

@ -493,6 +493,7 @@ namespace ErrorCodes
extern const int NO_REMOTE_SHARD_AVAILABLE = 519; extern const int NO_REMOTE_SHARD_AVAILABLE = 519;
extern const int CANNOT_DETACH_DICTIONARY_AS_TABLE = 520; extern const int CANNOT_DETACH_DICTIONARY_AS_TABLE = 520;
extern const int ATOMIC_RENAME_FAIL = 521; extern const int ATOMIC_RENAME_FAIL = 521;
extern const int UNKNOWN_ROW_POLICY = 522;
extern const int KEEPER_EXCEPTION = 999; extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000; extern const int POCO_EXCEPTION = 1000;

View File

@ -9,54 +9,28 @@
#include <Access/Quota.h> #include <Access/Quota.h>
#include <Access/RowPolicy.h> #include <Access/RowPolicy.h>
#include <Access/SettingsProfile.h> #include <Access/SettingsProfile.h>
#include <boost/range/algorithm/transform.hpp>
namespace DB namespace DB
{ {
namespace namespace ErrorCodes
{ {
using Kind = ASTDropAccessEntityQuery::Kind; extern const int NOT_IMPLEMENTED;
std::type_index getType(Kind kind)
{
switch (kind)
{
case Kind::USER: return typeid(User);
case Kind::ROLE: return typeid(Role);
case Kind::QUOTA: return typeid(Quota);
case Kind::ROW_POLICY: return typeid(RowPolicy);
case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile);
}
__builtin_unreachable();
}
AccessType getRequiredAccessType(Kind kind)
{
switch (kind)
{
case Kind::USER: return AccessType::DROP_USER;
case Kind::ROLE: return AccessType::DROP_ROLE;
case Kind::QUOTA: return AccessType::DROP_QUOTA;
case Kind::ROW_POLICY: return AccessType::DROP_ROW_POLICY;
case Kind::SETTINGS_PROFILE: return AccessType::DROP_SETTINGS_PROFILE;
}
__builtin_unreachable();
}
} }
using EntityType = IAccessEntity::Type;
BlockIO InterpreterDropAccessEntityQuery::execute() BlockIO InterpreterDropAccessEntityQuery::execute()
{ {
auto & query = query_ptr->as<ASTDropAccessEntityQuery &>(); auto & query = query_ptr->as<ASTDropAccessEntityQuery &>();
auto & access_control = context.getAccessControlManager(); auto & access_control = context.getAccessControlManager();
context.checkAccess(getRequiredAccess());
std::type_index type = getType(query.kind);
context.checkAccess(getRequiredAccessType(query.kind));
if (!query.cluster.empty()) if (!query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, context); return executeDDLQueryOnCluster(query_ptr, context);
if (query.kind == Kind::ROW_POLICY) if (query.type == EntityType::ROW_POLICY)
{ {
Strings names; Strings names;
for (auto & name_parts : query.row_policies_name_parts) for (auto & name_parts : query.row_policies_name_parts)
@ -73,10 +47,28 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
} }
if (query.if_exists) if (query.if_exists)
access_control.tryRemove(access_control.find(type, query.names)); access_control.tryRemove(access_control.find(query.type, query.names));
else else
access_control.remove(access_control.getIDs(type, query.names)); access_control.remove(access_control.getIDs(query.type, query.names));
return {}; return {};
} }
AccessRightsElements InterpreterDropAccessEntityQuery::getRequiredAccess() const
{
const auto & query = query_ptr->as<const ASTDropAccessEntityQuery &>();
AccessRightsElements res;
switch (query.type)
{
case EntityType::USER: res.emplace_back(AccessType::DROP_USER); return res;
case EntityType::ROLE: res.emplace_back(AccessType::DROP_ROLE); return res;
case EntityType::SETTINGS_PROFILE: res.emplace_back(AccessType::DROP_SETTINGS_PROFILE); return res;
case EntityType::ROW_POLICY: res.emplace_back(AccessType::DROP_ROW_POLICY); return res;
case EntityType::QUOTA: res.emplace_back(AccessType::DROP_QUOTA); return res;
case EntityType::MAX: break;
}
throw Exception(
toString(query.type) + ": type is not supported by DROP query", ErrorCodes::NOT_IMPLEMENTED);
}
} }

View File

@ -6,6 +6,8 @@
namespace DB namespace DB
{ {
class AccessRightsElements;
class InterpreterDropAccessEntityQuery : public IInterpreter class InterpreterDropAccessEntityQuery : public IInterpreter
{ {
public: public:
@ -14,6 +16,8 @@ public:
BlockIO execute() override; BlockIO execute() override;
private: private:
AccessRightsElements getRequiredAccess() const;
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;
}; };

View File

@ -30,9 +30,10 @@ namespace DB
{ {
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int LOGICAL_ERROR; extern const int NOT_IMPLEMENTED;
} }
namespace namespace
{ {
ASTPtr getCreateQueryImpl( ASTPtr getCreateQueryImpl(
@ -203,23 +204,10 @@ namespace
return getCreateQueryImpl(*quota, manager, attach_mode); return getCreateQueryImpl(*quota, manager, attach_mode);
if (const SettingsProfile * profile = typeid_cast<const SettingsProfile *>(&entity)) if (const SettingsProfile * profile = typeid_cast<const SettingsProfile *>(&entity))
return getCreateQueryImpl(*profile, manager, attach_mode); return getCreateQueryImpl(*profile, manager, attach_mode);
throw Exception("Unexpected type of access entity: " + entity.getTypeName(), ErrorCodes::LOGICAL_ERROR); throw Exception(entity.outputTypeAndName() + ": type is not supported by SHOW CREATE query", ErrorCodes::NOT_IMPLEMENTED);
} }
using Kind = ASTShowCreateAccessEntityQuery::Kind; using EntityType = IAccessEntity::Type;
std::type_index getType(Kind kind)
{
switch (kind)
{
case Kind::USER: return typeid(User);
case Kind::ROLE: return typeid(Role);
case Kind::QUOTA: return typeid(Quota);
case Kind::ROW_POLICY: return typeid(RowPolicy);
case Kind::SETTINGS_PROFILE: return typeid(SettingsProfile);
}
__builtin_unreachable();
}
} }
@ -274,8 +262,7 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(ASTShowCreateAcces
return getCreateQueryImpl(*quota, &access_control, false); return getCreateQueryImpl(*quota, &access_control, false);
} }
auto type = getType(show_query.kind); if (show_query.type == Type::ROW_POLICY)
if (show_query.kind == Kind::ROW_POLICY)
{ {
if (show_query.row_policy_name_parts.database.empty()) if (show_query.row_policy_name_parts.database.empty())
show_query.row_policy_name_parts.database = context.getCurrentDatabase(); show_query.row_policy_name_parts.database = context.getCurrentDatabase();
@ -283,30 +270,30 @@ ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(ASTShowCreateAcces
return getCreateQueryImpl(*policy, &access_control, false); return getCreateQueryImpl(*policy, &access_control, false);
} }
auto entity = access_control.read(access_control.getID(type, show_query.name)); auto entity = access_control.read(access_control.getID(show_query.type, show_query.name));
return getCreateQueryImpl(*entity, &access_control, false); return getCreateQueryImpl(*entity, &access_control, false);
} }
AccessRightsElements InterpreterShowCreateAccessEntityQuery::getRequiredAccess() const
{
const auto & show_query = query_ptr->as<ASTShowCreateAccessEntityQuery &>();
AccessRightsElements res;
switch (show_query.kind)
{
case Kind::USER: res.emplace_back(AccessType::SHOW_USERS); break;
case Kind::ROLE: res.emplace_back(AccessType::SHOW_ROLES); break;
case Kind::ROW_POLICY: res.emplace_back(AccessType::SHOW_ROW_POLICIES); break;
case Kind::SETTINGS_PROFILE: res.emplace_back(AccessType::SHOW_SETTINGS_PROFILES); break;
case Kind::QUOTA: res.emplace_back(AccessType::SHOW_QUOTAS); break;
}
return res;
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getAttachQuery(const IAccessEntity & entity) ASTPtr InterpreterShowCreateAccessEntityQuery::getAttachQuery(const IAccessEntity & entity)
{ {
return getCreateQueryImpl(entity, nullptr, true); return getCreateQueryImpl(entity, nullptr, true);
} }
AccessRightsElements InterpreterShowCreateAccessEntityQuery::getRequiredAccess() const
{
const auto & show_query = query_ptr->as<const ASTShowCreateAccessEntityQuery &>();
AccessRightsElements res;
switch (show_query.type)
{
case EntityType::USER: res.emplace_back(AccessType::SHOW_USERS); return res;
case EntityType::ROLE: res.emplace_back(AccessType::SHOW_ROLES); return res;
case EntityType::SETTINGS_PROFILE: res.emplace_back(AccessType::SHOW_SETTINGS_PROFILES); return res;
case EntityType::ROW_POLICY: res.emplace_back(AccessType::SHOW_ROW_POLICIES); return res;
case EntityType::QUOTA: res.emplace_back(AccessType::SHOW_QUOTAS); return res;
case EntityType::MAX: break;
}
throw Exception(toString(show_query.type) + ": type is not supported by SHOW CREATE query", ErrorCodes::NOT_IMPLEMENTED);
}
} }

View File

@ -105,7 +105,7 @@ namespace
return getGrantQueriesImpl(*user, manager, attach_mode); return getGrantQueriesImpl(*user, manager, attach_mode);
if (const Role * role = typeid_cast<const Role *>(&entity)) if (const Role * role = typeid_cast<const Role *>(&entity))
return getGrantQueriesImpl(*role, manager, attach_mode); return getGrantQueriesImpl(*role, manager, attach_mode);
throw Exception("Unexpected type of access entity: " + entity.getTypeName(), ErrorCodes::LOGICAL_ERROR); throw Exception(entity.outputTypeAndName() + " is expected to be user or role", ErrorCodes::LOGICAL_ERROR);
} }
} }

View File

@ -4,34 +4,12 @@
namespace DB namespace DB
{ {
namespace using EntityTypeInfo = IAccessEntity::TypeInfo;
{
using Kind = ASTDropAccessEntityQuery::Kind;
const char * getKeyword(Kind kind)
{
switch (kind)
{
case Kind::USER: return "USER";
case Kind::ROLE: return "ROLE";
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "ROW POLICY";
case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE";
}
__builtin_unreachable();
}
}
ASTDropAccessEntityQuery::ASTDropAccessEntityQuery(Kind kind_)
: kind(kind_)
{
}
String ASTDropAccessEntityQuery::getID(char) const String ASTDropAccessEntityQuery::getID(char) const
{ {
return String("DROP ") + getKeyword(kind) + " query"; return String("DROP ") + toString(type) + " query";
} }
@ -44,11 +22,11 @@ ASTPtr ASTDropAccessEntityQuery::clone() const
void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const void ASTDropAccessEntityQuery::formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{ {
settings.ostr << (settings.hilite ? hilite_keyword : "") settings.ostr << (settings.hilite ? hilite_keyword : "")
<< "DROP " << getKeyword(kind) << "DROP " << EntityTypeInfo::get(type).name
<< (if_exists ? " IF EXISTS" : "") << (if_exists ? " IF EXISTS" : "")
<< (settings.hilite ? hilite_none : ""); << (settings.hilite ? hilite_none : "");
if (kind == Kind::ROW_POLICY) if (type == EntityType::ROW_POLICY)
{ {
bool need_comma = false; bool need_comma = false;
for (const auto & name_parts : row_policies_name_parts) for (const auto & name_parts : row_policies_name_parts)

View File

@ -17,21 +17,13 @@ namespace DB
class ASTDropAccessEntityQuery : public IAST, public ASTQueryWithOnCluster class ASTDropAccessEntityQuery : public IAST, public ASTQueryWithOnCluster
{ {
public: public:
enum class Kind using EntityType = IAccessEntity::Type;
{
USER,
ROLE,
QUOTA,
ROW_POLICY,
SETTINGS_PROFILE,
};
const Kind kind; EntityType type;
bool if_exists = false; bool if_exists = false;
Strings names; Strings names;
std::vector<RowPolicy::NameParts> row_policies_name_parts; std::vector<RowPolicy::NameParts> row_policies_name_parts;
ASTDropAccessEntityQuery(Kind kind_);
String getID(char) const override; String getID(char) const override;
ASTPtr clone() const override; ASTPtr clone() const override;
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override; void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;

View File

@ -4,34 +4,12 @@
namespace DB namespace DB
{ {
namespace using EntityTypeInfo = IAccessEntity::TypeInfo;
{
using Kind = ASTShowCreateAccessEntityQuery::Kind;
const char * getKeyword(Kind kind)
{
switch (kind)
{
case Kind::USER: return "USER";
case Kind::ROLE: return "ROLE";
case Kind::QUOTA: return "QUOTA";
case Kind::ROW_POLICY: return "ROW POLICY";
case Kind::SETTINGS_PROFILE: return "SETTINGS PROFILE";
}
__builtin_unreachable();
}
}
ASTShowCreateAccessEntityQuery::ASTShowCreateAccessEntityQuery(Kind kind_)
: kind(kind_)
{
}
String ASTShowCreateAccessEntityQuery::getID(char) const String ASTShowCreateAccessEntityQuery::getID(char) const
{ {
return String("SHOW CREATE ") + getKeyword(kind) + " query"; return String("SHOW CREATE ") + toString(type) + " query";
} }
@ -44,7 +22,7 @@ ASTPtr ASTShowCreateAccessEntityQuery::clone() const
void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const
{ {
settings.ostr << (settings.hilite ? hilite_keyword : "") settings.ostr << (settings.hilite ? hilite_keyword : "")
<< "SHOW CREATE " << getKeyword(kind) << "SHOW CREATE " << EntityTypeInfo::get(type).name
<< (settings.hilite ? hilite_none : ""); << (settings.hilite ? hilite_none : "");
if (current_user) if (current_user)
@ -52,7 +30,7 @@ void ASTShowCreateAccessEntityQuery::formatQueryImpl(const FormatSettings & sett
} }
else if (current_quota) else if (current_quota)
settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : ""); settings.ostr << (settings.hilite ? hilite_keyword : "") << " CURRENT" << (settings.hilite ? hilite_none : "");
else if (kind == Kind::ROW_POLICY) else if (type == EntityType::ROW_POLICY)
{ {
const String & database = row_policy_name_parts.database; const String & database = row_policy_name_parts.database;
const String & table_name = row_policy_name_parts.table_name; const String & table_name = row_policy_name_parts.table_name;

View File

@ -15,22 +15,14 @@ namespace DB
class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput class ASTShowCreateAccessEntityQuery : public ASTQueryWithOutput
{ {
public: public:
enum class Kind using EntityType = IAccessEntity::Type;
{
USER,
ROLE,
QUOTA,
ROW_POLICY,
SETTINGS_PROFILE,
};
const Kind kind; EntityType type;
String name; String name;
bool current_quota = false; bool current_quota = false;
bool current_user = false; bool current_user = false;
RowPolicy::NameParts row_policy_name_parts; RowPolicy::NameParts row_policy_name_parts;
ASTShowCreateAccessEntityQuery(Kind kind_);
String getID(char) const override; String getID(char) const override;
ASTPtr clone() const override; ASTPtr clone() const override;

View File

@ -4,12 +4,16 @@
#include <Parsers/parseIdentifierOrStringLiteral.h> #include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/parseUserName.h> #include <Parsers/parseUserName.h>
#include <ext/range.h>
namespace DB namespace DB
{ {
namespace namespace
{ {
using EntityType = IAccessEntity::Type;
using EntityTypeInfo = IAccessEntity::TypeInfo;
bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names) bool parseNames(IParserBase::Pos & pos, Expected & expected, Strings & names)
{ {
return IParserBase::wrapParseImpl(pos, [&] return IParserBase::wrapParseImpl(pos, [&]
@ -79,19 +83,17 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
if (!ParserKeyword{"DROP"}.ignore(pos, expected)) if (!ParserKeyword{"DROP"}.ignore(pos, expected))
return false; return false;
using Kind = ASTDropAccessEntityQuery::Kind; std::optional<EntityType> type;
Kind kind; for (auto type_i : ext::range(EntityType::MAX))
if (ParserKeyword{"USER"}.ignore(pos, expected)) {
kind = Kind::USER; const auto & type_info = EntityTypeInfo::get(type_i);
else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) if (ParserKeyword{type_info.name.c_str()}.ignore(pos, expected)
kind = Kind::ROLE; || (!type_info.alias.empty() && ParserKeyword{type_info.alias.c_str()}.ignore(pos, expected)))
else if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) {
kind = Kind::QUOTA; type = type_i;
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) }
kind = Kind::ROW_POLICY; }
else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) if (!type)
kind = Kind::SETTINGS_PROFILE;
else
return false; return false;
bool if_exists = false; bool if_exists = false;
@ -101,12 +103,12 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
Strings names; Strings names;
std::vector<RowPolicy::NameParts> row_policies_name_parts; std::vector<RowPolicy::NameParts> row_policies_name_parts;
if ((kind == Kind::USER) || (kind == Kind::ROLE)) if ((type == EntityType::USER) || (type == EntityType::ROLE))
{ {
if (!parseUserNames(pos, expected, names)) if (!parseUserNames(pos, expected, names))
return false; return false;
} }
else if (kind == Kind::ROW_POLICY) else if (type == EntityType::ROW_POLICY)
{ {
if (!parseRowPolicyNames(pos, expected, row_policies_name_parts)) if (!parseRowPolicyNames(pos, expected, row_policies_name_parts))
return false; return false;
@ -124,9 +126,10 @@ bool ParserDropAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected &
return false; return false;
} }
auto query = std::make_shared<ASTDropAccessEntityQuery>(kind); auto query = std::make_shared<ASTDropAccessEntityQuery>();
node = query; node = query;
query->type = *type;
query->if_exists = if_exists; query->if_exists = if_exists;
query->cluster = std::move(cluster); query->cluster = std::move(cluster);
query->names = std::move(names); query->names = std::move(names);

View File

@ -4,29 +4,32 @@
#include <Parsers/parseIdentifierOrStringLiteral.h> #include <Parsers/parseIdentifierOrStringLiteral.h>
#include <Parsers/parseDatabaseAndTableName.h> #include <Parsers/parseDatabaseAndTableName.h>
#include <Parsers/parseUserName.h> #include <Parsers/parseUserName.h>
#include <ext/range.h>
#include <assert.h> #include <assert.h>
namespace DB namespace DB
{ {
using EntityType = IAccessEntity::Type;
using EntityTypeInfo = IAccessEntity::TypeInfo;
bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expected)
{ {
if (!ParserKeyword{"SHOW CREATE"}.ignore(pos, expected)) if (!ParserKeyword{"SHOW CREATE"}.ignore(pos, expected))
return false; return false;
using Kind = ASTShowCreateAccessEntityQuery::Kind; std::optional<EntityType> type;
Kind kind; for (auto type_i : ext::range(EntityType::MAX))
if (ParserKeyword{"USER"}.ignore(pos, expected)) {
kind = Kind::USER; const auto & type_info = EntityTypeInfo::get(type_i);
else if (ParserKeyword{"QUOTA"}.ignore(pos, expected)) if (ParserKeyword{type_info.name.c_str()}.ignore(pos, expected)
kind = Kind::QUOTA; || (!type_info.alias.empty() && ParserKeyword{type_info.alias.c_str()}.ignore(pos, expected)))
else if (ParserKeyword{"POLICY"}.ignore(pos, expected) || ParserKeyword{"ROW POLICY"}.ignore(pos, expected)) {
kind = Kind::ROW_POLICY; type = type_i;
else if (ParserKeyword{"ROLE"}.ignore(pos, expected)) }
kind = Kind::ROLE; }
else if (ParserKeyword{"SETTINGS PROFILE"}.ignore(pos, expected) || ParserKeyword{"PROFILE"}.ignore(pos, expected)) if (!type)
kind = Kind::SETTINGS_PROFILE;
else
return false; return false;
String name; String name;
@ -34,17 +37,17 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
bool current_user = false; bool current_user = false;
RowPolicy::NameParts row_policy_name_parts; RowPolicy::NameParts row_policy_name_parts;
if (kind == Kind::USER) if (type == EntityType::USER)
{ {
if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user)) if (!parseUserNameOrCurrentUserTag(pos, expected, name, current_user))
current_user = true; current_user = true;
} }
else if (kind == Kind::ROLE) else if (type == EntityType::ROLE)
{ {
if (!parseRoleName(pos, expected, name)) if (!parseRoleName(pos, expected, name))
return false; return false;
} }
else if (kind == Kind::ROW_POLICY) else if (type == EntityType::ROW_POLICY)
{ {
String & database = row_policy_name_parts.database; String & database = row_policy_name_parts.database;
String & table_name = row_policy_name_parts.table_name; String & table_name = row_policy_name_parts.table_name;
@ -53,7 +56,7 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
|| !parseDatabaseAndTableName(pos, expected, database, table_name)) || !parseDatabaseAndTableName(pos, expected, database, table_name))
return false; return false;
} }
else if (kind == Kind::QUOTA) else if (type == EntityType::QUOTA)
{ {
if (ParserKeyword{"CURRENT"}.ignore(pos, expected)) if (ParserKeyword{"CURRENT"}.ignore(pos, expected))
{ {
@ -70,15 +73,16 @@ bool ParserShowCreateAccessEntityQuery::parseImpl(Pos & pos, ASTPtr & node, Expe
current_quota = true; current_quota = true;
} }
} }
else if (kind == Kind::SETTINGS_PROFILE) else if (type == EntityType::SETTINGS_PROFILE)
{ {
if (!parseIdentifierOrStringLiteral(pos, expected, name)) if (!parseIdentifierOrStringLiteral(pos, expected, name))
return false; return false;
} }
auto query = std::make_shared<ASTShowCreateAccessEntityQuery>(kind); auto query = std::make_shared<ASTShowCreateAccessEntityQuery>();
node = query; node = query;
query->type = *type;
query->name = std::move(name); query->name = std::move(name);
query->current_quota = current_quota; query->current_quota = current_quota;
query->current_user = current_user; query->current_user = current_user;

View File

@ -35,7 +35,7 @@ def test_access_control_on_cluster():
assert ch3.query("SHOW GRANTS FOR Alex") == "" assert ch3.query("SHOW GRANTS FOR Alex") == ""
ch2.query("DROP USER Alex ON CLUSTER 'cluster'") ch2.query("DROP USER Alex ON CLUSTER 'cluster'")
assert "User `Alex` not found" in ch1.query_and_get_error("SHOW CREATE USER Alex") assert "There is no user `Alex`" in ch1.query_and_get_error("SHOW CREATE USER Alex")
assert "User `Alex` not found" in ch2.query_and_get_error("SHOW CREATE USER Alex") assert "There is no user `Alex`" in ch2.query_and_get_error("SHOW CREATE USER Alex")
assert "User `Alex` not found" in ch3.query_and_get_error("SHOW CREATE USER Alex") assert "There is no user `Alex`" in ch3.query_and_get_error("SHOW CREATE USER Alex")

View File

@ -97,9 +97,9 @@ def test_drop():
def check(): def check():
assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n" assert instance.query("SHOW CREATE USER u1") == "CREATE USER u1\n"
assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2\n" assert instance.query("SHOW CREATE SETTINGS PROFILE s2") == "CREATE SETTINGS PROFILE s2\n"
assert "User `u2` not found" in instance.query_and_get_error("SHOW CREATE USER u2") assert "There is no user `u2`" in instance.query_and_get_error("SHOW CREATE USER u2")
assert "Row policy `p ON mydb.mytable` not found" in instance.query_and_get_error("SHOW CREATE ROW POLICY p ON mydb.mytable") assert "There is no row policy `p ON mydb.mytable`" in instance.query_and_get_error("SHOW CREATE ROW POLICY p ON mydb.mytable")
assert "Quota `q` not found" in instance.query_and_get_error("SHOW CREATE QUOTA q") assert "There is no quota `q`" in instance.query_and_get_error("SHOW CREATE QUOTA q")
check() check()
instance.restart_clickhouse() # Check persistency instance.restart_clickhouse() # Check persistency