From 541cd638ba7b4bbbfbd3d01aea7c5a7db0f3044a Mon Sep 17 00:00:00 2001 From: hcz Date: Fri, 24 Apr 2020 16:33:09 +0800 Subject: [PATCH] Fix overflow in simpleLinearRegression --- .../AggregateFunctionSimpleLinearRegression.h | 26 +++++++++---------- .../0_stateless/00917_least_sqr.reference | 1 + tests/queries/0_stateless/00917_least_sqr.sql | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionSimpleLinearRegression.h b/src/AggregateFunctions/AggregateFunctionSimpleLinearRegression.h index c9dd22a0649..db4e57c0c6c 100644 --- a/src/AggregateFunctions/AggregateFunctionSimpleLinearRegression.h +++ b/src/AggregateFunctions/AggregateFunctionSimpleLinearRegression.h @@ -18,16 +18,16 @@ namespace ErrorCodes { } -template +template struct AggregateFunctionSimpleLinearRegressionData final { size_t count = 0; - Ret sum_x = 0; - Ret sum_y = 0; - Ret sum_xx = 0; - Ret sum_xy = 0; + T sum_x = 0; + T sum_y = 0; + T sum_xx = 0; + T sum_xy = 0; - void add(X x, Y y) + void add(T x, T y) { count += 1; sum_x += x; @@ -63,20 +63,20 @@ struct AggregateFunctionSimpleLinearRegressionData final readBinary(sum_xy, buf); } - Ret getK() const + T getK() const { - Ret divisor = sum_xx * count - sum_x * sum_x; + T divisor = sum_xx * count - sum_x * sum_x; if (divisor == 0) - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); return (sum_xy * count - sum_x * sum_y) / divisor; } - Ret getB(Ret k) const + T getB(T k) const { if (count == 0) - return std::numeric_limits::quiet_NaN(); + return std::numeric_limits::quiet_NaN(); return (sum_y - k * sum_x) / count; } @@ -86,7 +86,7 @@ struct AggregateFunctionSimpleLinearRegressionData final /// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation. template class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper< - AggregateFunctionSimpleLinearRegressionData, + AggregateFunctionSimpleLinearRegressionData, AggregateFunctionSimpleLinearRegression > { @@ -96,7 +96,7 @@ public: const Array & params ): IAggregateFunctionDataHelper< - AggregateFunctionSimpleLinearRegressionData, + AggregateFunctionSimpleLinearRegressionData, AggregateFunctionSimpleLinearRegression > {arguments, params} { diff --git a/tests/queries/0_stateless/00917_least_sqr.reference b/tests/queries/0_stateless/00917_least_sqr.reference index 8abd62892db..ee1e3115b0e 100644 --- a/tests/queries/0_stateless/00917_least_sqr.reference +++ b/tests/queries/0_stateless/00917_least_sqr.reference @@ -6,3 +6,4 @@ (0,3) (nan,nan) (nan,nan) +(100000000,900000000) diff --git a/tests/queries/0_stateless/00917_least_sqr.sql b/tests/queries/0_stateless/00917_least_sqr.sql index b6171c216d2..5fa01830231 100644 --- a/tests/queries/0_stateless/00917_least_sqr.sql +++ b/tests/queries/0_stateless/00917_least_sqr.sql @@ -6,4 +6,4 @@ select arrayReduce('simpleLinearRegression', [0], [0]); select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]); select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]); select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8()); - +select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [1000000000, 1100000000, 1200000000, 1300000000]);