From d8ba6167e96cf1494c392c2ad2610448f68b25fa Mon Sep 17 00:00:00 2001 From: vdimir Date: Fri, 28 May 2021 18:49:20 +0300 Subject: [PATCH] Minor fixes in AggregateFunctionSegmentLengthSumData --- .../AggregateFunctionSegmentLengthSum.h | 70 ++++++++++++------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionSegmentLengthSum.h b/src/AggregateFunctions/AggregateFunctionSegmentLengthSum.h index 0ca70cc7367..2f758431f12 100644 --- a/src/AggregateFunctions/AggregateFunctionSegmentLengthSum.h +++ b/src/AggregateFunctions/AggregateFunctionSegmentLengthSum.h @@ -1,22 +1,39 @@ #pragma once -#include +#include + #include -#include -#include -#include -#include + #include #include -#include +#include +#include + +#include +#include + +#include namespace DB { +namespace ErrorCodes +{ + extern const int TOO_LARGE_ARRAY_SIZE; +} + +/** + * Calculate total length of intervals without intersections. Each interval is the pair of numbers [begin, end]; + * Return UInt64 for integral types (UInt/Int*, DateTime) and return Float64 for Float*. + * + * Implementation simply stores intervals sorted by beginning and sums lengths at final. + */ template struct AggregateFunctionSegmentLengthSumData { + constexpr static size_t MAX_ARRAY_SIZE = 0xFFFFFF; + using Segment = std::pair; using Segments = PODArrayWithStackMemory; @@ -24,15 +41,13 @@ struct AggregateFunctionSegmentLengthSumData Segments segments; - size_t size() const { return segments.size(); } - - void add(T start, T end) + void add(T begin, T end) { - if (sorted && segments.size() > 0) + if (sorted && !segments.empty()) { - sorted = segments.back().first <= start; + sorted = segments.back().first <= begin; } - segments.emplace_back(start, end); + segments.emplace_back(begin, end); } void merge(const AggregateFunctionSegmentLengthSumData & other) @@ -46,7 +61,9 @@ struct AggregateFunctionSegmentLengthSumData /// either sort whole container or do so partially merging ranges afterwards if (!sorted && !other.sorted) - std::stable_sort(std::begin(segments), std::end(segments)); + { + std::sort(std::begin(segments), std::end(segments)); + } else { const auto begin = std::begin(segments); @@ -54,10 +71,10 @@ struct AggregateFunctionSegmentLengthSumData const auto end = std::end(segments); if (!sorted) - std::stable_sort(begin, middle); + std::sort(begin, middle); if (!other.sorted) - std::stable_sort(middle, end); + std::sort(middle, end); std::inplace_merge(begin, middle, end); } @@ -69,7 +86,7 @@ struct AggregateFunctionSegmentLengthSumData { if (!sorted) { - std::stable_sort(std::begin(segments), std::end(segments)); + std::sort(std::begin(segments), std::end(segments)); sorted = true; } } @@ -93,16 +110,18 @@ struct AggregateFunctionSegmentLengthSumData size_t size; readBinary(size, buf); + if (unlikely(size > MAX_ARRAY_SIZE)) + throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE); + segments.clear(); segments.reserve(size); - T start, end; - + Segment segment; for (size_t i = 0; i < size; ++i) { - readBinary(start, buf); - readBinary(end, buf); - segments.emplace_back(start, end); + readBinary(segment.first, buf); + readBinary(segment.second, buf); + segments.emplace_back(segment); } } }; @@ -114,7 +133,7 @@ private: template TResult getSegmentLengthSum(Data & data) const { - if (data.size() == 0) + if (data.segments.empty()) return 0; data.sort(); @@ -123,8 +142,9 @@ private: typename Data::Segment cur_segment = data.segments[0]; - for (size_t i = 1; i < data.segments.size(); ++i) + for (size_t i = 1, sz = data.segments.size(); i < sz; ++i) { + /// Check if current interval intersect with next one then add length, otherwise advance interval end if (cur_segment.second < data.segments[i].first) { res += cur_segment.second - cur_segment.first; @@ -167,9 +187,9 @@ public: void add(AggregateDataPtr __restrict place, const IColumn ** columns, const size_t row_num, Arena *) const override { - auto start = assert_cast *>(columns[0])->getData()[row_num]; + auto begin = assert_cast *>(columns[0])->getData()[row_num]; auto end = assert_cast *>(columns[1])->getData()[row_num]; - this->data(place).add(start, end); + this->data(place).add(begin, end); } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override