From 03deb677b964cb5bc7bc92fcb5881e92b7ef8e64 Mon Sep 17 00:00:00 2001 From: Maxim Kuznetsov Date: Wed, 23 Jan 2019 14:58:05 +0300 Subject: [PATCH 1/2] Code style --- .../AggregateFunctionMLMethod.h | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h index 1e45303b454..252cf6e579c 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h @@ -20,9 +20,11 @@ #include -namespace DB { +namespace DB +{ -struct LinearRegressionData { +struct LinearRegressionData +{ Float64 bias{0.0}; std::vector w1; Float64 learning_rate{0.01}; @@ -30,8 +32,10 @@ struct LinearRegressionData { UInt32 param_num = 0; - void add(Float64 target, std::vector& feature, Float64 learning_rate_, UInt32 param_num_) { - if (w1.empty()) { + void add(Float64 target, std::vector& feature, Float64 learning_rate_, UInt32 param_num_) + { + if (w1.empty()) + { learning_rate = learning_rate_; param_num = param_num_; w1.resize(param_num); @@ -53,11 +57,13 @@ struct LinearRegressionData { ++iter_num; } - void merge(const LinearRegressionData & rhs) { + void merge(const LinearRegressionData & rhs) + { if (iter_num == 0 && rhs.iter_num == 0) throw std::runtime_error("Strange..."); - if (param_num == 0) { + if (param_num == 0) + { param_num = rhs.param_num; w1.resize(param_num); } @@ -74,18 +80,22 @@ struct LinearRegressionData { iter_num += rhs.iter_num; } - void write(WriteBuffer & buf) const { + void write(WriteBuffer & buf) const + { writeBinary(bias, buf); writeBinary(w1, buf); writeBinary(iter_num, buf); } - void read(ReadBuffer & buf) { + void read(ReadBuffer & buf) + { readBinary(bias, buf); readBinary(w1, buf); readBinary(iter_num, buf); } - Float64 predict(std::vector& predict_feature) const { + + Float64 predict(std::vector& predict_feature) const + { Float64 res{0.0}; for (size_t i = 0; i < static_cast(param_num); ++i) { @@ -152,8 +162,9 @@ public: auto &column = dynamic_cast &>(to); std::vector predict_features(arguments.size() - 1); -// for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num) { - for (size_t i = 1; i < arguments.size(); ++i) { +// for (size_t row_num = 0, rows = block.rows(); row_num < rows; ++row_num) + for (size_t i = 1; i < arguments.size(); ++i) + { // predict_features[i] = array_elements[i].get(); predict_features[i - 1] = applyVisitor(FieldVisitorConvertToNumber(), (*block.getByPosition(arguments[i]).column)[row_num]); } @@ -162,7 +173,8 @@ public: // } } - void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { auto &column = dynamic_cast &>(to); std::ignore = column; std::ignore = place; @@ -177,4 +189,4 @@ private: struct NameLinearRegression { static constexpr auto name = "LinearRegression"; }; -} \ No newline at end of file +} From b9972f8e67a0972a7e47cd0ee7fe69c01758bd6e Mon Sep 17 00:00:00 2001 From: Masha Date: Wed, 23 Jan 2019 14:53:50 +0000 Subject: [PATCH 2/2] code style AggregateFunctionMLMethod.cpp --- .../AggregateFunctionMLMethod.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp index 7503a300a47..f1781c92bc9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp @@ -22,11 +22,10 @@ AggregateFunctionPtr createAggregateFunctionMLMethod( for (size_t i = 0; i < argument_types.size(); ++i) { if (!WhichDataType(argument_types[i]).isFloat64()) - throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + - std::to_string(i) + "for aggregate function " + name, - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + throw Exception("Illegal type " + argument_types[i]->getName() + " of argument " + + std::to_string(i) + "for aggregate function " + name, + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } - Float64 learning_rate; if (parameters.empty()) { @@ -35,7 +34,6 @@ AggregateFunctionPtr createAggregateFunctionMLMethod( { learning_rate = applyVisitor(FieldVisitorConvertToNumber(), parameters[0]); } - if (argument_types.size() < 2) throw Exception("Aggregate function " + name + " requires at least two arguments", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -44,8 +42,9 @@ AggregateFunctionPtr createAggregateFunctionMLMethod( } -void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) { +void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory) +{ factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod); } -} \ No newline at end of file +}