Minor style changes for ConcurrentHashJoin

This commit is contained in:
vdimir 2022-05-06 15:17:46 +00:00
parent bd5fab97d9
commit d712985575
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
5 changed files with 59 additions and 44 deletions

View File

@ -16,25 +16,28 @@
#include <Parsers/IAST_fwd.h>
#include <Parsers/parseQuery.h>
#include <Common/Exception.h>
#include <Common/WeakHash.h>
#include <Common/typeid_cast.h>
#include <base/scope_guard.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int SET_SIZE_LIMIT_EXCEEDED;
extern const int BAD_ARGUMENTS;
}
namespace JoinStuff
{
ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_)
: context(context_)
, table_join(table_join_)
, slots(slots_)
{
if (!slots_ || slots_ >= 256)
if (slots < 1 || 256 < slots)
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Invalid argument slot : {}", slots_);
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Number of slots should be [1, 255], got {}", slots);
}
for (size_t i = 0; i < slots; ++i)
@ -43,36 +46,45 @@ ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<Tabl
inner_hash_join->data = std::make_unique<HashJoin>(table_join_, right_sample_block, any_take_last_row_);
hash_joins.emplace_back(std::move(inner_hash_join));
}
}
bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits)
bool ConcurrentHashJoin::addJoinedBlock(const Block & right_block, bool check_limits)
{
Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_right, block);
Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_right, right_block);
std::list<size_t> pending_blocks;
for (size_t i = 0; i < dispatched_blocks.size(); ++i)
pending_blocks.emplace_back(i);
while (!pending_blocks.empty())
std::atomic<size_t> blocks_left = 0;
for (const auto & block : dispatched_blocks)
{
for (auto iter = pending_blocks.begin(); iter != pending_blocks.end();)
if (block)
{
++blocks_left;
}
}
while (blocks_left > 0)
{
/// insert blocks into corresponding HashJoin instances
for (size_t i = 0; i < dispatched_blocks.size(); ++i)
{
auto & i = *iter;
auto & hash_join = hash_joins[i];
auto & dispatched_block = dispatched_blocks[i];
if (hash_join->mutex.try_lock())
{
if (!hash_join->data->addJoinedBlock(dispatched_block, check_limits))
{
hash_join->mutex.unlock();
return false;
}
/// if current hash_join is already processed by another thread, skip it and try later
std::unique_lock<std::mutex> lock(hash_join->mutex, std::try_to_lock);
if (!lock.owns_lock())
continue;
hash_join->mutex.unlock();
iter = pending_blocks.erase(iter);
if (!dispatched_block)
continue;
bool limit_exceeded = !hash_join->data->addJoinedBlock(dispatched_block, check_limits);
dispatched_block = {};
blocks_left--;
if (limit_exceeded)
return false;
}
else
iter++;
}
}
@ -161,30 +173,32 @@ std::shared_ptr<NotJoinedBlocks> ConcurrentHashJoin::getNonJoinedBlocks(
throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness());
}
static IColumn::Selector hashToSelector(const WeakHash32 & hash, size_t num_shards)
{
const auto & data = hash.getData();
size_t num_rows = data.size();
IColumn::Selector selector(num_rows);
for (size_t i = 0; i < num_rows; ++i)
selector[i] = data[i] % num_shards;
return selector;
}
Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, const Block & from_block)
{
Blocks result;
size_t num_shards = hash_joins.size();
size_t num_rows = from_block.rows();
size_t num_cols = from_block.columns();
ColumnRawPtrs key_cols;
WeakHash32 hash(num_rows);
for (const auto & key_name : key_columns_names)
{
key_cols.push_back(from_block.getByName(key_name).column.get());
}
IColumn::Selector selector(num_rows);
for (size_t i = 0; i < num_rows; ++i)
{
SipHash hash;
for (const auto & key_col : key_cols)
{
key_col->updateHashWithValue(i, hash);
}
selector[i] = hash.get64() % num_shards;
const auto & key_col = from_block.getByName(key_name).column;
key_col->updateWeakHash32(hash);
}
auto selector = hashToSelector(hash, num_shards);
Blocks result;
for (size_t i = 0; i < num_shards; ++i)
{
result.emplace_back(from_block.cloneEmpty());
@ -203,4 +217,3 @@ Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, cons
}
}
}

View File

@ -15,8 +15,7 @@
namespace DB
{
namespace JoinStuff
{
/**
* Can run addJoinedBlock() parallelly to speedup the join process. On test, it almose linear speedup by
* the degree of parallelism.
@ -33,6 +32,7 @@ namespace JoinStuff
*/
class ConcurrentHashJoin : public IJoin
{
public:
explicit ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_ = false);
~ConcurrentHashJoin() override = default;
@ -49,6 +49,7 @@ public:
bool supportParallelJoin() const override { return true; }
std::shared_ptr<NotJoinedBlocks>
getNonJoinedBlocks(const Block & left_sample_block, const Block & result_sample_block, UInt64 max_block_size) const override;
private:
struct InternalHashJoin
{
@ -71,5 +72,5 @@ private:
Blocks dispatchBlock(const Strings & key_columns_names, const Block & from_block);
};
}
}

View File

@ -939,7 +939,7 @@ static std::shared_ptr<IJoin> chooseJoinAlgorithm(std::shared_ptr<TableJoin> ana
{
if (analyzed_join->allowParallelHashJoin())
{
return std::make_shared<JoinStuff::ConcurrentHashJoin>(context, analyzed_join, context->getSettings().max_threads, sample_block);
return std::make_shared<ConcurrentHashJoin>(context, analyzed_join, context->getSettings().max_threads, sample_block);
}
return std::make_shared<HashJoin>(analyzed_join, sample_block);
}

View File

@ -347,7 +347,7 @@ std::unique_ptr<QueryPipelineBuilder> QueryPipelineBuilder::joinPipelines(
/// ╞> FillingJoin ─> Resize ╣ ╞> Joining ─> (totals)
/// (totals) ─────────┘ ╙─────┘
auto num_streams = left->getNumStreams();
size_t num_streams = left->getNumStreams();
if (join->supportParallelJoin() && !right->hasTotals())
{

View File

@ -1,4 +1,5 @@
set join_algorithm='parallel_hash';
SET join_algorithm='parallel_hash';
SELECT
EventDate,
hits,