diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index a9bff29fca6..ec5c9e8bd67 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -307,8 +307,9 @@ private: ColumnVector>::Container; public: - template - static void NO_INLINE process(const A & a, const B & b, ResultContainerType & c, NativeResultType scale) + template + static void NO_INLINE process(const A & a, const B & b, ResultContainerType & c, + NativeResultType scale_a, NativeResultType scale_b) { if constexpr(op_case == OpCase::LeftConstant) static_assert(!IsDecimalNumber); if constexpr(op_case == OpCase::RightConstant) static_assert(!IsDecimalNumber); @@ -317,13 +318,22 @@ public: if constexpr (is_plus_minus_compare) { - if (scale != 1) + if (scale_a != 1) { for (size_t i = 0; i < size; ++i) - c[i] = applyScaled( + c[i] = applyScaled( unwrap(a, i), unwrap(b, i), - scale); + scale_a); + return; + } + else if (scale_b != 1) + { + for (size_t i = 0; i < size; ++i) + c[i] = applyScaled( + unwrap(a, i), + unwrap(b, i), + scale_b); return; } } @@ -333,7 +343,7 @@ public: c[i] = applyScaledDiv( unwrap(a, i), unwrap(b, i), - scale); + scale_a); return; } @@ -343,17 +353,21 @@ public: unwrap(b, i)); } - template - static ResultType process(A a, B b, NativeResultType scale) + template + static ResultType process(A a, B b, NativeResultType scale_a, NativeResultType scale_b) { static_assert(!IsDecimalNumber); static_assert(!IsDecimalNumber); if constexpr (is_division && is_decimal_b) - return applyScaledDiv(a, b, scale); + return applyScaledDiv(a, b, scale_a); else if constexpr (is_plus_minus_compare) - if (scale != 1) - return applyScaled(a, b, scale); + { + if (scale_a != 1) + return applyScaled(a, b, scale_a); + if (scale_b != 1) + return applyScaled(a, b, scale_b); + } return apply(a, b); } @@ -714,12 +728,21 @@ class FunctionBinaryArithmetic : public IFunction return col_const->template getValue(); } + template + void helperInvokeEither(const auto& left, const auto& right, auto& vec_res, auto scale_a, auto scale_b) const + { + if (check_decimal_overflow) + OpImplCheck::template process(left, right, vec_res, scale_a, scale_b); + else + OpImpl::template process(left, right, vec_res, scale_a, scale_b); + } + template ColumnPtr executeNumericWithDecimal( const auto& left, const auto& right, const ColumnConst * const col_left_const, const ColumnConst * const col_right_const, const auto * const col_left, const auto * const col_right, - size_t col_left_size) + size_t col_left_size) const { using T0 = typename LeftDataType::FieldType; using T1 = typename RightDataType::FieldType; @@ -747,24 +770,25 @@ class FunctionBinaryArithmetic : public IFunction return decimalResultType(left, right); }(); - ResultType left_scale; - ResultType right_scale; + const ResultType scale_a = [&] { + if constexpr (IsDataTypeDecimal && is_division) + return right.getScaleMultiplier(); + else if constexpr (result_is_decimal) + return type.scaleFactorFor(left, is_multiply); + else if constexpr(left_is_decimal) + return left.getScale(); + else + return 1; //won't be used, just to silence the warning + }(); - if constexpr (IsDataTypeDecimal && is_division) - left_scale = right.getScaleMultiplier(); - else if constexpr (result_is_decimal) - left_scale = type.scaleFactorFor(left, is_multiply); - else if constexpr(left_is_decimal) // BUG precision loss - left_scale = left.getScale(); - else - left_scale = 1; //won't be used, just to silence the warning - - if constexpr (result_is_decimal) - right_scale = type.scaleFactorFor(right, is_multiply || is_division); - else if constexpr(right_is_decimal) - right_scale = right.getScale(); - else - right_scale = 1; //same + const ResultType scale_b = [&] { + if constexpr (result_is_decimal) + return type.scaleFactorFor(right, is_multiply || is_division); + else if constexpr(right_is_decimal) + return right.getScale(); + else + return 1; //same + }(); /// non-vector result if (col_left_const && col_right_const) @@ -772,11 +796,9 @@ class FunctionBinaryArithmetic : public IFunction const NativeResultType const_a = helperGetOrConvert(col_left_const, left); const NativeResultType const_b = helperGetOrConvert(col_right_const, right); - auto res = check_decimal_overflow ? - OpImplCheck::template constantConstant( - const_a, const_b, left_scale, right_scale) : - OpImpl::template constantConstant( - const_a, const_b, left_scale, right_scale); + const ResultType res = check_decimal_overflow + ? OpImplCheck::template process(const_a, const_b, scale_a, scale_b) + : OpImpl::template process(const_a, const_b, scale_a, scale_b); if constexpr (result_is_decimal) return ResultDataType(type.getPrecision(), type.getScale()).createColumnConst( @@ -795,34 +817,22 @@ class FunctionBinaryArithmetic : public IFunction if (col_left && col_right) { - if (check_decimal_overflow) - OpImplCheck::template vectorVector( - col_left->getData(), col_right->getData(), vec_res, left_scale, right_scale); - else - OpImpl::template vectorVector( - col_left->getData(), col_right->getData(), vec_res, left_scale, right_scale); + helperInvokeEither( + col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b); } else if (col_left_const && col_right) { const NativeResultType const_a = helperGetOrConvert(col_left_const, left); - if (check_decimal_overflow) - OpImplCheck::template constantVector( - const_a, col_right->getData(), vec_res, left_scale, right_scale); - else - OpImpl::template constantVector( - const_a, col_right->getData(), vec_res, left_scale, right_scale); + helperInvokeEither( + const_a, col_right->getData(), vec_res, scale_a, scale_b); } else if (col_left && col_right_const) { const NativeResultType const_b = helperGetOrConvert(col_right_const, right); - if (check_decimal_overflow) - OpImplCheck::template vectorConstant( - col_left->getData(), const_b, vec_res, left_scale, right_scale); - else - OpImpl::template vectorConstant( - col_left->getData(), const_b, vec_res, left_scale, right_scale); + helperInvokeEither( + col_left->getData(), const_b, vec_res, scale_a, scale_b); } else return nullptr; @@ -980,11 +990,11 @@ public: out_chars.resize(col_left->getN()); - OpImpl::process( + OpImpl::template process( col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), - out_chars.size()); + out_chars.size(), {}); return ColumnConst::create(std::move(col_res), col_left_raw->size()); } @@ -1013,15 +1023,15 @@ public: if (!is_left_column_const && !is_right_column_const) { - OpImpl::process( + OpImpl::template process( col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), - out_chars.size()); + out_chars.size(), {}); } else if (is_left_column_const) { - OpImpl::process( + OpImpl::template process( col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), @@ -1030,7 +1040,7 @@ public: } else { - OpImpl::process( + OpImpl::template process( col_left->getChars().data(), col_right->getChars().data(), out_chars.data(), @@ -1053,6 +1063,8 @@ public: if constexpr (std::is_same_v) return nullptr; + static_assert(!std::is_same_v); + using T0 = typename LeftDataType::FieldType; using T1 = typename RightDataType::FieldType; using ResultType = typename ResultDataType::FieldType; @@ -1096,7 +1108,7 @@ public: if (col_left && col_right) { - OpImpl::process( + OpImpl::template process( col_left->getData().data(), col_right->getData().data(), vec_res.data(), @@ -1104,7 +1116,7 @@ public: } else if (col_left_const && col_right) { - OpImpl::process( + OpImpl::template process( col_left_const->template getValue(), col_right->getData().data(), vec_res.data(), @@ -1112,7 +1124,7 @@ public: } else if (col_left && col_right_const) { - OpImpl::process( + OpImpl::template process( col_left->getData().data(), col_right_const->template getValue(), vec_res.data(),