From 9dab51698fdfab9604ab67282496503c7bb5f773 Mon Sep 17 00:00:00 2001 From: hexiaoting Date: Mon, 8 Mar 2021 11:58:18 +0800 Subject: [PATCH] Rewriter fuse logical --- src/Interpreters/GetAggregatesVisitor.h | 9 -- src/Interpreters/TreeRewriter.cpp | 90 +++++++++++++++---- .../01744_fuse_sum_count_aggregate.reference | 6 ++ .../01744_fuse_sum_count_aggregate.sql | 5 +- 4 files changed, 84 insertions(+), 26 deletions(-) diff --git a/src/Interpreters/GetAggregatesVisitor.h b/src/Interpreters/GetAggregatesVisitor.h index 33bf0c58141..0eeab8348fd 100644 --- a/src/Interpreters/GetAggregatesVisitor.h +++ b/src/Interpreters/GetAggregatesVisitor.h @@ -25,7 +25,6 @@ public: std::unordered_set uniq_names {}; std::vector aggregates {}; std::vector window_functions {}; - std::map fuse_sum_count_avg {}; }; static bool needChildVisit(const ASTPtr & node, const ASTPtr & child) @@ -73,14 +72,6 @@ private: data.uniq_names.insert(column_name); data.aggregates.push_back(&node); - if ((node.name == "sum" || node.name == "avg" || node.name == "count") && node.arguments->children.size() == 1) - { - const auto & argument = node.arguments->children.at(0); - ASTIdentifier * column = argument->as(); - if (column) - data.fuse_sum_count_avg[column->name()] |= ((node.name == "sum") ? 0x1 : ((node.name == "count") ? 0x2 : 0x4)); - } - } else if (node.is_window_function) { diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index 2251f3998a0..f2282206018 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -182,29 +182,53 @@ struct CustomizeAggregateFunctionsMoveSuffixData } }; +struct FuseFunctions +{ + std::vector sums {}; + std::vector counts {}; + std::vector avgs {}; + + void addFuncNode(const ASTFunction * func) + { + if (func->name == "sum") + sums.push_back(func); + else if (func->name == "count") + counts.push_back(func); + else if (func->name == "avg") + avgs.push_back(func); + } + + bool canBeFused() const + { + if (sums.empty() && counts.empty()) + return false; + if (sums.empty() && avgs.empty()) + return false; + if (counts.empty() && avgs.empty()) + return false; + return true; + } +}; + struct CustomizeFuseAggregateFunctionsData { using TypeToVisit = ASTFunction; - std::map fuse_info; - - static inline UInt8 bitCount(UInt8 n) - { - UInt8 c = 0; - for (c = 0; n; n >>= 1) - c += n & 1; - return c; - } + std::unordered_map fuse_map {}; void visit(ASTFunction & func, ASTPtr &) const { if (func.name == "sum" || func.name == "avg" || func.name == "count") { + if (func.arguments->children.size() == 0) + return; + ASTIdentifier * ident = func.arguments->children.at(0)->as(); if (!ident) return; - auto column = fuse_info.find(ident->name()); - if (column != fuse_info.end() && bitCount(column->second) > 1) + + auto it = fuse_map.find(ident->name()); + if (it != fuse_map.end() && it->second.canBeFused()) { auto func_base = makeASTFunction("sumCount", func.arguments->children.at(0)->clone()); auto exp_list = std::make_shared(); @@ -253,6 +277,38 @@ void translateQualifiedNames(ASTPtr & query, const ASTSelectQuery & select_query throw Exception("Empty list of columns in SELECT query", ErrorCodes::EMPTY_LIST_OF_COLUMNS_QUERIED); } +void gatherFuseFunctions(std::unordered_map &fuse_map, std::vector &aggregates) +{ + for (auto & func : aggregates) + { + if ((func->name == "sum" || func->name == "avg" || func->name == "count") && func->arguments->children.size() == 1) + { + if (func->arguments->children.size() == 0) + return; + + ASTIdentifier * ident = func->arguments->children.at(0)->as(); + if (!ident) + return; + + ASTIdentifier * column = (func->arguments->children.at(0))->as(); + if (!column) + return; + + auto it = fuse_map.find(column->name()); + if (it != fuse_map.end()) + { + it->second.addFuncNode(func); + } + else + { + DB::FuseFunctions funcs{}; + funcs.addFuncNode(func); + fuse_map.emplace(column->name(), funcs); + } + } + } +} + bool hasArrayJoin(const ASTPtr & ast) { if (const ASTFunction * function = ast->as()) @@ -969,19 +1025,21 @@ void TreeRewriter::normalize(ASTPtr & query, Aliases & aliases, const Settings & CustomizeGlobalNotInVisitor(data_global_not_null_in).visit(query); } - // Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount() + /// Try to fuse sum/avg/count with identical column(at least two functions exist) to sumCount() if (settings.optimize_fuse_sum_count_avg) { - // Get statistics about sum/avg/count + /// Get statistics about sum/avg/count GetAggregatesVisitor::Data data; GetAggregatesVisitor(data).visit(query); + std::unordered_map fuse_map; + gatherFuseFunctions(fuse_map, data.aggregates); - // Try to fuse - CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_info = data.fuse_sum_count_avg}; + /// Try to fuse + CustomizeFuseAggregateFunctionsVisitor::Data data_fuse{.fuse_map = fuse_map}; CustomizeFuseAggregateFunctionsVisitor(data_fuse).visit(query); } - // Rewrite all aggregate functions to add -OrNull suffix to them + /// Rewrite all aggregate functions to add -OrNull suffix to them if (settings.aggregate_functions_null_for_empty) { CustomizeAggregateFunctionsOrNullVisitor::Data data_or_null{"OrNull"}; diff --git a/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.reference b/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.reference index c31bb3be922..70c19fc8ced 100644 --- a/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.reference +++ b/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.reference @@ -4,3 +4,9 @@ SELECT sumCount(b).1, sumCount(b).2 FROM fuse_tbl +---------NOT trigger fuse-------- +210 11.5 +SELECT + sum(a), + avg(b) +FROM fuse_tbl diff --git a/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql b/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql index 3cabcf18831..cad7b5803d4 100644 --- a/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql +++ b/tests/queries/0_stateless/01744_fuse_sum_count_aggregate.sql @@ -2,7 +2,10 @@ DROP TABLE IF EXISTS fuse_tbl; CREATE TABLE fuse_tbl(a Int8, b Int8) Engine = Log; INSERT INTO fuse_tbl SELECT number, number + 1 FROM numbers(1, 20); -set optimize_fuse_sum_count_avg = 1; +SET optimize_fuse_sum_count_avg = 1; SELECT sum(a), sum(b), count(b) from fuse_tbl; EXPLAIN SYNTAX SELECT sum(a), sum(b), count(b) from fuse_tbl; +SELECT '---------NOT trigger fuse--------'; +SELECT sum(a), avg(b) from fuse_tbl; +EXPLAIN SYNTAX SELECT sum(a), avg(b) from fuse_tbl; DROP TABLE fuse_tbl;