diff --git a/src/Interpreters/ConcurrentHashJoin.cpp b/src/Interpreters/ConcurrentHashJoin.cpp index a8187b457cf..ab4ea242882 100644 --- a/src/Interpreters/ConcurrentHashJoin.cpp +++ b/src/Interpreters/ConcurrentHashJoin.cpp @@ -14,12 +14,14 @@ #include #include #include +#include #include #include #include #include +#include #include -#include "base/logger_useful.h" +#include namespace DB { namespace ErrorCodes @@ -35,7 +37,7 @@ ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr= 256) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid argument slot : {}", slots_); } @@ -50,10 +52,10 @@ ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr(), std::make_shared()}; const auto & onexpr = table_join->getClauses()[0]; auto & left_dispatch_data = *dispatch_datas[0]; - std::tie(left_dispatch_data.hash_expression_actions, left_dispatch_data.hash_columns_names) = buildHashExpressionAction(left_sample_block, onexpr.key_names_left); + std::tie(left_dispatch_data.hash_expression_actions, left_dispatch_data.hash_column_name) = buildHashExpressionAction(left_sample_block, onexpr.key_names_left); auto & right_dispatch_data = *dispatch_datas[1]; - std::tie(right_dispatch_data.hash_expression_actions, right_dispatch_data.hash_columns_names) = buildHashExpressionAction(right_sample_block, onexpr.key_names_right); + std::tie(right_dispatch_data.hash_expression_actions, right_dispatch_data.hash_column_name) = buildHashExpressionAction(right_sample_block, onexpr.key_names_right); } bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits) @@ -80,7 +82,7 @@ bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits) hash_join->mutex.unlock(); iter = pending_blocks.erase(iter); } - else + else { iter++; } @@ -172,7 +174,7 @@ std::shared_ptr ConcurrentHashJoin::getNonJoinedBlocks( throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness()); } -std::pair, Strings> ConcurrentHashJoin::buildHashExpressionAction(const Block & block, const Strings & based_columns_names) +std::pair, String> ConcurrentHashJoin::buildHashExpressionAction(const Block & block, const Strings & based_columns_names) { Strings hash_columns_names; WriteBufferFromOwnString col_buf; @@ -183,17 +185,12 @@ std::pair, Strings> ConcurrentHashJoin::build col_buf << based_columns_names[i]; } WriteBufferFromOwnString write_buf; - for (size_t i = 0; i < slots; ++i) - { - if (i) - write_buf << ","; - write_buf << "cityHash64(" << col_buf.str() << ")%" << slots << "=" << i; - } + write_buf << "cityHash64(" << col_buf.str() << ") % " << slots; + auto settings = context->getSettings(); ParserExpressionList hash_expr_parser(true); ASTPtr func_ast = parseQuery(hash_expr_parser, write_buf.str(), "Parse Block hash expression", settings.max_query_size, settings.max_parser_depth); - for (auto & child : func_ast->children) - hash_columns_names.emplace_back(child->getColumnName()); + auto hash_column_name = func_ast->children[0]->getColumnName(); DebugASTLog visit_log; const auto & names_and_types = block.getNamesAndTypesList(); @@ -211,28 +208,56 @@ std::pair, Strings> ConcurrentHashJoin::build true, false, true, false); ActionsVisitor(visitor_data, visit_log.stream()).visit(func_ast); actions = visitor_data.getActions(); - return {std::make_shared(actions), hash_columns_names}; + return {std::make_shared(actions), hash_column_name}; } void ConcurrentHashJoin::dispatchBlock(BlockDispatchControlData & dispatch_data, Block & from_block, Blocks & dispatched_blocks) { - auto rows_before_filtration = from_block.rows(); auto header = from_block.cloneEmpty(); - dispatch_data.hash_expression_actions->execute(from_block, rows_before_filtration); - for (const auto & filter_column_name : dispatch_data.hash_columns_names) + auto num_shards = hash_joins.size(); + Block block_for_build_selector = from_block; + dispatch_data.hash_expression_actions->execute(block_for_build_selector); + auto selector_column = block_for_build_selector.getByName(dispatch_data.hash_column_name); + std::vector selector_slots; + for (UInt64 i = 0; i < num_shards; ++i) { - auto full_column = from_block.findByName(filter_column_name)->column->convertToFullColumnIfConst(); - auto filter_desc = std::make_unique(*full_column); - auto num_filtered_rows = filter_desc->countBytesInFilter(); - ColumnsWithTypeAndName filtered_block_columns; - for (size_t i = 0; i < header.columns(); ++i) - { - auto & from_column = from_block.getByPosition(i); - auto filtered_column = filter_desc->filter(*from_column.column, num_filtered_rows); - filtered_block_columns.emplace_back(filtered_column, from_column.type, from_column.name); - } - dispatched_blocks.emplace_back(std::move(filtered_block_columns)); + selector_slots.emplace_back(i); + dispatched_blocks.emplace_back(from_block.cloneEmpty()); } + if (selector_column.column->isNullable()) + { + const auto * nullable_col = typeid_cast(selector_column.column.get()); + const auto & nested_col = nullable_col->getNestedColumnPtr(); + size_t last_offset = 0; + MutableColumnPtr dst = nullable_col->cloneEmpty(); + for (size_t i = 0, sz = selector_column.column->size(); i < sz; ++i) + { + if (selector_column.column->isNullAt(i))[[unlikely]] + { + if (i > last_offset)[[likely]] + dst->insertRangeFrom(*nested_col, last_offset, i - last_offset); + dst->insertDefault(); + last_offset = i + 1; + } + } + if (last_offset < selector_column.column->size()) + { + dst->insertRangeFrom(*nested_col, last_offset, selector_column.column->size() - last_offset); + } + selector_column.column = std::move(dst); + } + auto selector = createBlockSelector(*selector_column.column, selector_slots); + + auto columns_in_block = header.columns(); + for (size_t i = 0; i < columns_in_block; ++i) + { + auto dispatched_columns = from_block.getByPosition(i).column->scatter(num_shards, selector); + for (size_t block_index = 0; block_index < num_shards; ++block_index) + { + dispatched_blocks[block_index].getByPosition(i).column = std::move(dispatched_columns[block_index]); + } + } + } } diff --git a/src/Interpreters/ConcurrentHashJoin.h b/src/Interpreters/ConcurrentHashJoin.h index 066fe7fefdb..e1f73d38a75 100644 --- a/src/Interpreters/ConcurrentHashJoin.h +++ b/src/Interpreters/ConcurrentHashJoin.h @@ -71,7 +71,7 @@ private: struct BlockDispatchControlData { std::shared_ptr hash_expression_actions; - Strings hash_columns_names; + String hash_column_name; BlockDispatchControlData() = default; }; @@ -79,9 +79,9 @@ private: Poco::Logger * logger = &Poco::Logger::get("ConcurrentHashJoin"); - std::pair, Strings> buildHashExpressionAction(const Block & block, const Strings & based_columns_names); + std::pair, String> buildHashExpressionAction(const Block & block, const Strings & based_columns_names); - static void dispatchBlock(BlockDispatchControlData & dispatch_data, Block & from_block, Blocks & dispatched_blocks); + void dispatchBlock(BlockDispatchControlData & dispatch_data, Block & from_block, Blocks & dispatched_blocks); }; }