diff --git a/dbms/src/Functions/FunctionBinaryArithmetic.h b/dbms/src/Functions/FunctionBinaryArithmetic.h index a9149fca0ed..10e7be67919 100644 --- a/dbms/src/Functions/FunctionBinaryArithmetic.h +++ b/dbms/src/Functions/FunctionBinaryArithmetic.h @@ -547,21 +547,27 @@ class FunctionBinaryArithmetic : public IFunction throw Exception{"Illegal column " + block.getByPosition(new_arguments[1]).column->getName() + " of argument of aggregation state multiply. Should be integer constant", ErrorCodes::ILLEGAL_COLUMN}; - const ColumnAggregateFunction * column = typeid_cast(block.getByPosition(new_arguments[0]).column.get()); - IAggregateFunction * function = column->getAggregateFunction().get(); + const IColumn & agg_state_column = *block.getByPosition(new_arguments[0]).column; + bool agg_state_is_const = agg_state_column.isColumnConst(); + const ColumnAggregateFunction & column = typeid_cast( + agg_state_is_const ? static_cast(agg_state_column).getDataColumn() : agg_state_column); + + AggregateFunctionPtr function = column.getAggregateFunction(); auto arena = std::make_shared(); - auto column_to = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena)); - column_to->reserve(input_rows_count); + size_t size = agg_state_is_const ? 1 : input_rows_count; - auto column_from = ColumnAggregateFunction::create(column->getAggregateFunction(), Arenas(1, arena)); - column_from->reserve(input_rows_count); + auto column_to = ColumnAggregateFunction::create(function, Arenas(1, arena)); + column_to->reserve(size); - for (size_t i = 0; i < input_rows_count; ++i) + auto column_from = ColumnAggregateFunction::create(function, Arenas(1, arena)); + column_from->reserve(size); + + for (size_t i = 0; i < size; ++i) { column_to->insertDefault(); - column_from->insertFrom(column->getData()[i]); + column_from->insertFrom(column.getData()[i]); } auto & vec_to = column_to->getData(); @@ -575,38 +581,55 @@ class FunctionBinaryArithmetic : public IFunction { if (m % 2) { - for (size_t i = 0; i < input_rows_count; ++i) + for (size_t i = 0; i < size; ++i) function->merge(vec_to[i], vec_from[i], arena.get()); --m; } else { - for (size_t i = 0; i < input_rows_count; ++i) + for (size_t i = 0; i < size; ++i) function->merge(vec_from[i], vec_from[i], arena.get()); m /= 2; } } - block.getByPosition(result).column = std::move(column_to); + if (agg_state_is_const) + block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count); + else + block.getByPosition(result).column = std::move(column_to); } /// Merge two aggregation states together. void executeAggregateAddition(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) const { - const ColumnAggregateFunction * columns[2]; - for (size_t i = 0; i < 2; ++i) - columns[i] = typeid_cast(block.getByPosition(arguments[i]).column.get()); + const IColumn & lhs_column = *block.getByPosition(arguments[0]).column; + const IColumn & rhs_column = *block.getByPosition(arguments[1]).column; - auto column_to = ColumnAggregateFunction::create(columns[0]->getAggregateFunction()); - column_to->reserve(input_rows_count); + bool lhs_is_const = lhs_column.isColumnConst(); + bool rhs_is_const = rhs_column.isColumnConst(); - for (size_t i = 0; i < input_rows_count; ++i) + const ColumnAggregateFunction & lhs = typeid_cast( + lhs_is_const ? static_cast(lhs_column).getDataColumn() : lhs_column); + const ColumnAggregateFunction & rhs = typeid_cast( + rhs_is_const ? static_cast(rhs_column).getDataColumn() : rhs_column); + + AggregateFunctionPtr function = lhs.getAggregateFunction(); + + size_t size = (lhs_is_const && rhs_is_const) ? 1 : input_rows_count; + + auto column_to = ColumnAggregateFunction::create(function); + column_to->reserve(size); + + for (size_t i = 0; i < size; ++i) { - column_to->insertFrom(columns[0]->getData()[i]); - column_to->insertMergeFrom(columns[1]->getData()[i]); + column_to->insertFrom(lhs.getData()[lhs_is_const ? 0 : i]); + column_to->insertMergeFrom(rhs.getData()[rhs_is_const ? 0 : i]); } - block.getByPosition(result).column = std::move(column_to); + if (lhs_is_const && rhs_is_const) + block.getByPosition(result).column = ColumnConst::create(std::move(column_to), input_rows_count); + else + block.getByPosition(result).column = std::move(column_to); } void executeDateTimeIntervalPlusMinus(Block & block, const ColumnNumbers & arguments, diff --git a/dbms/src/Functions/finalizeAggregation.cpp b/dbms/src/Functions/finalizeAggregation.cpp index c04bef41a82..21c62f5dd7e 100644 --- a/dbms/src/Functions/finalizeAggregation.cpp +++ b/dbms/src/Functions/finalizeAggregation.cpp @@ -43,6 +43,8 @@ public: return 1; } + bool useDefaultImplementationForConstants() const override { return true; } + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { const DataTypeAggregateFunction * type = checkAndGetDataType(arguments[0].get()); diff --git a/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.reference b/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.reference new file mode 100644 index 00000000000..84ab9cbe7ca --- /dev/null +++ b/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.reference @@ -0,0 +1,6 @@ +90 +90 +90 +90 +90 +90 diff --git a/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.sql b/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.sql new file mode 100644 index 00000000000..a0962e99f47 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00919_sum_aggregate_states_constants.sql @@ -0,0 +1,6 @@ +SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10))); +SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + materialize((SELECT sumState(number) FROM numbers(10)))); +SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) + (SELECT sumState(number) FROM numbers(10))); +SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) + materialize((SELECT sumState(number) FROM numbers(10)))); +SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10)))); +SELECT materialize(finalizeAggregation((SELECT sumState(number) FROM numbers(10)) + (SELECT sumState(number) FROM numbers(10)))); diff --git a/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.reference b/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.reference new file mode 100644 index 00000000000..b56570c9497 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.reference @@ -0,0 +1,4 @@ +450 +450 +450 +450 diff --git a/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.sql b/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.sql new file mode 100644 index 00000000000..e88b5d520d1 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00920_multiply_aggregate_states_constants.sql @@ -0,0 +1,4 @@ +SELECT finalizeAggregation((SELECT sumState(number) FROM numbers(10)) * 10); +SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10))) * 10); +SELECT finalizeAggregation(materialize((SELECT sumState(number) FROM numbers(10)) * 10)); +SELECT materialize(finalizeAggregation((SELECT sumState(number) FROM numbers(10)) * 10));