update codes

This commit is contained in:
lgbo-ustc 2022-04-21 12:14:36 +08:00
parent f461d18de5
commit bfd1a0e33a
4 changed files with 43 additions and 62 deletions

View File

@ -19,6 +19,7 @@
#include <Parsers/IAST_fwd.h>
#include <Parsers/parseQuery.h>
#include <Common/Exception.h>
#include "base/logger_useful.h"
namespace DB
{
namespace ErrorCodes
@ -29,7 +30,7 @@ namespace ErrorCodes
}
namespace JoinStuff
{
ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_)
ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & left_sample_block, const Block & right_sample_block, bool any_take_last_row_)
: context(context_)
, table_join(table_join_)
, slots(slots_)
@ -41,17 +42,23 @@ ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<Tabl
for (size_t i = 0; i < slots; ++i)
{
auto inner_hash_join = std::make_shared<InnerHashJoin>();
auto inner_hash_join = std::make_shared<InternalHashJoin>();
inner_hash_join->data = std::make_unique<HashJoin>(table_join_, right_sample_block, any_take_last_row_);
hash_joins.emplace_back(std::move(inner_hash_join));
}
dispatch_datas.emplace_back(std::make_shared<BlockDispatchControlData>());
dispatch_datas.emplace_back(std::make_shared<BlockDispatchControlData>());
dispatch_datas = {std::make_shared<BlockDispatchControlData>(), std::make_shared<BlockDispatchControlData>()};
const auto & onexpr = table_join->getClauses()[0];
auto & left_dispatch_data = *dispatch_datas[0];
std::tie(left_dispatch_data.hash_expression_actions, left_dispatch_data.hash_columns_names) = buildHashExpressionAction(left_sample_block, onexpr.key_names_left);
auto & right_dispatch_data = *dispatch_datas[1];
std::tie(right_dispatch_data.hash_expression_actions, right_dispatch_data.hash_columns_names) = buildHashExpressionAction(right_sample_block, onexpr.key_names_right);
}
bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits)
{
auto & dispatch_data = getBlockDispatchControlData(block, RIGHT);
auto & dispatch_data = *dispatch_datas[1];
std::vector<Block> dispatched_blocks;
Block cloned_block = block;
dispatchBlock(dispatch_data, cloned_block, dispatched_blocks);
@ -77,7 +84,6 @@ bool ConcurrentHashJoin::addJoinedBlock(const Block & block, bool check_limits)
void ConcurrentHashJoin::joinBlock(Block & block, std::shared_ptr<ExtraBlock> & not_processed)
{
if (block.rows())
waitAllAddJoinedBlocksFinished();
else
@ -87,7 +93,7 @@ void ConcurrentHashJoin::joinBlock(Block & block, std::shared_ptr<ExtraBlock> &
return;
}
auto & dispatch_data = getBlockDispatchControlData(block, LEFT);
auto & dispatch_data = *dispatch_datas[0];
std::vector<Block> dispatched_blocks;
Block cloned_block = block;
dispatchBlock(dispatch_data, cloned_block, dispatched_blocks);
@ -191,8 +197,9 @@ std::shared_ptr<NotJoinedBlocks> ConcurrentHashJoin::getNonJoinedBlocks(
throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness());
}
std::shared_ptr<ExpressionActions> ConcurrentHashJoin::buildHashExpressionAction(const Block & block, const Strings & based_columns_names, Strings & hash_columns_names)
std::pair<std::shared_ptr<ExpressionActions>, Strings> ConcurrentHashJoin::buildHashExpressionAction(const Block & block, const Strings & based_columns_names)
{
Strings hash_columns_names;
WriteBufferFromOwnString col_buf;
for (size_t i = 0, sz = based_columns_names.size(); i < sz; ++i)
{
@ -229,37 +236,13 @@ std::shared_ptr<ExpressionActions> ConcurrentHashJoin::buildHashExpressionAction
true, false, true, false);
ActionsVisitor(visitor_data, visit_log.stream()).visit(func_ast);
actions = visitor_data.getActions();
return std::make_shared<ExpressionActions>(actions);
}
ConcurrentHashJoin::BlockDispatchControlData & ConcurrentHashJoin::getBlockDispatchControlData(const Block & block, TableIndex table_index)
{
auto & data = *dispatch_datas[table_index];
if (data.has_init)[[likely]]
return data;
std::lock_guard lock(data.mutex);
if (data.has_init)
return data;
if (table_join->getClauses().empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "empty join clauses");
const auto & onexpr = table_join->getClauses()[0];
if (table_index == LEFT)
{
data.hash_expression_actions = buildHashExpressionAction(block, onexpr.key_names_left, data.hash_columns_names);
}
else
{
data.hash_expression_actions = buildHashExpressionAction(block, onexpr.key_names_right, data.hash_columns_names);
}
data.header = block.cloneEmpty();
data.has_init = true;
return data;
return {std::make_shared<ExpressionActions>(actions), hash_columns_names};
}
void ConcurrentHashJoin::dispatchBlock(BlockDispatchControlData & dispatch_data, Block & from_block, std::vector<Block> & dispatched_blocks)
{
auto rows_before_filtration = from_block.rows();
auto header = from_block.cloneEmpty();
dispatch_data.hash_expression_actions->execute(from_block, rows_before_filtration);
for (const auto & filter_column_name : dispatch_data.hash_columns_names)
{
@ -267,7 +250,7 @@ void ConcurrentHashJoin::dispatchBlock(BlockDispatchControlData & dispatch_data,
auto filter_desc = std::make_unique<FilterDescription>(*full_column);
auto num_filtered_rows = filter_desc->countBytesInFilter();
ColumnsWithTypeAndName filtered_block_columns;
for (size_t i = 0; i < dispatch_data.header.columns(); ++i)
for (size_t i = 0; i < header.columns(); ++i)
{
auto & from_column = from_block.getByPosition(i);
auto filtered_column = filter_desc->filter(*from_column.column, num_filtered_rows);
@ -281,7 +264,7 @@ void ConcurrentHashJoin::waitAllAddJoinedBlocksFinished()
{
while (finished_add_joined_blocks_tasks < hash_joins.size())[[unlikely]]
{
std::shared_ptr<InnerHashJoin> hash_join;
std::shared_ptr<InternalHashJoin> hash_join;
{
std::unique_lock lock(finished_add_joined_blocks_tasks_mutex);
hash_join = getUnfinishedAddJoinedBlockTasks();
@ -306,7 +289,7 @@ void ConcurrentHashJoin::waitAllAddJoinedBlocksFinished()
}
}
std::shared_ptr<ConcurrentHashJoin::InnerHashJoin> ConcurrentHashJoin::getUnfinishedAddJoinedBlockTasks()
std::shared_ptr<ConcurrentHashJoin::InternalHashJoin> ConcurrentHashJoin::getUnfinishedAddJoinedBlockTasks()
{
for (auto & hash_join : hash_joins)
{

View File

@ -34,7 +34,7 @@ namespace JoinStuff
class ConcurrentHashJoin : public IJoin
{
public:
explicit ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_ = false);
explicit ConcurrentHashJoin(ContextPtr context_, std::shared_ptr<TableJoin> table_join_, size_t slots_, const Block & left_sample_block, const Block & right_sample_block, bool any_take_last_row_ = false);
~ConcurrentHashJoin() override = default;
const TableJoin & getTableJoin() const override { return *table_join; }
@ -50,7 +50,7 @@ public:
std::shared_ptr<NotJoinedBlocks>
getNonJoinedBlocks(const Block & left_sample_block, const Block & result_sample_block, UInt64 max_block_size) const override;
private:
struct InnerHashJoin
struct InternalHashJoin
{
std::mutex mutex;
std::unique_ptr<HashJoin> data;
@ -60,7 +60,7 @@ private:
ContextPtr context;
std::shared_ptr<TableJoin> table_join;
size_t slots;
std::vector<std::shared_ptr<InnerHashJoin>> hash_joins;
std::vector<std::shared_ptr<InternalHashJoin>> hash_joins;
std::atomic<size_t> check_total_rows;
std::atomic<size_t> check_total_bytes;
@ -71,19 +71,13 @@ private:
mutable std::mutex totals_mutex;
Block totals;
/// Identifies which side of the join a block comes from.
/// The enumerator values are used directly as indices into `dispatch_datas`,
/// so LEFT must stay 0 and RIGHT must stay 1.
enum TableIndex
{
    LEFT,  ///< 0: left side of the join
    RIGHT  ///< 1: right side of the join
};
/// Per-side (left / right) state consulted by dispatchBlock() when a block is split
/// between the internal hash joins.
struct BlockDispatchControlData
{
/// Guard and flag used by the lazy initialisation in getBlockDispatchControlData().
// NOTE(review): these two members appear both live and commented out below; once the
// lazy-init path is gone only one form should remain — confirm which and drop the other.
std::mutex mutex;
std::atomic<bool> has_init = false;
//std::mutex mutex;
//std::atomic<bool> has_init = false;
/// Expression actions that dispatchBlock() executes on the incoming block before filtering it.
std::shared_ptr<ExpressionActions> hash_expression_actions;
/// Column names that dispatchBlock() iterates over as filter columns.
Strings hash_columns_names;
/// Empty clone of the first block seen for this side (set via block.cloneEmpty()).
// NOTE(review): same duplication as above — live declaration plus a commented-out copy.
Block header;
//Block header;
BlockDispatchControlData() = default;
};
@ -91,13 +85,12 @@ private:
Poco::Logger * logger = &Poco::Logger::get("ConcurrentHashJoin");
std::shared_ptr<ExpressionActions> buildHashExpressionAction(const Block & block, const Strings & based_columns_names, Strings & hash_columns_names);
BlockDispatchControlData & getBlockDispatchControlData(const Block & block, TableIndex table_index);
std::pair<std::shared_ptr<ExpressionActions>, Strings> buildHashExpressionAction(const Block & block, const Strings & based_columns_names);
static void dispatchBlock(BlockDispatchControlData & dispatch_data, Block & from_block, std::vector<Block> & dispatched_blocks);
void waitAllAddJoinedBlocksFinished();
std::shared_ptr<InnerHashJoin> getUnfinishedAddJoinedBlockTasks();
std::shared_ptr<InternalHashJoin> getUnfinishedAddJoinedBlockTasks();
};
}

View File

@ -60,6 +60,9 @@
#include <Processors/QueryPlan/QueryPlan.h>
#include <Parsers/formatAST.h>
#include <Poco/Logger.h>
#include <base/logger_useful.h>
namespace DB
{
@ -927,24 +930,24 @@ static ActionsDAGPtr createJoinedBlockActions(ContextPtr context, const TableJoi
return ExpressionAnalyzer(expression_list, syntax_result, context).getActionsDAG(true, false);
}
static std::shared_ptr<IJoin> chooseJoinAlgorithm(std::shared_ptr<TableJoin> analyzed_join, const Block & sample_block, ContextPtr context)
static std::shared_ptr<IJoin> chooseJoinAlgorithm(std::shared_ptr<TableJoin> analyzed_join, const Block left_sample_block, const Block & right_sample_block, ContextPtr context)
{
/// HashJoin with Dictionary optimisation
if (analyzed_join->tryInitDictJoin(sample_block, context))
return std::make_shared<HashJoin>(analyzed_join, sample_block);
if (analyzed_join->tryInitDictJoin(right_sample_block, context))
return std::make_shared<HashJoin>(analyzed_join, right_sample_block);
bool allow_merge_join = analyzed_join->allowMergeJoin();
if (analyzed_join->forceHashJoin() || (analyzed_join->preferMergeJoin() && !allow_merge_join))
{
if (analyzed_join->allowParallelHashJoin())
{
return std::make_shared<JoinStuff::ConcurrentHashJoin>(context, analyzed_join, context->getSettings().max_threads, sample_block);
return std::make_shared<JoinStuff::ConcurrentHashJoin>(context, analyzed_join, context->getSettings().max_threads, left_sample_block, right_sample_block);
}
return std::make_shared<HashJoin>(analyzed_join, sample_block);
return std::make_shared<HashJoin>(analyzed_join, right_sample_block);
}
else if (analyzed_join->forceMergeJoin() || (analyzed_join->preferMergeJoin() && allow_merge_join))
return std::make_shared<MergeJoin>(analyzed_join, sample_block);
return std::make_shared<JoinSwitcher>(analyzed_join, sample_block);
return std::make_shared<MergeJoin>(analyzed_join, right_sample_block);
return std::make_shared<JoinSwitcher>(analyzed_join, right_sample_block);
}
static std::unique_ptr<QueryPlan> buildJoinedPlan(
@ -1032,7 +1035,8 @@ JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin(
joined_plan->addStep(std::move(converting_step));
}
JoinPtr join = chooseJoinAlgorithm(analyzed_join, joined_plan->getCurrentDataStream().header, getContext());
Block left_sample_block(left_columns);
JoinPtr join = chooseJoinAlgorithm(analyzed_join, left_sample_block, joined_plan->getCurrentDataStream().header, getContext());
/// Do not make subquery for join over dictionary.
if (analyzed_join->getDictionaryReader())

View File

@ -754,7 +754,8 @@ bool TableJoin::allowParallelHashJoin() const
return false;
if (table_join.kind != ASTTableJoin::Kind::Left && table_join.kind != ASTTableJoin::Kind::Inner)
return false;
if (isSpecialStorage() || !oneDisjunct())
return false;
return true;
}