better code

This commit is contained in:
Anton Popov 2020-06-18 20:05:26 +03:00
parent 23cd919681
commit bd28b7e1c2
3 changed files with 28 additions and 19 deletions

View File

@ -18,8 +18,8 @@ namespace ErrorCodes
namespace
{
constexpr const char * any = "any";
constexpr const char * anyLast = "anyLast";
constexpr auto * any = "any";
constexpr auto * anyLast = "anyLast";
}
ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
@ -30,12 +30,12 @@ ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
}
///recursive searching of identifiers
void changeAllIdentifiers(ASTPtr & ast, size_t ind, std::string& mode)
void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name)
{
const char * name = mode.c_str();
ASTPtr * exact_child = getExactChild(ast, ind);
if (!exact_child)
return;
if ((*exact_child)->as<ASTIdentifier>())
{
///put new any
@ -43,14 +43,15 @@ void changeAllIdentifiers(ASTPtr & ast, size_t ind, std::string& mode)
*exact_child = makeASTFunction(name);
(*exact_child)->as<ASTFunction>()->arguments->children.push_back(old_ast);
}
else if ((*exact_child)->as<ASTFunction>() &&
!AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
else if ((*exact_child)->as<ASTFunction>())
{
if (AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
for (size_t i = 0; i < (*exact_child)->as<ASTFunction>()->arguments->children.size(); i++)
changeAllIdentifiers(*exact_child, i, mode);
else if ((*exact_child)->as<ASTFunction>() &&
AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
changeAllIdentifiers(*exact_child, i, name);
}
}
@ -62,18 +63,20 @@ void AnyInputMatcher::visit(ASTPtr & current_ast, Data data)
return;
auto * function_node = current_ast->as<ASTFunction>();
if (function_node && (function_node->name == any || function_node->name == anyLast)
&& !function_node->arguments->children.empty() && function_node->arguments->children[0] &&
function_node->arguments->children[0]->as<ASTFunction>())
if (!function_node || function_node->arguments->children.empty())
return;
const auto & function_argument = function_node->arguments->children[0];
if ((function_node->name == any || function_node->name == anyLast)
&& function_argument && function_argument->as<ASTFunction>())
{
std::string mode = function_node->name;
auto name = function_node->name;
///cut any or anyLast
if (function_node->arguments->children[0]->as<ASTFunction>() &&
!function_node->arguments->children[0]->as<ASTFunction>()->arguments->children.empty())
if (!function_argument->as<ASTFunction>()->arguments->children.empty())
{
current_ast = (function_node->arguments->children[0])->clone();
current_ast = function_argument->clone();
for (size_t i = 0; i < current_ast->as<ASTFunction>()->arguments->children.size(); ++i)
changeAllIdentifiers(current_ast, i, mode);
changeAllIdentifiers(current_ast, i, name);
}
}
}

View File

@ -0,0 +1,2 @@
9
SELECT any(number) + (any(number) * 2)\nFROM numbers(3, 10)

View File

@ -0,0 +1,4 @@
SET optimize_any_input=1;
SET enable_debug_queries=1;
SELECT any(number + number * 2) FROM numbers(3, 10);
ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10);