Add aggregate function combinators: -OrNull & -OrDefault (#7331)

This commit is contained in:
hcz 2019-10-18 17:25:39 +08:00 committed by Alexander Kuzmenkov
parent 71cbe878fc
commit 502672c973
7 changed files with 389 additions and 19 deletions

View File

@ -0,0 +1,39 @@
#include <AggregateFunctions/AggregateFunctionOrFill.h>
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
namespace DB
{
template <bool UseNull>
class AggregateFunctionCombinatorOrFill final : public IAggregateFunctionCombinator
{
public:
String getName() const override
{
if constexpr (UseNull)
return "OrNull";
else
return "OrDefault";
}
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const DataTypes & arguments,
const Array & params) const override
{
return std::make_shared<AggregateFunctionOrFill<UseNull>>(
nested_function,
arguments,
params);
}
};
void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorOrFill<false>>());
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorOrFill<true>>());
}
}

View File

@ -0,0 +1,179 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnNullable.h>
#include <Common/typeid_cast.h>
#include <DataTypes/DataTypeNullable.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ARGUMENT_OUT_OF_BOUND;
}
/**
* -OrDefault and -OrNull combinators for aggregate functions.
* If there are no input values, return NULL or a default value, accordingly.
* Use a single additional byte of data after the nested function data:
* 0 means there was no input, 1 means there was some.
*/
template <bool UseNull>
class AggregateFunctionOrFill final : public IAggregateFunctionHelper<AggregateFunctionOrFill<UseNull>>
{
private:
AggregateFunctionPtr nested_function;
size_t size_of_data;
DataTypePtr inner_type;
bool inner_nullable;
public:
AggregateFunctionOrFill(AggregateFunctionPtr nested_function_, const DataTypes & arguments, const Array & params)
: IAggregateFunctionHelper<AggregateFunctionOrFill>{arguments, params}
, nested_function{nested_function_}
, size_of_data {nested_function->sizeOfData()}
, inner_type {nested_function->getReturnType()}
, inner_nullable {inner_type->isNullable()}
{
// nothing
}
String getName() const override
{
if constexpr (UseNull)
return nested_function->getName() + "OrNull";
else
return nested_function->getName() + "OrDefault";
}
const char * getHeaderFilePath() const override
{
return __FILE__;
}
bool isState() const override
{
return nested_function->isState();
}
bool allocatesMemoryInArena() const override
{
return nested_function->allocatesMemoryInArena();
}
bool hasTrivialDestructor() const override
{
return nested_function->hasTrivialDestructor();
}
size_t sizeOfData() const override
{
return size_of_data + sizeof(char);
}
size_t alignOfData() const override
{
return nested_function->alignOfData();
}
void create(AggregateDataPtr place) const override
{
nested_function->create(place);
place[size_of_data] = 0;
}
void destroy(AggregateDataPtr place) const noexcept override
{
nested_function->destroy(place);
}
void add(
AggregateDataPtr place,
const IColumn ** columns,
size_t row_num,
Arena * arena) const override
{
nested_function->add(place, columns, row_num, arena);
place[size_of_data] = 1;
}
void merge(
AggregateDataPtr place,
ConstAggregateDataPtr rhs,
Arena * arena) const override
{
nested_function->merge(place, rhs, arena);
}
void serialize(
ConstAggregateDataPtr place,
WriteBuffer & buf) const override
{
nested_function->serialize(place, buf);
}
void deserialize(
AggregateDataPtr place,
ReadBuffer & buf,
Arena * arena) const override
{
nested_function->deserialize(place, buf, arena);
}
DataTypePtr getReturnType() const override
{
if constexpr (UseNull)
{
// -OrNull
if (inner_nullable)
return inner_type;
return std::make_shared<DataTypeNullable>(inner_type);
}
else
{
// -OrDefault
return inner_type;
}
}
void insertResultInto(
ConstAggregateDataPtr place,
IColumn & to) const override
{
if (place[size_of_data])
{
if constexpr (UseNull)
{
// -OrNull
if (inner_nullable)
nested_function->insertResultInto(place, to);
else
{
ColumnNullable & col = typeid_cast<ColumnNullable &>(to);
col.getNullMapColumn().insertDefault();
nested_function->insertResultInto(place, col.getNestedColumn());
}
}
else
{
// -OrDefault
nested_function->insertResultInto(place, to);
}
}
else
to.insertDefault();
}
};
}

View File

