merge 2 visitors to 1

This commit is contained in:
JackyWoo 2023-08-02 16:45:16 +08:00
parent f512b7a217
commit 461c2fba8b

View File

@ -86,72 +86,34 @@ private :
bool & keep_aggregator; bool & keep_aggregator;
}; };
/// Try to eliminate min/max/any/anyLast which will not decent into subqueries. /// Try to eliminate min/max/any/anyLast.
class EliminateFunctionVisitor : public InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor> class EliminateFunctionVisitor : public InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor>
{ {
public: public:
using Base = InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor>; using Base = InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor>;
using Base::Base; using Base::Base;
explicit EliminateFunctionVisitor(const QueryTreeNodes & group_by_keys_, ContextPtr context) : Base(context), group_by_keys(group_by_keys_) { } using GroupByKeysStack = std::vector<const QueryTreeNodes>;
void enterImpl(QueryTreeNodePtr & node)
{
/// Check if function is min/max/any/anyLast
auto * function_node = node->as<FunctionNode>();
if (!function_node
|| !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max"
|| function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast"))
return;
if (!function_node->getArguments().getNodes().empty())
{
bool keep_aggregator = false;
KeepEliminateFunctionVisitor visitor(group_by_keys, keep_aggregator);
visitor.visit(function_node->getArgumentsNode());
/// Place argument of an aggregate function instead of function
if (!keep_aggregator)
node = function_node->getArguments().getNodes()[0];
}
}
static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child)
{
/// Don't descent into table functions and subqueries and special case for ArrayJoin.
return !child->as<QueryNode>() && !child->as<TableNode>() && !child->as<ArrayJoinNode>();
}
private:
const QueryTreeNodes & group_by_keys;
};
/// Collect QueryNode and its group by keys.
class CollectQueryAndGroupByKeysVisitor : public InDepthQueryTreeVisitorWithContext<CollectQueryAndGroupByKeysVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<CollectQueryAndGroupByKeysVisitor>;
using Base::Base;
using Data = std::unordered_map<QueryTreeNodePtr, QueryTreeNodes>;
Data data;
void enterImpl(QueryTreeNodePtr & node) void enterImpl(QueryTreeNodePtr & node)
{ {
if (!getSettings().optimize_aggregators_of_group_by_keys) if (!getSettings().optimize_aggregators_of_group_by_keys)
return; return;
auto * query_node = node->as<QueryNode>(); /// (1) collect group by keys
if (!query_node) if (auto * query_node = node->as<QueryNode>())
return; {
if (!query_node->hasGroupBy()) if (!query_node->hasGroupBy())
return; {
group_by_keys_stack.push_back({});
if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup()) }
return; else if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup())
{
/// Keep aggregator if group by is with totals/cube/rollup.
group_by_keys_stack.push_back({});
}
else
{
QueryTreeNodes group_by_keys; QueryTreeNodes group_by_keys;
for (auto & group_key : query_node->getGroupBy().getNodes()) for (auto & group_key : query_node->getGroupBy().getNodes())
{ {
@ -166,23 +128,59 @@ public:
group_by_keys.push_back(group_key); group_by_keys.push_back(group_key);
} }
} }
data.insert({node, std::move(group_by_keys)}); group_by_keys_stack.push_back(std::move(group_by_keys));
} }
}
/// (2) Try to eliminate any/min/max
else if (auto * function_node = node->as<FunctionNode>())
{
if (!function_node
|| !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max"
|| function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast"))
return;
if (!function_node->getArguments().getNodes().empty())
{
bool keep_aggregator = false;
KeepEliminateFunctionVisitor visitor(group_by_keys_stack.back(), keep_aggregator);
visitor.visit(function_node->getArgumentsNode());
/// Place argument of an aggregate function instead of function
if (!keep_aggregator)
node = function_node->getArguments().getNodes()[0];
}
}
}
/// Now we visit all nodes in QueryNode, we should remove group_by_keys from stack.
void leaveImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_aggregators_of_group_by_keys)
return;
if (auto * query_node = node->as<QueryNode>())
group_by_keys_stack.pop_back();
}
static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child)
{
/// Skip ArrayJoin.
return !child->as<ArrayJoinNode>();
}
private:
GroupByKeysStack group_by_keys_stack;
}; };
} }
void AggregateFunctionOfGroupByKeysPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context) void AggregateFunctionOfGroupByKeysPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{ {
CollectQueryAndGroupByKeysVisitor collector(context); EliminateFunctionVisitor eliminator(context);
collector.visit(query_tree_node); eliminator.visit(query_tree_node);
for (auto & [query_node, group_by_keys] : collector.data)
{
EliminateFunctionVisitor eliminator(group_by_keys, query_node->as<QueryNode>()->getContext());
auto mutable_query_node = query_node;
eliminator.visit(mutable_query_node);
}
} }
}; };