Fix *If combinator with Nullable types

sumIf(Nullable()) and similar unary functions (unary w/o If combinator)
was working incorrectly, since it returns "sum" from the getName()
helper, and so distributed query processing fails.

The problem is in the optimization in
AggregateFunctionIfNullUnary::add() for the unary functions. It pass
only one column to write result to, instead of all passed arguments +
result columns.
While AggregateFunctionIf::add() assumes that it accepts arguments  +
result columns, and use last column as a result.

Introduced-in: #16610
Fixes: #18210
This commit is contained in:
Azat Khuzhin 2021-01-07 01:55:01 +03:00
parent 7907292bd7
commit fdcfacda60
3 changed files with 36 additions and 6 deletions

View File

@ -53,17 +53,35 @@ class AggregateFunctionIfNullUnary final
private:
size_t num_arguments;
/// The name of the nested function, including combinators (i.e. *If)
///
/// getName() from the nested_function cannot be used because in case of *If combinator
/// with Nullable argument nested_function will point to the function w/o combinator.
/// (I.e. sumIf(Nullable, 1) -> sum()), and distributed query processing will fail.
///
/// And nested_function cannot point to the function with *If since
/// due to optimization in the add() which pass only one column with the result,
/// and so AggregateFunctionIf::add() cannot be called this way
/// (it write to the last argument -- num_arguments-1).
///
/// And to avoid extra level of indirection, the name of function is cached:
///
/// AggregateFunctionIfNullUnary::add -> [ AggregateFunctionIf::add -> ] AggregateFunctionSum::add
String name;
using Base = AggregateFunctionNullBase<result_is_nullable, serialize_flag,
AggregateFunctionIfNullUnary<result_is_nullable, serialize_flag>>;
public:
String getName() const override
{
return Base::getName();
return name;
}
AggregateFunctionIfNullUnary(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: Base(std::move(nested_function_), arguments, params), num_arguments(arguments.size())
AggregateFunctionIfNullUnary(const String & name_, AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: Base(std::move(nested_function_), arguments, params)
, num_arguments(arguments.size())
, name(name_)
{
if (num_arguments == 0)
throw Exception("Aggregate function " + getName() + " require at least one argument",
@ -174,14 +192,14 @@ AggregateFunctionPtr AggregateFunctionIf::getOwnNullAdapter(
{
if (return_type_is_nullable)
{
return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(nested_func, arguments, params);
return std::make_shared<AggregateFunctionIfNullUnary<true, true>>(nested_function->getName(), nested_func, arguments, params);
}
else
{
if (serialize_flag)
return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(nested_func, arguments, params);
return std::make_shared<AggregateFunctionIfNullUnary<false, true>>(nested_function->getName(), nested_func, arguments, params);
else
return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(nested_func, arguments, params);
return std::make_shared<AggregateFunctionIfNullUnary<false, false>>(nested_function->getName(), nested_func, arguments, params);
}
}
else

View File

@ -0,0 +1,5 @@
\N
\N
\N
0
90

View File

@ -0,0 +1,7 @@
SELECT sumIf(dummy, dummy) FROM remote('127.0.0.{1,2}', view(SELECT cast(Null AS Nullable(UInt8)) AS dummy FROM system.one));
SELECT sumIf(dummy, 1) FROM remote('127.0.0.{1,2}', view(SELECT cast(Null AS Nullable(UInt8)) AS dummy FROM system.one));
-- Before #16610 it returns 0 while with this patch it will return NULL
SELECT sumIf(dummy, dummy) FROM remote('127.0.0.{1,2}', view(SELECT cast(dummy AS Nullable(UInt8)) AS dummy FROM system.one));
SELECT sumIf(dummy, 1) FROM remote('127.0.0.{1,2}', view(SELECT cast(dummy AS Nullable(UInt8)) AS dummy FROM system.one));
SELECT sumIf(n, 1) FROM remote('127.0.0.{1,2}', view(SELECT cast(* AS Nullable(UInt8)) AS n FROM system.numbers limit 10))