ClickHouse/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp

94 lines
3.5 KiB
C++
Raw Normal View History

2019-01-22 21:07:05 +00:00
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMLMethod.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
2019-01-23 01:29:53 +00:00
namespace
{
2019-01-22 21:07:05 +00:00
2019-01-26 12:38:42 +00:00
/// ML-method aggregate function over LinearModelData, tagged with NameLinearRegression.
/// Registered below under the name "LinearRegression".
using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
2019-01-28 11:54:55 +00:00
/// ML-method aggregate function over LinearModelData, tagged with NameLogisticRegression.
/// Registered below under the name "LogisticRegression".
using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;
2019-01-23 01:29:53 +00:00
/** Creates an ML-method aggregate function of type `Method` from its arguments and parameters.
  *
  * Parameters (positional, all optional):
  *   [0] learning rate, converted to Float64 (default 0.01);
  *   [1] batch size, converted to UInt32 (default 1);
  *   [2] weights updater selector, converted to UInt32:
  *       1 - StochasticGradientDescent (also the default when omitted),
  *       2 - Momentum, 3 - Nesterov, 4 - Adam.
  * A fourth parameter is accepted by the count check but is not interpreted here;
  * it is only forwarded to the `Method` constructor as part of `parameters`.
  *
  * Arguments: at least two, and every argument must be Float64.
  * `argument_types.size() - 1` is passed to `Method` as its first constructor argument.
  *
  * Throws Exception with:
  *   NUMBER_OF_ARGUMENTS_DOESNT_MATCH - more than four parameters or fewer than two arguments;
  *   ILLEGAL_TYPE_OF_ARGUMENT - a non-Float64 argument, an unsupported `Method`,
  *                              or an unknown weights updater selector.
  */
template <class Method>
AggregateFunctionPtr createAggregateFunctionMLMethod(
    const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
    /// Validate the counts first, so nothing is constructed from an invalid argument list.
    if (parameters.size() > 4)
        throw Exception("Aggregate function " + name + " requires at most four parameters", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    if (argument_types.size() < 2)
        throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

    for (size_t i = 0; i < argument_types.size(); ++i)
    {
        if (!WhichDataType(argument_types[i]).isFloat64())
            throw Exception("Illegal type " + argument_types[i]->getName() + " of argument "
                    + std::to_string(i) + " for aggregate function " + name,
                ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }

    Float64 learning_rate = Float64(0.01);
    UInt32 batch_size = 1;

    if (!parameters.empty())
        learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);

    if (parameters.size() > 1)
        batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[1]);

    /// The gradient computer is determined by the aggregate function type itself.
    std::shared_ptr<IGradientComputer> gc;
    if (std::is_same<Method, FuncLinearRegression>::value)
    {
        gc = std::make_shared<LinearRegression>(argument_types.size());
    }
    else if (std::is_same<Method, FuncLogisticRegression>::value)
    {
        gc = std::make_shared<LogisticRegression>(argument_types.size());
    }
    else
    {
        throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
    }

    /// The weights updater is selected by the third parameter; convert it only once.
    std::shared_ptr<IWeightsUpdater> wu;
    if (parameters.size() > 2)
    {
        const UInt32 updater_id = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
        switch (updater_id)
        {
            case 1:
                wu = std::make_shared<StochasticGradientDescent>();
                break;
            case 2:
                wu = std::make_shared<Momentum>();
                break;
            case 3:
                wu = std::make_shared<Nesterov>();
                break;
            case 4:
                wu = std::make_shared<Adam>();
                break;
            default:
                throw Exception("Such weights updater is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }
    }
    else
    {
        wu = std::make_shared<StochasticGradientDescent>();
    }

    return std::make_shared<Method>(argument_types.size() - 1, gc, wu, learning_rate, batch_size, argument_types, parameters);
}
2019-01-22 21:07:05 +00:00
2019-01-23 01:29:53 +00:00
}
/// Adds the ML-method aggregate functions to the given factory.
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
    /// Each name is bound to the creator instantiated for the matching function type.
    factory.registerFunction(
        "LinearRegression",
        createAggregateFunctionMLMethod<FuncLinearRegression>);

    factory.registerFunction(
        "LogisticRegression",
        createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
2019-01-22 21:07:05 +00:00
}