Manual implementation

This commit is contained in:
Alexey Milovidov 2024-11-10 22:13:22 +01:00
parent 08e6e598f7
commit 7877d59ff6
14 changed files with 318 additions and 48 deletions

View File

@ -1,22 +1,294 @@
#pragma once
#include <base/bit_cast.h>
#include <bit>
#include <base/types.h>
using BFloat16 = __bf16;
//using BFloat16 = __bf16;
/** Software implementation of a 16-bit floating point number whose bit pattern
  * is the upper 16 bits of a Float32 bit pattern:
  * 1 sign bit (mask 0x8000), 8 exponent bits (mask 0x7F80), 7 mantissa bits (mask 0x007F).
  *
  * Conversion from Float32 drops the lower 16 bits (truncation, no rounding).
  * Conversion to Float32 is exact: the stored bits are shifted back and the low bits are zero.
  *
  * All arithmetic and all comparisons are performed in Float32.
  */
class BFloat16
{
private:
    /// Raw bit pattern of the value.
    UInt16 x = 0;

public:
    constexpr BFloat16() = default;
    constexpr BFloat16(const BFloat16 & other) = default;
    constexpr BFloat16 & operator=(const BFloat16 & other) = default;

    /// Keeps the upper 16 bits of the Float32 representation; the lower 16 bits are discarded.
    explicit constexpr BFloat16(const Float32 & other)
    {
        x = static_cast<UInt16>(std::bit_cast<UInt32>(other) >> 16);
    }

    /// Any other type is first converted to Float32 and then narrowed as above.
    template <typename T>
    explicit constexpr BFloat16(const T & other)
        : BFloat16(Float32(other))
    {
    }

    /// Assignment from any other type goes through the explicit converting constructors.
    template <typename T>
    constexpr BFloat16 & operator=(const T & other)
    {
        *this = BFloat16(other);
        return *this;
    }

    /// Widens to Float32 by placing the stored bits into the upper half of a 32-bit pattern.
    explicit constexpr operator Float32() const
    {
        return std::bit_cast<Float32>(static_cast<UInt32>(x) << 16);
    }

    /// Conversion to any other type is done via Float32.
    template <typename T>
    explicit constexpr operator T() const
    {
        return T(Float32(*this));
    }

    /// False when all exponent bits are set (infinity or NaN).
    constexpr bool isFinite() const
    {
        return (x & 0b0111111110000000) != 0b0111111110000000;
    }

    /// True when all exponent bits are set and the mantissa is non-zero.
    constexpr bool isNaN() const
    {
        return !isFinite() && (x & 0b0000000001111111) != 0b0000000000000000;
    }

    /// True when the sign bit is set (this includes negative zero).
    constexpr bool signBit() const
    {
        return x & 0b1000000000000000;
    }

    /// Equality is evaluated on the Float32 values rather than on the raw bits,
    /// so that NaN never compares equal and +0 compares equal to -0. This keeps
    /// `==` and `!=` in agreement with the ordering operators and with the
    /// mixed-type comparisons, which all compare through Float32 as well.
    constexpr bool operator==(const BFloat16 & other) const
    {
        return Float32(*this) == Float32(other);
    }

    constexpr bool operator!=(const BFloat16 & other) const
    {
        return Float32(*this) != Float32(other);
    }

    /// Binary arithmetic: compute in Float32, then narrow the result back to BFloat16.
    constexpr BFloat16 operator+(const BFloat16 & other) const
    {
        return BFloat16(Float32(*this) + Float32(other));
    }

    constexpr BFloat16 operator-(const BFloat16 & other) const
    {
        return BFloat16(Float32(*this) - Float32(other));
    }

    constexpr BFloat16 operator*(const BFloat16 & other) const
    {
        return BFloat16(Float32(*this) * Float32(other));
    }

    constexpr BFloat16 operator/(const BFloat16 & other) const
    {
        return BFloat16(Float32(*this) / Float32(other));
    }

    /// Compound assignments are expressed through the binary operators above.
    constexpr BFloat16 & operator+=(const BFloat16 & other)
    {
        *this = *this + other;
        return *this;
    }

    constexpr BFloat16 & operator-=(const BFloat16 & other)
    {
        *this = *this - other;
        return *this;
    }

    constexpr BFloat16 & operator*=(const BFloat16 & other)
    {
        *this = *this * other;
        return *this;
    }

