Change the logic to use analyzer

This commit is contained in:
Yarik Briukhovetskyi 2024-12-03 17:31:45 +01:00 committed by GitHub
parent 802f98856a
commit c04a28d009
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 304 additions and 252 deletions

View File

@ -1,27 +1,7 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnNullable.h>
#include <AggregateFunctions/AggregateFunctionGroupConcat.h>
#include <Columns/ColumnString.h>
#include <Core/ServerSettings.h>
#include <Core/ColumnWithTypeAndName.h>
#include <Common/ArenaAllocator.h>
#include <Common/assert_cast.h>
#include <Interpreters/castColumn.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
struct Settings;
@ -33,209 +13,193 @@ namespace ErrorCodes
extern const int BAD_ARGUMENTS;
}
namespace
/// Make sure the buffer can take `add` more bytes.
/// When `data_size + add` reaches the current capacity, the capacity becomes the larger of
/// twice the old capacity and `data_size + add`, and the storage is reallocated through `arena`.
void GroupConcatDataBase::checkAndUpdateSize(UInt64 add, Arena * arena)
{
    const UInt64 required_size = data_size + add;
    if (required_size < allocated_size)
        return;

    const UInt64 previous_capacity = allocated_size;
    allocated_size = std::max(2 * allocated_size, required_size);
    data = arena->realloc(data, previous_capacity, allocated_size);
}
struct GroupConcatDataBase
void GroupConcatDataBase::insertChar(const char * str, UInt64 str_size, Arena * arena)
{
UInt64 data_size = 0;
UInt64 allocated_size = 0;
char * data = nullptr;
checkAndUpdateSize(str_size, arena);
memcpy(data + data_size, str, str_size);
data_size += str_size;
}
void checkAndUpdateSize(UInt64 add, Arena * arena)
{
if (data_size + add >= allocated_size)
{
auto old_size = allocated_size;
allocated_size = std::max(2 * allocated_size, data_size + add);
data = arena->realloc(data, old_size, allocated_size);
}
}
void insertChar(const char * str, UInt64 str_size, Arena * arena)
{
checkAndUpdateSize(str_size, arena);
memcpy(data + data_size, str, str_size);
data_size += str_size;
}
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{
WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, FormatSettings{});
auto string = buff.stringView();
insertChar(string.data(), string.size(), arena);
}
};
void GroupConcatDataBase::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{
WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, FormatSettings{});
auto string = buff.stringView();
insertChar(string.data(), string.size(), arena);
}
template <bool has_limit>
struct GroupConcatData;
template<>
struct GroupConcatData<false> final : public GroupConcatDataBase
UInt64 GroupConcatData<has_limit>::getSize(size_t i) const
{
};
template<>
struct GroupConcatData<true> final : public GroupConcatDataBase
{
using Offset = UInt64;
using Allocator = MixedAlignedArenaAllocator<alignof(Offset), 4096>;
using Offsets = PODArray<Offset, 32, Allocator>;
/// offset[i * 2] - beginning of the i-th row, offset[i * 2 + 1] - end of the i-th row
Offsets offsets;
UInt64 num_rows = 0;
UInt64 getSize(size_t i) const { return offsets[i * 2 + 1] - offsets[i * 2]; }
UInt64 getString(size_t i) const { return offsets[i * 2]; }
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{
WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, {});
auto string = buff.stringView();
checkAndUpdateSize(string.size(), arena);
memcpy(data + data_size, string.data(), string.size());
offsets.push_back(data_size, arena);
data_size += string.size();
offsets.push_back(data_size, arena);
num_rows++;
}
};
return offsets[i * 2 + 1] - offsets[i * 2];
}
template <bool has_limit>
class GroupConcatImpl final
: public IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>
UInt64 GroupConcatData<has_limit>::getString(size_t i) const
{
static constexpr auto name = "groupConcat";
return offsets[i * 2];
}
SerializationPtr serialization;
UInt64 limit;
const String delimiter;
const DataTypePtr type;
template <bool has_limit>
void GroupConcatData<has_limit>::insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena)
{
WriteBufferFromOwnString buff;
serialization->serializeText(*column, row_num, buff, {});
auto string = buff.stringView();
public:
GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_)
: IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>(
{data_type_}, parameters_, std::make_shared<DataTypeString>())
, limit(limit_)
, delimiter(delimiter_)
, type(data_type_)
{
serialization = isFixedString(type) ? std::make_shared<DataTypeString>()->getDefaultSerialization() : this->argument_types[0]->getDefaultSerialization();
}
checkAndUpdateSize(string.size(), arena);
memcpy(data + data_size, string.data(), string.size());
offsets.push_back(data_size, arena);
data_size += string.size();
offsets.push_back(data_size, arena);
num_rows++;
}
String getName() const override { return name; }
/// Stores the limit, delimiter and argument type, reports String as the result type,
/// and picks the serialization used to turn argument values into text.
template <bool has_limit>
GroupConcatImpl<has_limit>::GroupConcatImpl(
    const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_)
    : IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>(
        {data_type_}, parameters_, std::make_shared<DataTypeString>())
    , limit(limit_)
    , delimiter(delimiter_)
    , type(data_type_)
{
    /// FixedString arguments are cast to String in add(), so they use the String serialization.
    if (isFixedString(type))
        serialization = std::make_shared<DataTypeString>()->getDefaultSerialization();
    else
        serialization = this->argument_types[0]->getDefaultSerialization();
}
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
auto & cur_data = this->data(place);
/// Returns the function name held in the static `name` member.
template <bool has_limit>
String GroupConcatImpl<has_limit>::getName() const { return name; }
if constexpr (has_limit)
if (cur_data.num_rows >= limit)
return;
if (cur_data.data_size != 0)
cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
template <bool has_limit>
void GroupConcatImpl<has_limit>::add(
AggregateDataPtr __restrict place,
const IColumn ** columns,
size_t row_num,
Arena * arena) const
{
auto & cur_data = this->data(place);
if (isFixedString(type))
{
ColumnWithTypeAndName col = {columns[0]->getPtr(), type, "column"};
const auto & col_str = castColumn(col, std::make_shared<DataTypeString>());
cur_data.insert(col_str.get(), serialization, row_num, arena);
}
else
cur_data.insert(columns[0], serialization, row_num, arena);
}
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & cur_data = this->data(place);
auto & rhs_data = this->data(rhs);
if (rhs_data.data_size == 0)
if constexpr (has_limit)
if (cur_data.num_rows >= limit)
return;
if constexpr (has_limit)
{
UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows);
for (UInt64 i = 0; i < new_elems_count; ++i)
{
if (cur_data.data_size != 0)
cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
if (cur_data.data_size != 0)
cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
cur_data.offsets.push_back(cur_data.data_size, arena);
cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena);
cur_data.num_rows++;
cur_data.offsets.push_back(cur_data.data_size, arena);
}
}
else
if (isFixedString(type))
{
ColumnWithTypeAndName col = {columns[0]->getPtr(), type, "column"};
const auto & col_str = castColumn(col, std::make_shared<DataTypeString>());
cur_data.insert(col_str.get(), serialization, row_num, arena);
}
else
cur_data.insert(columns[0], serialization, row_num, arena);
}
template <bool has_limit>
void GroupConcatImpl<has_limit>::merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const
{
auto & cur_data = this->data(place);
auto & rhs_data = this->data(rhs);
if (rhs_data.data_size == 0)
return;
if constexpr (has_limit)
{
UInt64 new_elems_count = std::min(rhs_data.num_rows, limit - cur_data.num_rows);
for (UInt64 i = 0; i < new_elems_count; ++i)
{
if (cur_data.data_size != 0)
cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena);
cur_data.offsets.push_back(cur_data.data_size, arena);
cur_data.insertChar(rhs_data.data + rhs_data.getString(i), rhs_data.getSize(i), arena);
cur_data.num_rows++;
cur_data.offsets.push_back(cur_data.data_size, arena);
}
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const override
else
{
auto & cur_data = this->data(place);
if (cur_data.data_size != 0)
cur_data.insertChar(delimiter.c_str(), delimiter.size(), arena);
writeVarUInt(cur_data.data_size, buf);
buf.write(cur_data.data, cur_data.data_size);
if constexpr (has_limit)
{
writeVarUInt(cur_data.num_rows, buf);
for (const auto & offset : cur_data.offsets)
writeVarUInt(offset, buf);
}
cur_data.insertChar(rhs_data.data, rhs_data.data_size, arena);
}
}
void deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const override
template <bool has_limit>
void GroupConcatImpl<has_limit>::serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> /* version */) const
{
auto & cur_data = this->data(place);
writeVarUInt(cur_data.data_size, buf);
buf.write(cur_data.data, cur_data.data_size);
if constexpr (has_limit)
{
auto & cur_data = this->data(place);
UInt64 temp_size = 0;
readVarUInt(temp_size, buf);
cur_data.checkAndUpdateSize(temp_size, arena);
buf.readStrict(cur_data.data + cur_data.data_size, temp_size);
cur_data.data_size = temp_size;
if constexpr (has_limit)
{
readVarUInt(cur_data.num_rows, buf);
cur_data.offsets.resize_exact(cur_data.num_rows * 2, arena);
for (auto & offset : cur_data.offsets)
readVarUInt(offset, buf);
}
writeVarUInt(cur_data.num_rows, buf);
for (const auto & offset : cur_data.offsets)
writeVarUInt(offset, buf);
}
}
void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override
template <bool has_limit>
void GroupConcatImpl<has_limit>::deserialize(AggregateDataPtr __restrict place, ReadBuffer & buf, std::optional<size_t> /* version */, Arena * arena) const
{
auto & cur_data = this->data(place);
UInt64 temp_size = 0;
readVarUInt(temp_size, buf);
cur_data.checkAndUpdateSize(temp_size, arena);
buf.readStrict(cur_data.data + cur_data.data_size, temp_size);
cur_data.data_size = temp_size;
if constexpr (has_limit)
{
auto & cur_data = this->data(place);
readVarUInt(cur_data.num_rows, buf);
cur_data.offsets.resize_exact(cur_data.num_rows * 2, arena);
for (auto & offset : cur_data.offsets)
readVarUInt(offset, buf);
}
}
if (cur_data.data_size == 0)
{
to.insertDefault();
return;
}
template <bool has_limit>
void GroupConcatImpl<has_limit>::insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const
{
auto & cur_data = this->data(place);
auto & column_string = assert_cast<ColumnString &>(to);
column_string.insertData(cur_data.data, cur_data.data_size);
if (cur_data.data_size == 0)
{
to.insertDefault();
return;
}
bool allocatesMemoryInArena() const override { return true; }
};
auto & column_string = assert_cast<ColumnString &>(to);
column_string.insertData(cur_data.data, cur_data.data_size);
}
/// Always true: the state's buffers are grown through the Arena handed to the state's methods.
template <bool has_limit>
bool GroupConcatImpl<has_limit>::allocatesMemoryInArena() const
{
    return true;
}
// Implementation of add, merge, serialize, deserialize, insertResultInto, etc. remains unchanged.
AggregateFunctionPtr createAggregateFunctionGroupConcat(
const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings *)
@ -278,14 +242,12 @@ AggregateFunctionPtr createAggregateFunctionGroupConcat(
return std::make_shared<GroupConcatImpl</* has_limit= */ false>>(argument_types[0], parameters, limit, delimiter);
}
}
void registerAggregateFunctionGroupConcat(AggregateFunctionFactory & factory)
{
AggregateFunctionProperties properties = { .returns_default_when_only_null = false, .is_order_dependent = true };
factory.registerFunction("groupConcat", { createAggregateFunctionGroupConcat, properties });
factory.registerAlias("group_concat", "groupConcat", AggregateFunctionFactory::Case::Insensitive);
factory.registerAlias(GroupConcatImpl<false>::getNameAndAliases().at(1), GroupConcatImpl<false>::getNameAndAliases().at(0), AggregateFunctionFactory::Case::Insensitive);
}
}

