LinearRegression without -State modifier now simply returns its weights

This commit is contained in:
Alexander Kozhikhov 2019-05-25 02:18:44 +03:00
parent ef3e47c037
commit a63ad11ee7
4 changed files with 29 additions and 3 deletions

View File

@ -3,6 +3,7 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h"
@ -149,6 +150,16 @@ void LinearModelData::predict(
gradient_computer->predict(container, block, arguments, weights, bias, context);
}
/// Appends the model parameters to the column 'to': every weight in stored order, then the bias.
/// 'to' is cast to ColumnVector<Float64> without a runtime check, so the caller must pass a column of that type.
void LinearModelData::return_weights(IColumn & to) const
{
    auto & values = static_cast<ColumnVector<Float64> &>(to).getData();

    for (size_t i = 0; i < weights.size(); ++i)
        values.push_back(weights[i]);

    /// The bias always goes last, after all the weights.
    values.push_back(bias);
}
void LinearModelData::read(ReadBuffer & buf)
{
readBinary(bias, buf);

View File

@ -218,6 +218,7 @@ public:
/// The visible part of the implementation delegates to gradient_computer->predict,
/// passing this model's weights and bias along with the given arguments.
void
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
/// Appends all weights of the model, followed by the bias, to the column 'to'.
void return_weights(IColumn & to) const;
private:
std::vector<Float64> weights;
Float64 bias{0.0};
@ -301,11 +302,12 @@ public:
this->data(place).predict(column.getData(), block, arguments, context);
}
/** This function is called if aggregate function without State modifier is selected in a query.
  * Inserts all weights of the model into the column 'to', so user may use such information if needed.
  * The values are appended by return_weights() of the aggregation state stored at 'place'.
  */
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
    /// The former "not implemented" stub threw unconditionally and made this call unreachable; it is removed.
    this->data(place).return_weights(to);
}
/// Reports the path of the file this method is compiled from (the __FILE__ macro).
const char * getHeaderFilePath() const override
{
    return __FILE__;
}

View File

@ -1 +1,4 @@
1
1
1
1

View File

@ -15,3 +15,13 @@ create table test.model engine = Memory as select linearRegressionState(0.03, 0.
-- Check that the value computed by evalMLMethod from the state stored in test.model lies in (-67.0, -66.9).
select ans > -67.0 and ans < -66.9 from
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 1);
-- Check that the returned weights are close to the real ones.
-- Each query runs linearRegression without the -State modifier and picks one result row with `limit <offset>, 1`.
-- Offset 0: expected within (0.49, 0.51).
select ans > 0.49 and ans < 0.51 from
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 0, 1);
-- Offset 1: expected within (-2.01, -1.99).
select ans > -2.01 and ans < -1.99 from
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 1, 1);
-- Offset 2: expected within (2.99, 3.01). Presumably this is the bias, which return_weights appends last -- TODO confirm.
select ans > 2.99 and ans < 3.01 from
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 2, 1);