Add aggregate function leastSqr

This commit is contained in:
hcz 2019-03-13 15:22:57 +08:00
parent faa94c09a9
commit 72bcbc76b1
5 changed files with 304 additions and 0 deletions

View File

@ -0,0 +1,112 @@
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
namespace DB
{
namespace
{
/// Factory callback that builds a `leastSqr` aggregate function instance.
///
/// Takes no parameters and exactly two arguments (x, then y). Each argument must be
/// a native unsigned integer, a native signed integer or a float type; anything else
/// raises ILLEGAL_TYPE_OF_ARGUMENT. Every accepted type is mapped to its widest
/// representative (UInt64 / Int64 / Float64), which selects one of the nine
/// AggregateFunctionLeastSqr<X, Y> instantiations.
AggregateFunctionPtr createAggregateFunctionLeastSqr(
    const String & name,
    const DataTypes & arguments,
    const Array & params)
{
    assertNoParameters(name, params);
    assertBinary(name, arguments);

    /// Shared error path for both arguments; `position_text` carries the part of the
    /// message that says which argument was rejected.
    const auto throw_illegal_type = [&name](const IDataType * type, const char * position_text)
    {
        throw Exception {
            "Illegal type " + type->getName()
                + position_text
                + name + ", must be Native Int, Native UInt or Float",
            ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
        };
    };

    /// The first argument is validated before the second one is inspected.
    const IDataType * first_type = arguments.front().get();
    WhichDataType first_kind {first_type};
    const bool x_is_uint = first_kind.isNativeUInt();
    const bool x_is_int = first_kind.isNativeInt();
    const bool x_is_float = first_kind.isFloat();
    if (!(x_is_uint || x_is_int || x_is_float))
        throw_illegal_type(first_type, " of first argument of aggregate function ");

    const IDataType * second_type = arguments.back().get();
    WhichDataType second_kind {second_type};
    const bool y_is_uint = second_kind.isNativeUInt();
    const bool y_is_int = second_kind.isNativeInt();
    const bool y_is_float = second_kind.isFloat();
    if (!(y_is_uint || y_is_int || y_is_float))
        throw_illegal_type(second_type, " of second argument of aggregate function ");

    /// Both arguments passed validation, so the last alternative at each level is
    /// necessarily the float case.
    if (x_is_uint)
    {
        if (y_is_uint)
            return std::make_shared<AggregateFunctionLeastSqr<UInt64, UInt64>>(arguments, params);
        if (y_is_int)
            return std::make_shared<AggregateFunctionLeastSqr<UInt64, Int64>>(arguments, params);
        return std::make_shared<AggregateFunctionLeastSqr<UInt64, Float64>>(arguments, params);
    }

    if (x_is_int)
    {
        if (y_is_uint)
            return std::make_shared<AggregateFunctionLeastSqr<Int64, UInt64>>(arguments, params);
        if (y_is_int)
            return std::make_shared<AggregateFunctionLeastSqr<Int64, Int64>>(arguments, params);
        return std::make_shared<AggregateFunctionLeastSqr<Int64, Float64>>(arguments, params);
    }

    if (y_is_uint)
        return std::make_shared<AggregateFunctionLeastSqr<Float64, UInt64>>(arguments, params);
    if (y_is_int)
        return std::make_shared<AggregateFunctionLeastSqr<Float64, Int64>>(arguments, params);
    return std::make_shared<AggregateFunctionLeastSqr<Float64, Float64>>(arguments, params);
}
}
/// Registers the `leastSqr` creator callback in the given aggregate function factory
/// under the name "leastSqr".
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
{
    const auto & creator = createAggregateFunctionLeastSqr;
    factory.registerFunction("leastSqr", creator);
}
}

View File

@ -0,0 +1,176 @@
#pragma once
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/IColumn.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
/// Error codes referenced by the leastSqr aggregate function.
/// Only declared here (`extern`); the definition lives elsewhere in the project.
namespace ErrorCodes
{
/// Reported when an argument of an unsupported data type is passed.
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/** Aggregation state for fitting a line y = k * x + b by ordinary least squares.
  *
  * Keeps only the running sums required by the closed-form solution, so states
  * can be merged by plain addition.
  *
  * Template parameters:
  *   X, Y - types of the incoming x and y values;
  *   Ret  - type of the accumulators and of the resulting coefficients
  *          (the aggregate function class in this header defaults it to Float64).
  */
template <typename X, typename Y, typename Ret>
struct AggregateFunctionLeastSqrData
{
    size_t count = 0;   /// number of (x, y) points seen
    Ret sum_x = 0;      /// sum of x
    Ret sum_y = 0;      /// sum of y
    Ret sum_xx = 0;     /// sum of x * x
    Ret sum_xy = 0;     /// sum of x * y

    /// Accounts for one more point.
    void add(X x, Y y)
    {
        /// Widen to Ret *before* multiplying. Multiplying in the argument types is wrong:
        ///  - for X = UInt64, Y = Int64 the usual arithmetic conversions turn a negative y
        ///    into a huge unsigned value, so x * y gets the wrong sign and magnitude;
        ///  - for 64-bit integers the product may overflow (undefined behaviour when signed).
        const Ret x_value = static_cast<Ret>(x);
        const Ret y_value = static_cast<Ret>(y);

        count += 1;
        sum_x += x_value;
        sum_y += y_value;
        sum_xx += x_value * x_value;
        sum_xy += x_value * y_value;
    }

