fix on lin_ref_perf

This commit is contained in:
Alexander Kozhikhov 2019-05-31 18:39:17 +03:00
parent 9249561d08
commit 73c07cb9d5
2 changed files with 22 additions and 19 deletions

View File

@ -49,7 +49,7 @@ namespace
auto l2_reg_coef = Float64(0.01);
UInt32 batch_size = 1;
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
std::string weights_updater_name = "\'SGD\'";
std::shared_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty())
@ -66,19 +66,8 @@ namespace
}
if (parameters.size() > 3)
{
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
{
weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
{
weights_updater = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
{
weights_updater = std::make_shared<Nesterov>();
}
else
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -100,7 +89,7 @@ namespace
return std::make_shared<Method>(
argument_types.size() - 1,
gradient_computer,
weights_updater,
weights_updater_name,
learning_rate,
l2_reg_coef,
batch_size,

View File

@ -256,7 +256,7 @@ public:
explicit AggregateFunctionMLMethod(
UInt32 param_num,
std::shared_ptr<IGradientComputer> gradient_computer,
std::shared_ptr<IWeightsUpdater> weights_updater,
std::string weights_updater_name,
Float64 learning_rate,
Float64 l2_reg_coef,
UInt32 batch_size,
@ -268,7 +268,7 @@ public:
, l2_reg_coef(l2_reg_coef)
, batch_size(batch_size)
, gradient_computer(std::move(gradient_computer))
, weights_updater(std::move(weights_updater))
, weights_updater_name(std::move(weights_updater_name))
{
}
@ -284,7 +284,21 @@ public:
void create(AggregateDataPtr place) const override
{
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
std::shared_ptr<IWeightsUpdater> new_weights_updater;
if (weights_updater_name == "\'SGD\'")
{
new_weights_updater = std::make_shared<StochasticGradientDescent>();
} else if (weights_updater_name == "\'Momentum\'")
{
new_weights_updater = std::make_shared<Momentum>();
} else if (weights_updater_name == "\'Nesterov\'")
{
new_weights_updater = std::make_shared<Nesterov>();
} else
{
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
}
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -328,7 +342,7 @@ private:
Float64 l2_reg_coef;
UInt32 batch_size;
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
std::string weights_updater_name;
};
struct NameLinearRegression