Preparation

This commit is contained in:
Alexey Milovidov 2024-06-02 20:43:02 +02:00
parent 6e08f415c4
commit bf2a8f6a7f
19 changed files with 110 additions and 292 deletions

View File

@ -1,9 +1,22 @@
#pragma once
#include <base/bit_cast.h>
using BFloat16 = __bf16;
namespace std
{
inline constexpr bool isfinite(BFloat16) { return true; }
inline constexpr bool signbit(BFloat16) { return false; }
inline constexpr bool isfinite(BFloat16 x) { return (bit_cast<UInt16>(x) & 0b0111111110000000) != 0b0111111110000000; }
inline constexpr bool signbit(BFloat16 x) { return bit_cast<UInt16>(x) & 0b1000000000000000; }
}
/// Widens a BFloat16 to Float32: the 16 bits of `x` are placed into the upper half of a
/// 32-bit pattern (the lower 16 bits are zero) and that pattern is reinterpreted as Float32.
inline Float32 BFloat16ToFloat32(BFloat16 x)
{
    const UInt16 narrow_bits = bit_cast<UInt16>(x);
    const UInt32 wide_bits = static_cast<UInt32>(narrow_bits) << 16;
    return bit_cast<Float32>(wide_bits);
}
/// Narrows a Float32 to BFloat16 by keeping only the upper 16 bits of its bit pattern.
/// The lower 16 bits are dropped by the shift (truncation, no rounding to nearest).
inline BFloat16 Float32ToBFloat16(Float32 x)
{
    /// Narrow to UInt16 explicitly so that the source and the destination of the final bit_cast
    /// have the same size, instead of relying on how bit_cast treats a 4-byte to 2-byte cast.
    const UInt16 upper_bits = static_cast<UInt16>(bit_cast<UInt32>(x) >> 16);
    return bit_cast<BFloat16>(upper_bits);
}

View File

