mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-09-25 11:10:49 +00:00
fix on lin_ref_perf
This commit is contained in:
parent
9249561d08
commit
73c07cb9d5
@ -49,7 +49,7 @@ namespace
|
|||||||
auto l2_reg_coef = Float64(0.01);
|
auto l2_reg_coef = Float64(0.01);
|
||||||
UInt32 batch_size = 1;
|
UInt32 batch_size = 1;
|
||||||
|
|
||||||
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
|
std::string weights_updater_name = "\'SGD\'";
|
||||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||||
|
|
||||||
if (!parameters.empty())
|
if (!parameters.empty())
|
||||||
@ -66,19 +66,8 @@ namespace
|
|||||||
}
|
}
|
||||||
if (parameters.size() > 3)
|
if (parameters.size() > 3)
|
||||||
{
|
{
|
||||||
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
|
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
|
||||||
{
|
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
|
||||||
weights_updater = std::make_shared<StochasticGradientDescent>();
|
|
||||||
}
|
|
||||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
|
|
||||||
{
|
|
||||||
weights_updater = std::make_shared<Momentum>();
|
|
||||||
}
|
|
||||||
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
|
|
||||||
{
|
|
||||||
weights_updater = std::make_shared<Nesterov>();
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
{
|
||||||
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||||
}
|
}
|
||||||
@ -100,7 +89,7 @@ namespace
|
|||||||
return std::make_shared<Method>(
|
return std::make_shared<Method>(
|
||||||
argument_types.size() - 1,
|
argument_types.size() - 1,
|
||||||
gradient_computer,
|
gradient_computer,
|
||||||
weights_updater,
|
weights_updater_name,
|
||||||
learning_rate,
|
learning_rate,
|
||||||
l2_reg_coef,
|
l2_reg_coef,
|
||||||
batch_size,
|
batch_size,
|
||||||
|
@ -256,7 +256,7 @@ public:
|
|||||||
explicit AggregateFunctionMLMethod(
|
explicit AggregateFunctionMLMethod(
|
||||||
UInt32 param_num,
|
UInt32 param_num,
|
||||||
std::shared_ptr<IGradientComputer> gradient_computer,
|
std::shared_ptr<IGradientComputer> gradient_computer,
|
||||||
std::shared_ptr<IWeightsUpdater> weights_updater,
|
std::string weights_updater_name,
|
||||||
Float64 learning_rate,
|
Float64 learning_rate,
|
||||||
Float64 l2_reg_coef,
|
Float64 l2_reg_coef,
|
||||||
UInt32 batch_size,
|
UInt32 batch_size,
|
||||||
@ -268,7 +268,7 @@ public:
|
|||||||
, l2_reg_coef(l2_reg_coef)
|
, l2_reg_coef(l2_reg_coef)
|
||||||
, batch_size(batch_size)
|
, batch_size(batch_size)
|
||||||
, gradient_computer(std::move(gradient_computer))
|
, gradient_computer(std::move(gradient_computer))
|
||||||
, weights_updater(std::move(weights_updater))
|
, weights_updater_name(std::move(weights_updater_name))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,7 +284,21 @@ public:
|
|||||||
|
|
||||||
void create(AggregateDataPtr place) const override
|
void create(AggregateDataPtr place) const override
|
||||||
{
|
{
|
||||||
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
|
std::shared_ptr<IWeightsUpdater> new_weights_updater;
|
||||||
|
if (weights_updater_name == "\'SGD\'")
|
||||||
|
{
|
||||||
|
new_weights_updater = std::make_shared<StochasticGradientDescent>();
|
||||||
|
} else if (weights_updater_name == "\'Momentum\'")
|
||||||
|
{
|
||||||
|
new_weights_updater = std::make_shared<Momentum>();
|
||||||
|
} else if (weights_updater_name == "\'Nesterov\'")
|
||||||
|
{
|
||||||
|
new_weights_updater = std::make_shared<Nesterov>();
|
||||||
|
} else
|
||||||
|
{
|
||||||
|
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
|
||||||
|
}
|
||||||
|
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
|
||||||
}
|
}
|
||||||
|
|
||||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||||
@ -328,7 +342,7 @@ private:
|
|||||||
Float64 l2_reg_coef;
|
Float64 l2_reg_coef;
|
||||||
UInt32 batch_size;
|
UInt32 batch_size;
|
||||||
std::shared_ptr<IGradientComputer> gradient_computer;
|
std::shared_ptr<IGradientComputer> gradient_computer;
|
||||||
std::shared_ptr<IWeightsUpdater> weights_updater;
|
std::string weights_updater_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NameLinearRegression
|
struct NameLinearRegression
|
||||||
|
Loading…
Reference in New Issue
Block a user