Merge pull request #32961 from kssenii/fix-positional-args

Fix for positional args
This commit is contained in:
Kseniia Sumarokova 2021-12-22 09:41:55 +03:00 committed by GitHub
commit 5ae2f0028f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 110 deletions

View File

@ -116,82 +116,62 @@ bool checkPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_q
}
}
/// In case of expression/function (order by 1+2 and 2*x1, greatest(1, 2)) replace
/// positions only if all literals are numbers, otherwise it is not positional.
bool positional = true;
const auto * ast_literal = typeid_cast<const ASTLiteral *>(argument.get());
if (!ast_literal)
return false;
/// Case when GROUP BY element is position.
if (const auto * ast_literal = typeid_cast<const ASTLiteral *>(argument.get()))
auto which = ast_literal->value.getType();
if (which != Field::Types::UInt64)
return false;
auto pos = ast_literal->value.get<UInt64>();
if (!pos || pos > columns.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Positional argument out of bounds: {} (exprected in range [1, {}]",
pos, columns.size());
const auto & column = columns[--pos];
if (typeid_cast<const ASTIdentifier *>(column.get()))
{
auto which = ast_literal->value.getType();
if (which == Field::Types::UInt64)
argument = column->clone();
}
else if (typeid_cast<const ASTFunction *>(column.get()))
{
std::function<void(ASTPtr)> throw_if_aggregate_function = [&](ASTPtr node)
{
auto pos = ast_literal->value.get<UInt64>();
if (pos > 0 && pos <= columns.size())
if (const auto * function = typeid_cast<const ASTFunction *>(node.get()))
{
const auto & column = columns[--pos];
if (typeid_cast<const ASTIdentifier *>(column.get()))
auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function->name);
if (is_aggregate_function)
{
argument = column->clone();
}
else if (typeid_cast<const ASTFunction *>(column.get()))
{
std::function<void(ASTPtr)> throw_if_aggregate_function = [&](ASTPtr node)
{
if (const auto * function = typeid_cast<const ASTFunction *>(node.get()))
{
auto is_aggregate_function = AggregateFunctionFactory::instance().isAggregateFunctionName(function->name);
if (is_aggregate_function)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value (aggregate function) for positional argument in {}",
ASTSelectQuery::expressionToString(expression));
}
else
{
if (function->arguments)
{
for (const auto & arg : function->arguments->children)
throw_if_aggregate_function(arg);
}
}
}
};
if (expression == ASTSelectQuery::Expression::GROUP_BY)
throw_if_aggregate_function(column);
argument = column->clone();
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value (aggregate function) for positional argument in {}",
ASTSelectQuery::expressionToString(expression));
}
else
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value for positional argument in {}",
ASTSelectQuery::expressionToString(expression));
if (function->arguments)
{
for (const auto & arg : function->arguments->children)
throw_if_aggregate_function(arg);
}
}
}
else if (pos > columns.size() || !pos)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Positional argument out of bounds: {} (exprected in range [1, {}]",
pos, columns.size());
}
}
else
positional = false;
}
else if (const auto * ast_function = typeid_cast<const ASTFunction *>(argument.get()))
{
if (ast_function->arguments)
{
for (auto & arg : ast_function->arguments->children)
positional &= checkPositionalArguments(arg, select_query, expression);
}
};
if (expression == ASTSelectQuery::Expression::GROUP_BY)
throw_if_aggregate_function(column);
argument = column->clone();
}
else
positional = false;
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal value for positional argument in {}",
ASTSelectQuery::expressionToString(expression));
}
return positional;
return true;
}
void replaceForPositionalArguments(ASTPtr & argument, const ASTSelectQuery * select_query, ASTSelectQuery::Expression expression)

View File

@ -46,22 +46,6 @@ select x1, x2, x3 from test order by 3 limit 1 by 1;
100 100 1
10 1 10
1 10 100
explain syntax select x3, x2, x1 from test order by 1 + 1;
SELECT
x3,
x2,
x1
FROM test
ORDER BY x3 + x3 ASC
explain syntax select x3, x2, x1 from test order by (1 + 1) * 3;
SELECT
x3,
x2,
x1
FROM test
ORDER BY (x3 + x3) * x1 ASC
select x2, x1 from test group by x2 + x1; -- { serverError 215 }
select x2, x1 from test group by 1 + 2; -- { serverError 215 }
explain syntax select x3, x2, x1 from test order by 1;
SELECT
x3,
@ -110,27 +94,6 @@ GROUP BY
x2
select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3;
11 100
200 1
11 200
11 10
select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2
1 100 100
1 100 100
10 1 10
100 10 1
200 1 10
200 10 1
explain syntax select x1, x3 from test group by 1 + 2, 1, 2;
SELECT
x1,
x3
FROM test
GROUP BY
x1 + x3,
x1,
x3
explain syntax select x1 + x3, x3 from test group by 1, 2;
SELECT
x1 + x3,
@ -152,3 +115,5 @@ SELECT 1 + 1 AS a
GROUP BY a
select substr('aaaaaaaaaaaaaa', 8) as a group by a;
aaaaaaa
select substr('aaaaaaaaaaaaaa', 8) as a group by substr('aaaaaaaaaaaaaa', 8);
aaaaaaa

View File

@ -22,12 +22,6 @@ select x1, x2, x3 from test order by 3 limit 1 by 3;
select x1, x2, x3 from test order by x3 limit 1 by x1;
select x1, x2, x3 from test order by 3 limit 1 by 1;
explain syntax select x3, x2, x1 from test order by 1 + 1;
explain syntax select x3, x2, x1 from test order by (1 + 1) * 3;
select x2, x1 from test group by x2 + x1; -- { serverError 215 }
select x2, x1 from test group by 1 + 2; -- { serverError 215 }
explain syntax select x3, x2, x1 from test order by 1;
explain syntax select x3 + 1, x2, x1 from test order by 1;
explain syntax select x3, x3 - x2, x2, x1 from test order by 2;
@ -37,11 +31,7 @@ explain syntax select 1 + greatest(x1, 1), x2 from test group by 1, 2;
select max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select 1 + max(x1), x2 from test group by 1, 2; -- { serverError 43 }
select x1 + x2, x3 from test group by x1 + x2, x3;
select x3, x2, x1 from test order by x3 * 2, x2, x1; -- check x3 * 2 does not become x3 * x2
explain syntax select x1, x3 from test group by 1 + 2, 1, 2;
explain syntax select x1 + x3, x3 from test group by 1, 2;
create table test2(x1 Int, x2 Int, x3 Int) engine=Memory;
@ -52,3 +42,5 @@ select a, b, c, d, e, f from (select 44 a, 88 b, 13 c, 14 d, 15 e, 16 f) t grou
explain syntax select plus(1, 1) as a group by a;
select substr('aaaaaaaaaaaaaa', 8) as a group by a;
select substr('aaaaaaaaaaaaaa', 8) as a group by substr('aaaaaaaaaaaaaa', 8);