diff --git a/src/AggregateFunctions/AggregateFunctionTopK.h b/src/AggregateFunctions/AggregateFunctionTopK.h index f77fc482685..791df21d1a7 100644 --- a/src/AggregateFunctions/AggregateFunctionTopK.h +++ b/src/AggregateFunctions/AggregateFunctionTopK.h @@ -64,7 +64,10 @@ public: void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { - this->data(place).value.merge(this->data(rhs).value); + auto & set = this->data(place).value; + if (set.capacity() != reserved) + set.resize(reserved); + set.merge(this->data(rhs).value); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override @@ -197,7 +200,10 @@ public: void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { - this->data(place).value.merge(this->data(rhs).value); + auto & set = this->data(place).value; + if (set.capacity() != reserved) + set.resize(reserved); + set.merge(this->data(rhs).value); } void insertResultInto(AggregateDataPtr place, IColumn & to, Arena *) const override diff --git a/tests/queries/0_stateless/01409_topK_merge.reference b/tests/queries/0_stateless/01409_topK_merge.reference new file mode 100644 index 00000000000..69db2381018 --- /dev/null +++ b/tests/queries/0_stateless/01409_topK_merge.reference @@ -0,0 +1,6 @@ +AggregateFunctionTopK +20 +20 +AggregateFunctionTopKGenericData +20 +20 diff --git a/tests/queries/0_stateless/01409_topK_merge.sql b/tests/queries/0_stateless/01409_topK_merge.sql new file mode 100644 index 00000000000..5ac7c350093 --- /dev/null +++ b/tests/queries/0_stateless/01409_topK_merge.sql @@ -0,0 +1,13 @@ +drop table if exists data_01409; +create table data_01409 engine=Memory as select * from numbers(20); + +-- easier to check merging via distributed tables +-- but can be done vai topKMerge(topKState()) as well + +select 'AggregateFunctionTopK'; +select length(topK(20)(number)) from remote('127.{1,1}', currentDatabase(), data_01409); +select length(topKWeighted(20)(number, 1)) from remote('127.{1,1}', currentDatabase(), data_01409); + +select 'AggregateFunctionTopKGenericData'; +select length(topK(20)((number, ''))) from remote('127.{1,1}', currentDatabase(), data_01409); +select length(topKWeighted(20)((number, ''), 1)) from remote('127.{1,1}', currentDatabase(), data_01409);