mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
LinearRegression without -State modifier now simply returns its weights
This commit is contained in:
parent
ef3e47c037
commit
a63ad11ee7
@ -3,6 +3,7 @@
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "AggregateFunctionFactory.h"
|
||||
@ -149,6 +150,16 @@ void LinearModelData::predict(
|
||||
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
||||
}
|
||||
|
||||
void LinearModelData::return_weights(IColumn & to) const
|
||||
{
|
||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
for (auto weight_value : weights)
|
||||
{
|
||||
column.getData().push_back(weight_value);
|
||||
}
|
||||
column.getData().push_back(bias);
|
||||
}
|
||||
|
||||
void LinearModelData::read(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(bias, buf);
|
||||
|
@ -218,6 +218,7 @@ public:
|
||||
void
|
||||
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||
|
||||
void return_weights(IColumn & to) const;
|
||||
private:
|
||||
std::vector<Float64> weights;
|
||||
Float64 bias{0.0};
|
||||
@ -301,11 +302,12 @@ public:
|
||||
this->data(place).predict(column.getData(), block, arguments, context);
|
||||
}
|
||||
|
||||
/** This function is called if aggregate function without State modifier is selected in a query.
|
||||
* Inserts all weights of the model into the column 'to', so user may use such information if needed
|
||||
*/
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
std::ignore = place;
|
||||
std::ignore = to;
|
||||
throw std::runtime_error("not implemented");
|
||||
this->data(place).return_weights(to);
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
|
@ -1 +1,4 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
|
@ -15,3 +15,13 @@ create table test.model engine = Memory as select linearRegressionState(0.03, 0.
|
||||
|
||||
select ans > -67.0 and ans < -66.9 from
|
||||
(with (select state from test.model) as model select evalMLMethod(model, predict1, predict2) as ans from test.defaults limit 1);
|
||||
|
||||
-- Check that returned weights are close to real
|
||||
select ans > 0.49 and ans < 0.51 from
|
||||
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 0, 1);
|
||||
|
||||
select ans > -2.01 and ans < -1.99 from
|
||||
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 1, 1);
|
||||
|
||||
select ans > 2.99 and ans < 3.01 from
|
||||
(select linearRegression(0.03, 0.00001, 2, 'Nesterov')(target, param1, param2) as ans from test.defaults limit 2, 1);
|
||||
|
Loading…
Reference in New Issue
Block a user