Merge pull request #59148 from bigo-sg/improve_if_with_floating

Continue optimizing branch miss of if function when result type is float*/decimal*/int*
This commit is contained in:
Raúl Marín 2024-01-31 12:42:06 +01:00 committed by GitHub
commit f67bff12b7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 154 additions and 48 deletions

View File

@ -46,12 +46,32 @@ using namespace GatherUtils;
/** Selection function by condition: if(cond, then, else).
* cond - UInt8
* then, else - numeric types for which there is a general type, or dates, datetimes, or strings, or arrays of these types.
* For better performance, try to use branch free code for numeric types(i.e. cond ? a : b --> !!cond * a + !cond * b), except floating point types because of Inf or NaN.
* For better performance, try to use branch free code for numeric types(i.e. cond ? a : b --> !!cond * a + !cond * b)
*/
template <typename ResultType>
concept is_native_int_or_decimal_v
= std::is_integral_v<ResultType> || (is_decimal<ResultType> && sizeof(ResultType) <= 8);
// This macro performs a branch-free conditional assignment for floating point types.
// It uses bitwise operations to avoid branching, which can be beneficial for performance.
#define BRANCHFREE_IF_FLOAT(TYPE, vc, va, vb, vr) \
using UIntType = typename NumberTraits::Construct<false, false, sizeof(TYPE)>::Type; \
using IntType = typename NumberTraits::Construct<true, false, sizeof(TYPE)>::Type; \
auto mask = static_cast<UIntType>(static_cast<IntType>(vc) - 1); \
auto new_a = static_cast<ResultType>(va); \
auto new_b = static_cast<ResultType>(vb); \
UIntType uint_a; \
std::memcpy(&uint_a, &new_a, sizeof(UIntType)); \
UIntType uint_b; \
std::memcpy(&uint_b, &new_b, sizeof(UIntType)); \
UIntType tmp = (~mask & uint_a) | (mask & uint_b); \
(vr) = *(reinterpret_cast<ResultType *>(&tmp));
template <typename ArrayCond, typename ArrayA, typename ArrayB, typename ArrayResult, typename ResultType>
inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, ArrayResult & res)
{
size_t size = cond.size();
bool a_is_short = a.size() < size;
bool b_is_short = b.size() < size;
@ -61,47 +81,68 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
b_index += !cond[i];
}
}
else if (a_is_short)
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += !!cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[i]);
a_index += !!cond[i];
}
}
else if (b_is_short)
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
}
else
{
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
}
}
}
@ -114,21 +155,32 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += !!cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b);
a_index += !!cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
}
}
}
@ -141,21 +193,68 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
}
}
}
template <typename ArrayCond, typename A, typename B, typename ArrayResult, typename ResultType>
inline void fillConstantConstant(const ArrayCond & cond, A a, B b, ArrayResult & res)
{
size_t size = cond.size();
/// Int8(alias type of uint8_t) has special aliasing properties that prevents compiler from auto-vectorizing for below codes, refer to https://gist.github.com/alexei-zaripov/dcc14c78819c5f1354afe8b70932007c
///
/// for (size_t i = 0; i < size; ++i)
/// res[i] = cond[i] ? static_cast<Int8>(a) : static_cast<Int8>(b);
///
/// Therefore, we manually optimize it by avoiding branch miss when ResultType is Int8. Other types like (U)Int128|256 or Decimal128/256 also benefit from this optimization.
if constexpr (std::is_same_v<ResultType, Int8> || is_over_big_int<ResultType>)
{
alignas(64) const ResultType ab[2] = {static_cast<ResultType>(a), static_cast<ResultType>(b)};
for (size_t i = 0; i < size; ++i)
{
res[i] = ab[!cond[i]];
}
}
else if constexpr (std::is_same_v<ResultType, Decimal32> || std::is_same_v<ResultType, Decimal64>)
{
ResultType new_a = static_cast<ResultType>(a);
ResultType new_b = static_cast<ResultType>(b);
for (size_t i = 0; i < size; ++i)
{
/// Reuse new_a and new_b to achieve auto-vectorization
res[i] = cond[i] ? new_a : new_b;
}
}
else
{
for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
}
}
@ -201,8 +300,7 @@ struct NumIfImpl
auto col_res = ColVecResult::create(size);
ArrayResult & res = col_res->getData();
for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
return col_res;
}
};
@ -251,8 +349,7 @@ struct NumIfImpl<Decimal<A>, Decimal<B>, Decimal<R>>
auto col_res = ColVecResult::create(size, scale);
ArrayResult & res = col_res->getData();
for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
return col_res;
}
};
@ -1144,17 +1241,12 @@ public:
if (cond_const_col)
{
if (arg_then.type->equals(*arg_else.type))
{
return cond_const_col->getValue<UInt8>()
? arg_then.column
: arg_else.column;
}
UInt8 value = cond_const_col->getValue<UInt8>();
const ColumnWithTypeAndName & arg = value ? arg_then : arg_else;
if (arg.type->equals(*result_type))
return arg.column;
else
{
materialized_cond_col = cond_const_col->convertToFullColumn();
cond_col = typeid_cast<const ColumnUInt8 *>(&*materialized_cond_col);
}
return castColumn(arg, result_type);
}
if (!cond_col)
@ -1191,6 +1283,8 @@ public:
TypeIndex left_id = left_type->getTypeId();
TypeIndex right_id = right_type->getTypeId();
/// TODO optimize for map type
/// TODO optimize for nullable type
if (!(callOnBasicTypes<true, true, true, false>(left_id, right_id, call)
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|| (res = executeString(cond_col, arguments, result_type))

View File

@ -1,12 +1,24 @@
<test>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>
<!-- Tests when branches are both not constant -->
<query>with rand32() % 2 as x select if(x, materialize(1.234), materialize(2.456)) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1.234::Decimal64(3) as a, 2.456::Decimal64(3) as b select if(x, materialize(a), materialize(b)) from numbers(100000000) format Null</query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>
<!-- Tests when branches are both constant -->
<query>with rand32() % 2 as x, 1::Int8 as a, -1::Int8 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int64 as a, -1::Int64 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int32 as a, -1::Int32 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal32(3) as a, -1::Decimal32(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal64(3) as a, -1::Decimal64(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal128(3) as a, -1::Decimal128(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal256(3) as a, -1::Decimal256(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int128 as a, -1::Int128 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int256 as a, -1::Int256 as b select if(x, a, b) from numbers(100000000) format Null</query>
</test>