diff --git a/src/Interpreters/ExpressionAnalyzer.cpp b/src/Interpreters/ExpressionAnalyzer.cpp index 7059240e408..9cbf6258199 100644 --- a/src/Interpreters/ExpressionAnalyzer.cpp +++ b/src/Interpreters/ExpressionAnalyzer.cpp @@ -1047,10 +1047,13 @@ bool SelectQueryExpressionAnalyzer::appendJoinLeftKeys(ExpressionActionsChain & return true; } -JoinPtr SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain, ActionsDAGPtr & converting_join_columns) +JoinPtr SelectQueryExpressionAnalyzer::appendJoin( + ExpressionActionsChain & chain, + ActionsDAGPtr & converting_join_columns) { const ColumnsWithTypeAndName & left_sample_columns = chain.getLastStep().getResultColumns(); - JoinPtr table_join = makeTableJoin(*syntax->ast_join, left_sample_columns, converting_join_columns); + + JoinPtr join = makeJoin(*syntax->ast_join, left_sample_columns, converting_join_columns); if (converting_join_columns) { @@ -1060,9 +1063,9 @@ JoinPtr SelectQueryExpressionAnalyzer::appendJoin(ExpressionActionsChain & chain ExpressionActionsChain::Step & step = chain.lastStep(columns_after_array_join); chain.steps.push_back(std::make_unique( - syntax->analyzed_join, table_join, step.getResultColumns())); + syntax->analyzed_join, join, step.getResultColumns())); chain.addStep(); - return table_join; + return join; } static ActionsDAGPtr createJoinedBlockActions(ContextPtr context, const TableJoin & analyzed_join) @@ -1199,7 +1202,7 @@ std::shared_ptr tryKeyValueJoin(std::shared_ptr a return std::make_shared(analyzed_join, right_sample_block, storage); } -JoinPtr SelectQueryExpressionAnalyzer::makeTableJoin( +JoinPtr SelectQueryExpressionAnalyzer::makeJoin( const ASTTablesInSelectQueryElement & join_element, const ColumnsWithTypeAndName & left_columns, ActionsDAGPtr & left_convert_actions) diff --git a/src/Interpreters/ExpressionAnalyzer.h b/src/Interpreters/ExpressionAnalyzer.h index 278415f6429..167c3dfd918 100644 --- a/src/Interpreters/ExpressionAnalyzer.h +++ b/src/Interpreters/ExpressionAnalyzer.h @@ -375,7 +375,7 @@ private: NameSet required_result_columns; SelectQueryOptions query_options; - JoinPtr makeTableJoin( + JoinPtr makeJoin( const ASTTablesInSelectQueryElement & join_element, const ColumnsWithTypeAndName & left_columns, ActionsDAGPtr & left_convert_actions); diff --git a/src/Interpreters/TableJoin.cpp b/src/Interpreters/TableJoin.cpp index 20842023bab..db75a86fff6 100644 --- a/src/Interpreters/TableJoin.cpp +++ b/src/Interpreters/TableJoin.cpp @@ -507,14 +507,18 @@ static void renameIfNeeded(String & name, const NameToNameMap & renames) } std::pair -TableJoin::createConvertingActions(const ColumnsWithTypeAndName & left_sample_columns, const ColumnsWithTypeAndName & right_sample_columns) +TableJoin::createConvertingActions( + const ColumnsWithTypeAndName & left_sample_columns, + const ColumnsWithTypeAndName & right_sample_columns) { inferJoinKeyCommonType(left_sample_columns, right_sample_columns, !isSpecialStorage()); NameToNameMap left_key_column_rename; NameToNameMap right_key_column_rename; - auto left_converting_actions = applyKeyConvertToTable(left_sample_columns, left_type_map, left_key_column_rename, forceNullableLeft()); - auto right_converting_actions = applyKeyConvertToTable(right_sample_columns, right_type_map, right_key_column_rename, forceNullableRight()); + auto left_converting_actions = applyKeyConvertToTable( + left_sample_columns, left_type_map, left_key_column_rename, forceNullableLeft()); + auto right_converting_actions = applyKeyConvertToTable( + right_sample_columns, right_type_map, right_key_column_rename, forceNullableRight()); { auto log_actions = [](const String & side, const ActionsDAGPtr & dag) @@ -536,7 +540,18 @@ TableJoin::createConvertingActions(const ColumnsWithTypeAndName & left_sample_co else { LOG_DEBUG(&Poco::Logger::get("TableJoin"), "{} JOIN converting actions: empty", side); + return; } + auto format_cols = [](const auto & cols) -> std::string + { + std::vector str_cols; + str_cols.reserve(cols.size()); + for (const auto & col : cols) + str_cols.push_back(fmt::format("'{}': {}", col.name, col.type->getName())); + return fmt::format("[{}]", fmt::join(str_cols, ", ")); + }; + LOG_DEBUG(&Poco::Logger::get("TableJoin"), "{} JOIN converting actions: {} -> {}", + side, format_cols(dag->getRequiredColumns()), format_cols(dag->getResultColumns())); }; log_actions("Left", left_converting_actions); log_actions("Right", right_converting_actions); @@ -646,10 +661,18 @@ static ActionsDAGPtr changeKeyTypes(const ColumnsWithTypeAndName & cols_src, if (!has_some_to_do) return nullptr; - return ActionsDAG::makeConvertingActions(cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, add_new_cols, &key_column_rename); + return ActionsDAG::makeConvertingActions( + /* source= */ cols_src, + /* result= */ cols_dst, + /* mode= */ ActionsDAG::MatchColumnsMode::Name, + /* ignore_constant_values= */ true, + /* add_casted_columns= */ add_new_cols, + /* new_names= */ &key_column_rename); } -static ActionsDAGPtr changeTypesToNullable(const ColumnsWithTypeAndName & cols_src, const NameSet & exception_cols) +static ActionsDAGPtr changeTypesToNullable( + const ColumnsWithTypeAndName & cols_src, + const NameSet & exception_cols) { ColumnsWithTypeAndName cols_dst = cols_src; bool has_some_to_do = false; @@ -664,7 +687,14 @@ static ActionsDAGPtr changeTypesToNullable(const ColumnsWithTypeAndName & cols_s if (!has_some_to_do) return nullptr; - return ActionsDAG::makeConvertingActions(cols_src, cols_dst, ActionsDAG::MatchColumnsMode::Name, true, false, nullptr); + + return ActionsDAG::makeConvertingActions( + /* source= */ cols_src, + /* result= */ cols_dst, + /* mode= */ ActionsDAG::MatchColumnsMode::Name, + /* ignore_constant_values= */ true, + /* add_casted_columns= */ false, + /* new_names= */ nullptr); } ActionsDAGPtr TableJoin::applyKeyConvertToTable( @@ -679,7 +709,7 @@ ActionsDAGPtr TableJoin::applyKeyConvertToTable( /// Create DAG to make columns nullable if needed if (make_nullable) { - /// Do not need to make nullable temporary columns that would be used only as join keys, but now shown to user + /// Do not need to make nullable temporary columns that would be used only as join keys, but not shown to user NameSet cols_not_nullable; for (const auto & t : key_column_rename) cols_not_nullable.insert(t.second); diff --git a/src/Interpreters/TableJoin.h b/src/Interpreters/TableJoin.h index e44a0657da3..45a4106d040 100644 --- a/src/Interpreters/TableJoin.h +++ b/src/Interpreters/TableJoin.h @@ -156,7 +156,8 @@ private: /// Create converting actions and change key column names if required ActionsDAGPtr applyKeyConvertToTable( - const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, NameToNameMap & key_column_rename, + const ColumnsWithTypeAndName & cols_src, const NameToTypeMap & type_mapping, + NameToNameMap & key_column_rename, bool make_nullable) const; void addKey(const String & left_name, const String & right_name, const ASTPtr & left_ast, const ASTPtr & right_ast = nullptr); @@ -278,7 +279,9 @@ public: /// For `USING` join we will convert key columns inplace and affect into types in the result table /// For `JOIN ON` we will create new columns with converted keys to join by. std::pair - createConvertingActions(const ColumnsWithTypeAndName & left_sample_columns, const ColumnsWithTypeAndName & right_sample_columns); + createConvertingActions( + const ColumnsWithTypeAndName & left_sample_columns, + const ColumnsWithTypeAndName & right_sample_columns); void setAsofInequality(ASOF::Inequality inequality) { asof_inequality = inequality; } ASOF::Inequality getAsofInequality() { return asof_inequality; } diff --git a/src/Processors/Merges/IMergingTransform.h b/src/Processors/Merges/IMergingTransform.h index 061750d91e4..f7178d7b1ae 100644 --- a/src/Processors/Merges/IMergingTransform.h +++ b/src/Processors/Merges/IMergingTransform.h @@ -3,6 +3,7 @@ #include #include #include +#include namespace DB { diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index 37e81ab500b..09319438c1d 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -7,12 +8,13 @@ #include #include #include -#include "Columns/ColumnsNumber.h" -#include "Columns/IColumn.h" -#include "Core/SortCursor.h" -#include "Parsers/ASTTablesInSelectQuery.h" -#include "base/defines.h" -#include "base/types.h" +#include +#include +#include +#include +#include +#include +#include namespace DB @@ -42,20 +44,89 @@ FullMergeJoinCursor createCursor(const Block & block, const Names & columns) } -int ALWAYS_INLINE compareCursors(const SortCursor & lhs, const SortCursor & rhs) +template +int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1) +{ + if constexpr (has_left_nulls && has_right_nulls) + { + const auto * left_nullable = checkAndGetColumn(left_column); + const auto * right_nullable = checkAndGetColumn(right_column); + + if (left_nullable && right_nullable) + { + int res = left_column.compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint); + if (res) + return res; + + /// NULL != NULL case + if (left_column.isNullAt(lhs_pos)) + return null_direction_hint; + + return 0; + } + } + + if constexpr (has_left_nulls) + { + if (const auto * left_nullable = checkAndGetColumn(left_column)) + { + if (left_column.isNullAt(lhs_pos)) + return null_direction_hint; + return left_nullable->getNestedColumn().compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint); + } + } + + if constexpr (has_right_nulls) + { + if (const auto * right_nullable = checkAndGetColumn(right_column)) + { + if (right_column.isNullAt(rhs_pos)) + return -null_direction_hint; + return left_column.compareAt(lhs_pos, rhs_pos, right_nullable->getNestedColumn(), null_direction_hint); + } + } + + return left_column.compareAt(lhs_pos, rhs_pos, right_column, null_direction_hint); +} + +/// If on_pos == true, compare two columns at specified positions. +/// Otherwise, compare two columns at the current positions, `lpos` and `rpos` are ignored. +template +int ALWAYS_INLINE compareCursors(const Cursor & lhs, const Cursor & rhs, + [[ maybe_unused ]] size_t lpos = 0, + [[ maybe_unused ]] size_t rpos = 0) { for (size_t i = 0; i < lhs->sort_columns_size; ++i) { const auto & desc = lhs->desc[i]; int direction = desc.direction; int nulls_direction = desc.nulls_direction; - int res = direction * lhs->sort_columns[i]->compareAt(lhs->getRow(), rhs->getRow(), *(rhs.impl->sort_columns[i]), nulls_direction); - if (res != 0) - return res; + + int cmp = direction * nullableCompareAt( + *lhs->sort_columns[i], + *rhs->sort_columns[i], + on_pos ? lpos : lhs->getRow(), + on_pos ? rpos : rhs->getRow(), + nulls_direction); + if (cmp != 0) + return cmp; } return 0; } +bool ALWAYS_INLINE totallyLess(const FullMergeJoinCursor & lhs, const FullMergeJoinCursor & rhs) +{ + if (lhs->rows == 0 || rhs->rows == 0) + return false; + + if (!lhs->isValid() || !rhs->isValid()) + return false; + + /// The last row of this cursor is no larger than the first row of the another cursor. + int cmp = compareCursors(lhs, rhs, lhs->rows - 1, 0); + return cmp < 0; +} + void addIndexColumn(const Columns & columns, const IColumn & indices, Chunk & result) { for (const auto & col : columns) @@ -108,7 +179,7 @@ void MergeJoinAlgorithm::initialize(Inputs inputs) if (inputs.size() != 2) throw Exception("MergeJoinAlgorithm requires exactly two inputs", ErrorCodes::LOGICAL_ERROR); LOG_DEBUG(log, "MergeJoinAlgorithm initialize, number of inputs: {}", inputs.size()); - current_inputs.resize(inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { createSampleChunk(inputs[i].chunk, sample_chunks.emplace_back()); @@ -130,19 +201,15 @@ void MergeJoinAlgorithm::consume(Input & input, size_t source_num) { LOG_DEBUG(log, "Consume from {} chunk: {}", source_num, bool(input.chunk)); - left_stream_finished = left_stream_finished || (!input.chunk && source_num == 0); - right_stream_finished = right_stream_finished || (!input.chunk && source_num == 1); + if (!input.chunk) + cursors[source_num].completeAll(); prepareChunk(input.chunk); if (input.chunk.getNumRows() >= EMPTY_VALUE) throw Exception("Too many rows in input", ErrorCodes::TOO_MANY_ROWS); - current_inputs[source_num] = std::move(input); - if (current_inputs[source_num].chunk) - { - cursors[source_num].getImpl().reset(current_inputs[source_num].chunk.getColumns(), {}, current_inputs[source_num].permutation); - } + cursors[source_num].setInput(std::move(input)); } static size_t ALWAYS_INLINE rowsLeft(SortCursor cursor) @@ -215,92 +282,77 @@ static void leftOrFullAny(SortCursor left_cursor, SortCursor right_cursor, Padde } } +static Chunk createBlockWithDefaults(const Chunk & lhs, const Chunk & rhs) +{ + Chunk result; + size_t num_rows = std::max(lhs.getNumRows(), rhs.getNumRows()); + createSampleChunk(lhs, result, num_rows); + createSampleChunk(rhs, result, num_rows); + return result; +} + IMergingAlgorithm::Status MergeJoinAlgorithm::merge() { - if (current_inputs[0].skip_last_row || current_inputs[1].skip_last_row) - throw Exception("MergeJoinAlgorithm does not support skipLastRow", ErrorCodes::LOGICAL_ERROR); - - if (!current_inputs[0].chunk && !left_stream_finished) - { + if (!cursors[0].isValid() && !cursors[0].fullyCompleted()) return Status(0); - } - if (!current_inputs[1].chunk && !right_stream_finished) - { + if (!cursors[1].isValid() && !cursors[1].fullyCompleted()) return Status(1); - } JoinKind kind = table_join->getTableJoin().kind(); - if (left_stream_finished && right_stream_finished) + if (cursors[0].fullyCompleted() && cursors[1].fullyCompleted()) { return Status({}, true); } - if (isInner(kind) && (left_stream_finished || right_stream_finished)) + if (isInner(kind) && (cursors[0].fullyCompleted() || cursors[1].fullyCompleted())) { + LOG_DEBUG(log, "{}:{} ", __FILE__, __LINE__); return Status({}, true); } - auto create_block_with_defaults = [] (const Chunk & lhs, const Chunk & rhs) -> Chunk + if (cursors[0].fullyCompleted() && isRightOrFull(kind)) { - Chunk result; - size_t num_rows = std::max(lhs.getNumRows(), rhs.getNumRows()); - createSampleChunk(lhs, result, num_rows); - createSampleChunk(rhs, result, num_rows); - return result; - }; - - if (isLeftOrFull(kind) && right_stream_finished) - { - Chunk result = create_block_with_defaults(current_inputs[0].chunk, sample_chunks[1]); - current_inputs[0] = {}; - return Status(std::move(result), left_stream_finished && right_stream_finished); + Chunk result = createBlockWithDefaults(sample_chunks[0], cursors[1].moveCurrentChunk()); + return Status(std::move(result)); } - if (isRightOrFull(kind) && left_stream_finished) + if (isLeftOrFull(kind) && cursors[1].fullyCompleted()) { - Chunk result = create_block_with_defaults(sample_chunks[0], current_inputs[1].chunk); - current_inputs[1] = {}; - return Status(std::move(result), left_stream_finished && right_stream_finished); + Chunk result = createBlockWithDefaults(cursors[0].moveCurrentChunk(), sample_chunks[1]); + return Status(std::move(result)); } - SortCursor left_cursor = cursors[0].getCursor(); - SortCursor right_cursor = cursors[1].getCursor(); - - if (!left_cursor->isValid() || (right_cursor->isValid() && left_cursor.totallyLessOrEquals(right_cursor))) + if (!cursors[0]->isValid() || totallyLess(cursors[0], cursors[1])) { - current_inputs[0] = {}; - if (left_stream_finished) + if (cursors[0]->isValid() && isLeft(kind)) { - return Status({}, true); + Chunk result = createBlockWithDefaults(cursors[0].moveCurrentChunk(), sample_chunks[1]); + return Status(std::move(result), false); } + cursors[0].moveCurrentChunk(); + if (cursors[0].fullyCompleted()) + return Status({}, true); return Status(0); } - if (!right_cursor->isValid() || (left_cursor->isValid() && right_cursor.totallyLessOrEquals(left_cursor))) - { - current_inputs[1] = {}; - if (right_stream_finished) - { - return Status({}, true); - } - return Status(1); - } + // if (!cursors[1]->isValid() || totallyLess(cursors[1], cursors[0])) + // ... auto left_map = ColumnUInt64::create(); auto right_map = ColumnUInt64::create(); if (isInner(kind)) { - leftOrFullAny(left_cursor, right_cursor, left_map->getData(), right_map->getData()); + leftOrFullAny(cursors[0].getCursor(), cursors[1].getCursor(), left_map->getData(), right_map->getData()); } else if (isLeft(kind)) { - leftOrFullAny(left_cursor, right_cursor, left_map->getData(), right_map->getData()); + leftOrFullAny(cursors[0].getCursor(), cursors[1].getCursor(), left_map->getData(), right_map->getData()); } else if (isRight(kind)) { - leftOrFullAny(left_cursor, right_cursor, left_map->getData(), right_map->getData()); + leftOrFullAny(cursors[0].getCursor(), cursors[1].getCursor(), left_map->getData(), right_map->getData()); } else { @@ -308,20 +360,10 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge() } Chunk result; - addIndexColumn(current_inputs[0].chunk.getColumns(), *left_map, result); - addIndexColumn(current_inputs[1].chunk.getColumns(), *right_map, result); + addIndexColumn(cursors[0].getCurrentChunk().getColumns(), *left_map, result); + addIndexColumn(cursors[1].getCurrentChunk().getColumns(), *right_map, result); - if (!left_cursor->isValid()) - { - current_inputs[0] = {}; - } - - if (!right_cursor->isValid()) - { - current_inputs[1] = {}; - } - - return Status(std::move(result), left_stream_finished && right_stream_finished); + return Status(std::move(result), cursors[0].fullyCompleted() && cursors[1].fullyCompleted()); } MergeJoinTransform::MergeJoinTransform( diff --git a/src/Processors/Transforms/MergeJoinTransform.h b/src/Processors/Transforms/MergeJoinTransform.h index fef2c3fd720..673724e188c 100644 --- a/src/Processors/Transforms/MergeJoinTransform.h +++ b/src/Processors/Transforms/MergeJoinTransform.h @@ -65,8 +65,50 @@ public: return impl; } + Chunk moveCurrentChunk() + { + Chunk res = std::move(current_input.chunk); + current_input = {}; + return res; + } + + const Chunk & getCurrentChunk() const + { + return current_input.chunk; + } + + void setInput(IMergingAlgorithm::Input && input) + { + current_input = std::move(input); + + if (!current_input.chunk) + completeAll(); + + if (current_input.skip_last_row) + throw Exception("MergeJoinAlgorithm does not support skipLastRow", ErrorCodes::LOGICAL_ERROR); + + if (current_input.chunk) + { + impl.reset(current_input.chunk.getColumns(), {}, current_input.permutation); + } + } + + bool isValid() const + { + return current_input.chunk && impl.isValid(); + } + + bool fullyCompleted() const { return fully_completed; } + + void completeAll() { fully_completed = true; } + + SortCursorImpl * operator-> () { return &impl; } + const SortCursorImpl * operator-> () const { return &impl; } + private: SortCursorImpl impl; + IMergingAlgorithm::Input current_input; + bool fully_completed = false; // bool has_left_nullable = false; // bool has_right_nullable = false; }; @@ -88,14 +130,9 @@ private: SortDescription left_desc; SortDescription right_desc; - std::vector current_inputs; std::vector cursors; - std::vector sample_chunks; - bool left_stream_finished = false; - bool right_stream_finished = false; - JoinPtr table_join; Poco::Logger * log; };