diff --git a/dbms/src/Functions/repeat.cpp b/dbms/src/Functions/repeat.cpp index 741af698452..ad1d3954393 100644 --- a/dbms/src/Functions/repeat.cpp +++ b/dbms/src/Functions/repeat.cpp @@ -1,4 +1,3 @@ -#include #include #include #include @@ -18,38 +17,29 @@ namespace ErrorCodes struct RepeatImpl { - static void vectorNonConstStr( + static void vectorStrConstRepeat( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets, - const UInt64 & repeat_time) + UInt64 repeat_time) { UInt64 data_size = 0; res_offsets.assign(offsets); for (UInt64 i = 0; i < offsets.size(); ++i) { - data_size += (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1; + data_size += (offsets[i] - offsets[i - 1] - 1) * repeat_time + 1; /// Note that accessing -1th element is valid for PaddedPODArray. res_offsets[i] = data_size; } res_data.resize(data_size); for (UInt64 i = 0; i < res_offsets.size(); ++i) { - array(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time); + process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], repeat_time); } } - static void - vectorConst(const String & copy_str, const UInt64 & repeat_time, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets) - { - UInt64 data_size = copy_str.size() * repeat_time + 1; - res_data.resize(data_size); - res_offsets.resize_fill(1, data_size); - array(reinterpret_cast(const_cast(copy_str.data())), res_data.data(), copy_str.size() + 1, repeat_time); - } - template - static void vectorNonConst( + static void vectorStrVectorRepeat( const ColumnString::Chars & data, const ColumnString::Offsets & offsets, ColumnString::Chars & res_data, @@ -66,17 +56,20 @@ struct RepeatImpl res_data.resize(data_size); for (UInt64 i = 0; i < col_num.size(); ++i) { - array(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], col_num[i]); + process(data.data() + offsets[i - 1], res_data.data() + res_offsets[i - 1], offsets[i] - offsets[i - 1], col_num[i]); } } template - static void vectorNonConstInteger( - const String & copy_str, ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets, const PaddedPODArray & col_num) + static void constStrVectorRepeat( + const StringRef & copy_str, + ColumnString::Chars & res_data, + ColumnString::Offsets & res_offsets, + const PaddedPODArray & col_num) { UInt64 data_size = 0; res_offsets.resize(col_num.size()); - UInt64 str_size = copy_str.size(); + UInt64 str_size = copy_str.size; UInt64 col_size = col_num.size(); for (UInt64 i = 0; i < col_size; ++i) { @@ -86,8 +79,8 @@ struct RepeatImpl res_data.resize(data_size); for (UInt64 i = 0; i < col_size; ++i) { - array( - reinterpret_cast(const_cast(copy_str.data())), + process( + reinterpret_cast(const_cast(copy_str.data)), res_data.data() + res_offsets[i - 1], str_size + 1, col_num[i]); @@ -95,7 +88,7 @@ struct RepeatImpl } private: - static void array(const UInt8 * src, UInt8 * dst, const UInt64 & size, const UInt64 & repeat_time) + static void process(const UInt8 * src, UInt8 * dst, UInt64 size, UInt64 repeat_time) { for (UInt64 i = 0; i < repeat_time; ++i) { @@ -106,8 +99,8 @@ private: } }; -template -class FunctionRepeatImpl : public IFunction + +class FunctionRepeat : public IFunction { template static bool castType(const IDataType * type, F && f) @@ -117,7 +110,7 @@ class FunctionRepeatImpl : public IFunction public: static constexpr auto name = "repeat"; - static FunctionPtr create(const Context &) { return std::make_shared(); } + static FunctionPtr create(const Context &) { return std::make_shared(); } String getName() const override { return name; } @@ -138,74 +131,64 @@ public: void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t) override { - auto & strcolumn = block.getByPosition(arguments[0]).column; - auto & numcolumn = block.getByPosition(arguments[1]).column; + const auto & strcolumn = block.getByPosition(arguments[0]).column; + const auto & numcolumn = block.getByPosition(arguments[1]).column; if (const ColumnString * col = checkAndGetColumn(strcolumn.get())) { if (const ColumnConst * scale_column_num = checkAndGetColumn(numcolumn.get())) { - Field scale_field_num = scale_column_num->getField(); - UInt64 repeat_time = scale_field_num.get(); + UInt64 repeat_time = scale_column_num->getValue(); auto col_res = ColumnString::create(); - Impl::vectorNonConstStr(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time); + RepeatImpl::vectorStrConstRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), repeat_time); block.getByPosition(result).column = std::move(col_res); + return; } - else if (!castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type) - { - using DataType = std::decay_t; - using T0 = typename DataType::FieldType; - const ColumnVector * colnum = checkAndGetColumn>(numcolumn.get()); - auto col_res = ColumnString::create(); - Impl::vectorNonConst(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), colnum->getData()); - block.getByPosition(result).column = std::move(col_res); - return 0; - })); - else - throw Exception( - "Illegal column " + block.getByPosition(arguments[1]).column->getName() + " of argument of function " + getName(), - ErrorCodes::ILLEGAL_COLUMN); - } - else if (const ColumnConst * scale_column_str = checkAndGetColumn(strcolumn.get())) - { - Field scale_field_str = scale_column_str->getField(); - String copy_str = scale_field_str.get(); - if (const ColumnConst * scale_column_num = checkAndGetColumn(numcolumn.get())) - { - Field scale_field_num = scale_column_num->getField(); - UInt64 repeat_time = scale_field_num.get(); - auto col_res = ColumnString::create(); - Impl::vectorConst(copy_str, repeat_time, col_res->getChars(), col_res->getOffsets()); - block.getByPosition(result).column = std::move(col_res); - } - else if (!castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type) + else if (castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type) { using DataType = std::decay_t; - using T0 = typename DataType::FieldType; - const ColumnVector * colnum = checkAndGetColumn>(numcolumn.get()); + using T = typename DataType::FieldType; + const ColumnVector * colnum = checkAndGetColumn>(numcolumn.get()); auto col_res = ColumnString::create(); - Impl::vectorNonConstInteger(copy_str, col_res->getChars(), col_res->getOffsets(), colnum->getData()); + RepeatImpl::vectorStrVectorRepeat(col->getChars(), col->getOffsets(), col_res->getChars(), col_res->getOffsets(), colnum->getData()); block.getByPosition(result).column = std::move(col_res); - return 0; + return true; })) { + return; } - else - throw Exception( - "Illegal column " + block.getByPosition(arguments[1]).column->getName() + " of argument of function " + getName(), - ErrorCodes::ILLEGAL_COLUMN); } - else - throw Exception( - "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(), - ErrorCodes::ILLEGAL_COLUMN); + else if (const ColumnConst * col_const = checkAndGetColumn(strcolumn.get())) + { + /// Note that const-const case is handled by useDefaultImplementationForConstants. + + StringRef copy_str = col_const->getDataColumn().getDataAt(0); + + if (castType(block.getByPosition(arguments[1]).type.get(), [&](const auto & type) + { + using DataType = std::decay_t; + using T = typename DataType::FieldType; + const ColumnVector * colnum = checkAndGetColumn>(numcolumn.get()); + auto col_res = ColumnString::create(); + RepeatImpl::constStrVectorRepeat(copy_str, col_res->getChars(), col_res->getOffsets(), colnum->getData()); + block.getByPosition(result).column = std::move(col_res); + return true; + })) + { + return; + } + } + + throw Exception( + "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_COLUMN); } }; -using FunctionRepeat = FunctionRepeatImpl; void registerFunctionRepeat(FunctionFactory & factory) { factory.registerFunction(); } + } diff --git a/dbms/tests/queries/0_stateless/01013_repeat_function.sql b/dbms/tests/queries/0_stateless/01013_repeat_function.sql index 5de0e7a64e5..7d34307a21f 100644 --- a/dbms/tests/queries/0_stateless/01013_repeat_function.sql +++ b/dbms/tests/queries/0_stateless/01013_repeat_function.sql @@ -2,11 +2,11 @@ SELECT repeat('abc', 10); DROP TABLE IF EXISTS defaults; CREATE TABLE defaults ( - strings String, - u8 UInt8, - u16 UInt16, - u32 UInt32, - u64 UInt64 + strings String, + u8 UInt8, + u16 UInt16, + u32 UInt32, + u64 UInt64 )ENGINE = Memory(); INSERT INTO defaults values ('abc', 3, 12, 4, 56) ('sdfgg', 2, 10, 21, 200) ('xywq', 1, 4, 9, 5) ('plkf', 0, 5, 7,77);