diff --git a/dbms/src/Functions/FunctionsArithmetic.h b/dbms/src/Functions/FunctionsArithmetic.h index 8d2f8d06c14..37b45b34caf 100644 --- a/dbms/src/Functions/FunctionsArithmetic.h +++ b/dbms/src/Functions/FunctionsArithmetic.h @@ -909,16 +909,24 @@ public: const ColumnAggregateFunction * column = typeid_cast(block.getByPosition(new_arguments[0]).column.get()); IAggregateFunction * function = column->getAggregateFunction().get(); - MutableColumnPtr current = column->cloneEmpty(); auto arena = std::make_shared(); - auto & res = typeid_cast(*current); - auto & vec_to = res.getData(); - const auto & vec_from = column->getData(); + + auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena)); + column_to->reserve(input_rows_count); + + auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena)); + column_from->reserve(input_rows_count); for (size_t i = 0; i < input_rows_count; ++i) - res.insertDefault(); + { + column_to->insertDefault(); + column_from->insertFrom(column->getData()[i]); + } - size_t m = block.getByPosition(new_arguments[1]).column->getUInt(0); + auto & vec_to = column_to->getData(); + auto & vec_from = column_from->getData(); + + UInt64 m = block.getByPosition(new_arguments[1]).column->getUInt(0); /// We use exponentiation by squaring algorithm to perform multiplying aggregate states by N in O(log(N)) operations /// https://en.wikipedia.org/wiki/Exponentiation_by_squaring @@ -937,8 +945,8 @@ public: m /= 2; } } - - block.getByPosition(result).column = std::move(current); + + block.getByPosition(result).column = std::move(column_to); return; } diff --git a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference index 3ebd70dce2f..ff08aade246 100644 --- a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference +++ b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference @@ -4,3 +4,10 @@ 2 0 18 +20 4 +2 +3 +2 +[1,1,1,1,1] +[1,1] +[1] diff --git a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql index 209a84e3548..d7ebd7b0313 100644 --- a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql +++ b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql @@ -4,3 +4,19 @@ SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 )); SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2)); SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x)); + +DROP TABLE IF EXISTS test.mult_aggregation; +CREATE TABLE test.mult_aggregation(a UInt32, b UInt32) ENGINE = Memory; +INSERT INTO test.mult_aggregation VALUES(1, 1); +INSERT INTO test.mult_aggregation VALUES(1, 3); + +SELECT sumMerge(x * 5), sumMerge(x) FROM (SELECT sumState(b) AS x FROM test.mult_aggregation); +SELECT uniqMerge(x * 10) FROM (SELECT uniqState(b) AS x FROM test.mult_aggregation); +SELECT maxMerge(x * 10) FROM (SELECT maxState(b) AS x FROM test.mult_aggregation); +SELECT avgMerge(x * 10) FROM (SELECT avgState(b) AS x FROM test.mult_aggregation); + +SELECT groupArrayMerge(y * 5) FROM (SELECT groupArrayState(x) AS y FROM (SELECT 1 AS x)); +SELECT groupArrayMerge(2)(y * 5) FROM (SELECT groupArrayState(2)(x) AS y FROM (SELECT 1 AS x)); +SELECT groupUniqArrayMerge(y * 5) FROM (SELECT groupUniqArrayState(x) AS y FROM (SELECT 1 AS x)); + +DROP TABLE IF EXISTS test.mult_aggregation;