Merge pull request #15542 from Avogar/quantile_t_digest

Improve quantileTDigest performance
This commit is contained in:
alexey-milovidov 2020-10-04 15:53:36 +03:00 committed by GitHub
commit 6b39d248a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -36,7 +36,7 @@ namespace ErrorCodes
* uses asin, which slows down the algorithm a bit. * uses asin, which slows down the algorithm a bit.
*/ */
template <typename T> template <typename T>
class QuantileTDigest class TDigest
{ {
using Value = Float32; using Value = Float32;
using Count = Float32; using Count = Float32;
@ -86,20 +86,12 @@ class QuantileTDigest
/// The memory will be allocated to several elements at once, so that the state occupies 64 bytes. /// The memory will be allocated to several elements at once, so that the state occupies 64 bytes.
static constexpr size_t bytes_in_arena = 128 - sizeof(PODArray<Centroid>) - sizeof(Count) - sizeof(UInt32); static constexpr size_t bytes_in_arena = 128 - sizeof(PODArray<Centroid>) - sizeof(Count) - sizeof(UInt32);
using Summary = PODArrayWithStackMemory<Centroid, bytes_in_arena>; using Centroids = PODArrayWithStackMemory<Centroid, bytes_in_arena>;
Summary summary; Centroids centroids;
Count count = 0; Count count = 0;
UInt32 unmerged = 0; UInt32 unmerged = 0;
/** Linear interpolation at the point x on the line (x1, y1)..(x2, y2)
*/
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
double k = (x - x1) / (x2 - x1);
return y1 + k * (y2 - y1);
}
struct RadixSortTraits struct RadixSortTraits
{ {
using Element = Centroid; using Element = Centroid;
@ -122,13 +114,14 @@ class QuantileTDigest
*/ */
void addCentroid(const Centroid & c) void addCentroid(const Centroid & c)
{ {
summary.push_back(c); centroids.push_back(c);
count += c.count; count += c.count;
++unmerged; ++unmerged;
if (unmerged >= params.max_unmerged) if (unmerged >= params.max_unmerged)
compress(); compress();
} }
public:
/** Performs compression of accumulated centroids /** Performs compression of accumulated centroids
* When merging, the invariant is retained to the maximum size of each * When merging, the invariant is retained to the maximum size of each
* centroid that does not exceed `4 q (1 - q) \ delta N`. * centroid that does not exceed `4 q (1 - q) \ delta N`.
@ -137,16 +130,16 @@ class QuantileTDigest
{ {
if (unmerged > 0) if (unmerged > 0)
{ {
RadixSort<RadixSortTraits>::executeLSD(summary.data(), summary.size()); RadixSort<RadixSortTraits>::executeLSD(centroids.data(), centroids.size());
if (summary.size() > 3) if (centroids.size() > 3)
{ {
/// A pair of consecutive bars of the histogram. /// A pair of consecutive bars of the histogram.
auto l = summary.begin(); auto l = centroids.begin();
auto r = std::next(l); auto r = std::next(l);
Count sum = 0; Count sum = 0;
while (r != summary.end()) while (r != centroids.end())
{ {
// we use quantile which gives us the smallest error // we use quantile which gives us the smallest error
@ -188,14 +181,13 @@ class QuantileTDigest
} }
/// At the end of the loop, all values to the right of l were "eaten". /// At the end of the loop, all values to the right of l were "eaten".
summary.resize(l - summary.begin() + 1); centroids.resize(l - centroids.begin() + 1);
} }
unmerged = 0; unmerged = 0;
} }
} }
public:
/** Adds to the digest a change in `x` with a weight of `cnt` (default 1) /** Adds to the digest a change in `x` with a weight of `cnt` (default 1)
*/ */
void add(T x, UInt64 cnt = 1) void add(T x, UInt64 cnt = 1)
@ -203,17 +195,17 @@ public:
addCentroid(Centroid(Value(x), Count(cnt))); addCentroid(Centroid(Value(x), Count(cnt)));
} }
void merge(const QuantileTDigest & other) void merge(const TDigest & other)
{ {
for (const auto & c : other.summary) for (const auto & c : other.centroids)
addCentroid(c); addCentroid(c);
} }
void serialize(WriteBuffer & buf) void serialize(WriteBuffer & buf)
{ {
compress(); compress();
writeVarUInt(summary.size(), buf); writeVarUInt(centroids.size(), buf);
buf.write(reinterpret_cast<const char *>(summary.data()), summary.size() * sizeof(summary[0])); buf.write(reinterpret_cast<const char *>(centroids.data()), centroids.size() * sizeof(centroids[0]));
} }
void deserialize(ReadBuffer & buf) void deserialize(ReadBuffer & buf)
@ -222,36 +214,113 @@ public:
readVarUInt(size, buf); readVarUInt(size, buf);
if (size > params.max_unmerged) if (size > params.max_unmerged)
throw Exception("Too large t-digest summary size", ErrorCodes::TOO_LARGE_ARRAY_SIZE); throw Exception("Too large t-digest centroids size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
summary.resize(size); centroids.resize(size);
buf.read(reinterpret_cast<char *>(summary.data()), size * sizeof(summary[0])); buf.read(reinterpret_cast<char *>(centroids.data()), size * sizeof(centroids[0]));
count = 0; count = 0;
for (const auto & c : summary) for (const auto & c : centroids)
count += c.count; count += c.count;
} }
Count getCount()
{
return count;
}
const Centroids & getCentroids() const
{
return centroids;
}
void reset()
{
centroids.resize(0);
count = 0;
unmerged = 0;
}
};
template <typename T>
class QuantileTDigest
{
using Value = Float32;
using Count = Float32;
/** We store two t-digests. When an amount of elements in sub_tdigest become more than merge_threshold
* we merge sub_tdigest in main_tdigest and reset sub_tdigest. This method is needed to decrease an amount of
* centroids in t-digest (experiments show that after merge_threshold the size of t-digest significantly grows,
* but merging two big t-digest decreases it).
*/
TDigest<T> main_tdigest;
TDigest<T> sub_tdigest;
size_t merge_threshold = 1e7;
/** Linear interpolation at the point x on the line (x1, y1)..(x2, y2)
*/
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
double k = (x - x1) / (x2 - x1);
return y1 + k * (y2 - y1);
}
void mergeTDigests()
{
main_tdigest.merge(sub_tdigest);
sub_tdigest.reset();
}
public:
void add(T x, UInt64 cnt = 1)
{
if (sub_tdigest.getCount() >= merge_threshold)
mergeTDigests();
sub_tdigest.add(x, cnt);
}
void merge(const QuantileTDigest & other)
{
mergeTDigests();
main_tdigest.merge(other.main_tdigest);
main_tdigest.merge(other.sub_tdigest);
}
void serialize(WriteBuffer & buf)
{
mergeTDigests();
main_tdigest.serialize(buf);
}
void deserialize(ReadBuffer & buf)
{
sub_tdigest.reset();
main_tdigest.deserialize(buf);
}
/** Calculates the quantile q [0, 1] based on the digest. /** Calculates the quantile q [0, 1] based on the digest.
* For an empty digest returns NaN. * For an empty digest returns NaN.
*/ */
template <typename ResultType> template <typename ResultType>
ResultType getImpl(Float64 level) ResultType getImpl(Float64 level)
{ {
if (summary.empty()) mergeTDigests();
auto & centroids = main_tdigest.getCentroids();
if (centroids.empty())
return std::is_floating_point_v<ResultType> ? NAN : 0; return std::is_floating_point_v<ResultType> ? NAN : 0;
compress(); main_tdigest.compress();
if (summary.size() == 1) if (centroids.size() == 1)
return summary.front().mean; return centroids.front().mean;
Float64 x = level * count; Float64 x = level * main_tdigest.getCount();
Float64 prev_x = 0; Float64 prev_x = 0;
Count sum = 0; Count sum = 0;
Value prev_mean = summary.front().mean; Value prev_mean = centroids.front().mean;
for (const auto & c : summary) for (const auto & c : centroids)
{ {
Float64 current_x = sum + c.count * 0.5; Float64 current_x = sum + c.count * 0.5;
@ -263,7 +332,7 @@ public:
prev_x = current_x; prev_x = current_x;
} }
return summary.back().mean; return centroids.back().mean;
} }
/** Get multiple quantiles (`size` parts). /** Get multiple quantiles (`size` parts).
@ -274,29 +343,32 @@ public:
template <typename ResultType> template <typename ResultType>
void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result) void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result)
{ {
if (summary.empty()) mergeTDigests();
auto & centroids = main_tdigest.getCentroids();
if (centroids.empty())
{ {
for (size_t result_num = 0; result_num < size; ++result_num) for (size_t result_num = 0; result_num < size; ++result_num)
result[result_num] = std::is_floating_point_v<ResultType> ? NAN : 0; result[result_num] = std::is_floating_point_v<ResultType> ? NAN : 0;
return; return;
} }
compress(); main_tdigest.compress();
if (summary.size() == 1) if (centroids.size() == 1)
{ {
for (size_t result_num = 0; result_num < size; ++result_num) for (size_t result_num = 0; result_num < size; ++result_num)
result[result_num] = summary.front().mean; result[result_num] = centroids.front().mean;
return; return;
} }
Float64 x = levels[levels_permutation[0]] * count; Float64 x = levels[levels_permutation[0]] * main_tdigest.getCount();
Float64 prev_x = 0; Float64 prev_x = 0;
Count sum = 0; Count sum = 0;
Value prev_mean = summary.front().mean; Value prev_mean = centroids.front().mean;
size_t result_num = 0; size_t result_num = 0;
for (const auto & c : summary) for (const auto & c : centroids)
{ {
Float64 current_x = sum + c.count * 0.5; Float64 current_x = sum + c.count * 0.5;
@ -308,7 +380,7 @@ public:
if (result_num >= size) if (result_num >= size)
return; return;
x = levels[levels_permutation[result_num]] * count; x = levels[levels_permutation[result_num]] * main_tdigest.getCount();
} }
sum += c.count; sum += c.count;
@ -316,7 +388,7 @@ public:
prev_x = current_x; prev_x = current_x;
} }
auto rest_of_results = summary.back().mean; auto rest_of_results = centroids.back().mean;
for (; result_num < size; ++result_num) for (; result_num < size; ++result_num)
result[levels_permutation[result_num]] = rest_of_results; result[levels_permutation[result_num]] = rest_of_results;
} }