Merge pull request #16853 from amosbird/ss

Add -SimpleState combinator
This commit is contained in:
Kruglov Pavel 2020-12-17 14:00:05 +03:00 committed by GitHub
commit d82c23d5cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 162 additions and 9 deletions

View File

@ -25,6 +25,10 @@ Example 2: `uniqArray(arr)` Counts the number of unique elements in all a
-If and -Array can be combined. However, Array must come first, then If. Examples: `uniqArrayIf(arr, cond)`, `quantilesTimingArrayIf(level1, level2)(arr, cond)`. Due to this order, the cond argument wont be an array.
## -SimpleState {#agg-functions-combinator-simplestate}
If you apply this combinator, the aggregate function returns the same value but with a different type. This is an `SimpleAggregateFunction(...)` that can be stored in a table to work with [AggregatingMergeTree](../../engines/table-engines/mergetree-family/aggregatingmergetree.md) table engines.
## -State {#agg-functions-combinator-state}
If you apply this combinator, the aggregate function doesnt return the resulting value (such as the number of unique values for the [uniq](../../sql-reference/aggregate-functions/reference/uniq.md#agg_function-uniq) function), but an intermediate state of the aggregation (for `uniq`, this is the hash table for calculating the number of unique values). This is an `AggregateFunction(...)` that can be used for further processing or stored in a table to finish aggregating later.

View File

@ -0,0 +1,32 @@
#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
#include <AggregateFunctions/AggregateFunctionSimpleState.h>
namespace DB
{
namespace
{
class AggregateFunctionCombinatorSimpleState final : public IAggregateFunctionCombinator
{
public:
String getName() const override { return "SimpleState"; }
DataTypes transformArguments(const DataTypes & arguments) const override { return arguments; }
AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{
return std::make_shared<AggregateFunctionSimpleState>(nested_function, arguments, params);
}
};
}
void registerAggregateFunctionCombinatorSimpleState(AggregateFunctionCombinatorFactory & factory)
{
factory.registerCombinator(std::make_shared<AggregateFunctionCombinatorSimpleState>());
}
}

View File

@ -0,0 +1,77 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <DataTypes/DataTypeCustomSimpleAggregateFunction.h>
#include <DataTypes/DataTypeFactory.h>
namespace DB
{
/** Not an aggregate function, but an adapter of aggregate functions.
* Aggregate functions with the `SimpleState` suffix is almost identical to the corresponding ones,
* except the return type becomes DataTypeCustomSimpleAggregateFunction.
*/
class AggregateFunctionSimpleState final : public IAggregateFunctionHelper<AggregateFunctionSimpleState>
{
private:
AggregateFunctionPtr nested_func;
DataTypes arguments;
Array params;
public:
AggregateFunctionSimpleState(AggregateFunctionPtr nested_, const DataTypes & arguments_, const Array & params_)
: IAggregateFunctionHelper<AggregateFunctionSimpleState>(arguments_, params_)
, nested_func(nested_)
, arguments(arguments_)
, params(params_)
{
}
String getName() const override { return nested_func->getName() + "SimpleState"; }
DataTypePtr getReturnType() const override
{
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(nested_func);
// Need to make a clone because it'll be customized.
auto storage_type = DataTypeFactory::instance().get(nested_func->getReturnType()->getName());
DataTypeCustomNamePtr custom_name
= std::make_unique<DataTypeCustomSimpleAggregateFunction>(nested_func, DataTypes{nested_func->getReturnType()}, params);
storage_type->setCustomization(std::make_unique<DataTypeCustomDesc>(std::move(custom_name), nullptr));
return storage_type;
}
void create(AggregateDataPtr place) const override { nested_func->create(place); }
void destroy(AggregateDataPtr place) const noexcept override { nested_func->destroy(place); }
bool hasTrivialDestructor() const override { return nested_func->hasTrivialDestructor(); }
size_t sizeOfData() const override { return nested_func->sizeOfData(); }
size_t alignOfData() const override { return nested_func->alignOfData(); }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
nested_func->add(place, columns, row_num, arena);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override { nested_func->merge(place, rhs, arena); }
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { nested_func->serialize(place, buf); }
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
nested_func->deserialize(place, buf, arena);
}
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override
{
nested_func->insertResultInto(place, to, arena);
}
bool allocatesMemoryInArena() const override { return nested_func->allocatesMemoryInArena(); }
AggregateFunctionPtr getNestedFunction() const { return nested_func; }
};
}

View File

@ -47,6 +47,7 @@ class AggregateFunctionCombinatorFactory;
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorForEach(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorSimpleState(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &);
@ -104,6 +105,7 @@ void registerAggregateFunctions()
registerAggregateFunctionCombinatorIf(factory);
registerAggregateFunctionCombinatorArray(factory);
registerAggregateFunctionCombinatorForEach(factory);
registerAggregateFunctionCombinatorSimpleState(factory);
registerAggregateFunctionCombinatorState(factory);
registerAggregateFunctionCombinatorMerge(factory);
registerAggregateFunctionCombinatorNull(factory);

View File

@ -41,6 +41,7 @@ SRCS(
AggregateFunctionRetention.cpp
AggregateFunctionSequenceMatch.cpp
AggregateFunctionSimpleLinearRegression.cpp
AggregateFunctionSimpleState.cpp
AggregateFunctionState.cpp
AggregateFunctionStatistics.cpp
AggregateFunctionStatisticsSimple.cpp

View File

@ -25,10 +25,19 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
static const std::vector<String> supported_functions{"any", "anyLast", "min",
"max", "sum", "sumWithOverflow", "groupBitAnd", "groupBitOr", "groupBitXor",
"sumMap", "minMap", "maxMap", "groupArrayArray", "groupUniqArrayArray"};
void DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(const AggregateFunctionPtr & function)
{
static const std::vector<String> supported_functions{"any", "anyLast", "min",
"max", "sum", "sumWithOverflow", "groupBitAnd", "groupBitOr", "groupBitXor",
"sumMap", "minMap", "maxMap", "groupArrayArray", "groupUniqArrayArray"};
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
{
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
ErrorCodes::BAD_ARGUMENTS);
}
}
String DataTypeCustomSimpleAggregateFunction::getName() const
{
@ -114,12 +123,7 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
// check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))
{
throw Exception("Unsupported aggregate function " + function->getName() + ", supported functions are " + boost::algorithm::join(supported_functions, ","),
ErrorCodes::BAD_ARGUMENTS);
}
DataTypeCustomSimpleAggregateFunction::checkSupportedFunctions(function);
DataTypePtr storage_type = DataTypeFactory::instance().get(argument_types[0]->getName());

View File

@ -37,6 +37,7 @@ public:
const AggregateFunctionPtr getFunction() const { return function; }
String getName() const override;
static void checkSupportedFunctions(const AggregateFunctionPtr & function);
};
}

View File

@ -451,6 +451,7 @@ public:
static bool isSpecialCompressionAllowed(const SubstreamPath & path);
private:
friend class DataTypeFactory;
friend class AggregateFunctionSimpleState;
/// Customize this DataType
void setCustomization(DataTypeCustomDescPtr custom_desc_) const;

View File

@ -0,0 +1,14 @@
SimpleAggregateFunction(any, UInt64) 0
SimpleAggregateFunction(anyLast, UInt64) 0
SimpleAggregateFunction(min, UInt64) 0
SimpleAggregateFunction(max, UInt64) 0
SimpleAggregateFunction(sum, UInt64) 0
SimpleAggregateFunction(sumWithOverflow, UInt64) 0
SimpleAggregateFunction(groupBitAnd, UInt64) 0
SimpleAggregateFunction(groupBitOr, UInt64) 0
SimpleAggregateFunction(groupBitXor, UInt64) 0
SimpleAggregateFunction(sumMap, Tuple(Array(UInt64), Array(UInt64))) ([],[])
SimpleAggregateFunction(minMap, Tuple(Array(UInt64), Array(UInt64))) ([0],[0])
SimpleAggregateFunction(maxMap, Tuple(Array(UInt64), Array(UInt64))) ([0],[0])
SimpleAggregateFunction(groupArrayArray, Array(UInt64)) [0]
SimpleAggregateFunction(groupUniqArrayArray, Array(UInt64)) [0]

View File

@ -0,0 +1,17 @@
with anySimpleState(number) as c select toTypeName(c), c from numbers(1);
with anyLastSimpleState(number) as c select toTypeName(c), c from numbers(1);
with minSimpleState(number) as c select toTypeName(c), c from numbers(1);
with maxSimpleState(number) as c select toTypeName(c), c from numbers(1);
with sumSimpleState(number) as c select toTypeName(c), c from numbers(1);
with sumWithOverflowSimpleState(number) as c select toTypeName(c), c from numbers(1);
with groupBitAndSimpleState(number) as c select toTypeName(c), c from numbers(1);
with groupBitOrSimpleState(number) as c select toTypeName(c), c from numbers(1);
with groupBitXorSimpleState(number) as c select toTypeName(c), c from numbers(1);
with sumMapSimpleState(([number], [number])) as c select toTypeName(c), c from numbers(1);
with minMapSimpleState(([number], [number])) as c select toTypeName(c), c from numbers(1);
with maxMapSimpleState(([number], [number])) as c select toTypeName(c), c from numbers(1);
with groupArrayArraySimpleState([number]) as c select toTypeName(c), c from numbers(1);
with groupUniqArrayArraySimpleState([number]) as c select toTypeName(c), c from numbers(1);
-- non-SimpleAggregateFunction
with countSimpleState(number) as c select toTypeName(c), c from numbers(1); -- { serverError 36 }