From 9bc9e87cbaff2b2fa8c739df47a4f08b9ae80037 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Tue, 5 Apr 2022 14:32:50 +0200 Subject: [PATCH 1/5] DecimalColumn improve getPermutation performance using RadixSort --- src/Columns/ColumnDecimal.cpp | 54 ++++++++++++++++++++++++++++++ src/Columns/ColumnVector.cpp | 62 +++++++++++++---------------------- 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/src/Columns/ColumnDecimal.cpp b/src/Columns/ColumnDecimal.cpp index 0d82818a431..e5ebdc3666f 100644 --- a/src/Columns/ColumnDecimal.cpp +++ b/src/Columns/ColumnDecimal.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include @@ -138,6 +139,26 @@ void ColumnDecimal::updateHashFast(SipHash & hash) const hash.update(reinterpret_cast(data.data()), size() * sizeof(data[0])); } +namespace +{ + template + struct ValueWithIndex + { + T value; + UInt32 index; + }; + + template + struct RadixSortTraits : RadixSortNumTraits + { + using Element = ValueWithIndex; + using Result = size_t; + + static T & extractKey(Element & elem) { return elem.value; } + static size_t extractResult(Element & elem) { return elem.index; } + }; +} + template void ColumnDecimal::getPermutation(IColumn::PermutationSortDirection direction, IColumn::PermutationSortStability stability, size_t limit, int, IColumn::Permutation & res) const @@ -159,6 +180,39 @@ void ColumnDecimal::getPermutation(IColumn::PermutationSortDirection directio return data[lhs] > data[rhs]; }; + size_t data_size = data.size(); + res.resize(data_size); + + if (limit >= data_size) { + limit = 0; + } + + if (!limit) + { + /// A case for radix sort + /// LSD RadixSort is stable + if constexpr (is_arithmetic_v && !is_big_int_v) + { + bool reverse = direction == IColumn::PermutationSortDirection::Descending; + bool ascending = direction == IColumn::PermutationSortDirection::Ascending; + bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; + + /// TODO: LSD RadixSort is currently not stable if direction is descending + bool use_radix_sort = (sort_is_stable && ascending) || !sort_is_stable; + + /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. + if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) + { + PaddedPODArray> pairs(data_size); + for (UInt32 i = 0; i < UInt32(data_size); ++i) + pairs[i] = {data[i].value, i}; + + RadixSort>::executeLSD(pairs.data(), data_size, reverse, res.data()); + return; + } + } + } + if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) this->getPermutationImpl(limit, res, comparator_ascending, DefaultSort(), DefaultPartialSort()); else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index e46384e4d03..c1150957a8f 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -254,31 +254,16 @@ template void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction, IColumn::PermutationSortStability stability, size_t limit, int nan_direction_hint, IColumn::Permutation & res) const { - size_t s = data.size(); - res.resize(s); + size_t data_size = data.size(); + res.resize(data_size); - if (s == 0) + if (data_size == 0) return; - if (limit >= s) + if (limit >= data_size) limit = 0; - if (limit) - { - for (size_t i = 0; i < s; ++i) - res[i] = i; - - if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) - ::partial_sort(res.begin(), res.begin() + limit, res.end(), less(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) - ::partial_sort(res.begin(), res.begin() + limit, res.end(), less_stable(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Unstable) - ::partial_sort(res.begin(), res.begin() + limit, res.end(), greater(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Stable) - ::partial_sort(res.begin(), res.begin() + limit, res.end(), greater_stable(*this, nan_direction_hint)); - } - else - { + if (!limit) { /// A case for radix sort /// LSD RadixSort is stable if constexpr (is_arithmetic_v && !is_big_int_v) @@ -291,13 +276,13 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction bool use_radix_sort = (sort_is_stable && ascending && !std::is_floating_point_v) || !sort_is_stable; /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. - if (s >= 256 && s <= std::numeric_limits::max() && use_radix_sort) + if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) { - PaddedPODArray> pairs(s); - for (UInt32 i = 0; i < static_cast(s); ++i) + PaddedPODArray> pairs(data_size); + for (UInt32 i = 0; i < static_cast(data_size); ++i) pairs[i] = {data[i], i}; - RadixSort>::executeLSD(pairs.data(), s, reverse, res.data()); + RadixSort>::executeLSD(pairs.data(), data_size, reverse, res.data()); /// Radix sort treats all NaNs to be greater than all numbers. /// If the user needs the opposite, we must move them accordingly. @@ -305,9 +290,9 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction { size_t nans_to_move = 0; - for (size_t i = 0; i < s; ++i) + for (size_t i = 0; i < data_size; ++i) { - if (isNaN(data[res[reverse ? i : s - 1 - i]])) + if (isNaN(data[res[reverse ? i : data_size - 1 - i]])) ++nans_to_move; else break; @@ -315,26 +300,23 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction if (nans_to_move) { - std::rotate(std::begin(res), std::begin(res) + (reverse ? nans_to_move : s - nans_to_move), std::end(res)); + std::rotate(std::begin(res), std::begin(res) + (reverse ? nans_to_move : data_size - nans_to_move), std::end(res)); } } + return; } } - - /// Default sorting algorithm. - for (size_t i = 0; i < s; ++i) - res[i] = i; - - if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) - ::sort(res.begin(), res.end(), less(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) - ::sort(res.begin(), res.end(), less_stable(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Unstable) - ::sort(res.begin(), res.end(), greater(*this, nan_direction_hint)); - else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Stable) - ::sort(res.begin(), res.end(), greater_stable(*this, nan_direction_hint)); } + + if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) + this->getPermutationImpl(limit, res, less(*this, nan_direction_hint), DefaultSort(), DefaultPartialSort()); + else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) + this->getPermutationImpl(limit, res, less_stable(*this, nan_direction_hint), DefaultSort(), DefaultPartialSort()); + else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Unstable) + this->getPermutationImpl(limit, res, greater(*this, nan_direction_hint), DefaultSort(), DefaultPartialSort()); + else + this->getPermutationImpl(limit, res, greater_stable(*this, nan_direction_hint), DefaultSort(), DefaultPartialSort()); } template From 0a9835d0858e34a4b22231a8117c7e3a6c6326c1 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Tue, 5 Apr 2022 20:39:21 +0200 Subject: [PATCH 2/5] Added performance tests --- tests/performance/merge_tree_insert.xml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/performance/merge_tree_insert.xml b/tests/performance/merge_tree_insert.xml index 1e987d27d50..ec991c458c2 100644 --- a/tests/performance/merge_tree_insert.xml +++ b/tests/performance/merge_tree_insert.xml @@ -18,15 +18,29 @@ merge_tree_insert_6 + + + decimal_primary_key_table_name + + merge_tree_insert_7 + merge_tree_insert_8 + merge_tree_insert_9 + + CREATE TABLE merge_tree_insert_1 (value_1 UInt64, value_2 UInt64, value_3 UInt64) ENGINE = MergeTree ORDER BY (value_1) CREATE TABLE merge_tree_insert_2 (value_1 UInt64, value_2 UInt64, value_3 UInt64) ENGINE = MergeTree ORDER BY (value_1, value_2) CREATE TABLE merge_tree_insert_3 (value_1 UInt64, value_2 UInt64, value_3 UInt64) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) + CREATE TABLE merge_tree_insert_4 (value_1 String, value_2 String, value_3 String) ENGINE = MergeTree ORDER BY (value_1) CREATE TABLE merge_tree_insert_5 (value_1 String, value_2 String, value_3 String) ENGINE = MergeTree ORDER BY (value_1, value_2) CREATE TABLE merge_tree_insert_6 (value_1 String, value_2 String, value_3 String) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) + CREATE TABLE merge_tree_insert_4 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1) + CREATE TABLE merge_tree_insert_5 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2) + CREATE TABLE merge_tree_insert_6 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) + INSERT INTO {integer_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 500000 INSERT INTO {integer_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1000000 INSERT INTO {integer_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1500000 @@ -35,7 +49,12 @@ INSERT INTO {string_primary_key_table_name} SELECT toString(rand64(0)), toString(rand64(1)), toString(rand64(2)) FROM system.numbers LIMIT 1000000 INSERT INTO {string_primary_key_table_name} SELECT toString(rand64(0)), toString(rand64(1)), toString(rand64(2)) FROM system.numbers LIMIT 1500000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 500000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1000000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1500000 + DROP TABLE IF EXISTS {integer_primary_key_table_name} DROP TABLE IF EXISTS {string_primary_key_table_name} + DROP TABLE IF EXISTS {decimal_primary_key_table_name} From 1de95d8c369ef3abbbe7e3e2b22b6e82e9a71035 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Thu, 21 Sep 2023 14:20:11 +0300 Subject: [PATCH 3/5] Updated implementation --- base/base/sort.h | 19 +++ contrib/pdqsort/pdqsort.h | 200 +++++++++++++++++++++++- src/Columns/ColumnDecimal.cpp | 87 +++++++---- src/Columns/ColumnVector.cpp | 60 +++---- src/Columns/RadixSortHelper.h | 25 +++ tests/performance/merge_tree_insert.xml | 12 +- tests/performance/sort_patterns.xml | 22 +++ 7 files changed, 362 insertions(+), 63 deletions(-) create mode 100644 src/Columns/RadixSortHelper.h create mode 100644 tests/performance/sort_patterns.xml diff --git a/base/base/sort.h b/base/base/sort.h index 912545979dc..cda03d92451 100644 --- a/base/base/sort.h +++ b/base/base/sort.h @@ -131,3 +131,22 @@ void sort(RandomIt first, RandomIt last) using comparator = std::less; ::sort(first, last, comparator()); } + +template +bool trySort(RandomIt first, RandomIt last, Compare compare) +{ +#ifndef NDEBUG + ::shuffle(first, last); +#endif + + ComparatorWrapper compare_wrapper = compare; + return ::pdqsort_try_sort(first, last, compare_wrapper); +} + +template +bool trySort(RandomIt first, RandomIt last) +{ + using value_type = typename std::iterator_traits::value_type; + using comparator = std::less; + return ::pdqsort_try_sort(first, last, comparator()); +} diff --git a/contrib/pdqsort/pdqsort.h b/contrib/pdqsort/pdqsort.h index 01e82b710ee..cbfc82a4f41 100644 --- a/contrib/pdqsort/pdqsort.h +++ b/contrib/pdqsort/pdqsort.h @@ -54,8 +54,10 @@ namespace pdqsort_detail { block_size = 64, // Cacheline size, assumes power of two. - cacheline_size = 64 + cacheline_size = 64, + /// Try sort allowed iterations + try_sort_iterations = 3, }; #if __cplusplus >= 201103L @@ -501,6 +503,167 @@ namespace pdqsort_detail { leftmost = false; } } + + template + inline bool pdqsort_try_sort_loop(Iter begin, + Iter end, + Compare comp, + size_t bad_allowed, + size_t iterations_allowed, + bool force_sort = false, + bool leftmost = true) { + typedef typename std::iterator_traits::difference_type diff_t; + + // Use a while loop for tail recursion elimination. + while (true) { + if (!force_sort && iterations_allowed == 0) { + return false; + } + + diff_t size = end - begin; + + // Insertion sort is faster for small arrays. + if (size < insertion_sort_threshold) { + if (leftmost) insertion_sort(begin, end, comp); + else unguarded_insertion_sort(begin, end, comp); + + return true; + } + + // Choose pivot as median of 3 or pseudomedian of 9. + diff_t s2 = size / 2; + if (size > ninther_threshold) { + sort3(begin, begin + s2, end - 1, comp); + sort3(begin + 1, begin + (s2 - 1), end - 2, comp); + sort3(begin + 2, begin + (s2 + 1), end - 3, comp); + sort3(begin + (s2 - 1), begin + s2, begin + (s2 + 1), comp); + std::iter_swap(begin, begin + s2); + } else sort3(begin + s2, begin, end - 1, comp); + + // If *(begin - 1) is the end of the right partition of a previous partition operation + // there is no element in [begin, end) that is smaller than *(begin - 1). Then if our + // pivot compares equal to *(begin - 1) we change strategy, putting equal elements in + // the left partition, greater elements in the right partition. We do not have to + // recurse on the left partition, since it's sorted (all equal). + if (!leftmost && !comp(*(begin - 1), *begin)) { + begin = partition_left(begin, end, comp) + 1; + continue; + } + + // Partition and get results. + std::pair part_result = + Branchless ? partition_right_branchless(begin, end, comp) + : partition_right(begin, end, comp); + Iter pivot_pos = part_result.first; + bool already_partitioned = part_result.second; + + // Check for a highly unbalanced partition. + diff_t l_size = pivot_pos - begin; + diff_t r_size = end - (pivot_pos + 1); + bool highly_unbalanced = l_size < size / 8 || r_size < size / 8; + + // If we got a highly unbalanced partition we shuffle elements to break many patterns. + if (highly_unbalanced) { + if (!force_sort) { + return false; + } + + // If we had too many bad partitions, switch to heapsort to guarantee O(n log n). + if (--bad_allowed == 0) { + std::make_heap(begin, end, comp); + std::sort_heap(begin, end, comp); + return true; + } + + if (l_size >= insertion_sort_threshold) { + std::iter_swap(begin, begin + l_size / 4); + std::iter_swap(pivot_pos - 1, pivot_pos - l_size / 4); + + if (l_size > ninther_threshold) { + std::iter_swap(begin + 1, begin + (l_size / 4 + 1)); + std::iter_swap(begin + 2, begin + (l_size / 4 + 2)); + std::iter_swap(pivot_pos - 2, pivot_pos - (l_size / 4 + 1)); + std::iter_swap(pivot_pos - 3, pivot_pos - (l_size / 4 + 2)); + } + } + + if (r_size >= insertion_sort_threshold) { + std::iter_swap(pivot_pos + 1, pivot_pos + (1 + r_size / 4)); + std::iter_swap(end - 1, end - r_size / 4); + + if (r_size > ninther_threshold) { + std::iter_swap(pivot_pos + 2, pivot_pos + (2 + r_size / 4)); + std::iter_swap(pivot_pos + 3, pivot_pos + (3 + r_size / 4)); + std::iter_swap(end - 2, end - (1 + r_size / 4)); + std::iter_swap(end - 3, end - (2 + r_size / 4)); + } + } + } else { + // If we were decently balanced and we tried to sort an already partitioned + // sequence try to use insertion sort. + if (already_partitioned && partial_insertion_sort(begin, pivot_pos, comp) + && partial_insertion_sort(pivot_pos + 1, end, comp)) { + return true; + } + } + + // Sort the left partition first using recursion and do tail recursion elimination for + // the right-hand partition. + if (pdqsort_try_sort_loop(begin, + pivot_pos, + comp, + bad_allowed, + iterations_allowed - 1, + force_sort, + leftmost)) { + force_sort = true; + } else { + return false; + } + + --iterations_allowed; + begin = pivot_pos + 1; + leftmost = false; + } + + return false; + } + + template + inline bool pdqsort_try_sort_impl(Iter begin, Iter end, Compare comp, size_t bad_allowed) + { + typedef typename std::iterator_traits::difference_type diff_t; + + static constexpr size_t iterations_allowed = pdqsort_detail::try_sort_iterations; + static constexpr size_t num_to_try = 16; + + diff_t size = end - begin; + + if (size > num_to_try * 10) + { + size_t out_of_order_elements = 0; + + for (size_t i = 1; i < num_to_try; ++i) + { + diff_t offset = size / num_to_try; + + diff_t prev_position = offset * (i - 1); + diff_t curr_position = offset * i; + diff_t next_position = offset * (i + 1) - 1; + + bool prev_less_than_curr = comp(*(begin + prev_position), *(begin + curr_position)); + bool curr_less_than_next = comp(*(begin + curr_position), *(begin + next_position)); + if ((prev_less_than_curr && curr_less_than_next) || (!prev_less_than_curr && !curr_less_than_next)) + continue; + + ++out_of_order_elements; + if (out_of_order_elements > iterations_allowed) + return false; + } + } + + return pdqsort_try_sort_loop(begin, end, comp, bad_allowed, iterations_allowed); + } } @@ -538,6 +701,41 @@ inline void pdqsort_branchless(Iter begin, Iter end) { pdqsort_branchless(begin, end, std::less()); } +template +inline bool pdqsort_try_sort(Iter begin, Iter end, Compare comp) { + if (begin == end) return true; + +#if __cplusplus >= 201103L + return pdqsort_detail::pdqsort_try_sort_impl::type>::value && + std::is_arithmetic::value_type>::value>( + begin, end, comp, pdqsort_detail::log2(end - begin)); +#else + return pdqsort_detail::pdqsort_try_sort_impl( + begin, end, comp, pdqsort_detail::log2(end - begin)); +#endif +} + +template +inline bool pdqsort_try_sort(Iter begin, Iter end) { + typedef typename std::iterator_traits::value_type T; + return pdqsort_try_sort(begin, end, std::less()); +} + +template +inline bool pdqsort_try_sort_branchless(Iter begin, Iter end, Compare comp) { + if (begin == end) return true; + + return pdqsort_detail::pdqsort_try_sort_impl( + begin, end, comp, pdqsort_detail::log2(end - begin)); +} + +template +inline bool pdqsort_try_sort_branchless(Iter begin, Iter end) { + typedef typename std::iterator_traits::value_type T; + return pdqsort_try_sort_branchless(begin, end, std::less()); +} + #undef PDQSORT_PREFER_MOVE diff --git a/src/Columns/ColumnDecimal.cpp b/src/Columns/ColumnDecimal.cpp index e5ebdc3666f..111c0e3cb1c 100644 --- a/src/Columns/ColumnDecimal.cpp +++ b/src/Columns/ColumnDecimal.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include @@ -139,26 +140,6 @@ void ColumnDecimal::updateHashFast(SipHash & hash) const hash.update(reinterpret_cast(data.data()), size() * sizeof(data[0])); } -namespace -{ - template - struct ValueWithIndex - { - T value; - UInt32 index; - }; - - template - struct RadixSortTraits : RadixSortNumTraits - { - using Element = ValueWithIndex; - using Result = size_t; - - static T & extractKey(Element & elem) { return elem.value; } - static size_t extractResult(Element & elem) { return elem.index; } - }; -} - template void ColumnDecimal::getPermutation(IColumn::PermutationSortDirection direction, IColumn::PermutationSortStability stability, size_t limit, int, IColumn::Permutation & res) const @@ -183,16 +164,19 @@ void ColumnDecimal::getPermutation(IColumn::PermutationSortDirection directio size_t data_size = data.size(); res.resize(data_size); - if (limit >= data_size) { + if (limit >= data_size) limit = 0; - } - if (!limit) + for (size_t i = 0; i < data_size; ++i) + res[i] = i; + + if constexpr (is_arithmetic_v && !is_big_int_v) { - /// A case for radix sort - /// LSD RadixSort is stable - if constexpr (is_arithmetic_v && !is_big_int_v) + if (!limit) { + /// A case for radix sort + /// LSD RadixSort is stable + bool reverse = direction == IColumn::PermutationSortDirection::Descending; bool ascending = direction == IColumn::PermutationSortDirection::Ascending; bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; @@ -203,8 +187,25 @@ void ColumnDecimal::getPermutation(IColumn::PermutationSortDirection directio /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) { + for (size_t i = 0; i < data_size; ++i) + res[i] = i; + + bool try_sort = false; + + if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) + try_sort = trySort(res.begin(), res.end(), comparator_ascending); + else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) + try_sort = trySort(res.begin(), res.end(), comparator_ascending_stable); + else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Unstable) + try_sort = trySort(res.begin(), res.end(), comparator_descending); + else + try_sort = trySort(res.begin(), res.end(), comparator_descending_stable); + + if (try_sort) + return; + PaddedPODArray> pairs(data_size); - for (UInt32 i = 0; i < UInt32(data_size); ++i) + for (UInt32 i = 0; i < static_cast(data_size); ++i) pairs[i] = {data[i].value, i}; RadixSort>::executeLSD(pairs.data(), data_size, reverse, res.data()); @@ -245,7 +246,37 @@ void ColumnDecimal::updatePermutation(IColumn::PermutationSortDirection direc return data[lhs] < data[rhs]; }; auto equals_comparator = [this](size_t lhs, size_t rhs) { return data[lhs] == data[rhs]; }; - auto sort = [](auto begin, auto end, auto pred) { ::sort(begin, end, pred); }; + auto sort = [&](auto begin, auto end, auto pred) + { + bool reverse = direction == IColumn::PermutationSortDirection::Descending; + bool ascending = direction == IColumn::PermutationSortDirection::Ascending; + bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; + + /// TODO: LSD RadixSort is currently not stable if direction is descending + bool use_radix_sort = (sort_is_stable && ascending) || !sort_is_stable; + size_t size = end - begin; + + if (size >= 256 && size <= std::numeric_limits::max() && use_radix_sort) + { + bool try_sort = trySort(begin, end, pred); + if (try_sort) + return; + + PaddedPODArray> pairs(size); + size_t index = 0; + + for (auto * it = begin; it != end; ++it) + { + pairs[index] = {data[*it].value, static_cast(*it)}; + ++index; + } + + RadixSort>::executeLSD(pairs.data(), size, reverse, res.data()); + return; + } + + ::sort(begin, end, pred); + }; auto partial_sort = [](auto begin, auto mid, auto end, auto pred) { ::partial_sort(begin, mid, end, pred); }; if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index c1150957a8f..37e62c76596 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -192,26 +193,6 @@ struct ColumnVector::equals bool operator()(size_t lhs, size_t rhs) const { return CompareHelper::equals(parent.data[lhs], parent.data[rhs], nan_direction_hint); } }; -namespace -{ - template - struct ValueWithIndex - { - T value; - UInt32 index; - }; - - template - struct RadixSortTraits : RadixSortNumTraits - { - using Element = ValueWithIndex; - using Result = size_t; - - static T & extractKey(Element & elem) { return elem.value; } - static size_t extractResult(Element & elem) { return elem.index; } - }; -} - #if USE_EMBEDDED_COMPILER template @@ -263,11 +244,16 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction if (limit >= data_size) limit = 0; - if (!limit) { - /// A case for radix sort - /// LSD RadixSort is stable - if constexpr (is_arithmetic_v && !is_big_int_v) + for (size_t i = 0; i < data_size; ++i) + res[i] = i; + + if constexpr (is_arithmetic_v && !is_big_int_v) + { + if (!limit) { + /// A case for radix sort + /// LSD RadixSort is stable + bool reverse = direction == IColumn::PermutationSortDirection::Descending; bool ascending = direction == IColumn::PermutationSortDirection::Ascending; bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; @@ -278,6 +264,20 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) { + bool try_sort = false; + + if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Unstable) + try_sort = trySort(res.begin(), res.end(), less(*this, nan_direction_hint)); + else if (direction == IColumn::PermutationSortDirection::Ascending && stability == IColumn::PermutationSortStability::Stable) + try_sort = trySort(res.begin(), res.end(), less_stable(*this, nan_direction_hint)); + else if (direction == IColumn::PermutationSortDirection::Descending && stability == IColumn::PermutationSortStability::Unstable) + try_sort = trySort(res.begin(), res.end(), greater(*this, nan_direction_hint)); + else + try_sort = trySort(res.begin(), res.end(), greater_stable(*this, nan_direction_hint)); + + if (try_sort) + return; + PaddedPODArray> pairs(data_size); for (UInt32 i = 0; i < static_cast(data_size); ++i) pairs[i] = {data[i], i}; @@ -323,12 +323,12 @@ template void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direction, IColumn::PermutationSortStability stability, size_t limit, int nan_direction_hint, IColumn::Permutation & res, EqualRanges & equal_ranges) const { - bool reverse = direction == IColumn::PermutationSortDirection::Descending; - bool ascending = direction == IColumn::PermutationSortDirection::Ascending; - bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; - auto sort = [&](auto begin, auto end, auto pred) { + bool reverse = direction == IColumn::PermutationSortDirection::Descending; + bool ascending = direction == IColumn::PermutationSortDirection::Ascending; + bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; + /// A case for radix sort if constexpr (is_arithmetic_v && !is_big_int_v) { @@ -339,6 +339,10 @@ void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direct /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. if (size >= 256 && size <= std::numeric_limits::max() && use_radix_sort) { + bool try_sort = trySort(begin, end, pred); + if (try_sort) + return; + PaddedPODArray> pairs(size); size_t index = 0; diff --git a/src/Columns/RadixSortHelper.h b/src/Columns/RadixSortHelper.h new file mode 100644 index 00000000000..e7d8ea6e535 --- /dev/null +++ b/src/Columns/RadixSortHelper.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace DB +{ + +template +struct ValueWithIndex +{ + T value; + UInt32 index; +}; + +template +struct RadixSortTraits : RadixSortNumTraits +{ + using Element = ValueWithIndex; + using Result = size_t; + + static T & extractKey(Element & elem) { return elem.value; } + static size_t extractResult(Element & elem) { return elem.index; } +}; + +} diff --git a/tests/performance/merge_tree_insert.xml b/tests/performance/merge_tree_insert.xml index ec991c458c2..3e1d2541480 100644 --- a/tests/performance/merge_tree_insert.xml +++ b/tests/performance/merge_tree_insert.xml @@ -37,9 +37,9 @@ CREATE TABLE merge_tree_insert_5 (value_1 String, value_2 String, value_3 String) ENGINE = MergeTree ORDER BY (value_1, value_2) CREATE TABLE merge_tree_insert_6 (value_1 String, value_2 String, value_3 String) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) - CREATE TABLE merge_tree_insert_4 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1) - CREATE TABLE merge_tree_insert_5 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2) - CREATE TABLE merge_tree_insert_6 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) + CREATE TABLE merge_tree_insert_7 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1) + CREATE TABLE merge_tree_insert_8 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2) + CREATE TABLE merge_tree_insert_9 (value_1 Decimal64(8), value_2 Decimal64(8), value_3 Decimal64(8)) ENGINE = MergeTree ORDER BY (value_1, value_2, value_3) INSERT INTO {integer_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 500000 INSERT INTO {integer_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1000000 @@ -49,9 +49,9 @@ INSERT INTO {string_primary_key_table_name} SELECT toString(rand64(0)), toString(rand64(1)), toString(rand64(2)) FROM system.numbers LIMIT 1000000 INSERT INTO {string_primary_key_table_name} SELECT toString(rand64(0)), toString(rand64(1)), toString(rand64(2)) FROM system.numbers LIMIT 1500000 - INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 500000 - INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1000000 - INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0), rand64(1), rand64(2) FROM system.numbers LIMIT 1500000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0) % 1000000, rand64(1) % 1500000, rand64(2) % 2000000 FROM system.numbers LIMIT 500000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0) % 1000000, rand64(1) % 1500000, rand64(2) % 2000000 FROM system.numbers LIMIT 1000000 + INSERT INTO {decimal_primary_key_table_name} SELECT rand64(0) % 1000000, rand64(1) % 1500000, rand64(2) % 2000000 FROM system.numbers LIMIT 1500000 DROP TABLE IF EXISTS {integer_primary_key_table_name} DROP TABLE IF EXISTS {string_primary_key_table_name} diff --git a/tests/performance/sort_patterns.xml b/tests/performance/sort_patterns.xml new file mode 100644 index 00000000000..6ca4a34fc34 --- /dev/null +++ b/tests/performance/sort_patterns.xml @@ -0,0 +1,22 @@ + + + + integer_type + + UInt32 + UInt64 + + + + + CREATE TABLE sequential_{integer_type} (key {integer_type}, value {integer_type}) Engine = Memory + + INSERT INTO sequential_{integer_type} SELECT number, number FROM numbers(10000000) + + SELECT key, value FROM sequential_{integer_type} ORDER BY key; + SELECT key, value FROM sequential_{integer_type} ORDER BY key, value; + SELECT key, value FROM sequential_{integer_type} ORDER BY key DESC; + SELECT key, value FROM sequential_{integer_type} ORDER BY key DESC, value DESC; + + DROP TABLE IF EXISTS sequential_{integer_type} + From 40be8227ea3ed1507361378ce7dc500d3b817c51 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Fri, 22 Sep 2023 12:25:13 +0300 Subject: [PATCH 4/5] Fixed tests --- tests/performance/sort_patterns.xml | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/performance/sort_patterns.xml b/tests/performance/sort_patterns.xml index 6ca4a34fc34..fc49b20cc8c 100644 --- a/tests/performance/sort_patterns.xml +++ b/tests/performance/sort_patterns.xml @@ -7,16 +7,22 @@ UInt64 + + sort_expression + + key + key, value + key DESC + key DESC, value DESC + + CREATE TABLE sequential_{integer_type} (key {integer_type}, value {integer_type}) Engine = Memory - INSERT INTO sequential_{integer_type} SELECT number, number FROM numbers(10000000) + INSERT INTO sequential_{integer_type} SELECT number, number FROM numbers(500000000) - SELECT key, value FROM sequential_{integer_type} ORDER BY key; - SELECT key, value FROM sequential_{integer_type} ORDER BY key, value; - SELECT key, value FROM sequential_{integer_type} ORDER BY key DESC; - SELECT key, value FROM sequential_{integer_type} ORDER BY key DESC, value DESC; + SELECT key, value FROM sequential_{integer_type} ORDER BY {sort_expression} FORMAT Null DROP TABLE IF EXISTS sequential_{integer_type} From f7494a5e454110eb440c5e7eb7a7f65fdc9fa6c8 Mon Sep 17 00:00:00 2001 From: Maksim Kita Date: Mon, 25 Sep 2023 11:39:57 +0300 Subject: [PATCH 5/5] Added documentation --- base/base/sort.h | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/base/base/sort.h b/base/base/sort.h index cda03d92451..1a814587763 100644 --- a/base/base/sort.h +++ b/base/base/sort.h @@ -132,6 +132,13 @@ void sort(RandomIt first, RandomIt last) ::sort(first, last, comparator()); } +/** Try to fast sort elements for common sorting patterns: + * 1. If elements are already sorted. + * 2. If elements are already almost sorted. + * 3. If elements are already sorted in reverse order. + * + * Returns true if fast sort was performed or elements were already sorted, false otherwise. + */ template bool trySort(RandomIt first, RandomIt last, Compare compare) { @@ -148,5 +155,5 @@ bool trySort(RandomIt first, RandomIt last) { using value_type = typename std::iterator_traits::value_type; using comparator = std::less; - return ::pdqsort_try_sort(first, last, comparator()); + return ::trySort(first, last, comparator()); }