Merge pull request #27021 from vdimir/join-on-condition-constant

This commit is contained in:
Vladimir C 2021-11-08 15:52:24 +03:00 committed by GitHub
commit 6f369319c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 83 additions and 36 deletions

View File

@ -777,15 +777,15 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits)
auto join_mask_col = JoinCommon::getColumnAsMask(block, onexprs[onexpr_idx].condColumnNames().second);
/// Save blocks that do not hold conditions in ON section
ColumnUInt8::MutablePtr not_joined_map = nullptr;
if (!multiple_disjuncts && isRightOrFull(kind) && join_mask_col)
if (!multiple_disjuncts && isRightOrFull(kind) && !join_mask_col.isConstant())
{
const auto & join_mask = assert_cast<const ColumnUInt8 &>(*join_mask_col).getData();
const auto & join_mask = join_mask_col.getData();
/// Save rows that do not hold conditions
not_joined_map = ColumnUInt8::create(block.rows(), 0);
for (size_t i = 0, sz = join_mask.size(); i < sz; ++i)
for (size_t i = 0, sz = join_mask->size(); i < sz; ++i)
{
/// Condition hold, do not save row
if (join_mask[i])
if ((*join_mask)[i])
continue;
/// NULL key will be saved anyway because, do not save twice
@ -802,7 +802,8 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits)
{
size_t size = insertFromBlockImpl<strictness_>(
*this, data->type, map, rows, key_columns, key_sizes[onexpr_idx], stored_block, null_map,
join_mask_col ? &assert_cast<const ColumnUInt8 &>(*join_mask_col).getData() : nullptr,
/// If mask is false constant, rows are added to hashmap anyway. It's not a happy-flow, so this case is not optimized
join_mask_col.getData(),
data->pool);
if (multiple_disjuncts)
@ -846,7 +847,7 @@ struct JoinOnKeyColumns
ColumnPtr null_map_holder;
/// Only rows where mask == true can be joined
ColumnPtr join_mask_column;
JoinCommon::JoinMask join_mask_column;
Sizes key_sizes;
@ -859,17 +860,10 @@ struct JoinOnKeyColumns
, null_map_holder(extractNestedColumnsAndNullMap(key_columns, null_map))
, join_mask_column(JoinCommon::getColumnAsMask(block, cond_column_name))
, key_sizes(key_sizes_)
{}
bool isRowFiltered(size_t i) const
{
if (join_mask_column)
{
UInt8ColumnDataPtr mask = &assert_cast<const ColumnUInt8 &>(*(join_mask_column)).getData();
return !(*mask)[i];
}
return false;
}
bool isRowFiltered(size_t i) const { return join_mask_column.isRowFiltered(i); }
};
class AddedColumns
@ -985,6 +979,7 @@ public:
const IColumn & leftAsofKey() const { return *left_asof_key; }
std::vector<JoinOnKeyColumns> join_on_keys;
size_t rows_to_add;
std::unique_ptr<IColumn::Offsets> offsets_to_replicate;
bool need_filter = false;
@ -998,6 +993,7 @@ private:
std::optional<TypeIndex> asof_type;
ASOF::Inequality asof_inequality;
const IColumn * left_asof_key = nullptr;
bool is_join_get;
void addColumn(const ColumnWithTypeAndName & src_column, const std::string & qualified_name)
@ -1949,12 +1945,14 @@ private:
for (auto & it = *nulls_position; it != end && rows_added < max_block_size; ++it)
{
const Block * block = it->first;
const NullMap & nullmap = assert_cast<const ColumnUInt8 &>(*it->second).getData();
const auto * block = it->first;
ConstNullMapPtr nullmap = nullptr;
if (it->second)
nullmap = &assert_cast<const ColumnUInt8 &>(*it->second).getData();
for (size_t row = 0; row < nullmap.size(); ++row)
for (size_t row = 0; row < block->rows(); ++row)
{
if (nullmap[row])
if (nullmap && (*nullmap)[row])
{
for (size_t col = 0; col < columns_keys_and_right.size(); ++col)
columns_keys_and_right[col]->insertFrom(*block->getByPosition(col).column, row);

View File

@ -50,12 +50,12 @@ ColumnWithTypeAndName condtitionColumnToJoinable(const Block & block, const Stri
if (!src_column_name.empty())
{
auto mask_col = JoinCommon::getColumnAsMask(block, src_column_name);
assert(mask_col);
const auto & mask_data = assert_cast<const ColumnUInt8 &>(*mask_col).getData();
for (size_t i = 0; i < res_size; ++i)
null_map->getData()[i] = !mask_data[i];
auto join_mask = JoinCommon::getColumnAsMask(block, src_column_name);
if (!join_mask.isConstant())
{
for (size_t i = 0; i < res_size; ++i)
null_map->getData()[i] = join_mask.isRowFiltered(i);
}
}
ColumnPtr res_col = ColumnNullable::create(std::move(data_col), std::move(null_map));
@ -477,6 +477,7 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
, max_joined_block_rows(table_join->maxJoinedBlockRows())
, max_rows_in_right_block(table_join->maxRowsInRightBlock())
, max_files_to_merge(table_join->maxFilesToMerge())
, log(&Poco::Logger::get("MergeJoin"))
{
switch (table_join->strictness())
{
@ -549,6 +550,8 @@ MergeJoin::MergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right
makeSortAndMerge(key_names_left, left_sort_description, left_merge_description);
makeSortAndMerge(key_names_right, right_sort_description, right_merge_description);
LOG_DEBUG(log, "Joining keys: left [{}], right [{}]", fmt::join(key_names_left, ", "), fmt::join(key_names_right, ", "));
/// Temporary disable 'partial_merge_join_left_table_buffer_bytes' without 'partial_merge_join_optimizations'
if (table_join->enablePartialMergeJoinOptimizations())
if (size_t max_bytes = table_join->maxBytesInLeftBuffer())

View File

@ -118,6 +118,8 @@ private:
Names lowcard_right_keys;
Poco::Logger * log;
void changeLeftColumns(Block & block, MutableColumns && columns) const;
void addRightColumns(Block & block, MutableColumns && columns);

View File

@ -1,18 +1,18 @@
#include <Interpreters/join_common.h>
#include <Columns/ColumnConst.h>
#include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnConst.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/TableJoin.h>
#include <IO/WriteHelpers.h>
namespace DB
{
@ -492,23 +492,27 @@ bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type)
return left_type_strict->equals(*right_type_strict);
}
ColumnPtr getColumnAsMask(const Block & block, const String & column_name)
JoinMask getColumnAsMask(const Block & block, const String & column_name)
{
if (column_name.empty())
return nullptr;
return JoinMask(true);
const auto & src_col = block.getByName(column_name);
DataTypePtr col_type = recursiveRemoveLowCardinality(src_col.type);
if (isNothing(col_type))
return ColumnUInt8::create(block.rows(), 0);
return JoinMask(false);
const auto & join_condition_col = recursiveRemoveLowCardinality(src_col.column->convertToFullColumnIfConst());
if (const auto * const_cond = checkAndGetColumn<ColumnConst>(*src_col.column))
{
return JoinMask(const_cond->getBool(0));
}
ColumnPtr join_condition_col = recursiveRemoveLowCardinality(src_col.column->convertToFullColumnIfConst());
if (const auto * nullable_col = typeid_cast<const ColumnNullable *>(join_condition_col.get()))
{
if (isNothing(assert_cast<const DataTypeNullable &>(*col_type).getNestedType()))
return ColumnUInt8::create(block.rows(), 0);
return JoinMask(false);
/// Return nested column with NULL set to false
const auto & nest_col = assert_cast<const ColumnUInt8 &>(nullable_col->getNestedColumn());
@ -517,10 +521,10 @@ ColumnPtr getColumnAsMask(const Block & block, const String & column_name)
auto res = ColumnUInt8::create(nullable_col->size(), 0);
for (size_t i = 0, sz = nullable_col->size(); i < sz; ++i)
res->getData()[i] = !null_map.getData()[i] && nest_col.getData()[i];
return res;
return JoinMask(std::move(res));
}
else
return join_condition_col;
return JoinMask(std::move(join_condition_col));
}

View File

@ -19,6 +19,46 @@ using UInt8ColumnDataPtr = const ColumnUInt8::Container *;
namespace JoinCommon
{
/// Store boolean column handling constant value without materializing
/// Behaves similar to std::variant<bool, ColumnPtr>, but provides more convenient specialized interface
class JoinMask
{
public:
explicit JoinMask(bool value)
: column(nullptr)
, const_value(value)
{}
explicit JoinMask(ColumnPtr col)
: column(col)
, const_value(false)
{}
bool isConstant() { return !column; }
/// Return data if mask is not constant
UInt8ColumnDataPtr getData()
{
if (column)
return &assert_cast<const ColumnUInt8 &>(*column).getData();
return nullptr;
}
inline bool isRowFiltered(size_t row) const
{
if (column)
return !assert_cast<const ColumnUInt8 &>(*column).getData()[row];
return !const_value;
}
private:
ColumnPtr column;
/// Used if column is null
bool const_value;
};
bool canBecomeNullable(const DataTypePtr & type);
DataTypePtr convertTypeToNullable(const DataTypePtr & type);
void convertColumnToNullable(ColumnWithTypeAndName & column);
@ -58,7 +98,7 @@ void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count);
bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type);
/// Return mask array of type ColumnUInt8 for specified column. Source should have type UInt8 or Nullable(UInt8).
ColumnPtr getColumnAsMask(const Block & block, const String & column_name);
JoinMask getColumnAsMask(const Block & block, const String & column_name);
/// Split key and other columns by keys name list
void splitAdditionalColumns(const Names & key_names, const Block & sample_block, Block & block_keys, Block & block_others);