From 77a5865a22d290033aec1894b2c79e688f713238 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sat, 13 Jan 2024 22:30:30 +0100 Subject: [PATCH 01/35] Adding FP16 --- base/base/BFloat16.h | 9 ++++++++ base/base/DecomposedFloat.h | 10 +++++++++ base/base/TypeLists.h | 2 +- base/base/TypeName.h | 1 + base/base/extended_types.h | 20 ++++++++++++++--- base/base/wide_integer.h | 2 +- base/base/wide_integer_impl.h | 8 ++++++- .../AggregateFunctionGroupArray.cpp | 4 ++-- .../AggregateFunctionGroupArrayMoving.cpp | 2 +- .../AggregateFunctionIntervalLengthSum.cpp | 4 ++-- .../AggregateFunctionSparkbar.cpp | 6 ++--- src/AggregateFunctions/AggregateFunctionSum.h | 6 ++--- .../AggregateFunctionUniqCombined.h | 2 +- src/AggregateFunctions/QuantileTDigest.h | 2 +- src/AggregateFunctions/ReservoirSampler.h | 2 +- .../ReservoirSamplerDeterministic.h | 2 +- src/Columns/ColumnArray.cpp | 4 ++++ src/Columns/ColumnNullable.cpp | 2 ++ src/Columns/ColumnVector.cpp | 13 ++++++----- src/Columns/ColumnVector.h | 1 + src/Columns/ColumnsCommon.cpp | 1 + src/Columns/ColumnsNumber.h | 1 + src/Columns/MaskOperations.cpp | 2 ++ src/Columns/tests/gtest_column_vector.cpp | 1 + src/Columns/tests/gtest_low_cardinality.cpp | 1 + src/Common/FieldVisitorConvertToNumber.h | 4 ++-- src/Common/HashTable/Hash.h | 1 + src/Common/HashTable/HashTable.h | 2 +- src/Common/NaNUtils.h | 6 ++--- src/Common/findExtreme.h | 2 +- src/Common/transformEndianness.h | 2 +- src/Core/AccurateComparison.h | 18 +++++++-------- src/Core/DecimalFunctions.h | 2 +- src/Core/Field.h | 1 + src/Core/SortCursor.h | 1 + src/Core/TypeId.h | 2 ++ src/Core/Types_fwd.h | 7 +----- src/Core/callOnTypeIndex.h | 3 +++ src/DataTypes/DataTypeNumberBase.cpp | 1 + src/DataTypes/DataTypeNumberBase.h | 1 + src/DataTypes/DataTypesDecimal.h | 5 +++-- src/DataTypes/DataTypesNumber.cpp | 1 + src/DataTypes/DataTypesNumber.h | 1 + src/DataTypes/IDataType.h | 5 ++++- src/DataTypes/NumberTraits.h | 22 +++++++++---------- .../Serializations/SerializationNumber.cpp | 1 + src/DataTypes/Utils.cpp | 7 ++++++ src/DataTypes/getLeastSupertype.cpp | 6 ++++- src/DataTypes/getMostSubtype.cpp | 6 ++++- src/Formats/ProtobufSerializer.cpp | 2 +- src/Functions/DivisionUtils.h | 16 +++++++------- src/Functions/FunctionMathUnary.h | 4 ++-- src/Functions/FunctionsConversion.h | 12 +++++----- src/Functions/FunctionsRound.h | 2 +- src/Functions/factorial.cpp | 2 +- src/Functions/minus.cpp | 4 ++-- src/Functions/moduloOrZero.cpp | 2 +- src/Functions/multiply.cpp | 4 ++-- src/Functions/plus.cpp | 4 ++-- src/Functions/sign.cpp | 2 +- src/IO/ReadHelpers.h | 4 +++- src/IO/WriteHelpers.h | 22 +++++++++++-------- src/Interpreters/RowRefs.cpp | 2 +- 63 files changed, 192 insertions(+), 105 deletions(-) create mode 100644 base/base/BFloat16.h diff --git a/base/base/BFloat16.h b/base/base/BFloat16.h new file mode 100644 index 00000000000..17c3ebe9ef3 --- /dev/null +++ b/base/base/BFloat16.h @@ -0,0 +1,9 @@ +#pragma once + +using BFloat16 = __bf16; + +namespace std +{ + inline constexpr bool isfinite(BFloat16) { return true; } + inline constexpr bool signbit(BFloat16) { return false; } +} diff --git a/base/base/DecomposedFloat.h b/base/base/DecomposedFloat.h index f152637b94e..fda7ee8d3f4 100644 --- a/base/base/DecomposedFloat.h +++ b/base/base/DecomposedFloat.h @@ -10,6 +10,15 @@ template struct FloatTraits; +template <> +struct FloatTraits<__bf16> +{ + using UInt = uint16_t; + static constexpr size_t bits = 16; + static constexpr size_t exponent_bits = 8; + static constexpr size_t mantissa_bits = bits - exponent_bits - 1; +}; + template <> struct FloatTraits { @@ -217,3 +226,4 @@ struct DecomposedFloat using DecomposedFloat64 = DecomposedFloat; using DecomposedFloat32 = DecomposedFloat; +using DecomposedFloat16 = DecomposedFloat<__bf16>; diff --git a/base/base/TypeLists.h b/base/base/TypeLists.h index 6c1283d054c..ce3111b1da3 100644 --- a/base/base/TypeLists.h +++ b/base/base/TypeLists.h @@ -9,7 +9,7 @@ namespace DB { using TypeListNativeInt = TypeList; -using TypeListFloat = TypeList; +using TypeListFloat = TypeList; using TypeListNativeNumber = TypeListConcat; using TypeListWideInt = TypeList; using TypeListInt = TypeListConcat; diff --git a/base/base/TypeName.h b/base/base/TypeName.h index 9005b5a2bf4..1f4b475d653 100644 --- a/base/base/TypeName.h +++ b/base/base/TypeName.h @@ -32,6 +32,7 @@ TN_MAP(Int32) TN_MAP(Int64) TN_MAP(Int128) TN_MAP(Int256) +TN_MAP(BFloat16) TN_MAP(Float32) TN_MAP(Float64) TN_MAP(String) diff --git a/base/base/extended_types.h b/base/base/extended_types.h index b58df45a97e..39665784141 100644 --- a/base/base/extended_types.h +++ b/base/base/extended_types.h @@ -4,6 +4,8 @@ #include #include +#include + using Int128 = wide::integer<128, signed>; using UInt128 = wide::integer<128, unsigned>; @@ -24,6 +26,7 @@ struct is_signed // NOLINT(readability-identifier-naming) template <> struct is_signed { static constexpr bool value = true; }; template <> struct is_signed { static constexpr bool value = true; }; +template <> struct is_signed { static constexpr bool value = true; }; template inline constexpr bool is_signed_v = is_signed::value; @@ -47,8 +50,6 @@ template concept is_integer = || std::is_same_v || std::is_same_v; -template concept is_floating_point = std::is_floating_point_v; - template struct is_arithmetic // NOLINT(readability-identifier-naming) { @@ -59,11 +60,24 @@ template <> struct is_arithmetic { static constexpr bool value = true; } template <> struct is_arithmetic { static constexpr bool value = true; }; template <> struct is_arithmetic { static constexpr bool value = true; }; template <> struct is_arithmetic { static constexpr bool value = true; }; - +template <> struct is_arithmetic { static constexpr bool value = true; }; template inline constexpr bool is_arithmetic_v = is_arithmetic::value; + +template +struct is_floating_point // NOLINT(readability-identifier-naming) +{ + static constexpr bool value = std::is_floating_point_v; +}; + +template <> struct is_floating_point { static constexpr bool value = true; }; + +template +inline constexpr bool is_floating_point_v = is_floating_point::value; + + template struct make_unsigned // NOLINT(readability-identifier-naming) { diff --git a/base/base/wide_integer.h b/base/base/wide_integer.h index ffd30460c03..877ef5bd137 100644 --- a/base/base/wide_integer.h +++ b/base/base/wide_integer.h @@ -117,6 +117,7 @@ public: constexpr operator long double() const noexcept; constexpr operator double() const noexcept; constexpr operator float() const noexcept; + constexpr operator __bf16() const noexcept; struct _impl; @@ -262,4 +263,3 @@ struct hash>; // NOLINTEND(*) #include "wide_integer_impl.h" - diff --git a/base/base/wide_integer_impl.h b/base/base/wide_integer_impl.h index c1fd7b69b7f..7b95164e44d 100644 --- a/base/base/wide_integer_impl.h +++ b/base/base/wide_integer_impl.h @@ -154,7 +154,7 @@ struct common_type, Arithmetic> static_assert(wide::ArithmeticConcept()); using type = std::conditional_t< - std::is_floating_point_v, + is_floating_point_v, Arithmetic, std::conditional_t< sizeof(Arithmetic) * 8 < Bits, @@ -1291,6 +1291,12 @@ constexpr integer::operator float() const noexcept return static_cast(static_cast(*this)); } +template +constexpr integer::operator __bf16() const noexcept +{ + return static_cast<__bf16>(static_cast(*this)); +} + // Unary operators template constexpr integer operator~(const integer & lhs) noexcept diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index 6c6397e35d5..bcefa6b93dc 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -74,7 +74,7 @@ template struct GroupArraySamplerData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || std::is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point_v); // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; @@ -116,7 +116,7 @@ template struct GroupArrayNumericData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || std::is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point_v); // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; diff --git a/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp b/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp index 026b8d1956f..ee6a82686c5 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp @@ -38,7 +38,7 @@ template struct MovingData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || std::is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point_v); using Accumulator = T; diff --git a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp index eacd0596757..06156643aa0 100644 --- a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp +++ b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp @@ -187,7 +187,7 @@ public: static DataTypePtr createResultType() { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return std::make_shared(); return std::make_shared(); } @@ -227,7 +227,7 @@ public: void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); else assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); diff --git a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp index b6e538520a8..f4214f3a133 100644 --- a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp +++ b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp @@ -50,7 +50,7 @@ struct AggregateFunctionSparkbarData auto [it, inserted] = points.insert({x, y}); if (!inserted) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { it->getMapped() += y; return it->getMapped(); @@ -197,7 +197,7 @@ private: Y res; bool has_overfllow = false; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) res = histogram[index] + point.getMapped(); else has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res); @@ -246,7 +246,7 @@ private: } constexpr auto levels_num = static_cast(BAR_LEVELS - 1); - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { y = y / (y_max / levels_num) + 1; } diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index 5781ab69c6b..81df3244b38 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -69,7 +69,7 @@ struct AggregateFunctionSumData size_t count = end - start; const auto * end_ptr = ptr + count; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { /// Compiler cannot unroll this loop, do it manually. /// (at least for floats, most likely due to the lack of -fassociative-math) @@ -164,7 +164,7 @@ struct AggregateFunctionSumData return; } - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned /// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep) @@ -277,7 +277,7 @@ struct AggregateFunctionSumData template struct AggregateFunctionSumKahanData { - static_assert(std::is_floating_point_v, + static_assert(is_floating_point_v, "It doesn't make sense to use Kahan Summation algorithm for non floating point types"); T sum{}; diff --git a/src/AggregateFunctions/AggregateFunctionUniqCombined.h b/src/AggregateFunctions/AggregateFunctionUniqCombined.h index 10774442610..19e2665f9af 100644 --- a/src/AggregateFunctions/AggregateFunctionUniqCombined.h +++ b/src/AggregateFunctions/AggregateFunctionUniqCombined.h @@ -114,7 +114,7 @@ public: /// Initially UInt128 was introduced only for UUID, and then the other big-integer types were added. hash = static_cast(sipHash64(value)); } - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { hash = static_cast(intHash64(bit_cast(value))); } diff --git a/src/AggregateFunctions/QuantileTDigest.h b/src/AggregateFunctions/QuantileTDigest.h index 979c3f2af15..1407b73e669 100644 --- a/src/AggregateFunctions/QuantileTDigest.h +++ b/src/AggregateFunctions/QuantileTDigest.h @@ -380,7 +380,7 @@ public: ResultType getImpl(Float64 level) { if (centroids.empty()) - return std::is_floating_point_v ? std::numeric_limits::quiet_NaN() : 0; + return is_floating_point_v ? std::numeric_limits::quiet_NaN() : 0; compress(); diff --git a/src/AggregateFunctions/ReservoirSampler.h b/src/AggregateFunctions/ReservoirSampler.h index 37fc05a2e4c..242540102b8 100644 --- a/src/AggregateFunctions/ReservoirSampler.h +++ b/src/AggregateFunctions/ReservoirSampler.h @@ -278,6 +278,6 @@ private: if (OnEmpty == ReservoirSamplerOnEmpty::THROW) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSampler"); else - return NanLikeValueConstructor>::getValue(); + return NanLikeValueConstructor>::getValue(); } }; diff --git a/src/AggregateFunctions/ReservoirSamplerDeterministic.h b/src/AggregateFunctions/ReservoirSamplerDeterministic.h index daed0b98ca3..75af6638183 100644 --- a/src/AggregateFunctions/ReservoirSamplerDeterministic.h +++ b/src/AggregateFunctions/ReservoirSamplerDeterministic.h @@ -273,7 +273,7 @@ private: if (OnEmpty == ReservoirSamplerDeterministicOnEmpty::THROW) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSamplerDeterministic"); else - return NanLikeValueConstructor>::getValue(); + return NanLikeValueConstructor>::getValue(); } }; diff --git a/src/Columns/ColumnArray.cpp b/src/Columns/ColumnArray.cpp index 1cb8188bce6..4aaaf01e5ea 100644 --- a/src/Columns/ColumnArray.cpp +++ b/src/Columns/ColumnArray.cpp @@ -574,6 +574,8 @@ ColumnPtr ColumnArray::filter(const Filter & filt, ssize_t result_size_hint) con return filterNumber(filt, result_size_hint); if (typeid_cast(data.get())) return filterNumber(filt, result_size_hint); + if (typeid_cast(data.get())) + return filterNumber(filt, result_size_hint); if (typeid_cast(data.get())) return filterNumber(filt, result_size_hint); if (typeid_cast(data.get())) @@ -993,6 +995,8 @@ ColumnPtr ColumnArray::replicate(const Offsets & replicate_offsets) const return replicateNumber(replicate_offsets); if (typeid_cast(data.get())) return replicateNumber(replicate_offsets); + if (typeid_cast(data.get())) + return replicateNumber(replicate_offsets); if (typeid_cast(data.get())) return replicateNumber(replicate_offsets); if (typeid_cast(data.get())) diff --git a/src/Columns/ColumnNullable.cpp b/src/Columns/ColumnNullable.cpp index 4ee6bb3d586..3513ac06dcd 100644 --- a/src/Columns/ColumnNullable.cpp +++ b/src/Columns/ColumnNullable.cpp @@ -171,6 +171,8 @@ StringRef ColumnNullable::serializeValueIntoArena(size_t n, Arena & arena, char return static_cast(nested_column.get())->serializeValueIntoArena(n, arena, begin, &arr[n]); case TypeIndex::Int256: return static_cast(nested_column.get())->serializeValueIntoArena(n, arena, begin, &arr[n]); + case TypeIndex::BFloat16: + return static_cast(nested_column.get())->serializeValueIntoArena(n, arena, begin, &arr[n]); case TypeIndex::Float32: return static_cast(nested_column.get())->serializeValueIntoArena(n, arena, begin, &arr[n]); case TypeIndex::Float64: diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index b1cf449dfde..bad84e7147c 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -141,7 +141,7 @@ struct ColumnVector::less_stable if (unlikely(parent.data[lhs] == parent.data[rhs])) return lhs < rhs; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) { @@ -173,7 +173,7 @@ struct ColumnVector::greater_stable if (unlikely(parent.data[lhs] == parent.data[rhs])) return lhs < rhs; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) { @@ -259,7 +259,7 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point - bool use_radix_sort = (sort_is_stable && ascending && !std::is_floating_point_v) || !sort_is_stable; + bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point_v) || !sort_is_stable; /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) @@ -286,7 +286,7 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction /// Radix sort treats all NaNs to be greater than all numbers. /// If the user needs the opposite, we must move them accordingly. - if (std::is_floating_point_v && nan_direction_hint < 0) + if (is_floating_point_v && nan_direction_hint < 0) { size_t nans_to_move = 0; @@ -333,7 +333,7 @@ void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direct if constexpr (is_arithmetic_v && !is_big_int_v) { /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point - bool use_radix_sort = (sort_is_stable && ascending && !std::is_floating_point_v) || !sort_is_stable; + bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point_v) || !sort_is_stable; size_t size = end - begin; /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. @@ -356,7 +356,7 @@ void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direct /// Radix sort treats all NaNs to be greater than all numbers. /// If the user needs the opposite, we must move them accordingly. - if (std::is_floating_point_v && nan_direction_hint < 0) + if (is_floating_point_v && nan_direction_hint < 0) { size_t nans_to_move = 0; @@ -970,6 +970,7 @@ template class ColumnVector; template class ColumnVector; template class ColumnVector; template class ColumnVector; +template class ColumnVector; template class ColumnVector; template class ColumnVector; template class ColumnVector; diff --git a/src/Columns/ColumnVector.h b/src/Columns/ColumnVector.h index fab2d5f06aa..c976fac3bab 100644 --- a/src/Columns/ColumnVector.h +++ b/src/Columns/ColumnVector.h @@ -570,6 +570,7 @@ extern template class ColumnVector; extern template class ColumnVector; extern template class ColumnVector; extern template class ColumnVector; +extern template class ColumnVector; extern template class ColumnVector; extern template class ColumnVector; extern template class ColumnVector; diff --git a/src/Columns/ColumnsCommon.cpp b/src/Columns/ColumnsCommon.cpp index 4ac84e10750..444f5fae87a 100644 --- a/src/Columns/ColumnsCommon.cpp +++ b/src/Columns/ColumnsCommon.cpp @@ -328,6 +328,7 @@ INSTANTIATE(Int32) INSTANTIATE(Int64) INSTANTIATE(Int128) INSTANTIATE(Int256) +INSTANTIATE(BFloat16) INSTANTIATE(Float32) INSTANTIATE(Float64) INSTANTIATE(Decimal32) diff --git a/src/Columns/ColumnsNumber.h b/src/Columns/ColumnsNumber.h index ae7eddb0b22..2dce2269079 100644 --- a/src/Columns/ColumnsNumber.h +++ b/src/Columns/ColumnsNumber.h @@ -23,6 +23,7 @@ using ColumnInt64 = ColumnVector; using ColumnInt128 = ColumnVector; using ColumnInt256 = ColumnVector; +using ColumnBFloat16 = ColumnVector; using ColumnFloat32 = ColumnVector; using ColumnFloat64 = ColumnVector; diff --git a/src/Columns/MaskOperations.cpp b/src/Columns/MaskOperations.cpp index b84268356a7..ca4ca263811 100644 --- a/src/Columns/MaskOperations.cpp +++ b/src/Columns/MaskOperations.cpp @@ -63,6 +63,7 @@ INSTANTIATE(Int32) INSTANTIATE(Int64) INSTANTIATE(Int128) INSTANTIATE(Int256) +INSTANTIATE(BFloat16) INSTANTIATE(Float32) INSTANTIATE(Float64) INSTANTIATE(Decimal32) @@ -225,6 +226,7 @@ MaskInfo extractMaskImpl( || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info) + || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric(mask, column, null_value, null_bytemap, nulls, mask_info))) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Cannot convert column {} to mask.", column->getName()); diff --git a/src/Columns/tests/gtest_column_vector.cpp b/src/Columns/tests/gtest_column_vector.cpp index b71d4a095ab..3a084a89079 100644 --- a/src/Columns/tests/gtest_column_vector.cpp +++ b/src/Columns/tests/gtest_column_vector.cpp @@ -93,6 +93,7 @@ TEST(ColumnVector, Filter) testFilter(); testFilter(); testFilter(); + testFilter(); testFilter(); testFilter(); testFilter(); diff --git a/src/Columns/tests/gtest_low_cardinality.cpp b/src/Columns/tests/gtest_low_cardinality.cpp index 5e01279b7df..965c0d219b9 100644 --- a/src/Columns/tests/gtest_low_cardinality.cpp +++ b/src/Columns/tests/gtest_low_cardinality.cpp @@ -45,6 +45,7 @@ TEST(ColumnLowCardinality, Insert) testLowCardinalityNumberInsert(std::make_shared()); testLowCardinalityNumberInsert(std::make_shared()); + testLowCardinalityNumberInsert(std::make_shared()); testLowCardinalityNumberInsert(std::make_shared()); testLowCardinalityNumberInsert(std::make_shared()); } diff --git a/src/Common/FieldVisitorConvertToNumber.h b/src/Common/FieldVisitorConvertToNumber.h index bf8c8c8638e..38144650b97 100644 --- a/src/Common/FieldVisitorConvertToNumber.h +++ b/src/Common/FieldVisitorConvertToNumber.h @@ -58,7 +58,7 @@ public: T operator() (const Float64 & x) const { - if constexpr (!std::is_floating_point_v) + if constexpr (!is_floating_point_v) { if (!isFinite(x)) { @@ -88,7 +88,7 @@ public: template T operator() (const DecimalField & x) const { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return x.getValue().template convertTo() / x.getScaleMultiplier().template convertTo(); else return (x.getValue() / x.getScaleMultiplier()). template convertTo(); diff --git a/src/Common/HashTable/Hash.h b/src/Common/HashTable/Hash.h index fb6afcde133..b4bc6af1cef 100644 --- a/src/Common/HashTable/Hash.h +++ b/src/Common/HashTable/Hash.h @@ -322,6 +322,7 @@ DEFINE_HASH(Int32) DEFINE_HASH(Int64) DEFINE_HASH(Int128) DEFINE_HASH(Int256) +DEFINE_HASH(BFloat16) DEFINE_HASH(Float32) DEFINE_HASH(Float64) DEFINE_HASH(DB::UUID) diff --git a/src/Common/HashTable/HashTable.h b/src/Common/HashTable/HashTable.h index f23c4ca15dd..e4d5d3868c8 100644 --- a/src/Common/HashTable/HashTable.h +++ b/src/Common/HashTable/HashTable.h @@ -91,7 +91,7 @@ inline bool bitEquals(T && a, T && b) { using RealT = std::decay_t; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return 0 == memcmp(&a, &b, sizeof(RealT)); /// Note that memcmp with constant size is compiler builtin. else return a == b; diff --git a/src/Common/NaNUtils.h b/src/Common/NaNUtils.h index 1c5a619e919..6363e3e61a2 100644 --- a/src/Common/NaNUtils.h +++ b/src/Common/NaNUtils.h @@ -9,7 +9,7 @@ template inline bool isNaN(T x) { /// To be sure, that this function is zero-cost for non-floating point types. - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return std::isnan(x); else return false; @@ -19,7 +19,7 @@ inline bool isNaN(T x) template inline bool isFinite(T x) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return std::isfinite(x); else return true; @@ -29,7 +29,7 @@ inline bool isFinite(T x) template T NaNOrZero() { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return std::numeric_limits::quiet_NaN(); else return {}; diff --git a/src/Common/findExtreme.h b/src/Common/findExtreme.h index b38c24697c0..611af023d33 100644 --- a/src/Common/findExtreme.h +++ b/src/Common/findExtreme.h @@ -11,7 +11,7 @@ namespace DB { template -concept is_any_native_number = (is_any_of); +concept is_any_native_number = (is_any_of); template std::optional findExtremeMin(const T * __restrict ptr, size_t start, size_t end); diff --git a/src/Common/transformEndianness.h b/src/Common/transformEndianness.h index 1657305acda..2a0c45efe38 100644 --- a/src/Common/transformEndianness.h +++ b/src/Common/transformEndianness.h @@ -38,7 +38,7 @@ inline void transformEndianness(T & x) } template -requires std::is_floating_point_v +requires is_floating_point_v inline void transformEndianness(T & value) { if constexpr (ToEndian != FromEndian) diff --git a/src/Core/AccurateComparison.h b/src/Core/AccurateComparison.h index a201c136e3a..82d06876fe3 100644 --- a/src/Core/AccurateComparison.h +++ b/src/Core/AccurateComparison.h @@ -25,7 +25,7 @@ bool lessOp(A a, B b) return a < b; /// float vs float - if constexpr (std::is_floating_point_v && std::is_floating_point_v) + if constexpr (is_floating_point_v && is_floating_point_v) return a < b; /// anything vs NaN @@ -49,7 +49,7 @@ bool lessOp(A a, B b) } /// int vs float - if constexpr (is_integer && std::is_floating_point_v) + if constexpr (is_integer && is_floating_point_v) { if constexpr (sizeof(A) <= 4) return static_cast(a) < static_cast(b); @@ -57,7 +57,7 @@ bool lessOp(A a, B b) return DecomposedFloat(b).greater(a); } - if constexpr (std::is_floating_point_v && is_integer) + if constexpr (is_floating_point_v && is_integer) { if constexpr (sizeof(B) <= 4) return static_cast(a) < static_cast(b); @@ -65,8 +65,8 @@ bool lessOp(A a, B b) return DecomposedFloat(a).less(b); } - static_assert(is_integer || std::is_floating_point_v); - static_assert(is_integer || std::is_floating_point_v); + static_assert(is_integer || is_floating_point_v); + static_assert(is_integer || is_floating_point_v); UNREACHABLE(); } @@ -101,7 +101,7 @@ bool equalsOp(A a, B b) return a == b; /// float vs float - if constexpr (std::is_floating_point_v && std::is_floating_point_v) + if constexpr (is_floating_point_v && is_floating_point_v) return a == b; /// anything vs NaN @@ -125,7 +125,7 @@ bool equalsOp(A a, B b) } /// int vs float - if constexpr (is_integer && std::is_floating_point_v) + if constexpr (is_integer && is_floating_point_v) { if constexpr (sizeof(A) <= 4) return static_cast(a) == static_cast(b); @@ -133,7 +133,7 @@ bool equalsOp(A a, B b) return DecomposedFloat(b).equals(a); } - if constexpr (std::is_floating_point_v && is_integer) + if constexpr (is_floating_point_v && is_integer) { if constexpr (sizeof(B) <= 4) return static_cast(a) == static_cast(b); @@ -163,7 +163,7 @@ inline bool NO_SANITIZE_UNDEFINED convertNumeric(From value, To & result) return true; } - if constexpr (std::is_floating_point_v && std::is_floating_point_v) + if constexpr (is_floating_point_v && is_floating_point_v) { /// Note that NaNs doesn't compare equal to anything, but they are still in range of any Float type. if (isNaN(value)) diff --git a/src/Core/DecimalFunctions.h b/src/Core/DecimalFunctions.h index 8dad00c3a1e..c5bc4ad70f6 100644 --- a/src/Core/DecimalFunctions.h +++ b/src/Core/DecimalFunctions.h @@ -310,7 +310,7 @@ ReturnType convertToImpl(const DecimalType & decimal, UInt32 scale, To & result) using DecimalNativeType = typename DecimalType::NativeType; static constexpr bool throw_exception = std::is_void_v; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { result = static_cast(decimal.value) / static_cast(scaleMultiplier(scale)); } diff --git a/src/Core/Field.h b/src/Core/Field.h index 6afa98ed9c0..be70eb1ea07 100644 --- a/src/Core/Field.h +++ b/src/Core/Field.h @@ -251,6 +251,7 @@ template <> struct NearestFieldTypeImpl> { using Type = template <> struct NearestFieldTypeImpl> { using Type = DecimalField; }; template <> struct NearestFieldTypeImpl> { using Type = DecimalField; }; template <> struct NearestFieldTypeImpl> { using Type = DecimalField; }; +template <> struct NearestFieldTypeImpl { using Type = Float64; }; template <> struct NearestFieldTypeImpl { using Type = Float64; }; template <> struct NearestFieldTypeImpl { using Type = Float64; }; template <> struct NearestFieldTypeImpl { using Type = String; }; diff --git a/src/Core/SortCursor.h b/src/Core/SortCursor.h index 3c412fa1f17..a9dc90a8fa1 100644 --- a/src/Core/SortCursor.h +++ b/src/Core/SortCursor.h @@ -687,6 +687,7 @@ private: SortingQueueImpl>, strategy>, SortingQueueImpl>, strategy>, + SortingQueueImpl>, strategy>, SortingQueueImpl>, strategy>, SortingQueueImpl>, strategy>, diff --git a/src/Core/TypeId.h b/src/Core/TypeId.h index 9c634d2321c..73fa7da37e2 100644 --- a/src/Core/TypeId.h +++ b/src/Core/TypeId.h @@ -21,6 +21,7 @@ enum class TypeIndex Int64, Int128, Int256, + BFloat16, Float32, Float64, Date, @@ -91,6 +92,7 @@ TYPEID_MAP(Int32) TYPEID_MAP(Int64) TYPEID_MAP(Int128) TYPEID_MAP(Int256) +TYPEID_MAP(BFloat16) TYPEID_MAP(Float32) TYPEID_MAP(Float64) TYPEID_MAP(UUID) diff --git a/src/Core/Types_fwd.h b/src/Core/Types_fwd.h index a59e4b6eab8..2dffc910f9b 100644 --- a/src/Core/Types_fwd.h +++ b/src/Core/Types_fwd.h @@ -21,6 +21,7 @@ using Int128 = wide::integer<128, signed>; using UInt128 = wide::integer<128, unsigned>; using Int256 = wide::integer<256, signed>; using UInt256 = wide::integer<256, unsigned>; +using BFloat16 = __bf16; namespace DB { @@ -28,16 +29,10 @@ namespace DB using UUID = StrongTypedef; struct IPv4; - struct IPv6; struct Null; -using UInt128 = ::UInt128; -using UInt256 = ::UInt256; -using Int128 = ::Int128; -using Int256 = ::Int256; - enum class TypeIndex; /// Not a data type in database, defined just for convenience. diff --git a/src/Core/callOnTypeIndex.h b/src/Core/callOnTypeIndex.h index f5f67df563b..68aba2263c7 100644 --- a/src/Core/callOnTypeIndex.h +++ b/src/Core/callOnTypeIndex.h @@ -62,6 +62,7 @@ static bool callOnBasicType(TypeIndex number, F && f) { switch (number) { + case TypeIndex::BFloat16: return f(TypePair()); case TypeIndex::Float32: return f(TypePair()); case TypeIndex::Float64: return f(TypePair()); default: @@ -132,6 +133,7 @@ static inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F { switch (type_num1) { + case TypeIndex::BFloat16: return callOnBasicType(type_num2, std::forward(f)); case TypeIndex::Float32: return callOnBasicType(type_num2, std::forward(f)); case TypeIndex::Float64: return callOnBasicType(type_num2, std::forward(f)); default: @@ -189,6 +191,7 @@ static bool callOnIndexAndDataType(TypeIndex number, F && f, ExtraArgs && ... ar case TypeIndex::Int128: return f(TypePair, T>(), std::forward(args)...); case TypeIndex::Int256: return f(TypePair, T>(), std::forward(args)...); + case TypeIndex::BFloat16: return f(TypePair, T>(), std::forward(args)...); case TypeIndex::Float32: return f(TypePair, T>(), std::forward(args)...); case TypeIndex::Float64: return f(TypePair, T>(), std::forward(args)...); diff --git a/src/DataTypes/DataTypeNumberBase.cpp b/src/DataTypes/DataTypeNumberBase.cpp index be448fe1491..636d557f4d0 100644 --- a/src/DataTypes/DataTypeNumberBase.cpp +++ b/src/DataTypes/DataTypeNumberBase.cpp @@ -42,6 +42,7 @@ template class DataTypeNumberBase; template class DataTypeNumberBase; template class DataTypeNumberBase; template class DataTypeNumberBase; +template class DataTypeNumberBase; template class DataTypeNumberBase; template class DataTypeNumberBase; diff --git a/src/DataTypes/DataTypeNumberBase.h b/src/DataTypes/DataTypeNumberBase.h index 3a5b11c5124..11b9427a14d 100644 --- a/src/DataTypes/DataTypeNumberBase.h +++ b/src/DataTypes/DataTypeNumberBase.h @@ -68,6 +68,7 @@ extern template class DataTypeNumberBase; extern template class DataTypeNumberBase; extern template class DataTypeNumberBase; extern template class DataTypeNumberBase; +extern template class DataTypeNumberBase; extern template class DataTypeNumberBase; extern template class DataTypeNumberBase; diff --git a/src/DataTypes/DataTypesDecimal.h b/src/DataTypes/DataTypesDecimal.h index e2b433cbe2f..12d061b11e5 100644 --- a/src/DataTypes/DataTypesDecimal.h +++ b/src/DataTypes/DataTypesDecimal.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -209,9 +210,9 @@ inline ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & static constexpr bool throw_exception = std::is_same_v; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { - if (!std::isfinite(value)) + if (!isFinite(value)) { if constexpr (throw_exception) throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow. Cannot convert infinity or NaN to decimal", ToDataType::family_name); diff --git a/src/DataTypes/DataTypesNumber.cpp b/src/DataTypes/DataTypesNumber.cpp index 1c0c418411b..81c64df9711 100644 --- a/src/DataTypes/DataTypesNumber.cpp +++ b/src/DataTypes/DataTypesNumber.cpp @@ -54,6 +54,7 @@ void registerDataTypeNumbers(DataTypeFactory & factory) factory.registerDataType("Int32", createNumericDataType); factory.registerDataType("Int64", createNumericDataType); + factory.registerDataType("BFloat16", createNumericDataType); factory.registerDataType("Float32", createNumericDataType); factory.registerDataType("Float64", createNumericDataType); diff --git a/src/DataTypes/DataTypesNumber.h b/src/DataTypes/DataTypesNumber.h index 0c1f88a7925..1fe95f58e99 100644 --- a/src/DataTypes/DataTypesNumber.h +++ b/src/DataTypes/DataTypesNumber.h @@ -63,6 +63,7 @@ using DataTypeInt8 = DataTypeNumber; using DataTypeInt16 = DataTypeNumber; using DataTypeInt32 = DataTypeNumber; using DataTypeInt64 = DataTypeNumber; +using DataTypeBFloat16 = DataTypeNumber; using DataTypeFloat32 = DataTypeNumber; using DataTypeFloat64 = DataTypeNumber; diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index eabf066bc3d..ac71a61683a 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -372,9 +372,10 @@ struct WhichDataType constexpr bool isDecimal256() const { return idx == TypeIndex::Decimal256; } constexpr bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); } + constexpr bool isBFloat16() const { return idx == TypeIndex::BFloat16; } constexpr bool isFloat32() const { return idx == TypeIndex::Float32; } constexpr bool isFloat64() const { return idx == TypeIndex::Float64; } - constexpr bool isFloat() const { return isFloat32() || isFloat64(); } + constexpr bool isFloat() const { return isBFloat16() || isFloat32() || isFloat64(); } constexpr bool isNativeNumber() const { return isNativeInteger() || isFloat(); } constexpr bool isNumber() const { return isInteger() || isFloat() || isDecimal(); } @@ -558,6 +559,7 @@ template inline constexpr bool IsDataTypeEnum> = tr M(Int16) \ M(Int32) \ M(Int64) \ + M(BFloat16) \ M(Float32) \ M(Float64) @@ -574,6 +576,7 @@ template inline constexpr bool IsDataTypeEnum> = tr M(Int64) \ M(Int128) \ M(Int256) \ + M(BFloat16) \ M(Float32) \ M(Float64) } diff --git a/src/DataTypes/NumberTraits.h b/src/DataTypes/NumberTraits.h index cf283d3358c..35a6238c71a 100644 --- a/src/DataTypes/NumberTraits.h +++ b/src/DataTypes/NumberTraits.h @@ -74,7 +74,7 @@ template struct ResultOfAdditionMultiplication { using Type = typename Construct< is_signed_v || is_signed_v, - std::is_floating_point_v || std::is_floating_point_v, + is_floating_point_v || is_floating_point_v, nextSize(max(sizeof(A), sizeof(B)))>::Type; }; @@ -82,7 +82,7 @@ template struct ResultOfSubtraction { using Type = typename Construct< true, - std::is_floating_point_v || std::is_floating_point_v, + is_floating_point_v || is_floating_point_v, nextSize(max(sizeof(A), sizeof(B)))>::Type; }; @@ -113,7 +113,7 @@ template struct ResultOfModulo /// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16. static constexpr size_t size_of_result = result_is_signed ? nextSize(sizeof(B)) : sizeof(B); using Type0 = typename Construct::Type; - using Type = std::conditional_t || std::is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; }; template struct ResultOfPositiveModulo @@ -121,21 +121,21 @@ template struct ResultOfPositiveModulo /// function positive_modulo always return non-negative number. static constexpr size_t size_of_result = sizeof(B); using Type0 = typename Construct::Type; - using Type = std::conditional_t || std::is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; }; template struct ResultOfModuloLegacy { using Type0 = typename Construct || is_signed_v, false, sizeof(B)>::Type; - using Type = std::conditional_t || std::is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; }; template struct ResultOfNegate { using Type = typename Construct< true, - std::is_floating_point_v, + is_floating_point_v, is_signed_v ? sizeof(A) : nextSize(sizeof(A))>::Type; }; @@ -143,7 +143,7 @@ template struct ResultOfAbs { using Type = typename Construct< false, - std::is_floating_point_v, + is_floating_point_v, sizeof(A)>::Type; }; @@ -154,7 +154,7 @@ template struct ResultOfBit using Type = typename Construct< is_signed_v || is_signed_v, false, - std::is_floating_point_v || std::is_floating_point_v ? 8 : max(sizeof(A), sizeof(B))>::Type; + is_floating_point_v || is_floating_point_v ? 8 : max(sizeof(A), sizeof(B))>::Type; }; template struct ResultOfBitNot @@ -180,7 +180,7 @@ template struct ResultOfBitNot template struct ResultOfIf { - static constexpr bool has_float = std::is_floating_point_v || std::is_floating_point_v; + static constexpr bool has_float = is_floating_point_v || is_floating_point_v; static constexpr bool has_integer = is_integer || is_integer; static constexpr bool has_signed = is_signed_v || is_signed_v; static constexpr bool has_unsigned = !is_signed_v || !is_signed_v; @@ -189,7 +189,7 @@ struct ResultOfIf static constexpr size_t max_size_of_unsigned_integer = max(is_signed_v ? 0 : sizeof(A), is_signed_v ? 0 : sizeof(B)); static constexpr size_t max_size_of_signed_integer = max(is_signed_v ? sizeof(A) : 0, is_signed_v ? sizeof(B) : 0); static constexpr size_t max_size_of_integer = max(is_integer ? sizeof(A) : 0, is_integer ? sizeof(B) : 0); - static constexpr size_t max_size_of_float = max(std::is_floating_point_v ? sizeof(A) : 0, std::is_floating_point_v ? sizeof(B) : 0); + static constexpr size_t max_size_of_float = max(is_floating_point_v ? sizeof(A) : 0, is_floating_point_v ? sizeof(B) : 0); using ConstructedType = typename Construct= max_size_of_float) @@ -211,7 +211,7 @@ template struct ToInteger using Type = typename Construct< is_signed_v, false, - std::is_floating_point_v ? 8 : sizeof(A)>::Type; + is_floating_point_v ? 8 : sizeof(A)>::Type; }; diff --git a/src/DataTypes/Serializations/SerializationNumber.cpp b/src/DataTypes/Serializations/SerializationNumber.cpp index b6c7e4618b8..805253fccee 100644 --- a/src/DataTypes/Serializations/SerializationNumber.cpp +++ b/src/DataTypes/Serializations/SerializationNumber.cpp @@ -176,6 +176,7 @@ template class SerializationNumber; template class SerializationNumber; template class SerializationNumber; template class SerializationNumber; +template class SerializationNumber; template class SerializationNumber; template class SerializationNumber; diff --git a/src/DataTypes/Utils.cpp b/src/DataTypes/Utils.cpp index e58331a8bcb..d1e314e77dc 100644 --- a/src/DataTypes/Utils.cpp +++ b/src/DataTypes/Utils.cpp @@ -54,6 +54,13 @@ bool canBeSafelyCasted(const DataTypePtr & from_type, const DataTypePtr & to_typ return false; } + case TypeIndex::BFloat16: + { + if (to_which_type.isFloat32() || to_which_type.isFloat64() || to_which_type.isString()) + return true; + + return false; + } case TypeIndex::Float32: { if (to_which_type.isFloat64() || to_which_type.isString()) diff --git a/src/DataTypes/getLeastSupertype.cpp b/src/DataTypes/getLeastSupertype.cpp index e5bdb4b267f..0ed075563e2 100644 --- a/src/DataTypes/getLeastSupertype.cpp +++ b/src/DataTypes/getLeastSupertype.cpp @@ -108,6 +108,8 @@ DataTypePtr getNumericType(const TypeIndexSet & types) maximize(max_bits_of_signed_integer, 128); else if (type == TypeIndex::Int256) maximize(max_bits_of_signed_integer, 256); + else if (type == TypeIndex::BFloat16) + maximize(max_mantissa_bits_of_floating, 8); else if (type == TypeIndex::Float32) maximize(max_mantissa_bits_of_floating, 24); else if (type == TypeIndex::Float64) @@ -144,7 +146,9 @@ DataTypePtr getNumericType(const TypeIndexSet & types) if (max_mantissa_bits_of_floating) { size_t min_mantissa_bits = std::max(min_bit_width_of_integer, max_mantissa_bits_of_floating); - if (min_mantissa_bits <= 24) + if (min_mantissa_bits <= 8) + return std::make_shared(); + else if (min_mantissa_bits <= 24) return std::make_shared(); else if (min_mantissa_bits <= 53) return std::make_shared(); diff --git a/src/DataTypes/getMostSubtype.cpp b/src/DataTypes/getMostSubtype.cpp index 33b5735456e..d0ea716f2ff 100644 --- a/src/DataTypes/getMostSubtype.cpp +++ b/src/DataTypes/getMostSubtype.cpp @@ -297,6 +297,8 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth minimize(min_bits_of_signed_integer, 128); else if (typeid_cast(type.get())) minimize(min_bits_of_signed_integer, 256); + else if (typeid_cast(type.get())) + minimize(min_mantissa_bits_of_floating, 8); else if (typeid_cast(type.get())) minimize(min_mantissa_bits_of_floating, 24); else if (typeid_cast(type.get())) @@ -313,7 +315,9 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth /// If the result must be floating. if (!min_bits_of_signed_integer && !min_bits_of_unsigned_integer) { - if (min_mantissa_bits_of_floating <= 24) + if (min_mantissa_bits_of_floating <= 8) + return std::make_shared(); + else if (min_mantissa_bits_of_floating <= 24) return std::make_shared(); else if (min_mantissa_bits_of_floating <= 53) return std::make_shared(); diff --git a/src/Formats/ProtobufSerializer.cpp b/src/Formats/ProtobufSerializer.cpp index dd37c25719c..872991709af 100644 --- a/src/Formats/ProtobufSerializer.cpp +++ b/src/Formats/ProtobufSerializer.cpp @@ -540,7 +540,7 @@ namespace case FieldTypeId::TYPE_ENUM: { - if (std::is_floating_point_v) + if (is_floating_point_v) incompatibleColumnType(TypeName); write_function = [this](NumberType value) diff --git a/src/Functions/DivisionUtils.h b/src/Functions/DivisionUtils.h index ff07309e248..2508bd2b62b 100644 --- a/src/Functions/DivisionUtils.h +++ b/src/Functions/DivisionUtils.h @@ -47,9 +47,9 @@ inline auto checkedDivision(A a, B b) { throwIfDivisionLeadsToFPE(a, b); - if constexpr (is_big_int_v && std::is_floating_point_v) + if constexpr (is_big_int_v && is_floating_point_v) return static_cast(a) / b; - else if constexpr (is_big_int_v && std::is_floating_point_v) + else if constexpr (is_big_int_v && is_floating_point_v) return a / static_cast(b); else if constexpr (is_big_int_v && is_big_int_v) return static_cast(a / b); @@ -86,17 +86,17 @@ struct DivideIntegralImpl { /// Comparisons are not strict to avoid rounding issues when operand is implicitly casted to float. - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) if (isNaN(a) || a >= std::numeric_limits::max() || a <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) if (isNaN(b) || b >= std::numeric_limits::max() || b <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); auto res = checkedDivision(CastA(a), CastB(b)); - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) if (isNaN(res) || res >= static_cast(std::numeric_limits::max()) || res <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division, because it will produce infinite or too large number"); @@ -122,18 +122,18 @@ struct ModuloImpl template static inline Result apply(A a, B b) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { /// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance. return static_cast(a) - trunc(static_cast(a) / static_cast(b)) * static_cast(b); } else { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) if (isNaN(a) || a > std::numeric_limits::max() || a < std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) if (isNaN(b) || b > std::numeric_limits::max() || b < std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); diff --git a/src/Functions/FunctionMathUnary.h b/src/Functions/FunctionMathUnary.h index 9f400932356..8395855a564 100644 --- a/src/Functions/FunctionMathUnary.h +++ b/src/Functions/FunctionMathUnary.h @@ -66,7 +66,7 @@ private: /// Process all data as a whole and use FastOps implementation /// If the argument is integer, convert to Float64 beforehand - if constexpr (!std::is_floating_point_v) + if constexpr (!is_floating_point_v) { PODArray tmp_vec(size); for (size_t i = 0; i < size; ++i) @@ -150,7 +150,7 @@ private: { using Types = std::decay_t; using Type = typename Types::RightType; - using ReturnType = std::conditional_t, Float64, Type>; + using ReturnType = std::conditional_t, Float64, Type>; using ColVecType = ColumnVectorOrDecimal; const auto col_vec = checkAndGetColumn(col.column.get()); diff --git a/src/Functions/FunctionsConversion.h b/src/Functions/FunctionsConversion.h index eed75788fcd..fe4b14f5053 100644 --- a/src/Functions/FunctionsConversion.h +++ b/src/Functions/FunctionsConversion.h @@ -291,7 +291,7 @@ struct ConvertImpl else { /// If From Data is Nan or Inf and we convert to integer type, throw exception - if constexpr (std::is_floating_point_v && !std::is_floating_point_v) + if constexpr (is_floating_point_v && !is_floating_point_v) { if (!isFinite(vec_from[i])) { @@ -1314,7 +1314,7 @@ inline void convertFromTime(DataTypeDateTime::FieldType & x, t template void parseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { if (precise_float_parsing) readFloatTextPrecise(x, rb); @@ -1378,7 +1378,7 @@ inline void parseImpl(DataTypeIPv6::FieldType & x, ReadBuffer & rb template bool tryParseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { if (precise_float_parsing) return tryReadFloatTextPrecise(x, rb); @@ -2350,9 +2350,9 @@ private: using RightT = typename RightDataType::FieldType; static constexpr bool bad_left = - is_decimal || std::is_floating_point_v || is_big_int_v || is_signed_v; + is_decimal || is_floating_point_v || is_big_int_v || is_signed_v; static constexpr bool bad_right = - is_decimal || std::is_floating_point_v || is_big_int_v || is_signed_v; + is_decimal || is_floating_point_v || is_big_int_v || is_signed_v; /// Disallow int vs UUID conversion (but support int vs UInt128 conversion) if constexpr ((bad_left && std::is_same_v) || @@ -2678,7 +2678,7 @@ struct ToNumberMonotonicity /// Float cases. /// When converting to Float, the conversion is always monotonic. - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) return { .is_monotonic = true, .is_always_monotonic = true }; const auto * low_cardinality = typeid_cast(&type); diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index 3d1028c6d35..d775d616eb2 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -461,7 +461,7 @@ template - using FunctionRoundingImpl = std::conditional_t, + using FunctionRoundingImpl = std::conditional_t, FloatRoundingImpl, IntegerRoundingImpl>; diff --git a/src/Functions/factorial.cpp b/src/Functions/factorial.cpp index b814e8198e6..be545e398cd 100644 --- a/src/Functions/factorial.cpp +++ b/src/Functions/factorial.cpp @@ -21,7 +21,7 @@ struct FactorialImpl static inline NO_SANITIZE_UNDEFINED ResultType apply(A a) { - if constexpr (std::is_floating_point_v || is_over_big_int) + if constexpr (is_floating_point_v || is_over_big_int) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type of argument of function factorial, should not be floating point or big int"); diff --git a/src/Functions/minus.cpp b/src/Functions/minus.cpp index 04877a42b18..109e5894f5e 100644 --- a/src/Functions/minus.cpp +++ b/src/Functions/minus.cpp @@ -17,8 +17,8 @@ struct MinusImpl { if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) - static_cast(static_cast(b)); } diff --git a/src/Functions/moduloOrZero.cpp b/src/Functions/moduloOrZero.cpp index 3551ae74c5f..bfd786940ce 100644 --- a/src/Functions/moduloOrZero.cpp +++ b/src/Functions/moduloOrZero.cpp @@ -17,7 +17,7 @@ struct ModuloOrZeroImpl template static inline Result apply(A a, B b) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { /// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance. return ResultType(a) - trunc(ResultType(a) / ResultType(b)) * ResultType(b); diff --git a/src/Functions/multiply.cpp b/src/Functions/multiply.cpp index 4dc8cd10f31..ef51fe6061e 100644 --- a/src/Functions/multiply.cpp +++ b/src/Functions/multiply.cpp @@ -18,8 +18,8 @@ struct MultiplyImpl { if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) * static_cast(static_cast(b)); } diff --git a/src/Functions/plus.cpp b/src/Functions/plus.cpp index cd9cf6cec5c..ea79fb4702a 100644 --- a/src/Functions/plus.cpp +++ b/src/Functions/plus.cpp @@ -19,8 +19,8 @@ struct PlusImpl /// Next everywhere, static_cast - so that there is no wrong result in expressions of the form Int64 c = UInt32(a) * Int32(-1). if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) + static_cast(static_cast(b)); } diff --git a/src/Functions/sign.cpp b/src/Functions/sign.cpp index 6c849760eed..59a307e43bb 100644 --- a/src/Functions/sign.cpp +++ b/src/Functions/sign.cpp @@ -13,7 +13,7 @@ struct SignImpl static inline NO_SANITIZE_UNDEFINED ResultType apply(A a) { - if constexpr (is_decimal || std::is_floating_point_v) + if constexpr (is_decimal || is_floating_point_v) return a < A(0) ? -1 : a == A(0) ? 0 : 1; else if constexpr (is_signed_v) return a < 0 ? -1 : a == 0 ? 0 : 1; diff --git a/src/IO/ReadHelpers.h b/src/IO/ReadHelpers.h index 85584d63ee8..6068f49f5bf 100644 --- a/src/IO/ReadHelpers.h +++ b/src/IO/ReadHelpers.h @@ -1316,7 +1316,9 @@ inline bool tryReadText(UUID & x, ReadBuffer & buf) { return tryReadUUIDText(x, inline bool tryReadText(IPv4 & x, ReadBuffer & buf) { return tryReadIPv4Text(x, buf); } inline bool tryReadText(IPv6 & x, ReadBuffer & buf) { return tryReadIPv6Text(x, buf); } -inline void readText(is_floating_point auto & x, ReadBuffer & buf) { readFloatText(x, buf); } +template +requires is_floating_point_v +inline void readText(T & x, ReadBuffer & buf) { readFloatText(x, buf); } inline void readText(String & x, ReadBuffer & buf) { readEscapedString(x, buf); } diff --git a/src/IO/WriteHelpers.h b/src/IO/WriteHelpers.h index b4f8b476b11..c6a86b05f4d 100644 --- a/src/IO/WriteHelpers.h +++ b/src/IO/WriteHelpers.h @@ -153,6 +153,7 @@ inline void writeBoolText(bool x, WriteBuffer & buf) template +requires is_floating_point_v inline size_t writeFloatTextFastPath(T x, char * buffer) { Int64 result = 0; @@ -169,10 +170,13 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) } else { - if (DecomposedFloat32(x).isIntegerInRepresentableRange()) - result = itoa(Int32(x), buffer) - buffer; + /// This will support 16-bit floats as well. + float f32 = x; + + if (DecomposedFloat32(f32).isIntegerInRepresentableRange()) + result = itoa(Int32(f32), buffer) - buffer; else - result = jkj::dragonbox::to_chars_n(x, buffer) - buffer; + result = jkj::dragonbox::to_chars_n(f32, buffer) - buffer; } if (result <= 0) @@ -181,10 +185,9 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) } template +requires is_floating_point_v inline void writeFloatText(T x, WriteBuffer & buf) { - static_assert(std::is_same_v || std::is_same_v, "Argument for writeFloatText must be float or double"); - using Converter = DoubleConverter; if (likely(buf.available() >= Converter::MAX_REPRESENTATION_LENGTH)) { @@ -530,7 +533,7 @@ void writeJSONNumber(T x, WriteBuffer & ostr, const FormatSettings & settings) bool is_finite = isFinite(x); const bool need_quote = (is_integer && (sizeof(T) >= 8) && settings.json.quote_64bit_integers) - || (settings.json.quote_denormals && !is_finite) || (is_floating_point && (sizeof(T) >= 8) && settings.json.quote_64bit_floats); + || (settings.json.quote_denormals && !is_finite) || (is_floating_point_v && (sizeof(T) >= 8) && settings.json.quote_64bit_floats); if (need_quote) writeChar('"', ostr); @@ -541,7 +544,7 @@ void writeJSONNumber(T x, WriteBuffer & ostr, const FormatSettings & settings) writeCString("null", ostr); else { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point_v) { if (std::signbit(x)) { @@ -800,7 +803,6 @@ inline void writeXMLStringForTextElement(std::string_view s, WriteBuffer & buf) } /// @brief Serialize `uuid` into an array of characters in big-endian byte order. -/// @param uuid UUID to serialize. /// @return Array of characters in big-endian byte order. std::array formatUUID(const UUID & uuid); @@ -1065,7 +1067,9 @@ inline void writeText(is_integer auto x, WriteBuffer & buf) writeIntText(x, buf); } -inline void writeText(is_floating_point auto x, WriteBuffer & buf) { writeFloatText(x, buf); } +template +requires is_floating_point_v +inline void writeText(T x, WriteBuffer & buf) { writeFloatText(x, buf); } inline void writeText(is_enum auto x, WriteBuffer & buf) { writeText(magic_enum::enum_name(x), buf); } diff --git a/src/Interpreters/RowRefs.cpp b/src/Interpreters/RowRefs.cpp index 4335cde47f9..61caacd8346 100644 --- a/src/Interpreters/RowRefs.cpp +++ b/src/Interpreters/RowRefs.cpp @@ -181,7 +181,7 @@ private: if (!sorted.load(std::memory_order_relaxed)) { - if constexpr (std::is_arithmetic_v && !std::is_floating_point_v) + if constexpr (std::is_arithmetic_v && !is_floating_point_v) { if (likely(entries.size() > 256)) { From 2750f8ca1d9b7336fbfec6fc0e73a8fbf17eadee Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 2 Jun 2024 02:27:48 +0200 Subject: [PATCH 02/35] Whitespace --- src/Storages/StorageSet.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Storages/StorageSet.cpp b/src/Storages/StorageSet.cpp index 205a90423bf..a8c8e81e23d 100644 --- a/src/Storages/StorageSet.cpp +++ b/src/Storages/StorageSet.cpp @@ -130,7 +130,6 @@ StorageSetOrJoinBase::StorageSetOrJoinBase( storage_metadata.setComment(comment); setInMemoryMetadata(storage_metadata); - if (relative_path_.empty()) throw Exception(ErrorCodes::INCORRECT_FILE_NAME, "Join and Set storages require data path"); From 6e08f415c49afeac27ce08f97cde365dbf5940a2 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 2 Jun 2024 04:26:14 +0200 Subject: [PATCH 03/35] Preparation --- base/base/DecomposedFloat.h | 9 ++++++++ base/base/EnumReflection.h | 2 +- base/base/extended_types.h | 14 ++++-------- base/base/wide_integer_impl.h | 2 +- .../AggregateFunctionGroupArray.cpp | 4 ++-- .../AggregateFunctionGroupArrayMoving.cpp | 2 +- .../AggregateFunctionIntervalLengthSum.cpp | 4 ++-- .../AggregateFunctionSparkbar.cpp | 6 ++--- src/AggregateFunctions/AggregateFunctionSum.h | 6 ++--- src/AggregateFunctions/QuantileTDigest.h | 2 +- src/AggregateFunctions/ReservoirSampler.h | 2 +- .../ReservoirSamplerDeterministic.h | 2 +- src/Columns/ColumnVector.cpp | 16 +++++++------- src/Common/FieldVisitorConvertToNumber.h | 4 ++-- src/Common/HashTable/HashTable.h | 2 +- src/Common/NaNUtils.h | 14 ++++++------ src/Common/findExtreme.cpp | 2 +- src/Common/transformEndianness.h | 2 +- src/Core/AccurateComparison.h | 18 +++++++-------- src/Core/DecimalFunctions.h | 2 +- src/DataTypes/DataTypesDecimal.cpp | 5 +++-- src/DataTypes/NumberTraits.h | 22 +++++++++---------- src/Formats/ProtobufSerializer.cpp | 2 +- src/Functions/DivisionUtils.h | 16 +++++++------- src/Functions/FunctionMathUnary.h | 4 ++-- src/Functions/FunctionsConversion.cpp | 12 +++++----- src/Functions/FunctionsJSON.h | 4 ++-- src/Functions/FunctionsRound.h | 2 +- src/Functions/FunctionsVisitParam.h | 2 +- src/Functions/abs.cpp | 2 +- src/Functions/array/arrayAggregation.cpp | 2 +- src/Functions/factorial.cpp | 2 +- src/Functions/if.cpp | 16 +++++++------- src/Functions/minus.cpp | 4 ++-- src/Functions/moduloOrZero.cpp | 2 +- src/Functions/multiply.cpp | 4 ++-- src/Functions/plus.cpp | 4 ++-- src/Functions/sign.cpp | 2 +- src/IO/ReadHelpers.h | 2 +- src/IO/WriteHelpers.h | 10 ++++----- src/Interpreters/RowRefs.cpp | 2 +- 41 files changed, 120 insertions(+), 116 deletions(-) diff --git a/base/base/DecomposedFloat.h b/base/base/DecomposedFloat.h index 0997c39db16..b5bc3f08357 100644 --- a/base/base/DecomposedFloat.h +++ b/base/base/DecomposedFloat.h @@ -96,6 +96,15 @@ struct DecomposedFloat && ((mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0)); } + bool isFinite() const + { + return exponent() != ((1ull << Traits::exponent_bits) - 1); + } + + bool isNaN() const + { + return !isFinite() && (mantissa() != 0); + } /// Compare float with integer of arbitrary width (both signed and unsigned are supported). Assuming two's complement arithmetic. /// This function is generic, big integers (128, 256 bit) are supported as well. diff --git a/base/base/EnumReflection.h b/base/base/EnumReflection.h index 4a9de4d17a3..963c7e3f1b9 100644 --- a/base/base/EnumReflection.h +++ b/base/base/EnumReflection.h @@ -4,7 +4,7 @@ #include -template concept is_enum = std::is_enum_v; +template concept is_enum = std::is_enum_v; namespace detail { diff --git a/base/base/extended_types.h b/base/base/extended_types.h index de654152649..7ddf7de7e22 100644 --- a/base/base/extended_types.h +++ b/base/base/extended_types.h @@ -43,7 +43,7 @@ template <> struct is_unsigned { static constexpr bool value = true; }; template inline constexpr bool is_unsigned_v = is_unsigned::value; -template concept is_integer = +template concept is_integer = std::is_integral_v || std::is_same_v || std::is_same_v @@ -65,16 +65,10 @@ template <> struct is_arithmetic { static constexpr bool value = true; template inline constexpr bool is_arithmetic_v = is_arithmetic::value; -template -struct is_floating_point // NOLINT(readability-identifier-naming) -{ - static constexpr bool value = std::is_floating_point_v; -}; +template concept is_floating_point = + std::is_floating_point_v + || std::is_same_v; -template <> struct is_floating_point { static constexpr bool value = true; }; - -template -inline constexpr bool is_floating_point_v = is_floating_point::value; #define FOR_EACH_ARITHMETIC_TYPE(M) \ M(DataTypeDate) \ diff --git a/base/base/wide_integer_impl.h b/base/base/wide_integer_impl.h index c950fd27fa3..d0bbd7df9d4 100644 --- a/base/base/wide_integer_impl.h +++ b/base/base/wide_integer_impl.h @@ -154,7 +154,7 @@ struct common_type, Arithmetic> static_assert(wide::ArithmeticConcept()); using type = std::conditional_t< - is_floating_point_v, + std::is_floating_point_v || std::is_same_v, Arithmetic, std::conditional_t< sizeof(Arithmetic) * 8 < Bits, diff --git a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp index 0b478fe3c04..3a0bbb001c3 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArray.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArray.cpp @@ -73,7 +73,7 @@ template struct GroupArraySamplerData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point); // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; @@ -115,7 +115,7 @@ template struct GroupArrayNumericData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point); // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena using Allocator = MixedAlignedArenaAllocator; diff --git a/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp b/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp index ee6a82686c5..a9a09d7abd5 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupArrayMoving.cpp @@ -38,7 +38,7 @@ template struct MovingData { /// For easy serialization. - static_assert(std::has_unique_object_representations_v || is_floating_point_v); + static_assert(std::has_unique_object_representations_v || is_floating_point); using Accumulator = T; diff --git a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp index 06156643aa0..e5404add820 100644 --- a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp +++ b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.cpp @@ -187,7 +187,7 @@ public: static DataTypePtr createResultType() { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) return std::make_shared(); return std::make_shared(); } @@ -227,7 +227,7 @@ public: void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); else assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); diff --git a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp index 5b6fc3b315c..33412d50b21 100644 --- a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp +++ b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp @@ -50,7 +50,7 @@ struct AggregateFunctionSparkbarData auto [it, inserted] = points.insert({x, y}); if (!inserted) { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { it->getMapped() += y; return it->getMapped(); @@ -197,7 +197,7 @@ private: Y res; bool has_overfllow = false; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) res = histogram[index] + point.getMapped(); else has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res); @@ -246,7 +246,7 @@ private: } constexpr auto levels_num = static_cast(BAR_LEVELS - 1); - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { y = y / (y_max / levels_num) + 1; } diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index c663c632280..d0d600be70b 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -69,7 +69,7 @@ struct AggregateFunctionSumData size_t count = end - start; const auto * end_ptr = ptr + count; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { /// Compiler cannot unroll this loop, do it manually. /// (at least for floats, most likely due to the lack of -fassociative-math) @@ -193,7 +193,7 @@ struct AggregateFunctionSumData Impl::add(sum, local_sum); return; } - else if constexpr (is_floating_point_v) + else if constexpr (is_floating_point) { /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned /// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep) @@ -306,7 +306,7 @@ struct AggregateFunctionSumData template struct AggregateFunctionSumKahanData { - static_assert(is_floating_point_v, + static_assert(is_floating_point, "It doesn't make sense to use Kahan Summation algorithm for non floating point types"); T sum{}; diff --git a/src/AggregateFunctions/QuantileTDigest.h b/src/AggregateFunctions/QuantileTDigest.h index 408e500e941..a693c57e6d8 100644 --- a/src/AggregateFunctions/QuantileTDigest.h +++ b/src/AggregateFunctions/QuantileTDigest.h @@ -379,7 +379,7 @@ public: ResultType getImpl(Float64 level) { if (centroids.empty()) - return is_floating_point_v ? std::numeric_limits::quiet_NaN() : 0; + return is_floating_point ? std::numeric_limits::quiet_NaN() : 0; compress(); diff --git a/src/AggregateFunctions/ReservoirSampler.h b/src/AggregateFunctions/ReservoirSampler.h index 182a49af2ca..c21e76614c1 100644 --- a/src/AggregateFunctions/ReservoirSampler.h +++ b/src/AggregateFunctions/ReservoirSampler.h @@ -278,6 +278,6 @@ private: if (OnEmpty == ReservoirSamplerOnEmpty::THROW) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSampler"); else - return NanLikeValueConstructor>::getValue(); + return NanLikeValueConstructor>::getValue(); } }; diff --git a/src/AggregateFunctions/ReservoirSamplerDeterministic.h b/src/AggregateFunctions/ReservoirSamplerDeterministic.h index c9afcb21549..7fe5d23f4e4 100644 --- a/src/AggregateFunctions/ReservoirSamplerDeterministic.h +++ b/src/AggregateFunctions/ReservoirSamplerDeterministic.h @@ -272,7 +272,7 @@ private: if (OnEmpty == ReservoirSamplerDeterministicOnEmpty::THROW) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSamplerDeterministic"); else - return NanLikeValueConstructor>::getValue(); + return NanLikeValueConstructor>::getValue(); } }; diff --git a/src/Columns/ColumnVector.cpp b/src/Columns/ColumnVector.cpp index 19849b8a1c6..2b137231faa 100644 --- a/src/Columns/ColumnVector.cpp +++ b/src/Columns/ColumnVector.cpp @@ -118,7 +118,7 @@ struct ColumnVector::less_stable if (unlikely(parent.data[lhs] == parent.data[rhs])) return lhs < rhs; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) { @@ -150,7 +150,7 @@ struct ColumnVector::greater_stable if (unlikely(parent.data[lhs] == parent.data[rhs])) return lhs < rhs; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) { @@ -224,9 +224,9 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction iota(res.data(), data_size, IColumn::Permutation::value_type(0)); - if constexpr (has_find_extreme_implementation && !std::is_floating_point_v) + if constexpr (has_find_extreme_implementation && !is_floating_point) { - /// Disabled for:floating point + /// Disabled for floating point: /// * floating point: We don't deal with nan_direction_hint /// * stability::Stable: We might return any value, not the first if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable)) @@ -256,7 +256,7 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point - bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point_v) || !sort_is_stable; + bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point) || !sort_is_stable; /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. if (data_size >= 256 && data_size <= std::numeric_limits::max() && use_radix_sort) @@ -283,7 +283,7 @@ void ColumnVector::getPermutation(IColumn::PermutationSortDirection direction /// Radix sort treats all NaNs to be greater than all numbers. /// If the user needs the opposite, we must move them accordingly. - if (is_floating_point_v && nan_direction_hint < 0) + if (is_floating_point && nan_direction_hint < 0) { size_t nans_to_move = 0; @@ -330,7 +330,7 @@ void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direct if constexpr (is_arithmetic_v && !is_big_int_v) { /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point - bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point_v) || !sort_is_stable; + bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point) || !sort_is_stable; size_t size = end - begin; /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. @@ -353,7 +353,7 @@ void ColumnVector::updatePermutation(IColumn::PermutationSortDirection direct /// Radix sort treats all NaNs to be greater than all numbers. /// If the user needs the opposite, we must move them accordingly. - if (is_floating_point_v && nan_direction_hint < 0) + if (is_floating_point && nan_direction_hint < 0) { size_t nans_to_move = 0; diff --git a/src/Common/FieldVisitorConvertToNumber.h b/src/Common/FieldVisitorConvertToNumber.h index 646caadce35..ebd084df54d 100644 --- a/src/Common/FieldVisitorConvertToNumber.h +++ b/src/Common/FieldVisitorConvertToNumber.h @@ -58,7 +58,7 @@ public: T operator() (const Float64 & x) const { - if constexpr (!is_floating_point_v) + if constexpr (!is_floating_point) { if (!isFinite(x)) { @@ -88,7 +88,7 @@ public: template T operator() (const DecimalField & x) const { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) return x.getValue().template convertTo() / x.getScaleMultiplier().template convertTo(); else return (x.getValue() / x.getScaleMultiplier()).template convertTo(); diff --git a/src/Common/HashTable/HashTable.h b/src/Common/HashTable/HashTable.h index fd8832a56a3..8237c81461f 100644 --- a/src/Common/HashTable/HashTable.h +++ b/src/Common/HashTable/HashTable.h @@ -91,7 +91,7 @@ inline bool bitEquals(T && a, T && b) { using RealT = std::decay_t; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) /// Note that memcmp with constant size is compiler builtin. return 0 == memcmp(&a, &b, sizeof(RealT)); /// NOLINT else diff --git a/src/Common/NaNUtils.h b/src/Common/NaNUtils.h index 0e885541599..3e4af902104 100644 --- a/src/Common/NaNUtils.h +++ b/src/Common/NaNUtils.h @@ -3,24 +3,24 @@ #include #include #include +#include template inline bool isNaN(T x) { /// To be sure, that this function is zero-cost for non-floating point types. - if constexpr (is_floating_point_v) - return std::isnan(x); + if constexpr (is_floating_point) + return DecomposedFloat(x).isNaN(); else return false; } - template inline bool isFinite(T x) { - if constexpr (is_floating_point_v) - return std::isfinite(x); + if constexpr (is_floating_point) + return DecomposedFloat(x).isFinite(); else return true; } @@ -28,7 +28,7 @@ inline bool isFinite(T x) template bool canConvertTo(Float64 x) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) return true; if (!isFinite(x)) return false; @@ -41,7 +41,7 @@ bool canConvertTo(Float64 x) template T NaNOrZero() { - if constexpr (is_floating_point_v) + if constexpr (std::is_floating_point_v) return std::numeric_limits::quiet_NaN(); else return {}; diff --git a/src/Common/findExtreme.cpp b/src/Common/findExtreme.cpp index ce3bbb86d7c..a29750b848a 100644 --- a/src/Common/findExtreme.cpp +++ b/src/Common/findExtreme.cpp @@ -47,7 +47,7 @@ MULTITARGET_FUNCTION_AVX2_SSE42( /// Unroll the loop manually for floating point, since the compiler doesn't do it without fastmath /// as it might change the return value - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { constexpr size_t unroll_block = 512 / sizeof(T); /// Chosen via benchmarks with AVX2 so YMMV size_t unrolled_end = i + (((count - i) / unroll_block) * unroll_block); diff --git a/src/Common/transformEndianness.h b/src/Common/transformEndianness.h index 2a0c45efe38..e6e04ec75af 100644 --- a/src/Common/transformEndianness.h +++ b/src/Common/transformEndianness.h @@ -38,7 +38,7 @@ inline void transformEndianness(T & x) } template -requires is_floating_point_v +requires is_floating_point inline void transformEndianness(T & value) { if constexpr (ToEndian != FromEndian) diff --git a/src/Core/AccurateComparison.h b/src/Core/AccurateComparison.h index c1e93b8055a..87ff14e40e7 100644 --- a/src/Core/AccurateComparison.h +++ b/src/Core/AccurateComparison.h @@ -25,7 +25,7 @@ bool lessOp(A a, B b) return a < b; /// float vs float - if constexpr (is_floating_point_v && is_floating_point_v) + if constexpr (is_floating_point && is_floating_point) return a < b; /// anything vs NaN @@ -49,7 +49,7 @@ bool lessOp(A a, B b) } /// int vs float - if constexpr (is_integer && is_floating_point_v) + if constexpr (is_integer && is_floating_point) { if constexpr (sizeof(A) <= 4) return static_cast(a) < static_cast(b); @@ -57,7 +57,7 @@ bool lessOp(A a, B b) return DecomposedFloat(b).greater(a); } - if constexpr (is_floating_point_v && is_integer) + if constexpr (is_floating_point && is_integer) { if constexpr (sizeof(B) <= 4) return static_cast(a) < static_cast(b); @@ -65,8 +65,8 @@ bool lessOp(A a, B b) return DecomposedFloat(a).less(b); } - static_assert(is_integer || is_floating_point_v); - static_assert(is_integer || is_floating_point_v); + static_assert(is_integer || is_floating_point); + static_assert(is_integer || is_floating_point); UNREACHABLE(); } @@ -101,7 +101,7 @@ bool equalsOp(A a, B b) return a == b; /// float vs float - if constexpr (is_floating_point_v && is_floating_point_v) + if constexpr (is_floating_point && is_floating_point) return a == b; /// anything vs NaN @@ -125,7 +125,7 @@ bool equalsOp(A a, B b) } /// int vs float - if constexpr (is_integer && is_floating_point_v) + if constexpr (is_integer && is_floating_point) { if constexpr (sizeof(A) <= 4) return static_cast(a) == static_cast(b); @@ -133,7 +133,7 @@ bool equalsOp(A a, B b) return DecomposedFloat(b).equals(a); } - if constexpr (is_floating_point_v && is_integer) + if constexpr (is_floating_point && is_integer) { if constexpr (sizeof(B) <= 4) return static_cast(a) == static_cast(b); @@ -163,7 +163,7 @@ inline bool NO_SANITIZE_UNDEFINED convertNumeric(From value, To & result) return true; } - if constexpr (is_floating_point_v && is_floating_point_v) + if constexpr (is_floating_point && is_floating_point) { /// Note that NaNs doesn't compare equal to anything, but they are still in range of any Float type. if (isNaN(value)) diff --git a/src/Core/DecimalFunctions.h b/src/Core/DecimalFunctions.h index c5bc4ad70f6..435cef61145 100644 --- a/src/Core/DecimalFunctions.h +++ b/src/Core/DecimalFunctions.h @@ -310,7 +310,7 @@ ReturnType convertToImpl(const DecimalType & decimal, UInt32 scale, To & result) using DecimalNativeType = typename DecimalType::NativeType; static constexpr bool throw_exception = std::is_void_v; - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { result = static_cast(decimal.value) / static_cast(scaleMultiplier(scale)); } diff --git a/src/DataTypes/DataTypesDecimal.cpp b/src/DataTypes/DataTypesDecimal.cpp index 77a7a3e7237..d87eff97675 100644 --- a/src/DataTypes/DataTypesDecimal.cpp +++ b/src/DataTypes/DataTypesDecimal.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -269,9 +270,9 @@ ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value, static constexpr bool throw_exception = std::is_same_v; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { - if (!std::isfinite(value)) + if (!isFinite(value)) { if constexpr (throw_exception) throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow. Cannot convert infinity or NaN to decimal", ToDataType::family_name); diff --git a/src/DataTypes/NumberTraits.h b/src/DataTypes/NumberTraits.h index ad1e9eaa67b..ee0d9812097 100644 --- a/src/DataTypes/NumberTraits.h +++ b/src/DataTypes/NumberTraits.h @@ -74,7 +74,7 @@ template struct ResultOfAdditionMultiplication { using Type = typename Construct< is_signed_v || is_signed_v, - is_floating_point_v || is_floating_point_v, + is_floating_point || is_floating_point, nextSize(max(sizeof(A), sizeof(B)))>::Type; }; @@ -82,7 +82,7 @@ template struct ResultOfSubtraction { using Type = typename Construct< true, - is_floating_point_v || is_floating_point_v, + is_floating_point || is_floating_point, nextSize(max(sizeof(A), sizeof(B)))>::Type; }; @@ -113,7 +113,7 @@ template struct ResultOfModulo /// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16. static constexpr size_t size_of_result = result_is_signed ? nextSize(sizeof(B)) : sizeof(B); using Type0 = typename Construct::Type; - using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point, Float64, Type0>; }; template struct ResultOfPositiveModulo @@ -121,21 +121,21 @@ template struct ResultOfPositiveModulo /// function positive_modulo always return non-negative number. static constexpr size_t size_of_result = sizeof(B); using Type0 = typename Construct::Type; - using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point, Float64, Type0>; }; template struct ResultOfModuloLegacy { using Type0 = typename Construct || is_signed_v, false, sizeof(B)>::Type; - using Type = std::conditional_t || is_floating_point_v, Float64, Type0>; + using Type = std::conditional_t || is_floating_point, Float64, Type0>; }; template struct ResultOfNegate { using Type = typename Construct< true, - is_floating_point_v, + is_floating_point, is_signed_v ? sizeof(A) : nextSize(sizeof(A))>::Type; }; @@ -143,7 +143,7 @@ template struct ResultOfAbs { using Type = typename Construct< false, - is_floating_point_v, + is_floating_point, sizeof(A)>::Type; }; @@ -154,7 +154,7 @@ template struct ResultOfBit using Type = typename Construct< is_signed_v || is_signed_v, false, - is_floating_point_v || is_floating_point_v ? 8 : max(sizeof(A), sizeof(B))>::Type; + is_floating_point || is_floating_point ? 8 : max(sizeof(A), sizeof(B))>::Type; }; template struct ResultOfBitNot @@ -180,7 +180,7 @@ template struct ResultOfBitNot template struct ResultOfIf { - static constexpr bool has_float = is_floating_point_v || is_floating_point_v; + static constexpr bool has_float = is_floating_point || is_floating_point; static constexpr bool has_integer = is_integer || is_integer; static constexpr bool has_signed = is_signed_v || is_signed_v; static constexpr bool has_unsigned = !is_signed_v || !is_signed_v; @@ -189,7 +189,7 @@ struct ResultOfIf static constexpr size_t max_size_of_unsigned_integer = max(is_signed_v ? 0 : sizeof(A), is_signed_v ? 0 : sizeof(B)); static constexpr size_t max_size_of_signed_integer = max(is_signed_v ? sizeof(A) : 0, is_signed_v ? sizeof(B) : 0); static constexpr size_t max_size_of_integer = max(is_integer ? sizeof(A) : 0, is_integer ? sizeof(B) : 0); - static constexpr size_t max_size_of_float = max(is_floating_point_v ? sizeof(A) : 0, is_floating_point_v ? sizeof(B) : 0); + static constexpr size_t max_size_of_float = max(is_floating_point ? sizeof(A) : 0, is_floating_point ? sizeof(B) : 0); using ConstructedType = typename Construct= max_size_of_float) @@ -211,7 +211,7 @@ template struct ToInteger using Type = typename Construct< is_signed_v, false, - is_floating_point_v ? 8 : sizeof(A)>::Type; + is_floating_point ? 8 : sizeof(A)>::Type; }; diff --git a/src/Formats/ProtobufSerializer.cpp b/src/Formats/ProtobufSerializer.cpp index 7f03bdeb45d..86b11f45b72 100644 --- a/src/Formats/ProtobufSerializer.cpp +++ b/src/Formats/ProtobufSerializer.cpp @@ -541,7 +541,7 @@ namespace case FieldTypeId::TYPE_ENUM: { - if (is_floating_point_v) + if (is_floating_point) incompatibleColumnType(TypeName); write_function = [this](NumberType value) diff --git a/src/Functions/DivisionUtils.h b/src/Functions/DivisionUtils.h index 1a241c7171a..e8f5da342f8 100644 --- a/src/Functions/DivisionUtils.h +++ b/src/Functions/DivisionUtils.h @@ -47,9 +47,9 @@ inline auto checkedDivision(A a, B b) { throwIfDivisionLeadsToFPE(a, b); - if constexpr (is_big_int_v && is_floating_point_v) + if constexpr (is_big_int_v && is_floating_point) return static_cast(a) / b; - else if constexpr (is_big_int_v && is_floating_point_v) + else if constexpr (is_big_int_v && is_floating_point) return a / static_cast(b); else if constexpr (is_big_int_v && is_big_int_v) return static_cast(a / b); @@ -86,17 +86,17 @@ struct DivideIntegralImpl { /// Comparisons are not strict to avoid rounding issues when operand is implicitly casted to float. - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) if (isNaN(a) || a >= std::numeric_limits::max() || a <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) if (isNaN(b) || b >= std::numeric_limits::max() || b <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); auto res = checkedDivision(CastA(a), CastB(b)); - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) if (isNaN(res) || res >= static_cast(std::numeric_limits::max()) || res <= std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division, because it will produce infinite or too large number"); @@ -122,18 +122,18 @@ struct ModuloImpl template static Result apply(A a, B b) { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { /// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance. return static_cast(a) - trunc(static_cast(a) / static_cast(b)) * static_cast(b); } else { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) if (isNaN(a) || a > std::numeric_limits::max() || a < std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) if (isNaN(b) || b > std::numeric_limits::max() || b < std::numeric_limits::lowest()) throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); diff --git a/src/Functions/FunctionMathUnary.h b/src/Functions/FunctionMathUnary.h index 8395855a564..2cbd9b2e03c 100644 --- a/src/Functions/FunctionMathUnary.h +++ b/src/Functions/FunctionMathUnary.h @@ -66,7 +66,7 @@ private: /// Process all data as a whole and use FastOps implementation /// If the argument is integer, convert to Float64 beforehand - if constexpr (!is_floating_point_v) + if constexpr (!is_floating_point) { PODArray tmp_vec(size); for (size_t i = 0; i < size; ++i) @@ -150,7 +150,7 @@ private: { using Types = std::decay_t; using Type = typename Types::RightType; - using ReturnType = std::conditional_t, Float64, Type>; + using ReturnType = std::conditional_t, Float64, Type>; using ColVecType = ColumnVectorOrDecimal; const auto col_vec = checkAndGetColumn(col.column.get()); diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 44d0b750af9..8512ea5726f 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -638,7 +638,7 @@ inline void convertFromTime(DataTypeDateTime::FieldType & x, t template void parseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { if (precise_float_parsing) readFloatTextPrecise(x, rb); @@ -702,7 +702,7 @@ inline void parseImpl(DataTypeIPv6::FieldType & x, ReadBuffer & rb template bool tryParseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { if (precise_float_parsing) return tryReadFloatTextPrecise(x, rb); @@ -1767,7 +1767,7 @@ struct ConvertImpl else { /// If From Data is Nan or Inf and we convert to integer type, throw exception - if constexpr (std::is_floating_point_v && !std::is_floating_point_v) + if constexpr (is_floating_point && !is_floating_point) { if (!isFinite(vec_from[i])) { @@ -2253,9 +2253,9 @@ private: using RightT = typename RightDataType::FieldType; static constexpr bool bad_left = - is_decimal || std::is_floating_point_v || is_big_int_v || is_signed_v; + is_decimal || is_floating_point || is_big_int_v || is_signed_v; static constexpr bool bad_right = - is_decimal || std::is_floating_point_v || is_big_int_v || is_signed_v; + is_decimal || is_floating_point || is_big_int_v || is_signed_v; /// Disallow int vs UUID conversion (but support int vs UInt128 conversion) if constexpr ((bad_left && std::is_same_v) || @@ -2578,7 +2578,7 @@ struct ToNumberMonotonicity /// Float cases. /// When converting to Float, the conversion is always monotonic. - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) return { .is_monotonic = true, .is_always_monotonic = true }; const auto * low_cardinality = typeid_cast(&type); diff --git a/src/Functions/FunctionsJSON.h b/src/Functions/FunctionsJSON.h index 8a2ad457d34..65c1a6fb2d2 100644 --- a/src/Functions/FunctionsJSON.h +++ b/src/Functions/FunctionsJSON.h @@ -741,7 +741,7 @@ public: switch (element.type()) { case ElementType::DOUBLE: - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { /// We permit inaccurate conversion of double to float. /// Example: double 0.1 from JSON is not representable in float. @@ -769,7 +769,7 @@ public: case ElementType::STRING: { auto rb = ReadBufferFromMemory{element.getString()}; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { if (!tryReadFloatText(value, rb) || !rb.eof()) return false; diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index ab62deed45d..46fbe70458d 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -453,7 +453,7 @@ template - using FunctionRoundingImpl = std::conditional_t, + using FunctionRoundingImpl = std::conditional_t, FloatRoundingImpl, IntegerRoundingImpl>; diff --git a/src/Functions/FunctionsVisitParam.h b/src/Functions/FunctionsVisitParam.h index 5e13fbbad5c..fd59ea3a9c1 100644 --- a/src/Functions/FunctionsVisitParam.h +++ b/src/Functions/FunctionsVisitParam.h @@ -57,7 +57,7 @@ struct ExtractNumericType ResultType x = 0; if (!in.eof()) { - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) tryReadFloatText(x, in); else tryReadIntText(x, in); diff --git a/src/Functions/abs.cpp b/src/Functions/abs.cpp index 9ac2363f765..3a618686b30 100644 --- a/src/Functions/abs.cpp +++ b/src/Functions/abs.cpp @@ -22,7 +22,7 @@ struct AbsImpl return a < 0 ? static_cast(~a) + 1 : static_cast(a); else if constexpr (is_integer && is_unsigned_v) return static_cast(a); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) return static_cast(std::abs(a)); } diff --git a/src/Functions/array/arrayAggregation.cpp b/src/Functions/array/arrayAggregation.cpp index 03aa5fb9086..9c17e1095c5 100644 --- a/src/Functions/array/arrayAggregation.cpp +++ b/src/Functions/array/arrayAggregation.cpp @@ -85,7 +85,7 @@ struct ArrayAggregateResultImpl std::conditional_t, Decimal128, std::conditional_t, Decimal256, std::conditional_t, Decimal128, - std::conditional_t, Float64, + std::conditional_t, Float64, std::conditional_t, Int64, UInt64>>>>>>>>>>>; }; diff --git a/src/Functions/factorial.cpp b/src/Functions/factorial.cpp index 3b46d9e867f..32bdc84b954 100644 --- a/src/Functions/factorial.cpp +++ b/src/Functions/factorial.cpp @@ -21,7 +21,7 @@ struct FactorialImpl static NO_SANITIZE_UNDEFINED ResultType apply(A a) { - if constexpr (is_floating_point_v || is_over_big_int) + if constexpr (is_floating_point || is_over_big_int) throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type of argument of function factorial, should not be floating point or big int"); diff --git a/src/Functions/if.cpp b/src/Functions/if.cpp index 7a6d37d810d..dded3d46652 100644 --- a/src/Functions/if.cpp +++ b/src/Functions/if.cpp @@ -87,7 +87,7 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[a_index]) + (!cond[i]) * static_cast(b[b_index]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i]) } @@ -105,7 +105,7 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[a_index]) + (!cond[i]) * static_cast(b[i]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i]) } @@ -122,7 +122,7 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[i]) + (!cond[i]) * static_cast(b[b_index]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i]) } @@ -138,7 +138,7 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[i]) + (!cond[i]) * static_cast(b[i]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i]) } @@ -162,7 +162,7 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[a_index]) + (!cond[i]) * static_cast(b); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i]) } @@ -178,7 +178,7 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a[i]) + (!cond[i]) * static_cast(b); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i]) } @@ -200,7 +200,7 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a) + (!cond[i]) * static_cast(b[b_index]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i]) } @@ -216,7 +216,7 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar { if constexpr (is_native_int_or_decimal_v) res[i] = !!cond[i] * static_cast(a) + (!cond[i]) * static_cast(b[i]); - else if constexpr (std::is_floating_point_v) + else if constexpr (is_floating_point) { BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i]) } diff --git a/src/Functions/minus.cpp b/src/Functions/minus.cpp index 4d86442ad7e..cf318db805b 100644 --- a/src/Functions/minus.cpp +++ b/src/Functions/minus.cpp @@ -17,8 +17,8 @@ struct MinusImpl { if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) - static_cast(static_cast(b)); } diff --git a/src/Functions/moduloOrZero.cpp b/src/Functions/moduloOrZero.cpp index d233e4e4ce2..5a4d1539345 100644 --- a/src/Functions/moduloOrZero.cpp +++ b/src/Functions/moduloOrZero.cpp @@ -17,7 +17,7 @@ struct ModuloOrZeroImpl template static Result apply(A a, B b) { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { /// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance. return ResultType(a) - trunc(ResultType(a) / ResultType(b)) * ResultType(b); diff --git a/src/Functions/multiply.cpp b/src/Functions/multiply.cpp index 559143a43b4..740ab81d0d9 100644 --- a/src/Functions/multiply.cpp +++ b/src/Functions/multiply.cpp @@ -18,8 +18,8 @@ struct MultiplyImpl { if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) * static_cast(static_cast(b)); } diff --git a/src/Functions/plus.cpp b/src/Functions/plus.cpp index 00136e50c5b..26921713f78 100644 --- a/src/Functions/plus.cpp +++ b/src/Functions/plus.cpp @@ -19,8 +19,8 @@ struct PlusImpl /// Next everywhere, static_cast - so that there is no wrong result in expressions of the form Int64 c = UInt32(a) * Int32(-1). if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) + static_cast(static_cast(b)); } diff --git a/src/Functions/sign.cpp b/src/Functions/sign.cpp index 16f0efd2201..a6396a58c0c 100644 --- a/src/Functions/sign.cpp +++ b/src/Functions/sign.cpp @@ -13,7 +13,7 @@ struct SignImpl static NO_SANITIZE_UNDEFINED ResultType apply(A a) { - if constexpr (is_decimal || is_floating_point_v) + if constexpr (is_decimal || is_floating_point) return a < A(0) ? -1 : a == A(0) ? 0 : 1; else if constexpr (is_signed_v) return a < 0 ? -1 : a == 0 ? 0 : 1; diff --git a/src/IO/ReadHelpers.h b/src/IO/ReadHelpers.h index 6dda5a9b089..f1fcbb07af5 100644 --- a/src/IO/ReadHelpers.h +++ b/src/IO/ReadHelpers.h @@ -1382,7 +1382,7 @@ inline bool tryReadText(IPv4 & x, ReadBuffer & buf) { return tryReadIPv4Text(x, inline bool tryReadText(IPv6 & x, ReadBuffer & buf) { return tryReadIPv6Text(x, buf); } template -requires is_floating_point_v +requires is_floating_point inline void readText(T & x, ReadBuffer & buf) { readFloatText(x, buf); } inline void readText(String & x, ReadBuffer & buf) { readEscapedString(x, buf); } diff --git a/src/IO/WriteHelpers.h b/src/IO/WriteHelpers.h index cdeabfcf352..a4eefeaffe2 100644 --- a/src/IO/WriteHelpers.h +++ b/src/IO/WriteHelpers.h @@ -150,7 +150,7 @@ inline void writeBoolText(bool x, WriteBuffer & buf) template -requires is_floating_point_v +requires is_floating_point inline size_t writeFloatTextFastPath(T x, char * buffer) { Int64 result = 0; @@ -182,7 +182,7 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) } template -requires is_floating_point_v +requires is_floating_point inline void writeFloatText(T x, WriteBuffer & buf) { using Converter = DoubleConverter; @@ -530,7 +530,7 @@ void writeJSONNumber(T x, WriteBuffer & ostr, const FormatSettings & settings) bool is_finite = isFinite(x); const bool need_quote = (is_integer && (sizeof(T) >= 8) && settings.json.quote_64bit_integers) - || (settings.json.quote_denormals && !is_finite) || (is_floating_point_v && (sizeof(T) >= 8) && settings.json.quote_64bit_floats); + || (settings.json.quote_denormals && !is_finite) || (is_floating_point && (sizeof(T) >= 8) && settings.json.quote_64bit_floats); if (need_quote) writeChar('"', ostr); @@ -541,7 +541,7 @@ void writeJSONNumber(T x, WriteBuffer & ostr, const FormatSettings & settings) writeCString("null", ostr); else { - if constexpr (is_floating_point_v) + if constexpr (is_floating_point) { if (std::signbit(x)) { @@ -1065,7 +1065,7 @@ inline void writeText(is_integer auto x, WriteBuffer & buf) } template -requires is_floating_point_v +requires is_floating_point inline void writeText(T x, WriteBuffer & buf) { writeFloatText(x, buf); } inline void writeText(is_enum auto x, WriteBuffer & buf) { writeText(magic_enum::enum_name(x), buf); } diff --git a/src/Interpreters/RowRefs.cpp b/src/Interpreters/RowRefs.cpp index 9785ba46dab..c5ffbb96d6f 100644 --- a/src/Interpreters/RowRefs.cpp +++ b/src/Interpreters/RowRefs.cpp @@ -183,7 +183,7 @@ private: if (sorted.load(std::memory_order_relaxed)) return; - if constexpr (std::is_arithmetic_v && !std::is_floating_point_v) + if constexpr (std::is_arithmetic_v && !std::is_floating_point) { if (likely(entries.size() > 256)) { From bf2a8f6a7f6eb8073b60468058f8259cf4a4f341 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 2 Jun 2024 20:43:02 +0200 Subject: [PATCH 04/35] Preparation --- base/base/BFloat16.h | 17 +- src/AggregateFunctions/AggregateFunctionSum.h | 13 +- .../AggregateFunctionUniq.h | 2 +- src/Core/DecimalFunctions.h | 10 +- src/Core/iostream_debug_helpers.cpp | 149 ------------------ src/Core/iostream_debug_helpers.h | 49 ------ src/DataTypes/DataTypesDecimal.cpp | 15 +- src/Dictionaries/RangeHashedDictionary.h | 3 +- src/Functions/FunctionsRound.h | 3 +- src/Functions/array/mapPopulateSeries.cpp | 32 ++-- src/Functions/exp.cpp | 9 +- src/Functions/log.cpp | 9 +- src/Functions/minus.cpp | 4 +- src/Functions/sigmoid.cpp | 10 +- src/Functions/tanh.cpp | 9 +- src/IO/WriteHelpers.h | 14 +- src/Interpreters/RowRefs.cpp | 2 +- src/Parsers/iostream_debug_helpers.cpp | 35 ---- src/Parsers/iostream_debug_helpers.h | 17 -- 19 files changed, 110 insertions(+), 292 deletions(-) delete mode 100644 src/Core/iostream_debug_helpers.cpp delete mode 100644 src/Core/iostream_debug_helpers.h delete mode 100644 src/Parsers/iostream_debug_helpers.cpp delete mode 100644 src/Parsers/iostream_debug_helpers.h diff --git a/base/base/BFloat16.h b/base/base/BFloat16.h index 17c3ebe9ef3..99eab5c67cb 100644 --- a/base/base/BFloat16.h +++ b/base/base/BFloat16.h @@ -1,9 +1,22 @@ #pragma once +#include + + using BFloat16 = __bf16; namespace std { - inline constexpr bool isfinite(BFloat16) { return true; } - inline constexpr bool signbit(BFloat16) { return false; } + inline constexpr bool isfinite(BFloat16 x) { return (bit_cast(x) & 0b0111111110000000) != 0b0111111110000000; } + inline constexpr bool signbit(BFloat16 x) { return bit_cast(x) & 0b1000000000000000; } +} + +inline Float32 BFloat16ToFloat32(BFloat16 x) +{ + return bit_cast(static_cast(bit_cast(x)) << 16); +} + +inline BFloat16 Float32ToBFloat16(Float32 x) +{ + return bit_cast(std::bit_cast(x) >> 16); } diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index d0d600be70b..f6c51241a5c 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -193,12 +193,11 @@ struct AggregateFunctionSumData Impl::add(sum, local_sum); return; } - else if constexpr (is_floating_point) + else if constexpr (is_floating_point && (sizeof(Value) == 4 || sizeof(Value) == 8)) { - /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned + /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned /// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep) - static_assert(sizeof(Value) == 4 || sizeof(Value) == 8); - using equivalent_integer = typename std::conditional_t; + using EquivalentInteger = typename std::conditional_t; constexpr size_t unroll_count = 128 / sizeof(T); T partial_sums[unroll_count]{}; @@ -209,11 +208,11 @@ struct AggregateFunctionSumData { for (size_t i = 0; i < unroll_count; ++i) { - equivalent_integer value; - std::memcpy(&value, &ptr[i], sizeof(Value)); + EquivalentInteger value; + memcpy(&value, &ptr[i], sizeof(Value)); value &= (!condition_map[i] != add_if_zero) - 1; Value d; - std::memcpy(&d, &value, sizeof(Value)); + memcpy(&d, &value, sizeof(Value)); Impl::add(partial_sums[i], d); } ptr += unroll_count; diff --git a/src/AggregateFunctions/AggregateFunctionUniq.h b/src/AggregateFunctions/AggregateFunctionUniq.h index cef23f766c7..cd2d3c1eb18 100644 --- a/src/AggregateFunctions/AggregateFunctionUniq.h +++ b/src/AggregateFunctions/AggregateFunctionUniq.h @@ -257,7 +257,7 @@ template struct AggregateFunctionUniqTraits { static UInt64 hash(T x) { - if constexpr (std::is_same_v || std::is_same_v) + if constexpr (is_floating_point) { return bit_cast(x); } diff --git a/src/Core/DecimalFunctions.h b/src/Core/DecimalFunctions.h index 435cef61145..abd660a8a7f 100644 --- a/src/Core/DecimalFunctions.h +++ b/src/Core/DecimalFunctions.h @@ -17,6 +17,7 @@ class DataTypeNumber; namespace ErrorCodes { + extern const int NOT_IMPLEMENTED; extern const int DECIMAL_OVERFLOW; extern const int ARGUMENT_OUT_OF_BOUND; } @@ -310,7 +311,14 @@ ReturnType convertToImpl(const DecimalType & decimal, UInt32 scale, To & result) using DecimalNativeType = typename DecimalType::NativeType; static constexpr bool throw_exception = std::is_void_v; - if constexpr (is_floating_point) + if constexpr (std::is_same_v) + { + if constexpr (throw_exception) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from Decimal to BFloat16 is not implemented"); + else + return ReturnType(false); + } + else if constexpr (is_floating_point) { result = static_cast(decimal.value) / static_cast(scaleMultiplier(scale)); } diff --git a/src/Core/iostream_debug_helpers.cpp b/src/Core/iostream_debug_helpers.cpp deleted file mode 100644 index 38e61ac4fca..00000000000 --- a/src/Core/iostream_debug_helpers.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "iostream_debug_helpers.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - - -namespace DB -{ - -template <> -std::ostream & operator<< (std::ostream & stream, const Field & what) -{ - stream << applyVisitor(FieldVisitorDump(), what); - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const NameAndTypePair & what) -{ - stream << "NameAndTypePair(name = " << what.name << ", type = " << what.type << ")"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const IDataType & what) -{ - stream << "IDataType(name = " << what.getName() << ", default = " << what.getDefault() << ")"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const IStorage & what) -{ - auto table_id = what.getStorageID(); - stream << "IStorage(name = " << what.getName() << ", tableName = " << table_id.table_name << ") {" - << what.getInMemoryMetadataPtr()->getColumns().getAllPhysical().toString() << "}"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const TableLockHolder &) -{ - stream << "TableStructureReadLock()"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const IFunctionOverloadResolver & what) -{ - stream << "IFunction(name = " << what.getName() << ", variadic = " << what.isVariadic() << ", args = " << what.getNumberOfArguments() - << ")"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const Block & what) -{ - stream << "Block(" - << "num_columns = " << what.columns() << "){" << what.dumpStructure() << "}"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what) -{ - stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << *what.type << ", column = "; - return dumpValue(stream, what.column) << ")"; -} - -std::ostream & operator<<(std::ostream & stream, const IColumn & what) -{ - stream << "IColumn(" << what.dumpStructure() << ")"; - stream << "{"; - for (size_t i = 0; i < what.size(); ++i) - { - if (i) - stream << ", "; - stream << applyVisitor(FieldVisitorDump(), what[i]); - } - stream << "}"; - - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const Packet & what) -{ - stream << "Packet(" - << "type = " << what.type; - // types description: Core/Protocol.h - if (what.exception) - stream << "exception = " << what.exception.get(); - // TODO: profile_info - stream << ") {" << what.block << "}"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what) -{ - stream << "ExpressionActions(" << what.dumpActions() << ")"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const TreeRewriterResult & what) -{ - stream << "SyntaxAnalyzerResult{"; - stream << "storage=" << what.storage << "; "; - if (!what.source_columns.empty()) - { - stream << "source_columns="; - dumpValue(stream, what.source_columns); - stream << "; "; - } - if (!what.aliases.empty()) - { - stream << "aliases="; - dumpValue(stream, what.aliases); - stream << "; "; - } - if (!what.array_join_result_to_source.empty()) - { - stream << "array_join_result_to_source="; - dumpValue(stream, what.array_join_result_to_source); - stream << "; "; - } - if (!what.array_join_alias_to_name.empty()) - { - stream << "array_join_alias_to_name="; - dumpValue(stream, what.array_join_alias_to_name); - stream << "; "; - } - if (!what.array_join_name_to_alias.empty()) - { - stream << "array_join_name_to_alias="; - dumpValue(stream, what.array_join_name_to_alias); - stream << "; "; - } - stream << "rewrite_subqueries=" << what.rewrite_subqueries << "; "; - stream << "}"; - - return stream; -} - -} diff --git a/src/Core/iostream_debug_helpers.h b/src/Core/iostream_debug_helpers.h deleted file mode 100644 index e40bf74583e..00000000000 --- a/src/Core/iostream_debug_helpers.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once -#include - -namespace DB -{ - -// Use template to disable implicit casting for certain overloaded types such as Field, which leads -// to overload resolution ambiguity. -class Field; -template -requires std::is_same_v -std::ostream & operator<<(std::ostream & stream, const T & what); - -struct NameAndTypePair; -std::ostream & operator<<(std::ostream & stream, const NameAndTypePair & what); - -class IDataType; -std::ostream & operator<<(std::ostream & stream, const IDataType & what); - -class IStorage; -std::ostream & operator<<(std::ostream & stream, const IStorage & what); - -class IFunctionOverloadResolver; -std::ostream & operator<<(std::ostream & stream, const IFunctionOverloadResolver & what); - -class IFunctionBase; -std::ostream & operator<<(std::ostream & stream, const IFunctionBase & what); - -class Block; -std::ostream & operator<<(std::ostream & stream, const Block & what); - -struct ColumnWithTypeAndName; -std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what); - -class IColumn; -std::ostream & operator<<(std::ostream & stream, const IColumn & what); - -struct Packet; -std::ostream & operator<<(std::ostream & stream, const Packet & what); - -class ExpressionActions; -std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what); - -struct TreeRewriterResult; -std::ostream & operator<<(std::ostream & stream, const TreeRewriterResult & what); -} - -/// some operator<< should be declared before operator<<(... std::shared_ptr<>) -#include diff --git a/src/DataTypes/DataTypesDecimal.cpp b/src/DataTypes/DataTypesDecimal.cpp index d87eff97675..e0304e46b05 100644 --- a/src/DataTypes/DataTypesDecimal.cpp +++ b/src/DataTypes/DataTypesDecimal.cpp @@ -20,6 +20,7 @@ namespace ErrorCodes extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int DECIMAL_OVERFLOW; + extern const int NOT_IMPLEMENTED; } @@ -262,15 +263,19 @@ FOR_EACH_ARITHMETIC_TYPE(INVOKE); template requires (is_arithmetic_v && IsDataTypeDecimal) -ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value, UInt32 scale, typename ToDataType::FieldType & result) +ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & /*value*/, UInt32 /*scale*/, typename ToDataType::FieldType & /*result*/) { - using FromFieldType = typename FromDataType::FieldType; +/* using FromFieldType = typename FromDataType::FieldType; using ToFieldType = typename ToDataType::FieldType; using ToNativeType = typename ToFieldType::NativeType; static constexpr bool throw_exception = std::is_same_v; - if constexpr (is_floating_point) + if constexpr (std::is_same_v) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from BFloat16 to Decimal is not implemented"); + } + else if constexpr (is_floating_point) { if (!isFinite(value)) { @@ -302,7 +307,9 @@ ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value, return ReturnType(convertDecimalsImpl, ToDataType, ReturnType>(static_cast(value), 0, scale, result)); else return ReturnType(convertDecimalsImpl, ToDataType, ReturnType>(static_cast(value), 0, scale, result)); - } + }*/ + + return ReturnType(); } #define DISPATCH(FROM_DATA_TYPE, TO_DATA_TYPE) \ diff --git a/src/Dictionaries/RangeHashedDictionary.h b/src/Dictionaries/RangeHashedDictionary.h index bf004dbe32b..4950e7c8ee6 100644 --- a/src/Dictionaries/RangeHashedDictionary.h +++ b/src/Dictionaries/RangeHashedDictionary.h @@ -298,7 +298,8 @@ namespace impl using Types = std::decay_t; using DataType = typename Types::LeftType; - if constexpr (IsDataTypeDecimalOrNumber || IsDataTypeDateOrDateTime || IsDataTypeEnum) + if constexpr ((IsDataTypeDecimalOrNumber || IsDataTypeDateOrDateTime || IsDataTypeEnum) + && !std::is_same_v) { using ColumnType = typename DataType::ColumnType; func(TypePair()); diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index 46fbe70458d..7eea0d74975 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -579,7 +579,8 @@ public: using Types = std::decay_t; using DataType = typename Types::LeftType; - if constexpr (IsDataTypeNumber || IsDataTypeDecimal) + if constexpr ((IsDataTypeNumber || IsDataTypeDecimal) + && !std::is_same_v) { using FieldType = typename DataType::FieldType; res = Dispatcher::apply(column.column.get(), scale_arg); diff --git a/src/Functions/array/mapPopulateSeries.cpp b/src/Functions/array/mapPopulateSeries.cpp index 0db71ab2cf8..759696147c3 100644 --- a/src/Functions/array/mapPopulateSeries.cpp +++ b/src/Functions/array/mapPopulateSeries.cpp @@ -453,23 +453,29 @@ private: using ValueType = typename Types::RightType; static constexpr bool key_and_value_are_numbers = IsDataTypeNumber && IsDataTypeNumber; - static constexpr bool key_is_float = std::is_same_v || std::is_same_v; - if constexpr (key_and_value_are_numbers && !key_is_float) + if constexpr (key_and_value_are_numbers) { - using KeyFieldType = typename KeyType::FieldType; - using ValueFieldType = typename ValueType::FieldType; + if constexpr (is_floating_point) + { + return false; + } + else + { + using KeyFieldType = typename KeyType::FieldType; + using ValueFieldType = typename ValueType::FieldType; - executeImplTyped( - input.key_column, - input.value_column, - input.offsets_column, - input.max_key_column, - std::move(result_columns.result_key_column), - std::move(result_columns.result_value_column), - std::move(result_columns.result_offset_column)); + executeImplTyped( + input.key_column, + input.value_column, + input.offsets_column, + input.max_key_column, + std::move(result_columns.result_key_column), + std::move(result_columns.result_value_column), + std::move(result_columns.result_offset_column)); - return true; + return true; + } } return false; diff --git a/src/Functions/exp.cpp b/src/Functions/exp.cpp index d352cda7460..9b8207afe30 100644 --- a/src/Functions/exp.cpp +++ b/src/Functions/exp.cpp @@ -21,7 +21,14 @@ namespace template static void execute(const T * src, size_t size, T * dst) { - NFastOps::Exp(src, size, dst); + if constexpr (std::is_same_v) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name); + } + else + { + NFastOps::Exp(src, size, dst); + } } }; } diff --git a/src/Functions/log.cpp b/src/Functions/log.cpp index 9096b8c6f22..d5e10c90c83 100644 --- a/src/Functions/log.cpp +++ b/src/Functions/log.cpp @@ -20,7 +20,14 @@ struct LogName { static constexpr auto name = "log"; }; template static void execute(const T * src, size_t size, T * dst) { - NFastOps::Log(src, size, dst); + if constexpr (std::is_same_v) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name); + } + else + { + NFastOps::Log(src, size, dst); + } } }; diff --git a/src/Functions/minus.cpp b/src/Functions/minus.cpp index cf318db805b..a372e8d5d78 100644 --- a/src/Functions/minus.cpp +++ b/src/Functions/minus.cpp @@ -17,8 +17,8 @@ struct MinusImpl { if constexpr (is_big_int_v || is_big_int_v) { - using CastA = std::conditional_t, B, A>; - using CastB = std::conditional_t, A, B>; + using CastA = std::conditional_t, B, A>; + using CastB = std::conditional_t, A, B>; return static_cast(static_cast(a)) - static_cast(static_cast(b)); } diff --git a/src/Functions/sigmoid.cpp b/src/Functions/sigmoid.cpp index d121bdc7389..1179329845d 100644 --- a/src/Functions/sigmoid.cpp +++ b/src/Functions/sigmoid.cpp @@ -21,7 +21,14 @@ namespace template static void execute(const T * src, size_t size, T * dst) { - NFastOps::Sigmoid<>(src, size, dst); + if constexpr (std::is_same_v) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name); + } + else + { + NFastOps::Sigmoid<>(src, size, dst); + } } }; } @@ -47,4 +54,3 @@ REGISTER_FUNCTION(Sigmoid) } } - diff --git a/src/Functions/tanh.cpp b/src/Functions/tanh.cpp index bdefa5263d7..49788b31970 100644 --- a/src/Functions/tanh.cpp +++ b/src/Functions/tanh.cpp @@ -19,7 +19,14 @@ struct TanhName { static constexpr auto name = "tanh"; }; template static void execute(const T * src, size_t size, T * dst) { - NFastOps::Tanh<>(src, size, dst); + if constexpr (std::is_same_v) + { + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name); + } + else + { + NFastOps::Tanh<>(src, size, dst); + } } }; diff --git a/src/IO/WriteHelpers.h b/src/IO/WriteHelpers.h index a4eefeaffe2..d2e2868b245 100644 --- a/src/IO/WriteHelpers.h +++ b/src/IO/WriteHelpers.h @@ -155,7 +155,7 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) { Int64 result = 0; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { /// The library Ryu has low performance on integers. /// This workaround improves performance 6..10 times. @@ -165,10 +165,16 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) else result = jkj::dragonbox::to_chars_n(x, buffer) - buffer; } - else + else if constexpr (std::is_same_v) { - /// This will support 16-bit floats as well. - float f32 = x; + if (DecomposedFloat32(x).isIntegerInRepresentableRange()) + result = itoa(Int32(x), buffer) - buffer; + else + result = jkj::dragonbox::to_chars_n(x, buffer) - buffer; + } + else if constexpr (std::is_same_v) + { + Float32 f32 = BFloat16ToFloat32(x); if (DecomposedFloat32(f32).isIntegerInRepresentableRange()) result = itoa(Int32(f32), buffer) - buffer; diff --git a/src/Interpreters/RowRefs.cpp b/src/Interpreters/RowRefs.cpp index c5ffbb96d6f..a0fad8840e6 100644 --- a/src/Interpreters/RowRefs.cpp +++ b/src/Interpreters/RowRefs.cpp @@ -183,7 +183,7 @@ private: if (sorted.load(std::memory_order_relaxed)) return; - if constexpr (std::is_arithmetic_v && !std::is_floating_point) + if constexpr (std::is_arithmetic_v && !is_floating_point) { if (likely(entries.size() > 256)) { diff --git a/src/Parsers/iostream_debug_helpers.cpp b/src/Parsers/iostream_debug_helpers.cpp deleted file mode 100644 index b74d337b22d..00000000000 --- a/src/Parsers/iostream_debug_helpers.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "iostream_debug_helpers.h" -#include -#include -#include -#include -#include -#include - -namespace DB -{ - -std::ostream & operator<<(std::ostream & stream, const Token & what) -{ - stream << "Token (type="<< static_cast(what.type) <<"){"<< std::string{what.begin, what.end} << "}"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const Expected & what) -{ - stream << "Expected {variants="; - dumpValue(stream, what.variants) - << "; max_parsed_pos=" << what.max_parsed_pos << "}"; - return stream; -} - -std::ostream & operator<<(std::ostream & stream, const IAST & what) -{ - WriteBufferFromOStream buf(stream, 4096); - buf << "IAST{"; - what.dumpTree(buf); - buf << "}"; - return stream; -} - -} diff --git a/src/Parsers/iostream_debug_helpers.h b/src/Parsers/iostream_debug_helpers.h deleted file mode 100644 index 39f52ebcbc2..00000000000 --- a/src/Parsers/iostream_debug_helpers.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once -#include - -namespace DB -{ -struct Token; -std::ostream & operator<<(std::ostream & stream, const Token & what); - -struct Expected; -std::ostream & operator<<(std::ostream & stream, const Expected & what); - -class IAST; -std::ostream & operator<<(std::ostream & stream, const IAST & what); - -} - -#include From c3f42b7bc770e5e8104527011f6bc51d5b8469ff Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 01:25:25 +0100 Subject: [PATCH 05/35] Something --- cmake/linux/default_libs.cmake | 3 +- src/AggregateFunctions/ReservoirSampler.h | 2 +- src/Columns/ColumnUnique.cpp | 1 + src/Columns/ColumnUnique.h | 1 + src/Columns/IColumn.cpp | 1 + src/Common/FieldVisitorConvertToNumber.cpp | 2 +- src/Common/FieldVisitorConvertToNumber.h | 1 + src/DataTypes/DataTypesBinaryEncoding.cpp | 5 + src/DataTypes/DataTypesNumber.cpp | 1 + src/DataTypes/DataTypesNumber.h | 1 + src/Formats/JSONExtractTree.cpp | 4 +- src/Functions/FunctionBinaryArithmetic.h | 1 + src/Functions/FunctionsConversion.cpp | 1 + src/Functions/FunctionsRound.h | 2 +- src/IO/readFloatText.cpp | 9 ++ src/IO/readFloatText.h | 111 ++++++++++++++++-- .../Impl/Parquet/ParquetDataValuesReader.cpp | 2 + .../Impl/Parquet/ParquetLeafColReader.cpp | 1 + 18 files changed, 132 insertions(+), 17 deletions(-) diff --git a/cmake/linux/default_libs.cmake b/cmake/linux/default_libs.cmake index 51620bc9f33..79875e1ed6b 100644 --- a/cmake/linux/default_libs.cmake +++ b/cmake/linux/default_libs.cmake @@ -3,8 +3,7 @@ set (DEFAULT_LIBS "-nodefaultlibs") -# We need builtins from Clang's RT even without libcxx - for ubsan+int128. -# See https://bugs.llvm.org/show_bug.cgi?id=16404 +# We need builtins from Clang execute_process (COMMAND ${CMAKE_CXX_COMPILER} --target=${CMAKE_CXX_COMPILER_TARGET} --print-libgcc-file-name --rtlib=compiler-rt OUTPUT_VARIABLE BUILTINS_LIBRARY diff --git a/src/AggregateFunctions/ReservoirSampler.h b/src/AggregateFunctions/ReservoirSampler.h index 2668e0dc890..870cb429fb7 100644 --- a/src/AggregateFunctions/ReservoirSampler.h +++ b/src/AggregateFunctions/ReservoirSampler.h @@ -276,6 +276,6 @@ private: { if (OnEmpty == ReservoirSamplerOnEmpty::THROW) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSampler"); - return NanLikeValueConstructor>::getValue(); + return NanLikeValueConstructor>::getValue(); } }; diff --git a/src/Columns/ColumnUnique.cpp b/src/Columns/ColumnUnique.cpp index 54f45204c00..773edbfd590 100644 --- a/src/Columns/ColumnUnique.cpp +++ b/src/Columns/ColumnUnique.cpp @@ -16,6 +16,7 @@ template class ColumnUnique; template class ColumnUnique; template class ColumnUnique; template class ColumnUnique; +template class ColumnUnique; template class ColumnUnique; template class ColumnUnique; template class ColumnUnique; diff --git a/src/Columns/ColumnUnique.h b/src/Columns/ColumnUnique.h index ffa7c311e9e..ce7bbf0766f 100644 --- a/src/Columns/ColumnUnique.h +++ b/src/Columns/ColumnUnique.h @@ -760,6 +760,7 @@ extern template class ColumnUnique; extern template class ColumnUnique; extern template class ColumnUnique; extern template class ColumnUnique; +extern template class ColumnUnique; extern template class ColumnUnique; extern template class ColumnUnique; extern template class ColumnUnique; diff --git a/src/Columns/IColumn.cpp b/src/Columns/IColumn.cpp index c9a0514af4e..4a3886dddb6 100644 --- a/src/Columns/IColumn.cpp +++ b/src/Columns/IColumn.cpp @@ -443,6 +443,7 @@ template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; +template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; template class IColumnHelper, ColumnFixedSizeHelper>; diff --git a/src/Common/FieldVisitorConvertToNumber.cpp b/src/Common/FieldVisitorConvertToNumber.cpp index 75b3fbfe02a..a5963e3d028 100644 --- a/src/Common/FieldVisitorConvertToNumber.cpp +++ b/src/Common/FieldVisitorConvertToNumber.cpp @@ -1,5 +1,4 @@ #include -#include "base/Decimal.h" namespace DB { @@ -17,6 +16,7 @@ template class FieldVisitorConvertToNumber; template class FieldVisitorConvertToNumber; template class FieldVisitorConvertToNumber; template class FieldVisitorConvertToNumber; +//template class FieldVisitorConvertToNumber; template class FieldVisitorConvertToNumber; template class FieldVisitorConvertToNumber; diff --git a/src/Common/FieldVisitorConvertToNumber.h b/src/Common/FieldVisitorConvertToNumber.h index 638b8805b6a..38d5dc473c4 100644 --- a/src/Common/FieldVisitorConvertToNumber.h +++ b/src/Common/FieldVisitorConvertToNumber.h @@ -129,6 +129,7 @@ extern template class FieldVisitorConvertToNumber; extern template class FieldVisitorConvertToNumber; extern template class FieldVisitorConvertToNumber; extern template class FieldVisitorConvertToNumber; +//extern template class FieldVisitorConvertToNumber; extern template class FieldVisitorConvertToNumber; extern template class FieldVisitorConvertToNumber; diff --git a/src/DataTypes/DataTypesBinaryEncoding.cpp b/src/DataTypes/DataTypesBinaryEncoding.cpp index dc0f2f3f5aa..c3190b462c3 100644 --- a/src/DataTypes/DataTypesBinaryEncoding.cpp +++ b/src/DataTypes/DataTypesBinaryEncoding.cpp @@ -96,6 +96,7 @@ enum class BinaryTypeIndex : uint8_t SimpleAggregateFunction = 0x2E, Nested = 0x2F, JSON = 0x30, + BFloat16 = 0x31, }; /// In future we can introduce more arguments in the JSON data type definition. @@ -151,6 +152,8 @@ BinaryTypeIndex getBinaryTypeIndex(const DataTypePtr & type) return BinaryTypeIndex::Int128; case TypeIndex::Int256: return BinaryTypeIndex::Int256; + case TypeIndex::BFloat16: + return BinaryTypeIndex::BFloat16; case TypeIndex::Float32: return BinaryTypeIndex::Float32; case TypeIndex::Float64: @@ -565,6 +568,8 @@ DataTypePtr decodeDataType(ReadBuffer & buf) return std::make_shared(); case BinaryTypeIndex::Int256: return std::make_shared(); + case BinaryTypeIndex::BFloat16: + return std::make_shared(); case BinaryTypeIndex::Float32: return std::make_shared(); case BinaryTypeIndex::Float64: diff --git a/src/DataTypes/DataTypesNumber.cpp b/src/DataTypes/DataTypesNumber.cpp index 5972cebbca1..4c8918521fe 100644 --- a/src/DataTypes/DataTypesNumber.cpp +++ b/src/DataTypes/DataTypesNumber.cpp @@ -112,6 +112,7 @@ template class DataTypeNumber; template class DataTypeNumber; template class DataTypeNumber; template class DataTypeNumber; +template class DataTypeNumber; template class DataTypeNumber; template class DataTypeNumber; diff --git a/src/DataTypes/DataTypesNumber.h b/src/DataTypes/DataTypesNumber.h index 29899847c4b..a9e77e01b13 100644 --- a/src/DataTypes/DataTypesNumber.h +++ b/src/DataTypes/DataTypesNumber.h @@ -63,6 +63,7 @@ extern template class DataTypeNumber; extern template class DataTypeNumber; extern template class DataTypeNumber; extern template class DataTypeNumber; +extern template class DataTypeNumber; extern template class DataTypeNumber; extern template class DataTypeNumber; diff --git a/src/Formats/JSONExtractTree.cpp b/src/Formats/JSONExtractTree.cpp index ae6051823b7..62905a2e630 100644 --- a/src/Formats/JSONExtractTree.cpp +++ b/src/Formats/JSONExtractTree.cpp @@ -131,7 +131,7 @@ bool tryGetNumericValueFromJSONElement( switch (element.type()) { case ElementType::DOUBLE: - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { /// We permit inaccurate conversion of double to float. /// Example: double 0.1 from JSON is not representable in float. @@ -175,7 +175,7 @@ bool tryGetNumericValueFromJSONElement( return false; auto rb = ReadBufferFromMemory{element.getString()}; - if constexpr (std::is_floating_point_v) + if constexpr (is_floating_point) { if (!tryReadFloatText(value, rb) || !rb.eof()) { diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index df239b820af..854b40df441 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -110,6 +110,7 @@ template constexpr bool IsIntegralOrExtendedOrDecimal = IsDataTypeDecimal; template constexpr bool IsFloatingPoint = false; +template <> inline constexpr bool IsFloatingPoint = true; template <> inline constexpr bool IsFloatingPoint = true; template <> inline constexpr bool IsFloatingPoint = true; diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 70ec390b576..1c662dd1d9a 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -2930,6 +2930,7 @@ template <> struct FunctionTo { using Type = FunctionToInt32; }; template <> struct FunctionTo { using Type = FunctionToInt64; }; template <> struct FunctionTo { using Type = FunctionToInt128; }; template <> struct FunctionTo { using Type = FunctionToInt256; }; +//template <> struct FunctionTo { using Type = FunctionToBFloat16; }; template <> struct FunctionTo { using Type = FunctionToFloat32; }; template <> struct FunctionTo { using Type = FunctionToFloat64; }; diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index 809905c692e..255eca5b406 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -694,7 +694,7 @@ public: if (arguments.size() > 1) { const ColumnWithTypeAndName & scale_column = arguments[1]; - res = Dispatcher::template apply(value_arg.column.get(), scale_column.column.get()); + res = Dispatcher::template apply(value_arg.column.get(), scale_column.column.get()); return true; } res = Dispatcher::template apply(value_arg.column.get()); diff --git a/src/IO/readFloatText.cpp b/src/IO/readFloatText.cpp index 17ccc1b25b7..fb3c86fd7b6 100644 --- a/src/IO/readFloatText.cpp +++ b/src/IO/readFloatText.cpp @@ -47,26 +47,35 @@ void assertNaN(ReadBuffer & buf) } +template void readFloatTextPrecise(BFloat16 &, ReadBuffer &); template void readFloatTextPrecise(Float32 &, ReadBuffer &); template void readFloatTextPrecise(Float64 &, ReadBuffer &); +template bool tryReadFloatTextPrecise(BFloat16 &, ReadBuffer &); template bool tryReadFloatTextPrecise(Float32 &, ReadBuffer &); template bool tryReadFloatTextPrecise(Float64 &, ReadBuffer &); +template void readFloatTextFast(BFloat16 &, ReadBuffer &); template void readFloatTextFast(Float32 &, ReadBuffer &); template void readFloatTextFast(Float64 &, ReadBuffer &); +template bool tryReadFloatTextFast(BFloat16 &, ReadBuffer &); template bool tryReadFloatTextFast(Float32 &, ReadBuffer &); template bool tryReadFloatTextFast(Float64 &, ReadBuffer &); +template void readFloatTextSimple(BFloat16 &, ReadBuffer &); template void readFloatTextSimple(Float32 &, ReadBuffer &); template void readFloatTextSimple(Float64 &, ReadBuffer &); +template bool tryReadFloatTextSimple(BFloat16 &, ReadBuffer &); template bool tryReadFloatTextSimple(Float32 &, ReadBuffer &); template bool tryReadFloatTextSimple(Float64 &, ReadBuffer &); +template void readFloatText(BFloat16 &, ReadBuffer &); template void readFloatText(Float32 &, ReadBuffer &); template void readFloatText(Float64 &, ReadBuffer &); +template bool tryReadFloatText(BFloat16 &, ReadBuffer &); template bool tryReadFloatText(Float32 &, ReadBuffer &); template bool tryReadFloatText(Float64 &, ReadBuffer &); +template bool tryReadFloatTextNoExponent(BFloat16 &, ReadBuffer &); template bool tryReadFloatTextNoExponent(Float32 &, ReadBuffer &); template bool tryReadFloatTextNoExponent(Float64 &, ReadBuffer &); diff --git a/src/IO/readFloatText.h b/src/IO/readFloatText.h index c2fec9d4b0b..a7fd6058dd9 100644 --- a/src/IO/readFloatText.h +++ b/src/IO/readFloatText.h @@ -222,7 +222,6 @@ ReturnType readFloatTextPreciseImpl(T & x, ReadBuffer & buf) break; } - char tmp_buf[MAX_LENGTH]; int num_copied_chars = 0; @@ -597,22 +596,85 @@ ReturnType readFloatTextSimpleImpl(T & x, ReadBuffer & buf) return ReturnType(true); } -template void readFloatTextPrecise(T & x, ReadBuffer & in) { readFloatTextPreciseImpl(x, in); } -template bool tryReadFloatTextPrecise(T & x, ReadBuffer & in) { return readFloatTextPreciseImpl(x, in); } +template void readFloatTextPrecise(T & x, ReadBuffer & in) +{ + if constexpr (std::is_same_v) + { + Float32 tmp; + readFloatTextPreciseImpl(tmp, in); + x = BFloat16(tmp); + } + else + readFloatTextPreciseImpl(x, in); +} + +template bool tryReadFloatTextPrecise(T & x, ReadBuffer & in) +{ + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextPreciseImpl(tmp, in); + if (res) + x = BFloat16(tmp); + return res; + } + else + return readFloatTextPreciseImpl(x, in); +} template void readFloatTextFast(T & x, ReadBuffer & in) { bool has_fractional; - readFloatTextFastImpl(x, in, has_fractional); + if constexpr (std::is_same_v) + { + Float32 tmp; + readFloatTextFastImpl(tmp, in, has_fractional); + x = BFloat16(tmp); + } + else + readFloatTextFastImpl(x, in, has_fractional); } + template bool tryReadFloatTextFast(T & x, ReadBuffer & in) { bool has_fractional; - return readFloatTextFastImpl(x, in, has_fractional); + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextFastImpl(tmp, in, has_fractional); + if (res) + x = BFloat16(tmp); + return res; + } + else + return readFloatTextFastImpl(x, in, has_fractional); } -template void readFloatTextSimple(T & x, ReadBuffer & in) { readFloatTextSimpleImpl(x, in); } -template bool tryReadFloatTextSimple(T & x, ReadBuffer & in) { return readFloatTextSimpleImpl(x, in); } +template void readFloatTextSimple(T & x, ReadBuffer & in) +{ + if constexpr (std::is_same_v) + { + Float32 tmp; + readFloatTextSimpleImpl(tmp, in); + x = BFloat16(tmp); + } + else + readFloatTextSimpleImpl(x, in); +} + +template bool tryReadFloatTextSimple(T & x, ReadBuffer & in) +{ + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextSimpleImpl(tmp, in); + if (res) + x = BFloat16(tmp); + return res; + } + else + return readFloatTextSimpleImpl(x, in); +} /// Implementation that is selected as default. @@ -624,18 +686,47 @@ template bool tryReadFloatText(T & x, ReadBuffer & in) { return try template bool tryReadFloatTextNoExponent(T & x, ReadBuffer & in) { bool has_fractional; - return readFloatTextFastImpl(x, in, has_fractional); + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextFastImpl(tmp, in, has_fractional); + if (res) + x = BFloat16(tmp); + return res; + + } + else + return readFloatTextFastImpl(x, in, has_fractional); } /// With a @has_fractional flag /// Used for input_format_try_infer_integers template bool tryReadFloatTextExt(T & x, ReadBuffer & in, bool & has_fractional) { - return readFloatTextFastImpl(x, in, has_fractional); + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextFastImpl(tmp, in, has_fractional); + if (res) + x = BFloat16(tmp); + return res; + } + else + return readFloatTextFastImpl(x, in, has_fractional); } + template bool tryReadFloatTextExtNoExponent(T & x, ReadBuffer & in, bool & has_fractional) { - return readFloatTextFastImpl(x, in, has_fractional); + if constexpr (std::is_same_v) + { + Float32 tmp; + bool res = readFloatTextFastImpl(tmp, in, has_fractional); + if (res) + x = BFloat16(tmp); + return res; + } + else + return readFloatTextFastImpl(x, in, has_fractional); } } diff --git a/src/Processors/Formats/Impl/Parquet/ParquetDataValuesReader.cpp b/src/Processors/Formats/Impl/Parquet/ParquetDataValuesReader.cpp index b471989076b..4b79be98810 100644 --- a/src/Processors/Formats/Impl/Parquet/ParquetDataValuesReader.cpp +++ b/src/Processors/Formats/Impl/Parquet/ParquetDataValuesReader.cpp @@ -580,6 +580,7 @@ template class ParquetPlainValuesReader; template class ParquetPlainValuesReader; template class ParquetPlainValuesReader; template class ParquetPlainValuesReader; +template class ParquetPlainValuesReader; template class ParquetPlainValuesReader; template class ParquetPlainValuesReader; template class ParquetPlainValuesReader>; @@ -602,6 +603,7 @@ template class ParquetRleDictReader; template class ParquetRleDictReader; template class ParquetRleDictReader; template class ParquetRleDictReader; +template class ParquetRleDictReader; template class ParquetRleDictReader; template class ParquetRleDictReader; template class ParquetRleDictReader>; diff --git a/src/Processors/Formats/Impl/Parquet/ParquetLeafColReader.cpp b/src/Processors/Formats/Impl/Parquet/ParquetLeafColReader.cpp index c3c7db510ed..328dd37107e 100644 --- a/src/Processors/Formats/Impl/Parquet/ParquetLeafColReader.cpp +++ b/src/Processors/Formats/Impl/Parquet/ParquetLeafColReader.cpp @@ -644,6 +644,7 @@ template class ParquetLeafColReader; template class ParquetLeafColReader; template class ParquetLeafColReader; template class ParquetLeafColReader; +template class ParquetLeafColReader; template class ParquetLeafColReader; template class ParquetLeafColReader; template class ParquetLeafColReader; From 1da6e1fffa8e5cc40d71fee52d6f2742a59d8f21 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 02:25:29 +0100 Subject: [PATCH 06/35] Conversions --- src/Functions/FunctionsConversion.cpp | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 1c662dd1d9a..f37dff35862 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -7,10 +7,8 @@ #include #include #include -#include #include #include -#include #include #include #include @@ -73,8 +71,10 @@ #include #include + namespace DB { + namespace Setting { extern const SettingsBool cast_ipv4_ipv6_default_on_conversion_error; @@ -1862,6 +1862,11 @@ struct ConvertImpl } } + if constexpr ((std::is_same_v || std::is_same_v) + && !(std::is_same_v || std::is_same_v)) + throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from {} to {} is not supported", + TypeName, TypeName); + if constexpr (std::is_same_v || std::is_same_v) { @@ -2875,6 +2880,7 @@ struct NameToInt32 { static constexpr auto name = "toInt32"; }; struct NameToInt64 { static constexpr auto name = "toInt64"; }; struct NameToInt128 { static constexpr auto name = "toInt128"; }; struct NameToInt256 { static constexpr auto name = "toInt256"; }; +struct NameToBFloat16 { static constexpr auto name = "toBFloat16"; }; struct NameToFloat32 { static constexpr auto name = "toFloat32"; }; struct NameToFloat64 { static constexpr auto name = "toFloat64"; }; struct NameToUUID { static constexpr auto name = "toUUID"; }; @@ -2893,6 +2899,7 @@ using FunctionToInt32 = FunctionConvert>; using FunctionToInt128 = FunctionConvert>; using FunctionToInt256 = FunctionConvert>; +using FunctionToBFloat16 = FunctionConvert>; using FunctionToFloat32 = FunctionConvert>; using FunctionToFloat64 = FunctionConvert>; @@ -2930,7 +2937,7 @@ template <> struct FunctionTo { using Type = FunctionToInt32; }; template <> struct FunctionTo { using Type = FunctionToInt64; }; template <> struct FunctionTo { using Type = FunctionToInt128; }; template <> struct FunctionTo { using Type = FunctionToInt256; }; -//template <> struct FunctionTo { using Type = FunctionToBFloat16; }; +template <> struct FunctionTo { using Type = FunctionToBFloat16; }; template <> struct FunctionTo { using Type = FunctionToFloat32; }; template <> struct FunctionTo { using Type = FunctionToFloat64; }; @@ -2973,6 +2980,7 @@ struct NameToInt32OrZero { static constexpr auto name = "toInt32OrZero"; }; struct NameToInt64OrZero { static constexpr auto name = "toInt64OrZero"; }; struct NameToInt128OrZero { static constexpr auto name = "toInt128OrZero"; }; struct NameToInt256OrZero { static constexpr auto name = "toInt256OrZero"; }; +struct NameToBFloat16OrZero { static constexpr auto name = "toBFloat16OrZero"; }; struct NameToFloat32OrZero { static constexpr auto name = "toFloat32OrZero"; }; struct NameToFloat64OrZero { static constexpr auto name = "toFloat64OrZero"; }; struct NameToDateOrZero { static constexpr auto name = "toDateOrZero"; }; @@ -2999,6 +3007,7 @@ using FunctionToInt32OrZero = FunctionConvertFromString; using FunctionToInt128OrZero = FunctionConvertFromString; using FunctionToInt256OrZero = FunctionConvertFromString; +using FunctionToBFloat16OrZero = FunctionConvertFromString; using FunctionToFloat32OrZero = FunctionConvertFromString; using FunctionToFloat64OrZero = FunctionConvertFromString; using FunctionToDateOrZero = FunctionConvertFromString; @@ -3025,6 +3034,7 @@ struct NameToInt32OrNull { static constexpr auto name = "toInt32OrNull"; }; struct NameToInt64OrNull { static constexpr auto name = "toInt64OrNull"; }; struct NameToInt128OrNull { static constexpr auto name = "toInt128OrNull"; }; struct NameToInt256OrNull { static constexpr auto name = "toInt256OrNull"; }; +struct NameToBFloat16OrNull { static constexpr auto name = "toBFloat16OrNull"; }; struct NameToFloat32OrNull { static constexpr auto name = "toFloat32OrNull"; }; struct NameToFloat64OrNull { static constexpr auto name = "toFloat64OrNull"; }; struct NameToDateOrNull { static constexpr auto name = "toDateOrNull"; }; @@ -3051,6 +3061,7 @@ using FunctionToInt32OrNull = FunctionConvertFromString; using FunctionToInt128OrNull = FunctionConvertFromString; using FunctionToInt256OrNull = FunctionConvertFromString; +using FunctionToBFloat16OrNull = FunctionConvertFromString; using FunctionToFloat32OrNull = FunctionConvertFromString; using FunctionToFloat64OrNull = FunctionConvertFromString; using FunctionToDateOrNull = FunctionConvertFromString; @@ -5194,7 +5205,7 @@ private: if constexpr (is_any_of) { @@ -5447,6 +5458,7 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); @@ -5485,6 +5497,7 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); @@ -5513,6 +5526,7 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); From e65bb147d553b3fcd5f361366547b2858a122247 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 02:27:53 +0100 Subject: [PATCH 07/35] Style --- src/Functions/exp.cpp | 6 ++++++ src/Functions/log.cpp | 5 +++++ src/Functions/sigmoid.cpp | 6 ++++++ src/Functions/tanh.cpp | 6 ++++++ 4 files changed, 23 insertions(+) diff --git a/src/Functions/exp.cpp b/src/Functions/exp.cpp index 07c9288e8ab..24f1d313831 100644 --- a/src/Functions/exp.cpp +++ b/src/Functions/exp.cpp @@ -3,6 +3,12 @@ namespace DB { + +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + namespace { diff --git a/src/Functions/log.cpp b/src/Functions/log.cpp index beaa8128b2b..49fc509634b 100644 --- a/src/Functions/log.cpp +++ b/src/Functions/log.cpp @@ -4,6 +4,11 @@ namespace DB { +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + namespace { diff --git a/src/Functions/sigmoid.cpp b/src/Functions/sigmoid.cpp index 1179329845d..bb9710a15fe 100644 --- a/src/Functions/sigmoid.cpp +++ b/src/Functions/sigmoid.cpp @@ -3,6 +3,12 @@ namespace DB { + +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + namespace { diff --git a/src/Functions/tanh.cpp b/src/Functions/tanh.cpp index 293318f9bbb..d0e1440485b 100644 --- a/src/Functions/tanh.cpp +++ b/src/Functions/tanh.cpp @@ -3,6 +3,12 @@ namespace DB { + +namespace ErrorCodes +{ +extern const int NOT_IMPLEMENTED; +} + namespace { From b4acc885f35e4cccae818fca477efffbc9332ded Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 02:37:26 +0100 Subject: [PATCH 08/35] Documentation --- docs/en/sql-reference/data-types/float.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/en/sql-reference/data-types/float.md b/docs/en/sql-reference/data-types/float.md index 3c789076c1e..7185308bdce 100644 --- a/docs/en/sql-reference/data-types/float.md +++ b/docs/en/sql-reference/data-types/float.md @@ -1,10 +1,10 @@ --- slug: /en/sql-reference/data-types/float sidebar_position: 4 -sidebar_label: Float32, Float64 +sidebar_label: Float32, Float64, BFloat16 --- -# Float32, Float64 +# Float32, Float64, BFloat16 :::note If you need accurate calculations, in particular if you work with financial or business data requiring a high precision, you should consider using [Decimal](../data-types/decimal.md) instead. @@ -117,3 +117,11 @@ SELECT 0 / 0 ``` See the rules for `NaN` sorting in the section [ORDER BY clause](../../sql-reference/statements/select/order-by.md). + +## BFloat16 + +`BFloat16` is a 16-bit floating point data type with 8-bit exponent, sign, and 7-bit mantissa. + +It is useful for machine learning and AI applications. + +ClickHouse supports conversions between `Float32` and `BFloat16`. Most of other operations are not supported. From 6cb083621aece140d08d800620b0e5fe7bdc2da0 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 02:47:59 +0100 Subject: [PATCH 09/35] Documentation --- ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt | 1 + utils/check-style/aspell-ignore/en/aspell-dict.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt index e2966898be2..7cae8509b83 100644 --- a/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt +++ b/ci/jobs/scripts/check_style/aspell-ignore/en/aspell-dict.txt @@ -3131,3 +3131,4 @@ DistributedCachePoolBehaviourOnLimit SharedJoin ShareSet unacked +BFloat diff --git a/utils/check-style/aspell-ignore/en/aspell-dict.txt b/utils/check-style/aspell-ignore/en/aspell-dict.txt index a08143467cd..9765b45c085 100644 --- a/utils/check-style/aspell-ignore/en/aspell-dict.txt +++ b/utils/check-style/aspell-ignore/en/aspell-dict.txt @@ -3154,3 +3154,4 @@ znode znodes zookeeperSessionUptime zstd +BFloat From bec94da77e8333d64c71b4bf778fbf78d10a8519 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 13:19:08 +0100 Subject: [PATCH 10/35] Progressing --- src/DataTypes/DataTypesDecimal.cpp | 8 +++----- src/DataTypes/DataTypesDecimal.h | 2 -- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/DataTypes/DataTypesDecimal.cpp b/src/DataTypes/DataTypesDecimal.cpp index fddae052ada..63bd4bf2a59 100644 --- a/src/DataTypes/DataTypesDecimal.cpp +++ b/src/DataTypes/DataTypesDecimal.cpp @@ -262,9 +262,9 @@ FOR_EACH_ARITHMETIC_TYPE(INVOKE); template requires (is_arithmetic_v && IsDataTypeDecimal) -ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & /*value*/, UInt32 /*scale*/, typename ToDataType::FieldType & /*result*/) +ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value, UInt32 scale, typename ToDataType::FieldType & result) { -/* using FromFieldType = typename FromDataType::FieldType; + using FromFieldType = typename FromDataType::FieldType; using ToFieldType = typename ToDataType::FieldType; using ToNativeType = typename ToFieldType::NativeType; @@ -306,9 +306,7 @@ ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & /*value return ReturnType(convertDecimalsImpl, ToDataType, ReturnType>(static_cast(value), 0, scale, result)); else return ReturnType(convertDecimalsImpl, ToDataType, ReturnType>(static_cast(value), 0, scale, result)); - }*/ - - return ReturnType(); + } } #define DISPATCH(FROM_DATA_TYPE, TO_DATA_TYPE) \ diff --git a/src/DataTypes/DataTypesDecimal.h b/src/DataTypes/DataTypesDecimal.h index e0d49408981..09a25617506 100644 --- a/src/DataTypes/DataTypesDecimal.h +++ b/src/DataTypes/DataTypesDecimal.h @@ -3,9 +3,7 @@ #include #include #include -#include #include -#include #include #include #include From f0dc1330eb9d830161531819432a611a363fdc6b Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 13:53:08 +0100 Subject: [PATCH 11/35] Rounding --- src/Functions/FunctionsRound.h | 42 ++++++++++++++++++++++++++-------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index 255eca5b406..70ad4d17718 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -268,6 +268,19 @@ inline double roundWithMode(double x, RoundingMode mode) std::unreachable(); } +inline BFloat16 roundWithMode(BFloat16 x, RoundingMode mode) +{ + switch (mode) + { + case RoundingMode::Round: return BFloat16(nearbyintf(Float32(x))); + case RoundingMode::Floor: return BFloat16(floorf(Float32(x))); + case RoundingMode::Ceil: return BFloat16(ceilf(Float32(x))); + case RoundingMode::Trunc: return BFloat16(truncf(Float32(x))); + } + + std::unreachable(); +} + template class FloatRoundingComputationBase { @@ -289,6 +302,11 @@ public: } }; +template <> +class FloatRoundingComputationBase : public FloatRoundingComputationBase +{ +}; + /** Implementation of low-level round-off functions for floating-point values. */ @@ -688,20 +706,26 @@ public: using Types = std::decay_t; using DataType = typename Types::RightType; - if constexpr ((IsDataTypeNumber || IsDataTypeDecimal) - && !std::is_same_v) + if (arguments.size() > 1) { - if (arguments.size() > 1) + const ColumnWithTypeAndName & scale_column = arguments[1]; + + auto call_scale = [&](const auto & scaleTypes) -> bool { - const ColumnWithTypeAndName & scale_column = arguments[1]; - res = Dispatcher::template apply(value_arg.column.get(), scale_column.column.get()); + using ScaleTypes = std::decay_t; + using ScaleType = typename ScaleTypes::RightType; + + res = Dispatcher::template apply(value_arg.column.get(), scale_column.column.get()); return true; - } - res = Dispatcher::template apply(value_arg.column.get()); + }; + + TypeIndex right_index = scale_column.type->getTypeId(); + if (!callOnBasicType(right_index, call_scale)) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type"); return true; } - else - return false; + res = Dispatcher::template apply(value_arg.column.get()); + return true; }; #if !defined(__SSE4_1__) From db98fb4c79252d6305eabc06a749e2082bb1c489 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 14:39:45 +0100 Subject: [PATCH 12/35] Documentation --- src/Functions/FunctionsConversion.cpp | 64 +++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 3 deletions(-) diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index f37dff35862..37a4ba30d30 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -5458,7 +5458,17 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); - factory.registerFunction(); + + factory.registerFunction(FunctionDocumentation{.description=R"( +Converts Float32 to BFloat16 with losing the precision. + +Example: +[example:typical] +)", + .examples{ + {"typical", "SELECT toBFloat16(12.3::Float32);", "12.3125"}}, + .categories{"Conversion"}}); + factory.registerFunction(); factory.registerFunction(); @@ -5497,7 +5507,31 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); - factory.registerFunction(); + + factory.registerFunction(FunctionDocumentation{.description=R"( +Converts String to BFloat16. + +If the string does not represent a floating point value, the function returns zero. + +The function allows a silent loss of precision while converting from the string representation. In that case, it will return the truncated result. + +Example of successful conversion: +[example:typical] + +Examples of not successful conversion: +[example:invalid1] +[example:invalid2] + +Example of a loss of precision: +[example:precision] +)", + .examples{ + {"typical", "SELECT toBFloat16OrZero('12.3');", "12.3125"}}, + {"invalid1", "SELECT toBFloat16OrZero('abc');", "0"}}, + {"invalid2", "SELECT toBFloat16OrZero(' 1');", "0"}}, + {"precision", "SELECT toBFloat16OrZero('12.3456789');", "12.375"}}, + .categories{"Conversion"}}); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); @@ -5526,7 +5560,31 @@ REGISTER_FUNCTION(Conversion) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); - factory.registerFunction(); + + factory.registerFunction(FunctionDocumentation{.description=R"( +Converts String to Nullable(BFloat16). + +If the string does not represent a floating point value, the function returns NULL. + +The function allows a silent loss of precision while converting from the string representation. In that case, it will return the truncated result. + +Example of successful conversion: +[example:typical] + +Examples of not successful conversion: +[example:invalid1] +[example:invalid2] + +Example of a loss of precision: +[example:precision] +)", + .examples{ + {"typical", "SELECT toBFloat16OrNull('12.3');", "12.3125"}}, + {"invalid1", "SELECT toBFloat16OrNull('abc');", "NULL"}}, + {"invalid2", "SELECT toBFloat16OrNull(' 1');", "NULL"}}, + {"precision", "SELECT toBFloat16OrNull('12.3456789');", "12.375"}}, + .categories{"Conversion"}}); + factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); From 1c85a0401fbbddccbd3e310a965ce0eb67079a2b Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 15:14:17 +0100 Subject: [PATCH 13/35] Documentation --- src/Functions/FunctionsConversion.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 37a4ba30d30..7f4ccc338cf 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -5526,9 +5526,9 @@ Example of a loss of precision: [example:precision] )", .examples{ - {"typical", "SELECT toBFloat16OrZero('12.3');", "12.3125"}}, - {"invalid1", "SELECT toBFloat16OrZero('abc');", "0"}}, - {"invalid2", "SELECT toBFloat16OrZero(' 1');", "0"}}, + {"typical", "SELECT toBFloat16OrZero('12.3');", "12.3125"}, + {"invalid1", "SELECT toBFloat16OrZero('abc');", "0"}, + {"invalid2", "SELECT toBFloat16OrZero(' 1');", "0"}, {"precision", "SELECT toBFloat16OrZero('12.3456789');", "12.375"}}, .categories{"Conversion"}}); @@ -5579,9 +5579,9 @@ Example of a loss of precision: [example:precision] )", .examples{ - {"typical", "SELECT toBFloat16OrNull('12.3');", "12.3125"}}, - {"invalid1", "SELECT toBFloat16OrNull('abc');", "NULL"}}, - {"invalid2", "SELECT toBFloat16OrNull(' 1');", "NULL"}}, + {"typical", "SELECT toBFloat16OrNull('12.3');", "12.3125"}, + {"invalid1", "SELECT toBFloat16OrNull('abc');", "NULL"}, + {"invalid2", "SELECT toBFloat16OrNull(' 1');", "NULL"}, {"precision", "SELECT toBFloat16OrNull('12.3456789');", "12.375"}}, .categories{"Conversion"}}); From bf8fc60bacbb95e12760b00960115e2a6230c280 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 16:20:44 +0100 Subject: [PATCH 14/35] Arithmetic --- src/Functions/FunctionBinaryArithmetic.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index 854b40df441..43140427170 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -804,7 +804,7 @@ class FunctionBinaryArithmetic : public IFunction DataTypeFixedString, DataTypeString, DataTypeInterval>; - using Floats = TypeList; + using Floats = TypeList; using ValidTypes = std::conditional_t, @@ -2043,7 +2043,15 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A using DecimalResultType = typename BinaryOperationTraits::DecimalResultDataType; if constexpr (std::is_same_v) + { return nullptr; + } + else if constexpr ((std::is_same_v || std::is_same_v) + && (sizeof(typename LeftDataType::FieldType) > 8 || sizeof(typename RightDataType::FieldType) > 8)) + { + /// Big integers and BFloat16 are not supported together. + return nullptr; + } else // we can't avoid the else because otherwise the compiler may assume the ResultDataType may be Invalid // and that would produce the compile error. { @@ -2060,7 +2068,7 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A ColumnPtr left_col = nullptr; ColumnPtr right_col = nullptr; - /// When Decimal op Float32/64, convert both of them into Float64 + /// When Decimal op Float32/64/16, convert both of them into Float64 if constexpr (decimal_with_float) { const auto converted_type = std::make_shared(); @@ -2095,7 +2103,6 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A /// Here we check if we have `intDiv` or `intDivOrZero` and at least one of the arguments is decimal, because in this case originally we had result as decimal, so we need to convert result into integer after calculations else if constexpr (!decimal_with_float && (is_int_div || is_int_div_or_zero) && (IsDataTypeDecimal || IsDataTypeDecimal)) { - if constexpr (!std::is_same_v) { DataTypePtr type_res; From 62c94a784158274e28cf05136cf4023de47f4f01 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 16:40:24 +0100 Subject: [PATCH 15/35] Maybe better --- cmake/cpu_features.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/cpu_features.cmake b/cmake/cpu_features.cmake index 2bb6deb4847..dbc77d835be 100644 --- a/cmake/cpu_features.cmake +++ b/cmake/cpu_features.cmake @@ -85,7 +85,7 @@ elseif (ARCH_AARCH64) # [8] https://developer.arm.com/documentation/102651/a/What-are-dot-product-intructions- # [9] https://developer.arm.com/documentation/dui0801/g/A64-Data-Transfer-Instructions/LDAPR?lang=en # [10] https://github.com/aws/aws-graviton-getting-started/blob/main/README.md - set (COMPILER_FLAGS "${COMPILER_FLAGS} -march=armv8.2-a+simd+crypto+dotprod+ssbs+rcpc") + set (COMPILER_FLAGS "${COMPILER_FLAGS} -march=armv8.2-a+simd+crypto+dotprod+ssbs+rcpc+bf16") endif () # Best-effort check: The build generates and executes intermediate binaries, e.g. protoc and llvm-tablegen. If we build on ARM for ARM From 08e6e598f7c140d0be39a64d933521872716ed2c Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 17:41:37 +0100 Subject: [PATCH 16/35] Better code --- src/Common/findExtreme.h | 2 +- src/DataTypes/IDataType.h | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/Common/findExtreme.h b/src/Common/findExtreme.h index c2b31c51e87..68e7360d6e2 100644 --- a/src/Common/findExtreme.h +++ b/src/Common/findExtreme.h @@ -11,7 +11,7 @@ namespace DB { template -concept has_find_extreme_implementation = (is_any_of); +concept has_find_extreme_implementation = (is_any_of); template std::optional findExtremeMin(const T * __restrict ptr, size_t start, size_t end); diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index 4d64b927d83..1e41d6b2eba 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -606,7 +606,6 @@ template inline constexpr bool IsDataTypeEnum> = tr M(Int16) \ M(Int32) \ M(Int64) \ - M(BFloat16) \ M(Float32) \ M(Float64) From 7877d59ff6e7334cde310b2eec626bc6ba7442fe Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 22:13:22 +0100 Subject: [PATCH 17/35] Manual implementation --- base/base/BFloat16.h | 300 +++++++++++++++++- base/base/DecomposedFloat.h | 2 +- base/base/TypeLists.h | 5 +- src/AggregateFunctions/AggregateFunctionAvg.h | 2 +- .../AggregateFunctionDeltaSum.cpp | 6 +- .../AggregateFunctionDeltaSumTimestamp.cpp | 10 +- .../AggregateFunctionMaxIntersections.cpp | 6 +- .../AggregateFunctionSparkbar.cpp | 12 +- src/AggregateFunctions/AggregateFunctionSum.h | 11 +- src/Core/Types_fwd.h | 2 +- src/Functions/FunctionsRound.h | 2 +- src/Functions/PolygonUtils.h | 4 +- src/Functions/divide.cpp | 2 +- src/IO/WriteHelpers.h | 2 +- 14 files changed, 318 insertions(+), 48 deletions(-) diff --git a/base/base/BFloat16.h b/base/base/BFloat16.h index 99eab5c67cb..9c6196d6aab 100644 --- a/base/base/BFloat16.h +++ b/base/base/BFloat16.h @@ -1,22 +1,294 @@ #pragma once -#include +#include +#include -using BFloat16 = __bf16; +//using BFloat16 = __bf16; + +class BFloat16 +{ +private: + UInt16 x = 0; + +public: + constexpr BFloat16() = default; + constexpr BFloat16(const BFloat16 & other) = default; + constexpr BFloat16 & operator=(const BFloat16 & other) = default; + + explicit constexpr BFloat16(const Float32 & other) + { + x = static_cast(std::bit_cast(other) >> 16); + } + + template + explicit constexpr BFloat16(const T & other) + : BFloat16(Float32(other)) + { + } + + template + constexpr BFloat16 & operator=(const T & other) + { + *this = BFloat16(other); + return *this; + } + + explicit constexpr operator Float32() const + { + return std::bit_cast(static_cast(x) << 16); + } + + template + explicit constexpr operator T() const + { + return T(Float32(*this)); + } + + constexpr bool isFinite() const + { + return (x & 0b0111111110000000) != 0b0111111110000000; + } + + constexpr bool isNaN() const + { + return !isFinite() && (x & 0b0000000001111111) != 0b0000000000000000; + } + + constexpr bool signBit() const + { + return x & 0b1000000000000000; + } + + constexpr bool operator==(const BFloat16 & other) const + { + return x == other.x; + } + + constexpr bool operator!=(const BFloat16 & other) const + { + return x != other.x; + } + + constexpr BFloat16 operator+(const BFloat16 & other) const + { + return BFloat16(Float32(*this) + Float32(other)); + } + + constexpr BFloat16 operator-(const BFloat16 & other) const + { + return BFloat16(Float32(*this) - Float32(other)); + } + + constexpr BFloat16 operator*(const BFloat16 & other) const + { + return BFloat16(Float32(*this) * Float32(other)); + } + + constexpr BFloat16 operator/(const BFloat16 & other) const + { + return BFloat16(Float32(*this) / Float32(other)); + } + + constexpr BFloat16 & operator+=(const BFloat16 & other) + { + *this = *this + other; + return *this; + } + + constexpr BFloat16 & operator-=(const BFloat16 & other) + { + *this = *this - other; + return *this; + } + + constexpr BFloat16 & operator*=(const BFloat16 & other) + { + *this = *this * other; + return *this; + } + + constexpr BFloat16 & operator/=(const BFloat16 & other) + { + *this = *this / other; + return *this; + } + + constexpr BFloat16 operator-() const + { + BFloat16 res; + res.x = x ^ 0b1000000000000000; + return res; + } +}; + + +template +requires(!std::is_same_v) +constexpr bool operator==(const BFloat16 & a, const T & b) +{ + return Float32(a) == b; +} + +template +requires(!std::is_same_v) +constexpr bool operator==(const T & a, const BFloat16 & b) +{ + return a == Float32(b); +} + +template +requires(!std::is_same_v) +constexpr bool operator!=(const BFloat16 & a, const T & b) +{ + return Float32(a) != b; +} + +template +requires(!std::is_same_v) +constexpr bool operator!=(const T & a, const BFloat16 & b) +{ + return a != Float32(b); +} + +template +requires(!std::is_same_v) +constexpr bool operator<(const BFloat16 & a, const T & b) +{ + return Float32(a) < b; +} + +template +requires(!std::is_same_v) +constexpr bool operator<(const T & a, const BFloat16 & b) +{ + return a < Float32(b); +} + +constexpr inline bool operator<(BFloat16 a, BFloat16 b) +{ + return Float32(a) < Float32(b); +} + +template +requires(!std::is_same_v) +constexpr bool operator>(const BFloat16 & a, const T & b) +{ + return Float32(a) > b; +} + +template +requires(!std::is_same_v) +constexpr bool operator>(const T & a, const BFloat16 & b) +{ + return a > Float32(b); +} + +constexpr inline bool operator>(BFloat16 a, BFloat16 b) +{ + return Float32(a) > Float32(b); +} + + +template +requires(!std::is_same_v) +constexpr bool operator<=(const BFloat16 & a, const T & b) +{ + return Float32(a) <= b; +} + +template +requires(!std::is_same_v) +constexpr bool operator<=(const T & a, const BFloat16 & b) +{ + return a <= Float32(b); +} + +constexpr inline bool operator<=(BFloat16 a, BFloat16 b) +{ + return Float32(a) <= Float32(b); +} + +template +requires(!std::is_same_v) +constexpr bool operator>=(const BFloat16 & a, const T & b) +{ + return Float32(a) >= b; +} + +template +requires(!std::is_same_v) +constexpr bool operator>=(const T & a, const BFloat16 & b) +{ + return a >= Float32(b); +} + +constexpr inline bool operator>=(BFloat16 a, BFloat16 b) +{ + return Float32(a) >= Float32(b); +} + + +template +requires(!std::is_same_v) +constexpr inline auto operator+(T a, BFloat16 b) +{ + return a + Float32(b); +} + +template +requires(!std::is_same_v) +constexpr inline auto operator+(BFloat16 a, T b) +{ + return Float32(a) + b; +} + +template +requires(!std::is_same_v) +constexpr inline auto operator-(T a, BFloat16 b) +{ + return a - Float32(b); +} + +template +requires(!std::is_same_v) +constexpr inline auto operator-(BFloat16 a, T b) +{ + return Float32(a) - b; +} + +template +requires(!std::is_same_v) +constexpr inline auto operator*(T a, BFloat16 b) +{ + return a * Float32(b); +} + +template +requires(!std::is_same_v) +constexpr inline auto operator*(BFloat16 a, T b) +{ + return Float32(a) * b; +} + +template +requires(!std::is_same_v) +constexpr inline auto operator/(T a, BFloat16 b) +{ + return a / Float32(b); +} + +template +requires(!std::is_same_v) +constexpr inline auto operator/(BFloat16 a, T b) +{ + return Float32(a) / b; +} + namespace std { - inline constexpr bool isfinite(BFloat16 x) { return (bit_cast(x) & 0b0111111110000000) != 0b0111111110000000; } - inline constexpr bool signbit(BFloat16 x) { return bit_cast(x) & 0b1000000000000000; } -} - -inline Float32 BFloat16ToFloat32(BFloat16 x) -{ - return bit_cast(static_cast(bit_cast(x)) << 16); -} - -inline BFloat16 Float32ToBFloat16(Float32 x) -{ - return bit_cast(std::bit_cast(x) >> 16); + inline constexpr bool isfinite(BFloat16 x) { return x.isFinite(); } + inline constexpr bool isnan(BFloat16 x) { return x.isNaN(); } + inline constexpr bool signbit(BFloat16 x) { return x.signBit(); } } diff --git a/base/base/DecomposedFloat.h b/base/base/DecomposedFloat.h index 26a929b4997..3bd059cb21c 100644 --- a/base/base/DecomposedFloat.h +++ b/base/base/DecomposedFloat.h @@ -11,7 +11,7 @@ template struct FloatTraits; template <> -struct FloatTraits<__bf16> +struct FloatTraits { using UInt = uint16_t; static constexpr size_t bits = 16; diff --git a/base/base/TypeLists.h b/base/base/TypeLists.h index ce3111b1da3..375ea94b5ea 100644 --- a/base/base/TypeLists.h +++ b/base/base/TypeLists.h @@ -9,10 +9,11 @@ namespace DB { using TypeListNativeInt = TypeList; -using TypeListFloat = TypeList; -using TypeListNativeNumber = TypeListConcat; +using TypeListNativeFloat = TypeList; +using TypeListNativeNumber = TypeListConcat; using TypeListWideInt = TypeList; using TypeListInt = TypeListConcat; +using TypeListFloat = TypeListConcat>; using TypeListIntAndFloat = TypeListConcat; using TypeListDecimal = TypeList; using TypeListNumber = TypeListConcat; diff --git a/src/AggregateFunctions/AggregateFunctionAvg.h b/src/AggregateFunctions/AggregateFunctionAvg.h index 6e1e9289565..8d53a081ee0 100644 --- a/src/AggregateFunctions/AggregateFunctionAvg.h +++ b/src/AggregateFunctions/AggregateFunctionAvg.h @@ -231,7 +231,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final { - increment(place, static_cast(*columns[0]).getData()[row_num]); + increment(place, Numerator(static_cast(*columns[0]).getData()[row_num])); ++this->data(place).denominator; } diff --git a/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp b/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp index 42169c34c25..c61b9918a35 100644 --- a/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp +++ b/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp @@ -27,9 +27,9 @@ namespace template struct AggregationFunctionDeltaSumData { - T sum = 0; - T last = 0; - T first = 0; + T sum{}; + T last{}; + T first{}; bool seen = false; }; diff --git a/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp b/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp index 5819c533fd9..dc1adead87c 100644 --- a/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp +++ b/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp @@ -25,11 +25,11 @@ namespace template struct AggregationFunctionDeltaSumTimestampData { - ValueType sum = 0; - ValueType first = 0; - ValueType last = 0; - TimestampType first_ts = 0; - TimestampType last_ts = 0; + ValueType sum{}; + ValueType first{}; + ValueType last{}; + TimestampType first_ts{}; + TimestampType last_ts{}; bool seen = false; }; diff --git a/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp b/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp index ca91f960dab..f4edec7f528 100644 --- a/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp +++ b/src/AggregateFunctions/AggregateFunctionMaxIntersections.cpp @@ -155,9 +155,9 @@ public: void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - Int64 current_intersections = 0; - Int64 max_intersections = 0; - PointType position_of_max_intersections = 0; + Int64 current_intersections{}; + Int64 max_intersections{}; + PointType position_of_max_intersections{}; /// const_cast because we will sort the array auto & array = this->data(place).value; diff --git a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp index 33412d50b21..de2a741e105 100644 --- a/src/AggregateFunctions/AggregateFunctionSparkbar.cpp +++ b/src/AggregateFunctions/AggregateFunctionSparkbar.cpp @@ -45,7 +45,7 @@ struct AggregateFunctionSparkbarData Y insert(const X & x, const Y & y) { if (isNaN(y) || y <= 0) - return 0; + return {}; auto [it, inserted] = points.insert({x, y}); if (!inserted) @@ -173,13 +173,13 @@ private: if (from_x >= to_x) { - size_t sz = updateFrame(values, 8); + size_t sz = updateFrame(values, Y{8}); values.push_back('\0'); offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1); return; } - PaddedPODArray histogram(width, 0); + PaddedPODArray histogram(width, Y{0}); PaddedPODArray count_histogram(width, 0); /// The number of points in each bucket for (const auto & point : data.points) @@ -218,10 +218,10 @@ private: for (size_t i = 0; i < histogram.size(); ++i) { if (count_histogram[i] > 0) - histogram[i] /= count_histogram[i]; + histogram[i] = histogram[i] / count_histogram[i]; } - Y y_max = 0; + Y y_max{}; for (auto & y : histogram) { if (isNaN(y) || y <= 0) @@ -245,7 +245,7 @@ private: continue; } - constexpr auto levels_num = static_cast(BAR_LEVELS - 1); + constexpr auto levels_num = Y{BAR_LEVELS - 1}; if constexpr (is_floating_point) { y = y / (y_max / levels_num) + 1; diff --git a/src/AggregateFunctions/AggregateFunctionSum.h b/src/AggregateFunctions/AggregateFunctionSum.h index f6c51241a5c..7c7fb6338a2 100644 --- a/src/AggregateFunctions/AggregateFunctionSum.h +++ b/src/AggregateFunctions/AggregateFunctionSum.h @@ -83,7 +83,7 @@ struct AggregateFunctionSumData while (ptr < unrolled_end) { for (size_t i = 0; i < unroll_count; ++i) - Impl::add(partial_sums[i], ptr[i]); + Impl::add(partial_sums[i], T(ptr[i])); ptr += unroll_count; } @@ -95,7 +95,7 @@ struct AggregateFunctionSumData T local_sum{}; while (ptr < end_ptr) { - Impl::add(local_sum, *ptr); + Impl::add(local_sum, T(*ptr)); ++ptr; } Impl::add(sum, local_sum); @@ -227,7 +227,7 @@ struct AggregateFunctionSumData while (ptr < end_ptr) { if (!*condition_map == add_if_zero) - Impl::add(local_sum, *ptr); + Impl::add(local_sum, T(*ptr)); ++ptr; ++condition_map; } @@ -488,10 +488,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { const auto & column = assert_cast(*columns[0]); - if constexpr (is_big_int_v) - this->data(place).add(static_cast(column.getData()[row_num])); - else - this->data(place).add(column.getData()[row_num]); + this->data(place).add(static_cast(column.getData()[row_num])); } void addBatchSinglePlace( diff --git a/src/Core/Types_fwd.h b/src/Core/Types_fwd.h index 6d3383ae7ff..b94a29ce72c 100644 --- a/src/Core/Types_fwd.h +++ b/src/Core/Types_fwd.h @@ -21,7 +21,7 @@ using Int128 = wide::integer<128, signed>; using UInt128 = wide::integer<128, unsigned>; using Int256 = wide::integer<256, signed>; using UInt256 = wide::integer<256, unsigned>; -using BFloat16 = __bf16; +class BFloat16; namespace DB { diff --git a/src/Functions/FunctionsRound.h b/src/Functions/FunctionsRound.h index 70ad4d17718..6c9cc8a37b3 100644 --- a/src/Functions/FunctionsRound.h +++ b/src/Functions/FunctionsRound.h @@ -298,7 +298,7 @@ public: static VectorType prepare(size_t scale) { - return load1(scale); + return load1(ScalarType(scale)); } }; diff --git a/src/Functions/PolygonUtils.h b/src/Functions/PolygonUtils.h index bf8241774a6..601ffcb00b4 100644 --- a/src/Functions/PolygonUtils.h +++ b/src/Functions/PolygonUtils.h @@ -583,7 +583,7 @@ struct CallPointInPolygon template static ColumnPtr call(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) { - using Impl = TypeListChangeRoot; + using Impl = TypeListChangeRoot; if (auto column = typeid_cast *>(&x)) return Impl::template call(*column, y, impl); return CallPointInPolygon::call(x, y, impl); @@ -609,7 +609,7 @@ struct CallPointInPolygon<> template NO_INLINE ColumnPtr pointInPolygon(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) { - using Impl = TypeListChangeRoot; + using Impl = TypeListChangeRoot; return Impl::call(x, y, impl); } diff --git a/src/Functions/divide.cpp b/src/Functions/divide.cpp index 7c67245c382..3947ba2d142 100644 --- a/src/Functions/divide.cpp +++ b/src/Functions/divide.cpp @@ -18,7 +18,7 @@ struct DivideFloatingImpl template static NO_SANITIZE_UNDEFINED Result apply(A a [[maybe_unused]], B b [[maybe_unused]]) { - return static_cast(a) / b; + return static_cast(a) / static_cast(b); } #if USE_EMBEDDED_COMPILER diff --git a/src/IO/WriteHelpers.h b/src/IO/WriteHelpers.h index f01e09e3f73..0a32c4c5446 100644 --- a/src/IO/WriteHelpers.h +++ b/src/IO/WriteHelpers.h @@ -174,7 +174,7 @@ inline size_t writeFloatTextFastPath(T x, char * buffer) } else if constexpr (std::is_same_v) { - Float32 f32 = BFloat16ToFloat32(x); + Float32 f32 = Float32(x); if (DecomposedFloat32(f32).isIntegerInRepresentableRange()) result = itoa(Int32(f32), buffer) - buffer; From 16d05bbc6d9a1369b393f836d0ccd8ea64fe2057 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 22:41:40 +0100 Subject: [PATCH 18/35] Comparisons --- base/base/BFloat16.h | 22 +++++++++++++++++++++- src/Functions/FunctionsComparison.h | 7 +++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/base/base/BFloat16.h b/base/base/BFloat16.h index 9c6196d6aab..f7491b64eb3 100644 --- a/base/base/BFloat16.h +++ b/base/base/BFloat16.h @@ -4,8 +4,28 @@ #include -//using BFloat16 = __bf16; +/** BFloat16 is a 16-bit floating point type, which has the same number (8) of exponent bits as Float32. + * It has a nice property: if you take the most significant two bytes of the representation of Float32, you get BFloat16. + * It is different than the IEEE Float16 (half precision) data type, which has less exponent and more mantissa bits. + * + * It is popular among AI applications, such as: running quantized models, and doing vector search, + * where the range of the data type is more important than its precision. + * + * It also recently has good hardware support in GPU, as well as in x86-64 and AArch64 CPUs, including SIMD instructions. + * But it is rarely utilized by compilers. + * + * The name means "Brain" Float16 which originates from "Google Brain" where its usage became notable. + * It is also known under the name "bf16". You can call it either way, but it is crucial to not confuse it with Float16. + * Here is a manual implementation of this data type. Only required operations are implemented. + * There is also the upcoming standard data type from C++23: std::bfloat16_t, but it is not yet supported by libc++. + * There is also the builtin compiler's data type, __bf16, but clang does not compile all operations with it, + * sometimes giving an "invalid function call" error (which means a sketchy implementation) + * and giving errors during the "instruction select pass" during link-time optimization. + * + * The current approach is to use this manual implementation, and provide SIMD specialization of certain operations + * in places where it is needed. + */ class BFloat16 { private: diff --git a/src/Functions/FunctionsComparison.h b/src/Functions/FunctionsComparison.h index be0875581a5..bcb9e0641b8 100644 --- a/src/Functions/FunctionsComparison.h +++ b/src/Functions/FunctionsComparison.h @@ -721,6 +721,7 @@ private: || (res = executeNumRightType(col_left, col_right_untyped)) || (res = executeNumRightType(col_left, col_right_untyped)) || (res = executeNumRightType(col_left, col_right_untyped)) + || (res = executeNumRightType(col_left, col_right_untyped)) || (res = executeNumRightType(col_left, col_right_untyped)) || (res = executeNumRightType(col_left, col_right_untyped))) return res; @@ -741,6 +742,7 @@ private: || (res = executeNumConstRightType(col_left_const, col_right_untyped)) || (res = executeNumConstRightType(col_left_const, col_right_untyped)) || (res = executeNumConstRightType(col_left_const, col_right_untyped)) + || (res = executeNumConstRightType(col_left_const, col_right_untyped)) || (res = executeNumConstRightType(col_left_const, col_right_untyped)) || (res = executeNumConstRightType(col_left_const, col_right_untyped))) return res; @@ -1289,9 +1291,10 @@ public: || (res = executeNumLeftType(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType(col_left_untyped, col_right_untyped)) + || (res = executeNumLeftType(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType(col_left_untyped, col_right_untyped)))) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}", + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of the first argument of function {}", col_left_untyped->getName(), getName()); return res; @@ -1339,7 +1342,7 @@ public: getName(), left_type->getName(), right_type->getName()); - /// When Decimal comparing to Float32/64, we convert both of them into Float64. + /// When Decimal comparing to Float32/64/16, we convert both of them into Float64. /// Other systems like MySQL and Spark also do as this. if (left_is_float || right_is_float) { From 92e8fa23ba0073f2caa43d66bab5d99475d3c656 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 23:43:10 +0100 Subject: [PATCH 19/35] Remove obsolete setting from tests --- src/Databases/enableAllExperimentalSettings.cpp | 1 - tests/performance/avg_weighted.xml | 1 - tests/performance/reinterpret_as.xml | 1 - tests/queries/0_stateless/01035_avg.sql | 2 -- .../0_stateless/01182_materialized_view_different_structure.sql | 1 - tests/queries/0_stateless/01440_big_int_exotic_casts.sql | 2 -- .../0_stateless/01554_bloom_filter_index_big_integer_uuid.sql | 2 -- tests/queries/0_stateless/01622_byte_size.sql | 2 -- tests/queries/0_stateless/01721_dictionary_decimal_p_s.sql | 2 -- tests/queries/0_stateless/01804_dictionary_decimal256_type.sql | 2 -- .../0_stateless/01875_ssd_cache_dictionary_decimal256_type.sh | 2 -- 11 files changed, 18 deletions(-) diff --git a/src/Databases/enableAllExperimentalSettings.cpp b/src/Databases/enableAllExperimentalSettings.cpp index d51d2671992..bc2dae55f97 100644 --- a/src/Databases/enableAllExperimentalSettings.cpp +++ b/src/Databases/enableAllExperimentalSettings.cpp @@ -24,7 +24,6 @@ void enableAllExperimentalSettings(ContextMutablePtr context) context->setSetting("allow_experimental_dynamic_type", 1); context->setSetting("allow_experimental_json_type", 1); context->setSetting("allow_experimental_vector_similarity_index", 1); - context->setSetting("allow_experimental_bigint_types", 1); context->setSetting("allow_experimental_window_functions", 1); context->setSetting("allow_experimental_geo_types", 1); context->setSetting("allow_experimental_map_type", 1); diff --git a/tests/performance/avg_weighted.xml b/tests/performance/avg_weighted.xml index edf3c19fdfa..ec1b7aae5c2 100644 --- a/tests/performance/avg_weighted.xml +++ b/tests/performance/avg_weighted.xml @@ -1,6 +1,5 @@ - 1 1 8 diff --git a/tests/performance/reinterpret_as.xml b/tests/performance/reinterpret_as.xml index d05ef3bb038..2e0fa0571c3 100644 --- a/tests/performance/reinterpret_as.xml +++ b/tests/performance/reinterpret_as.xml @@ -1,6 +1,5 @@ - 1 15G diff --git a/tests/queries/0_stateless/01035_avg.sql b/tests/queries/0_stateless/01035_avg.sql index a3cb35a80ec..0f7baddaec5 100644 --- a/tests/queries/0_stateless/01035_avg.sql +++ b/tests/queries/0_stateless/01035_avg.sql @@ -1,5 +1,3 @@ -SET allow_experimental_bigint_types=1; - CREATE TABLE IF NOT EXISTS test_01035_avg ( i8 Int8 DEFAULT i64, i16 Int16 DEFAULT i64, diff --git a/tests/queries/0_stateless/01182_materialized_view_different_structure.sql b/tests/queries/0_stateless/01182_materialized_view_different_structure.sql index 485f9985974..7e41172bd0c 100644 --- a/tests/queries/0_stateless/01182_materialized_view_different_structure.sql +++ b/tests/queries/0_stateless/01182_materialized_view_different_structure.sql @@ -20,7 +20,6 @@ SELECT sum(value) FROM (SELECT number, sum(number) AS value FROM (SELECT *, toDe CREATE TABLE src (n UInt64, s FixedString(16)) ENGINE=Memory; CREATE TABLE dst (n UInt8, s String) ENGINE = Memory; CREATE MATERIALIZED VIEW mv TO dst (n String) AS SELECT * FROM src; -SET allow_experimental_bigint_types=1; CREATE TABLE dist (n Int128) ENGINE=Distributed(test_cluster_two_shards, currentDatabase(), mv); INSERT INTO src SELECT number, toString(number) FROM numbers(1000); diff --git a/tests/queries/0_stateless/01440_big_int_exotic_casts.sql b/tests/queries/0_stateless/01440_big_int_exotic_casts.sql index 42fde9da01b..f411af897e8 100644 --- a/tests/queries/0_stateless/01440_big_int_exotic_casts.sql +++ b/tests/queries/0_stateless/01440_big_int_exotic_casts.sql @@ -32,8 +32,6 @@ SELECT number y, toInt128(number) - y, toInt256(number) - y, toUInt256(number) - SELECT -number y, toInt128(number) + y, toInt256(number) + y, toUInt256(number) + y FROM numbers_mt(10) ORDER BY number; -SET allow_experimental_bigint_types = 1; - DROP TABLE IF EXISTS t; CREATE TABLE t (x UInt64, i256 Int256, u256 UInt256, d256 Decimal256(2)) ENGINE = Memory; diff --git a/tests/queries/0_stateless/01554_bloom_filter_index_big_integer_uuid.sql b/tests/queries/0_stateless/01554_bloom_filter_index_big_integer_uuid.sql index 3472f41092d..f82fe39f439 100644 --- a/tests/queries/0_stateless/01554_bloom_filter_index_big_integer_uuid.sql +++ b/tests/queries/0_stateless/01554_bloom_filter_index_big_integer_uuid.sql @@ -1,5 +1,3 @@ -SET allow_experimental_bigint_types = 1; - CREATE TABLE 01154_test (x Int128, INDEX ix_x x TYPE bloom_filter(0.01) GRANULARITY 1) ENGINE = MergeTree() ORDER BY x SETTINGS index_granularity=8192; INSERT INTO 01154_test VALUES (1), (2), (3); SELECT x FROM 01154_test WHERE x = 1; diff --git a/tests/queries/0_stateless/01622_byte_size.sql b/tests/queries/0_stateless/01622_byte_size.sql index 9f9de4e58e9..f73011f4151 100644 --- a/tests/queries/0_stateless/01622_byte_size.sql +++ b/tests/queries/0_stateless/01622_byte_size.sql @@ -4,8 +4,6 @@ select ''; select '# byteSize'; -set allow_experimental_bigint_types = 1; - -- numbers #0 -- select ''; select 'byteSize for numbers #0'; diff --git a/tests/queries/0_stateless/01721_dictionary_decimal_p_s.sql b/tests/queries/0_stateless/01721_dictionary_decimal_p_s.sql index 272bd2d7104..57483430cc0 100644 --- a/tests/queries/0_stateless/01721_dictionary_decimal_p_s.sql +++ b/tests/queries/0_stateless/01721_dictionary_decimal_p_s.sql @@ -1,6 +1,5 @@ -- Tags: no-parallel -set allow_experimental_bigint_types=1; drop database if exists db_01721; drop table if exists db_01721.table_decimal_dict; drop dictionary if exists db_01721.decimal_dict; @@ -77,4 +76,3 @@ SELECT dictGet('db_01721.decimal_dict', 'Decimal32_', toUInt64(5000)), drop table if exists table_decimal_dict; drop dictionary if exists cache_dict; drop database if exists db_01721; - diff --git a/tests/queries/0_stateless/01804_dictionary_decimal256_type.sql b/tests/queries/0_stateless/01804_dictionary_decimal256_type.sql index 08a8d0feb27..32b029442b9 100644 --- a/tests/queries/0_stateless/01804_dictionary_decimal256_type.sql +++ b/tests/queries/0_stateless/01804_dictionary_decimal256_type.sql @@ -1,7 +1,5 @@ -- Tags: no-parallel -SET allow_experimental_bigint_types = 1; - DROP TABLE IF EXISTS dictionary_decimal_source_table; CREATE TABLE dictionary_decimal_source_table ( diff --git a/tests/queries/0_stateless/01875_ssd_cache_dictionary_decimal256_type.sh b/tests/queries/0_stateless/01875_ssd_cache_dictionary_decimal256_type.sh index 1294ba53e82..2a24a931696 100755 --- a/tests/queries/0_stateless/01875_ssd_cache_dictionary_decimal256_type.sh +++ b/tests/queries/0_stateless/01875_ssd_cache_dictionary_decimal256_type.sh @@ -6,8 +6,6 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) . "$CURDIR"/../shell_config.sh $CLICKHOUSE_CLIENT --query=" - SET allow_experimental_bigint_types = 1; - DROP TABLE IF EXISTS dictionary_decimal_source_table; CREATE TABLE dictionary_decimal_source_table ( From 19ab7d484a6d7a2346103c5468bca611df03e3d9 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Sun, 10 Nov 2024 23:50:31 +0100 Subject: [PATCH 20/35] Add an experimental setting --- src/Core/Settings.cpp | 5 ++++- src/Core/SettingsChangesHistory.cpp | 1 + .../parseColumnsListForTableFunction.cpp | 14 ++++++++++++++ .../parseColumnsListForTableFunction.h | 1 + 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/Core/Settings.cpp b/src/Core/Settings.cpp index 01339226c2d..7c2042ee16d 100644 --- a/src/Core/Settings.cpp +++ b/src/Core/Settings.cpp @@ -5729,7 +5729,10 @@ Enable experimental functions for natural language processing. Enable experimental hash functions )", EXPERIMENTAL) \ DECLARE(Bool, allow_experimental_object_type, false, R"( -Allow Object and JSON data types +Allow the obsolete Object data type +)", EXPERIMENTAL) \ + DECLARE(Bool, allow_experimental_bfloat16_type, false, R"( +Allow BFloat16 data type (under development). )", EXPERIMENTAL) \ DECLARE(Bool, allow_experimental_time_series_table, false, R"( Allows creation of tables with the [TimeSeries](../../engines/table-engines/integrations/time-series.md) table engine. diff --git a/src/Core/SettingsChangesHistory.cpp b/src/Core/SettingsChangesHistory.cpp index 0ff9d0a6833..23aeeb47224 100644 --- a/src/Core/SettingsChangesHistory.cpp +++ b/src/Core/SettingsChangesHistory.cpp @@ -77,6 +77,7 @@ static std::initializer_list(&data_type)) diff --git a/src/Interpreters/parseColumnsListForTableFunction.h b/src/Interpreters/parseColumnsListForTableFunction.h index 6e00492c0ad..39b9f092d89 100644 --- a/src/Interpreters/parseColumnsListForTableFunction.h +++ b/src/Interpreters/parseColumnsListForTableFunction.h @@ -20,6 +20,7 @@ struct DataTypeValidationSettings bool allow_experimental_object_type = true; bool allow_suspicious_fixed_string_types = true; bool allow_experimental_variant_type = true; + bool allow_experimental_bfloat16_type = true; bool allow_suspicious_variant_types = true; bool validate_nested_types = true; bool allow_experimental_dynamic_type = true; From 1a2ee7929e746395a6f0426b6935887af287fd30 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 00:16:09 +0100 Subject: [PATCH 21/35] More conversions --- src/Functions/FunctionsConversion.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/Functions/FunctionsConversion.cpp b/src/Functions/FunctionsConversion.cpp index 7f4ccc338cf..effaa6faa6d 100644 --- a/src/Functions/FunctionsConversion.cpp +++ b/src/Functions/FunctionsConversion.cpp @@ -1862,11 +1862,6 @@ struct ConvertImpl } } - if constexpr ((std::is_same_v || std::is_same_v) - && !(std::is_same_v || std::is_same_v)) - throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from {} to {} is not supported", - TypeName, TypeName); - if constexpr (std::is_same_v || std::is_same_v) { From f042c921ee84ef583f1b76c9d4587b963bd06f45 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 00:16:28 +0100 Subject: [PATCH 22/35] Distances --- base/base/BFloat16.h | 7 ++ src/Common/CPUID.h | 6 ++ src/Common/TargetSpecific.cpp | 3 + src/Common/TargetSpecific.h | 26 +++++- src/Functions/array/arrayDistance.cpp | 112 +++++++++++++++++++------- 5 files changed, 119 insertions(+), 35 deletions(-) diff --git a/base/base/BFloat16.h b/base/base/BFloat16.h index f7491b64eb3..2df84dbc0f2 100644 --- a/base/base/BFloat16.h +++ b/base/base/BFloat16.h @@ -80,6 +80,13 @@ public: return x & 0b1000000000000000; } + constexpr BFloat16 abs() const + { + BFloat16 res; + res.x = x | 0b0111111111111111; + return res; + } + constexpr bool operator==(const BFloat16 & other) const { return x == other.x; diff --git a/src/Common/CPUID.h b/src/Common/CPUID.h index b49f7706904..b5c26e64d1e 100644 --- a/src/Common/CPUID.h +++ b/src/Common/CPUID.h @@ -266,6 +266,11 @@ inline bool haveAVX512VBMI2() noexcept return haveAVX512F() && ((CPUInfo(0x7, 0).registers.ecx >> 6) & 1u); } +inline bool haveAVX512BF16() noexcept +{ + return haveAVX512F() && ((CPUInfo(0x7, 1).registers.eax >> 5) & 1u); +} + inline bool haveRDRAND() noexcept { return CPUInfo(0x0).registers.eax >= 0x7 && ((CPUInfo(0x1).registers.ecx >> 30) & 1u); @@ -326,6 +331,7 @@ inline bool haveAMXINT8() noexcept OP(AVX512VL) \ OP(AVX512VBMI) \ OP(AVX512VBMI2) \ + OP(AVX512BF16) \ OP(PREFETCHWT1) \ OP(SHA) \ OP(ADX) \ diff --git a/src/Common/TargetSpecific.cpp b/src/Common/TargetSpecific.cpp index 8540c9a9986..4400d9a60b3 100644 --- a/src/Common/TargetSpecific.cpp +++ b/src/Common/TargetSpecific.cpp @@ -23,6 +23,8 @@ UInt32 getSupportedArchs() result |= static_cast(TargetArch::AVX512VBMI); if (CPU::CPUFlagsCache::have_AVX512VBMI2) result |= static_cast(TargetArch::AVX512VBMI2); + if (CPU::CPUFlagsCache::have_AVX512BF16) + result |= static_cast(TargetArch::AVX512BF16); if (CPU::CPUFlagsCache::have_AMXBF16) result |= static_cast(TargetArch::AMXBF16); if (CPU::CPUFlagsCache::have_AMXTILE) @@ -50,6 +52,7 @@ String toString(TargetArch arch) case TargetArch::AVX512BW: return "avx512bw"; case TargetArch::AVX512VBMI: return "avx512vbmi"; case TargetArch::AVX512VBMI2: return "avx512vbmi2"; + case TargetArch::AVX512BF16: return "avx512bf16"; case TargetArch::AMXBF16: return "amxbf16"; case TargetArch::AMXTILE: return "amxtile"; case TargetArch::AMXINT8: return "amxint8"; diff --git a/src/Common/TargetSpecific.h b/src/Common/TargetSpecific.h index f9523f667b2..5584bd1f63a 100644 --- a/src/Common/TargetSpecific.h +++ b/src/Common/TargetSpecific.h @@ -83,9 +83,10 @@ enum class TargetArch : UInt32 AVX512BW = (1 << 4), AVX512VBMI = (1 << 5), AVX512VBMI2 = (1 << 6), - AMXBF16 = (1 << 7), - AMXTILE = (1 << 8), - AMXINT8 = (1 << 9), + AVX512BF16 = (1 << 7), + AMXBF16 = (1 << 8), + AMXTILE = (1 << 9), + AMXINT8 = (1 << 10), }; /// Runtime detection. @@ -102,6 +103,7 @@ String toString(TargetArch arch); /// NOLINTNEXTLINE #define USE_MULTITARGET_CODE 1 +#define AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2,avx512bf16"))) #define AVX512VBMI2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2"))) #define AVX512VBMI_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi"))) #define AVX512BW_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw"))) @@ -111,6 +113,8 @@ String toString(TargetArch arch); #define SSE42_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt"))) #define DEFAULT_FUNCTION_SPECIFIC_ATTRIBUTE +# define BEGIN_AVX512BF16_SPECIFIC_CODE \ + _Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2,avx512bf16\"))),apply_to=function)") # define BEGIN_AVX512VBMI2_SPECIFIC_CODE \ _Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2\"))),apply_to=function)") # define BEGIN_AVX512VBMI_SPECIFIC_CODE \ @@ -197,6 +201,14 @@ namespace TargetSpecific::AVX512VBMI2 { \ } \ END_TARGET_SPECIFIC_CODE +#define DECLARE_AVX512BF16_SPECIFIC_CODE(...) \ +BEGIN_AVX512BF16_SPECIFIC_CODE \ +namespace TargetSpecific::AVX512BF16 { \ + DUMMY_FUNCTION_DEFINITION \ + using namespace DB::TargetSpecific::AVX512BF16; \ + __VA_ARGS__ \ +} \ +END_TARGET_SPECIFIC_CODE #else @@ -211,6 +223,7 @@ END_TARGET_SPECIFIC_CODE #define DECLARE_AVX512BW_SPECIFIC_CODE(...) #define DECLARE_AVX512VBMI_SPECIFIC_CODE(...) #define DECLARE_AVX512VBMI2_SPECIFIC_CODE(...) +#define DECLARE_AVX512BF16_SPECIFIC_CODE(...) #endif @@ -229,7 +242,8 @@ DECLARE_AVX2_SPECIFIC_CODE (__VA_ARGS__) \ DECLARE_AVX512F_SPECIFIC_CODE(__VA_ARGS__) \ DECLARE_AVX512BW_SPECIFIC_CODE (__VA_ARGS__) \ DECLARE_AVX512VBMI_SPECIFIC_CODE (__VA_ARGS__) \ -DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__) +DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__) \ +DECLARE_AVX512BF16_SPECIFIC_CODE (__VA_ARGS__) DECLARE_DEFAULT_CODE( constexpr auto BuildArch = TargetArch::Default; /// NOLINT @@ -263,6 +277,10 @@ DECLARE_AVX512VBMI2_SPECIFIC_CODE( constexpr auto BuildArch = TargetArch::AVX512VBMI2; /// NOLINT ) // DECLARE_AVX512VBMI2_SPECIFIC_CODE +DECLARE_AVX512BF16_SPECIFIC_CODE( + constexpr auto BuildArch = TargetArch::AVX512BF16; /// NOLINT +) // DECLARE_AVX512BF16_SPECIFIC_CODE + /** Runtime Dispatch helpers for class members. * * Example of usage: diff --git a/src/Functions/array/arrayDistance.cpp b/src/Functions/array/arrayDistance.cpp index a1f48747eb6..da49359c422 100644 --- a/src/Functions/array/arrayDistance.cpp +++ b/src/Functions/array/arrayDistance.cpp @@ -14,6 +14,31 @@ #include #endif + +namespace +{ + inline BFloat16 fabs(BFloat16 x) + { + return x.abs(); + } + + inline BFloat16 sqrt(BFloat16 x) + { + return BFloat16(::sqrtf(Float32(x))); + } + + template + inline BFloat16 pow(BFloat16 x, T p) + { + return BFloat16(::powf(Float32(x), Float32(p))); + } + + inline BFloat16 fmax(BFloat16 x, BFloat16 y) + { + return BFloat16(::fmaxf(Float32(x), Float32(y))); + } +} + namespace DB { namespace ErrorCodes @@ -34,7 +59,7 @@ struct L1Distance template struct State { - FloatType sum = 0; + FloatType sum{}; }; template @@ -65,7 +90,7 @@ struct L2Distance template struct State { - FloatType sum = 0; + FloatType sum{}; }; template @@ -82,7 +107,7 @@ struct L2Distance #if USE_MULTITARGET_CODE template - AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine( + AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine( const ResultType * __restrict data_x, const ResultType * __restrict data_y, size_t i_max, @@ -90,19 +115,29 @@ struct L2Distance size_t & i_y, State & state) { - static constexpr bool is_float32 = std::is_same_v; - __m512 sums; - if constexpr (is_float32) + if constexpr (sizeof(ResultType) <= 4) sums = _mm512_setzero_ps(); else sums = _mm512_setzero_pd(); - constexpr size_t n = is_float32 ? 16 : 8; + constexpr size_t n = sizeof(__m512) / sizeof(ResultType); for (; i_x + n < i_max; i_x += n, i_y += n) { - if constexpr (is_float32) + if constexpr (sizeof(ResultType) == 2) + { + __m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast(data_x + i_x))); + __m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast(data_x + i_x + n / 2))); + __m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast(data_y + i_y))); + __m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast(data_y + i_y + n / 2))); + + __m512 differences_1 = _mm512_sub_ps(x_1, y_1); + __m512 differences_2 = _mm512_sub_ps(x_2, y_2); + sums = _mm512_fmadd_ps(differences_1, differences_1, sums); + sums = _mm512_fmadd_ps(differences_2, differences_2, sums); + } + else if constexpr (sizeof(ResultType) == 4) { __m512 x = _mm512_loadu_ps(data_x + i_x); __m512 y = _mm512_loadu_ps(data_y + i_y); @@ -118,7 +153,7 @@ struct L2Distance } } - if constexpr (is_float32) + if constexpr (sizeof(ResultType) <= 4) state.sum = _mm512_reduce_add_ps(sums); else state.sum = _mm512_reduce_add_pd(sums); @@ -128,7 +163,7 @@ struct L2Distance template static ResultType finalize(const State & state, const ConstParams &) { - return sqrt(state.sum); + return sqrt(ResultType(state.sum)); } }; @@ -156,13 +191,13 @@ struct LpDistance template struct State { - FloatType sum = 0; + FloatType sum{}; }; template static void accumulate(State & state, ResultType x, ResultType y, const ConstParams & params) { - state.sum += static_cast(std::pow(fabs(x - y), params.power)); + state.sum += static_cast(pow(fabs(x - y), params.power)); } template @@ -174,7 +209,7 @@ struct LpDistance template static ResultType finalize(const State & state, const ConstParams & params) { - return static_cast(std::pow(state.sum, params.inverted_power)); + return static_cast(pow(state.sum, params.inverted_power)); } }; @@ -187,7 +222,7 @@ struct LinfDistance template struct State { - FloatType dist = 0; + FloatType dist{}; }; template @@ -218,9 +253,9 @@ struct CosineDistance template struct State { - FloatType dot_prod = 0; - FloatType x_squared = 0; - FloatType y_squared = 0; + FloatType dot_prod{}; + FloatType x_squared{}; + FloatType y_squared{}; }; template @@ -241,7 +276,7 @@ struct CosineDistance #if USE_MULTITARGET_CODE template - AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine( + AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine( const ResultType * __restrict data_x, const ResultType * __restrict data_y, size_t i_max, @@ -249,13 +284,11 @@ struct CosineDistance size_t & i_y, State & state) { - static constexpr bool is_float32 = std::is_same_v; - __m512 dot_products; __m512 x_squareds; __m512 y_squareds; - if constexpr (is_float32) + if constexpr (sizeof(ResultType) <= 4) { dot_products = _mm512_setzero_ps(); x_squareds = _mm512_setzero_ps(); @@ -268,11 +301,19 @@ struct CosineDistance y_squareds = _mm512_setzero_pd(); } - constexpr size_t n = is_float32 ? 16 : 8; + constexpr size_t n = sizeof(__m512) / sizeof(ResultType); for (; i_x + n < i_max; i_x += n, i_y += n) { - if constexpr (is_float32) + if constexpr (sizeof(ResultType) == 2) + { + __m512 x = _mm512_loadu_ps(data_x + i_x); + __m512 y = _mm512_loadu_ps(data_y + i_y); + dot_products = _mm512_dpbf16_ps(dot_products, x, y); + x_squareds = _mm512_dpbf16_ps(x_squareds, x, x); + y_squareds = _mm512_dpbf16_ps(y_squareds, y, y); + } + if constexpr (sizeof(ResultType) == 4) { __m512 x = _mm512_loadu_ps(data_x + i_x); __m512 y = _mm512_loadu_ps(data_y + i_y); @@ -290,7 +331,7 @@ struct CosineDistance } } - if constexpr (is_float32) + if constexpr (sizeof(ResultType) == 2 || sizeof(ResultType) == 4) { state.dot_prod = _mm512_reduce_add_ps(dot_products); state.x_squared = _mm512_reduce_add_ps(x_squareds); @@ -308,7 +349,7 @@ struct CosineDistance template static ResultType finalize(const State & state, const ConstParams &) { - return 1 - state.dot_prod / sqrt(state.x_squared * state.y_squared); + return ResultType(1) - state.dot_prod / sqrt(state.x_squared * state.y_squared); } }; @@ -353,11 +394,13 @@ public: return std::make_shared(); case TypeIndex::Float32: return std::make_shared(); + case TypeIndex::BFloat16: + return std::make_shared(); default: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.", getName(), common_type->getName()); } @@ -367,6 +410,9 @@ public: { switch (result_type->getTypeId()) { + case TypeIndex::BFloat16: + return executeWithResultType(arguments, input_rows_count); + break; case TypeIndex::Float32: return executeWithResultType(arguments, input_rows_count); break; @@ -388,6 +434,7 @@ public: ACTION(Int16) \ ACTION(Int32) \ ACTION(Int64) \ + ACTION(BFloat16) \ ACTION(Float32) \ ACTION(Float64) @@ -412,7 +459,7 @@ private: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.", getName(), type_x->getName()); } @@ -437,7 +484,7 @@ private: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.", getName(), type_y->getName()); } @@ -548,13 +595,15 @@ private: /// SIMD optimization: process multiple elements in both input arrays at once. /// To avoid combinatorial explosion of SIMD kernels, focus on - /// - the two most common input/output types (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 instead of 10 x - /// 10 input types x 2 output types, + /// - the three most common input/output types (BFloat16 x BFloat16) --> BFloat16, + /// (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 + /// instead of 10 x 10 input types x 2 output types, /// - const/non-const inputs instead of non-const/non-const inputs /// - the two most common metrics L2 and cosine distance, /// - the most powerful SIMD instruction set (AVX-512F). #if USE_MULTITARGET_CODE - if constexpr (std::is_same_v && std::is_same_v) /// ResultType is Float32 or Float64 + /// ResultType is BFloat16, Float32 or Float64 + if constexpr (std::is_same_v && std::is_same_v) { if constexpr (std::is_same_v || std::is_same_v) @@ -638,4 +687,5 @@ FunctionPtr createFunctionArrayL2SquaredDistance(ContextPtr context_) { return F FunctionPtr createFunctionArrayLpDistance(ContextPtr context_) { return FunctionArrayDistance::create(context_); } FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_) { return FunctionArrayDistance::create(context_); } FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_) { return FunctionArrayDistance::create(context_); } + } From 6dee7e42766177e30712f7c1c341663b3fba2f91 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 00:24:57 +0100 Subject: [PATCH 23/35] Fix style --- src/Databases/enableAllExperimentalSettings.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/Databases/enableAllExperimentalSettings.cpp b/src/Databases/enableAllExperimentalSettings.cpp index bc2dae55f97..1be54664bc9 100644 --- a/src/Databases/enableAllExperimentalSettings.cpp +++ b/src/Databases/enableAllExperimentalSettings.cpp @@ -27,6 +27,8 @@ void enableAllExperimentalSettings(ContextMutablePtr context) context->setSetting("allow_experimental_window_functions", 1); context->setSetting("allow_experimental_geo_types", 1); context->setSetting("allow_experimental_map_type", 1); + context->setSetting("allow_experimental_bigint_types", 1); + context->setSetting("allow_experimental_bfloat16_type", 1); context->setSetting("allow_deprecated_error_prone_window_functions", 1); context->setSetting("allow_suspicious_low_cardinality_types", 1); From 89b015cecfad9a6a8f44039efa556d209ea50239 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 00:25:11 +0100 Subject: [PATCH 24/35] Do not compile BFloat16 --- src/DataTypes/IDataType.h | 3 ++- src/DataTypes/Native.cpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DataTypes/IDataType.h b/src/DataTypes/IDataType.h index 1e41d6b2eba..8f06526ddbb 100644 --- a/src/DataTypes/IDataType.h +++ b/src/DataTypes/IDataType.h @@ -411,7 +411,8 @@ struct WhichDataType constexpr bool isBFloat16() const { return idx == TypeIndex::BFloat16; } constexpr bool isFloat32() const { return idx == TypeIndex::Float32; } constexpr bool isFloat64() const { return idx == TypeIndex::Float64; } - constexpr bool isFloat() const { return isBFloat16() || isFloat32() || isFloat64(); } + constexpr bool isNativeFloat() const { return isFloat32() || isFloat64(); } + constexpr bool isFloat() const { return isNativeFloat() || isBFloat16(); } constexpr bool isNativeNumber() const { return isNativeInteger() || isFloat(); } constexpr bool isNumber() const { return isInteger() || isFloat() || isDecimal(); } diff --git a/src/DataTypes/Native.cpp b/src/DataTypes/Native.cpp index 5dc490b0bd5..53354d7c6e0 100644 --- a/src/DataTypes/Native.cpp +++ b/src/DataTypes/Native.cpp @@ -37,7 +37,7 @@ bool canBeNativeType(const IDataType & type) return canBeNativeType(*data_type_nullable.getNestedType()); } - return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isFloat() || data_type.isDate() + return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isNativeFloat() || data_type.isDate() || data_type.isDate32() || data_type.isDateTime() || data_type.isEnum(); } From 968a559917577a63464bbaf87d3e724912cb7d5a Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 00:59:37 +0100 Subject: [PATCH 25/35] Add a test --- .../queries/0_stateless/03269_bf16.reference | 45 ++++++++++ tests/queries/0_stateless/03269_bf16.sql | 88 +++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 tests/queries/0_stateless/03269_bf16.reference create mode 100644 tests/queries/0_stateless/03269_bf16.sql diff --git a/tests/queries/0_stateless/03269_bf16.reference b/tests/queries/0_stateless/03269_bf16.reference new file mode 100644 index 00000000000..daa26cb252f --- /dev/null +++ b/tests/queries/0_stateless/03269_bf16.reference @@ -0,0 +1,45 @@ +1 -1 1.09375 -1.09375 1 -1 1.09375 -1.09375 18446744000000000000 -0 inf -inf nan +1.09375 1.09375 1.09375 1 +1 1 0 1 1 +0 2.1875 1.1962891 1 Float32 Float32 Float32 Float64 +-0.006250000000000089 2.19375 1.203125 1.0057142857142858 Float64 Float64 Float64 Float64 +0 0 1 0 +1000 1000 1 0 +2000 2000 1 0 +3000 2992 0 8 +4000 4000 1 0 +5000 4992 0 8 +6000 5984 0 16 +7000 6976 0 24 +8000 8000 1 0 +9000 8960 0 40 +49995000 49855104 4999.5 4985.5104 0 0 9999 9984 10000 925 10000 925 +0 0 1 0 +1000 1000 1 0 +2000 2000 1 0 +3000 2992 0 8 +4000 4000 1 0 +5000 4992 0 8 +6000 5984 0 16 +7000 6976 0 24 +8000 8000 1 0 +9000 8960 0 40 +49995000 49855104 4999.5 4985.5104 0 0 9999 9984 10000 925 10000 925 +Row 1: +────── +a32: [0,0.5,1,1.5,2,2.5,3,3.5,4,4.5,5,5.5,6,6.5,7,7.5,8,8.5,9,9.5,10,10.5,11,11.5,12,12.5,13,13.5,14,14.5,15,15.5,16,16.5,17,17.5,18,18.5,19,19.5,20,20.5,21,21.5,22,22.5,23,23.5,24,24.5,25,25.5,26,26.5,27,27.5,28,28.5,29,29.5,30,30.5,31,31.5,32,32.5,33,33.5,34,34.5,35,35.5,36,36.5,37,37.5,38,38.5,39,39.5,40,40.5,41,41.5,42,42.5,43,43.5,44,44.5,45,45.5,46,46.5,47,47.5,48,48.5,49,49.5,50,50.5,51,51.5,52,52.5,53,53.5,54,54.5,55,55.5,56,56.5,57,57.5,58,58.5,59,59.5,60,60.5,61,61.5,62,62.5,63,63.5,64,64.5,65,65.5,66,66.5,67,67.5,68,68.5,69,69.5,70,70.5,71,71.5,72,72.5,73,73.5,74,74.5,75,75.5,76,76.5,77,77.5,78,78.5,79,79.5,80,80.5,81,81.5,82,82.5,83,83.5,84,84.5,85,85.5,86,86.5,87,87.5,88,88.5,89,89.5,90,90.5,91,91.5,92,92.5,93,93.5,94,94.5,95,95.5,96,96.5,97,97.5,98,98.5,99,99.5,100,100.5,101,101.5,102,102.5,103,103.5,104,104.5,105,105.5,106,106.5,107,107.5,108,108.5,109,109.5,110,110.5,111,111.5,112,112.5,113,113.5,114,114.5,115,115.5,116,116.5,117,117.5,118,118.5,119,119.5,120,120.5,121,121.5,122,122.5,123,123.5,124,124.5,125,125.5,126,126.5,127,127.5,128,128.5,129,129.5,130,130.5,131,131.5,132,132.5,133,133.5,134,134.5,135,135.5,136,136.5,137,137.5,138,138.5,139,139.5,140,140.5,141,141.5,142,142.5,143,143.5,144,144.5,145,145.5,146,146.5,147,147.5,148,148.5,149,149.5,150,150.5,151,151.5,152,152.5,153,153.5,154,154.5,155,155.5,156,156.5,157,157.5,158,158.5,159,159.5,160,160.5,161,161.5,162,162.5,163,163.5,164,164.5,165,165.5,166,166.5,167,167.5,168,168.5,169,169.5,170,170.5,171,171.5,172,172.5,173,173.5,174,174.5,175,175.5,176,176.5,177,177.5,178,178.5,179,179.5,180,180.5,181,181.5,182,182.5,183,183.5,184,184.5,185,185.5,186,186.5,187,187.5,188,188.5,189,189.5,190,190.5,191,191.5] +a16: [0,0.5,1,1.5,2,2.5,3,3.5,4,4.5,5,5.5,6,6.5,7,7.5,8,8.5,9,9.5,10,10.5,11,11.5,12,12.5,13,13.5,14,14.5,15,15.5,16,16.5,17,17.5,18,18.5,19,19.5,20,20.5,21,21.5,22,22.5,23,23.5,24,24.5,25,25.5,26,26.5,27,27.5,28,28.5,29,29.5,30,30.5,31,31.5,32,32.5,33,33.5,34,34.5,35,35.5,36,36.5,37,37.5,38,38.5,39,39.5,40,40.5,41,41.5,42,42.5,43,43.5,44,44.5,45,45.5,46,46.5,47,47.5,48,48.5,49,49.5,50,50.5,51,51.5,52,52.5,53,53.5,54,54.5,55,55.5,56,56.5,57,57.5,58,58.5,59,59.5,60,60.5,61,61.5,62,62.5,63,63.5,64,64.5,65,65.5,66,66.5,67,67.5,68,68.5,69,69.5,70,70.5,71,71.5,72,72.5,73,73.5,74,74.5,75,75.5,76,76.5,77,77.5,78,78.5,79,79.5,80,80.5,81,81.5,82,82.5,83,83.5,84,84.5,85,85.5,86,86.5,87,87.5,88,88.5,89,89.5,90,90.5,91,91.5,92,92.5,93,93.5,94,94.5,95,95.5,96,96.5,97,97.5,98,98.5,99,99.5,100,100.5,101,101.5,102,102.5,103,103.5,104,104.5,105,105.5,106,106.5,107,107.5,108,108.5,109,109.5,110,110.5,111,111.5,112,112.5,113,113.5,114,114.5,115,115.5,116,116.5,117,117.5,118,118.5,119,119.5,120,120.5,121,121.5,122,122.5,123,123.5,124,124.5,125,125.5,126,126.5,127,127.5,128,128,129,129,130,130,131,131,132,132,133,133,134,134,135,135,136,136,137,137,138,138,139,139,140,140,141,141,142,142,143,143,144,144,145,145,146,146,147,147,148,148,149,149,150,150,151,151,152,152,153,153,154,154,155,155,156,156,157,157,158,158,159,159,160,160,161,161,162,162,163,163,164,164,165,165,166,166,167,167,168,168,169,169,170,170,171,171,172,172,173,173,174,174,175,175,176,176,177,177,178,178,179,179,180,180,181,181,182,182,183,183,184,184,185,185,186,186,187,187,188,188,189,189,190,190,191,191] +a32_1: [1,1.5,2,2.5,3,3.5,4,4.5,5,5.5,6,6.5,7,7.5,8,8.5,9,9.5,10,10.5,11,11.5,12,12.5,13,13.5,14,14.5,15,15.5,16,16.5,17,17.5,18,18.5,19,19.5,20,20.5,21,21.5,22,22.5,23,23.5,24,24.5,25,25.5,26,26.5,27,27.5,28,28.5,29,29.5,30,30.5,31,31.5,32,32.5,33,33.5,34,34.5,35,35.5,36,36.5,37,37.5,38,38.5,39,39.5,40,40.5,41,41.5,42,42.5,43,43.5,44,44.5,45,45.5,46,46.5,47,47.5,48,48.5,49,49.5,50,50.5,51,51.5,52,52.5,53,53.5,54,54.5,55,55.5,56,56.5,57,57.5,58,58.5,59,59.5,60,60.5,61,61.5,62,62.5,63,63.5,64,64.5,65,65.5,66,66.5,67,67.5,68,68.5,69,69.5,70,70.5,71,71.5,72,72.5,73,73.5,74,74.5,75,75.5,76,76.5,77,77.5,78,78.5,79,79.5,80,80.5,81,81.5,82,82.5,83,83.5,84,84.5,85,85.5,86,86.5,87,87.5,88,88.5,89,89.5,90,90.5,91,91.5,92,92.5,93,93.5,94,94.5,95,95.5,96,96.5,97,97.5,98,98.5,99,99.5,100,100.5,101,101.5,102,102.5,103,103.5,104,104.5,105,105.5,106,106.5,107,107.5,108,108.5,109,109.5,110,110.5,111,111.5,112,112.5,113,113.5,114,114.5,115,115.5,116,116.5,117,117.5,118,118.5,119,119.5,120,120.5,121,121.5,122,122.5,123,123.5,124,124.5,125,125.5,126,126.5,127,127.5,128,128.5,129,129.5,130,130.5,131,131.5,132,132.5,133,133.5,134,134.5,135,135.5,136,136.5,137,137.5,138,138.5,139,139.5,140,140.5,141,141.5,142,142.5,143,143.5,144,144.5,145,145.5,146,146.5,147,147.5,148,148.5,149,149.5,150,150.5,151,151.5,152,152.5,153,153.5,154,154.5,155,155.5,156,156.5,157,157.5,158,158.5,159,159.5,160,160.5,161,161.5,162,162.5,163,163.5,164,164.5,165,165.5,166,166.5,167,167.5,168,168.5,169,169.5,170,170.5,171,171.5,172,172.5,173,173.5,174,174.5,175,175.5,176,176.5,177,177.5,178,178.5,179,179.5,180,180.5,181,181.5,182,182.5,183,183.5,184,184.5,185,185.5,186,186.5,187,187.5,188,188.5,189,189.5,190,190.5,191,191.5,192,192.5] +a16_1: [1,1.5,2,2.5,3,3.5,4,4.5,5,5.5,6,6.5,7,7.5,8,8.5,9,9.5,10,10.5,11,11.5,12,12.5,13,13.5,14,14.5,15,15.5,16,16.5,17,17.5,18,18.5,19,19.5,20,20.5,21,21.5,22,22.5,23,23.5,24,24.5,25,25.5,26,26.5,27,27.5,28,28.5,29,29.5,30,30.5,31,31.5,32,32.5,33,33.5,34,34.5,35,35.5,36,36.5,37,37.5,38,38.5,39,39.5,40,40.5,41,41.5,42,42.5,43,43.5,44,44.5,45,45.5,46,46.5,47,47.5,48,48.5,49,49.5,50,50.5,51,51.5,52,52.5,53,53.5,54,54.5,55,55.5,56,56.5,57,57.5,58,58.5,59,59.5,60,60.5,61,61.5,62,62.5,63,63.5,64,64.5,65,65.5,66,66.5,67,67.5,68,68.5,69,69.5,70,70.5,71,71.5,72,72.5,73,73.5,74,74.5,75,75.5,76,76.5,77,77.5,78,78.5,79,79.5,80,80.5,81,81.5,82,82.5,83,83.5,84,84.5,85,85.5,86,86.5,87,87.5,88,88.5,89,89.5,90,90.5,91,91.5,92,92.5,93,93.5,94,94.5,95,95.5,96,96.5,97,97.5,98,98.5,99,99.5,100,100.5,101,101.5,102,102.5,103,103.5,104,104.5,105,105.5,106,106.5,107,107.5,108,108.5,109,109.5,110,110.5,111,111.5,112,112.5,113,113.5,114,114.5,115,115.5,116,116.5,117,117.5,118,118.5,119,119.5,120,120.5,121,121.5,122,122.5,123,123.5,124,124.5,125,125.5,126,126.5,127,127.5,128,128.5,129,129,130,130,131,131,132,132,133,133,134,134,135,135,136,136,137,137,138,138,139,139,140,140,141,141,142,142,143,143,144,144,145,145,146,146,147,147,148,148,149,149,150,150,151,151,152,152,153,153,154,154,155,155,156,156,157,157,158,158,159,159,160,160,161,161,162,162,163,163,164,164,165,165,166,166,167,167,168,168,169,169,170,170,171,171,172,172,173,173,174,174,175,175,176,176,177,177,178,178,179,179,180,180,181,181,182,182,183,183,184,184,185,185,186,186,187,187,188,188,189,189,190,190,191,191,192,192] +dotProduct(a32, a32_1): 4736944 -- 4.74 million +dotProduct(a16, a16_1): 4726688 -- 4.73 million +cosineDistance(a32, a32_1): 0.000010093636084174129 +cosineDistance(a16, a16_1): 0.00001010226319664298 +L2Distance(a32, a32_1): 19.595917942265423 +L2Distance(a16, a16_1): 19.595917942265423 +L1Distance(a32, a32_1): 384 +L1Distance(a16, a16_1): 384 +LinfDistance(a32, a32_1): 1 +LinfDistance(a16, a16_1): 1 +LpDistance(a32, a32_1, 5): 3.2875036590344515 +LpDistance(a16, a16_1, 5): 3.2875036590344515 diff --git a/tests/queries/0_stateless/03269_bf16.sql b/tests/queries/0_stateless/03269_bf16.sql new file mode 100644 index 00000000000..375cca73b62 --- /dev/null +++ b/tests/queries/0_stateless/03269_bf16.sql @@ -0,0 +1,88 @@ +SET allow_experimental_bfloat16_type = 1; + +-- This is a smoke test, non exhaustive. + +-- Conversions + +SELECT + 1::BFloat16, + -1::BFloat16, + 1.1::BFloat16, + -1.1::BFloat16, + CAST(1 AS BFloat16), + CAST(-1 AS BFloat16), + CAST(1.1 AS BFloat16), + CAST(-1.1 AS BFloat16), + CAST(0xFFFFFFFFFFFFFFFF AS BFloat16), + CAST(-0.0 AS BFloat16), + CAST(inf AS BFloat16), + CAST(-inf AS BFloat16), + CAST(nan AS BFloat16); + +-- Conversions back + +SELECT + CAST(1.1::BFloat16 AS BFloat16), + CAST(1.1::BFloat16 AS Float32), + CAST(1.1::BFloat16 AS Float64), + CAST(1.1::BFloat16 AS Int8); + +-- Comparisons + +SELECT + 1.1::BFloat16 = 1.1::BFloat16, + 1.1::BFloat16 < 1.1, + 1.1::BFloat16 > 1.1, + 1.1::BFloat16 > 1, + 1.1::BFloat16 = 1.09375; + +-- Arithmetic + +SELECT + 1.1::BFloat16 - 1.1::BFloat16 AS a, + 1.1::BFloat16 + 1.1::BFloat16 AS b, + 1.1::BFloat16 * 1.1::BFloat16 AS c, + 1.1::BFloat16 / 1.1::BFloat16 AS d, + toTypeName(a), toTypeName(b), toTypeName(c), toTypeName(d); + +SELECT + 1.1::BFloat16 - 1.1 AS a, + 1.1 + 1.1::BFloat16 AS b, + 1.1::BFloat16 * 1.1 AS c, + 1.1 / 1.1::BFloat16 AS d, + toTypeName(a), toTypeName(b), toTypeName(c), toTypeName(d); + +-- Tables + +DROP TABLE IF EXISTS t; +CREATE TEMPORARY TABLE t (n UInt64, x BFloat16); +INSERT INTO t SELECT number, number FROM numbers(10000); +SELECT *, n = x, n - x FROM t WHERE n % 1000 = 0 ORDER BY n; + +-- Aggregate functions + +SELECT sum(n), sum(x), avg(n), avg(x), min(n), min(x), max(n), max(x), uniq(n), uniq(x), uniqExact(n), uniqExact(x) FROM t; + +-- MergeTree + +DROP TABLE t; +CREATE TABLE t (n UInt64, x BFloat16) ENGINE = MergeTree ORDER BY n; +INSERT INTO t SELECT number, number FROM numbers(10000); +SELECT *, n = x, n - x FROM t WHERE n % 1000 = 0 ORDER BY n; +SELECT sum(n), sum(x), avg(n), avg(x), min(n), min(x), max(n), max(x), uniq(n), uniq(x), uniqExact(n), uniqExact(x) FROM t; + +-- Distances + +WITH + arrayMap(x -> toFloat32(x) / 2, range(384)) AS a32, + arrayMap(x -> toBFloat16(x) / 2, range(384)) AS a16, + arrayMap(x -> x + 1, a32) AS a32_1, + arrayMap(x -> x + 1, a16) AS a16_1 +SELECT a32, a16, a32_1, a16_1, + dotProduct(a32, a32_1), dotProduct(a16, a16_1), + cosineDistance(a32, a32_1), cosineDistance(a16, a16_1), + L2Distance(a32, a32_1), L2Distance(a16, a16_1), + L1Distance(a32, a32_1), L1Distance(a16, a16_1), + LinfDistance(a32, a32_1), LinfDistance(a16, a16_1), + LpDistance(a32, a32_1, 5), LpDistance(a16, a16_1, 5) +FORMAT Vertical; From bfeefa2c8a5ce71dd0cc90d68d831d694aef3418 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 01:02:10 +0100 Subject: [PATCH 26/35] Introspection --- src/Functions/FunctionsBinaryRepresentation.cpp | 1 + tests/queries/0_stateless/03269_bf16.reference | 1 + tests/queries/0_stateless/03269_bf16.sql | 7 +++++++ 3 files changed, 9 insertions(+) diff --git a/src/Functions/FunctionsBinaryRepresentation.cpp b/src/Functions/FunctionsBinaryRepresentation.cpp index c8e8f167e4c..50a3c0862f4 100644 --- a/src/Functions/FunctionsBinaryRepresentation.cpp +++ b/src/Functions/FunctionsBinaryRepresentation.cpp @@ -296,6 +296,7 @@ public: tryExecuteUIntOrInt(column, res_column) || tryExecuteString(column, res_column) || tryExecuteFixedString(column, res_column) || + tryExecuteFloat(column, res_column) || tryExecuteFloat(column, res_column) || tryExecuteFloat(column, res_column) || tryExecuteDecimal(column, res_column) || diff --git a/tests/queries/0_stateless/03269_bf16.reference b/tests/queries/0_stateless/03269_bf16.reference index daa26cb252f..31395d92e2b 100644 --- a/tests/queries/0_stateless/03269_bf16.reference +++ b/tests/queries/0_stateless/03269_bf16.reference @@ -43,3 +43,4 @@ LinfDistance(a32, a32_1): 1 LinfDistance(a16, a16_1): 1 LpDistance(a32, a32_1, 5): 3.2875036590344515 LpDistance(a16, a16_1, 5): 3.2875036590344515 +1.09375 8C3F 1000110000111111 2 16268 8C3F diff --git a/tests/queries/0_stateless/03269_bf16.sql b/tests/queries/0_stateless/03269_bf16.sql index 375cca73b62..de4e2f6da47 100644 --- a/tests/queries/0_stateless/03269_bf16.sql +++ b/tests/queries/0_stateless/03269_bf16.sql @@ -86,3 +86,10 @@ SELECT a32, a16, a32_1, a16_1, LinfDistance(a32, a32_1), LinfDistance(a16, a16_1), LpDistance(a32, a32_1, 5), LpDistance(a16, a16_1, 5) FORMAT Vertical; + +-- Introspection + +SELECT 1.1::BFloat16 AS x, + hex(x), bin(x), + byteSize(x), + reinterpretAsUInt16(x), hex(reinterpretAsString(x)); From 3e50cf94fe858e8440ffd69040334356326b97db Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 01:04:55 +0100 Subject: [PATCH 27/35] Rounding --- tests/queries/0_stateless/03269_bf16.reference | 1 + tests/queries/0_stateless/03269_bf16.sql | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tests/queries/0_stateless/03269_bf16.reference b/tests/queries/0_stateless/03269_bf16.reference index 31395d92e2b..896cc307623 100644 --- a/tests/queries/0_stateless/03269_bf16.reference +++ b/tests/queries/0_stateless/03269_bf16.reference @@ -44,3 +44,4 @@ LinfDistance(a16, a16_1): 1 LpDistance(a32, a32_1, 5): 3.2875036590344515 LpDistance(a16, a16_1, 5): 3.2875036590344515 1.09375 8C3F 1000110000111111 2 16268 8C3F +1.09375 1 1.09375 1.0859375 0 diff --git a/tests/queries/0_stateless/03269_bf16.sql b/tests/queries/0_stateless/03269_bf16.sql index de4e2f6da47..b332a6e3119 100644 --- a/tests/queries/0_stateless/03269_bf16.sql +++ b/tests/queries/0_stateless/03269_bf16.sql @@ -93,3 +93,8 @@ SELECT 1.1::BFloat16 AS x, hex(x), bin(x), byteSize(x), reinterpretAsUInt16(x), hex(reinterpretAsString(x)); + +-- Rounding (this could be not towards the nearest) + +SELECT 1.1::BFloat16 AS x, + round(x), round(x, 1), round(x, 2), round(x, -1); From 3a855f501cd5d16ff97e9dde8b6fcb2d3b7ae497 Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Mon, 11 Nov 2024 02:15:31 +0100 Subject: [PATCH 28/35] Cleanups --- base/base/DecomposedFloat.h | 2 +- base/base/wide_integer.h | 1 - base/base/wide_integer_impl.h | 8 +------- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/base/base/DecomposedFloat.h b/base/base/DecomposedFloat.h index 3bd059cb21c..fef91adefb0 100644 --- a/base/base/DecomposedFloat.h +++ b/base/base/DecomposedFloat.h @@ -230,4 +230,4 @@ struct DecomposedFloat using DecomposedFloat64 = DecomposedFloat; using DecomposedFloat32 = DecomposedFloat; -using DecomposedFloat16 = DecomposedFloat<__bf16>; +using DecomposedFloat16 = DecomposedFloat; diff --git a/base/base/wide_integer.h b/base/base/wide_integer.h index baf6e490ada..f3a4dc9e6d5 100644 --- a/base/base/wide_integer.h +++ b/base/base/wide_integer.h @@ -118,7 +118,6 @@ public: constexpr operator long double() const noexcept; constexpr operator double() const noexcept; constexpr operator float() const noexcept; - constexpr operator __bf16() const noexcept; struct _impl; diff --git a/base/base/wide_integer_impl.h b/base/base/wide_integer_impl.h index d0bbd7df9d4..3787971a20e 100644 --- a/base/base/wide_integer_impl.h +++ b/base/base/wide_integer_impl.h @@ -154,7 +154,7 @@ struct common_type, Arithmetic> static_assert(wide::ArithmeticConcept()); using type = std::conditional_t< - std::is_floating_point_v || std::is_same_v, + std::is_floating_point_v, Arithmetic, std::conditional_t< sizeof(Arithmetic) * 8 < Bits, @@ -1300,12 +1300,6 @@ constexpr integer::operator float() const noexcept return static_cast(static_cast(*this)); } -template -constexpr integer::operator __bf16() const noexcept -{ - return static_cast<__bf16>(static_cast(*this)); -} - // Unary operators template constexpr integer operator~(const integer & lhs) noexcept From 7310376413e34ad2c958c5ccde9cddd75e0a5aed Mon Sep 17 00:00:00 2001 From: Alexey Milovidov Date: Tue, 12 Nov 2024 01:23:01 +0100 Subject: [PATCH 29/35] Remove ridiculous code bloat --- .../AggregateFunctionDeltaSumTimestamp.cpp | 69 ++++++++++++++---- src/AggregateFunctions/Helpers.h | 70 +------------------ 2 files changed, 58 insertions(+), 81 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp b/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp index dc1adead87c..0c5b752b539 100644 --- a/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp +++ b/src/AggregateFunctions/AggregateFunctionDeltaSumTimestamp.cpp @@ -22,6 +22,13 @@ namespace ErrorCodes namespace { +/** Due to a lack of proper code review, this code was contributed with a multiplication of template instantiations + * over all pairs of data types, and we deeply regret that. + * + * We cannot remove all combinations, because the binary representation of serialized data has to remain the same, + * but we can partially heal the wound by treating unsigned and signed data types in the same way. + */ + template struct AggregationFunctionDeltaSumTimestampData { @@ -37,23 +44,22 @@ template class AggregationFunctionDeltaSumTimestamp final : public IAggregateFunctionDataHelper< AggregationFunctionDeltaSumTimestampData, - AggregationFunctionDeltaSumTimestamp - > + AggregationFunctionDeltaSumTimestamp> { public: AggregationFunctionDeltaSumTimestamp(const DataTypes & arguments, const Array & params) : IAggregateFunctionDataHelper< AggregationFunctionDeltaSumTimestampData, - AggregationFunctionDeltaSumTimestamp - >{arguments, params, createResultType()} - {} + AggregationFunctionDeltaSumTimestamp>{arguments, params, createResultType()} + { + } AggregationFunctionDeltaSumTimestamp() : IAggregateFunctionDataHelper< AggregationFunctionDeltaSumTimestampData, - AggregationFunctionDeltaSumTimestamp - >{} - {} + AggregationFunctionDeltaSumTimestamp>{} + { + } bool allocatesMemoryInArena() const override { return false; } @@ -63,8 +69,8 @@ public: void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override { - auto value = assert_cast &>(*columns[0]).getData()[row_num]; - auto ts = assert_cast &>(*columns[1]).getData()[row_num]; + auto value = unalignedLoad(columns[0]->getRawData().data() + row_num * sizeof(ValueType)); + auto ts = unalignedLoad(columns[1]->getRawData().data() + row_num * sizeof(TimestampType)); auto & data = this->data(place); @@ -172,10 +178,49 @@ public: void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override { - assert_cast &>(to).getData().push_back(this->data(place).sum); + static_cast(to).template insertRawData( + reinterpret_cast(&this->data(place).sum)); } }; + + +template class AggregateFunctionTemplate, typename... TArgs> +static IAggregateFunction * createWithTwoTypesSecond(const IDataType & second_type, TArgs && ... args) +{ + WhichDataType which(second_type); + + if (which.idx == TypeIndex::UInt32) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::UInt64) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::Int32) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::Int64) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::Float32) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::Float64) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate(args...); + if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate(args...); + + return nullptr; +} + +template