diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index 5781ab69c6b..58aaddf357a 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -146,9 +146,7 @@ struct AggregateFunctionSumData size_t count = end - start; const auto * end_ptr = ptr + count; - if constexpr ( - (is_integer && !is_big_int_v) - || (is_decimal && !std::is_same_v && !std::is_same_v)) + if constexpr ((is_integer || is_decimal) && !is_over_big_int) { /// For integers we can vectorize the operation if we replace the null check using a multiplication (by 0 for null, 1 for not null) /// https://quick-bench.com/q/MLTnfTvwC2qZFVeWHfOBR3U7a8I @@ -163,8 +161,39 @@ struct AggregateFunctionSumData Impl::add(sum, local_sum); return; } + else if constexpr (is_over_big_int) + { + /// Use a mask to discard or keep the value to reduce branch miss. + /// Notice that for (U)Int128 or Decimal128, MaskType is Int8 instead of Int64, otherwise extra branches will be introduced by compiler (for unknown reason) and performance will be worse. + using MaskType = std::conditional_t; + alignas(64) const MaskType masks[2] = {0, -1}; + T local_sum{}; + while (ptr < end_ptr) + { + Value v = *ptr; + if constexpr (!add_if_zero) + { + if constexpr (is_integer) + v &= masks[!!*condition_map]; + else + v.value &= masks[!!*condition_map]; + } + else + { + if constexpr (is_integer) + v &= masks[!*condition_map]; + else + v.value &= masks[!*condition_map]; + } - if constexpr (std::is_floating_point_v) + Impl::add(local_sum, v); + ++ptr; + ++condition_map; + } + Impl::add(sum, local_sum); + return; + } + else if constexpr (std::is_floating_point_v) { /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned /// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep) diff --git a/tests/performance/sum.xml b/tests/performance/sum.xml index 57b879a360d..36b898436bf 100644 --- a/tests/performance/sum.xml +++ b/tests/performance/sum.xml @@ -17,6 +17,13 @@ SELECT sumKahan(toNullable(toFloat32(number))) FROM numbers(100000000) SELECT sumKahan(toNullable(toFloat64(number))) FROM numbers(100000000) + select sumIf(number::Decimal128(3), rand32() % 2 = 0) from numbers(100000000) + select sumIf(number::Decimal256(3), rand32() % 2 = 0) from numbers(100000000) + select sumIf(number::Int128, rand32() % 2 = 0) from numbers(100000000) + select sumIf(number::UInt128, rand32() % 2 = 0) from numbers(100000000) + select sumIf(number::Int256, rand32() % 2 = 0) from numbers(100000000) + select sumIf(number::UInt256, rand32() % 2 = 0) from numbers(100000000) + CREATE TABLE nullfloat32 (x Nullable(Float32)) ENGINE = Memory INSERT INTO nullfloat32 diff --git a/tests/queries/0_stateless/02985_if_over_big_int_decimal.reference b/tests/queries/0_stateless/02985_if_over_big_int_decimal.reference new file mode 100644 index 00000000000..1dfad945ee2 --- /dev/null +++ b/tests/queries/0_stateless/02985_if_over_big_int_decimal.reference @@ -0,0 +1,12 @@ +49500 +49500 +49500 +49500 +49500 +49500 +450000 +450000 +450000 +450000 +450000 +450000 diff --git a/tests/queries/0_stateless/02985_if_over_big_int_decimal.sql b/tests/queries/0_stateless/02985_if_over_big_int_decimal.sql new file mode 100644 index 00000000000..0295a64a092 --- /dev/null +++ b/tests/queries/0_stateless/02985_if_over_big_int_decimal.sql @@ -0,0 +1,14 @@ +select sumIf(number::Int128, number % 10 == 0) from numbers(1000); +select sumIf(number::UInt128, number % 10 == 0) from numbers(1000); +select sumIf(number::Int256, number % 10 == 0) from numbers(1000); +select sumIf(number::UInt256, number % 10 == 0) from numbers(1000); +select sumIf(number::Decimal128(3), number % 10 == 0) from numbers(1000); +select sumIf(number::Decimal256(3), number % 10 == 0) from numbers(1000); + +-- Test when the condition is neither 0 nor 1 +select sumIf(number::Int128, number % 10) from numbers(1000); +select sumIf(number::UInt128, number % 10) from numbers(1000); +select sumIf(number::Int256, number % 10) from numbers(1000); +select sumIf(number::UInt256, number % 10) from numbers(1000); +select sumIf(number::Decimal128(3), number % 10) from numbers(1000); +select sumIf(number::Decimal256(3), number % 10) from numbers(1000);