diff --git a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp new file mode 100644 index 00000000000..68e04b45d99 --- /dev/null +++ b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.cpp @@ -0,0 +1,210 @@ +#include + +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int INVALID_JOIN_ON_EXPRESSION; + extern const int AMBIGUOUS_COLUMN_NAME; + extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; +} + +void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, + const std::pair & table_no) +{ + ASTPtr left = left_ast->clone(); + ASTPtr right = right_ast->clone(); + + if (table_no.first == 1 || table_no.second == 2) + analyzed_join.addOnKeys(left, right); + else if (table_no.first == 2 || table_no.second == 1) + analyzed_join.addOnKeys(right, left); + else + throw Exception("Cannot detect left and right JOIN keys. JOIN ON section is ambiguous.", + ErrorCodes::AMBIGUOUS_COLUMN_NAME); + has_some = true; +} + +void CollectJoinOnKeysMatcher::Data::addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, + const std::pair & table_no) +{ + if (table_no.first == 1 || table_no.second == 2) + { + asof_left_key = left_ast->clone(); + asof_right_key = right_ast->clone(); + return; + } + + throw Exception("ASOF JOIN for (left_table.x <= right_table.x) is not implemented", ErrorCodes::NOT_IMPLEMENTED); +} + +void CollectJoinOnKeysMatcher::Data::asofToJoinKeys() +{ + if (!asof_left_key || !asof_right_key) + throw Exception("No inequality in ASOF JOIN ON section.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + addJoinKeys(asof_left_key, asof_right_key, {1, 2}); +} + + +void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & ast, Data & data) +{ + if (func.name == "and") + return; /// go into children + + if (func.name == "equals") + { + ASTPtr left = func.arguments->children.at(0); + ASTPtr right = func.arguments->children.at(1); + auto table_numbers = getTableNumbers(ast, left, right, data); + data.addJoinKeys(left, right, table_numbers); + return; + } + + bool less_or_equals = (func.name == "lessOrEquals"); + bool greater_or_equals = (func.name == "greaterOrEquals"); + + if (data.is_asof && (less_or_equals || greater_or_equals)) + { + if (data.asof_left_key || data.asof_right_key) + throwSyntaxException("ASOF JOIN expects exactly one inequality in ON section, unexpected " + queryToString(ast) + "."); + + ASTPtr left = func.arguments->children.at(0); + ASTPtr right = func.arguments->children.at(1); + auto table_numbers = getTableNumbers(ast, left, right, data); + + if (greater_or_equals) + data.addAsofJoinKeys(left, right, table_numbers); + else + data.addAsofJoinKeys(right, left, std::make_pair(table_numbers.second, table_numbers.first)); + + return; + } + + throwSyntaxException("Expected equals expression, got " + queryToString(ast) + "."); +} + +void CollectJoinOnKeysMatcher::getIdentifiers(const ASTPtr & ast, std::vector & out) +{ + if (const auto * ident = ast->as()) + { + if (IdentifierSemantic::getColumnName(*ident)) + out.push_back(ident); + return; + } + + for (const auto & child : ast->children) + getIdentifiers(child, out); +} + +std::pair CollectJoinOnKeysMatcher::getTableNumbers(const ASTPtr & expr, const ASTPtr & left_ast, const ASTPtr & right_ast, + Data & data) +{ + std::vector left_identifiers; + std::vector right_identifiers; + + getIdentifiers(left_ast, left_identifiers); + getIdentifiers(right_ast, right_identifiers); + + size_t left_idents_table = getTableForIdentifiers(left_identifiers, data); + size_t right_idents_table = getTableForIdentifiers(right_identifiers, data); + + if (left_idents_table && left_idents_table == right_idents_table) + { + auto left_name = queryToString(*left_identifiers[0]); + auto right_name = queryToString(*right_identifiers[0]); + + throwSyntaxException("In expression " + queryToString(expr) + " columns " + left_name + " and " + right_name + + " are from the same table but from different arguments of equal function."); + } + + return std::make_pair(left_idents_table, right_idents_table); +} + +const ASTIdentifier * CollectJoinOnKeysMatcher::unrollAliases(const ASTIdentifier * identifier, const Aliases & aliases) +{ + if (identifier->compound()) + return identifier; + + UInt32 max_attempts = 100; + for (auto it = aliases.find(identifier->name); it != aliases.end();) + { + const ASTIdentifier * parent = identifier; + identifier = it->second->as(); + if (!identifier) + break; /// not a column alias + if (identifier == parent) + break; /// alias to itself with the same name: 'a as a' + if (identifier->compound()) + break; /// not an alias. Break to prevent cycle through short names: 'a as b, t1.b as a' + + it = aliases.find(identifier->name); + if (!max_attempts--) + throw Exception("Cannot unroll aliases for '" + identifier->name + "'", ErrorCodes::LOGICAL_ERROR); + } + + return identifier; +} + +/// @returns 1 if identifiers belongs to left table, 2 for right table and 0 if unknown. Throws on table mix. +/// Place detected identifier into identifiers[0] if any. +size_t CollectJoinOnKeysMatcher::getTableForIdentifiers(std::vector & identifiers, const Data & data) +{ + size_t table_number = 0; + + for (auto & ident : identifiers) + { + const ASTIdentifier * identifier = unrollAliases(ident, data.aliases); + if (!identifier) + continue; + + /// Column name could be cropped to a short form in TranslateQualifiedNamesVisitor. + /// In this case it saves membership in IdentifierSemantic. + size_t membership = IdentifierSemantic::getMembership(*identifier); + + if (!membership) + { + const String & name = identifier->name; + bool in_left_table = data.source_columns.count(name); + bool in_right_table = data.joined_columns.count(name); + + if (in_left_table && in_right_table) + throw Exception("Column '" + name + "' is ambiguous", ErrorCodes::AMBIGUOUS_COLUMN_NAME); + + if (in_left_table) + membership = 1; + if (in_right_table) + membership = 2; + } + + if (membership && table_number == 0) + { + table_number = membership; + std::swap(ident, identifiers[0]); /// move first detected identifier to the first position + } + + if (membership && membership != table_number) + { + throw Exception("Invalid columns in JOIN ON section. Columns " + + identifiers[0]->getAliasOrColumnName() + " and " + ident->getAliasOrColumnName() + + " are from different tables.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + } + } + + return table_number; +} + +[[noreturn]] void CollectJoinOnKeysMatcher::throwSyntaxException(const String & msg) +{ + throw Exception("Invalid expression for JOIN ON. " + msg + + " Supported syntax: JOIN ON Expr([table.]column, ...) = Expr([table.]column, ...) " + "[AND Expr([table.]column, ...) = Expr([table.]column, ...) ...]", + ErrorCodes::INVALID_JOIN_ON_EXPRESSION); +} + +} diff --git a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h index 7dc3051167a..bae6781a18a 100644 --- a/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/dbms/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -1,23 +1,16 @@ #pragma once +#include #include -#include - #include #include -#include namespace DB { -namespace ErrorCodes -{ - extern const int INVALID_JOIN_ON_EXPRESSION; - extern const int AMBIGUOUS_COLUMN_NAME; - extern const int LOGICAL_ERROR; -} - +class ASTIdentifier; +struct AnalyzedJoin; class CollectJoinOnKeysMatcher { @@ -30,7 +23,14 @@ public: const NameSet & source_columns; const NameSet & joined_columns; const Aliases & aliases; - bool has_some = false; + const bool is_asof; + ASTPtr asof_left_key{}; + ASTPtr asof_right_key{}; + bool has_some{false}; + + void addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no); + void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no); + void asofToJoinKeys(); }; static void visit(const ASTPtr & ast, Data & data) @@ -48,146 +48,14 @@ public: } private: - static void visit(const ASTFunction & func, const ASTPtr & ast, Data & data) - { - if (func.name == "and") - return; /// go into children + static void visit(const ASTFunction & func, const ASTPtr & ast, Data & data); - if (func.name == "equals") - { - ASTPtr left = func.arguments->children.at(0)->clone(); - ASTPtr right = func.arguments->children.at(1)->clone(); - addJoinKeys(ast, left, right, data); - return; - } + static void getIdentifiers(const ASTPtr & ast, std::vector & out); + static std::pair getTableNumbers(const ASTPtr & expr, const ASTPtr & left_ast, const ASTPtr & right_ast, Data & data); + static const ASTIdentifier * unrollAliases(const ASTIdentifier * identifier, const Aliases & aliases); + static size_t getTableForIdentifiers(std::vector & identifiers, const Data & data); - throwSyntaxException("Expected equals expression, got " + queryToString(ast) + "."); - } - - static void getIdentifiers(const ASTPtr & ast, std::vector & out) - { - if (const auto * ident = ast->as()) - { - if (IdentifierSemantic::getColumnName(*ident)) - out.push_back(ident); - return; - } - - for (const auto & child : ast->children) - getIdentifiers(child, out); - } - - static void addJoinKeys(const ASTPtr & expr, ASTPtr left_ast, ASTPtr right_ast, Data & data) - { - std::vector left_identifiers; - std::vector right_identifiers; - - getIdentifiers(left_ast, left_identifiers); - getIdentifiers(right_ast, right_identifiers); - - size_t left_idents_table = getTableForIdentifiers(left_identifiers, data); - size_t right_idents_table = getTableForIdentifiers(right_identifiers, data); - - if (left_idents_table && left_idents_table == right_idents_table) - { - auto left_name = queryToString(*left_identifiers[0]); - auto right_name = queryToString(*right_identifiers[0]); - - throwSyntaxException("In expression " + queryToString(expr) + " columns " + left_name + " and " + right_name - + " are from the same table but from different arguments of equal function."); - } - - if (left_idents_table == 1 || right_idents_table == 2) - data.analyzed_join.addOnKeys(left_ast, right_ast); - else if (left_idents_table == 2 || right_idents_table == 1) - data.analyzed_join.addOnKeys(right_ast, left_ast); - else - throw Exception("Cannot detect left and right JOIN keys. JOIN ON section is ambiguous.", - ErrorCodes::AMBIGUOUS_COLUMN_NAME); - - data.has_some = true; - } - - static const ASTIdentifier * unrollAliases(const ASTIdentifier * identifier, const Aliases & aliases) - { - if (identifier->compound()) - return identifier; - - UInt32 max_attempts = 100; - for (auto it = aliases.find(identifier->name); it != aliases.end();) - { - const ASTIdentifier * parent = identifier; - identifier = it->second->as(); - if (!identifier) - break; /// not a column alias - if (identifier == parent) - break; /// alias to itself with the same name: 'a as a' - if (identifier->compound()) - break; /// not an alias. Break to prevent cycle through short names: 'a as b, t1.b as a' - - it = aliases.find(identifier->name); - if (!max_attempts--) - throw Exception("Cannot unroll aliases for '" + identifier->name + "'", ErrorCodes::LOGICAL_ERROR); - } - - return identifier; - } - - /// @returns 1 if identifiers belongs to left table, 2 for right table and 0 if unknown. Throws on table mix. - /// Place detected identifier into identifiers[0] if any. - static size_t getTableForIdentifiers(std::vector & identifiers, const Data & data) - { - size_t table_number = 0; - - for (auto & ident : identifiers) - { - const ASTIdentifier * identifier = unrollAliases(ident, data.aliases); - if (!identifier) - continue; - - /// Column name could be cropped to a short form in TranslateQualifiedNamesVisitor. - /// In this case it saves membership in IdentifierSemantic. - size_t membership = IdentifierSemantic::getMembership(*identifier); - - if (!membership) - { - const String & name = identifier->name; - bool in_left_table = data.source_columns.count(name); - bool in_right_table = data.joined_columns.count(name); - - if (in_left_table && in_right_table) - throw Exception("Column '" + name + "' is ambiguous", ErrorCodes::AMBIGUOUS_COLUMN_NAME); - - if (in_left_table) - membership = 1; - if (in_right_table) - membership = 2; - } - - if (membership && table_number == 0) - { - table_number = membership; - std::swap(ident, identifiers[0]); /// move first detected identifier to the first position - } - - if (membership && membership != table_number) - { - throw Exception("Invalid columns in JOIN ON section. Columns " - + identifiers[0]->getAliasOrColumnName() + " and " + ident->getAliasOrColumnName() - + " are from different tables.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); - } - } - - return table_number; - } - - [[noreturn]] static void throwSyntaxException(const String & msg) - { - throw Exception("Invalid expression for JOIN ON. " + msg + - " Supported syntax: JOIN ON Expr([table.]column, ...) = Expr([table.]column, ...) " - "[AND Expr([table.]column, ...) = Expr([table.]column, ...) ...]", - ErrorCodes::INVALID_JOIN_ON_EXPRESSION); - } + [[noreturn]] static void throwSyntaxException(const String & msg); }; /// Parse JOIN ON expression and collect ASTs for joined columns. diff --git a/dbms/src/Interpreters/SyntaxAnalyzer.cpp b/dbms/src/Interpreters/SyntaxAnalyzer.cpp index 04102f5ae15..02156b20995 100644 --- a/dbms/src/Interpreters/SyntaxAnalyzer.cpp +++ b/dbms/src/Interpreters/SyntaxAnalyzer.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -509,11 +510,14 @@ void collectJoinedColumns(AnalyzedJoin & analyzed_join, const ASTSelectQuery & s for (const auto & col : analyzed_join.columns_from_joined_table) joined_columns.insert(col.original_name); - CollectJoinOnKeysVisitor::Data data{analyzed_join, source_columns, joined_columns, aliases}; + bool is_asof = (table_join.strictness == ASTTableJoin::Strictness::Asof); + CollectJoinOnKeysVisitor::Data data{analyzed_join, source_columns, joined_columns, aliases, is_asof}; CollectJoinOnKeysVisitor(data).visit(table_join.on_expression); if (!data.has_some) throw Exception("Cannot get JOIN keys from JOIN ON section: " + queryToString(table_join.on_expression), ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + if (is_asof) + data.asofToJoinKeys(); } bool make_nullable = join_use_nulls && isLeftOrFull(table_join.kind); diff --git a/dbms/tests/queries/0_stateless/00927_asof_join_noninclusive.sql b/dbms/tests/queries/0_stateless/00927_asof_join_noninclusive.sql index 50644352b64..5f15f3b593d 100644 --- a/dbms/tests/queries/0_stateless/00927_asof_join_noninclusive.sql +++ b/dbms/tests/queries/0_stateless/00927_asof_join_noninclusive.sql @@ -12,7 +12,7 @@ INSERT INTO B(k,t,b) VALUES (2,3,3); SELECT A.k, toString(A.t, 'UTC'), A.a, B.b, toString(B.t, 'UTC'), B.k FROM A ASOF LEFT JOIN B USING(k,t) ORDER BY (A.k, A.t); -SELECT A.k, toString(A.t, 'UTC'), A.a, B.b, toString(B.t, 'UTC'), B.k FROM A ASOF INNER JOIN B ON A.k == B.k AND A.t == B.t ORDER BY (A.k, A.t); +SELECT A.k, toString(A.t, 'UTC'), A.a, B.b, toString(B.t, 'UTC'), B.k FROM A ASOF INNER JOIN B ON A.k == B.k AND A.t >= B.t ORDER BY (A.k, A.t); SELECT A.k, toString(A.t, 'UTC'), A.a, B.b, toString(B.t, 'UTC'), B.k FROM A ASOF JOIN B USING(k,t) ORDER BY (A.k, A.t); diff --git a/dbms/tests/queries/0_stateless/00976_asof_join_on.reference b/dbms/tests/queries/0_stateless/00976_asof_join_on.reference new file mode 100644 index 00000000000..ffa8117cc75 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00976_asof_join_on.reference @@ -0,0 +1,13 @@ +1 1 0 0 +1 2 1 2 +1 3 1 2 +2 1 0 0 +2 2 0 0 +2 3 2 3 +3 1 0 0 +3 2 0 0 +3 3 0 0 +9 +1 2 1 2 +1 3 1 2 +2 3 2 3 diff --git a/dbms/tests/queries/0_stateless/00976_asof_join_on.sql b/dbms/tests/queries/0_stateless/00976_asof_join_on.sql new file mode 100644 index 00000000000..740287b7c30 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00976_asof_join_on.sql @@ -0,0 +1,21 @@ +DROP TABLE IF EXISTS A; +DROP TABLE IF EXISTS B; + +CREATE TABLE A(a UInt32, t UInt32) ENGINE = Memory; +CREATE TABLE B(b UInt32, t UInt32) ENGINE = Memory; + +INSERT INTO A (a,t) VALUES (1,1),(1,2),(1,3), (2,1),(2,2),(2,3), (3,1),(3,2),(3,3); +INSERT INTO B (b,t) VALUES (1,2),(1,4),(2,3); + +SELECT A.a, A.t, B.b, B.t FROM A ASOF LEFT JOIN B ON A.a == B.b AND A.t >= B.t ORDER BY (A.a, A.t); +SELECT count() FROM A ASOF LEFT JOIN B ON A.a == B.b AND B.t <= A.t; +SELECT A.a, A.t, B.b, B.t FROM A ASOF INNER JOIN B ON B.t <= A.t AND A.a == B.b; +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t <= B.t; -- { serverError 48 } +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND B.t >= A.t; -- { serverError 48 } +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t > B.t; -- { serverError 403 } +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t < B.t; -- { serverError 403 } +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t == B.t; -- { serverError 403 } +SELECT count() FROM A ASOF JOIN B ON A.a == B.b AND A.t != B.t; -- { serverError 403 } + +DROP TABLE A; +DROP TABLE B;