No, this is better

This commit is contained in:
Alexey Milovidov 2024-11-14 00:34:34 +01:00
parent d59087a0f6
commit d4599a68fa

View File

@ -606,6 +606,7 @@ private:
/// - const/non-const inputs instead of non-const/non-const inputs
/// - the two most common metrics L2 and cosine distance,
/// - the most powerful SIMD instruction set (AVX-512F).
bool processed = false;
#if USE_MULTITARGET_CODE
/// ResultType is Float32 or Float64
if constexpr (std::is_same_v<Kernel, L2Distance> || std::is_same_v<Kernel, CosineDistance>)
@ -613,28 +614,37 @@ private:
if constexpr (std::is_same_v<ResultType, LeftType> && std::is_same_v<ResultType, RightType>)
{
if (isArchSupported(TargetArch::AVX512F))
{
Kernel::template accumulateCombine<ResultType>(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
processed = true;
}
}
else if constexpr (std::is_same_v<Float32, ResultType> && std::is_same_v<BFloat16, LeftType> && std::is_same_v<BFloat16, RightType>)
{
if (isArchSupported(TargetArch::AVX512BF16))
{
Kernel::accumulateCombineBF16(data_x.data(), data_y.data(), i + offsets_x[0], i, prev, state);
processed = true;
}
}
}
#else
/// Process chunks in a vectorized manner.
static constexpr size_t VEC_SIZE = 16;
typename Kernel::template State<ResultType> states[VEC_SIZE];
for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE)
#endif
if (!processed)
{
for (size_t s = 0; s < VEC_SIZE; ++s)
Kernel::template accumulate<ResultType>(
states[s], static_cast<ResultType>(data_x[i + s]), static_cast<ResultType>(data_y[prev + s]), kernel_params);
/// Process chunks in a vectorized manner.
static constexpr size_t VEC_SIZE = 32;
typename Kernel::template State<ResultType> states[VEC_SIZE];
for (; prev + VEC_SIZE < off; i += VEC_SIZE, prev += VEC_SIZE)
{
for (size_t s = 0; s < VEC_SIZE; ++s)
Kernel::template accumulate<ResultType>(
states[s], static_cast<ResultType>(data_x[i + s]), static_cast<ResultType>(data_y[prev + s]), kernel_params);
}
for (const auto & other_state : states)
Kernel::template combine<ResultType>(state, other_state, kernel_params);
}
for (const auto & other_state : states)
Kernel::template combine<ResultType>(state, other_state, kernel_params);
#endif
/// Process the tail.
for (; prev < off; ++i, ++prev)
{