From f66e8464f94b41624099d2a71151a52778ed74d8 Mon Sep 17 00:00:00 2001 From: kssenii Date: Sun, 15 Aug 2021 09:55:43 +0300 Subject: [PATCH] Some final fixes --- .../InterpreterSelectIntersectExceptQuery.cpp | 7 +- .../InterpreterSelectIntersectExceptQuery.h | 5 +- .../SelectIntersectExceptQueryVisitor.cpp | 6 +- src/Parsers/ASTSelectIntersectExceptQuery.cpp | 5 +- src/Parsers/ExpressionListParsers.cpp | 86 +++---------------- src/Parsers/ExpressionListParsers.h | 17 +++- src/Parsers/ParserQueryWithOutput.cpp | 1 + src/Parsers/ParserUnionQueryElement.cpp | 3 +- .../QueryPlan/IntersectOrExceptStep.cpp | 6 +- .../Transforms/IntersectOrExceptTransform.cpp | 1 + ...02004_intersect_except_operators.reference | 5 ++ .../02004_intersect_except_operators.sql | 4 + .../02007_test_any_all_operators.reference | 20 +++++ .../02007_test_any_all_operators.sql | 11 ++- 14 files changed, 83 insertions(+), 94 deletions(-) diff --git a/src/Interpreters/InterpreterSelectIntersectExceptQuery.cpp b/src/Interpreters/InterpreterSelectIntersectExceptQuery.cpp index 4edd13d08e5..9c8dda56b44 100644 --- a/src/Interpreters/InterpreterSelectIntersectExceptQuery.cpp +++ b/src/Interpreters/InterpreterSelectIntersectExceptQuery.cpp @@ -89,7 +89,6 @@ InterpreterSelectIntersectExceptQuery::buildCurrentChildInterpreter(const ASTPtr if (ast_ptr_->as()) return std::make_unique(ast_ptr_, context, SelectQueryOptions()); - // if (ast_ptr_->as()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected query: {}", ast_ptr_->getID()); } @@ -140,4 +139,10 @@ BlockIO InterpreterSelectIntersectExceptQuery::execute() return res; } +void InterpreterSelectIntersectExceptQuery::ignoreWithTotals() +{ + for (auto & interpreter : nested_interpreters) + interpreter->ignoreWithTotals(); +} + } diff --git a/src/Interpreters/InterpreterSelectIntersectExceptQuery.h b/src/Interpreters/InterpreterSelectIntersectExceptQuery.h index 9cbde055b0b..805565e4c51 100644 --- a/src/Interpreters/InterpreterSelectIntersectExceptQuery.h +++ b/src/Interpreters/InterpreterSelectIntersectExceptQuery.h @@ -28,6 +28,8 @@ public: Block getSampleBlock() { return result_header; } + void ignoreWithTotals() override; + private: static String getName() { return "SelectIntersectExceptQuery"; } @@ -36,9 +38,8 @@ private: void buildQueryPlan(QueryPlan & query_plan) override; - void ignoreWithTotals() override {} - std::vector> nested_interpreters; + Operator final_operator; }; diff --git a/src/Interpreters/SelectIntersectExceptQueryVisitor.cpp b/src/Interpreters/SelectIntersectExceptQueryVisitor.cpp index 190ec279038..e26c4371591 100644 --- a/src/Interpreters/SelectIntersectExceptQueryVisitor.cpp +++ b/src/Interpreters/SelectIntersectExceptQueryVisitor.cpp @@ -12,11 +12,11 @@ namespace ErrorCodes /* * Note: there is a difference between intersect and except behaviour. - * `intersect` is supposed to be a part of last SelectQuery, i.e. the sequence with no parenthesis: + * `intersect` is supposed to be a part of the last SelectQuery, i.e. the sequence with no parenthesis: * select 1 union all select 2 except select 1 intersect 2 except select 2 union distinct select 5; * is interpreted as: * select 1 union all select 2 except (select 1 intersect 2) except select 2 union distinct select 5; - * Whereas `except` is applied to all union part like: + * Whereas `except` is applied to all left union part like: * (((select 1 union all select 2) except (select 1 intersect 2)) except select 2) union distinct select 5; **/ @@ -28,7 +28,7 @@ void SelectIntersectExceptQueryMatcher::visit(ASTPtr & ast, Data & data) void SelectIntersectExceptQueryMatcher::visit(ASTSelectWithUnionQuery & ast, Data &) { - auto & union_modes = ast.list_of_modes; + const auto & union_modes = ast.list_of_modes; if (union_modes.empty()) return; diff --git a/src/Parsers/ASTSelectIntersectExceptQuery.cpp b/src/Parsers/ASTSelectIntersectExceptQuery.cpp index 9d7a717fa6c..3b9cb0a2c16 100644 --- a/src/Parsers/ASTSelectIntersectExceptQuery.cpp +++ b/src/Parsers/ASTSelectIntersectExceptQuery.cpp @@ -30,11 +30,10 @@ void ASTSelectIntersectExceptQuery::formatQueryImpl(const FormatSettings & setti { settings.ostr << settings.nl_or_ws << indent_str << (settings.hilite ? hilite_keyword : "") << (final_operator == Operator::INTERSECT ? "INTERSECT" : "EXCEPT") - << (settings.hilite ? hilite_none : ""); + << (settings.hilite ? hilite_none : "") + << settings.nl_or_ws; } - if (it != children.begin()) - settings.ostr << settings.nl_or_ws; (*it)->formatImpl(settings, state, frame); } } diff --git a/src/Parsers/ExpressionListParsers.cpp b/src/Parsers/ExpressionListParsers.cpp index 69d95422799..58f5e766905 100644 --- a/src/Parsers/ExpressionListParsers.cpp +++ b/src/Parsers/ExpressionListParsers.cpp @@ -277,79 +277,6 @@ static bool modifyAST(ASTPtr ast, SubqueryFunctionType type) return true; } -bool ParserComparisonExpression::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) -{ - bool first = true; - - auto current_depth = pos.depth; - while (true) - { - if (first) - { - ASTPtr elem; - if (!elem_parser.parse(pos, elem, expected)) - return false; - - node = elem; - first = false; - } - else - { - /// try to find any of the valid operators - const char ** it; - Expected stub; - for (it = overlapping_operators_to_skip; *it; ++it) - if (ParserKeyword{*it}.checkWithoutMoving(pos, stub)) - break; - - if (*it) - break; - - for (it = operators; *it; it += 2) - if (parseOperator(pos, *it, expected)) - break; - - if (!*it) - break; - - /// the function corresponding to the operator - auto function = std::make_shared(); - - /// function arguments - auto exp_list = std::make_shared(); - - ASTPtr elem; - SubqueryFunctionType subquery_function_type = SubqueryFunctionType::NONE; - if (ParserKeyword("ANY").ignore(pos, expected)) - subquery_function_type = SubqueryFunctionType::ANY; - else if (ParserKeyword("ALL").ignore(pos, expected)) - subquery_function_type = SubqueryFunctionType::ALL; - else if (!elem_parser.parse(pos, elem, expected)) - return false; - - if (subquery_function_type != SubqueryFunctionType::NONE && !ParserSubquery().parse(pos, elem, expected)) - return false; - - /// the first argument of the function is the previous element, the second is the next one - function->name = it[1]; - function->arguments = exp_list; - function->children.push_back(exp_list); - - exp_list->children.push_back(node); - exp_list->children.push_back(elem); - - if (subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function, subquery_function_type)) - return false; - - pos.increaseDepth(); - node = function; - } - } - - pos.depth = current_depth; - return true; -} - bool ParserLeftAssociativeBinaryOperatorList::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { bool first = true; @@ -393,7 +320,15 @@ bool ParserLeftAssociativeBinaryOperatorList::parseImpl(Pos & pos, ASTPtr & node auto exp_list = std::make_shared(); ASTPtr elem; - if (!(remaining_elem_parser ? remaining_elem_parser : first_elem_parser)->parse(pos, elem, expected)) + SubqueryFunctionType subquery_function_type = SubqueryFunctionType::NONE; + if (allow_any_all_operators && ParserKeyword("ANY").ignore(pos, expected)) + subquery_function_type = SubqueryFunctionType::ANY; + else if (allow_any_all_operators && ParserKeyword("ALL").ignore(pos, expected)) + subquery_function_type = SubqueryFunctionType::ALL; + else if (!(remaining_elem_parser ? remaining_elem_parser : first_elem_parser)->parse(pos, elem, expected)) + return false; + + if (subquery_function_type != SubqueryFunctionType::NONE && !ParserSubquery().parse(pos, elem, expected)) return false; /// the first argument of the function is the previous element, the second is the next one @@ -404,6 +339,9 @@ bool ParserLeftAssociativeBinaryOperatorList::parseImpl(Pos & pos, ASTPtr & node exp_list->children.push_back(node); exp_list->children.push_back(elem); + if (allow_any_all_operators && subquery_function_type != SubqueryFunctionType::NONE && !modifyAST(function, subquery_function_type)) + return false; + /** special exception for the access operator to the element of the array `x[y]`, which * contains the infix part '[' and the suffix ''] '(specified as' [') */ diff --git a/src/Parsers/ExpressionListParsers.h b/src/Parsers/ExpressionListParsers.h index e44cacb313f..17deec4e9e4 100644 --- a/src/Parsers/ExpressionListParsers.h +++ b/src/Parsers/ExpressionListParsers.h @@ -121,6 +121,8 @@ private: Operators_t overlapping_operators_to_skip = { (const char *[]){ nullptr } }; ParserPtr first_elem_parser; ParserPtr remaining_elem_parser; + /// =, !=, <, > ALL (subquery) / ANY (subquery) + bool allow_any_all_operators = false; public: /** `operators_` - allowed operators and their corresponding functions @@ -130,8 +132,10 @@ public: { } - ParserLeftAssociativeBinaryOperatorList(Operators_t operators_, Operators_t overlapping_operators_to_skip_, ParserPtr && first_elem_parser_) - : operators(operators_), overlapping_operators_to_skip(overlapping_operators_to_skip_), first_elem_parser(std::move(first_elem_parser_)) + ParserLeftAssociativeBinaryOperatorList(Operators_t operators_, + Operators_t overlapping_operators_to_skip_, ParserPtr && first_elem_parser_, bool allow_any_all_operators_ = false) + : operators(operators_), overlapping_operators_to_skip(overlapping_operators_to_skip_), + first_elem_parser(std::move(first_elem_parser_)), allow_any_all_operators(allow_any_all_operators_) { } @@ -341,12 +345,16 @@ class ParserComparisonExpression : public IParserBase private: static const char * operators[]; static const char * overlapping_operators_to_skip[]; - ParserBetweenExpression elem_parser; + ParserLeftAssociativeBinaryOperatorList operator_parser {operators, + overlapping_operators_to_skip, std::make_unique(), true}; protected: const char * getName() const override{ return "comparison expression"; } - bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override + { + return operator_parser.parse(pos, node, expected); + } }; /** Parser for nullity checking with IS (NOT) NULL. @@ -355,6 +363,7 @@ class ParserNullityChecking : public IParserBase { private: ParserComparisonExpression elem_parser; + protected: const char * getName() const override { return "nullity checking"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; diff --git a/src/Parsers/ParserQueryWithOutput.cpp b/src/Parsers/ParserQueryWithOutput.cpp index 4a73952674c..82f9f561187 100644 --- a/src/Parsers/ParserQueryWithOutput.cpp +++ b/src/Parsers/ParserQueryWithOutput.cpp @@ -24,6 +24,7 @@ #include #include + namespace DB { diff --git a/src/Parsers/ParserUnionQueryElement.cpp b/src/Parsers/ParserUnionQueryElement.cpp index d59a7be2278..efd022e6362 100644 --- a/src/Parsers/ParserUnionQueryElement.cpp +++ b/src/Parsers/ParserUnionQueryElement.cpp @@ -10,8 +10,7 @@ namespace DB bool ParserUnionQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) { - if (!ParserSubquery().parse(pos, node, expected) - && !ParserSelectQuery().parse(pos, node, expected)) + if (!ParserSubquery().parse(pos, node, expected) && !ParserSelectQuery().parse(pos, node, expected)) return false; if (const auto * ast_subquery = node->as()) diff --git a/src/Processors/QueryPlan/IntersectOrExceptStep.cpp b/src/Processors/QueryPlan/IntersectOrExceptStep.cpp index b75898b815b..d1bb1eb41e9 100644 --- a/src/Processors/QueryPlan/IntersectOrExceptStep.cpp +++ b/src/Processors/QueryPlan/IntersectOrExceptStep.cpp @@ -36,10 +36,7 @@ IntersectOrExceptStep::IntersectOrExceptStep( , max_threads(max_threads_) { input_streams = std::move(input_streams_); - if (input_streams.size() == 1) - output_stream = input_streams.front(); - else - output_stream = DataStream{.header = header}; + output_stream = DataStream{.header = header}; } QueryPipelinePtr IntersectOrExceptStep::updatePipeline(QueryPipelines pipelines, const BuildQueryPipelineSettings &) @@ -71,6 +68,7 @@ QueryPipelinePtr IntersectOrExceptStep::updatePipeline(QueryPipelines pipelines, }); } + /// For the case of union. cur_pipeline->addTransform(std::make_shared(header, cur_pipeline->getNumStreams(), 1)); } diff --git a/src/Processors/Transforms/IntersectOrExceptTransform.cpp b/src/Processors/Transforms/IntersectOrExceptTransform.cpp index abfd1a7f0ad..3e39123ae4b 100644 --- a/src/Processors/Transforms/IntersectOrExceptTransform.cpp +++ b/src/Processors/Transforms/IntersectOrExceptTransform.cpp @@ -4,6 +4,7 @@ namespace DB { +/// After visitor is applied, ASTSelectIntersectExcept always has two child nodes. IntersectOrExceptTransform::IntersectOrExceptTransform(const Block & header_, Operator operator_) : IProcessor(InputPorts(2, header_), {header_}) , current_operator(operator_) diff --git a/tests/queries/0_stateless/02004_intersect_except_operators.reference b/tests/queries/0_stateless/02004_intersect_except_operators.reference index 7f41faaf83a..03b881f690b 100644 --- a/tests/queries/0_stateless/02004_intersect_except_operators.reference +++ b/tests/queries/0_stateless/02004_intersect_except_operators.reference @@ -70,6 +70,10 @@ select * from (select 1 intersect select 1); 1 with (select number from numbers(10) intersect select 5) as a select a * 10; 50 +with (select 5 except select 1) as a select a except select 5; +with (select number from numbers(10) intersect select 5) as a select a intersect select 1; +with (select number from numbers(10) intersect select 5) as a select a except select 1; +5 select count() from (select number from numbers(10) except select 5); 9 select count() from (select number from numbers(1000000) intersect select number from numbers(200000, 600000)); @@ -102,6 +106,7 @@ select * from (select 1 union all select 2 union all select 3 union all select 4 select 1 intersect (select 1 except select 2); 1 select 1 union all select 2 except (select 2 except select 1 union all select 1) except select 4; +select 1 intersect select count() from (select 1 except select 2 intersect select 2 union all select 1); explain syntax select 1 intersect select 1; SELECT 1 INTERSECT diff --git a/tests/queries/0_stateless/02004_intersect_except_operators.sql b/tests/queries/0_stateless/02004_intersect_except_operators.sql index ef0e52da116..7f08cc0adf2 100644 --- a/tests/queries/0_stateless/02004_intersect_except_operators.sql +++ b/tests/queries/0_stateless/02004_intersect_except_operators.sql @@ -21,6 +21,9 @@ select number from numbers(100) intersect select number from numbers(20, 60) exc select * from (select 1 intersect select 1); with (select number from numbers(10) intersect select 5) as a select a * 10; +with (select 5 except select 1) as a select a except select 5; +with (select number from numbers(10) intersect select 5) as a select a intersect select 1; +with (select number from numbers(10) intersect select 5) as a select a except select 1; select count() from (select number from numbers(10) except select 5); select count() from (select number from numbers(1000000) intersect select number from numbers(200000, 600000)); select count() from (select number from numbers(100) intersect select number from numbers(20, 60) except select number from numbers(30, 20) except select number from numbers(60, 20)); @@ -35,6 +38,7 @@ select * from (select 1 union all select 2 union all select 3 union all select 4 select 1 intersect (select 1 except select 2); select 1 union all select 2 except (select 2 except select 1 union all select 1) except select 4; +select 1 intersect select count() from (select 1 except select 2 intersect select 2 union all select 1); explain syntax select 1 intersect select 1; explain syntax select 1 except select 1; diff --git a/tests/queries/0_stateless/02007_test_any_all_operators.reference b/tests/queries/0_stateless/02007_test_any_all_operators.reference index ebd7cd8f6ca..a232320d15c 100644 --- a/tests/queries/0_stateless/02007_test_any_all_operators.reference +++ b/tests/queries/0_stateless/02007_test_any_all_operators.reference @@ -29,3 +29,23 @@ select number as a from numbers(10) where a != any (select 5 from numbers(3, 3)) 7 8 9 +select 1 < any (select 1 from numbers(10)); +0 +select 1 <= any (select 1 from numbers(10)); +1 +select 1 < any (select number from numbers(10)); +1 +select 1 > any (select number from numbers(10)); +1 +select 1 >= any (select number from numbers(10)); +1 +select 11 > all (select number from numbers(10)); +1 +select 11 <= all (select number from numbers(11)); +0 +select 11 < all (select 11 from numbers(10)); +0 +select 11 > all (select 11 from numbers(10)); +0 +select 11 >= all (select 11 from numbers(10)); +1 diff --git a/tests/queries/0_stateless/02007_test_any_all_operators.sql b/tests/queries/0_stateless/02007_test_any_all_operators.sql index 525f7e1fabd..10d7325afca 100644 --- a/tests/queries/0_stateless/02007_test_any_all_operators.sql +++ b/tests/queries/0_stateless/02007_test_any_all_operators.sql @@ -8,10 +8,19 @@ select 1 != all (select number from numbers(10)); select 1 == all (select 1 from numbers(10)); select 1 == all (select number from numbers(10)); - select 1 != any (select 1 from numbers(10)); select 1 != any (select number from numbers(10)); select number as a from numbers(10) where a == any (select number from numbers(3, 3)); select number as a from numbers(10) where a != any (select 5 from numbers(3, 3)); +select 1 < any (select 1 from numbers(10)); +select 1 <= any (select 1 from numbers(10)); +select 1 < any (select number from numbers(10)); +select 1 > any (select number from numbers(10)); +select 1 >= any (select number from numbers(10)); +select 11 > all (select number from numbers(10)); +select 11 <= all (select number from numbers(11)); +select 11 < all (select 11 from numbers(10)); +select 11 > all (select 11 from numbers(10)); +select 11 >= all (select 11 from numbers(10));