Fix overflow in simpleLinearRegression

This commit is contained in:
hcz 2020-04-24 16:33:09 +08:00
parent 4ecc86beca
commit 541cd638ba
3 changed files with 15 additions and 14 deletions

View File

@ -18,16 +18,16 @@ namespace ErrorCodes
{ {
} }
template <typename X, typename Y, typename Ret> template <typename T>
struct AggregateFunctionSimpleLinearRegressionData final struct AggregateFunctionSimpleLinearRegressionData final
{ {
size_t count = 0; size_t count = 0;
Ret sum_x = 0; T sum_x = 0;
Ret sum_y = 0; T sum_y = 0;
Ret sum_xx = 0; T sum_xx = 0;
Ret sum_xy = 0; T sum_xy = 0;
void add(X x, Y y) void add(T x, T y)
{ {
count += 1; count += 1;
sum_x += x; sum_x += x;
@ -63,20 +63,20 @@ struct AggregateFunctionSimpleLinearRegressionData final
readBinary(sum_xy, buf); readBinary(sum_xy, buf);
} }
Ret getK() const T getK() const
{ {
Ret divisor = sum_xx * count - sum_x * sum_x; T divisor = sum_xx * count - sum_x * sum_x;
if (divisor == 0) if (divisor == 0)
return std::numeric_limits<Ret>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();
return (sum_xy * count - sum_x * sum_y) / divisor; return (sum_xy * count - sum_x * sum_y) / divisor;
} }
Ret getB(Ret k) const T getB(T k) const
{ {
if (count == 0) if (count == 0)
return std::numeric_limits<Ret>::quiet_NaN(); return std::numeric_limits<T>::quiet_NaN();
return (sum_y - k * sum_x) / count; return (sum_y - k * sum_x) / count;
} }
@ -86,7 +86,7 @@ struct AggregateFunctionSimpleLinearRegressionData final
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation. /// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
template <typename X, typename Y, typename Ret = Float64> template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper< class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>, AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret> AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> >
{ {
@ -96,7 +96,7 @@ public:
const Array & params const Array & params
): ):
IAggregateFunctionDataHelper< IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>, AggregateFunctionSimpleLinearRegressionData<Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret> AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params} > {arguments, params}
{ {

View File

@ -6,3 +6,4 @@
(0,3) (0,3)
(nan,nan) (nan,nan)
(nan,nan) (nan,nan)
(100000000,900000000)

View File

@ -6,4 +6,4 @@ select arrayReduce('simpleLinearRegression', [0], [0]);
select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]); select arrayReduce('simpleLinearRegression', [3, 4], [3, 3]);
select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]); select arrayReduce('simpleLinearRegression', [3, 3], [3, 4]);
select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8()); select arrayReduce('simpleLinearRegression', emptyArrayUInt8(), emptyArrayUInt8());
select arrayReduce('simpleLinearRegression', [1, 2, 3, 4], [1000000000, 1100000000, 1200000000, 1300000000]);