Azat Khuzhin 2021-01-23 09:05:07 +03:00
parent 86ead0f0a9
commit df53438a66
4 changed files with 48 additions and 10 deletions

View File

@ -1,5 +1,6 @@
#pragma once
#include <experimental/type_traits>
#include <type_traits>
#include <IO/WriteHelpers.h>
@ -15,14 +16,37 @@
namespace DB
{
/// Uses addOverflow method (if available) to avoid UB for sumWithOverflow()
///
/// Since NO_SANITIZE_UNDEFINED works only for the function itself, without
/// callers, and in case of non-POD type (i.e. Decimal) you have overwritten
/// operator+=(), which will have UB.
template <typename T>
struct AggregateFunctionSumAddOverflowImpl
{
static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(T & lhs, const T & rhs)
{
lhs += rhs;
}
};
template <typename DecimalNativeType>
struct AggregateFunctionSumAddOverflowImpl<Decimal<DecimalNativeType>>
{
static void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(Decimal<DecimalNativeType> & lhs, const Decimal<DecimalNativeType> & rhs)
{
lhs.addOverflow(rhs);
}
};
template <typename T>
struct AggregateFunctionSumData
{
using Impl = AggregateFunctionSumAddOverflowImpl<T>;
T sum{};
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(T value)
{
sum += value;
Impl::add(sum, value);
}
/// Vectorized version
@ -45,22 +69,22 @@ struct AggregateFunctionSumData
while (ptr < unrolled_end)
{
for (size_t i = 0; i < unroll_count; ++i)
partial_sums[i] += ptr[i];
Impl::add(partial_sums[i], ptr[i]);
ptr += unroll_count;
}
for (size_t i = 0; i < unroll_count; ++i)
sum += partial_sums[i];
Impl::add(sum, partial_sums[i]);
}
/// clang cannot vectorize the loop if accumulator is class member instead of local variable.
T local_sum{};
while (ptr < end)
{
local_sum += *ptr;
Impl::add(local_sum, *ptr);
++ptr;
}
sum += local_sum;
Impl::add(sum, local_sum);
}
template <typename Value>
@ -78,30 +102,34 @@ struct AggregateFunctionSumData
while (ptr < unrolled_end)
{
for (size_t i = 0; i < unroll_count; ++i)
{
if (!null_map[i])
partial_sums[i] += ptr[i];
{
Impl::add(partial_sums[i], ptr[i]);
}
}
ptr += unroll_count;
null_map += unroll_count;
}
for (size_t i = 0; i < unroll_count; ++i)
sum += partial_sums[i];
Impl::add(sum, partial_sums[i]);
}
T local_sum{};
while (ptr < end)
{
if (!*null_map)
local_sum += *ptr;
Impl::add(local_sum, *ptr);
++ptr;
++null_map;
}
sum += local_sum;
Impl::add(sum, local_sum);
}
void NO_SANITIZE_UNDEFINED merge(const AggregateFunctionSumData & rhs)
{
sum += rhs.sum;
Impl::add(sum, rhs.sum);
}
void write(WriteBuffer & buf) const
@ -118,6 +146,7 @@ struct AggregateFunctionSumData
{
return sum;
}
};
template <typename T>

View File

@ -4,6 +4,7 @@
#include <string>
#include <vector>
#include <common/extended_types.h>
#include <common/defines.h>
namespace DB
@ -166,6 +167,9 @@ struct Decimal
const Decimal<T> & operator /= (const T & x) { value /= x; return *this; }
const Decimal<T> & operator %= (const T & x) { value %= x; return *this; }
/// This is to avoid UB for sumWithOverflow()
void NO_SANITIZE_UNDEFINED addOverflow(const T & x) { value += x; }
T value;
};

View File

@ -0,0 +1,3 @@
-- { echo }
SELECT sumWithOverflow(a - 65537) FROM (SELECT cast(number AS Decimal32(4)) a FROM numbers(10));
203668.4592

View File

@ -0,0 +1,2 @@
-- { echo }
SELECT sumWithOverflow(a - 65537) FROM (SELECT cast(number AS Decimal32(4)) a FROM numbers(10));