Rework notifications used in access management.

This commit is contained in:
Vitaly Baranov 2022-05-16 20:43:55 +02:00
parent 9ccddc44c6
commit 58f4a86ec7
27 changed files with 561 additions and 754 deletions

View File

@ -1314,7 +1314,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->setConfigReloadCallback([&]() global_context->setConfigReloadCallback([&]()
{ {
main_config_reloader->reload(); main_config_reloader->reload();
access_control.reloadUsersConfigs(); access_control.reload();
}); });
/// Limit on total number of concurrently executed queries. /// Limit on total number of concurrently executed queries.
@ -1406,6 +1406,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
/// Stop reloading of the main config. This must be done before `global_context->shutdown()` because /// Stop reloading of the main config. This must be done before `global_context->shutdown()` because
/// otherwise the reloading may pass a changed config to some destroyed parts of ContextSharedPart. /// otherwise the reloading may pass a changed config to some destroyed parts of ContextSharedPart.
main_config_reloader.reset(); main_config_reloader.reset();
access_control.stopPeriodicReloading();
async_metrics.stop(); async_metrics.stop();
@ -1629,7 +1630,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
buildLoggers(config(), logger()); buildLoggers(config(), logger());
main_config_reloader->start(); main_config_reloader->start();
access_control.startPeriodicReloadingUsersConfigs(); access_control.startPeriodicReloading();
if (dns_cache_updater) if (dns_cache_updater)
dns_cache_updater->start(); dns_cache_updater->start();

View File

@ -0,0 +1,122 @@
#include <Access/AccessChangesNotifier.h>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
AccessChangesNotifier::AccessChangesNotifier() : handlers(std::make_shared<Handlers>())
{
}
AccessChangesNotifier::~AccessChangesNotifier() = default;
void AccessChangesNotifier::onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.entity = new_entity;
event.type = new_entity->getType();
queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.entity = changed_entity;
event.type = changed_entity->getType();
queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityRemoved(const UUID & id, AccessEntityType type)
{
std::lock_guard lock{queue_mutex};
Event event;
event.id = id;
event.type = type;
queue.push(std::move(event));
}
scope_guard AccessChangesNotifier::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler)
{
std::lock_guard lock{handlers->mutex};
auto & list = handlers->by_type[static_cast<size_t>(type)];
list.push_back(handler);
auto handler_it = std::prev(list.end());
return [handlers=handlers, type, handler_it]
{
std::lock_guard lock2{handlers->mutex};
auto & list2 = handlers->by_type[static_cast<size_t>(type)];
list2.erase(handler_it);
};
}
scope_guard AccessChangesNotifier::subscribeForChanges(const UUID & id, const OnChangedHandler & handler)
{
std::lock_guard lock{handlers->mutex};
auto it = handlers->by_id.emplace(id, std::list<OnChangedHandler>{}).first;
auto & list = it->second;
list.push_back(handler);
auto handler_it = std::prev(list.end());
return [handlers=handlers, it, handler_it]
{
std::lock_guard lock2{handlers->mutex};
auto & list2 = it->second;
list2.erase(handler_it);
if (list2.empty())
handlers->by_id.erase(it);
};
}
scope_guard AccessChangesNotifier::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler)
{
scope_guard subscriptions;
for (const auto & id : ids)
subscriptions.join(subscribeForChanges(id, handler));
return subscriptions;
}
void AccessChangesNotifier::sendNotifications()
{
/// Only one thread can send notification at any time.
std::lock_guard sending_notifications_lock{sending_notifications};
std::unique_lock queue_lock{queue_mutex};
while (!queue.empty())
{
auto event = std::move(queue.front());
queue.pop();
queue_lock.unlock();
std::vector<OnChangedHandler> current_handlers;
{
std::lock_guard handlers_lock{handlers->mutex};
boost::range::copy(handlers->by_type[static_cast<size_t>(event.type)], std::back_inserter(current_handlers));
auto it = handlers->by_id.find(event.id);
if (it != handlers->by_id.end())
boost::range::copy(it->second, std::back_inserter(current_handlers));
}
for (const auto & handler : current_handlers)
{
try
{
handler(event.id, event.entity);
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
queue_lock.lock();
}
}
}

View File

@ -0,0 +1,73 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <base/scope_guard.h>
#include <list>
#include <queue>
#include <unordered_map>
namespace DB
{
/// Helper class implementing subscriptions and notifications in access management.
class AccessChangesNotifier
{
public:
AccessChangesNotifier();
~AccessChangesNotifier();
using OnChangedHandler
= std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler);
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler)
{
return subscribeForChanges(EntityClassT::TYPE, handler);
}
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler);
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler);
/// Called by access storages after a new access entity has been added.
void onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity);
/// Called by access storages after an access entity has been changed.
void onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity);
/// Called by access storages after an access entity has been removed.
void onEntityRemoved(const UUID & id, AccessEntityType type);
/// Sends notifications to subscribers about changes in access entities
/// (added with previous calls onEntityAdded(), onEntityUpdated(), onEntityRemoved()).
void sendNotifications();
private:
struct Handlers
{
std::unordered_map<UUID, std::list<OnChangedHandler>> by_id;
std::list<OnChangedHandler> by_type[static_cast<size_t>(AccessEntityType::MAX)];
std::mutex mutex;
};
/// shared_ptr is here for safety because AccessChangesNotifier can be destroyed before all subscriptions are removed.
std::shared_ptr<Handlers> handlers;
struct Event
{
UUID id;
AccessEntityPtr entity;
AccessEntityType type;
};
std::queue<Event> queue;
std::mutex queue_mutex;
std::mutex sending_notifications;
};
}

View File

