#pragma once

#include <cmath>

#include <Common/RadixSort.h>
#include <Common/PODArray.h>

#include <IO/WriteBuffer.h>
#include <IO/ReadBuffer.h>
#include <IO/VarInt.h>

namespace DB
{

namespace ErrorCodes
{
    extern const int TOO_LARGE_ARRAY_SIZE;
    extern const int CANNOT_PARSE_INPUT_ASSERTION_FAILED;
}

/** The algorithm was implemented by Alexei Borzenkov https://github.com/snaury
  * He owns the authorship of the code and half the comments in this namespace,
  * except for merging, serialization and sorting, as well as selecting types and other changes.
  * We thank Alexei Borzenkov for writing the original code.
  */

/** Implementation of the t-digest algorithm (https://github.com/tdunning/t-digest).
  * This option is very similar to MergingDigest on java; however, the decision about
  * the union is made based on the original condition from the article
  * (via a size constraint, using the approximation of the quantile of each
  * centroid, not the distance on the curve of the position of their boundaries). MergingDigest
  * on java gives significantly fewer centroids than this variant, which
  * negatively affects accuracy with the same compression factor, but gives
  * size guarantees. The author himself, on the proposal for this variant, said that
  * the size of the digest grows like O(log(n)), while the version on java
  * does not depend on the expected number of points. Also, the variant on java
  * uses asin, which slows down the algorithm a bit.
  */
template <typename T>
class TDigest
{
    using Value = Float32;
    using Count = Float32;
    using BetterFloat = Float64; // For intermediate results and sum(Count). Must have better precision than Count.

    /** The centroid stores the weight of points around their mean value.
      */
    struct Centroid
    {
        Value mean;
        Count count;

        Centroid() = default;

        explicit Centroid(Value mean_, Count count_)
            : mean(mean_)
            , count(count_)
        {}

        /// Ordering by mean only; used by the radix sort in compress().
        bool operator<(const Centroid & other) const
        {
            return mean < other.mean;
        }
    };

    /** :param epsilon: value \delta from the article - error in the range
      *                 quantile 0.5 (default is 0.01, i.e. 1%).
      *                 If you change epsilon, you must also change max_centroids.
      * :param max_centroids: depends on epsilon; the better the accuracy, the more centroids you need
      *                       to describe data with this accuracy. Read the article before changing.
      * :param max_unmerged: when the accumulated count of new points goes beyond this
      *                      value, centroid compression is triggered
      *                      (default is 2048; the higher the value, the more memory
      *                      is required, but amortization of execution time increases).
      *                      Change freely anytime.
      */
    struct Params
    {
        Value epsilon = 0.01;
        size_t max_centroids = 2048;
        size_t max_unmerged = 2048;
    };

    /** max_centroids_deserialize should be >= all max_centroids ever used in production.
      * This is a security parameter, preventing allocation of too many centroids in deserialize,
      * so it can be relatively large.
      */
    static constexpr size_t max_centroids_deserialize = 65536;

    static constexpr Params params{};

    /// Stack arena sized so that sizeof(TDigest) stays near 128 bytes.
    static constexpr size_t bytes_in_arena = 128 - sizeof(PODArray<Centroid>) - sizeof(BetterFloat) - sizeof(size_t); // If alignment is imperfect, sizeof(TDigest) will be more than naively expected

    using Centroids = PODArrayWithStackMemory<Centroid, bytes_in_arena>;

    Centroids centroids;
    BetterFloat count = 0;   /// Total weight of all points (sum of centroid counts), kept in high precision.
    size_t unmerged = 0;     /// Number of centroids appended since the last compress().

    /// Traits to radix-sort the centroids array in place by (Float32) mean.
    struct RadixSortTraits
    {
        using Element = Centroid;
        using Result = Element;
        using Key = Value;
        using CountType = UInt32;
        using KeyBits = UInt32;

        static constexpr size_t PART_SIZE_BITS = 8;

        using Transform = RadixSortFloatTransform<KeyBits>;
        using Allocator = RadixSortMallocAllocator;

        /// The function to get the key from an array element.
        static Key & extractKey(Element & elem) { return elem.mean; }
        static Result & extractResult(Element & elem) { return elem; }
    };

    /** Adds a centroid `c` to the digest.
      * The centroid must be valid; validity is checked in add(), deserialize()
      * and is maintained by compress().
      */
    void addCentroid(const Centroid & c)
    {
        centroids.push_back(c);
        count += c.count;
        ++unmerged;
        if (unmerged > params.max_unmerged)
            compress();
    }

