ClickHouse/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp

97 lines
3.7 KiB
C++
Raw Normal View History

2019-01-22 21:07:05 +00:00
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
2019-01-23 01:29:53 +00:00
namespace
{
2019-01-22 21:07:05 +00:00
2019-01-26 12:38:42 +00:00
/// Concrete aggregate-function instantiations: both regressions share the
/// LinearModelData state and differ only in the registered name (the gradient
/// computer is chosen per-name in createAggregateFunctionMLMethod below).
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
2019-01-23 01:29:53 +00:00
/** Creates a machine-learning aggregate function (stochastic linear/logistic regression).
  *
  * Parameters (all optional, positional):
  *   1. learning_rate (Float64, default 0.01)
  *   2. l2_regularization_coef (Float64, default 0.01)
  *   3. mini-batch size (UInt32, default 1)
  *   4. weights updater method: 'SGD', 'Momentum' or 'Nesterov' (default SGD)
  *
  * Arguments: the target column followed by one or more numeric feature columns.
  *
  * Throws NUMBER_OF_ARGUMENTS_DOESNT_MATCH on bad parameter/argument counts,
  * ILLEGAL_TYPE_OF_ARGUMENT on non-numeric arguments or an unknown updater name.
  */
template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
    const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
    if (parameters.size() > 4)
        throw Exception("Aggregate function " + name + " requires at most four parameters: learning_rate, l2_regularization_coef, mini-batch size and weights_updater method", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    if (argument_types.size() < 2)
        throw Exception("Aggregate function " + name + " requires at least two arguments: target and model's parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    for (size_t i = 0; i < argument_types.size(); ++i)
    {
        if (!isNumber(argument_types[i]))
            throw Exception("Argument " + std::to_string(i) + " of type " + argument_types[i]->getName() + " must be numeric for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }

    /// Such default parameters were picked because they did good on some tests,
    /// though it still requires to fit parameters to achieve better result
    auto learning_rate = Float64(0.01);
    auto l2_reg_coef = Float64(0.01);
    UInt32 batch_size = 1;

    std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
    std::shared_ptr<IGradientComputer> gradient_computer;

    if (!parameters.empty())
    {
        learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
    }
    if (parameters.size() > 1)
    {
        l2_reg_coef = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[1]);
    }
    if (parameters.size() > 2)
    {
        batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
    }
    if (parameters.size() > 3)
    {
        /// Convert once instead of once per comparison (the original recomputed
        /// the string visitor for each branch of the chain).
        /// The parameter arrives as a quoted string literal, hence the \' in the
        /// expected values.
        const auto updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
        if (updater_name == "\'SGD\'")
        {
            weights_updater = std::make_shared<StochasticGradientDescent>();
        }
        else if (updater_name == "\'Momentum\'")
        {
            weights_updater = std::make_shared<Momentum>();
        }
        else if (updater_name == "\'Nesterov\'")
        {
            weights_updater = std::make_shared<Nesterov>();
        }
        else
        {
            throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }
    }

    /// The gradient computer is fixed by the function name the user called,
    /// i.e. by the concrete Method instantiation.
    if (std::is_same<Method, FuncLinearRegression>::value)
    {
        gradient_computer = std::make_shared<LinearRegression>();
    }
    else if (std::is_same<Method, FuncLogisticRegression>::value)
    {
        gradient_computer = std::make_shared<LogisticRegression>();
    }
    else
    {
        throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }

    /// argument_types.size() - 1: the first argument is the target, the rest
    /// are the model's feature columns.
    return std::make_shared<Method>(argument_types.size() - 1,
                                    gradient_computer, weights_updater,
                                    learning_rate, l2_reg_coef, batch_size, argument_types, parameters);
}
2019-01-22 21:07:05 +00:00
2019-01-23 01:29:53 +00:00
}
/// Registers the machine-learning aggregate functions under their SQL-visible
/// names. Both use the factory helper above; the Method template argument
/// selects the gradient computer.
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
    factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
    factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
2019-01-22 21:07:05 +00:00
}