diff --git a/src/AggregateFunctions/AggregateFunctionIf.cpp b/src/AggregateFunctions/AggregateFunctionIf.cpp index d841fe8c06d..89688ce1ffd 100644 --- a/src/AggregateFunctions/AggregateFunctionIf.cpp +++ b/src/AggregateFunctions/AggregateFunctionIf.cpp @@ -1,5 +1,5 @@ -#include #include +#include #include "AggregateFunctionNull.h" @@ -11,6 +11,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int ILLEGAL_AGGREGATION; } class AggregateFunctionCombinatorIf final : public IAggregateFunctionCombinator @@ -37,6 +38,10 @@ public: const DataTypes & arguments, const Array & params) const override { + if (nested_function->getName().find(getName()) != String::npos) + { + throw Exception(ErrorCodes::ILLEGAL_AGGREGATION, "nested function for {0}-combinator must not have {0}-combinator", getName()); + } return std::make_shared(nested_function, arguments, params); } }; diff --git a/src/Parsers/ExpressionElementParsers.cpp b/src/Parsers/ExpressionElementParsers.cpp index dd9fc738094..5190e805922 100644 --- a/src/Parsers/ExpressionElementParsers.cpp +++ b/src/Parsers/ExpressionElementParsers.cpp @@ -49,7 +49,6 @@ namespace ErrorCodes extern const int BAD_ARGUMENTS; extern const int SYNTAX_ERROR; extern const int LOGICAL_ERROR; - extern const int ILLEGAL_AGGREGATION; } @@ -521,8 +520,8 @@ bool ParserFilterClause::parseImpl(Pos & pos, ASTPtr & node, Expected & expected assert(node); ASTFunction & function = dynamic_cast(*node); - ParserToken parser_openging_bracket(TokenType::OpeningRoundBracket); - if (!parser_openging_bracket.ignore(pos, expected)) + ParserToken parser_opening_bracket(TokenType::OpeningRoundBracket); + if (!parser_opening_bracket.ignore(pos, expected)) { return false; } @@ -534,7 +533,7 @@ bool ParserFilterClause::parseImpl(Pos & pos, ASTPtr & node, Expected & expected } ParserExpressionList parser_condition(false); ASTPtr condition; - if (!parser_condition.parse(pos, condition, expected)) + if (!parser_condition.parse(pos, condition, expected) || condition->children.size() != 1) { return false; } @@ -545,17 +544,6 @@ bool ParserFilterClause::parseImpl(Pos & pos, ASTPtr & node, Expected & expected return false; } - if (function.name.find("If") != String::npos) - { - throw Exception( - ErrorCodes::ILLEGAL_AGGREGATION, - "Filter clause provided for an aggregating function (" + function.name + ") already containing If suffix"); - } - if (condition->children.empty()) - { - throw Exception(ErrorCodes::SYNTAX_ERROR, "Empty condition for WHERE"); - } - function.name += "If"; function.arguments->children.push_back(condition->children[0]); return true; diff --git a/tests/queries/0_stateless/00545_weird_aggregate_functions.sql b/tests/queries/0_stateless/00545_weird_aggregate_functions.sql index 1f662850d05..c728dfcc534 100644 --- a/tests/queries/0_stateless/00545_weird_aggregate_functions.sql +++ b/tests/queries/0_stateless/00545_weird_aggregate_functions.sql @@ -1 +1 @@ -SELECT sumForEachMergeArray(y) FROM (SELECT sumForEachStateForEachIfArrayIfMerge(x) AS y FROM (SELECT sumForEachStateForEachIfArrayIfState([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], [1], 1) AS x)); +SELECT sumForEachMergeArray(y) FROM (SELECT sumForEachStateForEachIfArrayMerge(x) AS y FROM (SELECT sumForEachStateForEachIfArrayState([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], [1]) AS x)); diff --git a/tests/queries/0_stateless/02000_select_with_filter.sql b/tests/queries/0_stateless/02000_select_with_filter.sql index eb83b6478a1..4d10f86ed96 100644 --- a/tests/queries/0_stateless/02000_select_with_filter.sql +++ b/tests/queries/0_stateless/02000_select_with_filter.sql @@ -1,4 +1,3 @@ SELECT argMax(number, number + 1) FILTER(WHERE number != 99) FROM numbers(100) ; SELECT sum(number) FILTER(WHERE number % 2 == 0) FROM numbers(100); -SELECT sumIfOrNull(number, number % 2 == 1) FILTER(WHERE number % 2 == 0) FROM numbers(100); -- { clientError 184 } -SELECT sum(number) FILTER(WHERE) FROM numbers(100); -- { clientError 62 } +SELECT sumIfOrNull(number, number % 2 == 1) FILTER(WHERE number % 2 == 0) FROM numbers(100); -- { serverError 184 }