dbms: add SSE variants of lower/upper and UTF8 equivalents [#METR-14764]

This commit is contained in:
Andrey Mironov 2015-05-28 15:32:43 +03:00
parent 22b80c8226
commit b06bdb0edf
3 changed files with 182 additions and 0 deletions

View File

@ -192,6 +192,8 @@ public:
return *this;
}
/// Mutable access to the underlying buffer: hands back exactly what t_start() returns.
T * data()
{
    return t_start();
}
/// Read-only access to the underlying buffer: hands back exactly what t_start() returns.
const T * data() const
{
    return t_start();
}
/// Number of elements, computed as the pointer distance between t_start() and t_end().
size_t size() const
{
    return t_end() - t_start();
}
/// True when t_start() and t_end() coincide, i.e. size() would be zero.
bool empty() const
{
    return t_start() == t_end();
}

View File

@ -234,6 +234,172 @@ private:
}
};
/** Flips the case of ASCII letters using SSE2.
  *
  * Every byte lying in [not_case_lower_bound, not_case_upper_bound] gets bit 5 toggled;
  * all other bytes are written to the destination unchanged.
  * This file instantiates it with <'A', 'Z'> (sse_lower) and <'a', 'z'> (sse_upper).
  */
template <char not_case_lower_bound, char not_case_upper_bound>
struct LowerUpperImplVectorized
{
    /// The UTF-8 variant reuses the private `array` routine for pure-ASCII input.
    template <char, char, int(int)> friend class LowerUpperUTF8ImplVectorized;

    /// Whole string column: offsets are copied as is, all bytes are transformed in a single pass.
    static void vector(const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets,
        ColumnString::Chars_t & res_data, ColumnString::Offsets_t & res_offsets)
    {
        res_data.resize(data.size());
        res_offsets.assign(offsets);

        const auto src_begin = data.data();
        array(src_begin, src_begin + data.size(), res_data.data());
    }

    /// Fixed-size string column; the element width `n` is not needed because bytes are handled independently.
    static void vector_fixed(const ColumnString::Chars_t & data, size_t n,
        ColumnString::Chars_t & res_data)
    {
        res_data.resize(data.size());

        const auto src_begin = data.data();
        array(src_begin, src_begin + data.size(), res_data.data());
    }

    /// Single constant string.
    static void constant(const std::string & data, std::string & res_data)
    {
        res_data.resize(data.size());

        const auto src_begin = reinterpret_cast<const UInt8 *>(data.data());
        const auto src_finish = reinterpret_cast<const UInt8 *>(data.data() + data.size());
        array(src_begin, src_finish, reinterpret_cast<UInt8 *>(&res_data[0]));
    }

private:
    /// Transforms [src, src_end) into dst: 16 bytes at a time while possible, then byte by byte.
    static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
    {
        /// Upper and lower case ASCII letters differ only in this bit.
        const auto flip_case_mask = 1 << 5;

        /// End of the prefix that consists of whole 16-byte blocks.
        const auto src_end_sse = src_end - (src_end - src) % 16;

        const auto v_below_range = _mm_set1_epi8(not_case_lower_bound - 1);
        const auto v_above_range = _mm_set1_epi8(not_case_upper_bound + 1);
        const auto v_flip_case_mask = _mm_set1_epi8(flip_case_mask);

        while (src < src_end_sse)
        {
            /// 16 consecutive bytes of input
            const auto block = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));

            /// 0xFF in every lane whose byte lies inside the range to be converted.
            /// The comparisons are signed, so bytes >= 0x80 never match.
            const auto above_lower = _mm_cmpgt_epi8(block, v_below_range);
            const auto below_upper = _mm_cmplt_epi8(block, v_above_range);
            const auto in_range = _mm_and_si128(above_lower, below_upper);

            /// bit 5 set only in the matching lanes, zero elsewhere
            const auto lane_mask = _mm_and_si128(v_flip_case_mask, in_range);

            _mm_storeu_si128(reinterpret_cast<__m128i *>(dst), _mm_xor_si128(block, lane_mask));

