Review fixes

This commit is contained in:
kssenii 2021-10-22 12:22:16 +00:00
parent 5a51d1c29b
commit a1a2e276ae
4 changed files with 51 additions and 24 deletions

View File

@ -100,7 +100,7 @@ bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_q
{ {
auto columns = select_query->select()->children; auto columns = select_query->select()->children;
/// In case of expression/function (order by 1+2 and 2*x1, max(1, 2)) replace /// In case of expression/function (order by 1+2 and 2*x1, greatest(1, 2)) replace
/// positions only if all literals are numbers, otherwise it is not positional. /// positions only if all literals are numbers, otherwise it is not positional.
bool positional = true; bool positional = true;
@ -120,22 +120,47 @@ bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_q
} }
else if (const auto * function_ast = typeid_cast<const ASTFunction *>(column.get())) else if (const auto * function_ast = typeid_cast<const ASTFunction *>(column.get()))
{ {
auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function_ast->name); std::function<void(ASTPtr)> throw_if_aggregate_function = [&](ASTPtr node)
if (is_aggregate_function && expression != ASTSelectQuery::Expression::ORDER_BY)
{ {
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, if (const auto * function = typeid_cast<const ASTFunction *>(node.get()))
"Illegal value (aggregate function) for positional argument in {}", {
ASTSelectQuery::expressionToString(expression)); auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function->name);
} if (is_aggregate_function)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value (aggregate function) for positional argument in {}",
ASTSelectQuery::expressionToString(expression));
}
else
{
if (function->arguments)
{
for (const auto & arg : function->arguments->children)
throw_if_aggregate_function(arg);
}
}
}
};
if (expression == ASTSelectQuery::Expression::GROUP_BY)
throw_if_aggregate_function(column);
argument = column->clone(); argument = column->clone();
} }
else else
{ {
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal value for positional argument in {}", throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value for positional argument in {}",
ASTSelectQuery::expressionToString(expression)); ASTSelectQuery::expressionToString(expression));
} }
} }
/// Do not throw if out of bounds, see appendUnusedGroupByColumn. else if (pos > columns.size() || !pos)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Positional argument out of bounds: {} (exprected in range [1, {}]",
pos, columns.size());
}
/// Do not throw if pos < 0, becuase of TreeOptimizer::appendUnusedColumn()
} }
else else
positional = false; positional = false;

View File

@ -67,26 +67,17 @@ const std::unordered_set<String> possibly_injective_function_names
* Instead, leave `GROUP BY const`. * Instead, leave `GROUP BY const`.
* Next, see deleting the constants in the analyzeAggregation method. * Next, see deleting the constants in the analyzeAggregation method.
*/ */
void appendUnusedGroupByColumn(ASTSelectQuery * select_query, const NameSet & source_columns) void appendUnusedGroupByColumn(ASTSelectQuery * select_query)
{ {
/// You must insert a constant that is not the name of the column in the table. Such a case is rare, but it happens. /// You must insert a constant that is not the name of the column in the table. Such a case is rare, but it happens.
/// Also start unused_column integer from source_columns.size() + 1, because lower numbers ([1, source_columns.size()]) /// Also start unused_column integer must not intersect with ([1, source_columns.size()])
/// might be in positional GROUP BY. /// might be in positional GROUP BY.
UInt64 unused_column = source_columns.size() + 1;
String unused_column_name = toString(unused_column);
while (source_columns.count(unused_column_name))
{
++unused_column;
unused_column_name = toString(unused_column);
}
select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, std::make_shared<ASTExpressionList>()); select_query->setExpression(ASTSelectQuery::Expression::GROUP_BY, std::make_shared<ASTExpressionList>());
select_query->groupBy()->children.emplace_back(std::make_shared<ASTLiteral>(UInt64(unused_column))); select_query->groupBy()->children.emplace_back(std::make_shared<ASTLiteral>(Int64(-1)));
} }
/// Eliminates injective function calls and constant expressions from group by statement. /// Eliminates injective function calls and constant expressions from group by statement.
void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_columns, ContextPtr context) void optimizeGroupBy(ASTSelectQuery * select_query, ContextPtr context)
{ {
const FunctionFactory & function_factory = FunctionFactory::instance(); const FunctionFactory & function_factory = FunctionFactory::instance();
@ -191,7 +182,7 @@ void optimizeGroupBy(ASTSelectQuery * select_query, const NameSet & source_colum
} }
if (group_exprs.empty()) if (group_exprs.empty())
appendUnusedGroupByColumn(select_query, source_columns); appendUnusedGroupByColumn(select_query);
} }
struct GroupByKeysInfo struct GroupByKeysInfo
@ -710,7 +701,7 @@ void TreeOptimizer::apply(ASTPtr & query, TreeRewriterResult & result,
optimizeAggregationFunctions(query); optimizeAggregationFunctions(query);
/// GROUP BY injective function elimination. /// GROUP BY injective function elimination.
optimizeGroupBy(select_query, result.source_columns_set, context); optimizeGroupBy(select_query, context);
/// GROUP BY functions of other keys elimination. /// GROUP BY functions of other keys elimination.
if (settings.optimize_group_by_function_keys) if (settings.optimize_group_by_function_keys)

View File

@ -100,7 +100,16 @@ GROUP BY x2
ORDER BY ORDER BY
max(x1) ASC, max(x1) ASC,
x2 ASC x2 ASC
explain syntax select 1 + greatest(x1, 1), x2 from test group by 1, 2;
SELECT
1 + greatest(x1, 1),
x2
FROM test
GROUP BY
1 + greatest(x1, 1),
x2
select max(x1), x2 from test group by 1, 2; -- { serverError 43 } select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3; select x1 + x2, x3 from test group by x1 + x2, x3;
11 100 11 100
200 1 200 1

View File

@ -33,8 +33,10 @@ explain syntax select x3 + 1, x2, x1 from test order by 1;
explain syntax select x3, x3 - x2, x2, x1 from test order by 2; explain syntax select x3, x3 - x2, x2, x1 from test order by 2;
explain syntax select x3, if(x3 > 10, x3, plus(x1, x2)), x1 + x2 from test order by 2; explain syntax select x3, if(x3 > 10, x3, plus(x1, x2)), x1 + x2 from test order by 2;
explain syntax select max(x1), x2 from test group by 2 order by 1, 2; explain syntax select max(x1), x2 from test group by 2 order by 1, 2;
explain syntax select 1 + greatest(x1, 1), x2 from test group by 1, 2;
select max(x1), x2 from test group by 1, 2; -- { serverError 43 } select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3; select x1 + x2, x3 from test group by x1 + x2, x3;
select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2 select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2