dbms: Server: Small style cleanup + numerical stability improvement for covariance and correlation. [#METR-16188]

This commit is contained in:
Alexey Arno 2015-05-20 12:17:25 +03:00
parent a29d67c82d
commit 76dcc45edf

View File

@ -12,6 +12,27 @@
namespace DB
{
namespace
{
/// Tells whether two counts are both large and close to each other in magnitude.
/// It is used when a mean is computed by merging two sources: if the sizes of
/// both sources are large and comparable, a dedicated formula that provides
/// better numerical stability should be applied instead of the usual one.
/// Returns false if either count is zero.
bool areComparable(UInt64 a, UInt64 b)
{
    /// Largest allowed relative gap between the two counts.
    const Float64 sensitivity = 0.001;
    /// The smaller count must strictly exceed this value to be deemed "large".
    const UInt64 threshold = 10000;

    /// An empty side is never comparable; this also rules out division by zero below.
    if ((a == 0) || (b == 0))
        return false;

    const bool b_is_smaller = b < a;
    const UInt64 smaller = b_is_smaller ? b : a;
    const UInt64 larger = b_is_smaller ? a : b;

    if (!(smaller > threshold))
        return false;

    /// Relative gap: 0 when the counts are equal, approaching 1 as they diverge.
    return (1 - static_cast<Float64>(smaller) / larger) < sensitivity;
}
}
/** Статистические аггрегатные функции:
* varSamp - выборочная дисперсия
* stddevSamp - среднее выборочное квадратичное отклонение
@ -52,12 +73,8 @@ public:
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
Float64 delta = mean - source.mean;
auto res = std::minmax(count, source.count);
if (((1 - static_cast<Float64>(res.first) / res.second) < 0.001) && (res.first > 10000))
{
/// Эта формула более стабильная, когда размеры обоих источников велики и сравнимы.
if (areComparable(count, source.count))
mean = (source.count * source.mean + count * mean) / total_count;
}
else
mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
@ -93,7 +110,9 @@ private:
/** Основной код для реализации функций varSamp, stddevSamp, varPop, stddevPop.
*/
template<typename T, typename Op>
class AggregateFunctionVariance final : public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op> >
class AggregateFunctionVariance final
: public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>,
AggregateFunctionVariance<T, Op> >
{
public:
String getName() const override { return Op::name; }
@ -212,12 +231,12 @@ public:
void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
{
T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
Float64 val_left = static_cast<Float64>(left_received);
Float64 left_delta = val_left - left_mean;
Float64 left_val = static_cast<Float64>(left_received);
Float64 left_delta = left_val - left_mean;
U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
Float64 val_right = static_cast<Float64>(right_received);
Float64 right_delta = val_right - right_mean;
Float64 right_val = static_cast<Float64>(right_received);
Float64 right_delta = right_val - right_mean;
Float64 old_right_mean = right_mean;
@ -225,12 +244,12 @@ public:
left_mean += left_delta / count;
right_mean += right_delta / count;
co_moment += (val_left - left_mean) * (val_right - old_right_mean);
co_moment += (left_val - left_mean) * (right_val - old_right_mean);
if (compute_marginal_moments)
{
left_m2 += left_delta * (val_left - left_mean);
right_m2 += right_delta * (val_right - right_mean);
left_m2 += left_delta * (left_val - left_mean);
right_m2 += right_delta * (right_val - right_mean);
}
}
@ -244,8 +263,17 @@ public:
Float64 left_delta = left_mean - source.left_mean;
Float64 right_delta = right_mean - source.right_mean;
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
if (areComparable(count, source.count))
{
left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
}
else
{
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
}
co_moment += source.co_moment + left_delta * right_delta * factor;
count = total_count;