diff --git a/dbms/src/Functions/runningAccumulate.cpp b/dbms/src/Functions/runningAccumulate.cpp index a4ccc1e1553..53dc5e19777 100644 --- a/dbms/src/Functions/runningAccumulate.cpp +++ b/dbms/src/Functions/runningAccumulate.cpp @@ -15,6 +15,7 @@ namespace ErrorCodes { extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } @@ -46,10 +47,9 @@ public: return true; } - size_t getNumberOfArguments() const override - { - return 1; - } + bool isVariadic() const override { return true; } + + size_t getNumberOfArguments() const override { return 0; } bool isDeterministic() const override { return false; } @@ -60,6 +60,10 @@ public: DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { + if (arguments.size() < 1 || arguments.size() > 2) + throw Exception("Incorrect number of arguments of function " + getName() + ". Must be 1 or 2.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + const DataTypeAggregateFunction * type = checkAndGetDataType(arguments[0].get()); if (!type) throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.", @@ -72,19 +76,24 @@ public: { const ColumnAggregateFunction * column_with_states = typeid_cast(&*block.getByPosition(arguments.at(0)).column); + if (!column_with_states) throw Exception("Illegal column " + block.getByPosition(arguments.at(0)).column->getName() + " of first argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN); + ColumnPtr column_with_groups; + + if (arguments.size() == 2) + column_with_groups = block.getByPosition(arguments[1]).column; + AggregateFunctionPtr aggregate_function_ptr = column_with_states->getAggregateFunction(); const IAggregateFunction & agg_func = *aggregate_function_ptr; AlignedBuffer place(agg_func.sizeOfData(), agg_func.alignOfData()); - agg_func.create(place.data()); - SCOPE_EXIT(agg_func.destroy(place.data())); + /// Will pass empty arena if agg_func does not allocate memory in arena std::unique_ptr arena = agg_func.allocatesMemoryInArena() ? std::make_unique() : nullptr; auto result_column_ptr = agg_func.getReturnType()->createColumn(); @@ -92,11 +101,32 @@ public: result_column.reserve(column_with_states->size()); const auto & states = column_with_states->getData(); + + bool state_created = false; + SCOPE_EXIT({ + if (state_created) + agg_func.destroy(place.data()); + }); + + size_t row_number = 0; for (const auto & state_to_add : states) { - /// Will pass empty arena if agg_func does not allocate memory in arena + if (row_number == 0 || (column_with_groups && column_with_groups->compareAt(row_number, row_number - 1, *column_with_groups, 1) != 0)) + { + if (state_created) + { + agg_func.destroy(place.data()); + state_created = false; + } + + agg_func.create(place.data()); + state_created = true; + } + agg_func.merge(place.data(), state_to_add, arena.get()); agg_func.insertResultInto(place.data(), result_column); + + ++row_number; } block.getByPosition(result).column = std::move(result_column_ptr); diff --git a/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.reference b/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.reference new file mode 100644 index 00000000000..98d21902f5c --- /dev/null +++ b/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.reference @@ -0,0 +1,30 @@ +0 0 0 +0 6 6 +0 12 18 +0 18 36 +0 24 60 +1 1 1 +1 7 8 +1 13 21 +1 19 40 +1 25 65 +2 2 2 +2 8 10 +2 14 24 +2 20 44 +2 26 70 +3 3 3 +3 9 12 +3 15 27 +3 21 48 +3 27 75 +4 4 4 +4 10 14 +4 16 30 +4 22 52 +4 28 80 +5 5 5 +5 11 16 +5 17 33 +5 23 56 +5 29 85 diff --git a/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.sql b/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.sql new file mode 100644 index 00000000000..b9336b2f50c --- /dev/null +++ b/dbms/tests/queries/0_stateless/01012_reset_running_accumulate.sql @@ -0,0 +1,11 @@ +SELECT grouping, + item, + runningAccumulate(state, grouping) +FROM ( + SELECT number % 6 AS grouping, + number AS item, + sumState(number) AS state + FROM (SELECT number FROM system.numbers LIMIT 30) + GROUP BY grouping, item + ORDER BY grouping, item +); \ No newline at end of file