diff --git a/src/Interpreters/ConcurrentHashJoin.cpp b/src/Interpreters/ConcurrentHashJoin.cpp index f4e49d3230d..7dfe6a5549e 100644 --- a/src/Interpreters/ConcurrentHashJoin.cpp +++ b/src/Interpreters/ConcurrentHashJoin.cpp @@ -16,25 +16,28 @@ #include #include #include +#include #include +#include + namespace DB { + namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int SET_SIZE_LIMIT_EXCEEDED; extern const int BAD_ARGUMENTS; } -namespace JoinStuff -{ + ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_) : context(context_) , table_join(table_join_) , slots(slots_) { - if (!slots_ || slots_ >= 256) + if (slots < 1 || 256 < slots) { - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid argument slot : {}", slots_); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Number of slots should be [1, 255], got {}", slots); } for (size_t i = 0; i < slots; ++i) @@ -43,36 +46,45 @@ ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptrdata = std::make_unique(table_join_, right_sample_block, any_take_last_row_); hash_joins.emplace_back(std::move(inner_hash_join)); } - } -bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits) +bool ConcurrentHashJoin::addJoinedBlock(const Block & right_block, bool check_limits) { - Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_right, block); + Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_right, right_block); - std::list pending_blocks; - for (size_t i = 0; i < dispatched_blocks.size(); ++i) - pending_blocks.emplace_back(i); - while (!pending_blocks.empty()) + std::atomic blocks_left = 0; + for (const auto & block : dispatched_blocks) { - for (auto iter = pending_blocks.begin(); iter != pending_blocks.end();) + if (block) + { + ++blocks_left; + } + } + + while (blocks_left > 0) + { + /// insert blocks into corresponding HashJoin instances + for (size_t i = 0; i < dispatched_blocks.size(); ++i) { - auto & i = *iter; auto & hash_join = hash_joins[i]; auto & dispatched_block = dispatched_blocks[i]; - if (hash_join->mutex.try_lock()) { - if (!hash_join->data->addJoinedBlock(dispatched_block, check_limits)) - { - hash_join->mutex.unlock(); - return false; - } + /// if current hash_join is already processed by another thread, skip it and try later + std::unique_lock lock(hash_join->mutex, std::try_to_lock); + if (!lock.owns_lock()) + continue; - hash_join->mutex.unlock(); - iter = pending_blocks.erase(iter); + if (!dispatched_block) + continue; + + bool limit_exceeded = !hash_join->data->addJoinedBlock(dispatched_block, check_limits); + + dispatched_block = {}; + blocks_left--; + + if (limit_exceeded) + return false; } - else - iter++; } } @@ -161,30 +173,32 @@ std::shared_ptr ConcurrentHashJoin::getNonJoinedBlocks( throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness()); } +static IColumn::Selector hashToSelector(const WeakHash32 & hash, size_t num_shards) +{ + const auto & data = hash.getData(); + size_t num_rows = data.size(); + + IColumn::Selector selector(num_rows); + for (size_t i = 0; i < num_rows; ++i) + selector[i] = data[i] % num_shards; + return selector; +} + Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, const Block & from_block) { - Blocks result; - size_t num_shards = hash_joins.size(); size_t num_rows = from_block.rows(); size_t num_cols = from_block.columns(); - ColumnRawPtrs key_cols; + WeakHash32 hash(num_rows); for (const auto & key_name : key_columns_names) { - key_cols.push_back(from_block.getByName(key_name).column.get()); - } - IColumn::Selector selector(num_rows); - for (size_t i = 0; i < num_rows; ++i) - { - SipHash hash; - for (const auto & key_col : key_cols) - { - key_col->updateHashWithValue(i, hash); - } - selector[i] = hash.get64() % num_shards; + const auto & key_col = from_block.getByName(key_name).column; + key_col->updateWeakHash32(hash); } + auto selector = hashToSelector(hash, num_shards); + Blocks result; for (size_t i = 0; i < num_shards; ++i) { result.emplace_back(from_block.cloneEmpty()); @@ -203,4 +217,3 @@ Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, cons } } -} diff --git a/src/Interpreters/ConcurrentHashJoin.h b/src/Interpreters/ConcurrentHashJoin.h index 47fa2b2112f..fb226c39a0c 100644 --- a/src/Interpreters/ConcurrentHashJoin.h +++ b/src/Interpreters/ConcurrentHashJoin.h @@ -15,8 +15,7 @@ namespace DB { -namespace JoinStuff -{ + /** * Can run addJoinedBlock() parallelly to speedup the join process. On test, it almose linear speedup by * the degree of parallelism. @@ -33,6 +32,7 @@ namespace JoinStuff */ class ConcurrentHashJoin : public IJoin { + public: explicit ConcurrentHashJoin(ContextPtr context_, std::shared_ptr table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_ = false); ~ConcurrentHashJoin() override = default; @@ -49,6 +49,7 @@ public: bool supportParallelJoin() const override { return true; } std::shared_ptr getNonJoinedBlocks(const Block & left_sample_block, const Block & result_sample_block, UInt64 max_block_size) const override; + private: struct InternalHashJoin { @@ -71,5 +72,5 @@ private: Blocks dispatchBlock(const Strings & key_columns_names, const Block & from_block); }; -} + } diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 0a156ba0b3e..ee69e39782d 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -939,7 +939,7 @@ static std::shared_ptr chooseJoinAlgorithm(std::shared_ptr ana { if (analyzed_join->allowParallelHashJoin()) { - return std::make_shared(context, analyzed_join, context->getSettings().max_threads, sample_block); + return std::make_shared(context, analyzed_join, context->getSettings().max_threads, sample_block); } return std::make_shared(analyzed_join, sample_block); } diff --git a/src/QueryPipeline/QueryPipelineBuilder.cpp b/src/QueryPipeline/QueryPipelineBuilder.cpp index 5e074861110..012a825a9d5 100644 --- a/src/QueryPipeline/QueryPipelineBuilder.cpp +++ b/src/QueryPipeline/QueryPipelineBuilder.cpp @@ -347,7 +347,7 @@ std::unique_ptr QueryPipelineBuilder::joinPipelines( /// ╞> FillingJoin ─> Resize ╣ ╞> Joining ─> (totals) /// (totals) ─────────┘ ╙─────┘ - auto num_streams = left->getNumStreams(); + size_t num_streams = left->getNumStreams(); if (join->supportParallelJoin() && !right->hasTotals()) { diff --git a/tests/queries/1_stateful/00172_parallel_join.sql b/tests/queries/1_stateful/00172_parallel_join.sql index fce41d7a761..36b12a43b88 100644 --- a/tests/queries/1_stateful/00172_parallel_join.sql +++ b/tests/queries/1_stateful/00172_parallel_join.sql @@ -1,4 +1,5 @@ -set join_algorithm='parallel_hash'; +SET join_algorithm='parallel_hash'; + SELECT EventDate, hits,