Merge pull request #37021 from excitoon-favorites/fixtdigest

Fixed problem with infs in `quantileTDigest`
This commit is contained in:
Kruglov Pavel 2022-05-16 15:21:59 +02:00 committed by GitHub
commit 5e34f48a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 143 additions and 18 deletions

View File

@ -103,8 +103,9 @@ class QuantileTDigest
*/
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
    /// Relative position of x between x1 and x2 (0 at x1, 1 at x2).
    double k = (x - x1) / (x2 - x1);

    /// Symmetric interpolation for better results with infinities.
    /// The asymmetric form `y1 + k * (y2 - y1)` evaluates `y2 - y1`, which is NaN
    /// when y1 and y2 are the same infinity; the weighted sum below never subtracts
    /// the two endpoints, so equal infinite endpoints yield that infinity for 0 < k < 1.
    return (1 - k) * y1 + k * y2;
}
struct RadixSortTraits
@ -137,6 +138,11 @@ class QuantileTDigest
compress();
}
inline bool canBeMerged(const BetterFloat & l_mean, const Value & r_mean)
{
    /// Equal means (including two infinities of the same sign) may always be merged.
    if (l_mean == r_mean)
        return true;

    /// Otherwise merging is allowed only when neither mean is infinite.
    const bool any_infinite = std::isinf(l_mean) || std::isinf(r_mean);
    return !any_infinite;
}
void compressBrute()
{
if (centroids.size() <= params.max_centroids)
@ -149,13 +155,17 @@ class QuantileTDigest
BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
BetterFloat l_count = l->count;
size_t batch_pos = 0;
for (;r != centroids.end(); ++r)
for (; r != centroids.end(); ++r)
{
if (batch_pos < batch_size - 1)
{
/// The left column "eats" the right. Middle of the batch
l_count += r->count;
l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
if (r->mean != l_mean) /// Handling infinities of the same sign well.
{
l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
}
l->mean = l_mean;
l->count = l_count;
batch_pos += 1;
@ -163,8 +173,11 @@ class QuantileTDigest
else
{
// End of the batch, start the next one
sum += l->count; // Not l_count, otherwise actual sum of elements will be different
++l;
if (!std::isnan(l->mean)) /// Skip writing batch result if we compressed something to nan.
{
sum += l->count; // Not l_count, otherwise actual sum of elements will be different
++l;
}
/// We skip all the values "eaten" earlier.
*l = *r;
@ -173,8 +186,17 @@ class QuantileTDigest
batch_pos = 0;
}
}
count = sum + l_count; // Update count, it might be different due to += inaccuracy
centroids.resize(l - centroids.begin() + 1);
if (!std::isnan(l->mean))
{
count = sum + l_count; // Update count, it might be different due to += inaccuracy
centroids.resize(l - centroids.begin() + 1);
}
else /// Skip writing last batch if (super unlikely) it's nan.
{
count = sum;
centroids.resize(l - centroids.begin());
}
// Here centroids.size() <= params.max_centroids
}
@ -200,11 +222,8 @@ public:
BetterFloat l_count = l->count;
while (r != centroids.end())
{
/// N.B. Piece of logic which compresses the same singleton centroids into one centroid is removed
/// because: 1) singleton centroids are being processed in unusual way in recent version of algorithm
/// and such compression would break this logic;
/// 2) we shall not compress centroids further than `max_centroids` parameter requires because
/// this will lead to uneven compression.
/// N.B. We cannot merge all the same values into single centroids because this will lead to
/// unbalanced compression and wrong results.
/// For more information see: https://arxiv.org/abs/1902.04023
/// The fraction of the entire histogram that lies to the left of l, counting half of l itself. That is, the quantile level at position l.
@ -225,12 +244,15 @@ public:
* and at the edges decreases and is approximately equal to the distance to the edge * 4.
*/
if (l_count + r->count <= k)
if (l_count + r->count <= k && canBeMerged(l_mean, r->mean))
{
// it is possible to merge left and right
/// The left column "eats" the right.
l_count += r->count;
l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
if (r->mean != l_mean) /// Handling infinities of the same sign well.
{
l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
}
l->mean = l_mean;
l->count = l_count;
}
@ -254,6 +276,7 @@ public:
centroids.resize(l - centroids.begin() + 1);
unmerged = 0;
}
// Ensures centroids.size() <= max_centroids, independent of unprovable floating point blackbox above
compressBrute();
}
@ -298,10 +321,17 @@ public:
for (const auto & c : centroids)
{
if (c.count <= 0 || std::isnan(c.count) || std::isnan(c.mean)) // invalid count breaks compress(), invalid mean breaks sort()
if (c.count <= 0 || std::isnan(c.count)) // invalid count breaks compress()
throw Exception("Invalid centroid " + std::to_string(c.count) + ":" + std::to_string(c.mean), ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
count += c.count;
if (!std::isnan(c.mean))
{
count += c.count;
}
}
auto it = std::remove_if(centroids.begin(), centroids.end(), [](Centroid & c) { return std::isnan(c.mean); });
centroids.erase(it, centroids.end());
compress(); // Allows reading/writing TDigests with different epsilon/max_centroids params
}
@ -312,7 +342,7 @@ public:
ResultType getImpl(Float64 level)
{
if (centroids.empty())
return std::is_floating_point_v<ResultType> ? NAN : 0;
return std::is_floating_point_v<ResultType> ? std::numeric_limits<ResultType>::quiet_NaN() : 0;
compress();
@ -395,7 +425,6 @@ public:
while (current_x >= x)
{
if (x <= left)
result[levels_permutation[result_num]] = prev_mean;
else if (x >= right)

View File

@ -0,0 +1,42 @@
1
[-inf,-inf,-inf,nan,inf,inf,inf]
[-inf,-inf,-inf,nan,inf,inf,inf]
[0,0,0,inf,inf,inf,inf]
[-inf,-inf,-inf,-inf,0,0,0]
[-inf,-inf,-inf,0,inf,inf,inf]
[-inf,-inf,-inf,0,inf,inf,inf]
2
[-inf]
[-inf]
[inf]
3
[nan]
[inf]
[nan]
[-inf]
4
[nan]
[nan]
[nan]
[nan]
[0]
[0]
[0]
5
6
inf
inf
-inf
-inf
7
-inf
-inf
8
-inf
inf
-inf
-inf
inf
inf
-inf
inf

View File

@ -0,0 +1,54 @@
-- Section 1: quantilesTDigestArray at seven levels over 1,000,000-element arrays.
-- Each array starts from a one-element literal and is extended by two nested arrayResize calls:
-- first to 500000 elements with one filler value, then to 1000000 with another,
-- covering combinations of inf, -inf and 0.
SELECT '1';
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([inf], 500000, -inf), 1000000, inf));
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([inf], 500000, inf), 1000000, -inf));
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([inf], 500000, inf), 1000000, 0));
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([inf], 500000, -inf), 1000000, 0));
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([0], 500000, inf), 1000000, -inf));
SELECT quantilesTDigestArray(0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99)(arrayResize(arrayResize([0], 500000, -inf), 1000000, inf));
-- Section 2: quantilesTDigest at a low, middle and high level over 300 rows.
-- inf*(number%2-0.5) is -inf for even `number` and +inf for odd `number`,
-- so the input alternates between the two infinities.
SELECT '2';
SELECT quantilesTDigest(0.05)(x) FROM (SELECT inf*(number%2-0.5) x FROM numbers(300));
SELECT quantilesTDigest(0.5)(x) FROM (SELECT inf*(number%2-0.5) x FROM numbers(300));
SELECT quantilesTDigest(0.95)(x) FROM (SELECT inf*(number%2-0.5) x FROM numbers(300));
-- Section 3: the plain `quantiles` function (not the TDigest variant) applied to a
-- constant inf or -inf, once over 5 rows and once over 300 rows.
SELECT '3';
SELECT quantiles(0.5)(inf) FROM numbers(5);
SELECT quantiles(0.5)(inf) FROM numbers(300);
SELECT quantiles(0.5)(-inf) FROM numbers(5);
SELECT quantiles(0.5)(-inf) FROM numbers(300);
-- Section 4: the plain `quantiles` median over small inputs that mix inf, -inf and zero.
-- The first four queries permute a 3-element set; the last three use 6-element sets
-- that also contain -0, in different orders.
SELECT '4';
SELECT quantiles(0.5)(arrayJoin([inf, 0, -inf]));
SELECT quantiles(0.5)(arrayJoin([-inf, 0, inf]));
SELECT quantiles(0.5)(arrayJoin([inf, -inf, 0]));
SELECT quantiles(0.5)(arrayJoin([-inf, inf, 0]));
SELECT quantiles(0.5)(arrayJoin([inf, inf, 0, -inf, -inf, -0]));
SELECT quantiles(0.5)(arrayJoin([inf, -inf, 0, -inf, inf, -0]));
SELECT quantiles(0.5)(arrayJoin([-inf, -inf, 0, inf, inf, -0]));
-- Section 5: stores a quantilesTDigest aggregate state that includes inf in an
-- AggregatingMergeTree table and then runs OPTIMIZE ... FINAL on it.
-- The section selects nothing after the marker, so it only checks that the statements succeed.
-- NOTE(review): presumably a regression test for issue 32107 (inferred from the table name) - confirm.
SELECT '5';
DROP TABLE IF EXISTS issue32107;
CREATE TABLE issue32107(A Int64, s_quantiles AggregateFunction(quantilesTDigest(0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99), Float64)) ENGINE = AggregatingMergeTree ORDER BY A;
-- Each of the 100 source rows contributes three values: 2.0, inf and number / 33333.
INSERT INTO issue32107 SELECT A, quantilesTDigestState(0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)(x) FROM (SELECT 1 A, arrayJoin(cast([2.0, inf, number / 33333],'Array(Float64)')) x FROM numbers(100)) GROUP BY A;
OPTIMIZE TABLE issue32107 FINAL;
DROP TABLE IF EXISTS issue32107;
-- Section 6: quantileTDigest (default level) of a constant inf or -inf,
-- once over 200 rows and once over 500 rows.
SELECT '6';
SELECT quantileTDigest(inf) FROM numbers(200);
SELECT quantileTDigest(inf) FROM numbers(500);
SELECT quantileTDigest(-inf) FROM numbers(200);
SELECT quantileTDigest(-inf) FROM numbers(500);
-- Section 7: quantileTDigest over a two-row input holding one inf and one -inf,
-- with the UNION ALL branches written in both orders.
SELECT '7';
SELECT quantileTDigest(x) FROM (SELECT inf AS x UNION ALL SELECT -inf);
SELECT quantileTDigest(x) FROM (SELECT -inf AS x UNION ALL SELECT inf);
-- Section 8: quantileTDigest over three-row inputs, covering all eight
-- combinations of inf / -inf across the three UNION ALL branches.
SELECT '8';
SELECT quantileTDigest(x) FROM (SELECT inf AS x UNION ALL SELECT -inf UNION ALL SELECT -inf);
SELECT quantileTDigest(x) FROM (SELECT inf AS x UNION ALL SELECT inf UNION ALL SELECT -inf);
SELECT quantileTDigest(x) FROM (SELECT -inf AS x UNION ALL SELECT -inf UNION ALL SELECT -inf);
SELECT quantileTDigest(x) FROM (SELECT -inf AS x UNION ALL SELECT inf UNION ALL SELECT -inf);
SELECT quantileTDigest(x) FROM (SELECT inf AS x UNION ALL SELECT -inf UNION ALL SELECT inf);
SELECT quantileTDigest(x) FROM (SELECT inf AS x UNION ALL SELECT inf UNION ALL SELECT inf);
SELECT quantileTDigest(x) FROM (SELECT -inf AS x UNION ALL SELECT -inf UNION ALL SELECT inf);
SELECT quantileTDigest(x) FROM (SELECT -inf AS x UNION ALL SELECT inf UNION ALL SELECT inf);