Ensure that the variance is nonnegative

This commit is contained in:
Alexey Milovidov 2020-05-12 03:31:44 +03:00
parent d0b56a4c7d
commit a9f64b4c1c
3 changed files with 24 additions and 4 deletions

View File

@ -82,14 +82,17 @@ struct VarMoments
T NO_SANITIZE_UNDEFINED getPopulation() const
{
return (m[2] - m[1] * m[1] / m[0]) / m[0];
/// Due to numerical errors, the result can be slightly less than zero,
/// but it should be impossible. Trim to zero.
return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / m[0]);
}
T NO_SANITIZE_UNDEFINED getSample() const
{
if (m[0] == 0)
return std::numeric_limits<T>::quiet_NaN();
return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1);
return std::max(T{}, (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1));
}
T NO_SANITIZE_UNDEFINED getMoment3() const
@ -180,7 +183,7 @@ struct VarMomentsDecimal
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
return std::max(Float64{}, convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale));
}
Float64 getSample(UInt32 scale) const
@ -194,7 +197,7 @@ struct VarMomentsDecimal
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
return std::max(Float64{}, convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale));
}
Float64 getMoment3(UInt32 scale) const

View File

@ -0,0 +1,8 @@
0
0
0
0
0
0
0
0

View File

@ -0,0 +1,9 @@
SELECT varSamp(0.1) FROM numbers(1000000);
SELECT varPop(0.1) FROM numbers(1000000);
SELECT stddevSamp(0.1) FROM numbers(1000000);
SELECT stddevPop(0.1) FROM numbers(1000000);
SELECT varSampStable(0.1) FROM numbers(1000000);
SELECT varPopStable(0.1) FROM numbers(1000000);
SELECT stddevSampStable(0.1) FROM numbers(1000000);
SELECT stddevPopStable(0.1) FROM numbers(1000000);