From b297e0ef361d737f3b9e24485097c9a234720401 Mon Sep 17 00:00:00 2001 From: Russ Frank Date: Mon, 8 Feb 2021 22:48:56 -0500 Subject: [PATCH] feedback: use references, dont support decimal, rearrange struct members --- .../AggregateFunctionDeltaSum.cpp | 10 ++-- .../AggregateFunctionDeltaSum.h | 59 +++++++++++++------ .../0_stateless/01700_deltasum.reference | 4 +- tests/queries/0_stateless/01700_deltasum.sql | 2 + 4 files changed, 51 insertions(+), 24 deletions(-) diff --git a/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp b/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp index aeb2549e826..231b730d1aa 100644 --- a/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp +++ b/src/AggregateFunctions/AggregateFunctionDeltaSum.cpp @@ -30,13 +30,13 @@ AggregateFunctionPtr createAggregateFunctionDeltaSum( DataTypePtr data_type = arguments[0]; - if (!isNumber(data_type)) + if (isInteger(data_type) || isFloat(data_type)) + return AggregateFunctionPtr(createWithNumericType( + *data_type, arguments, params)); + else throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name, - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - return AggregateFunctionPtr(createWithNumericType(*arguments[0], arguments, params)); + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } - } void registerAggregateFunctionDeltaSum(AggregateFunctionFactory & factory) diff --git a/src/AggregateFunctions/AggregateFunctionDeltaSum.h b/src/AggregateFunctions/AggregateFunctionDeltaSum.h index 7d384438912..af745165379 100644 --- a/src/AggregateFunctions/AggregateFunctionDeltaSum.h +++ b/src/AggregateFunctions/AggregateFunctionDeltaSum.h @@ -15,14 +15,17 @@ namespace DB { +template +using DecimalOrVectorCol = std::conditional_t, ColumnDecimal, ColumnVector>; + template struct AggregationFunctionDeltaSumData { T sum = 0; - bool seen_last = false; T last = 0; - bool seen_first = false; T first = 0; + bool seen_last = false; + bool seen_first = false; }; template @@ -32,9 +35,11 @@ class AggregationFunctionDeltaSum final public: AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params) : IAggregateFunctionDataHelper, AggregationFunctionDeltaSum>{arguments, params} - { - // empty constructor - } + {} + + AggregationFunctionDeltaSum() + : IAggregateFunctionDataHelper, AggregationFunctionDeltaSum>{} + {} String getName() const override { return "deltaSum"; } @@ -42,7 +47,7 @@ public: void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override { - auto value = static_cast &>(*columns[0]).getData()[row_num]; + auto value = assert_cast &>(*columns[0]).getData()[row_num]; if ((this->data(place).last < value) && this->data(place).seen_last) { @@ -61,24 +66,42 @@ public: void NO_SANITIZE_UNDEFINED ALWAYS_INLINE merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override { - if ((this->data(place).last < this->data(rhs).first) && this->data(place).seen_last && this->data(rhs).seen_first) + auto place_data = &this->data(place); + auto rhs_data = &this->data(rhs); + + if ((place_data->last < rhs_data->first) && place_data->seen_last && rhs_data->seen_first) { - this->data(place).sum += this->data(rhs).sum + (this->data(rhs).first - this->data(place).last); - this->data(place).last = this->data(rhs).last; + // If the lhs last number seen is less than the first number the rhs saw, the lhs is before + // the rhs, for example [0, 2] [4, 7]. So we want to add the deltasums, but also add the + // difference between lhs last number and rhs first number (the 2 and 4). Then we want to + // take last value from the rhs, so first and last become 0 and 7. + + place_data->sum += rhs_data->sum + (rhs_data->first - place_data->last); + place_data->last = rhs_data->last; } - else if ((this->data(rhs).last < this->data(place).first && this->data(rhs).seen_last && this->data(place).seen_first)) + else if ((rhs_data->last < place_data->first && rhs_data->seen_last && place_data->seen_first)) { - this->data(place).sum += this->data(rhs).sum + (this->data(place).first - this->data(rhs).last); - this->data(place).first = this->data(rhs).first; + // In the opposite scenario, the lhs comes after the rhs, e.g. [4, 6] [1, 2]. Since we + // assume the input interval states are sorted by time, we assume this is a counter + // reset, and therefore do *not* add the difference between our first value and the + // rhs last value. + + place_data->sum += rhs_data->sum; + place_data->first = rhs_data->first; } - else + else if (rhs_data->seen_first) { - this->data(place).sum += this->data(rhs).sum; - this->data(place).first = this->data(rhs).first; - this->data(place).seen_first = this->data(rhs).seen_first; - this->data(place).last = this->data(rhs).last; - this->data(place).seen_last = this->data(rhs).seen_last; + // If we're here then the lhs is an empty state and the rhs does have some state, so + // we'll just take that state. + + place_data->first = rhs_data->first; + place_data->seen_first = rhs_data->seen_first; + place_data->last = rhs_data->last; + place_data->seen_last = rhs_data->seen_last; + place_data->sum = rhs_data->sum; } + + // Otherwise lhs either has data or is unitialized, so we don't need to modify its values. } void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override diff --git a/tests/queries/0_stateless/01700_deltasum.reference b/tests/queries/0_stateless/01700_deltasum.reference index d442bc1de2e..be5b176c627 100644 --- a/tests/queries/0_stateless/01700_deltasum.reference +++ b/tests/queries/0_stateless/01700_deltasum.reference @@ -4,4 +4,6 @@ 7 7 5 -5 +2 +2.25 +6.5 diff --git a/tests/queries/0_stateless/01700_deltasum.sql b/tests/queries/0_stateless/01700_deltasum.sql index a1447cd3c7c..93edb2e477d 100644 --- a/tests/queries/0_stateless/01700_deltasum.sql +++ b/tests/queries/0_stateless/01700_deltasum.sql @@ -5,3 +5,5 @@ select deltaSum(arrayJoin([1, 2, 3, 0, 3, 3, 3, 3, 3, 4, 2, 3])); select deltaSum(arrayJoin([1, 2, 3, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 2, 3])); select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0, 1])) as rows union all select deltaSumState(arrayJoin([4, 5])) as rows); select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([4, 5])) as rows union all select deltaSumState(arrayJoin([0, 1])) as rows); +select deltaSum(arrayJoin([2.25, 3, 4.5])); +select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0.1, 0.3, 0.5])) as rows union all select deltaSumState(arrayJoin([4.1, 5.1, 6.6])) as rows);