diff --git a/dbms/src/Functions/FunctionsArithmetic.h b/dbms/src/Functions/FunctionsArithmetic.h index 5c3633a6d1a..a9e6079f946 100644 --- a/dbms/src/Functions/FunctionsArithmetic.h +++ b/dbms/src/Functions/FunctionsArithmetic.h @@ -780,12 +780,33 @@ class FunctionBinaryArithmetic : public IFunction return castType(left, [&](const auto & left) { return castType(right, [&](const auto & right) { return f(left, right); }); }); } - bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) const + bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1, bool & shift) const { return std::is_same_v, MultiplyImpl> - && checkDataType(type0.get()) - && (checkDataType(type1.get()) - || checkDataType(type1.get())); + && + ( + ( + checkDataType(type0.get()) + && + ( + checkDataType(type1.get()) + || checkDataType(type1.get()) + ) + && + !(shift = false) + ) + || + ( + checkDataType(type1.get()) + && + ( + checkDataType(type0.get()) + || checkDataType(type0.get()) + ) + && + (shift = true) + ) + ); } FunctionBuilderPtr getFunctionForIntervalArithmetic(const DataTypePtr & type0, const DataTypePtr & type1) const @@ -845,8 +866,9 @@ public: DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { /// Special case when multiply aggregate function state - if (isAggregateMultiply(arguments[0], arguments[1])) - return arguments[0]; + bool shift; + if (isAggregateMultiply(arguments[0], arguments[1], shift)) + return arguments[shift?1:0]; /// Special case when the function is plus or minus, one of arguments is Date/DateTime and another is Interval. if (auto function_builder = getFunctionForIntervalArithmetic(arguments[0], arguments[1])) @@ -888,12 +910,13 @@ public: void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t input_rows_count) override { - if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type)) + bool shift; + if (isAggregateMultiply(block.getByPosition(arguments[0]).type, block.getByPosition(arguments[1]).type, shift)) { - auto c = block.getByPosition(arguments[0]).column->cloneEmpty(); - size_t m = block.getByPosition(arguments[1]).column->getUInt(0); + auto c = block.getByPosition(arguments[shift?1:0]).column->cloneEmpty(); + size_t m = block.getByPosition(arguments[shift?0:1]).column->getUInt(0); for (size_t i = 0; i < m; ++i) - c->insertRangeFrom(*(block.getByPosition(arguments[0]).column.get()), 0, input_rows_count); + c->insertRangeFrom(*(block.getByPosition(arguments[shift?1:0]).column.get()), 0, input_rows_count); block.getByPosition(result).column = std::move(c); return; } diff --git a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference index 604440660c8..3ebd70dce2f 100644 --- a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference +++ b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.reference @@ -1,3 +1,6 @@ 2 0 33 +2 +0 +18 diff --git a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql index 1264e8b991b..209a84e3548 100644 --- a/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql +++ b/dbms/tests/queries/0_stateless/00647_multiply_aggregation_state.sql @@ -1,3 +1,6 @@ SELECT countMerge(x) AS y FROM ( SELECT countState() * 2 AS x FROM ( SELECT 1 )); SELECT countMerge(x) AS y FROM ( SELECT countState() * 0 AS x FROM ( SELECT 1 UNION ALL SELECT 2)); SELECT sumMerge(y) AS z FROM ( SELECT sumState(x) * 11 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x)); +SELECT countMerge(x) AS y FROM ( SELECT 2 * countState() AS x FROM ( SELECT 1 )); +SELECT countMerge(x) AS y FROM ( SELECT 0 * countState() AS x FROM ( SELECT 1 UNION ALL SELECT 2)); +SELECT sumMerge(y) AS z FROM ( SELECT 3 * sumState(x) * 2 AS y FROM ( SELECT 1 AS x UNION ALL SELECT 2 AS x));