From deddd1db312822460dcf163c7ccf1160862af6a1 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:22:04 +0100 Subject: [PATCH] Change the logic --- .../reference/groupconcat.md | 3 +- .../AggregateFunctionGroupConcat.cpp | 56 ++++--------------- src/Parsers/ASTFunction.cpp | 42 +++++++++++++- src/Parsers/ASTFunction.h | 2 + .../0_stateless/03156_group_concat.sql | 4 +- 5 files changed, 56 insertions(+), 51 deletions(-) diff --git a/docs/en/sql-reference/aggregate-functions/reference/groupconcat.md b/docs/en/sql-reference/aggregate-functions/reference/groupconcat.md index 7f22e4125a6..304676f1819 100644 --- a/docs/en/sql-reference/aggregate-functions/reference/groupconcat.md +++ b/docs/en/sql-reference/aggregate-functions/reference/groupconcat.md @@ -15,8 +15,9 @@ groupConcat[(delimiter [, limit])](expression); **Arguments** +- `delimiter` — A [string](../../../sql-reference/data-types/string.md) that will be used to separate concatenated values. This parameter is optional and defaults to an empty string or delimiter from parameters if not specified. - `expression` — The expression or column name that outputs strings to be concatenated. -- `delimiter` — A [string](../../../sql-reference/data-types/string.md) that will be used to separate concatenated values. This parameter is optional and defaults to an empty string if not specified. + **Parameters** diff --git a/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp b/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp index f27c99db0b0..8cf5ec5705a 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp @@ -31,24 +31,16 @@ namespace ErrorCodes extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int BAD_ARGUMENTS; - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } namespace { -enum GroupConcatOverload -{ - ONE_ARGUMENT, - TWO_ARGUMENTS -}; - struct GroupConcatDataBase { UInt64 data_size = 0; UInt64 allocated_size = 0; char * data = nullptr; - String data_delimiter; void checkAndUpdateSize(UInt64 add, Arena * arena) { @@ -125,18 +117,16 @@ class GroupConcatImpl final UInt64 limit; const String delimiter; const DataTypePtr type; - GroupConcatOverload overload = ONE_ARGUMENT; public: - GroupConcatImpl(const DataTypes & data_types, const Array & parameters_, UInt64 limit_, const String & delimiter_) + GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_) : IAggregateFunctionDataHelper, GroupConcatImpl>( - {data_types}, parameters_, std::make_shared()) + {data_type_}, parameters_, std::make_shared()) , limit(limit_) , delimiter(delimiter_) - , type(data_types[0]) + , type(data_type_) { serialization = isFixedString(type) ? std::make_shared()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization(); - overload = data_types.size() > 1 ? TWO_ARGUMENTS : ONE_ARGUMENT; } String getName() const override { return name; } @@ -144,20 +134,13 @@ public: void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override { auto & cur_data = this->data(place); - if (cur_data.data_delimiter.empty()) - { - cur_data.data_delimiter = delimiter; - if (overload == GroupConcatOverload::TWO_ARGUMENTS) - cur_data.data_delimiter = columns[1]->getDataAt(0).toString(); - } - const String & delim = cur_data.data_delimiter; if constexpr (has_limit) if (cur_data.num_rows >= limit) return; if (cur_data.data_size != 0) - cur_data.insertChar(delim.c_str(), delim.size(), arena); + cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); if (isFixedString(type)) { @@ -177,15 +160,13 @@ public: if (rhs_data.data_size == 0) return; - const String & delim = cur_data.data_delimiter; - if constexpr (has_limit) { UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows); for (UInt64 i = 0; i < new_elems_count; ++i) { if (cur_data.data_size != 0) - cur_data.insertChar(delim.c_str(), delim.size(), arena); + cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); cur_data.offsets.push_back(cur_data.data_size, arena); cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena); @@ -196,7 +177,7 @@ public: else { if (cur_data.data_size != 0) - cur_data.insertChar(delim.c_str(), delim.size(), arena); + cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena); } @@ -257,16 +238,9 @@ public: }; AggregateFunctionPtr createAggregateFunctionGroupConcat( - const std::string & name, - const DataTypes & argument_types, - const Array & parameters, - const Settings * /* settings */ -) + const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) { - // Ensure the number of arguments is between 1 and 2 - if (argument_types.empty() || argument_types.size() > 2) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Incorrect number of arguments for function {}, expected 1 to 2, got {}", name, argument_types.size()); + assertUnary(name, argument_types); bool has_limit = false; UInt64 limit = 0; @@ -299,19 +273,9 @@ AggregateFunctionPtr createAggregateFunctionGroupConcat( limit = parameters[1].safeGet(); } - // Handle additional arguments if provided (delimiter and limit as arguments) - if (argument_types.size() == 2) - { - // Second argument should be delimiter (string) - const DataTypePtr & delimiter_type = argument_types[1]; - if (!isString(delimiter_type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Second argument for function {} must be a string", name); - } - if (has_limit) - return std::make_shared>(argument_types, parameters, limit, delimiter); - return std::make_shared>(argument_types, parameters, limit, delimiter); + return std::make_shared>(argument_types[0], parameters, limit, delimiter); + return std::make_shared>(argument_types[0], parameters, limit, delimiter); } } diff --git a/src/Parsers/ASTFunction.cpp b/src/Parsers/ASTFunction.cpp index 11cfe2e584e..358b7b1e26b 100644 --- a/src/Parsers/ASTFunction.cpp +++ b/src/Parsers/ASTFunction.cpp @@ -130,13 +130,51 @@ String ASTFunction::getID(char delim) const return "Function" + (delim + name); } +void ASTFunction::groupConcatArgumentOverride(std::shared_ptr res) const +{ + // Clone the first argument to be used as a parameter + ASTPtr first_arg = arguments->children[0]->clone(); + + // Clone the second argument to remain as the function argument + ASTPtr second_arg = arguments->children[1]->clone(); + + // Initialize or clear parameters + if (!res->parameters) + res->parameters = std::make_shared(); + else + res->parameters->children.clear(); + + // Add the first argument as a parameter + res->parameters->children.emplace_back(first_arg); + res->children.emplace_back(res->parameters); + + // Initialize arguments with the second argument only + res->arguments = std::make_shared(); + res->arguments->children.emplace_back(second_arg); + res->children.emplace_back(res->arguments); +} + ASTPtr ASTFunction::clone() const { auto res = std::make_shared(*this); res->children.clear(); - if (arguments) { res->arguments = arguments->clone(); res->children.push_back(res->arguments); } - if (parameters) { res->parameters = parameters->clone(); res->children.push_back(res->parameters); } + // Special handling for groupConcat with two arguments + if (name == "groupConcat" && arguments && arguments->children.size() == 2) + groupConcatArgumentOverride(res); + else + { + if (arguments) + { + res->arguments = arguments->clone(); + res->children.push_back(res->arguments); + } + if (parameters) + { + res->parameters = parameters->clone(); + res->children.push_back(res->parameters); + } + } if (window_definition) { diff --git a/src/Parsers/ASTFunction.h b/src/Parsers/ASTFunction.h index 1b4a5928d1c..b6aae46d21e 100644 --- a/src/Parsers/ASTFunction.h +++ b/src/Parsers/ASTFunction.h @@ -81,6 +81,8 @@ public: bool hasSecretParts() const override; + void groupConcatArgumentOverride(std::shared_ptr res) const; + protected: void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; void appendColumnNameImpl(WriteBuffer & ostr) const override; diff --git a/tests/queries/0_stateless/03156_group_concat.sql b/tests/queries/0_stateless/03156_group_concat.sql index 8989bb20df9..c6fea68551a 100644 --- a/tests/queries/0_stateless/03156_group_concat.sql +++ b/tests/queries/0_stateless/03156_group_concat.sql @@ -44,9 +44,9 @@ SELECT length(groupConcat(number)) FROM numbers(100000); SELECT 'TESTING GroupConcat second argument overload'; -SELECT groupConcat(p_int, ',') FROM test_groupConcat; +SELECT groupConcat(',', p_int) FROM test_groupConcat; SELECT groupConcat('.')(p_string) FROM test_groupConcat; -SELECT groupConcat(p_array, '/') FROM test_groupConcat; +SELECT groupConcat('/', p_array) FROM test_groupConcat; DROP TABLE IF EXISTS test_groupConcat;