fix conflicts

This commit is contained in:
taiyang-li 2024-11-15 09:36:15 +08:00
commit bcc2a31bac
165 changed files with 1679 additions and 545 deletions

313
base/base/BFloat16.h Normal file
View File

@ -0,0 +1,313 @@
#pragma once
#include <bit>
#include <base/types.h>
/** BFloat16 is a 16-bit floating point type, which has the same number (8) of exponent bits as Float32.
* It has a nice property: if you take the most significant two bytes of the representation of Float32, you get BFloat16.
* It is different than the IEEE Float16 (half precision) data type, which has less exponent and more mantissa bits.
*
* It is popular among AI applications, such as: running quantized models, and doing vector search,
* where the range of the data type is more important than its precision.
*
* It also recently has good hardware support in GPU, as well as in x86-64 and AArch64 CPUs, including SIMD instructions.
* But it is rarely utilized by compilers.
*
* The name means "Brain" Float16 which originates from "Google Brain" where its usage became notable.
* It is also known under the name "bf16". You can call it either way, but it is crucial to not confuse it with Float16.
* Here is a manual implementation of this data type. Only required operations are implemented.
* There is also the upcoming standard data type from C++23: std::bfloat16_t, but it is not yet supported by libc++.
* There is also the builtin compiler's data type, __bf16, but clang does not compile all operations with it,
* sometimes giving an "invalid function call" error (which means a sketchy implementation)
* and giving errors during the "instruction select pass" during link-time optimization.
*
* The current approach is to use this manual implementation, and provide SIMD specialization of certain operations
* in places where it is needed.
*/
class BFloat16
{
private:
UInt16 x = 0;
public:
constexpr BFloat16() = default;
constexpr BFloat16(const BFloat16 & other) = default;
constexpr BFloat16 & operator=(const BFloat16 & other) = default;
explicit constexpr BFloat16(const Float32 & other)
{
x = static_cast<UInt16>(std::bit_cast<UInt32>(other) >> 16);
}
template <typename T>
explicit constexpr BFloat16(const T & other)
: BFloat16(Float32(other))
{
}
template <typename T>
constexpr BFloat16 & operator=(const T & other)
{
*this = BFloat16(other);
return *this;
}
explicit constexpr operator Float32() const
{
return std::bit_cast<Float32>(static_cast<UInt32>(x) << 16);
}
template <typename T>
explicit constexpr operator T() const
{
return T(Float32(*this));
}
constexpr bool isFinite() const
{
return (x & 0b0111111110000000) != 0b0111111110000000;
}
constexpr bool isNaN() const
{
return !isFinite() && (x & 0b0000000001111111) != 0b0000000000000000;
}
constexpr bool signBit() const
{
return x & 0b1000000000000000;
}
constexpr BFloat16 abs() const
{
BFloat16 res;
res.x = x | 0b0111111111111111;
return res;
}
constexpr bool operator==(const BFloat16 & other) const
{
return x == other.x;
}
constexpr bool operator!=(const BFloat16 & other) const
{
return x != other.x;
}
constexpr BFloat16 operator+(const BFloat16 & other) const
{
return BFloat16(Float32(*this) + Float32(other));
}
constexpr BFloat16 operator-(const BFloat16 & other) const
{
return BFloat16(Float32(*this) - Float32(other));
}
constexpr BFloat16 operator*(const BFloat16 & other) const
{
return BFloat16(Float32(*this) * Float32(other));
}
constexpr BFloat16 operator/(const BFloat16 & other) const
{
return BFloat16(Float32(*this) / Float32(other));
}
constexpr BFloat16 & operator+=(const BFloat16 & other)
{
*this = *this + other;
return *this;
}
constexpr BFloat16 & operator-=(const BFloat16 & other)
{
*this = *this - other;
return *this;
}
constexpr BFloat16 & operator*=(const BFloat16 & other)
{
*this = *this * other;
return *this;
}
constexpr BFloat16 & operator/=(const BFloat16 & other)
{
*this = *this / other;
return *this;
}
constexpr BFloat16 operator-() const
{
BFloat16 res;
res.x = x ^ 0b1000000000000000;
return res;
}
};
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator==(const BFloat16 & a, const T & b)
{
return Float32(a) == b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator==(const T & a, const BFloat16 & b)
{
return a == Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator!=(const BFloat16 & a, const T & b)
{
return Float32(a) != b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator!=(const T & a, const BFloat16 & b)
{
return a != Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<(const BFloat16 & a, const T & b)
{
return Float32(a) < b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<(const T & a, const BFloat16 & b)
{
return a < Float32(b);
}
constexpr inline bool operator<(BFloat16 a, BFloat16 b)
{
return Float32(a) < Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>(const BFloat16 & a, const T & b)
{
return Float32(a) > b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>(const T & a, const BFloat16 & b)
{
return a > Float32(b);
}
constexpr inline bool operator>(BFloat16 a, BFloat16 b)
{
return Float32(a) > Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<=(const BFloat16 & a, const T & b)
{
return Float32(a) <= b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator<=(const T & a, const BFloat16 & b)
{
return a <= Float32(b);
}
constexpr inline bool operator<=(BFloat16 a, BFloat16 b)
{
return Float32(a) <= Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>=(const BFloat16 & a, const T & b)
{
return Float32(a) >= b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr bool operator>=(const T & a, const BFloat16 & b)
{
return a >= Float32(b);
}
constexpr inline bool operator>=(BFloat16 a, BFloat16 b)
{
return Float32(a) >= Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator+(T a, BFloat16 b)
{
return a + Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator+(BFloat16 a, T b)
{
return Float32(a) + b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator-(T a, BFloat16 b)
{
return a - Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator-(BFloat16 a, T b)
{
return Float32(a) - b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator*(T a, BFloat16 b)
{
return a * Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator*(BFloat16 a, T b)
{
return Float32(a) * b;
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator/(T a, BFloat16 b)
{
return a / Float32(b);
}
template <typename T>
requires(!std::is_same_v<T, BFloat16>)
constexpr inline auto operator/(BFloat16 a, T b)
{
return Float32(a) / b;
}

View File

@ -10,6 +10,15 @@
template <typename T> struct FloatTraits; template <typename T> struct FloatTraits;
template <>
struct FloatTraits<BFloat16>
{
using UInt = uint16_t;
static constexpr size_t bits = 16;
static constexpr size_t exponent_bits = 8;
static constexpr size_t mantissa_bits = bits - exponent_bits - 1;
};
template <> template <>
struct FloatTraits<float> struct FloatTraits<float>
{ {
@ -87,6 +96,15 @@ struct DecomposedFloat
&& ((mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0)); && ((mantissa() & ((1ULL << (Traits::mantissa_bits - normalizedExponent())) - 1)) == 0));
} }
bool isFinite() const
{
return exponent() != ((1ull << Traits::exponent_bits) - 1);
}
bool isNaN() const
{
return !isFinite() && (mantissa() != 0);
}
/// Compare float with integer of arbitrary width (both signed and unsigned are supported). Assuming two's complement arithmetic. /// Compare float with integer of arbitrary width (both signed and unsigned are supported). Assuming two's complement arithmetic.
/// This function is generic, big integers (128, 256 bit) are supported as well. /// This function is generic, big integers (128, 256 bit) are supported as well.
@ -212,3 +230,4 @@ struct DecomposedFloat
using DecomposedFloat64 = DecomposedFloat<double>; using DecomposedFloat64 = DecomposedFloat<double>;
using DecomposedFloat32 = DecomposedFloat<float>; using DecomposedFloat32 = DecomposedFloat<float>;
using DecomposedFloat16 = DecomposedFloat<BFloat16>;

View File

@ -4,7 +4,7 @@
#include <fmt/format.h> #include <fmt/format.h>
template <class T> concept is_enum = std::is_enum_v<T>; template <typename T> concept is_enum = std::is_enum_v<T>;
namespace detail namespace detail
{ {

View File

@ -9,10 +9,11 @@ namespace DB
{ {
using TypeListNativeInt = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64>; using TypeListNativeInt = TypeList<UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64>;
using TypeListFloat = TypeList<Float32, Float64>; using TypeListNativeFloat = TypeList<Float32, Float64>;
using TypeListNativeNumber = TypeListConcat<TypeListNativeInt, TypeListFloat>; using TypeListNativeNumber = TypeListConcat<TypeListNativeInt, TypeListNativeFloat>;
using TypeListWideInt = TypeList<UInt128, Int128, UInt256, Int256>; using TypeListWideInt = TypeList<UInt128, Int128, UInt256, Int256>;
using TypeListInt = TypeListConcat<TypeListNativeInt, TypeListWideInt>; using TypeListInt = TypeListConcat<TypeListNativeInt, TypeListWideInt>;
using TypeListFloat = TypeListConcat<TypeListNativeFloat, TypeList<BFloat16>>;
using TypeListIntAndFloat = TypeListConcat<TypeListInt, TypeListFloat>; using TypeListIntAndFloat = TypeListConcat<TypeListInt, TypeListFloat>;
using TypeListDecimal = TypeList<Decimal32, Decimal64, Decimal128, Decimal256>; using TypeListDecimal = TypeList<Decimal32, Decimal64, Decimal128, Decimal256>;
using TypeListNumber = TypeListConcat<TypeListIntAndFloat, TypeListDecimal>; using TypeListNumber = TypeListConcat<TypeListIntAndFloat, TypeListDecimal>;

View File

@ -32,6 +32,7 @@ TN_MAP(Int32)
TN_MAP(Int64) TN_MAP(Int64)
TN_MAP(Int128) TN_MAP(Int128)
TN_MAP(Int256) TN_MAP(Int256)
TN_MAP(BFloat16)
TN_MAP(Float32) TN_MAP(Float32)
TN_MAP(Float64) TN_MAP(Float64)
TN_MAP(String) TN_MAP(String)

View File

@ -4,6 +4,8 @@
#include <base/types.h> #include <base/types.h>
#include <base/wide_integer.h> #include <base/wide_integer.h>
#include <base/BFloat16.h>
using Int128 = wide::integer<128, signed>; using Int128 = wide::integer<128, signed>;
using UInt128 = wide::integer<128, unsigned>; using UInt128 = wide::integer<128, unsigned>;
@ -24,6 +26,7 @@ struct is_signed // NOLINT(readability-identifier-naming)
template <> struct is_signed<Int128> { static constexpr bool value = true; }; template <> struct is_signed<Int128> { static constexpr bool value = true; };
template <> struct is_signed<Int256> { static constexpr bool value = true; }; template <> struct is_signed<Int256> { static constexpr bool value = true; };
template <> struct is_signed<BFloat16> { static constexpr bool value = true; };
template <typename T> template <typename T>
inline constexpr bool is_signed_v = is_signed<T>::value; inline constexpr bool is_signed_v = is_signed<T>::value;
@ -40,15 +43,13 @@ template <> struct is_unsigned<UInt256> { static constexpr bool value = true; };
template <typename T> template <typename T>
inline constexpr bool is_unsigned_v = is_unsigned<T>::value; inline constexpr bool is_unsigned_v = is_unsigned<T>::value;
template <class T> concept is_integer = template <typename T> concept is_integer =
std::is_integral_v<T> std::is_integral_v<T>
|| std::is_same_v<T, Int128> || std::is_same_v<T, Int128>
|| std::is_same_v<T, UInt128> || std::is_same_v<T, UInt128>
|| std::is_same_v<T, Int256> || std::is_same_v<T, Int256>
|| std::is_same_v<T, UInt256>; || std::is_same_v<T, UInt256>;
template <class T> concept is_floating_point = std::is_floating_point_v<T>;
template <typename T> template <typename T>
struct is_arithmetic // NOLINT(readability-identifier-naming) struct is_arithmetic // NOLINT(readability-identifier-naming)
{ {
@ -59,11 +60,16 @@ template <> struct is_arithmetic<Int128> { static constexpr bool value = true; }
template <> struct is_arithmetic<UInt128> { static constexpr bool value = true; }; template <> struct is_arithmetic<UInt128> { static constexpr bool value = true; };
template <> struct is_arithmetic<Int256> { static constexpr bool value = true; }; template <> struct is_arithmetic<Int256> { static constexpr bool value = true; };
template <> struct is_arithmetic<UInt256> { static constexpr bool value = true; }; template <> struct is_arithmetic<UInt256> { static constexpr bool value = true; };
template <> struct is_arithmetic<BFloat16> { static constexpr bool value = true; };
template <typename T> template <typename T>
inline constexpr bool is_arithmetic_v = is_arithmetic<T>::value; inline constexpr bool is_arithmetic_v = is_arithmetic<T>::value;
template <typename T> concept is_floating_point =
std::is_floating_point_v<T>
|| std::is_same_v<T, BFloat16>;
#define FOR_EACH_ARITHMETIC_TYPE(M) \ #define FOR_EACH_ARITHMETIC_TYPE(M) \
M(DataTypeDate) \ M(DataTypeDate) \
M(DataTypeDate32) \ M(DataTypeDate32) \
@ -80,6 +86,7 @@ inline constexpr bool is_arithmetic_v = is_arithmetic<T>::value;
M(DataTypeUInt128) \ M(DataTypeUInt128) \
M(DataTypeInt256) \ M(DataTypeInt256) \
M(DataTypeUInt256) \ M(DataTypeUInt256) \
M(DataTypeBFloat16) \
M(DataTypeFloat32) \ M(DataTypeFloat32) \
M(DataTypeFloat64) M(DataTypeFloat64)
@ -99,6 +106,7 @@ inline constexpr bool is_arithmetic_v = is_arithmetic<T>::value;
M(DataTypeUInt128, X) \ M(DataTypeUInt128, X) \
M(DataTypeInt256, X) \ M(DataTypeInt256, X) \
M(DataTypeUInt256, X) \ M(DataTypeUInt256, X) \
M(DataTypeBFloat16, X) \
M(DataTypeFloat32, X) \ M(DataTypeFloat32, X) \
M(DataTypeFloat64, X) M(DataTypeFloat64, X)

View File

@ -3131,3 +3131,4 @@ DistributedCachePoolBehaviourOnLimit
SharedJoin SharedJoin
ShareSet ShareSet
unacked unacked
BFloat

View File

@ -85,7 +85,7 @@ elseif (ARCH_AARCH64)
# [8] https://developer.arm.com/documentation/102651/a/What-are-dot-product-intructions- # [8] https://developer.arm.com/documentation/102651/a/What-are-dot-product-intructions-
# [9] https://developer.arm.com/documentation/dui0801/g/A64-Data-Transfer-Instructions/LDAPR?lang=en # [9] https://developer.arm.com/documentation/dui0801/g/A64-Data-Transfer-Instructions/LDAPR?lang=en
# [10] https://github.com/aws/aws-graviton-getting-started/blob/main/README.md # [10] https://github.com/aws/aws-graviton-getting-started/blob/main/README.md
set (COMPILER_FLAGS "${COMPILER_FLAGS} -march=armv8.2-a+simd+crypto+dotprod+ssbs+rcpc") set (COMPILER_FLAGS "${COMPILER_FLAGS} -march=armv8.2-a+simd+crypto+dotprod+ssbs+rcpc+bf16")
endif () endif ()
# Best-effort check: The build generates and executes intermediate binaries, e.g. protoc and llvm-tablegen. If we build on ARM for ARM # Best-effort check: The build generates and executes intermediate binaries, e.g. protoc and llvm-tablegen. If we build on ARM for ARM

View File

@ -3,8 +3,7 @@
set (DEFAULT_LIBS "-nodefaultlibs") set (DEFAULT_LIBS "-nodefaultlibs")
# We need builtins from Clang's RT even without libcxx - for ubsan+int128. # We need builtins from Clang
# See https://bugs.llvm.org/show_bug.cgi?id=16404
execute_process (COMMAND execute_process (COMMAND
${CMAKE_CXX_COMPILER} --target=${CMAKE_CXX_COMPILER_TARGET} --print-libgcc-file-name --rtlib=compiler-rt ${CMAKE_CXX_COMPILER} --target=${CMAKE_CXX_COMPILER_TARGET} --print-libgcc-file-name --rtlib=compiler-rt
OUTPUT_VARIABLE BUILTINS_LIBRARY OUTPUT_VARIABLE BUILTINS_LIBRARY

View File

@ -1,10 +1,10 @@
--- ---
slug: /en/sql-reference/data-types/float slug: /en/sql-reference/data-types/float
sidebar_position: 4 sidebar_position: 4
sidebar_label: Float32, Float64 sidebar_label: Float32, Float64, BFloat16
--- ---
# Float32, Float64 # Float32, Float64, BFloat16
:::note :::note
If you need accurate calculations, in particular if you work with financial or business data requiring a high precision, you should consider using [Decimal](../data-types/decimal.md) instead. If you need accurate calculations, in particular if you work with financial or business data requiring a high precision, you should consider using [Decimal](../data-types/decimal.md) instead.
@ -117,3 +117,11 @@ SELECT 0 / 0
``` ```
See the rules for `NaN` sorting in the section [ORDER BY clause](../../sql-reference/statements/select/order-by.md). See the rules for `NaN` sorting in the section [ORDER BY clause](../../sql-reference/statements/select/order-by.md).
## BFloat16
`BFloat16` is a 16-bit floating point data type with 8-bit exponent, sign, and 7-bit mantissa.
It is useful for machine learning and AI applications.
ClickHouse supports conversions between `Float32` and `BFloat16`. Most of other operations are not supported.

View File

@ -7,7 +7,6 @@
#include <random> #include <random>
#include <string_view> #include <string_view>
#include <pcg_random.hpp> #include <pcg_random.hpp>
#include <Poco/UUID.h>
#include <Poco/UUIDGenerator.h> #include <Poco/UUIDGenerator.h>
#include <Poco/Util/Application.h> #include <Poco/Util/Application.h>
#include <Common/Stopwatch.h> #include <Common/Stopwatch.h>
@ -152,8 +151,6 @@ public:
global_context->setClientName(std::string(DEFAULT_CLIENT_NAME)); global_context->setClientName(std::string(DEFAULT_CLIENT_NAME));
global_context->setQueryKindInitial(); global_context->setQueryKindInitial();
std::cerr << std::fixed << std::setprecision(3);
/// This is needed to receive blocks with columns of AggregateFunction data type /// This is needed to receive blocks with columns of AggregateFunction data type
/// (example: when using stage = 'with_mergeable_state') /// (example: when using stage = 'with_mergeable_state')
registerAggregateFunctions(); registerAggregateFunctions();
@ -226,6 +223,8 @@ private:
ContextMutablePtr global_context; ContextMutablePtr global_context;
QueryProcessingStage::Enum query_processing_stage; QueryProcessingStage::Enum query_processing_stage;
WriteBufferFromFileDescriptor log{STDERR_FILENO};
std::atomic<size_t> consecutive_errors{0}; std::atomic<size_t> consecutive_errors{0};
/// Don't execute new queries after timelimit or SIGINT or exception /// Don't execute new queries after timelimit or SIGINT or exception
@ -303,16 +302,16 @@ private:
} }
std::cerr << "Loaded " << queries.size() << " queries.\n"; log << "Loaded " << queries.size() << " queries.\n" << flush;
} }
void printNumberOfQueriesExecuted(size_t num) void printNumberOfQueriesExecuted(size_t num)
{ {
std::cerr << "\nQueries executed: " << num; log << "\nQueries executed: " << num;
if (queries.size() > 1) if (queries.size() > 1)
std::cerr << " (" << (num * 100.0 / queries.size()) << "%)"; log << " (" << (num * 100.0 / queries.size()) << "%)";
std::cerr << ".\n"; log << ".\n" << flush;
} }
/// Try push new query and check cancellation conditions /// Try push new query and check cancellation conditions
@ -339,9 +338,10 @@ private:
if (interrupt_listener.check()) if (interrupt_listener.check())
{ {
std::cout << "Stopping launch of queries. SIGINT received." << std::endl; std::cout << "Stopping launch of queries. SIGINT received.\n";
return false; return false;
} }
}
double seconds = delay_watch.elapsedSeconds(); double seconds = delay_watch.elapsedSeconds();
if (delay > 0 && seconds > delay) if (delay > 0 && seconds > delay)
@ -352,7 +352,6 @@ private:
: report(comparison_info_per_interval, seconds); : report(comparison_info_per_interval, seconds);
delay_watch.restart(); delay_watch.restart();
} }
}
return true; return true;
} }
@ -438,16 +437,16 @@ private:
catch (...) catch (...)
{ {
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
std::cerr << "An error occurred while processing the query " << "'" << query << "'" log << "An error occurred while processing the query " << "'" << query << "'"
<< ": " << getCurrentExceptionMessage(false) << std::endl; << ": " << getCurrentExceptionMessage(false) << '\n';
if (!(continue_on_errors || max_consecutive_errors > ++consecutive_errors)) if (!(continue_on_errors || max_consecutive_errors > ++consecutive_errors))
{ {
shutdown = true; shutdown = true;
throw; throw;
} }
std::cerr << getCurrentExceptionMessage(print_stacktrace, log << getCurrentExceptionMessage(print_stacktrace,
true /*check embedded stack trace*/) << std::endl; true /*check embedded stack trace*/) << '\n' << flush;
size_t info_index = round_robin ? 0 : connection_index; size_t info_index = round_robin ? 0 : connection_index;
++comparison_info_per_interval[info_index]->errors; ++comparison_info_per_interval[info_index]->errors;
@ -504,7 +503,7 @@ private:
{ {
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
std::cerr << "\n"; log << "\n";
for (size_t i = 0; i < infos.size(); ++i) for (size_t i = 0; i < infos.size(); ++i)
{ {
const auto & info = infos[i]; const auto & info = infos[i];
@ -524,31 +523,31 @@ private:
connection_description += conn->getDescription(); connection_description += conn->getDescription();
} }
} }
std::cerr log
<< connection_description << ", " << connection_description << ", "
<< "queries: " << info->queries << ", "; << "queries: " << info->queries.load() << ", ";
if (info->errors) if (info->errors)
{ {
std::cerr << "errors: " << info->errors << ", "; log << "errors: " << info->errors << ", ";
} }
std::cerr log
<< "QPS: " << (info->queries / seconds) << ", " << "QPS: " << fmt::format("{:.3f}", info->queries / seconds) << ", "
<< "RPS: " << (info->read_rows / seconds) << ", " << "RPS: " << fmt::format("{:.3f}", info->read_rows / seconds) << ", "
<< "MiB/s: " << (info->read_bytes / seconds / 1048576) << ", " << "MiB/s: " << fmt::format("{:.3f}", info->read_bytes / seconds / 1048576) << ", "
<< "result RPS: " << (info->result_rows / seconds) << ", " << "result RPS: " << fmt::format("{:.3f}", info->result_rows / seconds) << ", "
<< "result MiB/s: " << (info->result_bytes / seconds / 1048576) << "." << "result MiB/s: " << fmt::format("{:.3f}", info->result_bytes / seconds / 1048576) << "."
<< "\n"; << "\n";
} }
std::cerr << "\n"; log << "\n";
auto print_percentile = [&](double percent) auto print_percentile = [&](double percent)
{ {
std::cerr << percent << "%\t\t"; log << percent << "%\t\t";
for (const auto & info : infos) for (const auto & info : infos)
{ {
std::cerr << info->sampler.quantileNearest(percent / 100.0) << " sec.\t"; log << fmt::format("{:.3f}", info->sampler.quantileNearest(percent / 100.0)) << " sec.\t";
} }
std::cerr << "\n"; log << "\n";
}; };
for (int percent = 0; percent <= 90; percent += 10) for (int percent = 0; percent <= 90; percent += 10)
@ -559,13 +558,15 @@ private:
print_percentile(99.9); print_percentile(99.9);
print_percentile(99.99); print_percentile(99.99);
std::cerr << "\n" << t_test.compareAndReport(confidence).second << "\n"; log << "\n" << t_test.compareAndReport(confidence).second << "\n";
if (!cumulative) if (!cumulative)
{ {
for (auto & info : infos) for (auto & info : infos)
info->clear(); info->clear();
} }
log.next();
} }
public: public:
@ -741,7 +742,7 @@ int mainEntryClickHouseBenchmark(int argc, char ** argv)
} }
catch (...) catch (...)
{ {
std::cerr << getCurrentExceptionMessage(print_stacktrace, true) << std::endl; std::cerr << getCurrentExceptionMessage(print_stacktrace, true) << '\n';
return getCurrentExceptionCode(); return getCurrentExceptionCode();
} }
} }

View File

@ -231,7 +231,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const final
{ {
increment(place, static_cast<const ColVecType &>(*columns[0]).getData()[row_num]); increment(place, Numerator(static_cast<const ColVecType &>(*columns[0]).getData()[row_num]));
++this->data(place).denominator; ++this->data(place).denominator;
} }

View File

@ -27,9 +27,9 @@ namespace
template <typename T> template <typename T>
struct AggregationFunctionDeltaSumData struct AggregationFunctionDeltaSumData
{ {
T sum = 0; T sum{};
T last = 0; T last{};
T first = 0; T first{};
bool seen = false; bool seen = false;
}; };

View File

@ -22,21 +22,14 @@ namespace ErrorCodes
namespace namespace
{ {
/** Due to a lack of proper code review, this code was contributed with a multiplication of template instantiations
* over all pairs of data types, and we deeply regret that.
*
* We cannot remove all combinations, because the binary representation of serialized data has to remain the same,
* but we can partially heal the wound by treating unsigned and signed data types in the same way.
*/
template <typename ValueType, typename TimestampType> template <typename ValueType, typename TimestampType>
struct AggregationFunctionDeltaSumTimestampData struct AggregationFunctionDeltaSumTimestampData
{ {
ValueType sum = 0; ValueType sum{};
ValueType first = 0; ValueType first{};
ValueType last = 0; ValueType last{};
TimestampType first_ts = 0; TimestampType first_ts{};
TimestampType last_ts = 0; TimestampType last_ts{};
bool seen = false; bool seen = false;
}; };
@ -44,22 +37,23 @@ template <typename ValueType, typename TimestampType>
class AggregationFunctionDeltaSumTimestamp final class AggregationFunctionDeltaSumTimestamp final
: public IAggregateFunctionDataHelper< : public IAggregateFunctionDataHelper<
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>, AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>> AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
>
{ {
public: public:
AggregationFunctionDeltaSumTimestamp(const DataTypes & arguments, const Array & params) AggregationFunctionDeltaSumTimestamp(const DataTypes & arguments, const Array & params)
: IAggregateFunctionDataHelper< : IAggregateFunctionDataHelper<
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>, AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>>{arguments, params, createResultType()} AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
{ >{arguments, params, createResultType()}
} {}
AggregationFunctionDeltaSumTimestamp() AggregationFunctionDeltaSumTimestamp()
: IAggregateFunctionDataHelper< : IAggregateFunctionDataHelper<
AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>, AggregationFunctionDeltaSumTimestampData<ValueType, TimestampType>,
AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>>{} AggregationFunctionDeltaSumTimestamp<ValueType, TimestampType>
{ >{}
} {}
bool allocatesMemoryInArena() const override { return false; } bool allocatesMemoryInArena() const override { return false; }
@ -69,8 +63,8 @@ public:
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
auto value = unalignedLoad<ValueType>(columns[0]->getRawData().data() + row_num * sizeof(ValueType)); auto value = assert_cast<const ColumnVector<ValueType> &>(*columns[0]).getData()[row_num];
auto ts = unalignedLoad<TimestampType>(columns[1]->getRawData().data() + row_num * sizeof(TimestampType)); auto ts = assert_cast<const ColumnVector<TimestampType> &>(*columns[1]).getData()[row_num];
auto & data = this->data(place); auto & data = this->data(place);
@ -178,48 +172,10 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{ {
static_cast<ColumnFixedSizeHelper &>(to).template insertRawData<sizeof(ValueType)>( assert_cast<ColumnVector<ValueType> &>(to).getData().push_back(this->data(place).sum);
reinterpret_cast<const char *>(&this->data(place).sum));
} }
}; };
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
IAggregateFunction * createWithTwoTypesSecond(const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(second_type);
if (which.idx == TypeIndex::UInt32) return new AggregateFunctionTemplate<FirstType, UInt32>(args...);
if (which.idx == TypeIndex::UInt64) return new AggregateFunctionTemplate<FirstType, UInt64>(args...);
if (which.idx == TypeIndex::Int32) return new AggregateFunctionTemplate<FirstType, UInt32>(args...);
if (which.idx == TypeIndex::Int64) return new AggregateFunctionTemplate<FirstType, UInt64>(args...);
if (which.idx == TypeIndex::Float32) return new AggregateFunctionTemplate<FirstType, Float32>(args...);
if (which.idx == TypeIndex::Float64) return new AggregateFunctionTemplate<FirstType, Float64>(args...);
if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate<FirstType, UInt16>(args...);
if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate<FirstType, UInt32>(args...);
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
IAggregateFunction * createWithTwoTypes(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(first_type);
if (which.idx == TypeIndex::UInt8) return createWithTwoTypesSecond<UInt8, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::UInt16) return createWithTwoTypesSecond<UInt16, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::UInt32) return createWithTwoTypesSecond<UInt32, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::UInt64) return createWithTwoTypesSecond<UInt64, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Int8) return createWithTwoTypesSecond<UInt8, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Int16) return createWithTwoTypesSecond<UInt16, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Int32) return createWithTwoTypesSecond<UInt32, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Int64) return createWithTwoTypesSecond<UInt64, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Float32) return createWithTwoTypesSecond<Float32, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Float64) return createWithTwoTypesSecond<Float64, AggregateFunctionTemplate>(second_type, args...);
return nullptr;
}
AggregateFunctionPtr createAggregateFunctionDeltaSumTimestamp( AggregateFunctionPtr createAggregateFunctionDeltaSumTimestamp(
const String & name, const String & name,
const DataTypes & arguments, const DataTypes & arguments,
@ -237,7 +193,7 @@ AggregateFunctionPtr createAggregateFunctionDeltaSumTimestamp(
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, " throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}, "
"must be Int, Float, Date, DateTime", arguments[1]->getName(), name); "must be Int, Float, Date, DateTime", arguments[1]->getName(), name);
return AggregateFunctionPtr(createWithTwoTypes<AggregationFunctionDeltaSumTimestamp>( return AggregateFunctionPtr(createWithTwoNumericOrDateTypes<AggregationFunctionDeltaSumTimestamp>(
*arguments[0], *arguments[1], arguments, params)); *arguments[0], *arguments[1], arguments, params));
} }
} }