@ -14,6 +14,7 @@
#include <Access/SettingsProfilesCache.h> #include <Access/SettingsProfilesCache.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/ExternalAuthenticators.h> #include <Access/ExternalAuthenticators.h>
#include <Access/AccessChangesNotifier.h>
#include <Core/Settings.h> #include <Core/Settings.h>
#include <base/find_symbols.h> #include <base/find_symbols.h>
#include <Poco/ExpireCache.h> #include <Poco/ExpireCache.h>
@ -142,7 +143,8 @@ AccessControl::AccessControl()
quota_cache(std::make_unique<QuotaCache>(*this)), quota_cache(std::make_unique<QuotaCache>(*this)),
settings_profiles_cache(std::make_unique<SettingsProfilesCache>(*this)), settings_profiles_cache(std::make_unique<SettingsProfilesCache>(*this)),
external_authenticators(std::make_unique<ExternalAuthenticators>()), external_authenticators(std::make_unique<ExternalAuthenticators>()),
custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>()) custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>()),
changes_notifier(std::make_unique<AccessChangesNotifier>())
{ {
} }
@ -231,35 +233,6 @@ void AccessControl::addUsersConfigStorage(
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath()); LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
} }
void AccessControl::reloadUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->reload();
}
}
void AccessControl::startPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->startPeriodicReloading();
}
}
void AccessControl::stopPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->stopPeriodicReloading();
}
}
void AccessControl::addReplicatedStorage( void AccessControl::addReplicatedStorage(
const String & storage_name_, const String & storage_name_,
@ -272,10 +245,9 @@ void AccessControl::addReplicatedStorage(
if (auto replicated_storage = typeid_cast<std::shared_ptr<ReplicatedAccessStorage>>(storage)) if (auto replicated_storage = typeid_cast<std::shared_ptr<ReplicatedAccessStorage>>(storage))
return; return;
} }
auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_); auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_, *changes_notifier);
addStorage(new_storage); addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName()); LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
new_storage->startup();
} }
void AccessControl::addDiskStorage(const String & directory_, bool readonly_) void AccessControl::addDiskStorage(const String & directory_, bool readonly_)
@ -298,7 +270,7 @@ void AccessControl::addDiskStorage(const String & storage_name_, const String &
} }
} }
} }
auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_); auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_, *changes_notifier);
addStorage(new_storage); addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath()); LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
} }
@ -312,7 +284,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
if (auto memory_storage = typeid_cast<std::shared_ptr<MemoryAccessStorage>>(storage)) if (auto memory_storage = typeid_cast<std::shared_ptr<MemoryAccessStorage>>(storage))
return; return;
} }
auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_); auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_, *changes_notifier);
addStorage(new_storage); addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName()); LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
} }
@ -320,7 +292,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
void AccessControl::addLDAPStorage(const String & storage_name_, const Poco::Util::AbstractConfiguration & config_, const String & prefix_) void AccessControl::addLDAPStorage(const String & storage_name_, const Poco::Util::AbstractConfiguration & config_, const String & prefix_)
{ {
auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, this, config_, prefix_); auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, *this, config_, prefix_);
addStorage(new_storage); addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', LDAP server name: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getLDAPServerName()); LOG_DEBUG(getLogger(), "Added {} access storage '{}', LDAP server name: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getLDAPServerName());
} }
@ -423,6 +395,57 @@ void AccessControl::addStoragesFromMainConfig(
} }
void AccessControl::reload()
{
MultipleAccessStorage::reload();
changes_notifier->sendNotifications();
}
scope_guard AccessControl::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(type, handler);
}
scope_guard AccessControl::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(id, handler);
}
scope_guard AccessControl::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
return changes_notifier->subscribeForChanges(ids, handler);
}
std::optional<UUID> AccessControl::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{
auto id = MultipleAccessStorage::insertImpl(entity, replace_if_exists, throw_if_exists);
if (id)
changes_notifier->sendNotifications();
return id;
}
bool AccessControl::removeImpl(const UUID & id, bool throw_if_not_exists)
{
bool removed = MultipleAccessStorage::removeImpl(id, throw_if_not_exists);
if (removed)
changes_notifier->sendNotifications();
return removed;
}
bool AccessControl::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
bool updated = MultipleAccessStorage::updateImpl(id, update_func, throw_if_not_exists);
if (updated)
changes_notifier->sendNotifications();
return updated;
}
AccessChangesNotifier & AccessControl::getChangesNotifier()
{
return *changes_notifier;
}
UUID AccessControl::authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const UUID AccessControl::authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const
{ {
try try

View File

@ -3,8 +3,8 @@
#include <Access/MultipleAccessStorage.h> #include <Access/MultipleAccessStorage.h>
#include <Common/SettingsChanges.h> #include <Common/SettingsChanges.h>
#include <Common/ZooKeeper/Common.h> #include <Common/ZooKeeper/Common.h>
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp> #include <boost/container/flat_set.hpp>
#include <Access/UsersConfigAccessStorage.h>
#include <memory> #include <memory>
@ -40,6 +40,7 @@ class SettingsProfilesCache;
class SettingsProfileElements; class SettingsProfileElements;
class ClientInfo; class ClientInfo;
class ExternalAuthenticators; class ExternalAuthenticators;
class AccessChangesNotifier;
struct Settings; struct Settings;
@ -50,6 +51,7 @@ public:
AccessControl(); AccessControl();
~AccessControl() override; ~AccessControl() override;
/// Initializes access storage (user directories).
void setUpFromMainConfig(const Poco::Util::AbstractConfiguration & config_, const String & config_path_, void setUpFromMainConfig(const Poco::Util::AbstractConfiguration & config_, const String & config_path_,
const zkutil::GetZooKeeper & get_zookeeper_function_); const zkutil::GetZooKeeper & get_zookeeper_function_);
@ -74,9 +76,6 @@ public:
const String & preprocessed_dir_, const String & preprocessed_dir_,
const zkutil::GetZooKeeper & get_zookeeper_function_ = {}); const zkutil::GetZooKeeper & get_zookeeper_function_ = {});
void reloadUsersConfigs();
void startPeriodicReloadingUsersConfigs();
void stopPeriodicReloadingUsersConfigs();
/// Loads access entities from the directory on the local disk. /// Loads access entities from the directory on the local disk.
/// Use that directory to keep created users/roles/etc. /// Use that directory to keep created users/roles/etc.
void addDiskStorage(const String & directory_, bool readonly_ = false); void addDiskStorage(const String & directory_, bool readonly_ = false);
@ -106,6 +105,26 @@ public:
const String & config_path, const String & config_path,
const zkutil::GetZooKeeper & get_zookeeper_function); const zkutil::GetZooKeeper & get_zookeeper_function);
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
void reload() override;
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
/// Sets the default profile's name. /// Sets the default profile's name.
/// The default profile's settings are always applied before any other profile's. /// The default profile's settings are always applied before any other profile's.
void setDefaultProfileName(const String & default_profile_name); void setDefaultProfileName(const String & default_profile_name);
@ -135,9 +154,6 @@ public:
void setOnClusterQueriesRequireClusterGrant(bool enable) { on_cluster_queries_require_cluster_grant = enable; } void setOnClusterQueriesRequireClusterGrant(bool enable) { on_cluster_queries_require_cluster_grant = enable; }
bool doesOnClusterQueriesRequireClusterGrant() const { return on_cluster_queries_require_cluster_grant; } bool doesOnClusterQueriesRequireClusterGrant() const { return on_cluster_queries_require_cluster_grant; }
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
std::shared_ptr<const ContextAccess> getContextAccess( std::shared_ptr<const ContextAccess> getContextAccess(
const UUID & user_id, const UUID & user_id,
const std::vector<UUID> & current_roles, const std::vector<UUID> & current_roles,
@ -178,10 +194,17 @@ public:
const ExternalAuthenticators & getExternalAuthenticators() const; const ExternalAuthenticators & getExternalAuthenticators() const;
/// Gets manager of notifications.
AccessChangesNotifier & getChangesNotifier();
private: private:
class ContextAccessCache; class ContextAccessCache;
class CustomSettingsPrefixes; class CustomSettingsPrefixes;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
std::unique_ptr<ContextAccessCache> context_access_cache; std::unique_ptr<ContextAccessCache> context_access_cache;
std::unique_ptr<RoleCache> role_cache; std::unique_ptr<RoleCache> role_cache;
std::unique_ptr<RowPolicyCache> row_policy_cache; std::unique_ptr<RowPolicyCache> row_policy_cache;
@ -189,6 +212,7 @@ private:
std::unique_ptr<SettingsProfilesCache> settings_profiles_cache; std::unique_ptr<SettingsProfilesCache> settings_profiles_cache;
std::unique_ptr<ExternalAuthenticators> external_authenticators; std::unique_ptr<ExternalAuthenticators> external_authenticators;
std::unique_ptr<CustomSettingsPrefixes> custom_settings_prefixes; std::unique_ptr<CustomSettingsPrefixes> custom_settings_prefixes;
std::unique_ptr<AccessChangesNotifier> changes_notifier;
std::atomic_bool allow_plaintext_password = true; std::atomic_bool allow_plaintext_password = true;
std::atomic_bool allow_no_password = true; std::atomic_bool allow_no_password = true;
std::atomic_bool users_without_row_policies_can_read_rows = false; std::atomic_bool users_without_row_policies_can_read_rows = false;

View File

@ -149,6 +149,21 @@ ContextAccess::ContextAccess(const AccessControl & access_control_, const Params
} }
ContextAccess::~ContextAccess()
{
enabled_settings.reset();
enabled_quota.reset();
enabled_row_policies.reset();
access_with_implicit.reset();
access.reset();
roles_info.reset();
subscription_for_roles_changes.reset();
enabled_roles.reset();
subscription_for_user_change.reset();
user.reset();
}
void ContextAccess::initialize() void ContextAccess::initialize()
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};

View File

@ -155,6 +155,8 @@ public:
/// without any limitations. This is used for the global context. /// without any limitations. This is used for the global context.
static std::shared_ptr<const ContextAccess> getFullAccess(); static std::shared_ptr<const ContextAccess> getFullAccess();
~ContextAccess();
private: private:
friend class AccessControl; friend class AccessControl;
ContextAccess() {} /// NOLINT ContextAccess() {} /// NOLINT

View File

