Better positional args

This commit is contained in:
kssenii 2021-10-20 10:45:19 +03:00
parent 5324cc8359
commit 6c990400d1
3 changed files with 211 additions and 56 deletions

View File

@ -96,6 +96,71 @@ bool allowEarlyConstantFolding(const ActionsDAG & actions, const Settings & sett
return true;
}
/// Substitutes positional arguments inside `argument` (in place) with the corresponding
/// elements of the SELECT list of `select_query`.
///
/// - A UInt64 literal N with 1 <= N <= (number of SELECT elements) is replaced by a clone of the
///   N-th SELECT element, provided that element is an identifier or a function.
///   An aggregate function is accepted only when `expression` is ORDER_BY.
/// - A UInt64 literal outside of that range is left untouched (and is still reported as positional).
/// - A function node is processed recursively, argument by argument.
///
/// Returns true if every leaf visited was a UInt64 literal, i.e. the whole expression can be
/// treated as positional; returns false as soon as some leaf is a non-UInt64 literal or any other
/// kind of node. Note that `argument` may have been partially rewritten even when false is
/// returned, so the caller decides whether to keep the result.
///
/// Throws ILLEGAL_TYPE_OF_ARGUMENT if a position refers to a SELECT element that is neither an
/// identifier nor a function, or to an aggregate function outside of ORDER BY.
bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_query, ASTSelectQuery::Expression expression)
{
    /// Bind by reference: the SELECT list is only read here, and this function is called
    /// recursively, so copying the whole vector of children on each call is wasted work.
    const auto & columns = select_query->select()->children;

    /// Case when the element is a position.
    if (const auto * ast_literal = typeid_cast<const ASTLiteral *>(argument.get()))
    {
        /// Only unsigned integer literals can be positions.
        if (ast_literal->value.getType() != Field::Types::UInt64)
            return false;

        const auto pos = ast_literal->value.get<UInt64>();

        /// Do not throw if out of bounds, see appendUnusedGroupByColumn.
        if (pos == 0 || pos > columns.size())
            return true;

        /// Positions are 1-based.
        const auto & column = columns[pos - 1];

        if (typeid_cast<const ASTIdentifier *>(column.get()))
        {
            argument = column->clone();
        }
        else if (const auto * function_ast = typeid_cast<const ASTFunction *>(column.get()))
        {
            auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function_ast->name);
            if (is_aggregate_function && expression != ASTSelectQuery::Expression::ORDER_BY)
            {
                throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                    "Illegal value (aggregate function) for positional argument in {}",
                    ASTSelectQuery::expressionToString(expression));
            }
            argument = column->clone();
        }
        else
        {
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal value for positional argument in {}",
                ASTSelectQuery::expressionToString(expression));
        }

        return true;
    }

    /// In case of expression/function (order by 1+2 and 2*x1, max(1, 2)) replace
    /// positions only if all literals are numbers, otherwise it is not positional.
    if (const auto * ast_function = typeid_cast<const ASTFunction *>(argument.get()))
    {
        bool positional = true;
        if (ast_function->arguments)
        {
            /// Deliberately `&=` and not `&&`: every argument is visited even after one of them
            /// turned out to be non-positional.
            for (auto & arg : ast_function->arguments->children)
                positional &= checkPositionalArguments(arg, select_query, expression);
        }
        return positional;
    }

    /// Any other node (identifier, subquery, ...) is not a positional argument.
    return false;
}
/// Replaces `argument` with its positional-substituted form, but only when the whole
/// expression is positional according to checkPositionalArguments.
/// The substitution is attempted on a clone, so `argument` is left exactly as it was
/// whenever the check returns false.
void replaceForPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_query, ASTSelectQuery::Expression expression)
{
    ASTPtr candidate = argument->clone();

    const bool fully_positional = checkPositionalArguments(candidate, select_query, expression);
    if (!fully_positional)
        return;

    argument = candidate;
}
}
bool sanitizeBlock(Block & block, bool throw_if_cannot_create_column)
@ -164,37 +229,6 @@ ExpressionAnalyzer::ExpressionAnalyzer(
analyzeAggregation(temp_actions);
}
/// Resolves a single positional argument against the SELECT list of `select_query`.
///
/// Returns a new ASTIdentifier named after the referenced SELECT element when `argument` is a
/// UInt64 literal N with 1 <= N <= (number of SELECT elements) and that element is an identifier.
/// Returns nullptr when `argument` is not such a literal or the position is out of range.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when the referenced SELECT element is not an identifier.
///
/// Expressions are not considered: only a bare literal is treated as a position, even if an
/// expression consists solely of constants.
static ASTPtr checkPositionalArgument(ASTPtr argument, const ASTSelectQuery * select_query, ASTSelectQuery::Expression expression)
{
    auto select_list = select_query->select()->children;

    const auto * literal = typeid_cast<const ASTLiteral *>(argument.get());
    if (!literal)
        return nullptr;

    if (literal->value.getType() != Field::Types::UInt64)
        return nullptr;

    const auto position = literal->value.get<UInt64>();

    /// Do not throw if out of bounds, see appendUnusedGroupByColumn.
    if (position == 0 || position > select_list.size())
        return nullptr;

    /// Positions are 1-based.
    const auto & target = select_list[position - 1];

    const auto * identifier = typeid_cast<const ASTIdentifier *>(target.get());
    if (!identifier)
    {
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal value for positional argument in {}",
            ASTSelectQuery::expressionToString(expression));
    }

    return std::make_shared<ASTIdentifier>(identifier->name());
}
NamesAndTypesList ExpressionAnalyzer::getColumnsAfterArrayJoin(ActionsDAGPtr & actions, const NamesAndTypesList & src_columns)
{
const auto * select_query = query->as<ASTSelectQuery>();
@ -282,16 +316,14 @@ void ExpressionAnalyzer::analyzeAggregation(ActionsDAGPtr & temp_actions)
for (ssize_t i = 0; i < ssize_t(group_asts.size()); ++i)
{
ssize_t size = group_asts.size();
getRootActionsNoMakeSet(group_asts[i], true, temp_actions, false);
if (getContext()->getSettingsRef().enable_positional_arguments)
{
auto new_argument = checkPositionalArgument(group_asts[i], select_query, ASTSelectQuery::Expression::GROUP_BY);
if (new_argument)
group_asts[i] = new_argument;
}
replaceForPositionalArguments(group_asts[i], select_query, ASTSelectQuery::Expression::GROUP_BY);
getRootActionsNoMakeSet(group_asts[i], true, temp_actions, false);
const auto & column_name = group_asts[i]->getColumnName();
const auto * node = temp_actions->tryFindInIndex(column_name);
if (!node)
throw Exception("Unknown identifier (in GROUP BY): " + column_name, ErrorCodes::UNKNOWN_IDENTIFIER);
@ -1231,6 +1263,16 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChai
ExpressionActionsChain::Step & step = chain.lastStep(aggregated_columns);
for (auto & child : select_query->orderBy()->children)
{
auto * ast = child->as<ASTOrderByElement>();
if (!ast || ast->children.empty())
throw Exception("Bad ORDER BY expression AST", ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE);
if (getContext()->getSettingsRef().enable_positional_arguments)
replaceForPositionalArguments(ast->children.at(0), select_query, ASTSelectQuery::Expression::ORDER_BY);
}
getRootActions(select_query->orderBy(), only_types, step.actions());
bool with_fill = false;
@ -1239,16 +1281,6 @@ ActionsDAGPtr SelectQueryExpressionAnalyzer::appendOrderBy(ExpressionActionsChai
for (auto & child : select_query->orderBy()->children)
{
auto * ast = child->as<ASTOrderByElement>();
if (!ast || ast->children.empty())
throw Exception("Bad ORDER BY expression AST", ErrorCodes::UNKNOWN_TYPE_OF_AST_NODE);
if (getContext()->getSettingsRef().enable_positional_arguments)
{
auto new_argument = checkPositionalArgument(ast->children.at(0), select_query, ASTSelectQuery::Expression::ORDER_BY);
if (new_argument)
ast->children[0] = new_argument;
}
ASTPtr order_expression = ast->children.at(0);
step.addRequiredOutput(order_expression->getColumnName());
@ -1302,11 +1334,7 @@ bool SelectQueryExpressionAnalyzer::appendLimitBy(ExpressionActionsChain & chain
for (auto & child : children)
{
if (getContext()->getSettingsRef().enable_positional_arguments)
{
auto new_argument = checkPositionalArgument(child, select_query, ASTSelectQuery::Expression::LIMIT_BY);
if (new_argument)
child = new_argument;
}
replaceForPositionalArguments(child, select_query, ASTSelectQuery::Expression::LIMIT_BY);
auto child_name = child->getColumnName();
if (!aggregated_names.count(child_name))

View File

@ -46,5 +46,108 @@ select x1, x2, x3 from test order by 3 limit 1 by 1;
100 100 1
10 1 10
1 10 100
select max(x3), max(x2), max(x1) from test group by 1; -- { serverError 43 }
select max(x1) from test order by 1; -- { serverError 43 }
select x3, x2, x1 from test order by x3 + x3;
1 100 100
1 100 100
10 1 10
100 10 1
200 10 1
200 1 10
select x3, x2, x1 from test order by 1 + 1;
1 100 100
1 100 100
10 1 10
100 10 1
200 10 1
200 1 10
select x3, x2, x1 from test order by (x3 + x3) * x1;
1 100 100
100 10 1
10 1 10
1 100 100
200 10 1
200 1 10
select x3, x2, x1 from test order by (1 + 1) * 3;
1 100 100
100 10 1
10 1 10
1 100 100
200 10 1
200 1 10
select x2, x1 from test group by x2 + x1; -- { serverError 215 }
select x2, x1 from test group by 1 + 2; -- { serverError 215 }
select x3, x2, x1 from test order by 1;
1 100 100
1 100 100
10 1 10
100 10 1
200 10 1
200 1 10
select x3 + 1, x2, x1 from test order by 1;
2 100 100
2 100 100
11 1 10
101 10 1
201 10 1
201 1 10
select x3, x3 - x2, x2, x1 from test order by 2;
1 -99 100 100
1 -99 100 100
10 9 1 10
100 90 10 1
200 190 10 1
200 199 1 10
select x3, if(0, x3, plus(x1, x2)), x1 + x2 from test order by 2;
200 11 11
200 11 11
100 11 11
10 11 11
1 200 200
1 200 200
select max(x1), x2 from test group by 2 order by 1, 2;
1 10
10 1
100 100
select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3;
11 100
200 1
11 200
11 10
select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2
1 100 100
1 100 100
10 1 10
100 10 1
200 1 10
200 10 1
select x3, x2, x1 from test order by 1 + 1, 2, 1;
1 100 100
1 100 100
10 1 10
100 10 1
200 1 10
200 10 1
explain syntax select x1, x3 from test group by 1 + 2, 1, 2;
SELECT
x1,
x3
FROM test
GROUP BY
x1 + x3,
x1,
x3
explain syntax select x1 + x3, x3 from test group by 1, 2;
SELECT
x1 + x3,
x3
FROM test
GROUP BY
x1 + x3,
x3
create table test2(x1 Int, x2 Int, x3 Int) engine=Memory;
insert into test2 values (1, 10, 100), (10, 1, 10), (100, 100, 1);
select x1, x1 * 2, max(x2), max(x3) from test2 group by 2, 1, x1 order by 1, 2, 4 desc, 3 asc;
1 2 10 100
10 20 1 10
100 200 100 1

View File

@ -1,6 +1,8 @@
set enable_positional_arguments = 1;
drop table if exists test;
drop table if exists test2;
create table test(x1 Int, x2 Int, x3 Int) engine=Memory();
insert into test values (1, 10, 100), (10, 1, 10), (100, 100, 1);
@ -20,7 +22,29 @@ select x1, x2, x3 from test order by 3 limit 1 by 3;
select x1, x2, x3 from test order by x3 limit 1 by x1;
select x1, x2, x3 from test order by 3 limit 1 by 1;
select max(x3), max(x2), max(x1) from test group by 1; -- { serverError 43 }
select max(x1) from test order by 1; -- { serverError 43 }
select x3, x2, x1 from test order by x3 + x3;
select x3, x2, x1 from test order by 1 + 1;
select x3, x2, x1 from test order by (x3 + x3) * x1;
select x3, x2, x1 from test order by (1 + 1) * 3;
select x2, x1 from test group by x2 + x1; -- { serverError 215 }
select x2, x1 from test group by 1 + 2; -- { serverError 215 }
select x3, x2, x1 from test order by 1;
select x3 + 1, x2, x1 from test order by 1;
select x3, x3 - x2, x2, x1 from test order by 2;
select x3, if(0, x3, plus(x1, x2)), x1 + x2 from test order by 2;
select max(x1), x2 from test group by 2 order by 1, 2;
select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3;
select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2
select x3, x2, x1 from test order by 1 + 1, 2, 1;
explain syntax select x1, x3 from test group by 1 + 2, 1, 2;
explain syntax select x1 + x3, x3 from test group by 1, 2;
create table test2(x1 Int, x2 Int, x3 Int) engine=Memory;
insert into test2 values (1, 10, 100), (10, 1, 10), (100, 100, 1);
select x1, x1 * 2, max(x2), max(x3) from test2 group by 2, 1, x1 order by 1, 2, 4 desc, 3 asc;