some joins code unification

This commit is contained in:
chertus 2019-09-10 21:39:10 +03:00
parent 02691f50ef
commit 8afa48fa42
8 changed files with 103 additions and 72 deletions

View File

@ -142,6 +142,20 @@ Names AnalyzedJoin::requiredJoinedNames() const
return Names(required_columns_set.begin(), required_columns_set.end()); return Names(required_columns_set.begin(), required_columns_set.end());
} }
std::unordered_map<String, DataTypePtr> AnalyzedJoin::requiredRightKeys() const
{
NameSet right_keys;
for (const auto & name : key_names_right)
right_keys.insert(name);
std::unordered_map<String, DataTypePtr> required;
for (const auto & column : columns_added_by_join)
if (right_keys.count(column.name))
required.insert({column.name, column.type});
return required;
}
NamesWithAliases AnalyzedJoin::getRequiredColumns(const Block & sample, const Names & action_required_columns) const NamesWithAliases AnalyzedJoin::getRequiredColumns(const Block & sample, const Names & action_required_columns) const
{ {
NameSet required_columns(action_required_columns.begin(), action_required_columns.end()); NameSet required_columns(action_required_columns.begin(), action_required_columns.end());
@ -230,7 +244,7 @@ BlockInputStreamPtr AnalyzedJoin::createStreamWithNonJoinedDataIfFullOrRightJoin
{ {
if (isRightOrFull(table_join.kind)) if (isRightOrFull(table_join.kind))
if (auto hash_join = typeid_cast<Join *>(join.get())) if (auto hash_join = typeid_cast<Join *>(join.get()))
return hash_join->createStreamWithNonJoinedRows(source_header, *this, max_block_size); return hash_join->createStreamWithNonJoinedRows(source_header, max_block_size);
return {}; return {};
} }

View File

@ -92,6 +92,7 @@ public:
void deduplicateAndQualifyColumnNames(const NameSet & left_table_columns, const String & right_table_prefix); void deduplicateAndQualifyColumnNames(const NameSet & left_table_columns, const String & right_table_prefix);
size_t rightKeyInclusion(const String & name) const; size_t rightKeyInclusion(const String & name) const;
std::unordered_map<String, DataTypePtr> requiredRightKeys() const;
void addJoinedColumn(const NameAndTypePair & joined_column); void addJoinedColumn(const NameAndTypePair & joined_column);
void addJoinedColumnsAndCorrectNullability(Block & sample_block) const; void addJoinedColumnsAndCorrectNullability(Block & sample_block) const;

View File

@ -0,0 +1,56 @@
#include <Interpreters/IJoin.h>
#include <Columns/ColumnNullable.h>
#include <DataStreams/materializeBlock.h>
#include <DataTypes/DataTypeLowCardinality.h>
namespace DB
{
ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add)
{
size_t keys_size = key_names_right.size();
ColumnRawPtrs key_columns(keys_size);
sample_block_with_columns_to_add = materializeBlock(right_sample_block);
for (size_t i = 0; i < keys_size; ++i)
{
const String & column_name = key_names_right[i];
/// there could be the same key names
if (sample_block_with_keys.has(column_name))
{
key_columns[i] = sample_block_with_keys.getByName(column_name).column.get();
continue;
}
auto & col = sample_block_with_columns_to_add.getByName(column_name);
col.column = recursiveRemoveLowCardinality(col.column);
col.type = recursiveRemoveLowCardinality(col.type);
/// Extract right keys with correct keys order.
sample_block_with_keys.insert(col);
sample_block_with_columns_to_add.erase(column_name);
key_columns[i] = sample_block_with_keys.getColumns().back().get();
/// We will join only keys, where all components are not NULL.
if (auto * nullable = checkAndGetColumn<ColumnNullable>(*key_columns[i]))
key_columns[i] = &nullable->getNestedColumn();
}
return key_columns;
}
void createMissedColumns(Block & block)
{
for (size_t i = 0; i < block.columns(); ++i)
{
auto & column = block.getByPosition(i);
if (!column.column)
column.column = column.type->createColumn();
}
}
}

View File

@ -1,11 +1,16 @@
#pragma once #pragma once
#include <memory> #include <memory>
#include <vector>
#include <Core/Names.h>
namespace DB namespace DB
{ {
class Block; class Block;
class IColumn;
using ColumnRawPtrs = std::vector<const IColumn *>;
class IJoin class IJoin
{ {
@ -29,4 +34,10 @@ public:
using JoinPtr = std::shared_ptr<IJoin>; using JoinPtr = std::shared_ptr<IJoin>;
/// Common join functions
ColumnRawPtrs extractKeysForJoin(const Names & key_names_right, const Block & right_sample_block,
Block & sample_block_with_keys, Block & sample_block_with_columns_to_add);
void createMissedColumns(Block & block);
} }

View File

@ -35,19 +35,6 @@ namespace ErrorCodes
extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_COLUMN;
} }
static std::unordered_map<String, DataTypePtr> requiredRightKeys(const Names & key_names, const NamesAndTypesList & columns_added_by_join)
{
NameSet right_keys;
for (const auto & name : key_names)
right_keys.insert(name);
std::unordered_map<String, DataTypePtr> required;
for (const auto & column : columns_added_by_join)
if (right_keys.count(column.name))
required.insert({column.name, column.type});
return required;
}
static void convertColumnToNullable(ColumnWithTypeAndName & column) static void convertColumnToNullable(ColumnWithTypeAndName & column)
{ {
@ -276,36 +263,7 @@ void Join::setSampleBlock(const Block & block)
if (!empty()) if (!empty())
return; return;
size_t keys_size = key_names_right.size(); ColumnRawPtrs key_columns = extractKeysForJoin(key_names_right, block, sample_block_with_keys, sample_block_with_columns_to_add);
ColumnRawPtrs key_columns(keys_size);
sample_block_with_columns_to_add = materializeBlock(block);
for (size_t i = 0; i < keys_size; ++i)
{
const String & column_name = key_names_right[i];
/// there could be the same key names
if (sample_block_with_keys.has(column_name))
{
key_columns[i] = sample_block_with_keys.getByName(column_name).column.get();
continue;
}
auto & col = sample_block_with_columns_to_add.getByName(column_name);
col.column = recursiveRemoveLowCardinality(col.column);
col.type = recursiveRemoveLowCardinality(col.type);
/// Extract right keys with correct keys order.
sample_block_with_keys.insert(col);
sample_block_with_columns_to_add.erase(column_name);
key_columns[i] = sample_block_with_keys.getColumns().back().get();
/// We will join only keys, where all components are not NULL.
if (auto * nullable = checkAndGetColumn<ColumnNullable>(*key_columns[i]))
key_columns[i] = &nullable->getNestedColumn();
}
if (strictness == ASTTableJoin::Strictness::Asof) if (strictness == ASTTableJoin::Strictness::Asof)
{ {
@ -344,19 +302,15 @@ void Join::setSampleBlock(const Block & block)
blocklist_sample = Block(block.getColumnsWithTypeAndName()); blocklist_sample = Block(block.getColumnsWithTypeAndName());
prepareBlockListStructure(blocklist_sample); prepareBlockListStructure(blocklist_sample);
size_t num_columns_to_add = sample_block_with_columns_to_add.columns(); createMissedColumns(sample_block_with_columns_to_add);
for (size_t i = 0; i < num_columns_to_add; ++i)
{
auto & column = sample_block_with_columns_to_add.getByPosition(i);
if (!column.column)
column.column = column.type->createColumn();
}
/// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable. /// In case of LEFT and FULL joins, if use_nulls, convert joined columns to Nullable.
if (use_nulls && isLeftOrFull(kind)) if (use_nulls && isLeftOrFull(kind))
{
size_t num_columns_to_add = sample_block_with_columns_to_add.columns();
for (size_t i = 0; i < num_columns_to_add; ++i) for (size_t i = 0; i < num_columns_to_add; ++i)
convertColumnToNullable(sample_block_with_columns_to_add.getByPosition(i)); convertColumnToNullable(sample_block_with_columns_to_add.getByPosition(i));
}
} }
namespace namespace
@ -784,7 +738,6 @@ template <ASTTableJoin::Kind KIND, ASTTableJoin::Strictness STRICTNESS, typename
void Join::joinBlockImpl( void Join::joinBlockImpl(
Block & block, Block & block,
const Names & key_names_left, const Names & key_names_left,
const NamesAndTypesList & columns_added_by_join,
const Block & block_with_columns_to_add, const Block & block_with_columns_to_add,
const Maps & maps_) const const Maps & maps_) const
{ {
@ -842,7 +795,7 @@ void Join::joinBlockImpl(
block.insert(added.moveColumn(i)); block.insert(added.moveColumn(i));
/// Filter & insert missing rows /// Filter & insert missing rows
auto right_keys = requiredRightKeys(key_names_right, columns_added_by_join); auto right_keys = join_options.requiredRightKeys();
constexpr bool is_all_join = STRICTNESS == ASTTableJoin::Strictness::All; constexpr bool is_all_join = STRICTNESS == ASTTableJoin::Strictness::All;
constexpr bool inner_or_right = static_in_v<KIND, ASTTableJoin::Kind::Inner, ASTTableJoin::Kind::Right>; constexpr bool inner_or_right = static_in_v<KIND, ASTTableJoin::Kind::Inner, ASTTableJoin::Kind::Right>;
@ -1025,7 +978,7 @@ template <typename Maps>
void Join::joinGetImpl(Block & block, const String & column_name, const Maps & maps_) const void Join::joinGetImpl(Block & block, const String & column_name, const Maps & maps_) const
{ {
joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>( joinBlockImpl<ASTTableJoin::Kind::Left, ASTTableJoin::Strictness::Any>(
block, {block.getByPosition(0).name}, {}, {sample_block_with_columns_to_add.getByName(column_name)}, maps_); block, {block.getByPosition(0).name}, {sample_block_with_columns_to_add.getByName(column_name)}, maps_);
} }
@ -1053,7 +1006,6 @@ void Join::joinGet(Block & block, const String & column_name) const
void Join::joinBlock(Block & block) void Join::joinBlock(Block & block)
{ {
const Names & key_names_left = join_options.keyNamesLeft(); const Names & key_names_left = join_options.keyNamesLeft();
const NamesAndTypesList & columns_added_by_join = join_options.columnsAddedByJoin();
std::shared_lock lock(rwlock); std::shared_lock lock(rwlock);
@ -1061,7 +1013,7 @@ void Join::joinBlock(Block & block)
if (joinDispatch(kind, strictness, maps, [&](auto kind_, auto strictness_, auto & map) if (joinDispatch(kind, strictness, maps, [&](auto kind_, auto strictness_, auto & map)
{ {
joinBlockImpl<kind_, strictness_>(block, key_names_left, columns_added_by_join, sample_block_with_columns_to_add, map); joinBlockImpl<kind_, strictness_>(block, key_names_left, sample_block_with_columns_to_add, map);
})) }))
{ {
/// Joined /// Joined
@ -1158,11 +1110,13 @@ struct AdderNonJoined<ASTTableJoin::Strictness::Asof, Mapped>
class NonJoinedBlockInputStream : public IBlockInputStream class NonJoinedBlockInputStream : public IBlockInputStream
{ {
public: public:
NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, const Names & key_names_left, NonJoinedBlockInputStream(const Join & parent_, const Block & left_sample_block, UInt64 max_block_size_)
const NamesAndTypesList & columns_added_by_join, UInt64 max_block_size_)
: parent(parent_) : parent(parent_)
, max_block_size(max_block_size_) , max_block_size(max_block_size_)
{ {
const Names & key_names_left = parent_.join_options.keyNamesLeft();
std::unordered_map<String, DataTypePtr> required_right_keys = parent_.join_options.requiredRightKeys();
/** left_sample_block contains keys and "left" columns. /** left_sample_block contains keys and "left" columns.
* result_sample_block - keys, "left" columns, and "right" columns. * result_sample_block - keys, "left" columns, and "right" columns.
*/ */
@ -1181,7 +1135,7 @@ public:
const Block & right_sample_block = parent.sample_block_with_columns_to_add; const Block & right_sample_block = parent.sample_block_with_columns_to_add;
std::unordered_map<size_t, size_t> left_to_right_key_map; std::unordered_map<size_t, size_t> left_to_right_key_map;
makeResultSampleBlock(left_sample_block, right_sample_block, columns_added_by_join, makeResultSampleBlock(left_sample_block, right_sample_block, required_right_keys,
key_positions_left, left_to_right_key_map); key_positions_left, left_to_right_key_map);
auto nullability_changes = getNullabilityChanges(parent.sample_block_with_keys, result_sample_block, auto nullability_changes = getNullabilityChanges(parent.sample_block_with_keys, result_sample_block,
@ -1250,7 +1204,7 @@ private:
void makeResultSampleBlock(const Block & left_sample_block, const Block & right_sample_block, void makeResultSampleBlock(const Block & left_sample_block, const Block & right_sample_block,
const NamesAndTypesList & columns_added_by_join, const std::unordered_map<String, DataTypePtr> & right_keys,
const std::vector<size_t> & key_positions_left, const std::vector<size_t> & key_positions_left,
std::unordered_map<size_t, size_t> & left_to_right_key_map) std::unordered_map<size_t, size_t> & left_to_right_key_map)
{ {
@ -1270,7 +1224,6 @@ private:
} }
const auto & key_names_right = parent.key_names_right; const auto & key_names_right = parent.key_names_right;
auto right_keys = requiredRightKeys(key_names_right, columns_added_by_join);
/// Add join key columns from right block if they has different name. /// Add join key columns from right block if they has different name.
for (size_t i = 0; i < key_names_right.size(); ++i) for (size_t i = 0; i < key_names_right.size(); ++i)
@ -1462,11 +1415,9 @@ private:
}; };
BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, const AnalyzedJoin & join_params, BlockInputStreamPtr Join::createStreamWithNonJoinedRows(const Block & left_sample_block, UInt64 max_block_size) const
UInt64 max_block_size) const
{ {
return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, return std::make_shared<NonJoinedBlockInputStream>(*this, left_sample_block, max_block_size);
join_params.keyNamesLeft(), join_params.columnsAddedByJoin(), max_block_size);
} }

View File

@ -158,8 +158,7 @@ public:
* Use only after all calls to joinBlock was done. * Use only after all calls to joinBlock was done.
* left_sample_block is passed without account of 'use_nulls' setting (columns will be converted to Nullable inside). * left_sample_block is passed without account of 'use_nulls' setting (columns will be converted to Nullable inside).
*/ */
BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, const AnalyzedJoin & join_params, BlockInputStreamPtr createStreamWithNonJoinedRows(const Block & left_sample_block, UInt64 max_block_size) const;
UInt64 max_block_size) const;
/// Number of keys in all built JOIN maps. /// Number of keys in all built JOIN maps.
size_t getTotalRowCount() const override; size_t getTotalRowCount() const override;
@ -346,7 +345,6 @@ private:
void joinBlockImpl( void joinBlockImpl(
Block & block, Block & block,
const Names & key_names_left, const Names & key_names_left,
const NamesAndTypesList & columns_added_by_join,
const Block & block_with_columns_to_add, const Block & block_with_columns_to_add,
const Maps & maps) const; const Maps & maps) const;

View File

@ -14,10 +14,9 @@ namespace ErrorCodes
MergeJoin::MergeJoin(const AnalyzedJoin & table_join_, const Block & right_sample_block) MergeJoin::MergeJoin(const AnalyzedJoin & table_join_, const Block & right_sample_block)
: table_join(table_join_) : table_join(table_join_)
, sample_block_with_columns_to_add(materializeBlock(right_sample_block))
{ {
for (auto & column : table_join.columnsAddedByJoin()) extractKeysForJoin(table_join.keyNamesRight(), right_sample_block, sample_block_with_keys, sample_block_with_columns_to_add);
sample_block_with_columns_to_add.getByName(column.name); createMissedColumns(sample_block_with_columns_to_add);
} }
/// TODO: sort /// TODO: sort

View File

@ -26,6 +26,7 @@ public:
private: private:
mutable std::shared_mutex rwlock; mutable std::shared_mutex rwlock;
const AnalyzedJoin & table_join; const AnalyzedJoin & table_join;
Block sample_block_with_keys;
Block sample_block_with_columns_to_add; Block sample_block_with_columns_to_add;
BlocksList right_blocks; BlocksList right_blocks;
size_t right_blocks_row_count = 0; size_t right_blocks_row_count = 0;