From a63ad11ee7915fb367cec80f18fb2c35d025d6d4 Mon Sep 17 00:00:00 2001 From: Alexander Kozhikhov Date: Sat, 25 May 2019 02:18:44 +0300 Subject: [PATCH] LinearRegression without -State modifier now simply returns its weights --- .../AggregateFunctions/AggregateFunctionMLMethod.cpp | 11 +++++++++++ .../AggregateFunctions/AggregateFunctionMLMethod.h | 8 +++++--- .../tests/queries/0_stateless/00947_ml_test.reference | 3 +++ dbms/tests/queries/0_stateless/00947_ml_test.sql | 10 ++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp index c7162f04b51..29f6b700d76 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include "AggregateFunctionFactory.h" @@ -149,6 +150,16 @@ void LinearModelData::predict( gradient_computer->predict(container, block, arguments, weights, bias, context); } +void LinearModelData::return_weights(IColumn & to) const +{ + auto & column = static_cast &>(to); + for (auto weight_value : weights) + { + column.getData().push_back(weight_value); + } + column.getData().push_back(bias); +} + void LinearModelData::read(ReadBuffer & buf) { readBinary(bias, buf); diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h index 3c4660f14f5..2ee2c1fe415 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionMLMethod.h @@ -218,6 +218,7 @@ public: void predict(ColumnVector::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const; + void return_weights(IColumn & to) const; private: std::vector weights; Float64 bias{0.0}; @@ -301,11 +302,12 @@ public: this->data(place).predict(column.getData(), block, arguments, context); } + /** This function is called if aggregate function without State modifier is selected in a query. + * Inserts all weights of the model into the column 'to', so user may use such information if needed + */ void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { - std::ignore = place; - std::ignore = to; - throw std::runtime_error("not implemented"); + this->data(place).return_weights(to); } const char * getHeaderFilePath() const override { return __FILE__; } diff --git a/dbms/tests/queries/0_stateless/00947_ml_test.reference b/dbms/tests/queries/0_stateless/00947_ml_test.reference index d00491fd7e5..98fb6a68656 100644 --- a/dbms/tests/queries/0_stateless/00947_ml_test.reference +++ b/dbms/tests/queries/0_stateless/00947_ml_test.reference @@ -1 +1,4 @@ 1 +1 +1 +1 diff --git a/dbms/tests/queries/0_stateless/00947_ml_test.sql b/dbms/tests/queries/0_stateless/00947_ml_test.sql index b5a602f6826..923fa1ff357 100644 --- a/dbms/tests/queries/0_stateless/00947_ml_test.sql +++ b/dbms/tests/queries/0_stateless/00947_ml_test.sql @@ -15,3 +15,13 @@ create table test.model engine = Memory as select linearRegressionState(0.03, 0. select ans > -67.0 and ans < -66.9 from (with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 1); + +-- Check that returned weights are close to real +select ans > 0.49 and ans < 0.51 from +(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 0, 1); + +select ans > -2.01 and ans < -1.99 from +(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 1, 1); + +select ans > 2.99 and ans < 3.01 from +(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 2, 1);