diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index 828f332af1d..6b019609dd9 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -40,6 +40,7 @@ #include #include #include +#include #include #include @@ -325,6 +326,39 @@ struct ExistsExpressionData using ExistsExpressionVisitor = InDepthNodeVisitor, false>; +struct ReplacePositionalArgumentsData +{ + using TypeToVisit = ASTSelectQuery; + ContextPtr context; + + void visit(ASTSelectQuery & select_query, ASTPtr &) const + { + if (context->getSettingsRef().enable_positional_arguments) + { + if (select_query.groupBy()) + { + for (auto & expr : select_query.groupBy()->children) + replaceForPositionalArguments(expr, &select_query, ASTSelectQuery::Expression::GROUP_BY); + } + if (select_query.orderBy()) + { + for (auto & expr : select_query.orderBy()->children) + { + auto & elem = assert_cast(*expr).children.at(0); + replaceForPositionalArguments(elem, &select_query, ASTSelectQuery::Expression::ORDER_BY); + } + } + if (select_query.limitBy()) + { + for (auto & expr : select_query.limitBy()->children) + replaceForPositionalArguments(expr, &select_query, ASTSelectQuery::Expression::LIMIT_BY); + } + } + } +}; + +using ReplacePositionalArgumentsVisitor = InDepthNodeVisitor, false>; + /// Translate qualified names such as db.table.column, table.column, table_alias.column to names' normal form. /// Expand asterisks and qualified asterisks with column names. /// There would be columns in normal form & column aliases after translation. Column & column alias would be normalized in QueryNormalizer. @@ -1316,25 +1350,6 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( all_source_columns_set.insert(name); } - if (getContext()->getSettingsRef().enable_positional_arguments) - { - if (select_query->groupBy()) - { - for (auto & expr : select_query->groupBy()->children) - replaceForPositionalArguments(expr, select_query, ASTSelectQuery::Expression::GROUP_BY); - } - if (select_query->orderBy()) - { - for (auto & expr : select_query->orderBy()->children) - replaceForPositionalArguments(expr, select_query, ASTSelectQuery::Expression::ORDER_BY); - } - if (select_query->limitBy()) - { - for (auto & expr : select_query->limitBy()->children) - replaceForPositionalArguments(expr, select_query, ASTSelectQuery::Expression::LIMIT_BY); - } - } - normalize(query, result.aliases, all_source_columns_set, select_options.ignore_alias, settings, /* allow_self_aliases = */ true, getContext()); // expand GROUP BY ALL @@ -1493,6 +1508,9 @@ void TreeRewriter::normalize( ExistsExpressionVisitor::Data exists; ExistsExpressionVisitor(exists).visit(query); + ReplacePositionalArgumentsVisitor::Data data_replace_positional_arguments{context_}; + ReplacePositionalArgumentsVisitor(data_replace_positional_arguments).visit(query); + if (settings.transform_null_in) { CustomizeInVisitor::Data data_null_in{"nullIn"}; diff --git a/tests/queries/0_stateless/02006_test_positional_arguments.reference b/tests/queries/0_stateless/02006_test_positional_arguments.reference index 56817961b30..e2bbea2149d 100644 --- a/tests/queries/0_stateless/02006_test_positional_arguments.reference +++ b/tests/queries/0_stateless/02006_test_positional_arguments.reference @@ -119,9 +119,25 @@ select b from (select 5 as a, 'Hello' as b order by 1); Hello drop table if exists tp2; create table tp2(first_col String, second_col Int32) engine = MergeTree() order by tuple(); +insert into tp2 select 'bbb', 1; +insert into tp2 select 'aaa', 2; select count(*) from (select first_col, count(second_col) from tp2 group by 1); -0 +2 select total from (select first_col, count(second_col) as total from tp2 group by 1); +1 +1 +select first_col from (select first_col, second_col as total from tp2 order by 1 desc); +bbb +aaa +select first_col from (select first_col, second_col as total from tp2 order by 2 desc); +aaa +bbb +select max from (select max(first_col) as max, second_col as total from tp2 group by 2) order by 1; +aaa +bbb +with res as (select first_col from (select first_col, second_col as total from tp2 order by 2 desc) limit 1) +select * from res; +aaa drop table if exists test; create table test ( diff --git a/tests/queries/0_stateless/02006_test_positional_arguments.sql b/tests/queries/0_stateless/02006_test_positional_arguments.sql index 8829a204ab6..67f4fe24c55 100644 --- a/tests/queries/0_stateless/02006_test_positional_arguments.sql +++ b/tests/queries/0_stateless/02006_test_positional_arguments.sql @@ -51,8 +51,15 @@ select b from (select 5 as a, 'Hello' as b order by 1); drop table if exists tp2; create table tp2(first_col String, second_col Int32) engine = MergeTree() order by tuple(); +insert into tp2 select 'bbb', 1; +insert into tp2 select 'aaa', 2; select count(*) from (select first_col, count(second_col) from tp2 group by 1); select total from (select first_col, count(second_col) as total from tp2 group by 1); +select first_col from (select first_col, second_col as total from tp2 order by 1 desc); +select first_col from (select first_col, second_col as total from tp2 order by 2 desc); +select max from (select max(first_col) as max, second_col as total from tp2 group by 2) order by 1; +with res as (select first_col from (select first_col, second_col as total from tp2 order by 2 desc) limit 1) +select * from res; drop table if exists test; create table test