Rework notifications used in access management.

This commit is contained in:
Vitaly Baranov 2022-05-16 20:43:55 +02:00
parent 9ccddc44c6
commit 58f4a86ec7
27 changed files with 561 additions and 754 deletions

View File

@ -1314,7 +1314,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->setConfigReloadCallback([&]()
{
main_config_reloader->reload();
access_control.reloadUsersConfigs();
access_control.reload();
});
/// Limit on total number of concurrently executed queries.
@ -1406,6 +1406,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
/// Stop reloading of the main config. This must be done before `global_context->shutdown()` because
/// otherwise the reloading may pass a changed config to some destroyed parts of ContextSharedPart.
main_config_reloader.reset();
access_control.stopPeriodicReloading();
async_metrics.stop();
@ -1629,7 +1630,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
buildLoggers(config(), logger());
main_config_reloader->start();
access_control.startPeriodicReloadingUsersConfigs();
access_control.startPeriodicReloading();
if (dns_cache_updater)
dns_cache_updater->start();

View File

@ -0,0 +1,122 @@
#include <Access/AccessChangesNotifier.h>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
AccessChangesNotifier::AccessChangesNotifier() : handlers(std::make_shared<Handlers>())
{
}
AccessChangesNotifier::~AccessChangesNotifier() = default;
void AccessChangesNotifier::onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.entity = new_entity;
event.type = new_entity->getType();
queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.entity = changed_entity;
event.type = changed_entity->getType();
queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityRemoved(const UUID & id, AccessEntityType type)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.type = type;
queue.push(std::move(event));
}
scope_guard AccessChangesNotifier::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler)
{
std::lock_guard lock{handlers->mutex};
auto & list = handlers->by_type[static_cast<size_t>(type)];
list.push_back(handler);
auto handler_it = std::prev(list.end());
return [handlers=handlers, type, handler_it]
{
std::lock_guard lock2{handlers->mutex};
auto & list2 = handlers->by_type[static_cast<size_t>(type)];
list2.erase(handler_it);
};
}
scope_guard AccessChangesNotifier::subscribeForChanges(const UUID & id, const OnChangedHandler & handler)
{
std::lock_guard lock{handlers->mutex};
auto it = handlers->by_id.emplace(id, std::list<OnChangedHandler>{}).first;
auto & list = it->second;
list.push_back(handler);
auto handler_it = std::prev(list.end());
return [handlers=handlers, it, handler_it]
{
std::lock_guard lock2{handlers->mutex};
auto & list2 = it->second;
list2.erase(handler_it);
if (list2.empty())
handlers->by_id.erase(it);
};
}
scope_guard AccessChangesNotifier::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler)
{
scope_guard subscriptions;
for (const auto & id : ids)
subscriptions.join(subscribeForChanges(id, handler));
return subscriptions;
}
void AccessChangesNotifier::sendNotifications()
{
/// Only one thread can send notification at any time.
std::lock_guard sending_notifications_lock{sending_notifications};
std::unique_lock queue_lock{queue_mutex};
while (!queue.empty())
{
auto event = std::move(queue.front());
queue.pop();
queue_lock.unlock();
std::vector<OnChangedHandler> current_handlers;
{
std::lock_guard handlers_lock{handlers->mutex};
boost::range::copy(handlers->by_type[static_cast<size_t>(event.type)], std::back_inserter(current_handlers));
auto it = handlers->by_id.find(event.id);
if (it != handlers->by_id.end())
boost::range::copy(it->second, std::back_inserter(current_handlers));
}
for (const auto & handler : current_handlers)
{
try
{
handler(event.id, event.entity);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
queue_lock.lock();
}
}
}

View File

@ -0,0 +1,73 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <base/scope_guard.h>
#include <list>
#include <queue>
#include <unordered_map>
namespace DB
{
/// Helper class implementing subscriptions and notifications in access management.
class AccessChangesNotifier
{
public:
AccessChangesNotifier();
~AccessChangesNotifier();
using OnChangedHandler
= std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler);
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler)
{
return subscribeForChanges(EntityClassT::TYPE, handler);
}
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler);
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler);
/// Called by access storages after a new access entity has been added.
void onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity);
/// Called by access storages after an access entity has been changed.
void onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity);
/// Called by access storages after an access entity has been removed.
void onEntityRemoved(const UUID & id, AccessEntityType type);
/// Sends notifications to subscribers about changes in access entities
/// (added with previous calls onEntityAdded(), onEntityUpdated(), onEntityRemoved()).
void sendNotifications();
private:
struct Handlers
{
std::unordered_map<UUID, std::list<OnChangedHandler>> by_id;
std::list<OnChangedHandler> by_type[static_cast<size_t>(AccessEntityType::MAX)];
std::mutex mutex;
};
/// shared_ptr is here for safety because AccessChangesNotifier can be destroyed before all subscriptions are removed.
std::shared_ptr<Handlers> handlers;
struct Event
{
UUID id;
AccessEntityPtr entity;
AccessEntityType type;
};
std::queue<Event> queue;
std::mutex queue_mutex;
std::mutex sending_notifications;
};
}

View File

@ -14,6 +14,7 @@
#include <Access/SettingsProfilesCache.h>
#include <Access/User.h>
#include <Access/ExternalAuthenticators.h>
#include <Access/AccessChangesNotifier.h>
#include <Core/Settings.h>
#include <base/find_symbols.h>
#include <Poco/ExpireCache.h>
@ -142,7 +143,8 @@ AccessControl::AccessControl()
quota_cache(std::make_unique<QuotaCache>(*this)),
settings_profiles_cache(std::make_unique<SettingsProfilesCache>(*this)),
external_authenticators(std::make_unique<ExternalAuthenticators>()),
custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>())
custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>()),
changes_notifier(std::make_unique<AccessChangesNotifier>())
{
}
@ -231,35 +233,6 @@ void AccessControl::addUsersConfigStorage(
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
}
void AccessControl::reloadUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->reload();
}
}
void AccessControl::startPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->startPeriodicReloading();
}
}
void AccessControl::stopPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->stopPeriodicReloading();
}
}
void AccessControl::addReplicatedStorage(
const String & storage_name_,
@ -272,10 +245,9 @@ void AccessControl::addReplicatedStorage(
if (auto replicated_storage = typeid_cast<std::shared_ptr<ReplicatedAccessStorage>>(storage))
return;
}
auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_);
auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
new_storage->startup();
}
void AccessControl::addDiskStorage(const String & directory_, bool readonly_)
@ -298,7 +270,7 @@ void AccessControl::addDiskStorage(const String & storage_name_, const String &
}
}
}
auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_);
auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
}
@ -312,7 +284,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
if (auto memory_storage = typeid_cast<std::shared_ptr<MemoryAccessStorage>>(storage))
return;
}
auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_);
auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
}
@ -320,7 +292,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
void AccessControl::addLDAPStorage(const String & storage_name_, const Poco::Util::AbstractConfiguration & config_, const String & prefix_)
{
auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, this, config_, prefix_);
auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, *this, config_, prefix_);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', LDAP server name: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getLDAPServerName());
}
@ -423,6 +395,57 @@ void AccessControl::addStoragesFromMainConfig(
}
void AccessControl::reload()
{
MultipleAccessStorage::reload();
changes_notifier->sendNotifications();
}
scope_guard AccessControl::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(type, handler);
}
scope_guard AccessControl::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(id, handler);
}
scope_guard AccessControl::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(ids, handler);
}
std::optional<UUID> AccessControl::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{
auto id = MultipleAccessStorage::insertImpl(entity, replace_if_exists, throw_if_exists);
if (id)
changes_notifier->sendNotifications();
return id;
}
bool AccessControl::removeImpl(const UUID & id, bool throw_if_not_exists)
{
bool removed = MultipleAccessStorage::removeImpl(id, throw_if_not_exists);
if (removed)
changes_notifier->sendNotifications();
return removed;
}
bool AccessControl::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
bool updated = MultipleAccessStorage::updateImpl(id, update_func, throw_if_not_exists);
if (updated)
changes_notifier->sendNotifications();
return updated;
}
AccessChangesNotifier & AccessControl::getChangesNotifier()
{
return *changes_notifier;
}
UUID AccessControl::authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const
{
try

View File

@ -3,8 +3,8 @@
#include <Access/MultipleAccessStorage.h>
#include <Common/SettingsChanges.h>
#include <Common/ZooKeeper/Common.h>
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <Access/UsersConfigAccessStorage.h>
#include <memory>
@ -40,6 +40,7 @@ class SettingsProfilesCache;
class SettingsProfileElements;
class ClientInfo;
class ExternalAuthenticators;
class AccessChangesNotifier;
struct Settings;
@ -50,6 +51,7 @@ public:
AccessControl();
~AccessControl() override;
/// Initializes access storage (user directories).
void setUpFromMainConfig(const Poco::Util::AbstractConfiguration & config_, const String & config_path_,
const zkutil::GetZooKeeper & get_zookeeper_function_);
@ -74,9 +76,6 @@ public:
const String & preprocessed_dir_,
const zkutil::GetZooKeeper & get_zookeeper_function_ = {});
void reloadUsersConfigs();
void startPeriodicReloadingUsersConfigs();
void stopPeriodicReloadingUsersConfigs();
/// Loads access entities from the directory on the local disk.
/// Use that directory to keep created users/roles/etc.
void addDiskStorage(const String & directory_, bool readonly_ = false);
@ -106,6 +105,26 @@ public:
const String & config_path,
const zkutil::GetZooKeeper & get_zookeeper_function);
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
void reload() override;
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
/// Sets the default profile's name.
/// The default profile's settings are always applied before any other profile's.
void setDefaultProfileName(const String & default_profile_name);
@ -135,9 +154,6 @@ public:
void setOnClusterQueriesRequireClusterGrant(bool enable) { on_cluster_queries_require_cluster_grant = enable; }
bool doesOnClusterQueriesRequireClusterGrant() const { return on_cluster_queries_require_cluster_grant; }
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
std::shared_ptr<const ContextAccess> getContextAccess(
const UUID & user_id,
const std::vector<UUID> & current_roles,
@ -178,10 +194,17 @@ public:
const ExternalAuthenticators & getExternalAuthenticators() const;
/// Gets manager of notifications.
AccessChangesNotifier & getChangesNotifier();
private:
class ContextAccessCache;
class CustomSettingsPrefixes;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
std::unique_ptr<ContextAccessCache> context_access_cache;
std::unique_ptr<RoleCache> role_cache;
std::unique_ptr<RowPolicyCache> row_policy_cache;
@ -189,6 +212,7 @@ private:
std::unique_ptr<SettingsProfilesCache> settings_profiles_cache;
std::unique_ptr<ExternalAuthenticators> external_authenticators;
std::unique_ptr<CustomSettingsPrefixes> custom_settings_prefixes;
std::unique_ptr<AccessChangesNotifier> changes_notifier;
std::atomic_bool allow_plaintext_password = true;
std::atomic_bool allow_no_password = true;
std::atomic_bool users_without_row_policies_can_read_rows = false;

View File

@ -149,6 +149,21 @@ ContextAccess::ContextAccess(const AccessControl & access_control_, const Params
}
ContextAccess::~ContextAccess()
{
enabled_settings.reset();
enabled_quota.reset();
enabled_row_policies.reset();
access_with_implicit.reset();
access.reset();
roles_info.reset();
subscription_for_roles_changes.reset();
enabled_roles.reset();
subscription_for_user_change.reset();
user.reset();
}
void ContextAccess::initialize()
{
std::lock_guard lock{mutex};

View File

@ -155,6 +155,8 @@ public:
/// without any limitations. This is used for the global context.
static std::shared_ptr<const ContextAccess> getFullAccess();
~ContextAccess();
private:
friend class AccessControl;
ContextAccess() {} /// NOLINT

View File

@ -1,5 +1,6 @@
#include <Access/DiskAccessStorage.h>
#include <Access/AccessEntityIO.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromFile.h>
@ -164,13 +165,8 @@ namespace
}
DiskAccessStorage::DiskAccessStorage(const String & directory_path_, bool readonly_)
: DiskAccessStorage(STORAGE_TYPE, directory_path_, readonly_)
{
}
DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_)
: IAccessStorage(storage_name_)
DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{
directory_path = makeDirectoryPathCanonical(directory_path_);
readonly = readonly_;
@ -199,7 +195,15 @@ DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String
DiskAccessStorage::~DiskAccessStorage()
{
stopListsWritingThread();
writeLists();
try
{
writeLists();
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
@ -470,19 +474,16 @@ std::optional<String> DiskAccessStorage::readNameImpl(const UUID & id, bool thro
std::optional<UUID> DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications))
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id;
return std::nullopt;
}
bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications)
bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType();
@ -514,7 +515,7 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
writeAccessEntityToDisk(id, *new_entity);
if (name_collision && replace_if_exists)
removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false);
/// Do insertion.
auto & entry = entries_by_id[id];
@ -523,22 +524,20 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
entry.name = name;
entry.entity = new_entity;
entries_by_name[entry.name] = &entry;
prepareNotifications(id, entry, false, notifications);
changes_notifier.onEntityAdded(id, new_entity);
return true;
}
bool DiskAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications);
return removeNoLock(id, throw_if_not_exists);
}
bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications)
bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -559,25 +558,24 @@ bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists,
deleteAccessEntityOnDisk(id);
/// Do removing.
prepareNotifications(id, entry, true, notifications);
UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(entry.name);
entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true;
}
bool DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications);
return updateNoLock(id, update_func, throw_if_not_exists);
}
bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications)
bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -626,7 +624,8 @@ bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
entries_by_name[entry.name] = &entry;
}
prepareNotifications(id, entry, false, notifications);
changes_notifier.onEntityUpdated(id, new_entity);
return true;
}
@ -650,74 +649,4 @@ void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const
throw Exception("Couldn't delete " + file_path, ErrorCodes::FILE_DOESNT_EXIST);
}
void DiskAccessStorage::prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const
{
if (!remove && !entry.entity)
return;
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.type)])
notifications.push_back({handler, id, entity});
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
bool DiskAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool DiskAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -7,14 +7,15 @@
namespace DB
{
class AccessChangesNotifier;
/// Loads and saves access entities on a local disk to a specified directory.
class DiskAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "local directory";
DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_ = false);
DiskAccessStorage(const String & directory_path_, bool readonly_ = false);
DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_);
~DiskAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -27,8 +28,6 @@ public:
bool isReadOnly() const override { return readonly; }
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -38,8 +37,6 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
void clear();
bool readLists();
@ -50,9 +47,9 @@ private:
void listsWritingThreadFunc();
void stopListsWritingThread();
bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications);
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications);
bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
AccessEntityPtr readAccessEntityFromDisk(const UUID & id) const;
void writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const;
@ -65,11 +62,8 @@ private:
String name;
AccessEntityType type;
mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet.
mutable std::list<OnChangedHandler> handlers_by_id;
};
void prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const;
String directory_path;
std::atomic<bool> readonly;
std::unordered_map<UUID, Entry> entries_by_id;
@ -79,7 +73,7 @@ private:
ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread.
std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit.
bool lists_writing_thread_is_waiting = false;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
mutable std::mutex mutex;
};
}