    /** Merges neighboring centroids in fixed-size batches until at most max_centroids remain.
      * Unlike the error-bound merge in compress(), this ignores accuracy and only
      * guarantees the size bound. Expects centroids to be sorted by mean (as after compress()).
      */
    void compressBrute()
    {
        if (centroids.size() <= params.max_centroids)
            return;
        const size_t batch_size = (centroids.size() + params.max_centroids - 1) / params.max_centroids; // at least 2

        auto l = centroids.begin();
        auto r = std::next(l);
        BetterFloat sum = 0;
        BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
        BetterFloat l_count = l->count;
        size_t batch_pos = 0;
        for (; r != centroids.end(); ++r)
        {
            if (batch_pos < batch_size - 1)
            {
                /// The left column "eats" the right. Middle of the batch.
                l_count += r->count;
                l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
                l->mean = l_mean;
                l->count = l_count;
                batch_pos += 1;
            }
            else
            {
                // End of the batch, start the next one.
                sum += l->count; // Not l_count, otherwise actual sum of elements will be different
                ++l;
                /// We skip all the values "eaten" earlier.
                *l = *r;
                l_mean = l->mean;
                l_count = l->count;
                batch_pos = 0;
            }
        }
        count = sum + l_count; // Update count, it might be different due to += inaccuracy
        centroids.resize(l - centroids.begin() + 1);
        // Here centroids.size() <= params.max_centroids
    }

public:
    /** Performs compression of accumulated centroids.
      * When merging, the invariant is retained that the maximum size of each
      * centroid does not exceed `4 q (1 - q) \delta N`.
      */
    void compress()
    {
        if (unmerged > 0 || centroids.size() > params.max_centroids)
        {
            // unmerged > 0 implies centroids.size() > 0, hence *l is valid below
            RadixSort<RadixSortTraits>::executeLSD(centroids.data(), centroids.size());

            /// A pair of consecutive bars of the histogram.
            auto l = centroids.begin();
            auto r = std::next(l);

            const BetterFloat count_epsilon_4 = count * params.epsilon * 4; // Compiler is unable to do this optimization
            BetterFloat sum = 0;
            BetterFloat l_mean = l->mean; // We have high-precision temporaries for numeric stability
            BetterFloat l_count = l->count;
            while (r != centroids.end())
            {
                if (l->mean == r->mean) // Perfect aggregation (fast). We compare l->mean, not l_mean, to avoid identical elements after compress
                {
                    l_count += r->count;
                    l->count = l_count;
                    ++r;
                    continue;
                }

                // We use the quantile which gives us the smallest error.

                /// The ratio of the part of the histogram up to l, including half of l, to the entire histogram. That is, what level quantile is at position l.
                BetterFloat ql = (sum + l_count * 0.5) / count;
                BetterFloat err = ql * (1 - ql);

                /// The ratio of the part of the histogram up to l, including l and half of r, to the entire histogram. That is, what level quantile is at position r.
                BetterFloat qr = (sum + l_count + r->count * 0.5) / count;
                BetterFloat err2 = qr * (1 - qr);

                if (err > err2)
                    err = err2;

                BetterFloat k = count_epsilon_4 * err;

                /** The ratio of the weight of the glued column pair to all values is not greater
                  * than epsilon multiplied by a certain quadratic coefficient, which in the median is 1 (4 * 1/2 * 1/2),
                  * and at the edges decreases and is approximately equal to the distance to the edge * 4.
                  */
                if (l_count + r->count <= k)
                {
                    // It is possible to merge left and right.
                    /// The left column "eats" the right.
                    l_count += r->count;
                    l_mean += r->count * (r->mean - l_mean) / l_count; // Symmetric algo (M1*C1 + M2*C2)/(C1+C2) is numerically better, but slower
                    l->mean = l_mean;
                    l->count = l_count;
                }
                else
                {
                    // Not enough capacity, check the next pair.
                    sum += l->count; // Not l_count, otherwise actual sum of elements will be different
                    ++l;

                    /// We skip all the values "eaten" earlier.
                    if (l != r)
                        *l = *r;
                    l_mean = l->mean;
                    l_count = l->count;
                }
                ++r;
            }
            count = sum + l_count; // Update count, it might be different due to += inaccuracy

            /// At the end of the loop, all values to the right of l were "eaten".
            centroids.resize(l - centroids.begin() + 1);
            unmerged = 0;
        }

        // Ensures centroids.size() <= max_centroids, independent of unprovable floating point blackbox above
        compressBrute();
    }

