diff --git a/docs/en/sql-reference/functions/string-replace-functions.md b/docs/en/sql-reference/functions/string-replace-functions.md index 3f50cd24f93..f8d46cacbdc 100644 --- a/docs/en/sql-reference/functions/string-replace-functions.md +++ b/docs/en/sql-reference/functions/string-replace-functions.md @@ -253,7 +253,7 @@ SELECT format('{} {}', 'Hello', 'World') ## translate -Replaces characters in the string `s` using a one-to-one character mapping defined by `from` and `to` strings. `from` and `to` must be constant ASCII strings of the same size. Non-ASCII characters in the original string are not modified. +Replaces characters in the string `s` using a one-to-one character mapping defined by `from` and `to` strings. `from` and `to` must be constant ASCII strings. Non-ASCII characters in the original string are not modified. If the number of characters in `from` list is larger than the `to` list, non overlapping characters will be deleted from the input string. **Syntax** diff --git a/src/Functions/translate.cpp b/src/Functions/translate.cpp index f7077f99629..1f4ff147a6a 100644 --- a/src/Functions/translate.cpp +++ b/src/Functions/translate.cpp @@ -32,17 +32,31 @@ struct TranslateImpl const std::string & map_from, const std::string & map_to) { - if (map_from.size() != map_to.size()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second and third arguments must be the same length"); - iota(map.data(), map.size(), UInt8(0)); - for (size_t i = 0; i < map_from.size(); ++i) + size_t min_size = std::min(map_from.size(), map_to.size()); + + // Map characters from map_from to map_to for the overlapping range + for (size_t i = 0; i < min_size; ++i) { if (!isASCII(map_from[i]) || !isASCII(map_to[i])) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second and third arguments must be ASCII strings"); + map[static_cast(map_from[i])] = static_cast(map_to[i]); + } - map[map_from[i]] = map_to[i]; + // Handle any remaining characters in map_from by assigning a default value + for (size_t i = min_size; i < map_from.size(); ++i) + { + if (!isASCII(map_from[i])) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument must be ASCII strings"); + map[static_cast(map_from[i])] = ascii_upper_bound + 1; + } + + // Validate any extra characters in map_to to ensure they are ASCII + for (size_t i = min_size; i < map_to.size(); ++i) + { + if (!isASCII(map_to[i])) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Third argument must be ASCII strings"); } } @@ -59,10 +73,11 @@ struct TranslateImpl fillMapWithValues(map, map_from, map_to); res_data.resize(data.size()); - res_offsets.assign(offsets); + res_offsets.resize(input_rows_count); UInt8 * dst = res_data.data(); + UInt64 data_size = 0; for (UInt64 i = 0; i < input_rows_count; ++i) { const UInt8 * src = data.data() + offsets[i - 1]; @@ -70,18 +85,24 @@ struct TranslateImpl while (src < src_end) { - if (*src <= ascii_upper_bound) - *dst = map[*src]; - else - *dst = *src; + if (*src <= ascii_upper_bound && map[*src] != ascii_upper_bound + 1) + { + *dst++ = map[*src]; + ++data_size; + } + else if (*src > ascii_upper_bound) + { + *dst++ = *src; + ++data_size; + } ++src; - ++dst; } /// Technically '\0' can be mapped into other character, /// so we need to process '\0' delimiter separately *dst++ = 0; + res_offsets[i] = ++data_size; } } @@ -92,6 +113,9 @@ struct TranslateImpl const std::string & map_to, ColumnString::Chars & res_data) { + if (map_from.size() != map_to.size()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second and third arguments must be the same length"); + std::array map; fillMapWithValues(map, map_from, map_to); @@ -128,12 +152,6 @@ struct TranslateUTF8Impl const std::string & map_from, const std::string & map_to) { - auto map_from_size = UTF8::countCodePoints(reinterpret_cast(map_from.data()), map_from.size()); - auto map_to_size = UTF8::countCodePoints(reinterpret_cast(map_to.data()), map_to.size()); - - if (map_from_size != map_to_size) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second and third arguments must be the same length"); - iota(map_ascii.data(), map_ascii.size(), UInt32(0)); const UInt8 * map_from_ptr = reinterpret_cast(map_from.data()); @@ -141,32 +159,44 @@ struct TranslateUTF8Impl const UInt8 * map_to_ptr = reinterpret_cast(map_to.data()); const UInt8 * map_to_end = map_to_ptr + map_to.size(); - while (map_from_ptr < map_from_end && map_to_ptr < map_to_end) + while (map_from_ptr < map_from_end) { size_t len_from = UTF8::seqLength(*map_from_ptr); - size_t len_to = UTF8::seqLength(*map_to_ptr); std::optional res_from, res_to; if (map_from_ptr + len_from <= map_from_end) res_from = UTF8::convertUTF8ToCodePoint(map_from_ptr, len_from); - if (map_to_ptr + len_to <= map_to_end) - res_to = UTF8::convertUTF8ToCodePoint(map_to_ptr, len_to); - if (!res_from) throw Exception(ErrorCodes::BAD_ARGUMENTS, "Second argument must be a valid UTF-8 string"); - if (!res_to) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Third argument must be a valid UTF-8 string"); + if (map_to_ptr < map_to_end) + { + size_t len_to = UTF8::seqLength(*map_to_ptr); - if (*map_from_ptr <= ascii_upper_bound) - map_ascii[*map_from_ptr] = *res_to; + if (map_to_ptr + len_to <= map_to_end) + res_to = UTF8::convertUTF8ToCodePoint(map_to_ptr, len_to); + + if (!res_to) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Third argument must be a valid UTF-8 string"); + + if (*map_from_ptr <= ascii_upper_bound) + map_ascii[*map_from_ptr] = *res_to; + else + map[*res_from] = *res_to; + + map_to_ptr += len_to; + } else - map[*res_from] = *res_to; + { + if (*map_from_ptr <= ascii_upper_bound) + map_ascii[*map_from_ptr] = max_uint32; + else + map[*res_from] = max_uint32; + } map_from_ptr += len_from; - map_to_ptr += len_to; } } @@ -205,6 +235,12 @@ struct TranslateUTF8Impl if (*src <= ascii_upper_bound) { + if (map_ascii[*src] == max_uint32) + { + src += 1; + continue; + } + size_t dst_len = UTF8::convertCodePointToUTF8(map_ascii[*src], dst, 4); assert(0 < dst_len && dst_len <= 4); @@ -226,10 +262,13 @@ struct TranslateUTF8Impl auto * it = map.find(*src_code_point); if (it != map.end()) { + src += src_len; + if (it->getMapped() == max_uint32) + continue; + size_t dst_len = UTF8::convertCodePointToUTF8(it->getMapped(), dst, 4); assert(0 < dst_len && dst_len <= 4); - src += src_len; dst += dst_len; data_size += dst_len; continue; @@ -270,6 +309,7 @@ struct TranslateUTF8Impl private: static constexpr auto ascii_upper_bound = '\x7f'; + static constexpr auto max_uint32 = 0xffffffff; }; diff --git a/tests/queries/0_stateless/02353_translate.reference b/tests/queries/0_stateless/02353_translate.reference index 557b5182127..16bf6bc6015 100644 --- a/tests/queries/0_stateless/02353_translate.reference +++ b/tests/queries/0_stateless/02353_translate.reference @@ -14,3 +14,9 @@ HotelGenev ¿йðՅনй abc abc +abc +内码 +1A2BC +1A2B3C +ABC + diff --git a/tests/queries/0_stateless/02353_translate.sql b/tests/queries/0_stateless/02353_translate.sql index f6f40c4265d..62a4eda0115 100644 --- a/tests/queries/0_stateless/02353_translate.sql +++ b/tests/queries/0_stateless/02353_translate.sql @@ -10,4 +10,12 @@ SELECT translate('abc', '', ''); SELECT translateUTF8('abc', '', ''); SELECT translate('abc', 'Ááéíóúôè', 'aaeiouoe'); -- { serverError BAD_ARGUMENTS } -SELECT translateUTF8('abc', 'efg', ''); -- { serverError BAD_ARGUMENTS } +SELECT translateUTF8('abc', 'efg', ''); + +SELECT translateUTF8('中文内码', '中文', ''); +SELECT translate('aAbBcC', 'abc', '12'); + +SELECT translate('aAbBcC', 'abc', '1235'); +SELECT translate('aAbBcC', 'abc', ''); +SELECT translate('abc', 'abc', ''); +SELECT translate('aAbBcC', '中文内码', '12'); -- { serverError BAD_ARGUMENTS }