View File

@ -6,7 +6,7 @@
namespace DB
{
EnabledRoles::EnabledRoles(const Params & params_) : params(params_)
EnabledRoles::EnabledRoles(const Params & params_) : params(params_), handlers(std::make_shared<Handlers>())
{
}
@ -15,42 +15,50 @@ EnabledRoles::~EnabledRoles() = default;
std::shared_ptr<const EnabledRolesInfo> EnabledRoles::getRolesInfo() const
{
std::lock_guard lock{mutex};
std::lock_guard lock{info_mutex};
return info;
}
scope_guard EnabledRoles::subscribeForChanges(const OnChangeHandler & handler) const
{
std::lock_guard lock{mutex};
handlers.push_back(handler);
auto it = std::prev(handlers.end());
std::lock_guard lock{handlers->mutex};
handlers->list.push_back(handler);
auto it = std::prev(handlers->list.end());
return [this, it]
return [handlers=handlers, it]
{
std::lock_guard lock2{mutex};
handlers.erase(it);
std::lock_guard lock2{handlers->mutex};
handlers->list.erase(it);
};
}
void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications)
void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications)
{
std::lock_guard lock{mutex};
if (info && info_ && *info == *info_)
return;
info = info_;
std::vector<OnChangeHandler> handlers_to_notify;
boost::range::copy(handlers, std::back_inserter(handlers_to_notify));
notifications.join(scope_guard([info = info, handlers_to_notify = std::move(handlers_to_notify)]
{
for (const auto & handler : handlers_to_notify)
handler(info);
}));
std::lock_guard lock{info_mutex};
if (info && info_ && *info == *info_)
return;
info = info_;
}
if (notifications)
{
std::vector<OnChangeHandler> handlers_to_notify;
{
std::lock_guard lock{handlers->mutex};
boost::range::copy(handlers->list, std::back_inserter(handlers_to_notify));
}
notifications->join(scope_guard(
[info = info, handlers_to_notify = std::move(handlers_to_notify)]
{
for (const auto & handler : handlers_to_notify)
handler(info);
}));
}
}
}

View File

@ -4,6 +4,7 @@
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <list>
#include <memory>
#include <mutex>
#include <vector>
@ -43,12 +44,21 @@ private:
friend class RoleCache;
explicit EnabledRoles(const Params & params_);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications);
const Params params;
mutable std::shared_ptr<const EnabledRolesInfo> info;
mutable std::list<OnChangeHandler> handlers;
mutable std::mutex mutex;
std::shared_ptr<const EnabledRolesInfo> info;
mutable std::mutex info_mutex;
struct Handlers
{
std::list<OnChangeHandler> list;
std::mutex mutex;
};
/// shared_ptr is here for safety because EnabledRoles can be destroyed before all subscriptions are removed.
std::shared_ptr<Handlers> handlers;
};
}

