fix on lin_ref_perf

This commit is contained in:
Alexander Kozhikhov 2019-05-31 18:39:17 +03:00
parent 9249561d08
commit 73c07cb9d5
2 changed files with 22 additions and 19 deletions

View File

@ -49,7 +49,7 @@ namespace
auto l2_reg_coef = Float64(0.01); auto l2_reg_coef = Float64(0.01);
UInt32 batch_size = 1; UInt32 batch_size = 1;
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>(); std::string weights_updater_name = "\'SGD\'";
std::shared_ptr<IGradientComputer> gradient_computer; std::shared_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty()) if (!parameters.empty())
@ -66,19 +66,8 @@ namespace
} }
if (parameters.size() > 3) if (parameters.size() > 3)
{ {
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'") weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
{ if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
{
weights_updater = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
{
weights_updater = std::make_shared<Nesterov>();
}
else
{ {
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
} }
@ -100,7 +89,7 @@ namespace
return std::make_shared<Method>( return std::make_shared<Method>(
argument_types.size() - 1, argument_types.size() - 1,
gradient_computer, gradient_computer,
weights_updater, weights_updater_name,
learning_rate, learning_rate,
l2_reg_coef, l2_reg_coef,
batch_size, batch_size,

View File

@ -256,7 +256,7 @@ public:
explicit AggregateFunctionMLMethod( explicit AggregateFunctionMLMethod(
UInt32 param_num, UInt32 param_num,
std::shared_ptr<IGradientComputer> gradient_computer, std::shared_ptr<IGradientComputer> gradient_computer,
std::shared_ptr<IWeightsUpdater> weights_updater, std::string weights_updater_name,
Float64 learning_rate, Float64 learning_rate,
Float64 l2_reg_coef, Float64 l2_reg_coef,
UInt32 batch_size, UInt32 batch_size,
@ -268,7 +268,7 @@ public:
, l2_reg_coef(l2_reg_coef) , l2_reg_coef(l2_reg_coef)
, batch_size(batch_size) , batch_size(batch_size)
, gradient_computer(std::move(gradient_computer)) , gradient_computer(std::move(gradient_computer))
, weights_updater(std::move(weights_updater)) , weights_updater_name(std::move(weights_updater_name))
{ {
} }
@ -284,7 +284,21 @@ public:
void create(AggregateDataPtr place) const override void create(AggregateDataPtr place) const override
{ {
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater); std::shared_ptr<IWeightsUpdater> new_weights_updater;
if (weights_updater_name == "\'SGD\'")
{
new_weights_updater = std::make_shared<StochasticGradientDescent>();
} else if (weights_updater_name == "\'Momentum\'")
{
new_weights_updater = std::make_shared<Momentum>();
} else if (weights_updater_name == "\'Nesterov\'")
{
new_weights_updater = std::make_shared<Nesterov>();
} else
{
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
}
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
} }
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -328,7 +342,7 @@ private:
Float64 l2_reg_coef; Float64 l2_reg_coef;
UInt32 batch_size; UInt32 batch_size;
std::shared_ptr<IGradientComputer> gradient_computer; std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater; std::string weights_updater_name;
}; };
struct NameLinearRegression struct NameLinearRegression