This commit is contained in:
Alexey Milovidov 2024-11-10 13:53:08 +01:00
parent bec94da77e
commit f0dc1330eb

View File

@ -268,6 +268,19 @@ inline double roundWithMode(double x, RoundingMode mode)
std::unreachable();
}
inline BFloat16 roundWithMode(BFloat16 x, RoundingMode mode)
{
switch (mode)
{
case RoundingMode::Round: return BFloat16(nearbyintf(Float32(x)));
case RoundingMode::Floor: return BFloat16(floorf(Float32(x)));
case RoundingMode::Ceil: return BFloat16(ceilf(Float32(x)));
case RoundingMode::Trunc: return BFloat16(truncf(Float32(x)));
}
std::unreachable();
}
template <typename T>
class FloatRoundingComputationBase<T, Vectorize::No>
{
@ -289,6 +302,11 @@ public:
}
};
template <>
class FloatRoundingComputationBase<BFloat16, Vectorize::Yes> : public FloatRoundingComputationBase<BFloat16, Vectorize::No>
{
};
/** Implementation of low-level round-off functions for floating-point values.
*/
@ -688,20 +706,26 @@ public:
using Types = std::decay_t<decltype(types)>;
using DataType = typename Types::RightType;
if constexpr ((IsDataTypeNumber<DataType> || IsDataTypeDecimal<DataType>)
&& !std::is_same_v<DataType, DataTypeBFloat16>)
if (arguments.size() > 1)
{
if (arguments.size() > 1)
const ColumnWithTypeAndName & scale_column = arguments[1];
auto call_scale = [&](const auto & scaleTypes) -> bool
{
const ColumnWithTypeAndName & scale_column = arguments[1];
res = Dispatcher<DataType, rounding_mode, tie_breaking_mode>::template apply<int>(value_arg.column.get(), scale_column.column.get());
using ScaleTypes = std::decay_t<decltype(scaleTypes)>;
using ScaleType = typename ScaleTypes::RightType;
res = Dispatcher<DataType, rounding_mode, tie_breaking_mode>::template apply<ScaleType>(value_arg.column.get(), scale_column.column.get());
return true;
}
res = Dispatcher<DataType, rounding_mode, tie_breaking_mode>::template apply<int>(value_arg.column.get());
};
TypeIndex right_index = scale_column.type->getTypeId();
if (!callOnBasicType<void, true, false, false, false>(right_index, call_scale))
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Scale argument for rounding functions must have integer type");
return true;
}
else
return false;
res = Dispatcher<DataType, rounding_mode, tie_breaking_mode>::template apply<int>(value_arg.column.get());
return true;
};
#if !defined(__SSE4_1__)