mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
Add aggregate function leastSqr
This commit is contained in:
parent
faa94c09a9
commit
72bcbc76b1
112
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp
Normal file
112
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.cpp
Normal file
@ -0,0 +1,112 @@
|
||||
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
)
|
||||
{
|
||||
assertNoParameters(name, params);
|
||||
assertBinary(name, arguments);
|
||||
|
||||
const IDataType * x_arg = arguments.front().get();
|
||||
|
||||
WhichDataType which_x {
|
||||
x_arg
|
||||
};
|
||||
|
||||
if (
|
||||
!which_x.isNativeUInt()
|
||||
&& !which_x.isNativeInt()
|
||||
&& !which_x.isFloat()
|
||||
)
|
||||
throw Exception {
|
||||
"Illegal type " + x_arg->getName()
|
||||
+ " of first argument of aggregate function "
|
||||
+ name + ", must be Native Int, Native UInt or Float",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
|
||||
};
|
||||
|
||||
const IDataType * y_arg = arguments.back().get();
|
||||
|
||||
WhichDataType which_y {
|
||||
y_arg
|
||||
};
|
||||
|
||||
if (
|
||||
!which_y.isNativeUInt()
|
||||
&& !which_y.isNativeInt()
|
||||
&& !which_y.isFloat()
|
||||
)
|
||||
throw Exception {
|
||||
"Illegal type " + y_arg->getName()
|
||||
+ " of second argument of aggregate function "
|
||||
+ name + ", must be Native Int, Native UInt or Float",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT
|
||||
};
|
||||
|
||||
if (which_x.isNativeUInt() && which_y.isNativeUInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<UInt64, UInt64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isNativeUInt() && which_y.isNativeInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<UInt64, Int64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isNativeUInt() && which_y.isFloat())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<UInt64, Float64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isNativeInt() && which_y.isNativeUInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Int64, UInt64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isNativeInt() && which_y.isNativeInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Int64, Int64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isNativeInt() && which_y.isFloat())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Int64, Float64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isFloat() && which_y.isNativeUInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Float64, UInt64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else if (which_x.isFloat() && which_y.isNativeInt())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Float64, Int64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
else // if (which_x.isFloat() && which_y.isFloat())
|
||||
return std::make_shared<AggregateFunctionLeastSqr<Float64, Float64>>(
|
||||
arguments,
|
||||
params
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Registers the `leastSqr` aggregate function in the given factory.
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
{
    const char * const function_name = "leastSqr";
    factory.registerFunction(function_name, createAggregateFunctionLeastSqr);
}
|
||||
|
||||
}
|
176
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h
Normal file
176
dbms/src/AggregateFunctions/AggregateFunctionLeastSqr.h
Normal file
@ -0,0 +1,176 @@
|
||||
#pragma once
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
#include <Columns/IColumn.h>
|
||||
#include <DataTypes/DataTypeNullable.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Ret>
|
||||
struct AggregateFunctionLeastSqrData
|
||||
{
|
||||
size_t count = 0;
|
||||
Ret sum_x = 0;
|
||||
Ret sum_y = 0;
|
||||
Ret sum_xx = 0;
|
||||
Ret sum_xy = 0;
|
||||
|
||||
void add(X x, Y y)
|
||||
{
|
||||
count += 1;
|
||||
sum_x += x;
|
||||
sum_y += y;
|
||||
sum_xx += x * x;
|
||||
sum_xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionLeastSqrData & other)
|
||||
{
|
||||
count += other.count;
|
||||
sum_x += other.sum_x;
|
||||
sum_y += other.sum_y;
|
||||
sum_xx += other.sum_xx;
|
||||
sum_xy += other.sum_xy;
|
||||
}
|
||||
|
||||
void serialize(WriteBuffer & buf) const
|
||||
{
|
||||
writeBinary(count, buf);
|
||||
writeBinary(sum_x, buf);
|
||||
writeBinary(sum_y, buf);
|
||||
writeBinary(sum_xx, buf);
|
||||
writeBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
void deserialize(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(count, buf);
|
||||
readBinary(sum_x, buf);
|
||||
readBinary(sum_y, buf);
|
||||
readBinary(sum_xx, buf);
|
||||
readBinary(sum_xy, buf);
|
||||
}
|
||||
|
||||
Ret getK() const
|
||||
{
|
||||
return (sum_xy * count - sum_x * sum_y)
|
||||
/ (sum_xx * count - sum_x * sum_x);
|
||||
}
|
||||
|
||||
Ret getB(Ret k) const
|
||||
{
|
||||
return (sum_y - k * sum_x) / count;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename X, typename Y, typename Ret = Float64>
|
||||
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionLeastSqr(
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
> {arguments, params}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
}
|
||||
|
||||
void add(
|
||||
AggregateDataPtr place,
|
||||
const IColumn ** columns,
|
||||
size_t row_num,
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
X x = (*columns[0])[row_num].template get<X>();
|
||||
Y y = (*columns[1])[row_num].template get<Y>();
|
||||
|
||||
this->data(place).add(x, y);
|
||||
}
|
||||
|
||||
void merge(
|
||||
AggregateDataPtr place,
|
||||
ConstAggregateDataPtr rhs, Arena *
|
||||
) const override
|
||||
{
|
||||
this->data(place).merge(this->data(rhs));
|
||||
}
|
||||
|
||||
void serialize(
|
||||
ConstAggregateDataPtr place,
|
||||
WriteBuffer & buf
|
||||
) const override
|
||||
{
|
||||
this->data(place).serialize(buf);
|
||||
}
|
||||
|
||||
void deserialize(
|
||||
AggregateDataPtr place,
|
||||
ReadBuffer & buf, Arena *
|
||||
) const override
|
||||
{
|
||||
this->data(place).deserialize(buf);
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types {
|
||||
std::make_shared<DataTypeNullable>(
|
||||
std::make_shared<DataTypeFloat64>()
|
||||
),
|
||||
std::make_shared<DataTypeNullable>(
|
||||
std::make_shared<DataTypeFloat64>()
|
||||
),
|
||||
};
|
||||
|
||||
Strings names {
|
||||
"k",
|
||||
"b",
|
||||
};
|
||||
|
||||
return std::make_shared<DataTypeTuple>(
|
||||
std::move(types),
|
||||
std::move(names)
|
||||
);
|
||||
}
|
||||
|
||||
void insertResultInto(
|
||||
ConstAggregateDataPtr place,
|
||||
IColumn & to
|
||||
) const override
|
||||
{
|
||||
Ret k = this->data(place).getK();
|
||||
Ret b = this->data(place).getB(k);
|
||||
|
||||
Tuple result;
|
||||
result.toUnderType().reserve(2);
|
||||
|
||||
result.toUnderType().emplace_back(k);
|
||||
result.toUnderType().emplace_back(b);
|
||||
|
||||
to.insert(std::move(result));
|
||||
}
|
||||
|
||||
String getName() const override { return "leastSqr"; }
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
};
|
||||
|
||||
}
|
@ -29,6 +29,7 @@ void registerAggregateFunctionsBitwise(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -69,6 +70,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionHistogram(factory);
|
||||
registerAggregateFunctionRetention(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionLeastSqr(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
7
dbms/tests/queries/0_stateless/00917_least_sqr.reference
Normal file
7
dbms/tests/queries/0_stateless/00917_least_sqr.reference
Normal file
@ -0,0 +1,7 @@
|
||||
(10,90)
|
||||
(10.3,89.5)
|
||||
(10,-90)
|
||||
(1,1)
|
||||
(nan,nan)
|
||||
(0,3)
|
||||
(nan,nan)
|
7
dbms/tests/queries/0_stateless/00917_least_sqr.sql
Normal file
7
dbms/tests/queries/0_stateless/00917_least_sqr.sql
Normal file
@ -0,0 +1,7 @@
|
||||
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 130]);
|
||||
select arrayReduce('leastSqr', [1, 2, 3, 4], [100, 110, 120, 131]);
|
||||
select arrayReduce('leastSqr', [-1, -2, -3, -4], [-100, -110, -120, -130]);
|
||||
select arrayReduce('leastSqr', [5, 5.1], [6, 6.1]);
|
||||
select arrayReduce('leastSqr', [0], [0]);
|
||||
select arrayReduce('leastSqr', [3, 4], [3, 3]);
|
||||
select arrayReduce('leastSqr', [3, 3], [3, 4]);
|
Loading…
Reference in New Issue
Block a user