hasToken function implementation

* Function to check if a given token is present in a string;
* Special case of hasToken for the 'tokenbf_v1' index;
* Test cases for hasToken()
* Test case for hasToken() + 'tokenbf_v1' integration
Vasily Nemkov 2019-08-21 11:12:39 +03:00
parent c38a8cb755
commit 6d78e3be94
9 changed files with 481 additions and 8 deletions

View File

@@ -1,5 +1,7 @@
#pragma once
#include <Common/Exception.h>
#include <Common/StringUtils/StringUtils.h>
#include <Common/UTF8Helpers.h>
#include <Core/Defines.h>
#include <ext/range.h>
@@ -23,6 +25,7 @@ namespace DB
namespace ErrorCodes
{
extern const int UNSUPPORTED_PARAMETER;
extern const int BAD_ARGUMENTS;
}
@@ -157,7 +160,7 @@ public:
#endif
}
ALWAYS_INLINE bool compare(const UInt8 * pos) const
ALWAYS_INLINE bool compare(const UInt8 * /*haystack*/, const UInt8 * /*haystack_end*/, const UInt8 * pos) const
{
static const Poco::UTF8Encoding utf8;
@@ -374,7 +377,7 @@ public:
#endif
}
ALWAYS_INLINE bool compare(const UInt8 * pos) const
ALWAYS_INLINE bool compare(const UInt8 * /*haystack*/, const UInt8 * /*haystack_end*/, const UInt8 * pos) const
{
#ifdef __SSE4_1__
if (pageSafe(pos))
@@ -567,7 +570,7 @@ public:
#endif
}
ALWAYS_INLINE bool compare(const UInt8 * pos) const
ALWAYS_INLINE bool compare(const UInt8 * /*haystack*/, const UInt8 * /*haystack_end*/, const UInt8 * pos) const
{
#ifdef __SSE4_1__
if (pageSafe(pos))
@@ -697,11 +700,81 @@ public:
}
};
// Searches for needle surrounded by token separators.
// A separator is any ASCII byte (0-127) that is not alphanumeric.
// Any byte outside basic ASCII (>= 128) is treated as a non-separator symbol, so UTF-8 strings
// should work just fine. Note, however, that Unicode whitespace is not treated as a token separator.
template <typename StringSearcher>
class TokenSearcher
{
StringSearcher searcher;
size_t needle_size;
public:
TokenSearcher(const char * const needle_, const size_t needle_size_)
: searcher{needle_, needle_size_},
needle_size(needle_size_)
{
if (std::any_of(reinterpret_cast<const UInt8 *>(needle_), reinterpret_cast<const UInt8 *>(needle_) + needle_size_, isTokenSeparator))
{
throw Exception{"needle must not contain token separator characters", ErrorCodes::BAD_ARGUMENTS};
}
}
ALWAYS_INLINE bool compare(const UInt8 * haystack, const UInt8 * haystack_end, const UInt8 * pos) const
{
// Use the searcher only if pos is at the beginning of a token and pos + needle_size is at the end of a token.
if (isToken(haystack, haystack_end, pos))
return searcher.compare(haystack, haystack_end, pos);
return false;
}
const UInt8 * search(const UInt8 * haystack, const UInt8 * const haystack_end) const
{
// Use searcher.search(), then verify that the returned position starts a whole token;
// if it does not, skip it and search again.
const UInt8 * pos = haystack;
while (pos < haystack_end)
{
pos = searcher.search(pos, haystack_end);
if (pos == haystack_end || isToken(haystack, haystack_end, pos))
return pos;
// Assuming that the needle does not contain any token separators.
pos += needle_size;
}
return haystack_end;
}
const UInt8 * search(const UInt8 * haystack, const size_t haystack_size) const
{
return search(haystack, haystack + haystack_size);
}
ALWAYS_INLINE bool isToken(const UInt8 * haystack, const UInt8 * const haystack_end, const UInt8* p) const
{
return (p == haystack || isTokenSeparator(*(p - 1)))
&& (p + needle_size >= haystack_end || isTokenSeparator(*(p + needle_size)));
}
ALWAYS_INLINE static bool isTokenSeparator(const UInt8 c)
{
if (isAlphaNumericASCII(c) || !isASCII(c))
return false;
return true;
}
};
using ASCIICaseSensitiveStringSearcher = StringSearcher<true, true>;
using ASCIICaseInsensitiveStringSearcher = StringSearcher<false, true>;
using UTF8CaseSensitiveStringSearcher = StringSearcher<true, false>;
using UTF8CaseInsensitiveStringSearcher = StringSearcher<false, false>;
using ASCIICaseSensitiveTokenSearcher = TokenSearcher<ASCIICaseSensitiveStringSearcher>;
/** Uses functions from libc.
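
To make the boundary rules above concrete, here is a minimal standalone sketch of the separator test and the token-boundary check performed by TokenSearcher::isToken. The helper names and the plain main() are illustrative only, not the ClickHouse helpers from StringUtils:

#include <cstdint>
#include <cstring>
#include <iostream>

// Sketch of the separator rule: a byte is a token separator only if it is ASCII
// (< 0x80) and not alphanumeric. Bytes >= 0x80 (e.g. UTF-8 sequences) never separate.
static bool isTokenSeparator(uint8_t c)
{
    const bool is_ascii = c < 0x80;
    const bool is_alnum = (c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
    return is_ascii && !is_alnum;
}

// A candidate match at pos of length needle_size counts as a token only when it is
// bounded by the string ends or by separators on both sides, mirroring isToken above.
static bool isToken(const uint8_t * begin, const uint8_t * end, const uint8_t * pos, size_t needle_size)
{
    return (pos == begin || isTokenSeparator(*(pos - 1)))
        && (pos + needle_size >= end || isTokenSeparator(*(pos + needle_size)));
}

int main()
{
    const char * s = "hay,needle,hay";
    const auto * b = reinterpret_cast<const uint8_t *>(s);
    const auto * e = b + std::strlen(s);
    std::cout << isToken(b, e, b + 4, 6) << '\n';  // 1: "needle" is delimited by commas
    std::cout << isToken(b, e, b, 3) << '\n';      // 1: "hay" starts at the string boundary
    std::cout << isToken(b, e, b + 5, 5) << '\n';  // 0: "eedle" sits inside a larger token
}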

View File

@@ -327,6 +327,8 @@ protected:
FallbackSearcher fallback_searcher;
public:
using Searcher = FallbackSearcher;
/** haystack_size_hint - the expected total size of the haystack for `search` calls. Optional (zero means unspecified).
* If you specify it small enough, the fallback algorithm will be used,
* since initializing the hash table is then considered a waste of time.
@@ -373,7 +375,7 @@ public:
const auto res = pos - (hash[cell_num] - 1);
/// The pointer in the code always comes from a padded array, so we can use page-safe semantics.
if (fallback_searcher.compare(res))
if (fallback_searcher.compare(haystack, haystack_end, res))
return res;
}
}
@@ -520,7 +522,7 @@ public:
{
const auto res = pos - (hash[cell_num].off - 1);
const size_t ind = hash[cell_num].id;
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(res))
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(haystack, haystack_end, res))
return true;
}
}
@@ -552,7 +554,7 @@ public:
{
const auto res = pos - (hash[cell_num].off - 1);
const size_t ind = hash[cell_num].id;
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(res))
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(haystack, haystack_end, res))
ans = std::min(ans, ind);
}
}
@@ -590,7 +592,7 @@ public:
{
const auto res = pos - (hash[cell_num].off - 1);
const size_t ind = hash[cell_num].id;
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(res))
if (res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(haystack, haystack_end, res))
ans = std::min<UInt64>(ans, res - haystack);
}
}
@@ -625,7 +627,7 @@ public:
{
const auto * res = pos - (hash[cell_num].off - 1);
const size_t ind = hash[cell_num].id;
if (ans[ind] == 0 && res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(res))
if (ans[ind] == 0 && res + needles[ind].size <= haystack_end && fallback_searchers[ind].compare(haystack, haystack_end, res))
ans[ind] = count_chars(haystack, res);
}
}
@@ -650,6 +652,8 @@ using VolnitskyUTF8 = VolnitskyBase<true, false, ASCIICaseSensitiveStringSearche
using VolnitskyCaseInsensitive = VolnitskyBase<false, true, ASCIICaseInsensitiveStringSearcher>; /// ignores non-ASCII bytes
using VolnitskyCaseInsensitiveUTF8 = VolnitskyBase<false, false, UTF8CaseInsensitiveStringSearcher>;
using VolnitskyToken = VolnitskyBase<true, true, ASCIICaseSensitiveTokenSearcher>;
using MultiVolnitsky = MultiVolnitskyBase<true, true, ASCIICaseSensitiveStringSearcher>;
using MultiVolnitskyUTF8 = MultiVolnitskyBase<true, false, ASCIICaseSensitiveStringSearcher>;
using MultiVolnitskyCaseInsensitive = MultiVolnitskyBase<false, true, ASCIICaseInsensitiveStringSearcher>;
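
The widened compare(haystack, haystack_end, pos) signature threaded through the hunks above is what lets a token-aware searcher be dropped into VolnitskyBase: the final verification step now sees the haystack bounds and can reject candidates that are not whole tokens. A rough standalone sketch of that delegation follows; the class names are made up and stand in for the real searchers:

#include <cctype>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string_view>

// Inner searcher that only verifies a candidate position (a stand-in for StringSearcher).
struct NaiveCompare
{
    std::string_view needle;
    bool compare(const uint8_t *, const uint8_t *, const uint8_t * pos) const
    {
        return std::memcmp(pos, needle.data(), needle.size()) == 0;
    }
};

// Decorator in the spirit of TokenSearcher: it needs haystack/haystack_end, which is
// exactly why the compare() signature gained those parameters.
struct TokenCompare
{
    NaiveCompare inner;

    static bool isSep(uint8_t c) { return c < 0x80 && !std::isalnum(c); }

    bool compare(const uint8_t * haystack, const uint8_t * haystack_end, const uint8_t * pos) const
    {
        const size_t n = inner.needle.size();
        return (pos == haystack || isSep(pos[-1]))
            && (pos + n >= haystack_end || isSep(pos[n]))
            && inner.compare(haystack, haystack_end, pos);
    }
};

int main()
{
    std::string_view h = "hay needlework needle";
    const auto * b = reinterpret_cast<const uint8_t *>(h.data());
    TokenCompare t{NaiveCompare{"needle"}};
    std::cout << t.compare(b, b + h.size(), b + 4) << '\n';   // 0: "needle" inside "needlework"
    std::cout << t.compare(b, b + h.size(), b + 15) << '\n';  // 1: a standalone trailing token
}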

View File

@@ -434,6 +434,74 @@ struct MultiSearchFirstIndexImpl
}
};
/** Token search in a string: the needle must be surrounded by separator characters, such as whitespace or punctuation.
*/
template <bool negate_result = false>
struct HasTokenImpl
{
using ResultType = UInt8;
static void vector_constant(
const ColumnString::Chars & data, const ColumnString::Offsets & offsets, const std::string & pattern, PaddedPODArray<UInt8> & res)
{
if (offsets.empty())
return;
const UInt8 * begin = data.data();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// The current index in the array of strings.
size_t i = 0;
VolnitskyToken searcher(pattern.data(), pattern.size(), end - pos);
/// We will search for the next occurrence in all rows at once.
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Let's determine which index it refers to.
while (begin + offsets[i] <= pos)
{
res[i] = negate_result;
++i;
}
/// Check that the match does not cross the boundaries of strings.
if (pos + pattern.size() < begin + offsets[i])
res[i] = !negate_result;
else
res[i] = negate_result;
pos = begin + offsets[i];
++i;
}
/// The tail, in which there can be no further matches.
if (i < res.size())
memset(&res[i], negate_result, (res.size() - i) * sizeof(res[0]));
}
static void constant_constant(const std::string & data, const std::string & pattern, UInt8 & res)
{
VolnitskyToken searcher(pattern.data(), pattern.size(), data.size());
const auto found = searcher.search(data.c_str(), data.size()) != data.end().base();
res = negate_result ^ found;
}
template <typename... Args>
static void vector_vector(Args &&...)
{
throw Exception("Function 'hasToken' does not support non-constant needle argument", ErrorCodes::ILLEGAL_COLUMN);
}
/// Search for different needles in a single haystack.
template <typename... Args>
static void constant_vector(Args &&...)
{
throw Exception("Function 'hasToken' does not support non-constant needle argument", ErrorCodes::ILLEGAL_COLUMN);
}
};
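
HasTokenImpl::vector_constant above operates on the flattened ColumnString layout: all row values are concatenated in data, each followed by a terminating zero byte, and offsets[i] points just past row i. The sketch below reproduces only that bookkeeping, with a plain substring search standing in for VolnitskyToken; the names and toy data are illustrative:

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> rows = {"hay needle hay", "no match here", "needle"};

    // ColumnString-like layout: rows concatenated, each followed by a zero byte;
    // offsets[i] points just past row i (including that zero byte).
    std::string data;
    std::vector<size_t> offsets;
    for (const auto & r : rows)
    {
        data += r;
        data += '\0';
        offsets.push_back(data.size());
    }

    const std::string pattern = "needle";
    std::vector<uint8_t> res(rows.size(), 0);   // rows without a hit keep 0, like the memset tail

    size_t i = 0;
    size_t pos = 0;
    while ((pos = data.find(pattern, pos)) != std::string::npos)
    {
        while (offsets[i] <= pos)               // find the row that contains this position
            ++i;
        if (pos + pattern.size() < offsets[i])  // the match must not cross the row end
            res[i] = 1;
        pos = offsets[i];                       // resume the search from the next row
        ++i;
    }

    for (size_t k = 0; k < res.size(); ++k)
        std::cout << "row " << k << ": " << int(res[k]) << '\n';   // prints 1, 0, 1
}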
struct NamePosition
{
@@ -516,6 +584,11 @@ struct NameMultiSearchFirstPositionCaseInsensitiveUTF8
static constexpr auto name = "multiSearchFirstPositionCaseInsensitiveUTF8";
};
struct NameHasToken
{
static constexpr auto name = "hasToken";
};
using FunctionPosition = FunctionsStringSearch<PositionImpl<PositionCaseSensitiveASCII>, NamePosition>;
using FunctionPositionUTF8 = FunctionsStringSearch<PositionImpl<PositionCaseSensitiveUTF8>, NamePositionUTF8>;
@@ -542,6 +615,7 @@ using FunctionMultiSearchFirstPositionUTF8 = FunctionsMultiStringSearch<MultiSea
using FunctionMultiSearchFirstPositionCaseInsensitive = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseInsensitiveASCII>, NameMultiSearchFirstPositionCaseInsensitive>;
using FunctionMultiSearchFirstPositionCaseInsensitiveUTF8 = FunctionsMultiStringSearch<MultiSearchFirstPositionImpl<PositionCaseInsensitiveUTF8>, NameMultiSearchFirstPositionCaseInsensitiveUTF8>;
using FunctionHasToken = FunctionsStringSearch<HasTokenImpl<false>, NameHasToken>;
void registerFunctionsStringSearch(FunctionFactory & factory)
{
@@ -570,6 +644,8 @@ void registerFunctionsStringSearch(FunctionFactory & factory)
factory.registerFunction<FunctionMultiSearchFirstPositionCaseInsensitive>();
factory.registerFunction<FunctionMultiSearchFirstPositionCaseInsensitiveUTF8>();
factory.registerFunction<FunctionHasToken>();
factory.registerAlias("locate", NamePosition::name, FunctionFactory::CaseInsensitive);
}
}

View File

@@ -168,6 +168,19 @@ const MergeTreeConditionFullText::AtomMap MergeTreeConditionFullText::atom_map
return true;
}
},
{
"hasToken",
[] (RPNElement & out, const Field & value, const MergeTreeIndexFullText & idx)
{
out.function = RPNElement::FUNCTION_EQUALS;
out.bloom_filter = std::make_unique<BloomFilter>(
idx.bloom_filter_size, idx.bloom_filter_hashes, idx.seed);
const auto & str = value.get<String>();
stringToBloomFilter(str.c_str(), str.size(), idx.token_extractor_func, *out.bloom_filter);
return true;
}
},
{
"startsWith",
[] (RPNElement & out, const Field & value, const MergeTreeIndexFullText & idx)
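
The new atom maps hasToken to an equality check against the granule's token bloom filter: the constant needle is hashed into a query-side filter via stringToBloomFilter, and any granule whose stored filter cannot contain those bits is skipped. A toy model of that containment test is sketched below; the class, the hashing scheme, and the numbers are invented for illustration and are not ClickHouse's BloomFilter:

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Invented toy filter; the (size, hashes, seed) parameters of tokenbf_v1(512, 3, 0)
// play the same roles in the real index.
struct ToyBloomFilter
{
    std::vector<uint8_t> bits;
    size_t num_hashes;
    size_t seed;

    ToyBloomFilter(size_t size_bytes, size_t num_hashes_, size_t seed_)
        : bits(size_bytes, 0), num_hashes(num_hashes_), seed(seed_) {}

    void add(const std::string & token)
    {
        const size_t base = std::hash<std::string>{}(token);
        for (size_t i = 0; i < num_hashes; ++i)
        {
            const size_t h = base ^ (seed + (i + 1) * 0x9e3779b97f4a7c15ULL);
            const size_t bit = h % (bits.size() * 8);
            bits[bit / 8] |= uint8_t(1u << (bit % 8));
        }
    }

    // True if every bit of the query-side filter is also set here; false means the
    // granule definitely lacks the token, so the index condition can skip it.
    bool mayContain(const ToyBloomFilter & query) const
    {
        for (size_t i = 0; i < bits.size(); ++i)
            if ((bits[i] & query.bits[i]) != query.bits[i])
                return false;
        return true;
    }
};

int main()
{
    ToyBloomFilter granule(512, 3, 0);
    for (const char * token : {"abc", "def", "zzz"})
        granule.add(token);

    ToyBloomFilter present(512, 3, 0), absent(512, 3, 0);
    present.add("abc");
    absent.add("missing");
    std::cout << granule.mayContain(present) << '\n';  // 1: the granule must be read
    std::cout << granule.mayContain(absent) << '\n';   // almost certainly 0: skip the granule
}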

View File

@@ -0,0 +1,124 @@
#!/usr/bin/env python
# encoding: utf-8
import re
HAYSTACKS = [
"hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay needle",
"needle hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay hay",
"hay hay hay hay hay hay hay hay hay needle hay hay hay hay hay hay hay hay hay",
]
NEEDLE = "needle"
HAY_RE = re.compile(r'\bhay\b', re.IGNORECASE)
NEEDLE_RE = re.compile(r'\bneedle\b', re.IGNORECASE)
def replace_follow_case(replacement):
def func(match):
g = match.group()
if g.islower(): return replacement.lower()
if g.istitle(): return replacement.title()
if g.isupper(): return replacement.upper()
return replacement
return func
def replace_separators(query, new_sep):
SEP_RE = re.compile('\\s+')
result = SEP_RE.sub(new_sep, query)
return result
def enlarge_haystack(query, times, separator=''):
return HAY_RE.sub(replace_follow_case(('hay' + separator) * times), query)
def small_needle(query):
return NEEDLE_RE.sub(replace_follow_case('n'), query)
def remove_needle(query):
return NEEDLE_RE.sub('', query)
def replace_needle(query, new_needle):
return NEEDLE_RE.sub(new_needle, query)
# with str.lower, str.upper, str.title and such
def transform_needle(query, string_transformation_func):
def replace_with_transformation(match):
g = match.group()
return string_transformation_func(g)
return NEEDLE_RE.sub(replace_with_transformation, query)
def create_cases(table_row_template, table_query_template, const_query_template):
const_queries = []
table_rows = []
table_queries = set()
def add_case(haystack, needle, match):
match = int(match)
const_queries.append(const_query_template.format(haystack=haystack, needle=needle, match=match))
table_queries.add(table_query_template.format(haystack=haystack, needle=needle, match=match))
table_rows.append(table_row_template.format(haystack=haystack, needle=needle, match=match))
# Negative cases
add_case(remove_needle(HAYSTACKS[0]), NEEDLE, False)
for haystack in HAYSTACKS:
add_case(transform_needle(haystack, str.title), NEEDLE, False)
sep = ''
h = replace_separators(haystack, sep)
add_case(h, NEEDLE, False)
add_case(small_needle(h), small_needle(NEEDLE), False)
add_case(enlarge_haystack(h, 10, sep), NEEDLE, False)
# Positive cases
for haystack in HAYSTACKS:
add_case(transform_needle(haystack, str.title), transform_needle(NEEDLE, str.title), True)
add_case(transform_needle(haystack, str.upper), transform_needle(NEEDLE, str.upper), True)
# Not checking all separators since some (like ' and \n) cause issues when coupled with
# regex-based replacement and quoting in the query;
# others are rare in practice, and checking all separators would make this test too lengthy.
# r'\\\\' turns into a single '\' in query
#separators = list(''' \t`~!@#$%^&*()-=+|]}[{";:/?.>,<''') + [r'\\\\']
separators = list(''' \t;:?.,''') + [r'\\\\']
for sep in separators:
h = replace_separators(haystack, sep)
add_case(h, NEEDLE, True)
add_case(small_needle(h), small_needle(NEEDLE), True)
add_case(enlarge_haystack(h, 200, sep), NEEDLE, True)
add_case(replace_needle(h, 'иголка'), replace_needle(NEEDLE, 'иголка'), True)
add_case(replace_needle(h, '指针'), replace_needle(NEEDLE, '指针'), True)
return table_rows, table_queries, const_queries
def main():
def query(x):
print x
CONST_QUERY = """SELECT hasToken('{haystack}', '{needle}'), ' expecting ', {match};"""
#SELECT hasToken(haystack, '{needle}') FROM ht WHERE needle = '{needle}' AND match = {match};"""
TABLE_QUERY = """WITH '{needle}' as n SELECT haystack, needle, hasToken(haystack, n) as result FROM ht WHERE needle = n AND result != match;"""
TABLE_ROW = """('{haystack}', '{needle}', {match})"""
rows, table_queries, const_queries = create_cases(TABLE_ROW, TABLE_QUERY, CONST_QUERY)
for q in const_queries:
query(q)
query("""DROP TABLE IF EXISTS ht;
CREATE TABLE IF NOT EXISTS
ht
(
haystack String,
needle String,
match UInt8
)
ENGINE MergeTree()
ORDER BY haystack;
INSERT INTO ht VALUES {values};""".format(values=", ".join(rows)))
for q in sorted(table_queries):
query(q)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,139 @@
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
0 expecting 0
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1
1 expecting 1

View File

@@ -0,0 +1,8 @@
#!/usr/bin/env bash
CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. $CURDIR/../shell_config.sh
# We should have correct env vars from shell_config.sh to run this test
python $CURDIR/00990_hasToken.python | ${CLICKHOUSE_CLIENT} -nm

View File

@@ -0,0 +1,3 @@
2007
2007
2007

View File

@@ -0,0 +1,33 @@
SET allow_experimental_data_skipping_indices = 1;
DROP TABLE IF EXISTS bloom_filter;
CREATE TABLE bloom_filter
(
id UInt64,
s String,
INDEX tok_bf (s, lower(s)) TYPE tokenbf_v1(512, 3, 0) GRANULARITY 1
) ENGINE = MergeTree() ORDER BY id SETTINGS index_granularity = 8;
insert into bloom_filter select number, 'yyy,uuu' from numbers(1024);
insert into bloom_filter select number+2000, 'abc,def,zzz' from numbers(8);
insert into bloom_filter select number+3000, 'yyy,uuu' from numbers(1024);
insert into bloom_filter select number+3000, 'abcdefzzz' from numbers(1024);
set max_rows_to_read = 16;
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'abc');
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'def');
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'zzz');
-- invert result
-- this does not work as expected, reading more rows than it should
-- SELECT max(id) FROM bloom_filter WHERE NOT hasToken(s, 'yyy');
-- accessing too many rows
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'yyy'); -- { serverError 158 }
-- this syntax is not supported by tokenbf
SELECT max(id) FROM bloom_filter WHERE hasToken(s, 'zzz') == 1; -- { serverError 158 }
DROP TABLE bloom_filter;