#include <AggregateFunctions/AggregateFunctionMLMethod.h>

#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <AggregateFunctions/Helpers.h>

namespace DB
{
namespace
{
    using FuncLinearRegression = AggregateFunctionMLMethod<LinearModelData, NameLinearRegression>;
    using FuncLogisticRegression = AggregateFunctionMLMethod<LinearModelData, NameLogisticRegression>;

    template <class Method>
    AggregateFunctionPtr
    createAggregateFunctionMLMethod(const std::string & name, const DataTypes & argument_types, const Array & parameters)
    {
        if (parameters.size() > 4)
            throw Exception(
                "Aggregate function " + name
                    + " requires at most four parameters: learning_rate, l2_regularization_coef, mini-batch size and weights_updater method",
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        if (argument_types.size() < 2)
            throw Exception(
                "Aggregate function " + name + " requires at least two arguments: target and model's parameters",
                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);

        for (size_t i = 0; i < argument_types.size(); ++i)
        {
            if (!isNumber(argument_types[i]))
                throw Exception(
                    "Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
                        + " must be numeric for aggregate function " + name,
                    ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        /// These default parameters were picked because they performed well in some tests,
        /// though tuning them per task is still required to achieve better results.
        auto learning_rate = Float64(0.01);
        auto l2_reg_coef = Float64(0.01);
        UInt32 batch_size = 1;

        std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
        std::shared_ptr<IGradientComputer> gradient_computer;

        if (!parameters.empty())
            learning_rate = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[0]);
        if (parameters.size() > 1)
            l2_reg_coef = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[1]);
        if (parameters.size() > 2)
            batch_size = applyVisitor(FieldVisitorConvertToNumber<UInt32>(), parameters[2]);
        if (parameters.size() > 3)
        {
            /// The fourth parameter selects the weights updater by numeric id.
            auto updater_id = applyVisitor(FieldVisitorConvertToNumber<Float64>(), parameters[3]);
            if (updater_id == Float64{1.0})
                weights_updater = std::make_shared<StochasticGradientDescent>();
            else if (updater_id == Float64{2.0})
                weights_updater = std::make_shared<Momentum>();
            else if (updater_id == Float64{3.0})
                weights_updater = std::make_shared<Nesterov>();
            else
                throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
        }

        if (std::is_same<Method, FuncLinearRegression>::value)
            gradient_computer = std::make_shared<LinearRegression>();
        else if (std::is_same<Method, FuncLogisticRegression>::value)
            gradient_computer = std::make_shared<LogisticRegression>();
        else
            throw Exception("Such gradient computer is not implemented yet", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);

        return std::make_shared<Method>(
            argument_types.size() - 1,
            gradient_computer,
            weights_updater,
            learning_rate,
            l2_reg_coef,
            batch_size,
            argument_types,
            parameters);
    }
}

void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
    factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
    factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
}
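
/** Usage sketch: a hedged example of how the functions registered above could be
  * invoked from SQL, given the parameter order enforced by
  * createAggregateFunctionMLMethod (learning_rate, l2_regularization_coef,
  * mini-batch size, weights_updater id). The table and column names are
  * hypothetical, not taken from this file.
  *
  *   -- Train a model: last two feature columns, target first; parameters are
  *   -- learning_rate = 0.01, l2 = 0.01, batch size = 1, updater 3 (Nesterov).
  *   SELECT LinearRegression(0.01, 0.01, 1, 3)(target, feature_1, feature_2)
  *   FROM training_data
  *
  * The updater ids 1/2/3 map to StochasticGradientDescent/Momentum/Nesterov,
  * matching the dispatch on parameters[3] above; any other value throws.
  */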