Rewrite visitor and add performance test

This commit is contained in:
hexiaoting 2021-03-12 15:29:38 +08:00
parent 055073931a
commit ab2aaa7fe7
2 changed files with 66 additions and 70 deletions

View File

@ -184,18 +184,18 @@ struct CustomizeAggregateFunctionsMoveSuffixData
struct FuseFunctions
{
std::vector<const ASTFunction *> sums {};
std::vector<const ASTFunction *> counts {};
std::vector<const ASTFunction *> avgs {};
std::vector<ASTFunction *> sums {};
std::vector<ASTFunction *> counts {};
std::vector<ASTFunction *> avgs {};
/// Records the given function node in the list that matches its name:
/// "sum" -> sums, "count" -> counts, "avg" -> avgs.
/// A node with any other name is not recorded.
/// The lists hold the address of the node that was passed in, not a copy.
void addFuncNode(const ASTFunction * func)
void addFuncNode(ASTFunction & func)
{
if (func->name == "sum")
sums.push_back(func);
else if (func->name == "count")
counts.push_back(func);
else if (func->name == "avg")
avgs.push_back(func);
if (func.name == "sum")
sums.push_back(&func);
else if (func.name == "count")
counts.push_back(&func);
else if (func.name == "avg")
avgs.push_back(&func);
}
bool canBeFused() const
@ -214,9 +214,9 @@ struct CustomizeFuseAggregateFunctionsData
{
using TypeToVisit = ASTFunction;
std::unordered_map<String, DB::FuseFunctions> fuse_map {};
std::unordered_map<String, DB::FuseFunctions> fuse_map;
void visit(ASTFunction & func, ASTPtr &) const
void visit(ASTFunction & func, ASTPtr &)
{
if (func.name == "sum" || func.name == "avg" || func.name == "count")
{
@ -226,32 +226,16 @@ struct CustomizeFuseAggregateFunctionsData
ASTIdentifier * ident = func.arguments->children.at(0)->as<ASTIdentifier>();
if (!ident)
return;
auto it = fuse_map.find(ident->name());
if (it != fuse_map.end() && it->second.canBeFused())
if (it != fuse_map.end())
{
auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone());
auto exp_list = std::make_shared<ASTExpressionList>();
if (func.name == "sum" || func.name == "count")
{
/// Rewrite "sum" to sumCount().1, rewrite "count" to sumCount().2
UInt8 idx = (func.name == "sum" ? 1 : 2);
func.name = "tupleElement";
exp_list->children.push_back(func_base);
exp_list->children.push_back(std::make_shared<ASTLiteral>(idx));
}
else
{
/// Rewrite "avg" to sumCount().1 / sumCount().2
auto new_arg1 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(1)));
auto new_arg2 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(2)));
func.name = "divide";
exp_list->children.push_back(new_arg1);
exp_list->children.push_back(new_arg2);
}
func.arguments = exp_list;
func.children.push_back(func.arguments);
it->second.addFuncNode(func);
}
else
{
DB::FuseFunctions funcs{};
funcs.addFuncNode(func);
fuse_map[ident->name()] = funcs;
}
}
}
@ -277,34 +261,43 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
}
void gatherFuseFunctions(std::unordered_map<String, DB::FuseFunctions> &fuse_map, std::vector<const ASTFunction *> &aggregates)
/// Rewrites `func` in place so that it is expressed through sumCount(column_name):
///   sum   -> tupleElement(sumCount(column_name), 1)
///   count -> tupleElement(sumCount(column_name), 2)
///   any other name (avg, per the comment below)
///         -> divide(tupleElement(sumCount(column_name), 1), tupleElement(sumCount(column_name), 2))
/// The function name and its argument list are both replaced.
void rewriterFusedFunction(String column_name, ASTFunction & func)
{
for (auto & func : aggregates)
/// The sumCount() call is built over a fresh identifier made from the column name.
auto func_base = makeASTFunction("sumCount", std::make_shared<ASTIdentifier>(column_name));
auto exp_list = std::make_shared<ASTExpressionList>();
if (func.name == "sum" || func.name == "count")
{
if ((func->name == "sum" || func->name == "avg" || func->name == "count") && func->arguments->children.size() == 1)
/// Rewrite "sum" to sumCount().1, rewrite "count" to sumCount().2
/// The index must be computed before func.name is overwritten on the next line.
UInt8 idx = (func.name == "sum" ? 1 : 2);
func.name = "tupleElement";
exp_list->children.push_back(func_base);
exp_list->children.push_back(std::make_shared<ASTLiteral>(idx));
}
else
{
/// Rewrite "avg" to sumCount().1 / sumCount().2
/// Both tupleElement calls receive the same func_base value.
auto new_arg1 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(1)));
auto new_arg2 = makeASTFunction("tupleElement", func_base, std::make_shared<ASTLiteral>(UInt8(2)));
func.name = "divide";
exp_list->children.push_back(new_arg1);
exp_list->children.push_back(new_arg2);
}
func.arguments = exp_list;
/// NOTE(review): the new argument list is appended to func.children; nothing here
/// removes what children already held - confirm the previous arguments node is not left behind.
func.children.push_back(func.arguments);
}
/// Walks every entry of fuse_map (keyed by a String that is passed on as the column name).
/// For each entry whose FuseFunctions reports canBeFused(), every collected
/// sum, avg and count node is rewritten through rewriterFusedFunction().
/// Entries that cannot be fused are left untouched.
void fuseCandidates(std::unordered_map<String, DB::FuseFunctions> &fuse_map)
{
for (auto & it : fuse_map)
{
if (it.second.canBeFused())
{
if (func->arguments->children.empty())
return;
ASTIdentifier * ident = func->arguments->children.at(0)->as<ASTIdentifier>();
if (!ident)
return;
ASTIdentifier * column = (func->arguments->children.at(0))->as<ASTIdentifier>();
if (!column)
return;
auto it = fuse_map.find(column->name());
if (it != fuse_map.end())
{
it->second.addFuncNode(func);
}
else
{
DB::FuseFunctions funcs{};
funcs.addFuncNode(func);
fuse_map.emplace(column->name(), funcs);
}
/// The stored pointers are dereferenced and the nodes are modified in place.
for (auto & func: it.second.sums)
rewriterFusedFunction(it.first, *func);
for (auto & func: it.second.avgs)
rewriterFusedFunction(it.first, *func);
for (auto & func: it.second.counts)
rewriterFusedFunction(it.first, *func);
}
}
}
@ -1028,15 +1021,9 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings &
/// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount()
if (settings.optimize_fuse_sum_count_avg)
{
/// Get statistics about sum/avg/count
GetAggregatesVisitor::Data data;
GetAggregatesVisitor(data).visit(query);
std::unordered_map<String, DB::FuseFunctions> fuse_map;
gatherFuseFunctions(fuse_map, data.aggregates);
/// Try to fuse
CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_map = fuse_map};
CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query);
CustomizeFuseAggregateFunctionsVisitor::Data data;
CustomizeFuseAggregateFunctionsVisitor(data).visit(query);
fuseCandidates(data.fuse_map);
}
/// Rewrite all aggregate functions to add -OrNull suffix to them

View File

@ -0,0 +1,9 @@
<test max_ignored_relative_change="0.4">
<!-- Performance test with optimize_fuse_sum_count_avg switched on.
     Each query combines at least two of sum/avg/count over the same column
     (number) of numbers(100000000): the three pairs, then all three together. -->
<settings>
<optimize_fuse_sum_count_avg>true</optimize_fuse_sum_count_avg>
</settings>
<query>SELECT sum(number), avg(number) FROM numbers(100000000)</query>
<query>SELECT sum(number), count(number) FROM numbers(100000000)</query>
<query>SELECT avg(number), count(number) FROM numbers(100000000)</query>
<query>SELECT sum(number), avg(number), count(number) FROM numbers(100000000)</query>
</test>