    /** Adds to the digest a change at `x` with a weight of `cnt` (default 1).
      */
    void add(T x, UInt64 cnt = 1)
    {
        auto vx = static_cast<Value>(x);
        if (cnt == 0 || std::isnan(vx))
            return; // Count 0 breaks compress() assumptions, Nan breaks sort(). We treat them as no sample.
        addCentroid(Centroid{vx, static_cast<Count>(cnt)});
    }

    /// Merges another digest into this one by re-adding all of its centroids.
    void merge(const TDigest & other)
    {
        for (const auto & c : other.centroids)
            addCentroid(c);
    }

    /// Compresses, then writes [varuint size][raw array of Centroid].
    void serialize(WriteBuffer & buf)
    {
        compress();
        writeVarUInt(centroids.size(), buf);
        buf.write(reinterpret_cast<const char *>(centroids.data()), centroids.size() * sizeof(centroids[0]));
    }

    /// Reads centroids written by serialize(); validates size and each centroid, then compresses.
    void deserialize(ReadBuffer & buf)
    {
        size_t size = 0;
        readVarUInt(size, buf);

        if (size > max_centroids_deserialize)
            throw Exception("Too large t-digest centroids size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);

        count = 0;
        unmerged = 0;

        centroids.resize(size);
        // From now, TDigest will be in invalid state if exception is thrown.
        buf.read(reinterpret_cast<char *>(centroids.data()), size * sizeof(centroids[0]));

        for (const auto & c : centroids)
        {
            if (c.count <= 0 || std::isnan(c.count) || std::isnan(c.mean)) // invalid count breaks compress(), invalid mean breaks sort()
                throw Exception("Invalid centroid " + std::to_string(c.count) + ":" + std::to_string(c.mean), ErrorCodes::CANNOT_PARSE_INPUT_ASSERTION_FAILED);
            count += c.count;
        }

        compress(); // Allows reading/writing TDigests with different epsilon/max_centroids params
    }

    /// Total weight of all added points. NOTE: narrows the internal BetterFloat sum to Count on return.
    Count getCount()
    {
        return count;
    }

    const Centroids & getCentroids() const
    {
        return centroids;
    }

    /// Clears all accumulated state (keeps allocated capacity of the PODArray).
    void reset()
    {
        centroids.resize(0);
        count = 0;
        unmerged = 0;
    }
};
template < typename T >
2020-10-02 19:20:35 +00:00
class QuantileTDigest
{
2020-10-02 17:07:54 +00:00
using Value = Float32 ;
using Count = Float32 ;
TDigest < T > main_tdigest ;
/** Linear interpolation at the point x on the line (x1, y1)..(x2, y2)
*/
static Value interpolate ( Value x , Value x1 , Value y1 , Value x2 , Value y2 )
{
double k = ( x - x1 ) / ( x2 - x1 ) ;
return y1 + k * ( y2 - y1 ) ;
}
public :
void add ( T x , UInt64 cnt = 1 )
{
2020-11-08 21:44:41 +00:00
main_tdigest . add ( x , cnt ) ;
2020-10-02 17:07:54 +00:00
}
void merge ( const QuantileTDigest & other )
{
main_tdigest . merge ( other . main_tdigest ) ;
}
void serialize ( WriteBuffer & buf )
{
main_tdigest . serialize ( buf ) ;
}
void deserialize ( ReadBuffer & buf )
{
main_tdigest . deserialize ( buf ) ;
}
2017-12-20 07:36:30 +00:00
/** Calculates the quantile q [0, 1] based on the digest.
* For an empty digest returns NaN .
*/
template < typename ResultType >
2017-12-20 20:25:22 +00:00
ResultType getImpl ( Float64 level )
2017-12-20 07:36:30 +00:00
{
2020-10-02 17:07:54 +00:00
auto & centroids = main_tdigest . getCentroids ( ) ;
if ( centroids . empty ( ) )
2017-12-20 07:36:30 +00:00
return std : : is_floating_point_v < ResultType > ? NAN : 0 ;
2020-10-02 17:07:54 +00:00
main_tdigest . compress ( ) ;
2017-12-20 07:36:30 +00:00
2020-10-02 17:07:54 +00:00
if ( centroids . size ( ) = = 1 )
return centroids . front ( ) . mean ;
2017-12-20 07:36:30 +00:00
2020-10-02 17:07:54 +00:00
Float64 x = level * main_tdigest . getCount ( ) ;
2017-12-20 20:25:22 +00:00
Float64 prev_x = 0 ;
2017-12-20 08:39:21 +00:00
Count sum = 0 ;
2020-10-02 17:07:54 +00:00
Value prev_mean = centroids . front ( ) . mean ;
2017-12-20 07:36:30 +00:00
2020-10-02 17:07:54 +00:00
for ( const auto & c : centroids )
2017-12-20 07:36:30 +00:00
{
2017-12-20 20:25:22 +00:00
Float64 current_x = sum + c . count * 0.5 ;
2017-12-20 07:36:30 +00:00
if ( current_x > = x )
return interpolate ( x , prev_x , prev_mean , current_x , c . mean ) ;
sum + = c . count ;
prev_mean = c . mean ;
prev_x = current_x ;
}
2020-10-02 17:07:54 +00:00
return centroids . back ( ) . mean ;
2017-12-20 07:36:30 +00:00
}
/** Get multiple quantiles (`size` parts).
* levels - an array of levels of the desired quantiles . They are in a random order .
* levels_permutation - array - permutation levels . The i - th position will be the index of the i - th ascending level in the ` levels ` array .
* result - the array where the results are added , in order of ` levels ` ,
*/
template < typename ResultType >
2017-12-20 20:25:22 +00:00
void getManyImpl ( const Float64 * levels , const size_t * levels_permutation , size_t size , ResultType * result )
2017-12-20 07:36:30 +00:00
{
2020-10-02 17:07:54 +00:00
auto & centroids = main_tdigest . getCentroids ( ) ;
if ( centroids . empty ( ) )
2017-12-20 07:36:30 +00:00
{
for ( size_t result_num = 0 ; result_num < size ; + + result_num )
result [ result_num ] = std : : is_floating_point_v < ResultType > ? NAN : 0 ;
return ;
}
2020-10-02 17:07:54 +00:00
main_tdigest . compress ( ) ;
2017-12-20 07:36:30 +00:00
2020-10-02 17:07:54 +00:00
if ( centroids . size ( ) = = 1 )
2017-12-20 07:36:30 +00:00
{
for ( size_t result_num = 0 ; result_num < size ; + + result_num )
2020-10-02 17:07:54 +00:00
result [ result_num ] = centroids . front ( ) . mean ;
2017-12-20 07:36:30 +00:00
return ;
}
2020-10-02 17:07:54 +00:00
Float64 x = levels [ levels_permutation [ 0 ] ] * main_tdigest . getCount ( ) ;
2017-12-20 20:25:22 +00:00
Float64 prev_x = 0 ;
2017-12-20 08:39:21 +00:00
Count sum = 0 ;
2020-10-02 17:07:54 +00:00
Value prev_mean = centroids . front ( ) . mean ;
2017-12-20 07:36:30 +00:00
size_t result_num = 0 ;
2020-10-02 17:07:54 +00:00
for ( const auto & c : centroids )
2017-12-20 07:36:30 +00:00
{
2017-12-20 20:25:22 +00:00
Float64 current_x = sum + c . count * 0.5 ;
2017-12-20 07:36:30 +00:00
while ( current_x > = x )
{
result [ levels_permutation [ result_num ] ] = interpolate ( x , prev_x , prev_mean , current_x , c . mean ) ;
+ + result_num ;
if ( result_num > = size )
return ;
2020-10-02 17:07:54 +00:00
x = levels [ levels_permutation [ result_num ] ] * main_tdigest . getCount ( ) ;
2017-12-20 07:36:30 +00:00
}
sum + = c . count ;
prev_mean = c . mean ;
prev_x = current_x ;
}
2020-10-02 17:07:54 +00:00
auto rest_of_results = centroids . back ( ) . mean ;
2017-12-20 07:36:30 +00:00
for ( ; result_num < size ; + + result_num )
result [ levels_permutation [ result_num ] ] = rest_of_results ;
}
2017-12-20 20:25:22 +00:00
T get ( Float64 level )
2017-12-20 07:36:30 +00:00
{
2017-12-20 20:25:22 +00:00
return getImpl < T > ( level ) ;
2017-12-20 07:36:30 +00:00
}
2017-12-21 01:19:25 +00:00
Float32 getFloat ( Float64 level )
2017-12-20 07:36:30 +00:00
{
2017-12-21 01:19:25 +00:00
return getImpl < Float32 > ( level ) ;
2017-12-20 07:36:30 +00:00
}
2017-12-20 20:25:22 +00:00
void getMany ( const Float64 * levels , const size_t * indices , size_t size , T * result )
2017-12-20 07:36:30 +00:00
{
getManyImpl ( levels , indices , size , result ) ;
}
2017-12-21 01:19:25 +00:00
void getManyFloat ( const Float64 * levels , const size_t * indices , size_t size , Float32 * result )
2017-12-20 07:36:30 +00:00
{
getManyImpl ( levels , indices , size , result ) ;
}
} ;
}