Merge pull request #60896 from loudongfeng/master_smj_nullorder

make nulls direction configurable for FullSortingMergeJoin
This commit is contained in:
vdimir 2024-03-21 10:40:43 +01:00 committed by GitHub
commit 9b51780458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 35 additions and 18 deletions

View File

@ -21,9 +21,11 @@ namespace ErrorCodes
class FullSortingMergeJoin : public IJoin
{
public:
/// Constructs the join object from the shared TableJoin description and the sample block of the right side.
/// `null_direction_` (default 1) is stored in the `null_direction` member.
/// NOTE(review): presumably it selects how NULL join keys are ordered relative to non-NULL values
/// when rows are compared during the merge - confirm against the comparison code.
/// Emits a trace-level log message so that the chosen join algorithm is visible in logs.
explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_)
explicit FullSortingMergeJoin(std::shared_ptr<TableJoin> table_join_, const Block & right_sample_block_,
int null_direction_ = 1)
: table_join(table_join_)
, right_sample_block(right_sample_block_)
, null_direction(null_direction_)
{
LOG_TRACE(getLogger("FullSortingMergeJoin"), "Will use full sorting merge join");
}
@ -31,6 +33,8 @@ public:
std::string getName() const override { return "FullSortingMergeJoin"; }
const TableJoin & getTableJoin() const override { return *table_join; }
int getNullDirection() const { return null_direction; }
bool addBlockToJoin(const Block & /* block */, bool /* check_limits */) override
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "FullSortingMergeJoin::addBlockToJoin should not be called");
@ -119,6 +123,7 @@ private:
std::shared_ptr<TableJoin> table_join;
Block right_sample_block;
Block totals;
int null_direction;
};
}

View File

