changes on return type of linearRegression

This commit is contained in:
Alexander Kozhikhov 2019-05-27 23:14:23 +03:00
parent dd259c408d
commit 5cd85baec6
9 changed files with 85 additions and 63 deletions

View File

@ -4,6 +4,7 @@
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h"
@ -152,24 +153,28 @@ void LinearModelData::predict(
/// Appends the model parameters to `to` as one array row: every element of `weights`
/// in order, followed by `bias`.
/// `to` is static_cast to ColumnArray and its nested data to ColumnFloat64, with no runtime type check.
/// NOTE(review): this span is a rendered diff hunk in which removed and added lines are interleaved;
/// the description above covers the offsets-based lines (arr_to / offsets_to / val_to).
void LinearModelData::returnWeights(IColumn & to) const
{
// auto & column = static_cast<ColumnVector<Float64> &>(to);
// for (auto weight_value : weights)
// {
// column.getData().push_back(weight_value);
// }
// column.getData().push_back(bias);
auto & column = static_cast<ColumnArray &>(to);
// One slot per weight plus one for the bias.
size_t size = weights.size() + 1;
Array weights_array;
weights_array.reserve(weights.size() + 1);
for (auto weight_value : weights)
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
// The new row ends `size` elements after the last recorded offset.
size_t old_size = offsets_to.back();
offsets_to.push_back(old_size + size);
// `size` is weights.size() + 1, so it is never zero here.
if (size)
{
weights_array.push_back(weight_value);
}
weights_array.push_back(bias);
typename ColumnFloat64::Container & val_to
= static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
// column.getData().push_back(weights_array);
column.getData().insert(weights_array);
val_to.reserve(old_size + size);
size_t i = 0;
// Copy the weights in index order; the bias is appended after the loop.
while (i < weights.size())
{
val_to.push_back(weights[i]);
i++;
}
val_to.push_back(bias);
}
}
void LinearModelData::read(ReadBuffer & buf)

View File