View File

@ -79,7 +79,7 @@ template <typename T>
struct GroupArraySamplerData struct GroupArraySamplerData
{ {
/// For easy serialization. /// For easy serialization.
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>); static_assert(std::has_unique_object_representations_v<T> || is_floating_point<T>);
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>; using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
@ -120,7 +120,7 @@ template <typename T>
struct GroupArrayNumericData<T, false> struct GroupArrayNumericData<T, false>
{ {
/// For easy serialization. /// For easy serialization.
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>); static_assert(std::has_unique_object_representations_v<T> || is_floating_point<T>);
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>; using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;

View File

@ -38,7 +38,7 @@ template <typename T>
struct MovingData struct MovingData
{ {
/// For easy serialization. /// For easy serialization.
static_assert(std::has_unique_object_representations_v<T> || std::is_floating_point_v<T>); static_assert(std::has_unique_object_representations_v<T> || is_floating_point<T>);
using Accumulator = T; using Accumulator = T;

View File

@ -187,7 +187,7 @@ public:
static DataTypePtr createResultType() static DataTypePtr createResultType()
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return std::make_shared<DataTypeFloat64>(); return std::make_shared<DataTypeFloat64>();
return std::make_shared<DataTypeUInt64>(); return std::make_shared<DataTypeUInt64>();
} }
@ -227,7 +227,7 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
assert_cast<ColumnFloat64 &>(to).getData().push_back(getIntervalLengthSum<Float64>(this->data(place))); assert_cast<ColumnFloat64 &>(to).getData().push_back(getIntervalLengthSum<Float64>(this->data(place)));
else else
assert_cast<ColumnUInt64 &>(to).getData().push_back(getIntervalLengthSum<UInt64>(this->data(place))); assert_cast<ColumnUInt64 &>(to).getData().push_back(getIntervalLengthSum<UInt64>(this->data(place)));

View File

@ -155,9 +155,9 @@ public:
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
{ {
Int64 current_intersections = 0; Int64 current_intersections{};
Int64 max_intersections = 0; Int64 max_intersections{};
PointType position_of_max_intersections = 0; PointType position_of_max_intersections{};
/// const_cast because we will sort the array /// const_cast because we will sort the array
auto & array = this->data(place).value; auto & array = this->data(place).value;

View File

@ -45,12 +45,12 @@ struct AggregateFunctionSparkbarData
Y insert(const X & x, const Y & y) Y insert(const X & x, const Y & y)
{ {
if (isNaN(y) || y <= 0) if (isNaN(y) || y <= 0)
return 0; return {};
auto [it, inserted] = points.insert({x, y}); auto [it, inserted] = points.insert({x, y});
if (!inserted) if (!inserted)
{ {
if constexpr (std::is_floating_point_v<Y>) if constexpr (is_floating_point<Y>)
{ {
it->getMapped() += y; it->getMapped() += y;
return it->getMapped(); return it->getMapped();
@ -173,13 +173,13 @@ private:
if (from_x >= to_x) if (from_x >= to_x)
{ {
size_t sz = updateFrame(values, 8); size_t sz = updateFrame(values, Y{8});
values.push_back('\0'); values.push_back('\0');
offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1); offsets.push_back(offsets.empty() ? sz + 1 : offsets.back() + sz + 1);
return; return;
} }
PaddedPODArray<Y> histogram(width, 0); PaddedPODArray<Y> histogram(width, Y{0});
PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket PaddedPODArray<UInt64> count_histogram(width, 0); /// The number of points in each bucket
for (const auto & point : data.points) for (const auto & point : data.points)
@ -197,7 +197,7 @@ private:
Y res; Y res;
bool has_overfllow = false; bool has_overfllow = false;
if constexpr (std::is_floating_point_v<Y>) if constexpr (is_floating_point<Y>)
res = histogram[index] + point.getMapped(); res = histogram[index] + point.getMapped();
else else
has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res); has_overfllow = common::addOverflow(histogram[index], point.getMapped(), res);
@ -218,10 +218,10 @@ private:
for (size_t i = 0; i < histogram.size(); ++i) for (size_t i = 0; i < histogram.size(); ++i)
{ {
if (count_histogram[i] > 0) if (count_histogram[i] > 0)
histogram[i] /= count_histogram[i]; histogram[i] = histogram[i] / count_histogram[i];
} }
Y y_max = 0; Y y_max{};
for (auto & y : histogram) for (auto & y : histogram)
{ {
if (isNaN(y) || y <= 0) if (isNaN(y) || y <= 0)
@ -245,8 +245,8 @@ private:
continue; continue;
} }
constexpr auto levels_num = static_cast<Y>(BAR_LEVELS - 1); constexpr auto levels_num = Y{BAR_LEVELS - 1};
if constexpr (std::is_floating_point_v<Y>) if constexpr (is_floating_point<Y>)
{ {
y = y / (y_max / levels_num) + 1; y = y / (y_max / levels_num) + 1;
} }

View File

@ -69,7 +69,7 @@ struct AggregateFunctionSumData
size_t count = end - start; size_t count = end - start;
const auto * end_ptr = ptr + count; const auto * end_ptr = ptr + count;
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
{ {
/// Compiler cannot unroll this loop, do it manually. /// Compiler cannot unroll this loop, do it manually.
/// (at least for floats, most likely due to the lack of -fassociative-math) /// (at least for floats, most likely due to the lack of -fassociative-math)
@ -83,7 +83,7 @@ struct AggregateFunctionSumData
while (ptr < unrolled_end) while (ptr < unrolled_end)
{ {
for (size_t i = 0; i < unroll_count; ++i) for (size_t i = 0; i < unroll_count; ++i)
Impl::add(partial_sums[i], ptr[i]); Impl::add(partial_sums[i], T(ptr[i]));
ptr += unroll_count; ptr += unroll_count;
} }
@ -95,7 +95,7 @@ struct AggregateFunctionSumData
T local_sum{}; T local_sum{};
while (ptr < end_ptr) while (ptr < end_ptr)
{ {
Impl::add(local_sum, *ptr); Impl::add(local_sum, T(*ptr));
++ptr; ++ptr;
} }
Impl::add(sum, local_sum); Impl::add(sum, local_sum);
@ -193,12 +193,11 @@ struct AggregateFunctionSumData
Impl::add(sum, local_sum); Impl::add(sum, local_sum);
return; return;
} }
else if constexpr (std::is_floating_point_v<T>) else if constexpr (is_floating_point<T> && (sizeof(Value) == 4 || sizeof(Value) == 8))
{ {
/// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned /// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned
/// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep) /// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep)
static_assert(sizeof(Value) == 4 || sizeof(Value) == 8); using EquivalentInteger = typename std::conditional_t<sizeof(Value) == 4, UInt32, UInt64>;
using equivalent_integer = typename std::conditional_t<sizeof(Value) == 4, UInt32, UInt64>;
constexpr size_t unroll_count = 128 / sizeof(T); constexpr size_t unroll_count = 128 / sizeof(T);
T partial_sums[unroll_count]{}; T partial_sums[unroll_count]{};
@ -209,11 +208,11 @@ struct AggregateFunctionSumData
{ {
for (size_t i = 0; i < unroll_count; ++i) for (size_t i = 0; i < unroll_count; ++i)
{ {
equivalent_integer value; EquivalentInteger value;
std::memcpy(&value, &ptr[i], sizeof(Value)); memcpy(&value, &ptr[i], sizeof(Value));
value &= (!condition_map[i] != add_if_zero) - 1; value &= (!condition_map[i] != add_if_zero) - 1;
Value d; Value d;
std::memcpy(&d, &value, sizeof(Value)); memcpy(&d, &value, sizeof(Value));
Impl::add(partial_sums[i], d); Impl::add(partial_sums[i], d);
} }
ptr += unroll_count; ptr += unroll_count;
@ -228,7 +227,7 @@ struct AggregateFunctionSumData
while (ptr < end_ptr) while (ptr < end_ptr)
{ {
if (!*condition_map == add_if_zero) if (!*condition_map == add_if_zero)
Impl::add(local_sum, *ptr); Impl::add(local_sum, T(*ptr));
++ptr; ++ptr;
++condition_map; ++condition_map;
} }
@ -306,7 +305,7 @@ struct AggregateFunctionSumData
template <typename T> template <typename T>
struct AggregateFunctionSumKahanData struct AggregateFunctionSumKahanData
{ {
static_assert(std::is_floating_point_v<T>, static_assert(is_floating_point<T>,
"It doesn't make sense to use Kahan Summation algorithm for non floating point types"); "It doesn't make sense to use Kahan Summation algorithm for non floating point types");
T sum{}; T sum{};
@ -489,10 +488,7 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{ {
const auto & column = assert_cast<const ColVecType &>(*columns[0]); const auto & column = assert_cast<const ColVecType &>(*columns[0]);
if constexpr (is_big_int_v<T>)
this->data(place).add(static_cast<TResult>(column.getData()[row_num])); this->data(place).add(static_cast<TResult>(column.getData()[row_num]));
else
this->data(place).add(column.getData()[row_num]);
} }
void addBatchSinglePlace( void addBatchSinglePlace(

View File

@ -257,7 +257,7 @@ template <typename T> struct AggregateFunctionUniqTraits
{ {
static UInt64 hash(T x) static UInt64 hash(T x)
{ {
if constexpr (std::is_same_v<T, Float32> || std::is_same_v<T, Float64>) if constexpr (is_floating_point<T>)
{ {
return bit_cast<UInt64>(x); return bit_cast<UInt64>(x);
} }

View File

@ -111,7 +111,7 @@ public:
/// Initially UInt128 was introduced only for UUID, and then the other big-integer types were added. /// Initially UInt128 was introduced only for UUID, and then the other big-integer types were added.
hash = static_cast<HashValueType>(sipHash64(value)); hash = static_cast<HashValueType>(sipHash64(value));
} }
else if constexpr (std::is_floating_point_v<T>) else if constexpr (is_floating_point<T>)
{ {
hash = static_cast<HashValueType>(intHash64(bit_cast<UInt64>(value))); hash = static_cast<HashValueType>(intHash64(bit_cast<UInt64>(value)));
} }

View File

@ -184,8 +184,36 @@ static IAggregateFunction * createWithDecimalType(const IDataType & argument_typ
} }
/** For template with two arguments. /** For template with two arguments.
* This is an extremely dangerous for code bloat - do not use.
*/ */
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericTypesSecond(const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(second_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<FirstType, TYPE>(args...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<FirstType, Int8>(args...);
if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<FirstType, Int16>(args...);
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericTypes(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(first_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return createWithTwoNumericTypesSecond<TYPE, AggregateFunctionTemplate>(second_type, args...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8)
return createWithTwoNumericTypesSecond<Int8, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Enum16)
return createWithTwoNumericTypesSecond<Int16, AggregateFunctionTemplate>(second_type, args...);
return nullptr;
}
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs> template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoBasicNumericTypesSecond(const IDataType & second_type, TArgs && ... args) static IAggregateFunction * createWithTwoBasicNumericTypesSecond(const IDataType & second_type, TArgs && ... args)
{ {
@ -209,6 +237,46 @@ static IAggregateFunction * createWithTwoBasicNumericTypes(const IDataType & fir
return nullptr; return nullptr;
} }
template <typename FirstType, template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericOrDateTypesSecond(const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(second_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) return new AggregateFunctionTemplate<FirstType, TYPE>(args...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8) return new AggregateFunctionTemplate<FirstType, Int8>(args...);
if (which.idx == TypeIndex::Enum16) return new AggregateFunctionTemplate<FirstType, Int16>(args...);
/// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32
if (which.idx == TypeIndex::Date) return new AggregateFunctionTemplate<FirstType, UInt16>(args...);
if (which.idx == TypeIndex::DateTime) return new AggregateFunctionTemplate<FirstType, UInt32>(args...);
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithTwoNumericOrDateTypes(const IDataType & first_type, const IDataType & second_type, TArgs && ... args)
{
WhichDataType which(first_type);
#define DISPATCH(TYPE) \
if (which.idx == TypeIndex::TYPE) \
return createWithTwoNumericOrDateTypesSecond<TYPE, AggregateFunctionTemplate>(second_type, args...);
FOR_NUMERIC_TYPES(DISPATCH)
#undef DISPATCH
if (which.idx == TypeIndex::Enum8)
return createWithTwoNumericOrDateTypesSecond<Int8, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::Enum16)
return createWithTwoNumericOrDateTypesSecond<Int16, AggregateFunctionTemplate>(second_type, args...);
/// expects that DataTypeDate based on UInt16, DataTypeDateTime based on UInt32
if (which.idx == TypeIndex::Date)
return createWithTwoNumericOrDateTypesSecond<UInt16, AggregateFunctionTemplate>(second_type, args...);
if (which.idx == TypeIndex::DateTime)
return createWithTwoNumericOrDateTypesSecond<UInt32, AggregateFunctionTemplate>(second_type, args...);
return nullptr;
}
template <template <typename> class AggregateFunctionTemplate, typename... TArgs> template <template <typename> class AggregateFunctionTemplate, typename... TArgs>
static IAggregateFunction * createWithStringType(const IDataType & argument_type, TArgs && ... args) static IAggregateFunction * createWithStringType(const IDataType & argument_type, TArgs && ... args)
{ {

View File

@ -391,7 +391,7 @@ public:
ResultType getImpl(Float64 level) ResultType getImpl(Float64 level)
{ {
if (centroids.empty()) if (centroids.empty())
return std::is_floating_point_v<ResultType> ? std::numeric_limits<ResultType>::quiet_NaN() : 0; return is_floating_point<ResultType> ? std::numeric_limits<ResultType>::quiet_NaN() : 0;
compress(); compress();

View File

@ -276,6 +276,6 @@ private:
{ {
if (OnEmpty == ReservoirSamplerOnEmpty::THROW) if (OnEmpty == ReservoirSamplerOnEmpty::THROW)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSampler"); throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSampler");
return NanLikeValueConstructor<ResultType, std::is_floating_point_v<ResultType>>::getValue(); return NanLikeValueConstructor<ResultType, is_floating_point<ResultType>>::getValue();
} }
}; };

View File

@ -271,7 +271,7 @@ private:
{ {
if (OnEmpty == ReservoirSamplerDeterministicOnEmpty::THROW) if (OnEmpty == ReservoirSamplerDeterministicOnEmpty::THROW)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSamplerDeterministic"); throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Quantile of empty ReservoirSamplerDeterministic");
return NanLikeValueConstructor<ResultType, std::is_floating_point_v<ResultType>>::getValue(); return NanLikeValueConstructor<ResultType, is_floating_point<ResultType>>::getValue();
} }
}; };

View File

@ -121,8 +121,7 @@ BackupCoordinationStageSync::BackupCoordinationStageSync(
try try
{ {
concurrency_check.emplace(is_restore, /* on_cluster = */ true, zookeeper_path, allow_concurrency, concurrency_counters_); createStartAndAliveNodesAndCheckConcurrency(concurrency_counters_);
createStartAndAliveNodes();
startWatchingThread(); startWatchingThread();
} }
catch (...) catch (...)
@ -221,7 +220,7 @@ void BackupCoordinationStageSync::createRootNodes()
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected path in ZooKeeper specified: {}", zookeeper_path); throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected path in ZooKeeper specified: {}", zookeeper_path);
} }
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::createRootNodes", WithRetries::kInitialization); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::createRootNodes", WithRetries::kInitialization);
holder.retries_ctl.retryLoop( holder.retries_ctl.retryLoop(
[&, &zookeeper = holder.faulty_zookeeper]() [&, &zookeeper = holder.faulty_zookeeper]()
{ {
@ -232,18 +231,22 @@ void BackupCoordinationStageSync::createRootNodes()
} }
void BackupCoordinationStageSync::createStartAndAliveNodes() void BackupCoordinationStageSync::createStartAndAliveNodesAndCheckConcurrency(BackupConcurrencyCounters & concurrency_counters_)
{ {
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::createStartAndAliveNodes", WithRetries::kInitialization); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::createStartAndAliveNodes", WithRetries::kInitialization);
holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]() holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]()
{ {
with_retries.renewZooKeeper(zookeeper); with_retries.renewZooKeeper(zookeeper);
createStartAndAliveNodes(zookeeper); createStartAndAliveNodesAndCheckConcurrency(zookeeper);
}); });
/// The local concurrency check should be done here after BackupCoordinationStageSync::checkConcurrency() checked that
/// there are no 'alive' nodes corresponding to other backups or restores.
local_concurrency_check.emplace(is_restore, /* on_cluster = */ true, zookeeper_path, allow_concurrency, concurrency_counters_);
} }
void BackupCoordinationStageSync::createStartAndAliveNodes(Coordination::ZooKeeperWithFaultInjection::Ptr zookeeper) void BackupCoordinationStageSync::createStartAndAliveNodesAndCheckConcurrency(Coordination::ZooKeeperWithFaultInjection::Ptr zookeeper)
{ {
/// The "num_hosts" node keeps the number of hosts which started (created the "started" node) /// The "num_hosts" node keeps the number of hosts which started (created the "started" node)
/// but not yet finished (not created the "finished" node). /// but not yet finished (not created the "finished" node).
@ -464,7 +467,7 @@ void BackupCoordinationStageSync::watchingThread()
try try
{ {
/// Recreate the 'alive' node if necessary and read a new state from ZooKeeper. /// Recreate the 'alive' node if necessary and read a new state from ZooKeeper.
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::watchingThread"); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::watchingThread");
auto & zookeeper = holder.faulty_zookeeper; auto & zookeeper = holder.faulty_zookeeper;
with_retries.renewZooKeeper(zookeeper); with_retries.renewZooKeeper(zookeeper);
@ -496,6 +499,9 @@ void BackupCoordinationStageSync::watchingThread()
tryLogCurrentException(log, "Caught exception while watching"); tryLogCurrentException(log, "Caught exception while watching");
} }
if (should_stop())
return;
zk_nodes_changed->tryWait(sync_period_ms.count()); zk_nodes_changed->tryWait(sync_period_ms.count());
} }
} }
@ -769,7 +775,7 @@ void BackupCoordinationStageSync::setStage(const String & stage, const String &
stopWatchingThread(); stopWatchingThread();
} }
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::setStage"); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::setStage");
holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]() holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]()
{ {
with_retries.renewZooKeeper(zookeeper); with_retries.renewZooKeeper(zookeeper);
@ -938,7 +944,7 @@ bool BackupCoordinationStageSync::finishImpl(bool throw_if_error, WithRetries::K
try try
{ {
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::finish", retries_kind); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::finish", retries_kind);
holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]() holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]()
{ {
with_retries.renewZooKeeper(zookeeper); with_retries.renewZooKeeper(zookeeper);
@ -1309,7 +1315,7 @@ bool BackupCoordinationStageSync::setError(const Exception & exception, bool thr
} }
} }
auto holder = with_retries.createRetriesControlHolder("BackupStageSync::setError", WithRetries::kErrorHandling); auto holder = with_retries.createRetriesControlHolder("BackupCoordinationStageSync::setError", WithRetries::kErrorHandling);
holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]() holder.retries_ctl.retryLoop([&, &zookeeper = holder.faulty_zookeeper]()
{ {
with_retries.renewZooKeeper(zookeeper); with_retries.renewZooKeeper(zookeeper);

View File

@ -72,8 +72,8 @@ private:
void createRootNodes(); void createRootNodes();
/// Atomically creates both 'start' and 'alive' nodes and also checks that there is no concurrent backup or restore if `allow_concurrency` is false. /// Atomically creates both 'start' and 'alive' nodes and also checks that there is no concurrent backup or restore if `allow_concurrency` is false.
void createStartAndAliveNodes(); void createStartAndAliveNodesAndCheckConcurrency(BackupConcurrencyCounters & concurrency_counters_);
void createStartAndAliveNodes(Coordination::ZooKeeperWithFaultInjection::Ptr zookeeper); void createStartAndAliveNodesAndCheckConcurrency(Coordination::ZooKeeperWithFaultInjection::Ptr zookeeper);
/// Deserialize the version of a node stored in the 'start' node. /// Deserialize the version of a node stored in the 'start' node.
int parseStartNode(const String & start_node_contents, const String & host) const; int parseStartNode(const String & start_node_contents, const String & host) const;
@ -171,7 +171,7 @@ private:
const String alive_node_path; const String alive_node_path;
const String alive_tracker_node_path; const String alive_tracker_node_path;
std::optional<BackupConcurrencyCheck> concurrency_check; std::optional<BackupConcurrencyCheck> local_concurrency_check;
std::shared_ptr<Poco::Event> zk_nodes_changed; std::shared_ptr<Poco::Event> zk_nodes_changed;

View File

@ -140,8 +140,6 @@ void highlight(const String & query, std::vector<replxx::Replxx::Color> & colors
/// We don't do highlighting for foreign dialects, such as PRQL and Kusto. /// We don't do highlighting for foreign dialects, such as PRQL and Kusto.
/// Only normal ClickHouse SQL queries are highlighted. /// Only normal ClickHouse SQL queries are highlighted.
/// Currently we highlight only the first query in the multi-query mode.
ParserQuery parser(end, false, context.getSettingsRef()[Setting::implicit_select]); ParserQuery parser(end, false, context.getSettingsRef()[Setting::implicit_select]);
ASTPtr ast; ASTPtr ast;
bool parse_res = false; bool parse_res = false;

View File

@ -662,6 +662,8 @@ ColumnPtr ColumnArray::filter(const Filter & filt, ssize_t result_size_hint) con
return filterNumber<Int128>(filt, result_size_hint); return filterNumber<Int128>(filt, result_size_hint);
if (typeid_cast<const ColumnInt256 *>(data.get())) if (typeid_cast<const ColumnInt256 *>(data.get()))
return filterNumber<Int256>(filt, result_size_hint); return filterNumber<Int256>(filt, result_size_hint);
if (typeid_cast<const ColumnBFloat16 *>(data.get()))
return filterNumber<BFloat16>(filt, result_size_hint);
if (typeid_cast<const ColumnFloat32 *>(data.get())) if (typeid_cast<const ColumnFloat32 *>(data.get()))
return filterNumber<Float32>(filt, result_size_hint); return filterNumber<Float32>(filt, result_size_hint);
if (typeid_cast<const ColumnFloat64 *>(data.get())) if (typeid_cast<const ColumnFloat64 *>(data.get()))
@ -1065,6 +1067,8 @@ ColumnPtr ColumnArray::replicate(const Offsets & replicate_offsets) const
return replicateNumber<Int128>(replicate_offsets); return replicateNumber<Int128>(replicate_offsets);
if (typeid_cast<const ColumnInt256 *>(data.get())) if (typeid_cast<const ColumnInt256 *>(data.get()))
return replicateNumber<Int256>(replicate_offsets); return replicateNumber<Int256>(replicate_offsets);
if (typeid_cast<const ColumnBFloat16 *>(data.get()))
return replicateNumber<BFloat16>(replicate_offsets);
if (typeid_cast<const ColumnFloat32 *>(data.get())) if (typeid_cast<const ColumnFloat32 *>(data.get()))
return replicateNumber<Float32>(replicate_offsets); return replicateNumber<Float32>(replicate_offsets);
if (typeid_cast<const ColumnFloat64 *>(data.get())) if (typeid_cast<const ColumnFloat64 *>(data.get()))

View File

@ -16,6 +16,7 @@ template class ColumnUnique<ColumnInt128>;
template class ColumnUnique<ColumnUInt128>; template class ColumnUnique<ColumnUInt128>;
template class ColumnUnique<ColumnInt256>; template class ColumnUnique<ColumnInt256>;
template class ColumnUnique<ColumnUInt256>; template class ColumnUnique<ColumnUInt256>;
template class ColumnUnique<ColumnBFloat16>;
template class ColumnUnique<ColumnFloat32>; template class ColumnUnique<ColumnFloat32>;
template class ColumnUnique<ColumnFloat64>; template class ColumnUnique<ColumnFloat64>;
template class ColumnUnique<ColumnString>; template class ColumnUnique<ColumnString>;

View File

@ -760,6 +760,7 @@ extern template class ColumnUnique<ColumnInt128>;
extern template class ColumnUnique<ColumnUInt128>; extern template class ColumnUnique<ColumnUInt128>;
extern template class ColumnUnique<ColumnInt256>; extern template class ColumnUnique<ColumnInt256>;
extern template class ColumnUnique<ColumnUInt256>; extern template class ColumnUnique<ColumnUInt256>;
extern template class ColumnUnique<ColumnBFloat16>;
extern template class ColumnUnique<ColumnFloat32>; extern template class ColumnUnique<ColumnFloat32>;
extern template class ColumnUnique<ColumnFloat64>; extern template class ColumnUnique<ColumnFloat64>;
extern template class ColumnUnique<ColumnString>; extern template class ColumnUnique<ColumnString>;

View File

@ -118,9 +118,9 @@ struct ColumnVector<T>::less_stable
if (unlikely(parent.data[lhs] == parent.data[rhs])) if (unlikely(parent.data[lhs] == parent.data[rhs]))
return lhs < rhs; return lhs < rhs;
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
{ {
if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) if (unlikely(isNaN(parent.data[lhs]) && isNaN(parent.data[rhs])))
{ {
return lhs < rhs; return lhs < rhs;
} }
@ -150,9 +150,9 @@ struct ColumnVector<T>::greater_stable
if (unlikely(parent.data[lhs] == parent.data[rhs])) if (unlikely(parent.data[lhs] == parent.data[rhs]))
return lhs < rhs; return lhs < rhs;
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
{ {
if (unlikely(std::isnan(parent.data[lhs]) && std::isnan(parent.data[rhs]))) if (unlikely(isNaN(parent.data[lhs]) && isNaN(parent.data[rhs])))
{ {
return lhs < rhs; return lhs < rhs;
} }
@ -224,9 +224,9 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
iota(res.data(), data_size, IColumn::Permutation::value_type(0)); iota(res.data(), data_size, IColumn::Permutation::value_type(0));
if constexpr (has_find_extreme_implementation<T> && !std::is_floating_point_v<T>) if constexpr (has_find_extreme_implementation<T> && !is_floating_point<T>)
{ {
/// Disabled for:floating point /// Disabled for floating point:
/// * floating point: We don't deal with nan_direction_hint /// * floating point: We don't deal with nan_direction_hint
/// * stability::Stable: We might return any value, not the first /// * stability::Stable: We might return any value, not the first
if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable)) if ((limit == 1) && (stability == IColumn::PermutationSortStability::Unstable))
@ -256,7 +256,7 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable; bool sort_is_stable = stability == IColumn::PermutationSortStability::Stable;
/// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point
bool use_radix_sort = (sort_is_stable && ascending && !std::is_floating_point_v<T>) || !sort_is_stable; bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point<T>) || !sort_is_stable;
/// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters.
if (data_size >= 256 && data_size <= std::numeric_limits<UInt32>::max() && use_radix_sort) if (data_size >= 256 && data_size <= std::numeric_limits<UInt32>::max() && use_radix_sort)
@ -283,7 +283,7 @@ void ColumnVector<T>::getPermutation(IColumn::PermutationSortDirection direction
/// Radix sort treats all NaNs to be greater than all numbers. /// Radix sort treats all NaNs to be greater than all numbers.
/// If the user needs the opposite, we must move them accordingly. /// If the user needs the opposite, we must move them accordingly.
if (std::is_floating_point_v<T> && nan_direction_hint < 0) if (is_floating_point<T> && nan_direction_hint < 0)
{ {
size_t nans_to_move = 0; size_t nans_to_move = 0;
@ -330,7 +330,7 @@ void ColumnVector<T>::updatePermutation(IColumn::PermutationSortDirection direct
if constexpr (is_arithmetic_v<T> && !is_big_int_v<T>) if constexpr (is_arithmetic_v<T> && !is_big_int_v<T>)
{ {
/// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point /// TODO: LSD RadixSort is currently not stable if direction is descending, or value is floating point
bool use_radix_sort = (sort_is_stable && ascending && !std::is_floating_point_v<T>) || !sort_is_stable; bool use_radix_sort = (sort_is_stable && ascending && !is_floating_point<T>) || !sort_is_stable;
size_t size = end - begin; size_t size = end - begin;
/// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters. /// Thresholds on size. Lower threshold is arbitrary. Upper threshold is chosen by the type for histogram counters.
@ -353,7 +353,7 @@ void ColumnVector<T>::updatePermutation(IColumn::PermutationSortDirection direct
/// Radix sort treats all NaNs to be greater than all numbers. /// Radix sort treats all NaNs to be greater than all numbers.
/// If the user needs the opposite, we must move them accordingly. /// If the user needs the opposite, we must move them accordingly.
if (std::is_floating_point_v<T> && nan_direction_hint < 0) if (is_floating_point<T> && nan_direction_hint < 0)
{ {
size_t nans_to_move = 0; size_t nans_to_move = 0;
@ -1005,6 +1005,7 @@ template class ColumnVector<Int32>;
template class ColumnVector<Int64>; template class ColumnVector<Int64>;
template class ColumnVector<Int128>; template class ColumnVector<Int128>;
template class ColumnVector<Int256>; template class ColumnVector<Int256>;
template class ColumnVector<BFloat16>;
template class ColumnVector<Float32>; template class ColumnVector<Float32>;
template class ColumnVector<Float64>; template class ColumnVector<Float64>;
template class ColumnVector<UUID>; template class ColumnVector<UUID>;

View File

@ -481,6 +481,7 @@ extern template class ColumnVector<Int32>;
extern template class ColumnVector<Int64>; extern template class ColumnVector<Int64>;
extern template class ColumnVector<Int128>; extern template class ColumnVector<Int128>;
extern template class ColumnVector<Int256>; extern template class ColumnVector<Int256>;
extern template class ColumnVector<BFloat16>;
extern template class ColumnVector<Float32>; extern template class ColumnVector<Float32>;
extern template class ColumnVector<Float64>; extern template class ColumnVector<Float64>;
extern template class ColumnVector<UUID>; extern template class ColumnVector<UUID>;

View File

@ -328,6 +328,7 @@ INSTANTIATE(Int32)
INSTANTIATE(Int64) INSTANTIATE(Int64)
INSTANTIATE(Int128) INSTANTIATE(Int128)
INSTANTIATE(Int256) INSTANTIATE(Int256)
INSTANTIATE(BFloat16)
INSTANTIATE(Float32) INSTANTIATE(Float32)
INSTANTIATE(Float64) INSTANTIATE(Float64)
INSTANTIATE(Decimal32) INSTANTIATE(Decimal32)

View File

@ -23,6 +23,7 @@ using ColumnInt64 = ColumnVector<Int64>;
using ColumnInt128 = ColumnVector<Int128>; using ColumnInt128 = ColumnVector<Int128>;
using ColumnInt256 = ColumnVector<Int256>; using ColumnInt256 = ColumnVector<Int256>;
using ColumnBFloat16 = ColumnVector<BFloat16>;
using ColumnFloat32 = ColumnVector<Float32>; using ColumnFloat32 = ColumnVector<Float32>;
using ColumnFloat64 = ColumnVector<Float64>; using ColumnFloat64 = ColumnVector<Float64>;

View File

@ -443,6 +443,7 @@ template class IColumnHelper<ColumnVector<Int32>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<Int64>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<Int64>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<Int128>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<Int128>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<Int256>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<Int256>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<BFloat16>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<Float32>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<Float32>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<Float64>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<Float64>, ColumnFixedSizeHelper>;
template class IColumnHelper<ColumnVector<UUID>, ColumnFixedSizeHelper>; template class IColumnHelper<ColumnVector<UUID>, ColumnFixedSizeHelper>;

View File

@ -63,6 +63,7 @@ INSTANTIATE(Int32)
INSTANTIATE(Int64) INSTANTIATE(Int64)
INSTANTIATE(Int128) INSTANTIATE(Int128)
INSTANTIATE(Int256) INSTANTIATE(Int256)
INSTANTIATE(BFloat16)
INSTANTIATE(Float32) INSTANTIATE(Float32)
INSTANTIATE(Float64) INSTANTIATE(Float64)
INSTANTIATE(Decimal32) INSTANTIATE(Decimal32)
@ -200,6 +201,7 @@ static MaskInfo extractMaskImpl(
|| extractMaskNumeric<inverted, Int16>(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric<inverted, Int16>(mask, column, null_value, null_bytemap, nulls, mask_info)
|| extractMaskNumeric<inverted, Int32>(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric<inverted, Int32>(mask, column, null_value, null_bytemap, nulls, mask_info)
|| extractMaskNumeric<inverted, Int64>(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric<inverted, Int64>(mask, column, null_value, null_bytemap, nulls, mask_info)
|| extractMaskNumeric<inverted, BFloat16>(mask, column, null_value, null_bytemap, nulls, mask_info)
|| extractMaskNumeric<inverted, Float32>(mask, column, null_value, null_bytemap, nulls, mask_info) || extractMaskNumeric<inverted, Float32>(mask, column, null_value, null_bytemap, nulls, mask_info)
|| extractMaskNumeric<inverted, Float64>(mask, column, null_value, null_bytemap, nulls, mask_info))) || extractMaskNumeric<inverted, Float64>(mask, column, null_value, null_bytemap, nulls, mask_info)))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Cannot convert column {} to mask.", column->getName()); throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Cannot convert column {} to mask.", column->getName());

View File

@ -93,6 +93,7 @@ TEST(ColumnVector, Filter)
testFilter<Int64>(); testFilter<Int64>();
testFilter<UInt128>(); testFilter<UInt128>();
testFilter<Int256>(); testFilter<Int256>();
testFilter<BFloat16>();
testFilter<Float32>(); testFilter<Float32>();
testFilter<Float64>(); testFilter<Float64>();
testFilter<UUID>(); testFilter<UUID>();

View File

@ -45,6 +45,7 @@ TEST(ColumnLowCardinality, Insert)
testLowCardinalityNumberInsert<Int128>(std::make_shared<DataTypeInt128>()); testLowCardinalityNumberInsert<Int128>(std::make_shared<DataTypeInt128>());
testLowCardinalityNumberInsert<Int256>(std::make_shared<DataTypeInt256>()); testLowCardinalityNumberInsert<Int256>(std::make_shared<DataTypeInt256>());
testLowCardinalityNumberInsert<BFloat16>(std::make_shared<DataTypeBFloat16>());
testLowCardinalityNumberInsert<Float32>(std::make_shared<DataTypeFloat32>()); testLowCardinalityNumberInsert<Float32>(std::make_shared<DataTypeFloat32>());
testLowCardinalityNumberInsert<Float64>(std::make_shared<DataTypeFloat64>()); testLowCardinalityNumberInsert<Float64>(std::make_shared<DataTypeFloat64>());
} }

View File

@ -266,6 +266,11 @@ inline bool haveAVX512VBMI2() noexcept
return haveAVX512F() && ((CPUInfo(0x7, 0).registers.ecx >> 6) & 1u); return haveAVX512F() && ((CPUInfo(0x7, 0).registers.ecx >> 6) & 1u);
} }
inline bool haveAVX512BF16() noexcept
{
return haveAVX512F() && ((CPUInfo(0x7, 1).registers.eax >> 5) & 1u);
}
inline bool haveRDRAND() noexcept inline bool haveRDRAND() noexcept
{ {
return CPUInfo(0x0).registers.eax >= 0x7 && ((CPUInfo(0x1).registers.ecx >> 30) & 1u); return CPUInfo(0x0).registers.eax >= 0x7 && ((CPUInfo(0x1).registers.ecx >> 30) & 1u);
@ -326,6 +331,7 @@ inline bool haveAMXINT8() noexcept
OP(AVX512VL) \ OP(AVX512VL) \
OP(AVX512VBMI) \ OP(AVX512VBMI) \
OP(AVX512VBMI2) \ OP(AVX512VBMI2) \
OP(AVX512BF16) \
OP(PREFETCHWT1) \ OP(PREFETCHWT1) \
OP(SHA) \ OP(SHA) \
OP(ADX) \ OP(ADX) \

View File

@ -87,6 +87,7 @@ APPLY_FOR_FAILPOINTS(M, M, M, M)
std::unordered_map<String, std::shared_ptr<FailPointChannel>> FailPointInjection::fail_point_wait_channels; std::unordered_map<String, std::shared_ptr<FailPointChannel>> FailPointInjection::fail_point_wait_channels;
std::mutex FailPointInjection::mu; std::mutex FailPointInjection::mu;
class FailPointChannel : private boost::noncopyable class FailPointChannel : private boost::noncopyable
{ {
public: public:

View File

@ -15,6 +15,7 @@
#include <unordered_map> #include <unordered_map>
namespace DB namespace DB
{ {
@ -27,6 +28,7 @@ namespace DB
/// 3. in test file, we can use system failpoint enable/disable 'failpoint_name' /// 3. in test file, we can use system failpoint enable/disable 'failpoint_name'
class FailPointChannel; class FailPointChannel;
class FailPointInjection class FailPointInjection
{ {
public: public:

View File

@ -1,5 +1,4 @@
#include <Common/FieldVisitorConvertToNumber.h> #include <Common/FieldVisitorConvertToNumber.h>
#include "base/Decimal.h"
namespace DB namespace DB
{ {
@ -17,6 +16,7 @@ template class FieldVisitorConvertToNumber<Int128>;
template class FieldVisitorConvertToNumber<UInt128>; template class FieldVisitorConvertToNumber<UInt128>;
template class FieldVisitorConvertToNumber<Int256>; template class FieldVisitorConvertToNumber<Int256>;
template class FieldVisitorConvertToNumber<UInt256>; template class FieldVisitorConvertToNumber<UInt256>;
//template class FieldVisitorConvertToNumber<BFloat16>;
template class FieldVisitorConvertToNumber<Float32>; template class FieldVisitorConvertToNumber<Float32>;
template class FieldVisitorConvertToNumber<Float64>; template class FieldVisitorConvertToNumber<Float64>;

View File

@ -58,7 +58,7 @@ public:
T operator() (const Float64 & x) const T operator() (const Float64 & x) const
{ {
if constexpr (!std::is_floating_point_v<T>) if constexpr (!is_floating_point<T>)
{ {
if (!isFinite(x)) if (!isFinite(x))
{ {
@ -88,7 +88,7 @@ public:
template <typename U> template <typename U>
T operator() (const DecimalField<U> & x) const T operator() (const DecimalField<U> & x) const
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return x.getValue().template convertTo<T>() / x.getScaleMultiplier().template convertTo<T>(); return x.getValue().template convertTo<T>() / x.getScaleMultiplier().template convertTo<T>();
else else
return (x.getValue() / x.getScaleMultiplier()).template convertTo<T>(); return (x.getValue() / x.getScaleMultiplier()).template convertTo<T>();
@ -129,6 +129,7 @@ extern template class FieldVisitorConvertToNumber<Int128>;
extern template class FieldVisitorConvertToNumber<UInt128>; extern template class FieldVisitorConvertToNumber<UInt128>;
extern template class FieldVisitorConvertToNumber<Int256>; extern template class FieldVisitorConvertToNumber<Int256>;
extern template class FieldVisitorConvertToNumber<UInt256>; extern template class FieldVisitorConvertToNumber<UInt256>;
//extern template class FieldVisitorConvertToNumber<BFloat16>;
extern template class FieldVisitorConvertToNumber<Float32>; extern template class FieldVisitorConvertToNumber<Float32>;
extern template class FieldVisitorConvertToNumber<Float64>; extern template class FieldVisitorConvertToNumber<Float64>;

View File

@ -322,6 +322,7 @@ DEFINE_HASH(Int32)
DEFINE_HASH(Int64) DEFINE_HASH(Int64)
DEFINE_HASH(Int128) DEFINE_HASH(Int128)
DEFINE_HASH(Int256) DEFINE_HASH(Int256)
DEFINE_HASH(BFloat16)
DEFINE_HASH(Float32) DEFINE_HASH(Float32)
DEFINE_HASH(Float64) DEFINE_HASH(Float64)
DEFINE_HASH(DB::UUID) DEFINE_HASH(DB::UUID)

View File

@ -76,7 +76,7 @@ struct HashTableNoState
template <typename T> template <typename T>
inline bool bitEquals(T a, T b) inline bool bitEquals(T a, T b)
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
/// Note that memcmp with constant size is a compiler builtin. /// Note that memcmp with constant size is a compiler builtin.
return 0 == memcmp(&a, &b, sizeof(T)); /// NOLINT return 0 == memcmp(&a, &b, sizeof(T)); /// NOLINT
else else

View File

@ -9,6 +9,7 @@
#include <mutex> #include <mutex>
#include <algorithm> #include <algorithm>
#include <Poco/Timespan.h>
namespace ProfileEvents namespace ProfileEvents
@ -49,16 +50,18 @@ HostResolver::WeakPtr HostResolver::getWeakFromThis()
} }
HostResolver::HostResolver(String host_, Poco::Timespan history_) HostResolver::HostResolver(String host_, Poco::Timespan history_)
: host(std::move(host_)) : HostResolver(
, history(history_) [](const String & host_to_resolve) { return DNSResolver::instance().resolveHostAllInOriginOrder(host_to_resolve); },
, resolve_function([](const String & host_to_resolve) { return DNSResolver::instance().resolveHostAllInOriginOrder(host_to_resolve); }) host_,
{ history_)
update(); {}
}
HostResolver::HostResolver( HostResolver::HostResolver(
ResolveFunction && resolve_function_, String host_, Poco::Timespan history_) ResolveFunction && resolve_function_, String host_, Poco::Timespan history_)
: host(std::move(host_)), history(history_), resolve_function(std::move(resolve_function_)) : host(std::move(host_))
, history(history_)
, resolve_interval(history_.totalMicroseconds() / 3)
, resolve_function(std::move(resolve_function_))
{ {
update(); update();
} }
@ -203,7 +206,7 @@ bool HostResolver::isUpdateNeeded()
Poco::Timestamp now; Poco::Timestamp now;
std::lock_guard lock(mutex); std::lock_guard lock(mutex);
return last_resolve_time + history < now || records.empty(); return last_resolve_time + resolve_interval < now || records.empty();
} }
void HostResolver::updateImpl(Poco::Timestamp now, std::vector<Poco::Net::IPAddress> & next_gen) void HostResolver::updateImpl(Poco::Timestamp now, std::vector<Poco::Net::IPAddress> & next_gen)

View File

@ -26,7 +26,7 @@
// a) it still occurs in resolve set after `history_` time or b) all other addresses are pessimized as well. // a) it still occurs in resolve set after `history_` time or b) all other addresses are pessimized as well.
// - resolve schedule // - resolve schedule
// Addresses are resolved through `DB::DNSResolver::instance()`. // Addresses are resolved through `DB::DNSResolver::instance()`.
// Usually it does not happen more often than once in `history_` time. // Usually it does not happen more often than 3 times in `history_` period.
// But also new resolve performed each `setFail()` call. // But also new resolve performed each `setFail()` call.
namespace DB namespace DB
@ -212,6 +212,7 @@ protected:
const String host; const String host;
const Poco::Timespan history; const Poco::Timespan history;
const Poco::Timespan resolve_interval;
const HostResolverMetrics metrics = getMetrics(); const HostResolverMetrics metrics = getMetrics();
// for tests purpose // for tests purpose
@ -245,4 +246,3 @@ private:
}; };
} }

View File

@ -3,24 +3,24 @@
#include <cmath> #include <cmath>
#include <limits> #include <limits>
#include <type_traits> #include <type_traits>
#include <base/DecomposedFloat.h>
template <typename T> template <typename T>
inline bool isNaN(T x) inline bool isNaN(T x)
{ {
/// To be sure, that this function is zero-cost for non-floating point types. /// To be sure, that this function is zero-cost for non-floating point types.
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return std::isnan(x); return DecomposedFloat(x).isNaN();
else else
return false; return false;
} }
template <typename T> template <typename T>
inline bool isFinite(T x) inline bool isFinite(T x)
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return std::isfinite(x); return DecomposedFloat(x).isFinite();
else else
return true; return true;
} }
@ -28,7 +28,7 @@ inline bool isFinite(T x)
template <typename T> template <typename T>
bool canConvertTo(Float64 x) bool canConvertTo(Float64 x)
{ {
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return true; return true;
if (!isFinite(x)) if (!isFinite(x))
return false; return false;
@ -46,3 +46,12 @@ T NaNOrZero()
else else
return {}; return {};
} }
template <typename T>
bool signBit(T x)
{
if constexpr (is_floating_point<T>)
return DecomposedFloat(x).isNegative();
else
return x < 0;
}

View File

@ -23,6 +23,8 @@ UInt32 getSupportedArchs()
result |= static_cast<UInt32>(TargetArch::AVX512VBMI); result |= static_cast<UInt32>(TargetArch::AVX512VBMI);
if (CPU::CPUFlagsCache::have_AVX512VBMI2) if (CPU::CPUFlagsCache::have_AVX512VBMI2)
result |= static_cast<UInt32>(TargetArch::AVX512VBMI2); result |= static_cast<UInt32>(TargetArch::AVX512VBMI2);
if (CPU::CPUFlagsCache::have_AVX512BF16)
result |= static_cast<UInt32>(TargetArch::AVX512BF16);
if (CPU::CPUFlagsCache::have_AMXBF16) if (CPU::CPUFlagsCache::have_AMXBF16)
result |= static_cast<UInt32>(TargetArch::AMXBF16); result |= static_cast<UInt32>(TargetArch::AMXBF16);
if (CPU::CPUFlagsCache::have_AMXTILE) if (CPU::CPUFlagsCache::have_AMXTILE)
@ -50,6 +52,7 @@ String toString(TargetArch arch)
case TargetArch::AVX512BW: return "avx512bw"; case TargetArch::AVX512BW: return "avx512bw";
case TargetArch::AVX512VBMI: return "avx512vbmi"; case TargetArch::AVX512VBMI: return "avx512vbmi";
case TargetArch::AVX512VBMI2: return "avx512vbmi2"; case TargetArch::AVX512VBMI2: return "avx512vbmi2";
case TargetArch::AVX512BF16: return "avx512bf16";
case TargetArch::AMXBF16: return "amxbf16"; case TargetArch::AMXBF16: return "amxbf16";
case TargetArch::AMXTILE: return "amxtile"; case TargetArch::AMXTILE: return "amxtile";
case TargetArch::AMXINT8: return "amxint8"; case TargetArch::AMXINT8: return "amxint8";

View File

@ -83,9 +83,10 @@ enum class TargetArch : UInt32
AVX512BW = (1 << 4), AVX512BW = (1 << 4),
AVX512VBMI = (1 << 5), AVX512VBMI = (1 << 5),
AVX512VBMI2 = (1 << 6), AVX512VBMI2 = (1 << 6),
AMXBF16 = (1 << 7), AVX512BF16 = (1 << 7),
AMXTILE = (1 << 8), AMXBF16 = (1 << 8),
AMXINT8 = (1 << 9), AMXTILE = (1 << 9),
AMXINT8 = (1 << 10),
}; };
/// Runtime detection. /// Runtime detection.
@ -102,6 +103,7 @@ String toString(TargetArch arch);
/// NOLINTNEXTLINE /// NOLINTNEXTLINE
#define USE_MULTITARGET_CODE 1 #define USE_MULTITARGET_CODE 1
#define AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2,avx512bf16")))
#define AVX512VBMI2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2"))) #define AVX512VBMI2_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2")))
#define AVX512VBMI_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi"))) #define AVX512VBMI_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi")))
#define AVX512BW_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw"))) #define AVX512BW_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw")))
@ -111,6 +113,8 @@ String toString(TargetArch arch);
#define SSE42_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt"))) #define SSE42_FUNCTION_SPECIFIC_ATTRIBUTE __attribute__((target("sse,sse2,sse3,ssse3,sse4,popcnt")))
#define DEFAULT_FUNCTION_SPECIFIC_ATTRIBUTE #define DEFAULT_FUNCTION_SPECIFIC_ATTRIBUTE
# define BEGIN_AVX512BF16_SPECIFIC_CODE \
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2,avx512bf16\"))),apply_to=function)")
# define BEGIN_AVX512VBMI2_SPECIFIC_CODE \ # define BEGIN_AVX512VBMI2_SPECIFIC_CODE \
_Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2\"))),apply_to=function)") _Pragma("clang attribute push(__attribute__((target(\"sse,sse2,sse3,ssse3,sse4,popcnt,avx,avx2,avx512f,avx512bw,avx512vl,avx512vbmi,avx512vbmi2\"))),apply_to=function)")
# define BEGIN_AVX512VBMI_SPECIFIC_CODE \ # define BEGIN_AVX512VBMI_SPECIFIC_CODE \
@ -197,6 +201,14 @@ namespace TargetSpecific::AVX512VBMI2 { \
} \ } \
END_TARGET_SPECIFIC_CODE END_TARGET_SPECIFIC_CODE
#define DECLARE_AVX512BF16_SPECIFIC_CODE(...) \
BEGIN_AVX512BF16_SPECIFIC_CODE \
namespace TargetSpecific::AVX512BF16 { \
DUMMY_FUNCTION_DEFINITION \
using namespace DB::TargetSpecific::AVX512BF16; \
__VA_ARGS__ \
} \
END_TARGET_SPECIFIC_CODE
#else #else
@ -211,6 +223,7 @@ END_TARGET_SPECIFIC_CODE
#define DECLARE_AVX512BW_SPECIFIC_CODE(...) #define DECLARE_AVX512BW_SPECIFIC_CODE(...)
#define DECLARE_AVX512VBMI_SPECIFIC_CODE(...) #define DECLARE_AVX512VBMI_SPECIFIC_CODE(...)
#define DECLARE_AVX512VBMI2_SPECIFIC_CODE(...) #define DECLARE_AVX512VBMI2_SPECIFIC_CODE(...)
#define DECLARE_AVX512BF16_SPECIFIC_CODE(...)
#endif #endif
@ -229,7 +242,8 @@ DECLARE_AVX2_SPECIFIC_CODE (__VA_ARGS__) \
DECLARE_AVX512F_SPECIFIC_CODE(__VA_ARGS__) \ DECLARE_AVX512F_SPECIFIC_CODE(__VA_ARGS__) \
DECLARE_AVX512BW_SPECIFIC_CODE (__VA_ARGS__) \ DECLARE_AVX512BW_SPECIFIC_CODE (__VA_ARGS__) \
DECLARE_AVX512VBMI_SPECIFIC_CODE (__VA_ARGS__) \ DECLARE_AVX512VBMI_SPECIFIC_CODE (__VA_ARGS__) \
DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__) DECLARE_AVX512VBMI2_SPECIFIC_CODE (__VA_ARGS__) \
DECLARE_AVX512BF16_SPECIFIC_CODE (__VA_ARGS__)
DECLARE_DEFAULT_CODE( DECLARE_DEFAULT_CODE(
constexpr auto BuildArch = TargetArch::Default; /// NOLINT constexpr auto BuildArch = TargetArch::Default; /// NOLINT
@ -263,6 +277,10 @@ DECLARE_AVX512VBMI2_SPECIFIC_CODE(
constexpr auto BuildArch = TargetArch::AVX512VBMI2; /// NOLINT constexpr auto BuildArch = TargetArch::AVX512VBMI2; /// NOLINT
) // DECLARE_AVX512VBMI2_SPECIFIC_CODE ) // DECLARE_AVX512VBMI2_SPECIFIC_CODE
DECLARE_AVX512BF16_SPECIFIC_CODE(
constexpr auto BuildArch = TargetArch::AVX512BF16; /// NOLINT
) // DECLARE_AVX512BF16_SPECIFIC_CODE
/** Runtime Dispatch helpers for class members. /** Runtime Dispatch helpers for class members.
* *
* Example of usage: * Example of usage:

View File

@ -47,7 +47,7 @@ MULTITARGET_FUNCTION_AVX2_SSE42(
/// Unroll the loop manually for floating point, since the compiler doesn't do it without fastmath /// Unroll the loop manually for floating point, since the compiler doesn't do it without fastmath
/// as it might change the return value /// as it might change the return value
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
{ {
constexpr size_t unroll_block = 512 / sizeof(T); /// Chosen via benchmarks with AVX2 so YMMV constexpr size_t unroll_block = 512 / sizeof(T); /// Chosen via benchmarks with AVX2 so YMMV
size_t unrolled_end = i + (((count - i) / unroll_block) * unroll_block); size_t unrolled_end = i + (((count - i) / unroll_block) * unroll_block);

View File

@ -38,7 +38,7 @@ inline void transformEndianness(T & x)
} }
template <std::endian ToEndian, std::endian FromEndian = std::endian::native, typename T> template <std::endian ToEndian, std::endian FromEndian = std::endian::native, typename T>
requires std::is_floating_point_v<T> requires is_floating_point<T>
inline void transformEndianness(T & value) inline void transformEndianness(T & value)
{ {
if constexpr (ToEndian != FromEndian) if constexpr (ToEndian != FromEndian)

View File

@ -3,7 +3,7 @@
#include <IO/WriteBuffer.h> #include <IO/WriteBuffer.h>
#include <Compression/ICompressionCodec.h> #include <Compression/ICompressionCodec.h>
#include <IO/BufferWithOwnMemory.h> #include <IO/BufferWithOwnMemory.h>
#include <Parsers/StringRange.h>
namespace DB namespace DB
{ {

View File

@ -7,7 +7,6 @@
#include <Parsers/ExpressionElementParsers.h> #include <Parsers/ExpressionElementParsers.h>
#include <Parsers/IParser.h> #include <Parsers/IParser.h>
#include <Parsers/TokenIterator.h> #include <Parsers/TokenIterator.h>
#include <base/types.h>
#include <Common/PODArray.h> #include <Common/PODArray.h>
#include <Common/Stopwatch.h> #include <Common/Stopwatch.h>

View File

@ -25,7 +25,7 @@ bool lessOp(A a, B b)
return a < b; return a < b;
/// float vs float /// float vs float
if constexpr (std::is_floating_point_v<A> && std::is_floating_point_v<B>) if constexpr (is_floating_point<A> && is_floating_point<B>)
return a < b; return a < b;
/// anything vs NaN /// anything vs NaN
@ -49,7 +49,7 @@ bool lessOp(A a, B b)
} }
/// int vs float /// int vs float
if constexpr (is_integer<A> && std::is_floating_point_v<B>) if constexpr (is_integer<A> && is_floating_point<B>)
{ {
if constexpr (sizeof(A) <= 4) if constexpr (sizeof(A) <= 4)
return static_cast<double>(a) < static_cast<double>(b); return static_cast<double>(a) < static_cast<double>(b);
@ -57,7 +57,7 @@ bool lessOp(A a, B b)
return DecomposedFloat<B>(b).greater(a); return DecomposedFloat<B>(b).greater(a);
} }
if constexpr (std::is_floating_point_v<A> && is_integer<B>) if constexpr (is_floating_point<A> && is_integer<B>)
{ {
if constexpr (sizeof(B) <= 4) if constexpr (sizeof(B) <= 4)
return static_cast<double>(a) < static_cast<double>(b); return static_cast<double>(a) < static_cast<double>(b);
@ -65,8 +65,8 @@ bool lessOp(A a, B b)
return DecomposedFloat<A>(a).less(b); return DecomposedFloat<A>(a).less(b);
} }
static_assert(is_integer<A> || std::is_floating_point_v<A>); static_assert(is_integer<A> || is_floating_point<A>);
static_assert(is_integer<B> || std::is_floating_point_v<B>); static_assert(is_integer<B> || is_floating_point<B>);
UNREACHABLE(); UNREACHABLE();
} }
@ -101,7 +101,7 @@ bool equalsOp(A a, B b)
return a == b; return a == b;
/// float vs float /// float vs float
if constexpr (std::is_floating_point_v<A> && std::is_floating_point_v<B>) if constexpr (is_floating_point<A> && is_floating_point<B>)
return a == b; return a == b;
/// anything vs NaN /// anything vs NaN
@ -125,7 +125,7 @@ bool equalsOp(A a, B b)
} }
/// int vs float /// int vs float
if constexpr (is_integer<A> && std::is_floating_point_v<B>) if constexpr (is_integer<A> && is_floating_point<B>)
{ {
if constexpr (sizeof(A) <= 4) if constexpr (sizeof(A) <= 4)
return static_cast<double>(a) == static_cast<double>(b); return static_cast<double>(a) == static_cast<double>(b);
@ -133,7 +133,7 @@ bool equalsOp(A a, B b)
return DecomposedFloat<B>(b).equals(a); return DecomposedFloat<B>(b).equals(a);
} }
if constexpr (std::is_floating_point_v<A> && is_integer<B>) if constexpr (is_floating_point<A> && is_integer<B>)
{ {
if constexpr (sizeof(B) <= 4) if constexpr (sizeof(B) <= 4)
return static_cast<double>(a) == static_cast<double>(b); return static_cast<double>(a) == static_cast<double>(b);
@ -163,7 +163,7 @@ inline bool NO_SANITIZE_UNDEFINED convertNumeric(From value, To & result)
return true; return true;
} }
if constexpr (std::is_floating_point_v<From> && std::is_floating_point_v<To>) if constexpr (is_floating_point<From> && is_floating_point<To>)
{ {
/// Note that NaNs doesn't compare equal to anything, but they are still in range of any Float type. /// Note that NaNs doesn't compare equal to anything, but they are still in range of any Float type.
if (isNaN(value)) if (isNaN(value))

View File

@ -17,6 +17,7 @@ class DataTypeNumber;
namespace ErrorCodes namespace ErrorCodes
{ {
extern const int NOT_IMPLEMENTED;
extern const int DECIMAL_OVERFLOW; extern const int DECIMAL_OVERFLOW;
extern const int ARGUMENT_OUT_OF_BOUND; extern const int ARGUMENT_OUT_OF_BOUND;
} }
@ -310,7 +311,14 @@ ReturnType convertToImpl(const DecimalType & decimal, UInt32 scale, To & result)
using DecimalNativeType = typename DecimalType::NativeType; using DecimalNativeType = typename DecimalType::NativeType;
static constexpr bool throw_exception = std::is_void_v<ReturnType>; static constexpr bool throw_exception = std::is_void_v<ReturnType>;
if constexpr (std::is_floating_point_v<To>) if constexpr (std::is_same_v<To, BFloat16>)
{
if constexpr (throw_exception)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from Decimal to BFloat16 is not implemented");
else
return ReturnType(false);
}
else if constexpr (is_floating_point<To>)
{ {
result = static_cast<To>(decimal.value) / static_cast<To>(scaleMultiplier<DecimalNativeType>(scale)); result = static_cast<To>(decimal.value) / static_cast<To>(scaleMultiplier<DecimalNativeType>(scale));
} }

View File

@ -257,6 +257,7 @@ template <> struct NearestFieldTypeImpl<DecimalField<Decimal64>> { using Type =
template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; }; template <> struct NearestFieldTypeImpl<DecimalField<Decimal128>> { using Type = DecimalField<Decimal128>; };
template <> struct NearestFieldTypeImpl<DecimalField<Decimal256>> { using Type = DecimalField<Decimal256>; }; template <> struct NearestFieldTypeImpl<DecimalField<Decimal256>> { using Type = DecimalField<Decimal256>; };
template <> struct NearestFieldTypeImpl<DecimalField<DateTime64>> { using Type = DecimalField<DateTime64>; }; template <> struct NearestFieldTypeImpl<DecimalField<DateTime64>> { using Type = DecimalField<DateTime64>; };
template <> struct NearestFieldTypeImpl<BFloat16> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; }; template <> struct NearestFieldTypeImpl<Float32> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; }; template <> struct NearestFieldTypeImpl<Float64> { using Type = Float64; };
template <> struct NearestFieldTypeImpl<const char *> { using Type = String; }; template <> struct NearestFieldTypeImpl<const char *> { using Type = String; };

View File

@ -4565,7 +4565,7 @@ Possible values:
- 0 - Disable - 0 - Disable
- 1 - Enable - 1 - Enable
)", 0) \ )", 0) \
DECLARE(Bool, query_plan_merge_filters, true, R"( DECLARE(Bool, query_plan_merge_filters, false, R"(
Allow to merge filters in the query plan Allow to merge filters in the query plan
)", 0) \ )", 0) \
DECLARE(Bool, query_plan_filter_push_down, true, R"( DECLARE(Bool, query_plan_filter_push_down, true, R"(
@ -5742,7 +5742,10 @@ Enable experimental functions for natural language processing.
Enable experimental hash functions Enable experimental hash functions
)", EXPERIMENTAL) \ )", EXPERIMENTAL) \
DECLARE(Bool, allow_experimental_object_type, false, R"( DECLARE(Bool, allow_experimental_object_type, false, R"(
Allow Object and JSON data types Allow the obsolete Object data type
)", EXPERIMENTAL) \
DECLARE(Bool, allow_experimental_bfloat16_type, false, R"(
Allow BFloat16 data type (under development).
)", EXPERIMENTAL) \ )", EXPERIMENTAL) \
DECLARE(Bool, allow_experimental_time_series_table, false, R"( DECLARE(Bool, allow_experimental_time_series_table, false, R"(
Allows creation of tables with the [TimeSeries](../../engines/table-engines/integrations/time-series.md) table engine. Allows creation of tables with the [TimeSeries](../../engines/table-engines/integrations/time-series.md) table engine.

View File

@ -77,8 +77,8 @@ static std::initializer_list<std::pair<ClickHouseVersion, SettingsChangesHistory
{"backup_restore_keeper_max_retries_while_initializing", 0, 20, "New setting."}, {"backup_restore_keeper_max_retries_while_initializing", 0, 20, "New setting."},
{"backup_restore_keeper_max_retries_while_handling_error", 0, 20, "New setting."}, {"backup_restore_keeper_max_retries_while_handling_error", 0, 20, "New setting."},
{"backup_restore_finish_timeout_after_error_sec", 0, 180, "New setting."}, {"backup_restore_finish_timeout_after_error_sec", 0, 180, "New setting."},
{"query_plan_merge_filters", false, true, "Allow to merge filters in the query plan. This is required to properly support filter-push-down with a new analyzer."},
{"parallel_replicas_local_plan", false, true, "Use local plan for local replica in a query with parallel replicas"}, {"parallel_replicas_local_plan", false, true, "Use local plan for local replica in a query with parallel replicas"},
{"allow_experimental_bfloat16_type", false, false, "Add new experimental BFloat16 type"},
{"filesystem_cache_skip_download_if_exceeds_per_query_cache_write_limit", 1, 1, "Rename of setting skip_download_if_exceeds_query_cache_limit"}, {"filesystem_cache_skip_download_if_exceeds_per_query_cache_write_limit", 1, 1, "Rename of setting skip_download_if_exceeds_query_cache_limit"},
{"filesystem_cache_prefer_bigger_buffer_size", true, true, "New setting"}, {"filesystem_cache_prefer_bigger_buffer_size", true, true, "New setting"},
{"read_in_order_use_virtual_row", false, false, "Use virtual row while reading in order of primary key or its monotonic function fashion. It is useful when searching over multiple parts as only relevant ones are touched."}, {"read_in_order_use_virtual_row", false, false, "Use virtual row while reading in order of primary key or its monotonic function fashion. It is useful when searching over multiple parts as only relevant ones are touched."},
@ -127,7 +127,7 @@ static std::initializer_list<std::pair<ClickHouseVersion, SettingsChangesHistory
{"allow_experimental_refreshable_materialized_view", false, true, "Not experimental anymore"}, {"allow_experimental_refreshable_materialized_view", false, true, "Not experimental anymore"},
{"max_parts_to_move", 0, 1000, "New setting"}, {"max_parts_to_move", 0, 1000, "New setting"},
{"hnsw_candidate_list_size_for_search", 64, 256, "New setting. Previously, the value was optionally specified in CREATE INDEX and 64 by default."}, {"hnsw_candidate_list_size_for_search", 64, 256, "New setting. Previously, the value was optionally specified in CREATE INDEX and 64 by default."},
{"allow_reorder_prewhere_conditions", false, true, "New setting"}, {"allow_reorder_prewhere_conditions", true, true, "New setting"},
{"input_format_parquet_bloom_filter_push_down", false, true, "When reading Parquet files, skip whole row groups based on the WHERE/PREWHERE expressions and bloom filter in the Parquet metadata."}, {"input_format_parquet_bloom_filter_push_down", false, true, "When reading Parquet files, skip whole row groups based on the WHERE/PREWHERE expressions and bloom filter in the Parquet metadata."},
{"date_time_64_output_format_cut_trailing_zeros_align_to_groups_of_thousands", false, false, "Dynamically trim the trailing zeros of datetime64 values to adjust the output scale to (0, 3, 6), corresponding to 'seconds', 'milliseconds', and 'microseconds'."}, {"date_time_64_output_format_cut_trailing_zeros_align_to_groups_of_thousands", false, false, "Dynamically trim the trailing zeros of datetime64 values to adjust the output scale to (0, 3, 6), corresponding to 'seconds', 'milliseconds', and 'microseconds'."},
} }

View File

@ -726,6 +726,7 @@ private:
SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Int128>>, strategy>, SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Int128>>, strategy>,
SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Int256>>, strategy>, SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Int256>>, strategy>,
SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<BFloat16>>, strategy>,
SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Float32>>, strategy>, SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Float32>>, strategy>,
SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Float64>>, strategy>, SortingQueueImpl<SpecializedSingleColumnSortCursor<ColumnVector<Float64>>, strategy>,

View File

@ -21,6 +21,7 @@ enum class TypeIndex : uint8_t
Int64, Int64,
Int128, Int128,
Int256, Int256,
BFloat16,
Float32, Float32,
Float64, Float64,
Date, Date,
@ -94,6 +95,7 @@ TYPEID_MAP(Int32)
TYPEID_MAP(Int64) TYPEID_MAP(Int64)
TYPEID_MAP(Int128) TYPEID_MAP(Int128)
TYPEID_MAP(Int256) TYPEID_MAP(Int256)
TYPEID_MAP(BFloat16)
TYPEID_MAP(Float32) TYPEID_MAP(Float32)
TYPEID_MAP(Float64) TYPEID_MAP(Float64)
TYPEID_MAP(UUID) TYPEID_MAP(UUID)

View File

@ -21,6 +21,7 @@ using Int128 = wide::integer<128, signed>;
using UInt128 = wide::integer<128, unsigned>; using UInt128 = wide::integer<128, unsigned>;
using Int256 = wide::integer<256, signed>; using Int256 = wide::integer<256, signed>;
using UInt256 = wide::integer<256, unsigned>; using UInt256 = wide::integer<256, unsigned>;
class BFloat16;
namespace DB namespace DB
{ {

View File

@ -63,6 +63,7 @@ static bool callOnBasicType(TypeIndex number, F && f)
{ {
switch (number) switch (number)
{ {
case TypeIndex::BFloat16: return f(TypePair<T, BFloat16>());
case TypeIndex::Float32: return f(TypePair<T, Float32>()); case TypeIndex::Float32: return f(TypePair<T, Float32>());
case TypeIndex::Float64: return f(TypePair<T, Float64>()); case TypeIndex::Float64: return f(TypePair<T, Float64>());
default: default:
@ -133,6 +134,7 @@ static inline bool callOnBasicTypes(TypeIndex type_num1, TypeIndex type_num2, F
{ {
switch (type_num1) switch (type_num1)
{ {
case TypeIndex::BFloat16: return callOnBasicType<BFloat16, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Float32: return callOnBasicType<Float32, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f)); case TypeIndex::Float32: return callOnBasicType<Float32, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
case TypeIndex::Float64: return callOnBasicType<Float64, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f)); case TypeIndex::Float64: return callOnBasicType<Float64, _int, _float, _decimal, _datetime>(type_num2, std::forward<F>(f));
default: default:
@ -190,6 +192,7 @@ static bool callOnIndexAndDataType(TypeIndex number, F && f, ExtraArgs && ... ar
case TypeIndex::Int128: return f(TypePair<DataTypeNumber<Int128>, T>(), std::forward<ExtraArgs>(args)...); case TypeIndex::Int128: return f(TypePair<DataTypeNumber<Int128>, T>(), std::forward<ExtraArgs>(args)...);
case TypeIndex::Int256: return f(TypePair<DataTypeNumber<Int256>, T>(), std::forward<ExtraArgs>(args)...); case TypeIndex::Int256: return f(TypePair<DataTypeNumber<Int256>, T>(), std::forward<ExtraArgs>(args)...);
case TypeIndex::BFloat16: return f(TypePair<DataTypeNumber<BFloat16>, T>(), std::forward<ExtraArgs>(args)...);
case TypeIndex::Float32: return f(TypePair<DataTypeNumber<Float32>, T>(), std::forward<ExtraArgs>(args)...); case TypeIndex::Float32: return f(TypePair<DataTypeNumber<Float32>, T>(), std::forward<ExtraArgs>(args)...);
case TypeIndex::Float64: return f(TypePair<DataTypeNumber<Float64>, T>(), std::forward<ExtraArgs>(args)...); case TypeIndex::Float64: return f(TypePair<DataTypeNumber<Float64>, T>(), std::forward<ExtraArgs>(args)...);

View File

@ -42,6 +42,7 @@ template class DataTypeNumberBase<Int32>;
template class DataTypeNumberBase<Int64>; template class DataTypeNumberBase<Int64>;
template class DataTypeNumberBase<Int128>; template class DataTypeNumberBase<Int128>;
template class DataTypeNumberBase<Int256>; template class DataTypeNumberBase<Int256>;
template class DataTypeNumberBase<BFloat16>;
template class DataTypeNumberBase<Float32>; template class DataTypeNumberBase<Float32>;
template class DataTypeNumberBase<Float64>; template class DataTypeNumberBase<Float64>;

View File

@ -68,6 +68,7 @@ extern template class DataTypeNumberBase<Int32>;
extern template class DataTypeNumberBase<Int64>; extern template class DataTypeNumberBase<Int64>;
extern template class DataTypeNumberBase<Int128>; extern template class DataTypeNumberBase<Int128>;
extern template class DataTypeNumberBase<Int256>; extern template class DataTypeNumberBase<Int256>;
extern template class DataTypeNumberBase<BFloat16>;
extern template class DataTypeNumberBase<Float32>; extern template class DataTypeNumberBase<Float32>;
extern template class DataTypeNumberBase<Float64>; extern template class DataTypeNumberBase<Float64>;

View File

@ -96,6 +96,7 @@ enum class BinaryTypeIndex : uint8_t
SimpleAggregateFunction = 0x2E, SimpleAggregateFunction = 0x2E,
Nested = 0x2F, Nested = 0x2F,
JSON = 0x30, JSON = 0x30,
BFloat16 = 0x31,
}; };
/// In future we can introduce more arguments in the JSON data type definition. /// In future we can introduce more arguments in the JSON data type definition.
@ -151,6 +152,8 @@ BinaryTypeIndex getBinaryTypeIndex(const DataTypePtr & type)
return BinaryTypeIndex::Int128; return BinaryTypeIndex::Int128;
case TypeIndex::Int256: case TypeIndex::Int256:
return BinaryTypeIndex::Int256; return BinaryTypeIndex::Int256;
case TypeIndex::BFloat16:
return BinaryTypeIndex::BFloat16;
case TypeIndex::Float32: case TypeIndex::Float32:
return BinaryTypeIndex::Float32; return BinaryTypeIndex::Float32;
case TypeIndex::Float64: case TypeIndex::Float64:
@ -565,6 +568,8 @@ DataTypePtr decodeDataType(ReadBuffer & buf)
return std::make_shared<DataTypeInt128>(); return std::make_shared<DataTypeInt128>();
case BinaryTypeIndex::Int256: case BinaryTypeIndex::Int256:
return std::make_shared<DataTypeInt256>(); return std::make_shared<DataTypeInt256>();
case BinaryTypeIndex::BFloat16:
return std::make_shared<DataTypeBFloat16>();
case BinaryTypeIndex::Float32: case BinaryTypeIndex::Float32:
return std::make_shared<DataTypeFloat32>(); return std::make_shared<DataTypeFloat32>();
case BinaryTypeIndex::Float64: case BinaryTypeIndex::Float64:

View File

@ -2,6 +2,7 @@
#include <DataTypes/Serializations/SerializationDecimal.h> #include <DataTypes/Serializations/SerializationDecimal.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <Common/NaNUtils.h>
#include <Core/DecimalFunctions.h> #include <Core/DecimalFunctions.h>
#include <DataTypes/DataTypeFactory.h> #include <DataTypes/DataTypeFactory.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
@ -19,6 +20,7 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int DECIMAL_OVERFLOW; extern const int DECIMAL_OVERFLOW;
extern const int NOT_IMPLEMENTED;
} }
@ -268,9 +270,13 @@ ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value,
static constexpr bool throw_exception = std::is_same_v<ReturnType, void>; static constexpr bool throw_exception = std::is_same_v<ReturnType, void>;
if constexpr (std::is_floating_point_v<FromFieldType>) if constexpr (std::is_same_v<typename FromDataType::FieldType, BFloat16>)
{ {
if (!std::isfinite(value)) throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from BFloat16 to Decimal is not implemented");
}
else if constexpr (is_floating_point<FromFieldType>)
{
if (!isFinite(value))
{ {
if constexpr (throw_exception) if constexpr (throw_exception)
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow. Cannot convert infinity or NaN to decimal", ToDataType::family_name); throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "{} convert overflow. Cannot convert infinity or NaN to decimal", ToDataType::family_name);

View File

@ -4,7 +4,6 @@
#include <base/extended_types.h> #include <base/extended_types.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include <base/Decimal.h> #include <base/Decimal.h>
#include <base/Decimal_fwd.h>
#include <DataTypes/IDataType.h> #include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeDate.h> #include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h> #include <DataTypes/DataTypeDate32.h>
@ -205,7 +204,6 @@ FOR_EACH_DECIMAL_TYPE(INVOKE);
#undef INVOKE #undef INVOKE
#undef DISPATCH #undef DISPATCH
template <typename FromDataType, typename ToDataType> template <typename FromDataType, typename ToDataType>
requires (is_arithmetic_v<typename FromDataType::FieldType> && IsDataTypeDecimal<ToDataType>) requires (is_arithmetic_v<typename FromDataType::FieldType> && IsDataTypeDecimal<ToDataType>)
typename ToDataType::FieldType convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale); typename ToDataType::FieldType convertToDecimal(const typename FromDataType::FieldType & value, UInt32 scale);

View File

@ -54,6 +54,7 @@ void registerDataTypeNumbers(DataTypeFactory & factory)
factory.registerDataType("Int32", createNumericDataType<Int32>); factory.registerDataType("Int32", createNumericDataType<Int32>);
factory.registerDataType("Int64", createNumericDataType<Int64>); factory.registerDataType("Int64", createNumericDataType<Int64>);
factory.registerDataType("BFloat16", createNumericDataType<BFloat16>);
factory.registerDataType("Float32", createNumericDataType<Float32>); factory.registerDataType("Float32", createNumericDataType<Float32>);
factory.registerDataType("Float64", createNumericDataType<Float64>); factory.registerDataType("Float64", createNumericDataType<Float64>);
@ -111,6 +112,7 @@ template class DataTypeNumber<Int8>;
template class DataTypeNumber<Int16>; template class DataTypeNumber<Int16>;
template class DataTypeNumber<Int32>; template class DataTypeNumber<Int32>;
template class DataTypeNumber<Int64>; template class DataTypeNumber<Int64>;
template class DataTypeNumber<BFloat16>;
template class DataTypeNumber<Float32>; template class DataTypeNumber<Float32>;
template class DataTypeNumber<Float64>; template class DataTypeNumber<Float64>;

View File

@ -63,6 +63,7 @@ extern template class DataTypeNumber<Int8>;
extern template class DataTypeNumber<Int16>; extern template class DataTypeNumber<Int16>;
extern template class DataTypeNumber<Int32>; extern template class DataTypeNumber<Int32>;
extern template class DataTypeNumber<Int64>; extern template class DataTypeNumber<Int64>;
extern template class DataTypeNumber<BFloat16>;
extern template class DataTypeNumber<Float32>; extern template class DataTypeNumber<Float32>;
extern template class DataTypeNumber<Float64>; extern template class DataTypeNumber<Float64>;
@ -79,6 +80,7 @@ using DataTypeInt8 = DataTypeNumber<Int8>;
using DataTypeInt16 = DataTypeNumber<Int16>; using DataTypeInt16 = DataTypeNumber<Int16>;
using DataTypeInt32 = DataTypeNumber<Int32>; using DataTypeInt32 = DataTypeNumber<Int32>;
using DataTypeInt64 = DataTypeNumber<Int64>; using DataTypeInt64 = DataTypeNumber<Int64>;
using DataTypeBFloat16 = DataTypeNumber<BFloat16>;
using DataTypeFloat32 = DataTypeNumber<Float32>; using DataTypeFloat32 = DataTypeNumber<Float32>;
using DataTypeFloat64 = DataTypeNumber<Float64>; using DataTypeFloat64 = DataTypeNumber<Float64>;

View File

@ -408,9 +408,11 @@ struct WhichDataType
constexpr bool isDecimal256() const { return idx == TypeIndex::Decimal256; } constexpr bool isDecimal256() const { return idx == TypeIndex::Decimal256; }
constexpr bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); } constexpr bool isDecimal() const { return isDecimal32() || isDecimal64() || isDecimal128() || isDecimal256(); }
constexpr bool isBFloat16() const { return idx == TypeIndex::BFloat16; }
constexpr bool isFloat32() const { return idx == TypeIndex::Float32; } constexpr bool isFloat32() const { return idx == TypeIndex::Float32; }
constexpr bool isFloat64() const { return idx == TypeIndex::Float64; } constexpr bool isFloat64() const { return idx == TypeIndex::Float64; }
constexpr bool isFloat() const { return isFloat32() || isFloat64(); } constexpr bool isNativeFloat() const { return isFloat32() || isFloat64(); }
constexpr bool isFloat() const { return isNativeFloat() || isBFloat16(); }
constexpr bool isNativeNumber() const { return isNativeInteger() || isFloat(); } constexpr bool isNativeNumber() const { return isNativeInteger() || isFloat(); }
constexpr bool isNumber() const { return isInteger() || isFloat() || isDecimal(); } constexpr bool isNumber() const { return isInteger() || isFloat() || isDecimal(); }
@ -625,6 +627,7 @@ template <typename T> inline constexpr bool IsDataTypeEnum<DataTypeEnum<T>> = tr
M(Int64) \ M(Int64) \
M(Int128) \ M(Int128) \
M(Int256) \ M(Int256) \
M(BFloat16) \
M(Float32) \ M(Float32) \
M(Float64) M(Float64)
} }

View File

@ -37,7 +37,7 @@ bool canBeNativeType(const IDataType & type)
return canBeNativeType(*data_type_nullable.getNestedType()); return canBeNativeType(*data_type_nullable.getNestedType());
} }
return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isFloat() || data_type.isDate() return data_type.isNativeInt() || data_type.isNativeUInt() || data_type.isNativeFloat() || data_type.isDate()
|| data_type.isDate32() || data_type.isDateTime() || data_type.isEnum(); || data_type.isDate32() || data_type.isDateTime() || data_type.isEnum();
} }

View File

@ -74,7 +74,7 @@ template <typename A, typename B> struct ResultOfAdditionMultiplication
{ {
using Type = typename Construct< using Type = typename Construct<
is_signed_v<A> || is_signed_v<B>, is_signed_v<A> || is_signed_v<B>,
std::is_floating_point_v<A> || std::is_floating_point_v<B>, is_floating_point<A> || is_floating_point<B>,
nextSize(max(sizeof(A), sizeof(B)))>::Type; nextSize(max(sizeof(A), sizeof(B)))>::Type;
}; };
@ -82,7 +82,7 @@ template <typename A, typename B> struct ResultOfSubtraction
{ {
using Type = typename Construct< using Type = typename Construct<
true, true,
std::is_floating_point_v<A> || std::is_floating_point_v<B>, is_floating_point<A> || is_floating_point<B>,
nextSize(max(sizeof(A), sizeof(B)))>::Type; nextSize(max(sizeof(A), sizeof(B)))>::Type;
}; };
@ -113,7 +113,7 @@ template <typename A, typename B> struct ResultOfModulo
/// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16. /// Example: toInt32(-199) % toUInt8(200) will return -199 that does not fit in Int8, only in Int16.
static constexpr size_t size_of_result = result_is_signed ? nextSize(sizeof(B)) : sizeof(B); static constexpr size_t size_of_result = result_is_signed ? nextSize(sizeof(B)) : sizeof(B);
using Type0 = typename Construct<result_is_signed, false, size_of_result>::Type; using Type0 = typename Construct<result_is_signed, false, size_of_result>::Type;
using Type = std::conditional_t<std::is_floating_point_v<A> || std::is_floating_point_v<B>, Float64, Type0>; using Type = std::conditional_t<is_floating_point<A> || is_floating_point<B>, Float64, Type0>;
}; };
template <typename A, typename B> struct ResultOfPositiveModulo template <typename A, typename B> struct ResultOfPositiveModulo
@ -121,21 +121,21 @@ template <typename A, typename B> struct ResultOfPositiveModulo
/// function positive_modulo always return non-negative number. /// function positive_modulo always return non-negative number.
static constexpr size_t size_of_result = sizeof(B); static constexpr size_t size_of_result = sizeof(B);
using Type0 = typename Construct<false, false, size_of_result>::Type; using Type0 = typename Construct<false, false, size_of_result>::Type;
using Type = std::conditional_t<std::is_floating_point_v<A> || std::is_floating_point_v<B>, Float64, Type0>; using Type = std::conditional_t<is_floating_point<A> || is_floating_point<B>, Float64, Type0>;
}; };
template <typename A, typename B> struct ResultOfModuloLegacy template <typename A, typename B> struct ResultOfModuloLegacy
{ {
using Type0 = typename Construct<is_signed_v<A> || is_signed_v<B>, false, sizeof(B)>::Type; using Type0 = typename Construct<is_signed_v<A> || is_signed_v<B>, false, sizeof(B)>::Type;
using Type = std::conditional_t<std::is_floating_point_v<A> || std::is_floating_point_v<B>, Float64, Type0>; using Type = std::conditional_t<is_floating_point<A> || is_floating_point<B>, Float64, Type0>;
}; };
template <typename A> struct ResultOfNegate template <typename A> struct ResultOfNegate
{ {
using Type = typename Construct< using Type = typename Construct<
true, true,
std::is_floating_point_v<A>, is_floating_point<A>,
is_signed_v<A> ? sizeof(A) : nextSize(sizeof(A))>::Type; is_signed_v<A> ? sizeof(A) : nextSize(sizeof(A))>::Type;
}; };
@ -143,7 +143,7 @@ template <typename A> struct ResultOfAbs
{ {
using Type = typename Construct< using Type = typename Construct<
false, false,
std::is_floating_point_v<A>, is_floating_point<A>,
sizeof(A)>::Type; sizeof(A)>::Type;
}; };
@ -154,7 +154,7 @@ template <typename A, typename B> struct ResultOfBit
using Type = typename Construct< using Type = typename Construct<
is_signed_v<A> || is_signed_v<B>, is_signed_v<A> || is_signed_v<B>,
false, false,
std::is_floating_point_v<A> || std::is_floating_point_v<B> ? 8 : max(sizeof(A), sizeof(B))>::Type; is_floating_point<A> || is_floating_point<B> ? 8 : max(sizeof(A), sizeof(B))>::Type;
}; };
template <typename A> struct ResultOfBitNot template <typename A> struct ResultOfBitNot
@ -180,7 +180,7 @@ template <typename A> struct ResultOfBitNot
template <typename A, typename B> template <typename A, typename B>
struct ResultOfIf struct ResultOfIf
{ {
static constexpr bool has_float = std::is_floating_point_v<A> || std::is_floating_point_v<B>; static constexpr bool has_float = is_floating_point<A> || is_floating_point<B>;
static constexpr bool has_integer = is_integer<A> || is_integer<B>; static constexpr bool has_integer = is_integer<A> || is_integer<B>;
static constexpr bool has_signed = is_signed_v<A> || is_signed_v<B>; static constexpr bool has_signed = is_signed_v<A> || is_signed_v<B>;
static constexpr bool has_unsigned = !is_signed_v<A> || !is_signed_v<B>; static constexpr bool has_unsigned = !is_signed_v<A> || !is_signed_v<B>;
@ -189,7 +189,7 @@ struct ResultOfIf
static constexpr size_t max_size_of_unsigned_integer = max(is_signed_v<A> ? 0 : sizeof(A), is_signed_v<B> ? 0 : sizeof(B)); static constexpr size_t max_size_of_unsigned_integer = max(is_signed_v<A> ? 0 : sizeof(A), is_signed_v<B> ? 0 : sizeof(B));
static constexpr size_t max_size_of_signed_integer = max(is_signed_v<A> ? sizeof(A) : 0, is_signed_v<B> ? sizeof(B) : 0); static constexpr size_t max_size_of_signed_integer = max(is_signed_v<A> ? sizeof(A) : 0, is_signed_v<B> ? sizeof(B) : 0);
static constexpr size_t max_size_of_integer = max(is_integer<A> ? sizeof(A) : 0, is_integer<B> ? sizeof(B) : 0); static constexpr size_t max_size_of_integer = max(is_integer<A> ? sizeof(A) : 0, is_integer<B> ? sizeof(B) : 0);
static constexpr size_t max_size_of_float = max(std::is_floating_point_v<A> ? sizeof(A) : 0, std::is_floating_point_v<B> ? sizeof(B) : 0); static constexpr size_t max_size_of_float = max(is_floating_point<A> ? sizeof(A) : 0, is_floating_point<B> ? sizeof(B) : 0);
using ConstructedType = typename Construct<has_signed, has_float, using ConstructedType = typename Construct<has_signed, has_float,
((has_float && has_integer && max_size_of_integer >= max_size_of_float) ((has_float && has_integer && max_size_of_integer >= max_size_of_float)
@ -244,7 +244,7 @@ template <typename A> struct ToInteger
using Type = typename Construct< using Type = typename Construct<
is_signed_v<A>, is_signed_v<A>,
false, false,
std::is_floating_point_v<A> ? 8 : sizeof(A)>::Type; is_floating_point<A> ? 8 : sizeof(A)>::Type;
}; };

View File

@ -238,6 +238,7 @@ template class SerializationNumber<Int32>;
template class SerializationNumber<Int64>; template class SerializationNumber<Int64>;
template class SerializationNumber<Int128>; template class SerializationNumber<Int128>;
template class SerializationNumber<Int256>; template class SerializationNumber<Int256>;
template class SerializationNumber<BFloat16>;
template class SerializationNumber<Float32>; template class SerializationNumber<Float32>;
template class SerializationNumber<Float64>; template class SerializationNumber<Float64>;

View File

@ -54,6 +54,13 @@ bool canBeSafelyCasted(const DataTypePtr & from_type, const DataTypePtr & to_typ
return false; return false;
} }
case TypeIndex::BFloat16:
{
if (to_which_type.isFloat32() || to_which_type.isFloat64() || to_which_type.isString())
return true;
return false;
}
case TypeIndex::Float32: case TypeIndex::Float32:
{ {
if (to_which_type.isFloat64() || to_which_type.isString()) if (to_which_type.isFloat64() || to_which_type.isString())

View File

@ -109,6 +109,8 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
maximize(max_bits_of_signed_integer, 128); maximize(max_bits_of_signed_integer, 128);
else if (type == TypeIndex::Int256) else if (type == TypeIndex::Int256)
maximize(max_bits_of_signed_integer, 256); maximize(max_bits_of_signed_integer, 256);
else if (type == TypeIndex::BFloat16)
maximize(max_mantissa_bits_of_floating, 8);
else if (type == TypeIndex::Float32) else if (type == TypeIndex::Float32)
maximize(max_mantissa_bits_of_floating, 24); maximize(max_mantissa_bits_of_floating, 24);
else if (type == TypeIndex::Float64) else if (type == TypeIndex::Float64)
@ -145,7 +147,9 @@ DataTypePtr getNumericType(const TypeIndexSet & types)
if (max_mantissa_bits_of_floating) if (max_mantissa_bits_of_floating)
{ {
size_t min_mantissa_bits = std::max(min_bit_width_of_integer, max_mantissa_bits_of_floating); size_t min_mantissa_bits = std::max(min_bit_width_of_integer, max_mantissa_bits_of_floating);
if (min_mantissa_bits <= 24) if (min_mantissa_bits <= 8)
return std::make_shared<DataTypeBFloat16>();
else if (min_mantissa_bits <= 24)
return std::make_shared<DataTypeFloat32>(); return std::make_shared<DataTypeFloat32>();
if (min_mantissa_bits <= 53) if (min_mantissa_bits <= 53)
return std::make_shared<DataTypeFloat64>(); return std::make_shared<DataTypeFloat64>();

View File

@ -297,6 +297,8 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth
minimize(min_bits_of_signed_integer, 128); minimize(min_bits_of_signed_integer, 128);
else if (typeid_cast<const DataTypeInt256 *>(type.get())) else if (typeid_cast<const DataTypeInt256 *>(type.get()))
minimize(min_bits_of_signed_integer, 256); minimize(min_bits_of_signed_integer, 256);
else if (typeid_cast<const DataTypeBFloat16 *>(type.get()))
minimize(min_mantissa_bits_of_floating, 8);
else if (typeid_cast<const DataTypeFloat32 *>(type.get())) else if (typeid_cast<const DataTypeFloat32 *>(type.get()))
minimize(min_mantissa_bits_of_floating, 24); minimize(min_mantissa_bits_of_floating, 24);
else if (typeid_cast<const DataTypeFloat64 *>(type.get())) else if (typeid_cast<const DataTypeFloat64 *>(type.get()))
@ -313,7 +315,9 @@ DataTypePtr getMostSubtype(const DataTypes & types, bool throw_if_result_is_noth
/// If the result must be floating. /// If the result must be floating.
if (!min_bits_of_signed_integer && !min_bits_of_unsigned_integer) if (!min_bits_of_signed_integer && !min_bits_of_unsigned_integer)
{ {
if (min_mantissa_bits_of_floating <= 24) if (min_mantissa_bits_of_floating <= 8)
return std::make_shared<DataTypeBFloat16>();
else if (min_mantissa_bits_of_floating <= 24)
return std::make_shared<DataTypeFloat32>(); return std::make_shared<DataTypeFloat32>();
if (min_mantissa_bits_of_floating <= 53) if (min_mantissa_bits_of_floating <= 53)
return std::make_shared<DataTypeFloat64>(); return std::make_shared<DataTypeFloat64>();

View File

@ -24,10 +24,11 @@ void enableAllExperimentalSettings(ContextMutablePtr context)
context->setSetting("allow_experimental_dynamic_type", 1); context->setSetting("allow_experimental_dynamic_type", 1);
context->setSetting("allow_experimental_json_type", 1); context->setSetting("allow_experimental_json_type", 1);
context->setSetting("allow_experimental_vector_similarity_index", 1); context->setSetting("allow_experimental_vector_similarity_index", 1);
context->setSetting("allow_experimental_bigint_types", 1);
context->setSetting("allow_experimental_window_functions", 1); context->setSetting("allow_experimental_window_functions", 1);
context->setSetting("allow_experimental_geo_types", 1); context->setSetting("allow_experimental_geo_types", 1);
context->setSetting("allow_experimental_map_type", 1); context->setSetting("allow_experimental_map_type", 1);
context->setSetting("allow_experimental_bigint_types", 1);
context->setSetting("allow_experimental_bfloat16_type", 1);
context->setSetting("allow_deprecated_error_prone_window_functions", 1); context->setSetting("allow_deprecated_error_prone_window_functions", 1);
context->setSetting("allow_suspicious_low_cardinality_types", 1); context->setSetting("allow_suspicious_low_cardinality_types", 1);

View File

@ -298,7 +298,8 @@ namespace impl
using Types = std::decay_t<decltype(types)>; using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType; using DataType = typename Types::LeftType;
if constexpr (IsDataTypeDecimalOrNumber<DataType> || IsDataTypeDateOrDateTime<DataType> || IsDataTypeEnum<DataType>) if constexpr ((IsDataTypeDecimalOrNumber<DataType> || IsDataTypeDateOrDateTime<DataType> || IsDataTypeEnum<DataType>)
&& !std::is_same_v<DataType, DataTypeBFloat16>)
{ {
using ColumnType = typename DataType::ColumnType; using ColumnType = typename DataType::ColumnType;
func(TypePair<ColumnType, void>()); func(TypePair<ColumnType, void>());

View File

@ -10,7 +10,6 @@
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <IO/ReadBufferFromString.h> #include <IO/ReadBufferFromString.h>
#include <IO/parseDateTimeBestEffort.h> #include <IO/parseDateTimeBestEffort.h>
#include <Parsers/TokenIterator.h>
namespace DB namespace DB

View File

@ -131,7 +131,7 @@ bool tryGetNumericValueFromJSONElement(
switch (element.type()) switch (element.type())
{ {
case ElementType::DOUBLE: case ElementType::DOUBLE:
if constexpr (std::is_floating_point_v<NumberType>) if constexpr (is_floating_point<NumberType>)
{ {
/// We permit inaccurate conversion of double to float. /// We permit inaccurate conversion of double to float.
/// Example: double 0.1 from JSON is not representable in float. /// Example: double 0.1 from JSON is not representable in float.
@ -175,7 +175,7 @@ bool tryGetNumericValueFromJSONElement(
return false; return false;
auto rb = ReadBufferFromMemory{element.getString()}; auto rb = ReadBufferFromMemory{element.getString()};
if constexpr (std::is_floating_point_v<NumberType>) if constexpr (is_floating_point<NumberType>)
{ {
if (!tryReadFloatText(value, rb) || !rb.eof()) if (!tryReadFloatText(value, rb) || !rb.eof())
{ {

View File

@ -540,7 +540,7 @@ namespace
case FieldTypeId::TYPE_ENUM: case FieldTypeId::TYPE_ENUM:
{ {
if (std::is_floating_point_v<NumberType>) if (is_floating_point<NumberType>)
incompatibleColumnType(TypeName<NumberType>); incompatibleColumnType(TypeName<NumberType>);
write_function = [this](NumberType value) write_function = [this](NumberType value)

View File

@ -87,9 +87,9 @@ inline auto checkedDivision(A a, B b)
{ {
throwIfDivisionLeadsToFPE(a, b); throwIfDivisionLeadsToFPE(a, b);
if constexpr (is_big_int_v<A> && std::is_floating_point_v<B>) if constexpr (is_big_int_v<A> && is_floating_point<B>)
return static_cast<B>(a) / b; return static_cast<B>(a) / b;
else if constexpr (is_big_int_v<B> && std::is_floating_point_v<A>) else if constexpr (is_big_int_v<B> && is_floating_point<A>)
return a / static_cast<A>(b); return a / static_cast<A>(b);
else if constexpr (is_big_int_v<A> && is_big_int_v<B>) else if constexpr (is_big_int_v<A> && is_big_int_v<B>)
return static_cast<A>(a / b); return static_cast<A>(a / b);
@ -126,17 +126,17 @@ struct DivideIntegralImpl
{ {
/// Comparisons are not strict to avoid rounding issues when operand is implicitly casted to float. /// Comparisons are not strict to avoid rounding issues when operand is implicitly casted to float.
if constexpr (std::is_floating_point_v<A>) if constexpr (is_floating_point<A>)
if (isNaN(a) || a >= std::numeric_limits<CastA>::max() || a <= std::numeric_limits<CastA>::lowest()) if (isNaN(a) || a >= std::numeric_limits<CastA>::max() || a <= std::numeric_limits<CastA>::lowest())
throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers");
if constexpr (std::is_floating_point_v<B>) if constexpr (is_floating_point<B>)
if (isNaN(b) || b >= std::numeric_limits<CastB>::max() || b <= std::numeric_limits<CastB>::lowest()) if (isNaN(b) || b >= std::numeric_limits<CastB>::max() || b <= std::numeric_limits<CastB>::lowest())
throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers");
auto res = checkedDivision(CastA(a), CastB(b)); auto res = checkedDivision(CastA(a), CastB(b));
if constexpr (std::is_floating_point_v<decltype(res)>) if constexpr (is_floating_point<decltype(res)>)
if (isNaN(res) || res >= static_cast<double>(std::numeric_limits<Result>::max()) || res <= std::numeric_limits<Result>::lowest()) if (isNaN(res) || res >= static_cast<double>(std::numeric_limits<Result>::max()) || res <= std::numeric_limits<Result>::lowest())
throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division, because it will produce infinite or too large number"); throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division, because it will produce infinite or too large number");
@ -162,18 +162,18 @@ struct ModuloImpl
template <typename Result = ResultType> template <typename Result = ResultType>
static Result apply(A a, B b) static Result apply(A a, B b)
{ {
if constexpr (std::is_floating_point_v<ResultType>) if constexpr (is_floating_point<ResultType>)
{ {
/// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance. /// This computation is similar to `fmod` but the latter is not inlined and has 40 times worse performance.
return static_cast<ResultType>(a) - trunc(static_cast<ResultType>(a) / static_cast<ResultType>(b)) * static_cast<ResultType>(b); return static_cast<ResultType>(a) - trunc(static_cast<ResultType>(a) / static_cast<ResultType>(b)) * static_cast<ResultType>(b);
} }
else else
{ {
if constexpr (std::is_floating_point_v<A>) if constexpr (is_floating_point<A>)
if (isNaN(a) || a > std::numeric_limits<IntegerAType>::max() || a < std::numeric_limits<IntegerAType>::lowest()) if (isNaN(a) || a > std::numeric_limits<IntegerAType>::max() || a < std::numeric_limits<IntegerAType>::lowest())
throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers");
if constexpr (std::is_floating_point_v<B>) if constexpr (is_floating_point<B>)
if (isNaN(b) || b > std::numeric_limits<IntegerBType>::max() || b < std::numeric_limits<IntegerBType>::lowest()) if (isNaN(b) || b > std::numeric_limits<IntegerBType>::max() || b < std::numeric_limits<IntegerBType>::lowest())
throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers"); throw Exception(ErrorCodes::ILLEGAL_DIVISION, "Cannot perform integer division on infinite or too large floating point numbers");

View File

@ -110,6 +110,7 @@ template <typename DataType> constexpr bool IsIntegralOrExtendedOrDecimal =
IsDataTypeDecimal<DataType>; IsDataTypeDecimal<DataType>;
template <typename DataType> constexpr bool IsFloatingPoint = false; template <typename DataType> constexpr bool IsFloatingPoint = false;
template <> inline constexpr bool IsFloatingPoint<DataTypeBFloat16> = true;
template <> inline constexpr bool IsFloatingPoint<DataTypeFloat32> = true; template <> inline constexpr bool IsFloatingPoint<DataTypeFloat32> = true;
template <> inline constexpr bool IsFloatingPoint<DataTypeFloat64> = true; template <> inline constexpr bool IsFloatingPoint<DataTypeFloat64> = true;
@ -804,7 +805,7 @@ class FunctionBinaryArithmetic : public IFunction
DataTypeFixedString, DataTypeString, DataTypeFixedString, DataTypeString,
DataTypeInterval>; DataTypeInterval>;
using Floats = TypeList<DataTypeFloat32, DataTypeFloat64>; using Floats = TypeList<DataTypeFloat32, DataTypeFloat64, DataTypeBFloat16>;
using ValidTypes = std::conditional_t<valid_on_float_arguments, using ValidTypes = std::conditional_t<valid_on_float_arguments,
TypeListConcat<Types, Floats>, TypeListConcat<Types, Floats>,
@ -1691,6 +1692,13 @@ public:
} }
else else
{ {
if constexpr ((std::is_same_v<LeftDataType, DataTypeBFloat16> || std::is_same_v<RightDataType, DataTypeBFloat16>)
&& (sizeof(typename LeftDataType::FieldType) > 8 || sizeof(typename RightDataType::FieldType) > 8))
{
/// Big integers and BFloat16 are not supported together.
return false;
}
using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType; using ResultDataType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::ResultDataType;
if constexpr (!std::is_same_v<ResultDataType, InvalidType>) if constexpr (!std::is_same_v<ResultDataType, InvalidType>)
@ -2043,7 +2051,15 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
using DecimalResultType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::DecimalResultDataType; using DecimalResultType = typename BinaryOperationTraits<Op, LeftDataType, RightDataType>::DecimalResultDataType;
if constexpr (std::is_same_v<ResultDataType, InvalidType>) if constexpr (std::is_same_v<ResultDataType, InvalidType>)
{
return nullptr; return nullptr;
}
else if constexpr ((std::is_same_v<LeftDataType, DataTypeBFloat16> || std::is_same_v<RightDataType, DataTypeBFloat16>)
&& (sizeof(typename LeftDataType::FieldType) > 8 || sizeof(typename RightDataType::FieldType) > 8))
{
/// Big integers and BFloat16 are not supported together.
return nullptr;
}
else // we can't avoid the else because otherwise the compiler may assume the ResultDataType may be Invalid else // we can't avoid the else because otherwise the compiler may assume the ResultDataType may be Invalid
// and that would produce the compile error. // and that would produce the compile error.
{ {
@ -2060,7 +2076,7 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
ColumnPtr left_col = nullptr; ColumnPtr left_col = nullptr;
ColumnPtr right_col = nullptr; ColumnPtr right_col = nullptr;
/// When Decimal op Float32/64, convert both of them into Float64 /// When Decimal op Float32/64/16, convert both of them into Float64
if constexpr (decimal_with_float) if constexpr (decimal_with_float)
{ {
const auto converted_type = std::make_shared<DataTypeFloat64>(); const auto converted_type = std::make_shared<DataTypeFloat64>();
@ -2095,7 +2111,6 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
/// Here we check if we have `intDiv` or `intDivOrZero` and at least one of the arguments is decimal, because in this case originally we had result as decimal, so we need to convert result into integer after calculations /// Here we check if we have `intDiv` or `intDivOrZero` and at least one of the arguments is decimal, because in this case originally we had result as decimal, so we need to convert result into integer after calculations
else if constexpr (!decimal_with_float && (is_int_div || is_int_div_or_zero) && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>)) else if constexpr (!decimal_with_float && (is_int_div || is_int_div_or_zero) && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>))
{ {
if constexpr (!std::is_same_v<DecimalResultType, InvalidType>) if constexpr (!std::is_same_v<DecimalResultType, InvalidType>)
{ {
DataTypePtr type_res; DataTypePtr type_res;

View File

@ -70,7 +70,7 @@ private:
/// Process all data as a whole and use FastOps implementation /// Process all data as a whole and use FastOps implementation
/// If the argument is integer, convert to Float64 beforehand /// If the argument is integer, convert to Float64 beforehand
if constexpr (!std::is_floating_point_v<T>) if constexpr (!is_floating_point<T>)
{ {
PODArray<Float64> tmp_vec(size); PODArray<Float64> tmp_vec(size);
for (size_t i = 0; i < size; ++i) for (size_t i = 0; i < size; ++i)
@ -152,7 +152,7 @@ private:
{ {
using Types = std::decay_t<decltype(types)>; using Types = std::decay_t<decltype(types)>;
using Type = typename Types::RightType; using Type = typename Types::RightType;
using ReturnType = std::conditional_t<Impl::always_returns_float64 || !std::is_floating_point_v<Type>, Float64, Type>; using ReturnType = std::conditional_t<Impl::always_returns_float64 || !is_floating_point<Type>, Float64, Type>;
using ColVecType = ColumnVectorOrDecimal<Type>; using ColVecType = ColumnVectorOrDecimal<Type>;
const auto col_vec = checkAndGetColumn<ColVecType>(col.column.get()); const auto col_vec = checkAndGetColumn<ColVecType>(col.column.get());

View File

@ -296,6 +296,7 @@ public:
tryExecuteUIntOrInt<Int256>(column, res_column) || tryExecuteUIntOrInt<Int256>(column, res_column) ||
tryExecuteString(column, res_column) || tryExecuteString(column, res_column) ||
tryExecuteFixedString(column, res_column) || tryExecuteFixedString(column, res_column) ||
tryExecuteFloat<BFloat16>(column, res_column) ||
tryExecuteFloat<Float32>(column, res_column) || tryExecuteFloat<Float32>(column, res_column) ||
tryExecuteFloat<Float64>(column, res_column) || tryExecuteFloat<Float64>(column, res_column) ||
tryExecuteDecimal<Decimal32>(column, res_column) || tryExecuteDecimal<Decimal32>(column, res_column) ||

View File

@ -721,6 +721,7 @@ private:
|| (res = executeNumRightType<T0, Int64>(col_left, col_right_untyped)) || (res = executeNumRightType<T0, Int64>(col_left, col_right_untyped))
|| (res = executeNumRightType<T0, Int128>(col_left, col_right_untyped)) || (res = executeNumRightType<T0, Int128>(col_left, col_right_untyped))
|| (res = executeNumRightType<T0, Int256>(col_left, col_right_untyped)) || (res = executeNumRightType<T0, Int256>(col_left, col_right_untyped))
|| (res = executeNumRightType<T0, BFloat16>(col_left, col_right_untyped))
|| (res = executeNumRightType<T0, Float32>(col_left, col_right_untyped)) || (res = executeNumRightType<T0, Float32>(col_left, col_right_untyped))
|| (res = executeNumRightType<T0, Float64>(col_left, col_right_untyped))) || (res = executeNumRightType<T0, Float64>(col_left, col_right_untyped)))
return res; return res;
@ -741,6 +742,7 @@ private:
|| (res = executeNumConstRightType<T0, Int64>(col_left_const, col_right_untyped)) || (res = executeNumConstRightType<T0, Int64>(col_left_const, col_right_untyped))
|| (res = executeNumConstRightType<T0, Int128>(col_left_const, col_right_untyped)) || (res = executeNumConstRightType<T0, Int128>(col_left_const, col_right_untyped))
|| (res = executeNumConstRightType<T0, Int256>(col_left_const, col_right_untyped)) || (res = executeNumConstRightType<T0, Int256>(col_left_const, col_right_untyped))
|| (res = executeNumConstRightType<T0, BFloat16>(col_left_const, col_right_untyped))
|| (res = executeNumConstRightType<T0, Float32>(col_left_const, col_right_untyped)) || (res = executeNumConstRightType<T0, Float32>(col_left_const, col_right_untyped))
|| (res = executeNumConstRightType<T0, Float64>(col_left_const, col_right_untyped))) || (res = executeNumConstRightType<T0, Float64>(col_left_const, col_right_untyped)))
return res; return res;
@ -1292,9 +1294,10 @@ public:
|| (res = executeNumLeftType<Int64>(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType<Int64>(col_left_untyped, col_right_untyped))
|| (res = executeNumLeftType<Int128>(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType<Int128>(col_left_untyped, col_right_untyped))
|| (res = executeNumLeftType<Int256>(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType<Int256>(col_left_untyped, col_right_untyped))
|| (res = executeNumLeftType<BFloat16>(col_left_untyped, col_right_untyped))
|| (res = executeNumLeftType<Float32>(col_left_untyped, col_right_untyped)) || (res = executeNumLeftType<Float32>(col_left_untyped, col_right_untyped))
|| (res = executeNumLeftType<Float64>(col_left_untyped, col_right_untyped)))) || (res = executeNumLeftType<Float64>(col_left_untyped, col_right_untyped))))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}", throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of the first argument of function {}",
col_left_untyped->getName(), getName()); col_left_untyped->getName(), getName());
return res; return res;
@ -1342,7 +1345,7 @@ public:
getName(), getName(),
left_type->getName(), left_type->getName(),
right_type->getName()); right_type->getName());
/// When Decimal comparing to Float32/64, we convert both of them into Float64. /// When Decimal comparing to Float32/64/16, we convert both of them into Float64.
/// Other systems like MySQL and Spark also do as this. /// Other systems like MySQL and Spark also do as this.
if (left_is_float || right_is_float) if (left_is_float || right_is_float)
{ {

View File

@ -7,10 +7,8 @@
#include <Columns/ColumnFixedString.h> #include <Columns/ColumnFixedString.h>
#include <Columns/ColumnLowCardinality.h> #include <Columns/ColumnLowCardinality.h>
#include <Columns/ColumnMap.h> #include <Columns/ColumnMap.h>
#include <Columns/ColumnNothing.h>
#include <Columns/ColumnNullable.h> #include <Columns/ColumnNullable.h>
#include <Columns/ColumnObjectDeprecated.h> #include <Columns/ColumnObjectDeprecated.h>
#include <Columns/ColumnObject.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Columns/ColumnStringHelpers.h> #include <Columns/ColumnStringHelpers.h>
#include <Columns/ColumnTuple.h> #include <Columns/ColumnTuple.h>
@ -79,6 +77,7 @@
namespace DB namespace DB
{ {
namespace Setting namespace Setting
{ {
extern const SettingsBool cast_ipv4_ipv6_default_on_conversion_error; extern const SettingsBool cast_ipv4_ipv6_default_on_conversion_error;
@ -694,7 +693,7 @@ inline void convertFromTime<DataTypeDateTime>(DataTypeDateTime::FieldType & x, t
template <typename DataType> template <typename DataType>
void parseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) void parseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing)
{ {
if constexpr (std::is_floating_point_v<typename DataType::FieldType>) if constexpr (is_floating_point<typename DataType::FieldType>)
{ {
if (precise_float_parsing) if (precise_float_parsing)
readFloatTextPrecise(x, rb); readFloatTextPrecise(x, rb);
@ -758,7 +757,7 @@ inline void parseImpl<DataTypeIPv6>(DataTypeIPv6::FieldType & x, ReadBuffer & rb
template <typename DataType> template <typename DataType>
bool tryParseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing) bool tryParseImpl(typename DataType::FieldType & x, ReadBuffer & rb, const DateLUTImpl *, bool precise_float_parsing)
{ {
if constexpr (std::is_floating_point_v<typename DataType::FieldType>) if constexpr (is_floating_point<typename DataType::FieldType>)
{ {
if (precise_float_parsing) if (precise_float_parsing)
return tryReadFloatTextPrecise(x, rb); return tryReadFloatTextPrecise(x, rb);
@ -1888,7 +1887,7 @@ struct ConvertImpl
else else
{ {
/// If From Data is Nan or Inf and we convert to integer type, throw exception /// If From Data is Nan or Inf and we convert to integer type, throw exception
if constexpr (std::is_floating_point_v<FromFieldType> && !std::is_floating_point_v<ToFieldType>) if constexpr (is_floating_point<FromFieldType> && !is_floating_point<ToFieldType>)
{ {
if (!isFinite(vec_from[i])) if (!isFinite(vec_from[i]))
{ {
@ -2420,9 +2419,9 @@ private:
using RightT = typename RightDataType::FieldType; using RightT = typename RightDataType::FieldType;
static constexpr bool bad_left = static constexpr bool bad_left =
is_decimal<LeftT> || std::is_floating_point_v<LeftT> || is_big_int_v<LeftT> || is_signed_v<LeftT>; is_decimal<LeftT> || is_floating_point<LeftT> || is_big_int_v<LeftT> || is_signed_v<LeftT>;
static constexpr bool bad_right = static constexpr bool bad_right =
is_decimal<RightT> || std::is_floating_point_v<RightT> || is_big_int_v<RightT> || is_signed_v<RightT>; is_decimal<RightT> || is_floating_point<RightT> || is_big_int_v<RightT> || is_signed_v<RightT>;
/// Disallow int vs UUID conversion (but support int vs UInt128 conversion) /// Disallow int vs UUID conversion (but support int vs UInt128 conversion)
if constexpr ((bad_left && std::is_same_v<RightDataType, DataTypeUUID>) || if constexpr ((bad_left && std::is_same_v<RightDataType, DataTypeUUID>) ||
@ -2749,7 +2748,7 @@ struct ToNumberMonotonicity
/// Float cases. /// Float cases.
/// When converting to Float, the conversion is always monotonic. /// When converting to Float, the conversion is always monotonic.
if constexpr (std::is_floating_point_v<T>) if constexpr (is_floating_point<T>)
return { .is_monotonic = true, .is_always_monotonic = true }; return { .is_monotonic = true, .is_always_monotonic = true };
const auto * low_cardinality = typeid_cast<const DataTypeLowCardinality *>(&type); const auto * low_cardinality = typeid_cast<const DataTypeLowCardinality *>(&type);
@ -2962,6 +2961,7 @@ struct NameToInt32 { static constexpr auto name = "toInt32"; };
struct NameToInt64 { static constexpr auto name = "toInt64"; }; struct NameToInt64 { static constexpr auto name = "toInt64"; };
struct NameToInt128 { static constexpr auto name = "toInt128"; }; struct NameToInt128 { static constexpr auto name = "toInt128"; };
struct NameToInt256 { static constexpr auto name = "toInt256"; }; struct NameToInt256 { static constexpr auto name = "toInt256"; };
struct NameToBFloat16 { static constexpr auto name = "toBFloat16"; };
struct NameToFloat32 { static constexpr auto name = "toFloat32"; }; struct NameToFloat32 { static constexpr auto name = "toFloat32"; };
struct NameToFloat64 { static constexpr auto name = "toFloat64"; }; struct NameToFloat64 { static constexpr auto name = "toFloat64"; };
struct NameToUUID { static constexpr auto name = "toUUID"; }; struct NameToUUID { static constexpr auto name = "toUUID"; };
@ -2980,6 +2980,7 @@ using FunctionToInt32 = FunctionConvert<DataTypeInt32, NameToInt32, ToNumberMono
using FunctionToInt64 = FunctionConvert<DataTypeInt64, NameToInt64, ToNumberMonotonicity<Int64>>; using FunctionToInt64 = FunctionConvert<DataTypeInt64, NameToInt64, ToNumberMonotonicity<Int64>>;
using FunctionToInt128 = FunctionConvert<DataTypeInt128, NameToInt128, ToNumberMonotonicity<Int128>>; using FunctionToInt128 = FunctionConvert<DataTypeInt128, NameToInt128, ToNumberMonotonicity<Int128>>;
using FunctionToInt256 = FunctionConvert<DataTypeInt256, NameToInt256, ToNumberMonotonicity<Int256>>; using FunctionToInt256 = FunctionConvert<DataTypeInt256, NameToInt256, ToNumberMonotonicity<Int256>>;
using FunctionToBFloat16 = FunctionConvert<DataTypeBFloat16, NameToBFloat16, ToNumberMonotonicity<BFloat16>>;
using FunctionToFloat32 = FunctionConvert<DataTypeFloat32, NameToFloat32, ToNumberMonotonicity<Float32>>; using FunctionToFloat32 = FunctionConvert<DataTypeFloat32, NameToFloat32, ToNumberMonotonicity<Float32>>;
using FunctionToFloat64 = FunctionConvert<DataTypeFloat64, NameToFloat64, ToNumberMonotonicity<Float64>>; using FunctionToFloat64 = FunctionConvert<DataTypeFloat64, NameToFloat64, ToNumberMonotonicity<Float64>>;
@ -3017,6 +3018,7 @@ template <> struct FunctionTo<DataTypeInt32> { using Type = FunctionToInt32; };
template <> struct FunctionTo<DataTypeInt64> { using Type = FunctionToInt64; }; template <> struct FunctionTo<DataTypeInt64> { using Type = FunctionToInt64; };
template <> struct FunctionTo<DataTypeInt128> { using Type = FunctionToInt128; }; template <> struct FunctionTo<DataTypeInt128> { using Type = FunctionToInt128; };
template <> struct FunctionTo<DataTypeInt256> { using Type = FunctionToInt256; }; template <> struct FunctionTo<DataTypeInt256> { using Type = FunctionToInt256; };
template <> struct FunctionTo<DataTypeBFloat16> { using Type = FunctionToBFloat16; };
template <> struct FunctionTo<DataTypeFloat32> { using Type = FunctionToFloat32; }; template <> struct FunctionTo<DataTypeFloat32> { using Type = FunctionToFloat32; };
template <> struct FunctionTo<DataTypeFloat64> { using Type = FunctionToFloat64; }; template <> struct FunctionTo<DataTypeFloat64> { using Type = FunctionToFloat64; };
@ -3059,6 +3061,7 @@ struct NameToInt32OrZero { static constexpr auto name = "toInt32OrZero"; };
struct NameToInt64OrZero { static constexpr auto name = "toInt64OrZero"; }; struct NameToInt64OrZero { static constexpr auto name = "toInt64OrZero"; };
struct NameToInt128OrZero { static constexpr auto name = "toInt128OrZero"; }; struct NameToInt128OrZero { static constexpr auto name = "toInt128OrZero"; };
struct NameToInt256OrZero { static constexpr auto name = "toInt256OrZero"; }; struct NameToInt256OrZero { static constexpr auto name = "toInt256OrZero"; };
struct NameToBFloat16OrZero { static constexpr auto name = "toBFloat16OrZero"; };
struct NameToFloat32OrZero { static constexpr auto name = "toFloat32OrZero"; }; struct NameToFloat32OrZero { static constexpr auto name = "toFloat32OrZero"; };
struct NameToFloat64OrZero { static constexpr auto name = "toFloat64OrZero"; }; struct NameToFloat64OrZero { static constexpr auto name = "toFloat64OrZero"; };
struct NameToDateOrZero { static constexpr auto name = "toDateOrZero"; }; struct NameToDateOrZero { static constexpr auto name = "toDateOrZero"; };
@ -3085,6 +3088,7 @@ using FunctionToInt32OrZero = FunctionConvertFromString<DataTypeInt32, NameToInt
using FunctionToInt64OrZero = FunctionConvertFromString<DataTypeInt64, NameToInt64OrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToInt64OrZero = FunctionConvertFromString<DataTypeInt64, NameToInt64OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToInt128OrZero = FunctionConvertFromString<DataTypeInt128, NameToInt128OrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToInt128OrZero = FunctionConvertFromString<DataTypeInt128, NameToInt128OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToInt256OrZero = FunctionConvertFromString<DataTypeInt256, NameToInt256OrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToInt256OrZero = FunctionConvertFromString<DataTypeInt256, NameToInt256OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToBFloat16OrZero = FunctionConvertFromString<DataTypeBFloat16, NameToBFloat16OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToFloat32OrZero = FunctionConvertFromString<DataTypeFloat32, NameToFloat32OrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToFloat32OrZero = FunctionConvertFromString<DataTypeFloat32, NameToFloat32OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToFloat64OrZero = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToFloat64OrZero = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrZero, ConvertFromStringExceptionMode::Zero>;
using FunctionToDateOrZero = FunctionConvertFromString<DataTypeDate, NameToDateOrZero, ConvertFromStringExceptionMode::Zero>; using FunctionToDateOrZero = FunctionConvertFromString<DataTypeDate, NameToDateOrZero, ConvertFromStringExceptionMode::Zero>;
@ -3111,6 +3115,7 @@ struct NameToInt32OrNull { static constexpr auto name = "toInt32OrNull"; };
struct NameToInt64OrNull { static constexpr auto name = "toInt64OrNull"; }; struct NameToInt64OrNull { static constexpr auto name = "toInt64OrNull"; };
struct NameToInt128OrNull { static constexpr auto name = "toInt128OrNull"; }; struct NameToInt128OrNull { static constexpr auto name = "toInt128OrNull"; };
struct NameToInt256OrNull { static constexpr auto name = "toInt256OrNull"; }; struct NameToInt256OrNull { static constexpr auto name = "toInt256OrNull"; };
struct NameToBFloat16OrNull { static constexpr auto name = "toBFloat16OrNull"; };
struct NameToFloat32OrNull { static constexpr auto name = "toFloat32OrNull"; }; struct NameToFloat32OrNull { static constexpr auto name = "toFloat32OrNull"; };
struct NameToFloat64OrNull { static constexpr auto name = "toFloat64OrNull"; }; struct NameToFloat64OrNull { static constexpr auto name = "toFloat64OrNull"; };
struct NameToDateOrNull { static constexpr auto name = "toDateOrNull"; }; struct NameToDateOrNull { static constexpr auto name = "toDateOrNull"; };
@ -3137,6 +3142,7 @@ using FunctionToInt32OrNull = FunctionConvertFromString<DataTypeInt32, NameToInt
using FunctionToInt64OrNull = FunctionConvertFromString<DataTypeInt64, NameToInt64OrNull, ConvertFromStringExceptionMode::Null>; using FunctionToInt64OrNull = FunctionConvertFromString<DataTypeInt64, NameToInt64OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToInt128OrNull = FunctionConvertFromString<DataTypeInt128, NameToInt128OrNull, ConvertFromStringExceptionMode::Null>; using FunctionToInt128OrNull = FunctionConvertFromString<DataTypeInt128, NameToInt128OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToInt256OrNull = FunctionConvertFromString<DataTypeInt256, NameToInt256OrNull, ConvertFromStringExceptionMode::Null>; using FunctionToInt256OrNull = FunctionConvertFromString<DataTypeInt256, NameToInt256OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToBFloat16OrNull = FunctionConvertFromString<DataTypeBFloat16, NameToBFloat16OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToFloat32OrNull = FunctionConvertFromString<DataTypeFloat32, NameToFloat32OrNull, ConvertFromStringExceptionMode::Null>; using FunctionToFloat32OrNull = FunctionConvertFromString<DataTypeFloat32, NameToFloat32OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToFloat64OrNull = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrNull, ConvertFromStringExceptionMode::Null>; using FunctionToFloat64OrNull = FunctionConvertFromString<DataTypeFloat64, NameToFloat64OrNull, ConvertFromStringExceptionMode::Null>;
using FunctionToDateOrNull = FunctionConvertFromString<DataTypeDate, NameToDateOrNull, ConvertFromStringExceptionMode::Null>; using FunctionToDateOrNull = FunctionConvertFromString<DataTypeDate, NameToDateOrNull, ConvertFromStringExceptionMode::Null>;
@ -5335,7 +5341,7 @@ private:
if constexpr (is_any_of<ToDataType, if constexpr (is_any_of<ToDataType,
DataTypeUInt16, DataTypeUInt32, DataTypeUInt64, DataTypeUInt128, DataTypeUInt256, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64, DataTypeUInt128, DataTypeUInt256,
DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, DataTypeInt128, DataTypeInt256, DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64, DataTypeInt128, DataTypeInt256,
DataTypeFloat32, DataTypeFloat64, DataTypeBFloat16, DataTypeFloat32, DataTypeFloat64,
DataTypeDate, DataTypeDate32, DataTypeDateTime, DataTypeDate, DataTypeDate32, DataTypeDateTime,
DataTypeUUID, DataTypeIPv4, DataTypeIPv6>) DataTypeUUID, DataTypeIPv4, DataTypeIPv6>)
{ {
@ -5588,6 +5594,17 @@ REGISTER_FUNCTION(Conversion)
factory.registerFunction<FunctionToInt64>(); factory.registerFunction<FunctionToInt64>();
factory.registerFunction<FunctionToInt128>(); factory.registerFunction<FunctionToInt128>();
factory.registerFunction<FunctionToInt256>(); factory.registerFunction<FunctionToInt256>();
factory.registerFunction<FunctionToBFloat16>(FunctionDocumentation{.description=R"(
Converts Float32 to BFloat16 with losing the precision.
Example:
[example:typical]
)",
.examples{
{"typical", "SELECT toBFloat16(12.3::Float32);", "12.3125"}},
.categories{"Conversion"}});
factory.registerFunction<FunctionToFloat32>(); factory.registerFunction<FunctionToFloat32>();
factory.registerFunction<FunctionToFloat64>(); factory.registerFunction<FunctionToFloat64>();
@ -5626,6 +5643,31 @@ REGISTER_FUNCTION(Conversion)
factory.registerFunction<FunctionToInt64OrZero>(); factory.registerFunction<FunctionToInt64OrZero>();
factory.registerFunction<FunctionToInt128OrZero>(); factory.registerFunction<FunctionToInt128OrZero>();
factory.registerFunction<FunctionToInt256OrZero>(); factory.registerFunction<FunctionToInt256OrZero>();
factory.registerFunction<FunctionToBFloat16OrZero>(FunctionDocumentation{.description=R"(
Converts String to BFloat16.
If the string does not represent a floating point value, the function returns zero.
The function allows a silent loss of precision while converting from the string representation. In that case, it will return the truncated result.
Example of successful conversion:
[example:typical]
Examples of not successful conversion:
[example:invalid1]
[example:invalid2]
Example of a loss of precision:
[example:precision]
)",
.examples{
{"typical", "SELECT toBFloat16OrZero('12.3');", "12.3125"},
{"invalid1", "SELECT toBFloat16OrZero('abc');", "0"},
{"invalid2", "SELECT toBFloat16OrZero(' 1');", "0"},
{"precision", "SELECT toBFloat16OrZero('12.3456789');", "12.375"}},
.categories{"Conversion"}});
factory.registerFunction<FunctionToFloat32OrZero>(); factory.registerFunction<FunctionToFloat32OrZero>();
factory.registerFunction<FunctionToFloat64OrZero>(); factory.registerFunction<FunctionToFloat64OrZero>();
factory.registerFunction<FunctionToDateOrZero>(); factory.registerFunction<FunctionToDateOrZero>();
@ -5654,6 +5696,31 @@ REGISTER_FUNCTION(Conversion)
factory.registerFunction<FunctionToInt64OrNull>(); factory.registerFunction<FunctionToInt64OrNull>();
factory.registerFunction<FunctionToInt128OrNull>(); factory.registerFunction<FunctionToInt128OrNull>();
factory.registerFunction<FunctionToInt256OrNull>(); factory.registerFunction<FunctionToInt256OrNull>();
factory.registerFunction<FunctionToBFloat16OrNull>(FunctionDocumentation{.description=R"(
Converts String to Nullable(BFloat16).
If the string does not represent a floating point value, the function returns NULL.
The function allows a silent loss of precision while converting from the string representation. In that case, it will return the truncated result.
Example of successful conversion:
[example:typical]
Examples of not successful conversion:
[example:invalid1]
[example:invalid2]
Example of a loss of precision:
[example:precision]
)",
.examples{
{"typical", "SELECT toBFloat16OrNull('12.3');", "12.3125"},
{"invalid1", "SELECT toBFloat16OrNull('abc');", "NULL"},
{"invalid2", "SELECT toBFloat16OrNull(' 1');", "NULL"},
{"precision", "SELECT toBFloat16OrNull('12.3456789');", "12.375"}},
.categories{"Conversion"}});
factory.registerFunction<FunctionToFloat32OrNull>(); factory.registerFunction<FunctionToFloat32OrNull>();
factory.registerFunction<FunctionToFloat64OrNull>(); factory.registerFunction<FunctionToFloat64OrNull>();
factory.registerFunction<FunctionToDateOrNull>(); factory.registerFunction<FunctionToDateOrNull>();

View File

@ -268,6 +268,19 @@ inline double roundWithMode(double x, RoundingMode mode)
std::unreachable(); std::unreachable();
} }
inline BFloat16 roundWithMode(BFloat16 x, RoundingMode mode)
{
switch (mode)
{
case RoundingMode::Round: return BFloat16(nearbyintf(Float32(x)));
case RoundingMode::Floor: return BFloat16(floorf(Float32(x)));
case RoundingMode::Ceil: return BFloat16(ceilf(Float32(x)));
case RoundingMode::Trunc: return BFloat16(truncf(Float32(x)));
}
std::unreachable();
}
template <typename T> template <typename T>
class FloatRoundingComputationBase<T, Vectorize::No> class FloatRoundingComputationBase<T, Vectorize::No>
{ {
@ -285,10 +298,15 @@ public:
static VectorType prepare(size_t scale) static VectorType prepare(size_t scale)
{ {
return load1(scale); return load1(ScalarType(scale));
} }
}; };
template <>
class FloatRoundingComputationBase<BFloat16, Vectorize::Yes> : public FloatRoundingComputationBase<BFloat16, Vectorize::No>
{
};
/** Implementation of low-level round-off functions for floating-point values. /** Implementation of low-level round-off functions for floating-point values.
*/ */
@ -511,7 +529,7 @@ template <typename T, RoundingMode rounding_mode, TieBreakingMode tie_breaking_m
struct Dispatcher struct Dispatcher
{ {
template <ScaleMode scale_mode> template <ScaleMode scale_mode>
using FunctionRoundingImpl = std::conditional_t<std::is_floating_point_v<T>, using FunctionRoundingImpl = std::conditional_t<is_floating_point<T>,
FloatRoundingImpl<T, rounding_mode, scale_mode>, FloatRoundingImpl<T, rounding_mode, scale_mode>,
IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>; IntegerRoundingImpl<T, rounding_mode, scale_mode, tie_breaking_mode>>;

View File

@ -57,7 +57,7 @@ struct ExtractNumericType
ResultType x = 0; ResultType x = 0;
if (!in.eof()) if (!in.eof())
{ {
if constexpr (std::is_floating_point_v<NumericType>) if constexpr (is_floating_point<NumericType>)
tryReadFloatText(x, in); tryReadFloatText(x, in);
else else
tryReadIntText(x, in); tryReadIntText(x, in);

View File

@ -583,7 +583,7 @@ struct CallPointInPolygon<Type, Types ...>
template <typename PointInPolygonImpl> template <typename PointInPolygonImpl>
static ColumnPtr call(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) static ColumnPtr call(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl)
{ {
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListIntAndFloat>; using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListNativeNumber>;
if (auto column = typeid_cast<const ColumnVector<Type> *>(&x)) if (auto column = typeid_cast<const ColumnVector<Type> *>(&x))
return Impl::template call<Type>(*column, y, impl); return Impl::template call<Type>(*column, y, impl);
return CallPointInPolygon<Types ...>::call(x, y, impl); return CallPointInPolygon<Types ...>::call(x, y, impl);
@ -609,7 +609,7 @@ struct CallPointInPolygon<>
template <typename PointInPolygonImpl> template <typename PointInPolygonImpl>
NO_INLINE ColumnPtr pointInPolygon(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl) NO_INLINE ColumnPtr pointInPolygon(const IColumn & x, const IColumn & y, PointInPolygonImpl && impl)
{ {
using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListIntAndFloat>; using Impl = TypeListChangeRoot<CallPointInPolygon, TypeListNativeNumber>;
return Impl::call(x, y, impl); return Impl::call(x, y, impl);
} }

View File

@ -27,7 +27,7 @@ struct AbsImpl
return a < 0 ? static_cast<ResultType>(~a) + 1 : static_cast<ResultType>(a); return a < 0 ? static_cast<ResultType>(~a) + 1 : static_cast<ResultType>(a);
else if constexpr (is_integer<A> && is_unsigned_v<A>) else if constexpr (is_integer<A> && is_unsigned_v<A>)
return static_cast<ResultType>(a); return static_cast<ResultType>(a);
else if constexpr (std::is_floating_point_v<A>) else if constexpr (is_floating_point<A>)
return static_cast<ResultType>(std::abs(a)); return static_cast<ResultType>(std::abs(a));
} }

View File

@ -87,7 +87,7 @@ struct ArrayAggregateResultImpl<ArrayElement, AggregateOperation::sum>
std::conditional_t<std::is_same_v<ArrayElement, Decimal128>, Decimal128, std::conditional_t<std::is_same_v<ArrayElement, Decimal128>, Decimal128,
std::conditional_t<std::is_same_v<ArrayElement, Decimal256>, Decimal256, std::conditional_t<std::is_same_v<ArrayElement, Decimal256>, Decimal256,
std::conditional_t<std::is_same_v<ArrayElement, DateTime64>, Decimal128, std::conditional_t<std::is_same_v<ArrayElement, DateTime64>, Decimal128,
std::conditional_t<std::is_floating_point_v<ArrayElement>, Float64, std::conditional_t<is_floating_point<ArrayElement>, Float64,
std::conditional_t<std::is_signed_v<ArrayElement>, Int64, std::conditional_t<std::is_signed_v<ArrayElement>, Int64,
UInt64>>>>>>>>>>>; UInt64>>>>>>>>>>>;
}; };

View File

@ -14,6 +14,7 @@
#include <immintrin.h> #include <immintrin.h>
#endif #endif
namespace DB namespace DB
{ {
namespace ErrorCodes namespace ErrorCodes
@ -34,7 +35,7 @@ struct L1Distance
template <typename FloatType> template <typename FloatType>
struct State struct State
{ {
FloatType sum = 0; FloatType sum{};
}; };
template <typename ResultType> template <typename ResultType>
@ -65,7 +66,7 @@ struct L2Distance
template <typename FloatType> template <typename FloatType>
struct State struct State
{ {
FloatType sum = 0; FloatType sum{};
}; };
template <typename ResultType> template <typename ResultType>
@ -90,19 +91,17 @@ struct L2Distance
size_t & i_y, size_t & i_y,
State<ResultType> & state) State<ResultType> & state)
{ {
static constexpr bool is_float32 = std::is_same_v<ResultType, Float32>;
__m512 sums; __m512 sums;
if constexpr (is_float32) if constexpr (sizeof(ResultType) <= 4)
sums = _mm512_setzero_ps(); sums = _mm512_setzero_ps();
else else
sums = _mm512_setzero_pd(); sums = _mm512_setzero_pd();
constexpr size_t n = is_float32 ? 16 : 8; constexpr size_t n = sizeof(__m512) / sizeof(ResultType);
for (; i_x + n < i_max; i_x += n, i_y += n) for (; i_x + n < i_max; i_x += n, i_y += n)
{ {
if constexpr (is_float32) if constexpr (sizeof(ResultType) == 4)
{ {
__m512 x = _mm512_loadu_ps(data_x + i_x); __m512 x = _mm512_loadu_ps(data_x + i_x);
__m512 y = _mm512_loadu_ps(data_y + i_y); __m512 y = _mm512_loadu_ps(data_y + i_y);
@ -118,11 +117,38 @@ struct L2Distance
} }
} }
if constexpr (is_float32) if constexpr (sizeof(ResultType) <= 4)
state.sum = _mm512_reduce_add_ps(sums); state.sum = _mm512_reduce_add_ps(sums);
else else
state.sum = _mm512_reduce_add_pd(sums); state.sum = _mm512_reduce_add_pd(sums);
} }
AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
const BFloat16 * __restrict data_x,
const BFloat16 * __restrict data_y,
size_t i_max,
size_t & i_x,
size_t & i_y,
State<Float32> & state)
{
__m512 sums = _mm512_setzero_ps();
constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
for (; i_x + n < i_max; i_x += n, i_y += n)
{
__m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x)));
__m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x + n / 2)));
__m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y)));
__m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y + n / 2)));
__m512 differences_1 = _mm512_sub_ps(x_1, y_1);
__m512 differences_2 = _mm512_sub_ps(x_2, y_2);
sums = _mm512_fmadd_ps(differences_1, differences_1, sums);
sums = _mm512_fmadd_ps(differences_2, differences_2, sums);
}
state.sum = _mm512_reduce_add_ps(sums);
}
#endif #endif
template <typename ResultType> template <typename ResultType>
@ -156,13 +182,13 @@ struct LpDistance
template <typename FloatType> template <typename FloatType>
struct State struct State
{ {
FloatType sum = 0; FloatType sum{};
}; };
template <typename ResultType> template <typename ResultType>
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams & params) static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams & params)
{ {
state.sum += static_cast<ResultType>(std::pow(fabs(x - y), params.power)); state.sum += static_cast<ResultType>(pow(fabs(x - y), params.power));
} }
template <typename ResultType> template <typename ResultType>
@ -174,7 +200,7 @@ struct LpDistance
template <typename ResultType> template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams & params) static ResultType finalize(const State<ResultType> & state, const ConstParams & params)
{ {
return static_cast<ResultType>(std::pow(state.sum, params.inverted_power)); return static_cast<ResultType>(pow(state.sum, params.inverted_power));
} }
}; };
@ -187,7 +213,7 @@ struct LinfDistance
template <typename FloatType> template <typename FloatType>
struct State struct State
{ {
FloatType dist = 0; FloatType dist{};
}; };
template <typename ResultType> template <typename ResultType>
@ -218,9 +244,9 @@ struct CosineDistance
template <typename FloatType> template <typename FloatType>
struct State struct State
{ {
FloatType dot_prod = 0; FloatType dot_prod{};
FloatType x_squared = 0; FloatType x_squared{};
FloatType y_squared = 0; FloatType y_squared{};
}; };
template <typename ResultType> template <typename ResultType>
@ -249,13 +275,11 @@ struct CosineDistance
size_t & i_y, size_t & i_y,
State<ResultType> & state) State<ResultType> & state)
{ {
static constexpr bool is_float32 = std::is_same_v<ResultType, Float32>;
__m512 dot_products; __m512 dot_products;
__m512 x_squareds; __m512 x_squareds;
__m512 y_squareds; __m512 y_squareds;
if constexpr (is_float32) if constexpr (sizeof(ResultType) <= 4)
{ {
dot_products = _mm512_setzero_ps(); dot_products = _mm512_setzero_ps();
x_squareds = _mm512_setzero_ps(); x_squareds = _mm512_setzero_ps();
@ -268,11 +292,11 @@ struct CosineDistance
y_squareds = _mm512_setzero_pd(); y_squareds = _mm512_setzero_pd();
} }
constexpr size_t n = is_float32 ? 16 : 8; constexpr size_t n = sizeof(__m512) / sizeof(ResultType);
for (; i_x + n < i_max; i_x += n, i_y += n) for (; i_x + n < i_max; i_x += n, i_y += n)
{ {
if constexpr (is_float32) if constexpr (sizeof(ResultType) == 4)
{ {
__m512 x = _mm512_loadu_ps(data_x + i_x); __m512 x = _mm512_loadu_ps(data_x + i_x);
__m512 y = _mm512_loadu_ps(data_y + i_y); __m512 y = _mm512_loadu_ps(data_y + i_y);
@ -290,7 +314,7 @@ struct CosineDistance
} }
} }
if constexpr (is_float32) if constexpr (sizeof(ResultType) == 4)
{ {
state.dot_prod = _mm512_reduce_add_ps(dot_products); state.dot_prod = _mm512_reduce_add_ps(dot_products);
state.x_squared = _mm512_reduce_add_ps(x_squareds); state.x_squared = _mm512_reduce_add_ps(x_squareds);
@ -303,16 +327,48 @@ struct CosineDistance
state.y_squared = _mm512_reduce_add_pd(y_squareds); state.y_squared = _mm512_reduce_add_pd(y_squareds);
} }
} }
AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
const BFloat16 * __restrict data_x,
const BFloat16 * __restrict data_y,
size_t i_max,
size_t & i_x,
size_t & i_y,
State<Float32> & state)
{
__m512 dot_products;
__m512 x_squareds;
__m512 y_squareds;
dot_products = _mm512_setzero_ps();
x_squareds = _mm512_setzero_ps();
y_squareds = _mm512_setzero_ps();
constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
for (; i_x + n < i_max; i_x += n, i_y += n)
{
__m512 x = _mm512_loadu_ps(data_x + i_x);
__m512 y = _mm512_loadu_ps(data_y + i_y);
dot_products = _mm512_dpbf16_ps(dot_products, x, y);
x_squareds = _mm512_dpbf16_ps(x_squareds, x, x);
y_squareds = _mm512_dpbf16_ps(y_squareds, y, y);
}
state.dot_prod = _mm512_reduce_add_ps(dot_products);
state.x_squared = _mm512_reduce_add_ps(x_squareds);
state.y_squared = _mm512_reduce_add_ps(y_squareds);
}
#endif #endif
template <typename ResultType> template <typename ResultType>
static ResultType finalize(const State<ResultType> & state, const ConstParams &) static ResultType finalize(const State<ResultType> & state, const ConstParams &)
{ {
return 1 - state.dot_prod / sqrt(state.x_squared * state.y_squared); return 1.0f - state.dot_prod / sqrt(state.x_squared * state.y_squared);
} }
}; };
template <class Kernel> template <typename Kernel>
class FunctionArrayDistance : public IFunction class FunctionArrayDistance : public IFunction
{ {
public: public:
@ -352,12 +408,13 @@ public:
case TypeIndex::Float64: case TypeIndex::Float64:
return std::make_shared<DataTypeFloat64>(); return std::make_shared<DataTypeFloat64>();
case TypeIndex::Float32: case TypeIndex::Float32:
case TypeIndex::BFloat16:
return std::make_shared<DataTypeFloat32>(); return std::make_shared<DataTypeFloat32>();
default: default:
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Arguments of function {} has nested type {}. " "Arguments of function {} has nested type {}. "
"Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.",
getName(), getName(),
common_type->getName()); common_type->getName());
} }
@ -369,10 +426,8 @@ public:
{ {
case TypeIndex::Float32: case TypeIndex::Float32:
return executeWithResultType<Float32>(arguments, input_rows_count); return executeWithResultType<Float32>(arguments, input_rows_count);
break;
case TypeIndex::Float64: case TypeIndex::Float64:
return executeWithResultType<Float64>(arguments, input_rows_count); return executeWithResultType<Float64>(arguments, input_rows_count);
break;
default: default:
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName()); throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName());
} }
@ -388,6 +443,7 @@ public:
ACTION(Int16) \ ACTION(Int16) \
ACTION(Int32) \ ACTION(Int32) \
ACTION(Int64) \ ACTION(Int64) \
ACTION(BFloat16) \
ACTION(Float32) \ ACTION(Float32) \
ACTION(Float64) ACTION(Float64)
@ -412,7 +468,7 @@ private:
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Arguments of function {} has nested type {}. " "Arguments of function {} has nested type {}. "
"Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.",
getName(), getName(),
type_x->getName()); type_x->getName());
} }
@ -437,7 +493,7 @@ private:
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Arguments of function {} has nested type {}. " "Arguments of function {} has nested type {}. "
"Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, BFloat16, Float32, Float64.",
getName(), getName(),
type_y->getName()); type_y->getName());
} }
@ -446,14 +502,10 @@ private:
template <typename ResultType, typename LeftType, typename RightType> template <typename ResultType, typename LeftType, typename RightType>
ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
{ {
if (typeid_cast<const ColumnConst *>(col_x.get())) if (col_x->isConst())
{
return executeWithLeftArgConst<ResultType, LeftType, RightType>(col_x, col_y, input_rows_count, arguments); return executeWithLeftArgConst<ResultType, LeftType, RightType>(col_x, col_y, input_rows_count, arguments);
} if (col_y->isConst())
if (typeid_cast<const ColumnConst *>(col_y.get()))
{
return executeWithLeftArgConst<ResultType, RightType, LeftType>(col_y, col_x, input_rows_count, arguments); return executeWithLeftArgConst<ResultType, RightType, LeftType>(col_y, col_x, input_rows_count, arguments);
}
const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get()); const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get());
const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get()); const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get());
@ -497,7 +549,7 @@ private:
state, static_cast<ResultType>(data_x[prev]), static_cast<ResultType>(data_y[prev]), kernel_params); state, static_cast<ResultType>(data_x[prev]), static_cast<ResultType>(data_y[prev]), kernel_params);
} }
result_data[row] = Kernel::finalize(state, kernel_params); result_data[row] = Kernel::finalize(state, kernel_params);
row++; ++row;
} }
return col_res; return col_res;
} }
@ -548,24 +600,39 @@ private:
/// SIMD optimization: process multiple elements in both input arrays at once. /// SIMD optimization: process multiple elements in both input arrays at once.
/// To avoid combinatorial explosion of SIMD kernels, focus on /// To avoid combinatorial explosion of SIMD kernels, focus on
/// - the two most common input/output types (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 instead of 10 x /// - the three most common input/output types (BFloat16 x BFloat16) --> Float32,
/// 10 input types x 2 output types, /// (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64
/// instead of 10 x 10 input types x 2 output types,
/// - const/non-const inputs instead of non-const/non-const inputs /// - const/non-const inputs instead of non-const/non-const inputs
/// - the two most common metrics L2 and cosine distance, /// - the two most common metrics L2 and cosine distance,
/// - the most powerful SIMD instruction set (AVX-512F). /// - the most powerful SIMD instruction set (AVX-512F).
bool processed = false;
#if USE_MULTITARGET_CODE #if USE_MULTITARGET_CODE
if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>) /// ResultType is Float32 or Float64 /// ResultType is Float32 or Float64
if constexpr (std::is_same_v<Kernel, L2Distance> || std::is_same_v<Kernel, CosineDistance>)
{ {
if constexpr (std::is_same_v<Kernel, L2Distance> if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
|| std::is_same_v<Kernel, CosineDistance>)
{ {
if (isArchSupported(TargetArch::AVX512F)) if (isArchSupported(TargetArch::AVX512F))
{
Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state); Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
processed = true;
} }
} }
#else else if constexpr (std::is_same_v<Float32, ResultType> && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
/// Process chunks in vectorized manner {
static constexpr size_t VEC_SIZE = 4; if (isArchSupported(TargetArch::AVX512BF16))
{
Kernel::accumulateCombineBF16(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
processed = true;
}
}
}
#endif
if (!processed)
{
/// Process chunks in a vectorized manner.
static constexpr size_t VEC_SIZE = 32;
typename Kernel::template State<ResultType> states[VEC_SIZE]; typename Kernel::template State<ResultType> states[VEC_SIZE];
for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE) for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE)
{ {
@ -576,8 +643,9 @@ private:
for (const auto & other_state : states) for (const auto & other_state : states)
Kernel::template combine<ResultType>(state, other_state, kernel_params); Kernel::template combine<ResultType>(state, other_state, kernel_params);
#endif }
/// Process the tail
/// Process the tail.
for (; prev < off; ++i, ++prev) for (; prev < off; ++i, ++prev)
{ {
Kernel::template accumulate<ResultType>( Kernel::template accumulate<ResultType>(
@ -638,4 +706,5 @@ FunctionPtr createFunctionArrayL2SquaredDistance(ContextPtr context_) { return F
FunctionPtr createFunctionArrayLpDistance(ContextPtr context_) { return FunctionArrayDistance<LpDistance>::create(context_); } FunctionPtr createFunctionArrayLpDistance(ContextPtr context_) { return FunctionArrayDistance<LpDistance>::create(context_); }
FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_) { return FunctionArrayDistance<LinfDistance>::create(context_); } FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_) { return FunctionArrayDistance<LinfDistance>::create(context_); }
FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_) { return FunctionArrayDistance<CosineDistance>::create(context_); } FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_) { return FunctionArrayDistance<CosineDistance>::create(context_); }
} }

View File

@ -452,9 +452,14 @@ private:
using ValueType = typename Types::RightType; using ValueType = typename Types::RightType;
static constexpr bool key_and_value_are_numbers = IsDataTypeNumber<KeyType> && IsDataTypeNumber<ValueType>; static constexpr bool key_and_value_are_numbers = IsDataTypeNumber<KeyType> && IsDataTypeNumber<ValueType>;
static constexpr bool key_is_float = std::is_same_v<KeyType, DataTypeFloat32> || std::is_same_v<KeyType, DataTypeFloat64>;
if constexpr (key_and_value_are_numbers && !key_is_float) if constexpr (key_and_value_are_numbers)
{
if constexpr (is_floating_point<typename KeyType::FieldType>)
{
return false;
}
else
{ {
using KeyFieldType = typename KeyType::FieldType; using KeyFieldType = typename KeyType::FieldType;
using ValueFieldType = typename ValueType::FieldType; using ValueFieldType = typename ValueType::FieldType;
@ -470,6 +475,7 @@ private:
return true; return true;
} }
}
return false; return false;
}; };