@ -16,6 +16,7 @@
#include <Core/SortCursor.h>
#include <Core/SortDescription.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/FullSortingMergeJoin.h>
#include <Interpreters/TableJoin.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Processors/Transforms/MergeJoinTransform.h>
@ -43,7 +44,7 @@ FullMergeJoinCursorPtr createCursor(const Block & block, const Names & columns)
}
template <bool has_left_nulls, bool has_right_nulls>
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint = 1)
int nullableCompareAt(const IColumn & left_column, const IColumn & right_column, size_t lhs_pos, size_t rhs_pos, int null_direction_hint)
{
if constexpr (has_left_nulls && has_right_nulls)
{
@ -88,35 +89,36 @@ int nullableCompareAt(const IColumn & left_column, const IColumn & right_column,
}
/// Compares row `lpos` of `lhs` with row `rpos` of `rhs`, sort column by sort column.
/// Iterates over `lhs.sort_columns_size` columns and returns the first non-zero result of
/// nullableCompareAt; returns 0 when every column compares equal.
/// `null_direction_hint` is forwarded unchanged to nullableCompareAt.
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, size_t lpos,
const SortCursorImpl & rhs, size_t rpos)
const SortCursorImpl & rhs, size_t rpos,
int null_direction_hint)
{
for (size_t i = 0; i < lhs.sort_columns_size; ++i)
{
/// TODO(@vdimir): use nullableCompareAt only if there's nullable columns
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos);
int cmp = nullableCompareAt<true, true>(*lhs.sort_columns[i], *rhs.sort_columns[i], lpos, rpos, null_direction_hint);
/// The first column that differs decides the result.
if (cmp != 0)
return cmp;
}
return 0;
}
/// Convenience overload: compares the current rows of both cursors (as reported by getRow())
/// by delegating to the positional compareCursors overload.
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs)
int ALWAYS_INLINE compareCursors(const SortCursorImpl & lhs, const SortCursorImpl & rhs, int null_direction_hint)
{
return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow());
return compareCursors(lhs, lhs.getRow(), rhs, rhs.getRow(), null_direction_hint);
}
/// Returns true when the last row of `lhs` (index `lhs.rows - 1`) compares strictly less than
/// the current row of `rhs`.
/// NOTE(review): `lhs.rows - 1` assumes `lhs` holds at least one row - confirm callers guarantee this.
bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs)
bool ALWAYS_INLINE totallyLess(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{
/// The last row of left cursor is less than the current row of the right cursor.
int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow());
int cmp = compareCursors(lhs, lhs.rows - 1, rhs, rhs.getRow(), null_direction_hint);
return cmp < 0;
}
/// Three-way check built on totallyLess:
///   -1 if `lhs` is totally less than `rhs`,
///    1 if `rhs` is totally less than `lhs`,
///    0 otherwise.
int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs)
int ALWAYS_INLINE totallyCompare(SortCursorImpl & lhs, SortCursorImpl & rhs, int null_direction_hint)
{
if (totallyLess(lhs, rhs))
if (totallyLess(lhs, rhs, null_direction_hint))
return -1;
if (totallyLess(rhs, lhs))
if (totallyLess(rhs, lhs, null_direction_hint))
return 1;
return 0;
}
@ -302,6 +304,13 @@ MergeJoinAlgorithm::MergeJoinAlgorithm(
size_t right_idx = input_headers[1].getPositionByName(right_key);
left_to_right_key_remap[left_idx] = right_idx;
}
const auto *smjPtr = typeid_cast<const FullSortingMergeJoin *>(table_join.get());
if (smjPtr)
{
null_direction_hint = smjPtr->getNullDirection();
}
}
void MergeJoinAlgorithm::logElapsed(double seconds)
@ -366,7 +375,8 @@ struct AllJoinImpl
size_t max_block_size,
PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map,
std::unique_ptr<AllJoinState> & state)
std::unique_ptr<AllJoinState> & state,
int null_direction_hint)
{
right_map.clear();
right_map.reserve(max_block_size);
@ -382,7 +392,7 @@ struct AllJoinImpl
lpos = left_cursor->getRow();
rpos = right_cursor->getRow();
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor);
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0)
{
size_t lnum = nextDistinct(left_cursor.cursor);
@ -517,7 +527,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::allJoin(JoinKind kind)
{
PaddedPODArray<UInt64> idx_map[2];
dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state);
dispatchKind<AllJoinImpl>(kind, *cursors[0], *cursors[1], max_block_size, idx_map[0], idx_map[1], all_join_state, null_direction_hint);
assert(idx_map[0].size() == idx_map[1].size());
Chunk result;
@ -576,7 +586,8 @@ struct AnyJoinImpl
FullMergeJoinCursor & right_cursor,
PaddedPODArray<UInt64> & left_map,
PaddedPODArray<UInt64> & right_map,
AnyJoinState & state)
AnyJoinState & state,
int null_direction_hint)
{
assert(enabled);
@ -599,7 +610,7 @@ struct AnyJoinImpl
lpos = left_cursor->getRow();
rpos = right_cursor->getRow();
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor);
cmp = compareCursors(left_cursor.cursor, right_cursor.cursor, null_direction_hint);
if (cmp == 0)
{
if constexpr (isLeftOrFull(kind))
@ -723,7 +734,7 @@ MergeJoinAlgorithm::Status MergeJoinAlgorithm::anyJoin(JoinKind kind)
PaddedPODArray<UInt64> idx_map[2];
size_t prev_pos[] = {current_left.getRow(), current_right.getRow()};
dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state);
dispatchKind<AnyJoinImpl>(kind, *cursors[0], *cursors[1], idx_map[0], idx_map[1], any_join_state, null_direction_hint);
assert(idx_map[0].empty() || idx_map[1].empty() || idx_map[0].size() == idx_map[1].size());
size_t num_result_rows = std::max(idx_map[0].size(), idx_map[1].size());
@ -816,7 +827,7 @@ IMergingAlgorithm::Status MergeJoinAlgorithm::merge()
}
/// check if blocks are not intersecting at all
if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor); cmp != 0)
if (int cmp = totallyCompare(cursors[0]->cursor, cursors[1]->cursor, null_direction_hint); cmp != 0)
{
if (cmp < 0)
{

View File

@ -258,6 +258,7 @@ private:
JoinPtr table_join;
size_t max_block_size;
int null_direction_hint = 1;
struct Statistic
{