ClickHouse/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Common/FieldVisitors.h>

namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}

namespace
{
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
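
/** Creates an aggregate function that trains a linear model by stochastic gradient descent.
  * Optional parameters, in order: learning rate, L2 regularization coefficient, batch size,
  * and a numeric id of the weights updater (1 = SGD, 2 = Momentum, 3 = Nesterov).
  */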
template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
if (parameters.size() > 4)
throw Exception("Aggregate function " + name + " requires at most four parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (argument_types.size() < 2)
throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!WhichDataType(argument_types[i]).isFloat64())
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + std::to_string(i) + " for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
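
// Default hyper-parameters; each can be overridden by the corresponding positional parameter below.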
Float64 learning_rate = Float64(0.01);
Float64 l2_reg_coef = Float64(0.01);
UInt32 batch_size = 1;
std::shared_ptr<IWeightsUpdater> wu;
std::shared_ptr<IGradientComputer> gc;
if (!parameters.empty())
{
learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
}
if (parameters.size() > 1)
{
l2_reg_coef = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[1]);
}
if (parameters.size() > 2)
{
batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
}
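
// parameters[3] selects the weights updater: 1 = plain SGD, 2 = Momentum, 3 = Nesterov. SGD is the default.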
if (parameters.size() > 3)
{
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == 1)
{
wu = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == 2)
{
wu = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == 3)
{
wu = std::make_shared<Nesterov>();
}
else
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
}
else
{
wu = std::make_shared<StochasticGradientDescent>();
}
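
// The gradient computer is chosen according to which aggregate function is being created.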
if (std::is_same<Method, FuncLinearRegression>::value)
{
gc = std::make_shared<LinearRegression>();
}
else if (std::is_same<Method, FuncLogisticRegression>::value)
{
gc = std::make_shared<LogisticRegression>();
}
else
{
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
return std::make_shared<Method>(argument_types.size() - 1, gc, wu, learning_rate, l2_reg_coef, batch_size, argument_types, parameters);
}
}
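
/** A usage sketch (column and table names are illustrative; assumes the target is passed as the first argument):
  *   SELECT LinearRegressionState(0.01, 0.01, 1, 1)(target, feature_1, feature_2) FROM some_table
  * trains a linear model with learning rate 0.01, L2 coefficient 0.01, batch size 1 and plain SGD.
  */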
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
}