add new function startsWithUTF8 and endsWithUTF8

This commit is contained in:
taiyang-li 2023-07-25 15:34:38 +08:00
parent 3c034d563b
commit a5d7391fbf
6 changed files with 175 additions and 10 deletions

View File

@ -693,6 +693,22 @@ Returns whether string `str` ends with `suffix`.
endsWith(str, suffix)
```
## endsWithUTF8
Returns whether string `str` ends with `suffix`, assuming that both strings contain svalid UTF-8 encoded text. If this assumption is violated, no exception is thrown and the result is undefined.
**Syntax**
```sql
endsWithUTF8(str, suffix)
```
**Example**
``` sql
SELECT endsWithUTF8('中国', '国');
```
## startsWith
Returns whether string `str` starts with `prefix`.
@ -709,6 +725,18 @@ startsWith(str, prefix)
SELECT startsWith('Spider-Man', 'Spi');
```
## startsWithUTF8
Returns whether string `str` starts with `prefix`, assuming that both string contains valid UTF-8 encoded text. If this assumption is violated, no exception is thrown and the result is undefined.
**Example**
``` sql
SELECT startsWithUTF8('中国', '中');
```
## trim
Removes the specified characters from the start or end of a string. If not specified otherwise, the function removes whitespace (ASCII-character 32).

View File

