Merge pull request #59731 from kitaisreal/asof-join-try-sort-with-radix-sort

ASOF JOIN use trySort with RadixSort
This commit is contained in:
Raúl Marín 2024-02-14 15:54:22 +01:00 committed by GitHub
commit 11519f949b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 118 additions and 32 deletions

View File

@ -15,6 +15,7 @@
#include <base/bit_cast.h> #include <base/bit_cast.h>
#include <base/extended_types.h> #include <base/extended_types.h>
#include <base/sort.h>
#include <Core/Defines.h> #include <Core/Defines.h>
@ -114,6 +115,11 @@ struct RadixSortFloatTraits
{ {
return x < y; return x < y;
} }
/** Strict "greater than" on two keys: true when x is ordered after y.
  * Expressed through operator< with the operands swapped.
  */
static bool greater(Key x, Key y)
{
    return y < x;
}
}; };
@ -148,6 +154,11 @@ struct RadixSortUIntTraits
{ {
return x < y; return x < y;
} }
/** Strict "greater than" on two keys: true when x is ordered after y.
  * Expressed through operator< with the operands swapped.
  */
static bool greater(Key x, Key y)
{
    return y < x;
}
}; };
@ -182,6 +193,11 @@ struct RadixSortIntTraits
{ {
return x < y; return x < y;
} }
/** Strict "greater than" on two keys: true when x is ordered after y.
  * Expressed through operator< with the operands swapped.
  */
static bool greater(Key x, Key y)
{
    return y < x;
}
}; };
@ -214,6 +230,22 @@ private:
static KeyBits keyToBits(Key x) { return bit_cast<KeyBits>(x); } static KeyBits keyToBits(Key x) { return bit_cast<KeyBits>(x); }
static Key bitsToKey(KeyBits x) { return bit_cast<Key>(x); } static Key bitsToKey(KeyBits x) { return bit_cast<Key>(x); }
/** Stateless comparator ordering elements ascending by their key.
  * Extracts the key of each element with Traits::extractKey and compares the keys with Traits::less.
  * Elements are taken by non-const reference because the key is extracted through Traits::extractKey(lhs/rhs)
  * on those references; the comparator itself does not modify them.
  */
struct LessComparator
{
    // The call operator is const: the comparator has no state, so it must be usable through const objects too.
    ALWAYS_INLINE bool operator()(Element & lhs, Element & rhs) const
    {
        return Traits::less(Traits::extractKey(lhs), Traits::extractKey(rhs));
    }
};
/** Stateless comparator ordering elements descending by their key.
  * Extracts the key of each element with Traits::extractKey and compares the keys with Traits::greater.
  * Elements are taken by non-const reference because the key is extracted through Traits::extractKey(lhs/rhs)
  * on those references; the comparator itself does not modify them.
  */
struct GreaterComparator
{
    // The call operator is const: the comparator has no state, so it must be usable through const objects too.
    ALWAYS_INLINE bool operator()(Element & lhs, Element & rhs) const
    {
        return Traits::greater(Traits::extractKey(lhs), Traits::extractKey(rhs));
    }
};
static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x) static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x)
{ {
if (Traits::Transform::transform_is_simple) if (Traits::Transform::transform_is_simple)
@ -504,6 +536,24 @@ private:
radixSortMSDInternal<PASS>(arr, size, limit); radixSortMSDInternal<PASS>(arr, size, limit);
} }
/** Attempts ::trySort on [arr, arr + size) with the given comparator; if it reports failure,
  * falls back to radixSortLSDInternal with the same DIRECT_WRITE_TO_DESTINATION mode.
  *
  * arr         - elements to sort; modified in place.
  * size        - number of elements in arr.
  * reverse     - forwarded to the radix sort fallback. The comparator is expected to describe the same
  *               direction (the callers visible here pass GreaterComparator when reverse is true).
  * comparator  - ordering used by ::trySort.
  * destination - when DIRECT_WRITE_TO_DESTINATION is true, receives Traits::extractResult of every element
  *               in sorted order; not written by this function otherwise.
  */
template <bool DIRECT_WRITE_TO_DESTINATION, typename Comparator>
static void executeLSDWithTrySortInternal(Element * arr, size_t size, bool reverse, Comparator comparator, Result * destination)
{
    // Fast path did not succeed: delegate the whole job to radix sort.
    if (!::trySort(arr, arr + size, comparator))
    {
        radixSortLSDInternal<DIRECT_WRITE_TO_DESTINATION>(arr, size, reverse, destination);
        return;
    }

    // Fast path succeeded, arr is sorted in place. Materialize results only if the caller asked for it.
    if constexpr (DIRECT_WRITE_TO_DESTINATION)
    {
        Element * const arr_end = arr + size;
        Result * out = destination;

        for (Element * it = arr; it != arr_end; ++it, ++out)
            *out = Traits::extractResult(*it);
    }
}
public: public:
/** Least significant digit radix sort (stable). /** Least significant digit radix sort (stable).
* This function will sort inplace (modify 'arr') * This function will sort inplace (modify 'arr')
@ -529,6 +579,38 @@ public:
radixSortLSDInternal<true>(arr, size, reverse, destination); radixSortLSDInternal<true>(arr, size, reverse, destination);
} }
/** First attempts a fast sort that targets common sorting patterns (unstable).
  * When the fast sort cannot be performed, least significant digit radix sort is executed instead.
  * This overload sorts 'arr' in place in non-reversed order (reverse = false).
  */
