diff --git a/src/Interpreters/CrossToInnerJoinVisitor.cpp b/src/Interpreters/CrossToInnerJoinVisitor.cpp index c8195706f04..b6f977cd9b5 100644 --- a/src/Interpreters/CrossToInnerJoinVisitor.cpp +++ b/src/Interpreters/CrossToInnerJoinVisitor.cpp @@ -81,56 +81,6 @@ private: ASTTableJoin * join = nullptr; }; -/// Collect all identifiers from ast -class IdentifiersCollector -{ -public: - using ASTIdentPtr = const ASTIdentifier *; - using ASTIdentifiers = std::vector; - struct Data - { - ASTIdentifiers idents; - }; - - static void visit(const ASTPtr & node, Data & data) - { - if (const auto * ident = node->as()) - data.idents.push_back(ident); - } - - static bool needChildVisit(const ASTPtr &, const ASTPtr &) - { - return true; - } - - static ASTIdentifiers collect(const ASTPtr & node) - { - IdentifiersCollector::Data ident_data; - ConstInDepthNodeVisitor ident_visitor(ident_data); - ident_visitor.visit(node); - return ident_data.idents; - } -}; - -/// Split expression `expr_1 AND expr_2 AND ... AND expr_n` into vector `[expr_1, expr_2, ..., expr_n]` -void collectConjunctions(const ASTPtr & node, std::vector & members) -{ - if (const auto * func = node->as(); func && func->name == NameAnd::name) - { - for (const auto & child : func->arguments->children) - collectConjunctions(child, members); - return; - } - members.push_back(node); -} - -std::vector collectConjunctions(const ASTPtr & node) -{ - std::vector members; - collectConjunctions(node, members); - return members; -} - bool isAllowedToRewriteCrossJoin(const ASTPtr & node, const Aliases & aliases) { if (node->as()) diff --git a/src/Interpreters/IdentifierSemantic.cpp b/src/Interpreters/IdentifierSemantic.cpp index ad5598afb5b..ff1fbbc8e2d 100644 --- a/src/Interpreters/IdentifierSemantic.cpp +++ b/src/Interpreters/IdentifierSemantic.cpp @@ -3,6 +3,8 @@ #include #include +#include + namespace DB { @@ -313,4 +315,22 @@ std::optional IdentifierMembershipCollector::getIdentsMembership(ASTPtr return IdentifierSemantic::getIdentsMembership(ast, tables, aliases); } +static void collectConjunctions(const ASTPtr & node, std::vector & members) +{ + if (const auto * func = node->as(); func && func->name == "and") + { + for (const auto & child : func->arguments->children) + collectConjunctions(child, members); + return; + } + members.push_back(node); +} + +std::vector collectConjunctions(const ASTPtr & node) +{ + std::vector members; + collectConjunctions(node, members); + return members; +} + } diff --git a/src/Interpreters/IdentifierSemantic.h b/src/Interpreters/IdentifierSemantic.h index 3a99150b792..9f11d8bdb9d 100644 --- a/src/Interpreters/IdentifierSemantic.h +++ b/src/Interpreters/IdentifierSemantic.h @@ -107,4 +107,7 @@ private: Aliases aliases; }; +/// Split expression `expr_1 AND expr_2 AND ... AND expr_n` into vector `[expr_1, expr_2, ..., expr_n]` +std::vector collectConjunctions(const ASTPtr & node); + } diff --git a/src/Storages/StorageMerge.cpp b/src/Storages/StorageMerge.cpp index df176bd3bcf..7730ef98c93 100644 --- a/src/Storages/StorageMerge.cpp +++ b/src/Storages/StorageMerge.cpp @@ -71,31 +71,23 @@ TreeRewriterResult modifySelect(ASTSelectQuery & select, const TreeRewriterResul return; const size_t left_table_pos = 0; - if (const auto * conjunctions = where->as(); conjunctions && conjunctions->name == "and") + /// Test each argument of `and` function and select ones related to only left table + std::shared_ptr new_conj = makeASTFunction("and"); + for (const auto & node : collectConjunctions(where)) { - /// Test each argument of `and` function and select related to only left table - std::shared_ptr new_conj = makeASTFunction("and"); - for (const auto & node : conjunctions->arguments->children) - { - if (membership_collector.getIdentsMembership(node) == left_table_pos) - new_conj->arguments->children.push_back(std::move(node)); - } - if (new_conj->arguments->children.empty()) - /// No identifiers from left table - query.setExpression(expr, {}); - else if (new_conj->arguments->children.size() == 1) - /// Only one expression, lift from `and` - query.setExpression(expr, std::move(new_conj->arguments->children[0])); - else - /// Set new expression - query.setExpression(expr, std::move(new_conj)); + if (membership_collector.getIdentsMembership(node) == left_table_pos) + new_conj->arguments->children.push_back(std::move(node)); } + + if (new_conj->arguments->children.empty()) + /// No identifiers from left table + query.setExpression(expr, {}); + else if (new_conj->arguments->children.size() == 1) + /// Only one expression, lift from `and` + query.setExpression(expr, std::move(new_conj->arguments->children[0])); else - { - /// Remove whole expression if not match to left table - if (membership_collector.getIdentsMembership(where) != left_table_pos) - query.setExpression(expr, {}); - } + /// Set new expression + query.setExpression(expr, std::move(new_conj)); }; replace_where(select,ASTSelectQuery::Expression::WHERE); replace_where(select,ASTSelectQuery::Expression::PREWHERE); diff --git a/tests/queries/0_stateless/01783_merge_engine_join_key_condition.reference b/tests/queries/0_stateless/01783_merge_engine_join_key_condition.reference index 9f7c2e7ee16..4068a6e00dd 100644 --- a/tests/queries/0_stateless/01783_merge_engine_join_key_condition.reference +++ b/tests/queries/0_stateless/01783_merge_engine_join_key_condition.reference @@ -2,3 +2,4 @@ 1 4 1 4 1 4 +1 4 diff --git a/tests/queries/0_stateless/01783_merge_engine_join_key_condition.sql b/tests/queries/0_stateless/01783_merge_engine_join_key_condition.sql index 97a5f2f0ef7..115ee42fe11 100644 --- a/tests/queries/0_stateless/01783_merge_engine_join_key_condition.sql +++ b/tests/queries/0_stateless/01783_merge_engine_join_key_condition.sql @@ -14,6 +14,7 @@ SET force_primary_key = 1; SELECT * FROM foo_merge WHERE Val = 3 AND Id = 3; SELECT count(), X FROM foo_merge JOIN t2 USING Val WHERE Val = 3 AND Id = 3 AND t2.X == 4 GROUP BY X; +SELECT count(), X FROM foo_merge JOIN t2 USING Val WHERE Val = 3 AND (Id = 3 AND t2.X == 4) GROUP BY X; SELECT count(), X FROM foo_merge JOIN t2 USING Val WHERE Val = 3 AND Id = 3 GROUP BY X; SELECT count(), X FROM (SELECT * FROM foo_merge) f JOIN t2 USING Val WHERE Val = 3 AND Id = 3 GROUP BY X;