diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h index bf4ad82c98d..435d5c4c638 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMerge.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMerge.h @@ -34,6 +34,11 @@ public: return nested_func->getReturnType(); } + AggregateFunctionPtr getNestedFunction() const + { + return nested_func_owner; + } + void setArguments(const DataTypes & arguments) override { if (arguments.size() != 1) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionState.cpp b/dbms/src/AggregateFunctions/AggregateFunctionState.cpp index f9d60a2f885..0bb1345cf00 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionState.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionState.cpp @@ -1,8 +1,34 @@ #include +#include namespace DB { +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + +DataTypePtr AggregateFunctionState::getReturnType() const +{ + auto ptr = std::make_shared(nested_func_owner, arguments, params); + + /// Special case: it is -MergeState combinator + if (typeid_cast(ptr->getFunction().get())) + { + if (arguments.size() != 1) + throw Exception("Combinator -MergeState expects only one argument", ErrorCodes::BAD_ARGUMENTS); + + if (!typeid_cast(arguments[0].get())) + throw Exception("Combinator -MergeState expects argument with AggregateFunction type", ErrorCodes::BAD_ARGUMENTS); + + return arguments[0]; + } + + return ptr; +} + + AggregateFunctionPtr createAggregateFunctionState(AggregateFunctionPtr & nested) { return std::make_shared(nested); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionState.h b/dbms/src/AggregateFunctions/AggregateFunctionState.h index 59e5f984399..86511be93d0 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionState.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionState.h @@ -1,3 +1,4 @@ + #pragma once #include @@ -30,10 +31,7 @@ public: return nested_func->getName() + "State"; } - DataTypePtr getReturnType() const override - { - return std::make_shared(nested_func_owner, arguments, params); - } + DataTypePtr getReturnType() const override; void setArguments(const DataTypes & arguments_) override { diff --git a/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.reference b/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.reference index 2b71732c082..6087cae7ec5 100644 --- a/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.reference +++ b/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.reference @@ -27,3 +27,6 @@ 1 0 + +1 +2 diff --git a/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.sql b/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.sql index 8235c6af5e9..e8d9704c3a9 100644 --- a/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.sql +++ b/dbms/tests/queries/0_stateless/00432_aggregate_function_scalars_and_constants.sql @@ -38,3 +38,20 @@ SELECT arrayReduce('groupUniqArrayMergeIf', SELECT ''; SELECT arrayReduce('avgState', [0]) IN (arrayReduce('avgState', [0, 1]), arrayReduce('avgState', [0])); SELECT arrayReduce('avgState', [0]) IN (arrayReduce('avgState', [0, 1]), arrayReduce('avgState', [1])); + +SELECT ''; +SELECT arrayReduce('uniqExactMerge', + [arrayReduce('uniqExactMergeState', + [ + arrayReduce('uniqExactState', [12345678901]), + arrayReduce('uniqExactState', [12345678901]) + ]) + ]); + +SELECT arrayReduce('uniqExactMerge', + [arrayReduce('uniqExactMergeState', + [ + arrayReduce('uniqExactState', [12345678901]), + arrayReduce('uniqExactState', [12345678902]) + ]) + ]);