#pragma once #include #include #include #include #include #include #include #include #include #include #include #include #include /** Алгоритм реализовал Алексей Борзенков https://███████████.yandex-team.ru/snaury * Ему принадлежит авторство кода и половины комментариев в данном namespace, * за исключением слияния, сериализации и сортировки, а также выбора типов и других изменений. * Мы благодарим Алексея Борзенкова за написание изначального кода. */ namespace tdigest { /** * Центроид хранит вес точек вокруг их среднего значения */ template struct Centroid { Value mean; Count count; Centroid() = default; explicit Centroid(Value mean, Count count = 1) : mean(mean) , count(count) {} Centroid & operator+=(const Centroid & other) { count += other.count; mean += other.count * (other.mean - mean) / count; return *this; } bool operator<(const Centroid & other) const { return mean < other.mean; } }; /** :param epsilon: значение \delta из статьи - погрешность в районе * квантиля 0.5 (по-умолчанию 0.01, т.е. 1%) * :param max_unmerged: при накоплении кол-ва новых точек сверх этого * значения запускается компрессия центроидов * (по-умолчанию 2048, чем выше значение - тем * больше требуется памяти, но повышается * амортизация времени выполнения) */ template struct Params { Value epsilon = 0.01; size_t max_unmerged = 2048; }; /** Реализация алгоритма t-digest (https://github.com/tdunning/t-digest). * Этот вариант очень похож на MergingDigest на java, однако решение об * объединении принимается на основе оригинального условия из статьи * (через ограничение на размер, используя апроксимацию квантиля каждого * центроида, а не расстояние на кривой положения их границ). MergingDigest * на java даёт значительно меньше центроидов, чем данный вариант, что * негативно влияет на точность при том же факторе компрессии, но даёт * гарантии размера. Сам автор на предложение об этом варианте сказал, что * размер дайжеста растёт как O(log(n)), в то время как вариант на java * не зависит от предполагаемого кол-ва точек. Кроме того вариант на java * использует asin, чем немного замедляет алгоритм. */ template class MergingDigest { using Params = tdigest::Params; using Centroid = tdigest::Centroid; /// Сразу будет выделена память на несколько элементов так, чтобы состояние занимало 64 байта. static constexpr size_t bytes_in_arena = 64 - sizeof(DB::PODArray) - sizeof(TotalCount) - sizeof(uint32_t); using Summary = DB::PODArray, bytes_in_arena>>; Summary summary; TotalCount count = 0; uint32_t unmerged = 0; /** Линейная интерполяция в точке x на прямой (x1, y1)..(x2, y2) */ static Value interpolate(Value x, Value x1, Value y1, Value x2, Value y2) { double k = (x - x1) / (x2 - x1); return y1 + k * (y2 - y1); } struct RadixSortTraits { using Element = Centroid; using Key = Value; using CountType = uint32_t; using KeyBits = uint32_t; static constexpr size_t PART_SIZE_BITS = 8; using Transform = RadixSortFloatTransform; using Allocator = RadixSortMallocAllocator; /// Функция получения ключа из элемента массива. static Key & extractKey(Element & elem) { return elem.mean; } }; public: /** Добавляет к дайджесту изменение x с весом cnt (по-умолчанию 1) */ void add(const Params & params, Value x, CentroidCount cnt = 1) { add(params, Centroid(x, cnt)); } /** Добавляет к дайджесту центроид c */ void add(const Params & params, const Centroid & c) { summary.push_back(c); count += c.count; ++unmerged; if (unmerged >= params.max_unmerged) compress(params); } /** Выполняет компрессию накопленных центроидов * При объединении сохраняется инвариант на максимальный размер каждого * центроида, не превышающий 4 q (1 - q) \delta N. */ void compress(const Params & params) { if (unmerged > 0) { RadixSort::execute(&summary[0], summary.size()); if (summary.size() > 3) { /// Пара подряд идущих столбиков гистограммы. auto l = summary.begin(); auto r = std::next(l); TotalCount sum = 0; while (r != summary.end()) { // we use quantile which gives us the smallest error /// Отношение части гистограммы до l, включая половинку l ко всей гистограмме. То есть, какого уровня квантиль в позиции l. Value ql = (sum + l->count * 0.5) / count; Value err = ql * (1 - ql); /// Отношение части гистограммы до l, включая l и половинку r ко всей гистограмме. То есть, какого уровня квантиль в позиции r. Value qr = (sum + l->count + r->count * 0.5) / count; Value err2 = qr * (1 - qr); if (err > err2) err = err2; Value k = 4 * count * err * params.epsilon; /** Отношение веса склеенной пары столбиков ко всем значениям не больше, * чем epsilon умножить на некий квадратичный коэффициент, который в медиане равен 1 (4 * 1/2 * 1/2), * а по краям убывает и примерно равен расстоянию до края * 4. */ if (l->count + r->count <= k) { // it is possible to merge left and right /// Левый столбик "съедает" правый. *l += *r; } else { // not enough capacity, check the next pair sum += l->count; ++l; /// Пропускаем все "съеденные" ранее значения. if (l != r) *l = *r; } ++r; } /// По окончании цикла, все значения правее l были "съедены". summary.resize(l - summary.begin() + 1); } unmerged = 0; } } /** Вычисляет квантиль q [0, 1] на основе дайджеста * Для пустого дайджеста возвращает NaN. */ Value getQuantile(const Params & params, Value q) { if (summary.empty()) return NAN; compress(params); if (summary.size() == 1) return summary.front().mean; Value x = q * count; TotalCount sum = 0; Value prev_mean = summary.front().mean; Value prev_x = 0; for (const auto & c : summary) { Value current_x = sum + c.count * 0.5; if (current_x >= x) return interpolate(x, prev_x, prev_mean, current_x, c.mean); sum += c.count; prev_mean = c.mean; prev_x = current_x; } return summary.back().mean; } /** Получить несколько квантилей (size штук). * levels - массив уровней нужных квантилей. Они идут в произвольном порядке. * levels_permutation - массив-перестановка уровней. На i-ой позиции будет лежать индекс i-го по возрастанию уровня в массиве levels. * result - массив, куда сложить результаты, в порядке levels, */ template void getManyQuantiles(const Params & params, const Value * levels, const size_t * levels_permutation, size_t size, ResultType * result) { if (summary.empty()) { for (size_t result_num = 0; result_num < size; ++result_num) result[result_num] = std::is_floating_point::value ? NAN : 0; return; } compress(params); if (summary.size() == 1) { for (size_t result_num = 0; result_num < size; ++result_num) result[result_num] = summary.front().mean; return; } Value x = levels[levels_permutation[0]] * count; TotalCount sum = 0; Value prev_mean = summary.front().mean; Value prev_x = 0; size_t result_num = 0; for (const auto & c : summary) { Value current_x = sum + c.count * 0.5; while (current_x >= x) { result[levels_permutation[result_num]] = interpolate(x, prev_x, prev_mean, current_x, c.mean); ++result_num; if (result_num >= size) return; x = levels[levels_permutation[result_num]] * count; } sum += c.count; prev_mean = c.mean; prev_x = current_x; } auto rest_of_results = summary.back().mean; for (; result_num < size; ++result_num) result[levels_permutation[result_num]] = rest_of_results; } /** Объединить с другим состоянием. */ void merge(const Params & params, const MergingDigest & other) { for (const auto & c : other.summary) add(params, c); } /** Записать в поток. */ void write(const Params & params, DB::WriteBuffer & buf) { compress(params); DB::writeVarUInt(summary.size(), buf); buf.write(reinterpret_cast(&summary[0]), summary.size() * sizeof(summary[0])); } /** Прочитать из потока и объединить с текущим состоянием. */ void readAndMerge(const Params & params, DB::ReadBuffer & buf) { size_t size = 0; DB::readVarUInt(size, buf); if (size > params.max_unmerged) throw DB::Exception("Too large t-digest summary size", DB::ErrorCodes::TOO_LARGE_ARRAY_SIZE); for (size_t i = 0; i < size; ++i) { Centroid c; DB::readPODBinary(c, buf); add(params, c); } } }; } namespace DB { struct AggregateFunctionQuantileTDigestData { tdigest::MergingDigest digest; }; template class AggregateFunctionQuantileTDigest final : public IUnaryAggregateFunction> { private: Float32 level; tdigest::Params params; DataTypePtr type; public: AggregateFunctionQuantileTDigest(double level_ = 0.5) : level(level_) {} String getName() const override { return "quantileTDigest"; } DataTypePtr getReturnType() const override { return type; } void setArgument(const DataTypePtr & argument) { if (returns_float) type = new DataTypeFloat32; else type = argument; } void setParameters(const Array & params) override { if (params.size() != 1) throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); level = apply_visitor(FieldVisitorConvertToNumber(), params[0]); } void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const { this->data(place).digest.add(params, static_cast &>(column).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override { this->data(place).digest.merge(params, this->data(rhs).digest); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(const_cast(place)).digest.write(params, buf); } void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override { this->data(place).digest.readAndMerge(params, buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { auto quantile = this->data(const_cast(place)).digest.getQuantile(params, level); if (returns_float) static_cast(to).getData().push_back(quantile); else static_cast &>(to).getData().push_back(quantile); } }; template class AggregateFunctionQuantileTDigestWeighted final : public IBinaryAggregateFunction> { private: Float32 level; tdigest::Params params; DataTypePtr type; public: AggregateFunctionQuantileTDigestWeighted(double level_ = 0.5) : level(level_) {} String getName() const override { return "quantileTDigestWeighted"; } DataTypePtr getReturnType() const override { return type; } void setArgumentsImpl(const DataTypes & arguments) { if (returns_float) type = new DataTypeFloat32; else type = arguments.at(0); } void setParameters(const Array & params) override { if (params.size() != 1) throw Exception("Aggregate function " + getName() + " requires exactly one parameter.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); level = apply_visitor(FieldVisitorConvertToNumber(), params[0]); } void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const { this->data(place).digest.add(params, static_cast &>(column_value).getData()[row_num], static_cast &>(column_weight).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override { this->data(place).digest.merge(params, this->data(rhs).digest); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(const_cast(place)).digest.write(params, buf); } void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override { this->data(place).digest.readAndMerge(params, buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { auto quantile = this->data(const_cast(place)).digest.getQuantile(params, level); if (returns_float) static_cast(to).getData().push_back(quantile); else static_cast &>(to).getData().push_back(quantile); } }; template class AggregateFunctionQuantilesTDigest final : public IUnaryAggregateFunction> { private: QuantileLevels levels; tdigest::Params params; DataTypePtr type; public: String getName() const override { return "quantilesTDigest"; } DataTypePtr getReturnType() const override { return new DataTypeArray(type); } void setArgument(const DataTypePtr & argument) { if (returns_float) type = new DataTypeFloat32; else type = argument; } void setParameters(const Array & params) override { levels.set(params); } void addImpl(AggregateDataPtr place, const IColumn & column, size_t row_num) const { this->data(place).digest.add(params, static_cast &>(column).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override { this->data(place).digest.merge(params, this->data(rhs).digest); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(const_cast(place)).digest.write(params, buf); } void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override { this->data(place).digest.readAndMerge(params, buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { ColumnArray & arr_to = static_cast(to); ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); size_t size = levels.size(); offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size); if (returns_float) { typename ColumnFloat32::Container_t & data_to = static_cast(arr_to.getData()).getData(); size_t old_size = data_to.size(); data_to.resize(data_to.size() + size); this->data(const_cast(place)).digest.getManyQuantiles( params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]); } else { typename ColumnVector::Container_t & data_to = static_cast &>(arr_to.getData()).getData(); size_t old_size = data_to.size(); data_to.resize(data_to.size() + size); this->data(const_cast(place)).digest.getManyQuantiles( params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]); } } }; template class AggregateFunctionQuantilesTDigestWeighted final : public IBinaryAggregateFunction> { private: QuantileLevels levels; tdigest::Params params; DataTypePtr type; public: String getName() const override { return "quantilesTDigest"; } DataTypePtr getReturnType() const override { return new DataTypeArray(type); } void setArgumentsImpl(const DataTypes & arguments) { if (returns_float) type = new DataTypeFloat32; else type = arguments.at(0); } void setParameters(const Array & params) override { levels.set(params); } void addImpl(AggregateDataPtr place, const IColumn & column_value, const IColumn & column_weight, size_t row_num) const { this->data(place).digest.add(params, static_cast &>(column_value).getData()[row_num], static_cast &>(column_weight).getData()[row_num]); } void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs) const override { this->data(place).digest.merge(params, this->data(rhs).digest); } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override { this->data(const_cast(place)).digest.write(params, buf); } void deserializeMerge(AggregateDataPtr place, ReadBuffer & buf) const override { this->data(place).digest.readAndMerge(params, buf); } void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override { ColumnArray & arr_to = static_cast(to); ColumnArray::Offsets_t & offsets_to = arr_to.getOffsets(); size_t size = levels.size(); offsets_to.push_back((offsets_to.size() == 0 ? 0 : offsets_to.back()) + size); if (returns_float) { typename ColumnFloat32::Container_t & data_to = static_cast(arr_to.getData()).getData(); size_t old_size = data_to.size(); data_to.resize(data_to.size() + size); this->data(const_cast(place)).digest.getManyQuantiles( params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]); } else { typename ColumnVector::Container_t & data_to = static_cast &>(arr_to.getData()).getData(); size_t old_size = data_to.size(); data_to.resize(data_to.size() + size); this->data(const_cast(place)).digest.getManyQuantiles( params, &levels.levels[0], &levels.permutation[0], size, &data_to[old_size]); } } }; }