Rewrite fuse logic

This commit is contained in:
hexiaoting 2021-03-08 11:58:18 +08:00
parent 4fe75ad168
commit 9dab51698f
4 changed files with 84 additions and 26 deletions

View File

@ -25,7 +25,6 @@ public:
std::unordered_set<String> uniq_names {}; std::unordered_set<String> uniq_names {};
std::vector<const ASTFunction *> aggregates {}; std::vector<const ASTFunction *> aggregates {};
std::vector<const ASTFunction *> window_functions {}; std::vector<const ASTFunction *> window_functions {};
std::map<String, UInt8> fuse_sum_count_avg {};
}; };
static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) static bool needChildVisit(const ASTPtr & node, const ASTPtr & child)
@ -73,14 +72,6 @@ private:
data.uniq_names.insert(column_name); data.uniq_names.insert(column_name);
data.aggregates.push_back(&node); data.aggregates.push_back(&node);
if ((node.name == "sum" || node.name == "avg" || node.name == "count") && node.arguments->children.size() == 1)
{
const auto & argument = node.arguments->children.at(0);
ASTIdentifier * column = argument->as<ASTIdentifier>();
if (column)
data.fuse_sum_count_avg[column->name()] |= ((node.name == "sum") ? 0x1 : ((node.name == "count") ? 0x2 : 0x4));
}
} }
else if (node.is_window_function) else if (node.is_window_function)
{ {

View File

@ -182,29 +182,53 @@ struct CustomizeAggregateFunctionsMoveSuffixData
} }
}; };
/// Holds pointers to sum/count/avg function nodes, bucketed by function name.
struct FuseFunctions
{
    std::vector<const ASTFunction *> sums {};
    std::vector<const ASTFunction *> counts {};
    std::vector<const ASTFunction *> avgs {};

    /// Files `func` into the bucket matching its name; nodes with any other name are ignored.
    void addFuncNode(const ASTFunction * func)
    {
        const auto & fn_name = func->name;

        if (fn_name == "sum")
        {
            sums.push_back(func);
            return;
        }
        if (fn_name == "count")
        {
            counts.push_back(func);
            return;
        }
        if (fn_name == "avg")
            avgs.push_back(func);
    }

    /// True when at least two of the three buckets (sums, counts, avgs) are non-empty.
    bool canBeFused() const
    {
        int kinds_present = 0;
        kinds_present += sums.empty() ? 0 : 1;
        kinds_present += counts.empty() ? 0 : 1;
        kinds_present += avgs.empty() ? 0 : 1;
        return kinds_present >= 2;
    }
};
struct CustomizeFuseAggregateFunctionsData struct CustomizeFuseAggregateFunctionsData
{ {
using TypeToVisit = ASTFunction; using TypeToVisit = ASTFunction;
std::map<String, UInt8> fuse_info; std::unordered_map<String, DB::FuseFunctions> fuse_map {};
static inline UInt8 bitCount(UInt8 n)
{
UInt8 c = 0;
for (c = 0; n; n >>= 1)
c += n & 1;
return c;
}
void visit(ASTFunction & func, ASTPtr &) const void visit(ASTFunction & func, ASTPtr &) const
{ {
if (func.name == "sum" || func.name == "avg" || func.name == "count") if (func.name == "sum" || func.name == "avg" || func.name == "count")
{ {
if (func.arguments->children.size() == 0)
return;
ASTIdentifier * ident = func.arguments->children.at(0)->as<ASTIdentifier>(); ASTIdentifier * ident = func.arguments->children.at(0)->as<ASTIdentifier>();
if (!ident) if (!ident)
return; return;
auto column = fuse_info.find(ident->name());
if (column != fuse_info.end() && bitCount(column->second) > 1) auto it = fuse_map.find(ident->name());
if (it != fuse_map.end() && it->second.canBeFused())
{ {
auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone()); auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone());
auto exp_list = std::make_shared<ASTExpressionList>(); auto exp_list = std::make_shared<ASTExpressionList>();
@ -253,6 +277,38 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query
throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED); throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED);
} }
/// Collects the single-argument sum/avg/count calls from `aggregates` into `fuse_map`,
/// keyed by the name of the identifier passed as that argument.
///
/// Calls with a different name, a different number of arguments, or an argument that is
/// not a plain identifier are skipped; they must not stop the scan, otherwise aggregates
/// that follow them in the list would never be registered.
void gatherFuseFunctions(std::unordered_map<String, DB::FuseFunctions> &fuse_map, std::vector<const ASTFunction *> &aggregates)
{
    for (const ASTFunction * func : aggregates)
    {
        const bool is_fusable_name = func->name == "sum" || func->name == "avg" || func->name == "count";
        if (!is_fusable_name || func->arguments->children.size() != 1)
            continue;

        const auto * column = func->arguments->children.at(0)->as<ASTIdentifier>();
        if (!column)
            continue;

        /// operator[] default-constructs the entry the first time a column is seen.
        fuse_map[column->name()].addFuncNode(func);
    }
}
bool hasArrayJoin(const ASTPtr & ast) bool hasArrayJoin(const ASTPtr & ast)
{ {
if (const ASTFunction * function = ast->as<ASTFunction>()) if (const ASTFunction * function = ast->as<ASTFunction>())
@ -969,19 +1025,21 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings &
CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query); CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query);
} }
// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount() /// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount()
if (settings.optimize_fuse_sum_count_avg) if (settings.optimize_fuse_sum_count_avg)
{ {
// Get statistics about sum/avg/count /// Get statistics about sum/avg/count
GetAggregatesVisitor::Data data; GetAggregatesVisitor::Data data;
GetAggregatesVisitor(data).visit(query); GetAggregatesVisitor(data).visit(query);
std::unordered_map<String, DB::FuseFunctions> fuse_map;
gatherFuseFunctions(fuse_map, data.aggregates);
// Try to fuse /// Try to fuse
CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_info = data.fuse_sum_count_avg}; CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_map = fuse_map};
CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query); CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query);
} }
// Rewrite all aggregate functions to add -OrNull suffix to them /// Rewrite all aggregate functions to add -OrNull suffix to them
if (settings.aggregate_functions_null_for_empty) if (settings.aggregate_functions_null_for_empty)
{ {
CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"}; CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"};

View File

@ -4,3 +4,9 @@ SELECT
sumCount(b).1, sumCount(b).1,
sumCount(b).2 sumCount(b).2
FROM fuse_tbl FROM fuse_tbl
---------NOT trigger fuse--------
210 11.5
SELECT
sum(a),
avg(b)
FROM fuse_tbl

View File

@ -2,7 +2,10 @@ DROP TABLE IF EXISTS fuse_tbl;
CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log; CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log;
INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20); INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20);
set optimize_fuse_sum_count_avg = 1; SET optimize_fuse_sum_count_avg = 1;
SELECT sum(a), sum(b), count(b) from fuse_tbl; SELECT sum(a), sum(b), count(b) from fuse_tbl;
EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl; EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl;
SELECT '---------NOT trigger fuse--------';
SELECT sum(a), avg(b) from fuse_tbl;
EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl;
DROP TABLE fuse_tbl; DROP TABLE fuse_tbl;