diff --git a/src/Functions/FunctionFactory.cpp b/src/Functions/FunctionFactory.cpp index 75d045a24ed..26cefc45df0 100644 --- a/src/Functions/FunctionFactory.cpp +++ b/src/Functions/FunctionFactory.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -32,11 +31,12 @@ const String & getFunctionCanonicalNameIfAny(const String & name) return FunctionFactory::instance().getCanonicalNameIfAny(name); } -void FunctionFactory::registerFunction(const - std::string & name, +void FunctionFactory::registerFunction( + const std::string & name, Value creator, CaseSensitiveness case_sensitiveness) { + std::lock_guard guard(mutex); if (!functions.emplace(name, creator).second) throw Exception("FunctionFactory: the function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); @@ -79,6 +79,7 @@ FunctionOverloadResolverPtr FunctionFactory::getImpl( std::vector FunctionFactory::getAllNames() const { + std::lock_guard guard(mutex); std::vector res; res.reserve(functions.size()); for (const auto & func : functions) @@ -97,6 +98,7 @@ FunctionOverloadResolverPtr FunctionFactory::tryGetImpl( const std::string & name_param, ContextConstPtr context) const { + std::lock_guard guard(mutex); String name = getAliasToOrName(name_param); FunctionOverloadResolverPtr res; @@ -146,6 +148,10 @@ void FunctionFactory::registerUserDefinedFunction(const ASTCreateFunctionQuery & if (AggregateFunctionFactory::instance().isAggregateFunctionName(create_function_query.function_name)) throw Exception(ErrorCodes::FUNCTION_ALREADY_EXISTS, "The aggregate function {} already exists", create_function_query.function_name); + { + std::lock_guard guard(mutex); + user_defined_functions.insert(create_function_query.function_name); + } registerFunction(create_function_query.function_name, [create_function_query](ContextPtr context) { auto function = UserDefinedFunction::create(context); @@ -155,11 +161,11 @@ void FunctionFactory::registerUserDefinedFunction(const ASTCreateFunctionQuery & FunctionOverloadResolverPtr res = std::make_unique(function); return res; }, CaseSensitiveness::CaseSensitive); - user_defined_functions.insert(create_function_query.function_name); } void FunctionFactory::unregisterUserDefinedFunction(const String & name) { + std::lock_guard guard(mutex); if (functions.contains(name)) { if (user_defined_functions.contains(name)) diff --git a/src/Functions/FunctionFactory.h b/src/Functions/FunctionFactory.h index ade597d63bb..e62f8eb97da 100644 --- a/src/Functions/FunctionFactory.h +++ b/src/Functions/FunctionFactory.h @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -20,7 +21,7 @@ namespace DB * some dictionaries from Context. */ class FunctionFactory : private boost::noncopyable, - public IFactoryWithAliases> + public IFactoryWithAliases> { public: static FunctionFactory & instance(); @@ -49,14 +50,14 @@ public: std::vector getAllNames() const; /// Throws an exception if not found. - FunctionOverloadResolverPtr get(const std::string & name, ContextPtr context) const; + FunctionOverloadResolverPtr get(const std::string & name, ContextConstPtr context) const; /// Returns nullptr if not found. - FunctionOverloadResolverPtr tryGet(const std::string & name, ContextPtr context) const; + FunctionOverloadResolverPtr tryGet(const std::string & name, ContextConstPtr context) const; /// The same methods to get developer interface implementation. - FunctionOverloadResolverPtr getImpl(const std::string & name, ContextPtr context) const; - FunctionOverloadResolverPtr tryGetImpl(const std::string & name, ContextPtr context) const; + FunctionOverloadResolverPtr getImpl(const std::string & name, ContextConstPtr context) const; + FunctionOverloadResolverPtr tryGetImpl(const std::string & name, ContextConstPtr context) const; /// Register a function by its name. /// No locking, you must register all functions before usage of get. @@ -71,9 +72,10 @@ private: Functions functions; std::unordered_set user_defined_functions; Functions case_insensitive_functions; + mutable std::mutex mutex; template - static FunctionOverloadResolverPtr adaptFunctionToOverloadResolver(ContextPtr context) + static FunctionOverloadResolverPtr adaptFunctionToOverloadResolver(ContextConstPtr context) { return std::make_unique(Function::create(context)); }