mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 23:52:03 +00:00
Merge pull request #4917 from yandex/hczhcz-master
Add aggregate function leastSqr
This commit is contained in:
commit
b656786608
85
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp
Normal file
85
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
|
||||||
|
|
||||||
|
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||||
|
#include <AggregateFunctions/FactoryHelpers.h>
|
||||||
|
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Factory callback for the leastSqr aggregate function.
///
/// Checks that no parameters were passed and that exactly two arguments are given,
/// then instantiates AggregateFunctionLeastSqr for the concrete pair of argument types.
/// Every combination of UInt8..UInt64, Int8..Int64, Float32 and Float64 is accepted;
/// any other pair of types results in an ILLEGAL_TYPE_OF_ARGUMENT exception.
AggregateFunctionPtr createAggregateFunctionLeastSqr(
    const String & name,
    const DataTypes & arguments,
    const Array & params
)
{
    assertNoParameters(name, params);
    assertBinary(name, arguments);

    const IDataType * x_type = arguments.front().get();
    const IDataType * y_type = arguments.back().get();

    const WhichDataType x_kind{x_type};
    const WhichDataType y_kind{y_type};

/// Applies ACTION(X_TYPE, <y type>) for every supported type of the second argument.
#define LEASTSQR_FOR_EACH_Y(ACTION, X_TYPE) \
    ACTION(X_TYPE, UInt8) \
    ACTION(X_TYPE, UInt16) \
    ACTION(X_TYPE, UInt32) \
    ACTION(X_TYPE, UInt64) \
    ACTION(X_TYPE, Int8) \
    ACTION(X_TYPE, Int16) \
    ACTION(X_TYPE, Int32) \
    ACTION(X_TYPE, Int64) \
    ACTION(X_TYPE, Float32) \
    ACTION(X_TYPE, Float64)

/// Applies ACTION to every supported (first argument type, second argument type) pair.
#define LEASTSQR_FOR_EACH_PAIR(ACTION) \
    LEASTSQR_FOR_EACH_Y(ACTION, UInt8) \
    LEASTSQR_FOR_EACH_Y(ACTION, UInt16) \
    LEASTSQR_FOR_EACH_Y(ACTION, UInt32) \
    LEASTSQR_FOR_EACH_Y(ACTION, UInt64) \
    LEASTSQR_FOR_EACH_Y(ACTION, Int8) \
    LEASTSQR_FOR_EACH_Y(ACTION, Int16) \
    LEASTSQR_FOR_EACH_Y(ACTION, Int32) \
    LEASTSQR_FOR_EACH_Y(ACTION, Int64) \
    LEASTSQR_FOR_EACH_Y(ACTION, Float32) \
    LEASTSQR_FOR_EACH_Y(ACTION, Float64)

/// Returns the matching instantiation as soon as both runtime type indices agree with the pair.
#define LEASTSQR_TRY_CREATE(X_TYPE, Y_TYPE) \
    if (x_kind.idx == TypeIndex::X_TYPE && y_kind.idx == TypeIndex::Y_TYPE) \
        return std::make_shared<AggregateFunctionLeastSqr<X_TYPE, Y_TYPE>>(arguments, params);

    LEASTSQR_FOR_EACH_PAIR(LEASTSQR_TRY_CREATE)

#undef LEASTSQR_TRY_CREATE
#undef LEASTSQR_FOR_EACH_PAIR
#undef LEASTSQR_FOR_EACH_Y

    /// None of the supported type pairs matched.
    throw Exception(
        "Illegal types ("
            + x_type->getName() + ", " + y_type->getName()
            + ") of arguments of aggregate function " + name
            + ", must be Native Ints, Native UInts or Floats",
        ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
    );
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Registers createAggregateFunctionLeastSqr in the given factory under the name "leastSqr".
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
{
    factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
}
|
||||||
|
|
||||||
|
}
|
195
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h
Normal file
195
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <AggregateFunctions/IAggregateFunction.h>
|
||||||
|
#include <Columns/ColumnVector.h>
|
||||||
|
#include <Columns/ColumnTuple.h>
|
||||||
|
#include <DataTypes/DataTypeNullable.h>
|
||||||
|
#include <DataTypes/DataTypesNumber.h>
|
||||||
|
#include <DataTypes/DataTypeTuple.h>
|
||||||
|
#include <IO/ReadHelpers.h>
|
||||||
|
#include <IO/WriteHelpers.h>
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
namespace DB
|
||||||
|
{
|
||||||
|
|
||||||
|
namespace ErrorCodes
|
||||||
|
{
|
||||||
|
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Y, typename Ret>
|
||||||
|
struct AggregateFunctionLeastSqrData final
|
||||||
|
{
|
||||||
|
size_t count = 0;
|
||||||
|
Ret sum_x = 0;
|
||||||
|
Ret sum_y = 0;
|
||||||
|
Ret sum_xx = 0;
|
||||||
|
Ret sum_xy = 0;
|
||||||
|
|
||||||
|
void add(X x, Y y)
|
||||||
|
{
|
||||||
|
count += 1;
|
||||||
|
sum_x += x;
|
||||||
|
sum_y += y;
|
||||||
|
sum_xx += x * x;
|
||||||
|
sum_xy += x * y;
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(const AggregateFunctionLeastSqrData & other)
|
||||||
|
{
|
||||||
|
count += other.count;
|
||||||
|
sum_x += other.sum_x;
|
||||||
|
sum_y += other.sum_y;
|
||||||
|
sum_xx += other.sum_xx;
|
||||||
|
sum_xy += other.sum_xy;
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(WriteBuffer & buf) const
|
||||||
|
{
|
||||||
|
writeBinary(count, buf);
|
||||||
|
writeBinary(sum_x, buf);
|
||||||
|
writeBinary(sum_y, buf);
|
||||||
|
writeBinary(sum_xx, buf);
|
||||||
|
writeBinary(sum_xy, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(ReadBuffer & buf)
|
||||||
|
{
|
||||||
|
readBinary(count, buf);
|
||||||
|
readBinary(sum_x, buf);
|
||||||
|
readBinary(sum_y, buf);
|
||||||
|
readBinary(sum_xx, buf);
|
||||||
|
readBinary(sum_xy, buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ret getK() const
|
||||||
|
{
|
||||||
|
Ret divisor = sum_xx * count - sum_x * sum_x;
|
||||||
|
|
||||||
|
if (divisor == 0)
|
||||||
|
return std::numeric_limits<Ret>::quiet_NaN();
|
||||||
|
|
||||||
|
return (sum_xy * count - sum_x * sum_y) / divisor;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ret getB(Ret k) const
|
||||||
|
{
|
||||||
|
if (count == 0)
|
||||||
|
return std::numeric_limits<Ret>::quiet_NaN();
|
||||||
|
|
||||||
|
return (sum_y - k * sum_x) / count;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Calculates simple linear regression parameters.
|
||||||
|
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
|
||||||
|
template <typename X, typename Y, typename Ret = Float64>
|
||||||
|
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
|
||||||
|
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||||
|
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||||
|
>
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
AggregateFunctionLeastSqr(
|
||||||
|
const DataTypes & arguments,
|
||||||
|
const Array & params
|
||||||
|
):
|
||||||
|
IAggregateFunctionDataHelper<
|
||||||
|
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||||
|
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||||
|
> {arguments, params}
|
||||||
|
{
|
||||||
|
// notice: arguments has been checked before
|
||||||
|
}
|
||||||
|
|
||||||
|
String getName() const override
|
||||||
|
{
|
||||||
|
return "leastSqr";
|
||||||
|
}
|
||||||
|
|
||||||
|
const char * getHeaderFilePath() const override
|
||||||
|
{
|
||||||
|
return __FILE__;
|
||||||
|
}
|
||||||
|
|
||||||
|
void add(
|
||||||
|
AggregateDataPtr place,
|
||||||
|
const IColumn ** columns,
|
||||||
|
size_t row_num,
|
||||||
|
Arena *
|
||||||
|
) const override
|
||||||
|
{
|
||||||
|
auto col_x {
|
||||||
|
static_cast<const ColumnVector<X> *>(columns[0])
|
||||||
|
};
|
||||||
|
auto col_y {
|
||||||
|
static_cast<const ColumnVector<Y> *>(columns[1])
|
||||||
|
};
|
||||||
|
|
||||||
|
X x = col_x->getData()[row_num];
|
||||||
|
Y y = col_y->getData()[row_num];
|
||||||
|
|
||||||
|
this->data(place).add(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
void merge(
|
||||||
|
AggregateDataPtr place,
|
||||||
|
ConstAggregateDataPtr rhs, Arena *
|
||||||
|
) const override
|
||||||
|
{
|
||||||
|
this->data(place).merge(this->data(rhs));
|
||||||
|
}
|
||||||
|
|
||||||
|
void serialize(
|
||||||
|
ConstAggregateDataPtr place,
|
||||||
|
WriteBuffer & buf
|
||||||
|
) const override
|
||||||
|
{
|
||||||
|
this->data(place).serialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
void deserialize(
|
||||||
|
AggregateDataPtr place,
|
||||||
|
ReadBuffer & buf, Arena *
|
||||||
|
) const override
|
||||||
|
{
|
||||||
|
this->data(place).deserialize(buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
DataTypePtr getReturnType() const override
|
||||||
|
{
|
||||||
|
DataTypes types {
|
||||||
|
std::make_shared<DataTypeNumber<Ret>>(),
|
||||||
|
std::make_shared<DataTypeNumber<Ret>>(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Strings names {
|
||||||
|
"k",
|
||||||
|
"b",
|
||||||
|
};
|
||||||
|
|
||||||
|
return std::make_shared<DataTypeTuple>(
|
||||||
|
std::move(types),
|
||||||
|
std::move(names)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
void insertResultInto(
|
||||||
|
ConstAggregateDataPtr place,
|
||||||
|
IColumn & to
|
||||||
|
) const override
|
||||||
|
{
|
||||||
|
Ret k = this->data(place).getK();
|
||||||
|
Ret b = this->data(place).getB(k);
|
||||||
|
|
||||||
|
auto & col_tuple = static_cast<ColumnTuple &>(to);
|
||||||
|
auto & col_k = static_cast<ColumnVector<Ret> &>(col_tuple.getColumn(0));
|
||||||
|
auto & col_b = static_cast<ColumnVector<Ret> &>(col_tuple.getColumn(1));
|
||||||
|
|
||||||
|
col_k.getData().push_back(k);
|
||||||
|
col_b.getData().push_back(b);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -29,6 +29,7 @@ void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
|||||||
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||||
|
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
|
||||||
|
|
||||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||||
@ -69,6 +70,7 @@ void registerAggregateFunctions()
|
|||||||
registerAggregateFunctionHistogram(factory);
|
registerAggregateFunctionHistogram(factory);
|
||||||
registerAggregateFunctionRetention(factory);
|
registerAggregateFunctionRetention(factory);
|
||||||
registerAggregateFunctionEntropy(factory);
|
registerAggregateFunctionEntropy(factory);
|
||||||
|
registerAggregateFunctionLeastSqr(factory);
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
8
dbms/tests/queries/0_stateless/00917_least_sqr.reference
Normal file
8
dbms/tests/queries/0_stateless/00917_least_sqr.reference
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
(10,90)
|
||||||
|
(10.3,89.5)
|
||||||
|
(10,-90)
|
||||||
|
(1,1)
|
||||||
|
(nan,nan)
|
||||||
|
(0,3)
|
||||||
|
(nan,nan)
|
||||||
|
(nan,nan)
|
9
dbms/tests/queries/0_stateless/00917_least_sqr.sql
Normal file
9
dbms/tests/queries/0_stateless/00917_least_sqr.sql
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]);
|
||||||
|
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]);
|
||||||
|
select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]);
|
||||||
|
select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
|
||||||
|
select arrayReduce('leastSqr', [0], [0]);
|
||||||
|
select arrayReduce('leastSqr', [3, 4], [3, 3]);
|
||||||
|
select arrayReduce('leastSqr', [3, 3], [3, 4]);
|
||||||
|
select arrayReduce('leastSqr', emptyArrayUInt8(), emptyArrayUInt8());
|
||||||
|
|
Loading…
Reference in New Issue
Block a user