Improved numeric stability and stricter invariants in TDigest. Fixes bug when TDigest centroids array will grow beyond reasonable means and trigger exception inTDigest::deserialize during database Merge operations

This commit is contained in:
Grigory Buteyko 2020-11-04 17:14:00 +03:00
parent 2fae1c3c31
commit dc51482e78

View File

@ -40,6 +40,7 @@ class TDigest
{
using Value = Float32;
using Count = Float32;
using BetterFloat = Float64; // For intermediate results and sum(Count). Must have better precision, than Count
/** The centroid stores the weight of points around their mean value
*/
@ -55,13 +56,6 @@ class TDigest
, count(count_)
{}
Centroid & operator+=(const Centroid & other)
{
count += other.count;
mean += other.count * (other.mean - mean) / count;
return *this;
}
bool operator<(const Centroid & other) const
{
return mean < other.mean;
@ -89,8 +83,8 @@ class TDigest
using Centroids = PODArrayWithStackMemory<Centroid, bytes_in_arena>;
Centroids centroids;
Count count = 0;
UInt32 unmerged = 0;
BetterFloat count = 0;
size_t unmerged = 0;
struct RadixSortTraits
{
@ -111,6 +105,7 @@ class TDigest
};
/** Adds a centroid `c` to the digest
* centroid must be valid, validity is checked in add(), deserialize() and is maintained by compress()
*/
void addCentroid(const Centroid & c)
{
@ -138,47 +133,63 @@ public:
auto l = centroids.begin();
auto r = std::next(l);
Count sum = 0;
const BetterFloat count_epsilon_4 = count * params.epsilon * 4; // Compiler is unable to do this optimization
BetterFloat sum = 0;
BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
BetterFloat l_count = l->count;
while (r != centroids.end())
{
if (l->mean == r->mean) // Perfect aggregation (fast). We compare l->mean, not l_mean, to avoid identical elements after compress
{
l_count += r->count;
l->count = l_count;
++r;
continue;
}
// we use quantile which gives us the smallest error
/// The ratio of the part of the histogram to l, including the half l to the entire histogram. That is, what level quantile in position l.
Value ql = (sum + l->count * 0.5) / count;
Value err = ql * (1 - ql);
BetterFloat ql = (sum + l_count * 0.5) / count;
BetterFloat err = ql * (1 - ql);
/// The ratio of the portion of the histogram to l, including l and half r to the entire histogram. That is, what level is the quantile in position r.
Value qr = (sum + l->count + r->count * 0.5) / count;
Value err2 = qr * (1 - qr);
BetterFloat qr = (sum + l_count + r->count * 0.5) / count;
BetterFloat err2 = qr * (1 - qr);
if (err > err2)
err = err2;
Value k = 4 * count * err * params.epsilon;
BetterFloat k = count_epsilon_4 * err;
/** The ratio of the weight of the glued column pair to all values is not greater,
* than epsilon multiply by a certain quadratic coefficient, which in the median is 1 (4 * 1/2 * 1/2),
* and at the edges decreases and is approximately equal to the distance to the edge * 4.
*/
if (l->count + r->count <= k)
if (l_count + r->count <= k)
{
// it is possible to merge left and right
/// The left column "eats" the right.
*l += *r;
l_count += r->count;
l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
l->mean = l_mean;
l->count = l_count;
}
else
{
// not enough capacity, check the next pair
sum += l->count;
sum += l->count; // Not l_count, otherwise actual sum of elements will be different
++l;
/// We skip all the values "eaten" earlier.
if (l != r)
*l = *r;
l_mean = l->mean;
l_count = l->count;
}
++r;
}
count = sum + l_count; // Update count, changed due inaccurancy
/// At the end of the loop, all values to the right of l were "eaten".
centroids.resize(l - centroids.begin() + 1);
@ -192,6 +203,8 @@ public:
*/
void add(T x, UInt64 cnt = 1)
{
if (cnt == 0)
return; // Count 0 breaks compress() assumptions
addCentroid(Centroid(Value(x), Count(cnt)));
}
@ -220,8 +233,16 @@ public:
buf.read(reinterpret_cast<char *>(centroids.data()), size * sizeof(centroids[0]));
count = 0;
for (const auto & c : centroids)
for (size_t i = 0; i != centroids.size(); ++i)
{
Centroid & c = centroids[i];
if (c.count <= 0 || std::isnan(c.count) || std::isnan(c.mean)) // invalid count breaks compress(), invalid mean breaks sort()
{
centroids.resize(i); // Exception safety, without this line we will end up with TDigest with invalid centroids
throw std::runtime_error("Invalid centroid " + std::to_string(c.count) + ":" + std::to_string(c.mean));
}
count += c.count;
}
}
Count getCount()