From a1a2e276ae6b0f157500525ce7bd4a7a9d89c742 Mon Sep 17 00:00:00 2001 From: kssenii Date: Fri, 22 Oct 2021 12:22:16 +0000 Subject: [PATCH] Review fixes --- src/Interpreters/ExpressionAnalyzer.cpp | 43 +++++++++++++++---- src/Interpreters/TreeOptimizer.cpp | 21 +++------ .../02006_test_positional_arguments.reference | 9 ++++ .../02006_test_positional_arguments.sql | 2 + 4 files changed, 51 insertions(+), 24 deletions(-) diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 68951f5ed68..c9ec8e15c64 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -100,7 +100,7 @@ bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_q { auto columns = select_query->select()->children; - /// In case of expression/function (order by 1+2 and 2*x1, max(1, 2)) replace + /// In case of expression/function (order by 1+2 and 2*x1, greatest(1, 2)) replace /// positions only if all literals are numbers, otherwise it is not positional. bool positional = true; @@ -120,22 +120,47 @@ bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_q } else if (const auto * function_ast = typeid_cast(column.get())) { - auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function_ast->name); - if (is_aggregate_function && expression != ASTSelectQuery::Expression::ORDER_BY) + std::function throw_if_aggregate_function = [&](ASTPtr node) { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal value (aggregate function) for positional argument in {}", - ASTSelectQuery::expressionToString(expression)); - } + if (const auto * function = typeid_cast(node.get())) + { + auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function->name); + if (is_aggregate_function) + { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal value (aggregate function) for positional argument in {}", + ASTSelectQuery::expressionToString(expression)); + } + else + { + if (function->arguments) + { + for (const auto & arg : function->arguments->children) + throw_if_aggregate_function(arg); + } + } + } + }; + + if (expression == ASTSelectQuery::Expression::GROUP_BY) + throw_if_aggregate_function(column); + argument = column->clone(); } else { - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal value for positional argument in {}", + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal value for positional argument in {}", ASTSelectQuery::expressionToString(expression)); } } - /// Do not throw if out of bounds, see appendUnusedGroupByColumn. + else if (pos > columns.size() || !pos) + { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Positional argument out of bounds: {} (exprected in range [1, {}]", + pos, columns.size()); + } + /// Do not throw if pos < 0, becuase of TreeOptimizer::appendUnusedColumn() } else positional = false; diff --git a/src/Interpreters/TreeOptimizer.cpp b/src/Interpreters/TreeOptimizer.cpp index 8fb72f74c65..e811299b327 100644 --- a/src/Interpreters/TreeOptimizer.cpp +++ b/src/Interpreters/TreeOptimizer.cpp @@ -67,26 +67,17 @@ const std::unordered_set possibly_injective_function_names * Instead, leave `GROUP BY const`. * Next, see deleting the constants in the analyzeAggregation method. */ -void appendUnusedGroupByColumn(ASTSelectQuery * select_query, const NameSet & source_columns) +void appendUnusedGroupByColumn(ASTSelectQuery * select_query) { /// You must insert a constant that is not the name of the column in the table. Such a case is rare, but it happens. - /// Also start unused_column integer from source_columns.size() + 1, because lower numbers ([1, source_columns.size()]) + /// Also start unused_column integer must not intersect with ([1, source_columns.size()]) /// might be in positional GROUP BY. - UInt64 unused_column = source_columns.size() + 1; - String unused_column_name = toString(unused_column); - - while (source_columns.count(unused_column_name)) - { - ++unused_column; - unused_column_name = toString(unused_column); - } - select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, std::make_shared()); - select_query->groupBy()->children.emplace_back(std::make_shared(UInt64(unused_column))); + select_query->groupBy()->children.emplace_back(std::make_shared(Int64(-1))); } /// Eliminates injective function calls and constant expressions from group by statement. -void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_columns, ContextPtr context) +void optimizeGroupBy(ASTSelectQuery * select_query, ContextPtr context) { const FunctionFactory & function_factory = FunctionFactory::instance(); @@ -191,7 +182,7 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum } if (group_exprs.empty()) - appendUnusedGroupByColumn(select_query, source_columns); + appendUnusedGroupByColumn(select_query); } struct GroupByKeysInfo @@ -710,7 +701,7 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result, optimizeAggregationFunctions(query); /// GROUP BY injective function elimination. - optimizeGroupBy(select_query, result.source_columns_set, context); + optimizeGroupBy(select_query, context); /// GROUP BY functions of other keys elimination. if (settings.optimize_group_by_function_keys) diff --git a/tests/queries/0_stateless/02006_test_positional_arguments.reference b/tests/queries/0_stateless/02006_test_positional_arguments.reference index ff985d765f6..27936137a1b 100644 --- a/tests/queries/0_stateless/02006_test_positional_arguments.reference +++ b/tests/queries/0_stateless/02006_test_positional_arguments.reference @@ -100,7 +100,16 @@ GROUP BY x2 ORDER BY max(x1) ASC, x2 ASC +explain syntax select 1 + greatest(x1, 1), x2 from test group by 1, 2; +SELECT + 1 + greatest(x1, 1), + x2 +FROM test +GROUP BY + 1 + greatest(x1, 1), + x2 select max(x1), x2 from test group by 1, 2; -- { serverError 43 } +select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 } select x1 + x2, x3 from test group by x1 + x2, x3; 11 100 200 1 diff --git a/tests/queries/0_stateless/02006_test_positional_arguments.sql b/tests/queries/0_stateless/02006_test_positional_arguments.sql index f1c3676405e..4b6affc290a 100644 --- a/tests/queries/0_stateless/02006_test_positional_arguments.sql +++ b/tests/queries/0_stateless/02006_test_positional_arguments.sql @@ -33,8 +33,10 @@ explain syntax select x3 + 1, x2, x1 from test order by 1; explain syntax select x3, x3 - x2, x2, x1 from test order by 2; explain syntax select x3, if(x3 > 10, x3, plus(x1, x2)), x1 + x2 from test order by 2; explain syntax select max(x1), x2 from test group by 2 order by 1, 2; +explain syntax select 1 + greatest(x1, 1), x2 from test group by 1, 2; select max(x1), x2 from test group by 1, 2; -- { serverError 43 } +select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 } select x1 + x2, x3 from test group by x1 + x2, x3; select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2