Remove unnecessary QuantileTDigest layer

This commit is contained in:
Pavel Kruglov 2020-11-17 14:01:51 +03:00
parent 1a02ec85a5
commit 7ecd207eac

View File

@ -37,7 +37,7 @@ namespace ErrorCodes
* uses asin, which slows down the algorithm a bit. * uses asin, which slows down the algorithm a bit.
*/ */
template <typename T> template <typename T>
class TDigest class QuantileTDigest
{ {
using Value = Float32; using Value = Float32;
using Count = Float32; using Count = Float32;
@ -95,6 +95,14 @@ class TDigest
BetterFloat count = 0; BetterFloat count = 0;
size_t unmerged = 0; size_t unmerged = 0;
/** Evaluates, at abscissa x, the straight line that passes through (x1, y1) and (x2, y2).
  * The ratio is formed from the Value operands and then held in a double.
  * The case x1 == x2 is not checked here; the division is performed as written.
  */
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
    const double fraction = (x - x1) / (x2 - x1);
    const double rise = fraction * (y2 - y1);
    return y1 + rise;
}
struct RadixSortTraits struct RadixSortTraits
{ {
using Element = Centroid; using Element = Centroid;
@ -124,7 +132,8 @@ class TDigest
if (unmerged > params.max_unmerged) if (unmerged > params.max_unmerged)
compress(); compress();
} }
void compressBrute() { void compressBrute()
{
if (centroids.size() <= params.max_centroids) if (centroids.size() <= params.max_centroids)
return; return;
const size_t batch_size = (centroids.size() + params.max_centroids - 1) / params.max_centroids; // at least 2 const size_t batch_size = (centroids.size() + params.max_centroids - 1) / params.max_centroids; // at least 2
@ -256,7 +265,7 @@ public:
addCentroid(Centroid{vx, static_cast<Count>(cnt)}); addCentroid(Centroid{vx, static_cast<Count>(cnt)});
} }
void merge(const TDigest & other) void merge(const QuantileTDigest & other)
{ {
for (const auto & c : other.centroids) for (const auto & c : other.centroids)
addCentroid(c); addCentroid(c);
@ -293,77 +302,21 @@ public:
compress(); // Allows reading/writing TDigests with different epsilon/max_centroids params compress(); // Allows reading/writing TDigests with different epsilon/max_centroids params
} }
/// Returns the `count` member, explicitly converted to the Count type.
Count getCount()
{
    return static_cast<Count>(count);
}
/// Read-only access to the `centroids` member; hands back a const reference, not a copy.
const Centroids & getCentroids() const
{
return centroids;
}
/** Puts the digest back into its empty state: zeroes the `unmerged` and `count`
  * counters and resizes `centroids` down to zero elements.
  */
void reset()
{
    unmerged = 0;
    count = 0;
    centroids.resize(0);
}
};
template <typename T>
class QuantileTDigest
{
using Value = Float32;
using Count = Float32;
TDigest<T> main_tdigest;
/** Evaluates, at abscissa x, the straight line that passes through (x1, y1) and (x2, y2).
  * The ratio is formed from the Value operands and then held in a double.
  * The case x1 == x2 is not checked here; the division is performed as written.
  */
static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2)
{
    const double fraction = (x - x1) / (x2 - x1);
    const double rise = fraction * (y2 - y1);
    return y1 + rise;
}
public:
/// Passes `x` and `cnt` (1 when omitted) straight through to `main_tdigest.add`.
void add(T x, UInt64 cnt = 1)
{
main_tdigest.add(x, cnt);
}
/// Merges `other` into this object by delegating to `main_tdigest.merge`
/// with the other object's wrapped `main_tdigest`.
void merge(const QuantileTDigest & other)
{
main_tdigest.merge(other.main_tdigest);
}
/// Writes the state to `buf`; the work is done entirely by `main_tdigest.serialize`.
void serialize(WriteBuffer & buf)
{
main_tdigest.serialize(buf);
}
/// Reads the state from `buf`; the work is done entirely by `main_tdigest.deserialize`.
void deserialize(ReadBuffer & buf)
{
main_tdigest.deserialize(buf);
}
/** Calculates the quantile q [0, 1] based on the digest. /** Calculates the quantile q [0, 1] based on the digest.
* For an empty digest returns NaN. * For an empty digest returns NaN.
*/ */
template <typename ResultType> template <typename ResultType>
ResultType getImpl(Float64 level) ResultType getImpl(Float64 level)
{ {
auto & centroids = main_tdigest.getCentroids();
if (centroids.empty()) if (centroids.empty())
return std::is_floating_point_v<ResultType> ? NAN : 0; return std::is_floating_point_v<ResultType> ? NAN : 0;
main_tdigest.compress(); compress();
if (centroids.size() == 1) if (centroids.size() == 1)
return centroids.front().mean; return centroids.front().mean;
Float64 x = level * main_tdigest.getCount(); Float64 x = level * count;
Float64 prev_x = 0; Float64 prev_x = 0;
Count sum = 0; Count sum = 0;
Value prev_mean = centroids.front().mean; Value prev_mean = centroids.front().mean;
@ -391,7 +344,6 @@ public:
template <typename ResultType> template <typename ResultType>
void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result) void getManyImpl(const Float64 * levels, const size_t * levels_permutation, size_t size, ResultType * result)
{ {
auto & centroids = main_tdigest.getCentroids();
if (centroids.empty()) if (centroids.empty())
{ {
for (size_t result_num = 0; result_num < size; ++result_num) for (size_t result_num = 0; result_num < size; ++result_num)
@ -399,7 +351,7 @@ public:
return; return;
} }
main_tdigest.compress(); compress();
if (centroids.size() == 1) if (centroids.size() == 1)
{ {
@ -408,7 +360,7 @@ public:
return; return;
} }
Float64 x = levels[levels_permutation[0]] * main_tdigest.getCount(); Float64 x = levels[levels_permutation[0]] * count;
Float64 prev_x = 0; Float64 prev_x = 0;
Count sum = 0; Count sum = 0;
Value prev_mean = centroids.front().mean; Value prev_mean = centroids.front().mean;
@ -426,7 +378,7 @@ public:
if (result_num >= size) if (result_num >= size)
return; return;
x = levels[levels_permutation[result_num]] * main_tdigest.getCount(); x = levels[levels_permutation[result_num]] * count;
} }
sum += c.count; sum += c.count;