// Patch summary (commentary only: it precedes the first `diff --git` header and is not part of any hunk).
//
// Purpose: gather the join helpers used by both Join and MergeJoin into `namespace JoinCommon`
// (declared in IJoin.h, defined in IJoin.cpp).
//
// IJoin.cpp / IJoin.h
//   * Adds JoinCommon::convertColumnToNullable (moved from a file-static helper in Join.cpp) and
//     JoinCommon::convertColumnsToNullable(block, starting_pos = 0), which applies it to every
//     column from `starting_pos` up to block.columns().
//   * Places extractKeysForJoin and createMissedColumns inside JoinCommon.
//   * Adds JoinCommon::checkTypesOfKeys as a free function (previously the member
//     Join::checkTypesOfKeys); the right-side key names are now an explicit parameter.
//     Key types are compared after removing LowCardinality and Nullable wrappers; a mismatch
//     throws an Exception with ErrorCodes::TYPE_MISMATCH.
//
// Join.cpp / Join.h
//   * Call sites switch to the JoinCommon:: helpers; the member declaration and definition of
//     checkTypesOfKeys are removed.
//   * joinBlockImpl, `right_or_full` branch: the convertToFullColumnIfConst loop is now bounded by
//     block.columns() instead of existing_columns, and the use_nulls conversion runs once after
//     the loop via convertColumnsToNullable(block).
//   * joinBlock: key_names_left is read after the shared lock is taken.
//
// MergeJoin.cpp
//   * joinBlock: takes the shared lock first, then checks key types, then calls addRightColumns
//     (previously addRightColumns ran before the lock and there was no type check).
//
// NOTE(review): checkTypesOfKeys loops over key_names_left.size() and indexes key_names_right[i]
// without comparing the sizes of the two lists — confirm every caller guarantees equal lengths.
diff --git a/dbms/src/Interpreters/IJoin.cpp b/dbms/src/Interpreters/IJoin.cpp index ed5c9f1935e..46497a8ed30 100644 --- a/dbms/src/Interpreters/IJoin.cpp +++ b/dbms/src/Interpreters/IJoin.cpp @@ -1,11 +1,37 @@ #include #include -#include +#include #include +#include namespace DB { +namespace ErrorCodes +{ + extern const int TYPE_MISMATCH; +} + + +namespace JoinCommon +{ + +void convertColumnToNullable(ColumnWithTypeAndName & column) +{ + if (column.type->isNullable() || !column.type->canBeInsideNullable()) + return; + + column.type = makeNullable(column.type); + if (column.column) + column.column = makeNullable(column.column); +} + +void convertColumnsToNullable(Block & block, size_t starting_pos) +{ + for (size_t i = starting_pos; i < block.columns(); ++i) + convertColumnToNullable(block.getByPosition(i)); +} + ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block, Block & sample_block_with_keys, Block & sample_block_with_columns_to_add) { @@ -43,6 +69,23 @@ ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & ri return key_columns; } +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right) +{ + size_t keys_size = key_names_left.size(); + + for (size_t i = 0; i < keys_size; ++i) + { + DataTypePtr left_type = removeNullable(recursiveRemoveLowCardinality(block_left.getByName(key_names_left[i]).type)); + DataTypePtr right_type = removeNullable(recursiveRemoveLowCardinality(block_right.getByName(key_names_right[i]).type)); + + if (!left_type->equals(*right_type)) + throw Exception("Type mismatch of columns to JOIN by: " + + key_names_left[i] + " " + left_type->getName() + " at left, " + + key_names_right[i] + " " + right_type->getName() + " at right", + ErrorCodes::TYPE_MISMATCH); + } +} + void createMissedColumns(Block & block) { for (size_t i = 0; i < block.columns(); ++i) @@ -54,3 +97,4 @@ void 
createMissedColumns(Block & block) } } +} diff --git a/dbms/src/Interpreters/IJoin.h b/dbms/src/Interpreters/IJoin.h index d7c6d28d551..42eada1c43e 100644 --- a/dbms/src/Interpreters/IJoin.h +++ b/dbms/src/Interpreters/IJoin.h @@ -8,6 +8,7 @@ namespace DB { +struct ColumnWithTypeAndName; class Block; class IColumn; using ColumnRawPtrs = std::vector; @@ -34,10 +35,22 @@ public: using JoinPtr = std::shared_ptr; -/// Common join functions +namespace JoinCommon +{ + +void convertColumnToNullable(ColumnWithTypeAndName & column); +void convertColumnsToNullable(Block & block, size_t starting_pos = 0); + +/// Split key and other columns by keys name list ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block, Block & sample_block_with_keys, Block & sample_block_with_columns_to_add); + +/// Throw an exception if blocks have different types of key columns. Compare up to Nullability. +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right); + void createMissedColumns(Block & block); } + +} diff --git a/dbms/src/Interpreters/Join.cpp b/dbms/src/Interpreters/Join.cpp index a0607837e12..6b3d9351740 100644 --- a/dbms/src/Interpreters/Join.cpp +++ b/dbms/src/Interpreters/Join.cpp @@ -36,21 +36,11 @@ namespace ErrorCodes } -static void convertColumnToNullable(ColumnWithTypeAndName & column) -{ - if (column.type->isNullable() || !column.type->canBeInsideNullable()) - return; - - column.type = makeNullable(column.type); - if (column.column) - column.column = makeNullable(column.column); -} - /// Converts column to nullable if needed. No backward convertion. 
static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column, bool nullable) { if (nullable) - convertColumnToNullable(column); + JoinCommon::convertColumnToNullable(column); return std::move(column); } @@ -58,7 +48,7 @@ static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column, { if (nullable) { - convertColumnToNullable(column); + JoinCommon::convertColumnToNullable(column); if (column.type->isNullable() && negative_null_map.size()) { MutableColumnPtr mutable_column = (*std::move(column.column)).mutate(); @@ -264,7 +254,7 @@ void Join::setSampleBlock(const Block & block) if (!empty()) return; - ColumnRawPtrs key_columns = extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add); + ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add); if (strictness == ASTTableJoin::Strictness::Asof) { @@ -303,15 +293,11 @@ void Join::setSampleBlock(const Block & block) blocklist_sample = Block(block.getColumnsWithTypeAndName()); prepareBlockListStructure(blocklist_sample); - createMissedColumns(sample_block_with_columns_to_add); + JoinCommon::createMissedColumns(sample_block_with_columns_to_add); /// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable. if (use_nulls && isLeftOrFull(kind)) - { - size_t num_columns_to_add = sample_block_with_columns_to_add.columns(); - for (size_t i = 0; i < num_columns_to_add; ++i) - convertColumnToNullable(sample_block_with_columns_to_add.getByPosition(i)); - } + JoinCommon::convertColumnsToNullable(sample_block_with_columns_to_add); } namespace @@ -500,12 +486,7 @@ bool Join::addJoinedBlock(const Block & block) /// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable. if (use_nulls && isLeftOrFull(kind)) - { - for (size_t i = isFull(kind) ? 
keys_size : 0; i < size; ++i) - { - convertColumnToNullable(stored_block->getByPosition(i)); - } - } + JoinCommon::convertColumnsToNullable(*stored_block, (isFull(kind) ? keys_size : 0)); if (kind != ASTTableJoin::Kind::Cross) { @@ -769,12 +750,11 @@ void Join::joinBlockImpl( constexpr bool right_or_full = static_in_v; if constexpr (right_or_full) { - for (size_t i = 0; i < existing_columns; ++i) - { + for (size_t i = 0; i < block.columns(); ++i) block.getByPosition(i).column = block.getByPosition(i).column->convertToFullColumnIfConst(); - if (use_nulls) - convertColumnToNullable(block.getByPosition(i)); - } + + if (use_nulls) + JoinCommon::convertColumnsToNullable(block); } /** For LEFT/INNER JOIN, the saved blocks do not contain keys. @@ -925,27 +905,6 @@ void Join::joinBlockImplCross(Block & block) const block = block.cloneWithColumns(std::move(dst_columns)); } - -void Join::checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const -{ - size_t keys_size = key_names_left.size(); - - for (size_t i = 0; i < keys_size; ++i) - { - /// Compare up to Nullability. 
- - DataTypePtr left_type = removeNullable(recursiveRemoveLowCardinality(block_left.getByName(key_names_left[i]).type)); - DataTypePtr right_type = removeNullable(recursiveRemoveLowCardinality(block_right.getByName(key_names_right[i]).type)); - - if (!left_type->equals(*right_type)) - throw Exception("Type mismatch of columns to JOIN by: " - + key_names_left[i] + " " + left_type->getName() + " at left, " - + key_names_right[i] + " " + right_type->getName() + " at right", - ErrorCodes::TYPE_MISMATCH); - } -} - - static void checkTypeOfKey(const Block & block_left, const Block & block_right) { auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0); @@ -1002,11 +961,10 @@ void Join::joinGet(Block & block, const String & column_name) const void Join::joinBlock(Block & block) { - const Names & key_names_left = join_options.keyNamesLeft(); - std::shared_lock lock(rwlock); - checkTypesOfKeys(block, key_names_left, right_table_keys); + const Names & key_names_left = join_options.keyNamesLeft(); + JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right); if (joinDispatch(kind, strictness, maps, [&](auto kind_, auto strictness_, auto & map) { @@ -1206,8 +1164,7 @@ private: /// Convert left columns to Nullable if allowed if (parent.use_nulls) - for (size_t i = 0; i < result_sample_block.columns(); ++i) - convertColumnToNullable(result_sample_block.getByPosition(i)); + JoinCommon::convertColumnsToNullable(result_sample_block); /// Add columns from the right-side table to the block. for (size_t i = 0; i < right_sample_block.columns(); ++i) diff --git a/dbms/src/Interpreters/Join.h b/dbms/src/Interpreters/Join.h index 1ed446034d4..fe84dac485f 100644 --- a/dbms/src/Interpreters/Join.h +++ b/dbms/src/Interpreters/Join.h @@ -340,9 +340,6 @@ private: */ void prepareBlockListStructure(Block & stored_block); - /// Throw an exception if blocks have different types of key columns. 
- void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const; - template void joinBlockImpl( Block & block, diff --git a/dbms/src/Interpreters/MergeJoin.cpp b/dbms/src/Interpreters/MergeJoin.cpp index 16abc0ee94c..2d4f582f50e 100644 --- a/dbms/src/Interpreters/MergeJoin.cpp +++ b/dbms/src/Interpreters/MergeJoin.cpp @@ -16,8 +16,8 @@ MergeJoin::MergeJoin(const AnalyzedJoin & table_join_, const Block & right_sampl : table_join(table_join_) , required_right_keys(table_join.requiredRightKeys()) { - extractKeysForJoin(table_join.keyNamesRight(), right_sample_block, right_table_keys, sample_block_with_columns_to_add); - createMissedColumns(sample_block_with_columns_to_add); + JoinCommon::extractKeysForJoin(table_join.keyNamesRight(), right_sample_block, right_table_keys, sample_block_with_columns_to_add); + JoinCommon::createMissedColumns(sample_block_with_columns_to_add); } /// TODO: sort @@ -34,10 +34,12 @@ bool MergeJoin::addJoinedBlock(const Block & block) void MergeJoin::joinBlock(Block & block) { - addRightColumns(block); - std::shared_lock lock(rwlock); + JoinCommon::checkTypesOfKeys(block, table_join.keyNamesLeft(), right_table_keys, table_join.keyNamesRight()); + + addRightColumns(block); + for (auto it = right_blocks.begin(); it != right_blocks.end(); ++it) mergeJoin(block, *it); }