From 849ac1fe9950367e4c2b96defe1b3cfb02a169de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Mar=C3=ADn?= Date: Wed, 17 Jan 2024 15:51:21 +0100 Subject: [PATCH] Implement findExtremeMinIndex / findExtremeMaxIndex --- src/AggregateFunctions/SingleValueData.cpp | 20 +------- src/AggregateFunctions/SingleValueData.h | 3 +- src/Columns/ColumnVector.cpp | 21 ++++++++ src/Common/findExtreme.cpp | 48 +++++++++++++++++-- src/Common/findExtreme.h | 22 +++++++-- .../agg_functions_argmin_argmax.xml | 24 ++++++++++ tests/performance/order_with_limit.xml | 3 ++ 7 files changed, 112 insertions(+), 29 deletions(-) create mode 100644 tests/performance/agg_functions_argmin_argmax.xml diff --git a/src/AggregateFunctions/SingleValueData.cpp b/src/AggregateFunctions/SingleValueData.cpp index 4c5ed86c206..dae5e80a38c 100644 --- a/src/AggregateFunctions/SingleValueData.cpp +++ b/src/AggregateFunctions/SingleValueData.cpp @@ -421,15 +421,7 @@ std::optional SingleValueDataFixed::getSmallestIndex(const IColumn & const auto & vec = assert_cast(column); if constexpr (has_find_extreme_implementation) { - std::optional opt = findExtremeMin(vec.getData().data(), row_begin, row_end); - if (!opt || (has() && value <= opt)) - return std::nullopt; - - /// TODO: Implement findExtremeMinIndex to do the lookup properly (with SIMD and batching) - for (size_t i = row_begin; i < row_end; i++) - if (vec.getData()[i] == *opt) - return i; - return row_end; + return findExtremeMinIndex(vec.getData().data(), row_begin, row_end); } else { @@ -450,15 +442,7 @@ std::optional SingleValueDataFixed::getGreatestIndex(const IColumn & const auto & vec = assert_cast(column); if constexpr (has_find_extreme_implementation) { - std::optional opt = findExtremeMax(vec.getData().data(), row_begin, row_end); - if (!opt || (has() && value >= opt)) - return std::nullopt; - - /// TODO: Implement findExtremeMaxIndex to do the lookup properly (with SIMD and batching) - for (size_t i = row_begin; i < row_end; i++) - if (vec.getData()[i] == *opt) - return i; - return row_end; + return findExtremeMaxIndex(vec.getData().data(), row_begin, row_end); } else { diff --git a/src/AggregateFunctions/SingleValueData.h b/src/AggregateFunctions/SingleValueData.h index eb770595a1a..e216586bf09 100644 --- a/src/AggregateFunctions/SingleValueData.h +++ b/src/AggregateFunctions/SingleValueData.h @@ -57,8 +57,7 @@ struct SingleValueDataBase virtual void setGreatestNotNullIf(const IColumn &, const UInt8 * __restrict, const UInt8 * __restrict, size_t, size_t, Arena *); /// Given a column returns the index of the smallest or greatest value in it - /// Doesn't return anything if the column is empty. In some cases (SingleValueDataFixed) it will also return - /// empty if the stored value is already the smallest/greatest + /// Doesn't return anything if the column is empty /// There are used to implement argMin / argMax virtual std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end); virtual std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end); diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index b1cf449dfde..b22d12b995f 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -247,6 +248,26 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction iota(res.data(), data_size, IColumn::Permutation::value_type(0)); + if constexpr (has_find_extreme_implementation && !std::is_floating_point_v) + { + /// Disabled for:floating point + /// * floating point: We don't deal with nan_direction_hint + /// * stability::Stable: We might return any value, not the first + if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable)) + { + std::optional index; + if (direction == IColumn::PermutationSortDirection::Ascending) + index = findExtremeMinIndex(data.data(), 0, data.size()); + else + index = findExtremeMaxIndex(data.data(), 0, data.size()); + if (index) + { + res.data()[0] = *index; + return; + } + } + } + if constexpr (is_arithmetic_v && !is_big_int_v) { if (!limit) diff --git a/src/Common/findExtreme.cpp b/src/Common/findExtreme.cpp index c02450e2adc..f7fc1c3cfea 100644 --- a/src/Common/findExtreme.cpp +++ b/src/Common/findExtreme.cpp @@ -133,14 +133,54 @@ std::optional findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __rest return findExtreme, false, false>(ptr, condition_map, start, end); } +template +std::optional findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end) +{ + /// This is implemented based on findNumericExtreme and not the other way around (or independently) because getting + /// the MIN or MAX value of an array is possible with SIMD, but getting the index isn't. + /// So what we do is use SIMD to find the lowest value and then iterate again over the array to find its position + std::optional opt = findExtremeMin(ptr, start, end); + if (!opt) + return std::nullopt; + + /// Some minimal heuristics for the case the input is sorted + if (*opt == ptr[start]) + return {start}; + for (size_t i = end - 1; i != start + 1; i--) + if (ptr[i] == *opt) + return {i}; + return std::nullopt; +} + +template +std::optional findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end) +{ + std::optional opt = findExtremeMax(ptr, start, end); + if (!opt) + return std::nullopt; + + /// Some minimal heuristics for the case the input is sorted + if (*opt == ptr[start]) + return {start}; + for (size_t i = end - 1; i != start + 1; i--) + if (ptr[i] == *opt) + return {i}; + return std::nullopt; +} #define INSTANTIATION(T) \ template std::optional findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \ - template std::optional findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ - template std::optional findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + template std::optional findExtremeMinNotNull( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + template std::optional findExtremeMinIf( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ template std::optional findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \ - template std::optional findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ - template std::optional findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); + template std::optional findExtremeMaxNotNull( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + template std::optional findExtremeMaxIf( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + template std::optional findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \ + template std::optional findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end); FOR_BASIC_NUMERIC_TYPES(INSTANTIATION) #undef INSTANTIATION diff --git a/src/Common/findExtreme.h b/src/Common/findExtreme.h index c43d2d43350..68e7360d6e2 100644 --- a/src/Common/findExtreme.h +++ b/src/Common/findExtreme.h @@ -31,15 +31,27 @@ std::optional findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * _ template std::optional findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); +template +std::optional findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); + +template +std::optional findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end); + #define EXTERN_INSTANTIATION(T) \ extern template std::optional findExtremeMin(const T * __restrict ptr, size_t start, size_t end); \ - extern template std::optional findExtremeMinNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ - extern template std::optional findExtremeMinIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + extern template std::optional findExtremeMinNotNull( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + extern template std::optional findExtremeMinIf( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ extern template std::optional findExtremeMax(const T * __restrict ptr, size_t start, size_t end); \ - extern template std::optional findExtremeMaxNotNull(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ - extern template std::optional findExtremeMaxIf(const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); + extern template std::optional findExtremeMaxNotNull( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + extern template std::optional findExtremeMaxIf( \ + const T * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end); \ + extern template std::optional findExtremeMinIndex(const T * __restrict ptr, size_t start, size_t end); \ + extern template std::optional findExtremeMaxIndex(const T * __restrict ptr, size_t start, size_t end); - FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION) +FOR_BASIC_NUMERIC_TYPES(EXTERN_INSTANTIATION) #undef EXTERN_INSTANTIATION } diff --git a/tests/performance/agg_functions_argmin_argmax.xml b/tests/performance/agg_functions_argmin_argmax.xml new file mode 100644 index 00000000000..e8eed2a82de --- /dev/null +++ b/tests/performance/agg_functions_argmin_argmax.xml @@ -0,0 +1,24 @@ + + + + group_scale + + 1000000 + + + + +select argMin(Title, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null +select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null +select argMinIf(Title::Nullable(String), EventTime::Nullable(DateTime), Title::Nullable(String) != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null + +select argMin(RegionID, EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null +select argMin((Title, RegionID), EventTime) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null +select argMinIf(Title, EventTime, Title != '') from hits_100m_single group by intHash32(UserID) % {group_scale} FORMAT Null + +select argMax(WatchID, Age) from hits_100m_single FORMAT Null +select argMax(WatchID, Age::Nullable(UInt8)) from hits_100m_single FORMAT Null +select argMax(WatchID, (EventDate, EventTime)) from hits_100m_single where Title != '' group by intHash32(UserID) % {group_scale} FORMAT Null +select argMax(MobilePhone, MobilePhoneModel) from hits_100m_single + + diff --git a/tests/performance/order_with_limit.xml b/tests/performance/order_with_limit.xml index 1e1cb52267c..d1ad2afade8 100644 --- a/tests/performance/order_with_limit.xml +++ b/tests/performance/order_with_limit.xml @@ -1,4 +1,5 @@ + SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1 FORMAT Null SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10 FORMAT Null SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 100 FORMAT Null SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 1500 FORMAT Null @@ -7,6 +8,7 @@ SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 10000 FORMAT Null SELECT number AS n FROM numbers_mt(200000000) ORDER BY n DESC LIMIT 65535 FORMAT Null + SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 1 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(500000000) ORDER BY n LIMIT 10 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 100 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 1500 FORMAT Null @@ -15,6 +17,7 @@ SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n LIMIT 10000 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(100000000) ORDER BY n LIMIT 65535 FORMAT Null + SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 10 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 100 FORMAT Null SELECT intHash64(number) AS n FROM numbers_mt(200000000) ORDER BY n, n + 1, n + 2 LIMIT 1500 FORMAT Null