View File

@ -0,0 +1,78 @@
#ifndef DB_GROUP_CONCAT_H
#define DB_GROUP_CONCAT_H
#include <AggregateFunctions/IAggregateFunction.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Core/ServerSettings.h>
#include <Common/ArenaAllocator.h>
#include <DataTypes/IDataType.h>
#include <DataTypes/DataTypeString.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
namespace DB
{
struct Settings;
/// Growable byte buffer that accumulates the concatenated text.
/// Storage is obtained through the Arena passed to each mutating call.
struct GroupConcatDataBase
{
/// Number of bytes currently stored in `data`.
UInt64 data_size = 0;
/// Current capacity of `data`, in bytes.
UInt64 allocated_size = 0;
/// Start of the buffer; initially nullptr.
char * data = nullptr;
/// Grows the buffer (to at least twice its capacity) when `data_size + add` reaches the capacity.
void checkAndUpdateSize(UInt64 add, Arena * arena);
/// Appends `str_size` bytes from `str`, growing the buffer first if needed.
void insertChar(const char * str, UInt64 str_size, Arena * arena);
/// Serializes row `row_num` of `column` to text via `serialization` and appends the result.
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena);
};
/// Aggregation state: the byte buffer from GroupConcatDataBase plus per-row boundaries inside it.
/// NOTE(review): this primary template declares `offsets` and `num_rows` for both values of
/// `has_limit` — confirm that tracking offsets is intended when has_limit is false.
template <bool has_limit>
struct GroupConcatData : public GroupConcatDataBase
{
using Offset = UInt64;
using Allocator = MixedAlignedArenaAllocator<alignof(Offset), 4096>;
using Offsets = PODArray<Offset, 32, Allocator>;
/// offsets[i * 2] - beginning of the i-th row, offsets[i * 2 + 1] - end of the i-th row
Offsets offsets;
/// Number of rows whose boundaries are recorded in `offsets`.
UInt64 num_rows = 0;
/// Length in bytes of the i-th row.
UInt64 getSize(size_t i) const;
/// Offset of the i-th row's first byte within `data` (an offset, not the string itself).
UInt64 getString(size_t i) const;
/// Appends the serialized row and records its begin/end offsets; hides GroupConcatDataBase::insert.
void insert(const IColumn * column, const SerializationPtr & serialization, size_t row_num, Arena * arena);
};
/// Aggregate function that concatenates the text form of its argument values into one String,
/// placing `delimiter` between values. When `has_limit` is true, rows are no longer added
/// once `limit` rows have been collected.
template <bool has_limit>
class GroupConcatImpl : public IAggregateFunctionDataHelper<GroupConcatData<has_limit>, GroupConcatImpl<has_limit>>
{
/// Name returned by getName().
static constexpr auto name = "groupConcat";
/// Index 0 is the function name, index 1 is its alias (used this way when registering the alias).
constexpr static std::array<const char *, 2> names_and_aliases = { "groupConcat", "group_concat" };
/// Serialization used to convert argument values to text; chosen in the constructor.
SerializationPtr serialization;
/// Maximum number of rows to concatenate; only consulted when has_limit is true.
UInt64 limit;
/// Separator inserted between concatenated values.
const String delimiter;
/// Type of the aggregated argument.
const DataTypePtr type;
public:
GroupConcatImpl(const DataTypePtr & data_type_, const Array & parameters_, UInt64 limit_, const String & delimiter_);
String getName() const override;
/// Returns a copy of `names_and_aliases`.
static std::array<const char *, 2> getNameAndAliases()
{
return names_and_aliases;
}
/// Appends row `row_num` of `columns[0]` to the state at `place`.
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override;
/// Appends the content of the state `rhs` to the state at `place`.
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override;
/// Writes the buffer size and bytes; with has_limit also the row count and offsets.
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf, std::optional<size_t> version) const override;
/// Reads a state in the layout produced by serialize().
void deserialize(AggregateDataPtr place, ReadBuffer & buf, std::optional<size_t> version, Arena * arena) const override;
/// Inserts the accumulated string into `to`, or a default value when nothing was accumulated.
void insertResultInto(AggregateDataPtr place, IColumn & to, Arena * arena) const override;
bool allocatesMemoryInArena() const override;
};
} // namespace DB
#endif // DB_GROUP_CONCAT_H

