Change the logic to use analyzer

This commit is contained in:
Yarik Briukhovetskyi 2024-12-03 17:31:45 +01:00 committed by GitHub
parent 802f98856a
commit c04a28d009
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 304 additions and 252 deletions

View File

@ -1,27 +1,7 @@
#include <AggregateFunctions/IAggregateFunction.h> #include <AggregateFunctions/AggregateFunctionGroupConcat.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnNullable.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
#include <Core/ServerSettings.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Common/ArenaAllocator.h>
#include <Common/assert_cast.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB namespace DB
{ {
struct Settings; struct Settings;
@ -33,67 +13,46 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS; extern const int BAD_ARGUMENTS;
} }
namespace void GroupConcatDataBase::checkAndUpdateSize(UInt64 add, Arena * arena)
{ {
struct GroupConcatDataBase
{
UInt64 data_size = 0;
UInt64 allocated_size = 0;
char * data = nullptr;
void checkAndUpdateSize(UInt64 add, Arena * arena)
{
if (data_size + add >= allocated_size) if (data_size + add >= allocated_size)
{ {
auto old_size = allocated_size; auto old_size = allocated_size;
allocated_size = std::max(2 * allocated_size, data_size + add); allocated_size = std::max(2 * allocated_size, data_size + add);
data = arena->realloc(data, old_size, allocated_size); data = arena->realloc(data, old_size, allocated_size);
} }
} }
void insertChar(const char * str, UInt64 str_size, Arena * arena) void GroupConcatDataBase::insertChar(const char * str, UInt64 str_size, Arena * arena)
{ {
checkAndUpdateSize(str_size, arena); checkAndUpdateSize(str_size, arena);
memcpy(data + data_size, str, str_size); memcpy(data + data_size, str, str_size);
data_size += str_size; data_size += str_size;
} }
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena) void GroupConcatDataBase::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{ {
WriteBufferFromOwnString buff; WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, FormatSettings{}); serialization->serializeText(*column, row_num, buff, FormatSettings{});
auto string = buff.stringView(); auto string = buff.stringView();
insertChar(string.data(), string.size(), arena); insertChar(string.data(), string.size(), arena);
} }
};
template <bool has_limit> template <bool has_limit>
struct GroupConcatData; UInt64 GroupConcatData<has_limit>::getSize(size_t i) const
template<>
struct GroupConcatData<false> final : public GroupConcatDataBase
{ {
}; return offsets[i * 2 + 1] - offsets[i * 2];
}
template<> template <bool has_limit>
struct GroupConcatData<true> final : public GroupConcatDataBase UInt64 GroupConcatData<has_limit>::getString(size_t i) const
{ {
using Offset = UInt64; return offsets[i * 2];
using Allocator = MixedAlignedArenaAllocator<alignof(Offset), 4096>; }
using Offsets = PODArray<Offset, 32, Allocator>;
/// offset[i * 2] - beginning of the i-th row, offset[i * 2 + 1] - end of the i-th row template <bool has_limit>
Offsets offsets; void GroupConcatData<has_limit>::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
UInt64 num_rows = 0; {
UInt64 getSize(size_t i) const { return offsets[i * 2 + 1] - offsets[i * 2]; }
UInt64 getString(size_t i) const { return offsets[i * 2]; }
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{
WriteBufferFromOwnString buff; WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, {}); serialization->serializeText(*column, row_num, buff, {});
auto string = buff.stringView(); auto string = buff.stringView();
@ -104,35 +63,34 @@ struct GroupConcatData<true> final : public GroupConcatDataBase
data_size += string.size(); data_size += string.size();
offsets.push_back(data_size, arena); offsets.push_back(data_size, arena);
num_rows++; num_rows++;
} }
};
template <bool has_limit> template <bool has_limit>
class GroupConcatImpl final GroupConcatImpl<has_limit>::GroupConcatImpl(
: public IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>> const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_)
{
static constexpr auto name = "groupConcat";
SerializationPtr serialization;
UInt64 limit;
const String delimiter;
const DataTypePtr type;
public:
GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_)
: IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>( : IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>(
{data_type_}, parameters_, std::make_shared<DataTypeString>()) {data_type_}, parameters_, std::make_shared<DataTypeString>())
, limit(limit_) , limit(limit_)
, delimiter(delimiter_) , delimiter(delimiter_)
, type(data_type_) , type(data_type_)
{ {
serialization = isFixedString(type) ? std::make_shared<DataTypeString>()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization(); serialization = isFixedString(type) ? std::make_shared<DataTypeString>()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization();
} }
String getName() const override { return name; } template <bool has_limit>
String GroupConcatImpl<has_limit>::getName() const
{
return name;
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{ template <bool has_limit>
void GroupConcatImpl<has_limit>::add(
AggregateDataPtr __restrict place,
const IColumn ** columns,
size_t row_num,
Arena * arena) const
{
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
if constexpr (has_limit) if constexpr (has_limit)
@ -150,10 +108,11 @@ public:
} }
else else
cur_data.insert(columns[0], serialization, row_num, arena); cur_data.insert(columns[0], serialization, row_num, arena);
} }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override template <bool has_limit>
{ void GroupConcatImpl<has_limit>::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const
{
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
auto & rhs_data = this->data(rhs); auto & rhs_data = this->data(rhs);
@ -181,10 +140,11 @@ public:
cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena); cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena);
} }
} }
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override template <bool has_limit>
{ void GroupConcatImpl<has_limit>::serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const
{
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
writeVarUInt(cur_data.data_size, buf); writeVarUInt(cur_data.data_size, buf);
@ -197,10 +157,11 @@ public:
for (const auto & offset : cur_data.offsets) for (const auto & offset : cur_data.offsets)
writeVarUInt(offset, buf); writeVarUInt(offset, buf);
} }
} }
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override template <bool has_limit>
{ void GroupConcatImpl<has_limit>::deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const
{
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
UInt64 temp_size = 0; UInt64 temp_size = 0;
@ -218,10 +179,11 @@ public:
for (auto & offset : cur_data.offsets) for (auto & offset : cur_data.offsets)
readVarUInt(offset, buf); readVarUInt(offset, buf);
} }
} }
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override template <bool has_limit>
{ void GroupConcatImpl<has_limit>::insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const
{
auto & cur_data = this->data(place); auto & cur_data = this->data(place);
if (cur_data.data_size == 0) if (cur_data.data_size == 0)
@ -232,10 +194,12 @@ public:
auto & column_string = assert_cast<ColumnString &>(to); auto & column_string = assert_cast<ColumnString &>(to);
column_string.insertData(cur_data.data, cur_data.data_size); column_string.insertData(cur_data.data, cur_data.data_size);
} }
bool allocatesMemoryInArena() const override { return true; } template <bool has_limit>
}; bool GroupConcatImpl<has_limit>::allocatesMemoryInArena() const { return true; }
// Implementation of add, merge, serialize, deserialize, insertResultInto, etc. remains unchanged.
AggregateFunctionPtr createAggregateFunctionGroupConcat( AggregateFunctionPtr createAggregateFunctionGroupConcat(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *) const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
@ -278,14 +242,12 @@ AggregateFunctionPtr createAggregateFunctionGroupConcat(
return std::make_shared<GroupConcatImpl</* has_limit= */ false>>(argument_types[0], parameters, limit, delimiter); return std::make_shared<GroupConcatImpl</* has_limit= */ false>>(argument_types[0], parameters, limit, delimiter);
} }
}
void registerAggregateFunctionGroupConcat(AggregateFunctionFactory & factory) void registerAggregateFunctionGroupConcat(AggregateFunctionFactory & factory)
{ {
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true }; AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("groupConcat", { createAggregateFunctionGroupConcat, properties }); factory.registerFunction("groupConcat", { createAggregateFunctionGroupConcat, properties });
factory.registerAlias("group_concat", "groupConcat", AggregateFunctionFactory::Case::Insensitive); factory.registerAlias(GroupConcatImpl<false>::getNameAndAliases().at(1), GroupConcatImpl<false>::getNameAndAliases().at(0), AggregateFunctionFactory::Case::Insensitive);
} }
} }