@ -28,10 +28,23 @@ namespace ErrorCodes
struct NameStartsWith
{
static constexpr auto name = "startsWith";
static constexpr auto is_utf8 = false;
};
struct NameEndsWith
{
static constexpr auto name = "endsWith";
static constexpr auto is_utf8 = false;
};
struct NameStartsWithUTF8
{
static constexpr auto name = "startsWithUTF8";
static constexpr auto is_utf8 = true;
};
struct NameEndsWithUTF8
{
static constexpr auto name = "endsWithUTF8";
static constexpr auto is_utf8 = true;
};
DECLARE_MULTITARGET_CODE(
@ -41,6 +54,7 @@ class FunctionStartsEndsWith : public IFunction
{
public:
static constexpr auto name = Name::name;
static constexpr auto is_utf8 = Name::is_utf8;
String getName() const override
{
@ -64,7 +78,8 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{
if (isStringOrFixedString(arguments[0]) && isStringOrFixedString(arguments[1]))
if (!is_utf8 && isStringOrFixedString(arguments[0]) && isStringOrFixedString(arguments[1])
|| isString(arguments[0]) && isString(arguments[1]))
return std::make_shared<DataTypeUInt8>();
if (isArray(arguments[0]) && isArray(arguments[1]))
@ -78,8 +93,11 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
auto data_type = arguments[0].type;
if (isStringOrFixedString(*data_type))
if (!is_utf8 && isStringOrFixedString(*data_type))
return executeImplString(arguments, {}, input_rows_count);
if (is_utf8 && isString(*data_type))
return executeImplStringUTF8(arguments, {}, input_rows_count);
if (isArray(data_type))
return executeImplArray(arguments, {}, input_rows_count);
return {};
@ -131,7 +149,6 @@ private:
typename ColumnVector<UInt8>::Container & vec_res = col_res->getData();
vec_res.resize(input_rows_count);
if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column))
dispatch<StringSource>(StringSource(*haystack), needle_column, vec_res);
else if (const ColumnFixedString * haystack_fixed = checkAndGetColumn<ColumnFixedString>(haystack_column))
@ -146,6 +163,26 @@ private:
return col_res;
}
ColumnPtr executeImplStringUTF8(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const
{
const IColumn * haystack_column = arguments[0].column.get();
const IColumn * needle_column = arguments[1].column.get();
auto col_res = ColumnVector<UInt8>::create();
typename ColumnVector<UInt8>::Container & vec_res = col_res->getData();
vec_res.resize(input_rows_count);
if (const ColumnString * haystack = checkAndGetColumn<ColumnString>(haystack_column))
dispatchUTF8<UTF8StringSource>(UTF8StringSource(*haystack), needle_column, vec_res);
else if (const ColumnConst * haystack_const = checkAndGetColumnConst<ColumnString>(haystack_column))
dispatchUTF8<ConstSource<UTF8StringSource>>(ConstSource<UTF8StringSource>(*haystack_const), needle_column, vec_res);
else
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal combination of columns as arguments of function {}", getName());
return col_res;
}
template <typename HaystackSource>
void dispatch(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const
{
@ -161,6 +198,17 @@ private:
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal combination of columns as arguments of function {}", getName());
}
template <typename HaystackSource>
void dispatchUTF8(HaystackSource haystack_source, const IColumn * needle_column, PaddedPODArray<UInt8> & res_data) const
{
if (const ColumnString * needle = checkAndGetColumn<ColumnString>(needle_column))
execute<HaystackSource, UTF8StringSource>(haystack_source, UTF8StringSource(*needle), res_data);
else if (const ColumnConst * needle_const = checkAndGetColumnConst<ColumnString>(needle_column))
execute<HaystackSource, ConstSource<UTF8StringSource>>(haystack_source, ConstSource<UTF8StringSource>(*needle_const), res_data);
else
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal combination of columns as arguments of function {}", getName());
}
template <typename HaystackSource, typename NeedleSource>
static void execute(HaystackSource haystack_source, NeedleSource needle_source, PaddedPODArray<UInt8> & res_data)
{
@ -172,18 +220,27 @@ private:
auto needle = needle_source.getWhole();
if (needle.size > haystack.size)
{
res_data[row_num] = false;
else
{
if constexpr (std::is_same_v<Name, NameStartsWith>) /// startsWith
res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size);
else if constexpr (std::is_same_v<Name, NameEndsWith>) /// endsWith
res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size);
else /// startsWithUTF8 or endsWithUTF8
{
auto length = UTF8::countCodePoints(needle.data, needle.size);
if constexpr (std::is_same_v<Name, NameStartsWithUTF8>)
{
auto slice = haystack_source.getSliceFromLeft(0, length);
res_data[row_num] = StringRef(slice.data, slice.size) == StringRef(needle.data, needle.size);
}
else
{
if constexpr (std::is_same_v<Name, NameStartsWith>)
{
res_data[row_num] = StringRef(haystack.data, needle.size) == StringRef(needle.data, needle.size);
auto slice = haystack_source.getSliceFromRight(length);
res_data[row_num] = StringRef(slice.data, slice.size) == StringRef(needle.data, needle.size);
}
else /// endsWith
{
res_data[row_num] = StringRef(haystack.data + haystack.size - needle.size, needle.size) == StringRef(needle.data, needle.size);
}
}

View File

@ -0,0 +1,16 @@
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStartsEndsWith.h>
namespace DB
{
using FunctionEndsWithUTF8 = FunctionStartsEndsWith<NameEndsWithUTF8>;
REGISTER_FUNCTION(EndsWithUTF8)
{
factory.registerFunction<FunctionEndsWithUTF8>();
}
}

View File

@ -0,0 +1,16 @@
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionStartsEndsWith.h>
namespace DB
{
using FunctionStartsWithUTF8 = FunctionStartsEndsWith<NameStartsWithUTF8>;
REGISTER_FUNCTION(StartsWithUTF8)
{
factory.registerFunction<FunctionStartsWithUTF8>();
}
}

View File

@ -0,0 +1,29 @@
-- { echoOn }
select startsWithUTF8('富强民主文明和谐', '富强');
1
select startsWithUTF8('富强民主文明和谐', '\xe5');
0
select startsWithUTF8('富强民主文明和谐', '');
1
SELECT startsWithUTF8('123', '123');
1
SELECT startsWithUTF8('123', '12');
1
SELECT startsWithUTF8('123', '1234');
0
SELECT startsWithUTF8('123', '');
1
select endsWithUTF8('富强民主文明和谐', '和谐');
1
select endsWithUTF8('富强民主文明和谐', '\x90');
0
select endsWithUTF8('富强民主文明和谐', '');
1
SELECT endsWithUTF8('123', '3');
1
SELECT endsWithUTF8('123', '23');
1
SELECT endsWithUTF8('123', '32');
0
SELECT endsWithUTF8('123', '');
1

View File

@ -0,0 +1,19 @@
-- { echoOn }
select startsWithUTF8('富强民主文明和谐', '富强');
select startsWithUTF8('富强民主文明和谐', '\xe5');
select startsWithUTF8('富强民主文明和谐', '');
SELECT startsWithUTF8('123', '123');
SELECT startsWithUTF8('123', '12');
SELECT startsWithUTF8('123', '1234');
SELECT startsWithUTF8('123', '');
select endsWithUTF8('富强民主文明和谐', '和谐');
select endsWithUTF8('富强民主文明和谐', '\x90');
select endsWithUTF8('富强民主文明和谐', '');
SELECT endsWithUTF8('123', '3');
SELECT endsWithUTF8('123', '23');
SELECT endsWithUTF8('123', '32');
SELECT endsWithUTF8('123', '');
-- { echoOff }