mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-27 18:12:02 +00:00
fix on lin_ref_perf
This commit is contained in:
parent
9249561d08
commit
73c07cb9d5
@ -49,7 +49,7 @@ namespace
|
||||
auto l2_reg_coef = Float64(0.01);
|
||||
UInt32 batch_size = 1;
|
||||
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
std::string weights_updater_name = "\'SGD\'";
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
|
||||
if (!parameters.empty())
|
||||
@ -66,19 +66,8 @@ namespace
|
||||
}
|
||||
if (parameters.size() > 3)
|
||||
{
|
||||
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
|
||||
{
|
||||
weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
|
||||
{
|
||||
weights_updater = std::make_shared<Momentum>();
|
||||
}
|
||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
|
||||
{
|
||||
weights_updater = std::make_shared<Nesterov>();
|
||||
}
|
||||
else
|
||||
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
|
||||
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
|
||||
{
|
||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
@ -100,7 +89,7 @@ namespace
|
||||
return std::make_shared<Method>(
|
||||
argument_types.size() - 1,
|
||||
gradient_computer,
|
||||
weights_updater,
|
||||
weights_updater_name,
|
||||
learning_rate,
|
||||
l2_reg_coef,
|
||||
batch_size,
|
||||
|
@ -256,7 +256,7 @@ public:
|
||||
explicit AggregateFunctionMLMethod(
|
||||
UInt32 param_num,
|
||||
std::shared_ptr<IGradientComputer> gradient_computer,
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater,
|
||||
std::string weights_updater_name,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
UInt32 batch_size,
|
||||
@ -268,7 +268,7 @@ public:
|
||||
, l2_reg_coef(l2_reg_coef)
|
||||
, batch_size(batch_size)
|
||||
, gradient_computer(std::move(gradient_computer))
|
||||
, weights_updater(std::move(weights_updater))
|
||||
, weights_updater_name(std::move(weights_updater_name))
|
||||
{
|
||||
}
|
||||
|
||||
@ -284,7 +284,21 @@ public:
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
|
||||
std::shared_ptr<IWeightsUpdater> new_weights_updater;
|
||||
if (weights_updater_name == "\'SGD\'")
|
||||
{
|
||||
new_weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
} else if (weights_updater_name == "\'Momentum\'")
|
||||
{
|
||||
new_weights_updater = std::make_shared<Momentum>();
|
||||
} else if (weights_updater_name == "\'Nesterov\'")
|
||||
{
|
||||
new_weights_updater = std::make_shared<Nesterov>();
|
||||
} else
|
||||
{
|
||||
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
@ -328,7 +342,7 @@ private:
|
||||
Float64 l2_reg_coef;
|
||||
UInt32 batch_size;
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||
std::string weights_updater_name;
|
||||
};
|
||||
|
||||
struct NameLinearRegression
|
||||
|
Loading…
Reference in New Issue
Block a user