            src += 16;
            dst += 16;
        }

        /// remaining (fewer than 16) bytes
        while (src < src_end)
        {
            const auto c = *src;

            if (c >= not_case_lower_bound && c <= not_case_upper_bound)
                *dst = c ^ flip_case_mask;
            else
                *dst = c;

            ++src;
            ++dst;
        }
    }
};
/** Case conversion of UTF-8 encoded data with fast paths for pure-ASCII input.
  *
  * Template parameters:
  *  not_case_lower_bound, not_case_upper_bound - inclusive range of ASCII letters that are in the
  *   "wrong" case and must be flipped ('A'..'Z' when lowering, 'a'..'z' when uppering);
  *  to_case - conversion of a single Unicode code point, used when non-ASCII bytes are present.
  *
  * The input is scanned once to classify it, then one of three paths is taken:
  *  - ASCII only and nothing to convert: plain copy;
  *  - ASCII only: SSE case flip via LowerUpperImplVectorized;
  *  - otherwise: code point by code point conversion through Poco::UTF8Encoding.
  */
template <char not_case_lower_bound, char not_case_upper_bound, int to_case(int)>
struct LowerUpperUTF8ImplVectorized
{
    /// Whole string column: offsets are copied as is, all bytes are processed in a single pass.
    static void vector(const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets,
        ColumnString::Chars_t & res_data, ColumnString::Offsets_t & res_offsets)
    {
        res_data.resize(data.size());
        res_offsets.assign(offsets);
        array(data.data(), data.data() + data.size(), res_data.data());
    }

    /// Fixed-size string column; the element width `n` is not used.
    static void vector_fixed(const ColumnString::Chars_t & data, size_t n,
        ColumnString::Chars_t & res_data)
    {
        res_data.resize(data.size());
        array(data.data(), data.data() + data.size(), res_data.data());
    }

