From 461c2fba8b9cd8ab58b3448e9438d961f9ed3660 Mon Sep 17 00:00:00 2001 From: JackyWoo Date: Wed, 2 Aug 2023 16:45:16 +0800 Subject: [PATCH] merge 2 visitors to 1 --- .../AggregateFunctionOfGroupByKeysPass.cpp | 140 +++++++++--------- 1 file changed, 69 insertions(+), 71 deletions(-) diff --git a/src/Analyzer/Passes/AggregateFunctionOfGroupByKeysPass.cpp b/src/Analyzer/Passes/AggregateFunctionOfGroupByKeysPass.cpp index f7b0e80a7a7..de2551a049c 100644 --- a/src/Analyzer/Passes/AggregateFunctionOfGroupByKeysPass.cpp +++ b/src/Analyzer/Passes/AggregateFunctionOfGroupByKeysPass.cpp @@ -86,103 +86,101 @@ private : bool & keep_aggregator; }; -/// Try to eliminate min/max/any/anyLast which will not decent into subqueries. +/// Try to eliminate min/max/any/anyLast. class EliminateFunctionVisitor : public InDepthQueryTreeVisitorWithContext { public: using Base = InDepthQueryTreeVisitorWithContext; using Base::Base; - explicit EliminateFunctionVisitor(const QueryTreeNodes & group_by_keys_, ContextPtr context) : Base(context), group_by_keys(group_by_keys_) { } - - void enterImpl(QueryTreeNodePtr & node) - { - /// Check if function is min/max/any/anyLast - auto * function_node = node->as(); - if (!function_node - || !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max" - || function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast")) - return; - - if (!function_node->getArguments().getNodes().empty()) - { - bool keep_aggregator = false; - - KeepEliminateFunctionVisitor visitor(group_by_keys, keep_aggregator); - visitor.visit(function_node->getArgumentsNode()); - - /// Place argument of an aggregate function instead of function - if (!keep_aggregator) - node = function_node->getArguments().getNodes()[0]; - } - } - - static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child) - { - /// Don't descent into table functions and subqueries and special case for ArrayJoin. - return !child->as() && !child->as() && !child->as(); - } - -private: - const QueryTreeNodes & group_by_keys; -}; - -/// Collect QueryNode and its group by keys. -class CollectQueryAndGroupByKeysVisitor : public InDepthQueryTreeVisitorWithContext -{ -public: - using Base = InDepthQueryTreeVisitorWithContext; - using Base::Base; - - using Data = std::unordered_map; - Data data; + using GroupByKeysStack = std::vector; void enterImpl(QueryTreeNodePtr & node) { if (!getSettings().optimize_aggregators_of_group_by_keys) return; - auto * query_node = node->as(); - if (!query_node) - return; - - if (!query_node->hasGroupBy()) - return; - - if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup()) - return; - - QueryTreeNodes group_by_keys; - for (auto & group_key : query_node->getGroupBy().getNodes()) + /// (1) collect group by keys + if (auto * query_node = node->as()) { - /// for grouping sets case - if (auto * list = group_key->as()) + if (!query_node->hasGroupBy()) { - for (auto & group_elem : list->getNodes()) - group_by_keys.push_back(group_elem); + group_by_keys_stack.push_back({}); + } + else if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup()) + { + /// Keep aggregator if group by is with totals/cube/rollup. + group_by_keys_stack.push_back({}); } else { - group_by_keys.push_back(group_key); + QueryTreeNodes group_by_keys; + for (auto & group_key : query_node->getGroupBy().getNodes()) + { + /// for grouping sets case + if (auto * list = group_key->as()) + { + for (auto & group_elem : list->getNodes()) + group_by_keys.push_back(group_elem); + } + else + { + group_by_keys.push_back(group_key); + } + } + group_by_keys_stack.push_back(std::move(group_by_keys)); } } - data.insert({node, std::move(group_by_keys)}); + /// (2) Try to eliminate any/min/max + else if (auto * function_node = node->as()) + { + if (!function_node + || !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max" + || function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast")) + return; + + if (!function_node->getArguments().getNodes().empty()) + { + bool keep_aggregator = false; + + KeepEliminateFunctionVisitor visitor(group_by_keys_stack.back(), keep_aggregator); + visitor.visit(function_node->getArgumentsNode()); + + /// Place argument of an aggregate function instead of function + if (!keep_aggregator) + node = function_node->getArguments().getNodes()[0]; + } + } + } + + /// Now we visit all nodes in QueryNode, we should remove group_by_keys from stack. + void leaveImpl(QueryTreeNodePtr & node) + { + if (!getSettings().optimize_aggregators_of_group_by_keys) + return; + + if (auto * query_node = node->as()) + group_by_keys_stack.pop_back(); + } + + static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child) + { + /// Skip ArrayJoin. + return !child->as(); + } + +private: + GroupByKeysStack group_by_keys_stack; + }; } void AggregateFunctionOfGroupByKeysPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) { - CollectQueryAndGroupByKeysVisitor collector(context); - collector.visit(query_tree_node); - - for (auto & [query_node, group_by_keys] : collector.data) - { - EliminateFunctionVisitor eliminator(group_by_keys, query_node->as()->getContext()); - auto mutable_query_node = query_node; - eliminator.visit(mutable_query_node); - } + EliminateFunctionVisitor eliminator(context); + eliminator.visit(query_tree_node); } };