further rearrangement

This commit is contained in:
myrrc 2020-12-22 01:37:07 +03:00
parent 881ce408bf
commit 2db721b647

View File

@ -307,8 +307,9 @@ private:
ColumnVector<ResultType>>::Container;
public:
template <OpCase op_case, bool is_decimal_a, bool is_decimal_b, bool scale_is_left_scale, class A, class B>
static void NO_INLINE process(const A & a, const B & b, ResultContainerType & c, NativeResultType scale)
template <OpCase op_case, bool is_decimal_a, bool is_decimal_b, class A, class B>
static void NO_INLINE process(const A & a, const B & b, ResultContainerType & c,
NativeResultType scale_a, NativeResultType scale_b)
{
if constexpr(op_case == OpCase::LeftConstant) static_assert(!IsDecimalNumber<A>);
if constexpr(op_case == OpCase::RightConstant) static_assert(!IsDecimalNumber<B>);
@ -317,13 +318,22 @@ public:
if constexpr (is_plus_minus_compare)
{
if (scale != 1)
if (scale_a != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<scale_is_left_scale>(
c[i] = applyScaled<true>(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i),
scale);
scale_a);
return;
}
else if (scale_b != 1)
{
for (size_t i = 0; i < size; ++i)
c[i] = applyScaled<false>(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i),
scale_b);
return;
}
}
@ -333,7 +343,7 @@ public:
c[i] = applyScaledDiv<is_decimal_a>(
unwrap<op_case, OpCase::LeftConstant>(a, i),
unwrap<op_case, OpCase::RightConstant>(b, i),
scale);
scale_a);
return;
}
@ -343,17 +353,21 @@ public:
unwrap<op_case, OpCase::RightConstant>(b, i));
}
template <bool is_decimal_a, bool is_decimal_b, bool scale_is_left_scale, class A, class B>
static ResultType process(A a, B b, NativeResultType scale)
template <bool is_decimal_a, bool is_decimal_b, class A, class B>
static ResultType process(A a, B b, NativeResultType scale_a, NativeResultType scale_b)
{
static_assert(!IsDecimalNumber<A>);
static_assert(!IsDecimalNumber<B>);
if constexpr (is_division && is_decimal_b)
return applyScaledDiv<is_decimal_a>(a, b, scale);
return applyScaledDiv<is_decimal_a>(a, b, scale_a);
else if constexpr (is_plus_minus_compare)
if (scale != 1)
return applyScaled<scale_is_left_scale>(a, b, scale);
{
if (scale_a != 1)
return applyScaled<true>(a, b, scale_a);
if (scale_b != 1)
return applyScaled<false>(a, b, scale_b);
}
return apply(a, b);
}
@ -714,12 +728,21 @@ class FunctionBinaryArithmetic : public IFunction
return col_const->template getValue<T>();
}
template <OpCase op_case, bool left_decimal, bool right_decimal, class OpImpl, class OpImplCheck>
void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b) const
{
if (check_decimal_overflow)
OpImplCheck::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b);
else
OpImpl::template process<op_case, left_decimal, right_decimal>(left, right, vec_res, scale_a, scale_b);
}
template <class LeftDataType, class RightDataType, class ResultDataType>
ColumnPtr executeNumericWithDecimal(
const auto& left, const auto& right,
const ColumnConst * const col_left_const, const ColumnConst * const col_right_const,
const auto * const col_left, const auto * const col_right,
size_t col_left_size)
size_t col_left_size) const
{
using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType;
@ -747,24 +770,25 @@ class FunctionBinaryArithmetic : public IFunction
return decimalResultType<is_multiply, is_division>(left, right);
}();
ResultType left_scale;
ResultType right_scale;
const ResultType scale_a = [&] {
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
return right.getScaleMultiplier();
else if constexpr (result_is_decimal)
return type.scaleFactorFor(left, is_multiply);
else if constexpr(left_is_decimal)
return left.getScale();
else
return 1; //won't be used, just to silence the warning
}();
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
left_scale = right.getScaleMultiplier();
else if constexpr (result_is_decimal)
left_scale = type.scaleFactorFor(left, is_multiply);
else if constexpr(left_is_decimal) // BUG precision loss
left_scale = left.getScale();
else
left_scale = 1; //won't be used, just to silence the warning
if constexpr (result_is_decimal)
right_scale = type.scaleFactorFor(right, is_multiply || is_division);
else if constexpr(right_is_decimal)
right_scale = right.getScale();
else
right_scale = 1; //same
const ResultType scale_b = [&] {
if constexpr (result_is_decimal)
return type.scaleFactorFor(right, is_multiply || is_division);
else if constexpr(right_is_decimal)
return right.getScale();
else
return 1; //same
}();
/// non-vector result
if (col_left_const && col_right_const)
@ -772,11 +796,9 @@ class FunctionBinaryArithmetic : public IFunction
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
auto res = check_decimal_overflow ?
OpImplCheck::template constantConstant<left_is_decimal, right_is_decimal>(
const_a, const_b, left_scale, right_scale) :
OpImpl::template constantConstant<left_is_decimal, right_is_decimal>(
const_a, const_b, left_scale, right_scale);
const ResultType res = check_decimal_overflow
? OpImplCheck::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b)
: OpImpl::template process<left_is_decimal, right_is_decimal>(const_a, const_b, scale_a, scale_b);
if constexpr (result_is_decimal)
return ResultDataType(type.getPrecision(), type.getScale()).createColumnConst(
@ -795,34 +817,22 @@ class FunctionBinaryArithmetic : public IFunction
if (col_left && col_right)
{
if (check_decimal_overflow)
OpImplCheck::template vectorVector<left_is_decimal, right_is_decimal>(
col_left->getData(), col_right->getData(), vec_res, left_scale, right_scale);
else
OpImpl::template vectorVector<left_is_decimal, right_is_decimal>(
col_left->getData(), col_right->getData(), vec_res, left_scale, right_scale);
helperInvokeEither<OpCase::Vector, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b);
}
else if (col_left_const && col_right)
{
const NativeResultType const_a = helperGetOrConvert<T0, ResultDataType>(col_left_const, left);
if (check_decimal_overflow)
OpImplCheck::template constantVector<left_is_decimal, right_is_decimal>(
const_a, col_right->getData(), vec_res, left_scale, right_scale);
else
OpImpl::template constantVector<left_is_decimal, right_is_decimal>(
const_a, col_right->getData(), vec_res, left_scale, right_scale);
helperInvokeEither<OpCase::LeftConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
const_a, col_right->getData(), vec_res, scale_a, scale_b);
}
else if (col_left && col_right_const)
{
const NativeResultType const_b = helperGetOrConvert<T1, ResultDataType>(col_right_const, right);
if (check_decimal_overflow)
OpImplCheck::template vectorConstant<left_is_decimal, right_is_decimal>(
col_left->getData(), const_b, vec_res, left_scale, right_scale);
else
OpImpl::template vectorConstant<left_is_decimal, right_is_decimal>(
col_left->getData(), const_b, vec_res, left_scale, right_scale);
helperInvokeEither<OpCase::RightConstant, left_is_decimal, right_is_decimal, OpImpl, OpImplCheck>(
col_left->getData(), const_b, vec_res, scale_a, scale_b);
}
else
return nullptr;
@ -980,11 +990,11 @@ public:
out_chars.resize(col_left->getN());
OpImpl::process<OpCase::Vector>(
OpImpl::template process<OpCase::Vector>(
col_left->getChars().data(),
col_right->getChars().data(),
out_chars.data(),
out_chars.size());
out_chars.size(), {});
return ColumnConst::create(std::move(col_res), col_left_raw->size());
}
@ -1013,15 +1023,15 @@ public:
if (!is_left_column_const && !is_right_column_const)
{
OpImpl::process<OpCase::Vector>(
OpImpl::template process<OpCase::Vector>(
col_left->getChars().data(),
col_right->getChars().data(),
out_chars.data(),
out_chars.size());
out_chars.size(), {});
}
else if (is_left_column_const)
{
OpImpl::process<OpCase::LeftConstant>(
OpImpl::template process<OpCase::LeftConstant>(
col_left->getChars().data(),
col_right->getChars().data(),
out_chars.data(),
@ -1030,7 +1040,7 @@ public:
}
else
{
OpImpl::process<OpCase::RightConstant>(
OpImpl::template process<OpCase::RightConstant>(
col_left->getChars().data(),
col_right->getChars().data(),
out_chars.data(),
@ -1053,6 +1063,8 @@ public:
if constexpr (std::is_same_v<ResultDataType, InvalidType>)
return nullptr;
static_assert(!std::is_same_v<ResultDataType, InvalidType>);
using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType;
using ResultType = typename ResultDataType::FieldType;
@ -1096,7 +1108,7 @@ public:
if (col_left && col_right)
{
OpImpl::process<OpCase::Vector>(
OpImpl::template process<OpCase::Vector>(
col_left->getData().data(),
col_right->getData().data(),
vec_res.data(),
@ -1104,7 +1116,7 @@ public:
}
else if (col_left_const && col_right)
{
OpImpl::process<OpCase::LeftConstant>(
OpImpl::template process<OpCase::LeftConstant>(
col_left_const->template getValue<T0>(),
col_right->getData().data(),
vec_res.data(),
@ -1112,7 +1124,7 @@ public:
}
else if (col_left && col_right_const)
{
OpImpl::process<OpCase::RightConstant>(
OpImpl::template process<OpCase::RightConstant>(
col_left->getData().data(),
col_right_const->template getValue<T1>(),
vec_res.data(),