merge 2 visitors to 1

This commit is contained in:
JackyWoo 2023-08-02 16:45:16 +08:00
parent f512b7a217
commit 461c2fba8b

View File

@ -86,72 +86,34 @@ private :
bool & keep_aggregator;
};
/// Try to eliminate min/max/any/anyLast which will not decent into subqueries.
/// Try to eliminate min/max/any/anyLast.
class EliminateFunctionVisitor : public InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<EliminateFunctionVisitor>;
using Base::Base;
explicit EliminateFunctionVisitor(const QueryTreeNodes & group_by_keys_, ContextPtr context) : Base(context), group_by_keys(group_by_keys_) { }
void enterImpl(QueryTreeNodePtr & node)
{
/// Check if function is min/max/any/anyLast
auto * function_node = node->as<FunctionNode>();
if (!function_node
|| !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max"
|| function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast"))
return;
if (!function_node->getArguments().getNodes().empty())
{
bool keep_aggregator = false;
KeepEliminateFunctionVisitor visitor(group_by_keys, keep_aggregator);
visitor.visit(function_node->getArgumentsNode());
/// Place argument of an aggregate function instead of function
if (!keep_aggregator)
node = function_node->getArguments().getNodes()[0];
}
}
static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child)
{
/// Don't descent into table functions and subqueries and special case for ArrayJoin.
return !child->as<QueryNode>() && !child->as<TableNode>() && !child->as<ArrayJoinNode>();
}
private:
const QueryTreeNodes & group_by_keys;
};
/// Collect QueryNode and its group by keys.
class CollectQueryAndGroupByKeysVisitor : public InDepthQueryTreeVisitorWithContext<CollectQueryAndGroupByKeysVisitor>
{
public:
using Base = InDepthQueryTreeVisitorWithContext<CollectQueryAndGroupByKeysVisitor>;
using Base::Base;
using Data = std::unordered_map<QueryTreeNodePtr, QueryTreeNodes>;
Data data;
using GroupByKeysStack = std::vector<const QueryTreeNodes>;
void enterImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_aggregators_of_group_by_keys)
return;
auto * query_node = node->as<QueryNode>();
if (!query_node)
return;
/// (1) collect group by keys
if (auto * query_node = node->as<QueryNode>())
{
if (!query_node->hasGroupBy())
return;
if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup())
return;
{
group_by_keys_stack.push_back({});
}
else if (query_node->isGroupByWithTotals() || query_node->isGroupByWithCube() || query_node->isGroupByWithRollup())
{
/// Keep aggregator if group by is with totals/cube/rollup.
group_by_keys_stack.push_back({});
}
else
{
QueryTreeNodes group_by_keys;
for (auto & group_key : query_node->getGroupBy().getNodes())
{
@ -166,23 +128,59 @@ public:
group_by_keys.push_back(group_key);
}
}
data.insert({node, std::move(group_by_keys)});
group_by_keys_stack.push_back(std::move(group_by_keys));
}
}
/// (2) Try to eliminate any/min/max
else if (auto * function_node = node->as<FunctionNode>())
{
if (!function_node
|| !(function_node->getFunctionName() == "min" || function_node->getFunctionName() == "max"
|| function_node->getFunctionName() == "any" || function_node->getFunctionName() == "anyLast"))
return;
if (!function_node->getArguments().getNodes().empty())
{
bool keep_aggregator = false;
KeepEliminateFunctionVisitor visitor(group_by_keys_stack.back(), keep_aggregator);
visitor.visit(function_node->getArgumentsNode());
/// Place argument of an aggregate function instead of function
if (!keep_aggregator)
node = function_node->getArguments().getNodes()[0];
}
}
}
/// Now we visit all nodes in QueryNode, we should remove group_by_keys from stack.
void leaveImpl(QueryTreeNodePtr & node)
{
if (!getSettings().optimize_aggregators_of_group_by_keys)
return;
if (auto * query_node = node->as<QueryNode>())
group_by_keys_stack.pop_back();
}
static bool needChildVisit(VisitQueryTreeNodeType & parent [[maybe_unused]], VisitQueryTreeNodeType & child)
{
/// Skip ArrayJoin.
return !child->as<ArrayJoinNode>();
}
private:
GroupByKeysStack group_by_keys_stack;
};
}
void AggregateFunctionOfGroupByKeysPass::run(QueryTreeNodePtr query_tree_node, ContextPtr context)
{
CollectQueryAndGroupByKeysVisitor collector(context);
collector.visit(query_tree_node);
for (auto & [query_node, group_by_keys] : collector.data)
{
EliminateFunctionVisitor eliminator(group_by_keys, query_node->as<QueryNode>()->getContext());
auto mutable_query_node = query_node;
eliminator.visit(mutable_query_node);
}
EliminateFunctionVisitor eliminator(context);
eliminator.visit(query_tree_node);
}
};