2019-01-22 21:07:05 +00:00
# include <AggregateFunctions/AggregateFunctionFactory.h>
# include <AggregateFunctions/AggregateFunctionMLMethod.h>
# include <AggregateFunctions/Helpers.h>
# include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
2019-01-23 01:29:53 +00:00
namespace
{
2019-01-22 21:07:05 +00:00
2019-01-26 12:38:42 +00:00
using FuncLinearRegression = AggregateFunctionMLMethod < LinearModelData , NameLinearRegression > ;
2019-01-28 11:54:55 +00:00
using FuncLogisticRegression = AggregateFunctionMLMethod < LinearModelData , NameLogisticRegression > ;
2019-01-23 01:29:53 +00:00
template < class Method >
AggregateFunctionPtr createAggregateFunctionMLMethod (
const std : : string & name , const DataTypes & argument_types , const Array & parameters )
{
2019-01-26 12:38:42 +00:00
if ( parameters . size ( ) > 4 )
throw Exception ( " Aggregate function " + name + " requires at most four parameters " , ErrorCodes : : NUMBER_OF_ARGUMENTS_DOESNT_MATCH ) ;
2019-01-22 21:07:05 +00:00
2019-04-08 21:01:10 +00:00
if ( argument_types . size ( ) < 2 )
throw Exception ( " Aggregate function " + name + " requires at least two arguments " , ErrorCodes : : NUMBER_OF_ARGUMENTS_DOESNT_MATCH ) ;
2019-01-23 01:29:53 +00:00
for ( size_t i = 0 ; i < argument_types . size ( ) ; + + i )
{
if ( ! WhichDataType ( argument_types [ i ] ) . isFloat64 ( ) )
2019-04-15 00:16:13 +00:00
throw Exception ( " Illegal type " + argument_types [ i ] - > getName ( ) + " of argument " + std : : to_string ( i ) + " for aggregate function " + name , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
2019-01-23 01:29:53 +00:00
}
2019-01-22 21:07:05 +00:00
2019-01-23 18:03:26 +00:00
Float64 learning_rate = Float64 ( 0.01 ) ;
2019-04-08 21:01:10 +00:00
Float64 l2_reg_coef = Float64 ( 0.01 ) ;
2019-01-23 18:03:26 +00:00
UInt32 batch_size = 1 ;
2019-01-26 12:38:42 +00:00
std : : shared_ptr < IWeightsUpdater > wu ;
2019-04-08 21:01:10 +00:00
std : : shared_ptr < IGradientComputer > gc ;
2019-01-23 18:03:26 +00:00
if ( ! parameters . empty ( ) )
2019-01-23 01:29:53 +00:00
{
learning_rate = applyVisitor ( FieldVisitorConvertToNumber < Float64 > ( ) , parameters [ 0 ] ) ;
}
2019-01-23 18:03:26 +00:00
if ( parameters . size ( ) > 1 )
{
2019-04-08 21:01:10 +00:00
l2_reg_coef = applyVisitor ( FieldVisitorConvertToNumber < Float64 > ( ) , parameters [ 1 ] ) ;
2019-01-23 18:03:26 +00:00
}
2019-04-08 21:01:10 +00:00
if ( parameters . size ( ) > 2 )
2019-01-26 12:38:42 +00:00
{
2019-04-08 21:01:10 +00:00
batch_size = applyVisitor ( FieldVisitorConvertToNumber < UInt32 > ( ) , parameters [ 2 ] ) ;
2019-01-26 12:38:42 +00:00
}
2019-04-08 21:01:10 +00:00
if ( parameters . size ( ) > 3 )
2019-01-26 12:38:42 +00:00
{
2019-04-08 21:01:10 +00:00
if ( applyVisitor ( FieldVisitorConvertToNumber < UInt32 > ( ) , parameters [ 3 ] ) = = Float64 { 1.0 } )
2019-01-26 12:38:42 +00:00
{
wu = std : : make_shared < StochasticGradientDescent > ( ) ;
2019-04-15 00:16:13 +00:00
}
else if ( applyVisitor ( FieldVisitorConvertToNumber < UInt32 > ( ) , parameters [ 3 ] ) = = Float64 { 2.0 } )
2019-01-28 10:39:57 +00:00
{
wu = std : : make_shared < Momentum > ( ) ;
2019-04-15 00:16:13 +00:00
}
else if ( applyVisitor ( FieldVisitorConvertToNumber < UInt32 > ( ) , parameters [ 3 ] ) = = Float64 { 3.0 } )
2019-02-26 08:12:16 +00:00
{
wu = std : : make_shared < Nesterov > ( ) ;
2019-03-03 08:46:36 +00:00
2019-04-15 00:16:13 +00:00
}
else
{
2019-04-08 21:01:10 +00:00
throw Exception ( " Invalid parameter for weights updater " , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
2019-01-26 12:38:42 +00:00
}
2019-04-15 00:16:13 +00:00
}
else
2019-01-26 12:38:42 +00:00
{
wu = std : : make_unique < StochasticGradientDescent > ( ) ;
}
2019-04-08 21:01:10 +00:00
if ( std : : is_same < Method , FuncLinearRegression > : : value )
{
gc = std : : make_shared < LinearRegression > ( ) ;
2019-04-15 00:16:13 +00:00
}
else if ( std : : is_same < Method , FuncLogisticRegression > : : value )
2019-04-08 21:01:10 +00:00
{
gc = std : : make_shared < LogisticRegression > ( ) ;
2019-04-15 00:16:13 +00:00
}
else
2019-04-08 21:01:10 +00:00
{
throw Exception ( " Such gradient computer is not implemented yet " , ErrorCodes : : ILLEGAL_TYPE_OF_ARGUMENT ) ;
}
2019-01-22 21:07:05 +00:00
2019-04-08 21:01:10 +00:00
return std : : make_shared < Method > ( argument_types . size ( ) - 1 , gc , wu , learning_rate , l2_reg_coef , batch_size , argument_types , parameters ) ;
2019-01-23 01:29:53 +00:00
}
2019-01-22 21:07:05 +00:00
2019-01-23 01:29:53 +00:00
}
2019-01-23 14:53:50 +00:00
/** Registers the ML-method aggregate functions with the factory.
  * Both names share one creator template; only the Method type (and therefore
  * the gradient computer it selects) differs between them.
  */
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
    factory.registerFunction(
        "LinearRegression",
        createAggregateFunctionMLMethod<FuncLinearRegression>);

    factory.registerFunction(
        "LogisticRegression",
        createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
2019-01-22 21:07:05 +00:00
2019-01-23 14:53:50 +00:00
}