mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 07:31:57 +00:00
Added support for aggregate functions of Nullable arguments in cases when they return non-Nullable result [#CLICKHOUSE-2].
This commit is contained in:
parent
aa1641937a
commit
e2cd0272a4
@ -5,12 +5,20 @@ namespace DB
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested)
|
||||
{
|
||||
return std::make_shared<AggregateFunctionNullUnary>(nested);
|
||||
const DataTypePtr & nested_return_type = nested->getReturnType();
|
||||
if (nested_return_type && !nested_return_type->canBeInsideNullable())
|
||||
return std::make_shared<AggregateFunctionNullUnary<false>>(nested);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionNullUnary<true>>(nested);
|
||||
}
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested)
|
||||
{
|
||||
return std::make_shared<AggregateFunctionNullVariadic>(nested);
|
||||
const DataTypePtr & nested_return_type = nested->getReturnType();
|
||||
if (nested_return_type && !nested_return_type->canBeInsideNullable())
|
||||
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested);
|
||||
else
|
||||
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -24,6 +24,10 @@ namespace ErrorCodes
|
||||
/// at least one nullable argument. It implements the logic according to which any
|
||||
/// row that contains at least one NULL is skipped.
|
||||
|
||||
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
|
||||
/// true - return NULL; false - return value from empty aggregation state of nested function.
|
||||
|
||||
template <bool result_is_nullable>
|
||||
class AggregateFunctionNullBase : public IAggregateFunction
|
||||
{
|
||||
protected:
|
||||
@ -36,27 +40,29 @@ protected:
|
||||
|
||||
static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept
|
||||
{
|
||||
return place + 1;
|
||||
return place + (result_is_nullable ? 1 : 0);
|
||||
}
|
||||
|
||||
static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept
|
||||
{
|
||||
return place + 1;
|
||||
return place + (result_is_nullable ? 1 : 0);
|
||||
}
|
||||
|
||||
static void initFlag(AggregateDataPtr place) noexcept
|
||||
{
|
||||
if (result_is_nullable)
|
||||
place[0] = 0;
|
||||
}
|
||||
|
||||
static void setFlag(AggregateDataPtr place) noexcept
|
||||
{
|
||||
if (result_is_nullable)
|
||||
place[0] = 1;
|
||||
}
|
||||
|
||||
static bool getFlag(ConstAggregateDataPtr place) noexcept
|
||||
{
|
||||
return place[0];
|
||||
return result_is_nullable ? place[0] : 1;
|
||||
}
|
||||
|
||||
public:
|
||||
@ -78,7 +84,9 @@ public:
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNullable>(nested_function->getReturnType());
|
||||
return result_is_nullable
|
||||
? std::make_shared<DataTypeNullable>(nested_function->getReturnType())
|
||||
: nested_function->getReturnType();
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
@ -109,7 +117,7 @@ public:
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
if (getFlag(rhs))
|
||||
if (result_is_nullable && getFlag(rhs))
|
||||
setFlag(place);
|
||||
|
||||
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
|
||||
@ -118,6 +126,7 @@ public:
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
bool flag = getFlag(place);
|
||||
if (result_is_nullable)
|
||||
writeBinary(flag, buf);
|
||||
if (flag)
|
||||
nested_function->serialize(nestedPlace(place), buf);
|
||||
@ -125,7 +134,8 @@ public:
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
|
||||
{
|
||||
bool flag;
|
||||
bool flag = 1;
|
||||
if (result_is_nullable)
|
||||
readBinary(flag, buf);
|
||||
if (flag)
|
||||
{
|
||||
@ -135,6 +145,8 @@ public:
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
if (result_is_nullable)
|
||||
{
|
||||
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
|
||||
if (getFlag(place))
|
||||
@ -147,6 +159,11 @@ public:
|
||||
to_concrete.insertDefault();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
nested_function->insertResultInto(nestedPlace(place), to);
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
@ -165,10 +182,11 @@ public:
|
||||
/** There are two cases: for single argument and variadic.
|
||||
* Code for single argument is much more efficient.
|
||||
*/
|
||||
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase
|
||||
template <bool result_is_nullable>
|
||||
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionNullBase::AggregateFunctionNullBase;
|
||||
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
|
||||
|
||||
void setArguments(const DataTypes & arguments) override
|
||||
{
|
||||
@ -178,7 +196,7 @@ public:
|
||||
if (!arguments.front()->isNullable())
|
||||
throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
|
||||
this->nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
@ -186,9 +204,9 @@ public:
|
||||
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
|
||||
if (!column->isNullAt(row_num))
|
||||
{
|
||||
setFlag(place);
|
||||
this->setFlag(place);
|
||||
const IColumn * nested_column = column->getNestedColumn().get();
|
||||
nested_function->add(nestedPlace(place), &nested_column, row_num, arena);
|
||||
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,17 +216,18 @@ public:
|
||||
return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena);
|
||||
}
|
||||
|
||||
AddFunc getAddressOfAddFunction() const override
|
||||
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
|
||||
{
|
||||
return &addFree;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase
|
||||
template <bool result_is_nullable>
|
||||
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable>
|
||||
{
|
||||
public:
|
||||
using AggregateFunctionNullBase::AggregateFunctionNullBase;
|
||||
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
|
||||
|
||||
void setArguments(const DataTypes & arguments) override
|
||||
{
|
||||
@ -238,7 +257,7 @@ public:
|
||||
nested_args[i] = arguments[i];
|
||||
}
|
||||
|
||||
nested_function->setArguments(nested_args);
|
||||
this->nested_function->setArguments(nested_args);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
@ -263,13 +282,13 @@ public:
|
||||
nested_columns[i] = columns[i];
|
||||
}
|
||||
|
||||
setFlag(place);
|
||||
nested_function->add(nestedPlace(place), nested_columns, row_num, arena);
|
||||
this->setFlag(place);
|
||||
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return nested_function->allocatesMemoryInArena();
|
||||
return this->nested_function->allocatesMemoryInArena();
|
||||
}
|
||||
|
||||
static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
|
||||
@ -278,7 +297,7 @@ public:
|
||||
return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena);
|
||||
}
|
||||
|
||||
AddFunc getAddressOfAddFunction() const override
|
||||
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
|
||||
{
|
||||
return &addFree;
|
||||
}
|
||||
|
@ -0,0 +1,2 @@
|
||||
1 [1,2] Array(UInt8) 1.5 Nullable(Float64)
|
||||
2 [] Array(UInt8) \N Nullable(Float64)
|
@ -0,0 +1 @@
|
||||
SELECT k, groupArray(x) AS res1, toTypeName(res1), avg(x) AS res2, toTypeName(res2) FROM (SELECT 1 AS k, arrayJoin([1, NULL, 2]) AS x UNION ALL SELECT 2 AS k, CAST(arrayJoin([NULL, NULL]) AS Nullable(UInt8)) AS x) GROUP BY k ORDER BY k;
|
Loading…
Reference in New Issue
Block a user