From 344bc05156c6336a98bfbacc613ad8675939587d Mon Sep 17 00:00:00 2001 From: lgbo Date: Mon, 10 Jul 2023 16:44:01 +0800 Subject: [PATCH] wip: reserve hash table size (#50875) --- src/Interpreters/GraceHashJoin.cpp | 34 +++++++++++++++++++----------- src/Interpreters/GraceHashJoin.h | 2 +- src/Interpreters/HashJoin.cpp | 14 +++++++++--- src/Interpreters/HashJoin.h | 30 ++++++++++++++++++++++++-- 4 files changed, 62 insertions(+), 18 deletions(-) diff --git a/src/Interpreters/GraceHashJoin.cpp b/src/Interpreters/GraceHashJoin.cpp index 5b9521f3774..edf604bc0b4 100644 --- a/src/Interpreters/GraceHashJoin.cpp +++ b/src/Interpreters/GraceHashJoin.cpp @@ -578,18 +578,12 @@ IBlocksStreamPtr GraceHashJoin::getDelayedBlocks() size_t bucket_idx = current_bucket->idx; - if (hash_join) + size_t prev_keys_num = 0; + if (hash_join && buckets.size() > 1) { - auto right_blocks = hash_join->releaseJoinedBlocks(/* restructure */ false); - for (auto & block : right_blocks) - { - Blocks blocks = JoinCommon::scatterBlockByHash(right_key_names, block, buckets.size()); - flushBlocksToBuckets(blocks, buckets, bucket_idx); - } + prev_keys_num = hash_join->getTotalRowCount(); } - hash_join = makeInMemoryJoin(); - for (bucket_idx = bucket_idx + 1; bucket_idx < buckets.size(); ++bucket_idx) { current_bucket = buckets[bucket_idx].get(); @@ -602,6 +596,7 @@ IBlocksStreamPtr GraceHashJoin::getDelayedBlocks() continue; } + hash_join = makeInMemoryJoin(prev_keys_num); auto right_reader = current_bucket->startJoining(); size_t num_rows = 0; /// count rows that were written and rehashed while (Block block = right_reader.read()) @@ -622,9 +617,10 @@ IBlocksStreamPtr GraceHashJoin::getDelayedBlocks() return nullptr; } -GraceHashJoin::InMemoryJoinPtr GraceHashJoin::makeInMemoryJoin() +GraceHashJoin::InMemoryJoinPtr GraceHashJoin::makeInMemoryJoin(size_t reserve_num) { - return std::make_unique(table_join, right_sample_block, any_take_last_row); + auto ret = std::make_unique(table_join, right_sample_block, any_take_last_row, reserve_num); + return std::move(ret); } Block GraceHashJoin::prepareRightBlock(const Block & block) @@ -652,6 +648,20 @@ void GraceHashJoin::addBlockToJoinImpl(Block block) if (!hash_join) hash_join = makeInMemoryJoin(); + // buckets size has been changed in other threads. Need to scatter current_block again. + // rehash could only happen under hash_join_mutex's scope. + auto current_buckets = getCurrentBuckets(); + if (buckets_snapshot.size() != current_buckets.size()) + { + LOG_TRACE(log, "mismatch buckets size. previous:{}, current:{}", buckets_snapshot.size(), getCurrentBuckets().size()); + Blocks blocks = JoinCommon::scatterBlockByHash(right_key_names, current_block, current_buckets.size()); + flushBlocksToBuckets(blocks, current_buckets, bucket_index); + current_block = std::move(blocks[bucket_index]); + if (!current_block.rows()) + return; + } + + auto prev_keys_num = hash_join->getTotalRowCount(); hash_join->addBlockToJoin(current_block, /* check_limits = */ false); if (!hasMemoryOverflow(hash_join)) @@ -680,7 +690,7 @@ void GraceHashJoin::addBlockToJoinImpl(Block block) current_block = concatenateBlocks(current_blocks); } - hash_join = makeInMemoryJoin(); + hash_join = makeInMemoryJoin(prev_keys_num); if (current_block.rows() > 0) hash_join->addBlockToJoin(current_block, /* check_limits = */ false); diff --git a/src/Interpreters/GraceHashJoin.h b/src/Interpreters/GraceHashJoin.h index 3736e9c5e1e..bce04ee6b04 100644 --- a/src/Interpreters/GraceHashJoin.h +++ b/src/Interpreters/GraceHashJoin.h @@ -91,7 +91,7 @@ public: private: void initBuckets(); /// Create empty join for in-memory processing. - InMemoryJoinPtr makeInMemoryJoin(); + InMemoryJoinPtr makeInMemoryJoin(size_t reserve_num = 0); /// Add right table block to the @join. Calls @rehash on overflow. void addBlockToJoinImpl(Block block); diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index 3a464de0f12..be08b7cbe1e 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -217,7 +217,7 @@ static void correctNullabilityInplace(ColumnWithTypeAndName & column, bool nulla JoinCommon::removeColumnNullability(column); } -HashJoin::HashJoin(std::shared_ptr table_join_, const Block & right_sample_block_, bool any_take_last_row_) +HashJoin::HashJoin(std::shared_ptr table_join_, const Block & right_sample_block_, bool any_take_last_row_, size_t reserve_num) : table_join(table_join_) , kind(table_join->kind()) , strictness(table_join->strictness()) @@ -302,7 +302,7 @@ HashJoin::HashJoin(std::shared_ptr table_join_, const Block & right_s } for (auto & maps : data->maps) - dataMapInit(maps); + dataMapInit(maps, reserve_num); } HashJoin::Type HashJoin::chooseMethod(JoinKind kind, const ColumnRawPtrs & key_columns, Sizes & key_sizes) @@ -454,13 +454,21 @@ struct KeyGetterForType using Type = typename KeyGetterForTypeImpl::Type; }; -void HashJoin::dataMapInit(MapsVariant & map) +void HashJoin::dataMapInit(MapsVariant & map, size_t reserve_num) { if (kind == JoinKind::Cross) return; joinDispatchInit(kind, strictness, map); joinDispatch(kind, strictness, map, [&](auto, auto, auto & map_) { map_.create(data->type); }); + + if (reserve_num) + { + joinDispatch(kind, strictness, map, [&](auto, auto, auto & map_) { map_.reserve(data->type, reserve_num); }); + } + + if (!data) + throw Exception(ErrorCodes::LOGICAL_ERROR, "HashJoin::dataMapInit called with empty data"); } bool HashJoin::empty() const diff --git a/src/Interpreters/HashJoin.h b/src/Interpreters/HashJoin.h index f30bbc3a46c..56dea98c1f1 100644 --- a/src/Interpreters/HashJoin.h +++ b/src/Interpreters/HashJoin.h @@ -146,7 +146,8 @@ public: class HashJoin : public IJoin { public: - HashJoin(std::shared_ptr table_join_, const Block & right_sample_block, bool any_take_last_row_ = false); + HashJoin( + std::shared_ptr table_join_, const Block & right_sample_block, bool any_take_last_row_ = false, size_t reserve_num = 0); ~HashJoin() override; @@ -217,6 +218,15 @@ public: M(keys256) \ M(hashed) + /// Only for maps using hash table. + #define APPLY_FOR_HASH_JOIN_VARIANTS(M) \ + M(key32) \ + M(key64) \ + M(key_string) \ + M(key_fixed_string) \ + M(keys128) \ + M(keys256) \ + M(hashed) /// Used for reading from StorageJoin and applying joinGet function #define APPLY_FOR_JOIN_VARIANTS_LIMITED(M) \ @@ -266,6 +276,22 @@ public: } } + void reserve(Type which, size_t num) + { + switch (which) + { + case Type::EMPTY: break; + case Type::CROSS: break; + case Type::key8: break; + case Type::key16: break; + + #define M(NAME) \ + case Type::NAME: NAME->reserve(num); break; + APPLY_FOR_HASH_JOIN_VARIANTS(M) + #undef M + } + } + size_t getTotalRowCount(Type which) const { switch (which) @@ -409,7 +435,7 @@ private: /// If set HashJoin instance is not available for modification (addBlockToJoin) TableLockHolder storage_join_lock = nullptr; - void dataMapInit(MapsVariant &); + void dataMapInit(MapsVariant &, size_t); void initRightBlockStructure(Block & saved_block_sample);