    constexpr BFloat16 & operator/=(const BFloat16 & other)
    {
        *this = *this / other;
        return *this;
    }

    /// Negation only flips the sign bit; no conversion to Float32 is involved.
    constexpr BFloat16 operator-() const
    {
        BFloat16 res;
        res.x = x ^ 0b1000000000000000;
        return res;
    }
};
/// Equality between a BFloat16 and a value of another type.
/// The BFloat16 operand is widened to Float32 and `==` is applied to the pair.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator==(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened == rhs;
}

/// Same as above with the operands in the opposite order.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator==(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs == widened;
}
/// Inequality between a BFloat16 and a value of another type.
/// The BFloat16 operand is widened to Float32 and `!=` is applied to the pair.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator!=(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened != rhs;
}

/// Same as above with the operands in the opposite order.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator!=(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs != widened;
}
/// "Less than" with a BFloat16 on the left and another type on the right.
/// The BFloat16 operand is widened to Float32 before comparing.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened < rhs;
}

/// "Less than" with another type on the left and a BFloat16 on the right.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs < widened;
}

/// "Less than" for two BFloat16 values: both are widened to Float32.
constexpr inline bool operator<(BFloat16 lhs, BFloat16 rhs)
{
    return static_cast<Float32>(lhs) < static_cast<Float32>(rhs);
}
/// "Greater than" with a BFloat16 on the left and another type on the right.
/// The BFloat16 operand is widened to Float32 before comparing.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened > rhs;
}

/// "Greater than" with another type on the left and a BFloat16 on the right.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs > widened;
}

/// "Greater than" for two BFloat16 values: both are widened to Float32.
constexpr inline bool operator>(BFloat16 lhs, BFloat16 rhs)
{
    return static_cast<Float32>(lhs) > static_cast<Float32>(rhs);
}
/// "Less than or equal" with a BFloat16 on the left and another type on the right.
/// The BFloat16 operand is widened to Float32 before comparing.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<=(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened <= rhs;
}

/// "Less than or equal" with another type on the left and a BFloat16 on the right.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<=(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs <= widened;
}

/// "Less than or equal" for two BFloat16 values: both are widened to Float32.
constexpr inline bool operator<=(BFloat16 lhs, BFloat16 rhs)
{
    return static_cast<Float32>(lhs) <= static_cast<Float32>(rhs);
}
/// "Greater than or equal" with a BFloat16 on the left and another type on the right.
/// The BFloat16 operand is widened to Float32 before comparing.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>=(const BFloat16 & lhs, const T & rhs)
{
    const Float32 widened = static_cast<Float32>(lhs);
    return widened >= rhs;
}

/// "Greater than or equal" with another type on the left and a BFloat16 on the right.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>=(const T & lhs, const BFloat16 & rhs)
{
    const Float32 widened = static_cast<Float32>(rhs);
    return lhs >= widened;
}

/// "Greater than or equal" for two BFloat16 values: both are widened to Float32.
constexpr inline bool operator>=(BFloat16 lhs, BFloat16 rhs)
{
    return static_cast<Float32>(lhs) >= static_cast<Float32>(rhs);
}
/// Mixed-type addition. The BFloat16 operand is widened to Float32; the result
/// has whatever type `+` yields for the other operand and Float32 (it is not
/// narrowed back to BFloat16).
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator+(T lhs, BFloat16 rhs)
{
    return lhs + static_cast<Float32>(rhs);
}

/// Same as above with the BFloat16 operand on the left.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator+(BFloat16 lhs, T rhs)
{
    return static_cast<Float32>(lhs) + rhs;
}
/// Mixed-type subtraction. The BFloat16 operand is widened to Float32; the result
/// has whatever type `-` yields for the other operand and Float32 (it is not
/// narrowed back to BFloat16).
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator-(T lhs, BFloat16 rhs)
{
    return lhs - static_cast<Float32>(rhs);
}

/// Same as above with the BFloat16 operand on the left.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator-(BFloat16 lhs, T rhs)
{
    return static_cast<Float32>(lhs) - rhs;
}
/// Mixed-type multiplication. The BFloat16 operand is widened to Float32; the
/// result has whatever type `*` yields for the other operand and Float32 (it is
/// not narrowed back to BFloat16).
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator*(T lhs, BFloat16 rhs)
{
    return lhs * static_cast<Float32>(rhs);
}