@ -193,12 +193,11 @@ struct AggregateFunctionSumData
Impl::add(sum, local_sum);
return;
}
else if constexpr (is_floating_point<T>)
else if constexpr (is_floating_point<T> && (sizeof(Value) == 4 || sizeof(Value) == 8))
{
/// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned
/// For floating point we use a similar trick as above, except that now we reinterpret the floating point number as an unsigned
/// integer of the same size and use a mask instead (0 to discard, 0xFF..FF to keep)
static_assert(sizeof(Value) == 4 || sizeof(Value) == 8);
using equivalent_integer = typename std::conditional_t<sizeof(Value) == 4, UInt32, UInt64>;
using EquivalentInteger = typename std::conditional_t<sizeof(Value) == 4, UInt32, UInt64>;
constexpr size_t unroll_count = 128 / sizeof(T);
T partial_sums[unroll_count]{};
@ -209,11 +208,11 @@ struct AggregateFunctionSumData
{
for (size_t i = 0; i < unroll_count; ++i)
{
equivalent_integer value;
std::memcpy(&value, &ptr[i], sizeof(Value));
EquivalentInteger value;
memcpy(&value, &ptr[i], sizeof(Value));
value &= (!condition_map[i] != add_if_zero) - 1;
Value d;
std::memcpy(&d, &value, sizeof(Value));
memcpy(&d, &value, sizeof(Value));
Impl::add(partial_sums[i], d);
}
ptr += unroll_count;

View File

@ -257,7 +257,7 @@ template <typename T> struct AggregateFunctionUniqTraits
{
static UInt64 hash(T x)
{
if constexpr (std::is_same_v<T, Float32> || std::is_same_v<T, Float64>)
if constexpr (is_floating_point<T>)
{
return bit_cast<UInt64>(x);
}

View File

@ -17,6 +17,7 @@ class DataTypeNumber;
namespace ErrorCodes
{
extern const int NOT_IMPLEMENTED;
extern const int DECIMAL_OVERFLOW;
extern const int ARGUMENT_OUT_OF_BOUND;
}
@ -310,7 +311,14 @@ ReturnType convertToImpl(const DecimalType & decimal, UInt32 scale, To & result)
using DecimalNativeType = typename DecimalType::NativeType;
static constexpr bool throw_exception = std::is_void_v<ReturnType>;
if constexpr (is_floating_point<To>)
if constexpr (std::is_same_v<To, BFloat16>)
{
if constexpr (throw_exception)
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from Decimal to BFloat16 is not implemented");
else
return ReturnType(false);
}
else if constexpr (is_floating_point<To>)
{
result = static_cast<To>(decimal.value) / static_cast<To>(scaleMultiplier<DecimalNativeType>(scale));
}

View File

@ -1,149 +0,0 @@
#include "iostream_debug_helpers.h"
#include <iostream>
#include <Client/Connection.h>
#include <Core/Block.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Core/Field.h>
#include <Core/NamesAndTypes.h>
#include <DataTypes/IDataType.h>
#include <Functions/IFunction.h>
#include <IO/WriteBufferFromOStream.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
#include <Parsers/IAST.h>
#include <Storages/IStorage.h>
#include <Common/COW.h>
#include <Common/FieldVisitorDump.h>
namespace DB
{
/// Prints a Field by streaming the result of applying FieldVisitorDump to it.
template <>
std::ostream & operator<< <Field>(std::ostream & stream, const Field & what)
{
    const auto dumped = applyVisitor(FieldVisitorDump(), what);
    stream << dumped;
    return stream;
}
/// Prints a NameAndTypePair as "NameAndTypePair(name = <name>, type = <type>)".
std::ostream & operator<<(std::ostream & stream, const NameAndTypePair & what)
{
    stream << "NameAndTypePair(name = " << what.name;
    stream << ", type = " << what.type;
    stream << ")";
    return stream;
}
/// Prints an IDataType as "IDataType(name = <getName()>, default = <getDefault()>)".
std::ostream & operator<<(std::ostream & stream, const IDataType & what)
{
    stream << "IDataType(name = " << what.getName();
    stream << ", default = " << what.getDefault();
    stream << ")";
    return stream;
}
/// Prints an IStorage as "IStorage(name = <name>, tableName = <table>) {<physical columns>}".
/// The column list is the string form of getAllPhysical() taken from the in-memory metadata.
std::ostream & operator<<(std::ostream & stream, const IStorage & what)
{
    const auto storage_id = what.getStorageID();

    stream << "IStorage(name = " << what.getName();
    stream << ", tableName = " << storage_id.table_name << ") {";
    stream << what.getInMemoryMetadataPtr()->getColumns().getAllPhysical().toString() << "}";
    return stream;
}
/// Prints a fixed placeholder for a TableLockHolder; the holder itself is not inspected.
std::ostream & operator<<(std::ostream & stream, const TableLockHolder & /*holder*/)
{
    return stream << "TableStructureReadLock()";
}
/// Prints an IFunctionOverloadResolver as "IFunction(name = ..., variadic = ..., args = ...)".
std::ostream & operator<<(std::ostream & stream, const IFunctionOverloadResolver & what)
{
    stream << "IFunction(name = " << what.getName();
    stream << ", variadic = " << what.isVariadic();
    stream << ", args = " << what.getNumberOfArguments();
    stream << ")";
    return stream;
}
/// Prints a Block as "Block(num_columns = <columns()>){<dumpStructure()>}".
std::ostream & operator<<(std::ostream & stream, const Block & what)
{
    stream << "Block(" << "num_columns = " << what.columns();
    stream << "){" << what.dumpStructure();
    stream << "}";
    return stream;
}
/// Prints a ColumnWithTypeAndName as "ColumnWithTypeAndName(name = ..., type = ..., column = ...)".
/// The type is printed through the dereferenced pointer; the column goes through dumpValue.
std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what)
{
    stream << "ColumnWithTypeAndName(name = " << what.name;
    stream << ", type = " << *what.type;
    stream << ", column = ";
    return dumpValue(stream, what.column) << ")";
}
/// Prints an IColumn as "IColumn(<dumpStructure()>){v0, v1, ...}", where every element
/// is rendered by applying FieldVisitorDump to what[i].
std::ostream & operator<<(std::ostream & stream, const IColumn & what)
{
    stream << "IColumn(" << what.dumpStructure() << ")" << "{";

    for (size_t row = 0; row < what.size(); ++row)
    {
        /// Separator goes before every element except the first one.
        if (row != 0)
            stream << ", ";

        stream << applyVisitor(FieldVisitorDump(), what[row]);
    }

    stream << "}";
    return stream;
}
/// Prints a Packet as "Packet(type = <type>[, exception = <pointer>]) {<block>}".
/// The exception part is printed only when `what.exception` is set; what is streamed
/// is the pointer returned by `what.exception.get()`.
std::ostream & operator<<(std::ostream & stream, const Packet & what)
{
    stream << "Packet("
           << "type = " << what.type;
    // types description: Core/Protocol.h

    /// The leading ", " keeps this field from being fused with the preceding type value.
    if (what.exception)
        stream << ", exception = " << what.exception.get();

    // TODO: profile_info
    stream << ") {" << what.block << "}";
    return stream;
}
/// Prints ExpressionActions as "ExpressionActions(<dumpActions()>)".
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what)
{
    stream << "ExpressionActions(";
    stream << what.dumpActions();
    stream << ")";
    return stream;
}
/// Prints a TreeRewriterResult as "SyntaxAnalyzerResult{storage=...; <non-empty members>; rewrite_subqueries=...; }".
/// Container members are printed (via dumpValue) only when they are not empty.
/// Note that the printed label is "SyntaxAnalyzerResult", not the name of the type.
std::ostream & operator<<(std::ostream & stream, const TreeRewriterResult & what)
{
    /// Emits "<label><dump of container>; " unless the container is empty.
    auto dump_if_not_empty = [&stream](const char * label, const auto & container)
    {
        if (container.empty())
            return;

        stream << label;
        dumpValue(stream, container);
        stream << "; ";
    };

    stream << "SyntaxAnalyzerResult{";
    stream << "storage=" << what.storage << "; ";

    dump_if_not_empty("source_columns=", what.source_columns);
    dump_if_not_empty("aliases=", what.aliases);
    dump_if_not_empty("array_join_result_to_source=", what.array_join_result_to_source);
    dump_if_not_empty("array_join_alias_to_name=", what.array_join_alias_to_name);
    dump_if_not_empty("array_join_name_to_alias=", what.array_join_name_to_alias);

    stream << "rewrite_subqueries=" << what.rewrite_subqueries << "; ";
    stream << "}";
    return stream;
}
}

View File

@ -1,49 +0,0 @@
#pragma once
#include <iostream>
/// Declarations of operator<< overloads that print several DB types to a std::ostream.
/// Every type is only forward-declared here, so this header does not pull in their definitions.
namespace DB
{
// Use template to disable implicit casting for certain overloaded types such as Field, which leads
// to overload resolution ambiguity.
/// The requires-clause below restricts the template to T being exactly Field.
class Field;
template <typename T>
requires std::is_same_v<T, Field>
std::ostream & operator<<(std::ostream & stream, const T & what);
struct NameAndTypePair;
std::ostream & operator<<(std::ostream & stream, const NameAndTypePair & what);
class IDataType;
std::ostream & operator<<(std::ostream & stream, const IDataType & what);
class IStorage;
std::ostream & operator<<(std::ostream & stream, const IStorage & what);
class IFunctionOverloadResolver;
std::ostream & operator<<(std::ostream & stream, const IFunctionOverloadResolver & what);
/// NOTE(review): no definition of the IFunctionBase overload appears in the accompanying
/// iostream_debug_helpers.cpp listing - confirm whether it is defined anywhere.
class IFunctionBase;
std::ostream & operator<<(std::ostream & stream, const IFunctionBase & what);
class Block;
std::ostream & operator<<(std::ostream & stream, const Block & what);
struct ColumnWithTypeAndName;
std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what);
class IColumn;
std::ostream & operator<<(std::ostream & stream, const IColumn & what);
struct Packet;
std::ostream & operator<<(std::ostream & stream, const Packet & what);
class ExpressionActions;
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what);
struct TreeRewriterResult;
std::ostream & operator<<(std::ostream & stream, const TreeRewriterResult & what);
}
/// some operator<< should be declared before operator<<(... std::shared_ptr<>)
#include <base/iostream_debug_helpers.h>

