Minor fixes for leastSqr.

This commit is contained in:
Nikolai Kochetov 2019-04-05 16:32:25 +03:00
parent feb16eedd2
commit bb9958b0d7
3 changed files with 13 additions and 12 deletions

View File

@ -2,6 +2,7 @@
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
@ -74,6 +75,8 @@ struct AggregateFunctionLeastSqrData final
}
};
/// Calculates simple linear regression parameters.
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
@ -150,12 +153,8 @@ public:
DataTypePtr getReturnType() const override
{
DataTypes types {
std::make_shared<DataTypeNullable>(
std::make_shared<DataTypeFloat64>()
),
std::make_shared<DataTypeNullable>(
std::make_shared<DataTypeFloat64>()
),
std::make_shared<DataTypeNumber<Ret>>(),
std::make_shared<DataTypeNumber<Ret>>(),
};
Strings names {
@ -177,13 +176,12 @@ public:
Ret k = this->data(place).getK();
Ret b = this->data(place).getB(k);
Tuple result;
result.toUnderType().reserve(2);
auto & col_tuple = static_cast<ColumnTuple &>(to);
auto & col_k = static_cast<ColumnVector<Ret> &>(col_tuple.getColumn(0));
auto & col_b = static_cast<ColumnVector<Ret> &>(col_tuple.getColumn(1));
result.toUnderType().emplace_back(k);
result.toUnderType().emplace_back(b);
to.insert(std::move(result));
col_k.getData().push_back(k);
col_b.getData().push_back(b);
}
};

View File

@ -5,3 +5,4 @@
(nan,nan)
(0,3)
(nan,nan)
(nan,nan)

View File

@ -5,3 +5,5 @@ select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
select arrayReduce('leastSqr', [0], [0]);
select arrayReduce('leastSqr', [3, 4], [3, 3]);
select arrayReduce('leastSqr', [3, 3], [3, 4]);
select arrayReduce('leastSqr', emptyArrayUInt8(), emptyArrayUInt8());