From dc51482e78e2f56784e4ddc343b9d8027269ce34 Mon Sep 17 00:00:00 2001 From: Grigory Buteyko Date: Wed, 4 Nov 2020 17:14:00 +0300 Subject: [PATCH] Improved numeric stability and stricter invariants in TDigest. Fixes bug when TDigest centroids array will grow beyond reasonable means and trigger exception inTDigest::deserialize during database Merge operations --- src/AggregateFunctions/QuantileTDigest.h | 59 ++++++++++++++++-------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/src/AggregateFunctions/QuantileTDigest.h b/src/AggregateFunctions/QuantileTDigest.h index 02d43ede66d..c09797573c4 100644 --- a/src/AggregateFunctions/QuantileTDigest.h +++ b/src/AggregateFunctions/QuantileTDigest.h @@ -40,6 +40,7 @@ class TDigest { using Value = Float32; using Count = Float32; + using BetterFloat = Float64; // For intermediate results and sum(Count). Must have better precision, than Count /** The centroid stores the weight of points around their mean value */ @@ -55,13 +56,6 @@ class TDigest , count(count_) {} - Centroid & operator+=(const Centroid & other) - { - count += other.count; - mean += other.count * (other.mean - mean) / count; - return *this; - } - bool operator<(const Centroid & other) const { return mean < other.mean; @@ -89,8 +83,8 @@ class TDigest using Centroids = PODArrayWithStackMemory; Centroids centroids; - Count count = 0; - UInt32 unmerged = 0; + BetterFloat count = 0; + size_t unmerged = 0; struct RadixSortTraits { @@ -111,6 +105,7 @@ class TDigest }; /** Adds a centroid `c` to the digest + * centroid must be valid, validity is checked in add(), deserialize() and is maintained by compress() */ void addCentroid(const Centroid & c) { @@ -138,47 +133,63 @@ public: auto l = centroids.begin(); auto r = std::next(l); - Count sum = 0; + const BetterFloat count_epsilon_4 = count * params.epsilon * 4; // Compiler is unable to do this optimization + BetterFloat sum = 0; + BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability + BetterFloat l_count = l->count; while (r != centroids.end()) { + if (l->mean == r->mean) // Perfect aggregation (fast). We compare l->mean, not l_mean, to avoid identical elements after compress + { + l_count += r->count; + l->count = l_count; + ++r; + continue; + } // we use quantile which gives us the smallest error /// The ratio of the part of the histogram to l, including the half l to the entire histogram. That is, what level quantile in position l. - Value ql = (sum + l->count * 0.5) / count; - Value err = ql * (1 - ql); + BetterFloat ql = (sum + l_count * 0.5) / count; + BetterFloat err = ql * (1 - ql); /// The ratio of the portion of the histogram to l, including l and half r to the entire histogram. That is, what level is the quantile in position r. - Value qr = (sum + l->count + r->count * 0.5) / count; - Value err2 = qr * (1 - qr); + BetterFloat qr = (sum + l_count + r->count * 0.5) / count; + BetterFloat err2 = qr * (1 - qr); if (err > err2) err = err2; - Value k = 4 * count * err * params.epsilon; + BetterFloat k = count_epsilon_4 * err; /** The ratio of the weight of the glued column pair to all values is not greater, * than epsilon multiply by a certain quadratic coefficient, which in the median is 1 (4 * 1/2 * 1/2), * and at the edges decreases and is approximately equal to the distance to the edge * 4. */ - if (l->count + r->count <= k) + if (l_count + r->count <= k) { // it is possible to merge left and right /// The left column "eats" the right. - *l += *r; + l_count += r->count; + l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower + l->mean = l_mean; + l->count = l_count; } else { // not enough capacity, check the next pair - sum += l->count; + sum += l->count; // Not l_count, otherwise actual sum of elements will be different ++l; /// We skip all the values "eaten" earlier. if (l != r) *l = *r; + l_mean = l->mean; + l_count = l->count; } ++r; } + count = sum + l_count; // Update count, changed due inaccurancy /// At the end of the loop, all values to the right of l were "eaten". centroids.resize(l - centroids.begin() + 1); @@ -192,6 +203,8 @@ public: */ void add(T x, UInt64 cnt = 1) { + if (cnt == 0) + return; // Count 0 breaks compress() assumptions addCentroid(Centroid(Value(x), Count(cnt))); } @@ -220,8 +233,16 @@ public: buf.read(reinterpret_cast(centroids.data()), size * sizeof(centroids[0])); count = 0; - for (const auto & c : centroids) + for (size_t i = 0; i != centroids.size(); ++i) + { + Centroid & c = centroids[i]; + if (c.count <= 0 || std::isnan(c.count) || std::isnan(c.mean)) // invalid count breaks compress(), invalid mean breaks sort() + { + centroids.resize(i); // Exception safety, without this line we will end up with TDigest with invalid centroids + throw std::runtime_error("Invalid centroid " + std::to_string(c.count) + ":" + std::to_string(c.mean)); + } count += c.count; + } } Count getCount()