View File

@ -122,6 +122,8 @@ private:
ColumnTransformersNodes buildColumnTransformers(const ASTPtr & matcher_expression, const ContextPtr & context) const;
QueryTreeNodePtr setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const;
ASTPtr query;
QueryTreeNodePtr query_tree_node;
};
@ -643,32 +645,44 @@ QueryTreeNodePtr QueryTreeBuilder::buildExpression(const ASTPtr & expression, co
}
else
{
auto function_node = std::make_shared<FunctionNode>(function->name);
function_node->setNullsAction(function->nulls_action);
if (function->parameters)
const char * name = function->name.c_str();
// Check if the function is groupConcat with exactly two arguments
if (std::any_of(GroupConcatImpl<false>::getNameAndAliases().begin(),
GroupConcatImpl<false>::getNameAndAliases().end(),
[name](const char *alias) { return std::strcmp(name, alias) == 0; })
&& function->arguments && function->arguments->children.size() == 2)
{
const auto & function_parameters_list = function->parameters->as<ASTExpressionList>()->children;
for (const auto & argument : function_parameters_list)
function_node->getParameters().getNodes().push_back(buildExpression(argument, context));
result = setFirstArgumentAsParameter(function, context);
}
if (function->arguments)
else
{
const auto & function_arguments_list = function->arguments->as<ASTExpressionList>()->children;
for (const auto & argument : function_arguments_list)
function_node->getArguments().getNodes().push_back(buildExpression(argument, context));
}
auto function_node = std::make_shared<FunctionNode>(function->name);
function_node->setNullsAction(function->nulls_action);
if (function->is_window_function)
{
if (function->window_definition)
function_node->getWindowNode() = buildWindow(function->window_definition, context);
else
function_node->getWindowNode() = std::make_shared<IdentifierNode>(Identifier(function->window_name));
}
if (function->parameters)
{
const auto & function_parameters_list = function->parameters->as<ASTExpressionList>()->children;
for (const auto & argument : function_parameters_list)
function_node->getParameters().getNodes().push_back(buildExpression(argument, context));
}
result = std::move(function_node);
if (function->arguments)
{
const auto & function_arguments_list = function->arguments->as<ASTExpressionList>()->children;
for (const auto & argument : function_arguments_list)
function_node->getArguments().getNodes().push_back(buildExpression(argument, context));
}
if (function->is_window_function)
{
if (function->window_definition)
function_node->getWindowNode() = buildWindow(function->window_definition, context);
else
function_node->getWindowNode() = std::make_shared<IdentifierNode>(Identifier(function->window_name));
}
result = std::move(function_node);
}
}
}
else if (const auto * subquery = expression->as<ASTSubquery>())
@ -1071,4 +1085,42 @@ QueryTreeNodePtr buildQueryTree(ASTPtr query, ContextPtr context)
return builder.getQueryTreeNode();
}
/// Builds the FunctionNode for a groupConcat call written with two arguments,
/// i.e. `groupConcat(separator, expr)`.
///
/// The first argument must be a String literal: it is attached to the node as a parameter,
/// while the second argument becomes the node's only argument. Nulls action and window
/// information are copied from the AST function the same way as for ordinary functions.
///
/// The caller must guarantee that `function->arguments` holds exactly two children.
/// Throws ILLEGAL_TYPE_OF_ARGUMENT when the first argument is not a String literal.
QueryTreeNodePtr QueryTreeBuilder::setFirstArgumentAsParameter(const ASTFunction * function, const ContextPtr & context) const
{
    const auto & call_arguments = function->arguments->children;
    const ASTPtr & separator_ast = call_arguments[0];
    const ASTPtr & column_ast = call_arguments[1];

    /// Both failure modes (not a literal / literal of another type) report the same error.
    const auto * separator_literal = separator_ast->as<ASTLiteral>();
    if (!separator_literal || separator_literal->value.getType() != Field::Types::String)
    {
        throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
            "If groupConcat is used with two arguments, the first argument must be a constant String");
    }

    auto function_node = std::make_shared<FunctionNode>(function->name);
    function_node->setNullsAction(function->nulls_action);

    /// buildExpression only reads the AST, so the children are passed as-is without cloning.
    function_node->getParameters().getNodes().push_back(buildExpression(separator_ast, context)); // Separator
    function_node->getArguments().getNodes().push_back(buildExpression(column_ast, context)); // Column to concatenate

    if (function->is_window_function)
    {
        if (function->window_definition)
            function_node->getWindowNode() = buildWindow(function->window_definition, context);
        else
            function_node->getWindowNode() = std::make_shared<IdentifierNode>(Identifier(function->window_name));
    }

    /// A returned local is converted to QueryTreeNodePtr as an rvalue; no explicit std::move is needed.
    return function_node;
}
}

