Fix crash in min(multiply(1)) with optimize_arithmetic_operations_in_aggregate_functions (#11756)

This commit is contained in:
Ruslan 2020-06-18 22:25:28 +03:00 committed by GitHub
parent fe24c715ca
commit f44dbcd2ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 14 deletions

View File

@ -41,15 +41,15 @@ bool onlyConstsInside(const ASTFunction * func_node)
bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name) bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name)
{ {
return (func_node->arguments->children[0]->as<ASTFunction>() && return (func_node->arguments->children[0]->as<ASTFunction>() &&
inter_func_name != func_node->arguments->children[0]->as<ASTFunction>()->name) || strcmp(inter_func_name, func_node->arguments->children[0]->as<ASTFunction>()->name.c_str()) != 0) ||
(func_node->arguments->children.size() == 2 && (func_node->arguments->children.size() == 2 &&
func_node->arguments->children[1]->as<ASTFunction>() && func_node->arguments->children[1]->as<ASTFunction>() &&
inter_func_name != func_node->arguments->children[1]->as<ASTFunction>()->name); strcmp(inter_func_name, func_node->arguments->children[1]->as<ASTFunction>()->name.c_str()) != 0);
} }
bool isInappropriate(const ASTPtr & node, const char * inter_func_name) bool isInappropriate(const ASTPtr & node, const char * inter_func_name)
{ {
return !node->as<ASTFunction>() || inter_func_name != node->as<ASTFunction>()->name; return !node->as<ASTFunction>() || (strcmp(inter_func_name, node->as<ASTFunction>()->name.c_str()) != 0);
} }
ASTFunction * getInternalFunction(const ASTFunction * f_n) ASTFunction * getInternalFunction(const ASTFunction * f_n)
@ -62,8 +62,10 @@ ASTFunction * getInternalFunction(const ASTFunction * f_n)
return f_n->arguments->children[0]->as<ASTFunction>(); return f_n->arguments->children[0]->as<ASTFunction>();
} }
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name) ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name, bool flag)
{ {
if (flag)
--size;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
{ {
old_tree->arguments->children = {}; old_tree->arguments->children = {};
@ -94,19 +96,23 @@ std::pair<ASTs, ASTs> tryGetConst(const char * name, const ASTs & arguments)
not_const.push_back(arg); not_const.push_back(arg);
} }
if ((name == plus || name == mul) && const_num.size() + not_const.size() != 2) if ((strcmp(name, plus) == 0 || strcmp(name, mul) == 0) && const_num.size() + not_const.size() != 2)
{
throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)", throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
}
return {const_num, not_const}; return {const_num, not_const};
} }
std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name) std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name)
{ {
if (!func_node->arguments) if (func_node->arguments->children.empty())
{
if (strcmp(func_node->name.c_str(), plus) == 0 || strcmp(func_node->name.c_str(), mul) == 0)
throw Exception("Wrong number of arguments for function" + func_node->name + "(0 instead of 2)",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
else
return {}; return {};
}
if (onlyConstsInside(func_node)) if (onlyConstsInside(func_node))
return tryGetConst(func_node->name.c_str(), func_node->arguments->children); return tryGetConst(func_node->name.c_str(), func_node->arguments->children);
@ -139,7 +145,7 @@ std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char *
std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children); std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children);
ASTs first_lvl_consts = fl.first; ASTs first_lvl_consts = fl.first;
ASTs first_lvl_not_consts = fl.second; ASTs first_lvl_not_consts = fl.second;
if (!first_lvl_not_consts[0]->as<ASTFunction>()) if (first_lvl_not_consts.empty() || !first_lvl_not_consts[0]->as<ASTFunction>())
return {first_lvl_consts, first_lvl_not_consts}; return {first_lvl_consts, first_lvl_not_consts};
std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name); std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name);
@ -176,17 +182,21 @@ void buildTree(ASTFunction * cur_node, const char * func_name, const char * intr
ASTs cons_val = tree_comp.first; ASTs cons_val = tree_comp.first;
ASTs non_cons = tree_comp.second; ASTs non_cons = tree_comp.second;
bool not_const_empty = non_cons.empty();
cur_node->name = intro_func; cur_node->name = intro_func;
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func); cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func, not_const_empty);
cur_node->name = func_name; cur_node->name = func_name;
if (non_cons.size() == 1) if (non_cons.empty())
cur_node->arguments->children.push_back(cons_val[cons_val.size() - 1]);
else if (non_cons.size() == 1)
cur_node->arguments->children.push_back(non_cons[0]); cur_node->arguments->children.push_back(non_cons[0]);
else else
{ {
cur_node->arguments->children.push_back(makeASTFunction(intro_func)); cur_node->arguments->children.push_back(makeASTFunction(intro_func));
cur_node = cur_node->arguments->children[0]->as<ASTFunction>(); cur_node = cur_node->arguments->children[0]->as<ASTFunction>();
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func); cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func, not_const_empty);
cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]}; cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]};
} }
} }

View File

@ -0,0 +1,15 @@
SET optimize_arithmetic_operations_in_aggregate_functions = 1;
SELECT max(multiply(1)); -- { serverError 42 }
SELECT min(multiply(2));-- { serverError 42 }
SELECT sum(multiply(3)); -- { serverError 42 }
SELECT max(plus(1)); -- { serverError 42 }
SELECT min(plus(2)); -- { serverError 42 }
SELECT sum(plus(3)); -- { serverError 42 }
SELECT max(multiply()); -- { serverError 42 }
SELECT min(multiply(1, 2 ,3)); -- { serverError 42 }
SELECT sum(plus() + multiply()); -- { serverError 42 }
SELECT sum(plus(multiply(42, 3), multiply(42))); -- { serverError 42 }