From 57db1fac5990a7227e720c9dd438d88a381d298f Mon Sep 17 00:00:00 2001
From: hcz
Date: Wed, 12 Jun 2019 15:46:36 +0800
Subject: [PATCH] Add aggregate function combinator Resample

---
Review notes (not part of the commit message):
- Line breaks, the leading whitespace of context lines and the text inside
  angle brackets (#include targets, template arguments, cast targets) were
  missing from the copy of this patch under review. They have been restored
  from how the names are used and from the hunk line counts - please confirm
  them against the original commit before applying.
- The commit id above and the blob ids on the "index" lines describe the
  original patch, not the revised contents below.
- Changes relative to the original patch:
  * the combinator reads (begin, end, step) from the LAST three parameters,
    matching transformParameters(), and type-checks only those three;
  * keys outside [begin, end) are no longer put into a bucket;
  * a zero step or an empty/inverted range yields zero buckets instead of a
    division by zero or a wrapped-around bucket count;
  * comment typo fixed in IAggregateFunctionCombinator.h.

 .../AggregateFunctionFactory.cpp              |   8 +-
 .../AggregateFunctionResample.cpp             | 125 ++++++++++
 .../AggregateFunctionResample.h               | 228 ++++++++++++++++++
 .../IAggregateFunctionCombinator.h            |  14 +-
 .../registerAggregateFunctions.cpp            |   2 +
 5 files changed, 374 insertions(+), 3 deletions(-)
 create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionResample.cpp
 create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionResample.h

diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
index 6aeaaef2bfa..ce7adf5b96d 100644
--- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
+++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
@@ -77,6 +77,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
             throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR);
 
         DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
+        Array nested_parameters = combinator->transformParameters(parameters);
 
         AggregateFunctionPtr nested_function;
 
@@ -84,7 +85,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
         /// Combinator will check if nested_function was created.
         if (name == "count" || std::none_of(argument_types.begin(), argument_types.end(), [](const auto & type) { return type->onlyNull(); }))
-            nested_function = getImpl(name, nested_types, parameters, recursion_level);
+            nested_function = getImpl(name, nested_types, nested_parameters, recursion_level);
 
         return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
     }
 
@@ -126,7 +127,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
 
         String nested_name = name.substr(0, name.size() - combinator->getName().size());
         DataTypes nested_types = combinator->transformArguments(argument_types);
-        AggregateFunctionPtr nested_function = get(nested_name, nested_types, parameters, recursion_level + 1);
+        Array nested_parameters = combinator->transformParameters(parameters);
+
+        AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1);
+
         return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
     }
 
diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp
new file mode 100644
index 00000000000..4ef51e5ee08
--- /dev/null
+++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp
@@ -0,0 +1,125 @@
+#include <AggregateFunctions/AggregateFunctionResample.h>
+
+#include <AggregateFunctions/AggregateFunctionCombinatorFactory.h>
+
+
+namespace DB
+{
+
+namespace ErrorCodes
+{
+    extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
+}
+
+/** Combinator with the "Resample" suffix.
+  * The last argument is the key column and the last three parameters are
+  * (begin, end, step); everything before them is passed to the nested function.
+  */
+class AggregateFunctionCombinatorResample final : public
+    IAggregateFunctionCombinator
+{
+public:
+    String getName() const override {
+        return "Resample";
+    }
+
+    /// Drops the trailing key argument; throws if there are no arguments at all.
+    DataTypes transformArguments(const DataTypes & arguments) const override
+    {
+        if (arguments.empty())
+            throw Exception {
+                "Incorrect number of arguments for aggregate function with "
+                    + getName() + " suffix",
+                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
+            };
+
+        return DataTypes(arguments.begin(), arguments.end() - 1);
+    }
+
+    /// Drops the trailing (begin, end, step); throws if fewer than three parameters are given.
+    Array transformParameters(const Array & params) const override
+    {
+        if (params.size() < 3)
+            throw Exception {
+                "Incorrect number of parameters for aggregate function with "
+                    + getName() + " suffix",
+                ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH
+            };
+
+        return Array(params.begin(), params.end() - 3);
+    }
+
+    /** Wraps nested_function into AggregateFunctionResample.
+      * Returns nullptr when begin/end/step are not integers or the key type
+      * is not supported.
+      * The factory calls transformArguments() and transformParameters() first,
+      * so `arguments` is non-empty and `params` has at least three elements here.
+      */
+    AggregateFunctionPtr transformAggregateFunction(
+        const AggregateFunctionPtr & nested_function,
+        const DataTypes & arguments,
+        const Array & params
+    ) const override
+    {
+        /// Only the trailing three parameters are ours. Leading ones belong to
+        /// the nested function and may be of any type, so they are not checked
+        /// (and not read) here - this must mirror transformParameters().
+        const Field & begin = params[params.size() - 3];
+        const Field & end = params[params.size() - 2];
+        const Field & step = params[params.size() - 1];
+
+        for (const Field * param : {&begin, &end, &step})
+        {
+            if (
+                param->getType() != Field::Types::UInt64
+                && param->getType() != Field::Types::Int64
+            )
+                return nullptr;
+        }
+
+        WhichDataType which {
+            arguments.back()
+        };
+
+        if (
+            which.isNativeUInt()
+            || which.isDateOrDateTime()
+        )
+            return std::make_shared<AggregateFunctionResample<UInt64>>(
+                nested_function,
+                begin.get<UInt64>(),
+                end.get<UInt64>(),
+                step.get<UInt64>(),
+                arguments,
+                params
+            );
+
+        if (
+            which.isNativeInt()
+            || which.isEnum()
+            || which.isInterval()
+        )
+            return std::make_shared<AggregateFunctionResample<Int64>>(
+                nested_function,
+                begin.get<Int64>(),
+                end.get<Int64>(),
+                step.get<Int64>(),
+                arguments,
+                params
+            );
+
+        // TODO
+        return nullptr;
+    }
+};
+
+void registerAggregateFunctionCombinatorResample(
+    AggregateFunctionCombinatorFactory & factory
+)
+{
+    factory.registerCombinator(
+        std::make_shared<AggregateFunctionCombinatorResample>()
+    );
+}
+
+}
diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.h b/dbms/src/AggregateFunctions/AggregateFunctionResample.h
new file mode 100644
index 00000000000..40ab8b24a1c
--- /dev/null
+++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.h
@@ -0,0 +1,228 @@
+#pragma once
+
+#include <AggregateFunctions/IAggregateFunction.h>
+#include <Columns/ColumnArray.h>
+#include <DataTypes/DataTypeArray.h>
+
+
+namespace DB
+{
+
+/** Splits the key range [begin, end) into consecutive buckets of width `step`
+  * (mirrored when step is negative) and keeps one state of the nested
+  * function per bucket. The key is read from the last argument column;
+  * the result is an array with one nested result per bucket.
+  */
+template <typename Key>
+class AggregateFunctionResample final : public IAggregateFunctionHelper<
+    AggregateFunctionResample<Key>
+>
+{
+private:
+    AggregateFunctionPtr nested_function;
+
+    /// Index of the key column (the last argument).
+    size_t last_col;
+
+    Key begin;
+    Key end;
+    Key step;
+
+    /// Number of buckets.
+    size_t total;
+    /// Alignment of one nested state.
+    size_t aod;
+    /// Size of one nested state, rounded up to a multiple of `aod`.
+    size_t sod;
+
+    /// True if `lhs` comes strictly before `rhs` when walking in the direction of `step`.
+    static bool isBefore(Key lhs, Key rhs, Key step)
+    {
+        return step > 0 ? lhs < rhs : lhs > rhs;
+    }
+
+    /// |step| as an unsigned value.
+    static size_t stepSize(Key step)
+    {
+        return step > 0
+            ? static_cast<size_t>(step)
+            : static_cast<size_t>(0) - static_cast<size_t>(step);
+    }
+
+    /** Distance from `from` to `to` along the direction of `step`.
+      * Requires that `to` is not before `from`. The subtraction is done in
+      * unsigned arithmetic, so it cannot overflow even for a signed Key.
+      */
+    static size_t forwardDistance(Key from, Key to, Key step)
+    {
+        return step > 0
+            ? static_cast<size_t>(to) - static_cast<size_t>(from)
+            : static_cast<size_t>(from) - static_cast<size_t>(to);
+    }
+
+    /** Number of buckets, i.e. ceil(forwardDistance(begin, end) / |step|).
+      * A zero step or an empty/inverted range gives no buckets instead of
+      * dividing by zero or wrapping around to a huge count.
+      */
+    static size_t countBuckets(Key begin, Key end, Key step)
+    {
+        if (step == 0 || !isBefore(begin, end, step))
+            return 0;
+
+        return (forwardDistance(begin, end, step) - 1) / stepSize(step) + 1;
+    }
+
+public:
+    AggregateFunctionResample(
+        AggregateFunctionPtr nested_function,
+        Key begin,
+        Key end,
+        Key step,
+        const DataTypes & arguments,
+        const Array & params
+    ) :
+        IAggregateFunctionHelper<
+            AggregateFunctionResample<Key>
+        > {arguments, params},
+        nested_function {nested_function},
+        last_col {arguments.size() - 1},
+        begin {begin},
+        end {end},
+        step {step},
+        total {countBuckets(begin, end, step)},
+        aod {nested_function->alignOfData()},
+        sod {(nested_function->sizeOfData() + aod - 1) / aod * aod}
+    {
+        // notice: argument types has been checked before
+    }
+
+    String getName() const override
+    {
+        return nested_function->getName() + "Resample";
+    }
+
+    const char * getHeaderFilePath() const override
+    {
+        return __FILE__;
+    }
+
+    bool isState() const override
+    {
+        return nested_function->isState();
+    }
+
+    bool allocatesMemoryInArena() const override
+    {
+        return nested_function->allocatesMemoryInArena();
+    }
+
+    bool hasTrivialDestructor() const override
+    {
+        return nested_function->hasTrivialDestructor();
+    }
+
+    /// All bucket states are laid out back to back, `sod` bytes apart.
+    size_t sizeOfData() const override
+    {
+        return total * sod;
+    }
+
+    size_t alignOfData() const override
+    {
+        return aod;
+    }
+
+    void create(AggregateDataPtr place) const override
+    {
+        for (size_t i = 0; i < total; ++i)
+            nested_function->create(place + i * sod);
+    }
+
+    void destroy(AggregateDataPtr place) const noexcept override
+    {
+        for (size_t i = 0; i < total; ++i)
+            nested_function->destroy(place + i * sod);
+    }
+
+    /// Adds the row to the bucket its key falls into; rows with an out-of-range key are ignored.
+    void add(
+        AggregateDataPtr place,
+        const IColumn ** columns,
+        size_t row_num,
+        Arena * arena
+    ) const override
+    {
+        Key key;
+
+        if constexpr (static_cast<Key>(-1) < 0)
+            key = columns[last_col]->getInt(row_num);
+        else
+            key = columns[last_col]->getUInt(row_num);
+
+        /// Skip keys outside [begin, end). This has to be decided on the key
+        /// itself: (key - begin) / step truncates towards zero (and wraps for
+        /// unsigned keys), which would file some out-of-range keys into a bucket.
+        /// total == 0 also covers step == 0, so the division below is safe.
+        if (total == 0 || isBefore(key, begin, step) || !isBefore(key, end, step))
+            return;
+
+        size_t pos = forwardDistance(begin, key, step) / stepSize(step);
+
+        nested_function->add(place + pos * sod, columns, row_num, arena);
+    }
+
+    void merge(
+        AggregateDataPtr place,
+        ConstAggregateDataPtr rhs,
+        Arena * arena
+    ) const override
+    {
+        for (size_t i = 0; i < total; ++i)
+            nested_function->merge(place + i * sod, rhs + i * sod, arena);
+    }
+
+    void serialize(
+        ConstAggregateDataPtr place,
+        WriteBuffer & buf
+    ) const override
+    {
+        for (size_t i = 0; i < total; ++i)
+            nested_function->serialize(place + i * sod, buf);
+    }
+
+    void deserialize(
+        AggregateDataPtr place,
+        ReadBuffer & buf,
+        Arena * arena
+    ) const override
+    {
+        for (size_t i = 0; i < total; ++i)
+            nested_function->deserialize(place + i * sod, buf, arena);
+    }
+
+    DataTypePtr getReturnType() const override
+    {
+        return std::make_shared<DataTypeArray>(
+            nested_function->getReturnType()
+        );
+    }
+
+    /// Appends one array row holding the nested result of every bucket, in bucket order.
+    void insertResultInto(
+        ConstAggregateDataPtr place,
+        IColumn & to
+    ) const override
+    {
+        auto & col = static_cast<ColumnArray &>(to);
+        auto & col_offsets = static_cast<ColumnArray::ColumnOffsets &>(
+            col.getOffsetsColumn()
+        );
+
+        for (size_t i = 0; i < total; ++i)
+            nested_function->insertResultInto(place + i * sod, col.getData());
+
+        col_offsets.getData().push_back(col.getData().size());
+    }
+};
+
+}
diff --git a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h
index 0ac9a3d41cd..03e2766dc2c 100644
--- a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h
+++ b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h
@@ -38,7 +38,19 @@ public:
       * get the arguments for nested function (ex: UInt64 for sum).
       * If arguments are not suitable for combined function, throw an exception.
       */
-    virtual DataTypes transformArguments(const DataTypes & arguments) const = 0;
+    virtual DataTypes transformArguments(const DataTypes & arguments) const
+    {
+        return arguments;
+    }
+
+    /** From the parameters for combined function,
+      * get the parameters for nested function.
+      * If parameters are not suitable for combined function, throw an exception.
+      */
+    virtual Array transformParameters(const Array & parameters) const
+    {
+        return parameters;
+    }
 
     /** Create combined aggregate function (ex: sumIf)
       * from nested function (ex: sum)
diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
index 45507b678fb..d94f2ff1f8b 100644
--- a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
+++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp
@@ -38,6 +38,7 @@ void registerAggregateFunctionCombinatorForEach(AggregateFunctionCombinatorFacto
 void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory &);
 void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory &);
 void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &);
+void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &);
 
 void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
 void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
@@ -85,6 +86,7 @@ void registerAggregateFunctions()
         registerAggregateFunctionCombinatorState(factory);
         registerAggregateFunctionCombinatorMerge(factory);
         registerAggregateFunctionCombinatorNull(factory);
+        registerAggregateFunctionCombinatorResample(factory);
     }
 }
 