This should be better

Alexey Milovidov 2024-11-13 23:33:09 +01:00
parent 78083130f1
commit 55f758fa29


@@ -15,30 +15,6 @@
 #endif

-namespace
-{
-inline BFloat16 fabs(BFloat16 x)
-{
-    return x.abs();
-}
-
-inline BFloat16 sqrt(BFloat16 x)
-{
-    return BFloat16(::sqrtf(Float32(x)));
-}
-
-template <typename T>
-inline BFloat16 pow(BFloat16 x, T p)
-{
-    return BFloat16(::powf(Float32(x), Float32(p)));
-}
-
-inline BFloat16 fmax(BFloat16 x, BFloat16 y)
-{
-    return BFloat16(::fmaxf(Float32(x), Float32(y)));
-}
-}
 namespace DB
 {

 namespace ErrorCodes
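This first hunk deletes the BFloat16 overloads of fabs, sqrt, pow, and fmax: after this commit the distance kernels accumulate BFloat16 inputs directly in Float32, so no BFloat16-returning math shims are needed. Each deleted shim followed the same widen-compute-narrow pattern. A self-contained illustration of that pattern (not ClickHouse code; BF16, toBF16, and toFloat are hypothetical stand-ins for the real BFloat16 type, and truncation is used where the real type may round):

#include <cstdint>
#include <cstring>
#include <cmath>
#include <cstdio>

// Stand-in for ClickHouse's BFloat16: the top 16 bits of an IEEE Float32.
struct BF16 { uint16_t bits; };

BF16 toBF16(float f) // truncating narrow
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return BF16{static_cast<uint16_t>(u >> 16)};
}

float toFloat(BF16 b) // lossless widen
{
    uint32_t u = static_cast<uint32_t>(b.bits) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main()
{
    // The deleted sqrt shim was equivalent to: widen, compute in Float32, narrow.
    BF16 x = toBF16(2.0f);
    BF16 r = toBF16(std::sqrt(toFloat(x)));
    std::printf("%g\n", toFloat(r)); // ~1.414, at BFloat16 precision
}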
@@ -107,7 +83,7 @@ struct L2Distance

 #if USE_MULTITARGET_CODE
     template <typename ResultType>
-    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
+    AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
         const ResultType * __restrict data_x,
         const ResultType * __restrict data_y,
         size_t i_max,
@@ -125,19 +101,7 @@ struct L2Distance

         for (; i_x + n < i_max; i_x += n, i_y += n)
         {
-            if constexpr (sizeof(ResultType) == 2)
-            {
-                __m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x)));
-                __m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x + n / 2)));
-                __m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y)));
-                __m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y + n / 2)));
-                __m512 differences_1 = _mm512_sub_ps(x_1, y_1);
-                __m512 differences_2 = _mm512_sub_ps(x_2, y_2);
-                sums = _mm512_fmadd_ps(differences_1, differences_1, sums);
-                sums = _mm512_fmadd_ps(differences_2, differences_2, sums);
-            }
-            else if constexpr (sizeof(ResultType) == 4)
+            if constexpr (sizeof(ResultType) == 4)
             {
                 __m512 x = _mm512_loadu_ps(data_x + i_x);
                 __m512 y = _mm512_loadu_ps(data_y + i_y);
@@ -158,12 +122,39 @@ struct L2Distance
         else
             state.sum = _mm512_reduce_add_pd(sums);
     }

+    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
+        const BFloat16 * __restrict data_x,
+        const BFloat16 * __restrict data_y,
+        size_t i_max,
+        size_t & i_x,
+        size_t & i_y,
+        State<Float32> & state)
+    {
+        __m512 sums = _mm512_setzero_ps();
+
+        constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
+
+        for (; i_x + n < i_max; i_x += n, i_y += n)
+        {
+            __m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x)));
+            __m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x + n / 2)));
+            __m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y)));
+            __m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y + n / 2)));
+            __m512 differences_1 = _mm512_sub_ps(x_1, y_1);
+            __m512 differences_2 = _mm512_sub_ps(x_2, y_2);
+            sums = _mm512_fmadd_ps(differences_1, differences_1, sums);
+            sums = _mm512_fmadd_ps(differences_2, differences_2, sums);
+        }
+
+        state.sum = _mm512_reduce_add_ps(sums);
+    }
+
 #endif

     template <typename ResultType>
     static ResultType finalize(const State<ResultType> & state, const ConstParams &)
     {
-        return sqrt(ResultType(state.sum));
+        return sqrt(state.sum);
     }
 };
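The new accumulateCombineBF16 widens 16 BFloat16 lanes at a time to Float32 with _mm512_cvtpbh_ps, processing each 32-element chunk as two 16-lane halves, and accumulates squared differences with FMA. A scalar sketch of what the loop computes per element (hypothetical helper, assuming inputs already widened to Float32; not the committed code):

#include <cstddef>

// Scalar reference for the L2 accumulateCombineBF16 loop above: squared
// differences accumulated in Float32, matching _mm512_sub_ps + _mm512_fmadd_ps.
float l2SquaredReference(const float * x, const float * y, size_t n)
{
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i)
    {
        float diff = x[i] - y[i];
        sum += diff * diff; // fused multiply-add in the SIMD version
    }
    return sum; // finalize() then takes sqrt(sum)
}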
@@ -276,7 +267,7 @@ struct CosineDistance

 #if USE_MULTITARGET_CODE
     template <typename ResultType>
-    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
+    AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
         const ResultType * __restrict data_x,
         const ResultType * __restrict data_y,
         size_t i_max,
@@ -305,14 +296,6 @@ struct CosineDistance
         for (; i_x + n < i_max; i_x += n, i_y += n)
         {
-            if constexpr (sizeof(ResultType) == 2)
-            {
-                __m512 x = _mm512_loadu_ps(data_x + i_x);
-                __m512 y = _mm512_loadu_ps(data_y + i_y);
-                dot_products = _mm512_dpbf16_ps(dot_products, x, y);
-                x_squareds = _mm512_dpbf16_ps(x_squareds, x, x);
-                y_squareds = _mm512_dpbf16_ps(y_squareds, y, y);
-            }
             if constexpr (sizeof(ResultType) == 4)
             {
                 __m512 x = _mm512_loadu_ps(data_x + i_x);
@@ -331,7 +314,7 @@ struct CosineDistance
             }
         }

-        if constexpr (sizeof(ResultType) == 2 || sizeof(ResultType) == 4)
+        if constexpr (sizeof(ResultType) == 4)
         {
             state.dot_prod = _mm512_reduce_add_ps(dot_products);
             state.x_squared = _mm512_reduce_add_ps(x_squareds);
@@ -344,16 +327,48 @@ struct CosineDistance
             state.y_squared = _mm512_reduce_add_pd(y_squareds);
         }
     }

+    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
+        const BFloat16 * __restrict data_x,
+        const BFloat16 * __restrict data_y,
+        size_t i_max,
+        size_t & i_x,
+        size_t & i_y,
+        State<Float32> & state)
+    {
+        __m512 dot_products;
+        __m512 x_squareds;
+        __m512 y_squareds;
+
+        dot_products = _mm512_setzero_ps();
+        x_squareds = _mm512_setzero_ps();
+        y_squareds = _mm512_setzero_ps();
+
+        constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
+
+        for (; i_x + n < i_max; i_x += n, i_y += n)
+        {
+            __m512 x = _mm512_loadu_ps(data_x + i_x);
+            __m512 y = _mm512_loadu_ps(data_y + i_y);
+            dot_products = _mm512_dpbf16_ps(dot_products, x, y);
+            x_squareds = _mm512_dpbf16_ps(x_squareds, x, x);
+            y_squareds = _mm512_dpbf16_ps(y_squareds, y, y);
+        }
+
+        state.dot_prod = _mm512_reduce_add_ps(dot_products);
+        state.x_squared = _mm512_reduce_add_ps(x_squareds);
+        state.y_squared = _mm512_reduce_add_ps(y_squareds);
+    }
+
 #endif

     template <typename ResultType>
     static ResultType finalize(const State<ResultType> & state, const ConstParams &)
     {
-        return ResultType(1) - state.dot_prod / sqrt(state.x_squared * state.y_squared);
+        return 1.0f - state.dot_prod / sqrt(state.x_squared * state.y_squared);
     }
 };

-template <class Kernel>
+template <typename Kernel>
 class FunctionArrayDistance : public IFunction
 {
 public:
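For cosine distance the BF16 kernel leans on _mm512_dpbf16_ps, which multiplies adjacent BFloat16 pairs and accumulates the products into Float32 lanes, so the dot product and both squared norms are built in Float32 and combined in finalize(). A scalar sketch of the accumulated quantity (hypothetical helper, assuming inputs already widened to Float32; not the committed code):

#include <cstddef>
#include <cmath>

// Scalar reference for the cosine kernel above: the three _mm512_dpbf16_ps
// accumulators correspond to dot, xx and yy, all kept in Float32.
float cosineDistanceReference(const float * x, const float * y, size_t n)
{
    float dot = 0.0f;
    float xx = 0.0f;
    float yy = 0.0f;
    for (size_t i = 0; i < n; ++i)
    {
        dot += x[i] * y[i];
        xx += x[i] * x[i];
        yy += y[i] * y[i];
    }
    return 1.0f - dot / std::sqrt(xx * yy); // matches finalize()
}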
@@ -393,9 +408,8 @@ public:
             case TypeIndex::Float64:
                 return std::make_shared<DataTypeFloat64>();
             case TypeIndex::Float32:
-                return std::make_shared<DataTypeFloat32>();
             case TypeIndex::BFloat16:
-                return std::make_shared<DataTypeBFloat16>();
+                return std::make_shared<DataTypeFloat32>();
             default:
                 throw Exception(
                     ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
@@ -410,15 +424,10 @@ public:
     {
         switch (result_type->getTypeId())
         {
-            case TypeIndex::BFloat16:
-                return executeWithResultType<BFloat16>(arguments, input_rows_count);
-                break;
             case TypeIndex::Float32:
                 return executeWithResultType<Float32>(arguments, input_rows_count);
-                break;
             case TypeIndex::Float64:
                 return executeWithResultType<Float64>(arguments, input_rows_count);
-                break;
             default:
                 throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName());
         }
@@ -493,14 +502,10 @@ private:
     template <typename ResultType, typename LeftType, typename RightType>
     ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
     {
-        if (typeid_cast<const ColumnConst *>(col_x.get()))
-        {
+        if (col_x->isConst())
             return executeWithLeftArgConst<ResultType, LeftType, RightType>(col_x, col_y, input_rows_count, arguments);
-        }
-        if (typeid_cast<const ColumnConst *>(col_y.get()))
-        {
+        if (col_y->isConst())
             return executeWithLeftArgConst<ResultType, RightType, LeftType>(col_y, col_x, input_rows_count, arguments);
-        }

         const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get());
         const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get());
@@ -544,7 +549,7 @@ private:
                     state, static_cast<ResultType>(data_x[prev]), static_cast<ResultType>(data_y[prev]), kernel_params);
             }
             result_data[row] = Kernel::finalize(state, kernel_params);
-            row++;
+            ++row;
         }
         return col_res;
     }
@@ -595,22 +600,26 @@

         /// SIMD optimization: process multiple elements in both input arrays at once.
         /// To avoid combinatorial explosion of SIMD kernels, focus on
-        /// - the three most common input/output types (BFloat16 x BFloat16) --> BFloat16,
+        /// - the three most common input/output types (BFloat16 x BFloat16) --> Float32,
         ///   (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64
         ///   instead of 10 x 10 input types x 2 output types,
         /// - const/non-const inputs instead of non-const/non-const inputs,
         /// - the two most common metrics L2 and cosine distance,
         /// - the most powerful SIMD instruction set (AVX-512F).
 #if USE_MULTITARGET_CODE
-        /// ResultType is BFloat16, Float32 or Float64
-        if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
+        /// ResultType is Float32 or Float64
+        if constexpr (std::is_same_v<Kernel, L2Distance> || std::is_same_v<Kernel, CosineDistance>)
         {
-            if constexpr (std::is_same_v<Kernel, L2Distance>
-                || std::is_same_v<Kernel, CosineDistance>)
+            if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
             {
                 if (isArchSupported(TargetArch::AVX512F))
                     Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
             }
+            else if constexpr (std::is_same_v<Float32, ResultType> && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
+            {
+                if (isArchSupported(TargetArch::AVX512BF16))
+                    Kernel::accumulateCombineBF16(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
+            }
         }
 #else
         /// Process chunks in vectorized manner
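The dispatch in this last hunk is two-level: type combinations are selected at compile time with if constexpr (so only valid kernels are instantiated), while the instruction-set check happens at run time, letting one binary carry both the AVX-512F and AVX-512BF16 kernels and pick per CPU. A minimal sketch of the pattern (names mirror the hunk but bodies are elided and the declarations are hypothetical stand-ins, not the committed code):

#include <type_traits>

// Stand-ins; in ClickHouse, isArchSupported() is backed by runtime CPUID checks.
enum class TargetArch { AVX512F, AVX512BF16 };
bool isArchSupported(TargetArch);

struct BFloat16; // stand-in for the real type
using Float32 = float;

template <typename Kernel, typename ResultType, typename LeftType, typename RightType>
void dispatchSketch()
{
    // Compile time: prune type combinations; run time: guard the ISA,
    // so AVX-512(BF16) instructions never execute on CPUs that lack them.
    if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
    {
        if (isArchSupported(TargetArch::AVX512F))
        {
            // Kernel::template accumulateCombine<ResultType>(...);
        }
    }
    else if constexpr (std::is_same_v<Float32, ResultType>
        && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
    {
        if (isArchSupported(TargetArch::AVX512BF16))
        {
            // Kernel::accumulateCombineBF16(...);
        }
    }
}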