diff --git a/src/Interpreters/JoinToSubqueryTransformVisitor.cpp b/src/Interpreters/JoinToSubqueryTransformVisitor.cpp index ca21a53b5b0..331c364c5fa 100644 --- a/src/Interpreters/JoinToSubqueryTransformVisitor.cpp +++ b/src/Interpreters/JoinToSubqueryTransformVisitor.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -129,6 +130,8 @@ private: /// Make aliases maps (alias -> column_name, column_name -> alias) struct ColumnAliasesMatcher { + using Visitor = ConstInDepthNodeVisitor; + struct Data { const std::vector tables; @@ -137,6 +140,7 @@ struct ColumnAliasesMatcher std::unordered_map aliases; /// alias -> long_name std::vector> compound_identifiers; std::set allowed_long_names; /// original names allowed as aliases '--t.x as t.x' (select expressions only). + bool inside_function = false; explicit Data(const std::vector && tables_) : tables(tables_) @@ -192,6 +196,10 @@ struct ColumnAliasesMatcher static bool needChildVisit(const ASTPtr & node, const ASTPtr &) { + /// Do not go into subqueries. Function visits children itself. + if (node->as() || + node->as()) + return false; return !node->as(); } @@ -199,11 +207,24 @@ struct ColumnAliasesMatcher { if (auto * t = ast->as()) visit(*t, ast, data); + else if (auto * f = ast->as()) + visit(*f, ast, data); - if (ast->as() || ast->as()) + /// Do not allow asterisks but ignore them inside functions. I.e. allow 'count(*)'. + if (!data.inside_function && (ast->as() || ast->as())) throw Exception("Multiple JOIN do not support asterisks for complex queries yet", ErrorCodes::NOT_IMPLEMENTED); } + static void visit(const ASTFunction &, const ASTPtr & ast, Data & data) + { + /// Grandchild case: Function -> (ExpressionList) -> Asterisk + data.inside_function = true; + Visitor visitor(data); + for (auto & child : ast->children) + visitor.visit(child); + data.inside_function = false; + } + static void visit(const ASTIdentifier & const_node, const ASTPtr &, Data & data) { ASTIdentifier & node = const_cast(const_node); /// we know it's not const @@ -348,7 +369,7 @@ bool needRewrite(ASTSelectQuery & select, std::vector; using RewriteVisitor = InDepthNodeVisitor; using ExtractAsterisksVisitor = ConstInDepthNodeVisitor; -using ColumnAliasesVisitor = ConstInDepthNodeVisitor; +using ColumnAliasesVisitor = ColumnAliasesMatcher::Visitor; using AppendSemanticMatcher = OneTypeMatcher; using AppendSemanticVisitor = InDepthNodeVisitor; diff --git a/tests/queries/0_stateless/01116_cross_count_asterisks.reference b/tests/queries/0_stateless/01116_cross_count_asterisks.reference new file mode 100644 index 00000000000..8347b144a35 --- /dev/null +++ b/tests/queries/0_stateless/01116_cross_count_asterisks.reference @@ -0,0 +1,4 @@ +2 +1 +2 +1 diff --git a/tests/queries/0_stateless/01116_cross_count_asterisks.sql b/tests/queries/0_stateless/01116_cross_count_asterisks.sql new file mode 100644 index 00000000000..1fb8b0b0e66 --- /dev/null +++ b/tests/queries/0_stateless/01116_cross_count_asterisks.sql @@ -0,0 +1,29 @@ +SET multiple_joins_rewriter_version = 2; + +SELECT count(*) +FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3 +WHERE (n1.number = n2.number) AND (n2.number = n3.number); + +SELECT count(*) c FROM ( + SELECT count(*), count(*) as c + FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3 + WHERE (n1.number = n2.number) AND (n2.number = n3.number) + AND (SELECT count(*) FROM numbers(1)) = 1 +) +WHERE (SELECT count(*) FROM numbers(2)) = 2 +HAVING c IN(SELECT count(*) c FROM numbers(1)); + +SET multiple_joins_rewriter_version = 1; + +SELECT count(*) +FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3 +WHERE (n1.number = n2.number) AND (n2.number = n3.number); + +SELECT count(*) c FROM ( + SELECT count(*), count(*) as c + FROM numbers(2) AS n1, numbers(3) AS n2, numbers(4) AS n3 + WHERE (n1.number = n2.number) AND (n2.number = n3.number) + AND (SELECT count(*) FROM numbers(1)) = 1 +) +WHERE (SELECT count(*) FROM numbers(2)) = 2 +HAVING c IN(SELECT count(*) c FROM numbers(1));