From a38ca3bca285be41b86fb2612719e4856f835cf4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ra=C3=BAl=20Mar=C3=ADn?=
Date: Mon, 22 Jan 2024 14:55:52 +0100
Subject: [PATCH] Optimize SingleValueDataNumeric getIndexNotNullIf

---
 src/AggregateFunctions/SingleValueData.cpp    | 165 +++++++++++++++++-
 src/AggregateFunctions/SingleValueData.h      |  36 ++--
 src/Common/findExtreme.cpp                    |  44 ++++-
 .../02406_minmax_behaviour.reference          |   4 +
 .../0_stateless/02406_minmax_behaviour.sql    |   3 +
 5 files changed, 228 insertions(+), 24 deletions(-)

diff --git a/src/AggregateFunctions/SingleValueData.cpp b/src/AggregateFunctions/SingleValueData.cpp
index 1f8af7d9f63..3eb9571b370 100644
--- a/src/AggregateFunctions/SingleValueData.cpp
+++ b/src/AggregateFunctions/SingleValueData.cpp
@@ -34,7 +34,7 @@ mergeIfAndNullFlags(const UInt8 * __restrict null_map, const UInt8 * __restrict
 }
 
 
-std::optional SingleValueDataBase::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataBase::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     if (row_begin >= row_end)
         return std::nullopt;
@@ -59,7 +59,7 @@ std::optional SingleValueDataBase::getSmallestIndex(const IColumn & colu
     }
 }
 
-std::optional SingleValueDataBase::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataBase::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     if (row_begin >= row_end)
         return std::nullopt;
@@ -85,7 +85,7 @@ std::optional SingleValueDataBase::getGreatestIndex(const IColumn & colu
 }
 
 std::optional SingleValueDataBase::getSmallestIndexNotNullIf(
-    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end)
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
 {
     size_t index = row_begin;
     while ((index < row_end) && ((if_map && if_map[index] == 0) || (null_map && null_map[index] != 0)))
@@ -100,7 +100,7 @@ std::optional SingleValueDataBase::getSmallestIndexNotNullIf(
 }
 
 std::optional SingleValueDataBase::getGreatestIndexNotNullIf(
-    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end)
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
 {
     size_t index = row_begin;
     while ((index < row_end) && ((if_map && if_map[index] == 0) || (null_map && null_map[index] != 0)))
@@ -409,7 +409,7 @@ void SingleValueDataFixed::setGreatestNotNullIf(
 }
 
 template
-std::optional SingleValueDataFixed::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataFixed::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     if (row_begin >= row_end)
         return std::nullopt;
@@ -430,7 +430,7 @@ std::optional SingleValueDataFixed::getSmallestIndex(const IColumn &
 }
 
 template
-std::optional SingleValueDataFixed::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataFixed::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     if (row_begin >= row_end)
         return std::nullopt;
@@ -450,6 +450,141 @@ std::optional SingleValueDataFixed::getGreatestIndex(const IColumn &
     }
 }
 
+/// Index of the first row in [row_begin, row_end) holding the smallest value among the rows that pass the
+/// filters (null_map[i] == 0 when null_map is set, if_map[i] != 0 when if_map is set); nullopt if no row passes.
+/// NOTE(review): with if_map == nullptr the findExtreme path reads null_map unconditionally - presumably callers always pass at least one map; confirm.
+template
+std::optional SingleValueDataFixed::getSmallestIndexNotNullIf(
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
+{
+    if (row_begin >= row_end)
+        return std::nullopt;
+
+    const auto & vec = assert_cast(column);
+
+    if constexpr (has_find_extreme_implementation)
+    {
+        std::optional opt;
+        if (!if_map)
+        {
+            opt = findExtremeMinNotNull(vec.getData().data(), null_map, row_begin, row_end);
+            if (!opt.has_value())
+                return opt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (!null_map[i] && vec[i] == *opt)
+                    return {i};
+            }
+        }
+        else if (!null_map)
+        {
+            opt = findExtremeMinIf(vec.getData().data(), if_map, row_begin, row_end);
+            if (!opt.has_value())
+                return opt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (if_map[i] && vec[i] == *opt)
+                    return {i};
+            }
+        }
+        else
+        {
+            auto final_flags = mergeIfAndNullFlags(null_map, if_map, row_begin, row_end);
+            opt = findExtremeMinIf(vec.getData().data(), final_flags.get(), row_begin, row_end);
+            if (!opt.has_value())
+                return std::nullopt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (final_flags[i] && vec[i] == *opt)
+                    return {i};
+            }
+        }
+        UNREACHABLE();
+    }
+    else
+    {
+        size_t index = row_begin;
+        while ((index < row_end) && ((if_map && if_map[index] == 0) || (null_map && null_map[index] != 0)))
+            index++;
+        if (index >= row_end)
+            return std::nullopt;
+
+        for (size_t i = index + 1; i < row_end; i++)
+            if ((!if_map || if_map[i] != 0) && (!null_map || null_map[i] == 0) && (vec[i] < vec[index]))
+                index = i;
+        return {index};
+    }
+}
+
+/// Index of the first row in [row_begin, row_end) holding the greatest value among the rows that pass the
+/// filters (null_map[i] == 0 when null_map is set, if_map[i] != 0 when if_map is set); nullopt if no row passes.
+/// NOTE(review): with if_map == nullptr the findExtreme path reads null_map unconditionally - presumably callers always pass at least one map; confirm.
+template
+std::optional SingleValueDataFixed::getGreatestIndexNotNullIf(
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
+{
+    if (row_begin >= row_end)
+        return std::nullopt;
+
+    const auto & vec = assert_cast(column);
+
+    if constexpr (has_find_extreme_implementation)
+    {
+        std::optional opt;
+        if (!if_map)
+        {
+            opt = findExtremeMaxNotNull(vec.getData().data(), null_map, row_begin, row_end);
+            if (!opt.has_value())
+                return opt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (!null_map[i] && vec[i] == *opt)
+                    return {i};
+            }
+            return opt;
+        }
+        else if (!null_map)
+        {
+            opt = findExtremeMaxIf(vec.getData().data(), if_map, row_begin, row_end);
+            if (!opt.has_value())
+                return opt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (if_map[i] && vec[i] == *opt)
+                    return {i};
+            }
+            return opt;
+        }
+        else
+        {
+            auto final_flags = mergeIfAndNullFlags(null_map, if_map, row_begin, row_end);
+            opt = findExtremeMaxIf(vec.getData().data(), final_flags.get(), row_begin, row_end);
+            if (!opt.has_value())
+                return std::nullopt;
+            for (size_t i = row_begin; i < row_end; i++)
+            {
+                if (final_flags[i] && vec[i] == *opt)
+                    return {i};
+            }
+        }
+        UNREACHABLE();
+    }
+    else
+    {
+        size_t index = row_begin;
+        while ((index < row_end) && ((if_map && if_map[index] == 0) || (null_map && null_map[index] != 0)))
+            index++;
+        if (index >= row_end)
+            return std::nullopt;
+
+        /// Strict '>' keeps the earliest row among equal maxima (mirrors '<' in the smallest-index variant).
+        for (size_t i = index + 1; i < row_end; i++)
+            if ((!if_map || if_map[i] != 0) && (!null_map || null_map[i] == 0) && (vec[i] > vec[index]))
+                index = i;
+        return {index};
+    }
+}
+
 
 #if USE_EMBEDDED_COMPILER
 
