Get rid of where_stack in CrossToInnerJoinVisitor

This commit is contained in:
vdimir 2023-02-23 13:18:55 +00:00
parent d4bb84e68b
commit bea8525234
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862

View File

@ -120,6 +120,20 @@ bool findInTableExpression(const IQueryTreeNode * source, const QueryTreeNodePtr
return false;
}
void getJoinNodes(QueryTreeNodePtr & join_tree_node, std::vector<JoinNode *> & join_nodes)
{
auto * join_node = join_tree_node->as<JoinNode>();
if (!join_node)
return;
if (!isCrossOrComma(join_node->getKind()))
return;
join_nodes.push_back(join_node);
getJoinNodes(join_node->getLeftTableExpression(), join_nodes);
getJoinNodes(join_node->getRightTableExpression(), join_nodes);
}
class CrossToInnerJoinVisitor : public InDepthQueryTreeVisitorWithContext<CrossToInnerJoinVisitor>
{
public:
@ -127,20 +141,16 @@ public:
using Base::Base;
/// Returns false if can't rewrite cross to inner join
bool tryRewrite(JoinNode * join_node)
bool tryRewrite(JoinNode & join_node, QueryTreeNodePtr & where_condition)
{
if (!isCrossOrComma(join_node->getKind()))
return true;
if (where_stack.empty())
if (!isCrossOrComma(join_node.getKind()))
return false;
auto & where_condition = *where_stack.back();
if (!where_condition)
return false;
const auto & left_table = join_node->getLeftTableExpression();
const auto & right_table = join_node->getRightTableExpression();
const auto & left_table = join_node.getLeftTableExpression();
const auto & right_table = join_node.getRightTableExpression();
QueryTreeNodes equi_conditions;
QueryTreeNodes other_conditions;
@ -171,7 +181,7 @@ public:
return false;
equi_conditions.erase(std::remove(equi_conditions.begin(), equi_conditions.end(), nullptr), equi_conditions.end());
join_node->crossToInner(makeConjunction(equi_conditions));
join_node.crossToInner(makeConjunction(equi_conditions));
where_condition = makeConjunction(other_conditions);
return true;
}
@ -183,29 +193,32 @@ public:
if (auto * query_node = node->as<QueryNode>())
{
/// We are entering the subtree and can use WHERE condition from this subtree
if (auto & where_node = query_node->getWhere())
where_stack.push_back(&where_node);
}
auto & where_node = query_node->getWhere();
if (!where_node)
return;
if (auto * join_node = node->as<JoinNode>())
{
bool is_rewritten = tryRewrite(join_node);
if (!is_rewritten && forceRewrite(join_node->getKind()))
auto & join_tree_node = query_node->getJoinTree();
if (!join_tree_node || join_tree_node->getNodeType() != QueryTreeNodeType::JOIN)
return;
/// In case of multiple joins, we can try to rewrite all of them
/// Example: SELECT * FROM t1, t2, t3 WHERE t1.a = t2.a AND t2.a = t3.a
std::vector<JoinNode *> join_nodes;
getJoinNodes(join_tree_node, join_nodes);
for (auto * join_node : join_nodes)
{
throw Exception(ErrorCodes::INCORRECT_QUERY,
"Failed to rewrite '{}' to INNER JOIN: "
"no equi-join conditions found in WHERE clause. "
"You may set setting `cross_to_inner_join_rewrite` to `1` to allow slow CROSS JOIN for this case",
join_node->formatASTForErrorMessage());
}
}
bool is_rewritten = tryRewrite(*join_node, where_node);
if (!where_stack.empty() && where_stack.back()->get() == node.get())
{
/// We are visiting the WHERE clause.
/// It means that we have visited current subtree and will go out of WHERE scope.
where_stack.pop_back();
if (!is_rewritten && forceRewrite(join_node->getKind()))
{
throw Exception(ErrorCodes::INCORRECT_QUERY,
"Failed to rewrite '{}' to INNER JOIN: "
"no equi-join conditions found in WHERE clause. "
"You may set setting `cross_to_inner_join_rewrite` to `1` to allow slow CROSS JOIN for this case",
join_node->formatASTForErrorMessage());
}
}
}
}
@ -239,8 +252,6 @@ private:
function_node->resolveAsFunction(function->build(function_node->getArgumentColumns()));
return function_node;
}
std::deque<QueryTreeNodePtr *> where_stack;
};
}