@ -29,8 +29,8 @@ private:
size_t step; size_t step;
size_t total; size_t total;
size_t aod; size_t align_of_data;
size_t sod; size_t size_of_data;
public: public:
AggregateFunctionResample( AggregateFunctionResample(
@ -47,8 +47,8 @@ public:
, end{end_} , end{end_}
, step{step_} , step{step_}
, total{0} , total{0}
, aod{nested_function->alignOfData()} , align_of_data{nested_function->alignOfData()}
, sod{(nested_function->sizeOfData() + aod - 1) / aod * aod} , size_of_data{(nested_function->sizeOfData() + align_of_data - 1) / align_of_data * align_of_data}
{ {
// notice: argument types has been checked before // notice: argument types has been checked before
if (step == 0) if (step == 0)
@ -94,24 +94,24 @@ public:
size_t sizeOfData() const override size_t sizeOfData() const override
{ {
return total * sod; return total * size_of_data;
} }
size_t alignOfData() const override size_t alignOfData() const override
{ {
return aod; return align_of_data;
} }
void create(AggregateDataPtr place) const override void create(AggregateDataPtr place) const override
{ {
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->create(place + i * sod); nested_function->create(place + i * size_of_data);
} }
void destroy(AggregateDataPtr place) const noexcept override void destroy(AggregateDataPtr place) const noexcept override
{ {
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->destroy(place + i * sod); nested_function->destroy(place + i * size_of_data);
} }
void add( void add(
@ -132,7 +132,7 @@ public:
size_t pos = (key - begin) / step; size_t pos = (key - begin) / step;
nested_function->add(place + pos * sod, columns, row_num, arena); nested_function->add(place + pos * size_of_data, columns, row_num, arena);
} }
void merge( void merge(
@ -141,7 +141,7 @@ public:
Arena * arena) const override Arena * arena) const override
{ {
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->merge(place + i * sod, rhs + i * sod, arena); nested_function->merge(place + i * size_of_data, rhs + i * size_of_data, arena);
} }
void serialize( void serialize(
@ -149,7 +149,7 @@ public:
WriteBuffer & buf) const override WriteBuffer & buf) const override
{ {
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->serialize(place + i * sod, buf); nested_function->serialize(place + i * size_of_data, buf);
} }
void deserialize( void deserialize(
@ -158,7 +158,7 @@ public:
Arena * arena) const override Arena * arena) const override
{ {
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->deserialize(place + i * sod, buf, arena); nested_function->deserialize(place + i * size_of_data, buf, arena);
} }
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
@ -174,7 +174,7 @@ public:
auto & col_offsets = assert_cast<ColumnArray::ColumnOffsets &>(col.getOffsetsColumn()); auto & col_offsets = assert_cast<ColumnArray::ColumnOffsets &>(col.getOffsetsColumn());
for (size_t i = 0; i < total; ++i) for (size_t i = 0; i < total; ++i)
nested_function->insertResultInto(place + i * sod, col.getData()); nested_function->insertResultInto(place + i * size_of_data, col.getData());
col_offsets.getData().push_back(col.getData().size()); col_offsets.getData().push_back(col.getData().size());
} }

View File

@ -42,6 +42,7 @@ void registerAggregateFunctionCombinatorForEach(AggregateFunctionCombinatorFacto
void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorOrFill(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctions() void registerAggregateFunctions()
@ -88,6 +89,7 @@ void registerAggregateFunctions()
registerAggregateFunctionCombinatorState(factory); registerAggregateFunctionCombinatorState(factory);
registerAggregateFunctionCombinatorMerge(factory); registerAggregateFunctionCombinatorMerge(factory);
registerAggregateFunctionCombinatorNull(factory); registerAggregateFunctionCombinatorNull(factory);
registerAggregateFunctionCombinatorOrFill(factory);
registerAggregateFunctionCombinatorResample(factory); registerAggregateFunctionCombinatorResample(factory);
} }
} }

View File

@ -0,0 +1,48 @@
--- Int Empty ---
0
\N
0
\N
0
\N
0
\N
0
\N
0
\N
--- Int Non-empty ---
1
1
nan
nan
1
1
1
1
nan
nan
1
1
--- Other Types Empty ---
\N
\N
\N
0.00
\N
0
\N
0.00
\N
--- Other Types Non-empty ---
hello
hello
2011-04-05 14:19:19
2011-04-05 14:19:19
-123.45
-123.45
inf
inf
-123.45
-123.45

View File

