ClickHouse/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp

254 lines
10 KiB
C++
Raw Normal View History

2011-09-19 03:40:05 +00:00
#include <DB/AggregateFunctions/AggregateFunctionFactory.h>
#include <DB/IO/WriteBuffer.h>
#include <DB/IO/WriteHelpers.h>
2015-09-24 12:40:36 +00:00
#include <DB/DataTypes/DataTypeAggregateFunction.h>
#include <DB/DataTypes/DataTypeArray.h>
#include <DB/DataTypes/DataTypeNullable.h>
#include <DB/Common/StringUtils.h>
#include <Poco/String.h>
2011-09-19 03:40:05 +00:00
namespace DB
{
2016-01-12 02:21:15 +00:00
/// Error codes referenced in this translation unit; they are only declared here (extern).
namespace ErrorCodes
{
    extern const int UNKNOWN_AGGREGATE_FUNCTION;
    extern const int LOGICAL_ERROR;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Used by getImpl for the "Merge" and "If" combinators; was missing from this list.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
2015-09-24 12:40:36 +00:00
namespace
{

/// Drops the last strlen(suffix) characters of `in` and returns the remaining prefix.
/// Performs no checks: it does not verify that `in` actually ends with `suffix`.
std::string trimRight(const std::string & in, const char * suffix)
{
    return in.substr(0, in.size() - strlen(suffix));
}

}
2015-09-24 12:40:36 +00:00
void registerAggregateFunctionAvg(AggregateFunctionFactory & factory);
void registerAggregateFunctionCount(AggregateFunctionFactory & factory);
void registerAggregateFunctionGroupArray(AggregateFunctionFactory & factory);
void registerAggregateFunctionGroupUniqArray(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantile(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileExact(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileExactWeighted(AggregateFunctionFactory & factory);
2015-09-24 12:40:36 +00:00
void registerAggregateFunctionsQuantileDeterministic(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileTiming(AggregateFunctionFactory & factory);
void registerAggregateFunctionsQuantileTDigest(AggregateFunctionFactory & factory);
2015-09-24 12:40:36 +00:00
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory);
void registerAggregateFunctionsMinMaxAny(AggregateFunctionFactory & factory);
void registerAggregateFunctionsStatistics(AggregateFunctionFactory & factory);
void registerAggregateFunctionSum(AggregateFunctionFactory & factory);
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory);
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory);
void registerAggregateFunctionDebug(AggregateFunctionFactory & factory);
2015-09-24 12:40:36 +00:00
/// Forward declarations of the functions that build an aggregate function on top of a nested one.
/// In this file they are called for the "Array", "If", "State" and "Merge" name suffixes and,
/// for createAggregateFunctionNull, when some argument type is Nullable.
/// Their definitions are not part of this file.
AggregateFunctionPtr createAggregateFunctionArray(AggregateFunctionPtr & nested);
AggregateFunctionPtr createAggregateFunctionIf(AggregateFunctionPtr & nested);
AggregateFunctionPtr createAggregateFunctionState(AggregateFunctionPtr & nested);
AggregateFunctionPtr createAggregateFunctionMerge(AggregateFunctionPtr & nested);
AggregateFunctionPtr createAggregateFunctionNull(AggregateFunctionPtr & nested);
2015-09-24 12:40:36 +00:00
/// Fills the factory by passing it to every registration routine declared in this file.
AggregateFunctionFactory::AggregateFunctionFactory()
{
    registerAggregateFunctionAvg(*this);
    registerAggregateFunctionCount(*this);
    registerAggregateFunctionGroupArray(*this);
    registerAggregateFunctionGroupUniqArray(*this);
    registerAggregateFunctionsQuantile(*this);
    registerAggregateFunctionsQuantileExact(*this);
    registerAggregateFunctionsQuantileExactWeighted(*this);
    registerAggregateFunctionsQuantileDeterministic(*this);
    registerAggregateFunctionsQuantileTiming(*this);
    registerAggregateFunctionsQuantileTDigest(*this);
    registerAggregateFunctionsSequenceMatch(*this);
    registerAggregateFunctionsMinMaxAny(*this);
    registerAggregateFunctionsStatistics(*this);
    registerAggregateFunctionSum(*this);
    registerAggregateFunctionsUniq(*this);
    registerAggregateFunctionUniqUpTo(*this);
    registerAggregateFunctionDebug(*this);
}
void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness)
2011-09-19 03:40:05 +00:00
{
2015-09-24 12:40:36 +00:00
if (creator == nullptr)
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
2015-09-24 12:40:36 +00:00
" a null constructor", ErrorCodes::LOGICAL_ERROR);
if (!aggregate_functions.emplace(name, creator).second)
throw Exception("AggregateFunctionFactory: the aggregate function name " + name + " is not unique",
ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name " + name + " is not unique",
ErrorCodes::LOGICAL_ERROR);
2015-09-24 12:40:36 +00:00
}
2015-09-24 12:40:36 +00:00
/** Resolves an aggregate function by name and argument types.
  *
  * When no argument type is Nullable, the request is forwarded to getImpl unchanged.
  * Otherwise every Nullable argument type is replaced by its nested type, the function is
  * resolved for those types, and the result is passed through createAggregateFunctionNull.
  */
AggregateFunctionPtr AggregateFunctionFactory::get(const String & name, const DataTypes & argument_types, int recursion_level) const
{
    bool any_nullable = false;
    for (const auto & type : argument_types)
    {
        if (type->isNullable())
        {
            any_nullable = true;
            break;
        }
    }

    /// Common case: nothing to unwrap.
    if (!any_nullable)
        return getImpl(name, argument_types, recursion_level);

    DataTypes unwrapped_types;
    unwrapped_types.reserve(argument_types.size());

    for (const auto & type : argument_types)
    {
        if (!type->isNullable())
        {
            unwrapped_types.push_back(type);
            continue;
        }

        const DataTypeNullable & nullable_type = static_cast<const DataTypeNullable &>(*type);
        const DataTypePtr & nested_type = nullable_type.getNestedType();
        unwrapped_types.push_back(nested_type);
    }

    AggregateFunctionPtr nested_function = getImpl(name, unwrapped_types, recursion_level);
    return createAggregateFunctionNull(nested_function);
}
/** Resolves an aggregate function by name for the given argument types.
  *
  * Lookup order:
  *  1. exact match in the case-sensitive map;
  *  2. lower-cased match in the case-insensitive map (only at recursion_level 0);
  *  3. names with a "State", "Merge", "If" or "Array" suffix: the suffix is stripped,
  *     the remaining name is resolved through get() and the result is wrapped by the
  *     corresponding createAggregateFunction* function.
  *
  * recursion_level restricts which suffixes may still be stripped:
  * "State" only at level 0, "Merge" up to level 1, "If" up to level 2, "Array" up to level 3.
  *
  * Throws UNKNOWN_AGGREGATE_FUNCTION if nothing matches, NUMBER_OF_ARGUMENTS_DOESNT_MATCH or
  * ILLEGAL_TYPE_OF_ARGUMENT if the arguments do not fit the suffix being processed.
  */
AggregateFunctionPtr AggregateFunctionFactory::getImpl(const String & name, const DataTypes & argument_types, int recursion_level) const
{
    {
        const auto it = aggregate_functions.find(name);
        if (it != aggregate_functions.end())
            return it->second(name, argument_types);
    }

    if (recursion_level == 0)
    {
        const auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name));
        if (it != case_insensitive_aggregate_functions.end())
            return it->second(name, argument_types);
    }

    if ((recursion_level == 0) && endsWith(name, "State"))
    {
        /// For aggregate functions of the form aggState, where agg is the name of another aggregate function.
        AggregateFunctionPtr nested = get(trimRight(name, "State"), argument_types, recursion_level + 1);
        return createAggregateFunctionState(nested);
    }

    if ((recursion_level <= 1) && endsWith(name, "Merge"))
    {
        /// For aggregate functions of the form aggMerge, where agg is the name of another aggregate function.
        if (argument_types.size() != 1)
            throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(&*argument_types[0]);
        if (!function)
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        AggregateFunctionPtr nested = get(trimRight(name, "Merge"), function->getArgumentsDataTypes(), recursion_level + 1);

        /// The function resolved from the name must be the one recorded in the argument type.
        if (nested->getName() != function->getFunctionName())
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        return createAggregateFunctionMerge(nested);
    }

    if ((recursion_level <= 2) && endsWith(name, "If"))
    {
        if (argument_types.empty())
            throw Exception{
                "Incorrect number of arguments for aggregate function " + name,
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
            };

        /// For aggregate functions of the form aggIf, where agg is the name of another aggregate function.
        /// The last argument is dropped before resolving the nested function.
        DataTypes nested_dt = argument_types;
        nested_dt.pop_back();

        AggregateFunctionPtr nested = get(trimRight(name, "If"), nested_dt, recursion_level + 1);
        return createAggregateFunctionIf(nested);
    }

    if ((recursion_level <= 3) && endsWith(name, "Array"))
    {
        /// For aggregate functions of the form aggArray, where agg is the name of another aggregate function.
        const size_t num_arguments = argument_types.size();

        DataTypes nested_arguments;
        nested_arguments.reserve(num_arguments);

        for (size_t i = 0; i < num_arguments; ++i)
        {
            if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(&*argument_types[i]))
                nested_arguments.push_back(array->getNestedType());
            else
                throw Exception("Illegal type " + argument_types[i]->getName() + " of argument #" + toString(i + 1) +
                    " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        /// + 3, so that no other modifier can come before Array
        AggregateFunctionPtr nested = get(trimRight(name, "Array"), nested_arguments, recursion_level + 3);
        return createAggregateFunctionArray(nested);
    }

    throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
2011-09-19 03:40:05 +00:00
2011-09-25 05:07:47 +00:00
/// Returns nullptr when isAggregateFunctionName(name) is false; otherwise returns get(name, argument_types).
/// Exceptions thrown by get() are not caught here.
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types) const
{
    if (!isAggregateFunctionName(name))
        return nullptr;

    return get(name, argument_types);
}
2015-04-24 15:49:30 +00:00
/** Checks whether `name` denotes a registered aggregate function, possibly followed by
  * "State", "Merge", "If" or "Array" suffixes.
  *
  * Only names are examined; argument types are not involved. recursion_level limits which
  * suffixes may still be stripped, using the same thresholds as getImpl.
  */
