#include #include #include #include #include #include #include #include #include #include #include namespace DB { namespace ErrorCodes { extern const int UNKNOWN_AGGREGATE_FUNCTION; extern const int LOGICAL_ERROR; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } namespace { /// Does not check anything. std::string trimRight(const std::string & in, const char * suffix) { return in.substr(0, in.size() - strlen(suffix)); } } AggregateFunctionPtr createAggregateFunctionArray(AggregateFunctionPtr & nested, const DataTypes & argument_types); AggregateFunctionPtr createAggregateFunctionForEach(AggregateFunctionPtr & nested, const DataTypes & argument_types); AggregateFunctionPtr createAggregateFunctionIf(AggregateFunctionPtr & nested, const DataTypes & argument_types); AggregateFunctionPtr createAggregateFunctionState(AggregateFunctionPtr & nested, const DataTypes & argument_types, const Array & parameters); AggregateFunctionPtr createAggregateFunctionMerge(const String & name, AggregateFunctionPtr & nested, const DataTypes & argument_types); AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested); AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested, const DataTypes & argument_types); AggregateFunctionPtr createAggregateFunctionCountNotNull(const String & name, const DataTypes & argument_types, const Array & parameters); AggregateFunctionPtr createAggregateFunctionNothing(); void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) { if (creator == nullptr) throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " " a null constructor", ErrorCodes::LOGICAL_ERROR); if (!aggregate_functions.emplace(name, creator).second) throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); if (case_sensitiveness == CaseInsensitive && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", ErrorCodes::LOGICAL_ERROR); } AggregateFunctionPtr AggregateFunctionFactory::get( const String & name, const DataTypes & argument_types, const Array & parameters, int recursion_level) const { bool has_nullable_types = false; bool has_null_types = false; for (const auto & arg_type : argument_types) { if (arg_type->isNullable()) { has_nullable_types = true; if (arg_type->onlyNull()) { has_null_types = true; break; } } } if (has_nullable_types) { /// Special case for 'count' function. It could be called with Nullable arguments /// - that means - count number of calls, when all arguments are not NULL. if (Poco::toLower(name) == "count") return createAggregateFunctionCountNotNull(name, argument_types, parameters); AggregateFunctionPtr nested_function; if (has_null_types) { nested_function = createAggregateFunctionNothing(); } else { DataTypes nested_argument_types; nested_argument_types.reserve(argument_types.size()); for (const auto & arg_type : argument_types) { if (arg_type->isNullable()) { const DataTypeNullable & actual_type = static_cast(*arg_type.get()); const DataTypePtr & nested_type = actual_type.getNestedType(); nested_argument_types.push_back(nested_type); } else nested_argument_types.push_back(arg_type); } nested_function = getImpl(name, nested_argument_types, parameters, recursion_level); } if (argument_types.size() == 1) return createAggregateFunctionNullUnary(nested_function); else return createAggregateFunctionNullVariadic(nested_function, argument_types); } else return getImpl(name, argument_types, parameters, recursion_level); } AggregateFunctionPtr AggregateFunctionFactory::getImpl( const String & name, const DataTypes & argument_types, const Array & parameters, int recursion_level) const { auto it = aggregate_functions.find(name); if (it != aggregate_functions.end()) { auto it = aggregate_functions.find(name); if (it != aggregate_functions.end()) return it->second(name, argument_types, parameters); } /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. if (recursion_level == 0) { auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); if (it != case_insensitive_aggregate_functions.end()) return it->second(name, argument_types, parameters); } /// Combinators of aggregate functions. /// For every aggregate function 'agg' and combiner '-Comb' there is combined aggregate function with name 'aggComb', /// that can have different number and/or types of arguments, different result type and different behaviour. if (endsWith(name, "State")) { AggregateFunctionPtr nested = get(trimRight(name, "State"), argument_types, parameters, recursion_level + 1); return createAggregateFunctionState(nested, argument_types, parameters); } if (endsWith(name, "Merge")) { if (argument_types.size() != 1) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); const DataTypeAggregateFunction * function = typeid_cast(argument_types[0].get()); if (!function) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name + " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); AggregateFunctionPtr nested = get(trimRight(name, "Merge"), function->getArgumentsDataTypes(), parameters, recursion_level + 1); if (nested->getName() != function->getFunctionName()) throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name + ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); return createAggregateFunctionMerge(name, nested, argument_types); } if (endsWith(name, "If")) { if (argument_types.empty()) throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); if (!typeid_cast(argument_types.back().get())) throw Exception("Illegal type " + argument_types.back()->getName() + " of last argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); DataTypes nested_dt = argument_types; nested_dt.pop_back(); AggregateFunctionPtr nested = get(trimRight(name, "If"), nested_dt, parameters, recursion_level + 1); return createAggregateFunctionIf(nested, argument_types); } if (endsWith(name, "Array")) { DataTypes nested_arguments; for (const auto & type : argument_types) { if (const DataTypeArray * array = typeid_cast(type.get())) nested_arguments.push_back(array->getNestedType()); else throw Exception("Illegal type " + type->getName() + " of argument" " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } AggregateFunctionPtr nested = get(trimRight(name, "Array"), nested_arguments, parameters, recursion_level + 1); return createAggregateFunctionArray(nested, argument_types); } if (endsWith(name, "ForEach")) { DataTypes nested_arguments; for (const auto & type : argument_types) { if (const DataTypeArray * array = typeid_cast(type.get())) nested_arguments.push_back(array->getNestedType()); else throw Exception("Illegal type " + type->getName() + " of argument" " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } AggregateFunctionPtr nested = get(trimRight(name, "ForEach"), nested_arguments, parameters, recursion_level + 1); return createAggregateFunctionForEach(nested, argument_types); } throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION); } AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const { return isAggregateFunctionName(name) ? get(name, argument_types, parameters) : nullptr; } bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const { if (aggregate_functions.count(name)) return true; if (recursion_level == 0 && case_insensitive_aggregate_functions.count(Poco::toLower(name))) return true; if (endsWith(name, "State")) return isAggregateFunctionName(trimRight(name, "State"), recursion_level + 1); if (endsWith(name, "Merge")) return isAggregateFunctionName(trimRight(name, "Merge"), recursion_level + 1); if (endsWith(name, "If")) return isAggregateFunctionName(trimRight(name, "If"), recursion_level + 1); if (endsWith(name, "Array")) return isAggregateFunctionName(trimRight(name, "Array"), recursion_level + 1); if (endsWith(name, "ForEach")) return isAggregateFunctionName(trimRight(name, "ForEach"), recursion_level + 1); return false; } }