From d75c73ece26a9caaaabeff0637ebe7a28a79fba7 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 12 Jun 2019 01:56:37 -0400 Subject: [PATCH] movingSum with/or without window_size parameter for numeric and decimal types --- .../AggregateFunctionMovingSum.cpp | 82 +++++++++ .../AggregateFunctionMovingSum.h | 161 ++++++++++++++++++ dbms/src/AggregateFunctions/Helpers.h | 9 + .../registerAggregateFunctions.cpp | 2 + 4 files changed, 254 insertions(+) create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp b/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp new file mode 100644 index 00000000000..6b17cd65660 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp @@ -0,0 +1,82 @@ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; + extern const int BAD_ARGUMENTS; +} + +namespace +{ + +template +inline AggregateFunctionPtr createAggregateFunctionMovingSumImpl(const std::string & name, const DataTypePtr & argument_type, TArgs ... args) +{ + AggregateFunctionPtr res; + + if (isDecimal(argument_type)) + res.reset(createWithDecimalType(*argument_type, argument_type, std::forward(args)...)); + else + res.reset(createWithNumericType(*argument_type, argument_type, std::forward(args)...)); + + if (!res) + throw Exception("Illegal type " + argument_type->getName() + " of argument for aggregate function " + name, + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; +} + + +static AggregateFunctionPtr createAggregateFunctionMovingSum(const std::string & name, const DataTypes & argument_types, const Array & parameters) +{ + assertUnary(name, argument_types); + + bool limit_size = false; + + UInt64 max_elems = std::numeric_limits::max(); + + if (parameters.empty()) + { + // cumulative sum without parameter + } + else if (parameters.size() == 1) + { + auto type = parameters[0].getType(); + if (type != Field::Types::Int64 && type != Field::Types::UInt64) + throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS); + + if ((type == Field::Types::Int64 && parameters[0].get() < 0) || + (type == Field::Types::UInt64 && parameters[0].get() == 0)) + throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS); + + limit_size = true; + max_elems = parameters[0].get(); + } + else + throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 0 or 1", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + if (!limit_size) + return createAggregateFunctionMovingSumImpl(name, argument_types[0]); + else + return createAggregateFunctionMovingSumImpl(name, argument_types[0], max_elems); +} + +} + + +void registerAggregateFunctionMovingSum(AggregateFunctionFactory & factory) +{ + factory.registerFunction("movingSum", createAggregateFunctionMovingSum); +} + +} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h b/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h new file mode 100644 index 00000000000..df97a476e48 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h @@ -0,0 +1,161 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int TOO_LARGE_ARRAY_SIZE; + extern const int LOGICAL_ERROR; +} + + +template +struct MovingSumData +{ + // Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena + using Allocator = MixedAlignedArenaAllocator; + using Array = PODArray; + + Array value; + Array window; + T sum = 0; +}; + + +template +class MovingSumImpl final + : public IAggregateFunctionDataHelper, MovingSumImpl> +{ + static constexpr bool limit_num_elems = Tlimit_num_elems::value; + DataTypePtr & data_type; + UInt64 win_size; + +public: + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + using ColVecResult = std::conditional_t, ColumnDecimal, ColumnVector>; // probably for overflow function in the future + + explicit MovingSumImpl(const DataTypePtr & data_type_, UInt64 win_size_ = std::numeric_limits::max()) + : IAggregateFunctionDataHelper, MovingSumImpl>({data_type_}, {}) + , data_type(this->argument_types[0]), win_size(win_size_) {} + + String getName() const override { return "movingSum"; } + + DataTypePtr getReturnType() const override + { + return std::make_shared(data_type); + } + + void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override + { + auto & sum = this->data(place).sum; + + sum += static_cast(*columns[0]).getData()[row_num]; + + this->data(place).value.push_back(sum, arena); + } + + void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override + { + auto & cur_elems = this->data(place); + auto & rhs_elems = this->data(rhs); + + size_t cur_size = cur_elems.value.size(); + + if (rhs_elems.value.size()) + cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena); + + for (size_t i = cur_size; i < cur_elems.value.size(); ++i) + { + cur_elems.value[i] += cur_elems.sum; + } + + cur_elems.sum += rhs_elems.sum; + } + + void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override + { + const auto & value = this->data(place).value; + size_t size = value.size(); + writeVarUInt(size, buf); + buf.write(reinterpret_cast(value.data()), size * sizeof(value[0])); + } + + void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override + { + size_t size = 0; + readVarUInt(size, buf); + + if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE)) + throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE); + + auto & value = this->data(place).value; + + value.resize(size, arena); + buf.read(reinterpret_cast(value.data()), size * sizeof(value[0])); + } + + void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override + { + const auto & value = this->data(place).value; + size_t size = value.size(); + + ColumnArray & arr_to = static_cast(to); + ColumnArray::Offsets & offsets_to = arr_to.getOffsets(); + + offsets_to.push_back(offsets_to.back() + size); + + if (size) + { + typename ColVecResult::Container & data_to = static_cast(arr_to.getData()).getData(); + + if (!limit_num_elems) + { + data_to.insert(this->data(place).value.begin(), this->data(place).value.end()); + } + else + { + size_t i = 0; + for (; i < std::min(static_cast(win_size), size); ++i) + { + data_to.push_back(value[i]); + } + for (; i < size; ++i) + { + data_to.push_back(value[i] - value[i - win_size]); + } + } + + } + } + + bool allocatesMemoryInArena() const override + { + return true; + } + + const char * getHeaderFilePath() const override { return __FILE__; } +}; + +#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE + +} diff --git a/dbms/src/AggregateFunctions/Helpers.h b/dbms/src/AggregateFunctions/Helpers.h index 44292d880cf..8d42654811a 100644 --- a/dbms/src/AggregateFunctions/Helpers.h +++ b/dbms/src/AggregateFunctions/Helpers.h @@ -108,6 +108,15 @@ static IAggregateFunction * createWithDecimalType(const IDataType & argument_typ return nullptr; } +template