@ -1,5 +1,6 @@
#include <Access/DiskAccessStorage.h> #include <Access/DiskAccessStorage.h>
#include <Access/AccessEntityIO.h> #include <Access/AccessEntityIO.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromFile.h> #include <IO/ReadBufferFromFile.h>
@ -164,13 +165,8 @@ namespace
} }
DiskAccessStorage::DiskAccessStorage(const String & directory_path_, bool readonly_) DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_)
: DiskAccessStorage(STORAGE_TYPE, directory_path_, readonly_) : IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{
}
DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_)
: IAccessStorage(storage_name_)
{ {
directory_path = makeDirectoryPathCanonical(directory_path_); directory_path = makeDirectoryPathCanonical(directory_path_);
readonly = readonly_; readonly = readonly_;
@ -199,8 +195,16 @@ DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String
DiskAccessStorage::~DiskAccessStorage() DiskAccessStorage::~DiskAccessStorage()
{ {
stopListsWritingThread(); stopListsWritingThread();
try
{
writeLists(); writeLists();
} }
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
String DiskAccessStorage::getStorageParamsJSON() const String DiskAccessStorage::getStorageParamsJSON() const
@ -470,19 +474,16 @@ std::optional<String> DiskAccessStorage::readNameImpl(const UUID & id, bool thro
std::optional<UUID> DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists) std::optional<UUID> DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID(); UUID id = generateRandomID();
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications)) if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id; return id;
return std::nullopt; return std::nullopt;
} }
bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications) bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{ {
const String & name = new_entity->getName(); const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType(); AccessEntityType type = new_entity->getType();
@ -514,7 +515,7 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
writeAccessEntityToDisk(id, *new_entity); writeAccessEntityToDisk(id, *new_entity);
if (name_collision && replace_if_exists) if (name_collision && replace_if_exists)
removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false, notifications); removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false);
/// Do insertion. /// Do insertion.
auto & entry = entries_by_id[id]; auto & entry = entries_by_id[id];
@ -523,22 +524,20 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
entry.name = name; entry.name = name;
entry.entity = new_entity; entry.entity = new_entity;
entries_by_name[entry.name] = &entry; entries_by_name[entry.name] = &entry;
prepareNotifications(id, entry, false, notifications);
changes_notifier.onEntityAdded(id, new_entity);
return true; return true;
} }
bool DiskAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists) bool DiskAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications); return removeNoLock(id, throw_if_not_exists);
} }
bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications) bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{ {
auto it = entries_by_id.find(id); auto it = entries_by_id.find(id);
if (it == entries_by_id.end()) if (it == entries_by_id.end())
@ -559,25 +558,24 @@ bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists,
deleteAccessEntityOnDisk(id); deleteAccessEntityOnDisk(id);
/// Do removing. /// Do removing.
prepareNotifications(id, entry, true, notifications); UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)]; auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(entry.name); entries_by_name.erase(entry.name);
entries_by_id.erase(it); entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true; return true;
} }
bool DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) bool DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications); return updateNoLock(id, update_func, throw_if_not_exists);
} }
bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications) bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{ {
auto it = entries_by_id.find(id); auto it = entries_by_id.find(id);
if (it == entries_by_id.end()) if (it == entries_by_id.end())
@ -626,7 +624,8 @@ bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
entries_by_name[entry.name] = &entry; entries_by_name[entry.name] = &entry;
} }
prepareNotifications(id, entry, false, notifications); changes_notifier.onEntityUpdated(id, new_entity);
return true; return true;
} }
@ -650,74 +649,4 @@ void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const
throw Exception("Couldn't delete " + file_path, ErrorCodes::FILE_DOESNT_EXIST); throw Exception("Couldn't delete " + file_path, ErrorCodes::FILE_DOESNT_EXIST);
} }
void DiskAccessStorage::prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const
{
if (!remove && !entry.entity)
return;
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.type)])
notifications.push_back({handler, id, entity});
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
bool DiskAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool DiskAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
} }

View File

@ -7,14 +7,15 @@
namespace DB namespace DB
{ {
class AccessChangesNotifier;
/// Loads and saves access entities on a local disk to a specified directory. /// Loads and saves access entities on a local disk to a specified directory.
class DiskAccessStorage : public IAccessStorage class DiskAccessStorage : public IAccessStorage
{ {
public: public:
static constexpr char STORAGE_TYPE[] = "local directory"; static constexpr char STORAGE_TYPE[] = "local directory";
DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_ = false); DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_);
DiskAccessStorage(const String & directory_path_, bool readonly_ = false);
~DiskAccessStorage() override; ~DiskAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; } const char * getStorageType() const override { return STORAGE_TYPE; }
@ -27,8 +28,6 @@ public:
bool isReadOnly() const override { return readonly; } bool isReadOnly() const override { return readonly; }
bool exists(const UUID & id) const override; bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private: private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override; std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -38,8 +37,6 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override; std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override; bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override; bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
void clear(); void clear();
bool readLists(); bool readLists();
@ -50,9 +47,9 @@ private:
void listsWritingThreadFunc(); void listsWritingThreadFunc();
void stopListsWritingThread(); void stopListsWritingThread();
bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications); bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications); bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications); bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
AccessEntityPtr readAccessEntityFromDisk(const UUID & id) const; AccessEntityPtr readAccessEntityFromDisk(const UUID & id) const;
void writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const; void writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const;
@ -65,11 +62,8 @@ private:
String name; String name;
AccessEntityType type; AccessEntityType type;
mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet. mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet.
mutable std::list<OnChangedHandler> handlers_by_id;
}; };
void prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const;
String directory_path; String directory_path;
std::atomic<bool> readonly; std::atomic<bool> readonly;
std::unordered_map<UUID, Entry> entries_by_id; std::unordered_map<UUID, Entry> entries_by_id;
@ -79,7 +73,7 @@ private:
ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread. ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread.
std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit. std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit.
bool lists_writing_thread_is_waiting = false; bool lists_writing_thread_is_waiting = false;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)]; AccessChangesNotifier & changes_notifier;
mutable std::mutex mutex; mutable std::mutex mutex;
}; };
} }

View File

@ -6,7 +6,7 @@
namespace DB namespace DB
{ {
EnabledRoles::EnabledRoles(const Params & params_) : params(params_) EnabledRoles::EnabledRoles(const Params & params_) : params(params_), handlers(std::make_shared<Handlers>())
{ {
} }
@ -15,42 +15,50 @@ EnabledRoles::~EnabledRoles() = default;
std::shared_ptr<const EnabledRolesInfo> EnabledRoles::getRolesInfo() const std::shared_ptr<const EnabledRolesInfo> EnabledRoles::getRolesInfo() const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{info_mutex};
return info; return info;
} }
scope_guard EnabledRoles::subscribeForChanges(const OnChangeHandler & handler) const scope_guard EnabledRoles::subscribeForChanges(const OnChangeHandler & handler) const
{ {
std::lock_guard lock{mutex}; std::lock_guard lock{handlers->mutex};
handlers.push_back(handler); handlers->list.push_back(handler);
auto it = std::prev(handlers.end()); auto it = std::prev(handlers->list.end());
return [this, it] return [handlers=handlers, it]
{ {
std::lock_guard lock2{mutex}; std::lock_guard lock2{handlers->mutex};
handlers.erase(it); handlers->list.erase(it);
}; };
} }
void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications) void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications)
{ {
std::lock_guard lock{mutex}; {
std::lock_guard lock{info_mutex};
if (info && info_ && *info == *info_) if (info && info_ && *info == *info_)
return; return;
info = info_; info = info_;
}
if (notifications)
{
std::vector<OnChangeHandler> handlers_to_notify; std::vector<OnChangeHandler> handlers_to_notify;
boost::range::copy(handlers, std::back_inserter(handlers_to_notify)); {
std::lock_guard lock{handlers->mutex};
boost::range::copy(handlers->list, std::back_inserter(handlers_to_notify));
}
notifications.join(scope_guard([info = info, handlers_to_notify = std::move(handlers_to_notify)] notifications->join(scope_guard(
[info = info, handlers_to_notify = std::move(handlers_to_notify)]
{ {
for (const auto & handler : handlers_to_notify) for (const auto & handler : handlers_to_notify)
handler(info); handler(info);
})); }));
} }
}
} }

View File

@ -4,6 +4,7 @@
#include <base/scope_guard.h> #include <base/scope_guard.h>
#include <boost/container/flat_set.hpp> #include <boost/container/flat_set.hpp>
#include <list> #include <list>
#include <memory>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
@ -43,12 +44,21 @@ private:
friend class RoleCache; friend class RoleCache;
explicit EnabledRoles(const Params & params_); explicit EnabledRoles(const Params & params_);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications); void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications);
const Params params; const Params params;
mutable std::shared_ptr<const EnabledRolesInfo> info;
mutable std::list<OnChangeHandler> handlers; std::shared_ptr<const EnabledRolesInfo> info;
mutable std::mutex mutex; mutable std::mutex info_mutex;
struct Handlers
{
std::list<OnChangeHandler> list;
std::mutex mutex;
};
/// shared_ptr is here for safety because EnabledRoles can be destroyed before all subscriptions are removed.
std::shared_ptr<Handlers> handlers;
}; };
} }

View File

