diff --git a/dbms/include/DB/AggregateFunctions/AggregateFunctionIf.h b/dbms/include/DB/AggregateFunctions/AggregateFunctionIf.h index 652afa5e649..f8dcd6fef81 100644 --- a/dbms/include/DB/AggregateFunctions/AggregateFunctionIf.h +++ b/dbms/include/DB/AggregateFunctions/AggregateFunctionIf.h @@ -19,6 +19,7 @@ class AggregateFunctionIf : public IAggregateFunction private: AggregateFunctionPtr nested_func_owner; IAggregateFunction * nested_func; + int num_agruments; public: AggregateFunctionIf(AggregateFunctionPtr nested_) : nested_func_owner(nested_), nested_func(nested_func_owner.get()) {} @@ -34,15 +35,15 @@ public: void setArguments(const DataTypes & arguments) { - if (arguments.size() != 2) - throw Exception("Aggregate function " + getName() + " requires exactly two arguments.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - - if (!dynamic_cast(&*arguments[1])) - throw Exception("Illegal type " + arguments[1]->getName() + " of second argument for aggregate function " + getName() + ". Must be UInt8.", + num_agruments = arguments.size(); + + if (!dynamic_cast(&*arguments[num_agruments - 1])) + throw Exception("Illegal type " + arguments[num_agruments - 1]->getName() + " of second argument for aggregate function " + getName() + ". Must be UInt8.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); DataTypes nested_arguments; - nested_arguments.push_back(arguments[0]); + for (int i = 0; i < num_agruments - 1; i ++) + nested_arguments.push_back(arguments[i]); nested_func->setArguments(nested_arguments); } @@ -73,7 +74,7 @@ public: void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num) const { - if (static_cast(*columns[1]).getData()[row_num]) + if (static_cast(*columns[num_agruments - 1]).getData()[row_num]) nested_func->add(place, columns, row_num); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 4476c35de81..7a97d3a27a6 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -278,10 +278,9 @@ AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const Da else if (recursion_level == 0 && name.size() >= 3 && name[name.size() - 2] == 'I' && name[name.size() - 1] == 'f') { /// Для агрегатных функций вида aggIf, где agg - имя другой агрегатной функции. - if (argument_types.size() != 2) - throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - - AggregateFunctionPtr nested = get(String(name.data(), name.size() - 2), DataTypes(1, argument_types[0]), 1); + DataTypes nested_dt = argument_types; + nested_dt.pop_back(); + AggregateFunctionPtr nested = get(String(name.data(), name.size() - 2), nested_dt); return new AggregateFunctionIf(nested); } else