Change the logic

This commit is contained in:
Yarik Briukhovetskyi 2024-12-02 20:22:04 +01:00 committed by GitHub
parent c2f74fa4aa
commit deddd1db31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 56 additions and 51 deletions

View File

@ -15,8 +15,9 @@ groupConcat[(delimiter [, limit])](expression);
**Arguments** **Arguments**
- `delimiter` — A [string](../../../sql-reference/data-types/string.md) that will be used to separate concatenated values. This argument is optional: if it is not specified, the delimiter given in the parameters is used, and if that is not specified either, it defaults to an empty string.
- `expression` — The expression or column name that outputs strings to be concatenated. - `expression` — The expression or column name that outputs strings to be concatenated.
- `delimiter` — A [string](../../../sql-reference/data-types/string.md) that will be used to separate concatenated values. This parameter is optional and defaults to an empty string if not specified.
**Parameters** **Parameters**

View File

@ -31,24 +31,16 @@ namespace ErrorCodes
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION; extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int BAD_ARGUMENTS; extern const int BAD_ARGUMENTS;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
} }
namespace namespace
{ {
enum GroupConcatOverload
{
ONE_ARGUMENT,
TWO_ARGUMENTS
};
struct GroupConcatDataBase struct GroupConcatDataBase
{ {
UInt64 data_size = 0; UInt64 data_size = 0;
UInt64 allocated_size = 0; UInt64 allocated_size = 0;
char * data = nullptr; char * data = nullptr;
String data_delimiter;
void checkAndUpdateSize(UInt64 add, Arena * arena) void checkAndUpdateSize(UInt64 add, Arena * arena)
{ {
@ -125,18 +117,16 @@ class GroupConcatImpl final
UInt64 limit; UInt64 limit;
const String delimiter; const String delimiter;
const DataTypePtr type; const DataTypePtr type;
GroupConcatOverload overload = ONE_ARGUMENT;
public: public:
GroupConcatImpl(const DataTypes & data_types, const Array & parameters_, UInt64 limit_, const String & delimiter_) GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_)
: IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>( : IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>(
{data_types}, parameters_, std::make_shared<DataTypeString>()) {data_type_}, parameters_, std::make_shared<DataTypeString>())
, limit(limit_) , limit(limit_)
, delimiter(delimiter_) , delimiter(delimiter_)
, type(data_types[0]) , type(data_type_)
{ {
serialization = isFixedString(type) ? std::make_shared<DataTypeString>()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization(); serialization = isFixedString(type) ? std::make_shared<DataTypeString>()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization();
overload = data_types.size() > 1 ? TWO_ARGUMENTS : ONE_ARGUMENT;
} }
String getName() const override { return name; } String getName() const override { return name; }
@ -144,20 +134,13 @@ public:
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ {
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
if (cur_data.data_delimiter.empty())
{
cur_data.data_delimiter = delimiter;
if (overload == GroupConcatOverload::TWO_ARGUMENTS)
cur_data.data_delimiter = columns[1]->getDataAt(0).toString();
}
const String & delim = cur_data.data_delimiter;
if constexpr (has_limit) if constexpr (has_limit)
if (cur_data.num_rows >= limit) if (cur_data.num_rows >= limit)
return; return;
if (cur_data.data_size != 0) if (cur_data.data_size != 0)
cur_data.insertChar(delim.c_str(), delim.size(), arena); cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
if (isFixedString(type)) if (isFixedString(type))
{ {
@ -177,15 +160,13 @@ public:
if (rhs_data.data_size == 0) if (rhs_data.data_size == 0)
return; return;
const String & delim = cur_data.data_delimiter;
if constexpr (has_limit) if constexpr (has_limit)
{ {
UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows); UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows);
for (UInt64 i = 0; i < new_elems_count; ++i) for (UInt64 i = 0; i < new_elems_count; ++i)
{ {
if (cur_data.data_size != 0) if (cur_data.data_size != 0)
cur_data.insertChar(delim.c_str(), delim.size(), arena); cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
cur_data.offsets.push_back(cur_data.data_size, arena); cur_data.offsets.push_back(cur_data.data_size, arena);
cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena); cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena);
@ -196,7 +177,7 @@ public:
else else
{ {
if (cur_data.data_size != 0) if (cur_data.data_size != 0)
cur_data.insertChar(delim.c_str(), delim.size(), arena); cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena); cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena);
} }
@ -257,16 +238,9 @@ public:
}; };
AggregateFunctionPtr createAggregateFunctionGroupConcat( AggregateFunctionPtr createAggregateFunctionGroupConcat(
const std::string & name, const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
const DataTypes & argument_types,
const Array & parameters,
const Settings * /* settings */
)
{ {
// Ensure the number of arguments is between 1 and 2 assertUnary(name, argument_types);
if (argument_types.empty() || argument_types.size() > 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Incorrect number of arguments for function {}, expected 1 to 2, got {}", name, argument_types.size());
bool has_limit = false; bool has_limit = false;
UInt64 limit = 0; UInt64 limit = 0;
@ -299,19 +273,9 @@ AggregateFunctionPtr createAggregateFunctionGroupConcat(
limit = parameters[1].safeGet<UInt64>(); limit = parameters[1].safeGet<UInt64>();
} }
// Handle additional arguments if provided (delimiter and limit as arguments)
if (argument_types.size() == 2)
{
// Second argument should be delimiter (string)
const DataTypePtr & delimiter_type = argument_types[1];
if (!isString(delimiter_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Second argument for function {} must be a string", name);
}
if (has_limit) if (has_limit)
return std::make_shared<GroupConcatImpl</* has_limit= */ true>>(argument_types, parameters, limit, delimiter); return std::make_shared<GroupConcatImpl</* has_limit= */ true>>(argument_types[0], parameters, limit, delimiter);
return std::make_shared<GroupConcatImpl</* has_limit= */ false>>(argument_types, parameters, limit, delimiter); return std::make_shared<GroupConcatImpl</* has_limit= */ false>>(argument_types[0], parameters, limit, delimiter);
} }
} }

View File

@ -130,13 +130,51 @@ String ASTFunction::getID(char delim) const
return "Function" + (delim + name); return "Function" + (delim + name);
} }
void ASTFunction::groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const
{
// Clone the first argument to be used as a parameter
ASTPtr first_arg = arguments->children[0]->clone();
// Clone the second argument to remain as the function argument
ASTPtr second_arg = arguments->children[1]->clone();
// Initialize or clear parameters
if (!res->parameters)
res->parameters = std::make_shared<ASTExpressionList>();
else
res->parameters->children.clear();
// Add the first argument as a parameter
res->parameters->children.emplace_back(first_arg);
res->children.emplace_back(res->parameters);
// Initialize arguments with the second argument only
res->arguments = std::make_shared<ASTExpressionList>();
res->arguments->children.emplace_back(second_arg);
res->children.emplace_back(res->arguments);
}
ASTPtr ASTFunction::clone() const ASTPtr ASTFunction::clone() const
{ {
auto res = std::make_shared<ASTFunction>(*this); auto res = std::make_shared<ASTFunction>(*this);
res->children.clear(); res->children.clear();
if (arguments) { res->arguments = arguments->clone(); res->children.push_back(res->arguments); } // Special handling for groupConcat with two arguments
if (parameters) { res->parameters = parameters->clone(); res->children.push_back(res->parameters); } if (name == "groupConcat" && arguments && arguments->children.size() == 2)
groupConcatArgumentOverride(res);
else
{
if (arguments)
{
res->arguments = arguments->clone();
res->children.push_back(res->arguments);
}
if (parameters)
{
res->parameters = parameters->clone();
res->children.push_back(res->parameters);
}
}
if (window_definition) if (window_definition)
{ {

View File

@ -81,6 +81,8 @@ public:
bool hasSecretParts() const override; bool hasSecretParts() const override;
/// Used by clone() for a `groupConcat` call with exactly two arguments: stores in `res` a clone of
/// the first argument as the only parameter and a clone of the second argument as the only argument.
/// Requires `arguments` to hold at least two children; this is not checked inside the function.
void groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const;
protected: protected:
void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;
void appendColumnNameImpl(WriteBuffer & ostr) const override; void appendColumnNameImpl(WriteBuffer & ostr) const override;

View File

@ -44,9 +44,9 @@ SELECT length(groupConcat(number)) FROM numbers(100000);
SELECT 'TESTING GroupConcat second argument overload'; SELECT 'TESTING GroupConcat second argument overload';
SELECT groupConcat(p_int, ',') FROM test_groupConcat; SELECT groupConcat(',', p_int) FROM test_groupConcat;
SELECT groupConcat('.')(p_string) FROM test_groupConcat; SELECT groupConcat('.')(p_string) FROM test_groupConcat;
SELECT groupConcat(p_array, '/') FROM test_groupConcat; SELECT groupConcat('/', p_array) FROM test_groupConcat;
DROP TABLE IF EXISTS test_groupConcat; DROP TABLE IF EXISTS test_groupConcat;