diff --git a/src/Functions/GatherUtils/Sources.h b/src/Functions/GatherUtils/Sources.h index 4dbaff9f567..9a459860a68 100644 --- a/src/Functions/GatherUtils/Sources.h +++ b/src/Functions/GatherUtils/Sources.h @@ -755,6 +755,7 @@ struct GenericValueSource : public ValueSourceImpl { using Slice = GenericValueSlice; using SinkType = GenericArraySink; + using Column = IColumn; const IColumn * column; size_t total_rows; diff --git a/src/Functions/leftPadString.cpp b/src/Functions/leftPadString.cpp deleted file mode 100644 index cdcfb46eb73..00000000000 --- a/src/Functions/leftPadString.cpp +++ /dev/null @@ -1,194 +0,0 @@ -#include -#include -#include - -#include -#include -#include - -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int ILLEGAL_COLUMN; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; - extern const int BAD_ARGUMENTS; -} - -namespace -{ - struct LeftPadStringImpl - { - static void vector( - const ColumnString::Chars & data, - const ColumnString::Offsets & offsets, - const size_t length, - const String & padstr, - ColumnString::Chars & res_data, - ColumnString::Offsets & res_offsets) - { - size_t size = offsets.size(); - res_data.resize((length + 1 /* zero terminator */) * size); - res_offsets.resize(size); - - const size_t padstr_size = padstr.size(); - - ColumnString::Offset prev_offset = 0; - ColumnString::Offset res_prev_offset = 0; - for (size_t i = 0; i < size; ++i) - { - size_t data_length = offsets[i] - prev_offset - 1 /* zero terminator */; - if (data_length < length) - { - for (size_t j = 0; j < length - data_length; ++j) - res_data[res_prev_offset + j] = padstr[j % padstr_size]; - memcpy(&res_data[res_prev_offset + length - data_length], &data[prev_offset], data_length); - } - else - { - memcpy(&res_data[res_prev_offset], &data[prev_offset], length); - } - res_data[res_prev_offset + length] = 0; - res_prev_offset += length + 1; - res_offsets[i] = res_prev_offset; - } - } - - static void vectorFixed( - const ColumnFixedString::Chars & data, - const size_t n, - const size_t length, - const String & padstr, - ColumnFixedString::Chars & res_data) - { - const size_t padstr_size = padstr.size(); - const size_t size = data.size() / n; - res_data.resize(length * size); - for (size_t i = 0; i < size; ++i) - { - if (length < n) - { - memcpy(&res_data[i * length], &data[i * n], length); - } - else - { - for (size_t j = 0; j < length - n; ++j) - res_data[i * length + j] = padstr[j % padstr_size]; - memcpy(&res_data[i * length + length - n], &data[i * n], n); - } - } - } - }; - - class FunctionLeftPadString : public IFunction - { - public: - static constexpr auto name = "leftPadString"; - static FunctionPtr create(const ContextPtr) { return std::make_shared(); } - - String getName() const override { return name; } - - bool isVariadic() const override { return true; } - size_t getNumberOfArguments() const override { return 0; } - - bool useDefaultImplementationForConstants() const override { return true; } - - DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override - { - size_t number_of_arguments = arguments.size(); - - if (number_of_arguments != 2 && number_of_arguments != 3) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", - getName(), - toString(number_of_arguments)); - - if (!isStringOrFixedString(arguments[0])) - throw Exception( - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}", arguments[0]->getName(), getName()); - - if (!isNativeNumber(arguments[1])) - throw Exception( - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of second argument of function {}", - arguments[1]->getName(), - getName()); - - if (number_of_arguments == 3 && !isStringOrFixedString(arguments[2])) - throw Exception( - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Illegal type {} of third argument of function {}", - arguments[2]->getName(), - getName()); - - return arguments[0]; - } - - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override - { - const ColumnPtr str_column = arguments[0].column; - String padstr = " "; - if (arguments.size() == 3) - { - const ColumnConst * pad_column = checkAndGetColumnConst(arguments[2].column.get()); - if (!pad_column) - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of third ('pad') argument of function {}. Must be constant string.", - arguments[2].column->getName(), - getName()); - - padstr = pad_column->getValue(); - } - - const ColumnConst * len_column = checkAndGetColumnConst(arguments[1].column.get()); - if (!len_column) - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of second ('len') argument of function {}. Must be a positive integer.", - arguments[1].column->getName(), - getName()); - Int64 len = len_column->getInt(0); - if (len <= 0) - throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Illegal value {} of second ('len') argument of function {}. Must be a positive integer.", - arguments[1].column->getName(), - getName()); - - if (const ColumnString * strings = checkAndGetColumn(str_column.get())) - { - auto col_res = ColumnString::create(); - LeftPadStringImpl::vector( - strings->getChars(), strings->getOffsets(), len, padstr, col_res->getChars(), col_res->getOffsets()); - return col_res; - } - else if (const ColumnFixedString * strings_fixed = checkAndGetColumn(str_column.get())) - { - auto col_res = ColumnFixedString::create(len); - LeftPadStringImpl::vectorFixed(strings_fixed->getChars(), strings_fixed->getN(), len, padstr, col_res->getChars()); - return col_res; - } - else - { - throw Exception( - ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of first ('str') argument of function {}. Must be a string or fixed string.", - arguments[0].column->getName(), - getName()); - } - } - }; -} - -void registerFunctionLeftPadString(FunctionFactory & factory) -{ - factory.registerFunction(FunctionFactory::CaseInsensitive); - factory.registerAlias("lpad", "leftPadString", FunctionFactory::CaseInsensitive); -} - -} diff --git a/src/Functions/padString.cpp b/src/Functions/padString.cpp new file mode 100644 index 00000000000..7711ab1a056 --- /dev/null +++ b/src/Functions/padString.cpp @@ -0,0 +1,308 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +using namespace GatherUtils; + +namespace ErrorCodes +{ + extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int TOO_LARGE_STRING_SIZE; +} + +namespace +{ + /// The maximum new padded length. + constexpr size_t MAX_NEW_LENGTH = 1000000; + + /// Appends padding characters to a sink based on a pad string. + /// Depending on how many padding characters are required to add + /// the pad string can be copied only partly or be repeated multiple times. + template + class PaddingChars + { + public: + explicit PaddingChars(const String & pad_string_) : pad_string(pad_string_) { init(); } + + ALWAYS_INLINE size_t numCharsInPadString() const + { + if constexpr (is_utf8) + return utf8_offsets.size() - 1; + else + return pad_string.length(); + } + + ALWAYS_INLINE size_t numCharsToNumBytes(size_t count) const + { + if constexpr (is_utf8) + return utf8_offsets[count]; + else + return count; + } + + void appendTo(StringSink & res_sink, size_t num_chars) const + { + if (!num_chars) + return; + + const size_t step = numCharsInPadString(); + while (true) + { + if (num_chars <= step) + { + writeSlice(StringSource::Slice{bit_cast(pad_string.data()), numCharsToNumBytes(num_chars)}, res_sink); + break; + } + writeSlice(StringSource::Slice{bit_cast(pad_string.data()), numCharsToNumBytes(step)}, res_sink); + num_chars -= step; + } + } + + private: + void init() + { + if (pad_string.empty()) + pad_string = " "; + + if constexpr (is_utf8) + { + size_t offset = 0; + utf8_offsets.reserve(pad_string.length() + 1); + while (true) + { + utf8_offsets.push_back(offset); + if (offset == pad_string.length()) + break; + offset += UTF8::seqLength(pad_string[offset]); + if (offset > pad_string.length()) + offset = pad_string.length(); + } + } + + /// Not necessary, but good for performance. + while (numCharsInPadString() < 16) + { + pad_string += pad_string; + if constexpr (is_utf8) + { + size_t old_size = utf8_offsets.size(); + utf8_offsets.reserve((old_size - 1) * 2); + size_t base = utf8_offsets.back(); + for (size_t i = 1; i != old_size; ++i) + utf8_offsets.push_back(utf8_offsets[i] + base); + } + } + } + + String pad_string; + std::vector utf8_offsets; + }; + + /// Returns the number of characters in a slice. + template + inline ALWAYS_INLINE size_t getLengthOfSlice(const StringSource::Slice & slice) + { + if constexpr (is_utf8) + return UTF8::countCodePoints(slice.data, slice.size); + else + return slice.size; + } + + /// Moves the end of a slice back by n characters. + template + inline ALWAYS_INLINE StringSource::Slice removeSuffixFromSlice(const StringSource::Slice & slice, size_t suffix_length) + { + StringSource::Slice res = slice; + if constexpr (is_utf8) + res.size = UTF8StringSource::skipCodePointsBackward(slice.data + slice.size, suffix_length, slice.data) - res.data; + else + res.size -= std::min(suffix_length, res.size); + return res; + } + + /// If `is_right_pad` - it's the rightPad() function instead of leftPad(). + /// If `is_utf8` - lengths are measured in code points instead of bytes. + template + class FunctionPadString : public IFunction + { + public: + static constexpr auto name = is_right_pad ? (is_utf8 ? "rightPadUTF8" : "rightPad") : (is_utf8 ? "leftPadUTF8" : "leftPad"); + static FunctionPtr create(const ContextPtr) { return std::make_shared(); } + + String getName() const override { return name; } + + bool isVariadic() const override { return true; } + size_t getNumberOfArguments() const override { return 0; } + + bool useDefaultImplementationForConstants() const override { return false; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + size_t number_of_arguments = arguments.size(); + + if (number_of_arguments != 2 && number_of_arguments != 3) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Number of arguments for function {} doesn't match: passed {}, should be 2 or 3", + getName(), + std::to_string(number_of_arguments)); + + if (!isStringOrFixedString(arguments[0])) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of the first argument of function {}, should be string", + arguments[0]->getName(), + getName()); + + if (!isUnsignedInteger(arguments[1])) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of the second argument of function {}, should be unsigned integer", + arguments[1]->getName(), + getName()); + + if (number_of_arguments == 3 && !isStringOrFixedString(arguments[2])) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of the third argument of function {}, should be const string", + arguments[2]->getName(), + getName()); + + return arguments[0]; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + auto column_string = arguments[0].column; + auto column_length = arguments[1].column; + + String pad_string; + if (arguments.size() == 3) + { + auto column_pad = arguments[2].column; + const ColumnConst * column_pad_const = checkAndGetColumnConst(column_pad.get()); + if (!column_pad_const) + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal column {}, third argument of function {} must be a constant string", + column_pad->getName(), + getName()); + + pad_string = column_pad_const->getValue(); + } + PaddingChars padding_chars{pad_string}; + + auto col_res = ColumnString::create(); + StringSink res_sink{*col_res, input_rows_count}; + + if (const ColumnString * col = checkAndGetColumn(column_string.get())) + executeForSource(StringSource{*col}, column_length, padding_chars, res_sink); + else if (const ColumnFixedString * col_fixed = checkAndGetColumn(column_string.get())) + executeForSource(FixedStringSource{*col_fixed}, column_length, padding_chars, res_sink); + else if (const ColumnConst * col_const = checkAndGetColumnConst(column_string.get())) + executeForSource(ConstSource{*col_const}, column_length, padding_chars, res_sink); + else if (const ColumnConst * col_const_fixed = checkAndGetColumnConst(column_string.get())) + executeForSource(ConstSource{*col_const_fixed}, column_length, padding_chars, res_sink); + else + throw Exception( + ErrorCodes::ILLEGAL_COLUMN, + "Illegal column {}, first argument of function {} must be a string", + arguments[0].column->getName(), + getName()); + + return col_res; + } + + private: + template + void executeForSource( + SourceStrings && strings, + const ColumnPtr & column_length, + const PaddingChars & padding_chars, + StringSink & res_sink) const + { + if (const auto * col_const = checkAndGetColumn(column_length.get())) + executeForSourceAndLength(std::forward(strings), ConstSource{*col_const}, padding_chars, res_sink); + else + executeForSourceAndLength(std::forward(strings), GenericValueSource{*column_length}, padding_chars, res_sink); + } + + template + void executeForSourceAndLength( + SourceStrings && strings, + SourceLengths && lengths, + const PaddingChars & padding_chars, + StringSink & res_sink) const + { + bool is_const_length = lengths.isConst(); + bool need_check_length = true; + + for (; !res_sink.isEnd(); res_sink.next(), strings.next(), lengths.next()) + { + auto str = strings.getWhole(); + size_t current_length = getLengthOfSlice(str); + + auto new_length_slice = lengths.getWhole(); + size_t new_length = new_length_slice.elements->getUInt(new_length_slice.position); + + if (need_check_length) + { + if (new_length > MAX_NEW_LENGTH) + { + throw Exception( + "New padded length (" + std::to_string(new_length) + ") is too big, maximum is: " + std::to_string(MAX_NEW_LENGTH), + ErrorCodes::TOO_LARGE_STRING_SIZE); + } + if (is_const_length) + { + size_t rows_count = res_sink.offsets.size(); + res_sink.reserve((new_length + 1 /* zero terminator */) * rows_count); + need_check_length = false; + } + } + + if (new_length == current_length) + { + writeSlice(str, res_sink); + } + else if (new_length < current_length) + { + str = removeSuffixFromSlice(str, current_length - new_length); + writeSlice(str, res_sink); + } + else if (new_length > current_length) + { + if constexpr (!is_right_pad) + padding_chars.appendTo(res_sink, new_length - current_length); + + writeSlice(str, res_sink); + + if constexpr (is_right_pad) + padding_chars.appendTo(res_sink, new_length - current_length); + } + } + } + }; +} + +void registerFunctionPadString(FunctionFactory & factory) +{ + factory.registerFunction>(); /// leftPad + factory.registerFunction>(); /// leftPadUTF8 + factory.registerFunction>(); /// rightPad + factory.registerFunction>(); /// rightPadUTF8 + + factory.registerAlias("lpad", "leftPad", FunctionFactory::CaseInsensitive); + factory.registerAlias("rpad", "rightPad", FunctionFactory::CaseInsensitive); +} + +} diff --git a/src/Functions/registerFunctionsString.cpp b/src/Functions/registerFunctionsString.cpp index 1c487981844..18a30469386 100644 --- a/src/Functions/registerFunctionsString.cpp +++ b/src/Functions/registerFunctionsString.cpp @@ -29,13 +29,11 @@ void registerFunctionAppendTrailingCharIfAbsent(FunctionFactory &); void registerFunctionStartsWith(FunctionFactory &); void registerFunctionEndsWith(FunctionFactory &); void registerFunctionTrim(FunctionFactory &); +void registerFunctionPadString(FunctionFactory &); void registerFunctionRegexpQuoteMeta(FunctionFactory &); void registerFunctionNormalizeQuery(FunctionFactory &); void registerFunctionNormalizedQueryHash(FunctionFactory &); void registerFunctionCountMatches(FunctionFactory &); -void registerFunctionEncodeXMLComponent(FunctionFactory & factory); -void registerFunctionDecodeXMLComponent(FunctionFactory & factory); -void registerFunctionLeftPadString(FunctionFactory & factory); void registerFunctionEncodeXMLComponent(FunctionFactory &); void registerFunctionDecodeXMLComponent(FunctionFactory &); void registerFunctionExtractTextFromHTML(FunctionFactory &); @@ -71,13 +69,13 @@ void registerFunctionsString(FunctionFactory & factory) registerFunctionStartsWith(factory); registerFunctionEndsWith(factory); registerFunctionTrim(factory); + registerFunctionPadString(factory); registerFunctionRegexpQuoteMeta(factory); registerFunctionNormalizeQuery(factory); registerFunctionNormalizedQueryHash(factory); registerFunctionCountMatches(factory); registerFunctionEncodeXMLComponent(factory); registerFunctionDecodeXMLComponent(factory); - registerFunctionLeftPadString(factory); registerFunctionExtractTextFromHTML(factory); #if USE_BASE64 registerFunctionBase64Encode(factory); diff --git a/src/Functions/ya.make b/src/Functions/ya.make index ba14e9a3e02..5f84511aa52 100644 --- a/src/Functions/ya.make +++ b/src/Functions/ya.make @@ -332,7 +332,6 @@ SRCS( jumpConsistentHash.cpp lcm.cpp least.cpp - leftPadString.cpp lengthUTF8.cpp less.cpp lessOrEquals.cpp @@ -388,6 +387,7 @@ SRCS( now.cpp now64.cpp nullIf.cpp + padString.cpp partitionId.cpp pi.cpp plus.cpp diff --git a/tests/queries/0_stateless/01940_pad_string.reference b/tests/queries/0_stateless/01940_pad_string.reference new file mode 100644 index 00000000000..22cd3f9be07 --- /dev/null +++ b/tests/queries/0_stateless/01940_pad_string.reference @@ -0,0 +1,54 @@ +leftPad + +a +ab +abc + abc + abc + abc +ab +*abc +**abc +*******abc +ab +*abc +*.abc +*.*.*.*abc +leftPadUTF8 +а +аб +аб +абвг +ЧАабвг +ЧАСЧАСЧАабвг +rightPad + +a +ab +abc +abc +abc +abc +ab +abc* +abc** +abc******* +ab +abc* +abc*. +abc*.*.*.* +rightPadUTF8 +а +аб +аб +абвг +абвгЧА +абвгЧАСЧАСЧА +numbers + +1^ +_2^^ +__3^^^ +___4^^^^ +____5^^^^^ +_____6^^^^^^ diff --git a/tests/queries/0_stateless/01940_pad_string.sql b/tests/queries/0_stateless/01940_pad_string.sql new file mode 100644 index 00000000000..e4ba0aec6d2 --- /dev/null +++ b/tests/queries/0_stateless/01940_pad_string.sql @@ -0,0 +1,54 @@ +SELECT 'leftPad'; +SELECT leftPad('abc', 0); +SELECT leftPad('abc', 1); +SELECT leftPad('abc', 2); +SELECT leftPad('abc', 3); +SELECT leftPad('abc', 4); +SELECT leftPad('abc', 5); +SELECT leftPad('abc', 10); + +SELECT leftPad('abc', 2, '*'); +SELECT leftPad('abc', 4, '*'); +SELECT leftPad('abc', 5, '*'); +SELECT leftPad('abc', 10, '*'); +SELECT leftPad('abc', 2, '*.'); +SELECT leftPad('abc', 4, '*.'); +SELECT leftPad('abc', 5, '*.'); +SELECT leftPad('abc', 10, '*.'); + +SELECT 'leftPadUTF8'; +SELECT leftPad('абвг', 2); +SELECT leftPadUTF8('абвг', 2); +SELECT leftPad('абвг', 4); +SELECT leftPadUTF8('абвг', 4); +SELECT leftPad('абвг', 12, 'ЧАС'); +SELECT leftPadUTF8('абвг', 12, 'ЧАС'); + +SELECT 'rightPad'; +SELECT rightPad('abc', 0); +SELECT rightPad('abc', 1); +SELECT rightPad('abc', 2); +SELECT rightPad('abc', 3); +SELECT rightPad('abc', 4); +SELECT rightPad('abc', 5); +SELECT rightPad('abc', 10); + +SELECT rightPad('abc', 2, '*'); +SELECT rightPad('abc', 4, '*'); +SELECT rightPad('abc', 5, '*'); +SELECT rightPad('abc', 10, '*'); +SELECT rightPad('abc', 2, '*.'); +SELECT rightPad('abc', 4, '*.'); +SELECT rightPad('abc', 5, '*.'); +SELECT rightPad('abc', 10, '*.'); + +SELECT 'rightPadUTF8'; +SELECT rightPad('абвг', 2); +SELECT rightPadUTF8('абвг', 2); +SELECT rightPad('абвг', 4); +SELECT rightPadUTF8('абвг', 4); +SELECT rightPad('абвг', 12, 'ЧАС'); +SELECT rightPadUTF8('абвг', 12, 'ЧАС'); + +SELECT 'numbers'; +SELECT rightPad(leftPad(toString(number), number, '_'), number*2, '^') FROM numbers(7);