diff --git a/programs/local/LocalServer.cpp b/programs/local/LocalServer.cpp index fbb64ea1135..7357c239e6b 100644 --- a/programs/local/LocalServer.cpp +++ b/programs/local/LocalServer.cpp @@ -43,7 +43,7 @@ #include #include #include -#include +#include #include #include #include @@ -757,7 +757,7 @@ void LocalServer::processConfig() } /// For ClickHouse local if path is not set the loader will be disabled. - global_context->getUserDefinedSQLObjectsLoader().loadObjects(); + global_context->getUserDefinedSQLObjectsStorage().loadObjects(); LOG_DEBUG(log, "Loaded metadata."); } diff --git a/programs/server/Server.cpp b/programs/server/Server.cpp index 36f0ce90e57..b88bbb37866 100644 --- a/programs/server/Server.cpp +++ b/programs/server/Server.cpp @@ -66,7 +66,7 @@ #include #include #include -#include +#include #include #include #include @@ -1716,7 +1716,7 @@ try /// After loading validate that default database exists database_catalog.assertDatabaseExists(default_database); /// Load user-defined SQL functions. - global_context->getUserDefinedSQLObjectsLoader().loadObjects(); + global_context->getUserDefinedSQLObjectsStorage().loadObjects(); } catch (...) { diff --git a/src/Functions/UserDefined/IUserDefinedSQLObjectsLoader.h b/src/Functions/UserDefined/IUserDefinedSQLObjectsLoader.h deleted file mode 100644 index 4c7850951b5..00000000000 --- a/src/Functions/UserDefined/IUserDefinedSQLObjectsLoader.h +++ /dev/null @@ -1,47 +0,0 @@ -#pragma once - -#include - - -namespace DB -{ -class IAST; -struct Settings; -enum class UserDefinedSQLObjectType; - -/// Interface for a loader of user-defined SQL objects. -/// Implementations: UserDefinedSQLLoaderFromDisk, UserDefinedSQLLoaderFromZooKeeper -class IUserDefinedSQLObjectsLoader -{ -public: - virtual ~IUserDefinedSQLObjectsLoader() = default; - - /// Whether this loader can replicate SQL objects to another node. - virtual bool isReplicated() const { return false; } - virtual String getReplicationID() const { return ""; } - - /// Loads all objects. Can be called once - if objects are already loaded the function does nothing. - virtual void loadObjects() = 0; - - /// Stops watching. - virtual void stopWatching() {} - - /// Immediately reloads all objects, throws an exception if failed. - virtual void reloadObjects() = 0; - - /// Immediately reloads a specified object only. - virtual void reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) = 0; - - /// Stores an object (must be called only by UserDefinedSQLFunctionFactory::registerFunction). - virtual bool storeObject( - UserDefinedSQLObjectType object_type, - const String & object_name, - const IAST & create_object_query, - bool throw_if_exists, - bool replace_if_exists, - const Settings & settings) = 0; - - /// Removes an object (must be called only by UserDefinedSQLFunctionFactory::unregisterFunction). - virtual bool removeObject(UserDefinedSQLObjectType object_type, const String & object_name, bool throw_if_not_exists) = 0; -}; -} diff --git a/src/Functions/UserDefined/IUserDefinedSQLObjectsStorage.h b/src/Functions/UserDefined/IUserDefinedSQLObjectsStorage.h new file mode 100644 index 00000000000..345ff8c5954 --- /dev/null +++ b/src/Functions/UserDefined/IUserDefinedSQLObjectsStorage.h @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include + +#include + + +namespace DB +{ +class IAST; +struct Settings; +enum class UserDefinedSQLObjectType; + +/// Interface for a storage of user-defined SQL objects. +/// Implementations: UserDefinedSQLObjectsDiskStorage, UserDefinedSQLObjectsZooKeeperStorage +class IUserDefinedSQLObjectsStorage +{ +public: + virtual ~IUserDefinedSQLObjectsStorage() = default; + + /// Whether this loader can replicate SQL objects to another node. + virtual bool isReplicated() const { return false; } + virtual String getReplicationID() const { return ""; } + + /// Loads all objects. Can be called once - if objects are already loaded the function does nothing. + virtual void loadObjects() = 0; + + /// Get object by name. If no object stored with object_name throws exception. + virtual ASTPtr get(const String & object_name) const = 0; + + /// Get object by name. If no object stored with object_name return nullptr. + virtual ASTPtr tryGet(const String & object_name) const = 0; + + /// Check if object with object_name is stored. + virtual bool has(const String & object_name) const = 0; + + /// Get all user defined object names. + virtual std::vector getAllObjectNames() const = 0; + + /// Get all user defined objects. + virtual std::vector> getAllObjects() const = 0; + + /// Check whether any UDFs have been stored. + virtual bool empty() const = 0; + + /// Stops watching. + virtual void stopWatching() {} + + /// Immediately reloads all objects, throws an exception if failed. + virtual void reloadObjects() = 0; + + /// Immediately reloads a specified object only. + virtual void reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) = 0; + + /// Stores an object (must be called only by UserDefinedSQLFunctionFactory::registerFunction). + virtual bool storeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + ASTPtr create_object_query, + bool throw_if_exists, + bool replace_if_exists, + const Settings & settings) = 0; + + /// Removes an object (must be called only by UserDefinedSQLFunctionFactory::unregisterFunction). + virtual bool removeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) = 0; +}; +} diff --git a/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.cpp b/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.cpp index c4a503589eb..e37e4a23b63 100644 --- a/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.cpp +++ b/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include #include @@ -14,8 +14,6 @@ #include #include -#include - namespace DB { @@ -23,7 +21,6 @@ namespace DB namespace ErrorCodes { extern const int FUNCTION_ALREADY_EXISTS; - extern const int UNKNOWN_FUNCTION; extern const int CANNOT_DROP_FUNCTION; extern const int CANNOT_CREATE_RECURSIVE_FUNCTION; extern const int UNSUPPORTED_METHOD; @@ -130,20 +127,17 @@ bool UserDefinedSQLFunctionFactory::registerFunction(const ContextMutablePtr & c checkCanBeRegistered(context, function_name, *create_function_query); create_function_query = normalizeCreateFunctionQuery(*create_function_query); - std::lock_guard lock{mutex}; - auto it = function_name_to_create_query_map.find(function_name); - if (it != function_name_to_create_query_map.end()) - { - if (throw_if_exists) - throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "User-defined function '{}' already exists", function_name); - else if (!replace_if_exists) - return false; - } - try { - auto & loader = context->getUserDefinedSQLObjectsLoader(); - bool stored = loader.storeObject(UserDefinedSQLObjectType::Function, function_name, *create_function_query, throw_if_exists, replace_if_exists, context->getSettingsRef()); + auto & loader = context->getUserDefinedSQLObjectsStorage(); + bool stored = loader.storeObject( + context, + UserDefinedSQLObjectType::Function, + function_name, + create_function_query, + throw_if_exists, + replace_if_exists, + context->getSettingsRef()); if (!stored) return false; } @@ -153,7 +147,6 @@ bool UserDefinedSQLFunctionFactory::registerFunction(const ContextMutablePtr & c throw; } - function_name_to_create_query_map[function_name] = create_function_query; return true; } @@ -161,20 +154,14 @@ bool UserDefinedSQLFunctionFactory::unregisterFunction(const ContextMutablePtr & { checkCanBeUnregistered(context, function_name); - std::lock_guard lock(mutex); - auto it = function_name_to_create_query_map.find(function_name); - if (it == function_name_to_create_query_map.end()) - { - if (throw_if_not_exists) - throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "User-defined function '{}' doesn't exist", function_name); - else - return false; - } - try { - auto & loader = context->getUserDefinedSQLObjectsLoader(); - bool removed = loader.removeObject(UserDefinedSQLObjectType::Function, function_name, throw_if_not_exists); + auto & storage = context->getUserDefinedSQLObjectsStorage(); + bool removed = storage.removeObject( + context, + UserDefinedSQLObjectType::Function, + function_name, + throw_if_not_exists); if (!removed) return false; } @@ -184,61 +171,41 @@ bool UserDefinedSQLFunctionFactory::unregisterFunction(const ContextMutablePtr & throw; } - function_name_to_create_query_map.erase(function_name); return true; } ASTPtr UserDefinedSQLFunctionFactory::get(const String & function_name) const { - std::lock_guard lock(mutex); - - auto it = function_name_to_create_query_map.find(function_name); - if (it == function_name_to_create_query_map.end()) - throw Exception(ErrorCodes::UNKNOWN_FUNCTION, - "The function name '{}' is not registered", - function_name); - - return it->second; + return global_context->getUserDefinedSQLObjectsStorage().get(function_name); } ASTPtr UserDefinedSQLFunctionFactory::tryGet(const std::string & function_name) const { - std::lock_guard lock(mutex); - - auto it = function_name_to_create_query_map.find(function_name); - if (it == function_name_to_create_query_map.end()) - return nullptr; - - return it->second; + return global_context->getUserDefinedSQLObjectsStorage().tryGet(function_name); } bool UserDefinedSQLFunctionFactory::has(const String & function_name) const { - return tryGet(function_name) != nullptr; + return global_context->getUserDefinedSQLObjectsStorage().has(function_name); } std::vector UserDefinedSQLFunctionFactory::getAllRegisteredNames() const { - std::vector registered_names; - - std::lock_guard lock(mutex); - registered_names.reserve(function_name_to_create_query_map.size()); - - for (const auto & [name, _] : function_name_to_create_query_map) - registered_names.emplace_back(name); - - return registered_names; + return global_context->getUserDefinedSQLObjectsStorage().getAllObjectNames(); } bool UserDefinedSQLFunctionFactory::empty() const { - std::lock_guard lock(mutex); - return function_name_to_create_query_map.empty(); + return global_context->getUserDefinedSQLObjectsStorage().empty(); } void UserDefinedSQLFunctionFactory::backup(BackupEntriesCollector & backup_entries_collector, const String & data_path_in_backup) const { - backupUserDefinedSQLObjects(backup_entries_collector, data_path_in_backup, UserDefinedSQLObjectType::Function, getAllFunctions()); + backupUserDefinedSQLObjects( + backup_entries_collector, + data_path_in_backup, + UserDefinedSQLObjectType::Function, + global_context->getUserDefinedSQLObjectsStorage().getAllObjects()); } void UserDefinedSQLFunctionFactory::restore(RestorerFromBackup & restorer, const String & data_path_in_backup) @@ -252,52 +219,4 @@ void UserDefinedSQLFunctionFactory::restore(RestorerFromBackup & restorer, const registerFunction(context, function_name, create_function_query, throw_if_exists, replace_if_exists); } -void UserDefinedSQLFunctionFactory::setAllFunctions(const std::vector> & new_functions) -{ - std::unordered_map normalized_functions; - for (const auto & [function_name, create_query] : new_functions) - normalized_functions[function_name] = normalizeCreateFunctionQuery(*create_query); - - std::lock_guard lock(mutex); - function_name_to_create_query_map = std::move(normalized_functions); -} - -std::vector> UserDefinedSQLFunctionFactory::getAllFunctions() const -{ - std::lock_guard lock{mutex}; - std::vector> all_functions; - all_functions.reserve(function_name_to_create_query_map.size()); - std::copy(function_name_to_create_query_map.begin(), function_name_to_create_query_map.end(), std::back_inserter(all_functions)); - return all_functions; -} - -void UserDefinedSQLFunctionFactory::setFunction(const String & function_name, const IAST & create_function_query) -{ - std::lock_guard lock(mutex); - function_name_to_create_query_map[function_name] = normalizeCreateFunctionQuery(create_function_query); -} - -void UserDefinedSQLFunctionFactory::removeFunction(const String & function_name) -{ - std::lock_guard lock(mutex); - function_name_to_create_query_map.erase(function_name); -} - -void UserDefinedSQLFunctionFactory::removeAllFunctionsExcept(const Strings & function_names_to_keep) -{ - boost::container::flat_set names_set_to_keep{function_names_to_keep.begin(), function_names_to_keep.end()}; - std::lock_guard lock(mutex); - for (auto it = function_name_to_create_query_map.begin(); it != function_name_to_create_query_map.end();) - { - auto current = it++; - if (!names_set_to_keep.contains(current->first)) - function_name_to_create_query_map.erase(current); - } -} - -std::unique_lock UserDefinedSQLFunctionFactory::getLock() const -{ - return std::unique_lock{mutex}; -} - } diff --git a/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.h b/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.h index a7d586061b2..b1f3940323a 100644 --- a/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.h +++ b/src/Functions/UserDefined/UserDefinedSQLFunctionFactory.h @@ -6,7 +6,7 @@ #include #include -#include +#include namespace DB @@ -48,23 +48,11 @@ public: void restore(RestorerFromBackup & restorer, const String & data_path_in_backup); private: - friend class UserDefinedSQLObjectsLoaderFromDisk; - friend class UserDefinedSQLObjectsLoaderFromZooKeeper; - /// Checks that a specified function can be registered, throws an exception if not. static void checkCanBeRegistered(const ContextPtr & context, const String & function_name, const IAST & create_function_query); static void checkCanBeUnregistered(const ContextPtr & context, const String & function_name); - /// The following functions must be called only by the loader. - void setAllFunctions(const std::vector> & new_functions); - std::vector> getAllFunctions() const; - void setFunction(const String & function_name, const IAST & create_function_query); - void removeFunction(const String & function_name); - void removeAllFunctionsExcept(const Strings & function_names_to_keep); - std::unique_lock getLock() const; - - std::unordered_map function_name_to_create_query_map; - mutable std::recursive_mutex mutex; + ContextPtr global_context = Context::getGlobalContextInstance(); }; } diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsBackup.cpp b/src/Functions/UserDefined/UserDefinedSQLObjectsBackup.cpp index 6920e8ce2c2..3ec5393fa6f 100644 --- a/src/Functions/UserDefined/UserDefinedSQLObjectsBackup.cpp +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsBackup.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include #include @@ -37,9 +37,9 @@ void backupUserDefinedSQLObjects( escapeForFileName(object_name) + ".sql", std::make_shared(queryToString(create_object_query))); auto context = backup_entries_collector.getContext(); - const auto & loader = context->getUserDefinedSQLObjectsLoader(); + const auto & storage = context->getUserDefinedSQLObjectsStorage(); - if (!loader.isReplicated()) + if (!storage.isReplicated()) { fs::path data_path_in_backup_fs{data_path_in_backup}; for (const auto & [file_name, entry] : backup_entries) @@ -47,7 +47,7 @@ void backupUserDefinedSQLObjects( return; } - String replication_id = loader.getReplicationID(); + String replication_id = storage.getReplicationID(); auto backup_coordination = backup_entries_collector.getBackupCoordination(); backup_coordination->addReplicatedSQLObjectsDir(replication_id, object_type, data_path_in_backup); @@ -80,9 +80,9 @@ std::vector> restoreUserDefinedSQLObjects(RestorerFromBackup & restorer, const String & data_path_in_backup, UserDefinedSQLObjectType object_type) { auto context = restorer.getContext(); - const auto & loader = context->getUserDefinedSQLObjectsLoader(); + const auto & storage = context->getUserDefinedSQLObjectsStorage(); - if (loader.isReplicated() && !restorer.getRestoreCoordination()->acquireReplicatedSQLObjects(loader.getReplicationID(), object_type)) + if (storage.isReplicated() && !restorer.getRestoreCoordination()->acquireReplicatedSQLObjects(storage.getReplicationID(), object_type)) return {}; /// Other replica is already restoring user-defined SQL objects. auto backup = restorer.getBackup(); diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.cpp b/src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.cpp similarity index 80% rename from src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.cpp rename to src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.cpp index d67c48f166d..271c464e79a 100644 --- a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.cpp +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.cpp @@ -1,4 +1,4 @@ -#include "Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.h" +#include "Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.h" #include "Functions/UserDefined/UserDefinedSQLFunctionFactory.h" #include "Functions/UserDefined/UserDefinedSQLObjectType.h" @@ -51,7 +51,7 @@ namespace } } -UserDefinedSQLObjectsLoaderFromDisk::UserDefinedSQLObjectsLoaderFromDisk(const ContextPtr & global_context_, const String & dir_path_) +UserDefinedSQLObjectsDiskStorage::UserDefinedSQLObjectsDiskStorage(const ContextPtr & global_context_, const String & dir_path_) : global_context(global_context_) , dir_path{makeDirectoryPathCanonical(dir_path_)} , log{&Poco::Logger::get("UserDefinedSQLObjectsLoaderFromDisk")} @@ -60,13 +60,13 @@ UserDefinedSQLObjectsLoaderFromDisk::UserDefinedSQLObjectsLoaderFromDisk(const C } -ASTPtr UserDefinedSQLObjectsLoaderFromDisk::tryLoadObject(UserDefinedSQLObjectType object_type, const String & object_name) +ASTPtr UserDefinedSQLObjectsDiskStorage::tryLoadObject(UserDefinedSQLObjectType object_type, const String & object_name) { return tryLoadObject(object_type, object_name, getFilePath(object_type, object_name), /* check_file_exists= */ true); } -ASTPtr UserDefinedSQLObjectsLoaderFromDisk::tryLoadObject(UserDefinedSQLObjectType object_type, const String & object_name, const String & path, bool check_file_exists) +ASTPtr UserDefinedSQLObjectsDiskStorage::tryLoadObject(UserDefinedSQLObjectType object_type, const String & object_name, const String & path, bool check_file_exists) { LOG_DEBUG(log, "Loading user defined object {} from file {}", backQuote(object_name), path); @@ -93,7 +93,6 @@ ASTPtr UserDefinedSQLObjectsLoaderFromDisk::tryLoadObject(UserDefinedSQLObjectTy "", 0, global_context->getSettingsRef().max_parser_depth); - UserDefinedSQLFunctionFactory::checkCanBeRegistered(global_context, object_name, *ast); return ast; } } @@ -106,20 +105,20 @@ ASTPtr UserDefinedSQLObjectsLoaderFromDisk::tryLoadObject(UserDefinedSQLObjectTy } -void UserDefinedSQLObjectsLoaderFromDisk::loadObjects() +void UserDefinedSQLObjectsDiskStorage::loadObjects() { if (!objects_loaded) loadObjectsImpl(); } -void UserDefinedSQLObjectsLoaderFromDisk::reloadObjects() +void UserDefinedSQLObjectsDiskStorage::reloadObjects() { loadObjectsImpl(); } -void UserDefinedSQLObjectsLoaderFromDisk::loadObjectsImpl() +void UserDefinedSQLObjectsDiskStorage::loadObjectsImpl() { LOG_INFO(log, "Loading user defined objects from {}", dir_path); createDirectory(); @@ -148,26 +147,25 @@ void UserDefinedSQLObjectsLoaderFromDisk::loadObjectsImpl() function_names_and_queries.emplace_back(function_name, ast); } - UserDefinedSQLFunctionFactory::instance().setAllFunctions(function_names_and_queries); + setAllObjects(function_names_and_queries); objects_loaded = true; LOG_DEBUG(log, "User defined objects loaded"); } -void UserDefinedSQLObjectsLoaderFromDisk::reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) +void UserDefinedSQLObjectsDiskStorage::reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) { createDirectory(); auto ast = tryLoadObject(object_type, object_name); - auto & factory = UserDefinedSQLFunctionFactory::instance(); if (ast) - factory.setFunction(object_name, *ast); + setObject(object_name, *ast); else - factory.removeFunction(object_name); + removeObject(object_name); } -void UserDefinedSQLObjectsLoaderFromDisk::createDirectory() +void UserDefinedSQLObjectsDiskStorage::createDirectory() { std::error_code create_dir_error_code; fs::create_directories(dir_path, create_dir_error_code); @@ -177,10 +175,11 @@ void UserDefinedSQLObjectsLoaderFromDisk::createDirectory() } -bool UserDefinedSQLObjectsLoaderFromDisk::storeObject( +bool UserDefinedSQLObjectsDiskStorage::storeObjectImpl( + const ContextPtr & /*current_context*/, UserDefinedSQLObjectType object_type, const String & object_name, - const IAST & create_object_query, + ASTPtr create_object_query, bool throw_if_exists, bool replace_if_exists, const Settings & settings) @@ -197,7 +196,7 @@ bool UserDefinedSQLObjectsLoaderFromDisk::storeObject( } WriteBufferFromOwnString create_statement_buf; - formatAST(create_object_query, create_statement_buf, false); + formatAST(*create_object_query, create_statement_buf, false); writeChar('\n', create_statement_buf); String create_statement = create_statement_buf.str(); @@ -228,8 +227,11 @@ bool UserDefinedSQLObjectsLoaderFromDisk::storeObject( } -bool UserDefinedSQLObjectsLoaderFromDisk::removeObject( - UserDefinedSQLObjectType object_type, const String & object_name, bool throw_if_not_exists) +bool UserDefinedSQLObjectsDiskStorage::removeObjectImpl( + const ContextPtr & /*current_context*/, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) { String file_path = getFilePath(object_type, object_name); LOG_DEBUG(log, "Removing user defined object {} stored in file {}", backQuote(object_name), file_path); @@ -249,7 +251,7 @@ bool UserDefinedSQLObjectsLoaderFromDisk::removeObject( } -String UserDefinedSQLObjectsLoaderFromDisk::getFilePath(UserDefinedSQLObjectType object_type, const String & object_name) const +String UserDefinedSQLObjectsDiskStorage::getFilePath(UserDefinedSQLObjectType object_type, const String & object_name) const { String file_path; switch (object_type) diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.h b/src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.h similarity index 65% rename from src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.h rename to src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.h index 7b0bb291f42..f0986dbda72 100644 --- a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromDisk.h +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsDiskStorage.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include @@ -9,10 +9,10 @@ namespace DB { /// Loads user-defined sql objects from a specified folder. -class UserDefinedSQLObjectsLoaderFromDisk : public IUserDefinedSQLObjectsLoader +class UserDefinedSQLObjectsDiskStorage : public UserDefinedSQLObjectsStorageBase { public: - UserDefinedSQLObjectsLoaderFromDisk(const ContextPtr & global_context_, const String & dir_path_); + UserDefinedSQLObjectsDiskStorage(const ContextPtr & global_context_, const String & dir_path_); void loadObjects() override; @@ -20,17 +20,22 @@ public: void reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) override; - bool storeObject( +private: + bool storeObjectImpl( + const ContextPtr & current_context, UserDefinedSQLObjectType object_type, const String & object_name, - const IAST & create_object_query, + ASTPtr create_object_query, bool throw_if_exists, bool replace_if_exists, const Settings & settings) override; - bool removeObject(UserDefinedSQLObjectType object_type, const String & object_name, bool throw_if_not_exists) override; + bool removeObjectImpl( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) override; -private: void createDirectory(); void loadObjectsImpl(); ASTPtr tryLoadObject(UserDefinedSQLObjectType object_type, const String & object_name); diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.cpp b/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.cpp new file mode 100644 index 00000000000..4f47a46b10d --- /dev/null +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.cpp @@ -0,0 +1,190 @@ +#include "Functions/UserDefined/UserDefinedSQLObjectsStorageBase.h" + +#include + +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int FUNCTION_ALREADY_EXISTS; + extern const int UNKNOWN_FUNCTION; +} + +namespace +{ + +ASTPtr normalizeCreateFunctionQuery(const IAST & create_function_query) +{ + auto ptr = create_function_query.clone(); + auto & res = typeid_cast(*ptr); + res.if_not_exists = false; + res.or_replace = false; + FunctionNameNormalizer().visit(res.function_core.get()); + return ptr; +} + +} + +ASTPtr UserDefinedSQLObjectsStorageBase::get(const String & object_name) const +{ + std::lock_guard lock(mutex); + + auto it = object_name_to_create_object_map.find(object_name); + if (it == object_name_to_create_object_map.end()) + throw Exception(ErrorCodes::UNKNOWN_FUNCTION, + "The object name '{}' is not saved", + object_name); + + return it->second; +} + +ASTPtr UserDefinedSQLObjectsStorageBase::tryGet(const std::string & object_name) const +{ + std::lock_guard lock(mutex); + + auto it = object_name_to_create_object_map.find(object_name); + if (it == object_name_to_create_object_map.end()) + return nullptr; + + return it->second; +} + +bool UserDefinedSQLObjectsStorageBase::has(const String & object_name) const +{ + return tryGet(object_name) != nullptr; +} + +std::vector UserDefinedSQLObjectsStorageBase::getAllObjectNames() const +{ + std::vector object_names; + + std::lock_guard lock(mutex); + object_names.reserve(object_name_to_create_object_map.size()); + + for (const auto & [name, _] : object_name_to_create_object_map) + object_names.emplace_back(name); + + return object_names; +} + +bool UserDefinedSQLObjectsStorageBase::empty() const +{ + std::lock_guard lock(mutex); + return object_name_to_create_object_map.empty(); +} + +bool UserDefinedSQLObjectsStorageBase::storeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + ASTPtr create_object_query, + bool throw_if_exists, + bool replace_if_exists, + const Settings & settings) +{ + std::lock_guard lock{mutex}; + auto it = object_name_to_create_object_map.find(object_name); + if (it != object_name_to_create_object_map.end()) + { + if (throw_if_exists) + throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "User-defined object '{}' already exists", object_name); + else if (!replace_if_exists) + return false; + } + + bool stored = storeObjectImpl( + current_context, + object_type, + object_name, + create_object_query, + throw_if_exists, + replace_if_exists, + settings); + + if (stored) + object_name_to_create_object_map[object_name] = create_object_query; + + return stored; +} + +bool UserDefinedSQLObjectsStorageBase::removeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) +{ + std::lock_guard lock(mutex); + auto it = object_name_to_create_object_map.find(object_name); + if (it == object_name_to_create_object_map.end()) + { + if (throw_if_not_exists) + throw Exception(ErrorCodes::UNKNOWN_FUNCTION, "User-defined object '{}' doesn't exist", object_name); + else + return false; + } + + bool removed = removeObjectImpl( + current_context, + object_type, + object_name, + throw_if_not_exists); + + if (removed) + object_name_to_create_object_map.erase(object_name); + + return removed; +} + +std::unique_lock UserDefinedSQLObjectsStorageBase::getLock() const +{ + return std::unique_lock{mutex}; +} + +void UserDefinedSQLObjectsStorageBase::setAllObjects(const std::vector> & new_objects) +{ + std::unordered_map normalized_functions; + for (const auto & [function_name, create_query] : new_objects) + normalized_functions[function_name] = normalizeCreateFunctionQuery(*create_query); + + std::lock_guard lock(mutex); + object_name_to_create_object_map = std::move(normalized_functions); +} + +std::vector> UserDefinedSQLObjectsStorageBase::getAllObjects() const +{ + std::lock_guard lock{mutex}; + std::vector> all_objects; + all_objects.reserve(object_name_to_create_object_map.size()); + std::copy(object_name_to_create_object_map.begin(), object_name_to_create_object_map.end(), std::back_inserter(all_objects)); + return all_objects; +} + +void UserDefinedSQLObjectsStorageBase::setObject(const String & object_name, const IAST & create_object_query) +{ + std::lock_guard lock(mutex); + object_name_to_create_object_map[object_name] = normalizeCreateFunctionQuery(create_object_query); +} + +void UserDefinedSQLObjectsStorageBase::removeObject(const String & object_name) +{ + std::lock_guard lock(mutex); + object_name_to_create_object_map.erase(object_name); +} + +void UserDefinedSQLObjectsStorageBase::removeAllObjectsExcept(const Strings & object_names_to_keep) +{ + boost::container::flat_set names_set_to_keep{object_names_to_keep.begin(), object_names_to_keep.end()}; + std::lock_guard lock(mutex); + for (auto it = object_name_to_create_object_map.begin(); it != object_name_to_create_object_map.end();) + { + auto current = it++; + if (!names_set_to_keep.contains(current->first)) + object_name_to_create_object_map.erase(current); + } +} + +} diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.h b/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.h new file mode 100644 index 00000000000..cab63a3bfcf --- /dev/null +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsStorageBase.h @@ -0,0 +1,69 @@ +#pragma once + +#include +#include + +#include + +#include + +namespace DB +{ + +class UserDefinedSQLObjectsStorageBase : public IUserDefinedSQLObjectsStorage +{ +public: + ASTPtr get(const String & object_name) const override; + + ASTPtr tryGet(const String & object_name) const override; + + bool has(const String & object_name) const override; + + std::vector getAllObjectNames() const override; + + std::vector> getAllObjects() const override; + + bool empty() const override; + + bool storeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + ASTPtr create_object_query, + bool throw_if_exists, + bool replace_if_exists, + const Settings & settings) override; + + bool removeObject( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) override; + +protected: + virtual bool storeObjectImpl( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + ASTPtr create_object_query, + bool throw_if_exists, + bool replace_if_exists, + const Settings & settings) = 0; + + virtual bool removeObjectImpl( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) = 0; + + std::unique_lock getLock() const; + void setAllObjects(const std::vector> & new_objects); + void setObject(const String & object_name, const IAST & create_object_query); + void removeObject(const String & object_name); + void removeAllObjectsExcept(const Strings & object_names_to_keep); + + std::unordered_map object_name_to_create_object_map; + mutable std::recursive_mutex mutex; +}; + +} diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.cpp b/src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.cpp similarity index 82% rename from src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.cpp rename to src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.cpp index 29aff666da5..6e5a5338437 100644 --- a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.cpp +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.cpp @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -47,7 +47,7 @@ namespace } -UserDefinedSQLObjectsLoaderFromZooKeeper::UserDefinedSQLObjectsLoaderFromZooKeeper( +UserDefinedSQLObjectsZooKeeperStorage::UserDefinedSQLObjectsZooKeeperStorage( const ContextPtr & global_context_, const String & zookeeper_path_) : global_context{global_context_} , zookeeper_getter{[global_context_]() { return global_context_->getZooKeeper(); }} @@ -66,20 +66,20 @@ UserDefinedSQLObjectsLoaderFromZooKeeper::UserDefinedSQLObjectsLoaderFromZooKeep zookeeper_path = "/" + zookeeper_path; } -UserDefinedSQLObjectsLoaderFromZooKeeper::~UserDefinedSQLObjectsLoaderFromZooKeeper() +UserDefinedSQLObjectsZooKeeperStorage::~UserDefinedSQLObjectsZooKeeperStorage() { SCOPE_EXIT_SAFE(stopWatchingThread()); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::startWatchingThread() +void UserDefinedSQLObjectsZooKeeperStorage::startWatchingThread() { if (!watching_flag.exchange(true)) { - watching_thread = ThreadFromGlobalPool(&UserDefinedSQLObjectsLoaderFromZooKeeper::processWatchQueue, this); + watching_thread = ThreadFromGlobalPool(&UserDefinedSQLObjectsZooKeeperStorage::processWatchQueue, this); } } -void UserDefinedSQLObjectsLoaderFromZooKeeper::stopWatchingThread() +void UserDefinedSQLObjectsZooKeeperStorage::stopWatchingThread() { if (watching_flag.exchange(false)) { @@ -89,7 +89,7 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::stopWatchingThread() } } -zkutil::ZooKeeperPtr UserDefinedSQLObjectsLoaderFromZooKeeper::getZooKeeper() +zkutil::ZooKeeperPtr UserDefinedSQLObjectsZooKeeperStorage::getZooKeeper() { auto [zookeeper, session_status] = zookeeper_getter.getZooKeeper(); @@ -106,18 +106,18 @@ zkutil::ZooKeeperPtr UserDefinedSQLObjectsLoaderFromZooKeeper::getZooKeeper() return zookeeper; } -void UserDefinedSQLObjectsLoaderFromZooKeeper::initZooKeeperIfNeeded() +void UserDefinedSQLObjectsZooKeeperStorage::initZooKeeperIfNeeded() { getZooKeeper(); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::resetAfterError() +void UserDefinedSQLObjectsZooKeeperStorage::resetAfterError() { zookeeper_getter.resetCache(); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::loadObjects() +void UserDefinedSQLObjectsZooKeeperStorage::loadObjects() { /// loadObjects() is called at start from Server::main(), so it's better not to stop here on no connection to ZooKeeper or any other error. /// However the watching thread must be started anyway in case the connection will be established later. @@ -136,7 +136,7 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::loadObjects() } -void UserDefinedSQLObjectsLoaderFromZooKeeper::processWatchQueue() +void UserDefinedSQLObjectsZooKeeperStorage::processWatchQueue() { LOG_DEBUG(log, "Started watching thread"); setThreadName("UserDefObjWatch"); @@ -173,13 +173,13 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::processWatchQueue() } -void UserDefinedSQLObjectsLoaderFromZooKeeper::stopWatching() +void UserDefinedSQLObjectsZooKeeperStorage::stopWatching() { stopWatchingThread(); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::reloadObjects() +void UserDefinedSQLObjectsZooKeeperStorage::reloadObjects() { auto zookeeper = getZooKeeper(); refreshAllObjects(zookeeper); @@ -187,23 +187,24 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::reloadObjects() } -void UserDefinedSQLObjectsLoaderFromZooKeeper::reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) +void UserDefinedSQLObjectsZooKeeperStorage::reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) { auto zookeeper = getZooKeeper(); refreshObject(zookeeper, object_type, object_name); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::createRootNodes(const zkutil::ZooKeeperPtr & zookeeper) +void UserDefinedSQLObjectsZooKeeperStorage::createRootNodes(const zkutil::ZooKeeperPtr & zookeeper) { zookeeper->createAncestors(zookeeper_path); zookeeper->createIfNotExists(zookeeper_path, ""); } -bool UserDefinedSQLObjectsLoaderFromZooKeeper::storeObject( +bool UserDefinedSQLObjectsZooKeeperStorage::storeObjectImpl( + const ContextPtr & /*current_context*/, UserDefinedSQLObjectType object_type, const String & object_name, - const IAST & create_object_query, + ASTPtr create_object_query, bool throw_if_exists, bool replace_if_exists, const Settings &) @@ -212,7 +213,7 @@ bool UserDefinedSQLObjectsLoaderFromZooKeeper::storeObject( LOG_DEBUG(log, "Storing user-defined object {} at zk path {}", backQuote(object_name), path); WriteBufferFromOwnString create_statement_buf; - formatAST(create_object_query, create_statement_buf, false); + formatAST(*create_object_query, create_statement_buf, false); writeChar('\n', create_statement_buf); String create_statement = create_statement_buf.str(); @@ -252,8 +253,11 @@ bool UserDefinedSQLObjectsLoaderFromZooKeeper::storeObject( } -bool UserDefinedSQLObjectsLoaderFromZooKeeper::removeObject( - UserDefinedSQLObjectType object_type, const String & object_name, bool throw_if_not_exists) +bool UserDefinedSQLObjectsZooKeeperStorage::removeObjectImpl( + const ContextPtr & /*current_context*/, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) { String path = getNodePath(zookeeper_path, object_type, object_name); LOG_DEBUG(log, "Removing user-defined object {} at zk path {}", backQuote(object_name), path); @@ -276,7 +280,7 @@ bool UserDefinedSQLObjectsLoaderFromZooKeeper::removeObject( return true; } -bool UserDefinedSQLObjectsLoaderFromZooKeeper::getObjectDataAndSetWatch( +bool UserDefinedSQLObjectsZooKeeperStorage::getObjectDataAndSetWatch( const zkutil::ZooKeeperPtr & zookeeper, String & data, const String & path, @@ -298,7 +302,7 @@ bool UserDefinedSQLObjectsLoaderFromZooKeeper::getObjectDataAndSetWatch( return zookeeper->tryGetWatch(path, data, &entity_stat, object_watcher); } -ASTPtr UserDefinedSQLObjectsLoaderFromZooKeeper::parseObjectData(const String & object_data, UserDefinedSQLObjectType object_type) +ASTPtr UserDefinedSQLObjectsZooKeeperStorage::parseObjectData(const String & object_data, UserDefinedSQLObjectType object_type) { switch (object_type) { @@ -317,7 +321,7 @@ ASTPtr UserDefinedSQLObjectsLoaderFromZooKeeper::parseObjectData(const String & UNREACHABLE(); } -ASTPtr UserDefinedSQLObjectsLoaderFromZooKeeper::tryLoadObject( +ASTPtr UserDefinedSQLObjectsZooKeeperStorage::tryLoadObject( const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type, const String & object_name) { String path = getNodePath(zookeeper_path, object_type, object_name); @@ -343,7 +347,7 @@ ASTPtr UserDefinedSQLObjectsLoaderFromZooKeeper::tryLoadObject( } } -Strings UserDefinedSQLObjectsLoaderFromZooKeeper::getObjectNamesAndSetWatch( +Strings UserDefinedSQLObjectsZooKeeperStorage::getObjectNamesAndSetWatch( const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type) { auto object_list_watcher = [my_watch_queue = watch_queue, object_type](const Coordination::WatchResponse &) @@ -371,7 +375,7 @@ Strings UserDefinedSQLObjectsLoaderFromZooKeeper::getObjectNamesAndSetWatch( return object_names; } -void UserDefinedSQLObjectsLoaderFromZooKeeper::refreshAllObjects(const zkutil::ZooKeeperPtr & zookeeper) +void UserDefinedSQLObjectsZooKeeperStorage::refreshAllObjects(const zkutil::ZooKeeperPtr & zookeeper) { /// It doesn't make sense to keep the old watch events because we will reread everything in this function. watch_queue->clear(); @@ -380,7 +384,7 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::refreshAllObjects(const zkutil::Z objects_loaded = true; } -void UserDefinedSQLObjectsLoaderFromZooKeeper::refreshObjects(const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type) +void UserDefinedSQLObjectsZooKeeperStorage::refreshObjects(const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type) { LOG_DEBUG(log, "Refreshing all user-defined {} objects", object_type); Strings object_names = getObjectNamesAndSetWatch(zookeeper, object_type); @@ -393,21 +397,20 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::refreshObjects(const zkutil::ZooK function_names_and_asts.emplace_back(function_name, ast); } - UserDefinedSQLFunctionFactory::instance().setAllFunctions(function_names_and_asts); + setAllObjects(function_names_and_asts); LOG_DEBUG(log, "All user-defined {} objects refreshed", object_type); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::syncObjects(const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type) +void UserDefinedSQLObjectsZooKeeperStorage::syncObjects(const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type) { LOG_DEBUG(log, "Syncing user-defined {} objects", object_type); Strings object_names = getObjectNamesAndSetWatch(zookeeper, object_type); - auto & factory = UserDefinedSQLFunctionFactory::instance(); - auto lock = factory.getLock(); + getLock(); /// Remove stale objects - factory.removeAllFunctionsExcept(object_names); + removeAllObjectsExcept(object_names); /// Read & parse only new SQL objects from ZooKeeper for (const auto & function_name : object_names) { @@ -418,16 +421,15 @@ void UserDefinedSQLObjectsLoaderFromZooKeeper::syncObjects(const zkutil::ZooKeep LOG_DEBUG(log, "User-defined {} objects synced", object_type); } -void UserDefinedSQLObjectsLoaderFromZooKeeper::refreshObject( +void UserDefinedSQLObjectsZooKeeperStorage::refreshObject( const zkutil::ZooKeeperPtr & zookeeper, UserDefinedSQLObjectType object_type, const String & object_name) { auto ast = tryLoadObject(zookeeper, object_type, object_name); - auto & factory = UserDefinedSQLFunctionFactory::instance(); if (ast) - factory.setFunction(object_name, *ast); + setObject(object_name, *ast); else - factory.removeFunction(object_name); + removeObject(object_name); } } diff --git a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.h b/src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.h similarity index 80% rename from src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.h rename to src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.h index 38e061fd4d9..9f41763c59c 100644 --- a/src/Functions/UserDefined/UserDefinedSQLObjectsLoaderFromZooKeeper.h +++ b/src/Functions/UserDefined/UserDefinedSQLObjectsZooKeeperStorage.h @@ -1,6 +1,6 @@ #pragma once -#include +#include #include #include #include @@ -12,11 +12,11 @@ namespace DB { /// Loads user-defined sql objects from ZooKeeper. -class UserDefinedSQLObjectsLoaderFromZooKeeper : public IUserDefinedSQLObjectsLoader +class UserDefinedSQLObjectsZooKeeperStorage : public UserDefinedSQLObjectsStorageBase { public: - UserDefinedSQLObjectsLoaderFromZooKeeper(const ContextPtr & global_context_, const String & zookeeper_path_); - ~UserDefinedSQLObjectsLoaderFromZooKeeper() override; + UserDefinedSQLObjectsZooKeeperStorage(const ContextPtr & global_context_, const String & zookeeper_path_); + ~UserDefinedSQLObjectsZooKeeperStorage() override; bool isReplicated() const override { return true; } String getReplicationID() const override { return zookeeper_path; } @@ -26,16 +26,21 @@ public: void reloadObjects() override; void reloadObject(UserDefinedSQLObjectType object_type, const String & object_name) override; - bool storeObject( +private: + bool storeObjectImpl( + const ContextPtr & current_context, UserDefinedSQLObjectType object_type, const String & object_name, - const IAST & create_object_query, + ASTPtr create_object_query, bool throw_if_exists, bool replace_if_exists, const Settings & settings) override; - bool removeObject(UserDefinedSQLObjectType object_type, const String & object_name, bool throw_if_not_exists) override; + bool removeObjectImpl( + const ContextPtr & current_context, + UserDefinedSQLObjectType object_type, + const String & object_name, + bool throw_if_not_exists) override; -private: void processWatchQueue(); zkutil::ZooKeeperPtr getZooKeeper(); diff --git a/src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.h b/src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.h deleted file mode 100644 index b3a4623dba3..00000000000 --- a/src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - - -namespace DB -{ -class IUserDefinedSQLObjectsLoader; - -std::unique_ptr createUserDefinedSQLObjectsLoader(const ContextMutablePtr & global_context); - -} diff --git a/src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.cpp b/src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.cpp similarity index 61% rename from src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.cpp rename to src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.cpp index b7ebc7abf14..f8847024508 100644 --- a/src/Functions/UserDefined/createUserDefinedSQLObjectsLoader.cpp +++ b/src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.cpp @@ -1,6 +1,6 @@ -#include -#include -#include +#include +#include +#include #include #include #include @@ -17,7 +17,7 @@ namespace ErrorCodes extern const int INVALID_CONFIG_PARAMETER; } -std::unique_ptr createUserDefinedSQLObjectsLoader(const ContextMutablePtr & global_context) +std::unique_ptr createUserDefinedSQLObjectsStorage(const ContextMutablePtr & global_context) { const String zookeeper_path_key = "user_defined_zookeeper_path"; const String disk_path_key = "user_defined_path"; @@ -33,12 +33,12 @@ std::unique_ptr createUserDefinedSQLObjectsLoader( zookeeper_path_key, disk_path_key); } - return std::make_unique(global_context, config.getString(zookeeper_path_key)); + return std::make_unique(global_context, config.getString(zookeeper_path_key)); } String default_path = fs::path{global_context->getPath()} / "user_defined/"; String path = config.getString(disk_path_key, default_path); - return std::make_unique(global_context, path); + return std::make_unique(global_context, path); } } diff --git a/src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.h b/src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.h new file mode 100644 index 00000000000..01659372dec --- /dev/null +++ b/src/Functions/UserDefined/createUserDefinedSQLObjectsStorage.h @@ -0,0 +1,12 @@ +#pragma once + +#include + + +namespace DB +{ +class IUserDefinedSQLObjectsStorage; + +std::unique_ptr createUserDefinedSQLObjectsStorage(const ContextMutablePtr & global_context); + +} diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 0a8a8f1f529..248b61f6e9b 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -65,8 +65,8 @@ #include #include #include -#include -#include +#include +#include #include #include #include @@ -253,8 +253,8 @@ struct ContextSharedPart : boost::noncopyable ExternalLoaderXMLConfigRepository * user_defined_executable_functions_config_repository TSA_GUARDED_BY(external_user_defined_executable_functions_mutex) = nullptr; scope_guard user_defined_executable_functions_xmls TSA_GUARDED_BY(external_user_defined_executable_functions_mutex); - mutable OnceFlag user_defined_sql_objects_loader_initialized; - mutable std::unique_ptr user_defined_sql_objects_loader; + mutable OnceFlag user_defined_sql_objects_storage_initialized; + mutable std::unique_ptr user_defined_sql_objects_storage; #if USE_NLP mutable OnceFlag synonyms_extensions_initialized; @@ -545,7 +545,7 @@ struct ContextSharedPart : boost::noncopyable SHUTDOWN(log, "dictionaries loader", external_dictionaries_loader, enablePeriodicUpdates(false)); SHUTDOWN(log, "UDFs loader", external_user_defined_executable_functions_loader, enablePeriodicUpdates(false)); - SHUTDOWN(log, "another UDFs loader", user_defined_sql_objects_loader, stopWatching()); + SHUTDOWN(log, "another UDFs storage", user_defined_sql_objects_storage, stopWatching()); LOG_TRACE(log, "Shutting down named sessions"); Session::shutdownNamedSessions(); @@ -572,7 +572,7 @@ struct ContextSharedPart : boost::noncopyable std::unique_ptr delete_embedded_dictionaries; std::unique_ptr delete_external_dictionaries_loader; std::unique_ptr delete_external_user_defined_executable_functions_loader; - std::unique_ptr delete_user_defined_sql_objects_loader; + std::unique_ptr delete_user_defined_sql_objects_storage; std::unique_ptr delete_buffer_flush_schedule_pool; std::unique_ptr delete_schedule_pool; std::unique_ptr delete_distributed_schedule_pool; @@ -652,7 +652,7 @@ struct ContextSharedPart : boost::noncopyable delete_embedded_dictionaries = std::move(embedded_dictionaries); delete_external_dictionaries_loader = std::move(external_dictionaries_loader); delete_external_user_defined_executable_functions_loader = std::move(external_user_defined_executable_functions_loader); - delete_user_defined_sql_objects_loader = std::move(user_defined_sql_objects_loader); + delete_user_defined_sql_objects_storage = std::move(user_defined_sql_objects_storage); delete_buffer_flush_schedule_pool = std::move(buffer_flush_schedule_pool); delete_schedule_pool = std::move(schedule_pool); delete_distributed_schedule_pool = std::move(distributed_schedule_pool); @@ -670,7 +670,7 @@ struct ContextSharedPart : boost::noncopyable delete_embedded_dictionaries.reset(); delete_external_dictionaries_loader.reset(); delete_external_user_defined_executable_functions_loader.reset(); - delete_user_defined_sql_objects_loader.reset(); + delete_user_defined_sql_objects_storage.reset(); delete_ddl_worker.reset(); delete_buffer_flush_schedule_pool.reset(); delete_schedule_pool.reset(); @@ -2448,24 +2448,30 @@ void Context::loadOrReloadUserDefinedExecutableFunctions(const Poco::Util::Abstr shared->user_defined_executable_functions_xmls = external_user_defined_executable_functions_loader.addConfigRepository(std::move(repository)); } -const IUserDefinedSQLObjectsLoader & Context::getUserDefinedSQLObjectsLoader() const +const IUserDefinedSQLObjectsStorage & Context::getUserDefinedSQLObjectsStorage() const { - callOnce(shared->user_defined_sql_objects_loader_initialized, [&] { - shared->user_defined_sql_objects_loader = createUserDefinedSQLObjectsLoader(getGlobalContext()); + callOnce(shared->user_defined_sql_objects_storage_initialized, [&] { + shared->user_defined_sql_objects_storage = createUserDefinedSQLObjectsStorage(getGlobalContext()); }); SharedLockGuard lock(shared->mutex); - return *shared->user_defined_sql_objects_loader; + return *shared->user_defined_sql_objects_storage; } -IUserDefinedSQLObjectsLoader & Context::getUserDefinedSQLObjectsLoader() +IUserDefinedSQLObjectsStorage & Context::getUserDefinedSQLObjectsStorage() { - callOnce(shared->user_defined_sql_objects_loader_initialized, [&] { - shared->user_defined_sql_objects_loader = createUserDefinedSQLObjectsLoader(getGlobalContext()); + callOnce(shared->user_defined_sql_objects_storage_initialized, [&] { + shared->user_defined_sql_objects_storage = createUserDefinedSQLObjectsStorage(getGlobalContext()); }); - SharedLockGuard lock(shared->mutex); - return *shared->user_defined_sql_objects_loader; + std::lock_guard lock(shared->mutex); + return *shared->user_defined_sql_objects_storage; +} + +void Context::setUserDefinedSQLObjectsStorage(std::unique_ptr storage) +{ + std::lock_guard lock(shared->mutex); + shared->user_defined_sql_objects_storage = std::move(storage); } #if USE_NLP diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 8c169dd664f..63a919c5f1a 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -68,7 +68,7 @@ enum class RowPolicyFilterType; class EmbeddedDictionaries; class ExternalDictionariesLoader; class ExternalUserDefinedExecutableFunctionsLoader; -class IUserDefinedSQLObjectsLoader; +class IUserDefinedSQLObjectsStorage; class InterserverCredentials; using InterserverCredentialsPtr = std::shared_ptr; class InterserverIOHandler; @@ -802,8 +802,9 @@ public: const ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader() const; ExternalUserDefinedExecutableFunctionsLoader & getExternalUserDefinedExecutableFunctionsLoader(); - const IUserDefinedSQLObjectsLoader & getUserDefinedSQLObjectsLoader() const; - IUserDefinedSQLObjectsLoader & getUserDefinedSQLObjectsLoader(); + const IUserDefinedSQLObjectsStorage & getUserDefinedSQLObjectsStorage() const; + IUserDefinedSQLObjectsStorage & getUserDefinedSQLObjectsStorage(); + void setUserDefinedSQLObjectsStorage(std::unique_ptr storage); void loadOrReloadUserDefinedExecutableFunctions(const Poco::Util::AbstractConfiguration & config); #if USE_NLP diff --git a/src/Interpreters/InterpreterCreateFunctionQuery.cpp b/src/Interpreters/InterpreterCreateFunctionQuery.cpp index 3e87f4fe440..b155476fd79 100644 --- a/src/Interpreters/InterpreterCreateFunctionQuery.cpp +++ b/src/Interpreters/InterpreterCreateFunctionQuery.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -32,7 +32,7 @@ BlockIO InterpreterCreateFunctionQuery::execute() if (!create_function_query.cluster.empty()) { - if (current_context->getUserDefinedSQLObjectsLoader().isReplicated()) + if (current_context->getUserDefinedSQLObjectsStorage().isReplicated()) throw Exception(ErrorCodes::INCORRECT_QUERY, "ON CLUSTER is not allowed because used-defined functions are replicated automatically"); DDLQueryOnClusterParams params; diff --git a/src/Interpreters/InterpreterDropFunctionQuery.cpp b/src/Interpreters/InterpreterDropFunctionQuery.cpp index af60d9c5df7..c2cd24044da 100644 --- a/src/Interpreters/InterpreterDropFunctionQuery.cpp +++ b/src/Interpreters/InterpreterDropFunctionQuery.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include @@ -32,7 +32,7 @@ BlockIO InterpreterDropFunctionQuery::execute() if (!drop_function_query.cluster.empty()) { - if (current_context->getUserDefinedSQLObjectsLoader().isReplicated()) + if (current_context->getUserDefinedSQLObjectsStorage().isReplicated()) throw Exception(ErrorCodes::INCORRECT_QUERY, "ON CLUSTER is not allowed because used-defined functions are replicated automatically"); DDLQueryOnClusterParams params; diff --git a/src/Interpreters/removeOnClusterClauseIfNeeded.cpp b/src/Interpreters/removeOnClusterClauseIfNeeded.cpp index 7dc452a0fcb..bee9a54cd0d 100644 --- a/src/Interpreters/removeOnClusterClauseIfNeeded.cpp +++ b/src/Interpreters/removeOnClusterClauseIfNeeded.cpp @@ -3,7 +3,7 @@ #include #include #include -#include +#include #include #include #include @@ -45,7 +45,7 @@ ASTPtr removeOnClusterClauseIfNeeded(const ASTPtr & query, ContextPtr context, c if ((isUserDefinedFunctionQuery(query) && context->getSettings().ignore_on_cluster_for_replicated_udf_queries - && context->getUserDefinedSQLObjectsLoader().isReplicated()) + && context->getUserDefinedSQLObjectsStorage().isReplicated()) || (isAccessControlQuery(query) && context->getSettings().ignore_on_cluster_for_replicated_access_entities_queries && context->getAccessControl().containsStorage(ReplicatedAccessStorage::STORAGE_TYPE))) diff --git a/tests/integration/test_replicated_user_defined_functions/test.py b/tests/integration/test_replicated_user_defined_functions/test.py index f54be21c4c0..e5f6683b90b 100644 --- a/tests/integration/test_replicated_user_defined_functions/test.py +++ b/tests/integration/test_replicated_user_defined_functions/test.py @@ -116,7 +116,7 @@ def test_create_and_replace(): node1.query("CREATE FUNCTION f1 AS (x, y) -> x + y") assert node1.query("SELECT f1(12, 3)") == "15\n" - expected_error = "User-defined function 'f1' already exists" + expected_error = "User-defined object 'f1' already exists" assert expected_error in node1.query_and_get_error( "CREATE FUNCTION f1 AS (x, y) -> x + 2 * y" ) @@ -135,7 +135,7 @@ def test_drop_if_exists(): node1.query("DROP FUNCTION IF EXISTS f1") node1.query("DROP FUNCTION IF EXISTS f1") - expected_error = "User-defined function 'f1' doesn't exist" + expected_error = "User-defined object 'f1' doesn't exist" assert expected_error in node1.query_and_get_error("DROP FUNCTION f1")