Merge pull request #62387 from ClickHouse/vdimir/join_on_expression_optimizer_fix

Use function isNotDistinctFrom only in join key
This commit is contained in:
vdimir 2024-04-11 10:34:28 +00:00 committed by GitHub
commit fe4373fa53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 87 additions and 48 deletions

View File

@ -15,6 +15,7 @@
#include <Functions/logical.h>
#include <Common/logger_useful.h>
#include <Analyzer/Utils.h>
namespace DB
@ -61,47 +62,7 @@ const QueryTreeNodePtr & getEquiArgument(const QueryTreeNodePtr & cond, size_t i
return func->getArguments().getNodes()[index];
}
/// Check that node has only one source and return it.
/// {_, false} - multiple sources
/// {nullptr, true} - no sources
/// {source, true} - single source
std::pair<const IQueryTreeNode *, bool> getExpressionSource(const QueryTreeNodePtr & node)
{
if (const auto * column = node->as<ColumnNode>())
{
auto source = column->getColumnSourceOrNull();
if (!source)
return {nullptr, false};
return {source.get(), true};
}
if (const auto * func = node->as<FunctionNode>())
{
const IQueryTreeNode * source = nullptr;
const auto & args = func->getArguments().getNodes();
for (const auto & arg : args)
{
auto [arg_source, is_ok] = getExpressionSource(arg);
if (!is_ok)
return {nullptr, false};
if (!source)
source = arg_source;
else if (arg_source && !source->isEqual(*arg_source))
return {nullptr, false};
}
return {source, true};
}
if (node->as<ConstantNode>())
return {nullptr, true};
return {nullptr, false};
}
bool findInTableExpression(const IQueryTreeNode * source, const QueryTreeNodePtr & table_expression)
bool findInTableExpression(const QueryTreeNodePtr & source, const QueryTreeNodePtr & table_expression)
{
if (!source)
return true;
@ -115,7 +76,6 @@ bool findInTableExpression(const IQueryTreeNode * source, const QueryTreeNodePtr
|| findInTableExpression(source, join_node->getRightTableExpression());
}
return false;
}
@ -169,10 +129,10 @@ public:
auto left_src = getExpressionSource(lhs_equi_argument);
auto right_src = getExpressionSource(rhs_equi_argument);
if (left_src.second && right_src.second && left_src.first && right_src.first)
if (left_src && right_src)
{
if ((findInTableExpression(left_src.first, left_table) && findInTableExpression(right_src.first, right_table)) ||
(findInTableExpression(left_src.first, right_table) && findInTableExpression(right_src.first, left_table)))
if ((findInTableExpression(left_src, left_table) && findInTableExpression(right_src, right_table)) ||
(findInTableExpression(left_src, right_table) && findInTableExpression(right_src, left_table)))
{
can_convert_cross_to_inner = true;
continue;

View File

@ -25,8 +25,9 @@ class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWi
public:
using Base = InDepthQueryTreeVisitorWithContext<JoinOnLogicalExpressionOptimizerVisitor>;
explicit JoinOnLogicalExpressionOptimizerVisitor(ContextPtr context)
explicit JoinOnLogicalExpressionOptimizerVisitor(const JoinNode * join_node_, ContextPtr context)
: Base(std::move(context))
, join_node(join_node_)
{}
void enterImpl(QueryTreeNodePtr & node)
@ -55,10 +56,11 @@ public:
}
private:
const JoinNode * join_node;
bool need_rerun_resolve = false;
/// Returns true if type of some operand is changed and parent function needs to be re-resolved
static bool tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node, const ContextPtr & context)
bool tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node, const ContextPtr & context)
{
auto & function_node = node->as<FunctionNode &>();
chassert(function_node.getFunctionName() == "or");
@ -93,6 +95,21 @@ private:
const auto & func_name = argument_function->getFunctionName();
if (func_name == "equals" || func_name == "isNotDistinctFrom")
{
const auto & argument_nodes = argument_function->getArguments().getNodes();
if (argument_nodes.size() != 2)
continue;
/// We can rewrite to a <=> b only if we are joining on a and b,
/// because the function is not yet implemented for other cases.
auto first_src = getExpressionSource(argument_nodes[0]);
auto second_src = getExpressionSource(argument_nodes[1]);
if (!first_src || !second_src)
continue;
const auto & lhs_join = *join_node->getLeftTableExpression();
const auto & rhs_join = *join_node->getRightTableExpression();
bool arguments_from_both_sides = (first_src->isEqual(lhs_join) && second_src->isEqual(rhs_join)) ||
(first_src->isEqual(rhs_join) && second_src->isEqual(lhs_join));
if (!arguments_from_both_sides)
continue;
equals_functions_indices.push_back(or_operands.size() - 1);
}
else if (func_name == "and")
@ -231,7 +248,7 @@ public:
/// Operator <=> is not supported outside of JOIN ON section
if (join_node->hasJoinExpression())
{
JoinOnLogicalExpressionOptimizerVisitor join_on_visitor(getContext());
JoinOnLogicalExpressionOptimizerVisitor join_on_visitor(join_node, getContext());
join_on_visitor.visit(join_node->getJoinExpression());
}
return;

View File

@ -760,6 +760,54 @@ QueryTreeNodePtr createCastFunction(QueryTreeNodePtr node, DataTypePtr result_ty
return function_node;
}
/** Returns:
* {_, false} - multiple sources
* {nullptr, true} - no sources (for constants)
* {source, true} - single source
*/
std::pair<QueryTreeNodePtr, bool> getExpressionSourceImpl(const QueryTreeNodePtr & node)
{
if (const auto * column = node->as<ColumnNode>())
{
auto source = column->getColumnSourceOrNull();
if (!source)
return {nullptr, false};
return {source, true};
}
if (const auto * func = node->as<FunctionNode>())
{
QueryTreeNodePtr source = nullptr;
const auto & args = func->getArguments().getNodes();
for (const auto & arg : args)
{
auto [arg_source, is_ok] = getExpressionSourceImpl(arg);
if (!is_ok)
return {nullptr, false};
if (!source)
source = arg_source;
else if (arg_source && !source->isEqual(*arg_source))
return {nullptr, false};
}
return {source, true};
}
if (node->as<ConstantNode>())
return {nullptr, true};
return {nullptr, false};
}
/// Returns the single source of the expression `node`.
/// Returns nullptr when getExpressionSourceImpl reports a failure or finds no source at all.
QueryTreeNodePtr getExpressionSource(const QueryTreeNodePtr & node)
{
    auto [found_source, resolved] = getExpressionSourceImpl(node);
    if (resolved)
        return found_source;
    return nullptr;
}
QueryTreeNodePtr buildSubqueryToReadColumnsFromTableExpression(QueryTreeNodePtr table_node, const ContextPtr & context)
{
const auto & storage_snapshot = table_node->as<TableNode>()->getStorageSnapshot();

View File

@ -105,6 +105,9 @@ NameSet collectIdentifiersFullNames(const QueryTreeNodePtr & node);
/// Wrap node into `_CAST` function
QueryTreeNodePtr createCastFunction(QueryTreeNodePtr node, DataTypePtr result_type, ContextPtr context);
/// Checks that node has only one source and returns it; returns nullptr if the node has no source, several sources, or cannot be resolved
QueryTreeNodePtr getExpressionSource(const QueryTreeNodePtr & node);
/// Build subquery which we execute for `IN table` function.
QueryTreeNodePtr buildSubqueryToReadColumnsFromTableExpression(QueryTreeNodePtr table_node, const ContextPtr & context);

View File

@ -8,6 +8,14 @@ SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) ORDER BY t1.x;
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON t1.x <=> t2.x AND (t1.x = t1.y OR t1.x IS NULL AND t1.y IS NULL) ORDER BY t1.x;
2 2 2 2
3 3 3 33
\N \N \N \N
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;
2 2 2 2
\N \N \N \N

View File

@ -14,6 +14,9 @@ SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) O
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST;
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) ORDER BY t1.x;
SELECT * FROM t1 JOIN t2 ON t1.x <=> t2.x AND (t1.x = t1.y OR t1.x IS NULL AND t1.y IS NULL) ORDER BY t1.x;
SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST;
SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST;