mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-18 13:42:02 +00:00
movingSum with/or without window_size parameter for numeric and decimal types
This commit is contained in:
parent
b5f42d36fa
commit
d75c73ece2
82
dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp
Normal file
82
dbms/src/AggregateFunctions/AggregateFunctionMovingSum.cpp
Normal file
@ -0,0 +1,82 @@
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionMovingSum.h>
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
|
||||
extern const int BAD_ARGUMENTS;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename has_limit, typename ... TArgs>
|
||||
inline AggregateFunctionPtr createAggregateFunctionMovingSumImpl(const std::string & name, const DataTypePtr & argument_type, TArgs ... args)
|
||||
{
|
||||
AggregateFunctionPtr res;
|
||||
|
||||
if (isDecimal(argument_type))
|
||||
res.reset(createWithDecimalType<MovingSumImpl, has_limit>(*argument_type, argument_type, std::forward<TArgs>(args)...));
|
||||
else
|
||||
res.reset(createWithNumericType<MovingSumImpl, has_limit>(*argument_type, argument_type, std::forward<TArgs>(args)...));
|
||||
|
||||
if (!res)
|
||||
throw Exception("Illegal type " + argument_type->getName() + " of argument for aggregate function " + name,
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
|
||||
static AggregateFunctionPtr createAggregateFunctionMovingSum(const std::string & name, const DataTypes & argument_types, const Array & parameters)
|
||||
{
|
||||
assertUnary(name, argument_types);
|
||||
|
||||
bool limit_size = false;
|
||||
|
||||
UInt64 max_elems = std::numeric_limits<UInt64>::max();
|
||||
|
||||
if (parameters.empty())
|
||||
{
|
||||
// cumulative sum without parameter
|
||||
}
|
||||
else if (parameters.size() == 1)
|
||||
{
|
||||
auto type = parameters[0].getType();
|
||||
if (type != Field::Types::Int64 && type != Field::Types::UInt64)
|
||||
throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
if ((type == Field::Types::Int64 && parameters[0].get<Int64>() < 0) ||
|
||||
(type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))
|
||||
throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS);
|
||||
|
||||
limit_size = true;
|
||||
max_elems = parameters[0].get<UInt64>();
|
||||
}
|
||||
else
|
||||
throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 0 or 1",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (!limit_size)
|
||||
return createAggregateFunctionMovingSumImpl<std::false_type>(name, argument_types[0]);
|
||||
else
|
||||
return createAggregateFunctionMovingSumImpl<std::true_type>(name, argument_types[0], max_elems);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
void registerAggregateFunctionMovingSum(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("movingSum", createAggregateFunctionMovingSum);
|
||||
}
|
||||
|
||||
}
|
161
dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h
Normal file
161
dbms/src/AggregateFunctions/AggregateFunctionMovingSum.h
Normal file
@ -0,0 +1,161 @@
|
||||
#pragma once
|
||||
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
|
||||
#include <Common/ArenaAllocator.h>
|
||||
|
||||
#include <AggregateFunctions/IAggregateFunction.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TOO_LARGE_ARRAY_SIZE;
|
||||
extern const int LOGICAL_ERROR;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct MovingSumData
|
||||
{
|
||||
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
|
||||
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
|
||||
using Array = PODArray<T, 32, Allocator>;
|
||||
|
||||
Array value;
|
||||
Array window;
|
||||
T sum = 0;
|
||||
};
|
||||
|
||||
|
||||
template <typename T, typename Tlimit_num_elems>
|
||||
class MovingSumImpl final
|
||||
: public IAggregateFunctionDataHelper<MovingSumData<T>, MovingSumImpl<T, Tlimit_num_elems>>
|
||||
{
|
||||
static constexpr bool limit_num_elems = Tlimit_num_elems::value;
|
||||
DataTypePtr & data_type;
|
||||
UInt64 win_size;
|
||||
|
||||
public:
|
||||
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
|
||||
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; // probably for overflow function in the future
|
||||
|
||||
explicit MovingSumImpl(const DataTypePtr & data_type_, UInt64 win_size_ = std::numeric_limits<UInt64>::max())
|
||||
: IAggregateFunctionDataHelper<MovingSumData<T>, MovingSumImpl<T, Tlimit_num_elems>>({data_type_}, {})
|
||||
, data_type(this->argument_types[0]), win_size(win_size_) {}
|
||||
|
||||
String getName() const override { return "movingSum"; }
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(data_type);
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
|
||||
{
|
||||
auto & sum = this->data(place).sum;
|
||||
|
||||
sum += static_cast<const ColVecType &>(*columns[0]).getData()[row_num];
|
||||
|
||||
this->data(place).value.push_back(sum, arena);
|
||||
}
|
||||
|
||||
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
|
||||
{
|
||||
auto & cur_elems = this->data(place);
|
||||
auto & rhs_elems = this->data(rhs);
|
||||
|
||||
size_t cur_size = cur_elems.value.size();
|
||||
|
||||
if (rhs_elems.value.size())
|
||||
cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
|
||||
|
||||
for (size_t i = cur_size; i < cur_elems.value.size(); ++i)
|
||||
{
|
||||
cur_elems.value[i] += cur_elems.sum;
|
||||
}
|
||||
|
||||
cur_elems.sum += rhs_elems.sum;
|
||||
}
|
||||
|
||||
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
writeVarUInt(size, buf);
|
||||
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
|
||||
{
|
||||
size_t size = 0;
|
||||
readVarUInt(size, buf);
|
||||
|
||||
if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
|
||||
throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
|
||||
|
||||
auto & value = this->data(place).value;
|
||||
|
||||
value.resize(size, arena);
|
||||
buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
|
||||
}
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const auto & value = this->data(place).value;
|
||||
size_t size = value.size();
|
||||
|
||||
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
offsets_to.push_back(offsets_to.back() + size);
|
||||
|
||||
if (size)
|
||||
{
|
||||
typename ColVecResult::Container & data_to = static_cast<ColVecResult &>(arr_to.getData()).getData();
|
||||
|
||||
if (!limit_num_elems)
|
||||
{
|
||||
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t i = 0;
|
||||
for (; i < std::min(static_cast<size_t>(win_size), size); ++i)
|
||||
{
|
||||
data_to.push_back(value[i]);
|
||||
}
|
||||
for (; i < size; ++i)
|
||||
{
|
||||
data_to.push_back(value[i] - value[i - win_size]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
bool allocatesMemoryInArena() const override
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
};
|
||||
|
||||
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE
|
||||
|
||||
}
|
@ -108,6 +108,15 @@ static IAggregateFunction * createWithDecimalType(const IDataType & argument_typ
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename... TArgs>
|
||||
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
|
||||
{
|
||||
WhichDataType which(argument_type);
|
||||
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32, Data>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64, Data>(std::forward<TArgs>(args)...);
|
||||
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128, Data>(std::forward<TArgs>(args)...);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/** For template with two arguments.
|
||||
*/
|
||||
|
@ -31,6 +31,7 @@ void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMovingSum(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -74,6 +75,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionMLMethod(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionSimpleLinearRegression(factory);
|
||||
registerAggregateFunctionMovingSum(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user