mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 15:12:02 +00:00
feedback: use references, dont support decimal, rearrange struct members
This commit is contained in:
parent
746dc1ddae
commit
b297e0ef36
@ -30,13 +30,13 @@ AggregateFunctionPtr createAggregateFunctionDeltaSum(
|
||||
|
||||
DataTypePtr data_type = arguments[0];
|
||||
|
||||
if (!isNumber(data_type))
|
||||
if (isInteger(data_type) || isFloat(data_type))
|
||||
return AggregateFunctionPtr(createWithNumericType<AggregationFunctionDeltaSum>(
|
||||
*data_type, arguments, params));
|
||||
else
|
||||
throw Exception("Illegal type " + arguments[0]->getName() + " of argument for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return AggregateFunctionPtr(createWithNumericType<AggregationFunctionDeltaSum>(*arguments[0], arguments, params));
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionDeltaSum(AggregateFunctionFactory & factory)
|
||||
|
@ -15,14 +15,17 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
template <typename T>
|
||||
using DecimalOrVectorCol = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
|
||||
template <typename T>
|
||||
struct AggregationFunctionDeltaSumData
|
||||
{
|
||||
T sum = 0;
|
||||
bool seen_last = false;
|
||||
T last = 0;
|
||||
bool seen_first = false;
|
||||
T first = 0;
|
||||
bool seen_last = false;
|
||||
bool seen_first = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -32,9 +35,11 @@ class AggregationFunctionDeltaSum final
|
||||
public:
|
||||
AggregationFunctionDeltaSum(const DataTypes & arguments, const Array & params)
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{arguments, params}
|
||||
{
|
||||
// empty constructor
|
||||
}
|
||||
{}
|
||||
|
||||
AggregationFunctionDeltaSum()
|
||||
: IAggregateFunctionDataHelper<AggregationFunctionDeltaSumData<T>, AggregationFunctionDeltaSum<T>>{}
|
||||
{}
|
||||
|
||||
String getName() const override { return "deltaSum"; }
|
||||
|
||||
@ -42,7 +47,7 @@ public:
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
|
||||
{
|
||||
auto value = static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];
|
||||
auto value = assert_cast<const DecimalOrVectorCol<T> &>(*columns[0]).getData()[row_num];
|
||||
|
||||
if ((this->data(place).last < value) && this->data(place).seen_last)
|
||||
{
|
||||
@ -61,24 +66,42 @@ public:
|
||||
|
||||
void NO_SANITIZE_UNDEFINED ALWAYS_INLINE merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena *) const override
|
||||
{
|
||||
if ((this->data(place).last < this->data(rhs).first) && this->data(place).seen_last && this->data(rhs).seen_first)
|
||||
auto place_data = &this->data(place);
|
||||
auto rhs_data = &this->data(rhs);
|
||||
|
||||
if ((place_data->last < rhs_data->first) && place_data->seen_last && rhs_data->seen_first)
|
||||
{
|
||||
this->data(place).sum += this->data(rhs).sum + (this->data(rhs).first - this->data(place).last);
|
||||
this->data(place).last = this->data(rhs).last;
|
||||
// If the lhs last number seen is less than the first number the rhs saw, the lhs is before
|
||||
// the rhs, for example [0, 2] [4, 7]. So we want to add the deltasums, but also add the
|
||||
// difference between lhs last number and rhs first number (the 2 and 4). Then we want to
|
||||
// take last value from the rhs, so first and last become 0 and 7.
|
||||
|
||||
place_data->sum += rhs_data->sum + (rhs_data->first - place_data->last);
|
||||
place_data->last = rhs_data->last;
|
||||
}
|
||||
else if ((this->data(rhs).last < this->data(place).first && this->data(rhs).seen_last && this->data(place).seen_first))
|
||||
else if ((rhs_data->last < place_data->first && rhs_data->seen_last && place_data->seen_first))
|
||||
{
|
||||
this->data(place).sum += this->data(rhs).sum + (this->data(place).first - this->data(rhs).last);
|
||||
this->data(place).first = this->data(rhs).first;
|
||||
// In the opposite scenario, the lhs comes after the rhs, e.g. [4, 6] [1, 2]. Since we
|
||||
// assume the input interval states are sorted by time, we assume this is a counter
|
||||
// reset, and therefore do *not* add the difference between our first value and the
|
||||
// rhs last value.
|
||||
|
||||
place_data->sum += rhs_data->sum;
|
||||
place_data->first = rhs_data->first;
|
||||
}
|
||||
else
|
||||
else if (rhs_data->seen_first)
|
||||
{
|
||||
this->data(place).sum += this->data(rhs).sum;
|
||||
this->data(place).first = this->data(rhs).first;
|
||||
this->data(place).seen_first = this->data(rhs).seen_first;
|
||||
this->data(place).last = this->data(rhs).last;
|
||||
this->data(place).seen_last = this->data(rhs).seen_last;
|
||||
// If we're here then the lhs is an empty state and the rhs does have some state, so
|
||||
// we'll just take that state.
|
||||
|
||||
place_data->first = rhs_data->first;
|
||||
place_data->seen_first = rhs_data->seen_first;
|
||||
place_data->last = rhs_data->last;
|
||||
place_data->seen_last = rhs_data->seen_last;
|
||||
place_data->sum = rhs_data->sum;
|
||||
}
|
||||
|
||||
// Otherwise lhs either has data or is unitialized, so we don't need to modify its values.
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
|
@ -4,4 +4,6 @@
|
||||
7
|
||||
7
|
||||
5
|
||||
5
|
||||
2
|
||||
2.25
|
||||
6.5
|
||||
|
@ -5,3 +5,5 @@ select deltaSum(arrayJoin([1, 2, 3, 0, 3, 3, 3, 3, 3, 4, 2, 3]));
|
||||
select deltaSum(arrayJoin([1, 2, 3, 0, 0, 0, 0, 3, 3, 3, 3, 3, 4, 2, 3]));
|
||||
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0, 1])) as rows union all select deltaSumState(arrayJoin([4, 5])) as rows);
|
||||
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([4, 5])) as rows union all select deltaSumState(arrayJoin([0, 1])) as rows);
|
||||
select deltaSum(arrayJoin([2.25, 3, 4.5]));
|
||||
select deltaSumMerge(rows) from (select deltaSumState(arrayJoin([0.1, 0.3, 0.5])) as rows union all select deltaSumState(arrayJoin([4.1, 5.1, 6.6])) as rows);
|
||||
|
Loading…
Reference in New Issue
Block a user