    /// Combines another partial state into this one by summing every accumulator.
    void merge(const AggregateFunctionLeastSqrData & other)
    {
        count += other.count;
        sum_x += other.sum_x;
        sum_y += other.sum_y;
        sum_xx += other.sum_xx;
        sum_xy += other.sum_xy;
    }

    /// Writes the state in binary form: count, sum_x, sum_y, sum_xx, sum_xy.
    void serialize(WriteBuffer & buf) const
    {
        writeBinary(count, buf);
        writeBinary(sum_x, buf);
        writeBinary(sum_y, buf);
        writeBinary(sum_xx, buf);
        writeBinary(sum_xy, buf);
    }

    /// Reads the state back; the field order must match serialize().
    void deserialize(ReadBuffer & buf)
    {
        readBinary(count, buf);
        readBinary(sum_x, buf);
        readBinary(sum_y, buf);
        readBinary(sum_xx, buf);
        readBinary(sum_xy, buf);
    }

    /// Slope: k = (n * Sxy - Sx * Sy) / (n * Sxx - Sx * Sx).
    /// The denominator is zero when there are no points or all x are equal; no special
    /// handling is done here, so with a floating-point Ret the result is then NaN or inf.
    Ret getK() const
    {
        return (sum_xy * count - sum_x * sum_y)
            / (sum_xx * count - sum_x * sum_x);
    }

    /// Intercept for the given slope: b = (Sy - k * Sx) / n.
    Ret getB(Ret k) const
    {
        return (sum_y - k * sum_x) / count;
    }
};
/** Aggregate function `leastSqr(x, y)`.
  *
  * Feeds every (x, y) row into an AggregateFunctionLeastSqrData state and returns
  * the coefficients of the fitted line as a named tuple (k, b), where both elements
  * are declared as Nullable(Float64).
  *
  * X and Y are the types the argument values are read as; Ret is the accumulator type.
  */
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
    AggregateFunctionLeastSqrData<X, Y, Ret>,
    AggregateFunctionLeastSqr<X, Y, Ret>
>
{
    using State = AggregateFunctionLeastSqrData<X, Y, Ret>;
    using Helper = IAggregateFunctionDataHelper<State, AggregateFunctionLeastSqr<X, Y, Ret>>;

public:
    /// The argument types are not re-validated here: they are expected to have been
    /// checked before the object is constructed.
    AggregateFunctionLeastSqr(const DataTypes & arguments, const Array & params)
        : Helper {arguments, params}
    {
    }

    String getName() const override { return "leastSqr"; }

    const char * getHeaderFilePath() const override { return __FILE__; }

    /// Result type: a tuple with elements named "k" and "b", each Nullable(Float64).
    DataTypePtr getReturnType() const override
    {
        const auto make_nullable_float64 = []
        {
            return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeFloat64>());
        };

        DataTypes element_types {
            make_nullable_float64(),
            make_nullable_float64(),
        };
        Strings element_names {
            "k",
            "b",
        };

        return std::make_shared<DataTypeTuple>(
            std::move(element_types),
            std::move(element_names));
    }

    /// Reads row `row_num` of the two argument columns and adds the point to the state.
    void add(
        AggregateDataPtr place,
        const IColumn ** columns,
        size_t row_num,
        Arena *) const override
    {
        X x_value = (*columns[0])[row_num].template get<X>();
        Y y_value = (*columns[1])[row_num].template get<Y>();

        this->data(place).add(x_value, y_value);
    }

    /// Folds the state at `rhs` into the state at `place`.
    void merge(
        AggregateDataPtr place,
        ConstAggregateDataPtr rhs,
        Arena *) const override
    {
        const auto & source = this->data(rhs);
        this->data(place).merge(source);
    }

    void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
    {
        this->data(place).serialize(buf);
    }

    void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena *) const override
    {
        this->data(place).deserialize(buf);
    }

    /// Computes the slope first (the intercept depends on it) and inserts both
    /// coefficients into the result column as one tuple value.
    void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
    {
        const auto & state = this->data(place);
        Ret slope = state.getK();
        Ret intercept = state.getB(slope);

        Tuple result;
        auto & elements = result.toUnderType();
        elements.reserve(2);
        elements.emplace_back(slope);
        elements.emplace_back(intercept);

        to.insert(std::move(result));
    }
};
}

View File

@ -29,6 +29,7 @@ void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -69,6 +70,7 @@ void registerAggregateFunctions()
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionLeastSqr(factory);
}
{

View File

@ -0,0 +1,7 @@
(10,90)
(10.3,89.5)
(10,-90)
(1,1)
(nan,nan)
(0,3)
(nan,nan)

View File

@ -0,0 +1,7 @@
-- Functional test for the leastSqr aggregate function: each query fits a line
-- y = k * x + b to the given points and prints the tuple (k, b).

-- Points lying exactly on y = 10 * x + 90.
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]);
-- Same points with the last y perturbed, so the fit is no longer exact.
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]);
-- Negative (signed integer) inputs.
select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]);
-- Floating-point inputs.
select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
-- A single point does not determine a line: the slope formula evaluates to 0 / 0.
select arrayReduce('leastSqr', [0], [0]);
-- Constant y: a horizontal line.
select arrayReduce('leastSqr', [3, 4], [3, 3]);
-- Constant x: a vertical line, for which the slope formula's denominator is zero.
select arrayReduce('leastSqr', [3, 3], [3, 4]);