optimize for decimal

This commit is contained in:
taiyang-li 2024-01-24 16:30:37 +08:00
parent 09e24ed6c5
commit aed8ffe3d8

View File

@ -57,7 +57,7 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_integral_v<ResultType>)
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
{
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
@ -71,33 +71,39 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
{
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += !!cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
}
}
else if (b_is_short)
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
{
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index++]);
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
}
}
@ -110,21 +116,25 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
{
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += !!cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
}
}
}
@ -137,21 +147,25 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
{
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index++]);
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (std::is_integral_v<ResultType> || ((is_decimal<ResultType> && sizeof(ResultType) <= 8)))
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
}
}
}
@ -197,6 +211,7 @@ struct NumIfImpl
auto col_res = ColVecResult::create(size);
ArrayResult & res = col_res->getData();
/// TODO 这里是否可避免分支跳转
for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
return col_res;
@ -1120,6 +1135,7 @@ public:
}
else
{
/// TODO 这里不物化行不行?
materialized_cond_col = cond_const_col->convertToFullColumn();
cond_col = typeid_cast<const ColumnUInt8 *>(&*materialized_cond_col);
}
@ -1159,6 +1175,8 @@ public:
TypeIndex left_id = left_type->getTypeId();
TypeIndex right_id = right_type->getTypeId();
/// TODO map类型是否有优化空间
/// TODO 对nullable类型是否有优化空间
if (!(callOnBasicTypes<true, true, true, false>(left_id, right_id, call)
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|| (res = executeString(cond_col, arguments, result_type))