diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index ad9d00661a0..5c6f7f0fdd5 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -144,7 +144,7 @@ public: } - void insert(const T value) { data.push_back(value); } + void insertValue(const T value) { data.push_back(value); } Container & getData() { return data; } const Container & getData() const { return data; } const T & getElement(size_t n) const { return data[n]; } diff --git a/dbms/src/Common/HashTable/Hash.h b/dbms/src/Common/HashTable/Hash.h index 90ee89953c0..befb660a968 100644 --- a/dbms/src/Common/HashTable/Hash.h +++ b/dbms/src/Common/HashTable/Hash.h @@ -84,6 +84,23 @@ struct DefaultHash>> } }; +template +struct DefaultHash && sizeof(T) <= 8>> +{ + size_t operator() (T key) const + { + return DefaultHash64(key); + } +}; + +template +struct DefaultHash && sizeof(T) == 16>> +{ + size_t operator() (T key) const + { + return DefaultHash64(key >> 64) ^ DefaultHash64(key); + } +}; template struct HashCRC32; diff --git a/dbms/src/Core/TypeListNumber.h b/dbms/src/Core/TypeListNumber.h index d9e6f82a7a6..84b716fa5b8 100644 --- a/dbms/src/Core/TypeListNumber.h +++ b/dbms/src/Core/TypeListNumber.h @@ -5,6 +5,9 @@ namespace DB { -using TypeListNumbers = TypeList; +using TypeListNativeNumbers = TypeList; +using TypeListDecimalNumbers = TypeList; +using TypeListNumbers = TypeList; } diff --git a/dbms/src/DataTypes/DataTypeLowCardinality.cpp b/dbms/src/DataTypes/DataTypeLowCardinality.cpp index 362db4efa33..417c988e5b9 100644 --- a/dbms/src/DataTypes/DataTypeLowCardinality.cpp +++ b/dbms/src/DataTypes/DataTypeLowCardinality.cpp @@ -894,7 +894,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat if (isColumnedAsNumber(type)) { MutableColumnUniquePtr column; - TypeListNumbers::forEach(CreateColumnVector(column, *type, creator)); + TypeListNativeNumbers::forEach(CreateColumnVector(column, *type, creator)); if (!column) throw Exception("Unexpected numeric type: " + type->getName(), ErrorCodes::LOGICAL_ERROR); diff --git a/dbms/src/Functions/GatherUtils/Algorithms.h b/dbms/src/Functions/GatherUtils/Algorithms.h index c4b21ced4ae..fd77d52ece6 100644 --- a/dbms/src/Functions/GatherUtils/Algorithms.h +++ b/dbms/src/Functions/GatherUtils/Algorithms.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include "Sources.h" #include "Sinks.h" @@ -79,8 +80,16 @@ inline ALWAYS_INLINE void writeSlice(const NumericArraySlice & slice, Generic { for (size_t i = 0; i < slice.size; ++i) { - Field field = T(slice.data[i]); - sink.elements.insert(field); + if constexpr (IsDecimalNumber) + { + DecimalField field(T(slice.data[i]), 0); /// TODO: Decimal scale + sink.elements.insert(field); + } + else + { + Field field = T(slice.data[i]); + sink.elements.insert(field); + } } sink.current_offset += slice.size; } @@ -424,7 +433,13 @@ bool sliceHasImpl(const FirstSliceType & first, const SecondSliceType & second, template bool sliceEqualElements(const NumericArraySlice & first, const NumericArraySlice & second, size_t first_ind, size_t second_ind) { - return accurate::equalsOp(first.data[first_ind], second.data[second_ind]); + /// TODO: Decimal scale + if constexpr (IsDecimalNumber && IsDecimalNumber) + return accurate::equalsOp(typename T::NativeType(first.data[first_ind]), typename U::NativeType(second.data[second_ind])); + else if constexpr (IsDecimalNumber || IsDecimalNumber) + return false; + else + return accurate::equalsOp(first.data[first_ind], second.data[second_ind]); } template diff --git a/dbms/src/Functions/GatherUtils/Sinks.h b/dbms/src/Functions/GatherUtils/Sinks.h index c6925fab865..5fd943ae78b 100644 --- a/dbms/src/Functions/GatherUtils/Sinks.h +++ b/dbms/src/Functions/GatherUtils/Sinks.h @@ -3,6 +3,7 @@ #include "IArraySink.h" #include +#include #include #include #include @@ -33,17 +34,18 @@ struct NullableValueSource; template struct NumericArraySink : public ArraySinkImpl> { + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; using CompatibleArraySource = NumericArraySource; using CompatibleValueSource = NumericValueSource; - typename ColumnVector::Container & elements; + typename ColVecType::Container & elements; typename ColumnArray::Offsets & offsets; size_t row_num = 0; ColumnArray::Offset current_offset = 0; NumericArraySink(ColumnArray & arr, size_t column_size) - : elements(typeid_cast &>(arr.getData()).getData()), offsets(arr.getOffsets()) + : elements(typeid_cast(arr.getData()).getData()), offsets(arr.getOffsets()) { offsets.resize(column_size); } diff --git a/dbms/src/Functions/GatherUtils/Sources.h b/dbms/src/Functions/GatherUtils/Sources.h index d43dc69b2b0..c21a6fc523c 100644 --- a/dbms/src/Functions/GatherUtils/Sources.h +++ b/dbms/src/Functions/GatherUtils/Sources.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -30,17 +31,18 @@ namespace GatherUtils template struct NumericArraySource : public ArraySourceImpl> { + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; using Slice = NumericArraySlice; using Column = ColumnArray; - const typename ColumnVector::Container & elements; + const typename ColVecType::Container & elements; const typename ColumnArray::Offsets & offsets; size_t row_num = 0; ColumnArray::Offset prev_offset = 0; explicit NumericArraySource(const ColumnArray & arr) - : elements(typeid_cast &>(arr.getData()).getData()), offsets(arr.getOffsets()) + : elements(typeid_cast(arr.getData()).getData()), offsets(arr.getOffsets()) { } @@ -650,7 +652,7 @@ template struct NumericValueSource : ValueSourceImpl> { using Slice = NumericValueSlice; - using Column = ColumnVector; + using Column = std::conditional_t, ColumnDecimal, ColumnVector>; const T * begin; size_t total_rows; diff --git a/dbms/src/Functions/GatherUtils/createArraySink.cpp b/dbms/src/Functions/GatherUtils/createArraySink.cpp index 0f052856dbe..e6d80cdab9f 100644 --- a/dbms/src/Functions/GatherUtils/createArraySink.cpp +++ b/dbms/src/Functions/GatherUtils/createArraySink.cpp @@ -14,7 +14,9 @@ struct ArraySinkCreator { static std::unique_ptr create(ColumnArray & col, NullMap * null_map, size_t column_size) { - if (typeid_cast *>(&col.getData())) + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + + if (typeid_cast(&col.getData())) { if (null_map) return std::make_unique>>(col, *null_map, column_size); diff --git a/dbms/src/Functions/GatherUtils/createArraySource.cpp b/dbms/src/Functions/GatherUtils/createArraySource.cpp index 2b0df7c7b7f..b7690a3f53c 100644 --- a/dbms/src/Functions/GatherUtils/createArraySource.cpp +++ b/dbms/src/Functions/GatherUtils/createArraySource.cpp @@ -14,7 +14,9 @@ struct ArraySourceCreator { static std::unique_ptr create(const ColumnArray & col, const NullMap * null_map, bool is_const, size_t total_rows) { - if (typeid_cast *>(&col.getData())) + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + + if (typeid_cast(&col.getData())) { if (null_map) { diff --git a/dbms/src/Functions/GatherUtils/createValueSource.cpp b/dbms/src/Functions/GatherUtils/createValueSource.cpp index faf7d96c4c9..c74c41999aa 100644 --- a/dbms/src/Functions/GatherUtils/createValueSource.cpp +++ b/dbms/src/Functions/GatherUtils/createValueSource.cpp @@ -14,7 +14,9 @@ struct ValueSourceCreator { static std::unique_ptr create(const IColumn & col, const NullMap * null_map, bool is_const, size_t total_rows) { - if (auto column_vector = typeid_cast *>(&col)) + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + + if (auto column_vector = typeid_cast(&col)) { if (null_map) { diff --git a/dbms/src/Functions/GeoUtils.h b/dbms/src/Functions/GeoUtils.h index 2191290d858..b13faa0f014 100644 --- a/dbms/src/Functions/GeoUtils.h +++ b/dbms/src/Functions/GeoUtils.h @@ -590,7 +590,7 @@ struct CallPointInPolygon template static ColumnPtr call(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) { - using Impl = typename ApplyTypeListForClass<::DB::GeoUtils::CallPointInPolygon, TypeListNumbers>::Type; + using Impl = typename ApplyTypeListForClass<::DB::GeoUtils::CallPointInPolygon, TypeListNativeNumbers>::Type; if (auto column = typeid_cast *>(&x)) return Impl::template call(*column, y, impl); return CallPointInPolygon::call(x, y, impl); @@ -616,7 +616,7 @@ struct CallPointInPolygon<> template ColumnPtr pointInPolygon(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) { - using Impl = typename ApplyTypeListForClass<::DB::GeoUtils::CallPointInPolygon, TypeListNumbers>::Type; + using Impl = typename ApplyTypeListForClass<::DB::GeoUtils::CallPointInPolygon, TypeListNativeNumbers>::Type; return Impl::call(x, y, impl); } diff --git a/dbms/src/Functions/array/arrayIntersect.cpp b/dbms/src/Functions/array/arrayIntersect.cpp index 8881abb1552..7485cec7f8f 100644 --- a/dbms/src/Functions/array/arrayIntersect.cpp +++ b/dbms/src/Functions/array/arrayIntersect.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -12,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -88,6 +90,19 @@ private: template void operator()(); }; + + struct DecimalExecutor + { + const UnpackedArrays & arrays; + const DataTypePtr & data_type; + ColumnPtr & result; + + DecimalExecutor(const UnpackedArrays & arrays_, const DataTypePtr & data_type_, ColumnPtr & result_) + : arrays(arrays_), data_type(data_type_), result(result_) {} + + template + void operator()(); + }; }; @@ -328,7 +343,8 @@ void FunctionArrayIntersect::executeImpl(Block & block, const ColumnNumbers & ar ColumnPtr result_column; auto not_nullable_nested_return_type = removeNullable(nested_return_type); - TypeListNumbers::forEach(NumberExecutor(arrays, not_nullable_nested_return_type, result_column)); + TypeListNativeNumbers::forEach(NumberExecutor(arrays, not_nullable_nested_return_type, result_column)); + TypeListDecimalNumbers::forEach(DecimalExecutor(arrays, not_nullable_nested_return_type, result_column)); using DateMap = ClearableHashMap, HashTableGrower, @@ -374,6 +390,17 @@ void FunctionArrayIntersect::NumberExecutor::operator()() result = execute, true>(arrays, ColumnVector::create()); } +template +void FunctionArrayIntersect::DecimalExecutor::operator()() +{ + using Map = ClearableHashMap, HashTableGrower, + HashTableAllocatorWithStackMemory<(1ULL << INITIAL_SIZE_DEGREE) * sizeof(T)>>; + + if (!result) + if (auto * decimal = typeid_cast *>(data_type.get())) + result = execute, true>(arrays, ColumnDecimal::create(0, decimal->getScale())); +} + template ColumnPtr FunctionArrayIntersect::execute(const UnpackedArrays & arrays, MutableColumnPtr result_data_ptr) { diff --git a/dbms/src/Functions/if.cpp b/dbms/src/Functions/if.cpp index f0534a13d66..aa7f924d1f9 100644 --- a/dbms/src/Functions/if.cpp +++ b/dbms/src/Functions/if.cpp @@ -175,9 +175,7 @@ public: private: template - static constexpr bool allow_arrays = - !IsDecimalNumber && !IsDecimalNumber && - !std::is_same_v && !std::is_same_v; + static constexpr bool allow_arrays = !std::is_same_v && !std::is_same_v; template static UInt32 decimalScale(Block & block [[maybe_unused]], const ColumnNumbers & arguments [[maybe_unused]]) diff --git a/dbms/tests/queries/0_stateless/00700_decimal_gathers.reference b/dbms/tests/queries/0_stateless/00700_decimal_gathers.reference new file mode 100644 index 00000000000..bbfd7388e12 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00700_decimal_gathers.reference @@ -0,0 +1,13 @@ +[2.000] +[2.0000000000] +[2.000000000000000000] +[1.000] +[1.0000000000] +[1.000000000000000000] +- +[2.000] +[1] +[2.000000000000000000] +[1.000] +[2] +[1.000000000000000000] diff --git a/dbms/tests/queries/0_stateless/00700_decimal_gathers.sql b/dbms/tests/queries/0_stateless/00700_decimal_gathers.sql new file mode 100644 index 00000000000..98519577b62 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00700_decimal_gathers.sql @@ -0,0 +1,17 @@ +select if(1, [cast(materialize(2.0),'Decimal(9,3)')], [cast(materialize(1.0),'Decimal(9,3)')]); +select if(1, [cast(materialize(2.0),'Decimal(18,10)')], [cast(materialize(1.0),'Decimal(18,10)')]); +select if(1, [cast(materialize(2.0),'Decimal(38,18)')], [cast(materialize(1.0),'Decimal(38,18)')]); + +select if(0, [cast(materialize(2.0),'Decimal(9,3)')], [cast(materialize(1.0),'Decimal(9,3)')]); +select if(0, [cast(materialize(2.0),'Decimal(18,10)')], [cast(materialize(1.0),'Decimal(18,10)')]); +select if(0, [cast(materialize(2.0),'Decimal(38,18)')], [cast(materialize(1.0),'Decimal(38,18)')]); + +select '-'; + +select if(1, [cast(materialize(2.0),'Decimal(9,3)')], [cast(materialize(1.0),'Decimal(9,0)')]); +select if(0, [cast(materialize(2.0),'Decimal(18,10)')], [cast(materialize(1.0),'Decimal(18,0)')]); +select if(1, [cast(materialize(2.0),'Decimal(38,18)')], [cast(materialize(1.0),'Decimal(38,8)')]); + +select if(0, [cast(materialize(2.0),'Decimal(9,0)')], [cast(materialize(1.0),'Decimal(9,3)')]); +select if(1, [cast(materialize(2.0),'Decimal(18,0)')], [cast(materialize(1.0),'Decimal(18,10)')]); +select if(0, [cast(materialize(2.0),'Decimal(38,0)')], [cast(materialize(1.0),'Decimal(38,18)')]);