From 8e2637aab2b7d22b97d9100aaccfb14da5ed470f Mon Sep 17 00:00:00 2001 From: vdimir Date: Thu, 2 Sep 2021 14:40:04 +0300 Subject: [PATCH] Store all related to one join disjunct in JoinOnClause, pt1 --- src/Interpreters/CollectJoinOnKeysVisitor.cpp | 7 +- src/Interpreters/CollectJoinOnKeysVisitor.h | 1 - .../LogicalExpressionsOptimizer.cpp | 1 - src/Interpreters/MergeJoin.cpp | 8 +- src/Interpreters/TableJoin.cpp | 141 ++++++++---------- src/Interpreters/TableJoin.h | 76 +++++++--- .../Optimizations/filterPushDown.cpp | 3 +- src/Storages/StorageJoin.cpp | 2 +- 8 files changed, 130 insertions(+), 109 deletions(-) diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.cpp b/src/Interpreters/CollectJoinOnKeysVisitor.cpp index e6e9c37f3fc..99f1fbc0082 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.cpp +++ b/src/Interpreters/CollectJoinOnKeysVisitor.cpp @@ -41,11 +41,6 @@ void CollectJoinOnKeysMatcher::Data::setDisjuncts(const ASTPtr & or_func_ast) analyzed_join.setDisjuncts(std::move(v)); } -void CollectJoinOnKeysMatcher::Data::addDisjunct(const ASTPtr & ast) -{ - analyzed_join.addDisjunct(std::move(ast)); -} - void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos) { ASTPtr left = left_ast->clone(); @@ -107,7 +102,7 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as return; } - data.addDisjunct(ast); + data.analyzed_join.addDisjunct(ast); if (func.name == "and") return; /// go into children diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.h b/src/Interpreters/CollectJoinOnKeysVisitor.h index 54b4ee39478..61e526b3b4d 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -52,7 +52,6 @@ public: void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos, const ASOF::Inequality & asof_inequality); void setDisjuncts(const ASTPtr & or_func_ast); - void addDisjunct(const ASTPtr & ast); void asofToJoinKeys(); }; diff --git a/src/Interpreters/LogicalExpressionsOptimizer.cpp b/src/Interpreters/LogicalExpressionsOptimizer.cpp index ad3a1b8424a..936ed0149d2 100644 --- a/src/Interpreters/LogicalExpressionsOptimizer.cpp +++ b/src/Interpreters/LogicalExpressionsOptimizer.cpp @@ -1,4 +1,3 @@ -#include #include #include diff --git a/src/Interpreters/MergeJoin.cpp b/src/Interpreters/MergeJoin.cpp index 9a54f64af72..2aced4e72be 100644 --- a/src/Interpreters/MergeJoin.cpp +++ b/src/Interpreters/MergeJoin.cpp @@ -507,7 +507,9 @@ MergeJoin::MergeJoin(std::shared_ptr table_join_, const Block & right ErrorCodes::PARAMETER_OUT_OF_BOUND); } - if (table_join->keyNamesLeft().size() > 1) + const auto & key_names_left_all = table_join->keyNamesLeft(); + const auto & key_names_right_all = table_join->keyNamesRight(); + if (key_names_left_all.size() != 1 || key_names_right_all.size() != 1) throw Exception("MergeJoin does not support OR", ErrorCodes::NOT_IMPLEMENTED); std::tie(mask_column_name_left, mask_column_name_right) = table_join->joinConditionColumnNames(0); @@ -522,8 +524,8 @@ MergeJoin::MergeJoin(std::shared_ptr table_join_, const Block & right key_names_right.push_back(deriveTempName(mask_column_name_right)); } - key_names_left.insert(key_names_left.end(), table_join->keyNamesLeft().front().begin(), table_join->keyNamesLeft().front().end()); - key_names_right.insert(key_names_right.end(), table_join->keyNamesRight().front().begin(), table_join->keyNamesRight().front().end()); + key_names_left.insert(key_names_left.end(), key_names_left_all.front().begin(), key_names_left_all.front().end()); + key_names_right.insert(key_names_right.end(), key_names_right_all.front().begin(), key_names_right_all.front().end()); addConditionJoinColumn(right_sample_block, JoinTableSide::Right); JoinCommon::splitAdditionalColumns(NamesVector{key_names_right}, right_sample_block, right_table_keys, right_columns_to_add); diff --git a/src/Interpreters/TableJoin.cpp b/src/Interpreters/TableJoin.cpp index 2564b00077d..edc03a794e2 100644 --- a/src/Interpreters/TableJoin.cpp +++ b/src/Interpreters/TableJoin.cpp @@ -62,22 +62,17 @@ TableJoin::TableJoin(const Settings & settings, VolumePtr tmp_volume_) , partial_merge_join_left_table_buffer_bytes(settings.partial_merge_join_left_table_buffer_bytes) , max_files_to_merge(settings.join_on_disk_max_files_to_merge) , temporary_files_codec(settings.temporary_files_codec) - , key_names_left(1) - , key_names_right(1) - , on_filter_condition_asts_left(1) - , on_filter_condition_asts_right(1) + , left_clauses(1) + , right_clauses(1) , tmp_volume(tmp_volume_) { } void TableJoin::resetCollected() { - key_names_left.clear(); - key_names_right.clear(); - key_asts_left.clear(); - key_asts_right.clear(); - on_filter_condition_asts_left.clear(); - on_filter_condition_asts_right.clear(); + left_clauses = std::vector(1); + right_clauses = std::vector(1); + columns_from_joined_table.clear(); columns_added_by_join.clear(); original_names.clear(); @@ -92,33 +87,21 @@ void TableJoin::resetCollected() void TableJoin::addUsingKey(const ASTPtr & ast) { - key_names_left.front().push_back(ast->getColumnName()); - key_names_right.front().push_back(ast->getAliasOrColumnName()); - - key_asts_left.push_back(ast); - key_asts_right.push_back(ast); - - auto & right_key = key_names_right.front().back(); - if (renames.count(right_key)) - right_key = renames[right_key]; + left_clauses.back().addKey(ast->getColumnName(), ast); + right_clauses.back().addKey(renamedRightColumnName(ast->getAliasOrColumnName()), ast); } -/// create new disjunct when see a child of a previously discovered OR +/// create new disjunct when see a direct child of a previously discovered OR void TableJoin::addDisjunct(const ASTPtr & ast) { const IAST * addr = ast.get(); if (std::find_if(disjuncts.begin(), disjuncts.end(), [addr](const ASTPtr & ast_){return ast_.get() == addr;}) != disjuncts.end()) { - assert(key_names_left.size() == disjunct_num + 1); - - if (!key_names_left[disjunct_num].empty() || !on_filter_condition_asts_left[disjunct_num].empty() || !on_filter_condition_asts_right[disjunct_num].empty()) + if (!left_clauses.back().key_names.empty() || !left_clauses.back().on_filter_conditions.empty() || !right_clauses.back().on_filter_conditions.empty()) { - disjunct_num++; - key_names_left.resize(disjunct_num + 1); - key_names_right.resize(disjunct_num + 1); - on_filter_condition_asts_left.resize(disjunct_num + 1); - on_filter_condition_asts_right.resize(disjunct_num + 1); + left_clauses.emplace_back(); + right_clauses.emplace_back(); } } } @@ -131,11 +114,8 @@ void TableJoin::setDisjuncts(Disjuncts&& disjuncts_) void TableJoin::addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast) { - key_names_left[disjunct_num].push_back(left_table_ast->getColumnName()); - key_names_right[disjunct_num].push_back(right_table_ast->getAliasOrColumnName()); - - key_asts_left.push_back(left_table_ast); - key_asts_right.push_back(right_table_ast); + left_clauses.back().addKey(left_table_ast->getColumnName(), left_table_ast); + right_clauses.back().addKey(right_table_ast->getAliasOrColumnName(), right_table_ast); } /// @return how many times right key appears in ON section. @@ -145,9 +125,8 @@ size_t TableJoin::rightKeyInclusion(const String & name) const return 0; size_t count = 0; - for (const auto & key_names : key_names_right) - count += std::count(key_names.begin(), key_names.end(), name); - + for (const auto & clause : right_clauses) + count += std::count(clause.key_names.begin(), clause.key_names.end(), name); return count; } @@ -194,31 +173,39 @@ NamesWithAliases TableJoin::getNamesWithAliases(const NameSet & required_columns ASTPtr TableJoin::leftKeysList() const { ASTPtr keys_list = std::make_shared(); - keys_list->children = key_asts_left; - const size_t disjuncts_num = key_names_left.size(); - for (size_t d = 0; d < disjuncts_num; ++d) - if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Left, d)) + for (size_t i = 0; i < left_clauses.size(); ++i) + { + const auto & clause = left_clauses[i]; + keys_list->children.insert(keys_list->children.end(), clause.key_asts.begin(), clause.key_asts.end()); + if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Left, i)) keys_list->children.push_back(extra_cond); + } return keys_list; } ASTPtr TableJoin::rightKeysList() const { ASTPtr keys_list = std::make_shared(); - if (hasOn()) - keys_list->children = key_asts_right; - const size_t disjuncts_num = key_names_left.size(); - for (size_t d = 0; d < disjuncts_num; ++d) - if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Right, d)) + for (size_t i = 0; i < right_clauses.size(); ++i) + { + if (hasOn()) + { + const auto & clause = right_clauses[i]; + keys_list->children.insert(keys_list->children.end(), clause.key_asts.begin(), clause.key_asts.end()); + } + + if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Right, i)) keys_list->children.push_back(extra_cond); + } return keys_list; } Names TableJoin::requiredJoinedNames() const { NameSet required_columns_set; - for (const auto& key_names_right_part : key_names_right) - required_columns_set.insert(key_names_right_part.begin(), key_names_right_part.end()); + for (const auto & clause : right_clauses) + required_columns_set.insert(clause.key_names.begin(), clause.key_names.end()); + for (const auto & joined_column : columns_added_by_join) required_columns_set.insert(joined_column.name); @@ -228,9 +215,9 @@ Names TableJoin::requiredJoinedNames() const NameSet TableJoin::requiredRightKeys() const { NameSet required; - for (const auto & key_names_right_part : key_names_right) + for (const auto & clause : right_clauses) { - for (const auto & name : key_names_right_part) + for (const auto & name : clause.key_names) { auto rename = renamedRightColumnName(name); for (const auto & column : columns_added_by_join) @@ -369,7 +356,7 @@ bool TableJoin::allowMergeJoin() const bool all_join = is_all && (isInner(kind()) || isLeft(kind()) || isRight(kind()) || isFull(kind())); bool special_left = isLeft(kind()) && (is_any || is_semi); - bool no_ors = (key_names_right.size() == 1); + bool no_ors = (left_clauses.size() == 1); return (all_join || special_left) && no_ors; } @@ -407,7 +394,7 @@ bool TableJoin::tryInitDictJoin(const Block & sample_block, ContextPtr context) if (!allowed_inner && !allowed_left) return false; - const Names & right_keys = keyNamesRight().front(); + const Names & right_keys = right_clauses.front().key_names; if (right_keys.size() != 1) return false; @@ -470,12 +457,14 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig for (const auto & col : right) right_types[renamedRightColumnName(col.name)] = col.type; - for (size_t d = 0; d < key_names_left.size(); ++d) + for (size_t d = 0; d < left_clauses.size(); ++d) { - for (size_t i = 0; i < key_names_left[d].size(); ++i) + auto & key_names_left = left_clauses[d].key_names; + auto & key_names_right = right_clauses[d].key_names; + for (size_t i = 0; i < key_names_left.size(); ++i) { - auto ltype = left_types.find(key_names_left[d][i]); - auto rtype = right_types.find(key_names_right[d][i]); + auto ltype = left_types.find(key_names_left[i]); + auto rtype = right_types.find(key_names_right[i]); if (ltype == left_types.end() || rtype == right_types.end()) { /// Name mismatch, give up @@ -495,13 +484,14 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig } catch (DB::Exception & ex) { - throw DB::Exception(ErrorCodes::TYPE_MISMATCH, - "Can't infer common type for joined columns: {}: {} at left, {}: {} at right. {}", - key_names_left[d][i], ltype->second->getName(), - key_names_right[d][i], rtype->second->getName(), - ex.message()); + throw Exception( + "Type mismatch of columns to JOIN by: " + + key_names_left[d][i] + ": " + ltype->second->getName() + " at left, " + + key_names_right[d][i] + ": " + rtype->second->getName() + " at right. " + + "Can't get supertype: " + ex.message(), + ErrorCodes::TYPE_MISMATCH); } - left_type_map[key_names_left[d][i]] = right_type_map[key_names_right[d][i]] = supertype; + left_type_map[key_names_left[i]] = right_type_map[key_names_right[i]] = supertype; } } @@ -518,7 +508,7 @@ bool TableJoin::inferJoinKeyCommonType(const LeftNamesAndTypes & left, const Rig } ActionsDAGPtr TableJoin::applyKeyConvertToTable( - const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NamesVector & names_vector_to_rename) const + const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, std::vector & join_clause) const { bool has_some_to_do = false; @@ -540,9 +530,9 @@ ActionsDAGPtr TableJoin::applyKeyConvertToTable( auto dag = ActionsDAG::makeConvertingActions( cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, !hasUsing(), &key_column_rename); - for (auto & disjunct_names : names_vector_to_rename) + for (auto & clause : join_clause) { - for (auto & name : disjunct_names) + for (auto & name : clause.key_names) { const auto it = key_column_rename.find(name); if (it != key_column_rename.end()) @@ -577,9 +567,9 @@ String TableJoin::renamedRightColumnName(const String & name) const void TableJoin::addJoinCondition(const ASTPtr & ast, bool is_left) { if (is_left) - on_filter_condition_asts_left[disjunct_num].push_back(ast); + left_clauses.back().on_filter_conditions.push_back(ast); else - on_filter_condition_asts_right[disjunct_num].push_back(ast); + right_clauses.back().on_filter_conditions.push_back(ast); } void TableJoin::leftToRightKeyRemap( @@ -607,33 +597,34 @@ std::unordered_map TableJoin::leftToRightKeyRemap() const if (hasUsing()) { const auto & required_right_keys = requiredRightKeys(); - for (size_t i = 0; i < key_names_left.size(); ++i) - TableJoin::leftToRightKeyRemap(key_names_left[i], key_names_right[i], required_right_keys, left_to_right_key_remap); + for (size_t i = 0; i < left_clauses.size(); ++i) + TableJoin::leftToRightKeyRemap(left_clauses[i].key_names, right_clauses[i].key_names, required_right_keys, left_to_right_key_remap); } return left_to_right_key_remap; } /// Returns all conditions related to one table joined with 'and' function -static ASTPtr buildJoinConditionColumn(const ASTsVector & on_filter_condition_asts, size_t disjunct) +static ASTPtr buildJoinConditionColumn(const ASTs & on_filter_condition_asts) { - if (on_filter_condition_asts[disjunct].empty()) + if (on_filter_condition_asts.empty()) return nullptr; - if (on_filter_condition_asts[disjunct].size() == 1) - return on_filter_condition_asts[disjunct][0]; + + if (on_filter_condition_asts.size() == 1) + return on_filter_condition_asts[0]; auto function = std::make_shared(); function->name = "and"; function->arguments = std::make_shared(); function->children.push_back(function->arguments); - function->arguments->children = on_filter_condition_asts[disjunct]; + function->arguments->children = on_filter_condition_asts; return function; } ASTPtr TableJoin::joinConditionColumn(JoinTableSide side, size_t disjunct) const { if (side == JoinTableSide::Left) - return buildJoinConditionColumn(on_filter_condition_asts_left, disjunct); - return buildJoinConditionColumn(on_filter_condition_asts_right, disjunct); + return buildJoinConditionColumn(left_clauses[disjunct].on_filter_conditions); + return buildJoinConditionColumn(right_clauses[disjunct].on_filter_conditions); } std::pair TableJoin::joinConditionColumnNames(size_t disjunct) const diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index 9624134b1eb..77bf6a2215f 100644 --- a/src/Interpreters/TableJoin.h +++ b/src/Interpreters/TableJoin.h @@ -74,16 +74,31 @@ private: const size_t max_files_to_merge = 0; const String temporary_files_codec = "LZ4"; - NamesVector key_names_left; - NamesVector key_names_right; /// Duplicating names are qualified. - ASTsVector on_filter_condition_asts_left; - ASTsVector on_filter_condition_asts_right; -private: - size_t disjunct_num = 0; + /// Corresponds to one disjunct + struct JoinOnClause + { + Names key_names; + ASTs key_asts; + + ASTs on_filter_conditions; + + JoinOnClause() = default; + + explicit JoinOnClause(const Names & names) + : key_names(names) + {} + + void addKey(const String & name, const ASTPtr & ast) + { + key_names.emplace_back(name); + key_asts.emplace_back(ast); + } + }; + Disjuncts disjuncts; - ASTs key_asts_left; - ASTs key_asts_right; + std::vector left_clauses; + std::vector right_clauses; /// Duplicating key_names are qualified. ASTTableJoin table_join; @@ -116,7 +131,7 @@ private: /// Create converting actions and change key column names if required ActionsDAGPtr applyKeyConvertToTable( - const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NamesVector & names_vector_to_rename) const; + const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, std::vector & join_clause) const; /// Calculates common supertypes for corresponding join key columns. template @@ -131,10 +146,8 @@ private: public: TableJoin() - : key_names_left(1) - , key_names_right(1) - , on_filter_condition_asts_left(1) - , on_filter_condition_asts_right(1) + : left_clauses(1) + , right_clauses(1) { } @@ -142,16 +155,14 @@ public: /// for StorageJoin TableJoin(SizeLimits limits, bool use_nulls, ASTTableJoin::Kind kind, ASTTableJoin::Strictness strictness, - const NamesVector & key_names_right_) + const Names & key_names_right) : size_limits(limits) , default_max_bytes(0) , join_use_nulls(use_nulls) , join_algorithm(JoinAlgorithm::HASH) - , key_names_left(1) - , key_names_right(key_names_right_) - , on_filter_condition_asts_left(1) - , on_filter_condition_asts_right(1) + , left_clauses(1) { + right_clauses.emplace_back(key_names_right); table_join.kind = kind; table_join.strictness = strictness; } @@ -232,8 +243,26 @@ public: ASTPtr leftKeysList() const; ASTPtr rightKeysList() const; /// For ON syntax only - const NamesVector & keyNamesLeft() const { return key_names_left; } - const NamesVector & keyNamesRight() const { return key_names_right; } + NamesVector keyNamesLeft() const + { + NamesVector key_names; + for (const auto & clause : left_clauses) + { + key_names.push_back(clause.key_names); + } + return key_names; + } + + NamesVector keyNamesRight() const + { + NamesVector key_names; + for (const auto & clause : right_clauses) + { + key_names.push_back(clause.key_names); + } + return key_names; + } + const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; } Names columnsAddedByJoin() const @@ -245,7 +274,12 @@ public: } /// StorageJoin overrides key names (cause of different names qualification) - void setRightKeys(const Names & keys) { key_names_right.clear(); key_names_right.push_back(keys); } + void setRightKeys(const Names & keys) + { + // assert(right_clauses.size() <= 1); + right_clauses.clear(); + right_clauses.emplace_back(keys); + } Block getRequiredRightKeys(const Block & right_table_keys, std::vector & keys_sources) const; diff --git a/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp b/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp index 10cfad82dd6..6f599ddd85c 100644 --- a/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp +++ b/src/Processors/QueryPlan/Optimizations/filterPushDown.cpp @@ -202,7 +202,8 @@ size_t tryPushDownFilter(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes const auto & left_header = join->getInputStreams().front().header; const auto & res_header = join->getOutputStream().header; Names allowed_keys; - for (const auto & name : table_join.keyNamesLeft().front()) + const auto & key_names_left = table_join.keyNamesLeft(); + for (const auto & name : key_names_left.front()) { /// Skip key if it is renamed. /// I don't know if it is possible. Just in case. diff --git a/src/Storages/StorageJoin.cpp b/src/Storages/StorageJoin.cpp index 39e81e6cf75..fdc026bb6c2 100644 --- a/src/Storages/StorageJoin.cpp +++ b/src/Storages/StorageJoin.cpp @@ -62,7 +62,7 @@ StorageJoin::StorageJoin( if (!metadata_snapshot->getColumns().hasPhysical(key)) throw Exception{"Key column (" + key + ") does not exist in table declaration.", ErrorCodes::NO_SUCH_COLUMN_IN_TABLE}; - table_join = std::make_shared(limits, use_nulls, kind, strictness, NamesVector{key_names}); + table_join = std::make_shared(limits, use_nulls, kind, strictness, key_names); join = std::make_shared(table_join, metadata_snapshot->getSampleBlock().sortColumns(), overwrite); restore(); }