@ -4,6 +4,8 @@
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include "IAggregateFunction.h"
namespace DB
@ -270,7 +272,15 @@ public:
{
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
/// Result type of the aggregate function itself: an array whose elements are Float64.
DataTypePtr getReturnType() const override
{
    DataTypePtr element_type = std::make_shared<DataTypeFloat64>();
    return std::make_shared<DataTypeArray>(element_type);
}
/// Type of the column produced when the model is used for prediction: a Float64 number.
DataTypePtr getReturnTypeToPredict() const override
{
    using PredictionType = DataTypeNumber<Float64>;
    return std::make_shared<PredictionType>();
}
void create(AggregateDataPtr place) const override
{

View File

@ -47,6 +47,10 @@ public:
/// Get the result type.
virtual DataTypePtr getReturnType() const = 0;
/// Get the result type used for prediction.
/// The default implementation always throws NOT_IMPLEMENTED, naming the function via getName();
/// functions that support prediction override it.
virtual DataTypePtr getReturnTypeToPredict() const
{
    const auto message = "Prediction is not supported for " + getName();
    throw Exception(message, ErrorCodes::NOT_IMPLEMENTED);
}
virtual ~IAggregateFunction() {}

View File

@ -35,25 +35,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
/// Shared preparation step for convertToValues() and predictValues(), kept in one place to avoid repetition.
/// If `func` is an AggregateFunctionState, stores in *res_ a view of this column that is bound to the
/// nested function and holds a copy of this column's `data` entries, and returns true.
/// Otherwise stores in *res_ a new column of func's return type with room reserved for data.size()
/// rows, and returns false.
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
{
    const auto * state_function = typeid_cast<const AggregateFunctionState *>(func.get());
    if (!state_function)
    {
        MutableColumnPtr plain_column = func->getReturnType()->createColumn();
        plain_column->reserve(data.size());
        *res_ = std::move(plain_column);
        return false;
    }

    auto view = createView();
    view->set(state_function->getNestedFunction());
    view->data.assign(data.begin(), data.end());
    *res_ = std::move(view);
    return true;
}
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
{
/** If the aggregate function returns an unfinalized/unfinished state,
@ -86,17 +67,17 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
/** Conversion function is used in convertToValues and predictValues
* in the similar part of both functions
*/
MutableColumnPtr res;
if (tryFinalizeAggregateFunction(&res))
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
for (auto val : data)
func->insertResultInto(val, *res);
@ -105,8 +86,8 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
{
MutableColumnPtr res;
tryFinalizeAggregateFunction(&res);
MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
res->reserve(data.size());
auto ML_function = func.get();
if (ML_function)

View File

@ -36,6 +36,7 @@ public:
bool canBeInsideNullable() const override { return false; }
/// Forwards to function->getReturnType().
DataTypePtr getReturnType() const { return function->getReturnType(); }
/// Forwards to function->getReturnTypeToPredict().
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
DataTypes getArgumentsDataTypes() const { return argument_types; }
/// NOTE These two functions for serializing single values are incompatible with the functions below.

View File

@ -2,6 +2,7 @@
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypesNumber.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Common/typeid_cast.h>
@ -60,7 +61,7 @@ public:
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return type->getReturnType();
return type->getReturnTypeToPredict();
}
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override

View File

@ -5,34 +5,32 @@
<stop_conditions>
<any_of>
<average_speed_not_changing_for_ms>3000</average_speed_not_changing_for_ms>
<total_time_ms>5000</total_time_ms>
<total_time_ms>10000</total_time_ms>
</any_of>
</stop_conditions>
<preconditions>
<table_exists>test.hits</table_exists>
</preconditions>
<main_metric>
<min_time/>
</main_metric>
<create_query>CREATE TABLE train_dataset(p1 Float64, p2 Float64, target Float64) ENGINE = Memory</create_query>
<create_query>CREATE TABLE test_dataset(p1 Float64, p2 Float64) ENGINE = Memory</create_query>
<create_query>CREATE TABLE test_model engine = Memory as select linearRegressionState(0.001)(target, p1, p2) as state from train_dataset</create_query>
<fill_query>INSERT INTO train_dataset values (2.0, 1.0, 6.0), (3.0, -1.0, 6.0), (4.0, 2.0, 11.0), (-2.0, 1.0, -2.0)</fill_query>
<fill_query>INSERT INTO test_dataset values (1.0, 1.0), (5.0, -3.0)</fill_query>
<create_query>CREATE TABLE test_model engine = Memory as select linearRegressionState(0.0001)(Age, Income, ParamPrice, Robotness, RefererHash) as state from test.hits</create_query>
<!-- Check model fit-->
<query>with (SELECT linearRegressionState(0.001)(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<query>with (SELECT linearRegressionState(0.0001, 0, 15)(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
<!-- Check fit with Momentum -->
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Momentum')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<!-- Check fit with Nesterov-->
<query>with (SELECT linearRegressionState(0.001, 0.1, 1, 'Nesterov')(target, p1, p2) FROM train_dataset) as model select toColumnTypeName(model)</query>
<!-- Check model fit with Momentum-->
<query>with (SELECT linearRegressionState(0.0001, 0, 15, 'Momentum')(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
<!-- Check model fit with Nesterov-->
<query>with (SELECT linearRegressionState(0.0001, 0, 15, 'Nesterov')(Age, Income, ParamPrice, Robotness, RefererHash) FROM test.hits) as model select toColumnTypeName(model)</query>
<!-- Check model predict-->
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, p1, p2) from test_dataset</query>
<query>with (SELECT state FROM test_model) as model select evalMLMethod(model, Income, ParamPrice, Robotness, RefererHash) from test.hits</query>
<drop_query>DROP TABLE IF EXISTS train_dataset</drop_query>
<drop_query>DROP TABLE IF EXISTS test_dataset</drop_query>

View File

@ -2,3 +2,5 @@
1
1
1
1
1

File diff suppressed because one or more lines are too long