Merge pull request #47134 from KevinyhZou/enable_int_types_repeat_function

This commit is contained in:
Vladimir C 2023-03-08 10:09:20 +01:00 committed by GitHub
commit b298fbeecc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 29 additions and 18 deletions

View File

@ -330,7 +330,7 @@ repeat(s, n)
**Arguments**
- `s` — The string to repeat. [String](../../sql-reference/data-types/string.md).
- `n` — The number of times to repeat the string. [UInt](../../sql-reference/data-types/int-uint.md).
- `n` — The number of times to repeat the string. [UInt or Int](../../sql-reference/data-types/int-uint.md).
**Returned value**

View File

@ -39,13 +39,15 @@ struct RepeatImpl
size, max_string_size);
}
template <typename T>
static void vectorStrConstRepeat(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
ColumnString::Chars & res_data,
ColumnString::Offsets & res_offsets,
UInt64 repeat_time)
T repeat_time)
{
repeat_time = repeat_time < 0 ? 0 : repeat_time;
checkRepeatTime(repeat_time);
UInt64 data_size = 0;
@ -77,7 +79,8 @@ struct RepeatImpl
res_offsets.assign(offsets);
for (UInt64 i = 0; i < col_num.size(); ++i)
{
size_t repeated_size = (offsets[i] - offsets[i - 1] - 1) * col_num[i] + 1;
T repeat_time = col_num[i] < 0 ? 0 : col_num[i];
size_t repeated_size = (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1;
checkStringSize(repeated_size);
data_size += repeated_size;
res_offsets[i] = data_size;
@ -86,7 +89,7 @@ struct RepeatImpl
for (UInt64 i = 0; i < col_num.size(); ++i)
{
T repeat_time = col_num[i];
T repeat_time = col_num[i] < 0 ? 0 : col_num[i];
checkRepeatTime(repeat_time);
process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time);
}
@ -105,7 +108,8 @@ struct RepeatImpl
UInt64 col_size = col_num.size();
for (UInt64 i = 0; i < col_size; ++i)
{
size_t repeated_size = str_size * col_num[i] + 1;
T repeat_time = col_num[i] < 0 ? 0 : col_num[i];
size_t repeated_size = str_size * repeat_time + 1;
checkStringSize(repeated_size);
data_size += repeated_size;
res_offsets[i] = data_size;
@ -113,7 +117,7 @@ struct RepeatImpl
res_data.resize(data_size);
for (UInt64 i = 0; i < col_size; ++i)
{
T repeat_time = col_num[i];
T repeat_time = col_num[i] < 0 ? 0 : col_num[i];
checkRepeatTime(repeat_time);
process(
reinterpret_cast<UInt8 *>(const_cast<char *>(copy_str.data())),
@ -168,7 +172,8 @@ class FunctionRepeat : public IFunction
template <typename F>
static bool castType(const IDataType * type, F && f)
{
return castTypeToEither<DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64>(type, std::forward<F>(f));
return castTypeToEither<DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
DataTypeUInt8, DataTypeUInt16, DataTypeUInt32, DataTypeUInt64>(type, std::forward<F>(f));
}
public:
@ -186,7 +191,7 @@ public:
if (!isString(arguments[0]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}",
arguments[0]->getName(), getName());
if (!isUnsignedInteger(arguments[1]))
if (!isInteger(arguments[1]))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}",
arguments[1]->getName(), getName());
return arguments[0];
@ -204,9 +209,15 @@ public:
{
if (const ColumnConst * scale_column_num = checkAndGetColumn<ColumnConst>(numcolumn.get()))
{
UInt64 repeat_time = scale_column_num->getValue<UInt64>();
auto col_res = ColumnString::create();
RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time);
castType(arguments[1].type.get(), [&](const auto & type)
{
using DataType = std::decay_t<decltype(type)>;
using T = typename DataType::FieldType;
T repeat_time = scale_column_num->getValue<T>();
RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time);
return true;
});
return col_res;
}
else if (castType(arguments[1].type.get(), [&](const auto & type)

View File

@ -1,7 +1,7 @@
abcabcabcabcabcabcabcabcabcabc
abcabcabc
sdfggsdfgg
xywq
abcabcabcabcabcabcabcabcabcabcabcabc
sdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfgg
@ -20,8 +20,8 @@ sdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfgg
xywqxywqxywqxywqxywqxywqxywqxywqxywqxywq
plkfplkfplkfplkfplkfplkfplkfplkfplkfplkf
abcabcabc
abcabc
abc
abcabcabcabcabcabcabcabcabcabcabcabc
abcabcabcabcabcabcabcabcabcabc

View File

@ -3,20 +3,20 @@ DROP TABLE IF EXISTS defaults;
CREATE TABLE defaults
(
strings String,
u8 UInt8,
i8 Int8,
u16 UInt16,
u32 UInt32,
u64 UInt64
)ENGINE = Memory();
INSERT INTO defaults values ('abc', 3, 12, 4, 56) ('sdfgg', 2, 10, 21, 200) ('xywq', 1, 4, 9, 5) ('plkf', 0, 5, 7,77);
INSERT INTO defaults values ('abc', 3, 12, 4, 56) ('sdfgg', -2, 10, 21, 200) ('xywq', -1, 4, 9, 5) ('plkf', 0, 5, 7,77);
SELECT repeat(strings, u8) FROM defaults;
SELECT repeat(strings, i8) FROM defaults;
SELECT repeat(strings, u16) FROM defaults;
SELECT repeat(strings, u32) from defaults;
SELECT repeat(strings, u64) FROM defaults;
SELECT repeat(strings, 10) FROM defaults;
SELECT repeat('abc', u8) FROM defaults;
SELECT repeat('abc', i8) FROM defaults;
SELECT repeat('abc', u16) FROM defaults;
SELECT repeat('abc', u32) FROM defaults;
SELECT repeat('abc', u64) FROM defaults;