Fix overflow in simpleLinearRegression

This commit is contained in:
hcz 2020-04-24 16:33:09 +08:00
parent 4ecc86beca
commit 541cd638ba
3 changed files with 15 additions and 14 deletions

View File

@@ -18,16 +18,16 @@ namespace ErrorCodes
{
}
template <typename X, typename Y, typename Ret>
template <typename T>
struct AggregateFunctionSimpleLinearRegressionData final
{
size_t count = 0;
Ret sum_x = 0;
Ret sum_y = 0;
Ret sum_xx = 0;
Ret sum_xy = 0;
T sum_x = 0;
T sum_y = 0;
T sum_xx = 0;
T sum_xy = 0;
void add(X x, Y y)
void add(T x, T y)
{
count += 1;
sum_x += x;
@@ -63,20 +63,20 @@ struct AggregateFunctionSimpleLinearRegressionData final
readBinary(sum_xy, buf);
}
Ret getK() const
T getK() const
{
Ret divisor = sum_xx * count - sum_x * sum_x;
T divisor = sum_xx * count - sum_x * sum_x;
if (divisor == 0)
return std::numeric_limits<Ret>::quiet_NaN();
return std::numeric_limits<T>::quiet_NaN();
return (sum_xy * count - sum_x * sum_y) / divisor;
}
Ret getB(Ret k) const
T getB(T k) const
{
if (count == 0)
return std::numeric_limits<Ret>::quiet_NaN();
return std::numeric_limits<T>::quiet_NaN();
return (sum_y - k * sum_x) / count;
}
@@ -86,7 +86,7 @@ struct AggregateFunctionSimpleLinearRegressionData final
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
>
{
@@ -96,7 +96,7 @@ public:
const Array & params
):
IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params}
{

View File

@@ -6,3 +6,4 @@
(0,3)
(nan,nan)
(nan,nan)
(100000000,900000000)

View File

@@ -6,4 +6,4 @@ select arrayReduce('simpleLinearRegression', [0], [0]);
select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]);
select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]);
select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8());
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [1000000000, 1100000000, 1200000000, 1300000000]);