View File

@ -410,34 +410,6 @@ bool IAccessStorage::updateImpl(const UUID & id, const UpdateFunc &, bool throw_
}
scope_guard IAccessStorage::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(type, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(id, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
scope_guard subscriptions;
for (const auto & id : ids)
subscriptions.join(subscribeForChangesImpl(id, handler));
return subscriptions;
}
void IAccessStorage::notify(const Notifications & notifications)
{
for (const auto & [fn, id, new_entity] : notifications)
fn(id, new_entity);
}
UUID IAccessStorage::authenticate(
const Credentials & credentials,
const Poco::Net::IPAddress & address,

View File

@ -3,7 +3,6 @@
#include <Access/IAccessEntity.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <base/scope_guard.h>
#include <functional>
#include <optional>
#include <vector>
@ -22,7 +21,7 @@ enum class AuthenticationType;
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
class IAccessStorage
class IAccessStorage : public boost::noncopyable
{
public:
explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
@ -41,6 +40,15 @@ public:
/// Returns true if this entity is readonly.
virtual bool isReadOnly(const UUID &) const { return isReadOnly(); }
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
virtual void reload() {}
/// Starts periodic reloading and update of entities in this storage.
virtual void startPeriodicReloading() {}
/// Stops periodic reloading and update of entities in this storage.
virtual void stopPeriodicReloading() {}
/// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll(AccessEntityType type) const;
@ -130,23 +138,6 @@ public:
/// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
virtual bool hasSubscription(AccessEntityType type) const = 0;
virtual bool hasSubscription(const UUID & id) const = 0;
/// Finds a user, check the provided credentials and returns the ID of the user if they are valid.
/// Throws an exception if no such user or credentials are invalid.
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool allow_no_password, bool allow_plaintext_password) const;
@ -160,8 +151,6 @@ protected:
virtual std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
virtual bool removeImpl(const UUID & id, bool throw_if_not_exists);
virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const = 0;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const;
virtual bool areCredentialsValid(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const;
virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const;
@ -181,9 +170,6 @@ protected:
[[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address);
[[noreturn]] static void throwInvalidCredentials();
[[noreturn]] static void throwAuthenticationTypeNotAllowed(AuthenticationType auth_type);
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>;
static void notify(const Notifications & notifications);
private:
const String storage_name;

View File

@ -27,10 +27,10 @@ namespace ErrorCodes
}
LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
: IAccessStorage(storage_name_)
LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
: IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{
setConfiguration(access_control_, config, prefix);
setConfiguration(config, prefix);
}
@ -40,7 +40,7 @@ String LDAPAccessStorage::getLDAPServerName() const
}
void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
void LDAPAccessStorage::setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix)
{
std::scoped_lock lock(mutex);
@ -80,7 +80,6 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
}
}
access_control = access_control_;
ldap_server_name = ldap_server_name_cfg;
role_search_params.swap(role_search_params_cfg);
common_role_names.swap(common_roles_cfg);
@ -91,7 +90,7 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
granted_role_names.clear();
granted_role_ids.clear();
role_change_subscription = access_control->subscribeForChanges<Role>(
role_change_subscription = access_control.subscribeForChanges<Role>(
[this] (const UUID & id, const AccessEntityPtr & entity)
{
return this->processRoleChange(id, entity);
@ -215,7 +214,7 @@ void LDAPAccessStorage::assignRolesNoLock(User & user, const LDAPClient::SearchR
auto it = granted_role_ids.find(role_name);
if (it == granted_role_ids.end())
{
if (const auto role_id = access_control->find<Role>(role_name))
if (const auto role_id = access_control.find<Role>(role_name))
{
granted_role_names.insert_or_assign(*role_id, role_name);
it = granted_role_ids.insert_or_assign(role_name, *role_id).first;
@ -450,33 +449,6 @@ std::optional<String> LDAPAccessStorage::readNameImpl(const UUID & id, bool thro
}
scope_guard LDAPAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard LDAPAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(type, handler);
}
bool LDAPAccessStorage::hasSubscription(const UUID & id) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(id);
}
bool LDAPAccessStorage::hasSubscription(AccessEntityType type) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(type);
}
std::optional<UUID> LDAPAccessStorage::authenticateImpl(
const Credentials & credentials,
const Poco::Net::IPAddress & address,

View File

@ -32,7 +32,7 @@ class LDAPAccessStorage : public IAccessStorage
public:
static constexpr char STORAGE_TYPE[] = "ldap";
explicit LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
explicit LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
virtual ~LDAPAccessStorage() override = default;
String getLDAPServerName() const;
@ -42,19 +42,15 @@ public:
virtual String getStorageParamsJSON() const override;
virtual bool isReadOnly() const override { return true; }
virtual bool exists(const UUID & id) const override;
virtual bool hasSubscription(const UUID & id) const override;
virtual bool hasSubscription(AccessEntityType type) const override;
private: // IAccessStorage implementations.
virtual std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
virtual std::vector<UUID> findAllImpl(AccessEntityType type) const override;
virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
void setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
void setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix);
void processRoleChange(const UUID & id, const AccessEntityPtr & entity);
void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name);
@ -66,7 +62,7 @@ private: // IAccessStorage implementations.
const ExternalAuthenticators & external_authenticators, LDAPClient::SearchResultsList & role_search_results) const;
mutable std::recursive_mutex mutex;
AccessControl * access_control = nullptr;
AccessControl & access_control;
String ldap_server_name;
LDAPClient::RoleSearchParamsList role_search_params;
std::set<String> common_role_names; // role name that should be granted to all users at all times

View File

@ -1,4 +1,5 @@
#include <Access/MemoryAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <boost/range/adaptor/map.hpp>
@ -7,8 +8,8 @@
namespace DB
{
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
: IAccessStorage(storage_name_)
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{
}
@ -63,19 +64,16 @@ AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id, bool throw_if_not
std::optional<UUID> MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications))
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id;
return std::nullopt;
}
bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications)
bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType();
@ -103,7 +101,7 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
if (name_collision && replace_if_exists)
{
const auto & existing_entry = *(it_by_name->second);
removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false);
}
/// Do insertion.
@ -111,22 +109,19 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
entry.id = id;
entry.entity = new_entity;
entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications);
changes_notifier.onEntityAdded(id, new_entity);
return true;
}
bool MemoryAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications);
return removeNoLock(id, throw_if_not_exists);
}
bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications)
bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -141,27 +136,25 @@ bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists
const String & name = entry.entity->getName();
AccessEntityType type = entry.entity->getType();
prepareNotifications(entry, true, notifications);
/// Do removing.
UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(name);
entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true;
}
bool MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications);
return updateNoLock(id, update_func, throw_if_not_exists);
}
bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications)
bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -195,7 +188,7 @@ bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & updat
entries_by_name[new_entity->getName()] = &entry;
}
prepareNotifications(entry, false, notifications);
changes_notifier.onEntityUpdated(id, new_entity);
return true;
}
@ -212,16 +205,8 @@ void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entiti
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
setAllNoLock(all_entities, notifications);
}
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{
boost::container::flat_set<UUID> not_used_ids;
std::vector<UUID> conflicting_ids;
@ -256,7 +241,7 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
boost::container::flat_set<UUID> ids_to_remove = std::move(not_used_ids);
boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end()));
for (const auto & id : ids_to_remove)
removeNoLock(id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(id, /* throw_if_not_exists = */ false);
/// Insert or update entities.
for (const auto & [id, entity] : all_entities)
@ -269,84 +254,14 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
const AccessEntityPtr & changed_entity = entity;
updateNoLock(id,
[&changed_entity](const AccessEntityPtr &) { return changed_entity; },
/* throw_if_not_exists = */ true,
notifications);
/* throw_if_not_exists = */ true);
}
}
else
{
insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true, notifications);
insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true);
}
}
}
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool MemoryAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool MemoryAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -9,13 +9,15 @@
namespace DB
{
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in memory.
class MemoryAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "memory";
explicit MemoryAccessStorage(const String & storage_name_ = STORAGE_TYPE);
explicit MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_);
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -24,8 +26,6 @@ public:
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -34,25 +34,20 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications);
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications);
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
mutable std::recursive_mutex mutex;
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
};
}