View File

@ -18,7 +18,7 @@ struct DivideFloatingImpl
template <typename Result = ResultType> template <typename Result = ResultType>
static NO_SANITIZE_UNDEFINED Result apply(A a [[maybe_unused]], B b [[maybe_unused]]) static NO_SANITIZE_UNDEFINED Result apply(A a [[maybe_unused]], B b [[maybe_unused]])
{ {
return static_cast<Result>(a) / b; return static_cast<Result>(a) / static_cast<Result>(b);
} }
#if USE_EMBEDDED_COMPILER #if USE_EMBEDDED_COMPILER

View File

@ -3,6 +3,12 @@
namespace DB namespace DB
{ {
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
}
namespace namespace
{ {
@ -20,9 +26,16 @@ namespace
template <typename T> template <typename T>
static void execute(const T * src, size_t size, T * dst) static void execute(const T * src, size_t size, T * dst)
{
if constexpr (std::is_same_v<T, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name);
}
else
{ {
NFastOps::Exp<true>(src, size, dst); NFastOps::Exp<true>(src, size, dst);
} }
}
}; };
} }

View File

@ -21,7 +21,7 @@ struct FactorialImpl
static NO_SANITIZE_UNDEFINED ResultType apply(A a) static NO_SANITIZE_UNDEFINED ResultType apply(A a)
{ {
if constexpr (std::is_floating_point_v<A> || is_over_big_int<A>) if constexpr (is_floating_point<A> || is_over_big_int<A>)
throw Exception( throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type of argument of function factorial, should not be floating point or big int"); "Illegal type of argument of function factorial, should not be floating point or big int");

Some files were not shown because too many files have changed in this diff Show More