mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
Added support for aggregate functions of Nullable arguments in cases when they return non-Nullable result [#CLICKHOUSE-2].
This commit is contained in:
parent
aa1641937a
commit
e2cd0272a4
@ -5,12 +5,20 @@ namespace DB
|
|||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested)
|
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested)
|
||||||
{
|
{
|
||||||
return std::make_shared<AggregateFunctionNullUnary>(nested);
|
const DataTypePtr & nested_return_type = nested->getReturnType();
|
||||||
|
if (nested_return_type && !nested_return_type->canBeInsideNullable())
|
||||||
|
return std::make_shared<AggregateFunctionNullUnary<false>>(nested);
|
||||||
|
else
|
||||||
|
return std::make_shared<AggregateFunctionNullUnary<true>>(nested);
|
||||||
}
|
}
|
||||||
|
|
||||||
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested)
|
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested)
|
||||||
{
|
{
|
||||||
return std::make_shared<AggregateFunctionNullVariadic>(nested);
|
const DataTypePtr & nested_return_type = nested->getReturnType();
|
||||||
|
if (nested_return_type && !nested_return_type->canBeInsideNullable())
|
||||||
|
return std::make_shared<AggregateFunctionNullVariadic<false>>(nested);
|
||||||
|
else
|
||||||
|
return std::make_shared<AggregateFunctionNullVariadic<true>>(nested);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -24,6 +24,10 @@ namespace ErrorCodes
|
|||||||
/// at least one nullable argument. It implements the logic according to which any
|
/// at least one nullable argument. It implements the logic according to which any
|
||||||
/// row that contains at least one NULL is skipped.
|
/// row that contains at least one NULL is skipped.
|
||||||
|
|
||||||
|
/// If all rows had NULL, the behaviour is determined by "result_is_nullable" template parameter.
|
||||||
|
/// true - return NULL; false - return value from empty aggregation state of nested function.
|
||||||
|
|
||||||
|
template <bool result_is_nullable>
|
||||||
class AggregateFunctionNullBase : public IAggregateFunction
|
class AggregateFunctionNullBase : public IAggregateFunction
|
||||||
{
|
{
|
||||||
protected:
|
protected:
|
||||||
@ -36,27 +40,29 @@ protected:
|
|||||||
|
|
||||||
static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept
|
static AggregateDataPtr nestedPlace(AggregateDataPtr place) noexcept
|
||||||
{
|
{
|
||||||
return place + 1;
|
return place + (result_is_nullable ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept
|
static ConstAggregateDataPtr nestedPlace(ConstAggregateDataPtr place) noexcept
|
||||||
{
|
{
|
||||||
return place + 1;
|
return place + (result_is_nullable ? 1 : 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void initFlag(AggregateDataPtr place) noexcept
|
static void initFlag(AggregateDataPtr place) noexcept
|
||||||
{
|
{
|
||||||
|
if (result_is_nullable)
|
||||||
place[0] = 0;
|
place[0] = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void setFlag(AggregateDataPtr place) noexcept
|
static void setFlag(AggregateDataPtr place) noexcept
|
||||||
{
|
{
|
||||||
|
if (result_is_nullable)
|
||||||
place[0] = 1;
|
place[0] = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool getFlag(ConstAggregateDataPtr place) noexcept
|
static bool getFlag(ConstAggregateDataPtr place) noexcept
|
||||||
{
|
{
|
||||||
return place[0];
|
return result_is_nullable ? place[0] : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -78,7 +84,9 @@ public:
|
|||||||
|
|
||||||
DataTypePtr getReturnType() const override
|
DataTypePtr getReturnType() const override
|
||||||
{
|
{
|
||||||
return std::make_shared<DataTypeNullable>(nested_function->getReturnType());
|
return result_is_nullable
|
||||||
|
? std::make_shared<DataTypeNullable>(nested_function->getReturnType())
|
||||||
|
: nested_function->getReturnType();
|
||||||
}
|
}
|
||||||
|
|
||||||
void create(AggregateDataPtr place) const override
|
void create(AggregateDataPtr place) const override
|
||||||
@ -109,7 +117,7 @@ public:
|
|||||||
|
|
||||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||||
{
|
{
|
||||||
if (getFlag(rhs))
|
if (result_is_nullable && getFlag(rhs))
|
||||||
setFlag(place);
|
setFlag(place);
|
||||||
|
|
||||||
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
|
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
|
||||||
@ -118,6 +126,7 @@ public:
|
|||||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||||
{
|
{
|
||||||
bool flag = getFlag(place);
|
bool flag = getFlag(place);
|
||||||
|
if (result_is_nullable)
|
||||||
writeBinary(flag, buf);
|
writeBinary(flag, buf);
|
||||||
if (flag)
|
if (flag)
|
||||||
nested_function->serialize(nestedPlace(place), buf);
|
nested_function->serialize(nestedPlace(place), buf);
|
||||||
@ -125,7 +134,8 @@ public:
|
|||||||
|
|
||||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
|
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
|
||||||
{
|
{
|
||||||
bool flag;
|
bool flag = 1;
|
||||||
|
if (result_is_nullable)
|
||||||
readBinary(flag, buf);
|
readBinary(flag, buf);
|
||||||
if (flag)
|
if (flag)
|
||||||
{
|
{
|
||||||
@ -135,6 +145,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||||
|
{
|
||||||
|
if (result_is_nullable)
|
||||||
{
|
{
|
||||||
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
|
ColumnNullable & to_concrete = static_cast<ColumnNullable &>(to);
|
||||||
if (getFlag(place))
|
if (getFlag(place))
|
||||||
@ -147,6 +159,11 @@ public:
|
|||||||
to_concrete.insertDefault();
|
to_concrete.insertDefault();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
nested_function->insertResultInto(nestedPlace(place), to);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool allocatesMemoryInArena() const override
|
bool allocatesMemoryInArena() const override
|
||||||
{
|
{
|
||||||
@ -165,10 +182,11 @@ public:
|
|||||||
/** There are two cases: for single argument and variadic.
|
/** There are two cases: for single argument and variadic.
|
||||||
* Code for single argument is much more efficient.
|
* Code for single argument is much more efficient.
|
||||||
*/
|
*/
|
||||||
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase
|
template <bool result_is_nullable>
|
||||||
|
class AggregateFunctionNullUnary final : public AggregateFunctionNullBase<result_is_nullable>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using AggregateFunctionNullBase::AggregateFunctionNullBase;
|
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
|
||||||
|
|
||||||
void setArguments(const DataTypes & arguments) override
|
void setArguments(const DataTypes & arguments) override
|
||||||
{
|
{
|
||||||
@ -178,7 +196,7 @@ public:
|
|||||||
if (!arguments.front()->isNullable())
|
if (!arguments.front()->isNullable())
|
||||||
throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR);
|
throw Exception("Logical error: not nullable data type is passed to AggregateFunctionNullUnary", ErrorCodes::LOGICAL_ERROR);
|
||||||
|
|
||||||
nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
|
this->nested_function->setArguments({static_cast<const DataTypeNullable &>(*arguments.front()).getNestedType()});
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||||
@ -186,9 +204,9 @@ public:
|
|||||||
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
|
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
|
||||||
if (!column->isNullAt(row_num))
|
if (!column->isNullAt(row_num))
|
||||||
{
|
{
|
||||||
setFlag(place);
|
this->setFlag(place);
|
||||||
const IColumn * nested_column = column->getNestedColumn().get();
|
const IColumn * nested_column = column->getNestedColumn().get();
|
||||||
nested_function->add(nestedPlace(place), &nested_column, row_num, arena);
|
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -198,17 +216,18 @@ public:
|
|||||||
return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena);
|
return static_cast<const AggregateFunctionNullUnary &>(*that).add(place, columns, row_num, arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
AddFunc getAddressOfAddFunction() const override
|
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
|
||||||
{
|
{
|
||||||
return &addFree;
|
return &addFree;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase
|
template <bool result_is_nullable>
|
||||||
|
class AggregateFunctionNullVariadic final : public AggregateFunctionNullBase<result_is_nullable>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using AggregateFunctionNullBase::AggregateFunctionNullBase;
|
using AggregateFunctionNullBase<result_is_nullable>::AggregateFunctionNullBase;
|
||||||
|
|
||||||
void setArguments(const DataTypes & arguments) override
|
void setArguments(const DataTypes & arguments) override
|
||||||
{
|
{
|
||||||
@ -238,7 +257,7 @@ public:
|
|||||||
nested_args[i] = arguments[i];
|
nested_args[i] = arguments[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
nested_function->setArguments(nested_args);
|
this->nested_function->setArguments(nested_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||||
@ -263,13 +282,13 @@ public:
|
|||||||
nested_columns[i] = columns[i];
|
nested_columns[i] = columns[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
setFlag(place);
|
this->setFlag(place);
|
||||||
nested_function->add(nestedPlace(place), nested_columns, row_num, arena);
|
this->nested_function->add(this->nestedPlace(place), nested_columns, row_num, arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool allocatesMemoryInArena() const override
|
bool allocatesMemoryInArena() const override
|
||||||
{
|
{
|
||||||
return nested_function->allocatesMemoryInArena();
|
return this->nested_function->allocatesMemoryInArena();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
|
static void addFree(const IAggregateFunction * that, AggregateDataPtr place,
|
||||||
@ -278,7 +297,7 @@ public:
|
|||||||
return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena);
|
return static_cast<const AggregateFunctionNullVariadic &>(*that).add(place, columns, row_num, arena);
|
||||||
}
|
}
|
||||||
|
|
||||||
AddFunc getAddressOfAddFunction() const override
|
IAggregateFunction::AddFunc getAddressOfAddFunction() const override
|
||||||
{
|
{
|
||||||
return &addFree;
|
return &addFree;
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
1 [1,2] Array(UInt8) 1.5 Nullable(Float64)
|
||||||
|
2 [] Array(UInt8) \N Nullable(Float64)
|
@ -0,0 +1 @@
|
|||||||
|
SELECT k, groupArray(x) AS res1, toTypeName(res1), avg(x) AS res2, toTypeName(res2) FROM (SELECT 1 AS k, arrayJoin([1, NULL, 2]) AS x UNION ALL SELECT 2 AS k, CAST(arrayJoin([NULL, NULL]) AS Nullable(UInt8)) AS x) GROUP BY k ORDER BY k;
|
Loading…
Reference in New Issue
Block a user