mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-20 00:30:49 +00:00
Merge pull request #59148 from bigo-sg/improve_if_with_floating
Continue optimizing branch miss of if function when result type is float*/decimal*/int*
This commit is contained in:
commit
f67bff12b7
@ -46,12 +46,32 @@ using namespace GatherUtils;
|
||||
/** Selection function by condition: if(cond, then, else).
|
||||
* cond - UInt8
|
||||
* then, else - numeric types for which there is a general type, or dates, datetimes, or strings, or arrays of these types.
|
||||
* For better performance, try to use branch free code for numeric types(i.e. cond ? a : b --> !!cond * a + !cond * b), except floating point types because of Inf or NaN.
|
||||
* For better performance, try to use branch-free code for numeric types (i.e. cond ? a : b --> !!cond * a + !cond * b).
|
||||
*/
|
||||
|
||||
template <typename ResultType>
|
||||
concept is_native_int_or_decimal_v
|
||||
= std::is_integral_v<ResultType> || (is_decimal<ResultType> && sizeof(ResultType) <= 8);
|
||||
|
||||
/// Branch-free `vr = vc ? va : vb` for floating point result types.
///
/// TYPE - floating point result type.
/// vc   - condition value; any non-zero value selects `va`.
/// va   - "then" value, converted to TYPE before the selection.
/// vb   - "else" value, converted to TYPE before the selection.
/// vr   - lvalue that receives the selected value.
///
/// The selection is done on the bit patterns of the operands with an all-zeros / all-ones mask,
/// so the chosen operand is copied bit for bit and no floating point arithmetic is involved.
///
/// The condition is normalized with `!!` before the mask is built: `x - 1` yields a valid mask
/// only for x == 0 or x == 1, and the condition is passed in as a raw value.
/// The result is copied back with std::memcpy instead of dereferencing a reinterpret_cast'ed
/// pointer, which would read an integer object through an lvalue of floating point type.
///
/// UIntType / IntType are whatever NumberTraits::Construct yields for sizeof(TYPE) bytes
/// (presumably the unsigned / signed integer of the same width - confirm against NumberTraits).
///
/// The expansion is a complete brace-enclosed block; no trailing semicolon is required.
#define BRANCHFREE_IF_FLOAT(TYPE, vc, va, vb, vr) \
    { \
        using UIntType = typename NumberTraits::Construct<false, false, sizeof(TYPE)>::Type; \
        using IntType = typename NumberTraits::Construct<true, false, sizeof(TYPE)>::Type; \
        /* All zeros when the condition holds, all ones otherwise. */ \
        const auto mask = static_cast<UIntType>(static_cast<IntType>(!!(vc)) - 1); \
        const auto new_a = static_cast<TYPE>(va); \
        const auto new_b = static_cast<TYPE>(vb); \
        UIntType uint_a; \
        std::memcpy(&uint_a, &new_a, sizeof(UIntType)); \
        UIntType uint_b; \
        std::memcpy(&uint_b, &new_b, sizeof(UIntType)); \
        const UIntType tmp = (~mask & uint_a) | (mask & uint_b); \
        TYPE float_res; \
        std::memcpy(&float_res, &tmp, sizeof(TYPE)); \
        (vr) = float_res; \
    }
