mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 16:12:01 +00:00
dbms: refactor, fix out of bounds access when needle is empty [#METR-16752]
This commit is contained in:
parent
fd6dca0432
commit
01e767afa0
@ -207,51 +207,68 @@ struct PositionUTF8Impl
|
||||
|
||||
struct PositionCaseInsensitiveImpl
|
||||
{
|
||||
using ResultType = UInt64;
|
||||
|
||||
static void vector(
|
||||
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
|
||||
PODArray<UInt64> & res)
|
||||
private:
|
||||
class CaseInsensitiveSearcher
|
||||
{
|
||||
static constexpr auto n = sizeof(__m128i);
|
||||
|
||||
const int page_size = getpagesize();
|
||||
|
||||
/// string to be searched for
|
||||
const std::string & needle;
|
||||
/// lower and uppercase variants of the first character in `needle`
|
||||
const auto l = std::tolower(needle.front());
|
||||
const auto u = std::toupper(needle.front());
|
||||
/// for detecting leftmost position of the first symbol
|
||||
const auto patl = _mm_set1_epi8(l);
|
||||
const auto patu = _mm_set1_epi8(u);
|
||||
UInt8 l{};
|
||||
UInt8 u{};
|
||||
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
|
||||
__m128i patl, patu;
|
||||
/// lower and uppercase vectors of first 16 characters of `needle`
|
||||
auto cachel = _mm_setzero_si128();
|
||||
auto cacheu = _mm_setzero_si128();
|
||||
int cachemask = 0;
|
||||
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
|
||||
int cachemask{};
|
||||
|
||||
const auto n = sizeof(cachel);
|
||||
const auto needle_begin = needle.data();
|
||||
const auto needle_end = needle_begin + needle.size();
|
||||
auto needle_pos = needle_begin;
|
||||
|
||||
for (const auto i : ext::range(0, n))
|
||||
bool page_safe(const void * const ptr) const
|
||||
{
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
|
||||
}
|
||||
|
||||
if (needle_pos != needle_end)
|
||||
public:
|
||||
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
|
||||
{
|
||||
if (needle.empty())
|
||||
return;
|
||||
|
||||
auto needle_pos = needle.data();
|
||||
|
||||
l = std::tolower(*needle_pos);
|
||||
u = std::toupper(*needle_pos);
|
||||
|
||||
patl = _mm_set1_epi8(l);
|
||||
patu = _mm_set1_epi8(u);
|
||||
|
||||
const auto needle_end = needle_pos + needle.size();
|
||||
|
||||
for (const auto i : ext::range(0, n))
|
||||
{
|
||||
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
|
||||
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
|
||||
cachemask |= 1 << i;
|
||||
++needle_pos;
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
|
||||
if (needle_pos != needle_end)
|
||||
{
|
||||
cachel = _mm_insert_epi8(cachel, std::tolower(*needle_pos), n - 1);
|
||||
cacheu = _mm_insert_epi8(cacheu, std::toupper(*needle_pos), n - 1);
|
||||
cachemask |= 1 << i;
|
||||
++needle_pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto page_size = getpagesize();
|
||||
const auto page_safe = [&] (const void * const ptr) {
|
||||
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
|
||||
};
|
||||
|
||||
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
|
||||
if (needle_begin == needle_end)
|
||||
const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
|
||||
{
|
||||
if (needle.empty())
|
||||
return haystack;
|
||||
|
||||
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
|
||||
const auto needle_end = needle_begin + needle.size();
|
||||
|
||||
while (haystack < haystack_end)
|
||||
{
|
||||
/// @todo supposedly for long strings spanning across multiple pages. Why don't we use this technique in other places?
|
||||
@ -285,13 +302,14 @@ struct PositionCaseInsensitiveImpl
|
||||
{
|
||||
if (mask == cachemask)
|
||||
{
|
||||
auto s1 = haystack + n;
|
||||
auto s2 = needle_begin + n;
|
||||
auto haystack_pos = haystack + n;
|
||||
auto needle_pos = needle_begin + n;
|
||||
|
||||
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2))
|
||||
++s1, ++s2;
|
||||
while (haystack_pos < haystack_end && needle_pos < needle_end &&
|
||||
std::tolower(*haystack_pos) == std::tolower(*needle_pos))
|
||||
++haystack_pos, ++needle_pos;
|
||||
|
||||
if (s2 == needle_end)
|
||||
if (needle_pos == needle_end)
|
||||
return haystack;
|
||||
}
|
||||
}
|
||||
@ -308,13 +326,14 @@ struct PositionCaseInsensitiveImpl
|
||||
|
||||
if (*haystack == l || *haystack == u)
|
||||
{
|
||||
auto s1 = haystack + 1;
|
||||
auto s2 = needle_begin + 1;
|
||||
auto haystack_pos = haystack + 1;
|
||||
auto needle_pos = needle_begin + 1;
|
||||
|
||||
while (s1 < haystack_end && s2 < needle_end && std::tolower(*s1) == std::tolower(*s2))
|
||||
++s1, ++s2;
|
||||
while (haystack_pos < haystack_end && needle_pos < needle_end &&
|
||||
std::tolower(*haystack_pos) == std::tolower(*needle_pos))
|
||||
++haystack_pos, ++needle_pos;
|
||||
|
||||
if (s2 == needle_end)
|
||||
if (needle_pos == needle_end)
|
||||
return haystack;
|
||||
}
|
||||
|
||||
@ -322,7 +341,17 @@ struct PositionCaseInsensitiveImpl
|
||||
}
|
||||
|
||||
return haystack_end;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
using ResultType = UInt64;
|
||||
|
||||
static void vector(
|
||||
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
|
||||
PODArray<UInt64> & res)
|
||||
{
|
||||
const CaseInsensitiveSearcher searcher{needle};
|
||||
|
||||
const UInt8 * begin = &data[0];
|
||||
const UInt8 * pos = begin;
|
||||
@ -332,7 +361,7 @@ struct PositionCaseInsensitiveImpl
|
||||
size_t i = 0;
|
||||
|
||||
/// Искать будем следующее вхождение сразу во всех строках.
|
||||
while (pos < end && end != (pos = find_ci(pos, end)))
|
||||
while (pos < end && end != (pos = searcher.find(pos, end)))
|
||||
{
|
||||
/// Определим, к какому индексу оно относится.
|
||||
while (begin + offsets[i] < pos)
|
||||
@ -370,97 +399,128 @@ struct PositionCaseInsensitiveImpl
|
||||
|
||||
struct PositionCaseInsensitiveUTF8Impl
|
||||
{
|
||||
using ResultType = UInt64;
|
||||
|
||||
static void vector(
|
||||
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
|
||||
PODArray<UInt64> & res)
|
||||
private:
|
||||
class CaseInsensitiveSearcher
|
||||
{
|
||||
using UTF8SequenceBuffer = UInt8[6];
|
||||
|
||||
static const Poco::UTF8Encoding utf8;
|
||||
UTF8SequenceBuffer l_seq, u_seq;
|
||||
static constexpr auto n = sizeof(__m128i);
|
||||
|
||||
const auto first_u32 = utf8.convert(reinterpret_cast<const UInt8 *>(needle.data()));
|
||||
const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
|
||||
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
|
||||
const int page_size = getpagesize();
|
||||
|
||||
/// string to be searched for
|
||||
const std::string & needle;
|
||||
bool first_needle_symbol_is_ascii{};
|
||||
/// lower and uppercase variants of the first octet of the first character in `needle`
|
||||
utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
|
||||
const auto l = l_seq[0];
|
||||
utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
|
||||
const auto u = u_seq[0];
|
||||
/// for detecting leftmost position of the first symbol
|
||||
const auto patl = _mm_set1_epi8(l);
|
||||
const auto patu = _mm_set1_epi8(u);
|
||||
/// lower and uppercase vectors of first 16 octets of `needle`
|
||||
auto cachel = _mm_setzero_si128();
|
||||
auto cacheu = _mm_setzero_si128();
|
||||
int cachemask = 0;
|
||||
UInt8 l{};
|
||||
UInt8 u{};
|
||||
/// vectors filled with `l` and `u`, for determining leftmost position of the first symbol
|
||||
__m128i patl, patu;
|
||||
/// lower and uppercase vectors of first 16 characters of `needle`
|
||||
__m128i cachel = _mm_setzero_si128(), cacheu = _mm_setzero_si128();
|
||||
int cachemask{};
|
||||
std::size_t cache_valid_len{};
|
||||
std::size_t cache_actual_len{};
|
||||
|
||||
const auto n = sizeof(cachel);
|
||||
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
|
||||
const auto needle_end = needle_begin + needle.size();
|
||||
auto needle_pos = needle_begin;
|
||||
|
||||
for (std::size_t i = 0; i < n;)
|
||||
bool page_safe(const void * const ptr) const
|
||||
{
|
||||
if (needle_pos == needle_end)
|
||||
{
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
++i;
|
||||
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
|
||||
}
|
||||
|
||||
continue;
|
||||
public:
|
||||
CaseInsensitiveSearcher(const std::string & needle) : needle(needle)
|
||||
{
|
||||
if (needle.empty())
|
||||
return;
|
||||
|
||||
static const Poco::UTF8Encoding utf8;
|
||||
UTF8SequenceBuffer l_seq, u_seq;
|
||||
|
||||
auto needle_pos = reinterpret_cast<const UInt8 *>(needle.data());
|
||||
if (*needle_pos < 0x80u)
|
||||
{
|
||||
first_needle_symbol_is_ascii = true;
|
||||
l = std::tolower(*needle_pos);
|
||||
u = std::toupper(*needle_pos);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto first_u32 = utf8.convert(needle_pos);
|
||||
const auto first_l_u32 = Poco::Unicode::toLower(first_u32);
|
||||
const auto first_u_u32 = Poco::Unicode::toUpper(first_u32);
|
||||
|
||||
/// lower and uppercase variants of the first octet of the first character in `needle`
|
||||
utf8.convert(first_l_u32, l_seq, sizeof(l_seq));
|
||||
l = l_seq[0];
|
||||
utf8.convert(first_u_u32, u_seq, sizeof(u_seq));
|
||||
u = u_seq[0];
|
||||
}
|
||||
|
||||
const auto src_len = utf8_seq_length(*needle_pos);
|
||||
const auto c_u32 = utf8.convert(needle_pos);
|
||||
/// for detecting leftmost position of the first symbol
|
||||
patl = _mm_set1_epi8(l);
|
||||
patu = _mm_set1_epi8(u);
|
||||
/// lower and uppercase vectors of first 16 octets of `needle`
|
||||
|
||||
const auto c_l_u32 = Poco::Unicode::toLower(c_u32);
|
||||
const auto c_u_u32 = Poco::Unicode::toUpper(c_u32);
|
||||
const auto needle_end = needle_pos + needle.size();
|
||||
|
||||
const auto dst_l_len = static_cast<UInt8>(utf8.convert(c_l_u32, l_seq, sizeof(l_seq)));
|
||||
const auto dst_u_len = static_cast<UInt8>(utf8.convert(c_u_u32, u_seq, sizeof(u_seq)));
|
||||
|
||||
/// @note Unicode standard states it is a rare but possible occasion
|
||||
if (!(dst_l_len == dst_u_len && dst_u_len == src_len))
|
||||
throw Exception{
|
||||
"UTF8 sequences with different lowercase and uppercase lengths are not supported",
|
||||
ErrorCodes::UNSUPPORTED_PARAMETER
|
||||
};
|
||||
|
||||
cache_actual_len += src_len;
|
||||
if (cache_actual_len < n)
|
||||
cache_valid_len += src_len;
|
||||
|
||||
for (std::size_t j = 0; j < src_len && i < n; ++j, ++i)
|
||||
for (std::size_t i = 0; i < n;)
|
||||
{
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
|
||||
if (needle_pos != needle_end)
|
||||
if (needle_pos == needle_end)
|
||||
{
|
||||
cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1);
|
||||
cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1);
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
++i;
|
||||
|
||||
cachemask |= 1 << i;
|
||||
++needle_pos;
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto src_len = utf8_seq_length(*needle_pos);
|
||||
const auto c_u32 = utf8.convert(needle_pos);
|
||||
|
||||
const auto c_l_u32 = Poco::Unicode::toLower(c_u32);
|
||||
const auto c_u_u32 = Poco::Unicode::toUpper(c_u32);
|
||||
|
||||
const auto dst_l_len = static_cast<UInt8>(utf8.convert(c_l_u32, l_seq, sizeof(l_seq)));
|
||||
const auto dst_u_len = static_cast<UInt8>(utf8.convert(c_u_u32, u_seq, sizeof(u_seq)));
|
||||
|
||||
/// @note Unicode standard states it is a rare but possible occasion
|
||||
if (!(dst_l_len == dst_u_len && dst_u_len == src_len))
|
||||
throw Exception{
|
||||
"UTF8 sequences with different lowercase and uppercase lengths are not supported",
|
||||
ErrorCodes::UNSUPPORTED_PARAMETER
|
||||
};
|
||||
|
||||
cache_actual_len += src_len;
|
||||
if (cache_actual_len < n)
|
||||
cache_valid_len += src_len;
|
||||
|
||||
for (std::size_t j = 0; j < src_len && i < n; ++j, ++i)
|
||||
{
|
||||
cachel = _mm_srli_si128(cachel, 1);
|
||||
cacheu = _mm_srli_si128(cacheu, 1);
|
||||
|
||||
if (needle_pos != needle_end)
|
||||
{
|
||||
cachel = _mm_insert_epi8(cachel, l_seq[j], n - 1);
|
||||
cacheu = _mm_insert_epi8(cacheu, u_seq[j], n - 1);
|
||||
|
||||
cachemask |= 1 << i;
|
||||
++needle_pos;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto page_size = getpagesize();
|
||||
const auto page_safe = [&] (const void * const ptr) {
|
||||
return ((page_size - 1) & reinterpret_cast<std::uintptr_t>(ptr)) <= page_size - n;
|
||||
};
|
||||
|
||||
const auto find_ci = [&] (const UInt8 * haystack, const UInt8 * const haystack_end) {
|
||||
if (needle_begin == needle_end)
|
||||
const UInt8 * find(const UInt8 * haystack, const UInt8 * const haystack_end) const
|
||||
{
|
||||
if (needle.empty())
|
||||
return haystack;
|
||||
|
||||
static const Poco::UTF8Encoding utf8;
|
||||
|
||||
const auto needle_begin = reinterpret_cast<const UInt8 *>(needle.data());
|
||||
const auto needle_end = needle_begin + needle.size();
|
||||
|
||||
while (haystack < haystack_end)
|
||||
{
|
||||
if (haystack + n <= haystack_end && page_safe(haystack))
|
||||
@ -494,19 +554,19 @@ struct PositionCaseInsensitiveUTF8Impl
|
||||
{
|
||||
if (mask == cachemask)
|
||||
{
|
||||
auto s1 = haystack + cache_valid_len;
|
||||
auto s2 = needle_begin + cache_valid_len;
|
||||
auto haystack_pos = haystack + cache_valid_len;
|
||||
auto needle_pos = needle_begin + cache_valid_len;
|
||||
|
||||
while (s1 < haystack_end && s2 < needle_end &&
|
||||
Poco::Unicode::toLower(utf8.convert(s1)) ==
|
||||
Poco::Unicode::toLower(utf8.convert(reinterpret_cast<const UInt8 *>(s2))))
|
||||
while (haystack_pos < haystack_end && needle_pos < needle_end &&
|
||||
Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
|
||||
Poco::Unicode::toLower(utf8.convert(needle_pos)))
|
||||
{
|
||||
/// @note assuming sequences for lowercase and uppercase have exact same length
|
||||
const auto len = utf8_seq_length(*s1);
|
||||
s1 += len, s2 += len;
|
||||
const auto len = utf8_seq_length(*haystack_pos);
|
||||
haystack_pos += len, needle_pos += len;
|
||||
}
|
||||
|
||||
if (s2 == needle_end)
|
||||
if (needle_pos == needle_end)
|
||||
return haystack;
|
||||
}
|
||||
}
|
||||
@ -524,18 +584,18 @@ struct PositionCaseInsensitiveUTF8Impl
|
||||
|
||||
if (*haystack == l || *haystack == u)
|
||||
{
|
||||
auto s1 = haystack;
|
||||
auto s2 = needle_begin;
|
||||
auto haystack_pos = haystack + first_needle_symbol_is_ascii;
|
||||
auto needle_pos = needle_begin + first_needle_symbol_is_ascii;
|
||||
|
||||
while (s1 < haystack_end && s2 < needle_end &&
|
||||
Poco::Unicode::toLower(utf8.convert(s1)) ==
|
||||
Poco::Unicode::toLower(utf8.convert(s2)))
|
||||
while (haystack_pos < haystack_end && needle_pos < needle_end &&
|
||||
Poco::Unicode::toLower(utf8.convert(haystack_pos)) ==
|
||||
Poco::Unicode::toLower(utf8.convert(needle_pos)))
|
||||
{
|
||||
const auto len = utf8_seq_length(*s1);
|
||||
s1 += len, s2 += len;
|
||||
const auto len = utf8_seq_length(*haystack_pos);
|
||||
haystack_pos += len, needle_pos += len;
|
||||
}
|
||||
|
||||
if (s2 == needle_end)
|
||||
if (needle_pos == needle_end)
|
||||
return haystack;
|
||||
}
|
||||
|
||||
@ -544,7 +604,17 @@ struct PositionCaseInsensitiveUTF8Impl
|
||||
}
|
||||
|
||||
return haystack_end;
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
using ResultType = UInt64;
|
||||
|
||||
static void vector(
|
||||
const ColumnString::Chars_t & data, const ColumnString::Offsets_t & offsets, const std::string & needle,
|
||||
PODArray<UInt64> & res)
|
||||
{
|
||||
const CaseInsensitiveSearcher searcher{needle};
|
||||
|
||||
const UInt8 * begin = &data[0];
|
||||
const UInt8 * pos = begin;
|
||||
@ -554,7 +624,7 @@ struct PositionCaseInsensitiveUTF8Impl
|
||||
size_t i = 0;
|
||||
|
||||
/// Искать будем следующее вхождение сразу во всех строках.
|
||||
while (pos < end && end != (pos = find_ci(pos, end)))
|
||||
while (pos < end && end != (pos = searcher.find(pos, end)))
|
||||
{
|
||||
/// Определим, к какому индексу оно относится.
|
||||
while (begin + offsets[i] < pos)
|
||||
|
Loading…
Reference in New Issue
Block a user