View File

@ -45,7 +45,6 @@ void MultipleAccessStorage::setStorages(const std::vector<StoragePtr> & storages
std::unique_lock lock{mutex};
nested_storages = std::make_shared<const Storages>(storages);
ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
}
void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
@ -56,7 +55,6 @@ void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
auto new_storages = std::make_shared<Storages>(*nested_storages);
new_storages->push_back(new_storage);
nested_storages = new_storages;
updateSubscriptionsToNestedStorages(lock);
}
void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
@ -70,7 +68,6 @@ void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
new_storages->erase(new_storages->begin() + index);
nested_storages = new_storages;
ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
}
std::vector<StoragePtr> MultipleAccessStorage::getStorages()
@ -225,6 +222,28 @@ bool MultipleAccessStorage::isReadOnly(const UUID & id) const
}
void MultipleAccessStorage::reload()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->reload();
}
void MultipleAccessStorage::startPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->startPeriodicReloading();
}
void MultipleAccessStorage::stopPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->stopPeriodicReloading();
}
std::optional<UUID> MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{
std::shared_ptr<IAccessStorage> storage_for_insertion;
@ -310,145 +329,6 @@ bool MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & updat
}
scope_guard MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
auto storage = findStorage(id);
if (!storage)
return {};
return storage->subscribeForChanges(id, handler);
}
bool MultipleAccessStorage::hasSubscription(const UUID & id) const
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
{
if (storage->hasSubscription(id))
return true;
}
return false;
}
scope_guard MultipleAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::unique_lock lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
if (handlers.size() == 1)
updateSubscriptionsToNestedStorages(lock);
return [this, type, handler_it]
{
std::unique_lock lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
if (handlers2.empty())
updateSubscriptionsToNestedStorages(lock2);
};
}
bool MultipleAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
/// Updates subscriptions to nested storages.
/// We need the subscriptions to the nested storages if someone has subscribed to us.
/// If any of the nested storages is changed we call our subscribers.
void MultipleAccessStorage::updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const
{
/// lock is already locked.
std::vector<std::pair<StoragePtr, scope_guard>> added_subscriptions[static_cast<size_t>(AccessEntityType::MAX)];
std::vector<scope_guard> removed_subscriptions;
for (auto type : collections::range(AccessEntityType::MAX))
{
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
if (handlers.empty())
{
/// None has subscribed to us, we need no subscriptions to the nested storages.
for (auto & subscription : subscriptions | boost::adaptors::map_values)
removed_subscriptions.push_back(std::move(subscription));
subscriptions.clear();
}
else
{
/// Someone has subscribed to us, now we need to have a subscription to each nested storage.
for (auto it = subscriptions.begin(); it != subscriptions.end();)
{
const auto & storage = it->first;
auto & subscription = it->second;
if (boost::range::find(*nested_storages, storage) == nested_storages->end())
{
removed_subscriptions.push_back(std::move(subscription));
it = subscriptions.erase(it);
}
else
++it;
}
for (const auto & storage : *nested_storages)
{
if (!subscriptions.contains(storage))
added_subscriptions[static_cast<size_t>(type)].push_back({storage, nullptr});
}
}
}
/// Unlock the mutex temporarily because it's much better to subscribe to the nested storages
/// with the mutex unlocked.
lock.unlock();
removed_subscriptions.clear();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto on_changed = [this, type](const UUID & id, const AccessEntityPtr & entity)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock2{mutex};
for (const auto & handler : handlers_by_type[static_cast<size_t>(type)])
notifications.push_back({handler, id, entity});
};
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
subscription = storage->subscribeForChanges(type, on_changed);
}
}
/// Lock the mutex again to store added subscriptions to the nested storages.
lock.lock();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
{
if (!subscriptions.contains(storage) && (boost::range::find(*nested_storages, storage) != nested_storages->end())
&& !handlers_by_type[static_cast<size_t>(type)].empty())
{
subscriptions.emplace(std::move(storage), std::move(subscription));
}
}
}
}
lock.unlock();
}
std::optional<UUID>
MultipleAccessStorage::authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address,
const ExternalAuthenticators & external_authenticators,

