Added support for aggregate functions of Nullable arguments in cases when they return non-Nullable result [#CLICKHOUSE-2].

This commit is contained in:
Alexey Milovidov 2017-12-08 12:07:52 +03:00
parent aa1641937a
commit e2cd0272a4
4 changed files with 60 additions and 30 deletions

View File

@ -5,12 +5,20 @@ namespace DB
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested)
{
return std::make_shared<AggregateFunctionNullUnary>(nested);
const DataTypePtr & nested_return_type = nested->getReturnType();
if (nested_return_type && !nested_return_type->canBeInsideNullable())
return std::make_shared<AggregateFunctionNullUnary<false>>(nested);
else
return std::make_shared<AggregateFunctionNullUnary<true>>(nested);
}
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested)
{
return std::make_shared<AggregateFunctionNullVariadic>(nested);
const DataTypePtr & nested_return_type = nested->getReturnType();
if (nested_return_type && !nested_return_type->canBeInsideNullable())
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested);
else
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested);
}
}

View File

@ -24,6 +24,10 @@ namespace ErrorCodes
/// at least one nullable argument. It implements the logic according to which any
/// row that contains at least one NULL is skipped.
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
/// true - return NULL; false - return value from empty aggregation state of nested function.
template <bool result_is_nullable>
class AggregateFunctionNullBase : public IAggregateFunction
{
protected:
@ -36,27 +40,29 @@ protected:
static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept
{
return place + 1;
return place + (result_is_nullable ? 1 : 0);
}
static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept
{
return place + 1;
return place + (result_is_nullable ? 1 : 0);
}
static void initFlag(AggregateDataPtr place) noexcept
{
place[0] = 0;
if (result_is_nullable)
place[0] = 0;
}
static void setFlag(AggregateDataPtr place) noexcept
{
place[0] = 1;
if (result_is_nullable)
place[0] = 1;
}
static bool getFlag(ConstAggregateDataPtr place) noexcept
{
return place[0];
return result_is_nullable ? place[0] : 1;
}
public:
@ -78,7 +84,9 @@ public:
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeNullable>(nested_function->getReturnType());
return result_is_nullable
? std::make_shared<DataTypeNullable>(nested_function->getReturnType())
: nested_function->getReturnType();
}
void create(AggregateDataPtr place) const override
@ -109,7 +117,7 @@ public:
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
if (getFlag(rhs))
if (result_is_nullable && getFlag(rhs))
setFlag(place);
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
@ -118,15 +126,17 @@ public:
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
bool flag = getFlag(place);
writeBinary(flag, buf);
if (result_is_nullable)
writeBinary(flag, buf);
if (flag)
nested_function->serialize(nestedPlace(place), buf);
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
bool flag;
readBinary(flag, buf);
bool flag = 1;
if (result_is_nullable)
readBinary(flag, buf);
if (flag)
{
setFlag(place);
@ -136,15 +146,22 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
if (getFlag(place))
if (result_is_nullable)
{
nested_function->insertResultInto(nestedPlace(place), *to_concrete.getNestedColumn());
to_concrete.getNullMap().push_back(0);
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
if (getFlag(place))
{
nested_function->insertResultInto(nestedPlace(place), *to_concrete.getNestedColumn());
to_concrete.getNullMap().push_back(0);
}
else
{
to_concrete.insertDefault();
}
}
else
{
to_concrete.insertDefault();
nested_function->insertResultInto(nestedPlace(place), to);
}
}
@ -165,10 +182,11 @@ public:
/** There are two cases: for single argument and variadic.
* Code for single argument is much more efficient.
*/
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase
template <bool result_is_nullable>
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable>
{
public:
using AggregateFunctionNullBase::AggregateFunctionNullBase;
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
void setArguments(const DataTypes & arguments) override
{
@ -178,7 +196,7 @@ public:
if (!arguments.front()->isNullable())
throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR);
nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
this->nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
@ -186,9 +204,9 @@ public:
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
{
setFlag(place);
this->setFlag(place);
const IColumn * nested_column = column->getNestedColumn().get();
nested_function->add(nestedPlace(place), &nested_column, row_num, arena);
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
}
}
@ -198,17 +216,18 @@ public:
return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena);
}
AddFunc getAddressOfAddFunction() const override
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
{
return &addFree;
}
};
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase
template <bool result_is_nullable>
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable>
{
public:
using AggregateFunctionNullBase::AggregateFunctionNullBase;
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
void setArguments(const DataTypes & arguments) override
{
@ -238,7 +257,7 @@ public:
nested_args[i] = arguments[i];
}
nested_function->setArguments(nested_args);
this->nested_function->setArguments(nested_args);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
@ -263,13 +282,13 @@ public:
nested_columns[i] = columns[i];
}
setFlag(place);
nested_function->add(nestedPlace(place), nested_columns, row_num, arena);
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
}
bool allocatesMemoryInArena() const override
{
return nested_function->allocatesMemoryInArena();
return this->nested_function->allocatesMemoryInArena();
}
static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
@ -278,7 +297,7 @@ public:
return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena);
}
AddFunc getAddressOfAddFunction() const override
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
{
return &addFree;
}

View File

@ -0,0 +1,2 @@
1 [1,2] Array(UInt8) 1.5 Nullable(Float64)
2 [] Array(UInt8) \N Nullable(Float64)

View File

@ -0,0 +1 @@
SELECT k, groupArray(x) AS res1, toTypeName(res1), avg(x) AS res2, toTypeName(res2) FROM (SELECT 1 AS k, arrayJoin([1, NULL, 2]) AS x UNION ALL SELECT 2 AS k, CAST(arrayJoin([NULL, NULL]) AS Nullable(UInt8)) AS x) GROUP BY k ORDER BY k;