diff --git a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp index 46056aeaf6f..081a27eb8fa 100644 --- a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp +++ b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp @@ -5,11 +5,214 @@ #include #include #include +#include #include +#include namespace DB { +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; +} + +/// Visitor that optimizes logical expressions _only_ in JOIN ON section +class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext +{ +public: + using Base = InDepthQueryTreeVisitorWithContext; + + explicit JoinOnLogicalExpressionOptimizerVisitor(ContextPtr context) + : Base(std::move(context)) + {} + + void enterImpl(QueryTreeNodePtr & node) + { + auto * function_node = node->as(); + + if (!function_node) + return; + + if (function_node->getFunctionName() == "or") + { + bool is_argument_type_changed = tryOptimizeIsNotDistinctOrIsNull(node, getContext()); + if (is_argument_type_changed) + need_rerun_resolve = true; + return; + } + } + + void leaveImpl(QueryTreeNodePtr & node) + { + if (!need_rerun_resolve) + return; + + if (auto * function_node = node->as()) + rerunFunctionResolve(function_node, getContext()); + } + +private: + bool need_rerun_resolve = false; + + /// Returns true if type of some operand is changed and parent function needs to be re-resolved + static bool tryOptimizeIsNotDistinctOrIsNull(QueryTreeNodePtr & node, const ContextPtr & context) + { + auto & function_node = node->as(); + chassert(function_node.getFunctionName() == "or"); + + + QueryTreeNodes or_operands; + or_operands.reserve(function_node.getArguments().getNodes().size()); + + /// Indices of `equals` or `isNotDistinctFrom` functions in the vector above + std::vector equals_functions_indices; + + /** Map from `isNull` argument to indices of operands that contains that `isNull` functions + * `a = b OR (a IS NULL AND b IS NULL) OR (a IS NULL AND c IS NULL)` + * will be mapped to + * { + * a => [(a IS NULL AND b IS NULL), (a IS NULL AND c IS NULL)] + * b => [(a IS NULL AND b IS NULL)] + * c => [(a IS NULL AND c IS NULL)] + * } + * Then for each a <=> b we can find all operands that contains both a IS NULL and b IS NULL + */ + QueryTreeNodePtrWithHashMap> is_null_argument_to_indices; + + for (const auto & argument : function_node.getArguments()) + { + or_operands.push_back(argument); + + auto * argument_function = argument->as(); + if (!argument_function) + continue; + + const auto & func_name = argument_function->getFunctionName(); + if (func_name == "equals" || func_name == "isNotDistinctFrom") + { + equals_functions_indices.push_back(or_operands.size() - 1); + } + else if (func_name == "and") + { + for (const auto & and_argument : argument_function->getArguments().getNodes()) + { + auto * and_argument_function = and_argument->as(); + if (and_argument_function && and_argument_function->getFunctionName() == "isNull") + { + const auto & is_null_argument = and_argument_function->getArguments().getNodes()[0]; + is_null_argument_to_indices[is_null_argument].push_back(or_operands.size() - 1); + } + } + } + } + + /// OR operands that are changed to and needs to be re-resolved + std::unordered_set arguments_to_reresolve; + + for (size_t equals_function_idx : equals_functions_indices) + { + auto * equals_function = or_operands[equals_function_idx]->as(); + + /// For a <=> b we are looking for expressions containing both `a IS NULL` and `b IS NULL` combined with AND + const auto & argument_nodes = equals_function->getArguments().getNodes(); + const auto & lhs_is_null_parents = is_null_argument_to_indices[argument_nodes[0]]; + const auto & rhs_is_null_parents = is_null_argument_to_indices[argument_nodes[1]]; + std::unordered_set operands_to_optimize; + std::set_intersection(lhs_is_null_parents.begin(), lhs_is_null_parents.end(), + rhs_is_null_parents.begin(), rhs_is_null_parents.end(), + std::inserter(operands_to_optimize, operands_to_optimize.begin())); + + /// If we have `a = b OR (a IS NULL AND b IS NULL)` we can optimize it to `a <=> b` + if (!operands_to_optimize.empty() && equals_function->getFunctionName() == "equals") + arguments_to_reresolve.insert(equals_function_idx); + + for (size_t to_optimize_idx : operands_to_optimize) + { + /// We are looking for operand `a IS NULL AND b IS NULL AND ...` + auto * operand_to_optimize = or_operands[to_optimize_idx]->as(); + + /// Remove `a IS NULL` and `b IS NULL` arguments from AND + QueryTreeNodes new_arguments; + for (const auto & and_argument : operand_to_optimize->getArguments().getNodes()) + { + bool to_eliminate = false; + + const auto * and_argument_function = and_argument->as(); + if (and_argument_function && and_argument_function->getFunctionName() == "isNull") + { + const auto & is_null_argument = and_argument_function->getArguments().getNodes()[0]; + to_eliminate = (is_null_argument->isEqual(*argument_nodes[0]) || is_null_argument->isEqual(*argument_nodes[1])); + } + + if (to_eliminate) + arguments_to_reresolve.insert(to_optimize_idx); + else + new_arguments.emplace_back(and_argument); + } + /// If less than two arguments left, we will remove or replace the whole AND below + operand_to_optimize->getArguments().getNodes() = std::move(new_arguments); + } + } + + if (arguments_to_reresolve.empty()) + /// Nothing have been changed + return false; + + auto and_function_resolver = FunctionFactory::instance().get("and", context); + auto strict_equals_function_resolver = FunctionFactory::instance().get("isNotDistinctFrom", context); + + bool need_reresolve = false; + QueryTreeNodes new_or_operands; + for (size_t i = 0; i < or_operands.size(); ++i) + { + if (arguments_to_reresolve.contains(i)) + { + auto * function = or_operands[i]->as(); + if (function->getFunctionName() == "equals") + { + /// We should replace `a = b` with `a <=> b` because we removed checks for IS NULL + need_reresolve |= function->getResultType()->isNullable(); + function->resolveAsFunction(strict_equals_function_resolver); + new_or_operands.emplace_back(std::move(or_operands[i])); + } + else if (function->getFunctionName() == "and") + { + const auto & and_arguments = function->getArguments().getNodes(); + if (and_arguments.size() > 1) + { + function->resolveAsFunction(and_function_resolver); + new_or_operands.emplace_back(std::move(or_operands[i])); + } + else if (and_arguments.size() == 1) + { + /// Replace AND with a single argument with the argument itself + new_or_operands.emplace_back(and_arguments[0]); + } + } + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected function name: '{}'", function->getFunctionName()); + } + else + { + new_or_operands.emplace_back(std::move(or_operands[i])); + } + } + + if (new_or_operands.size() == 1) + { + node = std::move(new_or_operands[0]); + return need_reresolve; + } + + /// Rebuild OR function + auto or_function_resolver = FunctionFactory::instance().get("or", context); + function_node.getArguments().getNodes() = std::move(new_or_operands); + function_node.resolveAsFunction(or_function_resolver); + return need_reresolve; + } +}; + class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext { public: @@ -21,6 +224,17 @@ public: void enterImpl(QueryTreeNodePtr & node) { + if (auto * join_node = node->as()) + { + /// Operator <=> is not supported outside of JOIN ON section + if (join_node->hasJoinExpression()) + { + JoinOnLogicalExpressionOptimizerVisitor join_on_visitor(getContext()); + join_on_visitor.visit(join_node->getJoinExpression()); + } + return; + } + auto * function_node = node->as(); if (!function_node) @@ -38,6 +252,7 @@ public: return; } } + private: void tryReplaceAndEqualsChainsWithConstant(QueryTreeNodePtr & node) { diff --git a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.h b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.h index 05c10ddc685..80062f38eac 100644 --- a/src/Analyzer/Passes/LogicalExpressionOptimizerPass.h +++ b/src/Analyzer/Passes/LogicalExpressionOptimizerPass.h @@ -67,6 +67,17 @@ namespace DB * FROM TABLE * WHERE a = 1 AND b = 'test'; * ------------------------------- + * + * 5. Remove unnecessary IS NULL checks in JOIN ON clause + * - equality check with explicit IS NULL check replaced with <=> operator + * ------------------------------- + * SELECT * FROM t1 JOIN t2 ON a = b OR (a IS NULL AND b IS NULL) + * SELECT * FROM t1 JOIN t2 ON a <=> b OR (a IS NULL AND b IS NULL) + * + * will be transformed into + * + * SELECT * FROM t1 JOIN t2 ON a <=> b + * ------------------------------- */ class LogicalExpressionOptimizerPass final : public IQueryTreePass diff --git a/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.reference b/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.reference new file mode 100644 index 00000000000..976c1503b02 --- /dev/null +++ b/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.reference @@ -0,0 +1,25 @@ +-- { echoOn } +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST; +2 2 2 2 +3 3 3 33 +\N \N \N \N +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST; +1 42 4 42 +2 2 2 2 +3 3 3 33 +\N \N \N \N +SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST; +2 2 2 2 +\N \N \N \N +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST; +1 42 4 42 +2 2 2 2 +3 3 3 33 +\N \N \N \N +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST; +\N \N \N \N +SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST; +0 +1 +1 +1 diff --git a/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.sql b/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.sql new file mode 100644 index 00000000000..6a98a7bb57b --- /dev/null +++ b/tests/queries/0_stateless/02911_join_on_nullsafe_optimization.sql @@ -0,0 +1,27 @@ +DROP TABLE IF EXISTS t1; +DROP TABLE IF EXISTS t2; + +CREATE TABLE t1 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog; +CREATE TABLE t2 (x Nullable(Int64), y Nullable(UInt64)) ENGINE = TinyLog; + +INSERT INTO t1 VALUES (1,42), (2,2), (3,3), (NULL,NULL); +INSERT INTO t2 VALUES (NULL,NULL), (2,2), (3,33), (4,42); + +SET allow_experimental_analyzer = 1; + +-- { echoOn } +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) ORDER BY t1.x NULLS LAST; + +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.x IS NULL AND t1.y <=> t2.y AND t2.x IS NULL) ORDER BY t1.x NULLS LAST; + +SELECT * FROM t1 JOIN t2 ON (t1.x = t2.x OR t1.x IS NULL AND t2.x IS NULL) AND t1.y <=> t2.y ORDER BY t1.x NULLS LAST; + +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR t1.y <=> t2.y OR (t1.x IS NULL AND t1.y IS NULL AND t2.x IS NULL AND t2.y IS NULL)) ORDER BY t1.x NULLS LAST; + +SELECT * FROM t1 JOIN t2 ON (t1.x <=> t2.x OR (t1.x IS NULL AND t2.x IS NULL)) AND (t1.y == t2.y OR (t1.y IS NULL AND t2.y IS NULL)) AND COALESCE(t1.x, 0) != 2 ORDER BY t1.x NULLS LAST; + +SELECT x = y OR (x IS NULL AND y IS NULL) FROM t1 ORDER BY x NULLS LAST; +-- { echoOff } + +DROP TABLE IF EXISTS t1; +DROP TABLE IF EXISTS t2;