diff --git a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.h b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.h index 5a82d9e6a5c..93e81be93fc 100644 --- a/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.h +++ b/src/AggregateFunctions/AggregateFunctionIntervalLengthSum.h @@ -26,12 +26,11 @@ namespace ErrorCodes extern const int TOO_LARGE_ARRAY_SIZE; } -/** - * Calculate total length of intervals without intersections. Each interval is the pair of numbers [begin, end]; - * Return Int64 for integral types (UInt/Int*, Date/DateTime) and return Float64 for Float*. - * - * Implementation simply stores intervals sorted by beginning and sums lengths at final. - */ +/** Calculate total length of intervals without intersections. Each interval is the pair of numbers [begin, end]; + * Returns UInt64 for integral types (UInt/Int*, Date/DateTime) and returns Float64 for Float*. + * + * Implementation simply stores intervals sorted by beginning and sums lengths at final. + */ template struct AggregateFunctionIntervalLengthSumData { @@ -46,10 +45,14 @@ struct AggregateFunctionIntervalLengthSumData void add(T begin, T end) { + /// Reversed intervals are counted by absolute value of their length. + if (unlikely(end < begin)) + std::swap(begin, end); + else if (unlikely(begin == end)) + return; + if (sorted && !segments.empty()) - { sorted = segments.back().first <= begin; - } segments.emplace_back(begin, end); } @@ -147,7 +150,7 @@ private: for (size_t i = 1, sz = data.segments.size(); i < sz; ++i) { - /// Check if current interval intersect with next one then add length, otherwise advance interval end + /// Check if current interval intersects with next one then add length, otherwise advance interval end. if (cur_segment.second < data.segments[i].first) { if constexpr (std::is_floating_point_v) @@ -184,7 +187,7 @@ public: { if constexpr (std::is_floating_point_v) return std::make_shared(); - return std::make_shared(); + return std::make_shared(); } bool allocatesMemoryInArena() const override { return false; } @@ -225,7 +228,7 @@ public: if constexpr (std::is_floating_point_v) assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); else - assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); + assert_cast(to).getData().push_back(getIntervalLengthSum(this->data(place))); } };