From e40ee1a173c86af5a7202d29bdde19dcc6c4d668 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Thu, 11 Jun 2020 06:45:12 +0300 Subject: [PATCH 1/2] Return non-Nullable results from COUNT(DISTINCT) --- .../AggregateFunctionCount.h | 6 +++ .../AggregateFunctionFactory.cpp | 10 ++-- .../AggregateFunctionNull.cpp | 47 +++++++++++++------ .../AggregateFunctionNull.h | 32 +++++++------ .../AggregateFunctionUniq.h | 12 +++++ .../AggregateFunctionWindowFunnel.h | 5 +- src/AggregateFunctions/IAggregateFunction.h | 6 +++ 7 files changed, 83 insertions(+), 35 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionCount.h b/src/AggregateFunctions/AggregateFunctionCount.h index 092ffc6b6cf..e54f014f7a4 100644 --- a/src/AggregateFunctions/AggregateFunctionCount.h +++ b/src/AggregateFunctions/AggregateFunctionCount.h @@ -67,6 +67,12 @@ public: { data(place).count = new_count; } + + /// The function returns non-Nullable type even when wrapped with Null combinator. + bool returnDefaultWhenOnlyNull() const override + { + return true; + } }; diff --git a/src/AggregateFunctions/AggregateFunctionFactory.cpp b/src/AggregateFunctions/AggregateFunctionFactory.cpp index aeb4fb6db96..3982c48700b 100644 --- a/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -63,14 +63,15 @@ AggregateFunctionPtr AggregateFunctionFactory::get( { auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); - /// If one of types is Nullable, we apply aggregate function combinator "Null". + /// If one of the types is Nullable, we apply aggregate function combinator "Null". if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), [](const auto & type) { return type->isNullable(); })) { AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null"); if (!combinator) - throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR); + throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", + ErrorCodes::LOGICAL_ERROR); DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); Array nested_parameters = combinator->transformParameters(parameters); @@ -132,9 +133,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( auto hints = this->getHints(name); if (!hints.empty()) - throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); + throw Exception(fmt::format("Unknown aggregate function {}. Maybe you meant: {}", name, toString(hints)), + ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); else - throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); + throw Exception(fmt::format("Unknown aggregate function {}", name), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } diff --git a/src/AggregateFunctions/AggregateFunctionNull.cpp b/src/AggregateFunctions/AggregateFunctionNull.cpp index 60712636562..77687f9f328 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -49,35 +49,52 @@ public: } if (!has_nullable_types) - throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - if (nested_function) - if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) - return adapter; - - /// Special case for 'count' function. It could be called with Nullable arguments - /// - that means - count number of calls, when all arguments are not NULL. - if (nested_function && nested_function->getName() == "count") - return std::make_shared(arguments[0], params); + throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); if (has_null_types) return std::make_shared(arguments, params); - bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable(); + assert(nested_function); + + if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) + return adapter; + + /// Special case for 'count' function. It could be called with Nullable arguments + /// - that means - count number of calls, when all arguments are not NULL. + if (nested_function->getName() == "count") + return std::make_shared(arguments[0], params); + + bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable(); + bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull(); if (arguments.size() == 1) { if (return_type_is_nullable) - return std::make_shared>(nested_function, arguments, params); + { + return std::make_shared>(nested_function, arguments, params); + } else - return std::make_shared>(nested_function, arguments, params); + { + if (serialize_flag) + return std::make_shared>(nested_function, arguments, params); + else + return std::make_shared>(nested_function, arguments, params); + } } else { if (return_type_is_nullable) - return std::make_shared>(nested_function, arguments, params); + { + return std::make_shared>(nested_function, arguments, params); + } else - return std::make_shared>(nested_function, arguments, params); + { + if (serialize_flag) + return std::make_shared>(nested_function, arguments, params); + else + return std::make_shared>(nested_function, arguments, params); + } } } }; diff --git a/src/AggregateFunctions/AggregateFunctionNull.h b/src/AggregateFunctions/AggregateFunctionNull.h index 55d610207f1..d6f0079232c 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.h +++ b/src/AggregateFunctions/AggregateFunctionNull.h @@ -28,7 +28,10 @@ namespace ErrorCodes /// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter. /// true - return NULL; false - return value from empty aggregation state of nested function. -template +/// When serialize_flag is set to true, the flag about presense of values is serialized +/// regardless to the "result_is_nullable" even if it's unneeded - for protocol compatibility. + +template class AggregateFunctionNullBase : public IAggregateFunctionHelper { protected: @@ -129,7 +132,7 @@ public: void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { bool flag = getFlag(place); - if constexpr (result_is_nullable) + if constexpr (serialize_flag) writeBinary(flag, buf); if (flag) nested_function->serialize(nestedPlace(place), buf); @@ -138,7 +141,7 @@ public: void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override { bool flag = 1; - if constexpr (result_is_nullable) + if constexpr (serialize_flag) readBinary(flag, buf); if (flag) { @@ -183,12 +186,15 @@ public: /** There are two cases: for single argument and variadic. * Code for single argument is much more efficient. */ -template -class AggregateFunctionNullUnary final : public AggregateFunctionNullBase> +template +class AggregateFunctionNullUnary final + : public AggregateFunctionNullBase> { public: AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) - : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params) + : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params) { } @@ -218,12 +224,15 @@ public: }; -template -class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase> +template +class AggregateFunctionNullVariadic final + : public AggregateFunctionNullBase> { public: AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) - : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params), + : AggregateFunctionNullBase>(std::move(nested_function_), arguments, params), number_of_arguments(arguments.size()) { if (number_of_arguments == 1) @@ -263,11 +272,6 @@ public: this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); } - bool allocatesMemoryInArena() const override - { - return this->nested_function->allocatesMemoryInArena(); - } - private: enum { MAX_ARGS = 8 }; size_t number_of_arguments = 0; diff --git a/src/AggregateFunctions/AggregateFunctionUniq.h b/src/AggregateFunctions/AggregateFunctionUniq.h index 334e809ebe7..1588611b8a2 100644 --- a/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/src/AggregateFunctions/AggregateFunctionUniq.h @@ -244,6 +244,12 @@ public: { assert_cast(to).getData().push_back(this->data(place).set.size()); } + + /// The function returns non-Nullable type even when wrapped with Null combinator. + bool returnDefaultWhenOnlyNull() const override + { + return true; + } }; @@ -298,6 +304,12 @@ public: { assert_cast(to).getData().push_back(this->data(place).set.size()); } + + /// The function returns non-Nullable type even when wrapped with Null combinator. + bool returnDefaultWhenOnlyNull() const override + { + return true; + } }; } diff --git a/src/AggregateFunctions/AggregateFunctionWindowFunnel.h b/src/AggregateFunctions/AggregateFunctionWindowFunnel.h index 726656d1ca8..b5704203ade 100644 --- a/src/AggregateFunctions/AggregateFunctionWindowFunnel.h +++ b/src/AggregateFunctions/AggregateFunctionWindowFunnel.h @@ -240,9 +240,10 @@ public: return std::make_shared(); } - AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override + AggregateFunctionPtr getOwnNullAdapter( + const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override { - return std::make_shared>(nested_function, arguments, params); + return std::make_shared>(nested_function, arguments, params); } void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override diff --git a/src/AggregateFunctions/IAggregateFunction.h b/src/AggregateFunctions/IAggregateFunction.h index 0087a41d437..439a5e07c2e 100644 --- a/src/AggregateFunctions/IAggregateFunction.h +++ b/src/AggregateFunctions/IAggregateFunction.h @@ -171,6 +171,12 @@ public: return nullptr; } + /** When the function is wrapped with Null combinator, + * should we return Nullable type with NULL when no values were aggregated + * or we should return non-Nullable type with default value (example: count, countDistinct). + */ + virtual bool returnDefaultWhenOnlyNull() const { return false; } + const DataTypes & getArgumentTypes() const { return argument_types; } const Array & getParameters() const { return parameters; } From 3958a032acd8e3da194704505715d9be22b33787 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 14 Jun 2020 08:15:29 +0300 Subject: [PATCH 2/2] Added a test --- src/AggregateFunctions/AggregateFunctionNull.cpp | 10 +++++----- ...1315_count_distinct_return_not_nullable.reference | 9 +++++++++ .../01315_count_distinct_return_not_nullable.sql | 12 ++++++++++++ 3 files changed, 26 insertions(+), 5 deletions(-) create mode 100644 tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference create mode 100644 tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql diff --git a/src/AggregateFunctions/AggregateFunctionNull.cpp b/src/AggregateFunctions/AggregateFunctionNull.cpp index 77687f9f328..993cb93c991 100644 --- a/src/AggregateFunctions/AggregateFunctionNull.cpp +++ b/src/AggregateFunctions/AggregateFunctionNull.cpp @@ -33,6 +33,11 @@ public: AggregateFunctionPtr transformAggregateFunction( const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override { + /// Special case for 'count' function. It could be called with Nullable arguments + /// - that means - count number of calls, when all arguments are not NULL. + if (nested_function && nested_function->getName() == "count") + return std::make_shared(arguments[0], params); + bool has_nullable_types = false; bool has_null_types = false; for (const auto & arg_type : arguments) @@ -60,11 +65,6 @@ public: if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) return adapter; - /// Special case for 'count' function. It could be called with Nullable arguments - /// - that means - count number of calls, when all arguments are not NULL. - if (nested_function->getName() == "count") - return std::make_shared(arguments[0], params); - bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable(); bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull(); diff --git a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference new file mode 100644 index 00000000000..f8b77704aa3 --- /dev/null +++ b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.reference @@ -0,0 +1,9 @@ +0 +0 +0 +5 +5 +5 +0 +\N +\N diff --git a/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql new file mode 100644 index 00000000000..2d9b5ef54aa --- /dev/null +++ b/tests/queries/0_stateless/01315_count_distinct_return_not_nullable.sql @@ -0,0 +1,12 @@ +SELECT uniq(number >= 10 ? number : NULL) FROM numbers(10); +SELECT uniqExact(number >= 10 ? number : NULL) FROM numbers(10); +SELECT count(DISTINCT number >= 10 ? number : NULL) FROM numbers(10); + +SELECT uniq(number >= 5 ? number : NULL) FROM numbers(10); +SELECT uniqExact(number >= 5 ? number : NULL) FROM numbers(10); +SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10); + +SELECT count(NULL); +-- These two returns NULL for now, but we want to change them to return 0. +SELECT uniq(NULL); +SELECT count(DISTINCT NULL);