2019-01-22 21:07:05 +00:00
|
|
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
|
|
|
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
|
|
|
|
#include <AggregateFunctions/Helpers.h>
|
|
|
|
#include <AggregateFunctions/FactoryHelpers.h>
|
|
|
|
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
2019-01-23 01:29:53 +00:00
|
|
|
namespace
|
|
|
|
{
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
/// Aggregate function type registered as "LinearRegression": the generic ML method over LinearModelData state, tagged with NameLinearRegression.
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
|
2019-01-28 11:54:55 +00:00
|
|
|
/// Aggregate function type registered as "LogisticRegression": same LinearModelData state type, distinguished only by the NameLogisticRegression tag.
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
|
2019-01-23 01:29:53 +00:00
|
|
|
template <class Method>
|
|
|
|
AggregateFunctionPtr createAggregateFunctionMLMethod(
|
|
|
|
const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
|
|
|
{
|
2019-01-26 12:38:42 +00:00
|
|
|
if (parameters.size() > 4)
|
|
|
|
throw Exception("Aggregate function " + name + " requires at most four parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-23 01:29:53 +00:00
|
|
|
for (size_t i = 0; i < argument_types.size(); ++i)
|
|
|
|
{
|
|
|
|
if (!WhichDataType(argument_types[i]).isFloat64())
|
2019-01-23 14:53:50 +00:00
|
|
|
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument "
|
|
|
|
+ std::to_string(i) + "for aggregate function " + name,
|
|
|
|
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
2019-01-23 01:29:53 +00:00
|
|
|
}
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-23 18:03:26 +00:00
|
|
|
Float64 learning_rate = Float64(0.01);
|
|
|
|
UInt32 batch_size = 1;
|
2019-01-26 12:38:42 +00:00
|
|
|
|
|
|
|
std::shared_ptr<IGradientComputer> gc;
|
|
|
|
std::shared_ptr<IWeightsUpdater> wu;
|
2019-01-23 18:03:26 +00:00
|
|
|
if (!parameters.empty())
|
2019-01-23 01:29:53 +00:00
|
|
|
{
|
|
|
|
learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
|
|
|
|
}
|
2019-01-23 18:03:26 +00:00
|
|
|
if (parameters.size() > 1)
|
|
|
|
{
|
|
|
|
batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[1]);
|
|
|
|
|
|
|
|
}
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-26 12:38:42 +00:00
|
|
|
/// Gradient_Computer for LinearRegression has LinearRegression gradient computer
|
|
|
|
if (std::is_same<Method, FuncLinearRegression>::value)
|
|
|
|
{
|
|
|
|
gc = std::make_shared<LinearRegression>(argument_types.size());
|
2019-01-28 11:54:55 +00:00
|
|
|
} else if (std::is_same<Method, FuncLogisticRegression>::value)
|
2019-01-28 10:39:57 +00:00
|
|
|
{
|
2019-01-28 11:54:55 +00:00
|
|
|
gc = std::make_shared<LogisticRegression>(argument_types.size());
|
2019-01-26 12:38:42 +00:00
|
|
|
} else
|
|
|
|
{
|
|
|
|
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
|
|
|
}
|
|
|
|
if (parameters.size() > 2)
|
|
|
|
{
|
|
|
|
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{1.0})
|
|
|
|
{
|
|
|
|
wu = std::make_shared<StochasticGradientDescent>();
|
2019-01-28 10:39:57 +00:00
|
|
|
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{2.0})
|
|
|
|
{
|
|
|
|
wu = std::make_shared<Momentum>();
|
2019-02-26 08:12:16 +00:00
|
|
|
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{3.0})
|
|
|
|
{
|
|
|
|
wu = std::make_shared<Nesterov>();
|
|
|
|
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]) == Float64{4.0})
|
|
|
|
{
|
|
|
|
wu = std::make_shared<Adam>();
|
2019-01-26 12:38:42 +00:00
|
|
|
} else
|
|
|
|
{
|
|
|
|
throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
|
|
|
}
|
|
|
|
} else
|
|
|
|
{
|
|
|
|
wu = std::make_unique<StochasticGradientDescent>();
|
|
|
|
}
|
|
|
|
|
2019-01-23 01:29:53 +00:00
|
|
|
if (argument_types.size() < 2)
|
|
|
|
throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-02-12 21:37:49 +00:00
|
|
|
return std::make_shared<Method>(argument_types.size() - 1, gc, wu, learning_rate, batch_size, argument_types, parameters);
|
2019-01-23 01:29:53 +00:00
|
|
|
}
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-23 01:29:53 +00:00
|
|
|
}
|
|
|
|
|
2019-01-23 14:53:50 +00:00
|
|
|
/// Registers the ML-method aggregate functions with the factory.
/// Both SQL names share the createAggregateFunctionMLMethod creator template; each one
/// instantiates it with its own AggregateFunctionMLMethod specialization.
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
    /// "LinearRegression" -> FuncLinearRegression
    factory.registerFunction(
        "LinearRegression",
        createAggregateFunctionMLMethod<FuncLinearRegression>);

    /// "LogisticRegression" -> FuncLogisticRegression
    factory.registerFunction(
        "LogisticRegression",
        createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
|
2019-01-22 21:07:05 +00:00
|
|
|
|
2019-01-23 14:53:50 +00:00
|
|
|
}
|