    /// Single constant string.
    static void constant(const std::string & data, std::string & res_data)
    {
        res_data.resize(data.size());
        array(reinterpret_cast<const UInt8 *>(data.data()), reinterpret_cast<const UInt8 *>(data.data() + data.size()),
            reinterpret_cast<UInt8 *>(&res_data[0]));
    }

private:
    /// Classifies [src, src_end) and dispatches to the cheapest applicable conversion.
    static void array(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
    {
        auto is_ascii = false;

        if (isCaseASCII(src, src_end, is_ascii))
            std::copy(src, src_end, dst);
        else if (is_ascii)
            LowerUpperImplVectorized<not_case_lower_bound, not_case_upper_bound>::array(src, src_end, dst);
        else
            UTF8ToCase(src, src_end, dst);
    }

    /** Scans [src, src_end).
      * As soon as a byte with the high bit set is found, sets is_ascii to false and returns false.
      * Otherwise sets is_ascii to true and returns true only if no byte lies in
      * [not_case_lower_bound, not_case_upper_bound], i.e. there is nothing to convert.
      *
      * The range being looked for must be the template range, not a fixed 'A'..'Z':
      * for the upper-casing instantiation the letters to convert are 'a'..'z'.
      */
    static bool isCaseASCII(const UInt8 * src, const UInt8 * const src_end, bool & is_ascii)
    {
        const auto src_end_sse = src_end - (src_end - src) % 16;

        const auto v_not_case_lower_bound = _mm_set1_epi8(not_case_lower_bound - 1);
        const auto v_not_case_upper_bound = _mm_set1_epi8(not_case_upper_bound + 1);
        const auto zero_16 = _mm_setzero_si128();

        auto is_case = true;

        for (; src < src_end_sse; src += 16)
        {
            const auto chars = _mm_loadu_si128(reinterpret_cast<const __m128i *>(src));

            /// bytes >= 0x80 are negative when interpreted as signed 8-bit values
            const auto is_not_ascii = _mm_cmplt_epi8(chars, zero_16);
            const auto mask_is_not_ascii = _mm_movemask_epi8(is_not_ascii);

            if (mask_is_not_ascii != 0)
            {
                is_ascii = false;
                return false;
            }

            /// does at least one byte of the block require case conversion?
            const auto is_not_case = _mm_and_si128(_mm_cmpgt_epi8(chars, v_not_case_lower_bound),
                _mm_cmplt_epi8(chars, v_not_case_upper_bound));
            const auto mask_is_not_case = _mm_movemask_epi8(is_not_case);

            /// keep scanning even after this: a later non-ASCII byte still has to be detected
            if (mask_is_not_case != 0)
                is_case = false;
        }

        /// handle remaining symbols
        for (; src < src_end; ++src)
        {
            if (*src > '\x7f')
            {
                is_ascii = false;
                return false;
            }
            else if (*src >= not_case_lower_bound && *src <= not_case_upper_bound)
                is_case = false;
        }

        is_ascii = true;
        return is_case;
    }

    /** Converts [src, src_end) one code point at a time: decode, apply to_case, encode into dst.
      * Source and destination are advanced by the encoded length of the converted code point,
      * so it is assumed that case conversion does not change the byte length of a character.
      *
      * NOTE(review): utf8.convert(src) is not given src_end; for a truncated multi-byte sequence
      * at the very end of the buffer it presumably reads past src_end - confirm against the
      * Poco::UTF8Encoding implementation in use.
      */
    static void UTF8ToCase(const UInt8 * src, const UInt8 * src_end, UInt8 * dst)
    {
        static const Poco::UTF8Encoding utf8;

        while (src < src_end)
        {
            const auto remaining = src_end - src;
            const auto chars = utf8.convert(to_case(utf8.convert(src)), dst, remaining);

            if (chars > 0 && chars <= remaining)
            {
                src += chars;
                dst += chars;
            }
            else
            {
                /// Nothing usable was produced (conversion failed or the encoded form does not fit
                /// into the remaining space): pass the source byte through, so that the output
                /// never contains uninitialized bytes and the pointers never run past the end.
                *dst = *src;
                ++src;
                ++dst;
            }
        }
    }
};
/** Если строка содержит текст в кодировке UTF-8 - перевести его в нижний (верхний) регистр.
* Замечание: предполагается, что после перевода символа в другой регистр,
@ -1424,6 +1590,11 @@ struct NameReverseUTF8 { static constexpr auto name = "reverseUTF8"; };
/// Name tag carrying the SQL-visible name "substring".
struct NameSubstring
{
    static constexpr auto name = "substring";
};

/// Name tag carrying the SQL-visible name "substringUTF8".
struct NameSubstringUTF8
{
    static constexpr auto name = "substringUTF8";
};
/// Name tags for the SSE-accelerated case conversion functions.
struct NameSSELower
{
    static constexpr auto name = "sse_lower";
};

struct NameSSEUpper
{
    static constexpr auto name = "sse_upper";
};

struct NameSSELowerUTF8
{
    static constexpr auto name = "sse_lowerUTF8";
};

struct NameSSEUpperUTF8
{
    static constexpr auto name = "sse_upperUTF8";
};
/// Functions built on FunctionStringOrArrayToT: implementation, name tag, result type.
using FunctionEmpty = FunctionStringOrArrayToT<EmptyImpl<false>, NameEmpty, UInt8>;
using FunctionNotEmpty = FunctionStringOrArrayToT<EmptyImpl<true>, NameNotEmpty, UInt8>;
using FunctionLength = FunctionStringOrArrayToT<LengthImpl, NameLength, UInt64>;
@ -1437,5 +1608,10 @@ typedef FunctionStringToString<ReverseUTF8Impl, NameReverseUTF8> FunctionReve
/// Substring functions: (string, number, number) -> string.
using FunctionSubstring = FunctionStringNumNumToString<SubstringImpl, NameSubstring>;
using FunctionSubstringUTF8 = FunctionStringNumNumToString<SubstringUTF8Impl, NameSubstringUTF8>;

/// SSE case conversion of single-byte (ASCII) letters.
/// The template arguments are the bounds of the letter range that gets its case flipped.
using FunctionSSELower = FunctionStringToString<
    LowerUpperImplVectorized<'A', 'Z'>,
    NameSSELower>;
using FunctionSSEUpper = FunctionStringToString<
    LowerUpperImplVectorized<'a', 'z'>,
    NameSSEUpper>;

/// SSE-assisted case conversion of UTF-8 text.
/// The third template argument converts a single Unicode code point.
using FunctionSSELowerUTF8 = FunctionStringToString<
    LowerUpperUTF8ImplVectorized<'A', 'Z', Poco::Unicode::toLower>,
    NameSSELowerUTF8>;
using FunctionSSEUpperUTF8 = FunctionStringToString<
    LowerUpperUTF8ImplVectorized<'a', 'z', Poco::Unicode::toUpper>,
    NameSSEUpperUTF8>;
}

View File

@ -20,6 +20,10 @@ void registerFunctionsString(FunctionFactory & factory)
factory.registerFunction<FunctionSubstring>();
factory.registerFunction<FunctionSubstringUTF8>();
factory.registerFunction<FunctionAppendTrailingCharIfAbsent>();
factory.registerFunction<FunctionSSELower>();
factory.registerFunction<FunctionSSEUpper>();
factory.registerFunction<FunctionSSELowerUTF8>();
factory.registerFunction<FunctionSSEUpperUTF8>();
}
}