Check argument types in DataTypeAggregateFunction ctor

This commit is contained in:
vdimir 2024-07-24 10:01:50 +00:00
parent 681eafef79
commit b53e757656
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
2 changed files with 28 additions and 7 deletions

View File

@ -33,6 +33,33 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
DataTypeAggregateFunction::DataTypeAggregateFunction(AggregateFunctionPtr function_, const DataTypes & argument_types_,
const Array & parameters_, std::optional<size_t> version_)
: function(std::move(function_))
, argument_types(argument_types_)
, parameters(parameters_)
, version(version_)
{
Strings argument_type_names;
for (const auto & argument_type : argument_types)
argument_type_names.push_back(argument_type->getName());
Strings function_argument_type_names;
const auto & function_argument_types = function->getArgumentTypes();
for (const auto & argument_type : function_argument_types)
function_argument_type_names.push_back(argument_type->getName());
size_t argument_types_size = std::max(argument_types.size(), function_argument_types.size());
for (size_t i = 0; i < argument_types_size; ++i)
{
if (argument_types.size() != function_argument_types.size() || !argument_types[i]->equals(*function_argument_types[i]))
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Data type AggregateFunction {} got argument types different from function argument types: [{}] != [{}]",
function->getName(), fmt::join(argument_type_names, ", "), fmt::join(function_argument_type_names, ", "));
}
}
String DataTypeAggregateFunction::getFunctionName() const
{
return function->getName();

View File

@ -30,13 +30,7 @@ public:
static constexpr bool is_parametric = true;
DataTypeAggregateFunction(AggregateFunctionPtr function_, const DataTypes & argument_types_,
const Array & parameters_, std::optional<size_t> version_ = std::nullopt)
: function(std::move(function_))
, argument_types(argument_types_)
, parameters(parameters_)
, version(version_)
{
}
const Array & parameters_, std::optional<size_t> version_ = std::nullopt);
size_t getVersion() const;