diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp index d1a06875387..cf6b0551655 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp @@ -49,7 +49,7 @@ namespace auto l2_reg_coef = Float64(0.01); UInt32 batch_size = 1; - std::shared_ptr weights_updater = std::make_shared(); + std::string weights_updater_name = "\'SGD\'"; std::shared_ptr gradient_computer; if (!parameters.empty()) @@ -66,19 +66,8 @@ namespace } if (parameters.size() > 3) { - if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'") - { - weights_updater = std::make_shared(); - } - else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'") - { - weights_updater = std::make_shared(); - } - else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'") - { - weights_updater = std::make_shared(); - } - else + weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]); + if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'") { throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } @@ -100,7 +89,7 @@ namespace return std::make_shared( argument_types.size() - 1, gradient_computer, - weights_updater, + weights_updater_name, learning_rate, l2_reg_coef, batch_size, diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h index 8d258691102..12d511c4557 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h @@ -256,7 +256,7 @@ public: explicit AggregateFunctionMLMethod( UInt32 param_num, std::shared_ptr gradient_computer, - std::shared_ptr weights_updater, + std::string weights_updater_name, Float64 learning_rate, Float64 l2_reg_coef, UInt32 batch_size, @@ -268,7 +268,7 @@ public: , l2_reg_coef(l2_reg_coef) , batch_size(batch_size) , gradient_computer(std::move(gradient_computer)) - , weights_updater(std::move(weights_updater)) + , weights_updater_name(std::move(weights_updater_name)) { } @@ -284,7 +284,21 @@ public: void create(AggregateDataPtr place) const override { - new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater); + std::shared_ptr new_weights_updater; + if (weights_updater_name == "\'SGD\'") + { + new_weights_updater = std::make_shared(); + } else if (weights_updater_name == "\'Momentum\'") + { + new_weights_updater = std::make_shared(); + } else if (weights_updater_name == "\'Nesterov\'") + { + new_weights_updater = std::make_shared(); + } else + { + throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR); + } + new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater); } void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override @@ -328,7 +342,7 @@ private: Float64 l2_reg_coef; UInt32 batch_size; std::shared_ptr gradient_computer; - std::shared_ptr weights_updater; + std::string weights_updater_name; }; struct NameLinearRegression