@ -0,0 +1,61 @@
SELECT '--- Int Empty ---';
SELECT arrayReduce('avgOrDefault', arrayPopBack([1]));
SELECT arrayReduce('avgOrNull', arrayPopBack([1]));
SELECT arrayReduce('stddevSampOrDefault', arrayPopBack([1]));
SELECT arrayReduce('stddevSampOrNull', arrayPopBack([1]));
SELECT arrayReduce('maxOrDefault', arrayPopBack([1]));
SELECT arrayReduce('maxOrNull', arrayPopBack([1]));
SELECT avgOrDefaultIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT avgOrNullIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT stddevSampOrDefaultIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT stddevSampOrNullIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT maxOrDefaultIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT maxOrNullIf(x, x > 1) FROM (SELECT 1 AS x);
SELECT '--- Int Non-empty ---';
SELECT arrayReduce('avgOrDefault', [1]);
SELECT arrayReduce('avgOrNull', [1]);
SELECT arrayReduce('stddevSampOrDefault', [1]);
SELECT arrayReduce('stddevSampOrNull', [1]);
SELECT arrayReduce('maxOrDefault', [1]);
SELECT arrayReduce('maxOrNull', [1]);
SELECT avgOrDefaultIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT avgOrNullIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT stddevSampOrDefaultIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT stddevSampOrNullIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT maxOrDefaultIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT maxOrNullIf(x, x > 0) FROM (SELECT 1 AS x);
SELECT '--- Other Types Empty ---';
SELECT arrayReduce('maxOrDefault', arrayPopBack(['hello']));
SELECT arrayReduce('maxOrNull', arrayPopBack(['hello']));
SELECT arrayReduce('maxOrDefault', arrayPopBack(arrayPopBack([toDateTime('2011-04-05 14:19:19'), null])));
SELECT arrayReduce('maxOrNull', arrayPopBack(arrayPopBack([toDateTime('2011-04-05 14:19:19'), null])));
SELECT arrayReduce('avgOrDefault', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT arrayReduce('avgOrNull', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT arrayReduce('stddevSampOrDefault', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT arrayReduce('stddevSampOrNull', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT arrayReduce('maxOrDefault', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT arrayReduce('maxOrNull', arrayPopBack([toDecimal128(-123.45, 2)]));
SELECT '--- Other Types Non-empty ---';
SELECT arrayReduce('maxOrDefault', ['hello']);
SELECT arrayReduce('maxOrNull', ['hello']);
SELECT arrayReduce('maxOrDefault', [toDateTime('2011-04-05 14:19:19'), null]);
SELECT arrayReduce('maxOrNull', [toDateTime('2011-04-05 14:19:19'), null]);
SELECT arrayReduce('avgOrDefault', [toDecimal128(-123.45, 2)]);
SELECT arrayReduce('avgOrNull', [toDecimal128(-123.45, 2)]);
SELECT arrayReduce('stddevSampOrDefault', [toDecimal128(-123.45, 2)]);
SELECT arrayReduce('stddevSampOrNull', [toDecimal128(-123.45, 2)]);
SELECT arrayReduce('maxOrDefault', [toDecimal128(-123.45, 2)]);
SELECT arrayReduce('maxOrNull', [toDecimal128(-123.45, 2)]);

View File

@ -10,7 +10,7 @@ Examples: `sumIf(column, cond)`, `countIf(cond)`, `avgIf(x, cond)`, `quantilesTi
With conditional aggregate functions, you can calculate aggregates for several conditions at once, without using subqueries and `JOIN`s. For example, in Yandex.Metrica, conditional aggregate functions are used to implement the segment comparison functionality. With conditional aggregate functions, you can calculate aggregates for several conditions at once, without using subqueries and `JOIN`s. For example, in Yandex.Metrica, conditional aggregate functions are used to implement the segment comparison functionality.
## -Array ## -Array {#agg-functions-combinator-array}
The -Array suffix can be appended to any aggregate function. In this case, the aggregate function takes arguments of the 'Array(T)' type (arrays) instead of 'T' type arguments. If the aggregate function accepts multiple arguments, this must be arrays of equal lengths. When processing arrays, the aggregate function works like the original aggregate function across all array elements. The -Array suffix can be appended to any aggregate function. In this case, the aggregate function takes arguments of the 'Array(T)' type (arrays) instead of 'T' type arguments. If the aggregate function accepts multiple arguments, this must be arrays of equal lengths. When processing arrays, the aggregate function works like the original aggregate function across all array elements.
@ -18,9 +18,9 @@ Example 1: `sumArray(arr)` - Totals all the elements of all 'arr' arrays. In thi
Example 2: `uniqArray(arr)` Counts the number of unique elements in all 'arr' arrays. This could be done an easier way: `uniq(arrayJoin(arr))`, but it's not always possible to add 'arrayJoin' to a query. Example 2: `uniqArray(arr)` Counts the number of unique elements in all 'arr' arrays. This could be done an easier way: `uniq(arrayJoin(arr))`, but it's not always possible to add 'arrayJoin' to a query.
-If and -Array can be combined. However, 'Array' must come first, then 'If'. Examples: `uniqArrayIf(arr, cond)`, `quantilesTimingArrayIf(level1, level2)(arr, cond)`. Due to this order, the 'cond' argument can't be an array. -If and -Array can be combined. However, 'Array' must come first, then 'If'. Examples: `uniqArrayIf(arr, cond)`, `quantilesTimingArrayIf(level1, level2)(arr, cond)`. Due to this order, the 'cond' argument won't be an array.
## -State ## -State {#agg-functions-combinator-state}
If you apply this combinator, the aggregate function doesn't return the resulting value (such as the number of unique values for the [uniq](reference.md#agg_function-uniq) function), but an intermediate state of the aggregation (for `uniq`, this is the hash table for calculating the number of unique values). This is an `AggregateFunction(...)` that can be used for further processing or stored in a table to finish aggregating later. If you apply this combinator, the aggregate function doesn't return the resulting value (such as the number of unique values for the [uniq](reference.md#agg_function-uniq) function), but an intermediate state of the aggregation (for `uniq`, this is the hash table for calculating the number of unique values). This is an `AggregateFunction(...)` that can be used for further processing or stored in a table to finish aggregating later.
@ -40,10 +40,51 @@ If you apply this combinator, the aggregate function takes the intermediate aggr
Merges the intermediate aggregation states in the same way as the -Merge combinator. However, it doesn't return the resulting value, but an intermediate aggregation state, similar to the -State combinator. Merges the intermediate aggregation states in the same way as the -Merge combinator. However, it doesn't return the resulting value, but an intermediate aggregation state, similar to the -State combinator.
## -ForEach ## -ForEach {#agg-functions-combinator-foreach}
Converts an aggregate function for tables into an aggregate function for arrays that aggregates the corresponding array items and returns an array of results. For example, `sumForEach` for the arrays `[1, 2]`, `[3, 4, 5]`and`[6, 7]`returns the result `[10, 13, 5]` after adding together the corresponding array items. Converts an aggregate function for tables into an aggregate function for arrays that aggregates the corresponding array items and returns an array of results. For example, `sumForEach` for the arrays `[1, 2]`, `[3, 4, 5]`and`[6, 7]`returns the result `[10, 13, 5]` after adding together the corresponding array items.
## -OrDefault {#agg-functions-combinator-ordefault}
Fills the default value of the aggregate function's return type if there is nothing to aggregate.
```sql
SELECT avg(number), avgOrDefault(number) FROM numbers(0)
```
```text
┌─avg(number)─┬─avgOrDefault(number)─┐
│ nan │ 0 │
└─────────────┴──────────────────────┘
```
## -OrNull {#agg-functions-combinator-ornull}
Fills `null` if there is nothing to aggregate. The return column will be nullable.
```sql
SELECT avg(number), avgOrNull(number) FROM numbers(0)
```
```text
┌─avg(number)─┬─avgOrNull(number)─┐
│ nan │ ᴺᵁᴸᴸ │
└─────────────┴───────────────────┘
```
-OrDefault and -OrNull can be combined with other combinators. It is useful when the aggregate function does not accept the empty input.
```sql
SELECT avgOrNullIf(x, x > 10)
FROM
(
SELECT toDecimal32(1.23, 2) AS x
)
```
```text
┌─avgOrNullIf(x, greater(x, 10))─┐
│ ᴺᵁᴸᴸ │
└────────────────────────────────┘
```
## -Resample {#agg_functions-combinator-resample} ## -Resample {#agg_functions-combinator-resample}
Lets you divide data into groups, and then separately aggregates the data in those groups. Groups are created by splitting the values from one column into intervals. Lets you divide data into groups, and then separately aggregates the data in those groups. Groups are created by splitting the values from one column into intervals.
@ -85,7 +126,7 @@ Let's get the names of the people whose age lies in the intervals of `[30,60)` a
To aggregate names in an array, we use the [groupArray](reference.md#agg_function-grouparray) aggregate function. It takes one argument. In our case, it's the `name` column. The `groupArrayResample` function should use the `age` column to aggregate names by age. To define the required intervals, we pass the `30, 75, 30` arguments into the `groupArrayResample` function. To aggregate names in an array, we use the [groupArray](reference.md#agg_function-grouparray) aggregate function. It takes one argument. In our case, it's the `name` column. The `groupArrayResample` function should use the `age` column to aggregate names by age. To define the required intervals, we pass the `30, 75, 30` arguments into the `groupArrayResample` function.
```sql ```sql
SELECT groupArrayResample(30, 75, 30)(name, age) from people SELECT groupArrayResample(30, 75, 30)(name, age) FROM people
``` ```
```text ```text
┌─groupArrayResample(30, 75, 30)(name, age)─────┐ ┌─groupArrayResample(30, 75, 30)(name, age)─────┐