extract more common join functions

This commit is contained in:
chertus 2019-09-11 21:03:21 +03:00
parent a836f0cfd6
commit fc7ce2753d
5 changed files with 78 additions and 65 deletions

View File

@ -1,11 +1,37 @@
#include <Interpreters/IJoin.h>
#include <Columns/ColumnNullable.h>
#include <DataStreams/materializeBlock.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataStreams/materializeBlock.h>
namespace DB
{
namespace ErrorCodes
{
extern const int TYPE_MISMATCH;
}
namespace JoinCommon
{
void convertColumnToNullable(ColumnWithTypeAndName & column)
{
if (column.type->isNullable() || !column.type->canBeInsideNullable())
return;
column.type = makeNullable(column.type);
if (column.column)
column.column = makeNullable(column.column);
}
void convertColumnsToNullable(Block & block, size_t starting_pos)
{
for (size_t i = starting_pos; i < block.columns(); ++i)
convertColumnToNullable(block.getByPosition(i));
}
ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add)
{
@ -43,6 +69,23 @@ ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & ri
return key_columns;
}
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right)
{
size_t keys_size = key_names_left.size();
for (size_t i = 0; i < keys_size; ++i)
{
DataTypePtr left_type = removeNullable(recursiveRemoveLowCardinality(block_left.getByName(key_names_left[i]).type));
DataTypePtr right_type = removeNullable(recursiveRemoveLowCardinality(block_right.getByName(key_names_right[i]).type));
if (!left_type->equals(*right_type))
throw Exception("Type mismatch of columns to JOIN by: "
+ key_names_left[i] + " " + left_type->getName() + " at left, "
+ key_names_right[i] + " " + right_type->getName() + " at right",
ErrorCodes::TYPE_MISMATCH);
}
}
void createMissedColumns(Block & block)
{
for (size_t i = 0; i < block.columns(); ++i)
@ -54,3 +97,4 @@ void createMissedColumns(Block & block)
}
}
}

View File

@ -8,6 +8,7 @@
namespace DB
{
struct ColumnWithTypeAndName;
class Block;
class IColumn;
using ColumnRawPtrs = std::vector<const IColumn *>;
@ -34,10 +35,22 @@ public:
using JoinPtr = std::shared_ptr<IJoin>;
/// Common join functions
namespace JoinCommon
{
void convertColumnToNullable(ColumnWithTypeAndName & column);
void convertColumnsToNullable(Block & block, size_t starting_pos = 0);
/// Split key and other columns by keys name list
ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add);
/// Throw an exception if blocks have different types of key columns. Compare up to Nullability.
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right);
void createMissedColumns(Block & block);
}
}

View File