View File

@ -20,6 +20,7 @@ namespace ErrorCodes
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int DECIMAL_OVERFLOW;
extern const int NOT_IMPLEMENTED;
}
@ -262,15 +263,19 @@ FOR_EACH_ARITHMETIC_TYPE(INVOKE);
template <typename FromDataType, typename ToDataType, typename ReturnType>
requires (is_arithmetic_v<typename FromDataType::FieldType> && IsDataTypeDecimal<ToDataType>)
ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value, UInt32 scale, typename ToDataType::FieldType & result)
ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & /*value*/, UInt32 /*scale*/, typename ToDataType::FieldType & /*result*/)
{
using FromFieldType = typename FromDataType::FieldType;
/* using FromFieldType = typename FromDataType::FieldType;
using ToFieldType = typename ToDataType::FieldType;
using ToNativeType = typename ToFieldType::NativeType;
static constexpr bool throw_exception = std::is_same_v<ReturnType, void>;
if constexpr (is_floating_point<FromFieldType>)
if constexpr (std::is_same_v<typename FromDataType::FieldType, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Conversion from BFloat16 to Decimal is not implemented");
}
else if constexpr (is_floating_point<FromFieldType>)
{
if (!isFinite(value))
{
@ -302,7 +307,9 @@ ReturnType convertToDecimalImpl(const typename FromDataType::FieldType & value,
return ReturnType(convertDecimalsImpl<DataTypeDecimal<Decimal128>, ToDataType, ReturnType>(static_cast<Int128>(value), 0, scale, result));
else
return ReturnType(convertDecimalsImpl<DataTypeDecimal<Decimal64>, ToDataType, ReturnType>(static_cast<Int64>(value), 0, scale, result));
}
}*/
return ReturnType();
}
#define DISPATCH(FROM_DATA_TYPE, TO_DATA_TYPE) \

