Updated implementation of find_first_symbols for run-time needle

Implemented compile-time and run-time dispatching between SSE4.2 and SSE2 implementations
Added find_first_symbols_sse2
Added tests
This commit is contained in:
Vasily Nemkov 2023-03-24 15:45:12 +01:00
parent 487296c1c3
commit e22eee8055
6 changed files with 204 additions and 241 deletions

View File

@ -40,7 +40,7 @@ template <char ...chars> constexpr bool is_in(char x) { return ((x == chars) ||
static bool is_in(char c, const char * symbols, size_t num_chars) static bool is_in(char c, const char * symbols, size_t num_chars)
{ {
for (auto i = 0u; i < num_chars; i++) for (size_t i = 0u; i < num_chars; ++i)
{ {
if (c == symbols[i]) if (c == symbols[i])
{ {
@ -66,6 +66,43 @@ inline __m128i mm_is_in(__m128i bytes)
__m128i eq = mm_is_in<s1, tail...>(bytes); __m128i eq = mm_is_in<s1, tail...>(bytes);
return _mm_or_si128(eq0, eq); return _mm_or_si128(eq0, eq);
} }
inline __m128i mm_is_in(__m128i bytes, const char * symbols, size_t num_chars)
{
__m128i accumulator = _mm_setzero_si128();
for (size_t i = 0; i < num_chars; ++i)
{
__m128i eq = _mm_cmpeq_epi8(bytes, _mm_set1_epi8(symbols[i]));
accumulator = _mm_or_si128(accumulator, eq);
}
return accumulator;
}
inline std::vector<__m128i> mm_is_in_prepare(const char * symbols, size_t num_chars)
{
std::vector<__m128i> result;
result.reserve(num_chars);
for (size_t i = 0; i < num_chars; ++i)
{
result.emplace_back(_mm_set1_epi8(symbols[i]));
}
return result;
}
inline __m128i mm_is_in_execute(__m128i bytes, const std::vector<__m128i> & needles)
{
__m128i accumulator = _mm_setzero_si128();
for (const auto & needle : needles)
{
__m128i eq = _mm_cmpeq_epi8(bytes, needle);
accumulator = _mm_or_si128(accumulator, eq);
}
return accumulator;
}
#endif #endif
template <bool positive> template <bool positive>
@ -112,6 +149,32 @@ inline const char * find_first_symbols_sse2(const char * const begin, const char
return return_mode == ReturnMode::End ? end : nullptr; return return_mode == ReturnMode::End ? end : nullptr;
} }
template <bool positive, ReturnMode return_mode>
inline const char * find_first_symbols_sse2(const char * const begin, const char * const end, const char * symbols, size_t num_chars)
{
const char * pos = begin;
const auto needles = mm_is_in_prepare(symbols, num_chars);
#if defined(__SSE2__)
for (; pos + 15 < end; pos += 16)
{
__m128i bytes = _mm_loadu_si128(reinterpret_cast<const __m128i *>(pos));
__m128i eq = mm_is_in_execute(bytes, needles);
uint16_t bit_mask = maybe_negate<positive>(uint16_t(_mm_movemask_epi8(eq)));
if (bit_mask)
return pos + __builtin_ctz(bit_mask);
}
#endif
for (; pos < end; ++pos)
if (maybe_negate<positive>(is_in(*pos, symbols, num_chars)))
return pos;
return return_mode == ReturnMode::End ? end : nullptr;
}
template <bool positive, ReturnMode return_mode, char... symbols> template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_last_symbols_sse2(const char * const begin, const char * const end) inline const char * find_last_symbols_sse2(const char * const begin, const char * const end)
@ -192,21 +255,6 @@ inline const char * find_first_symbols_sse42(const char * const begin, const cha
return return_mode == ReturnMode::End ? end : nullptr; return return_mode == ReturnMode::End ? end : nullptr;
} }
/// NOTE No SSE 4.2 implementation for find_last_symbols_or_null. Not worth to do.
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_first_symbols_dispatch(const char * begin, const char * end)
requires(0 <= sizeof...(symbols) && sizeof...(symbols) <= 16)
{
#if defined(__SSE4_2__)
if (sizeof...(symbols) >= 5)
return find_first_symbols_sse42<positive, return_mode, sizeof...(symbols), symbols...>(begin, end);
else
#endif
return find_first_symbols_sse2<positive, return_mode, symbols...>(begin, end);
}
template <bool positive, ReturnMode return_mode> template <bool positive, ReturnMode return_mode>
inline const char * find_first_symbols_sse42(const char * const begin, const char * const end, const char * symbols, size_t num_chars) inline const char * find_first_symbols_sse42(const char * const begin, const char * const end, const char * symbols, size_t num_chars)
{ {
@ -241,10 +289,30 @@ inline const char * find_first_symbols_sse42(const char * const begin, const cha
return return_mode == ReturnMode::End ? end : nullptr; return return_mode == ReturnMode::End ? end : nullptr;
} }
template <bool positive, ReturnMode return_mode> /// NOTE No SSE 4.2 implementation for find_last_symbols_or_null. Not worth to do.
auto find_first_symbols_sse42(std::string_view haystack, std::string_view symbols)
template <bool positive, ReturnMode return_mode, char... symbols>
inline const char * find_first_symbols_dispatch(const char * begin, const char * end)
requires(0 <= sizeof...(symbols) && sizeof...(symbols) <= 16)
{ {
return find_first_symbols_sse42<positive, return_mode>(haystack.begin(), haystack.end(), symbols.begin(), symbols.size()); #if defined(__SSE4_2__)
if (sizeof...(symbols) >= 5)
return find_first_symbols_sse42<positive, return_mode, sizeof...(symbols), symbols...>(begin, end);
else
#endif
return find_first_symbols_sse2<positive, return_mode, symbols...>(begin, end);
}
template <bool positive, ReturnMode return_mode>
inline const char * find_first_symbols_dispatch(const std::string_view haystack, const std::string_view symbols)
{
const size_t num_chars = std::max<size_t>(symbols.size(), 16);
#if defined(__SSE4_2__)
if (num_chars >= 5)
return find_first_symbols_sse42<positive, return_mode>(haystack.begin(), haystack.end(), symbols.begin(), num_chars);
else
#endif
return find_first_symbols_sse2<positive, return_mode>(haystack.begin(), haystack.end(), symbols.begin(), num_chars);
} }
} }
@ -266,7 +334,7 @@ inline char * find_first_symbols(char * begin, char * end)
inline const char * find_first_symbols(std::string_view haystack, std::string_view symbols) inline const char * find_first_symbols(std::string_view haystack, std::string_view symbols)
{ {
return detail::find_first_symbols_sse42<true, detail::ReturnMode::End>(haystack, symbols); return detail::find_first_symbols_dispatch<true, detail::ReturnMode::End>(haystack, symbols);
} }
template <char... symbols> template <char... symbols>
@ -283,7 +351,7 @@ inline char * find_first_not_symbols(char * begin, char * end)
inline const char * find_first_not_symbols(std::string_view haystack, std::string_view symbols) inline const char * find_first_not_symbols(std::string_view haystack, std::string_view symbols)
{ {
return detail::find_first_symbols_sse42<false, detail::ReturnMode::End>(haystack, symbols); return detail::find_first_symbols_dispatch<false, detail::ReturnMode::End>(haystack, symbols);
} }
template <char... symbols> template <char... symbols>
@ -300,7 +368,7 @@ inline char * find_first_symbols_or_null(char * begin, char * end)
inline const char * find_first_symbols_or_null(std::string_view haystack, std::string_view symbols) inline const char * find_first_symbols_or_null(std::string_view haystack, std::string_view symbols)
{ {
return detail::find_first_symbols_sse42<true, detail::ReturnMode::Nullptr>(haystack, symbols); return detail::find_first_symbols_dispatch<true, detail::ReturnMode::Nullptr>(haystack, symbols);
} }
template <char... symbols> template <char... symbols>
@ -317,7 +385,7 @@ inline char * find_first_not_symbols_or_null(char * begin, char * end)
inline const char * find_first_not_symbols_or_null(std::string_view haystack, std::string_view symbols) inline const char * find_first_not_symbols_or_null(std::string_view haystack, std::string_view symbols)
{ {
return detail::find_first_symbols_sse42<false, detail::ReturnMode::Nullptr>(haystack, symbols); return detail::find_first_symbols_dispatch<false, detail::ReturnMode::Nullptr>(haystack, symbols);
} }
template <char... symbols> template <char... symbols>

View File

@ -23,7 +23,7 @@ void test_find_first_not(const std::string & haystack, const std::string & symbo
TEST(FindSymbols, SimpleTest) TEST(FindSymbols, SimpleTest)
{ {
std::string s = "Hello, world! Goodbye..."; const std::string s = "Hello, world! Goodbye...";
const char * begin = s.data(); const char * begin = s.data();
const char * end = s.data() + s.size(); const char * end = s.data() + s.size();
@ -34,6 +34,14 @@ TEST(FindSymbols, SimpleTest)
ASSERT_EQ(find_first_symbols<'H'>(begin, end), begin); ASSERT_EQ(find_first_symbols<'H'>(begin, end), begin);
ASSERT_EQ((find_first_symbols<'a', 'e'>(begin, end)), begin + 1); ASSERT_EQ((find_first_symbols<'a', 'e'>(begin, end)), begin + 1);
// Check that nothing matches on big haystack,
ASSERT_EQ(find_first_symbols(s, "ABCDEFIJKLMNOPQRSTUVWXYZacfghijkmnpqstuvxz"), end);
// only 16 bytes of haystack are checked, so nothing is found
ASSERT_EQ(find_first_symbols(s, "ABCDEFIJKLMNOPQR0helloworld"), end);
// 16-byte needle
ASSERT_EQ(find_first_symbols(s, "XYZ!,.GHbdelorwy"), begin + 12);
ASSERT_EQ(find_last_symbols_or_null<'a'>(begin, end), nullptr); ASSERT_EQ(find_last_symbols_or_null<'a'>(begin, end), nullptr);
ASSERT_EQ(find_last_symbols_or_null<'e'>(begin, end), end - 4); ASSERT_EQ(find_last_symbols_or_null<'e'>(begin, end), end - 4);
ASSERT_EQ(find_last_symbols_or_null<'.'>(begin, end), end - 1); ASSERT_EQ(find_last_symbols_or_null<'.'>(begin, end), end - 1);
@ -54,6 +62,79 @@ TEST(FindSymbols, SimpleTest)
} }
} }
TEST(FindSymbols, RunTimeNeedle)
{
auto test_haystack = [](const auto & haystack, const auto & unfindable_needle) {
#define TEST_HAYSTACK_AND_NEEDLE(haystack_, needle_) \
do { \
const auto & h = haystack_; \
const auto & n = needle_; \
EXPECT_EQ( \
std::find_first_of(h.data(), h.data() + h.size(), n.data(), n.data() + n.size()), \
find_first_symbols(h, n) \
) << "haystack: \"" << h << "\"" \
<< ", needle: \"" << n << "\""; \
} \
while (false)
// can't find
TEST_HAYSTACK_AND_NEEDLE(haystack, unfindable_needle);
#define test_with_modified_needle(haystack, in_needle, needle_update, with) \
do \
{ \
std::string needle = in_needle; \
needle_update = with; \
TEST_HAYSTACK_AND_NEEDLE(haystack, needle); \
} \
while (false)
// findable symbol is at beginnig of the needle
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.front(), haystack.front());
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.front(), haystack.back());
// Can find in the middle of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.front(), haystack[haystack.size() / 2]);
// findable symbol is at end of the needle
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.back(), haystack.front());
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.back(), haystack.back());
// Can find in the middle of haystack
test_with_modified_needle(haystack, unfindable_needle, needle.back(), haystack[haystack.size() / 2]);
// findable symbol is in the middle of the needle
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle[needle.size() / 2], haystack.front());
// Can find at first pos of haystack
test_with_modified_needle(haystack, unfindable_needle, needle[needle.size() / 2], haystack.back());
// Can find in the middle of haystack
test_with_modified_needle(haystack, unfindable_needle, needle[needle.size() / 2], haystack[haystack.size() / 2]);
};
// there are 4 major groups of cases:
// haystack < 16 bytes, haystack > 16 bytes
// needle < 5 bytes, needle >= 5 bytes
// First and last symbols of haystack should be unique
const std::string long_haystack = "Hello, world! Goodbye...?";
const std::string short_haystack = "Hello, world!";
// In sync with find_first_symbols_dispatch code: long needles receve special treatment.
// as of now "long" means >= 5
const std::string unfindable_long_needle = "0123456789ABCDEF";
const std::string unfindable_short_needle = "0123";
test_haystack(long_haystack, unfindable_long_needle);
test_haystack(long_haystack, unfindable_short_needle);
test_haystack(short_haystack, unfindable_long_needle);
test_haystack(short_haystack, unfindable_short_needle);
}
TEST(FindNotSymbols, AllSymbolsPresent) TEST(FindNotSymbols, AllSymbolsPresent)
{ {
std::string str_with_17_bytes = "hello world hello"; std::string str_with_17_bytes = "hello world hello";

View File

@ -1,183 +0,0 @@
#include "InlineEscapingKeyStateHandler.h"
#include <Functions/keyvaluepair/src/impl/state/strategies/util/CharacterFinder.h>
#include <Functions/keyvaluepair/src/impl/state/strategies/util/EscapedCharacterReader.h>
#include <Functions/keyvaluepair/src/impl/state/strategies/util/NeedleFactory.h>
namespace DB
{
InlineEscapingKeyStateHandler::InlineEscapingKeyStateHandler(Configuration configuration_)
: extractor_configuration(std::move(configuration_))
{
wait_needles = EscapingNeedleFactory::getWaitNeedles(extractor_configuration);
read_needles = EscapingNeedleFactory::getReadNeedles(extractor_configuration);
read_quoted_needles = EscapingNeedleFactory::getReadQuotedNeedles(extractor_configuration);
}
NextState InlineEscapingKeyStateHandler::wait(std::string_view file, size_t pos) const
{
BoundsSafeCharacterFinder finder;
const auto quoting_character = extractor_configuration.quoting_character;
while (auto character_position_opt = finder.findFirstNot(file, pos, wait_needles))
{
auto character_position = *character_position_opt;
auto character = file[character_position];
if (quoting_character == character)
{
return {character_position + 1u, State::READING_QUOTED_KEY};
}
else
{
return {character_position, State::READING_KEY};
}
}
return {file.size(), State::END};
}
/*
* I only need to iteratively copy stuff if there are escape sequences. If not, views are sufficient.
* TSKV has a nice catch for that, implementers kept an auxiliary string to hold copied characters.
* If I find a key value delimiter and that is empty, I do not need to copy? hm,m hm hm
* */
NextState InlineEscapingKeyStateHandler::read(std::string_view file, size_t pos, ElementType & key) const
{
BoundsSafeCharacterFinder finder;
const auto & [key_value_delimiter, quoting_character, pair_delimiters]
= extractor_configuration;
key.clear();
/*
* Maybe modify finder return type to be the actual pos. In case of failures, it shall return pointer to the end.
* It might help updating current pos?
* */
while (auto character_position_opt = finder.findFirst(file, pos, read_needles))
{
auto character_position = *character_position_opt;
auto character = file[character_position];
auto next_pos = character_position + 1u;
if (EscapedCharacterReader::ESCAPE_CHARACTER == character)
{
for (auto i = pos; i < character_position; i++)
{
key.push_back(file[i]);
}
auto [next_byte_ptr, escaped_characters] = EscapedCharacterReader::read(file, character_position);
next_pos = next_byte_ptr - file.begin();
if (escaped_characters.empty())
{
return {next_pos, State::WAITING_KEY};
}
else
{
for (auto escaped_character : escaped_characters)
{
key.push_back(escaped_character);
}
}
}
else if (character == key_value_delimiter)
{
// todo try to optimize with resize and memcpy
for (auto i = pos; i < character_position; i++)
{
key.push_back(file[i]);
}
return {next_pos, State::WAITING_VALUE};
}
else if (std::find(pair_delimiters.begin(), pair_delimiters.end(), character) != pair_delimiters.end())
{
return {next_pos, State::WAITING_KEY};
}
pos = next_pos;
}
// might be problematic in case string reaches the end and I haven't copied anything over to key
return {file.size(), State::END};
}
NextState InlineEscapingKeyStateHandler::readQuoted(std::string_view file, size_t pos, ElementType & key) const
{
BoundsSafeCharacterFinder finder;
const auto quoting_character = extractor_configuration.quoting_character;
key.clear();
while (auto character_position_opt = finder.findFirst(file, pos, read_quoted_needles))
{
auto character_position = *character_position_opt;
auto character = file[character_position];
auto next_pos = character_position + 1u;
if (character == EscapedCharacterReader::ESCAPE_CHARACTER)
{
for (auto i = pos; i < character_position; i++)
{
key.push_back(file[i]);
}
auto [next_byte_ptr, escaped_characters] = EscapedCharacterReader::read(file, character_position);
next_pos = next_byte_ptr - file.begin();
if (escaped_characters.empty())
{
return {next_pos, State::WAITING_KEY};
}
else
{
for (auto escaped_character : escaped_characters)
{
key.push_back(escaped_character);
}
}
}
else if (quoting_character == character)
{
// todo try to optimize with resize and memcpy
for (auto i = pos; i < character_position; i++)
{
key.push_back(file[i]);
}
if (key.empty())
{
return {next_pos, State::WAITING_KEY};
}
return {next_pos, State::READING_KV_DELIMITER};
}
pos = next_pos;
}
return {file.size(), State::END};
}
NextState InlineEscapingKeyStateHandler::readKeyValueDelimiter(std::string_view file, size_t pos) const
{
if (pos == file.size())
{
return {pos, State::END};
}
else
{
const auto current_character = file[pos++];
return {pos, extractor_configuration.key_value_delimiter == current_character ? State::WAITING_VALUE : State::WAITING_KEY};
}
}
}

View File

@ -1,34 +0,0 @@
#pragma once
#include <optional>
#include <string>
#include <Functions/keyvaluepair/src/impl/state/Configuration.h>
#include <Functions/keyvaluepair/src/impl/state/StateHandler.h>
namespace DB
{
class InlineEscapingKeyStateHandler : public StateHandler
{
public:
using ElementType = std::string;
explicit InlineEscapingKeyStateHandler(Configuration configuration_);
[[nodiscard]] NextState wait(std::string_view file, size_t pos) const;
[[nodiscard]] NextState read(std::string_view file, size_t pos, ElementType & key) const;
[[nodiscard]] NextState readQuoted(std::string_view file, size_t pos, ElementType & key) const;
[[nodiscard]] NextState readKeyValueDelimiter(std::string_view file, size_t pos) const;
private:
Configuration extractor_configuration;
std::vector<char> wait_needles;
std::vector<char> read_needles;
std::vector<char> read_quoted_needles;
};
}

View File

@ -3,6 +3,13 @@
#include <Functions/keyvaluepair/src/impl/state/strategies/util/EscapedCharacterReader.h> #include <Functions/keyvaluepair/src/impl/state/strategies/util/EscapedCharacterReader.h>
#include <Functions/keyvaluepair/src/impl/state/strategies/util/NeedleFactory.h> #include <Functions/keyvaluepair/src/impl/state/strategies/util/NeedleFactory.h>
#include <IO/ReadBufferFromMemory.h>
namespace
{
}
namespace DB namespace DB
{ {

View File

@ -9,6 +9,30 @@
namespace DB namespace DB
{ {
class InlineEscapingKeyStateHandler : public StateHandler
{
public:
using ElementType = std::string;
explicit InlineEscapingKeyStateHandler(Configuration configuration_);
[[nodiscard]] NextState wait(std::string_view file, size_t pos) const;
[[nodiscard]] NextState read(std::string_view file, size_t pos, ElementType & key) const;
[[nodiscard]] NextState readQuoted(std::string_view file, size_t pos, ElementType & key) const;
[[nodiscard]] NextState readKeyValueDelimiter(std::string_view file, size_t pos) const;
private:
Configuration extractor_configuration;
std::vector<char> wait_needles;
std::vector<char> read_needles;
std::vector<char> read_quoted_needles;
};
class InlineEscapingValueStateHandler : public StateHandler class InlineEscapingValueStateHandler : public StateHandler
{ {
public: public: