dbms: refactor, fix out of bounds access when needle is empty [#METR-16752]

This commit is contained in:
Andrey Mironov 2015-09-24 17:28:31 +03:00
parent fd6dca0432
commit 01e767afa0

View File

@ -207,27 +207,44 @@ struct PositionUTF8Impl
struct PositionCaseInsensitiveImpl struct PositionCaseInsensitiveImpl
{ {
using ResultType = UInt64; private:
class CaseInsensitiveSearcher
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{ {
/// lower and uppercase variants of the first character in `needle` static constexpr auto n = sizeof(__m128i);
const auto l = std::tolower(needle.front());
const auto u = std::toupper(needle.front());
/// for detecting leftmost position of the first symbol
const auto patl = _mm_set1_epi8(l);
const auto patu = _mm_set1_epi8(u);
/// lower and uppercase vectors of first 16 characters of `needle`
auto cachel = _mm_setzero_si128();
auto cacheu = _mm_setzero_si128();
int cachemask = 0;
const auto n = sizeof(cachel); const int page_size = getpagesize();
const auto needle_begin = needle.data();
const auto needle_end = needle_begin + needle.size(); /// string to be searched for
auto needle_pos = needle_begin; const std::string & needle;
/// lower and uppercase variants of the first character in `needle`
UInt8 l{};
UInt8 u{};
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
__m128i patl, patu;
/// lower and uppercase vectors of first 16 characters of `needle`
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
int cachemask{};
/// Tells whether `n` bytes can be read starting at `ptr` without running past
/// the end of the memory page that `ptr` lies in.
/// NOTE(review): masking with `page_size - 1` assumes the page size is a power of two.
bool page_safe(const void * const ptr) const
{
    /// byte offset of `ptr` from the beginning of its page
    const auto offset_in_page = (page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr);

    /// the last offset at which an `n`-byte read still fits inside the page
    const auto last_safe_offset = page_size - n;

    return offset_in_page <= last_safe_offset;
}
public:
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
{
if (needle.empty())
return;
auto needle_pos = needle.data();
l = std::tolower(*needle_pos);
u = std::toupper(*needle_pos);
patl = _mm_set1_epi8(l);
patu = _mm_set1_epi8(u);
const auto needle_end = needle_pos + needle.size();
for (const auto i : ext::range(0, n)) for (const auto i : ext::range(0, n))
{ {
@ -242,16 +259,16 @@ struct PositionCaseInsensitiveImpl
++needle_pos; ++needle_pos;
} }
} }
}
const auto page_size = getpagesize(); const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
const auto page_safe = [&] (const void * const ptr) { {
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n; if (needle.empty())
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
return haystack; return haystack;
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
const auto needle_end = needle_begin + needle.size();
while (haystack < haystack_end) while (haystack < haystack_end)
{ {
/// @todo supposedly for long strings spanning across multiple pages. Why don't we use this technique in other places? /// @todo supposedly for long strings spanning across multiple pages. Why don't we use this technique in other places?
@ -285,13 +302,14 @@ struct PositionCaseInsensitiveImpl
{ {
if (mask == cachemask) if (mask == cachemask)
{ {
auto s1 = haystack + n; auto haystack_pos = haystack + n;
auto s2 = needle_begin + n; auto needle_pos = needle_begin + n;
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2)) while (haystack_pos < haystack_end && needle_pos < needle_end &&
++s1, ++s2; std::tolower(*haystack_pos) == std::tolower(*needle_pos))
++haystack_pos, ++needle_pos;
if (s2 == needle_end) if (needle_pos == needle_end)
return haystack; return haystack;
} }
} }
@ -308,13 +326,14 @@ struct PositionCaseInsensitiveImpl
if (*haystack == l || *haystack == u) if (*haystack == l || *haystack == u)
{ {
auto s1 = haystack + 1; auto haystack_pos = haystack + 1;
auto s2 = needle_begin + 1; auto needle_pos = needle_begin + 1;
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2)) while (haystack_pos < haystack_end && needle_pos < needle_end &&
++s1, ++s2; std::tolower(*haystack_pos) == std::tolower(*needle_pos))
++haystack_pos, ++needle_pos;
if (s2 == needle_end) if (needle_pos == needle_end)
return haystack; return haystack;
} }
@ -322,8 +341,18 @@ struct PositionCaseInsensitiveImpl
} }
return haystack_end; return haystack_end;
}
}; };
public:
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{
const CaseInsensitiveSearcher searcher{needle};
const UInt8 * begin = &data[0]; const UInt8 * begin = &data[0];
const UInt8 * pos = begin; const UInt8 * pos = begin;
const UInt8 * end = pos + data.size(); const UInt8 * end = pos + data.size();
@ -332,7 +361,7 @@ struct PositionCaseInsensitiveImpl
size_t i = 0; size_t i = 0;
/// Искать будем следующее вхождение сразу во всех строках. /// Искать будем следующее вхождение сразу во всех строках.
while (pos < end && end != (pos = find_ci(pos, end))) while (pos < end && end != (pos = searcher.find(pos, end)))
{ {
/// Определим, к какому индексу оно относится. /// Определим, к какому индексу оно относится.
while (begin + offsets[i] < pos) while (begin + offsets[i] < pos)
@ -370,40 +399,69 @@ struct PositionCaseInsensitiveImpl
struct PositionCaseInsensitiveUTF8Impl struct PositionCaseInsensitiveUTF8Impl
{ {
using ResultType = UInt64; private:
class CaseInsensitiveSearcher
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{ {
using UTF8SequenceBuffer = UInt8[6]; using UTF8SequenceBuffer = UInt8[6];
static constexpr auto n = sizeof(__m128i);
const int page_size = getpagesize();
/// string to be searched for
const std::string & needle;
bool first_needle_symbol_is_ascii{};
/// lower and uppercase variants of the first octet of the first character in `needle`
UInt8 l{};
UInt8 u{};
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
__m128i patl, patu;
/// lower and uppercase vectors of first 16 characters of `needle`
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
int cachemask{};
std::size_t cache_valid_len{};
std::size_t cache_actual_len{};
/// Tells whether `n` bytes can be read starting at `ptr` without running past
/// the end of the memory page that `ptr` lies in.
/// NOTE(review): masking with `page_size - 1` assumes the page size is a power of two.
bool page_safe(const void * const ptr) const
{
    /// byte offset of `ptr` from the beginning of its page
    const auto offset_in_page = (page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr);

    /// the last offset at which an `n`-byte read still fits inside the page
    const auto last_safe_offset = page_size - n;

    return offset_in_page <= last_safe_offset;
}
public:
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
{
if (needle.empty())
return;
static const Poco::UTF8Encoding utf8; static const Poco::UTF8Encoding utf8;
UTF8SequenceBuffer l_seq, u_seq; UTF8SequenceBuffer l_seq, u_seq;
const auto first_u32 = utf8.convert(reinterpret_cast<const UInt8 *>(needle.data())); auto needle_pos = reinterpret_cast<const UInt8 *>(needle.data());
if (*needle_pos < 0x80u)
{
first_needle_symbol_is_ascii = true;
l = std::tolower(*needle_pos);
u = std::toupper(*needle_pos);
}
else
{
const auto first_u32 = utf8.convert(needle_pos);
const auto first_l_u32 = Poco::Unicode::toLower(first_u32); const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32); const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
/// lower and uppercase variants of the first octet of the first character in `needle` /// lower and uppercase variants of the first octet of the first character in `needle`
utf8.convert(first_l_u32, l_seq, sizeof(l_seq)); utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
const auto l = l_seq[0]; l = l_seq[0];
utf8.convert(first_u_u32, u_seq, sizeof(u_seq)); utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
const auto u = u_seq[0]; u = u_seq[0];
/// for detecting leftmost position of the first symbol }
const auto patl = _mm_set1_epi8(l);
const auto patu = _mm_set1_epi8(u);
/// lower and uppercase vectors of first 16 octets of `needle`
auto cachel = _mm_setzero_si128();
auto cacheu = _mm_setzero_si128();
int cachemask = 0;
std::size_t cache_valid_len{};
std::size_t cache_actual_len{};
const auto n = sizeof(cachel); /// for detecting leftmost position of the first symbol
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data()); patl = _mm_set1_epi8(l);
const auto needle_end = needle_begin + needle.size(); patu = _mm_set1_epi8(u);
auto needle_pos = needle_begin; /// lower and uppercase vectors of first 16 octets of `needle`
const auto needle_end = needle_pos + needle.size();
for (std::size_t i = 0; i < n;) for (std::size_t i = 0; i < n;)
{ {
@ -451,16 +509,18 @@ struct PositionCaseInsensitiveUTF8Impl
} }
} }
} }
}
const auto page_size = getpagesize(); const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
const auto page_safe = [&] (const void * const ptr) { {
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n; if (needle.empty())
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
return haystack; return haystack;
static const Poco::UTF8Encoding utf8;
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
const auto needle_end = needle_begin + needle.size();
while (haystack < haystack_end) while (haystack < haystack_end)
{ {
if (haystack + n <= haystack_end && page_safe(haystack)) if (haystack + n <= haystack_end && page_safe(haystack))
@ -494,19 +554,19 @@ struct PositionCaseInsensitiveUTF8Impl
{ {
if (mask == cachemask) if (mask == cachemask)
{ {
auto s1 = haystack + cache_valid_len; auto haystack_pos = haystack + cache_valid_len;
auto s2 = needle_begin + cache_valid_len; auto needle_pos = needle_begin + cache_valid_len;
while (s1 < haystack_end && s2 < needle_end && while (haystack_pos < haystack_end && needle_pos < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) == Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
Poco::Unicode::toLower(utf8.convert(reinterpret_cast<const UInt8 *>(s2)))) Poco::Unicode::toLower(utf8.convert(needle_pos)))
{ {
/// @note assuming sequences for lowercase and uppercase have exact same length /// @note assuming sequences for lowercase and uppercase have exact same length
const auto len = utf8_seq_length(*s1); const auto len = utf8_seq_length(*haystack_pos);
s1 += len, s2 += len; haystack_pos += len, needle_pos += len;
} }
if (s2 == needle_end) if (needle_pos == needle_end)
return haystack; return haystack;
} }
} }
@ -524,18 +584,18 @@ struct PositionCaseInsensitiveUTF8Impl
if (*haystack == l || *haystack == u) if (*haystack == l || *haystack == u)
{ {
auto s1 = haystack; auto haystack_pos = haystack + first_needle_symbol_is_ascii;
auto s2 = needle_begin; auto needle_pos = needle_begin + first_needle_symbol_is_ascii;
while (s1 < haystack_end && s2 < needle_end && while (haystack_pos < haystack_end && needle_pos < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) == Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
Poco::Unicode::toLower(utf8.convert(s2))) Poco::Unicode::toLower(utf8.convert(needle_pos)))
{ {
const auto len = utf8_seq_length(*s1); const auto len = utf8_seq_length(*haystack_pos);
s1 += len, s2 += len; haystack_pos += len, needle_pos += len;
} }
if (s2 == needle_end) if (needle_pos == needle_end)
return haystack; return haystack;
} }
@ -544,8 +604,18 @@ struct PositionCaseInsensitiveUTF8Impl
} }
return haystack_end; return haystack_end;
}
}; };
public:
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{
const CaseInsensitiveSearcher searcher{needle};
const UInt8 * begin = &data[0]; const UInt8 * begin = &data[0];
const UInt8 * pos = begin; const UInt8 * pos = begin;
const UInt8 * end = pos + data.size(); const UInt8 * end = pos + data.size();
@ -554,7 +624,7 @@ struct PositionCaseInsensitiveUTF8Impl
size_t i = 0; size_t i = 0;
/// Искать будем следующее вхождение сразу во всех строках. /// Искать будем следующее вхождение сразу во всех строках.
while (pos < end && end != (pos = find_ci(pos, end))) while (pos < end && end != (pos = searcher.find(pos, end)))
{ {
/// Определим, к какому индексу оно относится. /// Определим, к какому индексу оно относится.
while (begin + offsets[i] < pos) while (begin + offsets[i] < pos)