View File

@ -0,0 +1,78 @@
#ifndef DB_GROUP_CONCAT_H
#define DB_GROUP_CONCAT_H
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Core/ServerSettings.h>
#include <Common/ArenaAllocator.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeString.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
namespace DB
{
struct Settings;
struct GroupConcatDataBase
{
UInt64 data_size = 0;
UInt64 allocated_size = 0;
char * data = nullptr;
void checkAndUpdateSize(UInt64 add, Arena * arena);
void insertChar(const char * str, UInt64 str_size, Arena * arena);
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena);
};
template <bool has_limit>
struct GroupConcatData : public GroupConcatDataBase
{
using Offset = UInt64;
using Allocator = MixedAlignedArenaAllocator<alignof(Offset), 4096>;
using Offsets = PODArray<Offset, 32, Allocator>;
Offsets offsets;
UInt64 num_rows = 0;
UInt64 getSize(size_t i) const;
UInt64 getString(size_t i) const;
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena);
};
template <bool has_limit>
class GroupConcatImpl : public IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>
{
static constexpr auto name = "groupConcat";
constexpr static std::array<const char *, 2> names_and_aliases = { "groupConcat", "group_concat" };
SerializationPtr serialization;
UInt64 limit;
const String delimiter;
const DataTypePtr type;
public:
GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_);
String getName() const override;
static std::array<const char *, 2> getNameAndAliases()
{
return names_and_aliases;
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override;
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override;
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf, std::optional<size_t> version) const override;
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override;
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override;
bool allocatesMemoryInArena() const override;
};
} // namespace DB
#endif // DB_GROUP_CONCAT_H

View File