View File

@ -130,51 +130,13 @@ String ASTFunction::getID(char delim) const
return "Function" + (delim + name);
}
void ASTFunction::groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const
{
// Clone the first argument to be used as a parameter
ASTPtr first_arg = arguments->children[0]->clone();
// Clone the second argument to remain as the function argument
ASTPtr second_arg = arguments->children[1]->clone();
// Initialize or clear parameters
if (!res->parameters)
res->parameters = std::make_shared<ASTExpressionList>();
else
res->parameters->children.clear();
// Add the first argument as a parameter
res->parameters->children.emplace_back(first_arg);
res->children.emplace_back(res->parameters);
// Initialize arguments with the second argument only
res->arguments = std::make_shared<ASTExpressionList>();
res->arguments->children.emplace_back(second_arg);
res->children.emplace_back(res->arguments);
}
ASTPtr ASTFunction::clone() const
{
auto res = std::make_shared<ASTFunction>(*this);
res->children.clear();
// Special handling for groupConcat with two arguments
if ((name == "groupConcat" || Poco::toLower(name) == "group_concat") && arguments && arguments->children.size() == 2)
groupConcatArgumentOverride(res);
else
{
if (arguments)
{
res->arguments = arguments->clone();
res->children.push_back(res->arguments);
}
if (parameters)
{
res->parameters = parameters->clone();
res->children.push_back(res->parameters);
}
}
if (arguments) { res->arguments = arguments->clone(); res->children.push_back(res->arguments); }
if (parameters) { res->parameters = parameters->clone(); res->children.push_back(res->parameters); }
if (window_definition)
{

View File

@ -81,8 +81,6 @@ public:
bool hasSecretParts() const override;
void groupConcatArgumentOverride(std::shared_ptr<ASTFunction> res) const;
protected:
void formatImplWithoutAlias(const FormatSettings & settings, FormatState & state, FormatStateStacked frame) const override;
void appendColumnNameImpl(WriteBuffer & ostr) const override;