diff --git a/src/AggregateFunctions/AggregateFunctionMap.h b/src/AggregateFunctions/AggregateFunctionMap.h index 9ed4b48c281..f2c56755504 100644 --- a/src/AggregateFunctions/AggregateFunctionMap.h +++ b/src/AggregateFunctions/AggregateFunctionMap.h @@ -84,6 +84,11 @@ private: using Base = IAggregateFunctionDataHelper>; public: + bool isState() const override + { + return nested_func->isState(); + } + AggregateFunctionMap(AggregateFunctionPtr nested, const DataTypes & types) : Base(types, nested->getParameters()), nested_func(nested) { if (types.empty()) diff --git a/src/Interpreters/AggregationUtils.cpp b/src/Interpreters/AggregationUtils.cpp index 43062546450..9b237a24928 100644 --- a/src/Interpreters/AggregationUtils.cpp +++ b/src/Interpreters/AggregationUtils.cpp @@ -1,4 +1,5 @@ #include +#include namespace DB { @@ -6,6 +7,7 @@ namespace DB namespace ErrorCodes { extern const int LOGICAL_ERROR; + extern const int ILLEGAL_COLUMN; } OutputBlockColumns prepareOutputBlockColumns( @@ -50,19 +52,27 @@ OutputBlockColumns prepareOutputBlockColumns( if (aggregate_functions[i]->isState()) { + IColumn * column_to_check = final_aggregate_columns[i].get(); + /// Aggregate state can be wrapped into array/map if aggregate function ends with -Resample/Map combinator + if (auto * column_map = typeid_cast(final_aggregate_columns[i].get())) + column_to_check = &column_map->getNestedData().getColumn(1); + else if (auto * column_array = typeid_cast(final_aggregate_columns[i].get())) + column_to_check = &column_array->getData(); + /// The ColumnAggregateFunction column captures the shared ownership of the arena with aggregate function states. - if (auto * column_aggregate_func = typeid_cast(final_aggregate_columns[i].get())) + if (auto * column_aggregate_func = typeid_cast(column_to_check)) + { for (auto & pool : aggregates_pools) column_aggregate_func->addArena(pool); - - /// Aggregate state can be wrapped into array if aggregate function ends with -Resample combinator. - final_aggregate_columns[i]->forEachSubcolumn( - [&aggregates_pools](auto & subcolumn) - { - if (auto * column_aggregate_func = typeid_cast(subcolumn.get())) - for (auto & pool : aggregates_pools) - column_aggregate_func->addArena(pool); - }); + } + else + { + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Aggregate function {} was marked as State, but result column {} doesn't contain AggregateFunction column", + aggregate_functions[i]->getName(), + final_aggregate_columns[i]->getName()); + } } } } diff --git a/tests/queries/0_stateless/02418_map_combinator_bug.reference b/tests/queries/0_stateless/02418_map_combinator_bug.reference new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/queries/0_stateless/02418_map_combinator_bug.sql b/tests/queries/0_stateless/02418_map_combinator_bug.sql new file mode 100644 index 00000000000..9fae64cd3f9 --- /dev/null +++ b/tests/queries/0_stateless/02418_map_combinator_bug.sql @@ -0,0 +1,11 @@ +drop table if exists test; +create table test (x Map(UInt8, AggregateFunction(uniq, UInt64))) engine=Memory; +insert into test select uniqStateMap(map(1, number)) from numbers(10); +select * from test format Null; +drop table test; + +create table test (x AggregateFunction(uniq, UInt64), y Int64) engine=Memory; +insert into test select uniqState(number) as x, number as y from numbers(10) group by number; +select uniqStateMap(map(1, x)) OVER (PARTITION BY y) from test; -- {serverError ILLEGAL_COLUMN} +drop table test; +