Mirror of https://github.com/ClickHouse/ClickHouse.git

Commit 55f758fa29 (parent 78083130f1): This should be better
@@ -15,30 +15,6 @@
 #endif
 
-namespace
-{
-inline BFloat16 fabs(BFloat16 x)
-{
-    return x.abs();
-}
-
-inline BFloat16 sqrt(BFloat16 x)
-{
-    return BFloat16(::sqrtf(Float32(x)));
-}
-
-template <typename T>
-inline BFloat16 pow(BFloat16 x, T p)
-{
-    return BFloat16(::powf(Float32(x), Float32(p)));
-}
-
-inline BFloat16 fmax(BFloat16 x, BFloat16 y)
-{
-    return BFloat16(::fmaxf(Float32(x), Float32(y)));
-}
-}
-
 namespace DB
 {
 namespace ErrorCodes
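Note: this commit reworks how ClickHouse's array distance kernels (arrayDistance.cpp) handle BFloat16: instead of treating BFloat16 as a possible result type, BFloat16 inputs now produce Float32 results, with dedicated AVX512BF16 kernels added below. The block removed above was a set of shims that made fabs/sqrt/pow/fmax work on BFloat16 by widening to Float32, computing there, and narrowing back; with BFloat16 gone from the set of result types, those shims appear to have no remaining callers. A minimal self-contained sketch of that widen-compute-narrow pattern (the BF16 type here is a toy stand-in, not ClickHouse's BFloat16 class):

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    struct BF16 /// toy bfloat16: keeps only the upper 16 bits of a Float32
    {
        uint16_t bits = 0;
        BF16() = default;
        explicit BF16(float f)
        {
            uint32_t u;
            std::memcpy(&u, &f, sizeof(u));
            bits = uint16_t(u >> 16); /// truncating conversion, good enough for a sketch
        }
        explicit operator float() const
        {
            uint32_t u = uint32_t(bits) << 16;
            float f;
            std::memcpy(&f, &u, sizeof(f));
            return f;
        }
    };

    /// Same shape as the removed sqrt() shim: compute in Float32, narrow back.
    inline BF16 bf16_sqrt(BF16 x) { return BF16(std::sqrt(float(x))); }

    /// And the removed pow() shim, templated on the exponent type.
    template <typename T>
    inline BF16 bf16_pow(BF16 x, T p) { return BF16(std::pow(float(x), float(p))); }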
@@ -107,7 +83,7 @@ struct L2Distance
 
 #if USE_MULTITARGET_CODE
     template <typename ResultType>
-    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
+    AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
         const ResultType * __restrict data_x,
         const ResultType * __restrict data_y,
         size_t i_max,
@@ -125,19 +101,7 @@ struct L2Distance
 
         for (; i_x + n < i_max; i_x += n, i_y += n)
         {
-            if constexpr (sizeof(ResultType) == 2)
-            {
-                __m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x)));
-                __m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x + n / 2)));
-                __m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y)));
-                __m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y + n / 2)));
-
-                __m512 differences_1 = _mm512_sub_ps(x_1, y_1);
-                __m512 differences_2 = _mm512_sub_ps(x_2, y_2);
-                sums = _mm512_fmadd_ps(differences_1, differences_1, sums);
-                sums = _mm512_fmadd_ps(differences_2, differences_2, sums);
-            }
-            else if constexpr (sizeof(ResultType) == 4)
+            if constexpr (sizeof(ResultType) == 4)
             {
                 __m512 x = _mm512_loadu_ps(data_x + i_x);
                 __m512 y = _mm512_loadu_ps(data_y + i_y);
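Note: sizeof(ResultType) == 2 in the removed branch identified BFloat16 (the only 2-byte result type), just as sizeof(ResultType) == 4 identifies Float32; Float64 falls through to the _pd code further down. The 2-byte branch leaves the generic kernel because BFloat16 stops being a ResultType, and the same loop body reappears below in a dedicated accumulateCombineBF16. A hedged illustration of that sizeof-based dispatch (T plays the role of ResultType):

    #include <cstdio>

    /// Stand-in for the old kernel's branch selection.
    template <typename T>
    const char * oldKernelBranch()
    {
        if constexpr (sizeof(T) == 2)
            return "bf16 widen-and-fma path";
        else if constexpr (sizeof(T) == 4)
            return "_ps (Float32) path";
        else
            return "_pd (Float64) path";
    }

    int main()
    {
        std::puts(oldKernelBranch<float>());  /// _ps (Float32) path
        std::puts(oldKernelBranch<double>()); /// _pd (Float64) path
    }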
@@ -158,12 +122,39 @@ struct L2Distance
         else
             state.sum = _mm512_reduce_add_pd(sums);
     }
+
+    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
+        const BFloat16 * __restrict data_x,
+        const BFloat16 * __restrict data_y,
+        size_t i_max,
+        size_t & i_x,
+        size_t & i_y,
+        State<Float32> & state)
+    {
+        __m512 sums = _mm512_setzero_ps();
+        constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
+
+        for (; i_x + n < i_max; i_x += n, i_y += n)
+        {
+            __m512 x_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x)));
+            __m512 x_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_x + i_x + n / 2)));
+            __m512 y_1 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y)));
+            __m512 y_2 = _mm512_cvtpbh_ps(_mm256_loadu_ps(reinterpret_cast<const Float32 *>(data_y + i_y + n / 2)));
+
+            __m512 differences_1 = _mm512_sub_ps(x_1, y_1);
+            __m512 differences_2 = _mm512_sub_ps(x_2, y_2);
+            sums = _mm512_fmadd_ps(differences_1, differences_1, sums);
+            sums = _mm512_fmadd_ps(differences_2, differences_2, sums);
+        }
+
+        state.sum = _mm512_reduce_add_ps(sums);
+    }
 #endif
 
     template <typename ResultType>
     static ResultType finalize(const State<ResultType> & state, const ConstParams &)
     {
-        return sqrt(ResultType(state.sum));
+        return sqrt(state.sum);
     }
 };
 
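Note: the new L2Distance::accumulateCombineBF16 processes n = sizeof(__m512) / sizeof(BFloat16) = 32 bfloat16 elements per iteration. Each _mm256_loadu_ps moves 256 bits, i.e. 16 bfloat16 values (the reinterpret_cast to const Float32 * exists only to satisfy the load intrinsic's pointer type), and _mm512_cvtpbh_ps widens those 16 lanes to Float32, so the two half-loads per operand cover all 32 elements. Accumulation stays in Float32 end to end, which is also why finalize() can drop the old narrowing cast and simply return sqrt(state.sum). A hedged scalar model of what one pass computes (bf16_to_float stands in for the per-lane behavior of _mm512_cvtpbh_ps):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static float bf16_to_float(uint16_t bits)
    {
        uint32_t u = uint32_t(bits) << 16; /// bf16 is the top half of an IEEE Float32
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    /// Scalar equivalent of the vector loop plus the final reduce_add:
    /// widen both operands, accumulate squared differences in Float32.
    static float l2_partial(const uint16_t * x, const uint16_t * y, size_t n)
    {
        float sum = 0.0f;
        for (size_t i = 0; i < n; ++i)
        {
            float d = bf16_to_float(x[i]) - bf16_to_float(y[i]);
            sum += d * d; /// mirrors _mm512_fmadd_ps(d, d, sums)
        }
        return sum; /// mirrors _mm512_reduce_add_ps(sums)
    }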
@@ -276,7 +267,7 @@ struct CosineDistance
 
 #if USE_MULTITARGET_CODE
     template <typename ResultType>
-    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
+    AVX512_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombine(
         const ResultType * __restrict data_x,
         const ResultType * __restrict data_y,
         size_t i_max,
@@ -305,14 +296,6 @@ struct CosineDistance
 
         for (; i_x + n < i_max; i_x += n, i_y += n)
        {
-            if constexpr (sizeof(ResultType) == 2)
-            {
-                __m512 x = _mm512_loadu_ps(data_x + i_x);
-                __m512 y = _mm512_loadu_ps(data_y + i_y);
-                dot_products = _mm512_dpbf16_ps(dot_products, x, y);
-                x_squareds = _mm512_dpbf16_ps(x_squareds, x, x);
-                y_squareds = _mm512_dpbf16_ps(y_squareds, y, y);
-            }
             if constexpr (sizeof(ResultType) == 4)
             {
                 __m512 x = _mm512_loadu_ps(data_x + i_x);
@@ -331,7 +314,7 @@ struct CosineDistance
             }
         }
 
-        if constexpr (sizeof(ResultType) == 2 || sizeof(ResultType) == 4)
+        if constexpr (sizeof(ResultType) == 4)
         {
             state.dot_prod = _mm512_reduce_add_ps(dot_products);
             state.x_squared = _mm512_reduce_add_ps(x_squareds);
@@ -344,16 +327,48 @@ struct CosineDistance
             state.y_squared = _mm512_reduce_add_pd(y_squareds);
         }
     }
+
+    AVX512BF16_FUNCTION_SPECIFIC_ATTRIBUTE static void accumulateCombineBF16(
+        const BFloat16 * __restrict data_x,
+        const BFloat16 * __restrict data_y,
+        size_t i_max,
+        size_t & i_x,
+        size_t & i_y,
+        State<Float32> & state)
+    {
+        __m512 dot_products;
+        __m512 x_squareds;
+        __m512 y_squareds;
+
+        dot_products = _mm512_setzero_ps();
+        x_squareds = _mm512_setzero_ps();
+        y_squareds = _mm512_setzero_ps();
+
+        constexpr size_t n = sizeof(__m512) / sizeof(BFloat16);
+
+        for (; i_x + n < i_max; i_x += n, i_y += n)
+        {
+            __m512 x = _mm512_loadu_ps(data_x + i_x);
+            __m512 y = _mm512_loadu_ps(data_y + i_y);
+            dot_products = _mm512_dpbf16_ps(dot_products, x, y);
+            x_squareds = _mm512_dpbf16_ps(x_squareds, x, x);
+            y_squareds = _mm512_dpbf16_ps(y_squareds, y, y);
+        }
+
+        state.dot_prod = _mm512_reduce_add_ps(dot_products);
+        state.x_squared = _mm512_reduce_add_ps(x_squareds);
+        state.y_squared = _mm512_reduce_add_ps(y_squareds);
+    }
 #endif
 
     template <typename ResultType>
     static ResultType finalize(const State<ResultType> & state, const ConstParams &)
     {
-        return ResultType(1) - state.dot_prod / sqrt(state.x_squared * state.y_squared);
+        return 1.0f - state.dot_prod / sqrt(state.x_squared * state.y_squared);
     }
 };
 
-template <class Kernel>
+template <typename Kernel>
 class FunctionArrayDistance : public IFunction
 {
 public:
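Note: the cosine kernel's BF16 path needs no explicit widening step because _mm512_dpbf16_ps fuses it in: the AVX512BF16 dot-product instruction multiplies adjacent pairs of bfloat16 lanes and accumulates the products into the corresponding Float32 lane, acc[i] += a[2i] * b[2i] + a[2i+1] * b[2i+1]. One 64-byte load thus feeds 32 bfloat16 inputs into 16 Float32 accumulators, and the three reductions hand finalize() plain Float32 values for 1 - dot / sqrt(x2 * y2), so the ResultType(1) cast can go. A hedged scalar model of the instruction as used above:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    static float bf16_to_float(uint16_t bits)
    {
        uint32_t u = uint32_t(bits) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }

    /// Scalar model of _mm512_dpbf16_ps: n_acc Float32 accumulators consume
    /// 2 * n_acc bfloat16 inputs per call (n_acc is 16 for a __m512).
    static void dpbf16_model(float * acc, const uint16_t * a, const uint16_t * b, size_t n_acc)
    {
        for (size_t i = 0; i < n_acc; ++i)
            acc[i] += bf16_to_float(a[2 * i]) * bf16_to_float(b[2 * i])
                + bf16_to_float(a[2 * i + 1]) * bf16_to_float(b[2 * i + 1]);
    }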
@@ -393,9 +408,8 @@ public:
             case TypeIndex::Float64:
                 return std::make_shared<DataTypeFloat64>();
             case TypeIndex::Float32:
                 return std::make_shared<DataTypeFloat32>();
             case TypeIndex::BFloat16:
-                return std::make_shared<DataTypeBFloat16>();
+                return std::make_shared<DataTypeFloat32>();
             default:
                 throw Exception(
                     ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
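Note: with this hunk, distance functions over BFloat16 arrays report Float32 instead of BFloat16. The kernels accumulate in Float32 in any case, so a BFloat16 return type would only narrow an already-computed Float32 value and lose precision. A hedged summary of the resulting mapping (the helper and enum below are hypothetical illustrations, not ClickHouse API):

    /// Hypothetical helper mirroring the switch above after this commit.
    enum class ElementType { BFloat16, Float32, Float64 };

    static ElementType distanceResultType(ElementType element)
    {
        switch (element)
        {
            case ElementType::Float64:  return ElementType::Float64;
            case ElementType::Float32:  return ElementType::Float32;
            case ElementType::BFloat16: return ElementType::Float32; /// widened, not narrowed
        }
        return ElementType::Float32; /// unreachable for valid input
    }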
@@ -410,15 +424,10 @@ public:
     {
         switch (result_type->getTypeId())
         {
-            case TypeIndex::BFloat16:
-                return executeWithResultType<BFloat16>(arguments, input_rows_count);
-                break;
             case TypeIndex::Float32:
                 return executeWithResultType<Float32>(arguments, input_rows_count);
-                break;
             case TypeIndex::Float64:
                 return executeWithResultType<Float64>(arguments, input_rows_count);
-                break;
             default:
                 throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName());
         }
@@ -493,14 +502,10 @@ private:
     template <typename ResultType, typename LeftType, typename RightType>
     ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
     {
-        if (typeid_cast<const ColumnConst *>(col_x.get()))
-        {
+        if (col_x->isConst())
             return executeWithLeftArgConst<ResultType, LeftType, RightType>(col_x, col_y, input_rows_count, arguments);
-        }
-        if (typeid_cast<const ColumnConst *>(col_y.get()))
-        {
+        if (col_y->isConst())
             return executeWithLeftArgConst<ResultType, RightType, LeftType>(col_y, col_x, input_rows_count, arguments);
-        }
 
         const auto & array_x = *assert_cast<const ColumnArray *>(col_x.get());
         const auto & array_y = *assert_cast<const ColumnArray *>(col_y.get());
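Note: IColumn::isConst() is a virtual predicate that holds exactly for ColumnConst, so the rewrite is behavior-preserving while stating intent directly (and the single-statement bodies let the braces go). A hedged toy model of the refactor, with dynamic_cast standing in for ClickHouse's typeid_cast and the class names as stand-ins:

    struct IColumn
    {
        virtual ~IColumn() = default;
        virtual bool isConst() const { return false; }
    };

    struct ColumnConst : IColumn
    {
        bool isConst() const override { return true; }
    };

    /// Old style: detect the wrapper via a cast to the concrete type.
    static bool wasConst(const IColumn * col)
    {
        return dynamic_cast<const ColumnConst *>(col) != nullptr;
    }

    /// New style: ask the column directly.
    static bool isConstNow(const IColumn & col)
    {
        return col.isConst();
    }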
@@ -544,7 +549,7 @@ private:
                     state, static_cast<ResultType>(data_x[prev]), static_cast<ResultType>(data_y[prev]), kernel_params);
             }
             result_data[row] = Kernel::finalize(state, kernel_params);
-            row++;
+            ++row;
         }
         return col_res;
     }
@@ -595,22 +600,26 @@ private:
 
         /// SIMD optimization: process multiple elements in both input arrays at once.
         /// To avoid combinatorial explosion of SIMD kernels, focus on
-        /// - the three most common input/output types (BFloat16 x BFloat16) --> BFloat16,
+        /// - the three most common input/output types (BFloat16 x BFloat16) --> Float32,
         ///   (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64
         ///   instead of 10 x 10 input types x 2 output types,
         /// - const/non-const inputs instead of non-const/non-const inputs
         /// - the two most common metrics L2 and cosine distance,
         /// - the most powerful SIMD instruction set (AVX-512F).
 #if USE_MULTITARGET_CODE
-        /// ResultType is BFloat16, Float32 or Float64
-        if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
+        /// ResultType is Float32 or Float64
+        if constexpr (std::is_same_v<Kernel, L2Distance> || std::is_same_v<Kernel, CosineDistance>)
         {
-            if constexpr (std::is_same_v<Kernel, L2Distance>
-                || std::is_same_v<Kernel, CosineDistance>)
+            if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
             {
                 if (isArchSupported(TargetArch::AVX512F))
                     Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
             }
+            else if constexpr (std::is_same_v<Float32, ResultType> && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
+            {
+                if (isArchSupported(TargetArch::AVX512BF16))
+                    Kernel::accumulateCombineBF16(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
+            }
         }
 #else
         /// Process chunks in vectorized manner
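Note: the reordered constexpr ladder first narrows to the two kernels that have SIMD specializations, then picks a fast path by element type: same-type Float32 or Float64 inputs go to the generic AVX-512F accumulateCombine, while (BFloat16 x BFloat16) --> Float32 goes to accumulateCombineBF16, each additionally gated on runtime CPU support via isArchSupported. Because the vector loops only run while i_x + n < i_max, the ordinary scalar loop always finishes the tail, and it also serves every other type combination. A hedged, self-contained model of the dispatch (kernel and type names below are stubs for the ClickHouse ones):

    #include <cstdio>
    #include <type_traits>

    struct L2Distance {};
    struct CosineDistance {};
    struct BFloat16 {};
    using Float32 = float;
    using Float64 = double;

    template <typename Kernel, typename ResultType, typename LeftType, typename RightType>
    const char * chosenPath()
    {
        if constexpr (std::is_same_v<Kernel, L2Distance> || std::is_same_v<Kernel, CosineDistance>)
        {
            if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
                return "AVX-512F kernel (same-type Float32 or Float64)";
            else if constexpr (std::is_same_v<Float32, ResultType>
                && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
                return "AVX512BF16 kernel (BFloat16 inputs, Float32 result)";
        }
        return "scalar path only";
    }

    int main()
    {
        std::puts(chosenPath<L2Distance, Float32, Float32, Float32>());     /// AVX-512F kernel
        std::puts(chosenPath<L2Distance, Float32, BFloat16, BFloat16>());   /// AVX512BF16 kernel
        std::puts(chosenPath<CosineDistance, Float64, Float32, Float32>()); /// scalar path only
    }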