@ -410,34 +410,6 @@ bool IAccessStorage::updateImpl(const UUID & id, const UpdateFunc &, bool throw_
} }
scope_guard IAccessStorage::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(type, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(id, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
scope_guard subscriptions;
for (const auto & id : ids)
subscriptions.join(subscribeForChangesImpl(id, handler));
return subscriptions;
}
void IAccessStorage::notify(const Notifications & notifications)
{
for (const auto & [fn, id, new_entity] : notifications)
fn(id, new_entity);
}
UUID IAccessStorage::authenticate( UUID IAccessStorage::authenticate(
const Credentials & credentials, const Credentials & credentials,
const Poco::Net::IPAddress & address, const Poco::Net::IPAddress & address,

View File

@ -3,7 +3,6 @@
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Core/Types.h> #include <Core/Types.h>
#include <Core/UUID.h> #include <Core/UUID.h>
#include <base/scope_guard.h>
#include <functional> #include <functional>
#include <optional> #include <optional>
#include <vector> #include <vector>
@ -22,7 +21,7 @@ enum class AuthenticationType;
/// Contains entities, i.e. instances of classes derived from IAccessEntity. /// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe. /// The implementations of this class MUST be thread-safe.
class IAccessStorage class IAccessStorage : public boost::noncopyable
{ {
public: public:
explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {} explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
@ -41,6 +40,15 @@ public:
/// Returns true if this entity is readonly. /// Returns true if this entity is readonly.
virtual bool isReadOnly(const UUID &) const { return isReadOnly(); } virtual bool isReadOnly(const UUID &) const { return isReadOnly(); }
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
virtual void reload() {}
/// Starts periodic reloading and update of entities in this storage.
virtual void startPeriodicReloading() {}
/// Stops periodic reloading and update of entities in this storage.
virtual void stopPeriodicReloading() {}
/// Returns the identifiers of all the entities of a specified type contained in the storage. /// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll(AccessEntityType type) const; std::vector<UUID> findAll(AccessEntityType type) const;
@ -130,23 +138,6 @@ public:
/// Updates multiple entities in the storage. Returns the list of successfully updated. /// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func); std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
virtual bool hasSubscription(AccessEntityType type) const = 0;
virtual bool hasSubscription(const UUID & id) const = 0;
/// Finds a user, check the provided credentials and returns the ID of the user if they are valid. /// Finds a user, check the provided credentials and returns the ID of the user if they are valid.
/// Throws an exception if no such user or credentials are invalid. /// Throws an exception if no such user or credentials are invalid.
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool allow_no_password, bool allow_plaintext_password) const; UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool allow_no_password, bool allow_plaintext_password) const;
@ -160,8 +151,6 @@ protected:
virtual std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists); virtual std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
virtual bool removeImpl(const UUID & id, bool throw_if_not_exists); virtual bool removeImpl(const UUID & id, bool throw_if_not_exists);
virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists); virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const = 0;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const; virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const;
virtual bool areCredentialsValid(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const; virtual bool areCredentialsValid(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const;
virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const; virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const;
@ -181,9 +170,6 @@ protected:
[[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address); [[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address);
[[noreturn]] static void throwInvalidCredentials(); [[noreturn]] static void throwInvalidCredentials();
[[noreturn]] static void throwAuthenticationTypeNotAllowed(AuthenticationType auth_type); [[noreturn]] static void throwAuthenticationTypeNotAllowed(AuthenticationType auth_type);
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>;
static void notify(const Notifications & notifications);
private: private:
const String storage_name; const String storage_name;

View File

@ -27,10 +27,10 @@ namespace ErrorCodes
} }
LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix) LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
: IAccessStorage(storage_name_) : IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{ {
setConfiguration(access_control_, config, prefix); setConfiguration(config, prefix);
} }
@ -40,7 +40,7 @@ String LDAPAccessStorage::getLDAPServerName() const
} }
void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix) void LDAPAccessStorage::setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix)
{ {
std::scoped_lock lock(mutex); std::scoped_lock lock(mutex);
@ -80,7 +80,6 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
} }
} }
access_control = access_control_;
ldap_server_name = ldap_server_name_cfg; ldap_server_name = ldap_server_name_cfg;
role_search_params.swap(role_search_params_cfg); role_search_params.swap(role_search_params_cfg);
common_role_names.swap(common_roles_cfg); common_role_names.swap(common_roles_cfg);
@ -91,7 +90,7 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
granted_role_names.clear(); granted_role_names.clear();
granted_role_ids.clear(); granted_role_ids.clear();
role_change_subscription = access_control->subscribeForChanges<Role>( role_change_subscription = access_control.subscribeForChanges<Role>(
[this] (const UUID & id, const AccessEntityPtr & entity) [this] (const UUID & id, const AccessEntityPtr & entity)
{ {
return this->processRoleChange(id, entity); return this->processRoleChange(id, entity);
@ -215,7 +214,7 @@ void LDAPAccessStorage::assignRolesNoLock(User & user, const LDAPClient::SearchR
auto it = granted_role_ids.find(role_name); auto it = granted_role_ids.find(role_name);
if (it == granted_role_ids.end()) if (it == granted_role_ids.end())
{ {
if (const auto role_id = access_control->find<Role>(role_name)) if (const auto role_id = access_control.find<Role>(role_name))
{ {
granted_role_names.insert_or_assign(*role_id, role_name); granted_role_names.insert_or_assign(*role_id, role_name);
it = granted_role_ids.insert_or_assign(role_name, *role_id).first; it = granted_role_ids.insert_or_assign(role_name, *role_id).first;
@ -450,33 +449,6 @@ std::optional<String> LDAPAccessStorage::readNameImpl(const UUID & id, bool thro
} }
scope_guard LDAPAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard LDAPAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(type, handler);
}
bool LDAPAccessStorage::hasSubscription(const UUID & id) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(id);
}
bool LDAPAccessStorage::hasSubscription(AccessEntityType type) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(type);
}
std::optional<UUID> LDAPAccessStorage::authenticateImpl( std::optional<UUID> LDAPAccessStorage::authenticateImpl(
const Credentials & credentials, const Credentials & credentials,
const Poco::Net::IPAddress & address, const Poco::Net::IPAddress & address,

View File

@ -32,7 +32,7 @@ class LDAPAccessStorage : public IAccessStorage
public: public:
static constexpr char STORAGE_TYPE[] = "ldap"; static constexpr char STORAGE_TYPE[] = "ldap";
explicit LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix); explicit LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
virtual ~LDAPAccessStorage() override = default; virtual ~LDAPAccessStorage() override = default;
String getLDAPServerName() const; String getLDAPServerName() const;
@ -42,19 +42,15 @@ public:
virtual String getStorageParamsJSON() const override; virtual String getStorageParamsJSON() const override;
virtual bool isReadOnly() const override { return true; } virtual bool isReadOnly() const override { return true; }
virtual bool exists(const UUID & id) const override; virtual bool exists(const UUID & id) const override;
virtual bool hasSubscription(const UUID & id) const override;
virtual bool hasSubscription(AccessEntityType type) const override;
private: // IAccessStorage implementations. private: // IAccessStorage implementations.
virtual std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override; virtual std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
virtual std::vector<UUID> findAllImpl(AccessEntityType type) const override; virtual std::vector<UUID> findAllImpl(AccessEntityType type) const override;
virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override; virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override; virtual std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override; virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
void setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix); void setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix);
void processRoleChange(const UUID & id, const AccessEntityPtr & entity); void processRoleChange(const UUID & id, const AccessEntityPtr & entity);
void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name); void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name);
@ -66,7 +62,7 @@ private: // IAccessStorage implementations.
const ExternalAuthenticators & external_authenticators, LDAPClient::SearchResultsList & role_search_results) const; const ExternalAuthenticators & external_authenticators, LDAPClient::SearchResultsList & role_search_results) const;
mutable std::recursive_mutex mutex; mutable std::recursive_mutex mutex;
AccessControl * access_control = nullptr; AccessControl & access_control;
String ldap_server_name; String ldap_server_name;
LDAPClient::RoleSearchParamsList role_search_params; LDAPClient::RoleSearchParamsList role_search_params;
std::set<String> common_role_names; // role name that should be granted to all users at all times std::set<String> common_role_names; // role name that should be granted to all users at all times

View File

