From a1a4bd3514f72b4545fbb734749dcb9d7f513007 Mon Sep 17 00:00:00 2001 From: nemonlou Date: Wed, 6 Mar 2024 10:42:30 +0800 Subject: [PATCH 1/5] make nulls direction configurable for FullSortingMergeJoin --- src/Core/Settings.h | 1 + src/Interpreters/InterpreterSelectQuery.cpp | 7 +- src/Processors/QueryPlan/JoinStep.cpp | 8 +- src/Processors/QueryPlan/JoinStep.h | 4 +- .../Transforms/MergeJoinTransform.cpp | 82 ++++++++++++------- .../Transforms/MergeJoinTransform.h | 16 +++- src/QueryPipeline/QueryPipelineBuilder.cpp | 3 +- src/QueryPipeline/QueryPipelineBuilder.h | 1 + 8 files changed, 84 insertions(+), 38 deletions(-) diff --git a/src/Core/Settings.h b/src/Core/Settings.h index a3c5638d97f..8d48b3f5e68 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -891,6 +891,7 @@ class IColumn; M(Int64, ignore_cold_parts_seconds, 0, "Only available in ClickHouse Cloud. Exclude new data parts from SELECT queries until they're either pre-warmed (see cache_populated_by_fetch) or this many seconds old. Only for Replicated-/SharedMergeTree.", 0) \ M(Int64, prefer_warmed_unmerged_parts_seconds, 0, "Only available in ClickHouse Cloud. If a merged part is less than this many seconds old and is not pre-warmed (see cache_populated_by_fetch), but all its source parts are available and pre-warmed, SELECT queries will read from those parts instead. Only for ReplicatedMergeTree. Note that this only checks whether CacheWarmer processed the part; if the part was fetched into cache by something else, it'll still be considered cold until CacheWarmer gets to it; if it was warmed, then evicted from cache, it'll still be considered warm.", 0) \ M(Bool, iceberg_engine_ignore_schema_evolution, false, "Ignore schema evolution in Iceberg table engine and read all data using latest schema saved on table creation. Note that it can lead to incorrect result", 0) \ + M(Bool, nulls_biggest_in_smj, true, "Treat nulls as biggest in sort. 
Used in sort merge join for compare null keys.", 0) \ // End of COMMON_SETTINGS // Please add settings related to formats into the FORMAT_FACTORY_SETTINGS, move obsolete settings to OBSOLETE_SETTINGS and obsolete format settings to OBSOLETE_FORMAT_SETTINGS. diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 7c87dadfce6..6f0a9fa9bfb 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -1693,9 +1693,10 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional

( query_plan.getCurrentDataStream(), joined_plan->getCurrentDataStream(), expressions.join, settings.max_block_size, max_streams, - analysis_result.optimize_read_in_order); + analysis_result.optimize_read_in_order, + null_direct_hint); join_step->setStepDescription(fmt::format("JOIN {}", expressions.join->pipelineType())); std::vector plans; diff --git a/src/Processors/QueryPlan/JoinStep.cpp b/src/Processors/QueryPlan/JoinStep.cpp index 1931b1eb3a1..0c46ce7893d 100644 --- a/src/Processors/QueryPlan/JoinStep.cpp +++ b/src/Processors/QueryPlan/JoinStep.cpp @@ -44,8 +44,10 @@ JoinStep::JoinStep( JoinPtr join_, size_t max_block_size_, size_t max_streams_, - bool keep_left_read_in_order_) - : join(std::move(join_)), max_block_size(max_block_size_), max_streams(max_streams_), keep_left_read_in_order(keep_left_read_in_order_) + bool keep_left_read_in_order_, + int null_direction_hint_) + : join(std::move(join_)), max_block_size(max_block_size_), max_streams(max_streams_), keep_left_read_in_order(keep_left_read_in_order_), + null_direction_hint(null_direction_hint_) { updateInputStreams(DataStreams{left_stream_, right_stream_}); } @@ -58,7 +60,7 @@ QueryPipelineBuilderPtr JoinStep::updatePipeline(QueryPipelineBuilders pipelines if (join->pipelineType() == JoinPipelineType::YShaped) { auto joined_pipeline = QueryPipelineBuilder::joinPipelinesYShaped( - std::move(pipelines[0]), std::move(pipelines[1]), join, output_stream->header, max_block_size, &processors); + std::move(pipelines[0]), std::move(pipelines[1]), join, output_stream->header, max_block_size, null_direction_hint, &processors); joined_pipeline->resize(max_streams); return joined_pipeline; } diff --git a/src/Processors/QueryPlan/JoinStep.h b/src/Processors/QueryPlan/JoinStep.h index a9059a083fe..08909ce48a9 100644 --- a/src/Processors/QueryPlan/JoinStep.h +++ b/src/Processors/QueryPlan/JoinStep.h @@ -19,7 +19,8 @@ public: JoinPtr join_, size_t max_block_size_, size_t max_streams_, - bool keep_left_read_in_order_); 
+ bool keep_left_read_in_order_, + int null_direction_hint_ = 1); String getName() const override { return "Join"; } @@ -42,6 +43,7 @@ private: size_t max_block_size; size_t max_streams; bool keep_left_read_in_order; + int null_direction_hint; }; /// Special step for the case when Join is already filled. diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index 2d313d4ea5c..c8e3a806a9f 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -43,7 +43,7 @@ FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns) } template -int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1) +int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint) { if constexpr (has_left_nulls && has_right_nulls) { @@ -88,35 +88,36 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, } int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos, - const SortCursorImpl & rhs, size_t rpos) + const SortCursorImpl & rhs, size_t rpos, + int null_direction_hint) { for (size_t i = 0; i < lhs.sort_columns_size; ++i) { /// TODO(@vdimir): use nullableCompareAt only if there's nullable columns - int cmp = nullableCompareAt(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos); + int cmp = nullableCompareAt(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint); if (cmp != 0) return cmp; } return 0; } -int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs) +int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint) { - return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow()); + return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), 
null_direction_hint); } -bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs) +bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint) { /// The last row of left cursor is less than the current row of the right cursor. - int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow()); + int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), null_direction_hint); return cmp < 0; } -int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs) +int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint) { - if (totallyLess(lhs, rhs)) + if (totallyLess(lhs, rhs, null_direction_hint)) return -1; - if (totallyLess(rhs, lhs)) + if (totallyLess(rhs, lhs, null_direction_hint)) return 1; return 0; } @@ -270,9 +271,11 @@ bool FullMergeJoinCursor::fullyCompleted() const MergeJoinAlgorithm::MergeJoinAlgorithm( JoinPtr table_join_, const Blocks & input_headers, - size_t max_block_size_) + size_t max_block_size_, + int null_direction_hint_) : table_join(table_join_) , max_block_size(max_block_size_) + , null_direction_hint(null_direction_hint_) , log(getLogger("MergeJoinAlgorithm")) { if (input_headers.size() != 2) @@ -356,7 +359,7 @@ void MergeJoinAlgorithm::consume(Input & input, size_t source_num) cursors[source_num]->setChunk(std::move(input.chunk)); } -template +template struct AllJoinImpl { constexpr static bool enabled = isInner(kind) || isLeft(kind) || isRight(kind) || isFull(kind); @@ -382,7 +385,7 @@ struct AllJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, nullDirection(nullOrder)); if (cmp == 0) { size_t lnum = nextDistinct(left_cursor.cursor); @@ -432,19 +435,37 @@ struct AllJoinImpl } }; -template class Impl, typename ... Args> -void dispatchKind(JoinKind kind, Args && ... 
args) +template class Impl, typename ... Args> +void dispatchKind(JoinKind kind, int null_direction_hint, Args && ... args) { - if (Impl::enabled && kind == JoinKind::Inner) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Left) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Right) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Full) - return Impl::join(std::forward(args)...); + if (isSmall(null_direction_hint)) + { + if (Impl::enabled && kind == JoinKind::Inner) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Left) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Right) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Full) + return Impl::join(std::forward(args)...); + else + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); + + } else - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); + { + if (Impl::enabled && kind == JoinKind::Inner) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Left) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Right) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Full) + return Impl::join(std::forward(args)...); + else + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); + + } } std::optional MergeJoinAlgorithm::handleAllJoinState() @@ -517,7 +538,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind) { PaddedPODArray idx_map[2]; - dispatchKind(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state); + dispatchKind(kind, null_direction_hint, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], 
all_join_state); assert(idx_map[0].size() == idx_map[1].size()); Chunk result; @@ -567,7 +588,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind) } -template +template struct AnyJoinImpl { constexpr static bool enabled = isInner(kind) || isLeft(kind) || isRight(kind); @@ -599,7 +620,7 @@ struct AnyJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, nullDirection(order)); if (cmp == 0) { if constexpr (isLeftOrFull(kind)) @@ -723,7 +744,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind) PaddedPODArray idx_map[2]; size_t prev_pos[] = {current_left.getRow(), current_right.getRow()}; - dispatchKind(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state); + dispatchKind(kind, null_direction_hint, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state); assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size()); size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size()); @@ -816,7 +837,7 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge() } /// check if blocks are not intersecting at all - if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor); cmp != 0) + if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor, null_direction_hint); cmp != 0) { if (cmp < 0) { @@ -851,6 +872,7 @@ MergeJoinTransform::MergeJoinTransform( const Blocks & input_headers, const Block & output_header, size_t max_block_size, + int null_direction_hint_, UInt64 limit_hint_) : IMergingTransform( input_headers, @@ -859,7 +881,7 @@ MergeJoinTransform::MergeJoinTransform( limit_hint_, /* always_read_till_end_= */ false, /* empty_chunk_on_finish_= */ true, - table_join, input_headers, max_block_size) + table_join, input_headers, max_block_size, null_direction_hint_) , log(getLogger("MergeJoinTransform")) { 
LOG_TRACE(log, "Use MergeJoinTransform"); diff --git a/src/Processors/Transforms/MergeJoinTransform.h b/src/Processors/Transforms/MergeJoinTransform.h index 959550067f7..43485321122 100644 --- a/src/Processors/Transforms/MergeJoinTransform.h +++ b/src/Processors/Transforms/MergeJoinTransform.h @@ -220,6 +220,17 @@ private: bool recieved_all_blocks = false; }; +/// Join method. +enum class NullOrder +{ + SMALLEST, /// null is treated as smallest + BIGGEST /// null is treated as biggest +}; + +inline constexpr bool isSmall(int null_direction) { return null_direction == 1; } + +inline constexpr int nullDirection(NullOrder order) {return order == NullOrder::SMALLEST ? 1 : -1;} + /* * This class is used to join chunks from two sorted streams. * It is used in MergeJoinTransform. @@ -227,7 +238,8 @@ private: class MergeJoinAlgorithm final : public IMergingAlgorithm { public: - explicit MergeJoinAlgorithm(JoinPtr table_join, const Blocks & input_headers, size_t max_block_size_); + explicit MergeJoinAlgorithm(JoinPtr table_join, const Blocks & input_headers, size_t max_block_size_, + int null_direction_hint = 1); const char * getName() const override { return "MergeJoinAlgorithm"; } void initialize(Inputs inputs) override; @@ -258,6 +270,7 @@ private: JoinPtr table_join; size_t max_block_size; + int null_direction_hint; struct Statistic { @@ -282,6 +295,7 @@ public: const Blocks & input_headers, const Block & output_header, size_t max_block_size, + int null_direction_hint, UInt64 limit_hint = 0); String getName() const override { return "MergeJoinTransform"; } diff --git a/src/QueryPipeline/QueryPipelineBuilder.cpp b/src/QueryPipeline/QueryPipelineBuilder.cpp index 67a8fe5dcab..e338c3ce0fa 100644 --- a/src/QueryPipeline/QueryPipelineBuilder.cpp +++ b/src/QueryPipeline/QueryPipelineBuilder.cpp @@ -349,6 +349,7 @@ std::unique_ptr QueryPipelineBuilder::joinPipelinesYShaped JoinPtr join, const Block & out_header, size_t max_block_size, + int null_direction_hint, Processors * 
collected_processors) { left->checkInitializedAndNotCompleted(); @@ -376,7 +377,7 @@ std::unique_ptr QueryPipelineBuilder::joinPipelinesYShaped } else { - auto joining = std::make_shared(join, inputs, out_header, max_block_size); + auto joining = std::make_shared(join, inputs, out_header, max_block_size, null_direction_hint); return mergePipelines(std::move(left), std::move(right), std::move(joining), collected_processors); } } diff --git a/src/QueryPipeline/QueryPipelineBuilder.h b/src/QueryPipeline/QueryPipelineBuilder.h index f0b2ead687e..4753f957a25 100644 --- a/src/QueryPipeline/QueryPipelineBuilder.h +++ b/src/QueryPipeline/QueryPipelineBuilder.h @@ -137,6 +137,7 @@ public: JoinPtr table_join, const Block & out_header, size_t max_block_size, + int null_direction_hint, Processors * collected_processors = nullptr); /// Add other pipeline and execute it before current one. From 0b5fc743f2e0711556ab4628aecd13e5fcd1a9b8 Mon Sep 17 00:00:00 2001 From: nemonlou Date: Tue, 12 Mar 2024 09:55:02 +0800 Subject: [PATCH 2/5] make nulls direction configurable for FullSortingMergeJoin (fix review comments) --- src/Core/Settings.h | 1 - src/Interpreters/InterpreterSelectQuery.cpp | 7 +-- .../Transforms/MergeJoinTransform.cpp | 56 +++++++------------ .../Transforms/MergeJoinTransform.h | 11 ---- 4 files changed, 22 insertions(+), 53 deletions(-) diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 8d48b3f5e68..a3c5638d97f 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -891,7 +891,6 @@ class IColumn; M(Int64, ignore_cold_parts_seconds, 0, "Only available in ClickHouse Cloud. Exclude new data parts from SELECT queries until they're either pre-warmed (see cache_populated_by_fetch) or this many seconds old. Only for Replicated-/SharedMergeTree.", 0) \ M(Int64, prefer_warmed_unmerged_parts_seconds, 0, "Only available in ClickHouse Cloud. 
If a merged part is less than this many seconds old and is not pre-warmed (see cache_populated_by_fetch), but all its source parts are available and pre-warmed, SELECT queries will read from those parts instead. Only for ReplicatedMergeTree. Note that this only checks whether CacheWarmer processed the part; if the part was fetched into cache by something else, it'll still be considered cold until CacheWarmer gets to it; if it was warmed, then evicted from cache, it'll still be considered warm.", 0) \ M(Bool, iceberg_engine_ignore_schema_evolution, false, "Ignore schema evolution in Iceberg table engine and read all data using latest schema saved on table creation. Note that it can lead to incorrect result", 0) \ - M(Bool, nulls_biggest_in_smj, true, "Treat nulls as biggest in sort. Used in sort merge join for compare null keys.", 0) \ // End of COMMON_SETTINGS // Please add settings related to formats into the FORMAT_FACTORY_SETTINGS, move obsolete settings to OBSOLETE_SETTINGS and obsolete format settings to OBSOLETE_FORMAT_SETTINGS. diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp index 6f0a9fa9bfb..7c87dadfce6 100644 --- a/src/Interpreters/InterpreterSelectQuery.cpp +++ b/src/Interpreters/InterpreterSelectQuery.cpp @@ -1693,10 +1693,9 @@ void InterpreterSelectQuery::executeImpl(QueryPlan & query_plan, std::optional