View File

@ -24,6 +24,10 @@ public:
bool isReadOnly() const override;
bool isReadOnly(const UUID & id) const override;
void reload() override;
void startPeriodicReloading() override;
void stopPeriodicReloading() override;
void setStorages(const std::vector<StoragePtr> & storages);
void addStorage(const StoragePtr & new_storage);
void removeStorage(const StoragePtr & storage_to_remove);
@ -37,8 +41,6 @@ public:
StoragePtr getStorage(const UUID & id);
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
protected:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -48,19 +50,14 @@ protected:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
private:
using Storages = std::vector<StoragePtr>;
std::shared_ptr<const Storages> getStoragesInternal() const;
void updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const;
std::shared_ptr<const Storages> nested_storages;
mutable LRUCache<UUID, Storage> ids_cache;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::unordered_map<StoragePtr, scope_guard> subscriptions_to_nested_storages[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::mutex mutex;
};

View File

@ -1,12 +1,14 @@
#include <Access/AccessEntityIO.h>
#include <Access/MemoryAccessStorage.h>
#include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/ReadHelpers.h>
#include <boost/container/flat_set.hpp>
#include <Common/ZooKeeper/KeeperException.h>
#include <Common/ZooKeeper/Types.h>
#include <Common/ZooKeeper/ZooKeeper.h>
#include <Common/escapeForFileName.h>
#include <Common/setThreadName.h>
#include <base/range.h>
#include <base/sleep.h>
@ -30,11 +32,13 @@ static UUID parseUUID(const String & text)
ReplicatedAccessStorage::ReplicatedAccessStorage(
const String & storage_name_,
const String & zookeeper_path_,
zkutil::GetZooKeeper get_zookeeper_)
zkutil::GetZooKeeper get_zookeeper_,
AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_)
, zookeeper_path(zookeeper_path_)
, get_zookeeper(get_zookeeper_)
, refresh_queue(std::numeric_limits<size_t>::max())
, watched_queue(std::make_shared<ConcurrentBoundedQueue<UUID>>(std::numeric_limits<size_t>::max()))
, changes_notifier(changes_notifier_)
{
if (zookeeper_path.empty())
throw Exception("ZooKeeper path must be non-empty", ErrorCodes::BAD_ARGUMENTS);
@ -45,29 +49,30 @@ ReplicatedAccessStorage::ReplicatedAccessStorage(
/// If zookeeper chroot prefix is used, path should start with '/', because chroot concatenates without it.
if (zookeeper_path.front() != '/')
zookeeper_path = "/" + zookeeper_path;
initializeZookeeper();
}
ReplicatedAccessStorage::~ReplicatedAccessStorage()
{
ReplicatedAccessStorage::shutdown();
stopWatchingThread();
}
void ReplicatedAccessStorage::startup()
void ReplicatedAccessStorage::startWatchingThread()
{
initializeZookeeper();
worker_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWorkerThread, this);
bool prev_watching_flag = watching.exchange(true);
if (!prev_watching_flag)
watching_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWatchingThread, this);
}
void ReplicatedAccessStorage::shutdown()
void ReplicatedAccessStorage::stopWatchingThread()
{
bool prev_stop_flag = stop_flag.exchange(true);
if (!prev_stop_flag)
bool prev_watching_flag = watching.exchange(false);
if (prev_watching_flag)
{
refresh_queue.finish();
if (worker_thread.joinable())
worker_thread.join();
watched_queue->finish();
if (watching_thread.joinable())
watching_thread.join();
}
}
@ -105,10 +110,8 @@ std::optional<UUID> ReplicatedAccessStorage::insertImpl(const AccessEntityPtr &
if (!ok)
return std::nullopt;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
return id;
}
@ -207,10 +210,8 @@ bool ReplicatedAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exis
if (!ok)
return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
removeEntityNoLock(id, notifications);
removeEntityNoLock(id);
return true;
}
@ -261,10 +262,8 @@ bool ReplicatedAccessStorage::updateImpl(const UUID & id, const UpdateFunc & upd
if (!ok)
return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
return true;
}
@ -328,16 +327,18 @@ bool ReplicatedAccessStorage::updateZooKeeper(const zkutil::ZooKeeperPtr & zooke
}
void ReplicatedAccessStorage::runWorkerThread()
void ReplicatedAccessStorage::runWatchingThread()
{
LOG_DEBUG(getLogger(), "Started worker thread");
while (!stop_flag)
LOG_DEBUG(getLogger(), "Started watching thread");
setThreadName("ReplACLWatch");
while (watching)
{
try
{
if (!initialized)
initializeZookeeper();
refresh();
if (refresh())
changes_notifier.sendNotifications();
}
catch (...)
{
@ -353,7 +354,7 @@ void ReplicatedAccessStorage::resetAfterError()
initialized = false;
UUID id;
while (refresh_queue.tryPop(id)) {}
while (watched_queue->tryPop(id)) {}
std::lock_guard lock{mutex};
for (const auto type : collections::range(AccessEntityType::MAX))
@ -389,21 +390,20 @@ void ReplicatedAccessStorage::createRootNodes(const zkutil::ZooKeeperPtr & zooke
}
}
void ReplicatedAccessStorage::refresh()
bool ReplicatedAccessStorage::refresh()
{
UUID id;
if (refresh_queue.tryPop(id, /* timeout_ms: */ 10000))
{
if (stop_flag)
return;
if (!watched_queue->tryPop(id, /* timeout_ms: */ 10000))
return false;
auto zookeeper = get_zookeeper();
auto zookeeper = get_zookeeper();
if (id == UUIDHelpers::Nil)
refreshEntities(zookeeper);
else
refreshEntity(zookeeper, id);
}
if (id == UUIDHelpers::Nil)
refreshEntities(zookeeper);
else
refreshEntity(zookeeper, id);
return true;
}
@ -412,9 +412,9 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
LOG_DEBUG(getLogger(), "Refreshing entities list");
const String zookeeper_uuids_path = zookeeper_path + "/uuid";
auto watch_entities_list = [this](const Coordination::WatchResponse &)
auto watch_entities_list = [watched_queue = watched_queue](const Coordination::WatchResponse &)
{
[[maybe_unused]] bool push_result = refresh_queue.push(UUIDHelpers::Nil);
[[maybe_unused]] bool push_result = watched_queue->push(UUIDHelpers::Nil);
};
Coordination::Stat stat;
const auto entity_uuid_strs = zookeeper->getChildrenWatch(zookeeper_uuids_path, &stat, watch_entities_list);
@ -424,8 +424,6 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
for (const String & entity_uuid_str : entity_uuid_strs)
entity_uuids.insert(parseUUID(entity_uuid_str));
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
std::vector<UUID> entities_to_remove;
@ -437,14 +435,14 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
entities_to_remove.push_back(entity_uuid);
}
for (const auto & entity_uuid : entities_to_remove)
removeEntityNoLock(entity_uuid, notifications);
removeEntityNoLock(entity_uuid);
/// Locally add entities that were added to ZooKeeper
for (const auto & entity_uuid : entity_uuids)
{
const auto it = entries_by_id.find(entity_uuid);
if (it == entries_by_id.end())
refreshEntityNoLock(zookeeper, entity_uuid, notifications);
refreshEntityNoLock(zookeeper, entity_uuid);
}
LOG_DEBUG(getLogger(), "Refreshing entities list finished");
@ -452,21 +450,18 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
void ReplicatedAccessStorage::refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
}
void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications)
void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{
LOG_DEBUG(getLogger(), "Refreshing entity {}", toString(id));
const auto watch_entity = [this, id](const Coordination::WatchResponse & response)
const auto watch_entity = [watched_queue = watched_queue, id](const Coordination::WatchResponse & response)
{
if (response.type == Coordination::Event::CHANGED)
[[maybe_unused]] bool push_result = refresh_queue.push(id);
[[maybe_unused]] bool push_result = watched_queue->push(id);
};
Coordination::Stat entity_stat;
const String entity_path = zookeeper_path + "/uuid/" + toString(id);
@ -475,16 +470,16 @@ void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & z
if (exists)
{
const AccessEntityPtr entity = deserializeAccessEntity(entity_definition, entity_path);
setEntityNoLock(id, entity, notifications);
setEntityNoLock(id, entity);
}
else
{
removeEntityNoLock(id, notifications);
removeEntityNoLock(id);
}
}
void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications)
void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity)
{
LOG_DEBUG(getLogger(), "Setting id {} to entity named {}", toString(id), entity->getName());
const AccessEntityType type = entity->getType();
@ -494,12 +489,14 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
if (auto it = entries_by_name.find(name); it != entries_by_name.end() && it->second->id != id)
{
removeEntityNoLock(it->second->id, notifications);
removeEntityNoLock(it->second->id);
}
/// If the entity already exists under a different type+name, remove old type+name
bool existed_before = false;
if (auto it = entries_by_id.find(id); it != entries_by_id.end())
{
existed_before = true;
const AccessEntityPtr & existing_entity = it->second.entity;
const AccessEntityType existing_type = existing_entity->getType();
const String & existing_name = existing_entity->getName();
@ -514,11 +511,18 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
entry.id = id;
entry.entity = entity;
entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications);
if (initialized)
{
if (existed_before)
changes_notifier.onEntityUpdated(id, entity);
else
changes_notifier.onEntityAdded(id, entity);
}
}
void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications & notifications)
void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id)
{
LOG_DEBUG(getLogger(), "Removing entity with id {}", toString(id));
const auto it = entries_by_id.find(id);
@ -531,7 +535,6 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
const Entry & entry = it->second;
const AccessEntityType type = entry.entity->getType();
const String & name = entry.entity->getName();
prepareNotifications(entry, true, notifications);
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
const auto name_it = entries_by_name.find(name);
@ -542,8 +545,11 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
else
entries_by_name.erase(name);
UUID removed_id = id;
entries_by_id.erase(id);
LOG_DEBUG(getLogger(), "Removed entity with id {}", toString(id));
changes_notifier.onEntityRemoved(removed_id, type);
}
@ -594,73 +600,4 @@ AccessEntityPtr ReplicatedAccessStorage::readImpl(const UUID & id, bool throw_if
return entry.entity;
}
void ReplicatedAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
const auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool ReplicatedAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
const auto & it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool ReplicatedAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -18,32 +18,33 @@
namespace DB
{
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in zookeeper.
class ReplicatedAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "replicated";
ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper);
ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper, AccessChangesNotifier & changes_notifier_);
virtual ~ReplicatedAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
virtual void startup();
virtual void shutdown();
void startPeriodicReloading() override { startWatchingThread(); }
void stopPeriodicReloading() override { stopWatchingThread(); }
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
String zookeeper_path;
zkutil::GetZooKeeper get_zookeeper;
std::atomic<bool> initialized = false;
std::atomic<bool> stop_flag = false;
ThreadFromGlobalPool worker_thread;
ConcurrentBoundedQueue<UUID> refresh_queue;
std::atomic<bool> watching = false;
ThreadFromGlobalPool watching_thread;
std::shared_ptr<ConcurrentBoundedQueue<UUID>> watched_queue;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
@ -53,37 +54,36 @@ private:
bool removeZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, bool throw_if_not_exists);
bool updateZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
void runWorkerThread();
void resetAfterError();
void initializeZookeeper();
void createRootNodes(const zkutil::ZooKeeperPtr & zookeeper);
void refresh();
void startWatchingThread();
void stopWatchingThread();
void runWatchingThread();
void resetAfterError();
bool refresh();
void refreshEntities(const zkutil::ZooKeeperPtr & zookeeper);
void refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications);
void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications);
void removeEntityNoLock(const UUID & id, Notifications & notifications);
void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity);
void removeEntityNoLock(const UUID & id);
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id;
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
};
}