@ -122,6 +122,8 @@ private:
ColumnTransformersNodes buildColumnTransformers(const ASTPtr & matcher_expression, const ContextPtr & context) const; ColumnTransformersNodes buildColumnTransformers(const ASTPtr & matcher_expression, const ContextPtr & context) const;
QueryTreeNodePtr setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const;
ASTPtr query; ASTPtr query;
QueryTreeNodePtr query_tree_node; QueryTreeNodePtr query_tree_node;
}; };
@ -642,6 +644,17 @@ QueryTreeNodePtr QueryTreeBuilder::buildExpression(const ASTPtr & expression, co
result = std::make_shared<LambdaNode>(std::move(lambda_arguments), std::move(lambda_expression_node)); result = std::make_shared<LambdaNode>(std::move(lambda_arguments), std::move(lambda_expression_node));
} }
else else
{
const char * name = function->name.c_str();
// Check if the function is groupConcat with exactly two arguments
if (std::any_of(GroupConcatImpl<false>::getNameAndAliases().begin(),
GroupConcatImpl<false>::getNameAndAliases().end(),
[name](const char *alias) { return std::strcmp(name, alias) == 0; })
&& function->arguments && function->arguments->children.size() == 2)
{
result = setFirstArgumentAsParameter(function, context);
}
else
{ {
auto function_node = std::make_shared<FunctionNode>(function->name); auto function_node = std::make_shared<FunctionNode>(function->name);
function_node->setNullsAction(function->nulls_action); function_node->setNullsAction(function->nulls_action);
@ -671,6 +684,7 @@ QueryTreeNodePtr QueryTreeBuilder::buildExpression(const ASTPtr & expression, co
result = std::move(function_node); result = std::move(function_node);
} }
} }
}
else if (const auto * subquery = expression->as<ASTSubquery>()) else if (const auto * subquery = expression->as<ASTSubquery>())
{ {
auto subquery_query = subquery->children[0]; auto subquery_query = subquery->children[0];
@ -1071,4 +1085,42 @@ QueryTreeNodePtr buildQueryTree(ASTPtr query, ContextPtr context)
return builder.getQueryTreeNode(); return builder.getQueryTreeNode();
} }
QueryTreeNodePtr QueryTreeBuilder::setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const
{
const auto * first_arg_ast = function->arguments->children[0].get();
const auto * first_arg_literal = first_arg_ast->as<ASTLiteral>();
if (!first_arg_literal)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"If groupConcat is used with two arguments, the first argument must be a constant String");
}
if (first_arg_literal->value.getType() != Field::Types::String)
{
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"If groupConcat is used with two arguments, the first argument must be a constant String");
}
std::string separator = first_arg_literal->value.safeGet<String>();
ASTPtr second_arg = function->arguments->children[1]->clone();
auto function_node = std::make_shared<FunctionNode>(function->name);
function_node->setNullsAction(function->nulls_action);
function_node->getParameters().getNodes().push_back(buildExpression(function->arguments->children[0], context)); // Separator
function_node->getArguments().getNodes().push_back(buildExpression(second_arg, context)); // Column to concatenate
if (function->is_window_function)
{
if (function->window_definition)
function_node->getWindowNode() = buildWindow(function->window_definition, context);
else
function_node->getWindowNode() = std::make_shared<IdentifierNode>(Identifier(function->window_name));
}
return std::move(function_node);
}
} }

View File

@ -130,51 +130,13 @@ String ASTFunction::getID(char delim) const
return "Function" + (delim + name); return "Function" + (delim + name);
} }
void ASTFunction::groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const
{
// Clone the first argument to be used as a parameter
ASTPtr first_arg = arguments->children[0]->clone();
// Clone the second argument to remain as the function argument
ASTPtr second_arg = arguments->children[1]->clone();
// Initialize or clear parameters
if (!res->parameters)
res->parameters = std::make_shared<ASTExpressionList>();
else
res->parameters->children.clear();
// Add the first argument as a parameter
res->parameters->children.emplace_back(first_arg);
res->children.emplace_back(res->parameters);
// Initialize arguments with the second argument only
res->arguments = std::make_shared<ASTExpressionList>();
res->arguments->children.emplace_back(second_arg);
res->children.emplace_back(res->arguments);
}
ASTPtr ASTFunction::clone() const ASTPtr ASTFunction::clone() const
{ {
auto res = std::make_shared<ASTFunction>(*this); auto res = std::make_shared<ASTFunction>(*this);
res->children.clear(); res->children.clear();
// Special handling for groupConcat with two arguments if (arguments) { res->arguments = arguments->clone(); res->children.push_back(res->arguments); }
if ((name == "groupConcat" || Poco::toLower(name) == "group_concat") && arguments && arguments->children.size() == 2) if (parameters) { res->parameters = parameters->clone(); res->children.push_back(res->parameters); }
groupConcatArgumentOverride(res);
else
{
if (arguments)
{
res->arguments = arguments->clone();
res->children.push_back(res->arguments);
}
if (parameters)
{
res->parameters = parameters->clone();
res->children.push_back(res->parameters);
}
}
if (window_definition) if (window_definition)
{ {

View File

@ -81,8 +81,6 @@ public:
bool hasSecretParts() const override; bool hasSecretParts() const override;
void groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const;
protected: protected:
void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override; void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;
void appendColumnNameImpl(WriteBuffer & ostr) const override; void appendColumnNameImpl(WriteBuffer & ostr) const override;