diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 19a01635065..6a541ff28b4 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -1,13 +1,17 @@ #include -#include -#include + #include #include #include +#include +#include +#include + #include -#include #include +#include + namespace DB { diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h index be44a7a9e87..4e5423396bf 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.h @@ -1,14 +1,22 @@ #pragma once -#include #include + #include +#include +#include +#include +#include +#include + namespace DB { +class Context; class IDataType; + using DataTypePtr = std::shared_ptr; using DataTypes = std::vector; @@ -19,22 +27,8 @@ class AggregateFunctionFactory final : public ext::singleton; - public: - - AggregateFunctionPtr get( - const String & name, - const DataTypes & argument_types, - const Array & parameters = {}, - int recursion_level = 0) const; - - AggregateFunctionPtr tryGet(const String & name, const DataTypes & argument_types, const Array & parameters = {}) const; - - bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; + using Creator = std::function; /// For compatibility with SQL, it's possible to specify that certain aggregate function name is case insensitive. enum CaseSensitiveness @@ -43,11 +37,29 @@ public: CaseInsensitive }; - /// Register an aggregate function by its name. - void registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a function by its name. + /// No locking, you must register all functions before usage of get. + void registerFunction( + const String & name, + Creator creator, + CaseSensitiveness case_sensitiveness = CaseSensitive); + + /// Throws an exception if not found. + AggregateFunctionPtr get( + const String & name, + const DataTypes & argument_types, + const Array & parameters = {}, + int recursion_level = 0) const; + + /// Returns nullptr if not found. + AggregateFunctionPtr tryGet( + const String & name, + const DataTypes & argument_types, + const Array & parameters = {}) const; + + bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; private: - AggregateFunctionPtr getImpl( const String & name, const DataTypes & argument_types, @@ -55,6 +67,8 @@ private: int recursion_level) const; private: + using AggregateFunctions = std::unordered_map; + AggregateFunctions aggregate_functions; /// Case insensitive aggregate functions will be additionally added here with lowercased name. diff --git a/dbms/src/Functions/FunctionFactory.cpp b/dbms/src/Functions/FunctionFactory.cpp index c6fb89eee1b..4b681b29263 100644 --- a/dbms/src/Functions/FunctionFactory.cpp +++ b/dbms/src/Functions/FunctionFactory.cpp @@ -1,7 +1,11 @@ -#include #include + +#include + #include +#include + namespace DB { @@ -13,7 +17,10 @@ namespace ErrorCodes } -void FunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) +void FunctionFactory::registerFunction(const + std::string & name, + Creator creator, + CaseSensitiveness case_sensitiveness) { if (!functions.emplace(name, creator).second) throw Exception("FunctionFactory: the function name '" + name + "' is not unique", diff --git a/dbms/src/Functions/FunctionFactory.h b/dbms/src/Functions/FunctionFactory.h index fa577674178..3f4693080a1 100644 --- a/dbms/src/Functions/FunctionFactory.h +++ b/dbms/src/Functions/FunctionFactory.h @@ -1,20 +1,19 @@ #pragma once -#include -#include -#include +#include + #include -#include -#include +#include +#include +#include +#include namespace DB { class Context; -class IFunction; -using FunctionPtr = std::shared_ptr; /** Creates function by name. @@ -25,12 +24,8 @@ class FunctionFactory : public ext::singleton { friend class StorageSystemFunctions; -private: - using Creator = FunctionPtr(*)(const Context & context); /// Not std::function, for lower object size and less indirection. - using Functions = std::unordered_map; - - Functions functions; - Functions case_insensitive_functions; +public: + using Creator = std::function; /// For compatibility with SQL, it's possible to specify that certain function name is case insensitive. enum CaseSensitiveness @@ -39,18 +34,30 @@ private: CaseInsensitive }; -public: - FunctionPtr get(const String & name, const Context & context) const; /// Throws an exception if not found. - FunctionPtr tryGet(const String & name, const Context & context) const; /// Returns nullptr if not found. - - /// No locking, you must register all functions before usage of get, tryGet. - void registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + /// Register a function by its name. + /// No locking, you must register all functions before usage of get. + void registerFunction( + const std::string & name, + Creator creator, + CaseSensitiveness case_sensitiveness = CaseSensitive); template void registerFunction() { - registerFunction(String(Function::name), &Function::create); + registerFunction(Function::name, &Function::create); } + + /// Throws an exception if not found. + FunctionPtr get(const std::string & name, const Context & context) const; + + /// Returns nullptr if not found. + FunctionPtr tryGet(const std::string & name, const Context & context) const; + +private: + using Functions = std::unordered_map; + + Functions functions; + Functions case_insensitive_functions; }; } diff --git a/dbms/src/TableFunctions/TableFunctionFactory.cpp b/dbms/src/TableFunctions/TableFunctionFactory.cpp index 267c3c202a2..a3355fd32e9 100644 --- a/dbms/src/TableFunctions/TableFunctionFactory.cpp +++ b/dbms/src/TableFunctions/TableFunctionFactory.cpp @@ -1,7 +1,8 @@ -#include +#include + #include -#include +#include namespace DB @@ -11,9 +12,17 @@ namespace ErrorCodes { extern const int READONLY; extern const int UNKNOWN_FUNCTION; + extern const int LOGICAL_ERROR; } +void TableFunctionFactory::registerFunction(const std::string & name, Creator creator) +{ + if (!functions.emplace(name, std::move(creator)).second) + throw Exception("TableFunctionFactory: the table function name '" + name + "' is not unique", + ErrorCodes::LOGICAL_ERROR); +} + TableFunctionPtr TableFunctionFactory::get( const std::string & name, const Context & context) const @@ -24,6 +33,7 @@ TableFunctionPtr TableFunctionFactory::get( auto it = functions.find(name); if (it == functions.end()) throw Exception("Unknown table function " + name, ErrorCodes::UNKNOWN_FUNCTION); + return it->second(); } diff --git a/dbms/src/TableFunctions/TableFunctionFactory.h b/dbms/src/TableFunctions/TableFunctionFactory.h index fd58b9e625e..de782630386 100644 --- a/dbms/src/TableFunctions/TableFunctionFactory.h +++ b/dbms/src/TableFunctions/TableFunctionFactory.h @@ -1,46 +1,50 @@ #pragma once -#include -#include -#include -#include #include +#include + +#include +#include +#include +#include + namespace DB { -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; -} +class Context; /** Lets you get a table function by its name. */ -class TableFunctionFactory : public ext::singleton +class TableFunctionFactory final: public ext::singleton { -private: - /// No std::function, for smaller object size and less indirection. - using Creator = TableFunctionPtr(*)(); - using TableFunctions = std::unordered_map; - - TableFunctions functions; - public: - TableFunctionPtr get( - const String & name, - const Context & context) const; + using Creator = std::function; - /// Register a table function by its name. + /// Register a function by its name. /// No locking, you must register all functions before usage of get. + void registerFunction(const std::string & name, Creator creator); + template void registerFunction() { - if (!functions.emplace(std::string(Function::name), []{ return TableFunctionPtr(std::make_unique()); }).second) - throw Exception("TableFunctionFactory: the table function name '" + String(Function::name) + "' is not unique", - ErrorCodes::LOGICAL_ERROR); + auto creator = [] () -> TableFunctionPtr { + return std::make_shared(); + }; + registerFunction(Function::name, std::move(creator)); } + + /// Throws an exception if not found. + TableFunctionPtr get( + const std::string & name, + const Context & context) const; + +private: + using TableFunctions = std::unordered_map; + + TableFunctions functions; }; }