View File

@ -298,7 +298,8 @@ namespace impl
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;
if constexpr (IsDataTypeDecimalOrNumber<DataType> || IsDataTypeDateOrDateTime<DataType> || IsDataTypeEnum<DataType>)
if constexpr ((IsDataTypeDecimalOrNumber<DataType> || IsDataTypeDateOrDateTime<DataType> || IsDataTypeEnum<DataType>)
&& !std::is_same_v<DataType, DataTypeBFloat16>)
{
using ColumnType = typename DataType::ColumnType;
func(TypePair<ColumnType, void>());

View File

@ -579,7 +579,8 @@ public:
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::LeftType;
if constexpr (IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>)
if constexpr ((IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>)
&& !std::is_same_v<DataType, DataTypeBFloat16>)
{
using FieldType = typename DataType::FieldType;
res = Dispatcher<FieldType, rounding_mode, tie_breaking_mode>::apply(column.column.get(), scale_arg);

View File

@ -453,23 +453,29 @@ private:
using ValueType = typename Types::RightType;
static constexpr bool key_and_value_are_numbers = IsDataTypeNumber<KeyType> && IsDataTypeNumber<ValueType>;
static constexpr bool key_is_float = std::is_same_v<KeyType, DataTypeFloat32> || std::is_same_v<KeyType, DataTypeFloat64>;
if constexpr (key_and_value_are_numbers && !key_is_float)
if constexpr (key_and_value_are_numbers)
{
using KeyFieldType = typename KeyType::FieldType;
using ValueFieldType = typename ValueType::FieldType;
if constexpr (is_floating_point<typename KeyType::FieldType>)
{
return false;
}
else
{
using KeyFieldType = typename KeyType::FieldType;
using ValueFieldType = typename ValueType::FieldType;
executeImplTyped<KeyFieldType, ValueFieldType>(
input.key_column,
input.value_column,
input.offsets_column,
input.max_key_column,
std::move(result_columns.result_key_column),
std::move(result_columns.result_value_column),
std::move(result_columns.result_offset_column));
executeImplTyped<KeyFieldType, ValueFieldType>(
input.key_column,
input.value_column,
input.offsets_column,
input.max_key_column,
std::move(result_columns.result_key_column),
std::move(result_columns.result_value_column),
std::move(result_columns.result_offset_column));
return true;
return true;
}
}
return false;

View File

@ -21,7 +21,14 @@ namespace
template <typename T>
static void execute(const T * src, size_t size, T * dst)
{
NFastOps::Exp<true>(src, size, dst);
if constexpr (std::is_same_v<T, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name);
}
else
{
NFastOps::Exp<true>(src, size, dst);
}
}
};
}

View File

@ -20,7 +20,14 @@ struct LogName { static constexpr auto name = "log"; };
template <typename T>
static void execute(const T * src, size_t size, T * dst)
{
NFastOps::Log<true>(src, size, dst);
if constexpr (std::is_same_v<T, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name);
}
else
{
NFastOps::Log<true>(src, size, dst);
}
}
};

View File

@ -17,8 +17,8 @@ struct MinusImpl
{
if constexpr (is_big_int_v<A> || is_big_int_v<B>)
{
using CastA = std::conditional_t<floating_point<B>, B, A>;
using CastB = std::conditional_t<floating_point<A>, A, B>;
using CastA = std::conditional_t<is_floating_point<B>, B, A>;
using CastB = std::conditional_t<is_floating_point<A>, A, B>;
return static_cast<Result>(static_cast<CastA>(a)) - static_cast<Result>(static_cast<CastB>(b));
}

View File

@ -21,7 +21,14 @@ namespace
template <typename T>
static void execute(const T * src, size_t size, T * dst)
{
NFastOps::Sigmoid<>(src, size, dst);
if constexpr (std::is_same_v<T, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name);
}
else
{
NFastOps::Sigmoid<>(src, size, dst);
}
}
};
}
@ -47,4 +54,3 @@ REGISTER_FUNCTION(Sigmoid)
}
}

