mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
changes on lin reg
This commit is contained in:
parent
bbae5448f4
commit
0be0529b59
@ -3,6 +3,7 @@
|
|||||||
#include <IO/ReadHelpers.h>
|
#include <IO/ReadHelpers.h>
|
||||||
#include <IO/WriteHelpers.h>
|
#include <IO/WriteHelpers.h>
|
||||||
#include <Interpreters/castColumn.h>
|
#include <Interpreters/castColumn.h>
|
||||||
|
#include <Columns/ColumnArray.h>
|
||||||
#include <Common/FieldVisitors.h>
|
#include <Common/FieldVisitors.h>
|
||||||
#include <Common/typeid_cast.h>
|
#include <Common/typeid_cast.h>
|
||||||
#include "AggregateFunctionFactory.h"
|
#include "AggregateFunctionFactory.h"
|
||||||
@ -149,14 +150,26 @@ void LinearModelData::predict(
|
|||||||
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LinearModelData::return_weights(IColumn & to) const
|
void LinearModelData::returnWeights(IColumn & to) const
|
||||||
{
|
{
|
||||||
auto & column = static_cast<ColumnVector<Float64> &>(to);
|
// auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||||
|
// for (auto weight_value : weights)
|
||||||
|
// {
|
||||||
|
// column.getData().push_back(weight_value);
|
||||||
|
// }
|
||||||
|
// column.getData().push_back(bias);
|
||||||
|
auto & column = static_cast<ColumnArray &>(to);
|
||||||
|
|
||||||
|
Array weights_array;
|
||||||
|
weights_array.reserve(weights.size() + 1);
|
||||||
for (auto weight_value : weights)
|
for (auto weight_value : weights)
|
||||||
{
|
{
|
||||||
column.getData().push_back(weight_value);
|
weights_array.push_back(weight_value);
|
||||||
}
|
}
|
||||||
column.getData().push_back(bias);
|
weights_array.push_back(bias);
|
||||||
|
|
||||||
|
// column.getData().push_back(weights_array);
|
||||||
|
column.getData().insert(weights_array);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LinearModelData::read(ReadBuffer & buf)
|
void LinearModelData::read(ReadBuffer & buf)
|
||||||
|
@ -218,7 +218,7 @@ public:
|
|||||||
void
|
void
|
||||||
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||||
|
|
||||||
void return_weights(IColumn & to) const;
|
void returnWeights(IColumn & to) const;
|
||||||
private:
|
private:
|
||||||
std::vector<Float64> weights;
|
std::vector<Float64> weights;
|
||||||
Float64 bias{0.0};
|
Float64 bias{0.0};
|
||||||
@ -307,7 +307,7 @@ public:
|
|||||||
*/
|
*/
|
||||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||||
{
|
{
|
||||||
this->data(place).return_weights(to);
|
this->data(place).returnWeights(to);
|
||||||
}
|
}
|
||||||
|
|
||||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||||
|
@ -21,17 +21,17 @@
|
|||||||
<fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query>
|
<fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query>
|
||||||
<fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query>
|
<fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query>
|
||||||
|
|
||||||
<!-- Проверяем, как работает fit-->
|
<!-- Check model fit-->
|
||||||
<query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
<query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||||
|
|
||||||
<!-- Проверяем fit с Momentum -->
|
<!-- Check fit with Momentum -->
|
||||||
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||||
|
|
||||||
<!-- Проверяем fit с Nesterov-->
|
<!-- Check fit with Nesterov-->
|
||||||
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||||
|
|
||||||
|
|
||||||
<!-- Проверяем, как работает predict-->
|
<!-- Check model predict-->
|
||||||
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query>
|
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query>
|
||||||
|
|
||||||
<drop_query>DROP TABLE IF EXISTS train_dataset</drop_query>
|
<drop_query>DROP TABLE IF EXISTS train_dataset</drop_query>
|
||||||
|
Loading…
Reference in New Issue
Block a user