@ -1,4 +1,5 @@
#include <Access/MemoryAccessStorage.h> #include <Access/MemoryAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <base/scope_guard.h> #include <base/scope_guard.h>
#include <boost/container/flat_set.hpp> #include <boost/container/flat_set.hpp>
#include <boost/range/adaptor/map.hpp> #include <boost/range/adaptor/map.hpp>
@ -7,8 +8,8 @@
namespace DB namespace DB
{ {
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_) MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_) : IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{ {
} }
@ -63,19 +64,16 @@ AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id, bool throw_if_not
std::optional<UUID> MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists) std::optional<UUID> MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID(); UUID id = generateRandomID();
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications)) if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id; return id;
return std::nullopt; return std::nullopt;
} }
bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications) bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{ {
const String & name = new_entity->getName(); const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType(); AccessEntityType type = new_entity->getType();
@ -103,7 +101,7 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
if (name_collision && replace_if_exists) if (name_collision && replace_if_exists)
{ {
const auto & existing_entry = *(it_by_name->second); const auto & existing_entry = *(it_by_name->second);
removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false, notifications); removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false);
} }
/// Do insertion. /// Do insertion.
@ -111,22 +109,19 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
entry.id = id; entry.id = id;
entry.entity = new_entity; entry.entity = new_entity;
entries_by_name[name] = &entry; entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications); changes_notifier.onEntityAdded(id, new_entity);
return true; return true;
} }
bool MemoryAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists) bool MemoryAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications); return removeNoLock(id, throw_if_not_exists);
} }
bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications) bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{ {
auto it = entries_by_id.find(id); auto it = entries_by_id.find(id);
if (it == entries_by_id.end()) if (it == entries_by_id.end())
@ -141,27 +136,25 @@ bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists
const String & name = entry.entity->getName(); const String & name = entry.entity->getName();
AccessEntityType type = entry.entity->getType(); AccessEntityType type = entry.entity->getType();
prepareNotifications(entry, true, notifications);
/// Do removing. /// Do removing.
UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)]; auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(name); entries_by_name.erase(name);
entries_by_id.erase(it); entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true; return true;
} }
bool MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) bool MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications); return updateNoLock(id, update_func, throw_if_not_exists);
} }
bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications) bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{ {
auto it = entries_by_id.find(id); auto it = entries_by_id.find(id);
if (it == entries_by_id.end()) if (it == entries_by_id.end())
@ -195,7 +188,7 @@ bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & updat
entries_by_name[new_entity->getName()] = &entry; entries_by_name[new_entity->getName()] = &entry;
} }
prepareNotifications(entry, false, notifications); changes_notifier.onEntityUpdated(id, new_entity);
return true; return true;
} }
@ -212,16 +205,8 @@ void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entiti
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities) void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
setAllNoLock(all_entities, notifications);
}
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{
boost::container::flat_set<UUID> not_used_ids; boost::container::flat_set<UUID> not_used_ids;
std::vector<UUID> conflicting_ids; std::vector<UUID> conflicting_ids;
@ -256,7 +241,7 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
boost::container::flat_set<UUID> ids_to_remove = std::move(not_used_ids); boost::container::flat_set<UUID> ids_to_remove = std::move(not_used_ids);
boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end())); boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end()));
for (const auto & id : ids_to_remove) for (const auto & id : ids_to_remove)
removeNoLock(id, /* throw_if_not_exists = */ false, notifications); removeNoLock(id, /* throw_if_not_exists = */ false);
/// Insert or update entities. /// Insert or update entities.
for (const auto & [id, entity] : all_entities) for (const auto & [id, entity] : all_entities)
@ -269,84 +254,14 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
const AccessEntityPtr & changed_entity = entity; const AccessEntityPtr & changed_entity = entity;
updateNoLock(id, updateNoLock(id,
[&changed_entity](const AccessEntityPtr &) { return changed_entity; }, [&changed_entity](const AccessEntityPtr &) { return changed_entity; },
/* throw_if_not_exists = */ true, /* throw_if_not_exists = */ true);
notifications);
} }
} }
else else
{ {
insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true, notifications); insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true);
} }
} }
} }
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool MemoryAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool MemoryAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
} }

View File

@ -9,13 +9,15 @@
namespace DB namespace DB
{ {
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in memory. /// Implementation of IAccessStorage which keeps all data in memory.
class MemoryAccessStorage : public IAccessStorage class MemoryAccessStorage : public IAccessStorage
{ {
public: public:
static constexpr char STORAGE_TYPE[] = "memory"; static constexpr char STORAGE_TYPE[] = "memory";
explicit MemoryAccessStorage(const String & storage_name_ = STORAGE_TYPE); explicit MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_);
const char * getStorageType() const override { return STORAGE_TYPE; } const char * getStorageType() const override { return STORAGE_TYPE; }
@ -24,8 +26,6 @@ public:
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities); void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
bool exists(const UUID & id) const override; bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private: private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override; std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -34,25 +34,20 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override; std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override; bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override; bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override; bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
struct Entry struct Entry
{ {
UUID id; UUID id;
AccessEntityPtr entity; AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
}; };
bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications); mutable std::mutex mutex;
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications);
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
mutable std::recursive_mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id; /// We want to search entries both by ID and by the pair of name and type. std::unordered_map<UUID, Entry> entries_by_id; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)]; std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)]; AccessChangesNotifier & changes_notifier;
}; };
} }

View File

@ -45,7 +45,6 @@ void MultipleAccessStorage::setStorages(const std::vector<StoragePtr> & storages
std::unique_lock lock{mutex}; std::unique_lock lock{mutex};
nested_storages = std::make_shared<const Storages>(storages); nested_storages = std::make_shared<const Storages>(storages);
ids_cache.reset(); ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
} }
void MultipleAccessStorage::addStorage(const StoragePtr & new_storage) void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
@ -56,7 +55,6 @@ void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
auto new_storages = std::make_shared<Storages>(*nested_storages); auto new_storages = std::make_shared<Storages>(*nested_storages);
new_storages->push_back(new_storage); new_storages->push_back(new_storage);
nested_storages = new_storages; nested_storages = new_storages;
updateSubscriptionsToNestedStorages(lock);
} }
void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove) void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
@ -70,7 +68,6 @@ void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
new_storages->erase(new_storages->begin() + index); new_storages->erase(new_storages->begin() + index);
nested_storages = new_storages; nested_storages = new_storages;
ids_cache.reset(); ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
} }
std::vector<StoragePtr> MultipleAccessStorage::getStorages() std::vector<StoragePtr> MultipleAccessStorage::getStorages()
@ -225,6 +222,28 @@ bool MultipleAccessStorage::isReadOnly(const UUID & id) const
} }
void MultipleAccessStorage::reload()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->reload();
}
void MultipleAccessStorage::startPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->startPeriodicReloading();
}
void MultipleAccessStorage::stopPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->stopPeriodicReloading();
}
std::optional<UUID> MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) std::optional<UUID> MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{ {
std::shared_ptr<IAccessStorage> storage_for_insertion; std::shared_ptr<IAccessStorage> storage_for_insertion;
@ -310,145 +329,6 @@ bool MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & updat
} }
scope_guard MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
auto storage = findStorage(id);
if (!storage)
return {};
return storage->subscribeForChanges(id, handler);
}
bool MultipleAccessStorage::hasSubscription(const UUID & id) const
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
{
if (storage->hasSubscription(id))
return true;
}
return false;
}
scope_guard MultipleAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::unique_lock lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
if (handlers.size() == 1)
updateSubscriptionsToNestedStorages(lock);
return [this, type, handler_it]
{
std::unique_lock lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
if (handlers2.empty())
updateSubscriptionsToNestedStorages(lock2);
};
}
bool MultipleAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
/// Updates subscriptions to nested storages.
/// We need the subscriptions to the nested storages if someone has subscribed to us.
/// If any of the nested storages is changed we call our subscribers.
void MultipleAccessStorage::updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const
{
/// lock is already locked.
std::vector<std::pair<StoragePtr, scope_guard>> added_subscriptions[static_cast<size_t>(AccessEntityType::MAX)];
std::vector<scope_guard> removed_subscriptions;
for (auto type : collections::range(AccessEntityType::MAX))
{
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
if (handlers.empty())
{
/// None has subscribed to us, we need no subscriptions to the nested storages.
for (auto & subscription : subscriptions | boost::adaptors::map_values)
removed_subscriptions.push_back(std::move(subscription));
subscriptions.clear();
}
else
{
/// Someone has subscribed to us, now we need to have a subscription to each nested storage.
for (auto it = subscriptions.begin(); it != subscriptions.end();)
{
const auto & storage = it->first;
auto & subscription = it->second;
if (boost::range::find(*nested_storages, storage) == nested_storages->end())
{
removed_subscriptions.push_back(std::move(subscription));
it = subscriptions.erase(it);
}
else
++it;
}
for (const auto & storage : *nested_storages)
{
if (!subscriptions.contains(storage))
added_subscriptions[static_cast<size_t>(type)].push_back({storage, nullptr});
}
}
}
/// Unlock the mutex temporarily because it's much better to subscribe to the nested storages
/// with the mutex unlocked.
lock.unlock();
removed_subscriptions.clear();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto on_changed = [this, type](const UUID & id, const AccessEntityPtr & entity)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock2{mutex};
for (const auto & handler : handlers_by_type[static_cast<size_t>(type)])
notifications.push_back({handler, id, entity});
};
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
subscription = storage->subscribeForChanges(type, on_changed);
}
}
/// Lock the mutex again to store added subscriptions to the nested storages.
lock.lock();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
{
if (!subscriptions.contains(storage) && (boost::range::find(*nested_storages, storage) != nested_storages->end())
&& !handlers_by_type[static_cast<size_t>(type)].empty())
{
subscriptions.emplace(std::move(storage), std::move(subscription));
}
}
}
}
lock.unlock();
}
std::optional<UUID> std::optional<UUID>
MultipleAccessStorage::authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, MultipleAccessStorage::authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address,
const ExternalAuthenticators & external_authenticators, const ExternalAuthenticators & external_authenticators,

