diff --git a/src/Common/ColumnsHashingImpl.h b/src/Common/ColumnsHashingImpl.h index 9af746a69ad..aa7ae6ea29d 100644 --- a/src/Common/ColumnsHashingImpl.h +++ b/src/Common/ColumnsHashingImpl.h @@ -124,6 +124,10 @@ class FindResultImpl : public FindResultImplBase, public FindResultImplOffsetBas Mapped * value; public: + FindResultImpl() + : FindResultImplBase(false), FindResultImplOffsetBase(0) + {} + FindResultImpl(Mapped * value_, bool found_, size_t off) : FindResultImplBase(found_), FindResultImplOffsetBase(off), value(value_) {} Mapped & getMapped() const { return *value; } diff --git a/src/Core/SortDescription.h b/src/Core/SortDescription.h index 41b4e5b6b32..e1653b9102b 100644 --- a/src/Core/SortDescription.h +++ b/src/Core/SortDescription.h @@ -42,15 +42,15 @@ struct SortColumnDescription bool with_fill; FillColumnDescription fill_description; - SortColumnDescription( - size_t column_number_, int direction_, int nulls_direction_, + explicit SortColumnDescription( + size_t column_number_, int direction_ = 1, int nulls_direction_ = 1, const std::shared_ptr & collator_ = nullptr, bool with_fill_ = false, const FillColumnDescription & fill_description_ = {}) : column_number(column_number_), direction(direction_), nulls_direction(nulls_direction_), collator(collator_) , with_fill(with_fill_), fill_description(fill_description_) {} - SortColumnDescription( - const std::string & column_name_, int direction_, int nulls_direction_, + explicit SortColumnDescription( + const std::string & column_name_, int direction_ = 1, int nulls_direction_ = 1, const std::shared_ptr & collator_ = nullptr, bool with_fill_ = false, const FillColumnDescription & fill_description_ = {}) : column_name(column_name_), column_number(0), direction(direction_), nulls_direction(nulls_direction_) diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.cpp b/src/Interpreters/CollectJoinOnKeysVisitor.cpp index 3b3fdaa65cb..9715af01a0a 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.cpp +++ 
b/src/Interpreters/CollectJoinOnKeysVisitor.cpp @@ -12,48 +12,77 @@ namespace ErrorCodes extern const int INVALID_JOIN_ON_EXPRESSION; extern const int AMBIGUOUS_COLUMN_NAME; extern const int SYNTAX_ERROR; - extern const int NOT_IMPLEMENTED; extern const int LOGICAL_ERROR; } -void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, - const std::pair & table_no) +namespace +{ + +bool isLeftIdentifier(JoinIdentifierPos pos) +{ + /// Unknown identifiers considered as left, we will try to process it on later stages + /// Usually such identifiers came from `ARRAY JOIN ... AS ...` + return pos == JoinIdentifierPos::Left || pos == JoinIdentifierPos::Unknown; +} + +bool isRightIdentifier(JoinIdentifierPos pos) +{ + return pos == JoinIdentifierPos::Right; +} + +} + +void CollectJoinOnKeysMatcher::Data::addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos) { ASTPtr left = left_ast->clone(); ASTPtr right = right_ast->clone(); - if (table_no.first == 1 || table_no.second == 2) + if (isLeftIdentifier(table_pos.first) && isRightIdentifier(table_pos.second)) analyzed_join.addOnKeys(left, right); - else if (table_no.first == 2 || table_no.second == 1) + else if (isRightIdentifier(table_pos.first) && isLeftIdentifier(table_pos.second)) analyzed_join.addOnKeys(right, left); else throw Exception("Cannot detect left and right JOIN keys. 
JOIN ON section is ambiguous.", - ErrorCodes::AMBIGUOUS_COLUMN_NAME); - has_some = true; + ErrorCodes::INVALID_JOIN_ON_EXPRESSION); } void CollectJoinOnKeysMatcher::Data::addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, - const std::pair & table_no, const ASOF::Inequality & inequality) + JoinIdentifierPosPair table_pos, const ASOF::Inequality & inequality) { - if (table_no.first == 1 || table_no.second == 2) + if (isLeftIdentifier(table_pos.first) && isRightIdentifier(table_pos.second)) { asof_left_key = left_ast->clone(); asof_right_key = right_ast->clone(); analyzed_join.setAsofInequality(inequality); } - else if (table_no.first == 2 || table_no.second == 1) + else if (isRightIdentifier(table_pos.first) && isLeftIdentifier(table_pos.second)) { asof_left_key = right_ast->clone(); asof_right_key = left_ast->clone(); analyzed_join.setAsofInequality(ASOF::reverseInequality(inequality)); } + else + { + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, + "Expressions {} and {} are from the same table but from different arguments of equal function in ASOF JOIN", + queryToString(left_ast), queryToString(right_ast)); + } } void CollectJoinOnKeysMatcher::Data::asofToJoinKeys() { if (!asof_left_key || !asof_right_key) throw Exception("No inequality in ASOF JOIN ON section.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); - addJoinKeys(asof_left_key, asof_right_key, {1, 2}); + addJoinKeys(asof_left_key, asof_right_key, {JoinIdentifierPos::Left, JoinIdentifierPos::Right}); +} + +void CollectJoinOnKeysMatcher::visit(const ASTIdentifier & ident, const ASTPtr & ast, CollectJoinOnKeysMatcher::Data & data) +{ + if (auto expr_from_table = getTableForIdentifiers(ast, false, data); expr_from_table != JoinIdentifierPos::Unknown) + data.analyzed_join.addJoinCondition(ast, isLeftIdentifier(expr_from_table)); + else + throw Exception("Unexpected identifier '" + ident.name() + "' in JOIN ON section", + ErrorCodes::INVALID_JOIN_ON_EXPRESSION); } void 
CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & ast, Data & data) @@ -61,9 +90,6 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as if (func.name == "and") return; /// go into children - if (func.name == "or") - throw Exception("JOIN ON does not support OR. Unexpected '" + queryToString(ast) + "'", ErrorCodes::NOT_IMPLEMENTED); - ASOF::Inequality inequality = ASOF::getInequality(func.name); if (func.name == "equals" || inequality != ASOF::Inequality::None) { @@ -71,32 +97,50 @@ void CollectJoinOnKeysMatcher::visit(const ASTFunction & func, const ASTPtr & as throw Exception("Function " + func.name + " takes two arguments, got '" + func.formatForErrorMessage() + "' instead", ErrorCodes::SYNTAX_ERROR); } - else - throw Exception("Expected equality or inequality, got '" + queryToString(ast) + "'", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); if (func.name == "equals") { ASTPtr left = func.arguments->children.at(0); ASTPtr right = func.arguments->children.at(1); - auto table_numbers = getTableNumbers(ast, left, right, data); - data.addJoinKeys(left, right, table_numbers); - } - else if (inequality != ASOF::Inequality::None) - { - if (!data.is_asof) - throw Exception("JOIN ON inequalities are not supported. 
Unexpected '" + queryToString(ast) + "'", - ErrorCodes::NOT_IMPLEMENTED); + auto table_numbers = getTableNumbers(left, right, data); + if (table_numbers.first == table_numbers.second) + { + if (table_numbers.first == JoinIdentifierPos::Unknown) + throw Exception("Ambiguous column in expression '" + queryToString(ast) + "' in JOIN ON section", + ErrorCodes::AMBIGUOUS_COLUMN_NAME); + data.analyzed_join.addJoinCondition(ast, isLeftIdentifier(table_numbers.first)); + return; + } + if (table_numbers.first != JoinIdentifierPos::NotApplicable && table_numbers.second != JoinIdentifierPos::NotApplicable) + { + data.addJoinKeys(left, right, table_numbers); + return; + } + } + + if (auto expr_from_table = getTableForIdentifiers(ast, false, data); expr_from_table != JoinIdentifierPos::Unknown) + { + data.analyzed_join.addJoinCondition(ast, isLeftIdentifier(expr_from_table)); + return; + } + + if (data.is_asof && inequality != ASOF::Inequality::None) + { if (data.asof_left_key || data.asof_right_key) throw Exception("ASOF JOIN expects exactly one inequality in ON section. Unexpected '" + queryToString(ast) + "'", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); ASTPtr left = func.arguments->children.at(0); ASTPtr right = func.arguments->children.at(1); - auto table_numbers = getTableNumbers(ast, left, right, data); + auto table_numbers = getTableNumbers(left, right, data); data.addAsofJoinKeys(left, right, table_numbers, inequality); + return; } + + throw Exception("Unsupported JOIN ON conditions. 
Unexpected '" + queryToString(ast) + "'", + ErrorCodes::INVALID_JOIN_ON_EXPRESSION); } void CollectJoinOnKeysMatcher::getIdentifiers(const ASTPtr & ast, std::vector & out) @@ -118,32 +162,10 @@ void CollectJoinOnKeysMatcher::getIdentifiers(const ASTPtr & ast, std::vector CollectJoinOnKeysMatcher::getTableNumbers(const ASTPtr & expr, const ASTPtr & left_ast, const ASTPtr & right_ast, - Data & data) +JoinIdentifierPosPair CollectJoinOnKeysMatcher::getTableNumbers(const ASTPtr & left_ast, const ASTPtr & right_ast, Data & data) { - std::vector left_identifiers; - std::vector right_identifiers; - - getIdentifiers(left_ast, left_identifiers); - getIdentifiers(right_ast, right_identifiers); - - if (left_identifiers.empty() || right_identifiers.empty()) - { - throw Exception("Not equi-join ON expression: " + queryToString(expr) + ". No columns in one of equality side.", - ErrorCodes::INVALID_JOIN_ON_EXPRESSION); - } - - size_t left_idents_table = getTableForIdentifiers(left_identifiers, data); - size_t right_idents_table = getTableForIdentifiers(right_identifiers, data); - - if (left_idents_table && left_idents_table == right_idents_table) - { - auto left_name = queryToString(*left_identifiers[0]); - auto right_name = queryToString(*right_identifiers[0]); - - throw Exception("In expression " + queryToString(expr) + " columns " + left_name + " and " + right_name - + " are from the same table but from different arguments of equal function", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); - } + auto left_idents_table = getTableForIdentifiers(left_ast, true, data); + auto right_idents_table = getTableForIdentifiers(right_ast, true, data); return std::make_pair(left_idents_table, right_idents_table); } @@ -173,11 +195,16 @@ const ASTIdentifier * CollectJoinOnKeysMatcher::unrollAliases(const ASTIdentifie return identifier; } -/// @returns 1 if identifiers belongs to left table, 2 for right table and 0 if unknown. Throws on table mix. 
+/// @returns Left or right table identifiers belongs to. /// Place detected identifier into identifiers[0] if any. -size_t CollectJoinOnKeysMatcher::getTableForIdentifiers(std::vector & identifiers, const Data & data) +JoinIdentifierPos CollectJoinOnKeysMatcher::getTableForIdentifiers(const ASTPtr & ast, bool throw_on_table_mix, const Data & data) { - size_t table_number = 0; + std::vector identifiers; + getIdentifiers(ast, identifiers); + if (identifiers.empty()) + return JoinIdentifierPos::NotApplicable; + + JoinIdentifierPos table_number = JoinIdentifierPos::Unknown; for (auto & ident : identifiers) { @@ -187,10 +214,20 @@ size_t CollectJoinOnKeysMatcher::getTableForIdentifiers(std::vectorname()); + } - if (!membership) + if (membership == JoinIdentifierPos::Unknown) { const String & name = identifier->name(); bool in_left_table = data.left_table.hasColumn(name); @@ -211,22 +248,24 @@ size_t CollectJoinOnKeysMatcher::getTableForIdentifiers(std::vectorgetAliasOrColumnName() + " and " + ident->getAliasOrColumnName() - + " are from different tables.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + if (throw_on_table_mix) + throw Exception("Invalid columns in JOIN ON section. Columns " + + identifiers[0]->getAliasOrColumnName() + " and " + ident->getAliasOrColumnName() + + " are from different tables.", ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + return JoinIdentifierPos::Unknown; } } diff --git a/src/Interpreters/CollectJoinOnKeysVisitor.h b/src/Interpreters/CollectJoinOnKeysVisitor.h index 54e008a114e..0647f58f79b 100644 --- a/src/Interpreters/CollectJoinOnKeysVisitor.h +++ b/src/Interpreters/CollectJoinOnKeysVisitor.h @@ -18,6 +18,21 @@ namespace ASOF enum class Inequality; } +enum class JoinIdentifierPos +{ + /// Position can't be established, identifier not resolved + Unknown, + /// Left side of JOIN + Left, + /// Right side of JOIN + Right, + /// Expression not valid, e.g. 
doesn't contain identifiers + NotApplicable, +}; + +using JoinIdentifierPosPair = std::pair; + + class CollectJoinOnKeysMatcher { public: @@ -32,10 +47,9 @@ public: const bool is_asof{false}; ASTPtr asof_left_key{}; ASTPtr asof_right_key{}; - bool has_some{false}; - void addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no); - void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, const std::pair & table_no, + void addJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos); + void addAsofJoinKeys(const ASTPtr & left_ast, const ASTPtr & right_ast, JoinIdentifierPosPair table_pos, const ASOF::Inequality & asof_inequality); void asofToJoinKeys(); }; @@ -43,7 +57,17 @@ public: static void visit(const ASTPtr & ast, Data & data) { if (auto * func = ast->as()) + { visit(*func, ast, data); + } + else if (auto * ident = ast->as()) + { + visit(*ident, ast, data); + } + else + { + /// visit children + } } static bool needChildVisit(const ASTPtr & node, const ASTPtr &) @@ -55,11 +79,12 @@ public: private: static void visit(const ASTFunction & func, const ASTPtr & ast, Data & data); + static void visit(const ASTIdentifier & ident, const ASTPtr & ast, Data & data); static void getIdentifiers(const ASTPtr & ast, std::vector & out); - static std::pair getTableNumbers(const ASTPtr & expr, const ASTPtr & left_ast, const ASTPtr & right_ast, Data & data); + static JoinIdentifierPosPair getTableNumbers(const ASTPtr & left_ast, const ASTPtr & right_ast, Data & data); static const ASTIdentifier * unrollAliases(const ASTIdentifier * identifier, const Aliases & aliases); - static size_t getTableForIdentifiers(std::vector & identifiers, const Data & data); + static JoinIdentifierPos getTableForIdentifiers(const ASTPtr & ast, bool throw_on_table_mix, const Data & data); }; /// Parse JOIN ON expression and collect ASTs for joined columns. 
diff --git a/src/Interpreters/ExpressionActions.cpp b/src/Interpreters/ExpressionActions.cpp index 905fcf0331c..6797947a101 100644 --- a/src/Interpreters/ExpressionActions.cpp +++ b/src/Interpreters/ExpressionActions.cpp @@ -812,6 +812,9 @@ void ExpressionActionsChain::JoinStep::finalize(const NameSet & required_output_ for (const auto & name : analyzed_join->keyNamesLeft()) required_names.emplace(name); + if (ASTPtr extra_condition_column = analyzed_join->joinConditionColumn(JoinTableSide::Left)) + required_names.emplace(extra_condition_column->getColumnName()); + for (const auto & column : required_columns) { if (required_names.count(column.name) != 0) diff --git a/src/Interpreters/HashJoin.cpp b/src/Interpreters/HashJoin.cpp index 56ad13511ac..dd17fc1004c 100644 --- a/src/Interpreters/HashJoin.cpp +++ b/src/Interpreters/HashJoin.cpp @@ -190,9 +190,12 @@ HashJoin::HashJoin(std::shared_ptr table_join_, const Block & right_s { LOG_DEBUG(log, "Right sample block: {}", right_sample_block.dumpStructure()); - table_join->splitAdditionalColumns(right_sample_block, right_table_keys, sample_block_with_columns_to_add); + JoinCommon::splitAdditionalColumns(key_names_right, right_sample_block, right_table_keys, sample_block_with_columns_to_add); + required_right_keys = table_join->getRequiredRightKeys(right_table_keys, required_right_keys_sources); + std::tie(condition_mask_column_name_left, condition_mask_column_name_right) = table_join->joinConditionColumnNames(); + JoinCommon::removeLowCardinalityInplace(right_table_keys); initRightBlockStructure(data->sample_block); @@ -500,7 +503,7 @@ namespace template size_t NO_INLINE insertFromBlockImplTypeCase( HashJoin & join, Map & map, size_t rows, const ColumnRawPtrs & key_columns, - const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, Arena & pool) + const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, UInt8ColumnDataPtr join_mask, Arena & pool) { [[maybe_unused]] constexpr bool 
mapped_one = std::is_same_v; constexpr bool is_asof_join = STRICTNESS == ASTTableJoin::Strictness::Asof; @@ -516,6 +519,10 @@ namespace if (has_null_map && (*null_map)[i]) continue; + /// Check condition for right table from ON section + if (join_mask && !(*join_mask)[i]) + continue; + if constexpr (is_asof_join) Inserter::insertAsof(join, map, key_getter, stored_block, i, pool, *asof_column); else if constexpr (mapped_one) @@ -530,19 +537,21 @@ namespace template size_t insertFromBlockImplType( HashJoin & join, Map & map, size_t rows, const ColumnRawPtrs & key_columns, - const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, Arena & pool) + const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, UInt8ColumnDataPtr join_mask, Arena & pool) { if (null_map) - return insertFromBlockImplTypeCase(join, map, rows, key_columns, key_sizes, stored_block, null_map, pool); + return insertFromBlockImplTypeCase( + join, map, rows, key_columns, key_sizes, stored_block, null_map, join_mask, pool); else - return insertFromBlockImplTypeCase(join, map, rows, key_columns, key_sizes, stored_block, null_map, pool); + return insertFromBlockImplTypeCase( + join, map, rows, key_columns, key_sizes, stored_block, null_map, join_mask, pool); } template size_t insertFromBlockImpl( HashJoin & join, HashJoin::Type type, Maps & maps, size_t rows, const ColumnRawPtrs & key_columns, - const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, Arena & pool) + const Sizes & key_sizes, Block * stored_block, ConstNullMapPtr null_map, UInt8ColumnDataPtr join_mask, Arena & pool) { switch (type) { @@ -553,7 +562,7 @@ namespace #define M(TYPE) \ case HashJoin::Type::TYPE: \ return insertFromBlockImplType>::Type>(\ - join, *maps.TYPE, rows, key_columns, key_sizes, stored_block, null_map, pool); \ + join, *maps.TYPE, rows, key_columns, key_sizes, stored_block, null_map, join_mask, pool); \ break; APPLY_FOR_JOIN_VARIANTS(M) #undef M @@ -624,10 +633,34 @@ bool 
HashJoin::addJoinedBlock(const Block & source_block, bool check_limits) UInt8 save_nullmap = 0; if (isRightOrFull(kind) && null_map) { + /// Save rows with NULL keys for (size_t i = 0; !save_nullmap && i < null_map->size(); ++i) save_nullmap |= (*null_map)[i]; } + auto join_mask_col = JoinCommon::getColumnAsMask(block, condition_mask_column_name_right); + + /// Save blocks that do not hold conditions in ON section + ColumnUInt8::MutablePtr not_joined_map = nullptr; + if (isRightOrFull(kind) && join_mask_col) + { + const auto & join_mask = assert_cast(*join_mask_col).getData(); + /// Save rows that do not hold conditions + not_joined_map = ColumnUInt8::create(block.rows(), 0); + for (size_t i = 0, sz = join_mask.size(); i < sz; ++i) + { + /// Condition hold, do not save row + if (join_mask[i]) + continue; + + /// NULL key will be saved anyway because, do not save twice + if (save_nullmap && (*null_map)[i]) + continue; + + not_joined_map->getData()[i] = 1; + } + } + Block structured_block = structureRightBlock(block); size_t total_rows = 0; size_t total_bytes = 0; @@ -647,7 +680,10 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits) { joinDispatch(kind, strictness, data->maps, [&](auto kind_, auto strictness_, auto & map) { - size_t size = insertFromBlockImpl(*this, data->type, map, rows, key_columns, key_sizes, stored_block, null_map, data->pool); + size_t size = insertFromBlockImpl( + *this, data->type, map, rows, key_columns, key_sizes, stored_block, null_map, + join_mask_col ? 
&assert_cast(*join_mask_col).getData() : nullptr, + data->pool); /// Number of buckets + 1 value from zero storage used_flags.reinit(size + 1); }); @@ -656,6 +692,9 @@ bool HashJoin::addJoinedBlock(const Block & source_block, bool check_limits) if (save_nullmap) data->blocks_nullmaps.emplace_back(stored_block, null_map_holder); + if (not_joined_map) + data->blocks_nullmaps.emplace_back(stored_block, std::move(not_joined_map)); + if (!check_limits) return true; @@ -693,6 +732,7 @@ public: const HashJoin & join, const ColumnRawPtrs & key_columns_, const Sizes & key_sizes_, + const UInt8ColumnDataPtr & join_mask_column_, bool is_asof_join, bool is_join_get_) : key_columns(key_columns_) @@ -700,6 +740,7 @@ public: , rows_to_add(block.rows()) , asof_type(join.getAsofType()) , asof_inequality(join.getAsofInequality()) + , join_mask_column(join_mask_column_) , is_join_get(is_join_get_) { size_t num_columns_to_add = block_with_columns_to_add.columns(); @@ -784,6 +825,8 @@ public: ASOF::Inequality asofInequality() const { return asof_inequality; } const IColumn & leftAsofKey() const { return *left_asof_key; } + bool isRowFiltered(size_t i) { return join_mask_column && !(*join_mask_column)[i]; } + const ColumnRawPtrs & key_columns; const Sizes & key_sizes; size_t rows_to_add; @@ -799,6 +842,7 @@ private: std::optional asof_type; ASOF::Inequality asof_inequality; const IColumn * left_asof_key = nullptr; + UInt8ColumnDataPtr join_mask_column; bool is_join_get; void addColumn(const ColumnWithTypeAndName & src_column, const std::string & qualified_name) @@ -891,7 +935,9 @@ NO_INLINE IColumn::Filter joinRightColumns( } } - auto find_result = key_getter.findKey(map, i, pool); + bool row_acceptable = !added_columns.isRowFiltered(i); + using FindResult = typename KeyGetter::FindResult; + auto find_result = row_acceptable ? 
key_getter.findKey(map, i, pool) : FindResult(); if (find_result.isFound()) { @@ -1098,7 +1144,20 @@ void HashJoin::joinBlockImpl( * For ASOF, the last column is used as the ASOF column */ - AddedColumns added_columns(block_with_columns_to_add, block, savedBlockSample(), *this, left_key_columns, key_sizes, is_asof_join, is_join_get); + /// Only rows where mask == true can be joined + ColumnPtr join_mask_column = JoinCommon::getColumnAsMask(block, condition_mask_column_name_left); + + AddedColumns added_columns( + block_with_columns_to_add, + block, + savedBlockSample(), + *this, + left_key_columns, + key_sizes, + join_mask_column ? &assert_cast(*join_mask_column).getData() : nullptr, + is_asof_join, + is_join_get); + bool has_required_right_keys = (required_right_keys.columns() != 0); added_columns.need_filter = need_filter || has_required_right_keys; @@ -1324,7 +1383,8 @@ ColumnWithTypeAndName HashJoin::joinGet(const Block & block, const Block & block void HashJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) { const Names & key_names_left = table_join->keyNamesLeft(); - JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right); + JoinCommon::checkTypesOfKeys(block, key_names_left, condition_mask_column_name_left, + right_sample_block, key_names_right, condition_mask_column_name_right); if (overDictionary()) { diff --git a/src/Interpreters/HashJoin.h b/src/Interpreters/HashJoin.h index 86c53081059..65e3f5dbabe 100644 --- a/src/Interpreters/HashJoin.h +++ b/src/Interpreters/HashJoin.h @@ -377,6 +377,10 @@ private: /// Left table column names that are sources for required_right_keys columns std::vector required_right_keys_sources; + /// Additional conditions for rows to join from JOIN ON section + String condition_mask_column_name_left; + String condition_mask_column_name_right; + Poco::Logger * log; Block totals; diff --git a/src/Interpreters/MergeJoin.cpp b/src/Interpreters/MergeJoin.cpp index 8f9d94b6079..0a89a4568e3 
100644 --- a/src/Interpreters/MergeJoin.cpp +++ b/src/Interpreters/MergeJoin.cpp @@ -1,19 +1,21 @@ #include +#include #include #include -#include +#include +#include +#include +#include +#include #include #include -#include #include -#include -#include -#include -#include -#include +#include #include -#include +#include +#include +#include namespace DB @@ -23,12 +25,50 @@ namespace ErrorCodes { extern const int NOT_IMPLEMENTED; extern const int PARAMETER_OUT_OF_BOUND; + extern const int ILLEGAL_COLUMN; extern const int LOGICAL_ERROR; } namespace { +String deriveTempName(const String & name) +{ + return "--" + name; +} + +/* + * Convert column with conditions for left or right table to join to joining key. + * Input column type is UInt8 output is Nullable(UInt8). + * 0 converted to NULL and such rows won't be joined, + * 1 converted to 0 (any constant non-NULL value to join) + */ +ColumnWithTypeAndName condtitionColumnToJoinable(const Block & block, const String & src_column_name) +{ + size_t res_size = block.rows(); + auto data_col = ColumnUInt8::create(res_size, 0); + auto null_map = ColumnUInt8::create(res_size, 0); + + if (!src_column_name.empty()) + { + auto mask_col = JoinCommon::getColumnAsMask(block, src_column_name); + assert(mask_col); + const auto & mask_data = assert_cast(*mask_col).getData(); + + for (size_t i = 0; i < res_size; ++i) + null_map->getData()[i] = !mask_data[i]; + } + + ColumnPtr res_col = ColumnNullable::create(std::move(data_col), std::move(null_map)); + DataTypePtr res_col_type = std::make_shared(std::make_shared()); + String res_name = deriveTempName(src_column_name); + + if (block.has(res_name)) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Conflicting column name '{}'", res_name); + + return {res_col, res_col_type, res_name}; +} + template int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos) { @@ -180,7 +220,7 @@ class MergeJoinCursor { public: MergeJoinCursor(const Block & 
block, const SortDescription & desc_) - : impl(SortCursorImpl(block, desc_)) + : impl(block, desc_) { /// SortCursorImpl can work with permutation, but MergeJoinCursor can't. if (impl.permutation) @@ -320,14 +360,17 @@ MutableColumns makeMutableColumns(const Block & block, size_t rows_to_reserve = void makeSortAndMerge(const Names & keys, SortDescription & sort, SortDescription & merge) { NameSet unique_keys; + for (const auto & sd: merge) + unique_keys.insert(sd.column_name); + for (const auto & key_name : keys) { - merge.emplace_back(SortColumnDescription(key_name, 1, 1)); + merge.emplace_back(key_name); - if (!unique_keys.count(key_name)) + if (!unique_keys.contains(key_name)) { unique_keys.insert(key_name); - sort.emplace_back(SortColumnDescription(key_name, 1, 1)); + sort.emplace_back(key_name); } } } @@ -464,15 +507,31 @@ MergeJoin::MergeJoin(std::shared_ptr table_join_, const Block & right ErrorCodes::PARAMETER_OUT_OF_BOUND); } - for (const auto & right_key : table_join->keyNamesRight()) + std::tie(mask_column_name_left, mask_column_name_right) = table_join->joinConditionColumnNames(); + + /// Add auxiliary joining keys to join only rows where conditions from JOIN ON sections holds + /// Input boolean column converted to nullable and only rows with non NULLS value will be joined + if (!mask_column_name_left.empty() || !mask_column_name_right.empty()) + { + JoinCommon::checkTypesOfMasks({}, "", right_sample_block, mask_column_name_right); + + key_names_left.push_back(deriveTempName(mask_column_name_left)); + key_names_right.push_back(deriveTempName(mask_column_name_right)); + } + + key_names_left.insert(key_names_left.end(), table_join->keyNamesLeft().begin(), table_join->keyNamesLeft().end()); + key_names_right.insert(key_names_right.end(), table_join->keyNamesRight().begin(), table_join->keyNamesRight().end()); + + addConditionJoinColumn(right_sample_block, JoinTableSide::Right); + JoinCommon::splitAdditionalColumns(key_names_right, right_sample_block, 
right_table_keys, right_columns_to_add); + + for (const auto & right_key : key_names_right) { if (right_sample_block.getByName(right_key).type->lowCardinality()) lowcard_right_keys.push_back(right_key); } - - table_join->splitAdditionalColumns(right_sample_block, right_table_keys, right_columns_to_add); JoinCommon::removeLowCardinalityInplace(right_table_keys); - JoinCommon::removeLowCardinalityInplace(right_sample_block, table_join->keyNamesRight()); + JoinCommon::removeLowCardinalityInplace(right_sample_block, key_names_right); const NameSet required_right_keys = table_join->requiredRightKeys(); for (const auto & column : right_table_keys) @@ -484,8 +543,8 @@ MergeJoin::MergeJoin(std::shared_ptr table_join_, const Block & right if (nullable_right_side) JoinCommon::convertColumnsToNullable(right_columns_to_add); - makeSortAndMerge(table_join->keyNamesLeft(), left_sort_description, left_merge_description); - makeSortAndMerge(table_join->keyNamesRight(), right_sort_description, right_merge_description); + makeSortAndMerge(key_names_left, left_sort_description, left_merge_description); + makeSortAndMerge(key_names_right, right_sort_description, right_merge_description); /// Temporary disable 'partial_merge_join_left_table_buffer_bytes' without 'partial_merge_join_optimizations' if (table_join->enablePartialMergeJoinOptimizations()) @@ -526,7 +585,8 @@ void MergeJoin::mergeInMemoryRightBlocks() pipeline.init(std::move(source)); /// TODO: there should be no split keys by blocks for RIGHT|FULL JOIN - pipeline.addTransform(std::make_shared(pipeline.getHeader(), right_sort_description, max_rows_in_right_block, 0, 0, 0, 0, nullptr, 0)); + pipeline.addTransform(std::make_shared( + pipeline.getHeader(), right_sort_description, max_rows_in_right_block, 0, 0, 0, 0, nullptr, 0)); auto sorted_input = PipelineExecutingBlockInputStream(std::move(pipeline)); @@ -602,6 +662,7 @@ bool MergeJoin::addJoinedBlock(const Block & src_block, bool) { Block block = 
modifyRightBlock(src_block); + addConditionJoinColumn(block, JoinTableSide::Right); sortBlock(block, right_sort_description); return saveRightBlock(std::move(block)); } @@ -611,16 +672,22 @@ void MergeJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) Names lowcard_keys = lowcard_right_keys; if (block) { - JoinCommon::checkTypesOfKeys(block, table_join->keyNamesLeft(), right_table_keys, table_join->keyNamesRight()); + JoinCommon::checkTypesOfMasks(block, mask_column_name_left, right_sample_block, mask_column_name_right); + + /// Add auxiliary column, will be removed after joining + addConditionJoinColumn(block, JoinTableSide::Left); + + JoinCommon::checkTypesOfKeys(block, key_names_left, right_table_keys, key_names_right); + materializeBlockInplace(block); - for (const auto & column_name : table_join->keyNamesLeft()) + for (const auto & column_name : key_names_left) { if (block.getByName(column_name).type->lowCardinality()) lowcard_keys.push_back(column_name); } - JoinCommon::removeLowCardinalityInplace(block, table_join->keyNamesLeft(), false); + JoinCommon::removeLowCardinalityInplace(block, key_names_left, false); sortBlock(block, left_sort_description); @@ -655,6 +722,9 @@ void MergeJoin::joinBlock(Block & block, ExtraBlockPtr & not_processed) if (!not_processed && left_blocks_buffer) not_processed = std::make_shared(NotProcessed{{}, 0, 0, 0}); + if (needConditionJoinColumn()) + block.erase(deriveTempName(mask_column_name_left)); + for (const auto & column_name : lowcard_keys) { if (!block.has(column_name)) @@ -697,7 +767,7 @@ void MergeJoin::joinSortedBlock(Block & block, ExtraBlockPtr & not_processed) if (skip_not_intersected) { - int intersection = left_cursor.intersect(min_max_right_blocks[i], table_join->keyNamesRight()); + int intersection = left_cursor.intersect(min_max_right_blocks[i], key_names_right); if (intersection < 0) break; /// (left) ... 
(right) if (intersection > 0) @@ -730,7 +800,7 @@ void MergeJoin::joinSortedBlock(Block & block, ExtraBlockPtr & not_processed) if (skip_not_intersected) { - int intersection = left_cursor.intersect(min_max_right_blocks[i], table_join->keyNamesRight()); + int intersection = left_cursor.intersect(min_max_right_blocks[i], key_names_right); if (intersection < 0) break; /// (left) ... (right) if (intersection > 0) @@ -831,7 +901,7 @@ bool MergeJoin::leftJoin(MergeJoinCursor & left_cursor, const Block & left_block } bool MergeJoin::allInnerJoin(MergeJoinCursor & left_cursor, const Block & left_block, RightBlockInfo & right_block_info, - MutableColumns & left_columns, MutableColumns & right_columns, size_t & left_key_tail) + MutableColumns & left_columns, MutableColumns & right_columns, size_t & left_key_tail) { const Block & right_block = *right_block_info.block; MergeJoinCursor right_cursor(right_block, right_merge_description); @@ -970,11 +1040,15 @@ void MergeJoin::initRightTableWriter() class NonMergeJoinedBlockInputStream : private NotJoined, public IBlockInputStream { public: - NonMergeJoinedBlockInputStream(const MergeJoin & parent_, const Block & result_sample_block_, UInt64 max_block_size_) + NonMergeJoinedBlockInputStream(const MergeJoin & parent_, + const Block & result_sample_block_, + const Names & key_names_right_, + UInt64 max_block_size_) : NotJoined(*parent_.table_join, parent_.modifyRightBlock(parent_.right_sample_block), parent_.right_sample_block, - result_sample_block_) + result_sample_block_, + {}, key_names_right_) , parent(parent_) , max_block_size(max_block_size_) {} @@ -1062,10 +1136,26 @@ private: BlockInputStreamPtr MergeJoin::createStreamWithNonJoinedRows(const Block & result_sample_block, UInt64 max_block_size) const { if (table_join->strictness() == ASTTableJoin::Strictness::All && (is_right || is_full)) - return std::make_shared(*this, result_sample_block, max_block_size); + return std::make_shared(*this, result_sample_block, 
key_names_right, max_block_size); return {}; } +bool MergeJoin::needConditionJoinColumn() const +{ + return !mask_column_name_left.empty() || !mask_column_name_right.empty(); +} + +void MergeJoin::addConditionJoinColumn(Block & block, JoinTableSide block_side) const +{ + if (needConditionJoinColumn()) + { + if (block_side == JoinTableSide::Left) + block.insert(condtitionColumnToJoinable(block, mask_column_name_left)); + else + block.insert(condtitionColumnToJoinable(block, mask_column_name_right)); + } +} + MergeJoin::RightBlockInfo::RightBlockInfo(std::shared_ptr block_, size_t block_number_, size_t & skip_, RowBitmaps * bitmaps_) : block(block_) diff --git a/src/Interpreters/MergeJoin.h b/src/Interpreters/MergeJoin.h index 8c829569a41..11e5dc86dc2 100644 --- a/src/Interpreters/MergeJoin.h +++ b/src/Interpreters/MergeJoin.h @@ -16,7 +16,7 @@ class TableJoin; class MergeJoinCursor; struct MergeJoinEqualRange; class RowBitmaps; - +enum class JoinTableSide; class MergeJoin : public IJoin { @@ -79,6 +79,14 @@ private: Block right_columns_to_add; SortedBlocksWriter::Blocks right_blocks; + Names key_names_right; + Names key_names_left; + + /// Additional conditions for rows to join from JOIN ON section. + /// Only rows where conditions are met can be joined. 
+ String mask_column_name_left; + String mask_column_name_right; + /// Each block stores first and last row from corresponding sorted block on disk Blocks min_max_right_blocks; std::shared_ptr left_blocks_buffer; @@ -151,6 +159,9 @@ private: void mergeFlushedRightBlocks(); void initRightTableWriter(); + + bool needConditionJoinColumn() const; + void addConditionJoinColumn(Block & block, JoinTableSide block_side) const; }; } diff --git a/src/Interpreters/TableJoin.cpp b/src/Interpreters/TableJoin.cpp index 122e2cd6479..20e8f6b18b4 100644 --- a/src/Interpreters/TableJoin.cpp +++ b/src/Interpreters/TableJoin.cpp @@ -1,17 +1,17 @@ #include -#include - -#include - -#include -#include -#include - #include +#include +#include +#include + #include -#include +#include +#include +#include + +#include namespace DB @@ -132,6 +132,8 @@ ASTPtr TableJoin::leftKeysList() const { ASTPtr keys_list = std::make_shared(); keys_list->children = key_asts_left; + if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Left)) + keys_list->children.push_back(extra_cond); return keys_list; } @@ -140,6 +142,8 @@ ASTPtr TableJoin::rightKeysList() const ASTPtr keys_list = std::make_shared(); if (hasOn()) keys_list->children = key_asts_right; + if (ASTPtr extra_cond = joinConditionColumn(JoinTableSide::Right)) + keys_list->children.push_back(extra_cond); return keys_list; } @@ -176,22 +180,6 @@ NamesWithAliases TableJoin::getRequiredColumns(const Block & sample, const Names return getNamesWithAliases(required_columns); } -void TableJoin::splitAdditionalColumns(const Block & sample_block, Block & block_keys, Block & block_others) const -{ - block_others = materializeBlock(sample_block); - - for (const String & column_name : key_names_right) - { - /// Extract right keys with correct keys order. There could be the same key names. 
- if (!block_keys.has(column_name)) - { - auto & col = block_others.getByName(column_name); - block_keys.insert(col); - block_others.erase(column_name); - } - } -} - Block TableJoin::getRequiredRightKeys(const Block & right_table_keys, std::vector & keys_sources) const { const Names & left_keys = keyNamesLeft(); @@ -474,4 +462,48 @@ String TableJoin::renamedRightColumnName(const String & name) const return name; } +void TableJoin::addJoinCondition(const ASTPtr & ast, bool is_left) +{ + LOG_TRACE(&Poco::Logger::get("TableJoin"), "Add join condition for {} table: {}", (is_left ? "left" : "right"), queryToString(ast)); + + if (is_left) + on_filter_condition_asts_left.push_back(ast); + else + on_filter_condition_asts_right.push_back(ast); +} + +/// Returns all conditions related to one table joined with 'and' function +static ASTPtr buildJoinConditionColumn(const ASTs & on_filter_condition_asts) +{ + if (on_filter_condition_asts.empty()) + return nullptr; + + if (on_filter_condition_asts.size() == 1) + return on_filter_condition_asts[0]; + + auto function = std::make_shared(); + function->name = "and"; + function->arguments = std::make_shared(); + function->children.push_back(function->arguments); + function->arguments->children = on_filter_condition_asts; + return function; +} + +ASTPtr TableJoin::joinConditionColumn(JoinTableSide side) const +{ + if (side == JoinTableSide::Left) + return buildJoinConditionColumn(on_filter_condition_asts_left); + return buildJoinConditionColumn(on_filter_condition_asts_right); +} + +std::pair TableJoin::joinConditionColumnNames() const +{ + std::pair res; + if (auto cond_ast = joinConditionColumn(JoinTableSide::Left)) + res.first = cond_ast->getColumnName(); + if (auto cond_ast = joinConditionColumn(JoinTableSide::Right)) + res.second = cond_ast->getColumnName(); + return res; +} + } diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index 08098e5378c..4c8c16028f5 100644 --- a/src/Interpreters/TableJoin.h +++ 
b/src/Interpreters/TableJoin.h @@ -33,6 +33,12 @@ struct Settings; class IVolume; using VolumePtr = std::shared_ptr; +enum class JoinTableSide +{ + Left, + Right +}; + class TableJoin { @@ -67,9 +73,12 @@ private: Names key_names_left; Names key_names_right; /// Duplicating names are qualified. + ASTs on_filter_condition_asts_left; + ASTs on_filter_condition_asts_right; ASTs key_asts_left; ASTs key_asts_right; + ASTTableJoin table_join; ASOF::Inequality asof_inequality = ASOF::Inequality::GreaterOrEquals; @@ -150,6 +159,23 @@ public: void addUsingKey(const ASTPtr & ast); void addOnKeys(ASTPtr & left_table_ast, ASTPtr & right_table_ast); + /* Conditions for left/right table from JOIN ON section. + * + * Conditions for left and right tables are stored separately and united with 'and' function into one column. + * For example, for the query: + * SELECT ... JOIN ... ON t1.id == t2.id AND expr11(t1) AND expr21(t2) AND expr12(t1) AND expr22(t2) + * + * We will build two new ASTs: `expr11(t1) AND expr12(t1)`, `expr21(t2) AND expr22(t2)` + * Such columns will be added and calculated for left and right tables respectively. + * Only rows where conditions are met (where new columns have non-zero value) will be joined. + * + * NOTE: non-equi condition containing columns from different tables (like `... ON t1.id = t2.id AND t1.val > t2.val`) + * isn't supported yet; it can be added later. 
+ */ + void addJoinCondition(const ASTPtr & ast, bool is_left); + ASTPtr joinConditionColumn(JoinTableSide side) const; + std::pair joinConditionColumnNames() const; + bool hasUsing() const { return table_join.using_expression_list != nullptr; } bool hasOn() const { return table_join.on_expression != nullptr; } @@ -201,8 +227,6 @@ public: /// StorageJoin overrides key names (cause of different names qualification) void setRightKeys(const Names & keys) { key_names_right = keys; } - /// Split key and other columns by keys name list - void splitAdditionalColumns(const Block & sample_block, Block & block_keys, Block & block_others) const; Block getRequiredRightKeys(const Block & right_table_keys, std::vector & keys_sources) const; String renamedRightColumnName(const String & name) const; diff --git a/src/Interpreters/TreeRewriter.cpp b/src/Interpreters/TreeRewriter.cpp index 44a33d0eecf..cc345004f6f 100644 --- a/src/Interpreters/TreeRewriter.cpp +++ b/src/Interpreters/TreeRewriter.cpp @@ -532,9 +532,12 @@ void collectJoinedColumns(TableJoin & analyzed_join, const ASTTableJoin & table_ CollectJoinOnKeysVisitor::Data data{analyzed_join, tables[0], tables[1], aliases, is_asof}; CollectJoinOnKeysVisitor(data).visit(table_join.on_expression); - if (!data.has_some) + if (analyzed_join.keyNamesLeft().empty()) + { throw Exception("Cannot get JOIN keys from JOIN ON section: " + queryToString(table_join.on_expression), ErrorCodes::INVALID_JOIN_ON_EXPRESSION); + } + if (is_asof) data.asofToJoinKeys(); } diff --git a/src/Interpreters/join_common.cpp b/src/Interpreters/join_common.cpp index 74f2c26a2ef..9d6abda42ed 100644 --- a/src/Interpreters/join_common.cpp +++ b/src/Interpreters/join_common.cpp @@ -1,21 +1,29 @@ #include -#include -#include -#include + #include -#include -#include -#include +#include + #include + +#include +#include +#include +#include + #include +#include +#include + +#include namespace DB { namespace ErrorCodes { - extern const int TYPE_MISMATCH; + extern 
const int INVALID_JOIN_ON_EXPRESSION; extern const int LOGICAL_ERROR; + extern const int TYPE_MISMATCH; } namespace @@ -220,6 +228,12 @@ ColumnRawPtrs materializeColumnsInplace(Block & block, const Names & names) return ptrs; } +ColumnPtr materializeColumn(const Block & block, const String & column_name) +{ + const auto & src_column = block.getByName(column_name).column; + return recursiveRemoveLowCardinality(src_column->convertToFullColumnIfConst()); +} + Columns materializeColumns(const Block & block, const Names & names) { Columns materialized; @@ -227,8 +241,7 @@ Columns materializeColumns(const Block & block, const Names & names) for (const auto & column_name : names) { - const auto & src_column = block.getByName(column_name).column; - materialized.emplace_back(recursiveRemoveLowCardinality(src_column->convertToFullColumnIfConst())); + materialized.emplace_back(materializeColumn(block, column_name)); } return materialized; @@ -294,7 +307,8 @@ ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_nam return key_columns; } -void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right) +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, + const Block & block_right, const Names & key_names_right) { size_t keys_size = key_names_left.size(); @@ -305,12 +319,38 @@ void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, co if (!left_type->equals(*right_type)) throw Exception("Type mismatch of columns to JOIN by: " - + key_names_left[i] + " " + left_type->getName() + " at left, " - + key_names_right[i] + " " + right_type->getName() + " at right", - ErrorCodes::TYPE_MISMATCH); + + key_names_left[i] + " " + left_type->getName() + " at left, " + + key_names_right[i] + " " + right_type->getName() + " at right", + ErrorCodes::TYPE_MISMATCH); } } +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const 
String & condition_name_left, + const Block & block_right, const Names & key_names_right, const String & condition_name_right) +{ + checkTypesOfKeys(block_left, key_names_left,block_right,key_names_right); + checkTypesOfMasks(block_left, condition_name_left, block_right, condition_name_right); +} + +void checkTypesOfMasks(const Block & block_left, const String & condition_name_left, + const Block & block_right, const String & condition_name_right) +{ + auto check_cond_column_type = [](const Block & block, const String & col_name) + { + if (col_name.empty()) + return; + + DataTypePtr dtype = removeNullable(recursiveRemoveLowCardinality(block.getByName(col_name).type)); + + if (!dtype->equals(DataTypeUInt8{})) + throw Exception(ErrorCodes::INVALID_JOIN_ON_EXPRESSION, + "Expected logical expression in JOIN ON section, got unexpected column '{}' of type '{}'", + col_name, dtype->getName()); + }; + check_cond_column_type(block_left, condition_name_left); + check_cond_column_type(block_right, condition_name_right); +} + void createMissedColumns(Block & block) { for (size_t i = 0; i < block.columns(); ++i) @@ -359,28 +399,80 @@ bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type) return left_type_strict->equals(*right_type_strict); } +ColumnPtr getColumnAsMask(const Block & block, const String & column_name) +{ + if (column_name.empty()) + return nullptr; + + const auto & src_col = block.getByName(column_name); + + DataTypePtr col_type = recursiveRemoveLowCardinality(src_col.type); + if (isNothing(col_type)) + return ColumnUInt8::create(block.rows(), 0); + + const auto & join_condition_col = recursiveRemoveLowCardinality(src_col.column->convertToFullColumnIfConst()); + + if (const auto * nullable_col = typeid_cast(join_condition_col.get())) + { + if (isNothing(assert_cast(*col_type).getNestedType())) + return ColumnUInt8::create(block.rows(), 0); + + /// Return nested column with NULL set to false + const auto & nest_col = 
assert_cast(nullable_col->getNestedColumn()); + const auto & null_map = nullable_col->getNullMapColumn(); + + auto res = ColumnUInt8::create(nullable_col->size(), 0); + for (size_t i = 0, sz = nullable_col->size(); i < sz; ++i) + res->getData()[i] = !null_map.getData()[i] && nest_col.getData()[i]; + return res; + } + else + return join_condition_col; +} + + +void splitAdditionalColumns(const Names & key_names, const Block & sample_block, Block & block_keys, Block & block_others) +{ + block_others = materializeBlock(sample_block); + + for (const String & column_name : key_names) + { + /// Extract right keys with correct keys order. There could be the same key names. + if (!block_keys.has(column_name)) + { + auto & col = block_others.getByName(column_name); + block_keys.insert(col); + block_others.erase(column_name); + } + } +} + } NotJoined::NotJoined(const TableJoin & table_join, const Block & saved_block_sample_, const Block & right_sample_block, - const Block & result_sample_block_) + const Block & result_sample_block_, const Names & key_names_left_, const Names & key_names_right_) : saved_block_sample(saved_block_sample_) , result_sample_block(materializeBlock(result_sample_block_)) + , key_names_left(key_names_left_.empty() ? table_join.keyNamesLeft() : key_names_left_) + , key_names_right(key_names_right_.empty() ? 
table_join.keyNamesRight() : key_names_right_) { std::vector tmp; Block right_table_keys; Block sample_block_with_columns_to_add; - table_join.splitAdditionalColumns(right_sample_block, right_table_keys, sample_block_with_columns_to_add); + + JoinCommon::splitAdditionalColumns(key_names_right, right_sample_block, right_table_keys, + sample_block_with_columns_to_add); Block required_right_keys = table_join.getRequiredRightKeys(right_table_keys, tmp); std::unordered_map left_to_right_key_remap; if (table_join.hasUsing()) { - for (size_t i = 0; i < table_join.keyNamesLeft().size(); ++i) + for (size_t i = 0; i < key_names_left.size(); ++i) { - const String & left_key_name = table_join.keyNamesLeft()[i]; - const String & right_key_name = table_join.keyNamesRight()[i]; + const String & left_key_name = key_names_left[i]; + const String & right_key_name = key_names_right[i]; size_t left_key_pos = result_sample_block.getPositionByName(left_key_name); size_t right_key_pos = saved_block_sample.getPositionByName(right_key_name); diff --git a/src/Interpreters/join_common.h b/src/Interpreters/join_common.h index 2da795d0d4c..8862116d1fa 100644 --- a/src/Interpreters/join_common.h +++ b/src/Interpreters/join_common.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -12,6 +13,7 @@ struct ColumnWithTypeAndName; class TableJoin; class IColumn; using ColumnRawPtrs = std::vector; +using UInt8ColumnDataPtr = const ColumnUInt8::Container *; namespace JoinCommon { @@ -22,6 +24,7 @@ void convertColumnsToNullable(Block & block, size_t starting_pos = 0); void removeColumnNullability(ColumnWithTypeAndName & column); void changeColumnRepresentation(const ColumnPtr & src_column, ColumnPtr & dst_column); ColumnPtr emptyNotNullableClone(const ColumnPtr & column); +ColumnPtr materializeColumn(const Block & block, const String & name); Columns materializeColumns(const Block & block, const Names & names); ColumnRawPtrs materializeColumnsInplace(Block & block, const Names & 
names); ColumnRawPtrs getRawPointers(const Columns & columns); @@ -31,8 +34,17 @@ void restoreLowCardinalityInplace(Block & block); ColumnRawPtrs extractKeysForJoin(const Block & block_keys, const Names & key_names_right); -/// Throw an exception if blocks have different types of key columns. Compare up to Nullability. -void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const Block & block_right, const Names & key_names_right); +/// Throw an exception if join condition column is not UInt8 (Nullable and LowCardinality wrappers are ignored) +void checkTypesOfMasks(const Block & block_left, const String & condition_name_left, + const Block & block_right, const String & condition_name_right); + +/// Throw an exception if blocks have different types of key columns. Compare up to Nullability. +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, + const Block & block_right, const Names & key_names_right); + +/// Check both keys and conditions +void checkTypesOfKeys(const Block & block_left, const Names & key_names_left, const String & condition_name_left, + const Block & block_right, const Names & key_names_right, const String & condition_name_right); void createMissedColumns(Block & block); void joinTotals(Block left_totals, Block right_totals, const TableJoin & table_join, Block & out_block); @@ -41,6 +53,12 @@ void addDefaultValues(IColumn & column, const DataTypePtr & type, size_t count); bool typesEqualUpToNullability(DataTypePtr left_type, DataTypePtr right_type); +/// Return mask array of type ColumnUInt8 for specified column. Source should have type UInt8 or Nullable(UInt8). 
+ColumnPtr getColumnAsMask(const Block & block, const String & column_name); + +/// Split key and other columns by keys name list +void splitAdditionalColumns(const Names & key_names, const Block & sample_block, Block & block_keys, Block & block_others); + void changeLowCardinalityInplace(ColumnWithTypeAndName & column); } @@ -50,7 +68,7 @@ class NotJoined { public: NotJoined(const TableJoin & table_join, const Block & saved_block_sample_, const Block & right_sample_block, - const Block & result_sample_block_); + const Block & result_sample_block_, const Names & key_names_left_ = {}, const Names & key_names_right_ = {}); void correctLowcardAndNullability(MutableColumns & columns_right); void addLeftColumns(Block & block, size_t rows_added) const; @@ -61,6 +79,9 @@ protected: Block saved_block_sample; Block result_sample_block; + Names key_names_left; + Names key_names_right; + ~NotJoined() = default; private: diff --git a/tests/queries/0_stateless/00878_join_unexpected_results.reference b/tests/queries/0_stateless/00878_join_unexpected_results.reference index a389cb47a96..1630e30d641 100644 --- a/tests/queries/0_stateless/00878_join_unexpected_results.reference +++ b/tests/queries/0_stateless/00878_join_unexpected_results.reference @@ -23,10 +23,15 @@ join_use_nulls = 1 - \N \N - +1 1 \N \N +2 2 \N \N - 1 1 1 1 2 2 \N \N - +1 1 1 1 +- +2 2 \N \N join_use_nulls = 0 1 1 2 2 @@ -49,7 +54,12 @@ join_use_nulls = 0 - - - +1 1 0 0 +2 2 0 0 - 1 1 1 1 2 2 0 0 - +1 1 1 1 +- +2 2 0 0 diff --git a/tests/queries/0_stateless/00878_join_unexpected_results.sql b/tests/queries/0_stateless/00878_join_unexpected_results.sql index 0aef5208b26..0ad7b1122e1 100644 --- a/tests/queries/0_stateless/00878_join_unexpected_results.sql +++ b/tests/queries/0_stateless/00878_join_unexpected_results.sql @@ -4,9 +4,8 @@ drop table if exists s; create table t(a Int64, b Int64) engine = Memory; create table s(a Int64, b Int64) engine = Memory; -insert into t values(1,1); -insert into t values(2,2); 
-insert into s values(1,1); +insert into t values (1,1), (2,2); +insert into s values (1,1); select 'join_use_nulls = 1'; set join_use_nulls = 1; @@ -30,11 +29,13 @@ select * from t left outer join s on (t.a=s.a and t.b=s.b) where s.a is null; select '-'; select s.* from t left outer join s on (t.a=s.a and t.b=s.b) where s.a is null; select '-'; -select t.*, s.* from t left join s on (s.a=t.a and t.b=s.b and t.a=toInt64(2)) order by t.a; -- {serverError 403 } +select t.*, s.* from t left join s on (s.a=t.a and t.b=s.b and t.a=toInt64(2)) order by t.a; select '-'; select t.*, s.* from t left join s on (s.a=t.a) order by t.a; select '-'; -select t.*, s.* from t left join s on (t.b=toInt64(2) and s.a=t.a) where s.b=2; -- {serverError 403 } +select t.*, s.* from t left join s on (t.b=toInt64(1) and s.a=t.a) where s.b=1; +select '-'; +select t.*, s.* from t left join s on (t.b=toInt64(2) and s.a=t.a) where t.b=2; select 'join_use_nulls = 0'; set join_use_nulls = 0; @@ -58,11 +59,13 @@ select '-'; select '-'; -- select s.* from t left outer join s on (t.a=s.a and t.b=s.b) where s.a is null; -- TODO select '-'; -select t.*, s.* from t left join s on (s.a=t.a and t.b=s.b and t.a=toInt64(2)) order by t.a; -- {serverError 403 } +select t.*, s.* from t left join s on (s.a=t.a and t.b=s.b and t.a=toInt64(2)) order by t.a; select '-'; select t.*, s.* from t left join s on (s.a=t.a) order by t.a; select '-'; -select t.*, s.* from t left join s on (t.b=toInt64(2) and s.a=t.a) where s.b=2; -- {serverError 403 } +select t.*, s.* from t left join s on (t.b=toInt64(1) and s.a=t.a) where s.b=1; +select '-'; +select t.*, s.* from t left join s on (t.b=toInt64(2) and s.a=t.a) where t.b=2; drop table t; drop table s; diff --git a/tests/queries/0_stateless/01095_tpch_like_smoke.reference b/tests/queries/0_stateless/01095_tpch_like_smoke.reference index e47b402bf9f..8cdcc2b015f 100644 --- a/tests/queries/0_stateless/01095_tpch_like_smoke.reference +++ 
b/tests/queries/0_stateless/01095_tpch_like_smoke.reference @@ -11,7 +11,7 @@ 10 11 12 -13 fail: join predicates +13 14 0.000000 15 fail: correlated subquery diff --git a/tests/queries/0_stateless/01095_tpch_like_smoke.sql b/tests/queries/0_stateless/01095_tpch_like_smoke.sql index ffd2e21dc39..5971178ade5 100644 --- a/tests/queries/0_stateless/01095_tpch_like_smoke.sql +++ b/tests/queries/0_stateless/01095_tpch_like_smoke.sql @@ -476,7 +476,7 @@ group by order by l_shipmode; -select 13, 'fail: join predicates'; -- TODO: Invalid expression for JOIN ON +select 13; select c_count, count(*) as custdist @@ -484,7 +484,7 @@ from ( select c_custkey, - count(o_orderkey) + count(o_orderkey) as c_count from customer left outer join orders on c_custkey = o_custkey @@ -496,7 +496,7 @@ group by c_count order by custdist desc, - c_count desc; -- { serverError 403 } + c_count desc; select 14; select diff --git a/tests/queries/0_stateless/01429_join_on_error_messages.sql b/tests/queries/0_stateless/01429_join_on_error_messages.sql index f9e2647f2e3..6e792e90d42 100644 --- a/tests/queries/0_stateless/01429_join_on_error_messages.sql +++ b/tests/queries/0_stateless/01429_join_on_error_messages.sql @@ -4,8 +4,8 @@ SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON (A.a = arrayJoin([1])); -- { SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON equals(a); -- { serverError 62 } SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON less(a); -- { serverError 62 } -SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b OR a = b; -- { serverError 48 } -SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a > b; -- { serverError 48 } -SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a < b; -- { serverError 48 } -SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a >= b; -- { serverError 48 } -SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a <= b; -- { serverError 48 } +SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b OR a = b; -- { 
serverError 403 } +SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a > b; -- { serverError 403 } +SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a < b; -- { serverError 403 } +SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a >= b; -- { serverError 403 } +SELECT 1 FROM (select 1 a) A JOIN (select 1 b) B ON a = b AND a <= b; -- { serverError 403 } diff --git a/tests/queries/0_stateless/01881_join_on_conditions.reference b/tests/queries/0_stateless/01881_join_on_conditions.reference new file mode 100644 index 00000000000..e1fac0e7dc3 --- /dev/null +++ b/tests/queries/0_stateless/01881_join_on_conditions.reference @@ -0,0 +1,108 @@ +-- hash_join -- +-- +222 2 +222 222 +333 333 +-- +222 222 +333 333 +-- +222 +333 +-- +1 +1 +1 +1 +1 +1 +1 +1 +1 +-- +2 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +-- +222 2 +333 3 +222 2 +333 3 +-- +0 2 AAA a +0 4 CCC CCC +1 111 111 0 +2 222 2 0 +2 222 222 2 AAA AAA +3 333 333 3 BBB BBB +-- +2 222 2 2 AAA a +2 222 222 2 AAA AAA +-- partial_merge -- +-- +222 2 +222 222 +333 333 +-- +222 222 +333 333 +-- +222 +333 +-- +1 +1 +1 +1 +1 +1 +1 +1 +1 +-- +2 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +2 +3 +-- +222 2 +333 3 +222 2 +333 3 +-- +0 2 AAA a +0 4 CCC CCC +1 111 111 0 +2 222 2 0 +2 222 222 2 AAA AAA +3 333 333 3 BBB BBB +-- +2 222 2 2 AAA a +2 222 222 2 AAA AAA diff --git a/tests/queries/0_stateless/01881_join_on_conditions.sql b/tests/queries/0_stateless/01881_join_on_conditions.sql new file mode 100644 index 00000000000..a34c413845b --- /dev/null +++ b/tests/queries/0_stateless/01881_join_on_conditions.sql @@ -0,0 +1,141 @@ +DROP TABLE IF EXISTS t1; +DROP TABLE IF EXISTS t2; +DROP TABLE IF EXISTS t2_nullable; +DROP TABLE IF EXISTS t2_lc; + +CREATE TABLE t1 (`id` Int32, key String, key2 String) ENGINE = TinyLog; +CREATE TABLE t2 (`id` Int32, key String, key2 String) ENGINE = TinyLog; +CREATE TABLE t2_nullable (`id` Int32, key String, key2 Nullable(String)) ENGINE = TinyLog; +CREATE TABLE 
t2_lc (`id` Int32, key String, key2 LowCardinality(String)) ENGINE = TinyLog; + +INSERT INTO t1 VALUES (1, '111', '111'),(2, '222', '2'),(2, '222', '222'),(3, '333', '333'); +INSERT INTO t2 VALUES (2, 'AAA', 'AAA'),(2, 'AAA', 'a'),(3, 'BBB', 'BBB'),(4, 'CCC', 'CCC'); +INSERT INTO t2_nullable VALUES (2, 'AAA', 'AAA'),(2, 'AAA', 'a'),(3, 'BBB', NULL),(4, 'CCC', 'CCC'); +INSERT INTO t2_lc VALUES (2, 'AAA', 'AAA'),(2, 'AAA', 'a'),(3, 'BBB', 'BBB'),(4, 'CCC', 'CCC'); + +SELECT '-- hash_join --'; + +SELECT '--'; +SELECT t1.key, t1.key2 FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2; +SELECT '--'; +SELECT t1.key, t1.key2 FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2; + +SELECT '--'; +SELECT t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2; + +SELECT '--'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t2.id > 2; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t2.id == 3; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t2.key2 == 'BBB'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND (t2.key == t2.key2 OR isNull(t2.key2)) AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_lc as t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND isNull(t2.key2); +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND t1.key2 like '33%'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND 
t1.key == t1.key2 AND t1.id >= length(t1.key); + +-- DISTINCT is used to remove the difference between 'hash' and 'merge' join: 'merge' doesn't support `any_join_distinct_right_table_keys` + +SELECT '--'; +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND t2.key2 != ''; +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(t2.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(t2.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(toNullable(t2.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(toLowCardinality(t2.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(t1.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(t1.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(toNullable(t1.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(toLowCardinality(t1.key2 != '')); + +SELECT '--'; +SELECT DISTINCT t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND e; +-- `e + 1` is UInt16 +SELECT DISTINCT t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND e + 1; -- { serverError 403 } +SELECT DISTINCT t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toUInt8(e + 1); + +SELECT '--'; +SELECT t1.id, t1.key, t1.key2, t2.id, t2.key, t2.key2 FROM t1 FULL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 ORDER BY t1.id, t2.id; + +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.id; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id + 2; -- { serverError 403 } +SELECT * 
FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.id + 2; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.key; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key; -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON t2.key == t2.key2 AND (t1.id == t2.id OR isNull(t2.key2)); -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON t2.key == t2.key2 OR t1.id == t2.id; -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON (t2.key == t2.key2 AND (t1.key == t1.key2 AND t1.key != 'XXX' OR t1.id == t2.id)) AND t1.id == t2.id; -- { serverError 403 } +-- non-equi conditions containing columns from different tables are not supported yet +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id >= t2.id; -- { serverError 403 } +SELECT * FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.id >= length(t2.key); -- { serverError 403 } + +SELECT '--'; +-- length(t1.key2) == length(t2.key2) is an expression over columns from both tables; it works because it is part of the joining key +SELECT t1.*, t2.* FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND length(t1.key2) == length(t2.key2) AND t1.key != '333'; + +SET join_algorithm = 'partial_merge'; + +SELECT '-- partial_merge --'; + +SELECT '--'; +SELECT t1.key, t1.key2 FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2; +SELECT '--'; +SELECT t1.key, t1.key2 FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2; + +SELECT '--'; +SELECT t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2; + +SELECT '--'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t2.id > 2; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t2.id == 3; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND 
t1.key == t1.key2 AND t2.key2 == 'BBB'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND (t2.key == t2.key2 OR isNull(t2.key2)) AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_lc as t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.key2 == '333'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND isNull(t2.key2); +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND t1.key2 like '33%'; +SELECT '333' = t1.key FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.id >= length(t1.key); + +-- DISTINCT is used to remove the difference between 'hash' and 'merge' join: 'merge' doesn't support `any_join_distinct_right_table_keys` + +SELECT '--'; +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2_nullable as t2 ON t1.id == t2.id AND t2.key2 != ''; +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(t2.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(t2.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(toNullable(t2.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(toLowCardinality(t2.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(t1.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(t1.key2 != ''); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toLowCardinality(toNullable(t1.key2 != '')); +SELECT DISTINCT t1.id FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toNullable(toLowCardinality(t1.key2 != '')); + +SELECT '--'; +SELECT DISTINCT 
t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND e; +-- `e + 1` is UInt16 +SELECT DISTINCT t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND e + 1; -- { serverError 403 } +SELECT DISTINCT t1.key, toUInt8(t1.id) as e FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND toUInt8(e + 1); + +SELECT '--'; +SELECT t1.id, t1.key, t1.key2, t2.id, t2.key, t2.key2 FROM t1 FULL JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 ORDER BY t1.id, t2.id; + +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.id; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id + 2; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.id + 2; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.key; -- { serverError 403 } +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t2.key; -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON t2.key == t2.key2 AND (t1.id == t2.id OR isNull(t2.key2)); -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON t2.key == t2.key2 OR t1.id == t2.id; -- { serverError 403 } +SELECT * FROM t1 JOIN t2 ON (t2.key == t2.key2 AND (t1.key == t1.key2 AND t1.key != 'XXX' OR t1.id == t2.id)) AND t1.id == t2.id; -- { serverError 403 } +-- non-equi conditions containing columns from different tables aren't supported yet +SELECT * FROM t1 INNER ALL JOIN t2 ON t1.id == t2.id AND t1.id >= t2.id; -- { serverError 403 } +SELECT * FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND t2.key == t2.key2 AND t1.key == t1.key2 AND t1.id >= length(t2.key); -- { serverError 403 } + +SELECT '--'; +-- length(t1.key2) == length(t2.key2) is an expression over columns from both tables; it works because it is part of the joining key +SELECT t1.*, t2.* FROM t1 INNER ANY JOIN t2 ON t1.id == t2.id AND length(t1.key2) == length(t2.key2) AND t1.key != '333'; 
+ +DROP TABLE IF EXISTS t1; +DROP TABLE IF EXISTS t2; +DROP TABLE IF EXISTS t2_nullable; +DROP TABLE IF EXISTS t2_lc;