Merge pull request #9163 from ClickHouse/match-zero-byte-fix

Fixed `match` and `extract` when haystack has zero bytes.
This commit is contained in:
alexey-milovidov 2020-02-21 02:01:50 +03:00 committed by GitHub
commit d1e26f5b35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 366 additions and 89 deletions

View File

@ -290,31 +290,44 @@ OptimizedRegularExpressionImpl<thread_safe>::OptimizedRegularExpressionImpl(cons
throw DB::Exception("OptimizedRegularExpression: too many subpatterns in regexp: " + regexp_, DB::ErrorCodes::CANNOT_COMPILE_REGEXP);
}
}
if (!required_substring.empty())
{
if (is_case_insensitive)
case_insensitive_substring_searcher.emplace(required_substring.data(), required_substring.size());
else
case_sensitive_substring_searcher.emplace(required_substring.data(), required_substring.size());
}
}
template <bool thread_safe>
bool OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, size_t subject_size) const
{
const UInt8 * haystack = reinterpret_cast<const UInt8 *>(subject);
const UInt8 * haystack_end = haystack + subject_size;
if (is_trivial)
{
if (is_case_insensitive)
return nullptr != strcasestr(subject, required_substring.data());
return haystack_end != case_insensitive_substring_searcher->search(haystack, subject_size);
else
return nullptr != strstr(subject, required_substring.data());
return haystack_end != case_sensitive_substring_searcher->search(haystack, subject_size);
}
else
{
if (!required_substring.empty())
{
const char * pos;
if (is_case_insensitive)
pos = strcasestr(subject, required_substring.data());
{
if (haystack_end == case_insensitive_substring_searcher->search(haystack, subject_size))
return false;
}
else
pos = strstr(subject, required_substring.data());
if (nullptr == pos)
return 0;
{
if (haystack_end == case_sensitive_substring_searcher->search(haystack, subject_size))
return false;
}
}
return re2->Match(StringPieceType(subject, subject_size), 0, subject_size, RegexType::UNANCHORED, nullptr, 0);
@ -325,19 +338,22 @@ bool OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, si
template <bool thread_safe>
bool OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, size_t subject_size, Match & match) const
{
const UInt8 * haystack = reinterpret_cast<const UInt8 *>(subject);
const UInt8 * haystack_end = haystack + subject_size;
if (is_trivial)
{
const char * pos;
const UInt8 * pos;
if (is_case_insensitive)
pos = strcasestr(subject, required_substring.data());
pos = case_insensitive_substring_searcher->search(haystack, subject_size);
else
pos = strstr(subject, required_substring.data());
pos = case_sensitive_substring_searcher->search(haystack, subject_size);
if (pos == nullptr)
return 0;
if (haystack_end == pos)
return false;
else
{
match.offset = pos - subject;
match.offset = pos - haystack;
match.length = required_substring.size();
return 1;
}
@ -346,25 +362,25 @@ bool OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, si
{
if (!required_substring.empty())
{
const char * pos;
const UInt8 * pos;
if (is_case_insensitive)
pos = strcasestr(subject, required_substring.data());
pos = case_insensitive_substring_searcher->search(haystack, subject_size);
else
pos = strstr(subject, required_substring.data());
pos = case_sensitive_substring_searcher->search(haystack, subject_size);
if (nullptr == pos)
return 0;
if (haystack_end == pos)
return false;
}
StringPieceType piece;
if (!RegexType::PartialMatch(StringPieceType(subject, subject_size), *re2, &piece))
return 0;
return false;
else
{
match.offset = piece.data() - subject;
match.length = piece.length();
return 1;
return true;
}
}
}
@ -373,6 +389,9 @@ bool OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, si
template <bool thread_safe>
unsigned OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject, size_t subject_size, MatchVec & matches, unsigned limit) const
{
const UInt8 * haystack = reinterpret_cast<const UInt8 *>(subject);
const UInt8 * haystack_end = haystack + subject_size;
matches.clear();
if (limit == 0)
@ -383,18 +402,18 @@ unsigned OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject
if (is_trivial)
{
const char * pos;
const UInt8 * pos;
if (is_case_insensitive)
pos = strcasestr(subject, required_substring.data());
pos = case_insensitive_substring_searcher->search(haystack, subject_size);
else
pos = strstr(subject, required_substring.data());
pos = case_sensitive_substring_searcher->search(haystack, subject_size);
if (pos == nullptr)
if (haystack_end == pos)
return 0;
else
{
Match match;
match.offset = pos - subject;
match.offset = pos - haystack;
match.length = required_substring.size();
matches.push_back(match);
return 1;
@ -404,13 +423,13 @@ unsigned OptimizedRegularExpressionImpl<thread_safe>::match(const char * subject
{
if (!required_substring.empty())
{
const char * pos;
const UInt8 * pos;
if (is_case_insensitive)
pos = strcasestr(subject, required_substring.data());
pos = case_insensitive_substring_searcher->search(haystack, subject_size);
else
pos = strstr(subject, required_substring.data());
pos = case_sensitive_substring_searcher->search(haystack, subject_size);
if (nullptr == pos)
if (haystack_end == pos)
return 0;
}

View File

@ -3,6 +3,8 @@
#include <string>
#include <vector>
#include <memory>
#include <optional>
#include <Common/StringSearcher.h>
#include <Common/config.h>
#include <re2/re2.h>
#if USE_RE2_ST
@ -101,6 +103,8 @@ private:
bool required_substring_is_prefix;
bool is_case_insensitive;
std::string required_substring;
std::optional<DB::StringSearcher<true, true>> case_sensitive_substring_searcher;
std::optional<DB::StringSearcher<false, true>> case_insensitive_substring_searcher;
std::unique_ptr<RegexType> re2;
unsigned number_of_subpatterns;
};

View File

@ -222,6 +222,8 @@ public:
return false;
}
/** Returns haystack_end if not found.
*/
const UInt8 * search(const UInt8 * haystack, const UInt8 * const haystack_end) const
{
if (0 == needle_size)

View File

@ -284,7 +284,7 @@ private:
/// The approximate total number of rows to read. For progress bar.
size_t total_rows_approx = 0;
/// The successors must implement this function.
/// Derived classes must implement this function.
virtual Block readImpl() = 0;
/// Here you can do a preliminary initialization.

View File

@ -89,6 +89,8 @@ inline bool likePatternIsStrstr(const String & pattern, String & res)
template <bool like, bool revert = false>
struct MatchImpl
{
static constexpr bool use_default_implementation_for_constants = true;
using ResultType = UInt8;
static void vector_constant(
@ -240,12 +242,6 @@ struct MatchImpl
}
}
static void constant_constant(const std::string & data, const std::string & pattern, UInt8 & res)
{
const auto & regexp = Regexps::get<like, true>(pattern);
res = revert ^ regexp->match(data);
}
template <typename... Args>
static void vector_vector(Args &&...)
{
@ -846,29 +842,6 @@ struct ReplaceStringImpl
#undef COPY_REST_OF_CURRENT_STRING
}
}
static void constant(const std::string & data, const std::string & needle, const std::string & replacement, std::string & res_data)
{
res_data = "";
int replace_cnt = 0;
for (size_t i = 0; i < data.size(); ++i)
{
bool match = true;
if (i + needle.size() > data.size() || (replace_one && replace_cnt > 0))
match = false;
for (size_t j = 0; match && j < needle.size(); ++j)
if (data[i + j] != needle[j])
match = false;
if (match)
{
++replace_cnt;
res_data += replacement;
i = i + needle.size() - 1;
}
else
res_data += data[i];
}
}
};

View File

@ -146,6 +146,8 @@ struct PositionCaseInsensitiveUTF8
template <typename Impl>
struct PositionImpl
{
static constexpr bool use_default_implementation_for_constants = false;
using ResultType = UInt64;
/// Find one substring in many strings.
@ -459,6 +461,8 @@ struct HasTokenImpl
{
using ResultType = UInt8;
static constexpr bool use_default_implementation_for_constants = true;
static void vector_constant(
const ColumnString::Chars & data, const ColumnString::Offsets & offsets, const std::string & pattern, PaddedPODArray<UInt8> & res)
{
@ -499,13 +503,6 @@ struct HasTokenImpl
memset(&res[i], negate_result, (res.size() - i) * sizeof(res[0]));
}
static void constant_constant(const std::string & data, const std::string & pattern, UInt8 & res)
{
TokenSearcher searcher(pattern.data(), pattern.size(), data.size());
const auto found = searcher.search(data.c_str(), data.size()) != data.end().base();
res = negate_result ^ found;
}
template <typename... Args>
static void vector_vector(Args &&...)
{

View File

@ -82,6 +82,15 @@ public:
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForConstants() const override { return Impl::use_default_implementation_for_constants; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
{
return Impl::use_default_implementation_for_constants
? ColumnNumbers{1, 2}
: ColumnNumbers{};
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (!isString(arguments[0]))
@ -105,6 +114,8 @@ public:
const ColumnConst * col_haystack_const = typeid_cast<const ColumnConst *>(&*column_haystack);
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);
if constexpr (!Impl::use_default_implementation_for_constants)
{
if (col_haystack_const && col_needle_const)
{
ResultType res{};
@ -113,6 +124,7 @@ public:
= block.getByPosition(result).type->createColumnConst(col_haystack_const->size(), toField(res));
return;
}
}
auto col_res = ColumnVector<ResultType>::create();

View File

@ -78,6 +78,8 @@ struct ExtractParamImpl
{
using ResultType = typename ParamExtractor::ResultType;
static constexpr bool use_default_implementation_for_constants = true;
/// It is assumed that `res` is the correct size and initialized with zeros.
static void vector_constant(const ColumnString::Chars & data, const ColumnString::Offsets & offsets,
std::string needle,
@ -119,19 +121,6 @@ struct ExtractParamImpl
memset(&res[i], 0, (res.size() - i) * sizeof(res[0]));
}
static void constant_constant(const std::string & data, std::string needle, ResultType & res)
{
needle = "\"" + needle + "\":";
size_t pos = data.find(needle);
if (pos == std::string::npos)
res = 0;
else
res = ParamExtractor::extract(
reinterpret_cast<const UInt8 *>(data.data() + pos + needle.size()),
reinterpret_cast<const UInt8 *>(data.data() + data.size())
);
}
template <typename... Args> static void vector_vector(Args &&...)
{
throw Exception("Functions 'visitParamHas' and 'visitParamExtract*' doesn't support non-constant needle argument", ErrorCodes::ILLEGAL_COLUMN);

View File

@ -22,7 +22,7 @@ namespace ErrorCodes
* Unlike std::ostream, it provides access to the internal buffer,
* and also allows you to manually manage the position inside the buffer.
*
* The successors must implement the nextImpl() method.
* Derived classes must implement the nextImpl() method.
*/
class WriteBuffer : public BufferBase
{
@ -55,7 +55,7 @@ public:
pos = working_buffer.begin();
}
/** it is desirable in the successors to place the next() call in the destructor,
/** it is desirable in the derived classes to place the next() call in the destructor,
* so that the last data is written
*/
virtual ~WriteBuffer() {}

View File

@ -0,0 +1,264 @@
1
1
1
1
1
1
4 key="v" 10 v
\0 key="v" 10 v
0 v
1 v
2 v
3 v
4 v
5 v
6 v
7 v
8 v
9 v
10 v
11 v
12 v
13 v
14 v
15 v
16 v
17 v
18 v
19 v
20 v
21 v
22 v
23 v
24 v
25 v
26 v
27 v
28 v
29 v
30 v
31 v
32 v
33 v
34 v
35 v
36 v
37 v
38 v
39 v
40 v
41 v
42 v
43 v
44 v
45 v
46 v
47 v
48 v
49 v
50 v
51 v
52 v
53 v
54 v
55 v
56 v
57 v
58 v
59 v
60 v
61 v
62 v
63 v
64 v
65 v
66 v
67 v
68 v
69 v
70 v
71 v
72 v
73 v
74 v
75 v
76 v
77 v
78 v
79 v
80 v
81 v
82 v
83 v
84 v
85 v
86 v
87 v
88 v
89 v
90 v
91 v
92 v
93 v
94 v
95 v
96 v
97 v
98 v
99 v
100 v
101 v
102 v
103 v
104 v
105 v
106 v
107 v
108 v
109 v
110 v
111 v
112 v
113 v
114 v
115 v
116 v
117 v
118 v
119 v
120 v
121 v
122 v
123 v
124 v
125 v
126 v
127 v
128 v
129 v
130 v
131 v
132 v
133 v
134 v
135 v
136 v
137 v
138 v
139 v
140 v
141 v
142 v
143 v
144 v
145 v
146 v
147 v
148 v
149 v
150 v
151 v
152 v
153 v
154 v
155 v
156 v
157 v
158 v
159 v
160 v
161 v
162 v
163 v
164 v
165 v
166 v
167 v
168 v
169 v
170 v
171 v
172 v
173 v
174 v
175 v
176 v
177 v
178 v
179 v
180 v
181 v
182 v
183 v
184 v
185 v
186 v
187 v
188 v
189 v
190 v
191 v
192 v
193 v
194 v
195 v
196 v
197 v
198 v
199 v
200 v
201 v
202 v
203 v
204 v
205 v
206 v
207 v
208 v
209 v
210 v
211 v
212 v
213 v
214 v
215 v
216 v
217 v
218 v
219 v
220 v
221 v
222 v
223 v
224 v
225 v
226 v
227 v
228 v
229 v
230 v
231 v
232 v
233 v
234 v
235 v
236 v
237 v
238 v
239 v
240 v
241 v
242 v
243 v
244 v
245 v
246 v
247 v
248 v
249 v
250 v
251 v
252 v
253 v
254 v
255 v

View File

@ -0,0 +1,17 @@
select match('a key="v" ', 'key="(.*?)"');
select match(materialize('a key="v" '), 'key="(.*?)"');
select match('\0 key="v" ', 'key="(.*?)"');
select match(materialize('\0 key="v" '), 'key="(.*?)"');
select multiMatchAny('\0 key="v" ', ['key="(.*?)"']);
select multiMatchAny(materialize('\0 key="v" '), ['key="(.*?)"']);
select unhex('34') || ' key="v" ' as haystack, length(haystack), extract( haystack, 'key="(.*?)"') as needle;
-- works, result = v
select unhex('00') || ' key="v" ' as haystack, length(haystack), extract( haystack, 'key="(.*?)"') as needle;
-- before fix: returns nothing (zero-byte in the begining of haystack)
select number as char_code, extract( char(char_code) || ' key="v" ' as haystack, 'key="(.*?)"') as needle from numbers(256);
-- every other chars codes (except of zero byte) works ok