View File

@ -24,6 +24,10 @@ public:
bool isReadOnly() const override; bool isReadOnly() const override;
bool isReadOnly(const UUID & id) const override; bool isReadOnly(const UUID & id) const override;
void reload() override;
void startPeriodicReloading() override;
void stopPeriodicReloading() override;
void setStorages(const std::vector<StoragePtr> & storages); void setStorages(const std::vector<StoragePtr> & storages);
void addStorage(const StoragePtr & new_storage); void addStorage(const StoragePtr & new_storage);
void removeStorage(const StoragePtr & storage_to_remove); void removeStorage(const StoragePtr & storage_to_remove);
@ -37,8 +41,6 @@ public:
StoragePtr getStorage(const UUID & id); StoragePtr getStorage(const UUID & id);
bool exists(const UUID & id) const override; bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
protected: protected:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override; std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -48,19 +50,14 @@ protected:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override; std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override; bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override; bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override; std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
private: private:
using Storages = std::vector<StoragePtr>; using Storages = std::vector<StoragePtr>;
std::shared_ptr<const Storages> getStoragesInternal() const; std::shared_ptr<const Storages> getStoragesInternal() const;
void updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const;
std::shared_ptr<const Storages> nested_storages; std::shared_ptr<const Storages> nested_storages;
mutable LRUCache<UUID, Storage> ids_cache; mutable LRUCache<UUID, Storage> ids_cache;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::unordered_map<StoragePtr, scope_guard> subscriptions_to_nested_storages[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::mutex mutex; mutable std::mutex mutex;
}; };

View File

@ -1,12 +1,14 @@
#include <Access/AccessEntityIO.h> #include <Access/AccessEntityIO.h>
#include <Access/MemoryAccessStorage.h> #include <Access/MemoryAccessStorage.h>
#include <Access/ReplicatedAccessStorage.h> #include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <boost/container/flat_set.hpp> #include <boost/container/flat_set.hpp>
#include <Common/ZooKeeper/KeeperException.h> #include <Common/ZooKeeper/KeeperException.h>
#include <Common/ZooKeeper/Types.h> #include <Common/ZooKeeper/Types.h>
#include <Common/ZooKeeper/ZooKeeper.h> #include <Common/ZooKeeper/ZooKeeper.h>
#include <Common/escapeForFileName.h> #include <Common/escapeForFileName.h>
#include <Common/setThreadName.h>
#include <base/range.h> #include <base/range.h>
#include <base/sleep.h> #include <base/sleep.h>
@ -30,11 +32,13 @@ static UUID parseUUID(const String & text)
ReplicatedAccessStorage::ReplicatedAccessStorage( ReplicatedAccessStorage::ReplicatedAccessStorage(
const String & storage_name_, const String & storage_name_,
const String & zookeeper_path_, const String & zookeeper_path_,
zkutil::GetZooKeeper get_zookeeper_) zkutil::GetZooKeeper get_zookeeper_,
AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_) : IAccessStorage(storage_name_)
, zookeeper_path(zookeeper_path_) , zookeeper_path(zookeeper_path_)
, get_zookeeper(get_zookeeper_) , get_zookeeper(get_zookeeper_)
, refresh_queue(std::numeric_limits<size_t>::max()) , watched_queue(std::make_shared<ConcurrentBoundedQueue<UUID>>(std::numeric_limits<size_t>::max()))
, changes_notifier(changes_notifier_)
{ {
if (zookeeper_path.empty()) if (zookeeper_path.empty())
throw Exception("ZooKeeper path must be non-empty", ErrorCodes::BAD_ARGUMENTS); throw Exception("ZooKeeper path must be non-empty", ErrorCodes::BAD_ARGUMENTS);
@ -45,29 +49,30 @@ ReplicatedAccessStorage::ReplicatedAccessStorage(
/// If zookeeper chroot prefix is used, path should start with '/', because chroot concatenates without it. /// If zookeeper chroot prefix is used, path should start with '/', because chroot concatenates without it.
if (zookeeper_path.front() != '/') if (zookeeper_path.front() != '/')
zookeeper_path = "/" + zookeeper_path; zookeeper_path = "/" + zookeeper_path;
initializeZookeeper();
} }
ReplicatedAccessStorage::~ReplicatedAccessStorage() ReplicatedAccessStorage::~ReplicatedAccessStorage()
{ {
ReplicatedAccessStorage::shutdown(); stopWatchingThread();
} }
void ReplicatedAccessStorage::startWatchingThread()
void ReplicatedAccessStorage::startup()
{ {
initializeZookeeper(); bool prev_watching_flag = watching.exchange(true);
worker_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWorkerThread, this); if (!prev_watching_flag)
watching_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWatchingThread, this);
} }
void ReplicatedAccessStorage::shutdown() void ReplicatedAccessStorage::stopWatchingThread()
{ {
bool prev_stop_flag = stop_flag.exchange(true); bool prev_watching_flag = watching.exchange(false);
if (!prev_stop_flag) if (prev_watching_flag)
{ {
refresh_queue.finish(); watched_queue->finish();
if (watching_thread.joinable())
if (worker_thread.joinable()) watching_thread.join();
worker_thread.join();
} }
} }
@ -105,10 +110,8 @@ std::optional<UUID> ReplicatedAccessStorage::insertImpl(const AccessEntityPtr &
if (!ok) if (!ok)
return std::nullopt; return std::nullopt;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications); refreshEntityNoLock(zookeeper, id);
return id; return id;
} }
@ -207,10 +210,8 @@ bool ReplicatedAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exis
if (!ok) if (!ok)
return false; return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
removeEntityNoLock(id, notifications); removeEntityNoLock(id);
return true; return true;
} }
@ -261,10 +262,8 @@ bool ReplicatedAccessStorage::updateImpl(const UUID & id, const UpdateFunc & upd
if (!ok) if (!ok)
return false; return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications); refreshEntityNoLock(zookeeper, id);
return true; return true;
} }
@ -328,16 +327,18 @@ bool ReplicatedAccessStorage::updateZooKeeper(const zkutil::ZooKeeperPtr & zooke
} }
void ReplicatedAccessStorage::runWorkerThread() void ReplicatedAccessStorage::runWatchingThread()
{ {
LOG_DEBUG(getLogger(), "Started worker thread"); LOG_DEBUG(getLogger(), "Started watching thread");
while (!stop_flag) setThreadName("ReplACLWatch");
while (watching)
{ {
try try
{ {
if (!initialized) if (!initialized)
initializeZookeeper(); initializeZookeeper();
refresh(); if (refresh())
changes_notifier.sendNotifications();
} }
catch (...) catch (...)
{ {
@ -353,7 +354,7 @@ void ReplicatedAccessStorage::resetAfterError()
initialized = false; initialized = false;
UUID id; UUID id;
while (refresh_queue.tryPop(id)) {} while (watched_queue->tryPop(id)) {}
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
for (const auto type : collections::range(AccessEntityType::MAX)) for (const auto type : collections::range(AccessEntityType::MAX))
@ -389,13 +390,11 @@ void ReplicatedAccessStorage::createRootNodes(const zkutil::ZooKeeperPtr & zooke
} }
} }
void ReplicatedAccessStorage::refresh() bool ReplicatedAccessStorage::refresh()
{ {
UUID id; UUID id;
if (refresh_queue.tryPop(id, /* timeout_ms: */ 10000)) if (!watched_queue->tryPop(id, /* timeout_ms: */ 10000))
{ return false;
if (stop_flag)
return;
auto zookeeper = get_zookeeper(); auto zookeeper = get_zookeeper();
@ -403,7 +402,8 @@ void ReplicatedAccessStorage::refresh()
refreshEntities(zookeeper); refreshEntities(zookeeper);
else else
refreshEntity(zookeeper, id); refreshEntity(zookeeper, id);
}
return true;
} }
@ -412,9 +412,9 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
LOG_DEBUG(getLogger(), "Refreshing entities list"); LOG_DEBUG(getLogger(), "Refreshing entities list");
const String zookeeper_uuids_path = zookeeper_path + "/uuid"; const String zookeeper_uuids_path = zookeeper_path + "/uuid";
auto watch_entities_list = [this](const Coordination::WatchResponse &) auto watch_entities_list = [watched_queue = watched_queue](const Coordination::WatchResponse &)
{ {
[[maybe_unused]] bool push_result = refresh_queue.push(UUIDHelpers::Nil); [[maybe_unused]] bool push_result = watched_queue->push(UUIDHelpers::Nil);
}; };
Coordination::Stat stat; Coordination::Stat stat;
const auto entity_uuid_strs = zookeeper->getChildrenWatch(zookeeper_uuids_path, &stat, watch_entities_list); const auto entity_uuid_strs = zookeeper->getChildrenWatch(zookeeper_uuids_path, &stat, watch_entities_list);
@ -424,8 +424,6 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
for (const String & entity_uuid_str : entity_uuid_strs) for (const String & entity_uuid_str : entity_uuid_strs)
entity_uuids.insert(parseUUID(entity_uuid_str)); entity_uuids.insert(parseUUID(entity_uuid_str));
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
std::vector<UUID> entities_to_remove; std::vector<UUID> entities_to_remove;
@ -437,14 +435,14 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
entities_to_remove.push_back(entity_uuid); entities_to_remove.push_back(entity_uuid);
} }
for (const auto & entity_uuid : entities_to_remove) for (const auto & entity_uuid : entities_to_remove)
removeEntityNoLock(entity_uuid, notifications); removeEntityNoLock(entity_uuid);
/// Locally add entities that were added to ZooKeeper /// Locally add entities that were added to ZooKeeper
for (const auto & entity_uuid : entity_uuids) for (const auto & entity_uuid : entity_uuids)
{ {
const auto it = entries_by_id.find(entity_uuid); const auto it = entries_by_id.find(entity_uuid);
if (it == entries_by_id.end()) if (it == entries_by_id.end())
refreshEntityNoLock(zookeeper, entity_uuid, notifications); refreshEntityNoLock(zookeeper, entity_uuid);
} }
LOG_DEBUG(getLogger(), "Refreshing entities list finished"); LOG_DEBUG(getLogger(), "Refreshing entities list finished");
@ -452,21 +450,18 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
void ReplicatedAccessStorage::refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id) void ReplicatedAccessStorage::refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{ {
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id);
refreshEntityNoLock(zookeeper, id, notifications);
} }
void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications) void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{ {
LOG_DEBUG(getLogger(), "Refreshing entity {}", toString(id)); LOG_DEBUG(getLogger(), "Refreshing entity {}", toString(id));
const auto watch_entity = [this, id](const Coordination::WatchResponse & response) const auto watch_entity = [watched_queue = watched_queue, id](const Coordination::WatchResponse & response)
{ {
if (response.type == Coordination::Event::CHANGED) if (response.type == Coordination::Event::CHANGED)
[[maybe_unused]] bool push_result = refresh_queue.push(id); [[maybe_unused]] bool push_result = watched_queue->push(id);
}; };
Coordination::Stat entity_stat; Coordination::Stat entity_stat;
const String entity_path = zookeeper_path + "/uuid/" + toString(id); const String entity_path = zookeeper_path + "/uuid/" + toString(id);
@ -475,16 +470,16 @@ void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & z
if (exists) if (exists)
{ {
const AccessEntityPtr entity = deserializeAccessEntity(entity_definition, entity_path); const AccessEntityPtr entity = deserializeAccessEntity(entity_definition, entity_path);
setEntityNoLock(id, entity, notifications); setEntityNoLock(id, entity);
} }
else else
{ {
removeEntityNoLock(id, notifications); removeEntityNoLock(id);
} }
} }
void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications) void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity)
{ {
LOG_DEBUG(getLogger(), "Setting id {} to entity named {}", toString(id), entity->getName()); LOG_DEBUG(getLogger(), "Setting id {} to entity named {}", toString(id), entity->getName());
const AccessEntityType type = entity->getType(); const AccessEntityType type = entity->getType();
@ -494,12 +489,14 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)]; auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
if (auto it = entries_by_name.find(name); it != entries_by_name.end() && it->second->id != id) if (auto it = entries_by_name.find(name); it != entries_by_name.end() && it->second->id != id)
{ {
removeEntityNoLock(it->second->id, notifications); removeEntityNoLock(it->second->id);
} }
/// If the entity already exists under a different type+name, remove old type+name /// If the entity already exists under a different type+name, remove old type+name
bool existed_before = false;
if (auto it = entries_by_id.find(id); it != entries_by_id.end()) if (auto it = entries_by_id.find(id); it != entries_by_id.end())
{ {
existed_before = true;
const AccessEntityPtr & existing_entity = it->second.entity; const AccessEntityPtr & existing_entity = it->second.entity;
const AccessEntityType existing_type = existing_entity->getType(); const AccessEntityType existing_type = existing_entity->getType();
const String & existing_name = existing_entity->getName(); const String & existing_name = existing_entity->getName();
@ -514,11 +511,18 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
entry.id = id; entry.id = id;
entry.entity = entity; entry.entity = entity;
entries_by_name[name] = &entry; entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications);
if (initialized)
{
if (existed_before)
changes_notifier.onEntityUpdated(id, entity);
else
changes_notifier.onEntityAdded(id, entity);
}
} }
void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications & notifications) void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id)
{ {
LOG_DEBUG(getLogger(), "Removing entity with id {}", toString(id)); LOG_DEBUG(getLogger(), "Removing entity with id {}", toString(id));
const auto it = entries_by_id.find(id); const auto it = entries_by_id.find(id);
@ -531,7 +535,6 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
const Entry & entry = it->second; const Entry & entry = it->second;
const AccessEntityType type = entry.entity->getType(); const AccessEntityType type = entry.entity->getType();
const String & name = entry.entity->getName(); const String & name = entry.entity->getName();
prepareNotifications(entry, true, notifications);
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)]; auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
const auto name_it = entries_by_name.find(name); const auto name_it = entries_by_name.find(name);
@ -542,8 +545,11 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
else else
entries_by_name.erase(name); entries_by_name.erase(name);
UUID removed_id = id;
entries_by_id.erase(id); entries_by_id.erase(id);
LOG_DEBUG(getLogger(), "Removed entity with id {}", toString(id)); LOG_DEBUG(getLogger(), "Removed entity with id {}", toString(id));
changes_notifier.onEntityRemoved(removed_id, type);
} }
@ -594,73 +600,4 @@ AccessEntityPtr ReplicatedAccessStorage::readImpl(const UUID & id, bool throw_if
return entry.entity; return entry.entity;
} }
void ReplicatedAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
const auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool ReplicatedAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
const auto & it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool ReplicatedAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
} }

