diff --git a/src/Parsers/ExpressionListParsers.cpp b/src/Parsers/ExpressionListParsers.cpp index 13af308736b..10cbd95ec59 100644 --- a/src/Parsers/ExpressionListParsers.cpp +++ b/src/Parsers/ExpressionListParsers.cpp @@ -328,14 +328,20 @@ bool ParserLeftAssociativeBinaryOperatorList::parseImpl(Pos & pos, ASTPtr & node ASTPtr elem; SubqueryFunctionType subquery_function_type = SubqueryFunctionType::NONE; - if (allow_any_all_operators && ParserKeyword("ANY").ignore(pos, expected)) - subquery_function_type = SubqueryFunctionType::ANY; - else if (allow_any_all_operators && ParserKeyword("ALL").ignore(pos, expected)) - subquery_function_type = SubqueryFunctionType::ALL; - else if (!(remaining_elem_parser ? remaining_elem_parser : first_elem_parser)->parse(pos, elem, expected)) - return false; + + if (comparison_expression) + { + if (ParserKeyword("ANY").ignore(pos, expected)) + subquery_function_type = SubqueryFunctionType::ANY; + else if (ParserKeyword("ALL").ignore(pos, expected)) + subquery_function_type = SubqueryFunctionType::ALL; + } if (subquery_function_type != SubqueryFunctionType::NONE && !ParserSubquery().parse(pos, elem, expected)) + subquery_function_type = SubqueryFunctionType::NONE; + + if (subquery_function_type == SubqueryFunctionType::NONE + && !(remaining_elem_parser ? remaining_elem_parser : first_elem_parser)->parse(pos, elem, expected)) return false; /// the first argument of the function is the previous element, the second is the next one @@ -346,7 +352,7 @@ bool ParserLeftAssociativeBinaryOperatorList::parseImpl(Pos & pos, ASTPtr & node exp_list->children.push_back(node); exp_list->children.push_back(elem); - if (allow_any_all_operators && subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function, subquery_function_type)) + if (comparison_expression && subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function, subquery_function_type)) return false; /** special exception for the access operator to the element of the array `x[y]`, which diff --git a/src/Parsers/ExpressionListParsers.h b/src/Parsers/ExpressionListParsers.h index 86d0fd0f861..18d25bcb6d9 100644 --- a/src/Parsers/ExpressionListParsers.h +++ b/src/Parsers/ExpressionListParsers.h @@ -122,7 +122,7 @@ private: ParserPtr first_elem_parser; ParserPtr remaining_elem_parser; /// =, !=, <, > ALL (subquery) / ANY (subquery) - bool allow_any_all_operators = false; + bool comparison_expression = false; public: /** `operators_` - allowed operators and their corresponding functions @@ -133,9 +133,9 @@ public: } ParserLeftAssociativeBinaryOperatorList(Operators_t operators_, - Operators_t overlapping_operators_to_skip_, ParserPtr && first_elem_parser_, bool allow_any_all_operators_ = false) + Operators_t overlapping_operators_to_skip_, ParserPtr && first_elem_parser_, bool comparison_expression_ = false) : operators(operators_), overlapping_operators_to_skip(overlapping_operators_to_skip_), - first_elem_parser(std::move(first_elem_parser_)), allow_any_all_operators(allow_any_all_operators_) + first_elem_parser(std::move(first_elem_parser_)), comparison_expression(comparison_expression_) { } diff --git a/tests/queries/0_stateless/02007_test_any_all_operators.reference b/tests/queries/0_stateless/02007_test_any_all_operators.reference index a232320d15c..69436560ff1 100644 --- a/tests/queries/0_stateless/02007_test_any_all_operators.reference +++ b/tests/queries/0_stateless/02007_test_any_all_operators.reference @@ -49,3 +49,7 @@ select 11 > all (select 11 from numbers(10)); 0 select 11 >= all (select 11 from numbers(10)); 1 +select sum(number) = any(number) from numbers(1) group by number; +1 +select 1 == any (1); +1 diff --git a/tests/queries/0_stateless/02007_test_any_all_operators.sql b/tests/queries/0_stateless/02007_test_any_all_operators.sql index 10d7325afca..dd539bcbc5c 100644 --- a/tests/queries/0_stateless/02007_test_any_all_operators.sql +++ b/tests/queries/0_stateless/02007_test_any_all_operators.sql @@ -24,3 +24,5 @@ select 11 <= all (select number from numbers(11)); select 11 < all (select 11 from numbers(10)); select 11 > all (select 11 from numbers(10)); select 11 >= all (select 11 from numbers(10)); +select sum(number) = any(number) from numbers(1) group by number; +select 1 == any (1);