mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
changes on return type of linearRegression
This commit is contained in:
parent
dd259c408d
commit
5cd85baec6
@ -4,6 +4,7 @@
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "AggregateFunctionFactory.h"
|
||||
@ -152,24 +153,28 @@ void LinearModelData::predict(
|
||||
|
||||
void LinearModelData::returnWeights(IColumn & to) const
|
||||
{
|
||||
// auto & column = static_cast<ColumnVector<Float64> &>(to);
|
||||
// for (auto weight_value : weights)
|
||||
// {
|
||||
// column.getData().push_back(weight_value);
|
||||
// }
|
||||
// column.getData().push_back(bias);
|
||||
auto & column = static_cast<ColumnArray &>(to);
|
||||
size_t size = weights.size() + 1;
|
||||
|
||||
Array weights_array;
|
||||
weights_array.reserve(weights.size() + 1);
|
||||
for (auto weight_value : weights)
|
||||
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
size_t old_size = offsets_to.back();
|
||||
offsets_to.push_back(old_size + size);
|
||||
|
||||
if (size)
|
||||
{
|
||||
weights_array.push_back(weight_value);
|
||||
}
|
||||
weights_array.push_back(bias);
|
||||
typename ColumnFloat64::Container & val_to
|
||||
= static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
|
||||
|
||||
// column.getData().push_back(weights_array);
|
||||
column.getData().insert(weights_array);
|
||||
val_to.reserve(old_size + size);
|
||||
size_t i = 0;
|
||||
while (i < weights.size())
|
||||
{
|
||||
val_to.push_back(weights[i]);
|
||||
i++;
|
||||
}
|
||||
val_to.push_back(bias);
|
||||
}
|
||||
}
|
||||
|
||||
void LinearModelData::read(ReadBuffer & buf)
|
||||
|
@ -4,6 +4,8 @@
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include "IAggregateFunction.h"
|
||||
|
||||
namespace DB
|
||||
@ -270,7 +272,15 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
/// Result type of the aggregate function itself: Array(Float64).
/// NOTE(review): presumably the array carries the model weights followed by the bias — confirm against returnWeights().
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
|
||||
}
|
||||
|
||||
/// Result type used when predicting with this function: a single Float64 value.
DataTypePtr getReturnTypeToPredict() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
|
@ -47,6 +47,10 @@ public:
|
||||
|
||||
/// Get the result type.
|
||||
virtual DataTypePtr getReturnType() const = 0;
|
||||
/// Get the result type of a prediction made with this function.
/// The default implementation throws NOT_IMPLEMENTED ("Prediction is not supported for <name>");
/// functions that support prediction override it.
virtual DataTypePtr getReturnTypeToPredict() const
|
||||
{
|
||||
throw Exception("Prediction is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual ~IAggregateFunction() {}
|
||||
|
||||
|
@ -35,25 +35,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
|
||||
arenas.push_back(arena_);
|
||||
}
|
||||
|
||||
/// This function is used in convertToValues() and predictValues()
|
||||
/// and is written here to avoid repetitions
|
||||
/// Prepares the result column for the caller and stores it in *res_.
/// Returns true if `func` is an AggregateFunctionState: *res_ is then a view of this column
/// bound to the nested function and holding a copy of this column's `data` entries.
/// Returns false otherwise: *res_ is then a new column of func's return type with space
/// reserved for data.size() rows; no values are inserted here.
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
|
||||
{
|
||||
/// typeid_cast yields nullptr when func is not an AggregateFunctionState, selecting the second path.
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
{
|
||||
auto res = createView();
|
||||
/// Rebind the view to the function wrapped by the -State combinator.
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
*res_ = std::move(res);
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Ordinary function: hand back an empty column of the return type, pre-sized for one row per entry in `data`.
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
*res_ = std::move(res);
|
||||
return false;
|
||||
}
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
{
|
||||
/** If the aggregate function returns an unfinalized/unfinished state,
|
||||
@ -86,17 +67,17 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
* AggregateFunction(quantileTiming(0.5), UInt64)
|
||||
* into UInt16 - already finished result of `quantileTiming`.
|
||||
*/
|
||||
|
||||
/** Conversion function is used in convertToValues and predictValues
|
||||
* in the similar part of both functions
|
||||
*/
|
||||
|
||||
MutableColumnPtr res;
|
||||
if (tryFinalizeAggregateFunction(&res))
|
||||
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
{
|
||||
auto res = createView();
|
||||
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
for (auto val : data)
|
||||
func->insertResultInto(val, *res);
|
||||
|
||||
@ -105,8 +86,8 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
|
||||
{
|
||||
MutableColumnPtr res;
|
||||
tryFinalizeAggregateFunction(&res);
|
||||
MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
auto ML_function = func.get();
|
||||
if (ML_function)
|
||||
|
@ -36,6 +36,7 @@ public:
|
||||
bool canBeInsideNullable() const override { return false; }
|
||||
|
||||
/// Result type of the wrapped aggregate function (forwards to function->getReturnType()).
DataTypePtr getReturnType() const { return function->getReturnType(); }
|
||||
/// Result type of a prediction made with the wrapped aggregate function (forwards to function->getReturnTypeToPredict()).
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
|
||||
/// Returns a copy of the stored argument_types.
DataTypes getArgumentsDataTypes() const { return argument_types; }
|
||||
|
||||
/// NOTE These two functions for serializing single values are incompatible with the functions below.
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
|
||||
@ -60,7 +61,7 @@ public:
|
||||
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return type->getReturnType();
|
||||
return type->getReturnTypeToPredict();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
|
@ -5,34 +5,32 @@
|
||||
<stop_conditions>
|
||||
<any_of>
|
||||
<average_speed_not_changing_for_ms>3000</average_speed_not_changing_for_ms>
|
||||
<total_time_ms>5000</total_time_ms>
|
||||
<total_time_ms>10000</total_time_ms>
|
||||
</any_of>
|
||||
</stop_conditions>
|
||||
|
||||
<preconditions>
|
||||
<table_exists>test.hits</table_exists>
|
||||
</preconditions>
|
||||
|
||||
<main_metric>
|
||||
<min_time/>
|
||||
</main_metric>
|
||||
|
||||
<create_query>CREATE TABLE train_dataset(p1 Float64, p2 Float64, target Float64) ENGINE = Memory</create_query>
|
||||
<create_query>CREATE TABLE test_dataset(p1 Float64, p2 Float64) ENGINE = Memory</create_query>
|
||||
|
||||
<create_query>CREATE TABLE test_model engine = Memory as select linearRegressionState(0.001)(target, p1, p2) as state from train_dataset</create_query>
|
||||
|
||||
<fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query>
|
||||
<fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query>
|
||||
<create_query>CREATE TABLE test_model engine = Memory as select linearRegressionState(0.0001)(Age, Income, ParamPrice, Robotness, RefererHash) as state from test.hits</create_query>
|
||||
|
||||
<!-- Check model fit-->
|
||||
<query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||
<query>with (SELECT linearRegressionState(0.0001, 0, 15)(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
|
||||
|
||||
<!-- Check fit with Momentum -->
|
||||
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||
|
||||
<!-- Check fit with Nesterov-->
|
||||
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
|
||||
<!-- Check model fit with Momentum-->
|
||||
<query>with (SELECT linearRegressionState(0.0001, 0, 15, 'Momentum')(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
|
||||
|
||||
<!-- Check model fit with Nesterov-->
|
||||
<query>with (SELECT linearRegressionState(0.0001, 0, 15, 'Nesterov')(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
|
||||
|
||||
<!-- Check model predict-->
|
||||
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query>
|
||||
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, Income, ParamPrice, Robotness, RefererHash) from test.hits</query>
|
||||
|
||||
|
||||
<drop_query>DROP TABLE IF EXISTS train_dataset</drop_query>
|
||||
<drop_query>DROP TABLE IF EXISTS test_dataset</drop_query>
|
||||
|
@ -2,3 +2,5 @@
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
1
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user