View File

@ -18,32 +18,33 @@
namespace DB namespace DB
{ {
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in zookeeper. /// Implementation of IAccessStorage which keeps all data in zookeeper.
class ReplicatedAccessStorage : public IAccessStorage class ReplicatedAccessStorage : public IAccessStorage
{ {
public: public:
static constexpr char STORAGE_TYPE[] = "replicated"; static constexpr char STORAGE_TYPE[] = "replicated";
ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper); ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper, AccessChangesNotifier & changes_notifier_);
virtual ~ReplicatedAccessStorage() override; virtual ~ReplicatedAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; } const char * getStorageType() const override { return STORAGE_TYPE; }
virtual void startup(); void startPeriodicReloading() override { startWatchingThread(); }
virtual void shutdown(); void stopPeriodicReloading() override { stopWatchingThread(); }
bool exists(const UUID & id) const override; bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private: private:
String zookeeper_path; String zookeeper_path;
zkutil::GetZooKeeper get_zookeeper; zkutil::GetZooKeeper get_zookeeper;
std::atomic<bool> initialized = false; std::atomic<bool> initialized = false;
std::atomic<bool> stop_flag = false;
ThreadFromGlobalPool worker_thread; std::atomic<bool> watching = false;
ConcurrentBoundedQueue<UUID> refresh_queue; ThreadFromGlobalPool watching_thread;
std::shared_ptr<ConcurrentBoundedQueue<UUID>> watched_queue;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override; std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override; bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
@ -53,37 +54,36 @@ private:
bool removeZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, bool throw_if_not_exists); bool removeZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, bool throw_if_not_exists);
bool updateZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists); bool updateZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
void runWorkerThread();
void resetAfterError();
void initializeZookeeper(); void initializeZookeeper();
void createRootNodes(const zkutil::ZooKeeperPtr & zookeeper); void createRootNodes(const zkutil::ZooKeeperPtr & zookeeper);
void refresh(); void startWatchingThread();
void stopWatchingThread();
void runWatchingThread();
void resetAfterError();
bool refresh();
void refreshEntities(const zkutil::ZooKeeperPtr & zookeeper); void refreshEntities(const zkutil::ZooKeeperPtr & zookeeper);
void refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id); void refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications); void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications); void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity);
void removeEntityNoLock(const UUID & id, Notifications & notifications); void removeEntityNoLock(const UUID & id);
struct Entry struct Entry
{ {
UUID id; UUID id;
AccessEntityPtr entity; AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
}; };
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override; std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(AccessEntityType type) const override; std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override; AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
mutable std::mutex mutex; mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id; std::unordered_map<UUID, Entry> entries_by_id;
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)]; std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)]; AccessChangesNotifier & changes_notifier;
}; };
} }

