Fix crash in min(multiply(1)) with optimize_arithmetic_operations_in_aggregate_functions (#11756)

This commit is contained in:
Ruslan 2020-06-18 22:25:28 +03:00 committed by GitHub
parent fe24c715ca
commit f44dbcd2ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 14 deletions

View File

@ -41,29 +41,31 @@ bool onlyConstsInside(const ASTFunction * func_node)
bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name)
{
return (func_node->arguments->children[0]->as<ASTFunction>() &&
inter_func_name != func_node->arguments->children[0]->as<ASTFunction>()->name) ||
strcmp(inter_func_name, func_node->arguments->children[0]->as<ASTFunction>()->name.c_str()) != 0) ||
(func_node->arguments->children.size() == 2 &&
func_node->arguments->children[1]->as<ASTFunction>() &&
inter_func_name != func_node->arguments->children[1]->as<ASTFunction>()->name);
strcmp(inter_func_name, func_node->arguments->children[1]->as<ASTFunction>()->name.c_str()) != 0);
}
bool isInappropriate(const ASTPtr & node, const char * inter_func_name)
{
return !node->as<ASTFunction>() || inter_func_name != node->as<ASTFunction>()->name;
return !node->as<ASTFunction>() || (strcmp(inter_func_name, node->as<ASTFunction>()->name.c_str()) != 0);
}
ASTFunction * getInternalFunction(const ASTFunction * f_n)
{
const auto * function_args = f_n->arguments->as<ASTExpressionList>();
if (!function_args || function_args->children.size() != 1)
throw Exception("Wrong number of arguments for function" + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)",
throw Exception("Wrong number of arguments for function " + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return f_n->arguments->children[0]->as<ASTFunction>();
}
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name)
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name, bool flag)
{
if (flag)
--size;
for (size_t i = 0; i < size; ++i)
{
old_tree->arguments->children = {};
@ -94,19 +96,23 @@ std::pair<ASTs, ASTs> tryGetConst(const char * name, const ASTs & arguments)
not_const.push_back(arg);
}
if ((name == plus || name == mul) && const_num.size() + not_const.size() != 2)
{
if ((strcmp(name, plus) == 0 || strcmp(name, mul) == 0) && const_num.size() + not_const.size() != 2)
throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
return {const_num, not_const};
}
std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name)
{
if (!func_node->arguments)
return {};
if (func_node->arguments->children.empty())
{
if (strcmp(func_node->name.c_str(), plus) == 0 || strcmp(func_node->name.c_str(), mul) == 0)
throw Exception("Wrong number of arguments for function" + func_node->name + "(0 instead of 2)",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else
return {};
}
if (onlyConstsInside(func_node))
return tryGetConst(func_node->name.c_str(), func_node->arguments->children);
@ -139,7 +145,7 @@ std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char *
std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children);
ASTs first_lvl_consts = fl.first;
ASTs first_lvl_not_consts = fl.second;
if (!first_lvl_not_consts[0]->as<ASTFunction>())
if (first_lvl_not_consts.empty() || !first_lvl_not_consts[0]->as<ASTFunction>())
return {first_lvl_consts, first_lvl_not_consts};
std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name);
@ -176,17 +182,21 @@ void buildTree(ASTFunction * cur_node, const char * func_name, const char * intr
ASTs cons_val = tree_comp.first;
ASTs non_cons = tree_comp.second;
bool not_const_empty = non_cons.empty();
cur_node->name = intro_func;
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func);
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func, not_const_empty);
cur_node->name = func_name;
if (non_cons.size() == 1)
if (non_cons.empty())
cur_node->arguments->children.push_back(cons_val[cons_val.size() - 1]);
else if (non_cons.size() == 1)
cur_node->arguments->children.push_back(non_cons[0]);
else
{
cur_node->arguments->children.push_back(makeASTFunction(intro_func));
cur_node = cur_node->arguments->children[0]->as<ASTFunction>();
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func);
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func, not_const_empty);
cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]};
}
}

View File

@ -0,0 +1,15 @@
SET optimize_arithmetic_operations_in_aggregate_functions = 1;
SELECT max(multiply(1)); -- { serverError 42 }
SELECT min(multiply(2));-- { serverError 42 }
SELECT sum(multiply(3)); -- { serverError 42 }
SELECT max(plus(1)); -- { serverError 42 }
SELECT min(plus(2)); -- { serverError 42 }
SELECT sum(plus(3)); -- { serverError 42 }
SELECT max(multiply()); -- { serverError 42 }
SELECT min(multiply(1, 2 ,3)); -- { serverError 42 }
SELECT sum(plus() + multiply()); -- { serverError 42 }
SELECT sum(plus(multiply(42, 3), multiply(42))); -- { serverError 42 }