mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-22 15:42:02 +00:00
dbms: Server: Small style cleanup + numerical stability improvement for covariance and correlation. [#METR-16188]
This commit is contained in:
parent
a29d67c82d
commit
76dcc45edf
@ -12,6 +12,27 @@
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
/// Эта функция возвращает true если оба значения велики и сравнимы.
|
||||
/// Она употребляется для вычисления среднего значения путём слияния двух источников.
|
||||
/// Ибо если размеры обоих источников велики и сравнимы, то надо применить особенную
|
||||
/// формулу гарантирующую больше стабильности.
|
||||
bool areComparable(UInt64 a, UInt64 b)
|
||||
{
|
||||
const Float64 sensitivity = 0.001;
|
||||
const UInt64 threshold = 10000;
|
||||
|
||||
if ((a == 0) || (b == 0))
|
||||
return false;
|
||||
|
||||
auto res = std::minmax(a, b);
|
||||
return (((1 - static_cast<Float64>(res.first) / res.second) < sensitivity) && (res.first > threshold));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/** Статистические аггрегатные функции:
|
||||
* varSamp - выборочная дисперсия
|
||||
* stddevSamp - среднее выборочное квадратичное отклонение
|
||||
@ -52,12 +73,8 @@ public:
|
||||
Float64 factor = static_cast<Float64>(count * source.count) / total_count;
|
||||
Float64 delta = mean - source.mean;
|
||||
|
||||
auto res = std::minmax(count, source.count);
|
||||
if (((1 - static_cast<Float64>(res.first) / res.second) < 0.001) && (res.first > 10000))
|
||||
{
|
||||
/// Эта формула более стабильная, когда размеры обоих источников велики и сравнимы.
|
||||
if (areComparable(count, source.count))
|
||||
mean = (source.count * source.mean + count * mean) / total_count;
|
||||
}
|
||||
else
|
||||
mean = source.mean + delta * (static_cast<Float64>(count) / total_count);
|
||||
|
||||
@ -93,7 +110,9 @@ private:
|
||||
/** Основной код для реализации функций varSamp, stddevSamp, varPop, stddevPop.
|
||||
*/
|
||||
template<typename T, typename Op>
|
||||
class AggregateFunctionVariance final : public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>, AggregateFunctionVariance<T, Op> >
|
||||
class AggregateFunctionVariance final
|
||||
: public IUnaryAggregateFunction<AggregateFunctionVarianceData<T, Op>,
|
||||
AggregateFunctionVariance<T, Op> >
|
||||
{
|
||||
public:
|
||||
String getName() const override { return Op::name; }
|
||||
@ -212,12 +231,12 @@ public:
|
||||
void update(const IColumn & column_left, const IColumn & column_right, size_t row_num)
|
||||
{
|
||||
T left_received = static_cast<const ColumnVector<T> &>(column_left).getData()[row_num];
|
||||
Float64 val_left = static_cast<Float64>(left_received);
|
||||
Float64 left_delta = val_left - left_mean;
|
||||
Float64 left_val = static_cast<Float64>(left_received);
|
||||
Float64 left_delta = left_val - left_mean;
|
||||
|
||||
U right_received = static_cast<const ColumnVector<U> &>(column_right).getData()[row_num];
|
||||
Float64 val_right = static_cast<Float64>(right_received);
|
||||
Float64 right_delta = val_right - right_mean;
|
||||
Float64 right_val = static_cast<Float64>(right_received);
|
||||
Float64 right_delta = right_val - right_mean;
|
||||
|
||||
Float64 old_right_mean = right_mean;
|
||||
|
||||
@ -225,12 +244,12 @@ public:
|
||||
|
||||
left_mean += left_delta / count;
|
||||
right_mean += right_delta / count;
|
||||
co_moment += (val_left - left_mean) * (val_right - old_right_mean);
|
||||
co_moment += (left_val - left_mean) * (right_val - old_right_mean);
|
||||
|
||||
if (compute_marginal_moments)
|
||||
{
|
||||
left_m2 += left_delta * (val_left - left_mean);
|
||||
right_m2 += right_delta * (val_right - right_mean);
|
||||
left_m2 += left_delta * (left_val - left_mean);
|
||||
right_m2 += right_delta * (right_val - right_mean);
|
||||
}
|
||||
}
|
||||
|
||||
@ -244,8 +263,17 @@ public:
|
||||
Float64 left_delta = left_mean - source.left_mean;
|
||||
Float64 right_delta = right_mean - source.right_mean;
|
||||
|
||||
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
|
||||
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
|
||||
if (areComparable(count, source.count))
|
||||
{
|
||||
left_mean = (source.count * source.left_mean + count * left_mean) / total_count;
|
||||
right_mean = (source.count * source.right_mean + count * right_mean) / total_count;
|
||||
}
|
||||
else
|
||||
{
|
||||
left_mean = source.left_mean + left_delta * (static_cast<Float64>(count) / total_count);
|
||||
right_mean = source.right_mean + right_delta * (static_cast<Float64>(count) / total_count);
|
||||
}
|
||||
|
||||
co_moment += source.co_moment + left_delta * right_delta * factor;
|
||||
count = total_count;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user