From 6729a3313ee11f3a4da8224b1c10edbabcdafdb5 Mon Sep 17 00:00:00 2001 From: vdimir Date: Fri, 15 Nov 2024 17:28:04 +0000 Subject: [PATCH] wip --- src/Analyzer/Utils.cpp | 2 +- src/Interpreters/HashJoin/HashJoin.cpp | 36 +- src/Interpreters/HashJoin/HashJoin.h | 3 + .../HashJoin/HashJoinMethodsImpl.h | 19 +- src/Interpreters/InterpreterExplainQuery.cpp | 8 +- src/Interpreters/JoinInfo.cpp | 28 + src/Interpreters/JoinInfo.h | 194 ++++++ src/Interpreters/JoinUtils.cpp | 11 +- src/Interpreters/TableJoin.cpp | 109 +++- src/Interpreters/TableJoin.h | 41 +- src/Interpreters/TreeRewriter.cpp | 5 +- src/Parsers/CreateQueryUUIDs.cpp | 2 +- src/Planner/PlannerJoinTree.cpp | 494 +++------------ src/Planner/PlannerJoins.cpp | 299 ++++++++- src/Planner/PlannerJoins.h | 15 + src/Processors/QueryPlan/JoinStepLogical.cpp | 579 ++++++++++++++++++ src/Processors/QueryPlan/JoinStepLogical.h | 60 ++ .../QueryPlan/Optimizations/Optimizations.h | 1 + .../QueryPlanOptimizationSettings.h | 2 + .../QueryPlan/Optimizations/optimizeJoin.cpp | 116 ++++ .../QueryPlan/Optimizations/optimizeTree.cpp | 2 + src/Processors/QueryPlan/QueryPlan.cpp | 30 +- .../QueryPlan/ReadFromMemoryStorageStep.h | 2 + .../QueryPlan/ReadFromPreparedSource.cpp | 6 +- .../QueryPlan/ReadFromPreparedSource.h | 8 +- src/Storages/IStorage.cpp | 6 +- src/Storages/IStorage.h | 2 +- src/Storages/NATS/StorageNATS.cpp | 2 +- src/Storages/RabbitMQ/StorageRabbitMQ.cpp | 2 +- src/Storages/StorageExecutable.cpp | 2 +- 30 files changed, 1598 insertions(+), 488 deletions(-) create mode 100644 src/Interpreters/JoinInfo.cpp create mode 100644 src/Interpreters/JoinInfo.h create mode 100644 src/Processors/QueryPlan/JoinStepLogical.cpp create mode 100644 src/Processors/QueryPlan/JoinStepLogical.h create mode 100644 src/Processors/QueryPlan/Optimizations/optimizeJoin.cpp diff --git a/src/Analyzer/Utils.cpp b/src/Analyzer/Utils.cpp index c73400532ba..b00628766fd 100644 --- a/src/Analyzer/Utils.cpp +++ b/src/Analyzer/Utils.cpp @@ -229,7 
+229,7 @@ QueryTreeNodePtr buildCastFunction(const QueryTreeNodePtr & expression, std::optional tryExtractConstantFromConditionNode(const QueryTreeNodePtr & condition_node) { - const auto * constant_node = condition_node->as(); + const auto * constant_node = condition_node ? condition_node->as() : nullptr; if (!constant_node) return {}; diff --git a/src/Interpreters/HashJoin/HashJoin.cpp b/src/Interpreters/HashJoin/HashJoin.cpp index 3e7f3deea8b..f0b6656bd3a 100644 --- a/src/Interpreters/HashJoin/HashJoin.cpp +++ b/src/Interpreters/HashJoin/HashJoin.cpp @@ -115,6 +115,11 @@ HashJoin::HashJoin(std::shared_ptr table_join_, const Block & right_s , instance_log_id(!instance_id_.empty() ? "(" + instance_id_ + ") " : "") , log(getLogger("HashJoin")) { + for (auto & column : right_sample_block) + { + if (!column.column) + column.column = column.type->createColumn(); + } LOG_TRACE(log, "{}Keys: {}, datatype: {}, kind: {}, strictness: {}, right header: {}", instance_log_id, TableJoin::formatClauses(table_join->getClauses(), true), data->type, kind, strictness, right_sample_block.dumpStructure()); @@ -383,6 +388,16 @@ size_t HashJoin::getTotalByteCount() const return res; } +bool HashJoin::isUsedByAnotherAlgorithm() const +{ + return table_join->isEnabledAlgorithm(JoinAlgorithm::AUTO) || table_join->isEnabledAlgorithm(JoinAlgorithm::GRACE_HASH); +} + +bool HashJoin::canRemoveColumnsFromLeftBlock() const +{ + return table_join->enableEnalyzer() && !table_join->hasUsing() && !isUsedByAnotherAlgorithm(); +} + void HashJoin::initRightBlockStructure(Block & saved_block_sample) { if (isCrossOrComma(kind)) @@ -394,8 +409,7 @@ void HashJoin::initRightBlockStructure(Block & saved_block_sample) bool multiple_disjuncts = !table_join->oneDisjunct(); /// We could remove key columns for LEFT | INNER HashJoin but we should keep them for JoinSwitcher (if any). 
- bool save_key_columns = table_join->isEnabledAlgorithm(JoinAlgorithm::AUTO) || - table_join->isEnabledAlgorithm(JoinAlgorithm::GRACE_HASH) || + bool save_key_columns = isUsedByAnotherAlgorithm() || isRightOrFull(kind) || multiple_disjuncts || table_join->getMixedJoinExpression(); @@ -909,6 +923,7 @@ void HashJoin::checkTypesOfKeys(const Block & block) const void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) { + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: >>> [{}]", __FILE__, __LINE__, block.dumpNames()); if (!data) throw Exception(ErrorCodes::LOGICAL_ERROR, "Cannot join after data has been released"); @@ -920,7 +935,7 @@ void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) right_sample_block, onexpr.key_names_right, cond_column_name.second); } - if (kind == JoinKind::Cross) + if (kind == JoinKind::Cross || kind == JoinKind::Comma) { joinBlockImplCross(block, not_processed); return; @@ -970,6 +985,7 @@ void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) else throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong JOIN combination: {} {}", strictness, kind); } + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: <<< [{}]", __FILE__, __LINE__, block.dumpNames()); } HashJoin::~HashJoin() @@ -1228,7 +1244,10 @@ IBlocksStreamPtr HashJoin::getNonJoinedBlocks(const Block & left_sample_block, { if (!JoinCommon::hasNonJoinedBlocks(*table_join)) return {}; + size_t left_columns_count = left_sample_block.columns(); + if (canRemoveColumnsFromLeftBlock()) + left_columns_count = table_join->getOutputColumns(JoinTableSide::Left).size(); bool flag_per_row = needUsedFlagsForPerRightTableRow(table_join); if (!flag_per_row) @@ -1237,9 +1256,11 @@ IBlocksStreamPtr HashJoin::getNonJoinedBlocks(const Block & left_sample_block, size_t expected_columns_count = left_columns_count + required_right_keys.columns() + sample_block_with_columns_to_add.columns(); if (expected_columns_count != result_sample_block.columns()) { - throw 
Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected number of columns in result sample block: {} instead of {} ({} + {} + {})", + throw Exception(ErrorCodes::LOGICAL_ERROR, + "Unexpected number of columns in result sample block: {} expected {} ([{}] = [{}] + [{}] + [{}])", result_sample_block.columns(), expected_columns_count, - left_columns_count, required_right_keys.columns(), sample_block_with_columns_to_add.columns()); + result_sample_block.dumpNames(), left_sample_block.dumpNames(), + required_right_keys.dumpNames(), sample_block_with_columns_to_add.dumpNames()); } } @@ -1329,8 +1350,9 @@ void HashJoin::validateAdditionalFilterExpression(ExpressionActionsPtr additiona if (expression_sample_block.columns() != 1) { throw Exception(ErrorCodes::LOGICAL_ERROR, - "Unexpected expression in JOIN ON section. Expected single column, got '{}'", - expression_sample_block.dumpStructure()); + "Unexpected expression in JOIN ON section. Expected single column, got '{}', expression:\n{}", + expression_sample_block.dumpStructure(), + additional_filter_expression->dumpActions()); } auto type = removeNullable(expression_sample_block.getByPosition(0).type); diff --git a/src/Interpreters/HashJoin/HashJoin.h b/src/Interpreters/HashJoin/HashJoin.h index 4c1ebbcdc66..d5abdc2ddb8 100644 --- a/src/Interpreters/HashJoin/HashJoin.h +++ b/src/Interpreters/HashJoin/HashJoin.h @@ -464,6 +464,9 @@ private: bool empty() const; + bool isUsedByAnotherAlgorithm() const; + bool canRemoveColumnsFromLeftBlock() const; + void validateAdditionalFilterExpression(std::shared_ptr additional_filter_expression); bool needUsedFlagsForPerRightTableRow(std::shared_ptr table_join_) const; diff --git a/src/Interpreters/HashJoin/HashJoinMethodsImpl.h b/src/Interpreters/HashJoin/HashJoinMethodsImpl.h index 45a766e2df6..b8d1dcba15c 100644 --- a/src/Interpreters/HashJoin/HashJoinMethodsImpl.h +++ b/src/Interpreters/HashJoin/HashJoinMethodsImpl.h @@ -56,7 +56,6 @@ Block HashJoinMethods::joinBlockImpl( const auto & 
key_names = !is_join_get ? onexprs[i].key_names_left : onexprs[i].key_names_right; join_on_keys.emplace_back(block, key_names, onexprs[i].condColumnNames().first, join.key_sizes[i]); } - size_t existing_columns = block.columns(); /** If you use FULL or RIGHT JOIN, then the columns from the "left" table must be materialized. * Because if they are constants, then in the "not joined" rows, they may have different values @@ -99,6 +98,22 @@ Block HashJoinMethods::joinBlockImpl( added_columns.buildJoinGetOutput(); else added_columns.buildOutput(); + + const auto & table_join = join.table_join; + std::set block_columns_to_erase; + if (join.canRemoveColumnsFromLeftBlock()) + { + std::unordered_set left_output_columns; + for (const auto & out_column : table_join->getOutputColumns(JoinTableSide::Left)) + left_output_columns.insert(out_column.name); + for (size_t i = 0; i < block.columns(); ++i) + { + if (!left_output_columns.contains(block.getByPosition(i).name)) + block_columns_to_erase.insert(i); + } + } + size_t existing_columns = block.columns(); + for (size_t i = 0; i < added_columns.size(); ++i) block.insert(added_columns.moveColumn(i)); @@ -160,6 +175,8 @@ Block HashJoinMethods::joinBlockImpl( block.safeGetByPosition(pos).column = block.safeGetByPosition(pos).column->replicate(*offsets_to_replicate); } } + + block.erase(block_columns_to_erase); return remaining_block; } diff --git a/src/Interpreters/InterpreterExplainQuery.cpp b/src/Interpreters/InterpreterExplainQuery.cpp index 2841d042cdc..5a84a1b140b 100644 --- a/src/Interpreters/InterpreterExplainQuery.cpp +++ b/src/Interpreters/InterpreterExplainQuery.cpp @@ -202,6 +202,7 @@ struct QueryPlanSettings /// Apply query plan optimizations. 
bool optimize = true; + bool logical_steps = false; bool json = false; constexpr static char name[] = "PLAN"; @@ -213,6 +214,7 @@ struct QueryPlanSettings {"actions", query_plan_options.actions}, {"indexes", query_plan_options.indexes}, {"optimize", optimize}, + {"logical", logical_steps}, {"json", json}, {"sorting", query_plan_options.sorting}, }; @@ -470,7 +472,11 @@ QueryPipeline InterpreterExplainQuery::executeImpl() } if (settings.optimize) - plan.optimize(QueryPlanOptimizationSettings::fromContext(context)); + { + auto optimization_settings = QueryPlanOptimizationSettings::fromContext(context); + optimization_settings.keep_logical_steps = settings.logical_steps; + plan.optimize(optimization_settings); + } if (settings.json) { diff --git a/src/Interpreters/JoinInfo.cpp b/src/Interpreters/JoinInfo.cpp new file mode 100644 index 00000000000..f34c29ec458 --- /dev/null +++ b/src/Interpreters/JoinInfo.cpp @@ -0,0 +1,28 @@ +#include + +namespace DB +{ + +// namespace Settings +// { +// #define DECLARE_JOIN_SETTINGS_EXTERN(type, name) \ +// extern const Settings##type name; + +// APPLY_FOR_JOIN_SETTINGS(DECLARE_JOIN_SETTINGS_EXTERN) +// #undef DECLARE_JOIN_SETTINGS_EXTERN +// } + +// JoinSettings JoinSettings::create(const Settings & query_settings) +// { +// JoinSettings join_settings; + +// #define COPY_JOIN_SETTINGS_FROM_QUERY(type, name) \ +// join_settings.name = query_settings[Setting::name]; + +// APPLY_FOR_JOIN_SETTINGS(COPY_JOIN_SETTINGS_FROM_QUERY) +// #undef COPY_JOIN_SETTINGS_FROM_QUERY + +// return join_settings; +// } + +} diff --git a/src/Interpreters/JoinInfo.h b/src/Interpreters/JoinInfo.h new file mode 100644 index 00000000000..62b673d31ff --- /dev/null +++ b/src/Interpreters/JoinInfo.h @@ -0,0 +1,194 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace DB +{ + +enum class PredicateOperator : UInt8 +{ + Equals, + NullSafeEquals, + Less, + LessOrEquals, + Greater, + GreaterOrEquals, +}; + 
+inline std::optional getJoinPredicateOperator(const String & func_name) +{ + if (func_name == "equals") + return PredicateOperator::Equals; + if (func_name == "isNotDistinctFrom") + return PredicateOperator::NullSafeEquals; + if (func_name == "less") + return PredicateOperator::Less; + if (func_name == "greater") + return PredicateOperator::Greater; + if (func_name == "lessOrEquals") + return PredicateOperator::LessOrEquals; + if (func_name == "greaterOrEquals") + return PredicateOperator::GreaterOrEquals; + return {}; +} + +inline PredicateOperator reversePredicateOperator(PredicateOperator op) +{ + switch (op) + { + case PredicateOperator::Equals: return PredicateOperator::Equals; + case PredicateOperator::NullSafeEquals: return PredicateOperator::NullSafeEquals; + case PredicateOperator::Less: return PredicateOperator::Greater; + case PredicateOperator::Greater: return PredicateOperator::Less; + case PredicateOperator::LessOrEquals: return PredicateOperator::GreaterOrEquals; + case PredicateOperator::GreaterOrEquals: return PredicateOperator::LessOrEquals; + } +} + +struct JoinExpressionActions +{ + JoinExpressionActions(const ColumnsWithTypeAndName & left_columns, const ColumnsWithTypeAndName & right_columns) + : left_pre_join_actions(left_columns) + , right_pre_join_actions(right_columns) + { + ColumnsWithTypeAndName concat_columns; + concat_columns.insert(concat_columns.end(), left_columns.begin(), left_columns.end()); + concat_columns.insert(concat_columns.end(), right_columns.begin(), right_columns.end()); + post_join_actions = ActionsDAG(concat_columns); + } + + ActionsDAG left_pre_join_actions; + ActionsDAG right_pre_join_actions; + ActionsDAG post_join_actions; +}; + +struct JoinActionRef +{ + const ActionsDAG::Node * node; + String column_name; + + explicit JoinActionRef(const ActionsDAG::Node * node_) + : node(node_) , column_name(node_ ? 
node_->result_name : "") + {} + + operator bool() const { return node != nullptr; } /// NOLINT +}; + + +/// JoinPredicate represents a single join qualifier +/// that that apply to the combination of two tables. +struct JoinPredicate +{ + JoinActionRef left_node; + JoinActionRef right_node; + PredicateOperator op; +}; + +/// JoinCondition determines if rows from two tables can be joined +struct JoinCondition +{ + /// Join predicates that must be satisfied to join rows + std::vector predicates; + + /// Pre-Join filters applied to the left and right tables independently + std::vector left_filter_conditions; + std::vector right_filter_conditions; + + /// Residual conditions depend on data from both tables and must be evaluated after the join has been performed. + /// Unlike the join predicates, these conditions can be arbitrary expressions. + std::vector residual_conditions; +}; + +struct JoinExpression +{ + /// A single join condition that must be satisfied to join rows + JoinCondition condition; + + /// Disjunctive join conditions represented by alternative conditions connected by the OR operator. + /// If any of the conditions is true, corresponding rows from the left and right tables can be joined. 
+ std::vector disjunctive_conditions; + + /// Indicates if the join expression is defined with the USING clause + bool is_using = false; + + /// Set if JOIN ON expression was folded to a single constant on analysis stage + std::optional constant_value = {}; +}; + +struct JoinInfo +{ + /// An expression in ON/USING clause of a JOIN statement + JoinExpression expression; + + /// The type of join (e.g., INNER, LEFT, RIGHT, FULL) + JoinKind kind; + + /// The strictness of the join (e.g., ALL, ANY, SEMI, ANTI) + JoinStrictness strictness; + + /// The locality of the join (e.g., LOCAL, GLOBAL) + JoinLocality locality; +}; + +// #define APPLY_FOR_JOIN_SETTINGS(M) \ +// M(JoinAlgorithm, algorithm) \ +// M(size_t, max_block_size) \ +// \ +// M(bool, join_use_nulls) \ +// M(bool, any_join_distinct_right_table_keys) \ +// \ +// M(size_t, max_rows_in_join) \ +// M(size_t, max_bytes_in_join) \ +// \ +// M(OverflowMode, join_overflow_mode) \ +// M(bool, join_any_take_last_row) \ +// \ +// /* CROSS JOIN settings */ \ +// M(UInt64, cross_join_min_rows_to_compress) \ +// M(UInt64, cross_join_min_bytes_to_compress) \ +// \ +// /* Partial merge join settings */ \ +// M(UInt64, partial_merge_join_left_table_buffer_bytes) \ +// M(UInt64, partial_merge_join_rows_in_right_blocks) \ +// M(UInt64, join_on_disk_max_files_to_merge) \ +// \ +// /* Grace hash join settings */ \ +// M(UInt64, grace_hash_join_initial_buckets) \ +// M(UInt64, grace_hash_join_max_buckets) \ +// \ +// /* Full sorting merge join settings */ \ +// M(UInt64, max_rows_in_set_to_optimize_join) \ +// \ +// /* Hash/Parallel hash join settings */ \ +// M(bool, collect_hash_table_stats_during_joins) \ +// M(UInt64, max_size_to_preallocate_for_joins) \ +// \ +// M(bool, query_plan_convert_outer_join_to_inner_join) \ +// M(bool, multiple_joins_try_to_keep_original_names) \ +// \ +// M(bool, parallel_replicas_prefer_local_join) \ +// M(bool, allow_experimental_join_condition) \ +// \ +// M(UInt64, cross_to_inner_join_rewrite) 
\ + +// struct JoinSettings +// { +// #define DECLARE_JOIN_SETTING_FILEDS(type, name) \ +// type name; + +// APPLY_FOR_JOIN_SETTINGS(APPLY_FOR_JOIN_SETTINGS) +// #undef DECLARE_JOIN_SETTING_FILEDS + +// static JoinSettings create(const Settings & query_settings); +// }; + + +} diff --git a/src/Interpreters/JoinUtils.cpp b/src/Interpreters/JoinUtils.cpp index d48ae16d3cd..d06339e1b06 100644 --- a/src/Interpreters/JoinUtils.cpp +++ b/src/Interpreters/JoinUtils.cpp @@ -18,6 +18,7 @@ #include #include +#include namespace DB { @@ -675,6 +676,12 @@ NotJoinedBlocks::NotJoinedBlocks(std::unique_ptr filler_, , result_sample_block(materializeBlock(result_sample_block_)) { const auto & left_to_right_key_remap = table_join.leftToRightKeyRemap(); + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: [{}]", __FILE__, __LINE__, fmt::join( + left_to_right_key_remap | std::views::transform([](const auto & e) + { + return fmt::format("{} -> {}", e.first, e.second); + }), ", ")); + for (size_t left_pos = 0; left_pos < left_columns_count; ++left_pos) { /// We need right 'x' for 'RIGHT JOIN ... 
USING(x)' @@ -715,9 +722,9 @@ NotJoinedBlocks::NotJoinedBlocks(std::unique_ptr filler_, throw Exception( ErrorCodes::LOGICAL_ERROR, "Error in columns mapping in JOIN: assertion failed {} + {} + {} != {}; " - "Result block [{}], Saved block [{}]", + "left_columns_count = {}, result_sample_block.columns = [{}], saved_block_sample.columns = [{}]", column_indices_left.size(), column_indices_right.size(), same_result_keys.size(), result_sample_block.columns(), - result_sample_block.dumpNames(), saved_block_sample.dumpNames()); + left_columns_count, result_sample_block.dumpNames(), saved_block_sample.dumpNames()); } } diff --git a/src/Interpreters/TableJoin.cpp b/src/Interpreters/TableJoin.cpp index 2532dddba3c..ad6afa09767 100644 --- a/src/Interpreters/TableJoin.cpp +++ b/src/Interpreters/TableJoin.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include @@ -41,6 +42,7 @@ namespace DB namespace Setting { extern const SettingsBool allow_experimental_join_right_table_sorting; + extern const SettingsBool allow_experimental_analyzer; extern const SettingsUInt64 cross_join_min_bytes_to_compress; extern const SettingsUInt64 cross_join_min_rows_to_compress; extern const SettingsUInt64 default_max_bytes_in_join; @@ -143,9 +145,58 @@ TableJoin::TableJoin(const Settings & settings, VolumePtr tmp_volume_, Temporary , max_memory_usage(settings[Setting::max_memory_usage]) , tmp_volume(tmp_volume_) , tmp_data(tmp_data_) + , enable_analyzer(settings[Setting::allow_experimental_analyzer]) { } + +TableJoin::TableJoin(SizeLimits limits, bool use_nulls, JoinKind kind, JoinStrictness strictness, const Names & key_names_right) + : size_limits(limits) + , default_max_bytes(0) + , join_use_nulls(use_nulls) + , join_algorithm({JoinAlgorithm::DEFAULT}) +{ + clauses.emplace_back().key_names_right = key_names_right; + table_join.kind = kind; + table_join.strictness = strictness; +} + + +JoinKind TableJoin::kind() const +{ + if (join_info) + return join_info->kind; + return 
table_join.kind; +} + +void TableJoin::setKind(JoinKind kind) +{ + if (join_info) + join_info->kind = kind; + table_join.kind = kind; +} + +JoinStrictness TableJoin::strictness() const +{ + if (join_info) + return join_info->strictness; + return table_join.strictness; +} + +bool TableJoin::hasUsing() const +{ + if (join_info) + return join_info->expression.is_using; + return table_join.using_expression_list != nullptr; +} + +bool TableJoin::hasOn() const +{ + if (join_info) + return !join_info->expression.is_using; + return table_join.on_expression != nullptr; +} + void TableJoin::resetKeys() { clauses.clear(); @@ -161,6 +212,8 @@ void TableJoin::resetCollected() clauses.clear(); columns_from_joined_table.clear(); columns_added_by_join.clear(); + columns_from_left_table.clear(); + result_columns_from_left_table.clear(); original_names.clear(); renames.clear(); left_type_map.clear(); @@ -203,6 +256,19 @@ size_t TableJoin::rightKeyInclusion(const String & name) const return count; } +void TableJoin::setInputColumns(NamesAndTypesList left_output_columns, NamesAndTypesList right_output_columns) +{ + columns_from_left_table = std::move(left_output_columns); + columns_from_joined_table = std::move(right_output_columns); +} + +const NamesAndTypesList & TableJoin::getOutputColumns(JoinTableSide side) +{ + if (side == JoinTableSide::Left) + return result_columns_from_left_table; + return columns_added_by_join; +} + void TableJoin::deduplicateAndQualifyColumnNames(const NameSet & left_table_columns, const String & right_table_prefix) { NameSet joined_columns; @@ -351,9 +417,41 @@ bool TableJoin::rightBecomeNullable(const DataTypePtr & column_type) const return forceNullableRight() && JoinCommon::canBecomeNullable(column_type); } +void TableJoin::setUsedColumns(const Names & column_names) +{ + std::unordered_map left_columns_idx; + for (auto it = columns_from_left_table.begin(); it != columns_from_left_table.end(); ++it) + left_columns_idx[it->name] = it; + + 
std::unordered_map right_columns_idx; + for (auto it = columns_from_joined_table.begin(); it != columns_from_joined_table.end(); ++it) + right_columns_idx[it->name] = it; + + for (const auto & column_name : column_names) + { + if (auto lit = left_columns_idx.find(column_name); lit != left_columns_idx.end()) + setUsedColumn(*lit->second, JoinTableSide::Left); + else if (auto rit = right_columns_idx.find(column_name); rit != right_columns_idx.end()) + setUsedColumn(*rit->second, JoinTableSide::Right); + else + throw Exception(ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK, + "Column {} not found in JOIN, left columns: [{}], right columns: [{}]", column_name, + fmt::join(columns_from_left_table | std::views::transform([](const auto & col) { return col.name; }), ", "), + fmt::join(columns_from_joined_table | std::views::transform([](const auto & col) { return col.name; }), ", ")); + } +} + +void TableJoin::setUsedColumn(const NameAndTypePair & joined_column, JoinTableSide side) +{ + if (side == JoinTableSide::Left) + result_columns_from_left_table.push_back(joined_column); + else + columns_added_by_join.push_back(joined_column); +} + void TableJoin::addJoinedColumn(const NameAndTypePair & joined_column) { - columns_added_by_join.emplace_back(joined_column); + setUsedColumn(joined_column, JoinTableSide::Right); } NamesAndTypesList TableJoin::correctedColumnsAddedByJoin() const @@ -974,9 +1072,9 @@ bool TableJoin::allowParallelHashJoin() const return false; if (!right_storage_name.empty()) return false; - if (table_join.kind != JoinKind::Left && table_join.kind != JoinKind::Inner) + if (kind() != JoinKind::Left && kind() != JoinKind::Inner) return false; - if (table_join.strictness == JoinStrictness::Asof) + if (strictness() == JoinStrictness::Asof) return false; if (isSpecialStorage() || !oneDisjunct()) return false; @@ -995,5 +1093,10 @@ size_t TableJoin::getMaxMemoryUsage() const return max_memory_usage; } +void TableJoin::assertEnableEnalyzer() const +{ + if 
(!enable_analyzer) + throw DB::Exception(ErrorCodes::NOT_IMPLEMENTED, "TableJoin: analyzer is disabled"); +} } diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index e1bae55a4ed..b01b28d91e2 100644 --- a/src/Interpreters/TableJoin.h +++ b/src/Interpreters/TableJoin.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -164,9 +165,13 @@ private: std::shared_ptr mixed_join_expression = nullptr; ASTTableJoin table_join; + std::optional join_info; ASOFJoinInequality asof_inequality = ASOFJoinInequality::GreaterOrEquals; + NamesAndTypesList columns_from_left_table; + NamesAndTypesList result_columns_from_left_table; + /// All columns which can be read from joined table. Duplicating names are qualified. NamesAndTypesList columns_from_joined_table; /// Columns will be added to block by JOIN. @@ -201,6 +206,7 @@ private: std::string right_storage_name; bool is_join_with_constant = false; + bool enable_analyzer = false; Names requiredJoinedNames() const; @@ -243,22 +249,13 @@ public: /// for StorageJoin TableJoin(SizeLimits limits, bool use_nulls, JoinKind kind, JoinStrictness strictness, - const Names & key_names_right) - : size_limits(limits) - , default_max_bytes(0) - , join_use_nulls(use_nulls) - , join_algorithm({JoinAlgorithm::DEFAULT}) - { - clauses.emplace_back().key_names_right = key_names_right; - table_join.kind = kind; - table_join.strictness = strictness; - } + const Names & key_names_right); TableJoin(const TableJoin & rhs) = default; - JoinKind kind() const { return table_join.kind; } - void setKind(JoinKind kind) { table_join.kind = kind; } - JoinStrictness strictness() const { return table_join.strictness; } + JoinKind kind() const; + void setKind(JoinKind kind); + JoinStrictness strictness() const; bool sameStrictnessAndKind(JoinStrictness, JoinKind) const; const SizeLimits & sizeLimits() const { return size_limits; } size_t getMaxMemoryUsage() const; @@ -266,6 +263,8 @@ public: VolumePtr 
getGlobalTemporaryVolume() { return tmp_volume; } TemporaryDataOnDiskScopePtr getTempDataOnDisk() { return tmp_data; } + bool enableEnalyzer() const { return enable_analyzer; } + void assertEnableEnalyzer() const; ActionsDAG createJoinedBlockActions(ContextPtr context) const; @@ -316,6 +315,8 @@ public: ASTTableJoin & getTableJoin() { return table_join; } const ASTTableJoin & getTableJoin() const { return table_join; } + void setJoinInfo(const JoinInfo & join_info_) { join_info = join_info_; } + JoinOnClause & getOnlyClause() { assertHasOneOnExpr(); return clauses[0]; } const JoinOnClause & getOnlyClause() const { assertHasOneOnExpr(); return clauses[0]; } @@ -349,8 +350,8 @@ public: */ void addJoinCondition(const ASTPtr & ast, bool is_left); - bool hasUsing() const { return table_join.using_expression_list != nullptr; } - bool hasOn() const { return table_join.on_expression != nullptr; } + bool hasUsing() const; + bool hasOn() const; String getOriginalName(const String & column_name) const; NamesWithAliases getNamesWithAliases(const NameSet & required_columns) const; @@ -372,6 +373,9 @@ public: bool leftBecomeNullable(const DataTypePtr & column_type) const; bool rightBecomeNullable(const DataTypePtr & column_type) const; void addJoinedColumn(const NameAndTypePair & joined_column); + void setUsedColumn(const NameAndTypePair & joined_column, JoinTableSide side); + void setUsedColumns(const Names & column_names); + void setColumnsAddedByJoin(const NamesAndTypesList & columns_added_by_join_value) { columns_added_by_join = columns_added_by_join_value; @@ -397,11 +401,16 @@ public: ASTPtr leftKeysList() const; ASTPtr rightKeysList() const; /// For ON syntax only - void setColumnsFromJoinedTable(NamesAndTypesList columns_from_joined_table_value, const NameSet & left_table_columns, const String & right_table_prefix) + void setColumnsFromJoinedTable(NamesAndTypesList columns_from_joined_table_value, const NameSet & left_table_columns, const String & right_table_prefix, 
const NamesAndTypesList & columns_from_left_table_) { columns_from_joined_table = std::move(columns_from_joined_table_value); deduplicateAndQualifyColumnNames(left_table_columns, right_table_prefix); + result_columns_from_left_table = columns_from_left_table_; + columns_from_left_table = columns_from_left_table_; } + + void setInputColumns(NamesAndTypesList left_output_columns, NamesAndTypesList right_output_columns); + const NamesAndTypesList & getOutputColumns(JoinTableSide side); const NamesAndTypesList & columnsFromJoinedTable() const { return columns_from_joined_table; } const NamesAndTypesList & columnsAddedByJoin() const { return columns_added_by_join; } diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index ea08fd92339..28e11166762 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -1353,12 +1353,15 @@ TreeRewriterResultPtr TreeRewriter::analyzeSelect( if (tables_with_columns.size() > 1) { + auto columns_from_left_table = tables_with_columns[0].columns; const auto & right_table = tables_with_columns[1]; auto columns_from_joined_table = right_table.columns; /// query can use materialized or aliased columns from right joined table, /// we want to request it for right table columns_from_joined_table.insert(columns_from_joined_table.end(), right_table.hidden_columns.begin(), right_table.hidden_columns.end()); - result.analyzed_join->setColumnsFromJoinedTable(std::move(columns_from_joined_table), source_columns_set, right_table.table.getQualifiedNamePrefix()); + columns_from_left_table.insert(columns_from_left_table.end(), tables_with_columns[0].hidden_columns.begin(), tables_with_columns[0].hidden_columns.end()); + result.analyzed_join->setColumnsFromJoinedTable( + std::move(columns_from_joined_table), source_columns_set, right_table.table.getQualifiedNamePrefix(), columns_from_left_table); } translateQualifiedNames(query, *select_query, source_columns_set, tables_with_columns); diff --git 
a/src/Parsers/CreateQueryUUIDs.cpp b/src/Parsers/CreateQueryUUIDs.cpp index c788cc7a025..14cf5761a11 100644 --- a/src/Parsers/CreateQueryUUIDs.cpp +++ b/src/Parsers/CreateQueryUUIDs.cpp @@ -31,7 +31,7 @@ CreateQueryUUIDs::CreateQueryUUIDs(const ASTCreateQuery & query, bool generate_r /// If we generate random UUIDs for already existing tables then those UUIDs will not be correct making those inner target table inaccessible. /// Thus it's not safe for example to replace /// "ATTACH MATERIALIZED VIEW mv AS SELECT a FROM b" with - /// "ATTACH MATERIALIZED VIEW mv TO INNER UUID "XXXX" AS SELECT a FROM b" + /// "ATTACH MATERIALIZED VIEW mv TO INNER UUID "248372b7-02c4-4c88-a5e1-282a83cc572a" AS SELECT a FROM b" /// This replacement is safe only for CREATE queries when inner target tables don't exist yet. if (!query.attach) { diff --git a/src/Planner/PlannerJoinTree.cpp b/src/Planner/PlannerJoinTree.cpp index 5c153f6db39..b0a6c96c52c 100644 --- a/src/Planner/PlannerJoinTree.cpp +++ b/src/Planner/PlannerJoinTree.cpp @@ -47,6 +47,7 @@ #include #include #include +#include #include #include #include @@ -108,9 +109,7 @@ namespace Setting namespace ErrorCodes { - extern const int INVALID_JOIN_ON_EXPRESSION; extern const int LOGICAL_ERROR; - extern const int NOT_IMPLEMENTED; extern const int ACCESS_DENIED; extern const int PARAMETER_OUT_OF_BOUND; extern const int TOO_MANY_COLUMNS; @@ -1241,396 +1240,19 @@ void joinCastPlanColumnsToNullable(QueryPlan & plan_to_add_cast, PlannerContextP plan_to_add_cast.addStep(std::move(cast_join_columns_step)); } -JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_expression, +JoinTreeQueryPlan joinPlansWithStep( + QueryPlanStepPtr join_step, JoinTreeQueryPlan left_join_tree_query_plan, JoinTreeQueryPlan right_join_tree_query_plan, const ColumnIdentifierSet & outer_scope_columns, PlannerContextPtr & planner_context) { - auto & join_node = join_table_expression->as(); - if (left_join_tree_query_plan.from_stage != 
QueryProcessingStage::FetchColumns) - throw Exception(ErrorCodes::UNSUPPORTED_METHOD, - "JOIN {} left table expression expected to process query to fetch columns stage. Actual {}", - join_node.formatASTForErrorMessage(), - QueryProcessingStage::toString(left_join_tree_query_plan.from_stage)); + std::vector plans; + plans.emplace_back(std::make_unique(std::move(left_join_tree_query_plan.query_plan))); + plans.emplace_back(std::make_unique(std::move(right_join_tree_query_plan.query_plan))); - auto left_plan = std::move(left_join_tree_query_plan.query_plan); - auto left_plan_output_columns = left_plan.getCurrentHeader().getColumnsWithTypeAndName(); - if (right_join_tree_query_plan.from_stage != QueryProcessingStage::FetchColumns) - throw Exception(ErrorCodes::UNSUPPORTED_METHOD, - "JOIN {} right table expression expected to process query to fetch columns stage. Actual {}", - join_node.formatASTForErrorMessage(), - QueryProcessingStage::toString(right_join_tree_query_plan.from_stage)); - - auto right_plan = std::move(right_join_tree_query_plan.query_plan); - auto right_plan_output_columns = right_plan.getCurrentHeader().getColumnsWithTypeAndName(); - - JoinClausesAndActions join_clauses_and_actions; - JoinKind join_kind = join_node.getKind(); - JoinStrictness join_strictness = join_node.getStrictness(); - - std::optional join_constant; - - if (join_strictness == JoinStrictness::All || join_strictness == JoinStrictness::Semi || join_strictness == JoinStrictness::Anti) - join_constant = tryExtractConstantFromJoinNode(join_table_expression); - - if (!join_constant && join_node.isOnJoinExpression()) - { - join_clauses_and_actions = buildJoinClausesAndActions(left_plan_output_columns, - right_plan_output_columns, - join_table_expression, - planner_context); - - join_clauses_and_actions.left_join_expressions_actions.appendInputsForUnusedColumns(left_plan.getCurrentHeader()); - auto left_join_expressions_actions_step = std::make_unique(left_plan.getCurrentHeader(), 
std::move(join_clauses_and_actions.left_join_expressions_actions)); - left_join_expressions_actions_step->setStepDescription("JOIN actions"); - appendSetsFromActionsDAG(left_join_expressions_actions_step->getExpression(), left_join_tree_query_plan.useful_sets); - left_plan.addStep(std::move(left_join_expressions_actions_step)); - - join_clauses_and_actions.right_join_expressions_actions.appendInputsForUnusedColumns(right_plan.getCurrentHeader()); - auto right_join_expressions_actions_step = std::make_unique(right_plan.getCurrentHeader(), std::move(join_clauses_and_actions.right_join_expressions_actions)); - right_join_expressions_actions_step->setStepDescription("JOIN actions"); - appendSetsFromActionsDAG(right_join_expressions_actions_step->getExpression(), right_join_tree_query_plan.useful_sets); - right_plan.addStep(std::move(right_join_expressions_actions_step)); - } - - std::unordered_map left_plan_column_name_to_cast_type; - std::unordered_map right_plan_column_name_to_cast_type; - - if (join_node.isUsingJoinExpression()) - { - auto & join_node_using_columns_list = join_node.getJoinExpression()->as(); - for (auto & join_node_using_node : join_node_using_columns_list.getNodes()) - { - auto & join_node_using_column_node = join_node_using_node->as(); - auto & inner_columns_list = join_node_using_column_node.getExpressionOrThrow()->as(); - - auto & left_inner_column_node = inner_columns_list.getNodes().at(0); - auto & left_inner_column = left_inner_column_node->as(); - - auto & right_inner_column_node = inner_columns_list.getNodes().at(1); - auto & right_inner_column = right_inner_column_node->as(); - - const auto & join_node_using_column_node_type = join_node_using_column_node.getColumnType(); - if (!left_inner_column.getColumnType()->equals(*join_node_using_column_node_type)) - { - const auto & left_inner_column_identifier = planner_context->getColumnNodeIdentifierOrThrow(left_inner_column_node); - 
left_plan_column_name_to_cast_type.emplace(left_inner_column_identifier, join_node_using_column_node_type); - } - - if (!right_inner_column.getColumnType()->equals(*join_node_using_column_node_type)) - { - const auto & right_inner_column_identifier = planner_context->getColumnNodeIdentifierOrThrow(right_inner_column_node); - right_plan_column_name_to_cast_type.emplace(right_inner_column_identifier, join_node_using_column_node_type); - } - } - } - - auto join_cast_plan_output_nodes = [&](QueryPlan & plan_to_add_cast, std::unordered_map & plan_column_name_to_cast_type) - { - ActionsDAG cast_actions_dag(plan_to_add_cast.getCurrentHeader().getColumnsWithTypeAndName()); - - for (auto & output_node : cast_actions_dag.getOutputs()) - { - auto it = plan_column_name_to_cast_type.find(output_node->result_name); - if (it == plan_column_name_to_cast_type.end()) - continue; - - const auto & cast_type = it->second; - output_node = &cast_actions_dag.addCast(*output_node, cast_type, output_node->result_name); - } - - cast_actions_dag.appendInputsForUnusedColumns(plan_to_add_cast.getCurrentHeader()); - auto cast_join_columns_step - = std::make_unique(plan_to_add_cast.getCurrentHeader(), std::move(cast_actions_dag)); - cast_join_columns_step->setStepDescription("Cast JOIN USING columns"); - plan_to_add_cast.addStep(std::move(cast_join_columns_step)); - }; - - if (!left_plan_column_name_to_cast_type.empty()) - join_cast_plan_output_nodes(left_plan, left_plan_column_name_to_cast_type); - - if (!right_plan_column_name_to_cast_type.empty()) - join_cast_plan_output_nodes(right_plan, right_plan_column_name_to_cast_type); - - const auto & query_context = planner_context->getQueryContext(); - const auto & settings = query_context->getSettingsRef(); - - if (settings[Setting::join_use_nulls]) - { - auto to_nullable_function = FunctionFactory::instance().get("toNullable", query_context); - if (isFull(join_kind)) - { - joinCastPlanColumnsToNullable(left_plan, planner_context, 
to_nullable_function); - joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function); - } - else if (isLeft(join_kind)) - { - joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function); - } - else if (isRight(join_kind)) - { - joinCastPlanColumnsToNullable(left_plan, planner_context, to_nullable_function); - } - } - - auto table_join = std::make_shared(settings, query_context->getGlobalTemporaryVolume(), query_context->getTempDataOnDisk()); - table_join->getTableJoin() = join_node.toASTTableJoin()->as(); - - if (join_constant) - { - /** If there is JOIN with always true constant, we transform it to cross. - * If there is JOIN with always false constant, we do not process JOIN keys. - * It is expected by join algorithm to handle such case. - * - * Example: SELECT * FROM test_table AS t1 INNER JOIN test_table AS t2 ON 1; - */ - if (*join_constant) - join_kind = JoinKind::Cross; - } - table_join->getTableJoin().kind = join_kind; - - if (join_kind == JoinKind::Comma) - { - join_kind = JoinKind::Cross; - table_join->getTableJoin().kind = JoinKind::Cross; - } - - table_join->setIsJoinWithConstant(join_constant != std::nullopt); - - if (join_node.isOnJoinExpression()) - { - const auto & join_clauses = join_clauses_and_actions.join_clauses; - bool is_asof = table_join->strictness() == JoinStrictness::Asof; - - if (join_clauses.size() > 1) - { - if (is_asof) - throw Exception(ErrorCodes::NOT_IMPLEMENTED, - "ASOF join {} doesn't support multiple ORs for keys in JOIN ON section", - join_node.formatASTForErrorMessage()); - } - - auto & table_join_clauses = table_join->getClauses(); - - for (const auto & join_clause : join_clauses) - { - table_join_clauses.emplace_back(); - auto & table_join_clause = table_join_clauses.back(); - - const auto & join_clause_left_key_nodes = join_clause.getLeftKeyNodes(); - const auto & join_clause_right_key_nodes = join_clause.getRightKeyNodes(); - - size_t join_clause_key_nodes_size = 
join_clause_left_key_nodes.size(); - chassert(join_clause_key_nodes_size == join_clause_right_key_nodes.size()); - - for (size_t i = 0; i < join_clause_key_nodes_size; ++i) - { - table_join_clause.addKey(join_clause_left_key_nodes[i]->result_name, - join_clause_right_key_nodes[i]->result_name, - join_clause.isNullsafeCompareKey(i)); - } - - const auto & join_clause_get_left_filter_condition_nodes = join_clause.getLeftFilterConditionNodes(); - if (!join_clause_get_left_filter_condition_nodes.empty()) - { - if (join_clause_get_left_filter_condition_nodes.size() != 1) - throw Exception(ErrorCodes::LOGICAL_ERROR, - "JOIN {} left filter conditions size must be 1. Actual {}", - join_node.formatASTForErrorMessage(), - join_clause_get_left_filter_condition_nodes.size()); - - const auto & join_clause_left_filter_condition_name = join_clause_get_left_filter_condition_nodes[0]->result_name; - table_join_clause.analyzer_left_filter_condition_column_name = join_clause_left_filter_condition_name; - } - - const auto & join_clause_get_right_filter_condition_nodes = join_clause.getRightFilterConditionNodes(); - if (!join_clause_get_right_filter_condition_nodes.empty()) - { - if (join_clause_get_right_filter_condition_nodes.size() != 1) - throw Exception(ErrorCodes::LOGICAL_ERROR, - "JOIN {} right filter conditions size must be 1. 
Actual {}", - join_node.formatASTForErrorMessage(), - join_clause_get_right_filter_condition_nodes.size()); - - const auto & join_clause_right_filter_condition_name = join_clause_get_right_filter_condition_nodes[0]->result_name; - table_join_clause.analyzer_right_filter_condition_column_name = join_clause_right_filter_condition_name; - } - - if (is_asof) - { - if (!join_clause.hasASOF()) - throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, - "JOIN {} no inequality in ASOF JOIN ON section", - join_node.formatASTForErrorMessage()); - } - - if (join_clause.hasASOF()) - { - const auto & asof_conditions = join_clause.getASOFConditions(); - assert(asof_conditions.size() == 1); - - const auto & asof_condition = asof_conditions[0]; - table_join->setAsofInequality(asof_condition.asof_inequality); - - /// Execution layer of JOIN algorithms expects that ASOF keys are last JOIN keys - std::swap(table_join_clause.key_names_left.at(asof_condition.key_index), table_join_clause.key_names_left.back()); - std::swap(table_join_clause.key_names_right.at(asof_condition.key_index), table_join_clause.key_names_right.back()); - } - } - - if (join_clauses_and_actions.mixed_join_expressions_actions) - { - ExpressionActionsPtr & mixed_join_expression = table_join->getMixedJoinExpression(); - mixed_join_expression = std::make_shared( - std::move(*join_clauses_and_actions.mixed_join_expressions_actions), - ExpressionActionsSettings::fromContext(planner_context->getQueryContext())); - - appendSetsFromActionsDAG(mixed_join_expression->getActionsDAG(), left_join_tree_query_plan.useful_sets); - } - } - else if (join_node.isUsingJoinExpression()) - { - auto & table_join_clauses = table_join->getClauses(); - table_join_clauses.emplace_back(); - auto & table_join_clause = table_join_clauses.back(); - - auto & using_list = join_node.getJoinExpression()->as(); - - for (auto & join_using_node : using_list.getNodes()) - { - auto & join_using_column_node = join_using_node->as(); - auto & 
using_join_columns_list = join_using_column_node.getExpressionOrThrow()->as(); - auto & using_join_left_join_column_node = using_join_columns_list.getNodes().at(0); - auto & using_join_right_join_column_node = using_join_columns_list.getNodes().at(1); - - const auto & left_column_identifier = planner_context->getColumnNodeIdentifierOrThrow(using_join_left_join_column_node); - const auto & right_column_identifier = planner_context->getColumnNodeIdentifierOrThrow(using_join_right_join_column_node); - - table_join_clause.key_names_left.push_back(left_column_identifier); - table_join_clause.key_names_right.push_back(right_column_identifier); - } - } - - const Block & left_header = left_plan.getCurrentHeader(); - auto left_table_names = left_header.getNames(); - NameSet left_table_names_set(left_table_names.begin(), left_table_names.end()); - - auto columns_from_joined_table = right_plan.getCurrentHeader().getNamesAndTypesList(); - table_join->setColumnsFromJoinedTable(columns_from_joined_table, left_table_names_set, ""); - - for (auto & column_from_joined_table : columns_from_joined_table) - { - /// Add columns from joined table only if they are presented in outer scope, otherwise they can be dropped - if (planner_context->getGlobalPlannerContext()->hasColumnIdentifier(column_from_joined_table.name) && - outer_scope_columns.contains(column_from_joined_table.name)) - table_join->addJoinedColumn(column_from_joined_table); - } - - const Block & right_header = right_plan.getCurrentHeader(); - auto join_algorithm = chooseJoinAlgorithm(table_join, join_node.getRightTableExpression(), left_header, right_header, planner_context); - - auto result_plan = QueryPlan(); - - bool is_filled_join = join_algorithm->isFilled(); - if (is_filled_join) - { - auto filled_join_step - = std::make_unique(left_plan.getCurrentHeader(), join_algorithm, settings[Setting::max_block_size]); - - filled_join_step->setStepDescription("Filled JOIN"); - left_plan.addStep(std::move(filled_join_step)); - - 
result_plan = std::move(left_plan); - } - else - { - auto add_sorting = [&] (QueryPlan & plan, const Names & key_names, JoinTableSide join_table_side) - { - SortDescription sort_description; - sort_description.reserve(key_names.size()); - for (const auto & key_name : key_names) - sort_description.emplace_back(key_name); - - SortingStep::Settings sort_settings(*query_context); - - auto sorting_step = std::make_unique( - plan.getCurrentHeader(), std::move(sort_description), 0 /*limit*/, sort_settings, true /*is_sorting_for_merge_join*/); - sorting_step->setStepDescription(fmt::format("Sort {} before JOIN", join_table_side)); - plan.addStep(std::move(sorting_step)); - }; - - auto crosswise_connection = CreateSetAndFilterOnTheFlyStep::createCrossConnection(); - auto add_create_set = [&settings, crosswise_connection](QueryPlan & plan, const Names & key_names, JoinTableSide join_table_side) - { - auto creating_set_step = std::make_unique( - plan.getCurrentHeader(), key_names, settings[Setting::max_rows_in_set_to_optimize_join], crosswise_connection, join_table_side); - creating_set_step->setStepDescription(fmt::format("Create set and filter {} joined stream", join_table_side)); - - auto * step_raw_ptr = creating_set_step.get(); - plan.addStep(std::move(creating_set_step)); - return step_raw_ptr; - }; - - if (join_algorithm->pipelineType() == JoinPipelineType::YShaped && join_kind != JoinKind::Paste) - { - const auto & join_clause = table_join->getOnlyClause(); - - bool join_type_allows_filtering = (join_strictness == JoinStrictness::All || join_strictness == JoinStrictness::Any) - && (isInner(join_kind) || isLeft(join_kind) || isRight(join_kind)); - - - auto has_non_const = [](const Block & block, const auto & keys) - { - for (const auto & key : keys) - { - const auto & column = block.getByName(key).column; - if (column && !isColumnConst(*column)) - return true; - } - return false; - }; - - /// This optimization relies on the sorting that should buffer data from both 
streams before emitting any rows. - /// Sorting on a stream with const keys can start returning rows immediately and pipeline may stuck. - /// Note: it's also doesn't work with the read-in-order optimization. - /// No checks here because read in order is not applied if we have `CreateSetAndFilterOnTheFlyStep` in the pipeline between the reading and sorting steps. - bool has_non_const_keys = has_non_const(left_plan.getCurrentHeader(), join_clause.key_names_left) - && has_non_const(right_plan.getCurrentHeader(), join_clause.key_names_right); - - if (settings[Setting::max_rows_in_set_to_optimize_join] > 0 && join_type_allows_filtering && has_non_const_keys) - { - auto * left_set = add_create_set(left_plan, join_clause.key_names_left, JoinTableSide::Left); - auto * right_set = add_create_set(right_plan, join_clause.key_names_right, JoinTableSide::Right); - - if (isInnerOrLeft(join_kind)) - right_set->setFiltering(left_set->getSet()); - - if (isInnerOrRight(join_kind)) - left_set->setFiltering(right_set->getSet()); - } - - add_sorting(left_plan, join_clause.key_names_left, JoinTableSide::Left); - add_sorting(right_plan, join_clause.key_names_right, JoinTableSide::Right); - } - - auto join_pipeline_type = join_algorithm->pipelineType(); - auto join_step = std::make_unique( - left_plan.getCurrentHeader(), - right_plan.getCurrentHeader(), - std::move(join_algorithm), - settings[Setting::max_block_size], - settings[Setting::max_threads], - false /*optimize_read_in_order*/); - - join_step->setStepDescription(fmt::format("JOIN {}", join_pipeline_type)); - - std::vector plans; - plans.emplace_back(std::make_unique(std::move(left_plan))); - plans.emplace_back(std::make_unique(std::move(right_plan))); - - result_plan.unitePlans(std::move(join_step), {std::move(plans)}); - } + QueryPlan result_plan; + result_plan.unitePlans(std::move(join_step), {std::move(plans)}); ActionsDAG drop_unused_columns_after_join_actions_dag(result_plan.getCurrentHeader().getColumnsWithTypeAndName()); 
ActionsDAG::NodeRawConstPtrs drop_unused_columns_after_join_actions_dag_updated_outputs; @@ -1667,33 +1289,103 @@ JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_ if (drop_unused_columns_after_join_actions_dag_updated_outputs.empty() && first_skipped_column_node_index) drop_unused_columns_after_join_actions_dag_updated_outputs.push_back(drop_unused_columns_after_join_actions_dag_outputs[*first_skipped_column_node_index]); - drop_unused_columns_after_join_actions_dag_outputs = std::move(drop_unused_columns_after_join_actions_dag_updated_outputs); + if (drop_unused_columns_after_join_actions_dag_outputs.size() != drop_unused_columns_after_join_actions_dag_updated_outputs.size()) + { + drop_unused_columns_after_join_actions_dag_outputs = std::move(drop_unused_columns_after_join_actions_dag_updated_outputs); - auto drop_unused_columns_after_join_transform_step = std::make_unique(result_plan.getCurrentHeader(), std::move(drop_unused_columns_after_join_actions_dag)); - drop_unused_columns_after_join_transform_step->setStepDescription("DROP unused columns after JOIN"); - result_plan.addStep(std::move(drop_unused_columns_after_join_transform_step)); + auto drop_unused_columns_after_join_transform_step = std::make_unique(result_plan.getCurrentHeader(), std::move(drop_unused_columns_after_join_actions_dag)); + drop_unused_columns_after_join_transform_step->setStepDescription("DROP unused columns after JOIN"); + result_plan.addStep(std::move(drop_unused_columns_after_join_transform_step)); + } + /// Collect all required row_policies and actions sets from left and right join tree query plans + + auto result_used_row_policies = std::move(left_join_tree_query_plan.used_row_policies); for (const auto & right_join_tree_query_plan_row_policy : right_join_tree_query_plan.used_row_policies) - left_join_tree_query_plan.used_row_policies.insert(right_join_tree_query_plan_row_policy); + 
result_used_row_policies.insert(right_join_tree_query_plan_row_policy); - /// Collect all required actions sets in `left_join_tree_query_plan.useful_sets` - if (!is_filled_join) - for (const auto & useful_set : right_join_tree_query_plan.useful_sets) - left_join_tree_query_plan.useful_sets.insert(useful_set); + auto result_useful_sets = std::move(left_join_tree_query_plan.useful_sets); + for (const auto & useful_set : right_join_tree_query_plan.useful_sets) + result_useful_sets.insert(useful_set); - auto mapping = std::move(left_join_tree_query_plan.query_node_to_plan_step_mapping); - auto & r_mapping = right_join_tree_query_plan.query_node_to_plan_step_mapping; - mapping.insert(r_mapping.begin(), r_mapping.end()); + auto result_mapping = std::move(left_join_tree_query_plan.query_node_to_plan_step_mapping); + const auto & r_mapping = right_join_tree_query_plan.query_node_to_plan_step_mapping; + result_mapping.insert(r_mapping.begin(), r_mapping.end()); return JoinTreeQueryPlan{ .query_plan = std::move(result_plan), .from_stage = QueryProcessingStage::FetchColumns, - .used_row_policies = std::move(left_join_tree_query_plan.used_row_policies), - .useful_sets = std::move(left_join_tree_query_plan.useful_sets), - .query_node_to_plan_step_mapping = std::move(mapping), + .used_row_policies = std::move(result_used_row_policies), + .useful_sets = std::move(result_useful_sets), + .query_node_to_plan_step_mapping = std::move(result_mapping), }; } +JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_expression, + JoinTreeQueryPlan left_join_tree_query_plan, + JoinTreeQueryPlan right_join_tree_query_plan, + const ColumnIdentifierSet & outer_scope_columns, + PlannerContextPtr & planner_context) +{ + auto & join_node = join_table_expression->as(); + if (left_join_tree_query_plan.from_stage != QueryProcessingStage::FetchColumns) + throw Exception(ErrorCodes::UNSUPPORTED_METHOD, + "JOIN {} left table expression expected to process query to fetch columns 
stage. Actual {}", + join_node.formatASTForErrorMessage(), + QueryProcessingStage::toString(left_join_tree_query_plan.from_stage)); + + auto & left_plan = left_join_tree_query_plan.query_plan; + auto & right_plan = right_join_tree_query_plan.query_plan; + + const auto & query_context = planner_context->getQueryContext(); + const auto & settings = query_context->getSettingsRef(); + + if (settings[Setting::join_use_nulls]) + { + JoinKind join_kind = join_node.getKind(); + + auto to_nullable_function = FunctionFactory::instance().get("toNullable", query_context); + if (isFull(join_kind)) + { + joinCastPlanColumnsToNullable(left_plan, planner_context, to_nullable_function); + joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function); + } + else if (isLeft(join_kind)) + { + joinCastPlanColumnsToNullable(right_plan, planner_context, to_nullable_function); + } + else if (isRight(join_kind)) + { + joinCastPlanColumnsToNullable(left_plan, planner_context, to_nullable_function); + } + } + + auto join_step_logical = buildJoinStepLogical( + left_plan.getCurrentHeader(), + right_plan.getCurrentHeader(), + outer_scope_columns, + join_node, + planner_context); + + std::visit([&join_step_logical](auto prepared_storage) + { + if constexpr (!std::is_same_v>) + { + if (!prepared_storage) + return; + join_step_logical->setPreparedJoinStorage(std::move(prepared_storage)); + } + }, tryGetStorageInTableJoin(join_node.getRightTableExpression())); + + + return joinPlansWithStep( + std::move(join_step_logical), + std::move(left_join_tree_query_plan), + std::move(right_join_tree_query_plan), + outer_scope_columns, + planner_context); +} + JoinTreeQueryPlan buildQueryPlanForArrayJoinNode(const QueryTreeNodePtr & array_join_table_expression, JoinTreeQueryPlan join_tree_query_plan, const ColumnIdentifierSet & outer_scope_columns, diff --git a/src/Planner/PlannerJoins.cpp b/src/Planner/PlannerJoins.cpp index c9a10f61502..ea44da64e27 100644 --- 
a/src/Planner/PlannerJoins.cpp +++ b/src/Planner/PlannerJoins.cpp @@ -39,9 +39,11 @@ #include #include #include +#include #include #include +#include namespace DB { @@ -140,7 +142,7 @@ TableExpressionSet extractTableExpressionsSet(const QueryTreeNodePtr & node) return res; } -std::set extractJoinTableSidesFromExpression(//const ActionsDAG::Node * expression_root_node, +std::set extractJoinTableSidesFromExpression( const IQueryTreeNode * expression_root_node, const TableExpressionSet & left_table_expressions, const TableExpressionSet & right_table_expressions, @@ -383,6 +385,24 @@ void buildJoinClause( } } +QueryTreeNodePtr getJoinExpressionFromNode(const JoinNode & join_node) +{ + /** Returns the expression that JOIN condition analysis should work on: nullptr when join_node has no join expression, the source expression when the join expression is a constant that still has one, and the join expression itself otherwise. Callers have to handle the nullptr result. It is possible to have constant value in JOIN ON section, that we need to ignore during DAG construction. + * If we do not ignore it, this function will be replaced by underlying constant. + * For example ASOF JOIN does not support JOIN with constants, and we should process it like ordinary JOIN. + * + * Example: SELECT * FROM (SELECT 1 AS id, 1 AS value) AS t1 ASOF LEFT JOIN (SELECT 1 AS id, 1 AS value) AS t2 + * ON (t1.id = t2.id) AND 1 != 1 AND (t1.value >= t1.value); + */ + const auto & join_expression = join_node.getJoinExpression(); + if (!join_expression) + return nullptr; + const auto * constant_join_expression = join_expression->as(); + if (constant_join_expression && constant_join_expression->hasSourceExpression()) + return constant_join_expression->getSourceExpression(); + return join_expression; +} + JoinClausesAndActions buildJoinClausesAndActions( const ColumnsWithTypeAndName & left_table_expression_columns, const ColumnsWithTypeAndName & right_table_expression_columns, @@ -391,29 +411,15 @@ JoinClausesAndActions buildJoinClausesAndActions( { ActionsDAG left_join_actions(left_table_expression_columns); ActionsDAG right_join_actions(right_table_expression_columns); + ColumnsWithTypeAndName mixed_table_expression_columns; for (const auto & left_column : 
left_table_expression_columns) - { mixed_table_expression_columns.push_back(left_column); - } for (const auto & right_column : right_table_expression_columns) - { mixed_table_expression_columns.push_back(right_column); - } ActionsDAG mixed_join_actions(mixed_table_expression_columns); - /** It is possible to have constant value in JOIN ON section, that we need to ignore during DAG construction. - * If we do not ignore it, this function will be replaced by underlying constant. - * For example ASOF JOIN does not support JOIN with constants, and we should process it like ordinary JOIN. - * - * Example: SELECT * FROM (SELECT 1 AS id, 1 AS value) AS t1 ASOF LEFT JOIN (SELECT 1 AS id, 1 AS value) AS t2 - * ON (t1.id = t2.id) AND 1 != 1 AND (t1.value >= t1.value); - */ - auto join_expression = join_node.getJoinExpression(); - auto * constant_join_expression = join_expression->as(); - - if (constant_join_expression && constant_join_expression->hasSourceExpression()) - join_expression = constant_join_expression->getSourceExpression(); + auto join_expression = getJoinExpressionFromNode(join_node);/* NOTE(review): getJoinExpressionFromNode returns nullptr when the node has no join expression, and the next statement dereferences the result unconditionally - confirm every caller guarantees a join expression is present, or add an explicit check. */ auto * function_node = join_expression->as(); if (!function_node) @@ -658,6 +664,224 @@ JoinClausesAndActions buildJoinClausesAndActions( return result; } +struct JoinInfoBuildContext +{/** Working state for translating one JoinNode. The constructor keeps the join node and planner context, copies the column lists of both sides, computes the table expression set of each side with extractTableExpressionsSet, and seeds result_join_info with the kind, strictness and locality of the join node. */ + explicit JoinInfoBuildContext( + const JoinNode & join_node_, + const ColumnsWithTypeAndName & left_table_columns_, + const ColumnsWithTypeAndName & right_table_columns_, + const PlannerContextPtr & planner_context_) + : join_node(join_node_) + , planner_context(planner_context_) + , left_table_columns(left_table_columns_) + , right_table_columns(right_table_columns_) + , left_table_expression_set(extractTableExpressionsSet(join_node.getLeftTableExpression())) + , right_table_expression_set(extractTableExpressionsSet(join_node.getRightTableExpression())) + , result_join_expression_actions(left_table_columns, right_table_columns) + { + result_join_info.kind = join_node.getKind(); + 
result_join_info.strictness = join_node.getStrictness(); + result_join_info.locality = join_node.getLocality(); + } +/** Which side(s) of the JOIN an expression reads from, as classified by getExpressionSource. */ + enum class JoinSource : uint8_t { None, Left, Right, Both }; +/** Classifies node by the set of table sides returned by extractJoinTableSidesFromExpression: an empty set gives None, a single element gives Left or Right, anything larger gives Both. */ + JoinSource getExpressionSource(const QueryTreeNodePtr & node) + { + auto res = extractJoinTableSidesFromExpression(node.get(), left_table_expression_set, right_table_expression_set, join_node); + if (res.empty()) + return JoinSource::None; + if (res.size() == 1) + { + if (*res.begin() == JoinTableSide::Left) + return JoinSource::Left; + return JoinSource::Right; + } + return JoinSource::Both; + } +/** Appends node to the actions selected by src (left_pre_join_actions for Left, right_pre_join_actions for Right, post_join_actions for None and for Both) and wraps the resulting DAG node in a JoinActionRef. NOTE(review): the None branch and the final else branch are identical and could be merged. */ + JoinActionRef addExpression(const QueryTreeNodePtr & node, JoinSource src) + { + const ActionsDAG::Node * dag_node_ptr = nullptr; + if (src == JoinSource::None) + dag_node_ptr = appendExpression(result_join_expression_actions.post_join_actions, node, planner_context, join_node); + else if (src == JoinSource::Left) + dag_node_ptr = appendExpression(result_join_expression_actions.left_pre_join_actions, node, planner_context, join_node); + else if (src == JoinSource::Right) + dag_node_ptr = appendExpression(result_join_expression_actions.right_pre_join_actions, node, planner_context, join_node); + else + dag_node_ptr = appendExpression(result_join_expression_actions.post_join_actions, node, planner_context, join_node); + return JoinActionRef(dag_node_ptr); + } +/** The next two members are references to the constructor arguments, not copies. NOTE(review): the referenced JoinNode and PlannerContextPtr therefore have to outlive this object - confirm at the call sites. */ + const JoinNode & join_node; + const PlannerContextPtr & planner_context; + + ColumnsWithTypeAndName left_table_columns; + ColumnsWithTypeAndName right_table_columns; + + TableExpressionSet left_table_expression_set; + TableExpressionSet right_table_expression_set; +/** Results accumulated while the join expression is processed. */ + JoinExpressionActions result_join_expression_actions; + JoinInfo result_join_info; +}; +/** Tries to read function_node as a two-argument predicate that relates the left and right JOIN sides. Returns false without touching join_condition when the node is null, does not have exactly two arguments, its name is not recognized by getJoinPredicateOperator, or its arguments are not one Left and one Right expression. Otherwise appends a JoinPredicate to join_condition.predicates (left-side expression first) and returns true. */ +bool tryGetJoinPredicate(const FunctionNode * function_node, JoinInfoBuildContext & builder_context, JoinCondition & join_condition) +{ + if (!function_node || function_node->getArguments().getNodes().size() != 2) + return false; + + auto predicate_operator = 
getJoinPredicateOperator(function_node->getFunctionName()); + if (!predicate_operator.has_value()) + return false; + + auto left_node = function_node->getArguments().getNodes().at(0); + auto left_expr_source = builder_context.getExpressionSource(left_node); + + auto right_node = function_node->getArguments().getNodes().at(1); + auto right_expr_source = builder_context.getExpressionSource(right_node); + + if (left_expr_source == JoinInfoBuildContext::JoinSource::Left && right_expr_source == JoinInfoBuildContext::JoinSource::Right) + { + join_condition.predicates.emplace_back(JoinPredicate{ + builder_context.addExpression(left_node, JoinInfoBuildContext::JoinSource::Left), + builder_context.addExpression(right_node, JoinInfoBuildContext::JoinSource::Right), + predicate_operator.value()}); + return true; + } + + if (left_expr_source == JoinInfoBuildContext::JoinSource::Right && right_expr_source == JoinInfoBuildContext::JoinSource::Left) + { + join_condition.predicates.push_back(JoinPredicate{ + builder_context.addExpression(right_node, JoinInfoBuildContext::JoinSource::Left), + builder_context.addExpression(left_node, JoinInfoBuildContext::JoinSource::Right), + reversePredicateOperator(predicate_operator.value())}); + return true; + } + + return false; +} + +void buildJoinOnCondition(const QueryTreeNodePtr & node, JoinInfoBuildContext & builder_context, JoinCondition & join_condition) +{ + auto & using_list = node->as(); + for (auto & using_node : using_list.getNodes()) + { + auto & using_column_node = using_node->as(); + auto & inner_columns_list = using_column_node.getExpressionOrThrow()->as(); + chassert(inner_columns_list.getNodes().size() == 2); + + join_condition.predicates.emplace_back(JoinPredicate{ + builder_context.addExpression(inner_columns_list.getNodes().at(0), JoinInfoBuildContext::JoinSource::Left), + builder_context.addExpression(inner_columns_list.getNodes().at(1), JoinInfoBuildContext::JoinSource::Right), + PredicateOperator::Equals}); + 
LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: {} ({}) == {} ({})", __FILE__, __LINE__, + join_condition.predicates.back().left_node.column_name, + builder_context.planner_context->getColumnNodeIdentifierOrThrow(inner_columns_list.getNodes().at(0)), + join_condition.predicates.back().right_node.column_name, + builder_context.planner_context->getColumnNodeIdentifierOrThrow(inner_columns_list.getNodes().at(1))); + } +} + +void buildJoinCondition(const QueryTreeNodePtr & node, JoinInfoBuildContext & builder_context, JoinCondition & join_condition) +{ + std::string function_name; + const auto * function_node = node->as(); + if (function_node) + function_name = function_node->getFunction()->getName(); + + if (function_name == "and") + { + for (const auto & child : function_node->getArguments()) + buildJoinCondition(child, builder_context, join_condition); + return; + } + + bool is_predicate = tryGetJoinPredicate(function_node, builder_context, join_condition); + if (is_predicate) + return; + + auto expr_source = builder_context.getExpressionSource(node); + if (expr_source == JoinInfoBuildContext::JoinSource::Left) + join_condition.left_filter_conditions.push_back(builder_context.addExpression(node, expr_source)); + else if (expr_source == JoinInfoBuildContext::JoinSource::Right) + join_condition.right_filter_conditions.push_back(builder_context.addExpression(node, expr_source)); + else + join_condition.residual_conditions.push_back(builder_context.addExpression(node, expr_source)); +} + +void buildDisjunctiveJoinConditions(const QueryTreeNodePtr & node, JoinInfoBuildContext & builder_context, std::vector & join_conditions) +{ + auto * function_node = node->as(); + if (!function_node) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, + "JOIN {} join expression expected function", + node->formatASTForErrorMessage()); + + const auto & function_name = function_node->getFunction()->getName(); + + if (function_name == "or") + { + for (const auto & child : 
function_node->getArguments()) + buildDisjunctiveJoinConditions(child, builder_context, join_conditions); + return; + } + buildJoinCondition(node, builder_context, join_conditions.emplace_back()); +} + +} + +std::unique_ptr buildJoinStepLogical( + const Block & left_header, + const Block & right_header, + const NameSet & outer_scope_columns, + const JoinNode & join_node, + const PlannerContextPtr & planner_context) +{ + const auto & left_columns = left_header.getColumnsWithTypeAndName(); + const auto & right_columns = right_header.getColumnsWithTypeAndName(); + JoinInfoBuildContext build_context(join_node, left_columns, right_columns, planner_context); + + auto join_expression_constant = tryExtractConstantFromConditionNode(join_node.getJoinExpression()); + build_context.result_join_info.expression.constant_value = join_expression_constant; + + auto join_expression_node = getJoinExpressionFromNode(join_node); + + /// CROSS JOIN: doesn't have expression + if (join_expression_node == nullptr) + { + if (!isCrossOrComma(join_node.getKind())) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, "Missing join expression in {}", join_node.formatASTForErrorMessage()); + } + /// USING + else if (join_node.isUsingJoinExpression()) + { + buildJoinOnCondition(join_expression_node, build_context, build_context.result_join_info.expression.condition); + } + /// JOIN ON some non-constant expression + else if (!join_expression_constant.has_value()) + { + if (join_expression_node->getNodeType() != QueryTreeNodeType::FUNCTION) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, + "JOIN {} join expression expected function", + join_node.formatASTForErrorMessage()); + + buildDisjunctiveJoinConditions(join_expression_node, build_context, build_context.result_join_info.expression.disjunctive_conditions); + if (!build_context.result_join_info.expression.disjunctive_conditions.empty()) + { + build_context.result_join_info.expression.condition = 
build_context.result_join_info.expression.disjunctive_conditions.back(); + build_context.result_join_info.expression.disjunctive_conditions.pop_back(); + } + } + + return std::make_unique( + left_header, + right_header, + std::move(build_context.result_join_info), + std::move(build_context.result_join_expression_actions), + Names(outer_scope_columns.begin(), outer_scope_columns.end()), + planner_context->getQueryContext()); } JoinClausesAndActions buildJoinClausesAndActions( @@ -684,10 +908,7 @@ std::optional tryExtractConstantFromJoinNode(const QueryTreeNodePtr & join return tryExtractConstantFromConditionNode(join_node_typed.getJoinExpression()); } -namespace -{ - -void trySetStorageInTableJoin(const QueryTreeNodePtr & table_expression, std::shared_ptr & table_join) +PreparedJoinStorage tryGetStorageInTableJoin(const QueryTreeNodePtr & table_expression) { StoragePtr storage; @@ -698,19 +919,31 @@ void trySetStorageInTableJoin(const QueryTreeNodePtr & table_expression, std::sh auto storage_join = std::dynamic_pointer_cast(storage); if (storage_join) + return storage_join; + + auto storage_dictionary = std::dynamic_pointer_cast(storage); + if (storage_dictionary && storage_dictionary->getDictionary()->getSpecialKeyType() != DictionarySpecialKeyType::Range) + return std::dynamic_pointer_cast(storage_dictionary->getDictionary()); + + if (auto storage_key_value = std::dynamic_pointer_cast(storage)) + return storage_key_value; + + return {}; +} + +namespace +{ + +void trySetStorageInTableJoin(const QueryTreeNodePtr & table_expression, std::shared_ptr & table_join) +{ + auto storage = tryGetStorageInTableJoin(table_expression); + if (std::holds_alternative(storage)) + return; + std::visit([&table_join](const auto & storage_) { - table_join->setStorageJoin(storage_join); - return; - } - - if (!table_join->isEnabledAlgorithm(JoinAlgorithm::DIRECT) && !table_join->isEnabledAlgorithm(JoinAlgorithm::DEFAULT)) - return; - - if (auto storage_dictionary = 
std::dynamic_pointer_cast(storage); - storage_dictionary && storage_dictionary->getDictionary()->getSpecialKeyType() != DictionarySpecialKeyType::Range) - table_join->setStorageJoin(std::dynamic_pointer_cast(storage_dictionary->getDictionary())); - else if (auto storage_key_value = std::dynamic_pointer_cast(storage); storage_key_value) - table_join->setStorageJoin(storage_key_value); + if constexpr (!std::is_same_v>) + table_join->setStorageJoin(storage_); + }, storage); } std::shared_ptr tryDirectJoin(const std::shared_ptr & table_join, diff --git a/src/Planner/PlannerJoins.h b/src/Planner/PlannerJoins.h index d8665ab7739..24096c8c8f3 100644 --- a/src/Planner/PlannerJoins.h +++ b/src/Planner/PlannerJoins.h @@ -6,8 +6,11 @@ #include #include #include +#include +#include #include +#include namespace DB { @@ -224,4 +227,16 @@ std::shared_ptr chooseJoinAlgorithm(std::shared_ptr & table_jo const Block & right_table_expression_header, const PlannerContextPtr & planner_context); +using PreparedJoinStorage = std::variant, std::shared_ptr>; +PreparedJoinStorage tryGetStorageInTableJoin(const QueryTreeNodePtr & table_expression); + +class JoinStepLogical; + +std::unique_ptr buildJoinStepLogical( + const Block & left_header, + const Block & right_header, + const NameSet & outer_scope_columns, + const JoinNode & join_node, + const PlannerContextPtr & planner_context); + } diff --git a/src/Processors/QueryPlan/JoinStepLogical.cpp b/src/Processors/QueryPlan/JoinStepLogical.cpp new file mode 100644 index 00000000000..05fdc7aacf2 --- /dev/null +++ b/src/Processors/QueryPlan/JoinStepLogical.cpp @@ -0,0 +1,579 @@ +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace Setting +{ + extern const SettingsJoinAlgorithm join_algorithm; + extern const SettingsBool join_any_take_last_row; +} + +namespace ErrorCodes +{ + extern const int 
NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; + extern const int INVALID_JOIN_ON_EXPRESSION; +} + +std::string_view toString(PredicateOperator op) +{ + switch (op) + { + case PredicateOperator::Equals: return "="; + case PredicateOperator::NullSafeEquals: return "<=>"; + case PredicateOperator::Less: return "<"; + case PredicateOperator::LessOrEquals: return "<="; + case PredicateOperator::Greater: return ">"; + case PredicateOperator::GreaterOrEquals: return ">="; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Illegal value for PredicateOperator: {}", static_cast(op)); +} + + +std::string toFunctionName(PredicateOperator op) +{ + switch (op) + { + case PredicateOperator::Equals: return "equals"; + case PredicateOperator::NullSafeEquals: return "isNotDistinctFrom"; + case PredicateOperator::Less: return "less"; + case PredicateOperator::LessOrEquals: return "lessOrEquals"; + case PredicateOperator::Greater: return "greater"; + case PredicateOperator::GreaterOrEquals: return "greaterOrEquals"; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Illegal value for PredicateOperator: {}", static_cast(op)); +} + +std::optional operatorToAsofInequality(PredicateOperator op) +{ + switch (op) + { + case PredicateOperator::Less: return ASOFJoinInequality::Less; + case PredicateOperator::LessOrEquals: return ASOFJoinInequality::LessOrEquals; + case PredicateOperator::Greater: return ASOFJoinInequality::Greater; + case PredicateOperator::GreaterOrEquals: return ASOFJoinInequality::GreaterOrEquals; + default: return {}; + } +} + +void formatJoinCondition(const JoinCondition & join_condition, WriteBuffer & buf) +{ + auto quote_string = std::views::transform([](const auto & s) { return fmt::format("({})", s.column_name); }); + auto format_predicate = std::views::transform([](const auto & p) { return fmt::format("{} {} {}", p.left_node.column_name, toString(p.op), p.right_node.column_name); }); + buf << "["; + buf << fmt::format("Keys: ({})", 
fmt::join(join_condition.predicates | format_predicate, " AND ")); + if (!join_condition.left_filter_conditions.empty()) + buf << " " << fmt::format("Left: ({})", fmt::join(join_condition.left_filter_conditions | quote_string, " AND ")); + if (!join_condition.right_filter_conditions.empty()) + buf << " " << fmt::format("Right: ({})", fmt::join(join_condition.right_filter_conditions | quote_string, " AND ")); + if (!join_condition.residual_conditions.empty()) + buf << " " << fmt::format("Residual: ({})", fmt::join(join_condition.residual_conditions | quote_string, " AND ")); + buf << "]"; +} + +std::vector> describeJoinActions(const JoinInfo & join_info) +{ + std::vector> description; + + description.emplace_back("Type", toString(join_info.kind)); + description.emplace_back("Strictness", toString(join_info.strictness)); + description.emplace_back("Locality", toString(join_info.locality)); + + { + WriteBufferFromOwnString join_expression_str; + join_expression_str << (join_info.expression.is_using ? 
"USING" : "ON") << " " ; + formatJoinCondition(join_info.expression.condition, join_expression_str); + for (const auto & condition : join_info.expression.disjunctive_conditions) + { + join_expression_str << " | "; + formatJoinCondition(condition, join_expression_str); + } + description.emplace_back("Expression", join_expression_str.str()); + } + + return description; +} + + +JoinStepLogical::JoinStepLogical( + const Block & left_header_, + const Block & right_header_, + JoinInfo join_info_, + JoinExpressionActions join_expression_actions_, + Names required_output_columns_, + ContextPtr context_) + : expression_actions(std::move(join_expression_actions_)) + , join_info(std::move(join_info_)) + , required_output_columns(std::move(required_output_columns_)) + , query_context(std::move(context_)) +{ + updateInputHeaders({left_header_, right_header_}); +} + +QueryPipelineBuilderPtr JoinStepLogical::updatePipeline(QueryPipelineBuilders /* pipelines */, const BuildQueryPipelineSettings & /* settings */) +{ + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Cannot execute JoinStepLogical, it should be converted physical step first"); +} + + +void JoinStepLogical::describePipeline(FormatSettings & settings) const +{ + IQueryPlanStep::describePipeline(processors, settings); +} + +void JoinStepLogical::describeActions(FormatSettings & settings) const +{ + String prefix(settings.offset, settings.indent_char); + String prefix2(settings.offset + settings.indent, settings.indent_char); + + for (const auto & [name, value] : describeJoinActions(join_info)) + settings.out << prefix << name << ": " << value << '\n'; + settings.out << prefix << "Post Expression:\n"; + ExpressionActions(expression_actions.post_join_actions.clone()).describeActions(settings.out, prefix2); + settings.out << prefix << "Left Expression:\n"; + // settings.out << expression_actions.left_pre_join_actions.dumpDAG(); + 
ExpressionActions(expression_actions.left_pre_join_actions.clone()).describeActions(settings.out, prefix2); + settings.out << prefix << "Right Expression:\n"; + ExpressionActions(expression_actions.right_pre_join_actions.clone()).describeActions(settings.out, prefix2); +} + +void JoinStepLogical::describeActions(JSONBuilder::JSONMap & map) const +{ + for (const auto & [name, value] : describeJoinActions(join_info)) + map.add(name, value); + + map.add("Left Actions", ExpressionActions(expression_actions.left_pre_join_actions.clone()).toTree()); + map.add("Right Actions", ExpressionActions(expression_actions.right_pre_join_actions.clone()).toTree()); + map.add("Post Actions", ExpressionActions(expression_actions.post_join_actions.clone()).toTree()); +} + +static Block stackHeadersFromStreams(const Headers & input_headers, const Names & required_output_columns) +{ + NameSet required_output_columns_set(required_output_columns.begin(), required_output_columns.end()); + + Block result_header; + for (const auto & header : input_headers) + { + for (const auto & column : header) + { + if (required_output_columns_set.contains(column.name)) + { + result_header.insert(column); + } + else if (required_output_columns_set.empty()) + { + /// If no required columns specified, use one first column. 
+ result_header.insert(column); + return result_header; + } + } + } + return result_header; +} + +void JoinStepLogical::updateOutputHeader() +{ + output_header = stackHeadersFromStreams(input_headers, required_output_columns); +} + + +JoinActionRef concatConditions(const std::vector & conditions, ActionsDAG & actions_dag, const ContextPtr & query_context) +{ + if (conditions.empty()) + return JoinActionRef(nullptr); + + if (conditions.size() == 1) + { + actions_dag.addOrReplaceInOutputs(*conditions.front().node); + return conditions.front(); + } + + auto and_function = FunctionFactory::instance().get("and", query_context); + ActionsDAG::NodeRawConstPtrs nodes; + nodes.reserve(conditions.size()); + for (const auto & condition : conditions) + { + if (!condition.node) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Condition node is nullptr"); + nodes.push_back(condition.node); + } + + const auto & result_node = actions_dag.addFunction(and_function, nodes, {}); + actions_dag.addOrReplaceInOutputs(result_node); + return JoinActionRef(&result_node); +} + +JoinActionRef concatMergeConditions(std::vector & conditions, ActionsDAG & actions_dag, const ContextPtr & query_context) +{ + auto condition = concatConditions(conditions, actions_dag, query_context); + conditions.clear(); + if (condition) + conditions = {condition}; + return condition; +} + + + +/// Can be used when action.node is outside of actions_dag. 
+const ActionsDAG::Node & addInputIfAbsent(ActionsDAG & actions_dag, const JoinActionRef & action) +{ + for (const auto * node : actions_dag.getInputs()) + { + if (node->result_name == action.column_name) + { + if (!node->result_type->equals(*action.node->result_type)) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Column '{}' expected to have type {} but got {}, in actions DAG: {}", + action.column_name, action.node->result_type->getName(), node->result_type->getName(), actions_dag.dumpDAG()); + return *node; + } + } + return actions_dag.addInput(action.column_name, action.node->result_type); +} + + +JoinActionRef predicateToCondition(const JoinPredicate & predicate, ActionsDAG & actions_dag, const ContextPtr & query_context) +{ + const auto & left_node = addInputIfAbsent(actions_dag, predicate.left_node); + const auto & right_node = addInputIfAbsent(actions_dag, predicate.right_node); + + auto operator_function = FunctionFactory::instance().get(toFunctionName(predicate.op), query_context); + const auto & result_node = actions_dag.addFunction(operator_function, {&left_node, &right_node}, {}); + return JoinActionRef(&result_node); +} + +bool canPushDownFromOn(const JoinInfo & join_info, std::optional side = {}) +{ + if (!join_info.expression.disjunctive_conditions.empty()) + return false; + + if (join_info.strictness != JoinStrictness::All + && join_info.strictness != JoinStrictness::Any + && join_info.strictness != JoinStrictness::RightAny + && join_info.strictness != JoinStrictness::Semi) + return false; + + return join_info.kind == JoinKind::Inner + || join_info.kind == JoinKind::Cross + || join_info.kind == JoinKind::Comma + || join_info.kind == JoinKind::Paste + || (side == JoinTableSide::Left && join_info.kind == JoinKind::Right) + || (side == JoinTableSide::Right && join_info.kind == JoinKind::Left); +} + +void addRequiredInputToOutput(ActionsDAG & dag, const NameSet & required_output_columns) +{ + NameSet existing_output_columns; + for (const auto & node : 
dag.getOutputs()) + existing_output_columns.insert(node->result_name); + + for (const auto * node : dag.getInputs()) + { + if (!required_output_columns.contains(node->result_name) + || existing_output_columns.contains(node->result_name)) + continue; + dag.addOrReplaceInOutputs(*node); + } +} + + +struct JoinPlanningContext +{ + bool is_join_with_special_storage = false; + bool is_asof = false; +}; + +void predicateOperandsToCommonType(JoinPredicate & predicate, JoinExpressionActions & expression_actions, JoinPlanningContext join_context) +{ + auto & left_node = predicate.left_node; + auto & right_node = predicate.right_node; + const auto & left_type = left_node.node->result_type; + const auto & right_type = right_node.node->result_type; + + if (left_type->equals(*right_type)) + return; + + DataTypePtr common_type; + try + { + common_type = getLeastSupertype(DataTypes{left_type, right_type}); + } + catch (Exception & ex) + { + ex.addMessage("JOIN cannot infer common type in ON section for keys. Left key '{}' type {}. 
Right key '{}' type {}", + left_node.column_name, left_type->getName(), + right_node.column_name, right_type->getName()); + throw; + } + + if (!left_type->equals(*common_type)) + { + left_node = JoinActionRef(&expression_actions.left_pre_join_actions.addCast(*left_node.node, common_type, {})); + expression_actions.left_pre_join_actions.addOrReplaceInOutputs(*left_node.node); + } + + if (!join_context.is_join_with_special_storage && !right_type->equals(*common_type)) + { + right_node = JoinActionRef(&expression_actions.right_pre_join_actions.addCast(*right_node.node, common_type, {})); + expression_actions.right_pre_join_actions.addOrReplaceInOutputs(*right_node.node); + } +} + +void addJoinConditionToTableJoin(JoinCondition & join_condition, TableJoin::JoinOnClause & table_join_clause, JoinExpressionActions & expression_actions, const ContextPtr & query_context, JoinPlanningContext join_context) +{ + std::vector new_predicates; + for (size_t i = 0; i < join_condition.predicates.size(); ++i) + { + auto & predicate = join_condition.predicates[i]; + predicateOperandsToCommonType(predicate, expression_actions, join_context); + if (PredicateOperator::Equals == predicate.op || PredicateOperator::NullSafeEquals == predicate.op) + { + table_join_clause.addKey(predicate.left_node.column_name, predicate.right_node.column_name, PredicateOperator::NullSafeEquals == predicate.op); + new_predicates.push_back(predicate); + } + else if (!join_context.is_asof) + { + auto predicate_action = predicateToCondition(predicate, expression_actions.post_join_actions, query_context); + join_condition.residual_conditions.push_back(predicate_action); + } + } + + if (new_predicates.empty()) + { + WriteBufferFromOwnString buf; + formatJoinCondition(join_condition, buf); + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, "No equality condition found in JOIN ON expression {}", buf.str()); + } + + join_condition.predicates = std::move(new_predicates); +} + + +void 
addRequiredOutputs(ActionsDAG & actions_dag, const Names & required_output_columns) +{ + NameSet required_output_columns_set(required_output_columns.begin(), required_output_columns.end()); + for (const auto * node : actions_dag.getInputs()) + { + if (required_output_columns_set.contains(node->result_name)) + actions_dag.addOrReplaceInOutputs(*node); + } +} + +JoinActionRef buildSingleActionForJoinExpression(const JoinCondition & join_condition, JoinExpressionActions & expression_actions, const ContextPtr & query_context) +{ + std::vector all_conditions; + auto left_filter_conditions_action = concatConditions(join_condition.left_filter_conditions, expression_actions.left_pre_join_actions, query_context); + if (left_filter_conditions_action) + { + left_filter_conditions_action.node = &addInputIfAbsent(expression_actions.post_join_actions, left_filter_conditions_action); + all_conditions.push_back(left_filter_conditions_action); + } + + auto right_filter_conditions_action = concatConditions(join_condition.right_filter_conditions, expression_actions.right_pre_join_actions, query_context); + if (right_filter_conditions_action) + { + right_filter_conditions_action.node = &addInputIfAbsent(expression_actions.post_join_actions, right_filter_conditions_action); + all_conditions.push_back(right_filter_conditions_action); + } + + for (const auto & predicate : join_condition.predicates) + { + auto predicate_action = predicateToCondition(predicate, expression_actions.post_join_actions, query_context); + all_conditions.push_back(predicate_action); + } + + return concatConditions(all_conditions, expression_actions.post_join_actions, query_context); +} + +JoinActionRef buildSingleActionForJoinExpression(const JoinExpression & join_expression, JoinExpressionActions & expression_actions, const ContextPtr & query_context) +{ + std::vector all_conditions; + + if (auto condition = buildSingleActionForJoinExpression(join_expression.condition, expression_actions, query_context)) + 
all_conditions.push_back(condition); + + for (const auto & join_condition : join_expression.disjunctive_conditions) + if (auto condition = buildSingleActionForJoinExpression(join_condition, expression_actions, query_context)) + all_conditions.push_back(condition); + + return concatConditions(all_conditions, expression_actions.post_join_actions, query_context); +} + +JoinPtr JoinStepLogical::chooseJoinAlgorithm(JoinActionRef & left_filter, JoinActionRef & right_filter, JoinActionRef & post_filter, bool is_explain_logical) +{ + for (const auto & [name, value] : describeJoinActions(join_info)) + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: {}: {}", __FILE__, __LINE__, name, value); + + const auto & settings = query_context->getSettingsRef(); + + auto table_join = std::make_shared(settings, query_context->getGlobalTemporaryVolume(), query_context->getTempDataOnDisk()); + table_join->setJoinInfo(join_info); + + auto & join_expression = join_info.expression; + + JoinPlanningContext join_context; + join_context.is_join_with_special_storage = std::visit([&](auto && storage_) + { + if (storage_ && join_expression.disjunctive_conditions.empty()) + { + table_join->setStorageJoin(storage_); + return true; + } + return false; + }, prepared_join_storage); + join_context.is_asof = join_info.strictness == JoinStrictness::Asof; + + auto & table_join_clauses = table_join->getClauses(); + + if (!isCrossOrComma(join_info.kind)) + { + addJoinConditionToTableJoin( + join_expression.condition, table_join_clauses.emplace_back(), + expression_actions, query_context, join_context); + } + + if (auto left_pre_filter_condition = concatMergeConditions(join_expression.condition.left_filter_conditions, expression_actions.left_pre_join_actions, query_context)) + { + if (canPushDownFromOn(join_info, JoinTableSide::Left)) + left_filter = left_pre_filter_condition; + else + table_join_clauses.back().analyzer_left_filter_condition_column_name = left_pre_filter_condition.column_name; + } + + if (auto 
right_pre_filter_condition = concatMergeConditions(join_expression.condition.right_filter_conditions, expression_actions.right_pre_join_actions, query_context)) + { + if (canPushDownFromOn(join_info, JoinTableSide::Right)) + right_filter = right_pre_filter_condition; + else + table_join_clauses.back().analyzer_right_filter_condition_column_name = right_pre_filter_condition.column_name; + } + + if (join_info.strictness == JoinStrictness::Asof) + { + if (!join_info.expression.disjunctive_conditions.empty()) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, "ASOF join does not support multiple disjuncts in JOIN ON expression"); + + /// Find strictly only one inequality in predicate list for ASOF join + chassert(table_join_clauses.size() == 1); + auto & join_predicates = join_info.expression.condition.predicates; + bool asof_predicate_found = false; + for (auto & predicate : join_predicates) + { + predicateOperandsToCommonType(predicate, expression_actions, join_context); + auto asof_inequality_op = operatorToAsofInequality(predicate.op); + if (!asof_inequality_op) + continue; + + if (asof_predicate_found) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, "ASOF join does not support multiple inequality predicates in JOIN ON expression"); + table_join->setAsofInequality(*asof_inequality_op); + table_join_clauses.front().addKey(predicate.left_node.column_name, predicate.right_node.column_name, /* null_safe_comparison = */ false); + } + if (!asof_predicate_found) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, "ASOF join requires one inequality predicate in JOIN ON expression"); + } + + for (auto & join_condition : join_info.expression.disjunctive_conditions) + { + auto & table_join_clause = table_join_clauses.emplace_back(); + addJoinConditionToTableJoin(join_condition, table_join_clause, expression_actions, query_context, join_context); + if (auto left_pre_filter_condition = concatMergeConditions(join_condition.left_filter_conditions, 
expression_actions.left_pre_join_actions, query_context)) + table_join_clause.analyzer_left_filter_condition_column_name = left_pre_filter_condition.column_name; + if (auto right_pre_filter_condition = concatMergeConditions(join_condition.right_filter_conditions, expression_actions.right_pre_join_actions, query_context)) + table_join_clause.analyzer_right_filter_condition_column_name = right_pre_filter_condition.column_name; + } + + JoinActionRef residual_filter_condition(nullptr); + if (join_info.expression.disjunctive_conditions.empty()) + { + residual_filter_condition = concatMergeConditions( + join_info.expression.condition.residual_conditions, expression_actions.post_join_actions, query_context); + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: [{}]", __FILE__, __LINE__, fmt::join(expression_actions.post_join_actions.getOutputs() | std::views::transform([](const auto * node) { return node->result_name; }), ", ")); + } + else + { + bool need_residual_filter = !join_info.expression.condition.residual_conditions.empty(); + for (const auto & join_condition : join_info.expression.disjunctive_conditions) + { + need_residual_filter = need_residual_filter || !join_condition.residual_conditions.empty(); + if (need_residual_filter) + break; + } + + if (need_residual_filter) + residual_filter_condition = buildSingleActionForJoinExpression(join_info.expression, expression_actions, query_context); + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: [{}]", __FILE__, __LINE__, fmt::join(expression_actions.post_join_actions.getOutputs() | std::views::transform([](const auto * node) { return node->result_name; }), ", ")); + } + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: [{}]", __FILE__, __LINE__, fmt::join(expression_actions.post_join_actions.getOutputs() | std::views::transform([](const auto * node) { return node->result_name; }), ", ")); + LOG_DEBUG(&Poco::Logger::get("XXXX"), "{}:{}: residual_filter_condition {} ", __FILE__, __LINE__, residual_filter_condition.column_name); + 
+ + if (residual_filter_condition && canPushDownFromOn(join_info)) + { + post_filter = residual_filter_condition; + } + else if (residual_filter_condition) + { + ActionsDAG dag; + if (is_explain_logical) + { + /// Keep post_join_actions for explain + dag = expression_actions.post_join_actions.clone(); + } + else + { + /// Move post_join_actions to join, replace with no-op dag + dag = std::move(expression_actions.post_join_actions); + expression_actions.post_join_actions = ActionsDAG(dag.getRequiredColumns()); + } + auto & outputs = dag.getOutputs(); + for (const auto * node : outputs) + { + if (node->result_name == residual_filter_condition.column_name) + { + outputs = {node}; + break; + } + } + ExpressionActionsPtr & mixed_join_expression = table_join->getMixedJoinExpression(); + mixed_join_expression = std::make_shared(std::move(dag), ExpressionActionsSettings::fromContext(query_context)); + } + + NameSet required_output_columns_set(required_output_columns.begin(), required_output_columns.end()); + addRequiredInputToOutput(expression_actions.left_pre_join_actions, required_output_columns_set); + addRequiredInputToOutput(expression_actions.right_pre_join_actions, required_output_columns_set); + addRequiredInputToOutput(expression_actions.post_join_actions, required_output_columns_set); + + table_join->setInputColumns( + expression_actions.left_pre_join_actions.getNamesAndTypesList(), + expression_actions.right_pre_join_actions.getNamesAndTypesList()); + table_join->setUsedColumns(expression_actions.post_join_actions.getRequiredColumnsNames()); + + Block right_sample_block(expression_actions.right_pre_join_actions.getResultColumns()); + JoinPtr join_ptr; + if (join_info.kind == JoinKind::Paste) + join_ptr = std::make_shared(table_join, right_sample_block); + else + join_ptr = std::make_shared(table_join, right_sample_block, settings[Setting::join_any_take_last_row]); + + return join_ptr; +} + +} diff --git a/src/Processors/QueryPlan/JoinStepLogical.h 
b/src/Processors/QueryPlan/JoinStepLogical.h
new file mode 100644
index 00000000000..dd2222a7b5e
--- /dev/null
+++ b/src/Processors/QueryPlan/JoinStepLogical.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+
+namespace DB
+{
+
+class StorageJoin;
+class IKeyValueEntity;
+
+/** JoinStepLogical is a logical step for the JOIN operation.
+  * It doesn't contain any specific join algorithm or other execution details.
+  * It's a placeholder for the join operation together with its description, which can be serialized.
+  * It is transformed into an actual join step during plan optimization.
+  */
+class JoinStepLogical final : public IQueryPlanStep
+{
+public:
+    JoinStepLogical(
+        const Block & left_header_,
+        const Block & right_header_,
+        JoinInfo join_info_,
+        JoinExpressionActions join_expression_actions_,
+        Names required_output_columns_,
+        ContextPtr context_);
+
+    String getName() const override { return "JoinLogical"; }
+
+    QueryPipelineBuilderPtr updatePipeline(QueryPipelineBuilders pipelines, const BuildQueryPipelineSettings &) override;
+
+    void describePipeline(FormatSettings & settings) const override;
+
+    void describeActions(JSONBuilder::JSONMap & map) const override;
+    void describeActions(FormatSettings & settings) const override;
+
+    template
+    void setPreparedJoinStorage(T storage) { prepared_join_storage.emplace(std::move(storage)); } /// Stores `storage` in prepared_join_storage.
+    /// Builds the join to execute. The filter refs are out-parameters; is_explain_logical keeps post_join_actions intact (cloned instead of moved out).
+    JoinPtr chooseJoinAlgorithm(JoinActionRef & left_filter, JoinActionRef & right_filter, JoinActionRef & post_filter, bool is_explain_logical);
+
+    JoinExpressionActions & getExpressionActions() { return expression_actions; } /// Mutable access: the optimizeJoin pass moves the actions DAGs out of it.
+
+    ContextPtr getContext() const { return query_context; }
+
+protected:
+    void updateOutputHeader() override;
+
+    JoinExpressionActions expression_actions; /// Pre-join (left / right) and post-join actions DAGs.
+    JoinInfo join_info; /// Join kind and join expression (conditions).
+
+    Names required_output_columns; /// Passed to addRequiredInputToOutput for each actions DAG in chooseJoinAlgorithm.
+    ContextPtr query_context; /// Returned by getContext(); callers read settings from it.
+
+    std::variant, std::shared_ptr> prepared_join_storage; /// Filled by setPreparedJoinStorage.
+};
+
+}
diff --git a/src/Processors/QueryPlan/Optimizations/Optimizations.h
b/src/Processors/QueryPlan/Optimizations/Optimizations.h index 751d5182dc3..48e22a714eb 100644 --- a/src/Processors/QueryPlan/Optimizations/Optimizations.h +++ b/src/Processors/QueryPlan/Optimizations/Optimizations.h @@ -113,6 +113,7 @@ void optimizePrimaryKeyConditionAndLimit(const Stack & stack); void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes); void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes); void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &); +void optimizeJoin(QueryPlan::Node & node, QueryPlan::Nodes &, bool keep_logical); void optimizeDistinctInOrder(QueryPlan::Node & node, QueryPlan::Nodes &); /// A separate tree traverse to apply sorting properties after *InOrder optimizations. diff --git a/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h b/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h index 6232fc7f54f..c4371ab6b86 100644 --- a/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h +++ b/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h @@ -80,6 +80,8 @@ struct QueryPlanOptimizationSettings bool build_sets = true; + bool keep_logical_steps = false; + static QueryPlanOptimizationSettings fromSettings(const Settings & from); static QueryPlanOptimizationSettings fromContext(ContextPtr from); }; diff --git a/src/Processors/QueryPlan/Optimizations/optimizeJoin.cpp b/src/Processors/QueryPlan/Optimizations/optimizeJoin.cpp new file mode 100644 index 00000000000..120525ef6be --- /dev/null +++ b/src/Processors/QueryPlan/Optimizations/optimizeJoin.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + + +namespace DB::Setting +{ + extern const SettingsMaxThreads max_threads; + extern const SettingsUInt64 max_block_size; +} + +namespace 
DB::QueryPlanOptimizations
+{
+/// Tries to estimate how many rows are read under `node`; returns an empty optional when it cannot. NOTE(review): no caller in this file besides its own recursion.
+static std::optional estimateReadRowsCount(QueryPlan::Node & node)
+{
+    IQueryPlanStep * step = node.step.get();
+    if (const auto * reading = typeid_cast(step))
+    {
+        if (auto analyzed_result = reading->getAnalyzedResult())
+            return analyzed_result->selected_rows;
+        if (auto analyzed_result = reading->selectRangesToRead())
+            return analyzed_result->selected_rows;
+        return {};
+    }
+
+    if (const auto * reading = typeid_cast(step))
+        return reading->getStorage()->totalRows(Settings{}); /// Note: uses default-constructed Settings, not the settings of the query.
+
+    if (node.children.size() != 1)
+        return {};
+
+    if (typeid_cast(step) || typeid_cast(step)) /// Look through these single-child steps and estimate from the child.
+        return estimateReadRowsCount(*node.children.front());
+
+    return {};
+}
+/// Adds to `nodes` a node whose only child is `node` and whose step applies `actions_dag` to the child's output header; returns it. A non-empty filter_column_name selects the second (presumably filtering) step kind.
+QueryPlan::Node * makeExpressionNodeOnTopOf(QueryPlan::Node * node, ActionsDAG actions_dag, const String & filter_column_name, QueryPlan::Nodes & nodes)
+{
+    const auto & header = node->step->getOutputHeader();
+
+    QueryPlanStepPtr step;
+
+    if (filter_column_name.empty())
+        step = std::make_unique(header, std::move(actions_dag));
+    else
+        step = std::make_unique(header, std::move(actions_dag), filter_column_name, false);
+
+    return &nodes.emplace_back(QueryPlan::Node{std::move(step), {node}});
+}
+/// If `node` holds a logical join step with two children, replaces it with: pre-join steps over both children, a join step using the chosen join, and a post-join step stored in `node` itself.
+void optimizeJoin(QueryPlan::Node & node, QueryPlan::Nodes & nodes, bool keep_logical)
+{
+    auto * join_step = typeid_cast(node.step.get());
+    if (!join_step || node.children.size() != 2)
+        return;
+
+    JoinActionRef left_filter(nullptr);
+    JoinActionRef right_filter(nullptr);
+    JoinActionRef post_filter(nullptr);
+    auto join_ptr = join_step->chooseJoinAlgorithm(left_filter, right_filter, post_filter, keep_logical); /// The three filter refs are out-parameters.
+    if (keep_logical)
+        return; /// The logical step stays in the plan; chooseJoinAlgorithm has still been invoked above.
+
+    auto & join_expression_actions = join_step->getExpressionActions();
+
+    auto * new_left_node = makeExpressionNodeOnTopOf(node.children[0], std::move(join_expression_actions.left_pre_join_actions), left_filter.column_name, nodes);
+    auto * new_right_node = makeExpressionNodeOnTopOf(node.children[1], std::move(join_expression_actions.right_pre_join_actions), right_filter.column_name, nodes);
+
+    const auto & settings = join_step->getContext()->getSettingsRef();
+
+    auto new_join_step = std::make_unique(
+        new_left_node->step->getOutputHeader(),
+        new_right_node->step->getOutputHeader(),
+        join_ptr,
+        settings[Setting::max_block_size],
+        settings[Setting::max_threads],
+        false);
+
+    auto & new_join_node = nodes.emplace_back();
+    new_join_node.step = std::move(new_join_step);
+    new_join_node.children = {new_left_node, new_right_node};
+
+    { /// NOTE(review): the description is written to a local buffer that is never read; looks like leftover debug code — confirm and remove.
+        WriteBufferFromOwnString buffer;
+        IQueryPlanStep::FormatSettings settings_out{.out = buffer, .write_header = true};
+        new_join_node.step->describeActions(settings_out);
+    }
+
+    if (!post_filter)
+        node.step = std::make_unique(new_join_node.step->getOutputHeader(), std::move(join_expression_actions.post_join_actions));
+    else
+        node.step = std::make_unique(new_join_node.step->getOutputHeader(), std::move(join_expression_actions.post_join_actions), post_filter.column_name, false);
+    node.children = {&new_join_node}; /// The original node now holds the post-join step on top of the new join node.
+}
+
+}
diff --git a/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp b/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp index 03418c752d4..c20d9673d01 100644 --- a/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp +++ b/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp @@ -168,6 +168,8 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s if (optimization_settings.aggregation_in_order) optimizeAggregationInOrder(*frame.node, nodes); + + optimizeJoin(*frame.node, nodes, optimization_settings.keep_logical_steps); } /// Traverse all children first.
diff --git a/src/Processors/QueryPlan/QueryPlan.cpp b/src/Processors/QueryPlan/QueryPlan.cpp index 98fd209c12a..3cdc9f8be3c 100644 --- a/src/Processors/QueryPlan/QueryPlan.cpp +++ b/src/Processors/QueryPlan/QueryPlan.cpp @@ -18,7 +18,7 @@ #include #include - +#include namespace DB { @@ -26,6 +26,7 @@ namespace DB namespace ErrorCodes { extern const int LOGICAL_ERROR; + extern const int NOT_FOUND_COLUMN_IN_BLOCK; } QueryPlan::QueryPlan() = default; @@ -460,16 +461,25 @@ void QueryPlan::explainPipeline(WriteBuffer & buffer, const ExplainPipelineOptio void QueryPlan::optimize(const QueryPlanOptimizationSettings & optimization_settings) { - /// optimization need to be applied before "mergeExpressions" optimization - /// it removes redundant sorting steps, but keep underlying expressions, - /// so "mergeExpressions" optimization handles them afterwards - if (optimization_settings.remove_redundant_sorting) - QueryPlanOptimizations::tryRemoveRedundantSorting(root); + try + { + /// optimization need to be applied before "mergeExpressions" optimization + /// it removes redundant sorting steps, but keep underlying expressions, + /// so "mergeExpressions" optimization handles them afterwards + if (optimization_settings.remove_redundant_sorting) + QueryPlanOptimizations::tryRemoveRedundantSorting(root); - QueryPlanOptimizations::optimizeTreeFirstPass(optimization_settings, *root, nodes); - QueryPlanOptimizations::optimizeTreeSecondPass(optimization_settings, *root, nodes); - if (optimization_settings.build_sets) - QueryPlanOptimizations::addStepsToBuildSets(*this, *root, nodes); + QueryPlanOptimizations::optimizeTreeFirstPass(optimization_settings, *root, nodes); + QueryPlanOptimizations::optimizeTreeSecondPass(optimization_settings, *root, nodes); + if (optimization_settings.build_sets) + QueryPlanOptimizations::addStepsToBuildSets(*this, *root, nodes); + } + catch (Exception & e) + { + if (e.code() == ErrorCodes::NOT_FOUND_COLUMN_IN_BLOCK) + e.addMessage("while optimizing 
query plan:\n{}", dumpQueryPlan(*this)); + throw; + } } void QueryPlan::explainEstimate(MutableColumns & columns) const diff --git a/src/Processors/QueryPlan/ReadFromMemoryStorageStep.h b/src/Processors/QueryPlan/ReadFromMemoryStorageStep.h index 238c1a3aad0..fc0325017ae 100644 --- a/src/Processors/QueryPlan/ReadFromMemoryStorageStep.h +++ b/src/Processors/QueryPlan/ReadFromMemoryStorageStep.h @@ -33,6 +33,8 @@ public: String getName() const override { return name; } + const StoragePtr & getStorage() const { return storage; } + void initializePipeline(QueryPipelineBuilder & pipeline, const BuildQueryPipelineSettings &) override; private: diff --git a/src/Processors/QueryPlan/ReadFromPreparedSource.cpp b/src/Processors/QueryPlan/ReadFromPreparedSource.cpp index 7f254b9bc51..82b8e9b746b 100644 --- a/src/Processors/QueryPlan/ReadFromPreparedSource.cpp +++ b/src/Processors/QueryPlan/ReadFromPreparedSource.cpp @@ -1,6 +1,7 @@ #include #include #include +#include namespace DB { @@ -21,14 +22,15 @@ void ReadFromPreparedSource::initializePipeline(QueryPipelineBuilder & pipeline, ReadFromStorageStep::ReadFromStorageStep( Pipe pipe_, - String storage_name, + StoragePtr storage_, ContextPtr context_, const SelectQueryInfo & query_info_) : ReadFromPreparedSource(std::move(pipe_)) + , storage(std::move(storage_)) , context(std::move(context_)) , query_info(query_info_) { - setStepDescription(storage_name); + setStepDescription(storage->getName()); for (const auto & processor : pipe.getProcessors()) processor->setStorageLimits(query_info.storage_limits); diff --git a/src/Processors/QueryPlan/ReadFromPreparedSource.h b/src/Processors/QueryPlan/ReadFromPreparedSource.h index b40a656cee3..32369afb2a3 100644 --- a/src/Processors/QueryPlan/ReadFromPreparedSource.h +++ b/src/Processors/QueryPlan/ReadFromPreparedSource.h @@ -21,14 +21,18 @@ protected: Pipe pipe; }; -class ReadFromStorageStep : public ReadFromPreparedSource +class ReadFromStorageStep final : public 
ReadFromPreparedSource { public: - ReadFromStorageStep(Pipe pipe_, String storage_name, ContextPtr context_, const SelectQueryInfo & query_info_); + ReadFromStorageStep(Pipe pipe_, StoragePtr storage_, ContextPtr context_, const SelectQueryInfo & query_info_); String getName() const override { return "ReadFromStorage"; } + const StoragePtr & getStorage() const { return storage; } + private: + StoragePtr storage; + ContextPtr context; SelectQueryInfo query_info; }; diff --git a/src/Storages/IStorage.cpp b/src/Storages/IStorage.cpp index 23f1811d330..982be5702ff 100644 --- a/src/Storages/IStorage.cpp +++ b/src/Storages/IStorage.cpp @@ -165,7 +165,7 @@ void IStorage::read( if (parallelize_output && parallelizeOutputAfterReading(context) && output_ports > 0 && output_ports < num_streams) pipe.resize(num_streams); - readFromPipe(query_plan, std::move(pipe), column_names, storage_snapshot, query_info, context, getName()); + readFromPipe(query_plan, std::move(pipe), column_names, storage_snapshot, query_info, context, shared_from_this()); } void IStorage::readFromPipe( @@ -175,7 +175,7 @@ void IStorage::readFromPipe( const StorageSnapshotPtr & storage_snapshot, SelectQueryInfo & query_info, ContextPtr context, - std::string storage_name) + std::shared_ptr storage_) { if (pipe.empty()) { @@ -184,7 +184,7 @@ void IStorage::readFromPipe( } else { - auto read_step = std::make_unique(std::move(pipe), storage_name, context, query_info); + auto read_step = std::make_unique(std::move(pipe), storage_, context, query_info); query_plan.addStep(std::move(read_step)); } } diff --git a/src/Storages/IStorage.h b/src/Storages/IStorage.h index 0dc48634282..60ea5c66a42 100644 --- a/src/Storages/IStorage.h +++ b/src/Storages/IStorage.h @@ -757,7 +757,7 @@ public: const StorageSnapshotPtr & storage_snapshot, SelectQueryInfo & query_info, ContextPtr context, - std::string storage_name); + std::shared_ptr storage_); private: /// Lock required for alter queries (lockForAlter). 
diff --git a/src/Storages/NATS/StorageNATS.cpp b/src/Storages/NATS/StorageNATS.cpp index 5a51f078e7b..7dc130ff9d5 100644 --- a/src/Storages/NATS/StorageNATS.cpp +++ b/src/Storages/NATS/StorageNATS.cpp @@ -398,7 +398,7 @@ void StorageNATS::read( } else { - auto read_step = std::make_unique(std::move(pipe), getName(), local_context, query_info); + auto read_step = std::make_unique(std::move(pipe), shared_from_this(), local_context, query_info); query_plan.addStep(std::move(read_step)); query_plan.addInterpreterContext(modified_context); } diff --git a/src/Storages/RabbitMQ/StorageRabbitMQ.cpp b/src/Storages/RabbitMQ/StorageRabbitMQ.cpp index 3e922b541f7..2fa1a39f76e 100644 --- a/src/Storages/RabbitMQ/StorageRabbitMQ.cpp +++ b/src/Storages/RabbitMQ/StorageRabbitMQ.cpp @@ -839,7 +839,7 @@ void StorageRabbitMQ::read( } else { - auto read_step = std::make_unique(std::move(pipe), getName(), local_context, query_info); + auto read_step = std::make_unique(std::move(pipe), shared_from_this(), local_context, query_info); query_plan.addStep(std::move(read_step)); query_plan.addInterpreterContext(modified_context); } diff --git a/src/Storages/StorageExecutable.cpp b/src/Storages/StorageExecutable.cpp index 013acb04f3e..b7cca09b81b 100644 --- a/src/Storages/StorageExecutable.cpp +++ b/src/Storages/StorageExecutable.cpp @@ -200,7 +200,7 @@ void StorageExecutable::read( } auto pipe = coordinator->createPipe(script_path, settings->script_arguments, std::move(inputs), std::move(sample_block), context, configuration); - IStorage::readFromPipe(query_plan, std::move(pipe), column_names, storage_snapshot, query_info, context, getName()); + IStorage::readFromPipe(query_plan, std::move(pipe), column_names, storage_snapshot, query_info, context, shared_from_this()); query_plan.addResources(std::move(resources)); }