|
||||
|
||||
template <typename ArrayCond, typename ArrayA, typename ArrayB, typename ArrayResult, typename ResultType>
|
||||
inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, ArrayResult & res)
|
||||
{
|
||||
|
||||
size_t size = cond.size();
|
||||
bool a_is_short = a.size() < size;
|
||||
bool b_is_short = b.size() < size;
|
||||
@ -61,47 +81,68 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
|
||||
size_t a_index = 0, b_index = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
|
||||
a_index += !!cond[i];
|
||||
b_index += !cond[i];
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[b_index++]);
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[b_index]);
|
||||
|
||||
a_index += !!cond[i];
|
||||
b_index += !cond[i];
|
||||
}
|
||||
}
|
||||
else if (a_is_short)
|
||||
{
|
||||
size_t a_index = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
|
||||
a_index += !!cond[i];
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[i]);
|
||||
|
||||
a_index += !!cond[i];
|
||||
}
|
||||
}
|
||||
else if (b_is_short)
|
||||
{
|
||||
size_t b_index = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
|
||||
b_index += !cond[i];
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index++]);
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index]);
|
||||
|
||||
b_index += !cond[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -114,21 +155,32 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
|
||||
{
|
||||
size_t a_index = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
|
||||
a_index += !!cond[i];
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b);
|
||||
|
||||
a_index += !!cond[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -141,21 +193,68 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
|
||||
{
|
||||
size_t b_index = 0;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
|
||||
b_index += !cond[i];
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index++]);
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index]);
|
||||
|
||||
b_index += !cond[i];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
if constexpr (std::is_integral_v<ResultType>)
|
||||
{
|
||||
if constexpr (is_native_int_or_decimal_v<ResultType>)
|
||||
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
|
||||
else if constexpr (std::is_floating_point_v<ResultType>)
|
||||
{
|
||||
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
|
||||
}
|
||||
else
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ArrayCond, typename A, typename B, typename ArrayResult, typename ResultType>
|
||||
inline void fillConstantConstant(const ArrayCond & cond, A a, B b, ArrayResult & res)
|
||||
{
|
||||
size_t size = cond.size();
|
||||
|
||||
/// Int8(alias type of uint8_t) has special aliasing properties that prevents compiler from auto-vectorizing for below codes, refer to https://gist.github.com/alexei-zaripov/dcc14c78819c5f1354afe8b70932007c
|
||||
///
|
||||
/// for (size_t i = 0; i < size; ++i)
|
||||
/// res[i] = cond[i] ? static_cast<Int8>(a) : static_cast<Int8>(b);
|
||||
///
|
||||
/// Therefore, we manually optimize it by avoiding branch miss when ResultType is Int8. Other types like (U)Int128|256 or Decimal128/256 also benefit from this optimization.
|
||||
if constexpr (std::is_same_v<ResultType, Int8> || is_over_big_int<ResultType>)
|
||||
{
|
||||
alignas(64) const ResultType ab[2] = {static_cast<ResultType>(a), static_cast<ResultType>(b)};
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
res[i] = ab[!cond[i]];
|
||||
}
|
||||
}
|
||||
else if constexpr (std::is_same_v<ResultType, Decimal32> || std::is_same_v<ResultType, Decimal64>)
|
||||
{
|
||||
ResultType new_a = static_cast<ResultType>(a);
|
||||
ResultType new_b = static_cast<ResultType>(b);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
/// Reuse new_a and new_b to achieve auto-vectorization
|
||||
res[i] = cond[i] ? new_a : new_b;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
|
||||
}
|
||||
}
|
||||
|
||||
@ -201,8 +300,7 @@ struct NumIfImpl
|
||||
auto col_res = ColVecResult::create(size);
|
||||
ArrayResult & res = col_res->getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
|
||||
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
|
||||
return col_res;
|
||||
}
|
||||
};
|
||||
@ -251,8 +349,7 @@ struct NumIfImpl<Decimal<A>, Decimal<B>, Decimal<R>>
|
||||
auto col_res = ColVecResult::create(size, scale);
|
||||
ArrayResult & res = col_res->getData();
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
|
||||
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
|
||||
return col_res;
|
||||
}
|
||||
};
|
||||
@ -1144,17 +1241,12 @@ public:
|
||||
|
||||
if (cond_const_col)
|
||||
{
|
||||
if (arg_then.type->equals(*arg_else.type))
|
||||
{
|
||||
return cond_const_col->getValue<UInt8>()
|
||||
? arg_then.column
|
||||
: arg_else.column;
|
||||
}
|
||||
UInt8 value = cond_const_col->getValue<UInt8>();
|
||||
const ColumnWithTypeAndName & arg = value ? arg_then : arg_else;
|
||||
if (arg.type->equals(*result_type))
|
||||
return arg.column;
|
||||
else
|
||||
{
|
||||
materialized_cond_col = cond_const_col->convertToFullColumn();
|
||||
cond_col = typeid_cast<const ColumnUInt8 *>(&*materialized_cond_col);
|
||||
}
|
||||
return castColumn(arg, result_type);
|
||||
}
|
||||
|
||||
if (!cond_col)
|
||||
@ -1191,6 +1283,8 @@ public:
|
||||
TypeIndex left_id = left_type->getTypeId();
|
||||
TypeIndex right_id = right_type->getTypeId();
|
||||
|
||||
/// TODO optimize for map type
|
||||
/// TODO optimize for nullable type
|
||||
if (!(callOnBasicTypes<true, true, true, false>(left_id, right_id, call)
|
||||
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|
||||
|| (res = executeString(cond_col, arguments, result_type))
|
||||
|
@ -1,12 +1,24 @@
|
||||
<test>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>
|
||||
|
||||
<!-- Tests where neither branch is constant -->
|
||||
<query>with rand32() % 2 as x select if(x, materialize(1.234), materialize(2.456)) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1.234::Decimal64(3) as a, 2.456::Decimal64(3) as b select if(x, materialize(a), materialize(b)) from numbers(100000000) format Null</query>
|
||||
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>
|
||||
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
|
||||
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>
|
||||
|
||||
<!-- Tests when branches are both constant -->
|
||||
<query>with rand32() % 2 as x, 1::Int8 as a, -1::Int8 as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Int64 as a, -1::Int64 as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Int32 as a, -1::Int32 as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Decimal32(3) as a, -1::Decimal32(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Decimal64(3) as a, -1::Decimal64(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Decimal128(3) as a, -1::Decimal128(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Decimal256(3) as a, -1::Decimal256(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Int128 as a, -1::Int128 as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
<query>with rand32() % 2 as x, 1::Int256 as a, -1::Int256 as b select if(x, a, b) from numbers(100000000) format Null</query>
|
||||
</test>
|
||||
|
Loading…
Reference in New Issue
Block a user