diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 7240a9627ac..6756189408c 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -153,7 +153,7 @@ public: UInt32 getScale() const { return scale; } private: - UInt32 scale = DecimalField::wrongScale(); + UInt32 scale = DecimalField::wrongScale(); }; @@ -258,7 +258,7 @@ public: if constexpr (IsDecimalNumber) { UInt32 scale = data.getScale(); - if (scale == DecimalField::wrongScale()) + if (scale == DecimalField::wrongScale()) throw Exception("Extracting Decimal field with unknown scale. Scale is lost.", ErrorCodes::LOGICAL_ERROR); return DecimalField(data[n], scale); } diff --git a/dbms/src/Core/Field.cpp b/dbms/src/Core/Field.cpp index 279dd90c52b..eb6278adc95 100644 --- a/dbms/src/Core/Field.cpp +++ b/dbms/src/Core/Field.cpp @@ -273,21 +273,36 @@ namespace DB } - bool DecimalField::operator < (const DecimalField & r) const + template + static bool decEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) { - using Comparator = DecimalComparison; - return Comparator::compare(Decimal128(dec), Decimal128(r.dec), scale, r.scale); + using Comparator = DecimalComparison; + return Comparator::compare(x, y, x_scale, y_scale); } - bool DecimalField::operator <= (const DecimalField & r) const + template + static bool decLess(T x, T y, UInt32 x_scale, UInt32 y_scale) { - using Comparator = DecimalComparison; - return Comparator::compare(Decimal128(dec), Decimal128(r.dec), scale, r.scale); + using Comparator = DecimalComparison; + return Comparator::compare(x, y, x_scale, y_scale); } - bool DecimalField::operator == (const DecimalField & r) const + template + static bool decLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale) { - using Comparator = DecimalComparison; - return Comparator::compare(Decimal128(dec), Decimal128(r.dec), scale, r.scale); + using Comparator = DecimalComparison; + return Comparator::compare(x, y, x_scale, y_scale); } + + template <> bool decimalEqual(Decimal32 x, Decimal32 y, UInt32 xs, UInt32 ys) { return decEqual(x, y, xs, ys); } + template <> bool decimalLess(Decimal32 x, Decimal32 y, UInt32 xs, UInt32 ys) { return decLess(x, y, xs, ys); } + template <> bool decimalLessOrEqual(Decimal32 x, Decimal32 y, UInt32 xs, UInt32 ys) { return decLessOrEqual(x, y, xs, ys); } + + template <> bool decimalEqual(Decimal64 x, Decimal64 y, UInt32 xs, UInt32 ys) { return decEqual(x, y, xs, ys); } + template <> bool decimalLess(Decimal64 x, Decimal64 y, UInt32 xs, UInt32 ys) { return decLess(x, y, xs, ys); } + template <> bool decimalLessOrEqual(Decimal64 x, Decimal64 y, UInt32 xs, UInt32 ys) { return decLessOrEqual(x, y, xs, ys); } + + template <> bool decimalEqual(Decimal128 x, Decimal128 y, UInt32 xs, UInt32 ys) { return decEqual(x, y, xs, ys); } + template <> bool decimalLess(Decimal128 x, Decimal128 y, UInt32 xs, UInt32 ys) { return decLess(x, y, xs, ys); } + template <> bool decimalLessOrEqual(Decimal128 x, Decimal128 y, UInt32 xs, UInt32 ys) { return decLessOrEqual(x, y, xs, ys); } } diff --git a/dbms/src/Core/Field.h b/dbms/src/Core/Field.h index a77b22a3da4..31062e7d66b 100644 --- a/dbms/src/Core/Field.h +++ b/dbms/src/Core/Field.h @@ -27,33 +27,52 @@ using Array = std::vector; using TupleBackend = std::vector; STRONG_TYPEDEF(TupleBackend, Tuple) /// Array and Tuple are different types with equal representation inside Field. +template bool decimalEqual(T x, T y, UInt32 x_scale, UInt32 y_scale); +template bool decimalLess(T x, T y, UInt32 x_scale, UInt32 y_scale); +template bool decimalLessOrEqual(T x, T y, UInt32 x_scale, UInt32 y_scale); +template class DecimalField { public: static constexpr UInt32 wrongScale() { return std::numeric_limits::max(); } - DecimalField(Int128 value, UInt32 scale_ = wrongScale()) + DecimalField(T value, UInt32 scale_ = wrongScale()) : dec(value), scale(scale_) {} - operator Decimal32() const { return dec; } - operator Decimal64() const { return dec; } - operator Decimal128() const { return dec; } + operator T() const { return dec; } UInt32 getScale() const { return scale; } - bool operator < (const DecimalField & r) const; - bool operator <= (const DecimalField & r) const; - bool operator == (const DecimalField & r) const; + template + bool operator < (const DecimalField & r) const + { + using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; + return decimalLess(dec, r, scale, r.getScale()); + } - bool operator > (const DecimalField & r) const { return r < *this; } - bool operator >= (const DecimalField & r) const { return r <= * this; } - bool operator != (const DecimalField & r) const { return !(*this == r); } + template + bool operator <= (const DecimalField & r) const + { + using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; + return decimalLessOrEqual(dec, r, scale, r.getScale()); + } + + template + bool operator == (const DecimalField & r) const + { + using MaxType = std::conditional_t<(sizeof(T) > sizeof(U)), T, U>; + return decimalEqual(dec, r, scale, r.getScale()); + } + + template bool operator > (const DecimalField & r) const { return r < *this; } + template bool operator >= (const DecimalField & r) const { return r <= * this; } + template bool operator != (const DecimalField & r) const { return !(*this == r); } private: - Int128 dec; + T dec; UInt32 scale; }; @@ -91,7 +110,9 @@ public: String = 16, Array = 17, Tuple = 18, - Decimal = 19, + Decimal32 = 19, + Decimal64 = 20, + Decimal128 = 21, }; static const int MIN_NON_POD = 16; @@ -109,7 +130,9 @@ public: case String: return "String"; case Array: return "Array"; case Tuple: return "Tuple"; - case Decimal: return "Decimal"; + case Decimal32: return "Decimal32"; + case Decimal64: return "Decimal64"; + case Decimal128: return "Decimal128"; } throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -121,6 +144,7 @@ public: template struct TypeToEnum; template struct EnumToType; + static bool IsDecimal(Types::Which which) { return which >= Types::Decimal32 && which <= Types::Decimal128; } Field() : which(Types::Null) @@ -294,7 +318,9 @@ public: case Types::String: return get() < rhs.get(); case Types::Array: return get() < rhs.get(); case Types::Tuple: return get() < rhs.get(); - case Types::Decimal: return get() < rhs.get(); + case Types::Decimal32: return get>() < rhs.get>(); + case Types::Decimal64: return get>() < rhs.get>(); + case Types::Decimal128: return get>() < rhs.get>(); } throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -323,7 +349,9 @@ public: case Types::String: return get() <= rhs.get(); case Types::Array: return get() <= rhs.get(); case Types::Tuple: return get() <= rhs.get(); - case Types::Decimal: return get() <= rhs.get(); + case Types::Decimal32: return get>() <= rhs.get>(); + case Types::Decimal64: return get>() <= rhs.get>(); + case Types::Decimal128: return get>() <= rhs.get>(); } throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -350,7 +378,9 @@ public: case Types::Tuple: return get() == rhs.get(); case Types::UInt128: return get() == rhs.get(); case Types::Int128: return get() == rhs.get(); - case Types::Decimal: return get() == rhs.get(); + case Types::Decimal32: return get>() == rhs.get>(); + case Types::Decimal64: return get>() == rhs.get>(); + case Types::Decimal128: return get>() == rhs.get>(); } throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -363,7 +393,8 @@ public: private: std::aligned_union_t, DecimalField, DecimalField > storage; Types::Which which; @@ -412,7 +443,9 @@ private: case Types::String: f(field.template get()); return; case Types::Array: f(field.template get()); return; case Types::Tuple: f(field.template get()); return; - case Types::Decimal: f(field.template get()); return; + case Types::Decimal32: f(field.template get>()); return; + case Types::Decimal64: f(field.template get>()); return; + case Types::Decimal128: f(field.template get>()); return; default: throw Exception("Bad type of Field", ErrorCodes::BAD_TYPE_OF_FIELD); @@ -496,7 +529,9 @@ template <> struct Field::TypeToEnum { static const Types::Which value template <> struct Field::TypeToEnum { static const Types::Which value = Types::String; }; template <> struct Field::TypeToEnum { static const Types::Which value = Types::Array; }; template <> struct Field::TypeToEnum { static const Types::Which value = Types::Tuple; }; -template <> struct Field::TypeToEnum{ static const Types::Which value = Types::Decimal; }; +template <> struct Field::TypeToEnum>{ static const Types::Which value = Types::Decimal32; }; +template <> struct Field::TypeToEnum>{ static const Types::Which value = Types::Decimal64; }; +template <> struct Field::TypeToEnum>{ static const Types::Which value = Types::Decimal128; }; template <> struct Field::EnumToType { using Type = Null; }; template <> struct Field::EnumToType { using Type = UInt64; }; @@ -507,7 +542,9 @@ template <> struct Field::EnumToType { using Type = Float template <> struct Field::EnumToType { using Type = String; }; template <> struct Field::EnumToType { using Type = Array; }; template <> struct Field::EnumToType { using Type = Tuple; }; -template <> struct Field::EnumToType { using Type = DecimalField; }; +template <> struct Field::EnumToType { using Type = DecimalField; }; +template <> struct Field::EnumToType { using Type = DecimalField; }; +template <> struct Field::EnumToType { using Type = DecimalField; }; template @@ -551,9 +588,9 @@ template <> struct NearestFieldType { using Type = Int64; }; template <> struct NearestFieldType { using Type = Int64; }; template <> struct NearestFieldType { using Type = Int64; }; template <> struct NearestFieldType { using Type = Int128; }; -template <> struct NearestFieldType { using Type = DecimalField; }; -template <> struct NearestFieldType { using Type = DecimalField; }; -template <> struct NearestFieldType { using Type = DecimalField; }; +template <> struct NearestFieldType { using Type = DecimalField; }; +template <> struct NearestFieldType { using Type = DecimalField; }; +template <> struct NearestFieldType { using Type = DecimalField; }; template <> struct NearestFieldType { using Type = Float64; }; template <> struct NearestFieldType { using Type = Float64; }; template <> struct NearestFieldType { using Type = String; }; diff --git a/dbms/src/DataTypes/DataTypesDecimal.cpp b/dbms/src/DataTypes/DataTypesDecimal.cpp index 609a87fb552..fb64e744642 100644 --- a/dbms/src/DataTypes/DataTypesDecimal.cpp +++ b/dbms/src/DataTypes/DataTypesDecimal.cpp @@ -102,7 +102,7 @@ T DataTypeDecimal::parseFromString(const String & str) const template void DataTypeDecimal::serializeBinary(const Field & field, WriteBuffer & ostr) const { - FieldType x = get(field); + FieldType x = get>(field); writeBinary(x, ostr); }