2017-04-01 09:19:00 +00:00
|
|
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
2017-08-18 17:06:22 +00:00
|
|
|
|
2017-04-01 09:19:00 +00:00
|
|
|
#include <DataTypes/DataTypeAggregateFunction.h>
|
|
|
|
#include <DataTypes/DataTypeArray.h>
|
|
|
|
#include <DataTypes/DataTypeNullable.h>
|
2017-12-20 07:36:30 +00:00
|
|
|
#include <DataTypes/DataTypesNumber.h>
|
2017-08-18 17:06:22 +00:00
|
|
|
#include <IO/WriteBuffer.h>
|
|
|
|
#include <IO/WriteHelpers.h>
|
|
|
|
#include <Interpreters/Context.h>
|
|
|
|
|
2017-04-01 09:19:00 +00:00
|
|
|
#include <Common/StringUtils.h>
|
2017-07-13 16:49:09 +00:00
|
|
|
#include <Common/typeid_cast.h>
|
2016-07-12 13:02:52 +00:00
|
|
|
|
2017-08-18 17:06:22 +00:00
|
|
|
#include <Poco/String.h>
|
|
|
|
|
2011-09-19 03:40:05 +00:00
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
2016-01-12 02:21:15 +00:00
|
|
|
namespace ErrorCodes
|
|
|
|
{
|
2017-04-01 07:20:54 +00:00
|
|
|
extern const int UNKNOWN_AGGREGATE_FUNCTION;
|
|
|
|
extern const int LOGICAL_ERROR;
|
|
|
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
2017-12-20 07:36:30 +00:00
|
|
|
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
2016-01-12 02:21:15 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
|
2015-09-24 12:40:36 +00:00
|
|
|
namespace
|
2011-09-19 03:40:05 +00:00
|
|
|
{
|
|
|
|
|
2017-03-25 20:12:56 +00:00
|
|
|
/// Returns `in` with its last strlen(suffix) characters removed.
/// No validation is performed: callers are expected to have already verified
/// (with endsWith) that `in` really ends with `suffix`.
std::string trimRight(const std::string & in, const char * suffix)
{
    const std::string::size_type suffix_length = strlen(suffix);
    const std::string::size_type kept_length = in.size() - suffix_length;
    return std::string(in, 0, kept_length);
}
|
|
|
|
|
2014-08-18 05:45:41 +00:00
|
|
|
}
|
|
|
|
|
2017-12-20 07:36:30 +00:00
|
|
|
AggregateFunctionPtr createAggregateFunctionArray(AggregateFunctionPtr & nested, const DataTypes & argument_types);
|
|
|
|
AggregateFunctionPtr createAggregateFunctionForEach(AggregateFunctionPtr & nested, const DataTypes & argument_types);
|
|
|
|
AggregateFunctionPtr createAggregateFunctionIf(AggregateFunctionPtr & nested, const DataTypes & argument_types);
|
|
|
|
AggregateFunctionPtr createAggregateFunctionState(AggregateFunctionPtr & nested, const DataTypes & argument_types, const Array & parameters);
|
2017-12-20 08:14:33 +00:00
|
|
|
AggregateFunctionPtr createAggregateFunctionMerge(const String & name, AggregateFunctionPtr & nested, const DataTypes & argument_types);
|
2017-12-20 20:25:22 +00:00
|
|
|
|
|
|
|
AggregateFunctionPtr createAggregateFunctionNullUnary(AggregateFunctionPtr & nested);
|
2017-12-20 07:36:30 +00:00
|
|
|
AggregateFunctionPtr createAggregateFunctionNullVariadic(AggregateFunctionPtr & nested, const DataTypes & argument_types);
|
|
|
|
AggregateFunctionPtr createAggregateFunctionCountNotNull(const String & name, const DataTypes & argument_types, const Array & parameters);
|
2017-12-08 05:09:08 +00:00
|
|
|
AggregateFunctionPtr createAggregateFunctionNothing();
|
2014-08-18 05:45:41 +00:00
|
|
|
|
2015-03-01 01:06:49 +00:00
|
|
|
|
2016-07-14 05:22:09 +00:00
|
|
|
void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness)
|
2011-09-19 03:40:05 +00:00
|
|
|
{
|
2017-04-01 07:20:54 +00:00
|
|
|
if (creator == nullptr)
|
|
|
|
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
|
|
|
|
" a null constructor", ErrorCodes::LOGICAL_ERROR);
|
|
|
|
|
|
|
|
if (!aggregate_functions.emplace(name, creator).second)
|
2017-06-10 09:04:31 +00:00
|
|
|
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
|
2017-04-01 07:20:54 +00:00
|
|
|
ErrorCodes::LOGICAL_ERROR);
|
|
|
|
|
|
|
|
if (case_sensitiveness == CaseInsensitive
|
|
|
|
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second)
|
2017-06-10 09:04:31 +00:00
|
|
|
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
|
2017-04-01 07:20:54 +00:00
|
|
|
ErrorCodes::LOGICAL_ERROR);
|
2015-09-24 12:40:36 +00:00
|
|
|
}
|
2015-05-17 17:46:21 +00:00
|
|
|
|
|
|
|
|
2017-07-10 23:30:17 +00:00
|
|
|
/// Returns the aggregate function `name` for the given argument types and parameters.
///
/// When no argument type is Nullable, this simply delegates to getImpl.
/// Otherwise the result is wrapped with a Null combinator:
///  - `count` (in any letter case) is handled by createAggregateFunctionCountNotNull;
///  - if some Nullable argument reports onlyNull(), the nested function is createAggregateFunctionNothing();
///  - otherwise the nested function comes from getImpl, called with every Nullable
///    argument type replaced by its nested type.
/// A single argument gets the unary Null wrapper; any other count gets the variadic one.
AggregateFunctionPtr AggregateFunctionFactory::get(
    const String & name,
    const DataTypes & argument_types,
    const Array & parameters,
    int recursion_level) const
{
    bool any_nullable = false;
    bool any_only_null = false;

    for (const auto & type : argument_types)
    {
        if (!type->isNullable())
            continue;

        any_nullable = true;
        if (type->onlyNull())
        {
            any_only_null = true;
            break;
        }
    }

    if (!any_nullable)
        return getImpl(name, argument_types, parameters, recursion_level);

    /// 'count' with Nullable arguments counts the calls where all arguments are not NULL.
    if (Poco::toLower(name) == "count")
        return createAggregateFunctionCountNotNull(name, argument_types, parameters);

    AggregateFunctionPtr nested_function;

    if (any_only_null)
    {
        nested_function = createAggregateFunctionNothing();
    }
    else
    {
        DataTypes stripped_types;
        stripped_types.reserve(argument_types.size());

        for (const auto & type : argument_types)
            stripped_types.push_back(type->isNullable()
                ? static_cast<const DataTypeNullable &>(*type).getNestedType()
                : type);

        nested_function = getImpl(name, stripped_types, parameters, recursion_level);
    }

    return argument_types.size() == 1
        ? createAggregateFunctionNullUnary(nested_function)
        : createAggregateFunctionNullVariadic(nested_function, argument_types);
}
|
|
|
|
|
|
|
|
|
2017-07-10 23:30:17 +00:00
|
|
|
/// Resolves `name` to an aggregate function. Nullable arguments are not handled here:
/// that is done by get(), which each combinator branch below calls back into.
///
/// Resolution order:
///  1. an exact (case-sensitive) registered name;
///  2. only at the top level (recursion_level == 0), a case-insensitive registered name;
///  3. the combinator suffixes -State, -Merge, -If, -Array, -ForEach. Each one resolves the
///     nested function (the name with the suffix removed) through get() with recursion_level + 1.
///
/// Throws UNKNOWN_AGGREGATE_FUNCTION when nothing matches. Combinators throw
/// NUMBER_OF_ARGUMENTS_DOESNT_MATCH or ILLEGAL_TYPE_OF_ARGUMENT for unsuitable arguments.
AggregateFunctionPtr AggregateFunctionFactory::getImpl(
    const String & name,
    const DataTypes & argument_types,
    const Array & parameters,
    int recursion_level) const
{
    /// Native (case-sensitive) names: a single lookup is sufficient.
    auto native_it = aggregate_functions.find(name);
    if (native_it != aggregate_functions.end())
        return native_it->second(name, argument_types, parameters);

    /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names.
    /// Hence case-insensitive names are only looked up at the top level.
    if (recursion_level == 0)
    {
        auto case_insensitive_it = case_insensitive_aggregate_functions.find(Poco::toLower(name));
        if (case_insensitive_it != case_insensitive_aggregate_functions.end())
            return case_insensitive_it->second(name, argument_types, parameters);
    }

    /// Combinators of aggregate functions.
    /// For every aggregate function 'agg' and combinator '-Comb' there is a combined aggregate function named 'aggComb',
    /// which can have a different number and/or types of arguments, a different result type and different behaviour.

    if (endsWith(name, "State"))
    {
        AggregateFunctionPtr nested = get(trimRight(name, "State"), argument_types, parameters, recursion_level + 1);
        return createAggregateFunctionState(nested, argument_types, parameters);
    }

    if (endsWith(name, "Merge"))
    {
        if (argument_types.size() != 1)
            throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(argument_types[0].get());
        if (!function)
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name
                + " must be AggregateFunction(...)", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        AggregateFunctionPtr nested = get(trimRight(name, "Merge"), function->getArgumentsDataTypes(), parameters, recursion_level + 1);

        /// The state type must have been produced by the very function we are merging into.
        if (nested->getName() != function->getFunctionName())
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name
                + ", because it corresponds to different aggregate function: " + function->getFunctionName() + " instead of " + nested->getName(),
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        return createAggregateFunctionMerge(name, nested, argument_types);
    }

    if (endsWith(name, "If"))
    {
        if (argument_types.empty())
            throw Exception("Incorrect number of arguments for aggregate function " + name,
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        /// The last argument is the condition and must be UInt8.
        if (!typeid_cast<const DataTypeUInt8 *>(argument_types.back().get()))
            throw Exception("Illegal type " + argument_types.back()->getName() + " of last argument for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        DataTypes nested_dt = argument_types;
        nested_dt.pop_back();

        AggregateFunctionPtr nested = get(trimRight(name, "If"), nested_dt, parameters, recursion_level + 1);
        return createAggregateFunctionIf(nested, argument_types);
    }

    if (endsWith(name, "Array"))
    {
        /// The nested function receives the element types of the array arguments.
        DataTypes nested_arguments;
        for (const auto & type : argument_types)
        {
            if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(type.get()))
                nested_arguments.push_back(array->getNestedType());
            else
                throw Exception("Illegal type " + type->getName() + " of argument"
                    " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        AggregateFunctionPtr nested = get(trimRight(name, "Array"), nested_arguments, parameters, recursion_level + 1);
        return createAggregateFunctionArray(nested, argument_types);
    }

    if (endsWith(name, "ForEach"))
    {
        /// The nested function receives the element types of the array arguments.
        DataTypes nested_arguments;
        for (const auto & type : argument_types)
        {
            if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(type.get()))
                nested_arguments.push_back(array->getNestedType());
            else
                throw Exception("Illegal type " + type->getName() + " of argument"
                    " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        AggregateFunctionPtr nested = get(trimRight(name, "ForEach"), nested_arguments, parameters, recursion_level + 1);
        return createAggregateFunctionForEach(nested, argument_types);
    }

    throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
|
2011-09-19 03:40:05 +00:00
|
|
|
|
2011-09-25 05:07:47 +00:00
|
|
|
|
2017-07-10 23:30:17 +00:00
|
|
|
/// Like get(), but returns nullptr instead of throwing when isAggregateFunctionName(name)
/// says `name` is not an aggregate function name (with or without combinators).
/// Exceptions raised by get() itself, for example for unsuitable argument types, still propagate.
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const
{
    if (!isAggregateFunctionName(name))
        return nullptr;

    return get(name, argument_types, parameters);
}
|
|
|
|
|
2015-04-24 15:49:30 +00:00
|
|
|
|
|
|
|
bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const
|
|
|
|
{
|
2017-04-01 07:20:54 +00:00
|
|
|
if (aggregate_functions.count(name))
|
|
|
|
return true;
|
2016-07-14 05:22:09 +00:00
|
|
|
|
2017-04-01 07:20:54 +00:00
|
|
|
if (recursion_level == 0 && case_insensitive_aggregate_functions.count(Poco::toLower(name)))
|
|
|
|
return true;
|
2016-07-14 05:22:09 +00:00
|
|
|
|
2017-12-20 07:36:30 +00:00
|
|
|
if (endsWith(name, "State"))
|
2017-04-01 07:20:54 +00:00
|
|
|
return isAggregateFunctionName(trimRight(name, "State"), recursion_level + 1);
|
2017-12-20 07:36:30 +00:00
|
|
|
if (endsWith(name, "Merge"))
|
2017-04-01 07:20:54 +00:00
|
|
|
return isAggregateFunctionName(trimRight(name, "Merge"), recursion_level + 1);
|
2017-12-20 07:36:30 +00:00
|
|
|
if (endsWith(name, "If"))
|
2017-04-01 07:20:54 +00:00
|
|
|
return isAggregateFunctionName(trimRight(name, "If"), recursion_level + 1);
|
2017-12-20 07:36:30 +00:00
|
|
|
if (endsWith(name, "Array"))
|
|
|
|
return isAggregateFunctionName(trimRight(name, "Array"), recursion_level + 1);
|
|
|
|
if (endsWith(name, "ForEach"))
|
|
|
|
return isAggregateFunctionName(trimRight(name, "ForEach"), recursion_level + 1);
|
2017-04-09 12:26:41 +00:00
|
|
|
|
2017-07-10 23:30:17 +00:00
|
|
|
return false;
|
2015-09-24 12:40:36 +00:00
|
|
|
}
|
|
|
|
|
2011-09-19 03:40:05 +00:00
|
|
|
}
|