movingSum with/or without window_size parameter for numeric and decimal types

This commit is contained in:
unknown 2019-06-12 01:56:37 -04:00
parent b5f42d36fa
commit d75c73ece2
4 changed files with 254 additions and 0 deletions

View File

@ -0,0 +1,82 @@
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionMovingSum.h>
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
namespace DB
{
namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
}
namespace
{
template <typename has_limit, typename ... TArgs>
inline AggregateFunctionPtr createAggregateFunctionMovingSumImpl(const std::string & name, const DataTypePtr & argument_type, TArgs ... args)
{
AggregateFunctionPtr res;
if (isDecimal(argument_type))
res.reset(createWithDecimalType<MovingSumImpl, has_limit>(*argument_type, argument_type, std::forward<TArgs>(args)...));
else
res.reset(createWithNumericType<MovingSumImpl, has_limit>(*argument_type, argument_type, std::forward<TArgs>(args)...));
if (!res)
throw Exception("Illegal type " + argument_type->getName() + " of argument for aggregate function " + name,
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res;
}
static AggregateFunctionPtr createAggregateFunctionMovingSum(const std::string & name, const DataTypes & argument_types, const Array & parameters)
{
assertUnary(name, argument_types);
bool limit_size = false;
UInt64 max_elems = std::numeric_limits<UInt64>::max();
if (parameters.empty())
{
// cumulative sum without parameter
}
else if (parameters.size() == 1)
{
auto type = parameters[0].getType();
if (type != Field::Types::Int64 && type != Field::Types::UInt64)
throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS);
if ((type == Field::Types::Int64 && parameters[0].get<Int64>() < 0) ||
(type == Field::Types::UInt64 && parameters[0].get<UInt64>() == 0))
throw Exception("Parameter for aggregate function " + name + " should be positive number", ErrorCodes::BAD_ARGUMENTS);
limit_size = true;
max_elems = parameters[0].get<UInt64>();
}
else
throw Exception("Incorrect number of parameters for aggregate function " + name + ", should be 0 or 1",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
if (!limit_size)
return createAggregateFunctionMovingSumImpl<std::false_type>(name, argument_types[0]);
else
return createAggregateFunctionMovingSumImpl<std::true_type>(name, argument_types[0], max_elems);
}
}
void registerAggregateFunctionMovingSum(AggregateFunctionFactory & factory)
{
factory.registerFunction("movingSum", createAggregateFunctionMovingSum);
}
}

View File

@ -0,0 +1,161 @@
#pragma once
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypesDecimal.h>
#include <Columns/ColumnVector.h>
#include <Columns/ColumnArray.h>
#include <Common/ArenaAllocator.h>
#include <AggregateFunctions/IAggregateFunction.h>
#include <type_traits>
#define AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE 0xFFFFFF
namespace DB
{
namespace ErrorCodes
{
extern const int TOO_LARGE_ARRAY_SIZE;
extern const int LOGICAL_ERROR;
}
template <typename T>
struct MovingSumData
{
// Switch to ordinary Allocator after 4096 bytes to avoid fragmentation and trash in Arena
using Allocator = MixedAlignedArenaAllocator<alignof(T), 4096>;
using Array = PODArray<T, 32, Allocator>;
Array value;
Array window;
T sum = 0;
};
template <typename T, typename Tlimit_num_elems>
class MovingSumImpl final
: public IAggregateFunctionDataHelper<MovingSumData<T>, MovingSumImpl<T, Tlimit_num_elems>>
{
static constexpr bool limit_num_elems = Tlimit_num_elems::value;
DataTypePtr & data_type;
UInt64 win_size;
public:
using ColVecType = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>;
using ColVecResult = std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<T>, ColumnVector<T>>; // probably for overflow function in the future
explicit MovingSumImpl(const DataTypePtr & data_type_, UInt64 win_size_ = std::numeric_limits<UInt64>::max())
: IAggregateFunctionDataHelper<MovingSumData<T>, MovingSumImpl<T, Tlimit_num_elems>>({data_type_}, {})
, data_type(this->argument_types[0]), win_size(win_size_) {}
String getName() const override { return "movingSum"; }
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(data_type);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
auto & sum = this->data(place).sum;
sum += static_cast<const ColVecType &>(*columns[0]).getData()[row_num];
this->data(place).value.push_back(sum, arena);
}
void merge(AggregateDataPtr place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
auto & cur_elems = this->data(place);
auto & rhs_elems = this->data(rhs);
size_t cur_size = cur_elems.value.size();
if (rhs_elems.value.size())
cur_elems.value.insert(rhs_elems.value.begin(), rhs_elems.value.end(), arena);
for (size_t i = cur_size; i < cur_elems.value.size(); ++i)
{
cur_elems.value[i] += cur_elems.sum;
}
cur_elems.sum += rhs_elems.sum;
}
void serialize(ConstAggregateDataPtr place, WriteBuffer & buf) const override
{
const auto & value = this->data(place).value;
size_t size = value.size();
writeVarUInt(size, buf);
buf.write(reinterpret_cast<const char *>(value.data()), size * sizeof(value[0]));
}
void deserialize(AggregateDataPtr place, ReadBuffer & buf, Arena * arena) const override
{
size_t size = 0;
readVarUInt(size, buf);
if (unlikely(size > AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE))
throw Exception("Too large array size", ErrorCodes::TOO_LARGE_ARRAY_SIZE);
auto & value = this->data(place).value;
value.resize(size, arena);
buf.read(reinterpret_cast<char *>(value.data()), size * sizeof(value[0]));
}
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const auto & value = this->data(place).value;
size_t size = value.size();
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
offsets_to.push_back(offsets_to.back() + size);
if (size)
{
typename ColVecResult::Container & data_to = static_cast<ColVecResult &>(arr_to.getData()).getData();
if (!limit_num_elems)
{
data_to.insert(this->data(place).value.begin(), this->data(place).value.end());
}
else
{
size_t i = 0;
for (; i < std::min(static_cast<size_t>(win_size), size); ++i)
{
data_to.push_back(value[i]);
}
for (; i < size; ++i)
{
data_to.push_back(value[i] - value[i - win_size]);
}
}
}
}
bool allocatesMemoryInArena() const override
{
return true;
}
const char * getHeaderFilePath() const override { return __FILE__; }
};
#undef AGGREGATE_FUNCTION_GROUP_ARRAY_MAX_ARRAY_SIZE
}

View File

@ -108,6 +108,15 @@ static IAggregateFunction * createWithDecimalType(const IDataType & argument_typ
return nullptr;
}
template <template <typename, typename> class AggregateFunctionTemplate, typename Data, typename... TArgs>
static IAggregateFunction * createWithDecimalType(const IDataType & argument_type, TArgs && ... args)
{
WhichDataType which(argument_type);
if (which.idx == TypeIndex::Decimal32) return new AggregateFunctionTemplate<Decimal32, Data>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal64) return new AggregateFunctionTemplate<Decimal64, Data>(std::forward<TArgs>(args)...);
if (which.idx == TypeIndex::Decimal128) return new AggregateFunctionTemplate<Decimal128, Data>(std::forward<TArgs>(args)...);
return nullptr;
}
/** For template with two arguments.
*/

View File

@ -31,6 +31,7 @@ void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionMovingSum(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -74,6 +75,7 @@ void registerAggregateFunctions()
registerAggregateFunctionMLMethod(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
registerAggregateFunctionMovingSum(factory);
}
{