Merge pull request #11593 from ClickHouse/return-not-nullable-from-count-distinct

Return non-Nullable results from COUNT(DISTINCT)
This commit is contained in:
alexey-milovidov 2020-06-14 20:39:19 +03:00 committed by GitHub
commit d990b98b90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 104 additions and 35 deletions

View File

@ -67,6 +67,12 @@ public:
{ {
data(place).count = new_count; data(place).count = new_count;
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };

View File

@ -63,14 +63,15 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
{ {
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
/// If one of types is Nullable, we apply aggregate function combinator "Null". /// If one of the types is Nullable, we apply aggregate function combinator "Null".
if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), if (std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->isNullable(); })) [](const auto & type) { return type->isNullable(); }))
{ {
AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null"); AggregateFunctionCombinatorPtr combinator = AggregateFunctionCombinatorFactory::instance().tryFindSuffix("Null");
if (!combinator) if (!combinator)
throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.",
ErrorCodes::LOGICAL_ERROR);
DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
Array nested_parameters = combinator->transformParameters(parameters); Array nested_parameters = combinator->transformParameters(parameters);
@ -132,9 +133,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
auto hints = this->getHints(name); auto hints = this->getHints(name);
if (!hints.empty()) if (!hints.empty())
throw Exception("Unknown aggregate function " + name + ". Maybe you meant: " + toString(hints), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); throw Exception(fmt::format("Unknown aggregate function {}. Maybe you meant: {}", name, toString(hints)),
ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
else else
throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); throw Exception(fmt::format("Unknown aggregate function {}", name), ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
} }

View File

@ -33,6 +33,11 @@ public:
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
{ {
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
bool has_nullable_types = false; bool has_nullable_types = false;
bool has_null_types = false; bool has_null_types = false;
for (const auto & arg_type : arguments) for (const auto & arg_type : arguments)
@ -49,35 +54,47 @@ public:
} }
if (!has_nullable_types) if (!has_nullable_types)
throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Aggregate function combinator 'Null' requires at least one argument to be Nullable",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (nested_function)
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter;
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
if (has_null_types) if (has_null_types)
return std::make_shared<AggregateFunctionNothing>(arguments, params); return std::make_shared<AggregateFunctionNothing>(arguments, params);
bool return_type_is_nullable = nested_function->getReturnType()->canBeInsideNullable(); assert(nested_function);
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter;
bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull();
if (arguments.size() == 1) if (arguments.size() == 1)
{ {
if (return_type_is_nullable) if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullUnary<true>>(nested_function, arguments, params); {
return std::make_shared<AggregateFunctionNullUnary<true, true>>(nested_function, arguments, params);
}
else else
return std::make_shared<AggregateFunctionNullUnary<false>>(nested_function, arguments, params); {
if (serialize_flag)
return std::make_shared<AggregateFunctionNullUnary<false, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullUnary<false, false>>(nested_function, arguments, params);
}
} }
else else
{ {
if (return_type_is_nullable) if (return_type_is_nullable)
return std::make_shared<AggregateFunctionNullVariadic<true, true>>(nested_function, arguments, params); {
return std::make_shared<AggregateFunctionNullVariadic<true, true, true>>(nested_function, arguments, params);
}
else else
return std::make_shared<AggregateFunctionNullVariadic<false, true>>(nested_function, arguments, params); {
if (serialize_flag)
return std::make_shared<AggregateFunctionNullVariadic<false, true, true>>(nested_function, arguments, params);
else
return std::make_shared<AggregateFunctionNullVariadic<false, true, false>>(nested_function, arguments, params);
}
} }
} }
}; };

View File

