diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index 1da55537408..e35b0e8a225 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -765,7 +765,7 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits) auto join_mask_col = JoinCommon::getColumnAsMask(source_block, onexprs[onexpr_idx].condColumnNames().second); /// Save blocks that do not hold conditions in ON section ColumnUInt8::MutablePtr not_joined_map = nullptr; - if (!multiple_disjuncts && isRightOrFull(kind) && !join_mask_col.isConstant()) + if (!multiple_disjuncts && isRightOrFull(kind) && join_mask_col.hasData()) { const auto & join_mask = join_mask_col.getData(); /// Save rows that do not hold conditions @@ -845,7 +845,6 @@ struct JoinOnKeyColumns Sizes key_sizes; - explicit JoinOnKeyColumns(const Block & block, const Names & key_names_, const String & cond_column_name, const Sizes & key_sizes_) : key_names(key_names_) , materialized_keys_holder(JoinCommon::materializeColumns(block, key_names)) /// Rare case, when keys are constant or low cardinality. To avoid code bloat, simply materialize them. diff --git a/src/Interpreters/JoinUtils.cpp b/src/Interpreters/JoinUtils.cpp index 9a0781cd2f3..b8d8dd5df74 100644 --- a/src/Interpreters/JoinUtils.cpp +++ b/src/Interpreters/JoinUtils.cpp @@ -532,24 +532,24 @@ bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type) JoinMask getColumnAsMask(const Block & block, const String & column_name) { if (column_name.empty()) - return JoinMask(true); + return JoinMask(true, block.rows()); const auto & src_col = block.getByName(column_name); DataTypePtr col_type = recursiveRemoveLowCardinality(src_col.type); if (isNothing(col_type)) - return JoinMask(false); + return JoinMask(false, block.rows()); if (const auto * const_cond = checkAndGetColumn(*src_col.column)) { - return JoinMask(const_cond->getBool(0)); + return JoinMask(const_cond->getBool(0), block.rows()); } ColumnPtr join_condition_col = recursiveRemoveLowCardinality(src_col.column->convertToFullColumnIfConst()); if (const auto * nullable_col = typeid_cast(join_condition_col.get())) { if (isNothing(assert_cast(*col_type).getNestedType())) - return JoinMask(false); + return JoinMask(false, block.rows()); /// Return nested column with NULL set to false const auto & nest_col = assert_cast(nullable_col->getNestedColumn()); diff --git a/src/Interpreters/JoinUtils.h b/src/Interpreters/JoinUtils.h index b5bdf801b0a..f4f5f5bdc8d 100644 --- a/src/Interpreters/JoinUtils.h +++ b/src/Interpreters/JoinUtils.h @@ -21,24 +21,26 @@ using UInt8ColumnDataPtr = const ColumnUInt8::Container *; namespace JoinCommon { -/// Store boolean column handling constant value without materializing -/// Behaves similar to std::variant, but provides more convenient specialized interface class JoinMask { public: - explicit JoinMask(bool value) + explicit JoinMask() : column(nullptr) - , const_value(value) + {} + + explicit JoinMask(bool value, size_t size) + : column(ColumnUInt8::create(size, value)) {} explicit JoinMask(ColumnPtr col) : column(col) - , const_value(false) {} - bool isConstant() { return !column; } + bool hasData() + { + return column != nullptr; + } - /// Return data if mask is not constant UInt8ColumnDataPtr getData() { if (column) @@ -48,15 +50,11 @@ public: inline bool isRowFiltered(size_t row) const { - if (column) - return !assert_cast(*column).getData()[row]; - return !const_value; + return !assert_cast(*column).getData()[row]; } private: ColumnPtr column; - /// Used if column is null - bool const_value; }; diff --git a/src/Interpreters/MergeJoin.cpp b/src/Interpreters/MergeJoin.cpp index 2d54accc76a..1b3f35614f9 100644 --- a/src/Interpreters/MergeJoin.cpp +++ b/src/Interpreters/MergeJoin.cpp @@ -55,7 +55,7 @@ ColumnWithTypeAndName condtitionColumnToJoinable(const Block & block, const Stri if (!src_column_name.empty()) { auto join_mask = JoinCommon::getColumnAsMask(block, src_column_name); - if (!join_mask.isConstant()) + if (join_mask.hasData()) { for (size_t i = 0; i < res_size; ++i) null_map->getData()[i] = join_mask.isRowFiltered(i);