Merge pull request #60896 from loudongfeng/master_smj_nullorder

make nulls direction configuable for FullSortingMergeJoin
This commit is contained in:
vdimir 2024-03-21 10:40:43 +01:00 committed by GitHub
commit 9b51780458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 18 deletions

View File

@ -21,9 +21,11 @@ namespace ErrorCodes
class FullSortingMergeJoin : public IJoin class FullSortingMergeJoin : public IJoin
{ {
public: public:
explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_) explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_,
int null_direction_ = 1)
: table_join(table_join_) : table_join(table_join_)
, right_sample_block(right_sample_block_) , right_sample_block(right_sample_block_)
, null_direction(null_direction_)
{ {
LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join"); LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join");
} }
@ -31,6 +33,8 @@ public:
std::string getName() const override { return "FullSortingMergeJoin"; } std::string getName() const override { return "FullSortingMergeJoin"; }
const TableJoin & getTableJoin() const override { return *table_join; } const TableJoin & getTableJoin() const override { return *table_join; }
int getNullDirection() const { return null_direction; }
bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override
{ {
throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called"); throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called");
@ -119,6 +123,7 @@ private:
std::shared_ptr<TableJoin> table_join; std::shared_ptr<TableJoin> table_join;
Block right_sample_block; Block right_sample_block;
Block totals; Block totals;
int null_direction;
}; };
} }

View File

@ -16,6 +16,7 @@
#include <Core/SortCursor.h> #include <Core/SortCursor.h>
#include <Core/SortDescription.h> #include <Core/SortDescription.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Interpreters/FullSortingMergeJoin.h>
#include <Interpreters/TableJoin.h> #include <Interpreters/TableJoin.h>
#include <Parsers/ASTTablesInSelectQuery.h> #include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Transforms/MergeJoinTransform.h> #include <Processors/Transforms/MergeJoinTransform.h>
@ -43,7 +44,7 @@ FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns)
} }
template <bool has_left_nulls, bool has_right_nulls> template <bool has_left_nulls, bool has_right_nulls>
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1) int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint)
{ {
if constexpr (has_left_nulls && has_right_nulls) if constexpr (has_left_nulls && has_right_nulls)
{ {
@ -88,35 +89,36 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column,
} }
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos, int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos,
const SortCursorImpl & rhs, size_t rpos) const SortCursorImpl & rhs, size_t rpos,
int null_direction_hint)
{ {
for (size_t i = 0; i < lhs.sort_columns_size; ++i) for (size_t i = 0; i < lhs.sort_columns_size; ++i)
{ {
/// TODO(@vdimir): use nullableCompareAt only if there's nullable columns /// TODO(@vdimir): use nullableCompareAt only if there's nullable columns
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos); int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint);
if (cmp != 0) if (cmp != 0)
return cmp; return cmp;
} }
return 0; return 0;
} }
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs) int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint)
{ {
return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow()); return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), null_direction_hint);
} }
bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs) bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{ {
/// The last row of left cursor is less than the current row of the right cursor. /// The last row of left cursor is less than the current row of the right cursor.
int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow()); int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), null_direction_hint);
return cmp < 0; return cmp < 0;
} }
int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs) int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{ {
if (totallyLess(lhs, rhs)) if (totallyLess(lhs, rhs, null_direction_hint))
return -1; return -1;
if (totallyLess(rhs, lhs)) if (totallyLess(rhs, lhs, null_direction_hint))
return 1; return 1;
return 0; return 0;
} }
@ -302,6 +304,13 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
size_t right_idx = input_headers[1].getPositionByName(right_key); size_t right_idx = input_headers[1].getPositionByName(right_key);
left_to_right_key_remap[left_idx] = right_idx; left_to_right_key_remap[left_idx] = right_idx;
} }
const auto *smjPtr = typeid_cast<const FullSortingMergeJoin *>(table_join.get());
if (smjPtr)
{
null_direction_hint = smjPtr->getNullDirection();
}
} }
void MergeJoinAlgorithm::logElapsed(double seconds) void MergeJoinAlgorithm::logElapsed(double seconds)
@ -366,7 +375,8 @@ struct AllJoinImpl
size_t max_block_size, size_t max_block_size,
PaddedPODArray<UInt64> & left_map, PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map, PaddedPODArray<UInt64> & right_map,
std::unique_ptr<AllJoinState> & state) std::unique_ptr<AllJoinState> & state,
int null_direction_hint)
{ {
right_map.clear(); right_map.clear();
right_map.reserve(max_block_size); right_map.reserve(max_block_size);
@ -382,7 +392,7 @@ struct AllJoinImpl
lpos = left_cursor->getRow(); lpos = left_cursor->getRow();
rpos = right_cursor->getRow(); rpos = right_cursor->getRow();
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0) if (cmp == 0)
{ {
size_t lnum = nextDistinct(left_cursor.cursor); size_t lnum = nextDistinct(left_cursor.cursor);
@ -517,7 +527,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind)
{ {
PaddedPODArray<UInt64> idx_map[2]; PaddedPODArray<UInt64> idx_map[2];
dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state); dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state, null_direction_hint);
assert(idx_map[0].size() == idx_map[1].size()); assert(idx_map[0].size() == idx_map[1].size());
Chunk result; Chunk result;
@ -576,7 +586,8 @@ struct AnyJoinImpl
FullMergeJoinCursor & right_cursor, FullMergeJoinCursor & right_cursor,
PaddedPODArray<UInt64> & left_map, PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map, PaddedPODArray<UInt64> & right_map,
AnyJoinState & state) AnyJoinState & state,
int null_direction_hint)
{ {
assert(enabled); assert(enabled);
@ -599,7 +610,7 @@ struct AnyJoinImpl
lpos = left_cursor->getRow(); lpos = left_cursor->getRow();
rpos = right_cursor->getRow(); rpos = right_cursor->getRow();
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor); cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0) if (cmp == 0)
{ {
if constexpr (isLeftOrFull(kind)) if constexpr (isLeftOrFull(kind))
@ -723,7 +734,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind)
PaddedPODArray<UInt64> idx_map[2]; PaddedPODArray<UInt64> idx_map[2];
size_t prev_pos[] = {current_left.getRow(), current_right.getRow()}; size_t prev_pos[] = {current_left.getRow(), current_right.getRow()};
dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state); dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state, null_direction_hint);
assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size()); assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size());
size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size()); size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size());
@ -816,7 +827,7 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge()
} }
/// check if blocks are not intersecting at all /// check if blocks are not intersecting at all
if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor); cmp != 0) if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor, null_direction_hint); cmp != 0)
{ {
if (cmp < 0) if (cmp < 0)
{ {

View File

@ -258,6 +258,7 @@ private:
JoinPtr table_join; JoinPtr table_join;
size_t max_block_size; size_t max_block_size;
int null_direction_hint = 1;
struct Statistic struct Statistic
{ {