diff --git a/src/Common/RadixSort.h b/src/Common/RadixSort.h index 0ad2dcafe14..6612fa89085 100644 --- a/src/Common/RadixSort.h +++ b/src/Common/RadixSort.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -208,6 +209,10 @@ private: static constexpr size_t KEY_BITS = sizeof(Key) * 8; static constexpr size_t NUM_PASSES = (KEY_BITS + (Traits::PART_SIZE_BITS - 1)) / Traits::PART_SIZE_BITS; + + static KeyBits keyToBits(Key x) { return ext::bit_cast(x); } + static Key bitsToKey(KeyBits x) { return ext::bit_cast(x); } + static ALWAYS_INLINE KeyBits getPart(size_t N, KeyBits x) { if (Traits::Transform::transform_is_simple) @@ -216,10 +221,12 @@ private: return (x >> (N * Traits::PART_SIZE_BITS)) & PART_BITMASK; } - static KeyBits keyToBits(Key x) { return ext::bit_cast(x); } - static Key bitsToKey(KeyBits x) { return ext::bit_cast(x); } + static ALWAYS_INLINE KeyBits extractPart(size_t N, Element & elem) + { + return getPart(N, keyToBits(Traits::extractKey(elem))); + } - static void insertionSortInternal(Element *arr, size_t size) + static void insertionSortInternal(Element * arr, size_t size) { Element * end = arr + size; for (Element * i = arr + 1; i < end; ++i) @@ -236,92 +243,6 @@ private: } } - /* Main MSD radix sort subroutine - * Puts elements to buckets based on PASS-th digit, then recursively calls insertion sort or itself on the buckets - */ - template - static inline void radixSortMSDInternal(Element * arr, size_t size, size_t limit) - { - Element * last_list[HISTOGRAM_SIZE + 1]; - Element ** last = last_list + 1; - size_t count[HISTOGRAM_SIZE] = {0}; - - for (Element * i = arr; i < arr + size; ++i) - ++count[getPart(PASS, *i)]; - - last_list[0] = last_list[1] = arr; - - size_t buckets_for_recursion = HISTOGRAM_SIZE; - Element * finish = arr + size; - for (size_t i = 1; i < HISTOGRAM_SIZE; ++i) - { - last[i] = last[i - 1] + count[i - 1]; - if (last[i] >= arr + limit) - { - buckets_for_recursion = i; - finish = last[i]; - } - } - - /* At this point, we have the following variables: - * count[i] is the size of i-th bucket - * last[i] is a pointer to the beginning of i-th bucket, last[-1] == last[0] - * buckets_for_recursion is the number of buckets that should be sorted, the last of them only partially - * finish is a pointer to the end of the first buckets_for_recursion buckets - */ - - // Scatter array elements to buckets until the first buckets_for_recursion buckets are full - for (size_t i = 0; i < buckets_for_recursion; ++i) - { - Element * end = last[i - 1] + count[i]; - if (end == finish) - { - last[i] = end; - break; - } - while (last[i] != end) - { - Element swapper = *last[i]; - KeyBits tag = getPart(PASS, swapper); - if (tag != i) - { - do - { - std::swap(swapper, *last[tag]++); - } while ((tag = getPart(PASS, swapper)) != i); - *last[i] = swapper; - } - ++last[i]; - } - } - - if constexpr (PASS > 0) - { - // Recursively sort buckets, except the last one - for (size_t i = 0; i < buckets_for_recursion - 1; ++i) - { - Element * start = last[i - 1]; - size_t subsize = last[i] - last[i - 1]; - radixSortMSDInternalHelper(start, subsize, subsize); - } - - // Sort last necessary bucket with limit - Element * start = last[buckets_for_recursion - 2]; - size_t subsize = last[buckets_for_recursion - 1] - last[buckets_for_recursion - 2]; - size_t sublimit = limit - (last[buckets_for_recursion - 1] - arr); - radixSortMSDInternalHelper(start, subsize, sublimit); - } - } - - // A helper to choose sorting algorithm based on array length - template - static inline void radixSortMSDInternalHelper(Element * arr, size_t size, size_t limit) - { - if (size <= INSERTION_SORT_THRESHOLD) - insertionSortInternal(arr, size); - else - radixSortMSDInternal(arr, size, limit); - } template static NO_INLINE void radixSortLSDInternal(Element * arr, size_t size, bool reverse, Result * destination) @@ -346,7 +267,7 @@ private: Traits::extractKey(arr[i]) = bitsToKey(Traits::Transform::forward(keyToBits(Traits::extractKey(arr[i])))); for (size_t pass = 0; pass < NUM_PASSES; ++pass) - ++histograms[pass * HISTOGRAM_SIZE + getPart(pass, keyToBits(Traits::extractKey(arr[i])))]; + ++histograms[pass * HISTOGRAM_SIZE + extractPart(pass, arr[i])]; } { @@ -372,7 +293,7 @@ private: for (size_t i = 0; i < size; ++i) { - size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i]))); + size_t pos = extractPart(pass, reader[i]); /// Place the element on the next free position. auto & dest = writer[++histograms[pass * HISTOGRAM_SIZE + pos]]; @@ -394,7 +315,7 @@ private: { for (size_t i = 0; i < size; ++i) { - size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i]))); + size_t pos = extractPart(pass, reader[i]); writer[size - 1 - (++histograms[pass * HISTOGRAM_SIZE + pos])] = Traits::extractResult(reader[i]); } } @@ -402,7 +323,7 @@ private: { for (size_t i = 0; i < size; ++i) { - size_t pos = getPart(pass, keyToBits(Traits::extractKey(reader[i]))); + size_t pos = extractPart(pass, reader[i]); writer[++histograms[pass * HISTOGRAM_SIZE + pos]] = Traits::extractResult(reader[i]); } } @@ -413,7 +334,7 @@ private: if (NUM_PASSES % 2) memcpy(arr, swap_buffer, size * sizeof(Element)); - /// This is suboptimal, we can embed it to the last pass. + /// TODO This is suboptimal, we can embed it to the last pass. if (reverse) std::reverse(arr, arr + size); } @@ -421,6 +342,169 @@ private: allocator.deallocate(swap_buffer, size * sizeof(Element)); } + + /* Main MSD radix sort subroutine. + * Puts elements to buckets based on PASS-th digit, then recursively calls insertion sort or itself on the buckets. + * + * TODO: Provide support for 'reverse' and 'DIRECT_WRITE_TO_DESTINATION'. + * + * Invariant: higher significant parts of the elements than PASS are constant within arr or is is the first PASS. + * PASS is counted from least significant (0), so the first pass is NUM_PASSES - 1. + */ + template + static inline void radixSortMSDInternal(Element * arr, size_t size, size_t limit) + { +// std::cerr << PASS << ", " << size << ", " << limit << "\n"; + + /// The beginning of every i-1-th bucket. 0th element will be equal to 1st. + /// Last element will point to array end. + Element * prev_buckets[HISTOGRAM_SIZE + 1]; + /// The beginning of every i-th bucket (the same array shifted by one). + Element ** buckets = &prev_buckets[1]; + + prev_buckets[0] = arr; + prev_buckets[1] = arr; + + /// The end of the range of buckets that we need with limit. + Element * finish = arr + size; + + /// Count histogram of current element parts. + + /// We use loop unrolling to minimize data dependencies and increase instruction level parallelism. + /// Unroll 8 times looks better on experiments; + /// also it corresponds with the results from https://github.com/powturbo/TurboHist + + static constexpr size_t UNROLL_COUNT = 8; + CountType count[HISTOGRAM_SIZE * UNROLL_COUNT]{}; + size_t unrolled_size = size / UNROLL_COUNT * UNROLL_COUNT; + + for (Element * elem = arr; elem < arr + unrolled_size; elem += UNROLL_COUNT) + for (size_t i = 0; i < UNROLL_COUNT; ++i) + ++count[i * HISTOGRAM_SIZE + extractPart(PASS, elem[i])]; + + for (Element * elem = arr + unrolled_size; elem < arr + size; ++elem) + ++count[extractPart(PASS, *elem)]; + + for (size_t i = 0; i < HISTOGRAM_SIZE; ++i) + for (size_t j = 1; j < UNROLL_COUNT; ++j) + count[i] += count[j * HISTOGRAM_SIZE + i]; + + /// Fill pointers to buckets according to the histogram. + + /// How many buckets we will recurse into. + ssize_t buckets_for_recursion = HISTOGRAM_SIZE; + bool finish_early = false; + + for (size_t i = 1; i < HISTOGRAM_SIZE; ++i) + { + /// Positions are just a cumulative sum of counts. + buckets[i] = buckets[i - 1] + count[i - 1]; + + /// If this bucket starts after limit, we don't need it. + if (!finish_early && buckets[i] >= arr + limit) + { + buckets_for_recursion = i; + finish = buckets[i]; + finish_early = true; + /// We cannot break here, because we need correct pointers to all buckets, see the next loop. + } + } + + /* At this point, we have the following variables: + * count[i] is the size of i-th bucket + * buckets[i] is a pointer to the beginning of i-th bucket, buckets[-1] == buckets[0] + * buckets_for_recursion is the number of buckets that should be sorted, the last of them only partially + * finish is a pointer to the end of the first buckets_for_recursion buckets + */ + + /// Scatter array elements to buckets until the first buckets_for_recursion buckets are full + /// After the above loop, buckets are shifted towards the end and now pointing to the beginning of i+1th bucket. + + for (ssize_t i = 0; /* guarded by 'finish' */; ++i) + { + assert(i < buckets_for_recursion); + + /// We look at i-1th index, because bucket pointers are shifted right on every loop iteration, + /// and all buckets before i was completely shifted to the beginning of the next bucket. + /// So, the beginning of i-th bucket is at buckets[i - 1]. + + Element * bucket_end = buckets[i - 1] + count[i]; + + /// Fill this bucket. + while (buckets[i] != bucket_end) + { + Element swapper = *buckets[i]; + KeyBits tag = extractPart(PASS, swapper); + + if (tag != KeyBits(i)) + { + /// Invariant: tag > i, because the elements with less tags are already at the right places. + assert(tag > KeyBits(i)); + + /// While the tag (digit) of the element is not that we need, + /// swap the element with the next element in the bucket for that tag. + + /// Interesting observation: + /// - we will definitely find the needed element, + /// because the tag's bucket will contain at least one "wrong" element, + /// because the "right" element is appeared in our bucket. + + /// After this loop we shift buckets[i] and buckets[tag] pointers to the right for all found tags. + /// And all positions that were traversed are filled with the proper values. + + do + { + std::swap(swapper, *buckets[tag]); + ++buckets[tag]; + tag = extractPart(PASS, swapper); + } while (tag != KeyBits(i)); + *buckets[i] = swapper; + } + + /// Now we have the right element at this place. + ++buckets[i]; + } + + if (bucket_end == finish) + break; + } + + /// Recursion for the relevant buckets. + + if constexpr (PASS > 0) + { + /// Recursively sort buckets, except the last one + for (ssize_t i = 0; i < buckets_for_recursion - 1; ++i) + { + Element * start = buckets[i - 1]; + ssize_t subsize = count[i]; + + radixSortMSDInternalHelper(start, subsize, subsize); + } + + /// Sort the last necessary bucket with limit + { + ssize_t i = buckets_for_recursion - 1; + + Element * start = buckets[i - 1]; + ssize_t subsize = count[i]; + ssize_t sublimit = limit - (start - arr); + + radixSortMSDInternalHelper(start, subsize, sublimit); + } + } + } + + // A helper to choose sorting algorithm based on array length + template + static inline void radixSortMSDInternalHelper(Element * arr, size_t size, size_t limit) + { + if (size <= INSERTION_SORT_THRESHOLD) + insertionSortInternal(arr, size); + else + radixSortMSDInternal(arr, size, limit); + } + public: /** Least significant digit radix sort (stable). * This function will sort inplace (modify 'arr') @@ -442,7 +526,14 @@ public: } /* Most significant digit radix sort - * Usually slower than LSD and is not stable, but allows partial sorting + * Is not stable, but allows partial sorting. + * And it's more cache-friendly and usually faster than LSD variant. + * + * NOTE: It's beneficial over std::partial_sort only if limit is above ~2% of size for 8 bit radix. + * NOTE: When lowering down limit to 1%, the radix of 4..6 or 10..12 bit started to become beneficial. + * For less than 1% limit, it's not recommended to use. + * NOTE: For huge arrays without limit, the radix 11 suddenly becomes better... but not for smaller arrays. + * Maybe it because histogram will fit in half of L1d cache (2048 * 4 = 16384). * * Based on https://github.com/voutcn/kxsort, license: * The MIT License @@ -480,13 +571,13 @@ public: /// Use RadixSort with custom traits for complex types instead. template -void radixSortLSD(T *arr, size_t size) +void radixSortLSD(T * arr, size_t size) { RadixSort>::executeLSD(arr, size); } template -void radixSortMSD(T *arr, size_t size, size_t limit) +void radixSortMSD(T * arr, size_t size, size_t limit) { RadixSort>::executeMSD(arr, size, limit); } diff --git a/src/Common/tests/CMakeLists.txt b/src/Common/tests/CMakeLists.txt index 72c47d1ef49..b68e71c0b43 100644 --- a/src/Common/tests/CMakeLists.txt +++ b/src/Common/tests/CMakeLists.txt @@ -34,6 +34,7 @@ target_link_libraries (compact_array PRIVATE clickhouse_common_io) add_executable (radix_sort radix_sort.cpp) target_link_libraries (radix_sort PRIVATE clickhouse_common_io) +target_include_directories(radix_sort SYSTEM PRIVATE ${PDQSORT_INCLUDE_DIR}) if (USE_OPENCL) add_executable (bitonic_sort bitonic_sort.cpp) diff --git a/src/Common/tests/radix_sort.cpp b/src/Common/tests/radix_sort.cpp index a7313d05cec..9288e32e734 100644 --- a/src/Common/tests/radix_sort.cpp +++ b/src/Common/tests/radix_sort.cpp @@ -1,14 +1,23 @@ -#if !defined(__APPLE__) && !defined(__FreeBSD__) -#include -#endif +#include #include #include + +//#if defined(NDEBUG) +//#undef NDEBUG #include +//#endif + #include +#include #include #include +#include -using Key = double; +/// Example: +/// for i in {6,8} {11..26}; do echo $i; for j in {1..10}; do ./radix_sort $i 65536 1000; done; echo; done + + +using Key = UInt64; static void NO_INLINE sort1(Key * data, size_t size) { @@ -24,29 +33,150 @@ static void NO_INLINE sort3(Key * data, size_t size) { std::sort(data, data + size, [](Key a, Key b) { - return RadixSortFloatTransform::forward(ext::bit_cast(a)) - < RadixSortFloatTransform::forward(ext::bit_cast(b)); + return RadixSortFloatTransform::forward(ext::bit_cast(a)) + < RadixSortFloatTransform::forward(ext::bit_cast(b)); }); } +static void NO_INLINE sort4(Key * data, size_t size) +{ + radixSortMSD(data, size, size); +} + +static void NO_INLINE sort5(Key * data, size_t size) +{ + pdqsort(data, data + size); +} + + +static void NO_INLINE sort6(Key * data, size_t size, size_t limit) +{ + std::partial_sort(data, data + limit, data + size); +} + +static void NO_INLINE sort7(Key * data, size_t size, size_t limit) +{ + std::partial_sort(data, data + limit, data + size, [](Key a, Key b) + { + return RadixSortFloatTransform::forward(ext::bit_cast(a)) + < RadixSortFloatTransform::forward(ext::bit_cast(b)); + }); +} + +static void NO_INLINE sort8(Key * data, size_t size, size_t limit) +{ + radixSortMSD(data, size, limit); +} + + +template +struct RadixSortTraitsWithCustomBits : RadixSortNumTraits +{ + static constexpr size_t PART_SIZE_BITS = N; +}; + +static void NO_INLINE sort11(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort12(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort13(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort14(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort15(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort16(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort17(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort18(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort19(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort20(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort21(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort22(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort23(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort24(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort25(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + +static void NO_INLINE sort26(Key * data, size_t size, size_t limit) +{ + RadixSort>::executeMSD(data, size, limit); +} + int main(int argc, char ** argv) { - pcg64 rng; + pcg64 rng(randomSeed()); - if (argc < 3) + if (argc < 3 || argc > 4) { - std::cerr << "Usage: program n method\n"; + std::cerr << "Usage: program method n [limit]\n"; return 1; } - size_t n = DB::parse(argv[1]); - size_t method = DB::parse(argv[2]); + size_t method = DB::parse(argv[1]); + size_t n = DB::parse(argv[2]); + size_t limit = n; + + if (argc == 4) + limit = DB::parse(argv[3]); + + std::cerr << std::fixed << std::setprecision(3); std::vector data(n); -// srand(time(nullptr)); - { Stopwatch watch; @@ -54,12 +184,12 @@ int main(int argc, char ** argv) elem = rng(); watch.stop(); - double elapsed = watch.elapsedSeconds(); + /* double elapsed = watch.elapsedSeconds(); std::cerr << "Filled in " << elapsed << " (" << n / elapsed << " elem/sec., " << n * sizeof(Key) / elapsed / 1048576 << " MB/sec.)" - << std::endl; + << std::endl;*/ } if (n <= 100) @@ -70,13 +200,34 @@ int main(int argc, char ** argv) std::cerr << std::endl; } - { Stopwatch watch; - if (method == 1) sort1(data.data(), n); - if (method == 2) sort2(data.data(), n); - if (method == 3) sort3(data.data(), n); + if (method == 1) sort1(data.data(), n); + if (method == 2) sort2(data.data(), n); + if (method == 3) sort3(data.data(), n); + if (method == 4) sort4(data.data(), n); + if (method == 5) sort5(data.data(), n); + if (method == 6) sort6(data.data(), n, limit); + if (method == 7) sort7(data.data(), n, limit); + if (method == 8) sort8(data.data(), n, limit); + + if (method == 11) sort11(data.data(), n, limit); + if (method == 12) sort12(data.data(), n, limit); + if (method == 13) sort13(data.data(), n, limit); + if (method == 14) sort14(data.data(), n, limit); + if (method == 15) sort15(data.data(), n, limit); + if (method == 16) sort16(data.data(), n, limit); + if (method == 17) sort17(data.data(), n, limit); + if (method == 18) sort18(data.data(), n, limit); + if (method == 19) sort19(data.data(), n, limit); + if (method == 20) sort20(data.data(), n, limit); + if (method == 21) sort21(data.data(), n, limit); + if (method == 22) sort22(data.data(), n, limit); + if (method == 23) sort23(data.data(), n, limit); + if (method == 24) sort24(data.data(), n, limit); + if (method == 25) sort25(data.data(), n, limit); + if (method == 26) sort26(data.data(), n, limit); watch.stop(); double elapsed = watch.elapsedSeconds(); @@ -87,33 +238,39 @@ int main(int argc, char ** argv) << std::endl; } + bool ok = true; + { Stopwatch watch; size_t i = 1; - while (i < n) + while (i < limit) { if (!(data[i - 1] <= data[i])) + { + ok = false; break; + } ++i; } watch.stop(); double elapsed = watch.elapsedSeconds(); - std::cerr - << "Checked in " << elapsed - << " (" << n / elapsed << " elem/sec., " - << n * sizeof(Key) / elapsed / 1048576 << " MB/sec.)" - << std::endl - << "Result: " << (i == n ? "Ok." : "Fail!") << std::endl; + if (!ok) + std::cerr + << "Checked in " << elapsed + << " (" << limit / elapsed << " elem/sec., " + << limit * sizeof(Key) / elapsed / 1048576 << " MB/sec.)" + << std::endl + << "Result: " << (ok ? "Ok." : "Fail!") << std::endl; } - if (n <= 1000) + if (!ok && limit <= 100000) { std::cerr << std::endl; std::cerr << data[0] << ' '; - for (size_t i = 1; i < n; ++i) + for (size_t i = 1; i < limit; ++i) { if (!(data[i - 1] <= data[i])) std::cerr << "*** ";