function if branch free fix tests.

This commit is contained in:
zhanglistar 2023-12-19 16:40:13 +08:00
parent b252b7182c
commit 59b049ce08

View File

@ -57,10 +57,10 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_arithmetic_v<ResultType>)
if constexpr (std::is_integral_v<ResultType>)
{
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += cond[i];
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
b_index += !cond[i];
}
else
@ -71,10 +71,10 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
if constexpr (std::is_integral_v<ResultType>)
{
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += cond[i];
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += !!cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
@ -83,9 +83,9 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
if constexpr (std::is_integral_v<ResultType>)
{
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
else
@ -94,8 +94,8 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
if constexpr (std::is_integral_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
@ -111,10 +111,10 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
if constexpr (std::is_integral_v<ResultType>)
{
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += cond[i];
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += !!cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
@ -122,8 +122,8 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
if constexpr (std::is_integral_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
}
@ -139,9 +139,9 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
if constexpr (std::is_integral_v<ResultType>)
{
res[i] = cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
else
@ -150,8 +150,8 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
if constexpr (std::is_integral_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
}