View File

@ -66,9 +66,6 @@ RoleCache::~RoleCache() = default;
std::shared_ptr<const EnabledRoles>
RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option)
{
/// Declared before `lock` to send notifications after the mutex will be unlocked.
scope_guard notifications;
std::lock_guard lock{mutex};
EnabledRoles::Params params;
params.current_roles.insert(roles.begin(), roles.end());
@ -83,13 +80,13 @@ RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UU
}
auto res = std::shared_ptr<EnabledRoles>(new EnabledRoles(params));
collectEnabledRoles(*res, notifications);
collectEnabledRoles(*res, nullptr);
enabled_roles.emplace(std::move(params), res);
return res;
}
void RoleCache::collectEnabledRoles(scope_guard & notifications)
void RoleCache::collectEnabledRoles(scope_guard * notifications)
{
/// `mutex` is already locked.
@ -107,7 +104,7 @@ void RoleCache::collectEnabledRoles(scope_guard & notifications)
}
void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications)
void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications)
{
/// `mutex` is already locked.
@ -170,7 +167,7 @@ void RoleCache::roleChanged(const UUID & role_id, const RolePtr & changed_role)
return;
role_from_cache->first = changed_role;
cache.update(role_id, role_from_cache);
collectEnabledRoles(notifications);
collectEnabledRoles(&notifications);
}
@ -181,7 +178,7 @@ void RoleCache::roleRemoved(const UUID & role_id)
std::lock_guard lock{mutex};
cache.remove(role_id);
collectEnabledRoles(notifications);
collectEnabledRoles(&notifications);
}
}

