Merge pull request #27901 from amosbird/applylambda

APPLY with lambda.
This commit is contained in:
Kseniia Sumarokova 2021-08-21 00:26:03 +03:00 committed by GitHub
commit f3f44ec8fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 18 deletions

View File

@ -7,6 +7,7 @@
#include <Common/quoteString.h>
#include <IO/Operators.h>
#include <re2/re2.h>
#include <stack>
namespace DB
@ -40,10 +41,18 @@ void ASTColumnsApplyTransformer::formatImpl(const FormatSettings & settings, For
if (!column_name_prefix.empty())
settings.ostr << "(";
settings.ostr << func_name;
if (parameters)
parameters->formatImpl(settings, state, frame);
if (lambda)
{
lambda->formatImpl(settings, state, frame);
}
else
{
settings.ostr << func_name;
if (parameters)
parameters->formatImpl(settings, state, frame);
}
if (!column_name_prefix.empty())
settings.ostr << ", '" << column_name_prefix << "')";
@ -64,9 +73,33 @@ void ASTColumnsApplyTransformer::transform(ASTs & nodes) const
else
name = column->getColumnName();
}
auto function = makeASTFunction(func_name, column);
function->parameters = parameters;
column = function;
if (lambda)
{
auto body = lambda->as<const ASTFunction &>().arguments->children.at(1)->clone();
std::stack<ASTPtr> stack;
stack.push(body);
while (!stack.empty())
{
auto ast = stack.top();
stack.pop();
for (auto & child : ast->children)
{
if (auto arg_name = tryGetIdentifierName(child); arg_name && arg_name == lambda_arg)
{
child = column->clone();
continue;
}
stack.push(child);
}
}
column = body;
}
else
{
auto function = makeASTFunction(func_name, column);
function->parameters = parameters;
column = function;
}
if (!column_name_prefix.empty())
column->setAlias(column_name_prefix + name);
}

View File

@ -25,13 +25,22 @@ public:
auto res = std::make_shared<ASTColumnsApplyTransformer>(*this);
if (parameters)
res->parameters = parameters->clone();
if (lambda)
res->lambda = lambda->clone();
return res;
}
void transform(ASTs & nodes) const override;
// Case 1 APPLY (quantile(0.9))
String func_name;
String column_name_prefix;
ASTPtr parameters;
// Case 2 APPLY (x -> quantile(0.9)(x))
ASTPtr lambda;
String lambda_arg;
String column_name_prefix;
protected:
void formatImpl(const FormatSettings & settings, FormatState &, FormatStateStacked) const override;
};

View File

@ -1827,20 +1827,47 @@ bool ParserColumnsTransformers::parseImpl(Pos & pos, ASTPtr & node, Expected & e
with_open_round_bracket = true;
}
ASTPtr lambda;
String lambda_arg;
ASTPtr func_name;
if (!ParserIdentifier().parse(pos, func_name, expected))
return false;
ASTPtr expr_list_args;
if (pos->type == TokenType::OpeningRoundBracket)
auto opos = pos;
if (ParserLambdaExpression().parse(pos, lambda, expected))
{
++pos;
if (!ParserExpressionList(false).parse(pos, expr_list_args, expected))
if (const auto * func = lambda->as<ASTFunction>(); func && func->name == "lambda")
{
const auto * lambda_args_tuple = func->arguments->children.at(0)->as<ASTFunction>();
const ASTs & lambda_arg_asts = lambda_args_tuple->arguments->children;
if (lambda_arg_asts.size() != 1)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "APPLY column transformer can only accept lambda with one argument");
if (auto opt_arg_name = tryGetIdentifierName(lambda_arg_asts[0]); opt_arg_name)
lambda_arg = *opt_arg_name;
else
throw Exception(ErrorCodes::BAD_ARGUMENTS, "lambda argument declarations must be identifiers");
}
else
{
lambda = nullptr;
pos = opos;
}
}
if (!lambda)
{
if (!ParserIdentifier().parse(pos, func_name, expected))
return false;
if (pos->type != TokenType::ClosingRoundBracket)
return false;
++pos;
if (pos->type == TokenType::OpeningRoundBracket)
{
++pos;
if (!ParserExpressionList(false).parse(pos, expr_list_args, expected))
return false;
if (pos->type != TokenType::ClosingRoundBracket)
return false;
++pos;
}
}
String column_name_prefix;
@ -1864,8 +1891,16 @@ bool ParserColumnsTransformers::parseImpl(Pos & pos, ASTPtr & node, Expected & e
}
auto res = std::make_shared<ASTColumnsApplyTransformer>();
res->func_name = getIdentifierName(func_name);
res->parameters = expr_list_args;
if (lambda)
{
res->lambda = lambda;
res->lambda_arg = lambda_arg;
}
else
{
res->func_name = getIdentifierName(func_name);
res->parameters = expr_list_args;
}
res->column_name_prefix = column_name_prefix;
node = std::move(res);
return true;

View File

@ -1 +1,4 @@
100 10 324 120.00 B 8.00 B 23.00 B
0
SELECT argMax(number, number)
FROM numbers(1)

View File

@ -5,3 +5,6 @@ INSERT INTO columns_transformers VALUES (100, 10, 324, 120, 8, 23);
SELECT * EXCEPT 'bytes', COLUMNS('bytes') APPLY formatReadableSize FROM columns_transformers;
DROP TABLE IF EXISTS columns_transformers;
SELECT * APPLY x->argMax(x, number) FROM numbers(1);
EXPLAIN SYNTAX SELECT * APPLY x->argMax(x, number) FROM numbers(1);