From c04a28d0095dc0296e1af4f08072826adfc37019 Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:31:45 +0100 Subject: [PATCH] Change the logic to use analyzer --- .../AggregateFunctionGroupConcat.cpp | 340 ++++++++---------- .../AggregateFunctionGroupConcat.h | 78 ++++ src/Analyzer/QueryTreeBuilder.cpp | 94 +++-- src/Parsers/ASTFunction.cpp | 42 +-- src/Parsers/ASTFunction.h | 2 - 5 files changed, 304 insertions(+), 252 deletions(-) create mode 100644 src/AggregateFunctions/AggregateFunctionGroupConcat.h diff --git a/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp b/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp index 8cf5ec5705a..81dd7541e3a 100644 --- a/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp +++ b/src/AggregateFunctions/AggregateFunctionGroupConcat.cpp @@ -1,27 +1,7 @@ -#include -#include -#include - -#include -#include +#include #include - -#include -#include - -#include -#include #include -#include -#include -#include -#include - -#include -#include - - namespace DB { struct Settings; @@ -33,209 +13,193 @@ namespace ErrorCodes extern const int BAD_ARGUMENTS; } -namespace +void GroupConcatDataBase::checkAndUpdateSize(UInt64 add, Arena * arena) { + if (data_size + add >= allocated_size) + { + auto old_size = allocated_size; + allocated_size = std::max(2 * allocated_size, data_size + add); + data = arena->realloc(data, old_size, allocated_size); + } +} -struct GroupConcatDataBase +void GroupConcatDataBase::insertChar(const char * str, UInt64 str_size, Arena * arena) { - UInt64 data_size = 0; - UInt64 allocated_size = 0; - char * data = nullptr; + checkAndUpdateSize(str_size, arena); + memcpy(data + data_size, str, str_size); + data_size += str_size; +} - void checkAndUpdateSize(UInt64 add, Arena * arena) - { - if (data_size + add >= allocated_size) - { - auto old_size = allocated_size; - allocated_size = std::max(2 * allocated_size, data_size + add); - data = arena->realloc(data, old_size, allocated_size); - } - } - - void insertChar(const char * str, UInt64 str_size, Arena * arena) - { - checkAndUpdateSize(str_size, arena); - memcpy(data + data_size, str, str_size); - data_size += str_size; - } - - void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena) - { - WriteBufferFromOwnString buff; - serialization->serializeText(*column, row_num, buff, FormatSettings{}); - auto string = buff.stringView(); - insertChar(string.data(), string.size(), arena); - } - -}; +void GroupConcatDataBase::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena) +{ + WriteBufferFromOwnString buff; + serialization->serializeText(*column, row_num, buff, FormatSettings{}); + auto string = buff.stringView(); + insertChar(string.data(), string.size(), arena); +} template -struct GroupConcatData; - -template<> -struct GroupConcatData final : public GroupConcatDataBase +UInt64 GroupConcatData::getSize(size_t i) const { -}; - -template<> -struct GroupConcatData final : public GroupConcatDataBase -{ - using Offset = UInt64; - using Allocator = MixedAlignedArenaAllocator; - using Offsets = PODArray; - - /// offset[i * 2] - beginning of the i-th row, offset[i * 2 + 1] - end of the i-th row - Offsets offsets; - UInt64 num_rows = 0; - - UInt64 getSize(size_t i) const { return offsets[i * 2 + 1] - offsets[i * 2]; } - - UInt64 getString(size_t i) const { return offsets[i * 2]; } - - void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena) - { - WriteBufferFromOwnString buff; - serialization->serializeText(*column, row_num, buff, {}); - auto string = buff.stringView(); - - checkAndUpdateSize(string.size(), arena); - memcpy(data + data_size, string.data(), string.size()); - offsets.push_back(data_size, arena); - data_size += string.size(); - offsets.push_back(data_size, arena); - num_rows++; - } -}; + return offsets[i * 2 + 1] - offsets[i * 2]; +} template -class GroupConcatImpl final - : public IAggregateFunctionDataHelper, GroupConcatImpl> +UInt64 GroupConcatData::getString(size_t i) const { - static constexpr auto name = "groupConcat"; + return offsets[i * 2]; +} - SerializationPtr serialization; - UInt64 limit; - const String delimiter; - const DataTypePtr type; +template +void GroupConcatData::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena) +{ + WriteBufferFromOwnString buff; + serialization->serializeText(*column, row_num, buff, {}); + auto string = buff.stringView(); -public: - GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_) - : IAggregateFunctionDataHelper, GroupConcatImpl>( - {data_type_}, parameters_, std::make_shared()) - , limit(limit_) - , delimiter(delimiter_) - , type(data_type_) - { - serialization = isFixedString(type) ? std::make_shared()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization(); - } + checkAndUpdateSize(string.size(), arena); + memcpy(data + data_size, string.data(), string.size()); + offsets.push_back(data_size, arena); + data_size += string.size(); + offsets.push_back(data_size, arena); + num_rows++; +} - String getName() const override { return name; } +template +GroupConcatImpl::GroupConcatImpl( + const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_) + : IAggregateFunctionDataHelper, GroupConcatImpl>( + {data_type_}, parameters_, std::make_shared()) + , limit(limit_) + , delimiter(delimiter_) + , type(data_type_) +{ + serialization = isFixedString(type) ? std::make_shared()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization(); +} - void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override - { - auto & cur_data = this->data(place); +template +String GroupConcatImpl::getName() const +{ + return name; +} - if constexpr (has_limit) - if (cur_data.num_rows >= limit) - return; - if (cur_data.data_size != 0) - cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); +template +void GroupConcatImpl::add( + AggregateDataPtr __restrict place, + const IColumn ** columns, + size_t row_num, + Arena * arena) const +{ + auto & cur_data = this->data(place); - if (isFixedString(type)) - { - ColumnWithTypeAndName col = {columns[0]->getPtr(), type, "column"}; - const auto & col_str = castColumn(col, std::make_shared()); - cur_data.insert(col_str.get(), serialization, row_num, arena); - } - else - cur_data.insert(columns[0], serialization, row_num, arena); - } - - void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override - { - auto & cur_data = this->data(place); - auto & rhs_data = this->data(rhs); - - if (rhs_data.data_size == 0) + if constexpr (has_limit) + if (cur_data.num_rows >= limit) return; - if constexpr (has_limit) - { - UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows); - for (UInt64 i = 0; i < new_elems_count; ++i) - { - if (cur_data.data_size != 0) - cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); + if (cur_data.data_size != 0) + cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); - cur_data.offsets.push_back(cur_data.data_size, arena); - cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena); - cur_data.num_rows++; - cur_data.offsets.push_back(cur_data.data_size, arena); - } - } - else + if (isFixedString(type)) + { + ColumnWithTypeAndName col = {columns[0]->getPtr(), type, "column"}; + const auto & col_str = castColumn(col, std::make_shared()); + cur_data.insert(col_str.get(), serialization, row_num, arena); + } + else + cur_data.insert(columns[0], serialization, row_num, arena); +} + +template +void GroupConcatImpl::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const +{ + auto & cur_data = this->data(place); + auto & rhs_data = this->data(rhs); + + if (rhs_data.data_size == 0) + return; + + if constexpr (has_limit) + { + UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows); + for (UInt64 i = 0; i < new_elems_count; ++i) { if (cur_data.data_size != 0) cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); - cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena); + cur_data.offsets.push_back(cur_data.data_size, arena); + cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena); + cur_data.num_rows++; + cur_data.offsets.push_back(cur_data.data_size, arena); } } - - void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const override + else { - auto & cur_data = this->data(place); + if (cur_data.data_size != 0) + cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena); - writeVarUInt(cur_data.data_size, buf); - - buf.write(cur_data.data, cur_data.data_size); - - if constexpr (has_limit) - { - writeVarUInt(cur_data.num_rows, buf); - for (const auto & offset : cur_data.offsets) - writeVarUInt(offset, buf); - } + cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena); } +} - void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena * arena) const override +template +void GroupConcatImpl::serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional /* version */) const +{ + auto & cur_data = this->data(place); + + writeVarUInt(cur_data.data_size, buf); + + buf.write(cur_data.data, cur_data.data_size); + + if constexpr (has_limit) { - auto & cur_data = this->data(place); - - UInt64 temp_size = 0; - readVarUInt(temp_size, buf); - - cur_data.checkAndUpdateSize(temp_size, arena); - - buf.readStrict(cur_data.data + cur_data.data_size, temp_size); - cur_data.data_size = temp_size; - - if constexpr (has_limit) - { - readVarUInt(cur_data.num_rows, buf); - cur_data.offsets.resize_exact(cur_data.num_rows * 2, arena); - for (auto & offset : cur_data.offsets) - readVarUInt(offset, buf); - } + writeVarUInt(cur_data.num_rows, buf); + for (const auto & offset : cur_data.offsets) + writeVarUInt(offset, buf); } +} - void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override +template +void GroupConcatImpl::deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional /* version */, Arena * arena) const +{ + auto & cur_data = this->data(place); + + UInt64 temp_size = 0; + readVarUInt(temp_size, buf); + + cur_data.checkAndUpdateSize(temp_size, arena); + + buf.readStrict(cur_data.data + cur_data.data_size, temp_size); + cur_data.data_size = temp_size; + + if constexpr (has_limit) { - auto & cur_data = this->data(place); + readVarUInt(cur_data.num_rows, buf); + cur_data.offsets.resize_exact(cur_data.num_rows * 2, arena); + for (auto & offset : cur_data.offsets) + readVarUInt(offset, buf); + } +} - if (cur_data.data_size == 0) - { - to.insertDefault(); - return; - } +template +void GroupConcatImpl::insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const +{ + auto & cur_data = this->data(place); - auto & column_string = assert_cast(to); - column_string.insertData(cur_data.data, cur_data.data_size); + if (cur_data.data_size == 0) + { + to.insertDefault(); + return; } - bool allocatesMemoryInArena() const override { return true; } -}; + auto & column_string = assert_cast(to); + column_string.insertData(cur_data.data, cur_data.data_size); +} + +template +bool GroupConcatImpl::allocatesMemoryInArena() const { return true; } + +// Implementation of add, merge, serialize, deserialize, insertResultInto, etc. remains unchanged. AggregateFunctionPtr createAggregateFunctionGroupConcat( const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) @@ -278,14 +242,12 @@ AggregateFunctionPtr createAggregateFunctionGroupConcat( return std::make_shared>(argument_types[0], parameters, limit, delimiter); } -} - void registerAggregateFunctionGroupConcat(AggregateFunctionFactory & factory) { AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true }; factory.registerFunction("groupConcat", { createAggregateFunctionGroupConcat, properties }); - factory.registerAlias("group_concat", "groupConcat", AggregateFunctionFactory::Case::Insensitive); + factory.registerAlias(GroupConcatImpl::getNameAndAliases().at(1), GroupConcatImpl::getNameAndAliases().at(0), AggregateFunctionFactory::Case::Insensitive); } } diff --git a/src/AggregateFunctions/AggregateFunctionGroupConcat.h b/src/AggregateFunctions/AggregateFunctionGroupConcat.h new file mode 100644 index 00000000000..290b8d80050 --- /dev/null +++ b/src/AggregateFunctions/AggregateFunctionGroupConcat.h @@ -0,0 +1,78 @@ +#ifndef DB_GROUP_CONCAT_H +#define DB_GROUP_CONCAT_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +struct Settings; + +struct GroupConcatDataBase +{ + UInt64 data_size = 0; + UInt64 allocated_size = 0; + char * data = nullptr; + + void checkAndUpdateSize(UInt64 add, Arena * arena); + void insertChar(const char * str, UInt64 str_size, Arena * arena); + void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena); +}; + +template +struct GroupConcatData : public GroupConcatDataBase +{ + using Offset = UInt64; + using Allocator = MixedAlignedArenaAllocator; + using Offsets = PODArray; + + Offsets offsets; + UInt64 num_rows = 0; + + UInt64 getSize(size_t i) const; + UInt64 getString(size_t i) const; + + void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena); +}; + +template +class GroupConcatImpl : public IAggregateFunctionDataHelper, GroupConcatImpl> +{ + static constexpr auto name = "groupConcat"; + + constexpr static std::array names_and_aliases = { "groupConcat", "group_concat" }; + + SerializationPtr serialization; + UInt64 limit; + const String delimiter; + const DataTypePtr type; + +public: + GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_); + + String getName() const override; + + static std::array getNameAndAliases() + { + return names_and_aliases; + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override; + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override; + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf, std::optional version) const override; + void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional version, Arena * arena) const override; + void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override; + + bool allocatesMemoryInArena() const override; +}; + +} // namespace DB + +#endif // DB_GROUP_CONCAT_H diff --git a/src/Analyzer/QueryTreeBuilder.cpp b/src/Analyzer/QueryTreeBuilder.cpp index d3c88d39213..69e4598a941 100644 --- a/src/Analyzer/QueryTreeBuilder.cpp +++ b/src/Analyzer/QueryTreeBuilder.cpp @@ -122,6 +122,8 @@ private: ColumnTransformersNodes buildColumnTransformers(const ASTPtr & matcher_expression, const ContextPtr & context) const; + QueryTreeNodePtr setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const; + ASTPtr query; QueryTreeNodePtr query_tree_node; }; @@ -643,32 +645,44 @@ QueryTreeNodePtr QueryTreeBuilder::buildExpression(const ASTPtr & expression, co } else { - auto function_node = std::make_shared(function->name); - function_node->setNullsAction(function->nulls_action); - - if (function->parameters) + const char * name = function->name.c_str(); + // Check if the function is groupConcat with exactly two arguments + if (std::any_of(GroupConcatImpl::getNameAndAliases().begin(), + GroupConcatImpl::getNameAndAliases().end(), + [name](const char *alias) { return std::strcmp(name, alias) == 0; }) + && function->arguments && function->arguments->children.size() == 2) { - const auto & function_parameters_list = function->parameters->as()->children; - for (const auto & argument : function_parameters_list) - function_node->getParameters().getNodes().push_back(buildExpression(argument, context)); + result = setFirstArgumentAsParameter(function, context); } - - if (function->arguments) + else { - const auto & function_arguments_list = function->arguments->as()->children; - for (const auto & argument : function_arguments_list) - function_node->getArguments().getNodes().push_back(buildExpression(argument, context)); - } + auto function_node = std::make_shared(function->name); + function_node->setNullsAction(function->nulls_action); - if (function->is_window_function) - { - if (function->window_definition) - function_node->getWindowNode() = buildWindow(function->window_definition, context); - else - function_node->getWindowNode() = std::make_shared(Identifier(function->window_name)); - } + if (function->parameters) + { + const auto & function_parameters_list = function->parameters->as()->children; + for (const auto & argument : function_parameters_list) + function_node->getParameters().getNodes().push_back(buildExpression(argument, context)); + } - result = std::move(function_node); + if (function->arguments) + { + const auto & function_arguments_list = function->arguments->as()->children; + for (const auto & argument : function_arguments_list) + function_node->getArguments().getNodes().push_back(buildExpression(argument, context)); + } + + if (function->is_window_function) + { + if (function->window_definition) + function_node->getWindowNode() = buildWindow(function->window_definition, context); + else + function_node->getWindowNode() = std::make_shared(Identifier(function->window_name)); + } + + result = std::move(function_node); + } } } else if (const auto * subquery = expression->as()) @@ -1071,4 +1085,42 @@ QueryTreeNodePtr buildQueryTree(ASTPtr query, ContextPtr context) return builder.getQueryTreeNode(); } +QueryTreeNodePtr QueryTreeBuilder::setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const +{ + const auto * first_arg_ast = function->arguments->children[0].get(); + const auto * first_arg_literal = first_arg_ast->as(); + + if (!first_arg_literal) + { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "If groupConcat is used with two arguments, the first argument must be a constant String"); + } + + if (first_arg_literal->value.getType() != Field::Types::String) + { + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "If groupConcat is used with two arguments, the first argument must be a constant String"); + } + + std::string separator = first_arg_literal->value.safeGet(); + + ASTPtr second_arg = function->arguments->children[1]->clone(); + + auto function_node = std::make_shared(function->name); + function_node->setNullsAction(function->nulls_action); + + function_node->getParameters().getNodes().push_back(buildExpression(function->arguments->children[0], context)); // Separator + function_node->getArguments().getNodes().push_back(buildExpression(second_arg, context)); // Column to concatenate + + if (function->is_window_function) + { + if (function->window_definition) + function_node->getWindowNode() = buildWindow(function->window_definition, context); + else + function_node->getWindowNode() = std::make_shared(Identifier(function->window_name)); + } + + return std::move(function_node); +} + } diff --git a/src/Parsers/ASTFunction.cpp b/src/Parsers/ASTFunction.cpp index bac512545d9..11cfe2e584e 100644 --- a/src/Parsers/ASTFunction.cpp +++ b/src/Parsers/ASTFunction.cpp @@ -130,51 +130,13 @@ String ASTFunction::getID(char delim) const return "Function" + (delim + name); } -void ASTFunction::groupConcatArgumentOverride(std::shared_ptr res) const -{ - // Clone the first argument to be used as a parameter - ASTPtr first_arg = arguments->children[0]->clone(); - - // Clone the second argument to remain as the function argument - ASTPtr second_arg = arguments->children[1]->clone(); - - // Initialize or clear parameters - if (!res->parameters) - res->parameters = std::make_shared(); - else - res->parameters->children.clear(); - - // Add the first argument as a parameter - res->parameters->children.emplace_back(first_arg); - res->children.emplace_back(res->parameters); - - // Initialize arguments with the second argument only - res->arguments = std::make_shared(); - res->arguments->children.emplace_back(second_arg); - res->children.emplace_back(res->arguments); -} - ASTPtr ASTFunction::clone() const { auto res = std::make_shared(*this); res->children.clear(); - // Special handling for groupConcat with two arguments - if ((name == "groupConcat" || Poco::toLower(name) == "group_concat") && arguments && arguments->children.size() == 2) - groupConcatArgumentOverride(res); - else - { - if (arguments) - { - res->arguments = arguments->clone(); - res->children.push_back(res->arguments); - } - if (parameters) - { - res->parameters = parameters->clone(); - res->children.push_back(res->parameters); - } - } + if (arguments) { res->arguments = arguments->clone(); res->children.push_back(res->arguments); } + if (parameters) { res->parameters = parameters->clone(); res->children.push_back(res->parameters); } if (window_definition) { diff --git a/src/Parsers/ASTFunction.h b/src/Parsers/ASTFunction.h index b6aae46d21e..1b4a5928d1c 100644 --- a/src/Parsers/ASTFunction.h +++ b/src/Parsers/ASTFunction.h @@ -81,8 +81,6 @@ public: bool hasSecretParts() const override; - void groupConcatArgumentOverride(std::shared_ptr res) const; - protected: void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; void appendColumnNameImpl(WriteBuffer & ostr) const override;