View File

@ -24,8 +24,8 @@ public:
const std::vector<UUID> & current_roles_with_admin_option);
private:
void collectEnabledRoles(scope_guard & notifications);
void collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications);
void collectEnabledRoles(scope_guard * notifications);
void collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications);
RolePtr getRole(const UUID & role_id);
void roleChanged(const UUID & role_id, const RolePtr & changed_role);
void roleRemoved(const UUID & role_id);

View File

@ -4,6 +4,7 @@
#include <Access/User.h>
#include <Access/SettingsProfile.h>
#include <Access/AccessControl.h>
#include <Access/AccessChangesNotifier.h>
#include <Dictionaries/IDictionary.h>
#include <Common/Config/ConfigReloader.h>
#include <Common/StringUtils/StringUtils.h>
@ -14,9 +15,6 @@
#include <Poco/JSON/JSON.h>
#include <Poco/JSON/Object.h>
#include <Poco/JSON/Stringifier.h>
#include <Common/logger_useful.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/adaptor/map.hpp>
#include <cstring>
#include <filesystem>
#include <base/FnTraits.h>
@ -525,8 +523,8 @@ namespace
}
}
UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_)
: IAccessStorage(storage_name_), access_control(access_control_)
UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_)
: IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{
}
@ -605,9 +603,9 @@ void UsersConfigAccessStorage::load(
std::make_shared<Poco::Event>(),
[&](Poco::AutoPtr<Poco::Util::AbstractConfiguration> new_config, bool /*initial_loading*/)
{
parseFromConfig(*new_config);
Settings::checkNoSettingNamesAtTopLevel(*new_config, users_config_path);
parseFromConfig(*new_config);
access_control.getChangesNotifier().sendNotifications();
},
/* already_loaded = */ false);
}
@ -662,27 +660,4 @@ std::optional<String> UsersConfigAccessStorage::readNameImpl(const UUID & id, bo
return memory_storage.readName(id, throw_if_not_exists);
}
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(type, handler);
}
bool UsersConfigAccessStorage::hasSubscription(const UUID & id) const
{
return memory_storage.hasSubscription(id);
}
bool UsersConfigAccessStorage::hasSubscription(AccessEntityType type) const
{
return memory_storage.hasSubscription(type);
}
}

View File

@ -22,7 +22,7 @@ public:
static constexpr char STORAGE_TYPE[] = "users.xml";
UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_);
UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_);
~UsersConfigAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -37,13 +37,12 @@ public:
const String & include_from_path = {},
const String & preprocessed_dir = {},
const zkutil::GetZooKeeper & get_zookeeper_function = {});
void reload();
void startPeriodicReloading();
void stopPeriodicReloading();
void reload() override;
void startPeriodicReloading() override;
void stopPeriodicReloading() override;
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
void parseFromConfig(const Poco::Util::AbstractConfiguration & config);
@ -51,10 +50,8 @@ private:
std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
const AccessControl & access_control;
AccessControl & access_control;
MemoryAccessStorage memory_storage;
String path;
std::unique_ptr<ConfigReloader> config_reloader;

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
using namespace DB;
@ -12,18 +13,6 @@ namespace ErrorCodes
}
TEST(ReplicatedAccessStorage, ShutdownWithoutStartup)
{
auto get_zk = []()
{
return std::shared_ptr<zkutil::ZooKeeper>();
};
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk);
storage.shutdown();
}
TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
{
auto get_zk = []()
@ -31,16 +20,16 @@ TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
return std::shared_ptr<zkutil::ZooKeeper>();
};
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk);
AccessChangesNotifier changes_notifier;
try
{
storage.startup();
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk, changes_notifier);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::NO_ZOOKEEPER)
throw;
}
storage.shutdown();
}

View File

@ -342,8 +342,6 @@ struct ContextSharedPart
/// Stop periodic reloading of the configuration files.
/// This must be done first because otherwise the reloading may pass a changed config
/// to some destroyed parts of ContextSharedPart.
if (access_control)
access_control->stopPeriodicReloadingUsersConfigs();
if (external_dictionaries_loader)
external_dictionaries_loader->enablePeriodicUpdates(false);
if (external_user_defined_executable_functions_loader)