View File

@ -19,7 +19,14 @@ struct TanhName { static constexpr auto name = "tanh"; };
template <typename T>
static void execute(const T * src, size_t size, T * dst)
{
NFastOps::Tanh<>(src, size, dst);
if constexpr (std::is_same_v<T, BFloat16>)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Function `{}` is not implemented for BFloat16", name);
}
else
{
NFastOps::Tanh<>(src, size, dst);
}
}
};

View File

@ -155,7 +155,7 @@ inline size_t writeFloatTextFastPath(T x, char * buffer)
{
Int64 result = 0;
if constexpr (std::is_same_v<T, double>)
if constexpr (std::is_same_v<T, Float64>)
{
/// The library Ryu has low performance on integers.
/// This workaround improves performance 6..10 times.
@ -165,10 +165,16 @@ inline size_t writeFloatTextFastPath(T x, char * buffer)
else
result = jkj::dragonbox::to_chars_n(x, buffer) - buffer;
}
else
else if constexpr (std::is_same_v<T, Float32>)
{
/// This will support 16-bit floats as well.
float f32 = x;
if (DecomposedFloat32(x).isIntegerInRepresentableRange())
result = itoa(Int32(x), buffer) - buffer;
else
result = jkj::dragonbox::to_chars_n(x, buffer) - buffer;
}
else if constexpr (std::is_same_v<T, BFloat16>)
{
Float32 f32 = BFloat16ToFloat32(x);
if (DecomposedFloat32(f32).isIntegerInRepresentableRange())
result = itoa(Int32(f32), buffer) - buffer;

View File

@ -183,7 +183,7 @@ private:
if (sorted.load(std::memory_order_relaxed))
return;
if constexpr (std::is_arithmetic_v<TKey> && !std::is_floating_point<TKey>)
if constexpr (std::is_arithmetic_v<TKey> && !is_floating_point<TKey>)
{
if (likely(entries.size() > 256))
{

View File

@ -1,35 +0,0 @@
#include "iostream_debug_helpers.h"
#include <Parsers/IAST.h>
#include <Parsers/IParser.h>
#include <Parsers/Lexer.h>
#include <Parsers/TokenIterator.h>
#include <IO/WriteBufferFromOStream.h>
#include <IO/Operators.h>
namespace DB
{
/// Prints a Token as "Token (type=<numeric type>){<text>}", where the text is the
/// character range [what.begin, what.end).
std::ostream & operator<<(std::ostream & stream, const Token & what)
{
    stream << "Token (type=" << static_cast<int>(what.type);
    stream << "){" << std::string{what.begin, what.end};
    stream << "}";
    return stream;
}
/// Prints an Expected as "Expected {variants=<variants>; max_parsed_pos=<pos>}".
/// The variants are rendered by dumpValue; the rest is appended to whatever dumpValue returns.
std::ostream & operator<<(std::ostream & stream, const Expected & what)
{
    stream << "Expected {variants=";

    auto && after_variants = dumpValue(stream, what.variants);
    after_variants << "; max_parsed_pos=" << what.max_parsed_pos << "}";

    return stream;
}
/// Prints an IAST as "IAST{<tree dump>}". The text is not written to `stream` directly:
/// it goes through a WriteBufferFromOStream constructed over `stream`.
std::ostream & operator<<(std::ostream & stream, const IAST & what)
{
/// Second constructor argument is presumably the buffer size in bytes - confirm.
WriteBufferFromOStream buf(stream, 4096);
buf << "IAST{";
what.dumpTree(buf);
buf << "}";
/// NOTE(review): `buf` is not flushed explicitly here; presumably its contents reach `stream`
/// when `buf` goes out of scope - confirm.
return stream;
}
}

View File

@ -1,17 +0,0 @@
#pragma once
#include <iostream>
/// Declarations of operator<< overloads that print parser-related types to a std::ostream.
/// The types are only forward-declared here.
namespace DB
{
struct Token;
std::ostream & operator<<(std::ostream & stream, const Token & what);
struct Expected;
std::ostream & operator<<(std::ostream & stream, const Expected & what);
class IAST;
std::ostream & operator<<(std::ostream & stream, const IAST & what);
}
#include <Core/iostream_debug_helpers.h>