mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
aggregate function states addition
This commit is contained in:
parent
b32e0c48af
commit
fb923dcbf5
@ -1130,6 +1130,27 @@ class FunctionBinaryArithmetic : public IFunction
|
||||
return FunctionFactory::instance().get(function_name.str(), context);
|
||||
}
|
||||
|
||||
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
{
|
||||
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
|
||||
return false;
|
||||
auto is_uint_type = [](const DataTypePtr & type)
|
||||
{
|
||||
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|
||||
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
|
||||
};
|
||||
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|
||||
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
|
||||
}
|
||||
|
||||
bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
{
|
||||
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>)
|
||||
return false;
|
||||
return checkDataType<DataTypeAggregateFunction>(type0.get())
|
||||
&& checkDataType<DataTypeAggregateFunction>(type1.get());
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
|
||||
@ -1156,6 +1177,24 @@ public:
|
||||
return arguments[1];
|
||||
}
|
||||
|
||||
/// Special case - addition of two aggregate functions states
|
||||
if (isAggregateAddition(arguments[0], arguments[1]))
|
||||
{
|
||||
const DataTypeAggregateFunction * new_arguments[2];
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
new_arguments[i] = typeid_cast<const DataTypeAggregateFunction *>(arguments[i].get());
|
||||
|
||||
if (new_arguments[0]->getFunctionName() != new_arguments[1]->getFunctionName())
|
||||
throw Exception("Cannot add aggregate states of different functions: "
|
||||
+ new_arguments[0]->getFunctionName() + " and " + new_arguments[1]->getFunctionName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (new_arguments[0]->getReturnType()->getName() != new_arguments[1]->getReturnType()->getName())
|
||||
throw Exception("Cannot add aggregate states with different return types: "
|
||||
+ new_arguments[0]->getReturnType()->getName() + " and " + new_arguments[1]->getReturnType()->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return arguments[0];
|
||||
}
|
||||
|
||||
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
|
||||
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
|
||||
{
|
||||
@ -1206,19 +1245,6 @@ public:
|
||||
return type_res;
|
||||
}
|
||||
|
||||
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
|
||||
{
|
||||
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
|
||||
return false;
|
||||
auto is_uint_type = [](const DataTypePtr & type)
|
||||
{
|
||||
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|
||||
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
|
||||
};
|
||||
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|
||||
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
|
||||
{
|
||||
/// Special case when multiply aggregate function state
|
||||
@ -1267,6 +1293,24 @@ public:
|
||||
m /= 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Special case - addition of two aggregate functions states
|
||||
if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
|
||||
{
|
||||
const ColumnAggregateFunction * columns[2];
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
|
||||
|
||||
auto arena = std::make_shared<Arena>();
|
||||
auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction(), Arenas(1, arena));
|
||||
column_to->reserve(input_rows_count);
|
||||
|
||||
for(size_t i = 0; i < input_rows_count; ++i)
|
||||
{
|
||||
column_to->insertFrom(columns[0]->getData()[i]);
|
||||
column_to->insertMergeFrom(columns[1]->getData()[i]);
|
||||
}
|
||||
|
||||
block.getByPosition(result).column = std::move(column_to);
|
||||
return;
|
||||
|
@ -0,0 +1,6 @@
|
||||
4
|
||||
7 4 3
|
||||
1
|
||||
3
|
||||
[1,1,2,3]
|
||||
[1,2,3]
|
@ -0,0 +1,19 @@
|
||||
USE test;
|
||||
SET send_logs_level = 'none';
|
||||
DROP TABLE IF EXISTS add_aggregate;
|
||||
CREATE TABLE add_aggregate(a UInt32, b UInt32) ENGINE = Memory;
|
||||
|
||||
INSERT INTO add_aggregate VALUES(1, 2);
|
||||
INSERT INTO add_aggregate VALUES(3, 1);
|
||||
|
||||
SELECT countMerge(x + y) FROM (SELECT countState(a) as x, countState(b) as y from add_aggregate);
|
||||
SELECT sumMerge(x + y), sumMerge(x), sumMerge(y) FROM (SELECT sumState(a) as x, sumState(b) as y from add_aggregate);
|
||||
SELECT sumMerge(x) FROM (SELECT sumState(a) + countState(b) as x FROM add_aggregate); -- { serverError 43 }
|
||||
SELECT sumMerge(x) FROM (SELECT sumState(a) + sumState(toInt32(b)) as x FROM add_aggregate); -- { serverError 43 }
|
||||
|
||||
SELECT minMerge(x) FROM (SELECT minState(a) + minState(b) as x FROM add_aggregate);
|
||||
|
||||
SELECT uniqMerge(x + y) FROM (SELECT uniqState(a) as x, uniqState(b) as y FROM add_aggregate);
|
||||
|
||||
SELECT arraySort(groupArrayMerge(x + y)) FROM (SELECT groupArrayState(a) AS x, groupArrayState(b) as y FROM add_aggregate);
|
||||
SELECT arraySort(groupUniqArrayMerge(x + y)) FROM (SELECT groupUniqArrayState(a) AS x, groupUniqArrayState(b) as y FROM add_aggregate);
|
Loading…
Reference in New Issue
Block a user