Fix bug in Decimal scale (#14603)

This commit is contained in:
Artem Zuikov 2020-09-09 16:18:58 +03:00 committed by GitHub
parent 3973a17530
commit 48f29ae11f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 37 deletions

View File

@ -129,7 +129,7 @@ private:
Shift shift;
if (decimal0 && decimal1)
{
auto result_type = decimalResultType(*decimal0, *decimal1, false, false);
auto result_type = decimalResultType<false, false>(*decimal0, *decimal1);
shift.a = static_cast<CompareInt>(result_type.scaleFactorFor(*decimal0, false).value);
shift.b = static_cast<CompareInt>(result_type.scaleFactorFor(*decimal1, false).value);
}

View File

@ -156,38 +156,31 @@ protected:
};
template <typename T, typename U, template <typename> typename DecimalType>
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), DecimalType<T>>
inline decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty, bool is_multiply, bool is_divide)
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
inline auto decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
if (is_multiply)
UInt32 scale{};
if constexpr (is_multiply)
scale = tx.getScale() + ty.getScale();
else if (is_divide)
else if constexpr (is_division)
scale = tx.getScale();
return DecimalType<T>(DecimalUtils::maxPrecision<T>(), scale);
else
scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
if constexpr (sizeof(T) < sizeof(U))
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), scale);
else
return DecimalType<T>(DecimalUtils::maxPrecision<T>(), scale);
}
template <typename T, typename U, template <typename> typename DecimalType>
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DecimalType<U>>
inline decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty, bool is_multiply, bool is_divide)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
if (is_multiply)
scale = tx.getScale() * ty.getScale();
else if (is_divide)
scale = tx.getScale();
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), scale);
}
template <typename T, typename U, template <typename> typename DecimalType>
inline const DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> &, bool, bool)
template <bool, bool, typename T, typename U, template <typename> typename DecimalType>
inline const DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> &)
{
return DecimalType<T>(DecimalUtils::maxPrecision<T>(), tx.getScale());
}
template <typename T, typename U, template <typename> typename DecimalType>
inline const DecimalType<U> decimalResultType(const DataTypeNumber<T> &, const DecimalType<U> & ty, bool, bool)
template <bool, bool, typename T, typename U, template <typename> typename DecimalType>
inline const DecimalType<U> decimalResultType(const DataTypeNumber<T> &, const DecimalType<U> & ty)
{
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), ty.getScale());
}

View File

@ -561,6 +561,9 @@ public:
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
class FunctionBinaryArithmetic : public IFunction
{
static constexpr const bool is_multiply = IsOperation<Op>::multiply;
static constexpr const bool is_division = IsOperation<Op>::division;
const Context & context;
bool check_decimal_overflow = true;
@ -858,7 +861,7 @@ public:
return false;
else if constexpr (std::is_same_v<LeftDataType, RightDataType>)
{
if (left.getN() == right.getN())
if (left.getN() == right.getN())
{
type_res = std::make_shared<LeftDataType>(left.getN());
return true;
@ -872,10 +875,7 @@ public:
{
if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>)
{
constexpr bool is_multiply = IsOperation<Op>::multiply;
constexpr bool is_division = IsOperation<Op>::division;
ResultDataType result_type = decimalResultType(left, right, is_multiply, is_division);
ResultDataType result_type = decimalResultType<is_multiply, is_division>(left, right);
type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale());
}
else if constexpr (IsDataTypeDecimal<LeftDataType>)
@ -899,7 +899,7 @@ public:
type_res = std::make_shared<ResultDataType>();
return true;
}
}
}
return false;
});
if (!valid)
@ -995,8 +995,6 @@ public:
if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
{
constexpr bool result_is_decimal = IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
constexpr bool is_multiply = IsOperation<Op>::multiply;
constexpr bool is_division = IsOperation<Op>::division;
using T0 = typename LeftDataType::FieldType;
using T1 = typename RightDataType::FieldType;
@ -1019,7 +1017,7 @@ public:
/// the only case with a non-vector result
if constexpr (result_is_decimal)
{
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
@ -1044,7 +1042,7 @@ public:
typename ColVecResult::MutablePtr col_res = nullptr;
if constexpr (result_is_decimal)
{
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
col_res = ColVecResult::create(0, type.getScale());
}
else
@ -1059,7 +1057,7 @@ public:
{
if constexpr (result_is_decimal)
{
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
@ -1079,12 +1077,13 @@ public:
{
if constexpr (result_is_decimal)
{
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
scale_a = right.getScaleMultiplier();
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
{
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b,

View File

@ -13,7 +13,7 @@
12
13 fail: join predicates
14
0.00000000
0.000000
15 fail: correlated subquery
16
17 fail: correlated subquery

View File

@ -0,0 +1,18 @@
1.000 Decimal(9, 3)
1.000 Decimal(9, 3)
1.0000 Decimal(18, 4)
1.0000 Decimal(18, 4)
1.00000 Decimal(38, 5)
1.00000 Decimal(38, 5)
1.000 Decimal(18, 3)
1.000 Decimal(18, 3)
1.0000 Decimal(18, 4)
1.0000 Decimal(18, 4)
1.00000 Decimal(38, 5)
1.00000 Decimal(38, 5)
1.000 Decimal(38, 3)
1.000 Decimal(38, 3)
1.0000 Decimal(38, 4)
1.0000 Decimal(38, 4)
1.00000 Decimal(38, 5)
1.00000 Decimal(38, 5)

View File

@ -0,0 +1,20 @@
SELECT toDecimal32(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
SELECT toDecimal32(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
SELECT toDecimal32(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
SELECT toDecimal32(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
SELECT toDecimal32(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
SELECT toDecimal32(1, 3) * toDecimal128(1, 2) x, toTypeName(x);
SELECT toDecimal64(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
SELECT toDecimal64(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
SELECT toDecimal64(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
SELECT toDecimal64(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
SELECT toDecimal64(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
SELECT toDecimal64(1, 3) * toDecimal128(1, 2) x, toTypeName(x);
SELECT toDecimal128(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
SELECT toDecimal128(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
SELECT toDecimal128(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
SELECT toDecimal128(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
SELECT toDecimal128(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
SELECT toDecimal128(1, 3) * toDecimal128(1, 2) x, toTypeName(x);