diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h index cdb3c88261c..6f3746eb757 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h @@ -20,7 +20,8 @@ #include -namespace DB { +namespace DB +{ namespace ErrorCodes { @@ -92,6 +93,7 @@ struct LinearRegressionData readBinary(weights, buf); readBinary(iter_num, buf); } + Float64 predict(const std::vector& predict_feature) const { Float64 res{0.0}; @@ -162,7 +164,8 @@ public: auto &column = dynamic_cast &>(to); std::vector predict_features(arguments.size() - 1); - for (size_t i = 1; i < arguments.size(); ++i) { + for (size_t i = 1; i < arguments.size(); ++i) + { const auto& element = (*block.getByPosition(arguments[i]).column)[row_num]; if (element.getType() != Field::Types::Float64) throw Exception("Prediction arguments must be values of type Float", @@ -189,4 +192,4 @@ private: struct NameLinearRegression { static constexpr auto name = "LinearRegression"; }; -} \ No newline at end of file +}