This commit is contained in:
Alexander Kozhikhov 2019-04-15 03:16:13 +03:00
parent 12132b8fdf
commit 19021e76bb
4 changed files with 22 additions and 14 deletions

View File

@ -25,9 +25,7 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!WhichDataType(argument_types[i]).isFloat64())
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument "
+ std::to_string(i) + "for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + std::to_string(i) + "for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
Float64 learning_rate = Float64(0.01);
@ -55,17 +53,22 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{1.0})
{
wu = std::make_shared<StochasticGradientDescent>();
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{2.0})
{
wu = std::make_shared<Momentum>();
} else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
}
else if (applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[3]) == Float64{3.0})
{
wu = std::make_shared<Nesterov>();
} else {
}
else
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
} else
}
else
{
wu = std::make_unique<StochasticGradientDescent>();
}
@ -73,10 +76,12 @@ AggregateFunctionPtr createAggregateFunctionMLMethod(
if (std::is_same<Method, FuncLinearRegression>::value)
{
gc = std::make_shared<LinearRegression>();
} else if (std::is_same<Method, FuncLogisticRegression>::value)
}
else if (std::is_same<Method, FuncLogisticRegression>::value)
{
gc = std::make_shared<LogisticRegression>();
} else
}
else
{
throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}

View File

@ -483,8 +483,8 @@ public:
l2_reg_coef(l2_reg_coef),
batch_size(batch_size),
gc(std::move(gradient_computer)),
wu(std::move(weights_updater)) {
}
wu(std::move(weights_updater))
{}
DataTypePtr getReturnType() const override
{

View File

@ -120,7 +120,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
++row_num;
}
} else if (ML_function_Logistic)
}
else if (ML_function_Logistic)
{
size_t row_num = 0;
for (auto val : data)
@ -128,7 +129,8 @@ MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const Col
ML_function_Logistic->predictResultInto(val, *res, block, arguments);
++row_num;
}
} else
}
else
{
throw Exception("Illegal aggregate function is passed",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

View File

@ -39,7 +39,8 @@ public:
return name;
}
bool isVariadic() const override {
bool isVariadic() const override
{
return true;
}
size_t getNumberOfArguments() const override