2017-04-01 09:19:00 +00:00
# include <AggregateFunctions/AggregateFunctionFactory.h>
2017-08-18 17:06:22 +00:00
2017-04-01 09:19:00 +00:00
# include <DataTypes/DataTypeAggregateFunction.h>
# include <DataTypes/DataTypeArray.h>
# include <DataTypes/DataTypeNullable.h>
2017-08-18 17:06:22 +00:00
# include <IO/WriteBuffer.h>
# include <IO/WriteHelpers.h>
# include <Interpreters/Context.h>
2017-04-01 09:19:00 +00:00
# include <Common/StringUtils.h>
2017-07-13 16:49:09 +00:00
# include <Common/typeid_cast.h>
2016-07-12 13:02:52 +00:00
2017-08-18 17:06:22 +00:00
# include <Poco/String.h>
2011-09-19 03:40:05 +00:00
namespace DB
{
2016-01-12 02:21:15 +00:00
namespace ErrorCodes
{
    /// Error codes thrown from this file; their definitions live elsewhere in the project.
    extern const int UNKNOWN_AGGREGATE_FUNCTION;
    extern const int LOGICAL_ERROR;
    extern const int ILLEGAL_TYPE_OF_ARGUMENT;
    /// Used by getImpl for the Merge, If and ForEach combinators; previously not declared here.
    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
2015-09-24 12:40:36 +00:00
namespace
{

/// Returns `in` with the last strlen(suffix) characters removed.
/// No validation is performed: callers are expected to have checked (e.g. via endsWith)
/// that `in` actually ends with `suffix`.
std::string trimRight(const std::string & in, const char * suffix)
{
    const size_t kept_length = in.size() - strlen(suffix);
    return std::string(in, 0, kept_length);
}

}
2015-09-24 12:40:36 +00:00
AggregateFunctionPtr createAggregateFunctionArray ( AggregateFunctionPtr & nested ) ;
2017-04-09 12:26:41 +00:00
AggregateFunctionPtr createAggregateFunctionForEach ( AggregateFunctionPtr & nested ) ;
2015-09-24 12:40:36 +00:00
AggregateFunctionPtr createAggregateFunctionIf ( AggregateFunctionPtr & nested ) ;
AggregateFunctionPtr createAggregateFunctionState ( AggregateFunctionPtr & nested ) ;
AggregateFunctionPtr createAggregateFunctionMerge ( AggregateFunctionPtr & nested ) ;
2017-02-10 05:03:42 +00:00
AggregateFunctionPtr createAggregateFunctionNullUnary ( AggregateFunctionPtr & nested ) ;
AggregateFunctionPtr createAggregateFunctionNullVariadic ( AggregateFunctionPtr & nested ) ;
2017-02-10 09:02:10 +00:00
AggregateFunctionPtr createAggregateFunctionCountNotNull ( const DataTypes & argument_types ) ;
2014-08-18 05:45:41 +00:00
2015-03-01 01:06:49 +00:00
2016-07-14 05:22:09 +00:00
void AggregateFunctionFactory : : registerFunction ( const String & name , Creator creator , CaseSensitiveness case_sensitiveness )
2011-09-19 03:40:05 +00:00
{
2017-04-01 07:20:54 +00:00
if ( creator = = nullptr )
throw Exception ( " AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor " , ErrorCodes : : LOGICAL_ERROR ) ;
if ( ! aggregate_functions . emplace ( name , creator ) . second )
2017-06-10 09:04:31 +00:00
throw Exception ( " AggregateFunctionFactory: the aggregate function name ' " + name + " ' is not unique " ,
2017-04-01 07:20:54 +00:00
ErrorCodes : : LOGICAL_ERROR ) ;
if ( case_sensitiveness = = CaseInsensitive
& & ! case_insensitive_aggregate_functions . emplace ( Poco : : toLower ( name ) , creator ) . second )
2017-06-10 09:04:31 +00:00
throw Exception ( " AggregateFunctionFactory: the case insensitive aggregate function name ' " + name + " ' is not unique " ,
2017-04-01 07:20:54 +00:00
ErrorCodes : : LOGICAL_ERROR ) ;
2015-09-24 12:40:36 +00:00
}
2015-05-17 17:46:21 +00:00
2017-07-10 23:30:17 +00:00
/// Creates an aggregate function by name for the given argument types and parameters.
/// If any argument is Nullable (or the Null type itself), the function is built over the
/// non-Nullable nested types and then wrapped in a Null-handling adapter. `count` is the exception:
/// over Nullable arguments it becomes "count rows where all arguments are not NULL".
AggregateFunctionPtr AggregateFunctionFactory::get(
    const String & name,
    const DataTypes & argument_types,
    const Array & parameters,
    int recursion_level) const
{
    bool any_nullable = false;
    for (size_t i = 0; i < argument_types.size() && !any_nullable; ++i)
        any_nullable = argument_types[i]->isNullable() || argument_types[i]->isNull();

    if (!any_nullable)
        return getImpl(name, argument_types, parameters, recursion_level);

    /// Special case: count with Nullable arguments counts calls where every argument is non-NULL.
    if (Poco::toLower(name) == "count")
        return createAggregateFunctionCountNotNull(argument_types);

    /// Unwrap Nullable arguments; other arguments (including the Null type) are passed through as is.
    DataTypes stripped_types;
    stripped_types.reserve(argument_types.size());
    for (const auto & type : argument_types)
        stripped_types.push_back(type->isNullable()
            ? static_cast<const DataTypeNullable &>(*type).getNestedType()
            : type);

    AggregateFunctionPtr nested_function = getImpl(name, stripped_types, parameters, recursion_level);

    if (argument_types.size() == 1)
        return createAggregateFunctionNullUnary(nested_function);
    return createAggregateFunctionNullVariadic(nested_function);
}
2017-07-10 23:30:17 +00:00
/// Resolves `name` to an aggregate function for already-unwrapped (non-Nullable) argument types.
/// Resolution order:
///  1. exact match in the case-sensitive registry;
///  2. at recursion_level 0 only, the lower-cased name in the case-insensitive registry;
///  3. combinator suffixes, each stripped and resolved recursively through get():
///     State (level 0 only), Merge (level <= 1), If (level <= 2), Array / ForEach (level <= 3).
///     Array and ForEach advance the level by 3, so no further suffix can be stripped after them.
/// Throws UNKNOWN_AGGREGATE_FUNCTION if nothing matches. It also throws
/// NUMBER_OF_ARGUMENTS_DOESNT_MATCH or ILLEGAL_TYPE_OF_ARGUMENT when a combinator's arguments are invalid.
AggregateFunctionPtr AggregateFunctionFactory::getImpl(
    const String & name,
    const DataTypes & argument_types,
    const Array & parameters,
    int recursion_level) const
{
    /// A single lookup suffices for the exact, case-sensitive match.
    auto it = aggregate_functions.find(name);
    if (it != aggregate_functions.end())
        return it->second(name, argument_types, parameters);

    if (recursion_level == 0)
    {
        auto ci_it = case_insensitive_aggregate_functions.find(Poco::toLower(name));
        if (ci_it != case_insensitive_aggregate_functions.end())
            return ci_it->second(name, argument_types, parameters);
    }

    if ((recursion_level == 0) && endsWith(name, "State"))
    {
        /// For aggregate functions of the form `aggState`, where `agg` is the name of another aggregate function.
        AggregateFunctionPtr nested = get(trimRight(name, "State"), argument_types, parameters, recursion_level + 1);
        return createAggregateFunctionState(nested);
    }

    if ((recursion_level <= 1) && endsWith(name, "Merge"))
    {
        /// For aggregate functions of the form `aggMerge`, where `agg` is the name of another aggregate function.
        /// The single argument must be an AggregateFunction state whose function matches the nested one.
        if (argument_types.size() != 1)
            throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        const DataTypeAggregateFunction * function = typeid_cast<const DataTypeAggregateFunction *>(&*argument_types[0]);
        if (!function)
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        AggregateFunctionPtr nested = get(trimRight(name, "Merge"), function->getArgumentsDataTypes(), parameters, recursion_level + 1);

        if (nested->getName() != function->getFunctionName())
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        return createAggregateFunctionMerge(nested);
    }

    if ((recursion_level <= 2) && endsWith(name, "If"))
    {
        /// For aggregate functions of the form `aggIf`, where `agg` is the name of another aggregate function.
        /// The last argument is the condition; the nested function receives the remaining arguments.
        if (argument_types.empty())
            throw Exception{
                "Incorrect number of arguments for aggregate function " + name,
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};

        DataTypes nested_dt = argument_types;
        nested_dt.pop_back();
        AggregateFunctionPtr nested = get(trimRight(name, "If"), nested_dt, parameters, recursion_level + 1);
        return createAggregateFunctionIf(nested);
    }

    if ((recursion_level <= 3) && endsWith(name, "Array"))
    {
        /// For aggregate functions of the form `aggArray`, where `agg` is the name of another aggregate function.
        /// Every argument must be an Array; the nested function receives the element types.
        const size_t num_arguments = argument_types.size();
        DataTypes nested_arguments;
        nested_arguments.reserve(num_arguments);
        for (size_t i = 0; i < num_arguments; ++i)
        {
            const DataTypeArray * array = typeid_cast<const DataTypeArray *>(&*argument_types[i]);
            if (!array)
                throw Exception("Illegal type " + argument_types[i]->getName() + " of argument #" + toString(i + 1) +
                    " for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
            nested_arguments.push_back(array->getNestedType());
        }

        /// + 3, so that no other modifier can go before the `Array`
        AggregateFunctionPtr nested = get(trimRight(name, "Array"), nested_arguments, parameters, recursion_level + 3);
        return createAggregateFunctionArray(nested);
    }

    if ((recursion_level <= 3) && endsWith(name, "ForEach"))
    {
        /// For functions like aggForEach, where 'agg' is the name of another aggregate function.
        /// Exactly one Array argument is accepted; the nested function receives its element type.
        if (argument_types.size() != 1)
            throw Exception("Incorrect number of arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        DataTypes nested_arguments;
        if (const DataTypeArray * array = typeid_cast<const DataTypeArray *>(&*argument_types[0]))
            nested_arguments.push_back(array->getNestedType());
        else
            throw Exception("Illegal type " + argument_types[0]->getName() + " of argument for aggregate function " + name + ". Must be array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        /// + 3, so that no other modifier can stay before ForEach. Note that the modifiers Array and ForEach are mutually exclusive.
        AggregateFunctionPtr nested = get(trimRight(name, "ForEach"), nested_arguments, parameters, recursion_level + 3);
        return createAggregateFunctionForEach(nested);
    }

    throw Exception("Unknown aggregate function " + name, ErrorCodes::UNKNOWN_AGGREGATE_FUNCTION);
}
2011-09-19 03:40:05 +00:00
2011-09-25 05:07:47 +00:00
2017-07-10 23:30:17 +00:00
/// Like get(), but returns nullptr instead of throwing when `name` is not a known
/// aggregate function (or a valid combinator chain over one).
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const
{
    if (!isAggregateFunctionName(name))
        return nullptr;

    return get(name, argument_types, parameters);
}
2015-04-24 15:49:30 +00:00
bool AggregateFunctionFactory : : isAggregateFunctionName ( const String & name , int recursion_level ) const
{
2017-04-01 07:20:54 +00:00
if ( aggregate_functions . count ( name ) )
return true ;
2016-07-14 05:22:09 +00:00
2017-04-01 07:20:54 +00:00
if ( recursion_level = = 0 & & case_insensitive_aggregate_functions . count ( Poco : : toLower ( name ) ) )
return true ;
2016-07-14 05:22:09 +00:00
2017-04-01 07:20:54 +00:00
/// For aggregate functions of the form `aggState`, where `agg` is the name of another aggregate function.
if ( ( recursion_level < = 0 ) & & endsWith ( name , " State " ) )
return isAggregateFunctionName ( trimRight ( name , " State " ) , recursion_level + 1 ) ;
2016-07-14 05:22:09 +00:00
2017-04-01 07:20:54 +00:00
/// For aggregate functions of the form `aggMerge`, where `agg` is the name of another aggregate function.
if ( ( recursion_level < = 1 ) & & endsWith ( name , " Merge " ) )
return isAggregateFunctionName ( trimRight ( name , " Merge " ) , recursion_level + 1 ) ;
2016-07-14 05:22:09 +00:00
2017-04-01 07:20:54 +00:00
/// For aggregate functions of the form `aggIf`, where `agg` is the name of another aggregate function.
if ( ( recursion_level < = 2 ) & & endsWith ( name , " If " ) )
return isAggregateFunctionName ( trimRight ( name , " If " ) , recursion_level + 1 ) ;
2016-07-14 05:22:09 +00:00
2017-04-01 07:20:54 +00:00
/// For aggregate functions of the form `aggArray`, where `agg` is the name of another aggregate function.
if ( ( recursion_level < = 3 ) & & endsWith ( name , " Array " ) )
{
/// + 3, so that no other modifier can go before `Array`
return isAggregateFunctionName ( trimRight ( name , " Array " ) , recursion_level + 3 ) ;
}
2015-09-24 12:40:36 +00:00
2017-07-10 23:30:17 +00:00
if ( ( recursion_level < = 3 ) & & endsWith ( name , " ForEach " ) )
{
2017-04-09 12:26:41 +00:00
/// + 3, so that no other modifier can go before `ForEach`
2017-07-10 23:30:17 +00:00
return isAggregateFunctionName ( trimRight ( name , " ForEach " ) , recursion_level + 3 ) ;
}
2017-04-09 12:26:41 +00:00
2017-07-10 23:30:17 +00:00
return false ;
2015-09-24 12:40:36 +00:00
}
2011-09-19 03:40:05 +00:00
}