View File

@ -66,9 +66,6 @@ RoleCache::~RoleCache() = default;
std::shared_ptr<const EnabledRoles> std::shared_ptr<const EnabledRoles>
RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option) RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option)
{ {
/// Declared before `lock` to send notifications after the mutex will be unlocked.
scope_guard notifications;
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
EnabledRoles::Params params; EnabledRoles::Params params;
params.current_roles.insert(roles.begin(), roles.end()); params.current_roles.insert(roles.begin(), roles.end());
@ -83,13 +80,13 @@ RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UU
} }
auto res = std::shared_ptr<EnabledRoles>(new EnabledRoles(params)); auto res = std::shared_ptr<EnabledRoles>(new EnabledRoles(params));
collectEnabledRoles(*res, notifications); collectEnabledRoles(*res, nullptr);
enabled_roles.emplace(std::move(params), res); enabled_roles.emplace(std::move(params), res);
return res; return res;
} }
void RoleCache::collectEnabledRoles(scope_guard & notifications) void RoleCache::collectEnabledRoles(scope_guard * notifications)
{ {
/// `mutex` is already locked. /// `mutex` is already locked.
@ -107,7 +104,7 @@ void RoleCache::collectEnabledRoles(scope_guard & notifications)
} }
void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications) void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications)
{ {
/// `mutex` is already locked. /// `mutex` is already locked.
@ -170,7 +167,7 @@ void RoleCache::roleChanged(const UUID & role_id, const RolePtr & changed_role)
return; return;
role_from_cache->first = changed_role; role_from_cache->first = changed_role;
cache.update(role_id, role_from_cache); cache.update(role_id, role_from_cache);
collectEnabledRoles(notifications); collectEnabledRoles(&notifications);
} }
@ -181,7 +178,7 @@ void RoleCache::roleRemoved(const UUID & role_id)
std::lock_guard lock{mutex}; std::lock_guard lock{mutex};
cache.remove(role_id); cache.remove(role_id);
collectEnabledRoles(notifications); collectEnabledRoles(&notifications);
} }
} }

View File

@ -24,8 +24,8 @@ public:
const std::vector<UUID> & current_roles_with_admin_option); const std::vector<UUID> & current_roles_with_admin_option);
private: private:
void collectEnabledRoles(scope_guard & notifications); void collectEnabledRoles(scope_guard * notifications);
void collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications); void collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications);
RolePtr getRole(const UUID & role_id); RolePtr getRole(const UUID & role_id);
void roleChanged(const UUID & role_id, const RolePtr & changed_role); void roleChanged(const UUID & role_id, const RolePtr & changed_role);
void roleRemoved(const UUID & role_id); void roleRemoved(const UUID & role_id);

View File

@ -4,6 +4,7 @@
#include <Access/User.h> #include <Access/User.h>
#include <Access/SettingsProfile.h> #include <Access/SettingsProfile.h>
#include <Access/AccessControl.h> #include <Access/AccessControl.h>
#include <Access/AccessChangesNotifier.h>
#include <Dictionaries/IDictionary.h> #include <Dictionaries/IDictionary.h>
#include <Common/Config/ConfigReloader.h> #include <Common/Config/ConfigReloader.h>
#include <Common/StringUtils/StringUtils.h> #include <Common/StringUtils/StringUtils.h>
@ -14,9 +15,6 @@
#include <Poco/JSON/JSON.h> #include <Poco/JSON/JSON.h>
#include <Poco/JSON/Object.h> #include <Poco/JSON/Object.h>
#include <Poco/JSON/Stringifier.h> #include <Poco/JSON/Stringifier.h>
#include <Common/logger_useful.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/adaptor/map.hpp>
#include <cstring> #include <cstring>
#include <filesystem> #include <filesystem>
#include <base/FnTraits.h> #include <base/FnTraits.h>
@ -525,8 +523,8 @@ namespace
} }
} }
UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_) UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_)
: IAccessStorage(storage_name_), access_control(access_control_) : IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{ {
} }
@ -605,9 +603,9 @@ void UsersConfigAccessStorage::load(
std::make_shared<Poco::Event>(), std::make_shared<Poco::Event>(),
[&](Poco::AutoPtr<Poco::Util::AbstractConfiguration> new_config, bool /*initial_loading*/) [&](Poco::AutoPtr<Poco::Util::AbstractConfiguration> new_config, bool /*initial_loading*/)
{ {
parseFromConfig(*new_config);
Settings::checkNoSettingNamesAtTopLevel(*new_config, users_config_path); Settings::checkNoSettingNamesAtTopLevel(*new_config, users_config_path);
parseFromConfig(*new_config);
access_control.getChangesNotifier().sendNotifications();
}, },
/* already_loaded = */ false); /* already_loaded = */ false);
} }
@ -662,27 +660,4 @@ std::optional<String> UsersConfigAccessStorage::readNameImpl(const UUID & id, bo
return memory_storage.readName(id, throw_if_not_exists); return memory_storage.readName(id, throw_if_not_exists);
} }
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(type, handler);
}
bool UsersConfigAccessStorage::hasSubscription(const UUID & id) const
{
return memory_storage.hasSubscription(id);
}
bool UsersConfigAccessStorage::hasSubscription(AccessEntityType type) const
{
return memory_storage.hasSubscription(type);
}
} }

View File

@ -22,7 +22,7 @@ public:
static constexpr char STORAGE_TYPE[] = "users.xml"; static constexpr char STORAGE_TYPE[] = "users.xml";
UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_); UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_);
~UsersConfigAccessStorage() override; ~UsersConfigAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; } const char * getStorageType() const override { return STORAGE_TYPE; }
@ -37,13 +37,12 @@ public:
const String & include_from_path = {}, const String & include_from_path = {},
const String & preprocessed_dir = {}, const String & preprocessed_dir = {},
const zkutil::GetZooKeeper & get_zookeeper_function = {}); const zkutil::GetZooKeeper & get_zookeeper_function = {});
void reload();
void startPeriodicReloading(); void reload() override;
void stopPeriodicReloading(); void startPeriodicReloading() override;
void stopPeriodicReloading() override;
bool exists(const UUID & id) const override; bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private: private:
void parseFromConfig(const Poco::Util::AbstractConfiguration & config); void parseFromConfig(const Poco::Util::AbstractConfiguration & config);
@ -51,10 +50,8 @@ private:
std::vector<UUID> findAllImpl(AccessEntityType type) const override; std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override; AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override; std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
const AccessControl & access_control; AccessControl & access_control;
MemoryAccessStorage memory_storage; MemoryAccessStorage memory_storage;
String path; String path;
std::unique_ptr<ConfigReloader> config_reloader; std::unique_ptr<ConfigReloader> config_reloader;

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <Access/ReplicatedAccessStorage.h> #include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
using namespace DB; using namespace DB;
@ -12,18 +13,6 @@ namespace ErrorCodes
} }
TEST(ReplicatedAccessStorage, ShutdownWithoutStartup)
{
auto get_zk = []()
{
return std::shared_ptr<zkutil::ZooKeeper>();
};
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk);
storage.shutdown();
}
TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup) TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
{ {
auto get_zk = []() auto get_zk = []()
@ -31,16 +20,16 @@ TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
return std::shared_ptr<zkutil::ZooKeeper>(); return std::shared_ptr<zkutil::ZooKeeper>();
}; };
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk); AccessChangesNotifier changes_notifier;
try try
{ {
storage.startup(); auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk, changes_notifier);
} }
catch (Exception & e) catch (Exception & e)
{ {
if (e.code() != ErrorCodes::NO_ZOOKEEPER) if (e.code() != ErrorCodes::NO_ZOOKEEPER)
throw; throw;
} }
storage.shutdown();
} }

View File

@ -342,8 +342,6 @@ struct ContextSharedPart
/// Stop periodic reloading of the configuration files. /// Stop periodic reloading of the configuration files.
/// This must be done first because otherwise the reloading may pass a changed config /// This must be done first because otherwise the reloading may pass a changed config
/// to some destroyed parts of ContextSharedPart. /// to some destroyed parts of ContextSharedPart.
if (access_control)
access_control->stopPeriodicReloadingUsersConfigs();
if (external_dictionaries_loader) if (external_dictionaries_loader)
external_dictionaries_loader->enablePeriodicUpdates(false); external_dictionaries_loader->enablePeriodicUpdates(false);
if (external_user_defined_executable_functions_loader) if (external_user_defined_executable_functions_loader)