diff --git a/dbms/src/Interpreters/ExpressionAnalyzer.cpp b/dbms/src/Interpreters/ExpressionAnalyzer.cpp index 85a2dd2c3f8..d1be66df217 100644 --- a/dbms/src/Interpreters/ExpressionAnalyzer.cpp +++ b/dbms/src/Interpreters/ExpressionAnalyzer.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include diff --git a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp index 389d6c825b0..f40e91e7dcd 100644 --- a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp +++ b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.cpp @@ -38,6 +38,7 @@ struct NonGlobalTableData const CheckShardsAndTables & checker; const Context & context; + std::vector & renamed_tables; ASTFunction * function = nullptr; ASTTableJoin * table_join = nullptr; @@ -95,10 +96,11 @@ private: String alias = database_and_table->tryGetAlias(); if (alias.empty()) - throw Exception("Distributed table should have an alias when distributed_product_mode set to local.", + throw Exception("Distributed table should have an alias when distributed_product_mode set to local", ErrorCodes::DISTRIBUTED_IN_JOIN_SUBQUERY_DENIED); auto & identifier = database_and_table->as(); + renamed_tables.emplace_back(identifier.clone()); identifier.resetTable(database, table); } else @@ -118,6 +120,7 @@ public: { const CheckShardsAndTables & checker; const Context & context; + std::vector>> & renamed_tables; }; static void visit(ASTPtr & node, Data & data) @@ -148,8 +151,11 @@ private: if (node.name == "in" || node.name == "notIn") { auto & subquery = node.arguments->children.at(1); - NonGlobalTableVisitor::Data table_data{data.checker, data.context, &node, nullptr}; + std::vector renamed; + NonGlobalTableVisitor::Data table_data{data.checker, data.context, renamed, &node, nullptr}; NonGlobalTableVisitor(table_data).visit(subquery); + if (!renamed.empty()) + data.renamed_tables.emplace_back(subquery, std::move(renamed)); } } @@ -163,8 +169,11 @@ private: { if (auto & subquery = node.table_expression->as()->subquery) { - NonGlobalTableVisitor::Data table_data{data.checker, data.context, nullptr, table_join}; + std::vector renamed; + NonGlobalTableVisitor::Data table_data{data.checker, data.context, renamed, nullptr, table_join}; NonGlobalTableVisitor(table_data).visit(subquery); + if (!renamed.empty()) + data.renamed_tables.emplace_back(subquery, std::move(renamed)); } } } @@ -208,7 +217,7 @@ void InJoinSubqueriesPreprocessor::visit(ASTPtr & ast) const return; } - NonGlobalSubqueryVisitor::Data visitor_data{*checker, context}; + NonGlobalSubqueryVisitor::Data visitor_data{*checker, context, renamed_tables}; NonGlobalSubqueryVisitor(visitor_data).visit(ast); } diff --git a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.h b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.h index ff39d812dee..5aa9cfbcadf 100644 --- a/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.h +++ b/dbms/src/Interpreters/InJoinSubqueriesPreprocessor.h @@ -35,6 +35,8 @@ class Context; class InJoinSubqueriesPreprocessor { public: + using SubqueryTables = std::vector>>; /// {subquery, renamed_tables} + struct CheckShardsAndTables { using Ptr = std::unique_ptr; @@ -45,8 +47,10 @@ public: virtual ~CheckShardsAndTables() {} }; - InJoinSubqueriesPreprocessor(const Context & context_, CheckShardsAndTables::Ptr _checker = std::make_unique()) + InJoinSubqueriesPreprocessor(const Context & context_, SubqueryTables & renamed_tables_, + CheckShardsAndTables::Ptr _checker = std::make_unique()) : context(context_) + , renamed_tables(renamed_tables_) , checker(std::move(_checker)) {} @@ -54,6 +58,7 @@ public: private: const Context & context; + SubqueryTables & renamed_tables; CheckShardsAndTables::Ptr checker; }; diff --git a/dbms/src/Interpreters/InterpreterSelectQuery.cpp b/dbms/src/Interpreters/InterpreterSelectQuery.cpp index 4fe83afa48d..db1894026d8 100644 --- a/dbms/src/Interpreters/InterpreterSelectQuery.cpp +++ b/dbms/src/Interpreters/InterpreterSelectQuery.cpp @@ -301,6 +301,8 @@ InterpreterSelectQuery::InterpreterSelectQuery( source_header = interpreter_subquery->getSampleBlock(); } + joined_tables.rewriteDistributedInAndJoins(query_ptr); + max_streams = settings.max_threads; ASTSelectQuery & query = getSelectQuery(); diff --git a/dbms/src/Interpreters/JoinedTables.cpp b/dbms/src/Interpreters/JoinedTables.cpp index beec338f9bf..48e763ffb19 100644 --- a/dbms/src/Interpreters/JoinedTables.cpp +++ b/dbms/src/Interpreters/JoinedTables.cpp @@ -1,12 +1,18 @@ #include #include #include +#include +#include +#include #include #include #include #include #include +#include #include +#include +#include namespace DB { @@ -14,6 +20,7 @@ namespace DB namespace ErrorCodes { extern const int ALIAS_REQUIRED; + extern const int AMBIGUOUS_COLUMN_NAME; } namespace @@ -32,6 +39,71 @@ void checkTablesWithColumns(const std::vector & tables_with_columns, const Co } } +class RenameQualifiedIdentifiersMatcher +{ +public: + using Data = const std::vector; + + static void visit(ASTPtr & ast, Data & data) + { + if (auto * t = ast->as()) + visit(*t, ast, data); + if (auto * node = ast->as()) + visit(*node, ast, data); + } + + static bool needChildVisit(ASTPtr & node, const ASTPtr & child) + { + if (node->as() || + node->as() || + child->as()) + return false; // NOLINT + return true; + } + +private: + static void visit(ASTIdentifier & identifier, ASTPtr &, Data & data) + { + if (identifier.isShort()) + return; + + bool rewritten = false; + for (auto & table : data) + { + /// Table has an alias. We do not need to rewrite qualified names with table alias (match == ColumnMatch::TableName). + auto match = IdentifierSemantic::canReferColumnToTable(identifier, table); + if (match == IdentifierSemantic::ColumnMatch::AliasedTableName || + match == IdentifierSemantic::ColumnMatch::DbAndTable) + { + if (rewritten) + throw Exception("Failed to rewrite distributed table names. Ambiguous column '" + identifier.name + "'", + ErrorCodes::AMBIGUOUS_COLUMN_NAME); + /// Table has an alias. So we set a new name qualified by table alias. + IdentifierSemantic::setColumnLongName(identifier, table); + rewritten = true; + } + } + } + + static void visit(const ASTQualifiedAsterisk & node, const ASTPtr &, Data & data) + { + ASTIdentifier & identifier = *node.children[0]->as(); + bool rewritten = false; + for (auto & table : data) + { + if (identifier.name == table.table) + { + if (rewritten) + throw Exception("Failed to rewrite distributed table. Ambiguous column '" + identifier.name + "'", + ErrorCodes::AMBIGUOUS_COLUMN_NAME); + identifier.setShortName(table.alias); + rewritten = true; + } + } + } +}; +using RenameQualifiedIdentifiersVisitor = InDepthNodeVisitor; + } JoinedTables::JoinedTables(Context && context_, const ASTSelectQuery & select_query) @@ -114,4 +186,27 @@ void JoinedTables::makeFakeTable(StoragePtr storage, const Block & source_header tables_with_columns.emplace_back(DatabaseAndTableWithAlias{}, source_header.getNamesAndTypesList()); } +void JoinedTables::rewriteDistributedInAndJoins(ASTPtr & query) +{ + /// Rewrite IN and/or JOIN for distributed tables according to distributed_product_mode setting. + InJoinSubqueriesPreprocessor::SubqueryTables renamed_tables; + InJoinSubqueriesPreprocessor(context, renamed_tables).visit(query); + + String database; + if (!renamed_tables.empty()) + database = context.getCurrentDatabase(); + + for (auto & [subquery, ast_tables] : renamed_tables) + { + std::vector renamed; + renamed.reserve(ast_tables.size()); + for (auto & ast : ast_tables) + renamed.emplace_back(DatabaseAndTableWithAlias(*ast->as(), database)); + + /// Change qualified column names in distributed subqueries using table aliases. + RenameQualifiedIdentifiersVisitor::Data data(renamed); + RenameQualifiedIdentifiersVisitor(data).visit(subquery); + } +} + } diff --git a/dbms/src/Interpreters/JoinedTables.h b/dbms/src/Interpreters/JoinedTables.h index f1940366ef5..66b3c8de609 100644 --- a/dbms/src/Interpreters/JoinedTables.h +++ b/dbms/src/Interpreters/JoinedTables.h @@ -37,6 +37,8 @@ public: const StorageID & leftTableID() const { return table_id; } + void rewriteDistributedInAndJoins(ASTPtr & query); + std::unique_ptr makeLeftTableSubquery(const SelectQueryOptions & select_options); private: diff --git a/dbms/src/Interpreters/SyntaxAnalyzer.cpp b/dbms/src/Interpreters/SyntaxAnalyzer.cpp index 7338487c5e8..f93d11fa1da 100644 --- a/dbms/src/Interpreters/SyntaxAnalyzer.cpp +++ b/dbms/src/Interpreters/SyntaxAnalyzer.cpp @@ -2,7 +2,6 @@ #include #include -#include #include #include #include @@ -819,9 +818,6 @@ SyntaxAnalyzerResultPtr SyntaxAnalyzer::analyzeSelect( translateQualifiedNames(query, *select_query, source_columns_set, tables_with_column_names); - /// Rewrite IN and/or JOIN for distributed tables according to distributed_product_mode setting. - InJoinSubqueriesPreprocessor(context).visit(query); - /// Optimizes logical expressions. LogicalExpressionsOptimizer(select_query, settings.optimize_min_equality_disjunction_chain_length.value).perform(); diff --git a/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp b/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp index 9a6d7ca4162..9a17f03f32a 100644 --- a/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp +++ b/dbms/src/Interpreters/tests/in_join_subqueries_preprocessor.cpp @@ -1181,7 +1181,8 @@ TestResult check(const TestEntry & entry) try { - DB::InJoinSubqueriesPreprocessor(context, std::make_unique()).visit(ast_input); + DB::InJoinSubqueriesPreprocessor::SubqueryTables renamed; + DB::InJoinSubqueriesPreprocessor(context, renamed, std::make_unique()).visit(ast_input); } catch (const DB::Exception & ex) { diff --git a/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.reference b/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.reference new file mode 100644 index 00000000000..53f7332cffb --- /dev/null +++ b/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.reference @@ -0,0 +1,12 @@ +42 +42 +42 +42 +42 +42 +42 +42 +42 +42 +42 +42 diff --git a/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.sql b/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.sql new file mode 100644 index 00000000000..32655420a27 --- /dev/null +++ b/dbms/tests/queries/0_stateless/01103_distributed_product_mode_local_column_renames.sql @@ -0,0 +1,87 @@ +CREATE DATABASE IF NOT EXISTS test_01103; +USE test_01103; + +DROP TABLE IF EXISTS t1_shard; +DROP TABLE IF EXISTS t2_shard; +DROP TABLE IF EXISTS t1_distr; +DROP TABLE IF EXISTS t2_distr; + +create table t1_shard (id Int32) engine MergeTree order by id; +create table t2_shard (id Int32) engine MergeTree order by id; + +create table t1_distr as t1_shard engine Distributed(test_cluster_two_shards_localhost, test_01103, t1_shard, id); +create table t2_distr as t2_shard engine Distributed(test_cluster_two_shards_localhost, test_01103, t2_shard, id); + +insert into t1_shard values (42); +insert into t2_shard values (42); + +SET distributed_product_mode = 'local'; + +select d0.id +from t1_distr d0 +where d0.id in +( + select d1.id + from t1_distr as d1 + inner join t2_distr as d2 on d1.id = d2.id + where d1.id > 0 + order by d1.id +); + +select t1_distr.id +from t1_distr +where t1_distr.id in +( + select t1_distr.id + from t1_distr as d1 + inner join t2_distr as d2 on t1_distr.id = t2_distr.id + where t1_distr.id > 0 + order by t1_distr.id +); + +select test_01103.t1_distr.id +from test_01103.t1_distr +where test_01103.t1_distr.id in +( + select test_01103.t1_distr.id + from test_01103.t1_distr as d1 + inner join test_01103.t2_distr as d2 on test_01103.t1_distr.id = test_01103.t2_distr.id + where test_01103.t1_distr.id > 0 + order by test_01103.t1_distr.id +); + +select d0.id +from t1_distr d0 +join ( + select d1.id + from t1_distr as d1 + inner join t2_distr as d2 on d1.id = d2.id + where d1.id > 0 + order by d1.id +) s0 using id; + +select t1_distr.id +from t1_distr +join ( + select t1_distr.id + from t1_distr as d1 + inner join t2_distr as d2 on t1_distr.id = t2_distr.id + where t1_distr.id > 0 + order by t1_distr.id +) s0 using id; + +select test_01103.t1_distr.id +from test_01103.t1_distr +join ( + select test_01103.t1_distr.id + from test_01103.t1_distr as d1 + inner join test_01103.t2_distr as d2 on test_01103.t1_distr.id = test_01103.t2_distr.id + where test_01103.t1_distr.id > 0 + order by test_01103.t1_distr.id +) s0 using id; + +DROP TABLE t1_shard; +DROP TABLE t2_shard; +DROP TABLE t1_distr; +DROP TABLE t2_distr; +DROP DATABASE test_01103; diff --git a/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.reference b/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.reference new file mode 100644 index 00000000000..c5079fa2cfd --- /dev/null +++ b/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.reference @@ -0,0 +1,4 @@ +100 +100 +100 +100 diff --git a/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.sql b/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.sql new file mode 100644 index 00000000000..b301c0ac00f --- /dev/null +++ b/dbms/tests/queries/0_stateless/01104_distributed_numbers_test.sql @@ -0,0 +1,12 @@ +DROP TABLE IF EXISTS d_numbers; +CREATE TABLE d_numbers (number UInt32) ENGINE = Distributed(test_cluster_two_shards_localhost, system, numbers, rand()); + +SET experimental_use_processors = 1; + +SELECT '100' AS number FROM d_numbers AS n WHERE n.number = 100 LIMIT 2; + +SET distributed_product_mode = 'local'; + +SELECT '100' AS number FROM d_numbers AS n WHERE n.number = 100 LIMIT 2; + +DROP TABLE d_numbers; diff --git a/dbms/tests/queries/0_stateless/01104_distributed_one_test.reference b/dbms/tests/queries/0_stateless/01104_distributed_one_test.reference new file mode 100644 index 00000000000..929dd64ae90 --- /dev/null +++ b/dbms/tests/queries/0_stateless/01104_distributed_one_test.reference @@ -0,0 +1,6 @@ +local_0 1 +distributed_0 1 1 +distributed_0 2 1 +local_0 1 +distributed_0 1 1 +distributed_0 2 1 diff --git a/dbms/tests/queries/0_stateless/01104_distributed_one_test.sql b/dbms/tests/queries/0_stateless/01104_distributed_one_test.sql new file mode 100644 index 00000000000..92b4a83ebf3 --- /dev/null +++ b/dbms/tests/queries/0_stateless/01104_distributed_one_test.sql @@ -0,0 +1,18 @@ +DROP TABLE IF EXISTS d_one; +CREATE TABLE d_one (dummy UInt8) ENGINE = Distributed(test_cluster_two_shards_localhost, system, one, rand()); + +SELECT 'local_0', toUInt8(1) AS dummy FROM system.one AS o WHERE o.dummy = 0; +SELECT 'local_1', toUInt8(1) AS dummy FROM system.one AS o WHERE o.dummy = 1; + +SELECT 'distributed_0', _shard_num, toUInt8(1) AS dummy FROM d_one AS o WHERE o.dummy = 0 ORDER BY _shard_num; +SELECT 'distributed_1', _shard_num, toUInt8(1) AS dummy FROM d_one AS o WHERE o.dummy = 1 ORDER BY _shard_num; + +SET distributed_product_mode = 'local'; + +SELECT 'local_0', toUInt8(1) AS dummy FROM system.one AS o WHERE o.dummy = 0; +SELECT 'local_1', toUInt8(1) AS dummy FROM system.one AS o WHERE o.dummy = 1; + +SELECT 'distributed_0', _shard_num, toUInt8(1) AS dummy FROM d_one AS o WHERE o.dummy = 0 ORDER BY _shard_num; +SELECT 'distributed_1', _shard_num, toUInt8(1) AS dummy FROM d_one AS o WHERE o.dummy = 1 ORDER BY _shard_num; + +DROP TABLE d_one;