@ -28,7 +28,10 @@ namespace ErrorCodes
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter. /// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function. /// true - return NULL; false - return value from empty aggregation state of nested function.
template <bool result_is_nullable, typename Derived> /// When serialize_flag is set to true, the flag about presense of values is serialized
/// regardless to the "result_is_nullable" even if it's unneeded - for protocol compatibility.
template <bool result_is_nullable, bool serialize_flag, typename Derived>
class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived> class AggregateFunctionNullBase : public IAggregateFunctionHelper<Derived>
{ {
protected: protected:
@ -129,7 +132,7 @@ public:
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
bool flag = getFlag(place); bool flag = getFlag(place);
if constexpr (result_is_nullable) if constexpr (serialize_flag)
writeBinary(flag, buf); writeBinary(flag, buf);
if (flag) if (flag)
nested_function->serialize(nestedPlace(place), buf); nested_function->serialize(nestedPlace(place), buf);
@ -138,7 +141,7 @@ public:
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{ {
bool flag = 1; bool flag = 1;
if constexpr (result_is_nullable) if constexpr (serialize_flag)
readBinary(flag, buf); readBinary(flag, buf);
if (flag) if (flag)
{ {
@ -183,12 +186,15 @@ public:
/** There are two cases: for single argument and variadic. /** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient. * Code for single argument is much more efficient.
*/ */
template <bool result_is_nullable> template <bool result_is_nullable, bool serialize_flag>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>> class AggregateFunctionNullUnary final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>
{ {
public: public:
AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) AggregateFunctionNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullUnary<result_is_nullable>>(std::move(nested_function_), arguments, params) : AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullUnary<result_is_nullable, serialize_flag>>(std::move(nested_function_), arguments, params)
{ {
} }
@ -218,12 +224,15 @@ public:
}; };
template <bool result_is_nullable, bool null_is_skipped> template <bool result_is_nullable, bool serialize_flag, bool null_is_skipped>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable, null_is_skipped>> class AggregateFunctionNullVariadic final
: public AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>
{ {
public: public:
AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params) AggregateFunctionNullVariadic(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: AggregateFunctionNullBase<result_is_nullable, AggregateFunctionNullVariadic<result_is_nullable, null_is_skipped>>(std::move(nested_function_), arguments, params), : AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionNullVariadic<result_is_nullable, serialize_flag, null_is_skipped>>(std::move(nested_function_), arguments, params),
number_of_arguments(arguments.size()) number_of_arguments(arguments.size())
{ {
if (number_of_arguments == 1) if (number_of_arguments == 1)
@ -263,11 +272,6 @@ public:
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena); this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
} }
bool allocatesMemoryInArena() const override
{
return this->nested_function->allocatesMemoryInArena();
}
private: private:
enum { MAX_ARGS = 8 }; enum { MAX_ARGS = 8 };
size_t number_of_arguments = 0; size_t number_of_arguments = 0;

View File

@ -244,6 +244,12 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
@ -298,6 +304,12 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
} }

View File

@ -240,9 +240,10 @@ public:
return std::make_shared<DataTypeUInt8>(); return std::make_shared<DataTypeUInt8>();
} }
AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override
{ {
return std::make_shared<AggregateFunctionNullVariadic<false, false>>(nested_function, arguments, params); return std::make_shared<AggregateFunctionNullVariadic<false, false, false>>(nested_function, arguments, params);
} }
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override

View File

@ -171,6 +171,12 @@ public:
return nullptr; return nullptr;
} }
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
virtual bool returnDefaultWhenOnlyNull() const { return false; }
const DataTypes & getArgumentTypes() const { return argument_types; } const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; } const Array & getParameters() const { return parameters; }

View File

@ -0,0 +1,9 @@
0
0
0
5
5
5
0
\N
\N

View File

@ -0,0 +1,12 @@
SELECT uniq(number >= 10 ? number : NULL) FROM numbers(10);
SELECT uniqExact(number >= 10 ? number : NULL) FROM numbers(10);
SELECT count(DISTINCT number >= 10 ? number : NULL) FROM numbers(10);
SELECT uniq(number >= 5 ? number : NULL) FROM numbers(10);
SELECT uniqExact(number >= 5 ? number : NULL) FROM numbers(10);
SELECT count(DISTINCT number >= 5 ? number : NULL) FROM numbers(10);
SELECT count(NULL);
-- These two returns NULL for now, but we want to change them to return 0.
SELECT uniq(NULL);
SELECT count(DISTINCT NULL);