refactor: Extract block sharding logic from ConcurrentHashJoin

This commit is contained in:
Sergey Skvortsov 2022-06-11 15:47:46 +03:00
parent a211b7f360
commit 56861d1f05
3 changed files with 73 additions and 48 deletions

View File

@ -171,47 +171,10 @@ std::shared_ptr<NotJoinedBlocks> ConcurrentHashJoin::getNonJoinedBlocks(
throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness()); throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness());
} }
static IColumn::Selector hashToSelector(const WeakHash32 & hash, size_t num_shards)
{
const auto & data = hash.getData();
size_t num_rows = data.size();
IColumn::Selector selector(num_rows);
for (size_t i = 0; i < num_rows; ++i)
selector[i] = data[i] % num_shards;
return selector;
}
Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, const Block & from_block) Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, const Block & from_block)
{ {
size_t num_shards = hash_joins.size(); size_t num_shards = hash_joins.size();
size_t num_rows = from_block.rows(); return JoinCommon::scatterBlockByHash(key_columns_names, from_block, num_shards);
size_t num_cols = from_block.columns();
WeakHash32 hash(num_rows);
for (const auto & key_name : key_columns_names)
{
const auto & key_col = from_block.getByName(key_name).column;
key_col->updateWeakHash32(hash);
}
auto selector = hashToSelector(hash, num_shards);
Blocks result;
for (size_t i = 0; i < num_shards; ++i)
{
result.emplace_back(from_block.cloneEmpty());
}
for (size_t i = 0; i < num_cols; ++i)
{
auto dispatched_columns = from_block.getByPosition(i).column->scatter(num_shards, selector);
assert(result.size() == dispatched_columns.size());
for (size_t block_index = 0; block_index < num_shards; ++block_index)
{
result[block_index].getByPosition(i).column = std::move(dispatched_columns[block_index]);
}
}
return result;
} }
} }

View File

@ -4,8 +4,9 @@
namespace DB namespace DB
{ {
GraceHashJoin::GraceHashJoin(std::shared_ptr<TableJoin> table_join, const Block & right_sample_block, bool any_take_last_row) GraceHashJoin::GraceHashJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block, bool any_take_last_row)
: first_bucket{std::make_shared<HashJoin>(table_join, right_sample_block, any_take_last_row)} : table_join{std::move(table_join_)}
, first_bucket{std::make_shared<HashJoin>(table_join, right_sample_block, any_take_last_row)}
{ {
} }

View File

@ -14,6 +14,10 @@
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Common/WeakHash.h>
#include <base/FnTraits.h>
namespace DB namespace DB
{ {
@ -573,17 +577,74 @@ void splitAdditionalColumns(const Names & key_names, const Block & sample_block,
} }
} }
template <std::integral T> template <Fn<size_t(size_t)> Sharder>
static bool isPowerOf2(T number) { static IColumn::Selector hashToSelector(const WeakHash32 & hash, Sharder sharder)
// TODO(sskvor) {
return false; const auto & hashes = hash.getData();
size_t num_rows = hashes.size();
IColumn::Selector selector(num_rows);
for (size_t i = 0; i < num_rows; ++i)
selector[i] = sharder(hashes[i]);
return selector;
} }
Blocks scatterBlockByHash(const Strings& key_columns_names, const Block& block, size_t num_shards) { template <Fn<size_t(size_t)> Sharder>
if (likely(isPowerOf2(num_shards))) { static Blocks scatterBlockByHashImpl(const Strings & key_columns_names, const Block & block, size_t num_shards, Sharder sharder) {
return scatterBlockByHashLog2(key_columns_names, block, num_shards_log2); size_t num_rows = block.rows();
size_t num_cols = block.columns();
WeakHash32 hash(num_rows);
for (const auto & key_name : key_columns_names)
{
const auto & key_col = block.getByName(key_name).column;
key_col->updateWeakHash32(hash);
} }
return scatterBlockByHashImpl(key_columns_names, ) auto selector = hashToSelector(hash, sharder);
Blocks result;
result.reserve(num_shards);
for (size_t i = 0; i < num_shards; ++i)
{
result.emplace_back(block.cloneEmpty());
}
for (size_t i = 0; i < num_cols; ++i)
{
auto dispatched_columns = block.getByPosition(i).column->scatter(num_shards, selector);
assert(result.size() == dispatched_columns.size());
for (size_t block_index = 0; block_index < num_shards; ++block_index)
{
result[block_index].getByPosition(i).column = std::move(dispatched_columns[block_index]);
}
}
return result;
}
template <std::integral T>
static bool isPowerOf2(T number) {
return number == T{1} << bitScanReverse(number);
}
static Blocks scatterBlockByHashPow2(const Strings & key_columns_names, const Block & block, size_t num_shards) {
UInt32 log2 = bitScanReverse(num_shards);
UInt32 mask = maskLowBits<size_t>(log2);
return scatterBlockByHashImpl(key_columns_names, block, num_shards, [mask](size_t hash) {
return hash & mask;
});
}
static Blocks scatterBlockByHashGeneric(const Strings & key_columns_names, const Block & block, size_t num_shards) {
return scatterBlockByHashImpl(key_columns_names, block, num_shards, [num_shards](size_t hash) {
return hash % num_shards;
});
}
Blocks scatterBlockByHash(const Strings & key_columns_names, const Block & block, size_t num_shards) {
if (likely(isPowerOf2(num_shards)))
return scatterBlockByHashPow2(key_columns_names, block, num_shards);
return scatterBlockByHashGeneric(key_columns_names, block, num_shards);
} }
} }