@@ -864,17 +999,31 @@ void SingleValueDataNumeric::setGreatestNotNullIf(
 }
 
 template
-std::optional SingleValueDataNumeric::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataNumeric::getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     return memory.get().getSmallestIndex(column, row_begin, row_end);
 }
 
 template
-std::optional SingleValueDataNumeric::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end)
+std::optional SingleValueDataNumeric::getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const
 {
     return memory.get().getGreatestIndex(column, row_begin, row_end);
 }
 
+template
+std::optional SingleValueDataNumeric::getSmallestIndexNotNullIf(
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
+{
+    return memory.get().getSmallestIndexNotNullIf(column, null_map, if_map, row_begin, row_end);
+}
+
+template
+std::optional SingleValueDataNumeric::getGreatestIndexNotNullIf(
+    const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const
+{
+    return memory.get().getGreatestIndexNotNullIf(column, null_map, if_map, row_begin, row_end);
+}
+
 #define DISPATCH(TYPE) template struct SingleValueDataNumeric;
 
 FOR_SINGLE_VALUE_NUMERIC_TYPES(DISPATCH)
diff --git a/src/AggregateFunctions/SingleValueData.h b/src/AggregateFunctions/SingleValueData.h
index 5bfa02b41ee..923c986923c 100644
--- a/src/AggregateFunctions/SingleValueData.h
+++ b/src/AggregateFunctions/SingleValueData.h
@@ -59,12 +59,12 @@ struct SingleValueDataBase
     /// Given a column returns the index of the smallest or greatest value in it
     /// Doesn't return anything if the column is empty
     /// There are used to implement argMin / argMax
-    virtual std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end);
-    virtual std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end);
-    static std::optional getSmallestIndexNotNullIf(
-        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end);
-    static std::optional getGreatestIndexNotNullIf(
-        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end);
+    virtual std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
+    virtual std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
+    virtual std::optional getSmallestIndexNotNullIf(
+        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
+    virtual std::optional getGreatestIndexNotNullIf(
+        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
 };
 
 
@@ -136,8 +136,12 @@ struct SingleValueDataFixed
         size_t row_end,
         Arena *);
 
