mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
dbms: add SSE variants of lower/upper and UTF8 equivalents [#METR-14764]
This commit is contained in:
parent
22b80c8226
commit
b06bdb0edf
@ -192,6 +192,8 @@ public:
|
||||
return *this;
|
||||
}
|
||||
|
||||
/// Mutable access to the stored elements: hands out the pointer returned by t_start().
T * data()
{
    return t_start();
}
|
||||
/// Read-only access to the stored elements: hands out the pointer returned by t_start().
const T * data() const
{
    return t_start();
}
|
||||
|
||||
/// Number of elements, computed as the distance between t_start() and t_end().
size_t size() const
{
    return t_end() - t_start();
}
|
||||
/// True when t_start() and t_end() coincide, i.e. size() would be zero.
bool empty() const
{
    return t_start() == t_end();
}
|
||||
|
@ -234,6 +234,172 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
template <char not_case_lower_bound, char not_case_upper_bound>
|
||||
struct LowerUpperImplVectorized
|
||||
{
|
||||
template <char, char, int(int)> friend class LowerUpperUTF8ImplVectorized;
|
||||
|
||||
static void vector(const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets,
|
||||
ColumnString::Chars_t & res_data, ColumnString::Offsets_t & res_offsets)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
res_offsets.assign(offsets);
|
||||
array(data.data(), data.data() + data.size(), res_data.data());
|
||||
}
|
||||
|
||||
static void vector_fixed(const ColumnString::Chars_t & data, size_t n,
|
||||
ColumnString::Chars_t & res_data)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
array(data.data(), data.data() + data.size(), res_data.data());
|
||||
}
|
||||
|
||||
static void constant(const std::string & data, std::string & res_data)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
array(reinterpret_cast<const UInt8 *>(data.data()), reinterpret_cast<const UInt8 *>(data.data() + data.size()),
|
||||
reinterpret_cast<UInt8 *>(&res_data[0]));
|
||||
}
|
||||
|
||||
private:
|
||||
static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
|
||||
{
|
||||
const auto src_end_sse = src_end - (src_end - src) % 16;
|
||||
|
||||
const auto flip_case_mask = 1 << 5;
|
||||
|
||||
const auto v_not_case_lower_bound = _mm_set1_epi8(not_case_lower_bound - 1);
|
||||
const auto v_not_case_upper_bound = _mm_set1_epi8(not_case_upper_bound + 1);
|
||||
const auto v_flip_case_mask = _mm_set1_epi8(flip_case_mask);
|
||||
|
||||
for (; src < src_end_sse; src += 16, dst += 16)
|
||||
{
|
||||
/// load 16 sequential 8-bit characters
|
||||
const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));
|
||||
|
||||
/// find which 8-bit sequences belong to range [case_lower_bound, case_upper_bound]
|
||||
const auto is_not_case = _mm_and_si128(_mm_cmpgt_epi8(chars, v_not_case_lower_bound),
|
||||
_mm_cmplt_epi8(chars, v_not_case_upper_bound));
|
||||
|
||||
/// keep `flip_case_mask` only where necessary, zero out elsewhere
|
||||
const auto xor_mask = _mm_and_si128(v_flip_case_mask, is_not_case);
|
||||
|
||||
/// flip case by applying calculated mask
|
||||
const auto cased_chars = _mm_xor_si128(chars, xor_mask);
|
||||
|
||||
/// store result back to destination
|
||||
_mm_storeu_si128(reinterpret_cast<__m128i *>(dst), cased_chars);
|
||||
}
|
||||
|
||||
for (; src < src_end; ++src, ++dst)
|
||||
*dst = (*src >= not_case_lower_bound && *src <= not_case_upper_bound) ? *src ^ flip_case_mask : *src;
|
||||
}
|
||||
};
|
||||
|
||||
template <char not_case_lower_bound, char not_case_upper_bound, int to_case(int)>
|
||||
struct LowerUpperUTF8ImplVectorized
|
||||
{
|
||||
static void vector(const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets,
|
||||
ColumnString::Chars_t & res_data, ColumnString::Offsets_t & res_offsets)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
res_offsets.assign(offsets);
|
||||
array(data.data(), data.data() + data.size(), res_data.data());
|
||||
}
|
||||
|
||||
static void vector_fixed(const ColumnString::Chars_t & data, size_t n,
|
||||
ColumnString::Chars_t & res_data)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
array(data.data(), data.data() + data.size(), res_data.data());
|
||||
}
|
||||
|
||||
static void constant(const std::string & data, std::string & res_data)
|
||||
{
|
||||
res_data.resize(data.size());
|
||||
array(reinterpret_cast<const UInt8 *>(data.data()), reinterpret_cast<const UInt8 *>(data.data() + data.size()),
|
||||
reinterpret_cast<UInt8 *>(&res_data[0]));
|
||||
}
|
||||
|
||||
private:
|
||||
static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
|
||||
{
|
||||
auto is_ascii = false;
|
||||
|
||||
if (isCaseASCII(src, src_end, is_ascii))
|
||||
std::copy(src, src_end, dst);
|
||||
else if (is_ascii)
|
||||
LowerUpperImplVectorized<not_case_lower_bound, not_case_upper_bound>::array(src, src_end, dst);
|
||||
else
|
||||
UTF8ToCase(src, src_end, dst);
|
||||
}
|
||||
|
||||
static bool isCaseASCII(const UInt8 * src, const UInt8 * const src_end, bool & is_ascii)
|
||||
{
|
||||
const auto src_end_sse = src_end - (src_end - src) % 16;
|
||||
|
||||
const auto not_case_a_16 = _mm_set1_epi8('A' - 1);
|
||||
const auto not_case_z_16 = _mm_set1_epi8('Z' + 1);
|
||||
const auto zero_16 = _mm_setzero_si128();
|
||||
|
||||
auto is_case = true;
|
||||
|
||||
for (; src < src_end_sse; src += 16)
|
||||
{
|
||||
const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));
|
||||
|
||||
/// check for ASCII and case
|
||||
const auto is_not_ascii = _mm_cmplt_epi8(chars, zero_16);
|
||||
const auto mask_is_not_ascii = _mm_movemask_epi8(is_not_ascii);
|
||||
|
||||
if (mask_is_not_ascii != 0)
|
||||
{
|
||||
is_ascii = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto is_not_case = _mm_and_si128(_mm_cmpgt_epi8(chars, not_case_a_16),
|
||||
_mm_cmplt_epi8(chars, not_case_z_16));
|
||||
const auto mask_is_not_case = _mm_movemask_epi8(is_not_case);
|
||||
|
||||
if (mask_is_not_case != 0)
|
||||
is_case = false;
|
||||
}
|
||||
|
||||
/// handle remaining symbols
|
||||
for (; src < src_end; ++src)
|
||||
if (*src > '\x7f')
|
||||
{
|
||||
is_ascii = false;
|
||||
return false;
|
||||
}
|
||||
else if (*src >= 'A' && *src <= 'Z')
|
||||
is_case = false;
|
||||
|
||||
is_ascii = true;
|
||||
return is_case;
|
||||
}
|
||||
|
||||
static void UTF8ToCase(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
|
||||
{
|
||||
static const Poco::UTF8Encoding utf8;
|
||||
|
||||
while (src < src_end)
|
||||
{
|
||||
if (const auto chars = utf8.convert(to_case(utf8.convert(src)), dst, src_end - src))
|
||||
{
|
||||
src += chars;
|
||||
dst += chars;
|
||||
}
|
||||
else
|
||||
{
|
||||
++src;
|
||||
++dst;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/** If the string contains UTF-8 encoded text, convert it to lower (upper) case.
|
||||
* Note: it is assumed that after a character is converted to the other case,
|
||||
@ -1424,6 +1590,11 @@ struct NameReverseUTF8 { static constexpr auto name = "reverseUTF8"; };
|
||||
/// Name holder used as a template argument of the `substring` function type.
struct NameSubstring
{
    static constexpr auto name = "substring";
};

/// Name holder used as a template argument of the `substringUTF8` function type.
struct NameSubstringUTF8
{
    static constexpr auto name = "substringUTF8";
};
|
||||
|
||||
/// Name holders for the SSE-based case conversion functions;
/// each one is used as a template argument of the corresponding function type.
struct NameSSELower
{
    static constexpr auto name = "sse_lower";
};

struct NameSSEUpper
{
    static constexpr auto name = "sse_upper";
};

struct NameSSELowerUTF8
{
    static constexpr auto name = "sse_lowerUTF8";
};

struct NameSSEUpperUTF8
{
    static constexpr auto name = "sse_upperUTF8";
};
|
||||
|
||||
/// Function types that take a string or an array and return a number.
using FunctionEmpty = FunctionStringOrArrayToT<EmptyImpl<false>, NameEmpty, UInt8>;
using FunctionNotEmpty = FunctionStringOrArrayToT<EmptyImpl<true>, NameNotEmpty, UInt8>;
using FunctionLength = FunctionStringOrArrayToT<LengthImpl, NameLength, UInt64>;
|
||||
@ -1437,5 +1608,10 @@ typedef FunctionStringToString<ReverseUTF8Impl, NameReverseUTF8> FunctionReve
|
||||
/// Substring function types: byte-based and UTF-8 aware implementations.
using FunctionSubstring = FunctionStringNumNumToString<SubstringImpl, NameSubstring>;
using FunctionSubstringUTF8 = FunctionStringNumNumToString<SubstringUTF8Impl, NameSubstringUTF8>;
|
||||
|
||||
/// SSE-based case conversion. The character pair is the range of letters that must change case:
/// 'A'..'Z' for conversion to lower case, 'a'..'z' for conversion to upper case.
using FunctionSSELower = FunctionStringToString<
    LowerUpperImplVectorized<'A', 'Z'>,
    NameSSELower>;

using FunctionSSEUpper = FunctionStringToString<
    LowerUpperImplVectorized<'a', 'z'>,
    NameSSEUpper>;

/// UTF-8 aware variants; the third argument converts a single Unicode code point.
using FunctionSSELowerUTF8 = FunctionStringToString<
    LowerUpperUTF8ImplVectorized<'A', 'Z', Poco::Unicode::toLower>,
    NameSSELowerUTF8>;

using FunctionSSEUpperUTF8 = FunctionStringToString<
    LowerUpperUTF8ImplVectorized<'a', 'z', Poco::Unicode::toUpper>,
    NameSSEUpperUTF8>;
|
||||
|
||||
|
||||
}
|
||||
|
@ -20,6 +20,10 @@ void registerFunctionsString(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionSubstring>();
|
||||
factory.registerFunction<FunctionSubstringUTF8>();
|
||||
factory.registerFunction<FunctionAppendTrailingCharIfAbsent>();
|
||||
factory.registerFunction<FunctionSSELower>();
|
||||
factory.registerFunction<FunctionSSEUpper>();
|
||||
factory.registerFunction<FunctionSSELowerUTF8>();
|
||||
factory.registerFunction<FunctionSSEUpperUTF8>();
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user