From 47e3eccc59ed62149aff6fe68068b6ca4cbc1dad Mon Sep 17 00:00:00 2001 From: Antonio Andelic Date: Mon, 18 Mar 2024 13:41:25 +0100 Subject: [PATCH] Fix repeat with non native integers --- src/Functions/repeat.cpp | 32 ++++++++++++++----- ...3_repeat_with_nonnative_integers.reference | 4 +++ .../03013_repeat_with_nonnative_integers.sql | 4 +++ 3 files changed, 32 insertions(+), 8 deletions(-) create mode 100644 tests/queries/0_stateless/03013_repeat_with_nonnative_integers.reference create mode 100644 tests/queries/0_stateless/03013_repeat_with_nonnative_integers.sql diff --git a/src/Functions/repeat.cpp b/src/Functions/repeat.cpp index 11a2ca37a3b..6f2078b7e48 100644 --- a/src/Functions/repeat.cpp +++ b/src/Functions/repeat.cpp @@ -44,7 +44,7 @@ struct RepeatImpl ColumnString::Offsets & res_offsets, T repeat_time) { - repeat_time = repeat_time < 0 ? 0 : repeat_time; + repeat_time = repeat_time < 0 ? static_cast(0) : repeat_time; checkRepeatTime(repeat_time); UInt64 data_size = 0; @@ -76,7 +76,7 @@ struct RepeatImpl res_offsets.assign(offsets); for (UInt64 i = 0; i < col_num.size(); ++i) { - T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + T repeat_time = col_num[i] < 0 ? static_cast(0) : col_num[i]; size_t repeated_size = (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1; checkStringSize(repeated_size); data_size += repeated_size; @@ -86,7 +86,7 @@ struct RepeatImpl for (UInt64 i = 0; i < col_num.size(); ++i) { - T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + T repeat_time = col_num[i] < 0 ? static_cast(0) : col_num[i]; checkRepeatTime(repeat_time); process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time); } @@ -105,7 +105,7 @@ struct RepeatImpl UInt64 col_size = col_num.size(); for (UInt64 i = 0; i < col_size; ++i) { - T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + T repeat_time = col_num[i] < 0 ? static_cast(0) : col_num[i]; size_t repeated_size = str_size * repeat_time + 1; checkStringSize(repeated_size); data_size += repeated_size; @@ -114,7 +114,7 @@ struct RepeatImpl res_data.resize(data_size); for (UInt64 i = 0; i < col_size; ++i) { - T repeat_time = col_num[i] < 0 ? 0 : col_num[i]; + T repeat_time = col_num[i] < 0 ? static_cast(0) : col_num[i]; checkRepeatTime(repeat_time); process( reinterpret_cast(const_cast(copy_str.data())), @@ -169,8 +169,19 @@ class FunctionRepeat : public IFunction template static bool castType(const IDataType * type, F && f) { - return castTypeToEither(type, std::forward(f)); + return castTypeToEither< + DataTypeInt8, + DataTypeInt16, + DataTypeInt32, + DataTypeInt64, + DataTypeInt128, + DataTypeInt256, + DataTypeUInt8, + DataTypeUInt16, + DataTypeUInt32, + DataTypeUInt64, + DataTypeUInt128, + DataTypeUInt256>(type, std::forward(f)); } public: @@ -208,7 +219,7 @@ public: if (const ColumnConst * col_num_const = checkAndGetColumn(col_num.get())) { auto col_res = ColumnString::create(); - castType(arguments[1].type.get(), [&](const auto & type) + auto success = castType(arguments[1].type.get(), [&](const auto & type) { using DataType = std::decay_t; using T = typename DataType::FieldType; @@ -216,6 +227,11 @@ public: RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), times); return true; }); + + if (!success) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column type {} of 'n' of function {}", + arguments[1].column->getName(), getName()); + return col_res; } else if (castType(arguments[1].type.get(), [&](const auto & type) diff --git a/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.reference b/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.reference new file mode 100644 index 00000000000..50cb7002a32 --- /dev/null +++ b/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.reference @@ -0,0 +1,4 @@ +000000000000 +000000000000 +000000000000 +000000000000 diff --git a/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.sql b/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.sql new file mode 100644 index 00000000000..0dbe98994b9 --- /dev/null +++ b/tests/queries/0_stateless/03013_repeat_with_nonnative_integers.sql @@ -0,0 +1,4 @@ +SELECT repeat(toString(number), toUInt256(12)) FROM numbers(1); +SELECT repeat(toString(number), toUInt128(12)) FROM numbers(1); +SELECT repeat(toString(number), toInt256(12)) FROM numbers(1); +SELECT repeat(toString(number), toInt128(12)) FROM numbers(1);