dbms: refactor, fix out of bounds access when needle is empty [#METR-16752]

This commit is contained in:
Andrey Mironov 2015-09-24 17:28:31 +03:00
parent fd6dca0432
commit 01e767afa0

View File

@ -207,51 +207,68 @@ struct PositionUTF8Impl
struct PositionCaseInsensitiveImpl
{
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
private:
class CaseInsensitiveSearcher
{
static constexpr auto n = sizeof(__m128i);
const int page_size = getpagesize();
/// string to be searched for
const std::string & needle;
/// lower and uppercase variants of the first character in `needle`
const auto l = std::tolower(needle.front());
const auto u = std::toupper(needle.front());
/// for detecting leftmost position of the first symbol
const auto patl = _mm_set1_epi8(l);
const auto patu = _mm_set1_epi8(u);
UInt8 l{};
UInt8 u{};
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
__m128i patl, patu;
/// lower and uppercase vectors of first 16 characters of `needle`
auto cachel = _mm_setzero_si128();
auto cacheu = _mm_setzero_si128();
int cachemask = 0;
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
int cachemask{};
const auto n = sizeof(cachel);
const auto needle_begin = needle.data();
const auto needle_end = needle_begin + needle.size();
auto needle_pos = needle_begin;
for (const auto i : ext::range(0, n))
bool page_safe(const void * const ptr) const
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
}
if (needle_pos != needle_end)
public:
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
{
if (needle.empty())
return;
auto needle_pos = needle.data();
l = std::tolower(*needle_pos);
u = std::toupper(*needle_pos);
patl = _mm_set1_epi8(l);
patu = _mm_set1_epi8(u);
const auto needle_end = needle_pos + needle.size();
for (const auto i : ext::range(0, n))
{
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
cachemask |= 1 << i;
++needle_pos;
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
if (needle_pos != needle_end)
{
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
cachemask |= 1 << i;
++needle_pos;
}
}
}
const auto page_size = getpagesize();
const auto page_safe = [&] (const void * const ptr) {
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
{
if (needle.empty())
return haystack;
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
const auto needle_end = needle_begin + needle.size();
while (haystack < haystack_end)
{
/// @todo supposedly for long strings spanning across multiple pages. Why don't we use this technique in other places?
@ -285,13 +302,14 @@ struct PositionCaseInsensitiveImpl
{
if (mask == cachemask)
{
auto s1 = haystack + n;
auto s2 = needle_begin + n;
auto haystack_pos = haystack + n;
auto needle_pos = needle_begin + n;
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2))
++s1, ++s2;
while (haystack_pos < haystack_end && needle_pos < needle_end &&
std::tolower(*haystack_pos) == std::tolower(*needle_pos))
++haystack_pos, ++needle_pos;
if (s2 == needle_end)
if (needle_pos == needle_end)
return haystack;
}
}
@ -308,13 +326,14 @@ struct PositionCaseInsensitiveImpl
if (*haystack == l || *haystack == u)
{
auto s1 = haystack + 1;
auto s2 = needle_begin + 1;
auto haystack_pos = haystack + 1;
auto needle_pos = needle_begin + 1;
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2))
++s1, ++s2;
while (haystack_pos < haystack_end && needle_pos < needle_end &&
std::tolower(*haystack_pos) == std::tolower(*needle_pos))
++haystack_pos, ++needle_pos;
if (s2 == needle_end)
if (needle_pos == needle_end)
return haystack;
}
@ -322,7 +341,17 @@ struct PositionCaseInsensitiveImpl
}
return haystack_end;
};
}
};
public:
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{
const CaseInsensitiveSearcher searcher{needle};
const UInt8 * begin = &data[0];
const UInt8 * pos = begin;
@ -332,7 +361,7 @@ struct PositionCaseInsensitiveImpl
size_t i = 0;
/// Искать будем следующее вхождение сразу во всех строках.
while (pos < end && end != (pos = find_ci(pos, end)))
while (pos < end && end != (pos = searcher.find(pos, end)))
{
/// Определим, к какому индексу оно относится.
while (begin + offsets[i] < pos)
@ -370,97 +399,128 @@ struct PositionCaseInsensitiveImpl
struct PositionCaseInsensitiveUTF8Impl
{
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
private:
class CaseInsensitiveSearcher
{
using UTF8SequenceBuffer = UInt8[6];
static const Poco::UTF8Encoding utf8;
UTF8SequenceBuffer l_seq, u_seq;
static constexpr auto n = sizeof(__m128i);
const auto first_u32 = utf8.convert(reinterpret_cast<const UInt8 *>(needle.data()));
const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
const int page_size = getpagesize();
/// string to be searched for
const std::string & needle;
bool first_needle_symbol_is_ascii{};
/// lower and uppercase variants of the first octet of the first character in `needle`
utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
const auto l = l_seq[0];
utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
const auto u = u_seq[0];
/// for detecting leftmost position of the first symbol
const auto patl = _mm_set1_epi8(l);
const auto patu = _mm_set1_epi8(u);
/// lower and uppercase vectors of first 16 octets of `needle`
auto cachel = _mm_setzero_si128();
auto cacheu = _mm_setzero_si128();
int cachemask = 0;
UInt8 l{};
UInt8 u{};
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
__m128i patl, patu;
/// lower and uppercase vectors of first 16 characters of `needle`
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
int cachemask{};
std::size_t cache_valid_len{};
std::size_t cache_actual_len{};
const auto n = sizeof(cachel);
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
const auto needle_end = needle_begin + needle.size();
auto needle_pos = needle_begin;
for (std::size_t i = 0; i < n;)
bool page_safe(const void * const ptr) const
{
if (needle_pos == needle_end)
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
++i;
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
}
continue;
public:
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
{
if (needle.empty())
return;
static const Poco::UTF8Encoding utf8;
UTF8SequenceBuffer l_seq, u_seq;
auto needle_pos = reinterpret_cast<const UInt8 *>(needle.data());
if (*needle_pos < 0x80u)
{
first_needle_symbol_is_ascii = true;
l = std::tolower(*needle_pos);
u = std::toupper(*needle_pos);
}
else
{
const auto first_u32 = utf8.convert(needle_pos);
const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
/// lower and uppercase variants of the first octet of the first character in `needle`
utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
l = l_seq[0];
utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
u = u_seq[0];
}
const auto src_len = utf8_seq_length(*needle_pos);
const auto c_u32 = utf8.convert(needle_pos);
/// for detecting leftmost position of the first symbol
patl = _mm_set1_epi8(l);
patu = _mm_set1_epi8(u);
/// lower and uppercase vectors of first 16 octets of `needle`
const auto c_l_u32 = Poco::Unicode::toLower(c_u32);
const auto c_u_u32 = Poco::Unicode::toUpper(c_u32);
const auto needle_end = needle_pos + needle.size();
const auto dst_l_len = static_cast<UInt8>(utf8.convert(c_l_u32, l_seq, sizeof(l_seq)));
const auto dst_u_len = static_cast<UInt8>(utf8.convert(c_u_u32, u_seq, sizeof(u_seq)));
/// @note Unicode standard states it is a rare but possible occasion
if (!(dst_l_len == dst_u_len && dst_u_len == src_len))
throw Exception{
"UTF8 sequences with different lowercase and uppercase lengths are not supported",
ErrorCodes::UNSUPPORTED_PARAMETER
};
cache_actual_len += src_len;
if (cache_actual_len < n)
cache_valid_len += src_len;
for (std::size_t j = 0; j < src_len && i < n; ++j, ++i)
for (std::size_t i = 0; i < n;)
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
if (needle_pos != needle_end)
if (needle_pos == needle_end)
{
cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1);
cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1);
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
++i;
cachemask |= 1 << i;
++needle_pos;
continue;
}
const auto src_len = utf8_seq_length(*needle_pos);
const auto c_u32 = utf8.convert(needle_pos);
const auto c_l_u32 = Poco::Unicode::toLower(c_u32);
const auto c_u_u32 = Poco::Unicode::toUpper(c_u32);
const auto dst_l_len = static_cast<UInt8>(utf8.convert(c_l_u32, l_seq, sizeof(l_seq)));
const auto dst_u_len = static_cast<UInt8>(utf8.convert(c_u_u32, u_seq, sizeof(u_seq)));
/// @note Unicode standard states it is a rare but possible occasion
if (!(dst_l_len == dst_u_len && dst_u_len == src_len))
throw Exception{
"UTF8 sequences with different lowercase and uppercase lengths are not supported",
ErrorCodes::UNSUPPORTED_PARAMETER
};
cache_actual_len += src_len;
if (cache_actual_len < n)
cache_valid_len += src_len;
for (std::size_t j = 0; j < src_len && i < n; ++j, ++i)
{
cachel = _mm_srli_si128(cachel, 1);
cacheu = _mm_srli_si128(cacheu, 1);
if (needle_pos != needle_end)
{
cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1);
cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1);
cachemask |= 1 << i;
++needle_pos;
}
}
}
}
const auto page_size = getpagesize();
const auto page_safe = [&] (const void * const ptr) {
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
};
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
if (needle_begin == needle_end)
const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
{
if (needle.empty())
return haystack;
static const Poco::UTF8Encoding utf8;
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
const auto needle_end = needle_begin + needle.size();
while (haystack < haystack_end)
{
if (haystack + n <= haystack_end && page_safe(haystack))
@ -494,19 +554,19 @@ struct PositionCaseInsensitiveUTF8Impl
{
if (mask == cachemask)
{
auto s1 = haystack + cache_valid_len;
auto s2 = needle_begin + cache_valid_len;
auto haystack_pos = haystack + cache_valid_len;
auto needle_pos = needle_begin + cache_valid_len;
while (s1 < haystack_end && s2 < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) ==
Poco::Unicode::toLower(utf8.convert(reinterpret_cast<const UInt8 *>(s2))))
while (haystack_pos < haystack_end && needle_pos < needle_end &&
Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
Poco::Unicode::toLower(utf8.convert(needle_pos)))
{
/// @note assuming sequences for lowercase and uppercase have exact same length
const auto len = utf8_seq_length(*s1);
s1 += len, s2 += len;
const auto len = utf8_seq_length(*haystack_pos);
haystack_pos += len, needle_pos += len;
}
if (s2 == needle_end)
if (needle_pos == needle_end)
return haystack;
}
}
@ -524,18 +584,18 @@ struct PositionCaseInsensitiveUTF8Impl
if (*haystack == l || *haystack == u)
{
auto s1 = haystack;
auto s2 = needle_begin;
auto haystack_pos = haystack + first_needle_symbol_is_ascii;
auto needle_pos = needle_begin + first_needle_symbol_is_ascii;
while (s1 < haystack_end && s2 < needle_end &&
Poco::Unicode::toLower(utf8.convert(s1)) ==
Poco::Unicode::toLower(utf8.convert(s2)))
while (haystack_pos < haystack_end && needle_pos < needle_end &&
Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
Poco::Unicode::toLower(utf8.convert(needle_pos)))
{
const auto len = utf8_seq_length(*s1);
s1 += len, s2 += len;
const auto len = utf8_seq_length(*haystack_pos);
haystack_pos += len, needle_pos += len;
}
if (s2 == needle_end)
if (needle_pos == needle_end)
return haystack;
}
@ -544,7 +604,17 @@ struct PositionCaseInsensitiveUTF8Impl
}
return haystack_end;
};
}
};
public:
using ResultType = UInt64;
static void vector(
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
PODArray<UInt64> & res)
{
const CaseInsensitiveSearcher searcher{needle};
const UInt8 * begin = &data[0];
const UInt8 * pos = begin;
@ -554,7 +624,7 @@ struct PositionCaseInsensitiveUTF8Impl
size_t i = 0;
/// Искать будем следующее вхождение сразу во всех строках.
while (pos < end && end != (pos = find_ci(pos, end)))
while (pos < end && end != (pos = searcher.find(pos, end)))
{
/// Определим, к какому индексу оно относится.
while (begin + offsets[i] < pos)