Added support for aggregate functions of Nullable arguments in cases when they return non-Nullable result [#CLICKHOUSE-2].

This commit is contained in:
Alexey Milovidov 2017-12-08 12:07:52 +03:00
parent aa1641937a
commit e2cd0272a4
4 changed files with 60 additions and 30 deletions

View File

@ -5,12 +5,20 @@ namespace DB
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested) AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested)
{ {
return std::make_shared<AggregateFunctionNullUnary>(nested); const DataTypePtr & nested_return_type = nested->getReturnType();
if (nested_return_type && !nested_return_type->canBeInsideNullable())
return std::make_shared<AggregateFunctionNullUnary<false>>(nested);
else
return std::make_shared<AggregateFunctionNullUnary<true>>(nested);
} }
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested) AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested)
{ {
return std::make_shared<AggregateFunctionNullVariadic>(nested); const DataTypePtr & nested_return_type = nested->getReturnType();
if (nested_return_type && !nested_return_type->canBeInsideNullable())
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested);
else
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested);
} }
} }

View File

@ -24,6 +24,10 @@ namespace ErrorCodes
/// at least one nullable argument. It implements the logic according to which any /// at least one nullable argument. It implements the logic according to which any
/// row that contains at least one NULL is skipped. /// row that contains at least one NULL is skipped.
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function.
template <bool result_is_nullable>
class AggregateFunctionNullBase : public IAggregateFunction class AggregateFunctionNullBase : public IAggregateFunction
{ {
protected: protected:
@ -36,27 +40,29 @@ protected:
static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept
{ {
return place + 1; return place + (result_is_nullable ? 1 : 0);
} }
static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept
{ {
return place + 1; return place + (result_is_nullable ? 1 : 0);
} }
static void initFlag(AggregateDataPtr place) noexcept static void initFlag(AggregateDataPtr place) noexcept
{ {
if (result_is_nullable)
place[0] = 0; place[0] = 0;
} }
static void setFlag(AggregateDataPtr place) noexcept static void setFlag(AggregateDataPtr place) noexcept
{ {
if (result_is_nullable)
place[0] = 1; place[0] = 1;
} }
static bool getFlag(ConstAggregateDataPtr place) noexcept static bool getFlag(ConstAggregateDataPtr place) noexcept
{ {
return place[0]; return result_is_nullable ? place[0] : 1;
} }
public: public:
@ -78,7 +84,9 @@ public:
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeNullable>(nested_function->getReturnType()); return result_is_nullable
? std::make_shared<DataTypeNullable>(nested_function->getReturnType())
: nested_function->getReturnType();
} }
void create(AggregateDataPtr place) const override void create(AggregateDataPtr place) const override
@ -109,7 +117,7 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{ {
if (getFlag(rhs)) if (result_is_nullable && getFlag(rhs))
setFlag(place); setFlag(place);
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena); nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
@ -118,6 +126,7 @@ public:
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{ {
bool flag = getFlag(place); bool flag = getFlag(place);
if (result_is_nullable)
writeBinary(flag, buf); writeBinary(flag, buf);
if (flag) if (flag)
nested_function->serialize(nestedPlace(place), buf); nested_function->serialize(nestedPlace(place), buf);
@ -125,7 +134,8 @@ public:
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{ {
bool flag; bool flag = 1;
if (result_is_nullable)
readBinary(flag, buf); readBinary(flag, buf);
if (flag) if (flag)
{ {
@ -135,6 +145,8 @@ public:
} }
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
if (result_is_nullable)
{ {
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to); ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
if (getFlag(place)) if (getFlag(place))
@ -147,6 +159,11 @@ public:
to_concrete.insertDefault(); to_concrete.insertDefault();
} }
} }
else
{
nested_function->insertResultInto(nestedPlace(place), to);
}
}
bool allocatesMemoryInArena() const override bool allocatesMemoryInArena() const override
{ {
@ -165,10 +182,11 @@ public:
/** There are two cases: for single argument and variadic. /** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient. * Code for single argument is much more efficient.
*/ */
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable>
{ {
public: public:
using AggregateFunctionNullBase::AggregateFunctionNullBase; using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
void setArguments(const DataTypes & arguments) override void setArguments(const DataTypes & arguments) override
{ {
@ -178,7 +196,7 @@ public:
if (!arguments.front()->isNullable()) if (!arguments.front()->isNullable())
throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR);
nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()}); this->nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
@ -186,9 +204,9 @@ public:
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]); const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num)) if (!column->isNullAt(row_num))
{ {
setFlag(place); this->setFlag(place);
const IColumn * nested_column = column->getNestedColumn().get(); const IColumn * nested_column = column->getNestedColumn().get();
nested_function->add(nestedPlace(place), &nested_column, row_num, arena); this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
} }
} }
@ -198,17 +216,18 @@ public:
return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena); return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena);
} }
AddFunc getAddressOfAddFunction() const override IAggregateFunction::AddFunc getAddressOfAddFunction() const override
{ {
return &addFree; return &addFree;
} }
}; };
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase template <bool result_is_nullable>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable>
{ {
public: public:
using AggregateFunctionNullBase::AggregateFunctionNullBase; using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
void setArguments(const DataTypes & arguments) override void setArguments(const DataTypes & arguments) override
{ {
@ -238,7 +257,7 @@ public:
nested_args[i] = arguments[i]; nested_args[i] = arguments[i];
} }
nested_function->setArguments(nested_args); this->nested_function->setArguments(nested_args);
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
@ -263,13 +282,13 @@ public:
nested_columns[i] = columns[i]; nested_columns[i] = columns[i];
} }
setFlag(place); this->setFlag(place);
nested_function->add(nestedPlace(place), nested_columns, row_num, arena); this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
} }
bool allocatesMemoryInArena() const override bool allocatesMemoryInArena() const override
{ {
return nested_function->allocatesMemoryInArena(); return this->nested_function->allocatesMemoryInArena();
} }
static void addFree(const IAggregateFunction * that, AggregateDataPtr place, static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
@ -278,7 +297,7 @@ public:
return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena); return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena);
} }
AddFunc getAddressOfAddFunction() const override IAggregateFunction::AddFunc getAddressOfAddFunction() const override
{ {
return &addFree; return &addFree;
} }

View File

@ -0,0 +1,2 @@
1 [1,2] Array(UInt8) 1.5 Nullable(Float64)
2 [] Array(UInt8) \N Nullable(Float64)

View File

@ -0,0 +1 @@
SELECT k, groupArray(x) AS res1, toTypeName(res1), avg(x) AS res2, toTypeName(res2) FROM (SELECT 1 AS k, arrayJoin([1, NULL, 2]) AS x UNION ALL SELECT 2 AS k, CAST(arrayJoin([NULL, NULL]) AS Nullable(UInt8)) AS x) GROUP BY k ORDER BY k;