-    std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end);
-    std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end);
+    std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
+    std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const;
+    std::optional getSmallestIndexNotNullIf(
+        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
+    std::optional getGreatestIndexNotNullIf(
+        const IColumn & column, const UInt8 * __restrict null_map, const UInt8 * __restrict if_map, size_t row_begin, size_t row_end) const;
 
     static bool allocatesMemoryInArena() { return false; }
 
@@ -241,8 +245,20 @@ public:
         size_t row_end,
         Arena * arena) override;
 
-    std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) override;
-    std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) override;
+    std::optional getSmallestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
+    std::optional getGreatestIndex(const IColumn & column, size_t row_begin, size_t row_end) const override;
+    std::optional getSmallestIndexNotNullIf(
+        const IColumn & column,
+        const UInt8 * __restrict null_map,
+        const UInt8 * __restrict if_map,
+        size_t row_begin,
+        size_t row_end) const override;
+    std::optional getGreatestIndexNotNullIf(
+        const IColumn & column,
+        const UInt8 * __restrict null_map,
+        const UInt8 * __restrict if_map,
+        size_t row_begin,
+        size_t row_end) const override;
 
     static bool allocatesMemoryInArena() { return false; }
 };
diff --git a/src/Common/findExtreme.cpp b/src/Common/findExtreme.cpp
index 70d9c8293ab..2d131cfd524 100644
--- a/src/Common/findExtreme.cpp
+++ b/src/Common/findExtreme.cpp
@@ -2,6 +2,9 @@
 #include
 #include
 
+#include
+#include
+
 namespace DB
 {
 
@@ -67,15 +70,44 @@ MULTITARGET_FUNCTION_AVX2_SSE42(
             for (size_t unroll_it = 0; unroll_it < unroll_block; unroll_it++)
                 ret = ComparatorClass::cmp(ret, partial_min[unroll_it]);
         }
-    }
 
-    for (; i < count; i++)
+        for (; i < count; i++)
+        {
+            if (add_all_elements || !condition_map[i] == add_if_cond_zero)
+                ret = ComparatorClass::cmp(ret, ptr[i]);
+        }
+        return ret;
+    }
+    else
     {
-        if (add_all_elements || !condition_map[i] == add_if_cond_zero)
-            ret = ComparatorClass::cmp(ret, ptr[i]);
+        /// Only native integers
+        for (; i < count; i++)
+        {
+            constexpr bool is_min = std::same_as>;
+            if constexpr (add_all_elements)
+            {
+                ret = ComparatorClass::cmp(ret, ptr[i]);
+            }
+            else if constexpr (is_min)
+            {
+                bool keep_number = add_if_cond_zero ? !condition_map[i] : !!condition_map[i];
+                /// If keep_number = ptr[i] * 1 + 0 * max = ptr[i]
+                /// If not keep_number = ptr[i] * 0 + 1 * max = max
+                T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits::max();
+                ret = ComparatorClass::cmp(ret, final);
+            }
+            else
+            {
+                static_assert(std::same_as>);
+                bool keep_number = add_if_cond_zero ? !condition_map[i] : !!condition_map[i];
+                /// If keep_number = ptr[i] * 1 + 0 * lowest = ptr[i]
+                /// If not keep_number = ptr[i] * 0 + 1 * lowest = lowest
+                T final = ptr[i] * T{keep_number} + T{!keep_number} * std::numeric_limits::lowest();
+                ret = ComparatorClass::cmp(ret, final);
+            }
+        }
+        return ret;
     }
-
-    return ret;
 }
 ))
 
diff --git a/tests/queries/0_stateless/02406_minmax_behaviour.reference b/tests/queries/0_stateless/02406_minmax_behaviour.reference
index d52ba640a0e..91e1de441d9 100644
--- a/tests/queries/0_stateless/02406_minmax_behaviour.reference
+++ b/tests/queries/0_stateless/02406_minmax_behaviour.reference
@@ -56,6 +56,10 @@ SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1,
 22
 SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
 26
+SELECT max(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
+0
+SELECT min(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
+-126
 SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;
 10
 SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=20000;
diff --git a/tests/queries/0_stateless/02406_minmax_behaviour.sql b/tests/queries/0_stateless/02406_minmax_behaviour.sql
index a3afe7d40b0..4b5bdbe4fd1 100644
--- a/tests/queries/0_stateless/02406_minmax_behaviour.sql
+++ b/tests/queries/0_stateless/02406_minmax_behaviour.sql
@@ -48,6 +48,9 @@ SELECT maxIf(number::Nullable(String), number < 10) as number from numbers(10, 1
 SELECT min(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
 SELECT max(n::Nullable(String)) from (Select if(number < 15 and number % 2 == 1, number * 2, NULL) as n from numbers(10, 20));
 
+SELECT max(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
+SELECT min(number) from (Select if(number % 2 == 1, NULL, -number::Int8) as number FROM numbers(128));
+
 SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;
 SELECT argMax(number, now()) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=20000;
 SELECT argMax(number, 1) FROM (Select number as number from numbers(10, 10000)) settings max_threads=1, max_block_size=100;