@ -36,21 +36,11 @@ namespace ErrorCodes
}
static void convertColumnToNullable(ColumnWithTypeAndName & column)
{
if (column.type->isNullable() || !column.type->canBeInsideNullable())
return;
column.type = makeNullable(column.type);
if (column.column)
column.column = makeNullable(column.column);
}
/// Converts column to nullable if needed. No backward convertion.
static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column, bool nullable)
{
if (nullable)
convertColumnToNullable(column);
JoinCommon::convertColumnToNullable(column);
return std::move(column);
}
@ -58,7 +48,7 @@ static ColumnWithTypeAndName correctNullability(ColumnWithTypeAndName && column,
{
if (nullable)
{
convertColumnToNullable(column);
JoinCommon::convertColumnToNullable(column);
if (column.type->isNullable() && negative_null_map.size())
{
MutableColumnPtr mutable_column = (*std::move(column.column)).mutate();
@ -264,7 +254,7 @@ void Join::setSampleBlock(const Block & block)
if (!empty())
return;
ColumnRawPtrs key_columns = extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add);
ColumnRawPtrs key_columns = JoinCommon::extractKeysForJoin(key_names_right, block, right_table_keys, sample_block_with_columns_to_add);
if (strictness == ASTTableJoin::Strictness::Asof)
{
@ -303,15 +293,11 @@ void Join::setSampleBlock(const Block & block)
blocklist_sample = Block(block.getColumnsWithTypeAndName());
prepareBlockListStructure(blocklist_sample);
createMissedColumns(sample_block_with_columns_to_add);
JoinCommon::createMissedColumns(sample_block_with_columns_to_add);
/// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable.
if (use_nulls && isLeftOrFull(kind))
{
size_t num_columns_to_add = sample_block_with_columns_to_add.columns();
for (size_t i = 0; i < num_columns_to_add; ++i)
convertColumnToNullable(sample_block_with_columns_to_add.getByPosition(i));
}
JoinCommon::convertColumnsToNullable(sample_block_with_columns_to_add);
}
namespace
@ -500,12 +486,7 @@ bool Join::addJoinedBlock(const Block & block)
/// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable.
if (use_nulls && isLeftOrFull(kind))
{
for (size_t i = isFull(kind) ? keys_size : 0; i < size; ++i)
{
convertColumnToNullable(stored_block->getByPosition(i));
}
}
JoinCommon::convertColumnsToNullable(*stored_block, (isFull(kind) ? keys_size : 0));
if (kind != ASTTableJoin::Kind::Cross)
{
@ -769,12 +750,11 @@ void Join::joinBlockImpl(
constexpr bool right_or_full = static_in_v<KIND, ASTTableJoin::Kind::Right, ASTTableJoin::Kind::Full>;
if constexpr (right_or_full)
{
for (size_t i = 0; i < existing_columns; ++i)
{
for (size_t i = 0; i < block.columns(); ++i)
block.getByPosition(i).column = block.getByPosition(i).column->convertToFullColumnIfConst();
if (use_nulls)
convertColumnToNullable(block.getByPosition(i));
}
if (use_nulls)
JoinCommon::convertColumnsToNullable(block);
}
/** For LEFT/INNER JOIN, the saved blocks do not contain keys.
@ -925,27 +905,6 @@ void Join::joinBlockImplCross(Block & block) const
block = block.cloneWithColumns(std::move(dst_columns));
}
void Join::checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const
{
size_t keys_size = key_names_left.size();
for (size_t i = 0; i < keys_size; ++i)
{
/// Compare up to Nullability.
DataTypePtr left_type = removeNullable(recursiveRemoveLowCardinality(block_left.getByName(key_names_left[i]).type));
DataTypePtr right_type = removeNullable(recursiveRemoveLowCardinality(block_right.getByName(key_names_right[i]).type));
if (!left_type->equals(*right_type))
throw Exception("Type mismatch of columns to JOIN by: "
+ key_names_left[i] + " " + left_type->getName() + " at left, "
+ key_names_right[i] + " " + right_type->getName() + " at right",
ErrorCodes::TYPE_MISMATCH);
}
}
static void checkTypeOfKey(const Block & block_left, const Block & block_right)
{
auto & [c1, left_type_origin, left_name] = block_left.safeGetByPosition(0);
@ -1002,11 +961,10 @@ void Join::joinGet(Block & block, const String & column_name) const
void Join::joinBlock(Block & block)
{
const Names & key_names_left = join_options.keyNamesLeft();
std::shared_lock lock(rwlock);
checkTypesOfKeys(block, key_names_left, right_table_keys);
const Names & key_names_left = join_options.keyNamesLeft();
JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right);
if (joinDispatch(kind, strictness, maps, [&](auto kind_, auto strictness_, auto & map)
{
@ -1206,8 +1164,7 @@ private:
/// Convert left columns to Nullable if allowed
if (parent.use_nulls)
for (size_t i = 0; i < result_sample_block.columns(); ++i)
convertColumnToNullable(result_sample_block.getByPosition(i));
JoinCommon::convertColumnsToNullable(result_sample_block);
/// Add columns from the right-side table to the block.
for (size_t i = 0; i < right_sample_block.columns(); ++i)

View File

@ -340,9 +340,6 @@ private:
*/
void prepareBlockListStructure(Block & stored_block);
/// Throw an exception if blocks have different types of key columns.
void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right) const;
template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename Maps>
void joinBlockImpl(
Block & block,

View File

@ -16,8 +16,8 @@ MergeJoin::MergeJoin(const AnalyzedJoin & table_join_, const Block & right_sampl
: table_join(table_join_)
, required_right_keys(table_join.requiredRightKeys())
{
extractKeysForJoin(table_join.keyNamesRight(), right_sample_block, right_table_keys, sample_block_with_columns_to_add);
createMissedColumns(sample_block_with_columns_to_add);
JoinCommon::extractKeysForJoin(table_join.keyNamesRight(), right_sample_block, right_table_keys, sample_block_with_columns_to_add);
JoinCommon::createMissedColumns(sample_block_with_columns_to_add);
}
/// TODO: sort
@ -34,10 +34,12 @@ bool MergeJoin::addJoinedBlock(const Block & block)
void MergeJoin::joinBlock(Block & block)
{
addRightColumns(block);
std::shared_lock lock(rwlock);
JoinCommon::checkTypesOfKeys(block, table_join.keyNamesLeft(), right_table_keys, table_join.keyNamesRight());
addRightColumns(block);
for (auto it = right_blocks.begin(); it != right_blocks.end(); ++it)
mergeJoin(block, *it);
}