changes on lin reg

This commit is contained in:
Alexander Kozhikhov 2019-05-25 21:41:58 +03:00
parent bbae5448f4
commit 0be0529b59
3 changed files with 23 additions and 10 deletions

View File

@ -3,6 +3,7 @@
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h> #include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Common/FieldVisitors.h> #include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h> #include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h" #include "AggregateFunctionFactory.h"
@ -149,14 +150,26 @@ void LinearModelData::predict(
gradient_computer->predict(container, block, arguments, weights, bias, context); gradient_computer->predict(container, block, arguments, weights, bias, context);
} }
void LinearModelData::return_weights(IColumn & to) const void LinearModelData::returnWeights(IColumn & to) const
{ {
auto & column = static_cast<ColumnVector<Float64> &>(to); // auto & column = static_cast<ColumnVector<Float64> &>(to);
// for (auto weight_value : weights)
// {
// column.getData().push_back(weight_value);
// }
// column.getData().push_back(bias);
auto & column = static_cast<ColumnArray &>(to);
Array weights_array;
weights_array.reserve(weights.size() + 1);
for (auto weight_value : weights) for (auto weight_value : weights)
{ {
column.getData().push_back(weight_value); weights_array.push_back(weight_value);
} }
column.getData().push_back(bias); weights_array.push_back(bias);
// column.getData().push_back(weights_array);
column.getData().insert(weights_array);
} }
void LinearModelData::read(ReadBuffer & buf) void LinearModelData::read(ReadBuffer & buf)

View File

@ -218,7 +218,7 @@ public:
void void
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const; predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
void return_weights(IColumn & to) const; void returnWeights(IColumn & to) const;
private: private:
std::vector<Float64> weights; std::vector<Float64> weights;
Float64 bias{0.0}; Float64 bias{0.0};
@ -307,7 +307,7 @@ public:
*/ */
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{ {
this->data(place).return_weights(to); this->data(place).returnWeights(to);
} }
const char * getHeaderFilePath() const override { return __FILE__; } const char * getHeaderFilePath() const override { return __FILE__; }

View File

@ -21,17 +21,17 @@
<fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query> <fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query>
<fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query> <fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query>
<!-- Проверяем, как работает fit--> <!-- Check model fit-->
<query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query> <query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<!-- Проверяем fit с Momentum --> <!-- Check fit with Momentum -->
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query> <query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<!-- Проверяем fit с Nesterov--> <!-- Check fit with Nesterov-->
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query> <query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<!-- Проверяем, как работает predict--> <!-- Check model predict-->
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query> <query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query>
<drop_query>DROP TABLE IF EXISTS train_dataset</drop_query> <drop_query>DROP TABLE IF EXISTS train_dataset</drop_query>