mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 08:02:02 +00:00
Fix crash in min(multiply(1))
with optimize_arithmetic_operations_in_aggregate_functions (#11756)
This commit is contained in:
parent
fe24c715ca
commit
f44dbcd2ab
@ -41,29 +41,31 @@ bool onlyConstsInside(const ASTFunction * func_node)
|
|||||||
bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name)
|
bool inappropriateNameInside(const ASTFunction * func_node, const char * inter_func_name)
|
||||||
{
|
{
|
||||||
return (func_node->arguments->children[0]->as<ASTFunction>() &&
|
return (func_node->arguments->children[0]->as<ASTFunction>() &&
|
||||||
inter_func_name != func_node->arguments->children[0]->as<ASTFunction>()->name) ||
|
strcmp(inter_func_name, func_node->arguments->children[0]->as<ASTFunction>()->name.c_str()) != 0) ||
|
||||||
(func_node->arguments->children.size() == 2 &&
|
(func_node->arguments->children.size() == 2 &&
|
||||||
func_node->arguments->children[1]->as<ASTFunction>() &&
|
func_node->arguments->children[1]->as<ASTFunction>() &&
|
||||||
inter_func_name != func_node->arguments->children[1]->as<ASTFunction>()->name);
|
strcmp(inter_func_name, func_node->arguments->children[1]->as<ASTFunction>()->name.c_str()) != 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool isInappropriate(const ASTPtr & node, const char * inter_func_name)
|
bool isInappropriate(const ASTPtr & node, const char * inter_func_name)
|
||||||
{
|
{
|
||||||
return !node->as<ASTFunction>() || inter_func_name != node->as<ASTFunction>()->name;
|
return !node->as<ASTFunction>() || (strcmp(inter_func_name, node->as<ASTFunction>()->name.c_str()) != 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
ASTFunction * getInternalFunction(const ASTFunction * f_n)
|
ASTFunction * getInternalFunction(const ASTFunction * f_n)
|
||||||
{
|
{
|
||||||
const auto * function_args = f_n->arguments->as<ASTExpressionList>();
|
const auto * function_args = f_n->arguments->as<ASTExpressionList>();
|
||||||
if (!function_args || function_args->children.size() != 1)
|
if (!function_args || function_args->children.size() != 1)
|
||||||
throw Exception("Wrong number of arguments for function" + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)",
|
throw Exception("Wrong number of arguments for function " + f_n->name + "(" + toString(function_args->children.size()) + " instead of 1)",
|
||||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
|
||||||
return f_n->arguments->children[0]->as<ASTFunction>();
|
return f_n->arguments->children[0]->as<ASTFunction>();
|
||||||
}
|
}
|
||||||
|
|
||||||
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name)
|
ASTFunction * treeFiller(ASTFunction * old_tree, const ASTs & nodes_array, size_t size, const char * name, bool flag)
|
||||||
{
|
{
|
||||||
|
if (flag)
|
||||||
|
--size;
|
||||||
for (size_t i = 0; i < size; ++i)
|
for (size_t i = 0; i < size; ++i)
|
||||||
{
|
{
|
||||||
old_tree->arguments->children = {};
|
old_tree->arguments->children = {};
|
||||||
@ -94,19 +96,23 @@ std::pair<ASTs, ASTs> tryGetConst(const char * name, const ASTs & arguments)
|
|||||||
not_const.push_back(arg);
|
not_const.push_back(arg);
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((name == plus || name == mul) && const_num.size() + not_const.size() != 2)
|
if ((strcmp(name, plus) == 0 || strcmp(name, mul) == 0) && const_num.size() + not_const.size() != 2)
|
||||||
{
|
|
||||||
throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)",
|
throw Exception("Wrong number of arguments for function 'plus' or 'multiply' (" + toString(const_num.size() + not_const.size()) + " instead of 2)",
|
||||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
}
|
|
||||||
|
|
||||||
return {const_num, not_const};
|
return {const_num, not_const};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name)
|
std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char * inter_func_name)
|
||||||
{
|
{
|
||||||
if (!func_node->arguments)
|
if (func_node->arguments->children.empty())
|
||||||
return {};
|
{
|
||||||
|
if (strcmp(func_node->name.c_str(), plus) == 0 || strcmp(func_node->name.c_str(), mul) == 0)
|
||||||
|
throw Exception("Wrong number of arguments for function" + func_node->name + "(0 instead of 2)",
|
||||||
|
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||||
|
else
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
if (onlyConstsInside(func_node))
|
if (onlyConstsInside(func_node))
|
||||||
return tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
return tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
||||||
@ -139,7 +145,7 @@ std::pair<ASTs, ASTs> findAllConsts(const ASTFunction * func_node, const char *
|
|||||||
std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
std::pair<ASTs, ASTs> fl = tryGetConst(func_node->name.c_str(), func_node->arguments->children);
|
||||||
ASTs first_lvl_consts = fl.first;
|
ASTs first_lvl_consts = fl.first;
|
||||||
ASTs first_lvl_not_consts = fl.second;
|
ASTs first_lvl_not_consts = fl.second;
|
||||||
if (!first_lvl_not_consts[0]->as<ASTFunction>())
|
if (first_lvl_not_consts.empty() || !first_lvl_not_consts[0]->as<ASTFunction>())
|
||||||
return {first_lvl_consts, first_lvl_not_consts};
|
return {first_lvl_consts, first_lvl_not_consts};
|
||||||
|
|
||||||
std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name);
|
std::pair<ASTs, ASTs> ans = findAllConsts(first_lvl_not_consts[0]->as<ASTFunction>(), inter_func_name);
|
||||||
@ -176,17 +182,21 @@ void buildTree(ASTFunction * cur_node, const char * func_name, const char * intr
|
|||||||
ASTs cons_val = tree_comp.first;
|
ASTs cons_val = tree_comp.first;
|
||||||
ASTs non_cons = tree_comp.second;
|
ASTs non_cons = tree_comp.second;
|
||||||
|
|
||||||
|
bool not_const_empty = non_cons.empty();
|
||||||
|
|
||||||
cur_node->name = intro_func;
|
cur_node->name = intro_func;
|
||||||
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func);
|
cur_node = treeFiller(cur_node, cons_val, cons_val.size(), intro_func, not_const_empty);
|
||||||
cur_node->name = func_name;
|
cur_node->name = func_name;
|
||||||
|
|
||||||
if (non_cons.size() == 1)
|
if (non_cons.empty())
|
||||||
|
cur_node->arguments->children.push_back(cons_val[cons_val.size() - 1]);
|
||||||
|
else if (non_cons.size() == 1)
|
||||||
cur_node->arguments->children.push_back(non_cons[0]);
|
cur_node->arguments->children.push_back(non_cons[0]);
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
cur_node->arguments->children.push_back(makeASTFunction(intro_func));
|
cur_node->arguments->children.push_back(makeASTFunction(intro_func));
|
||||||
cur_node = cur_node->arguments->children[0]->as<ASTFunction>();
|
cur_node = cur_node->arguments->children[0]->as<ASTFunction>();
|
||||||
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func);
|
cur_node = treeFiller(cur_node, non_cons, non_cons.size() - 2, intro_func, not_const_empty);
|
||||||
cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]};
|
cur_node->arguments->children = {non_cons[non_cons.size() - 2], non_cons[non_cons.size() - 1]};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,15 @@
|
|||||||
|
SET optimize_arithmetic_operations_in_aggregate_functions = 1;
|
||||||
|
|
||||||
|
SELECT max(multiply(1)); -- { serverError 42 }
|
||||||
|
SELECT min(multiply(2));-- { serverError 42 }
|
||||||
|
SELECT sum(multiply(3)); -- { serverError 42 }
|
||||||
|
|
||||||
|
SELECT max(plus(1)); -- { serverError 42 }
|
||||||
|
SELECT min(plus(2)); -- { serverError 42 }
|
||||||
|
SELECT sum(plus(3)); -- { serverError 42 }
|
||||||
|
|
||||||
|
SELECT max(multiply()); -- { serverError 42 }
|
||||||
|
SELECT min(multiply(1, 2 ,3)); -- { serverError 42 }
|
||||||
|
SELECT sum(plus() + multiply()); -- { serverError 42 }
|
||||||
|
|
||||||
|
SELECT sum(plus(multiply(42, 3), multiply(42))); -- { serverError 42 }
|
Loading…
Reference in New Issue
Block a user