mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Fix bug in Decimal scale (#14603)
This commit is contained in:
parent
3973a17530
commit
48f29ae11f
@ -129,7 +129,7 @@ private:
|
||||
Shift shift;
|
||||
if (decimal0 && decimal1)
|
||||
{
|
||||
auto result_type = decimalResultType(*decimal0, *decimal1, false, false);
|
||||
auto result_type = decimalResultType<false, false>(*decimal0, *decimal1);
|
||||
shift.a = static_cast<CompareInt>(result_type.scaleFactorFor(*decimal0, false).value);
|
||||
shift.b = static_cast<CompareInt>(result_type.scaleFactorFor(*decimal1, false).value);
|
||||
}
|
||||
|
@ -156,38 +156,31 @@ protected:
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename U, template <typename> typename DecimalType>
|
||||
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), DecimalType<T>>
|
||||
inline decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty, bool is_multiply, bool is_divide)
|
||||
template <bool is_multiply, bool is_division, typename T, typename U, template <typename> typename DecimalType>
|
||||
inline auto decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty)
|
||||
{
|
||||
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
|
||||
if (is_multiply)
|
||||
UInt32 scale{};
|
||||
if constexpr (is_multiply)
|
||||
scale = tx.getScale() + ty.getScale();
|
||||
else if (is_divide)
|
||||
else if constexpr (is_division)
|
||||
scale = tx.getScale();
|
||||
else
|
||||
scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
|
||||
|
||||
if constexpr (sizeof(T) < sizeof(U))
|
||||
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), scale);
|
||||
else
|
||||
return DecimalType<T>(DecimalUtils::maxPrecision<T>(), scale);
|
||||
}
|
||||
|
||||
template <typename T, typename U, template <typename> typename DecimalType>
|
||||
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DecimalType<U>>
|
||||
inline decimalResultType(const DecimalType<T> & tx, const DecimalType<U> & ty, bool is_multiply, bool is_divide)
|
||||
{
|
||||
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
|
||||
if (is_multiply)
|
||||
scale = tx.getScale() * ty.getScale();
|
||||
else if (is_divide)
|
||||
scale = tx.getScale();
|
||||
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), scale);
|
||||
}
|
||||
|
||||
template <typename T, typename U, template <typename> typename DecimalType>
|
||||
inline const DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> &, bool, bool)
|
||||
template <bool, bool, typename T, typename U, template <typename> typename DecimalType>
|
||||
inline const DecimalType<T> decimalResultType(const DecimalType<T> & tx, const DataTypeNumber<U> &)
|
||||
{
|
||||
return DecimalType<T>(DecimalUtils::maxPrecision<T>(), tx.getScale());
|
||||
}
|
||||
|
||||
template <typename T, typename U, template <typename> typename DecimalType>
|
||||
inline const DecimalType<U> decimalResultType(const DataTypeNumber<T> &, const DecimalType<U> & ty, bool, bool)
|
||||
template <bool, bool, typename T, typename U, template <typename> typename DecimalType>
|
||||
inline const DecimalType<U> decimalResultType(const DataTypeNumber<T> &, const DecimalType<U> & ty)
|
||||
{
|
||||
return DecimalType<U>(DecimalUtils::maxPrecision<U>(), ty.getScale());
|
||||
}
|
||||
|
@ -561,6 +561,9 @@ public:
|
||||
template <template <typename, typename> class Op, typename Name, bool valid_on_default_arguments = true>
|
||||
class FunctionBinaryArithmetic : public IFunction
|
||||
{
|
||||
static constexpr const bool is_multiply = IsOperation<Op>::multiply;
|
||||
static constexpr const bool is_division = IsOperation<Op>::division;
|
||||
|
||||
const Context & context;
|
||||
bool check_decimal_overflow = true;
|
||||
|
||||
@ -872,10 +875,7 @@ public:
|
||||
{
|
||||
if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>)
|
||||
{
|
||||
constexpr bool is_multiply = IsOperation<Op>::multiply;
|
||||
constexpr bool is_division = IsOperation<Op>::division;
|
||||
|
||||
ResultDataType result_type = decimalResultType(left, right, is_multiply, is_division);
|
||||
ResultDataType result_type = decimalResultType<is_multiply, is_division>(left, right);
|
||||
type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale());
|
||||
}
|
||||
else if constexpr (IsDataTypeDecimal<LeftDataType>)
|
||||
@ -995,8 +995,6 @@ public:
|
||||
if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
|
||||
{
|
||||
constexpr bool result_is_decimal = IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>;
|
||||
constexpr bool is_multiply = IsOperation<Op>::multiply;
|
||||
constexpr bool is_division = IsOperation<Op>::division;
|
||||
|
||||
using T0 = typename LeftDataType::FieldType;
|
||||
using T1 = typename RightDataType::FieldType;
|
||||
@ -1019,7 +1017,7 @@ public:
|
||||
/// the only case with a non-vector result
|
||||
if constexpr (result_is_decimal)
|
||||
{
|
||||
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
|
||||
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
|
||||
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
|
||||
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
|
||||
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
|
||||
@ -1044,7 +1042,7 @@ public:
|
||||
typename ColVecResult::MutablePtr col_res = nullptr;
|
||||
if constexpr (result_is_decimal)
|
||||
{
|
||||
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
|
||||
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
|
||||
col_res = ColVecResult::create(0, type.getScale());
|
||||
}
|
||||
else
|
||||
@ -1059,7 +1057,7 @@ public:
|
||||
{
|
||||
if constexpr (result_is_decimal)
|
||||
{
|
||||
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
|
||||
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
|
||||
|
||||
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
|
||||
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
|
||||
@ -1079,12 +1077,13 @@ public:
|
||||
{
|
||||
if constexpr (result_is_decimal)
|
||||
{
|
||||
ResultDataType type = decimalResultType(left, right, is_multiply, is_division);
|
||||
ResultDataType type = decimalResultType<is_multiply, is_division>(left, right);
|
||||
|
||||
typename ResultDataType::FieldType scale_a = type.scaleFactorFor(left, is_multiply);
|
||||
typename ResultDataType::FieldType scale_b = type.scaleFactorFor(right, is_multiply || is_division);
|
||||
if constexpr (IsDataTypeDecimal<RightDataType> && is_division)
|
||||
scale_a = right.getScaleMultiplier();
|
||||
|
||||
if (auto col_right = checkAndGetColumn<ColVecT1>(col_right_raw))
|
||||
{
|
||||
OpImpl::vectorVector(col_left->getData(), col_right->getData(), vec_res, scale_a, scale_b,
|
||||
|
@ -13,7 +13,7 @@
|
||||
12
|
||||
13 fail: join predicates
|
||||
14
|
||||
0.00000000
|
||||
0.000000
|
||||
15 fail: correlated subquery
|
||||
16
|
||||
17 fail: correlated subquery
|
||||
|
18
tests/queries/0_stateless/01474_decimal_scale_bug.reference
Normal file
18
tests/queries/0_stateless/01474_decimal_scale_bug.reference
Normal file
@ -0,0 +1,18 @@
|
||||
1.000 Decimal(9, 3)
|
||||
1.000 Decimal(9, 3)
|
||||
1.0000 Decimal(18, 4)
|
||||
1.0000 Decimal(18, 4)
|
||||
1.00000 Decimal(38, 5)
|
||||
1.00000 Decimal(38, 5)
|
||||
1.000 Decimal(18, 3)
|
||||
1.000 Decimal(18, 3)
|
||||
1.0000 Decimal(18, 4)
|
||||
1.0000 Decimal(18, 4)
|
||||
1.00000 Decimal(38, 5)
|
||||
1.00000 Decimal(38, 5)
|
||||
1.000 Decimal(38, 3)
|
||||
1.000 Decimal(38, 3)
|
||||
1.0000 Decimal(38, 4)
|
||||
1.0000 Decimal(38, 4)
|
||||
1.00000 Decimal(38, 5)
|
||||
1.00000 Decimal(38, 5)
|
20
tests/queries/0_stateless/01474_decimal_scale_bug.sql
Normal file
20
tests/queries/0_stateless/01474_decimal_scale_bug.sql
Normal file
@ -0,0 +1,20 @@
|
||||
SELECT toDecimal32(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal32(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
|
||||
SELECT toDecimal32(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal32(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal32(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal32(1, 3) * toDecimal128(1, 2) x, toTypeName(x);
|
||||
|
||||
SELECT toDecimal64(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal64(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
|
||||
SELECT toDecimal64(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal64(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal64(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal64(1, 3) * toDecimal128(1, 2) x, toTypeName(x);
|
||||
|
||||
SELECT toDecimal128(1, 2) * toDecimal32(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal128(1, 1) * toDecimal32(1, 2) x, toTypeName(x);
|
||||
SELECT toDecimal128(1, 3) * toDecimal64(1, 1) x, toTypeName(x);
|
||||
SELECT toDecimal128(1, 1) * toDecimal64(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal128(1, 2) * toDecimal128(1, 3) x, toTypeName(x);
|
||||
SELECT toDecimal128(1, 3) * toDecimal128(1, 2) x, toTypeName(x);
|
Loading…
Reference in New Issue
Block a user