( query_plan.getCurrentDataStream(), joined_plan->getCurrentDataStream(), expressions.join, settings.max_block_size, max_streams, - analysis_result.optimize_read_in_order, - null_direct_hint); + analysis_result.optimize_read_in_order); join_step->setStepDescription(fmt::format("JOIN {}", expressions.join->pipelineType())); std::vector plans; diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index c8e3a806a9f..6288a850d76 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -359,7 +359,7 @@ void MergeJoinAlgorithm::consume(Input & input, size_t source_num) cursors[source_num]->setChunk(std::move(input.chunk)); } -template +template struct AllJoinImpl { constexpr static bool enabled = isInner(kind) || isLeft(kind) || isRight(kind) || isFull(kind); @@ -369,7 +369,8 @@ struct AllJoinImpl size_t max_block_size, PaddedPODArray & left_map, PaddedPODArray & right_map, - std::unique_ptr & state) + std::unique_ptr & state, + int null_direction_hint) { right_map.clear(); right_map.reserve(max_block_size); @@ -385,7 +386,7 @@ struct AllJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, nullDirection(nullOrder)); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint); if (cmp == 0) { size_t lnum = nextDistinct(left_cursor.cursor); @@ -435,37 +436,19 @@ struct AllJoinImpl } }; -template class Impl, typename ... Args> -void dispatchKind(JoinKind kind, int null_direction_hint, Args && ... args) +template class Impl, typename ... Args> +void dispatchKind(JoinKind kind, Args && ... 
args) { - if (isSmall(null_direction_hint)) - { - if (Impl::enabled && kind == JoinKind::Inner) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Left) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Right) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Full) - return Impl::join(std::forward(args)...); + if (Impl::enabled && kind == JoinKind::Inner) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Left) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Right) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Full) + return Impl::join(std::forward(args)...); else throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); - - } - else - { - if (Impl::enabled && kind == JoinKind::Inner) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Left) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Right) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Full) - return Impl::join(std::forward(args)...); - else - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); - - } } std::optional MergeJoinAlgorithm::handleAllJoinState() @@ -538,7 +521,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind) { PaddedPODArray idx_map[2]; - dispatchKind(kind, null_direction_hint, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state); + dispatchKind(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state, null_direction_hint); assert(idx_map[0].size() == idx_map[1].size()); Chunk result; @@ -588,7 +571,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind) } -template +template 
struct AnyJoinImpl { constexpr static bool enabled = isInner(kind) || isLeft(kind) || isRight(kind); @@ -597,7 +580,8 @@ struct AnyJoinImpl FullMergeJoinCursor & right_cursor, PaddedPODArray & left_map, PaddedPODArray & right_map, - AnyJoinState & state) + AnyJoinState & state, + int null_direction_hint) { assert(enabled); @@ -620,7 +604,7 @@ struct AnyJoinImpl lpos = left_cursor->getRow(); rpos = right_cursor->getRow(); - cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, nullDirection(order)); + cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint); if (cmp == 0) { if constexpr (isLeftOrFull(kind)) @@ -744,7 +728,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind) PaddedPODArray idx_map[2]; size_t prev_pos[] = {current_left.getRow(), current_right.getRow()}; - dispatchKind(kind, null_direction_hint, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state); + dispatchKind(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state, null_direction_hint); assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size()); size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size()); diff --git a/src/Processors/Transforms/MergeJoinTransform.h b/src/Processors/Transforms/MergeJoinTransform.h index 43485321122..8af486ea34b 100644 --- a/src/Processors/Transforms/MergeJoinTransform.h +++ b/src/Processors/Transforms/MergeJoinTransform.h @@ -220,17 +220,6 @@ private: bool recieved_all_blocks = false; }; -/// Join method. -enum class NullOrder -{ - SMALLEST, /// null is treated as smallest - BIGGEST /// null is treated as biggest -}; - -inline constexpr bool isSmall(int null_direction) { return null_direction == 1; } - -inline constexpr int nullDirection(NullOrder order) {return order == NullOrder::SMALLEST ? 1 : -1;} - /* * This class is used to join chunks from two sorted streams. * It is used in MergeJoinTransform. 
From 5cf22bae6f40ab1beb258cf22a9b0627c601495d Mon Sep 17 00:00:00 2001 From: nemonlou Date: Tue, 12 Mar 2024 10:04:37 +0800 Subject: [PATCH 3/5] minor: fix style change --- .../Transforms/MergeJoinTransform.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index 6288a850d76..37a178810cb 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -440,15 +440,15 @@ template class Impl, typename ... Args> void dispatchKind(JoinKind kind, Args && ... args) { if (Impl::enabled && kind == JoinKind::Inner) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Left) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Right) - return Impl::join(std::forward(args)...); - else if (Impl::enabled && kind == JoinKind::Full) - return Impl::join(std::forward(args)...); - else - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Left) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Right) + return Impl::join(std::forward(args)...); + else if (Impl::enabled && kind == JoinKind::Full) + return Impl::join(std::forward(args)...); + else + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); } std::optional MergeJoinAlgorithm::handleAllJoinState() From 8ff21d7e47c25637b87f3ab674421cfc0a2c4487 Mon Sep 17 00:00:00 2001 From: nemonlou Date: Tue, 19 Mar 2024 10:59:32 +0800 Subject: [PATCH 4/5] fix review comments --- src/Interpreters/FullSortingMergeJoin.h | 7 ++++++- src/Processors/QueryPlan/JoinStep.cpp | 8 +++----- src/Processors/QueryPlan/JoinStep.h | 4 +--- .../Transforms/MergeJoinTransform.cpp | 17 +++++++++++------ 
src/Processors/Transforms/MergeJoinTransform.h | 6 ++---- src/QueryPipeline/QueryPipelineBuilder.cpp | 3 +-- src/QueryPipeline/QueryPipelineBuilder.h | 1 - 7 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/Interpreters/FullSortingMergeJoin.h b/src/Interpreters/FullSortingMergeJoin.h index 7688d44f7a9..7e07c2004b6 100644 --- a/src/Interpreters/FullSortingMergeJoin.h +++ b/src/Interpreters/FullSortingMergeJoin.h @@ -21,9 +21,11 @@ namespace ErrorCodes class FullSortingMergeJoin : public IJoin { public: - explicit FullSortingMergeJoin(std::shared_ptr table_join_, const Block & right_sample_block_) + explicit FullSortingMergeJoin(std::shared_ptr table_join_, const Block & right_sample_block_, + int null_direction_ = 1) : table_join(table_join_) , right_sample_block(right_sample_block_) + , null_direction(null_direction_) { LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join"); } @@ -31,6 +33,8 @@ public: std::string getName() const override { return "FullSortingMergeJoin"; } const TableJoin & getTableJoin() const override { return *table_join; } + int getNullDirection() const { return null_direction; } + bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override { throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called"); @@ -119,6 +123,7 @@ private: std::shared_ptr table_join; Block right_sample_block; Block totals; + int null_direction; }; } diff --git a/src/Processors/QueryPlan/JoinStep.cpp b/src/Processors/QueryPlan/JoinStep.cpp index 0c46ce7893d..1931b1eb3a1 100644 --- a/src/Processors/QueryPlan/JoinStep.cpp +++ b/src/Processors/QueryPlan/JoinStep.cpp @@ -44,10 +44,8 @@ JoinStep::JoinStep( JoinPtr join_, size_t max_block_size_, size_t max_streams_, - bool keep_left_read_in_order_, - int null_direction_hint_) - : join(std::move(join_)), max_block_size(max_block_size_), max_streams(max_streams_), keep_left_read_in_order(keep_left_read_in_order_), - 
null_direction_hint(null_direction_hint_) + bool keep_left_read_in_order_) + : join(std::move(join_)), max_block_size(max_block_size_), max_streams(max_streams_), keep_left_read_in_order(keep_left_read_in_order_) { updateInputStreams(DataStreams{left_stream_, right_stream_}); } @@ -60,7 +58,7 @@ QueryPipelineBuilderPtr JoinStep::updatePipeline(QueryPipelineBuilders pipelines if (join->pipelineType() == JoinPipelineType::YShaped) { auto joined_pipeline = QueryPipelineBuilder::joinPipelinesYShaped( - std::move(pipelines[0]), std::move(pipelines[1]), join, output_stream->header, max_block_size, null_direction_hint, &processors); + std::move(pipelines[0]), std::move(pipelines[1]), join, output_stream->header, max_block_size, &processors); joined_pipeline->resize(max_streams); return joined_pipeline; } diff --git a/src/Processors/QueryPlan/JoinStep.h b/src/Processors/QueryPlan/JoinStep.h index 08909ce48a9..a9059a083fe 100644 --- a/src/Processors/QueryPlan/JoinStep.h +++ b/src/Processors/QueryPlan/JoinStep.h @@ -19,8 +19,7 @@ public: JoinPtr join_, size_t max_block_size_, size_t max_streams_, - bool keep_left_read_in_order_, - int null_direction_hint_ = 1); + bool keep_left_read_in_order_); String getName() const override { return "Join"; } @@ -43,7 +42,6 @@ private: size_t max_block_size; size_t max_streams; bool keep_left_read_in_order; - int null_direction_hint; }; /// Special step for the case when Join is already filled. 
diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index 37a178810cb..b63598483ef 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -271,11 +272,9 @@ bool FullMergeJoinCursor::fullyCompleted() const MergeJoinAlgorithm::MergeJoinAlgorithm( JoinPtr table_join_, const Blocks & input_headers, - size_t max_block_size_, - int null_direction_hint_) + size_t max_block_size_) : table_join(table_join_) , max_block_size(max_block_size_) - , null_direction_hint(null_direction_hint_) , log(getLogger("MergeJoinAlgorithm")) { if (input_headers.size() != 2) @@ -305,6 +304,13 @@ MergeJoinAlgorithm::MergeJoinAlgorithm( size_t right_idx = input_headers[1].getPositionByName(right_key); left_to_right_key_remap[left_idx] = right_idx; } + + auto smjPtr = typeid_cast(table_join.get()); + if (smjPtr) + { + null_direction_hint = smjPtr->getNullDirection(); + } + } void MergeJoinAlgorithm::logElapsed(double seconds) @@ -448,7 +454,7 @@ void dispatchKind(JoinKind kind, Args && ... 
args) else if (Impl::enabled && kind == JoinKind::Full) return Impl::join(std::forward(args)...); else - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Unsupported join kind: \"{}\"", kind); } std::optional MergeJoinAlgorithm::handleAllJoinState() @@ -856,7 +862,6 @@ MergeJoinTransform::MergeJoinTransform( const Blocks & input_headers, const Block & output_header, size_t max_block_size, - int null_direction_hint_, UInt64 limit_hint_) : IMergingTransform( input_headers, @@ -865,7 +870,7 @@ MergeJoinTransform::MergeJoinTransform( limit_hint_, /* always_read_till_end_= */ false, /* empty_chunk_on_finish_= */ true, - table_join, input_headers, max_block_size, null_direction_hint_) + table_join, input_headers, max_block_size) , log(getLogger("MergeJoinTransform")) { LOG_TRACE(log, "Use MergeJoinTransform"); diff --git a/src/Processors/Transforms/MergeJoinTransform.h b/src/Processors/Transforms/MergeJoinTransform.h index 8af486ea34b..cf9331abd59 100644 --- a/src/Processors/Transforms/MergeJoinTransform.h +++ b/src/Processors/Transforms/MergeJoinTransform.h @@ -227,8 +227,7 @@ private: class MergeJoinAlgorithm final : public IMergingAlgorithm { public: - explicit MergeJoinAlgorithm(JoinPtr table_join, const Blocks & input_headers, size_t max_block_size_, - int null_direction_hint = 1); + explicit MergeJoinAlgorithm(JoinPtr table_join, const Blocks & input_headers, size_t max_block_size_); const char * getName() const override { return "MergeJoinAlgorithm"; } void initialize(Inputs inputs) override; @@ -259,7 +258,7 @@ private: JoinPtr table_join; size_t max_block_size; - int null_direction_hint; + int null_direction_hint = 1; struct Statistic { @@ -284,7 +283,6 @@ public: const Blocks & input_headers, const Block & output_header, size_t max_block_size, - int null_direction_hint, UInt64 limit_hint = 0); String getName() const override { return "MergeJoinTransform"; } diff --git 
a/src/QueryPipeline/QueryPipelineBuilder.cpp b/src/QueryPipeline/QueryPipelineBuilder.cpp index e338c3ce0fa..67a8fe5dcab 100644 --- a/src/QueryPipeline/QueryPipelineBuilder.cpp +++ b/src/QueryPipeline/QueryPipelineBuilder.cpp @@ -349,7 +349,6 @@ std::unique_ptr QueryPipelineBuilder::joinPipelinesYShaped JoinPtr join, const Block & out_header, size_t max_block_size, - int null_direction_hint, Processors * collected_processors) { left->checkInitializedAndNotCompleted(); @@ -377,7 +376,7 @@ std::unique_ptr QueryPipelineBuilder::joinPipelinesYShaped } else { - auto joining = std::make_shared(join, inputs, out_header, max_block_size, null_direction_hint); + auto joining = std::make_shared(join, inputs, out_header, max_block_size); return mergePipelines(std::move(left), std::move(right), std::move(joining), collected_processors); } } diff --git a/src/QueryPipeline/QueryPipelineBuilder.h b/src/QueryPipeline/QueryPipelineBuilder.h index 4753f957a25..f0b2ead687e 100644 --- a/src/QueryPipeline/QueryPipelineBuilder.h +++ b/src/QueryPipeline/QueryPipelineBuilder.h @@ -137,7 +137,6 @@ public: JoinPtr table_join, const Block & out_header, size_t max_block_size, - int null_direction_hint, Processors * collected_processors = nullptr); /// Add other pipeline and execute it before current one. 
From 96e90438e0a1f47fc706792492cf86fe093d0b26 Mon Sep 17 00:00:00 2001 From: nemonlou Date: Wed, 20 Mar 2024 14:25:16 +0800 Subject: [PATCH 5/5] fix clang-tidy --- src/Processors/Transforms/MergeJoinTransform.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Processors/Transforms/MergeJoinTransform.cpp b/src/Processors/Transforms/MergeJoinTransform.cpp index b63598483ef..62361bef5e2 100644 --- a/src/Processors/Transforms/MergeJoinTransform.cpp +++ b/src/Processors/Transforms/MergeJoinTransform.cpp @@ -305,7 +305,7 @@ MergeJoinAlgorithm::MergeJoinAlgorithm( left_to_right_key_remap[left_idx] = right_idx; } - auto smjPtr = typeid_cast(table_join.get()); + const auto *smjPtr = typeid_cast(table_join.get()); if (smjPtr) { null_direction_hint = smjPtr->getNullDirection();