mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 09:32:06 +00:00
Merge pull request #5623 from yandex/Quid37-lin_ref_perf
Merging PR #5505
This commit is contained in:
commit
00809c78e2
@ -49,8 +49,8 @@ namespace
|
||||
auto l2_reg_coef = Float64(0.1);
|
||||
UInt32 batch_size = 15;
|
||||
|
||||
std::string weights_updater_name = "\'SGD\'";
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::string weights_updater_name = "SGD";
|
||||
std::unique_ptr<IGradientComputer> gradient_computer;
|
||||
|
||||
if (!parameters.empty())
|
||||
{
|
||||
@ -66,20 +66,19 @@ namespace
|
||||
}
|
||||
if (parameters.size() > 3)
|
||||
{
|
||||
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
|
||||
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
|
||||
{
|
||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
weights_updater_name = parameters[3].safeGet<String>();
|
||||
if (weights_updater_name != "SGD" && weights_updater_name != "Momentum" && weights_updater_name != "Nesterov")
|
||||
throw Exception("Invalid parameter for weights updater. The only supported are 'SGD', 'Momentum' and 'Nesterov'",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
if (std::is_same<Method, FuncLinearRegression>::value)
|
||||
{
|
||||
gradient_computer = std::make_shared<LinearRegression>();
|
||||
gradient_computer = std::make_unique<LinearRegression>();
|
||||
}
|
||||
else if (std::is_same<Method, FuncLogisticRegression>::value)
|
||||
{
|
||||
gradient_computer = std::make_shared<LogisticRegression>();
|
||||
gradient_computer = std::make_unique<LogisticRegression>();
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -88,7 +87,7 @@ namespace
|
||||
|
||||
return std::make_shared<Method>(
|
||||
argument_types.size() - 1,
|
||||
gradient_computer,
|
||||
std::move(gradient_computer),
|
||||
weights_updater_name,
|
||||
learning_rate,
|
||||
l2_reg_coef,
|
||||
|
@ -37,8 +37,7 @@ public:
|
||||
Float64 l2_reg_coef,
|
||||
Float64 target,
|
||||
const IColumn ** columns,
|
||||
size_t row_num)
|
||||
= 0;
|
||||
size_t row_num) = 0;
|
||||
|
||||
virtual void predict(
|
||||
ColumnVector<Float64>::Container & container,
|
||||
@ -201,9 +200,8 @@ private:
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* LinearModelData is a class which manages current state of learning
|
||||
*/
|
||||
/** LinearModelData is a class which manages current state of learning
|
||||
*/
|
||||
class LinearModelData
|
||||
{
|
||||
public:
|
||||
@ -249,9 +247,8 @@ private:
|
||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
||||
|
||||
/**
|
||||
* The function is called when we want to flush current batch and update our weights
|
||||
*/
|
||||
/** The function is called when we want to flush current batch and update our weights
|
||||
*/
|
||||
void update_state();
|
||||
};
|
||||
|
||||
@ -268,7 +265,7 @@ public:
|
||||
|
||||
explicit AggregateFunctionMLMethod(
|
||||
UInt32 param_num,
|
||||
std::shared_ptr<IGradientComputer> gradient_computer,
|
||||
std::unique_ptr<IGradientComputer> gradient_computer,
|
||||
std::string weights_updater_name,
|
||||
Float64 learning_rate,
|
||||
Float64 l2_reg_coef,
|
||||
@ -300,19 +297,15 @@ public:
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
std::shared_ptr<IWeightsUpdater> new_weights_updater;
|
||||
if (weights_updater_name == "\'SGD\'")
|
||||
{
|
||||
if (weights_updater_name == "SGD")
|
||||
new_weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||
} else if (weights_updater_name == "\'Momentum\'")
|
||||
{
|
||||
else if (weights_updater_name == "Momentum")
|
||||
new_weights_updater = std::make_shared<Momentum>();
|
||||
} else if (weights_updater_name == "\'Nesterov\'")
|
||||
{
|
||||
else if (weights_updater_name == "Nesterov")
|
||||
new_weights_updater = std::make_shared<Nesterov>();
|
||||
} else
|
||||
{
|
||||
else
|
||||
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user