From cdc8596f81cd4322fda3a188ffcbb9a4065e8882 Mon Sep 17 00:00:00 2001 From: KevinyhZou Date: Thu, 2 Mar 2023 20:08:39 +0800 Subject: [PATCH] enable int type in repeat function --- .../functions/string-functions.md | 2 +- src/Functions/repeat.cpp | 29 +++++++++++++------ .../01013_repeat_function.reference | 8 ++--- .../0_stateless/01013_repeat_function.sql | 8 ++--- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/docs/en/sql-reference/functions/string-functions.md b/docs/en/sql-reference/functions/string-functions.md index 845be6e04c7..f3c5b20f886 100644 --- a/docs/en/sql-reference/functions/string-functions.md +++ b/docs/en/sql-reference/functions/string-functions.md @@ -330,7 +330,7 @@ repeat(s, n) **Arguments** - `s` — The string to repeat. [String](../../sql-reference/data-types/string.md). -- `n` — The number of times to repeat the string. [UInt](../../sql-reference/data-types/int-uint.md). +- `n` — The number of times to repeat the string. [UInt or Int](../../sql-reference/data-types/int-uint.md). **Returned value** diff --git a/src/Functions/repeat.cpp b/src/Functions/repeat.cpp index dcd05f373fc..0c323c39969 100644 --- a/src/Functions/repeat.cpp +++ b/src/Functions/repeat.cpp @@ -39,13 +39,15 @@ struct RepeatImpl size, max_string_size); } + template static void vectorStrConstRepeat( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets, - UInt64 repeat_time) + T repeat_time) { + repeat_time = repeat_time < 0 ? 0 : repeat_time; checkRepeatTime(repeat_time); UInt64 data_size = 0; @@ -77,7 +79,8 @@ struct RepeatImpl res_offsets.assign(offsets); for (UInt64 i = 0; i < col_num.size(); ++i) { - size_t repeated_size = (offsets[i] - offsets[i - 1] - 1) * col_num[i] + 1; + T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + size_t repeated_size = (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1; checkStringSize(repeated_size); data_size += repeated_size; res_offsets[i] = data_size; @@ -86,7 +89,7 @@ struct RepeatImpl for (UInt64 i = 0; i < col_num.size(); ++i) { - T repeat_time = col_num[i]; + T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; checkRepeatTime(repeat_time); process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time); } @@ -105,7 +108,8 @@ struct RepeatImpl UInt64 col_size = col_num.size(); for (UInt64 i = 0; i < col_size; ++i) { - size_t repeated_size = str_size * col_num[i] + 1; + T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + size_t repeated_size = str_size * repeat_time + 1; checkStringSize(repeated_size); data_size += repeated_size; res_offsets[i] = data_size; @@ -113,7 +117,7 @@ struct RepeatImpl res_data.resize(data_size); for (UInt64 i = 0; i < col_size; ++i) { - T repeat_time = col_num[i]; + T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; checkRepeatTime(repeat_time); process( reinterpret_cast(const_cast(copy_str.data())), @@ -168,7 +172,8 @@ class FunctionRepeat : public IFunction template static bool castType(const IDataType * type, F && f) { - return castTypeToEither(type, std::forward(f)); + return castTypeToEither(type, std::forward(f)); } public: @@ -186,7 +191,7 @@ public: if (!isString(arguments[0])) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", arguments[0]->getName(), getName()); - if (!isUnsignedInteger(arguments[1])) + if (!isInteger(arguments[1])) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", arguments[1]->getName(), getName()); return arguments[0]; @@ -204,9 +209,15 @@ public: { if (const ColumnConst * scale_column_num = checkAndGetColumn(numcolumn.get())) { - UInt64 repeat_time = scale_column_num->getValue(); auto col_res = ColumnString::create(); - RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time); + castType(arguments[1].type.get(), [&](const auto & type) + { + using DataType = std::decay_t; + using T = typename DataType::FieldType; + T repeat_time = scale_column_num->getValue(); + RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time); + return true; + }); return col_res; } else if (castType(arguments[1].type.get(), [&](const auto & type) diff --git a/tests/queries/0_stateless/01013_repeat_function.reference b/tests/queries/0_stateless/01013_repeat_function.reference index 46bb248a99a..ea0dadd524f 100644 --- a/tests/queries/0_stateless/01013_repeat_function.reference +++ b/tests/queries/0_stateless/01013_repeat_function.reference @@ -1,7 +1,7 @@ abcabcabcabcabcabcabcabcabcabc abcabcabc -sdfggsdfgg -xywq + + abcabcabcabcabcabcabcabcabcabcabcabc sdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfgg @@ -20,8 +20,8 @@ sdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfggsdfgg xywqxywqxywqxywqxywqxywqxywqxywqxywqxywq plkfplkfplkfplkfplkfplkfplkfplkfplkfplkf abcabcabc -abcabc -abc + + abcabcabcabcabcabcabcabcabcabcabcabc abcabcabcabcabcabcabcabcabcabc diff --git a/tests/queries/0_stateless/01013_repeat_function.sql b/tests/queries/0_stateless/01013_repeat_function.sql index 85b0c16b4ab..b29cc032f28 100644 --- a/tests/queries/0_stateless/01013_repeat_function.sql +++ b/tests/queries/0_stateless/01013_repeat_function.sql @@ -3,20 +3,20 @@ DROP TABLE IF EXISTS defaults; CREATE TABLE defaults ( strings String, - u8 UInt8, + i8 Int8, u16 UInt16, u32 UInt32, u64 UInt64 )ENGINE = Memory(); -INSERT INTO defaults values ('abc', 3, 12, 4, 56) ('sdfgg', 2, 10, 21, 200) ('xywq', 1, 4, 9, 5) ('plkf', 0, 5, 7,77); +INSERT INTO defaults values ('abc', 3, 12, 4, 56) ('sdfgg', -2, 10, 21, 200) ('xywq', -1, 4, 9, 5) ('plkf', 0, 5, 7,77); -SELECT repeat(strings, u8) FROM defaults; +SELECT repeat(strings, i8) FROM defaults; SELECT repeat(strings, u16) FROM defaults; SELECT repeat(strings, u32) from defaults; SELECT repeat(strings, u64) FROM defaults; SELECT repeat(strings, 10) FROM defaults; -SELECT repeat('abc', u8) FROM defaults; +SELECT repeat('abc', i8) FROM defaults; SELECT repeat('abc', u16) FROM defaults; SELECT repeat('abc', u32) FROM defaults; SELECT repeat('abc', u64) FROM defaults;