aggregate function states addition

This commit is contained in:
CurtizJ 2018-09-06 20:59:23 +03:00
parent b32e0c48af
commit fb923dcbf5
3 changed files with 82 additions and 13 deletions

View File

@ -1130,6 +1130,27 @@ class FunctionBinaryArithmetic : public IFunction
return FunctionFactory::instance().get(function_name.str(), context);
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
auto is_uint_type = [](const DataTypePtr & type)
{
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
};
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
}
bool isAggregateAddition(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, PlusImpl<UInt8, UInt8>>)
return false;
return checkDataType<DataTypeAggregateFunction>(type0.get())
&& checkDataType<DataTypeAggregateFunction>(type1.get());
}
public:
static constexpr auto name = Name::name;
static FunctionPtr create(const Context & context) { return std::make_shared<FunctionBinaryArithmetic>(context); }
@ -1156,6 +1177,24 @@ public:
return arguments[1];
}
/// Special case - addition of two aggregate functions states
if (isAggregateAddition(arguments[0], arguments[1]))
{
const DataTypeAggregateFunction * new_arguments[2];
for (size_t i = 0; i < 2; ++i)
new_arguments[i] = typeid_cast<const DataTypeAggregateFunction *>(arguments[i].get());
if (new_arguments[0]->getFunctionName() != new_arguments[1]->getFunctionName())
throw Exception("Cannot add aggregate states of different functions: "
+ new_arguments[0]->getFunctionName() + " and " + new_arguments[1]->getFunctionName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (new_arguments[0]->getReturnType()->getName() != new_arguments[1]->getReturnType()->getName())
throw Exception("Cannot add aggregate states with different return types: "
+ new_arguments[0]->getReturnType()->getName() + " and " + new_arguments[1]->getReturnType()->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return arguments[0];
}
/// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval.
if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1]))
{
@ -1206,19 +1245,6 @@ public:
return type_res;
}
bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const
{
if constexpr (!std::is_same_v<Op<UInt8, UInt8>, MultiplyImpl<UInt8, UInt8>>)
return false;
auto is_uint_type = [](const DataTypePtr & type)
{
return checkDataType<DataTypeUInt8>(type.get()) || checkDataType<DataTypeUInt16>(type.get())
|| checkDataType<DataTypeUInt32>(type.get()) || checkDataType<DataTypeUInt64>(type.get());
};
return ((checkDataType<DataTypeAggregateFunction>(type0.get()) && is_uint_type(type1))
|| (is_uint_type(type0) && checkDataType<DataTypeAggregateFunction>(type1.get())));
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override
{
/// Special case when multiply aggregate function state
@ -1267,6 +1293,24 @@ public:
m /= 2;
}
}
}
/// Special case - addition of two aggregate functions states
if (isAggregateAddition(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type))
{
const ColumnAggregateFunction * columns[2];
for (size_t i = 0; i < 2; ++i)
columns[i] = typeid_cast<const ColumnAggregateFunction *>(block.getByPosition(arguments[i]).column.get());
auto arena = std::make_shared<Arena>();
auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction(), Arenas(1, arena));
column_to->reserve(input_rows_count);
for(size_t i = 0; i < input_rows_count; ++i)
{
column_to->insertFrom(columns[0]->getData()[i]);
column_to->insertMergeFrom(columns[1]->getData()[i]);
}
block.getByPosition(result).column = std::move(column_to);
return;

View File

@ -0,0 +1,6 @@
4
7 4 3
1
3
[1,1,2,3]
[1,2,3]

View File

@ -0,0 +1,19 @@
USE test;
SET send_logs_level = 'none';
DROP TABLE IF EXISTS add_aggregate;
CREATE TABLE add_aggregate(a UInt32, b UInt32) ENGINE = Memory;
INSERT INTO add_aggregate VALUES(1, 2);
INSERT INTO add_aggregate VALUES(3, 1);
SELECT countMerge(x + y) FROM (SELECT countState(a) as x, countState(b) as y from add_aggregate);
SELECT sumMerge(x + y), sumMerge(x), sumMerge(y) FROM (SELECT sumState(a) as x, sumState(b) as y from add_aggregate);
SELECT sumMerge(x) FROM (SELECT sumState(a) + countState(b) as x FROM add_aggregate); -- { serverError 43 }
SELECT sumMerge(x) FROM (SELECT sumState(a) + sumState(toInt32(b)) as x FROM add_aggregate); -- { serverError 43 }
SELECT minMerge(x) FROM (SELECT minState(a) + minState(b) as x FROM add_aggregate);
SELECT uniqMerge(x + y) FROM (SELECT uniqState(a) as x, uniqState(b) as y FROM add_aggregate);
SELECT arraySort(groupArrayMerge(x + y)) FROM (SELECT groupArrayState(a) AS x, groupArrayState(b) as y FROM add_aggregate);
SELECT arraySort(groupUniqArrayMerge(x + y)) FROM (SELECT groupUniqArrayState(a) AS x, groupUniqArrayState(b) as y FROM add_aggregate);