diff --git a/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp b/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp index 503d92cbcb1..8aa16b0e971 100644 --- a/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp +++ b/dbms/src/Parsers/ParserSelectWithUnionQuery.cpp @@ -1,7 +1,9 @@ -#include #include #include #include +#include +#include +#include namespace DB @@ -11,17 +13,34 @@ bool ParserSelectWithUnionQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & { ASTPtr list_node; - ParserList parser(std::make_unique(), std::make_unique("UNION ALL"), false); + ParserList parser(std::make_unique(), std::make_unique("UNION ALL"), false); if (!parser.parse(pos, list_node, expected)) return false; - auto res = std::make_shared(); + auto select_with_union_query = std::make_shared(); - res->list_of_selects = std::move(list_node); - res->children.push_back(res->list_of_selects); + node = select_with_union_query; + select_with_union_query->list_of_selects = std::make_shared(); + select_with_union_query->children.push_back(select_with_union_query->list_of_selects); + + // flatten inner union query + for (auto & child : list_node->children) + getSelectsFromUnionListNode(child, select_with_union_query->list_of_selects->children); - node = res; return true; } +void ParserSelectWithUnionQuery::getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects) +{ + if (ASTSelectWithUnionQuery * inner_union = typeid_cast(ast_select.get())) + { + for (auto & child : inner_union->list_of_selects->children) + getSelectsFromUnionListNode(child, selects); + + return; + } + + selects.push_back(std::move(ast_select)); +} + } diff --git a/dbms/src/Parsers/ParserSelectWithUnionQuery.h b/dbms/src/Parsers/ParserSelectWithUnionQuery.h index 33857fe33cb..07217a2ec3f 100644 --- a/dbms/src/Parsers/ParserSelectWithUnionQuery.h +++ b/dbms/src/Parsers/ParserSelectWithUnionQuery.h @@ -1,7 +1,6 @@ #pragma once -#include - +#include namespace DB { @@ -12,6 +11,9 @@ class ParserSelectWithUnionQuery : public IParserBase protected: const char * getName() const override { return "SELECT query, possibly with UNION"; } bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; + +private: + void getSelectsFromUnionListNode(ASTPtr & ast_select, ASTs & selects); }; } diff --git a/dbms/src/Parsers/ParserUnionQueryElement.cpp b/dbms/src/Parsers/ParserUnionQueryElement.cpp new file mode 100644 index 00000000000..b4c8408312d --- /dev/null +++ b/dbms/src/Parsers/ParserUnionQueryElement.cpp @@ -0,0 +1,22 @@ +#include +#include +#include +#include +#include + + +namespace DB +{ + +bool ParserUnionQueryElement::parseImpl(Pos & pos, ASTPtr & node, Expected & expected) +{ + if (!ParserSubquery().parse(pos, node, expected) && !ParserSelectQuery().parse(pos, node, expected)) + return false; + + if (auto * ast_sub_query = typeid_cast(node.get())) + node = ast_sub_query->children.at(0); + + return true; +} + +} diff --git a/dbms/src/Parsers/ParserUnionQueryElement.h b/dbms/src/Parsers/ParserUnionQueryElement.h new file mode 100644 index 00000000000..6b63c62c85b --- /dev/null +++ b/dbms/src/Parsers/ParserUnionQueryElement.h @@ -0,0 +1,17 @@ +#pragma once + +#include + + +namespace DB +{ + + +class ParserUnionQueryElement : public IParserBase +{ +protected: + const char * getName() const override { return "SELECT query, subquery, possibly with UNION"; } + bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override; +}; + +} diff --git a/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.reference b/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.reference new file mode 100644 index 00000000000..64eef762b5d --- /dev/null +++ b/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.reference @@ -0,0 +1,12 @@ +0 +0 +0 +1 +1 +2 +0 +0 +0 +1 +1 +2 diff --git a/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.sql b/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.sql new file mode 100644 index 00000000000..5db394ec6e9 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00612_union_query_with_subquery.sql @@ -0,0 +1,2 @@ +SELECT * FROM ((SELECT * FROM system.numbers LIMIT 1) UNION ALL SELECT * FROM system.numbers LIMIT 2 UNION ALL (SELECT * FROM system.numbers LIMIT 3)) ORDER BY number; +SELECT * FROM (SELECT * FROM system.numbers LIMIT 1 UNION ALL (SELECT * FROM system.numbers LIMIT 2 UNION ALL (SELECT * FROM system.numbers LIMIT 3))) ORDER BY number; \ No newline at end of file