static void executeLSDWithTrySort(Element * arr, size_t size)
{
    executeLSDWithTrySort(arr, size, false);
}
/** Same as the four-argument overload with no destination buffer (destination = nullptr):
  * 'arr' is sorted in place, in reversed order when 'reverse' is true.
  */
static void executeLSDWithTrySort(Element * arr, size_t size, bool reverse)
{
    executeLSDWithTrySort(arr, size, reverse, nullptr);
}
/** Dispatches to executeLSDWithTrySortInternal, choosing
  *  - the comparator by direction: GreaterComparator when 'reverse' is true, LessComparator otherwise;
  *  - the DIRECT_WRITE_TO_DESTINATION mode by whether 'destination' is non-null.
  */
static void executeLSDWithTrySort(Element * arr, size_t size, bool reverse, Result * destination)
{
    const bool write_to_destination = destination != nullptr;

    if (reverse && write_to_destination)
        executeLSDWithTrySortInternal<true>(arr, size, reverse, GreaterComparator(), destination);
    else if (reverse)
        executeLSDWithTrySortInternal<false>(arr, size, reverse, GreaterComparator(), destination);
    else if (write_to_destination)
        executeLSDWithTrySortInternal<true>(arr, size, reverse, LessComparator(), destination);
    else
        executeLSDWithTrySortInternal<false>(arr, size, reverse, LessComparator(), destination);
}
/* Most significant digit radix sort /* Most significant digit radix sort
* Is not stable, but allows partial sorting. * Is not stable, but allows partial sorting.
* And it's more cache-friendly and usually faster than LSD variant. * And it's more cache-friendly and usually faster than LSD variant.

View File

@ -175,45 +175,42 @@ private:
// the array becomes immutable // the array becomes immutable
void sort() void sort()
{ {
if (!sorted.load(std::memory_order_acquire)) if (sorted.load(std::memory_order_acquire))
return;
std::lock_guard<std::mutex> l(lock);
if (sorted.load(std::memory_order_relaxed))
return;
if constexpr (std::is_arithmetic_v<TKey> && !std::is_floating_point_v<TKey>)
{ {
std::lock_guard<std::mutex> l(lock); if (likely(entries.size() > 256))
if (!sorted.load(std::memory_order_relaxed))
{ {
if constexpr (std::is_arithmetic_v<TKey> && !std::is_floating_point_v<TKey>) struct RadixSortTraits : RadixSortNumTraits<TKey>
{ {
if (likely(entries.size() > 256)) using Element = Entry;
{ using Result = Element;
struct RadixSortTraits : RadixSortNumTraits<TKey>
{
using Element = Entry;
using Result = Element;
static TKey & extractKey(Element & elem) { return elem.value; } static TKey & extractKey(Element & elem) { return elem.value; }
static Result extractResult(Element & elem) { return elem; } static Result extractResult(Element & elem) { return elem; }
}; };
if constexpr (is_descending)
RadixSort<RadixSortTraits>::executeLSD(entries.data(), entries.size(), true);
else
RadixSort<RadixSortTraits>::executeLSD(entries.data(), entries.size(), false);
sorted.store(true, std::memory_order_release);
return;
}
}
if constexpr (is_descending)
::sort(entries.begin(), entries.end(), GreaterEntryOperator());
else
::sort(entries.begin(), entries.end(), LessEntryOperator());
RadixSort<RadixSortTraits>::executeLSDWithTrySort(entries.data(), entries.size(), is_descending /*reverse*/);
sorted.store(true, std::memory_order_release); sorted.store(true, std::memory_order_release);
return;
} }
} }
if constexpr (is_descending)
::sort(entries.begin(), entries.end(), GreaterEntryOperator());
else
::sort(entries.begin(), entries.end(), LessEntryOperator());
sorted.store(true, std::memory_order_release);
} }
}; };
} }
AsofRowRefs createAsofRowRef(TypeIndex type, ASOFJoinInequality inequality) AsofRowRefs createAsofRowRef(TypeIndex type, ASOFJoinInequality inequality)

View File

@ -43,6 +43,13 @@
</query> </query>
<substitutions> <substitutions>
<substitution>
<name>num_unique_sessions</name>
<values>
<value>1000</value>
<value>1000000</value>
</values>
</substitution>
<substitution> <substitution>
<name>num_rows</name> <name>num_rows</name>
<values> <values>
@ -56,15 +63,15 @@
FROM FROM
( (
SELECT SELECT
number AS id, (number % {num_unique_sessions}) AS visitor_id,
number AS visitor_id number AS id
FROM system.numbers FROM system.numbers
LIMIT {num_rows} LIMIT {num_rows}
) AS sessions ) AS sessions
ASOF LEFT JOIN ASOF LEFT JOIN
( (
SELECT SELECT
number AS visitor_id, (number % {num_unique_sessions}) AS visitor_id,
number AS starting_session_id number AS starting_session_id
FROM system.numbers FROM system.numbers
LIMIT {num_rows} LIMIT {num_rows}