/// Same as above with the BFloat16 operand on the left.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator*(BFloat16 lhs, T rhs)
{
    return static_cast<Float32>(lhs) * rhs;
}
/// Mixed-type division. The BFloat16 operand is widened to Float32; the result
/// has whatever type `/` yields for the other operand and Float32 (it is not
/// narrowed back to BFloat16).
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator/(T lhs, BFloat16 rhs)
{
    return lhs / static_cast<Float32>(rhs);
}

/// Same as above with the BFloat16 operand on the left.
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator/(BFloat16 lhs, T rhs)
{
    return static_cast<Float32>(lhs) / rhs;
}
namespace std
{
inline constexpr bool isfinite(BFloat16 x) { return (bit_cast<UInt16>(x) & 0b0111111110000000) != 0b0111111110000000; }
inline constexpr bool signbit(BFloat16 x) { return bit_cast<UInt16>(x) & 0b1000000000000000; }
}
inline Float32 BFloat16ToFloat32(BFloat16 x)
{
return bit_cast<Float32>(static_cast<UInt32>(bit_cast<UInt16>(x)) << 16);
}
inline BFloat16 Float32ToBFloat16(Float32 x)
{
return bit_cast<BFloat16>(std::bit_cast<UInt32>(x) >> 16);
inline constexpr bool isfinite(BFloat16 x) { return x.isFinite(); }
inline constexpr bool isnan(BFloat16 x) { return x.isNaN(); }
inline constexpr bool signbit(BFloat16 x) { return x.signBit(); }
}

View File

