diff --git a/programs/odbc-bridge/CMakeLists.txt b/programs/odbc-bridge/CMakeLists.txt index ab8d94f2a0c..628f9ee018a 100644 --- a/programs/odbc-bridge/CMakeLists.txt +++ b/programs/odbc-bridge/CMakeLists.txt @@ -14,6 +14,7 @@ set (CLICKHOUSE_ODBC_BRIDGE_SOURCES set (CLICKHOUSE_ODBC_BRIDGE_LINK PRIVATE clickhouse_parsers + clickhouse_aggregate_functions daemon dbms Poco::Data diff --git a/src/AggregateFunctions/AggregateFunctionArray.cpp b/src/AggregateFunctions/AggregateFunctionArray.cpp index ced95185263..7fe4f1f448b 100644 --- a/src/AggregateFunctions/AggregateFunctionArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionArray.cpp @@ -36,7 +36,10 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array &) const override { return std::make_shared(nested_function, arguments); } diff --git a/src/AggregateFunctions/AggregateFunctionCount.cpp b/src/AggregateFunctions/AggregateFunctionCount.cpp index 6c22fec87a2..b00adaa0f1a 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.cpp +++ b/src/AggregateFunctions/AggregateFunctionCount.cpp @@ -7,6 +7,12 @@ namespace DB { +AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter( + const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const +{ + return std::make_shared(types[0], params); +} + namespace { @@ -22,7 +28,7 @@ AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, cons void registerAggregateFunctionCount(AggregateFunctionFactory & factory) { - factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive); + factory.registerFunction("count", {createAggregateFunctionCount, {true}}, AggregateFunctionFactory::CaseInsensitive); } } diff --git 
a/src/AggregateFunctions/AggregateFunctionCount.h b/src/AggregateFunctions/AggregateFunctionCount.h index e54f014f7a4..feb5725d9f1 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.h +++ b/src/AggregateFunctions/AggregateFunctionCount.h @@ -68,16 +68,14 @@ public: data(place).count = new_count; } - /// The function returns non-Nullable type even when wrapped with Null combinator. - bool returnDefaultWhenOnlyNull() const override - { - return true; - } + AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override; }; /// Simply count number of not-NULL values. -class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper +class AggregateFunctionCountNotNullUnary final + : public IAggregateFunctionDataHelper { public: AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params) diff --git a/src/AggregateFunctions/AggregateFunctionFactory.cpp b/src/AggregateFunctions/AggregateFunctionFactory.cpp index 3982c48700b..7ff52fe0f70 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -29,18 +29,18 @@ namespace ErrorCodes } -void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) +void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness) { - if (creator == nullptr) + if (creator_with_properties.creator == nullptr) throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " " a null constructor", ErrorCodes::LOGICAL_ERROR); - if (!aggregate_functions.emplace(name, creator).second) + if (!aggregate_functions.emplace(name, creator_with_properties).second) throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); if 
(case_sensitiveness == CaseInsensitive - && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) + && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second) throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); } @@ -59,6 +59,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get( const String & name, const DataTypes & argument_types, const Array & parameters, + AggregateFunctionProperties & out_properties, int recursion_level) const { auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); @@ -76,18 +77,11 @@ AggregateFunctionPtr AggregateFunctionFactory::get( DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); Array nested_parameters = combinator->transformParameters(parameters); - AggregateFunctionPtr nested_function; - - /// A little hack - if we have NULL arguments, don't even create nested function. - /// Combinator will check if nested_function was created. 
- if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), - [](const auto & type) { return type->onlyNull(); })) - nested_function = getImpl(name, nested_types, nested_parameters, recursion_level); - - return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters); + AggregateFunctionPtr nested_function = getImpl(name, nested_types, nested_parameters, out_properties, recursion_level); + return combinator->transformAggregateFunction(nested_function, out_properties, type_without_low_cardinality, parameters); } - auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); + auto res = getImpl(name, type_without_low_cardinality, parameters, out_properties, recursion_level); if (!res) throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR); return res; @@ -98,19 +92,37 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( const String & name_param, const DataTypes & argument_types, const Array & parameters, + AggregateFunctionProperties & out_properties, int recursion_level) const { String name = getAliasToOrName(name_param); + Value found; + /// Find by exact match. if (auto it = aggregate_functions.find(name); it != aggregate_functions.end()) - return it->second(name, argument_types, parameters); - + { + found = it->second; + } /// Find by case-insensitive name. /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. 
- if (recursion_level == 0) + else if (recursion_level == 0) { - if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end()) - return it->second(name, argument_types, parameters); + if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end()) + found = jt->second; + } + + if (found.creator) + { + out_properties = found.properties; + + /// The case when aggregate function should return NULL on NULL arguments. This case is handled in "get" method. + if (!out_properties.returns_default_when_only_null + && std::any_of(argument_types.begin(), argument_types.end(), [](const auto & type) { return type->onlyNull(); })) + { + return nullptr; + } + + return found.creator(name, argument_types, parameters); } /// Combinators of aggregate functions. @@ -126,9 +138,8 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( DataTypes nested_types = combinator->transformArguments(argument_types); Array nested_parameters = combinator->transformParameters(parameters); - AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); - - return combinator->transformAggregateFunction(nested_function, argument_types, parameters); + AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, out_properties, recursion_level + 1); + return combinator->transformAggregateFunction(nested_function, out_properties, argument_types, parameters); } auto hints = this->getHints(name); @@ -140,10 +151,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( } -AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const +AggregateFunctionPtr AggregateFunctionFactory::tryGet( + const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const 
{ return isAggregateFunctionName(name) - ? get(name, argument_types, parameters) + ? get(name, argument_types, parameters, out_properties) : nullptr; } diff --git a/src/AggregateFunctions/AggregateFunctionFactory.h b/src/AggregateFunctions/AggregateFunctionFactory.h index 6e755cc9e8c..ab45fcf683f 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.h +++ b/src/AggregateFunctions/AggregateFunctionFactory.h @@ -26,34 +26,50 @@ using DataTypes = std::vector; */ using AggregateFunctionCreator = std::function; +struct AggregateFunctionWithProperties +{ + AggregateFunctionCreator creator; + AggregateFunctionProperties properties; + + AggregateFunctionWithProperties() = default; + AggregateFunctionWithProperties(const AggregateFunctionWithProperties &) = default; + + template > * = nullptr> + AggregateFunctionWithProperties(Creator creator_, AggregateFunctionProperties properties_ = {}) + : creator(std::forward(creator_)), properties(std::move(properties_)) + { + } +}; + /** Creates an aggregate function by name. */ -class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases +class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases { public: - static AggregateFunctionFactory & instance(); /// Register a function by its name. /// No locking, you must register all functions before usage of get. void registerFunction( const String & name, - Creator creator, + Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive); /// Throws an exception if not found. AggregateFunctionPtr get( const String & name, const DataTypes & argument_types, - const Array & parameters = {}, + const Array & parameters, + AggregateFunctionProperties & out_properties, int recursion_level = 0) const; /// Returns nullptr if not found. 
AggregateFunctionPtr tryGet( const String & name, const DataTypes & argument_types, - const Array & parameters = {}) const; + const Array & parameters, + AggregateFunctionProperties & out_properties) const; bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; @@ -62,19 +78,20 @@ private: const String & name, const DataTypes & argument_types, const Array & parameters, + AggregateFunctionProperties & out_properties, int recursion_level) const; private: - using AggregateFunctions = std::unordered_map; + using AggregateFunctions = std::unordered_map; AggregateFunctions aggregate_functions; /// Case insensitive aggregate functions will be additionally added here with lowercased name. AggregateFunctions case_insensitive_aggregate_functions; - const AggregateFunctions & getCreatorMap() const override { return aggregate_functions; } + const AggregateFunctions & getMap() const override { return aggregate_functions; } - const AggregateFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_aggregate_functions; } + const AggregateFunctions & getCaseInsensitiveMap() const override { return case_insensitive_aggregate_functions; } String getFactoryName() const override { return "AggregateFunctionFactory"; } diff --git a/src/AggregateFunctions/AggregateFunctionForEach.cpp b/src/AggregateFunctions/AggregateFunctionForEach.cpp index 775dab2dcd9..693bc6839fa 100644 --- a/src/AggregateFunctions/AggregateFunctionForEach.cpp +++ b/src/AggregateFunctions/AggregateFunctionForEach.cpp @@ -33,7 +33,10 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array &) const override { return std::make_shared(nested_function, arguments); } diff --git a/src/AggregateFunctions/AggregateFunctionIf.cpp 
b/src/AggregateFunctions/AggregateFunctionIf.cpp index cb5f9f15b1c..19a175de911 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.cpp +++ b/src/AggregateFunctions/AggregateFunctionIf.cpp @@ -31,7 +31,10 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array &) const override { return std::make_shared(nested_function, arguments); } diff --git a/src/AggregateFunctions/AggregateFunctionMerge.cpp b/src/AggregateFunctions/AggregateFunctionMerge.cpp index 05d941844d9..2ce3f0e11f6 100644 --- a/src/AggregateFunctions/AggregateFunctionMerge.cpp +++ b/src/AggregateFunctions/AggregateFunctionMerge.cpp @@ -34,7 +34,10 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array &) const override { const DataTypePtr & argument = arguments[0]; diff --git a/src/AggregateFunctions/AggregateFunctionNothing.h b/src/AggregateFunctions/AggregateFunctionNothing.h index 511dbbecd38..b3206f6db6e 100644 --- a/src/AggregateFunctions/AggregateFunctionNothing.h +++ b/src/AggregateFunctions/AggregateFunctionNothing.h @@ -25,7 +25,7 @@ public: DataTypePtr getReturnType() const override { - return std::make_shared(std::make_shared()); + return argument_types.front(); } void create(AggregateDataPtr) const override diff --git a/src/AggregateFunctions/AggregateFunctionNull.cpp b/src/AggregateFunctions/AggregateFunctionNull.cpp index 993cb93c991..85d960eae62 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -31,13 +31,11 @@ public: } 
AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties & properties, + const DataTypes & arguments, + const Array & params) const override { - /// Special case for 'count' function. It could be called with Nullable arguments - /// - that means - count number of calls, when all arguments are not NULL. - if (nested_function && nested_function->getName() == "count") - return std::make_shared(arguments[0], params); - bool has_nullable_types = false; bool has_null_types = false; for (const auto & arg_type : arguments) @@ -58,15 +56,21 @@ public: ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (has_null_types) - return std::make_shared(arguments, params); + { + /// Currently the only functions that return not-NULL on all NULL arguments are count and uniq, and they return UInt64.
+ if (properties.returns_default_when_only_null) + return std::make_shared(DataTypes{std::make_shared()}, params); + else + return std::make_shared(arguments, params); + } assert(nested_function); if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) return adapter; - bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable(); - bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull(); + bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable(); + bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null; if (arguments.size() == 1) { diff --git a/src/AggregateFunctions/AggregateFunctionOrFill.cpp b/src/AggregateFunctions/AggregateFunctionOrFill.cpp index b9cc2f9b8b7..ce8fc8d9ca5 100644 --- a/src/AggregateFunctions/AggregateFunctionOrFill.cpp +++ b/src/AggregateFunctions/AggregateFunctionOrFill.cpp @@ -21,6 +21,7 @@ public: AggregateFunctionPtr transformAggregateFunction( const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, const DataTypes & arguments, const Array & params) const override { diff --git a/src/AggregateFunctions/AggregateFunctionResample.cpp b/src/AggregateFunctions/AggregateFunctionResample.cpp index d8d13e22120..389c9048918 100644 --- a/src/AggregateFunctions/AggregateFunctionResample.cpp +++ b/src/AggregateFunctions/AggregateFunctionResample.cpp @@ -43,6 +43,7 @@ public: AggregateFunctionPtr transformAggregateFunction( const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, const DataTypes & arguments, const Array & params) const override { diff --git a/src/AggregateFunctions/AggregateFunctionState.cpp b/src/AggregateFunctions/AggregateFunctionState.cpp index fd92953d114..9d1c677c0ff 100644 --- 
a/src/AggregateFunctions/AggregateFunctionState.cpp +++ b/src/AggregateFunctions/AggregateFunctionState.cpp @@ -24,7 +24,10 @@ public: } AggregateFunctionPtr transformAggregateFunction( - const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override + const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties &, + const DataTypes & arguments, + const Array & params) const override { return std::make_shared(nested_function, arguments, params); } diff --git a/src/AggregateFunctions/AggregateFunctionUniq.cpp b/src/AggregateFunctions/AggregateFunctionUniq.cpp index 1d079550124..40742ae336e 100644 --- a/src/AggregateFunctions/AggregateFunctionUniq.cpp +++ b/src/AggregateFunctions/AggregateFunctionUniq.cpp @@ -123,13 +123,13 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory) { factory.registerFunction("uniq", - createAggregateFunctionUniq); + {createAggregateFunctionUniq, {true}}); factory.registerFunction("uniqHLL12", - createAggregateFunctionUniq); + {createAggregateFunctionUniq, {true}}); factory.registerFunction("uniqExact", - createAggregateFunctionUniq>); + {createAggregateFunctionUniq>, {true}}); } } diff --git a/src/AggregateFunctions/AggregateFunctionUniq.h b/src/AggregateFunctions/AggregateFunctionUniq.h index 1588611b8a2..334e809ebe7 100644 --- a/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/src/AggregateFunctions/AggregateFunctionUniq.h @@ -244,12 +244,6 @@ public: { assert_cast(to).getData().push_back(this->data(place).set.size()); } - - /// The function returns non-Nullable type even when wrapped with Null combinator. - bool returnDefaultWhenOnlyNull() const override - { - return true; - } }; @@ -304,12 +298,6 @@ public: { assert_cast(to).getData().push_back(this->data(place).set.size()); } - - /// The function returns non-Nullable type even when wrapped with Null combinator. 
- bool returnDefaultWhenOnlyNull() const override - { - return true; - } }; } diff --git a/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp b/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp index a9a8ae0eaf3..9befc515de6 100644 --- a/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp +++ b/src/AggregateFunctions/AggregateFunctionUniqUpTo.cpp @@ -85,7 +85,7 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory) { - factory.registerFunction("uniqUpTo", createAggregateFunctionUniqUpTo); + factory.registerFunction("uniqUpTo", {createAggregateFunctionUniqUpTo, {true}}); } } diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index 439a5e07c2e..5f4291dd21d 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -166,17 +166,12 @@ public: * nested_function is a smart pointer to this aggregate function itself. * arguments and params are for nested_function. */ - virtual AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const + virtual AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const { return nullptr; } - /** When the function is wrapped with Null combinator, - * should we return Nullable type with NULL when no values were aggregated - * or we should return non-Nullable type with default value (example: count, countDistinct). - */ - virtual bool returnDefaultWhenOnlyNull() const { return false; } - const DataTypes & getArgumentTypes() const { return argument_types; } const Array & getParameters() const { return parameters; } @@ -286,4 +281,15 @@ public: }; +/// Properties of aggregate function that are independent of argument types and parameters. 
+struct AggregateFunctionProperties +{ + /** When the function is wrapped with Null combinator, + * should we return Nullable type with NULL when no values were aggregated + * or we should return non-Nullable type with default value (example: count, countDistinct). + */ + bool returns_default_when_only_null = false; +}; + + } diff --git a/src/AggregateFunctions/IAggregateFunctionCombinator.h b/src/AggregateFunctions/IAggregateFunctionCombinator.h index 03e2766dc2c..89c313567a3 100644 --- a/src/AggregateFunctions/IAggregateFunctionCombinator.h +++ b/src/AggregateFunctions/IAggregateFunctionCombinator.h @@ -59,6 +59,7 @@ public: */ virtual AggregateFunctionPtr transformAggregateFunction( const AggregateFunctionPtr & nested_function, + const AggregateFunctionProperties & properties, const DataTypes & arguments, const Array & params) const = 0; diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fe223373cf3..321bba1139a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -381,6 +381,6 @@ if (ENABLE_TESTS AND USE_GTEST) -Wno-gnu-zero-variadic-macro-arguments ) - target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils) + target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_aggregate_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils) add_check(unit_tests_dbms) endif () diff --git a/src/Common/IFactoryWithAliases.h b/src/Common/IFactoryWithAliases.h index 64703e51082..994b2c1a02c 100644 --- a/src/Common/IFactoryWithAliases.h +++ b/src/Common/IFactoryWithAliases.h @@ -16,14 +16,14 @@ namespace ErrorCodes } /** If stored objects may have several names (aliases) - * this interface may be helpful - * template parameter is available as Creator - */ -template -class IFactoryWithAliases : public IHints<2, IFactoryWithAliases> + * this interface may be helpful + * template parameter is available 
as Value + */ +template +class IFactoryWithAliases : public IHints<2, IFactoryWithAliases> { protected: - using Creator = CreatorFunc; + using Value = ValueType; String getAliasToOrName(const String & name) const { @@ -43,13 +43,13 @@ public: CaseInsensitive }; - /** Register additional name for creator - * real_name have to be already registered. - */ + /** Register additional name for value + * real_name have to be already registered. + */ void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive) { - const auto & creator_map = getCreatorMap(); - const auto & case_insensitive_creator_map = getCaseInsensitiveCreatorMap(); + const auto & creator_map = getMap(); + const auto & case_insensitive_creator_map = getCaseInsensitiveMap(); const String factory_name = getFactoryName(); String real_dict_name; @@ -80,7 +80,7 @@ public: { std::vector result; auto getter = [](const auto & pair) { return pair.first; }; - std::transform(getCreatorMap().begin(), getCreatorMap().end(), std::back_inserter(result), getter); + std::transform(getMap().begin(), getMap().end(), std::back_inserter(result), getter); std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter); return result; } @@ -88,7 +88,7 @@ public: bool isCaseInsensitive(const String & name) const { String name_lowercase = Poco::toLower(name); - return getCaseInsensitiveCreatorMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase); + return getCaseInsensitiveMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase); } const String & aliasTo(const String & name) const @@ -109,11 +109,11 @@ public: virtual ~IFactoryWithAliases() override {} private: - using InnerMap = std::unordered_map; // name -> creator + using InnerMap = std::unordered_map; // name -> creator using AliasMap = std::unordered_map; // alias -> original type - virtual const InnerMap & getCreatorMap() const = 0; - virtual 
const InnerMap & getCaseInsensitiveCreatorMap() const = 0; + virtual const InnerMap & getMap() const = 0; + virtual const InnerMap & getCaseInsensitiveMap() const = 0; virtual String getFactoryName() const = 0; /// Alias map to data_types from previous two maps diff --git a/src/DataStreams/tests/CMakeLists.txt b/src/DataStreams/tests/CMakeLists.txt index 14db417b71c..d01c79aee5f 100644 --- a/src/DataStreams/tests/CMakeLists.txt +++ b/src/DataStreams/tests/CMakeLists.txt @@ -1,4 +1,4 @@ set(SRCS) add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS}) -target_link_libraries (finish_sorting_stream PRIVATE dbms) +target_link_libraries (finish_sorting_stream PRIVATE clickhouse_aggregate_functions dbms) diff --git a/src/DataTypes/DataTypeAggregateFunction.cpp b/src/DataTypes/DataTypeAggregateFunction.cpp index 59811b1cd55..fdb17606f78 100644 --- a/src/DataTypes/DataTypeAggregateFunction.cpp +++ b/src/DataTypes/DataTypeAggregateFunction.cpp @@ -392,7 +392,8 @@ static DataTypePtr create(const ASTPtr & arguments) if (function_name.empty()) throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); - function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); + AggregateFunctionProperties properties; + function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties); return std::make_shared(function, argument_types, params_row); } diff --git a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp index 2ddce184cce..157192642ba 100644 --- a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp +++ b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp @@ -110,7 +110,8 @@ static std::pair create(const ASTPtr & argum if (function_name.empty()) throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); - function = 
AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); + AggregateFunctionProperties properties; + function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties); // check function if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) diff --git a/src/DataTypes/DataTypeFactory.cpp b/src/DataTypes/DataTypeFactory.cpp index 880f25d009d..69dbed10ccc 100644 --- a/src/DataTypes/DataTypeFactory.cpp +++ b/src/DataTypes/DataTypeFactory.cpp @@ -80,7 +80,7 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr } -void DataTypeFactory::registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness) +void DataTypeFactory::registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness) { if (creator == nullptr) throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided " @@ -136,7 +136,7 @@ void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCre }, case_sensitiveness); } -const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const +const DataTypeFactory::Value & DataTypeFactory::findCreatorByName(const String & family_name) const { { DataTypesDictionary::const_iterator it = data_types.find(family_name); diff --git a/src/DataTypes/DataTypeFactory.h b/src/DataTypes/DataTypeFactory.h index 6bf09d31727..67b72945acc 100644 --- a/src/DataTypes/DataTypeFactory.h +++ b/src/DataTypes/DataTypeFactory.h @@ -23,7 +23,7 @@ class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAli { private: using SimpleCreator = std::function; - using DataTypesDictionary = std::unordered_map; + using DataTypesDictionary = std::unordered_map; using CreatorWithCustom = std::function(const ASTPtr & parameters)>; using SimpleCreatorWithCustom = 
std::function()>; @@ -35,7 +35,7 @@ public: DataTypePtr get(const ASTPtr & ast) const; /// Register a type family by its name. - void registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + void registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive); /// Register a simple data type, that have no parameters. void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); @@ -47,7 +47,7 @@ public: void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive); private: - const Creator& findCreatorByName(const String & family_name) const; + const Value & findCreatorByName(const String & family_name) const; private: DataTypesDictionary data_types; @@ -57,9 +57,9 @@ private: DataTypeFactory(); - const DataTypesDictionary & getCreatorMap() const override { return data_types; } + const DataTypesDictionary & getMap() const override { return data_types; } - const DataTypesDictionary & getCaseInsensitiveCreatorMap() const override { return case_insensitive_data_types; } + const DataTypesDictionary & getCaseInsensitiveMap() const override { return case_insensitive_data_types; } String getFactoryName() const override { return "DataTypeFactory"; } }; diff --git a/src/Formats/tests/CMakeLists.txt b/src/Formats/tests/CMakeLists.txt index 187700dff72..e1cb7604fab 100644 --- a/src/Formats/tests/CMakeLists.txt +++ b/src/Formats/tests/CMakeLists.txt @@ -1,4 +1,4 @@ set(SRCS ) add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS}) -target_link_libraries (tab_separated_streams PRIVATE dbms) +target_link_libraries (tab_separated_streams PRIVATE clickhouse_aggregate_functions dbms) diff --git a/src/Functions/FunctionFactory.cpp b/src/Functions/FunctionFactory.cpp index 63f12188771..fbc8e11a9c9 100644 --- 
a/src/Functions/FunctionFactory.cpp +++ b/src/Functions/FunctionFactory.cpp @@ -20,7 +20,7 @@ namespace ErrorCodes void FunctionFactory::registerFunction(const std::string & name, - Creator creator, + Value creator, CaseSensitiveness case_sensitiveness) { if (!functions.emplace(name, creator).second) diff --git a/src/Functions/FunctionFactory.h b/src/Functions/FunctionFactory.h index ccaf2044693..7990e78daf8 100644 --- a/src/Functions/FunctionFactory.h +++ b/src/Functions/FunctionFactory.h @@ -53,7 +53,7 @@ public: FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const; private: - using Functions = std::unordered_map; + using Functions = std::unordered_map; Functions functions; Functions case_insensitive_functions; @@ -64,9 +64,9 @@ private: return std::make_unique(Function::create(context)); } - const Functions & getCreatorMap() const override { return functions; } + const Functions & getMap() const override { return functions; } - const Functions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_functions; } + const Functions & getCaseInsensitiveMap() const override { return case_insensitive_functions; } String getFactoryName() const override { return "FunctionFactory"; } @@ -74,7 +74,7 @@ private: /// No locking, you must register all functions before usage of get. 
void registerFunction( const std::string & name, - Creator creator, + Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive); }; diff --git a/src/Functions/FunctionsBitmap.h b/src/Functions/FunctionsBitmap.h index bf84bfbe47e..868bf8095a4 100644 --- a/src/Functions/FunctionsBitmap.h +++ b/src/Functions/FunctionsBitmap.h @@ -113,8 +113,9 @@ public: auto nested_type = array_type->getNestedType(); DataTypes argument_types = {nested_type}; Array params_row; - AggregateFunctionPtr bitmap_function - = AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData::name(), argument_types, params_row); + AggregateFunctionProperties properties; + AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get( + AggregateFunctionGroupBitmapData::name(), argument_types, params_row, properties); return std::make_shared(bitmap_function, argument_types, params_row); } @@ -156,8 +157,9 @@ private: // output data Array params_row; - AggregateFunctionPtr bitmap_function - = AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData::name(), argument_types, params_row); + AggregateFunctionProperties properties; + AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get( + AggregateFunctionGroupBitmapData::name(), argument_types, params_row, properties); auto col_to = ColumnAggregateFunction::create(bitmap_function); col_to->reserve(offsets.size()); diff --git a/src/Functions/array/arrayReduce.cpp b/src/Functions/array/arrayReduce.cpp index 8d44acc82f5..2b37965260f 100644 --- a/src/Functions/array/arrayReduce.cpp +++ b/src/Functions/array/arrayReduce.cpp @@ -97,7 +97,8 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, aggregate_function_name, params_row, "function " + getName()); - aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, 
params_row); + AggregateFunctionProperties properties; + aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); } return aggregate_function->getReturnType(); diff --git a/src/Functions/array/arrayReduceInRanges.cpp b/src/Functions/array/arrayReduceInRanges.cpp index 2dd0cd56343..c3c65c4d9e5 100644 --- a/src/Functions/array/arrayReduceInRanges.cpp +++ b/src/Functions/array/arrayReduceInRanges.cpp @@ -115,7 +115,8 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, aggregate_function_name, params_row, "function " + getName()); - aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); + AggregateFunctionProperties properties; + aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties); } return std::make_shared(aggregate_function->getReturnType()); diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index ecfa011f1c8..4c2a8b3dcea 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -420,8 +420,9 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & action aggregate.argument_names[i] = name; } + AggregateFunctionProperties properties; aggregate.parameters = (node->parameters) ? 
getAggregateFunctionParametersArray(node->parameters) : Array(); - aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters); + aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties); aggregate_descriptions.push_back(aggregate); } diff --git a/src/Interpreters/tests/CMakeLists.txt b/src/Interpreters/tests/CMakeLists.txt index 324a38b1a17..4ab7da014e4 100644 --- a/src/Interpreters/tests/CMakeLists.txt +++ b/src/Interpreters/tests/CMakeLists.txt @@ -34,11 +34,11 @@ target_include_directories (two_level_hash_map SYSTEM BEFORE PRIVATE ${SPARSEHAS target_link_libraries (two_level_hash_map PRIVATE dbms) add_executable (in_join_subqueries_preprocessor in_join_subqueries_preprocessor.cpp) -target_link_libraries (in_join_subqueries_preprocessor PRIVATE dbms clickhouse_parsers) +target_link_libraries (in_join_subqueries_preprocessor PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers) add_check(in_join_subqueries_preprocessor) add_executable (users users.cpp) -target_link_libraries (users PRIVATE dbms clickhouse_common_config) +target_link_libraries (users PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_config) if (OS_LINUX) add_executable (internal_iotop internal_iotop.cpp) diff --git a/src/Interpreters/tests/hash_map.cpp b/src/Interpreters/tests/hash_map.cpp index 8ddbd3b5886..dc87fd9ddde 100644 --- a/src/Interpreters/tests/hash_map.cpp +++ b/src/Interpreters/tests/hash_map.cpp @@ -103,9 +103,10 @@ int main(int argc, char ** argv) std::vector data(n); Value value; - AggregateFunctionPtr func_count = factory.get("count", data_types_empty); - AggregateFunctionPtr func_avg = factory.get("avg", data_types_uint64); - AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64); + AggregateFunctionProperties properties; + AggregateFunctionPtr func_count = factory.get("count", data_types_empty, {}, properties); + AggregateFunctionPtr 
func_avg = factory.get("avg", data_types_uint64, {}, properties); + AggregateFunctionPtr func_uniq = factory.get("uniq", data_types_uint64, {}, properties); #define INIT \ { \ diff --git a/src/Processors/Merges/Algorithms/SummingSortedAlgorithm.cpp b/src/Processors/Merges/Algorithms/SummingSortedAlgorithm.cpp index 89154044ae5..8be4aac4067 100644 --- a/src/Processors/Merges/Algorithms/SummingSortedAlgorithm.cpp +++ b/src/Processors/Merges/Algorithms/SummingSortedAlgorithm.cpp @@ -47,7 +47,8 @@ struct SummingSortedAlgorithm::AggregateDescription void init(const char * function_name, const DataTypes & argument_types) { - function = AggregateFunctionFactory::instance().get(function_name, argument_types); + AggregateFunctionProperties properties; + function = AggregateFunctionFactory::instance().get(function_name, argument_types, {}, properties); add_function = function->getAddressOfAddFunction(); state.reset(function->sizeOfData(), function->alignOfData()); } diff --git a/src/Storages/MergeTree/registerStorageMergeTree.cpp b/src/Storages/MergeTree/registerStorageMergeTree.cpp index e08ea1739a5..13cce2b0536 100644 --- a/src/Storages/MergeTree/registerStorageMergeTree.cpp +++ b/src/Storages/MergeTree/registerStorageMergeTree.cpp @@ -117,8 +117,9 @@ static void appendGraphitePattern( aggregate_function_name, params_row, "GraphiteMergeTree storage initialization"); /// TODO Not only Float64 - pattern.function = AggregateFunctionFactory::instance().get(aggregate_function_name, {std::make_shared()}, - params_row); + AggregateFunctionProperties properties; + pattern.function = AggregateFunctionFactory::instance().get( + aggregate_function_name, {std::make_shared()}, params_row, properties); } else if (startsWith(key, "retention")) { diff --git a/src/TableFunctions/TableFunctionFactory.cpp b/src/TableFunctions/TableFunctionFactory.cpp index 1b34c1a1e6f..bc139edfb73 100644 --- a/src/TableFunctions/TableFunctionFactory.cpp +++ b/src/TableFunctions/TableFunctionFactory.cpp @@ 
-15,7 +15,7 @@ namespace ErrorCodes } -void TableFunctionFactory::registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness) +void TableFunctionFactory::registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness) { if (!table_functions.emplace(name, creator).second) throw Exception("TableFunctionFactory: the table function name '" + name + "' is not unique", diff --git a/src/TableFunctions/TableFunctionFactory.h b/src/TableFunctions/TableFunctionFactory.h index cd87fa9c7f0..6d0302a64ff 100644 --- a/src/TableFunctions/TableFunctionFactory.h +++ b/src/TableFunctions/TableFunctionFactory.h @@ -24,12 +24,11 @@ using TableFunctionCreator = std::function; class TableFunctionFactory final: private boost::noncopyable, public IFactoryWithAliases { public: - static TableFunctionFactory & instance(); /// Register a function by its name. /// No locking, you must register all functions before usage of get. - void registerFunction(const std::string & name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); + void registerFunction(const std::string & name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive); template void registerFunction(CaseSensitiveness case_sensitiveness = CaseSensitive) @@ -50,11 +49,11 @@ public: bool isTableFunctionName(const std::string & name) const; private: - using TableFunctions = std::unordered_map; + using TableFunctions = std::unordered_map; - const TableFunctions & getCreatorMap() const override { return table_functions; } + const TableFunctions & getMap() const override { return table_functions; } - const TableFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_table_functions; } + const TableFunctions & getCaseInsensitiveMap() const override { return case_insensitive_table_functions; } String getFactoryName() const override { return "TableFunctionFactory"; } diff --git 
a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference index f8b77704aa3..76b82419556 100644 --- a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference +++ b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference @@ -5,5 +5,5 @@ 5 5 0 -\N -\N +0 +0 diff --git a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql index 2d9b5ef54aa..9787ee2bd70 100644 --- a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql +++ b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql @@ -7,6 +7,5 @@ SELECT uniqExact(number >= 5 ? number : NULL) FROM numbers(10); SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10); SELECT count(NULL); --- These two returns NULL for now, but we want to change them to return 0. SELECT uniq(NULL); SELECT count(DISTINCT NULL); diff --git a/utils/convert-month-partitioned-parts/CMakeLists.txt b/utils/convert-month-partitioned-parts/CMakeLists.txt index 14853590c76..ea6429a0610 100644 --- a/utils/convert-month-partitioned-parts/CMakeLists.txt +++ b/utils/convert-month-partitioned-parts/CMakeLists.txt @@ -1,2 +1,2 @@ add_executable (convert-month-partitioned-parts main.cpp) -target_link_libraries(convert-month-partitioned-parts PRIVATE dbms clickhouse_parsers boost::program_options) +target_link_libraries(convert-month-partitioned-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_parsers boost::program_options) diff --git a/utils/zookeeper-adjust-block-numbers-to-parts/CMakeLists.txt b/utils/zookeeper-adjust-block-numbers-to-parts/CMakeLists.txt index 08907e1c5b9..882c510ea1c 100644 --- a/utils/zookeeper-adjust-block-numbers-to-parts/CMakeLists.txt +++ b/utils/zookeeper-adjust-block-numbers-to-parts/CMakeLists.txt @@ -1,3 +1,3 @@ add_executable 
(zookeeper-adjust-block-numbers-to-parts main.cpp ${SRCS}) target_compile_options(zookeeper-adjust-block-numbers-to-parts PRIVATE -Wno-format) -target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE dbms clickhouse_common_zookeeper boost::program_options) +target_link_libraries (zookeeper-adjust-block-numbers-to-parts PRIVATE clickhouse_aggregate_functions dbms clickhouse_common_zookeeper boost::program_options)