Merge branch 'master' into alexkoja_ML

This commit is contained in:
alexander kozhikhov 2019-01-23 21:07:05 +03:00
commit 8472048328
2 changed files with 11 additions and 8 deletions

View File

@ -22,9 +22,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!WhichDataType(argument_types[i]).isFloat64())
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " +
std::to_string(i) + "for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument "
+ std::to_string(i) + "for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
Float64 learning_rate = Float64(0.01);
@ -47,8 +47,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
}
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) {
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
}
}
}

View File

@ -20,7 +20,8 @@
#include <Common/FieldVisitors.h>
namespace DB {
namespace DB
{
namespace ErrorCodes
{
@ -209,7 +210,8 @@ public:
auto &column = dynamic_cast<ColumnVector<Float64> &>(to);
std::vector<Float64> predict_features(arguments.size() - 1);
for (size_t i = 1; i < arguments.size(); ++i) {
for (size_t i = 1; i < arguments.size(); ++i)
{
const auto& element = (*block.getByPosition(arguments[i]).column)[row_num];
if (element.getType() != Field::Types::Float64)
throw Exception("Prediction arguments must be values of type Float",
@ -237,4 +239,4 @@ private:
struct NameLinearRegression { static constexpr auto name = "LinearRegression"; };
}
}