function if branch free fix tests.

This commit is contained in:
zhanglistar 2023-12-19 16:40:13 +08:00
parent b252b7182c
commit 59b049ce08

View File

@ -57,10 +57,10 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
size_t a_index = 0, b_index = 0; size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
{ {
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
{ {
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]); res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += cond[i]; a_index += !!cond[i];
b_index += !cond[i]; b_index += !cond[i];
} }
else else
@ -71,10 +71,10 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{ {
size_t a_index = 0; size_t a_index = 0;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
{ {
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]); res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += cond[i]; a_index += !!cond[i];
} }
else else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]); res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
@ -83,9 +83,9 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{ {
size_t b_index = 0; size_t b_index = 0;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
{ {
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]); res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i]; b_index += !cond[i];
} }
else else
@ -94,8 +94,8 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
else else
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]); res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]); res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
} }
@ -111,10 +111,10 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
{ {
size_t a_index = 0; size_t a_index = 0;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
{ {
res[i] = cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b); res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += cond[i]; a_index += !!cond[i];
} }
else else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b); res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
@ -122,8 +122,8 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
else else
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b); res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b); res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
} }
@ -139,9 +139,9 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
{ {
size_t b_index = 0; size_t b_index = 0;
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
{ {
res[i] = cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]); res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i]; b_index += !cond[i];
} }
else else
@ -150,8 +150,8 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
else else
{ {
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
if constexpr (std::is_arithmetic_v<ResultType>) if constexpr (std::is_integral_v<ResultType>)
res[i] = cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]); res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]); res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
} }