better code

This commit is contained in:
Anton Popov 2020-06-18 20:05:26 +03:00
parent 23cd919681
commit bd28b7e1c2
3 changed files with 28 additions and 19 deletions

View File

@ -18,8 +18,8 @@ namespace ErrorCodes
namespace namespace
{ {
constexpr const char * any = "any"; constexpr auto * any = "any";
constexpr const char * anyLast = "anyLast"; constexpr auto * anyLast = "anyLast";
} }
ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind) ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
@ -30,12 +30,12 @@ ASTPtr * getExactChild(const ASTPtr & ast, const size_t ind)
} }
///recursive searching of identifiers ///recursive searching of identifiers
void changeAllIdentifiers(ASTPtr & ast, size_t ind, std::string& mode) void changeAllIdentifiers(ASTPtr & ast, size_t ind, const std::string & name)
{ {
const char * name = mode.c_str();
ASTPtr * exact_child = getExactChild(ast, ind); ASTPtr * exact_child = getExactChild(ast, ind);
if (!exact_child) if (!exact_child)
return; return;
if ((*exact_child)->as<ASTIdentifier>()) if ((*exact_child)->as<ASTIdentifier>())
{ {
///put new any ///put new any
@ -43,14 +43,15 @@ void changeAllIdentifiers(ASTPtr & ast, size_t ind, std::string& mode)
*exact_child = makeASTFunction(name); *exact_child = makeASTFunction(name);
(*exact_child)->as<ASTFunction>()->arguments->children.push_back(old_ast); (*exact_child)->as<ASTFunction>()->arguments->children.push_back(old_ast);
} }
else if ((*exact_child)->as<ASTFunction>() && else if ((*exact_child)->as<ASTFunction>())
!AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name)) {
if (AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
for (size_t i = 0; i < (*exact_child)->as<ASTFunction>()->arguments->children.size(); i++) for (size_t i = 0; i < (*exact_child)->as<ASTFunction>()->arguments->children.size(); i++)
changeAllIdentifiers(*exact_child, i, mode); changeAllIdentifiers(*exact_child, i, name);
else if ((*exact_child)->as<ASTFunction>() && }
AggregateFunctionFactory::instance().isAggregateFunctionName((*exact_child)->as<ASTFunction>()->name))
throw Exception("Aggregate function " + (*exact_child)->as<ASTFunction>()->name +
" is found inside aggregate function " + name + " in query", ErrorCodes::ILLEGAL_AGGREGATION);
} }
@ -62,18 +63,20 @@ void AnyInputMatcher::visit(ASTPtr & current_ast, Data data)
return; return;
auto * function_node = current_ast->as<ASTFunction>(); auto * function_node = current_ast->as<ASTFunction>();
if (function_node && (function_node->name == any || function_node->name == anyLast) if (!function_node || function_node->arguments->children.empty())
&& !function_node->arguments->children.empty() && function_node->arguments->children[0] && return;
function_node->arguments->children[0]->as<ASTFunction>())
const auto & function_argument = function_node->arguments->children[0];
if ((function_node->name == any || function_node->name == anyLast)
&& function_argument && function_argument->as<ASTFunction>())
{ {
std::string mode = function_node->name; auto name = function_node->name;
///cut any or anyLast ///cut any or anyLast
if (function_node->arguments->children[0]->as<ASTFunction>() && if (!function_argument->as<ASTFunction>()->arguments->children.empty())
!function_node->arguments->children[0]->as<ASTFunction>()->arguments->children.empty())
{ {
current_ast = (function_node->arguments->children[0])->clone(); current_ast = function_argument->clone();
for (size_t i = 0; i < current_ast->as<ASTFunction>()->arguments->children.size(); ++i) for (size_t i = 0; i < current_ast->as<ASTFunction>()->arguments->children.size(); ++i)
changeAllIdentifiers(current_ast, i, mode); changeAllIdentifiers(current_ast, i, name);
} }
} }
} }

View File

@ -0,0 +1,2 @@
9
SELECT any(number) + (any(number) * 2)\nFROM numbers(3, 10)

View File

@ -0,0 +1,4 @@
SET optimize_any_input=1;
SET enable_debug_queries=1;
SELECT any(number + number * 2) FROM numbers(3, 10);
ANALYZE SELECT any(number + number * 2) FROM numbers(3, 10);