bool AggregateFunctionFactory::isAggregateFunctionName(const String & name, int recursion_level) const
{
    if (aggregate_functions.count(name))
        return true;

    /// Case-insensitive names are only considered for the outermost name.
    if (recursion_level == 0 && case_insensitive_aggregate_functions.count(Poco::toLower(name)))
        return true;

    /// For aggregate functions of the form aggState, where agg is the name of another aggregate function.
    if ((recursion_level <= 0) && endsWith(name, "State"))
        return isAggregateFunctionName(trimRight(name, "State"), recursion_level + 1);

    /// For aggregate functions of the form aggMerge, where agg is the name of another aggregate function.
    if ((recursion_level <= 1) && endsWith(name, "Merge"))
        return isAggregateFunctionName(trimRight(name, "Merge"), recursion_level + 1);

    /// For aggregate functions of the form aggIf, where agg is the name of another aggregate function.
    if ((recursion_level <= 2) && endsWith(name, "If"))
        return isAggregateFunctionName(trimRight(name, "If"), recursion_level + 1);

    /// For aggregate functions of the form aggArray, where agg is the name of another aggregate function.
    if ((recursion_level <= 3) && endsWith(name, "Array"))
    {
        /// + 3, so that no other modifier can come before Array
        return isAggregateFunctionName(trimRight(name, "Array"), recursion_level + 3);
    }

    return false;
}
2011-09-19 03:40:05 +00:00
}