@ -11,7 +11,7 @@
template <typename T> struct FloatTraits;
template <>
struct FloatTraits<__bf16>
struct FloatTraits<BFloat16>
{
using UInt = uint16_t;
static constexpr size_t bits = 16;

View File

@ -9,10 +9,11 @@ namespace DB
{
using TypeListNativeInt = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64>;
using TypeListFloat = TypeList<BFloat16, Float32, Float64>;
using TypeListNativeNumber = TypeListConcat<TypeListNativeInt, TypeListFloat>;
using TypeListNativeFloat = TypeList<Float32, Float64>;
using TypeListNativeNumber = TypeListConcat<TypeListNativeInt, TypeListNativeFloat>;
using TypeListWideInt = TypeList<UInt128, Int128, UInt256, Int256>;
using TypeListInt = TypeListConcat<TypeListNativeInt, TypeListWideInt>;
using TypeListFloat = TypeListConcat<TypeListNativeFloat, TypeList<BFloat16>>;
using TypeListIntAndFloat = TypeListConcat<TypeListInt, TypeListFloat>;
using TypeListDecimal = TypeList<Decimal32, Decimal64, Decimal128, Decimal256>;
using TypeListNumber = TypeListConcat<TypeListIntAndFloat, TypeListDecimal>;

View File

@ -231,7 +231,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{
increment(place, static_cast<const ColVecType &>(*columns[0]).getData()[row_num]);
increment(place, Numerator(static_cast<const ColVecType &>(*columns[0]).getData()[row_num]));
++this->data(place).denominator;
}

View File

@ -27,9 +27,9 @@ namespace
template <typename T>
struct AggregationFunctionDeltaSumData
{
T sum = 0;
T last = 0;
T first = 0;
T sum{};
T last{};
T first{};
bool seen = false;
};

View File

@ -25,11 +25,11 @@ namespace
template <typename ValueType, typename TimestampType>
struct AggregationFunctionDeltaSumTimestampData
{
ValueType sum = 0;
ValueType first = 0;
ValueType last = 0;
TimestampType first_ts = 0;
TimestampType last_ts = 0;
ValueType sum{};
ValueType first{};
ValueType last{};
TimestampType first_ts{};
TimestampType last_ts{};
bool seen = false;
};

View File

@ -155,9 +155,9 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{
Int64 current_intersections = 0;
Int64 max_intersections = 0;
PointType position_of_max_intersections = 0;
Int64 current_intersections{};
Int64 max_intersections{};
PointType position_of_max_intersections{};
/// const_cast because we will sort the array
auto & array = this->data(place).value;

View File

@ -45,7 +45,7 @@ struct AggregateFunctionSparkbarData
Y insert(const X & x, const Y & y)
{
if (isNaN(y) || y <= 0)
return 0;
return {};
auto [it, inserted] = points.insert({x, y});
if (!inserted)
@ -173,13 +173,13 @@ private:
if (from_x >= to_x)
{
size_t sz = updateFrame(values, 8);
size_t sz = updateFrame(values, Y{8});
values.push_back('\0');
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
return;
}
PaddedPODArray<Y> histogram(width, 0);
PaddedPODArray<Y> histogram(width, Y{0});
PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket
for (const auto & point : data.points)
@ -218,10 +218,10 @@ private:
for (size_t i = 0; i < histogram.size(); ++i)
{
if (count_histogram[i] > 0)
histogram[i] /= count_histogram[i];
histogram[i] = histogram[i] / count_histogram[i];
}
Y y_max = 0;
Y y_max{};
for (auto & y : histogram)
{
if (isNaN(y) || y <= 0)
@ -245,7 +245,7 @@ private:
continue;
}
constexpr auto levels_num = static_cast<Y>(BAR_LEVELS - 1);
constexpr auto levels_num = Y{BAR_LEVELS - 1};
if constexpr (is_floating_point<Y>)
{
y = y / (y_max / levels_num) + 1;

View File

@ -83,7 +83,7 @@ struct AggregateFunctionSumData
while (ptr < unrolled_end)
{
for (size_t i = 0; i < unroll_count; ++i)
Impl::add(partial_sums[i], ptr[i]);
Impl::add(partial_sums[i], T(ptr[i]));
ptr += unroll_count;
}
@ -95,7 +95,7 @@ struct AggregateFunctionSumData
T local_sum{};
while (ptr < end_ptr)
{
Impl::add(local_sum, *ptr);
Impl::add(local_sum, T(*ptr));
++ptr;
}
Impl::add(sum, local_sum);
@ -227,7 +227,7 @@ struct AggregateFunctionSumData
while (ptr < end_ptr)
{
if (!*condition_map == add_if_zero)
Impl::add(local_sum, *ptr);
Impl::add(local_sum, T(*ptr));
++ptr;
++condition_map;
}
@ -488,10 +488,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
const auto & column = assert_cast<const ColVecType &>(*columns[0]);
if constexpr (is_big_int_v<T>)
this->data(place).add(static_cast<TResult>(column.getData()[row_num]));
else
this->data(place).add(column.getData()[row_num]);
}
void addBatchSinglePlace(

View File

@ -21,7 +21,7 @@ using Int128 = wide::integer<128, signed>;
using UInt128 = wide::integer<128, unsigned>;
using Int256 = wide::integer<256, signed>;
using UInt256 = wide::integer<256, unsigned>;
using BFloat16 = __bf16;
class BFloat16;
namespace DB
{

View File

@ -298,7 +298,7 @@ public:
static VectorType prepare(size_t scale)
{
return load1(scale);
return load1(ScalarType(scale));
}
};

View File

@ -583,7 +583,7 @@ struct CallPointInPolygon<Type, Types ...>
template <typename PointInPolygonImpl>
static ColumnPtr call(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl)
{
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListIntAndFloat>;
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListNativeNumber>;
if (auto column = typeid_cast<const ColumnVector<Type> *>(&x))
return Impl::template call<Type>(*column, y, impl);
return CallPointInPolygon<Types ...>::call(x, y, impl);
@ -609,7 +609,7 @@ struct CallPointInPolygon<>
template <typename PointInPolygonImpl>
NO_INLINE ColumnPtr pointInPolygon(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl)
{
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListIntAndFloat>;
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListNativeNumber>;
return Impl::call(x, y, impl);
}

View File

@ -18,7 +18,7 @@ struct DivideFloatingImpl
template <typename Result = ResultType>
static NO_SANITIZE_UNDEFINED Result apply(A a [[maybe_unused]], B b [[maybe_unused]])
{
return static_cast<Result>(a) / b;
return static_cast<Result>(a) / static_cast<Result>(b);
}
#if USE_EMBEDDED_COMPILER

View File

@ -174,7 +174,7 @@ inline size_t writeFloatTextFastPath(T x, char * buffer)
}
else if constexpr (std::is_same_v<T, BFloat16>)
{
Float32 f32 = BFloat16ToFloat32(x);
Float32 f32 = Float32(x);
if (DecomposedFloat32(f32).isIntegerInRepresentableRange())
result = itoa(Int32(f32), buffer) - buffer;