#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int LOGICAL_ERROR; extern const int SET_SIZE_LIMIT_EXCEEDED; extern const int BAD_ARGUMENTS; } static UInt32 toPowerOfTwo(UInt32 x) { x = x - 1; for (UInt32 i = 1; i < sizeof(UInt32) * 8; i <<= 1) x = x | x >> i; return x + 1; } ConcurrentHashJoin::ConcurrentHashJoin(ContextPtr context_, std::shared_ptr table_join_, size_t slots_, const Block & right_sample_block, bool any_take_last_row_) : context(context_) , table_join(table_join_) , slots(slots_) { slots = std::min(std::max(1, slots), 255); slots = toPowerOfTwo(slots); for (size_t i = 0; i < slots; ++i) { auto inner_hash_join = std::make_shared(); inner_hash_join->data = std::make_unique(table_join_, right_sample_block, any_take_last_row_); hash_joins.emplace_back(std::move(inner_hash_join)); } } bool ConcurrentHashJoin::addJoinedBlock(const Block & right_block, bool check_limits) { Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_right, right_block); size_t blocks_left = 0; for (const auto & block : dispatched_blocks) { if (block) { ++blocks_left; } } while (blocks_left > 0) { /// insert blocks into corresponding HashJoin instances for (size_t i = 0; i < dispatched_blocks.size(); ++i) { auto & hash_join = hash_joins[i]; auto & dispatched_block = dispatched_blocks[i]; if (dispatched_block) { /// if current hash_join is already processed by another thread, skip it and try later std::unique_lock lock(hash_join->mutex, std::try_to_lock); if (!lock.owns_lock()) continue; bool limit_exceeded = !hash_join->data->addJoinedBlock(dispatched_block, check_limits); dispatched_block = {}; blocks_left--; if (limit_exceeded) return false; } } } if (check_limits) return table_join->sizeLimits().check(getTotalRowCount(), getTotalByteCount(), "JOIN", ErrorCodes::SET_SIZE_LIMIT_EXCEEDED); return true; } void ConcurrentHashJoin::joinBlock(Block & block, std::shared_ptr & /*not_processed*/) { Blocks dispatched_blocks = dispatchBlock(table_join->getOnlyClause().key_names_left, block); for (size_t i = 0; i < dispatched_blocks.size(); ++i) { std::shared_ptr none_extra_block; auto & hash_join = hash_joins[i]; auto & dispatched_block = dispatched_blocks[i]; hash_join->data->joinBlock(dispatched_block, none_extra_block); if (none_extra_block && !none_extra_block->empty()) throw Exception(ErrorCodes::LOGICAL_ERROR, "not_processed should be empty"); } block = concatenateBlocks(dispatched_blocks); } void ConcurrentHashJoin::checkTypesOfKeys(const Block & block) const { hash_joins[0]->data->checkTypesOfKeys(block); } void ConcurrentHashJoin::setTotals(const Block & block) { if (block) { std::lock_guard lock(totals_mutex); totals = block; } } const Block & ConcurrentHashJoin::getTotals() const { return totals; } size_t ConcurrentHashJoin::getTotalRowCount() const { size_t res = 0; for (const auto & hash_join : hash_joins) { std::lock_guard lock(hash_join->mutex); res += hash_join->data->getTotalRowCount(); } return res; } size_t ConcurrentHashJoin::getTotalByteCount() const { size_t res = 0; for (const auto & hash_join : hash_joins) { std::lock_guard lock(hash_join->mutex); res += hash_join->data->getTotalByteCount(); } return res; } bool ConcurrentHashJoin::alwaysReturnsEmptySet() const { for (const auto & hash_join : hash_joins) { std::lock_guard lock(hash_join->mutex); if (!hash_join->data->alwaysReturnsEmptySet()) return false; } return true; } std::shared_ptr ConcurrentHashJoin::getNonJoinedBlocks( const Block & /*left_sample_block*/, const Block & /*result_sample_block*/, UInt64 /*max_block_size*/) const { if (table_join->strictness() == ASTTableJoin::Strictness::Asof || table_join->strictness() == ASTTableJoin::Strictness::Semi || !isRightOrFull(table_join->kind())) { return {}; } throw Exception(ErrorCodes::LOGICAL_ERROR, "Invalid join type. join kind: {}, strictness: {}", table_join->kind(), table_join->strictness()); } static IColumn::Selector hashToSelector(const WeakHash32 & hash, size_t num_shards) { assert(num_shards > 0 && (num_shards & (num_shards - 1)) == 0); const auto & data = hash.getData(); size_t num_rows = data.size(); IColumn::Selector selector(num_rows); for (size_t i = 0; i < num_rows; ++i) selector[i] = data[i] & (num_shards - 1); return selector; } Blocks ConcurrentHashJoin::dispatchBlock(const Strings & key_columns_names, const Block & from_block) { size_t num_shards = hash_joins.size(); size_t num_rows = from_block.rows(); size_t num_cols = from_block.columns(); WeakHash32 hash(num_rows); for (const auto & key_name : key_columns_names) { const auto & key_col = from_block.getByName(key_name).column->convertToFullColumnIfConst(); const auto & key_col_no_lc = recursiveRemoveLowCardinality(recursiveRemoveSparse(key_col)); key_col_no_lc->updateWeakHash32(hash); } auto selector = hashToSelector(hash, num_shards); Blocks result; for (size_t i = 0; i < num_shards; ++i) { result.emplace_back(from_block.cloneEmpty()); } for (size_t i = 0; i < num_cols; ++i) { auto dispatched_columns = from_block.getByPosition(i).column->scatter(num_shards, selector); assert(result.size() == dispatched_columns.size()); for (size_t block_index = 0; block_index < num_shards; ++block_index) { result[block_index].getByPosition(i).column = std::move(dispatched_columns[block_index]); } } return result; } }