From 57db1fac5990a7227e720c9dd438d88a381d298f Mon Sep 17 00:00:00 2001 From: hcz Date: Wed, 12 Jun 2019 15:46:36 +0800 Subject: [PATCH 1/6] Add aggregate function combinator Resample --- .../AggregateFunctionFactory.cpp | 8 +- .../AggregateFunctionResample.cpp | 106 ++++++++++ .../AggregateFunctionResample.h | 182 ++++++++++++++++++ .../IAggregateFunctionCombinator.h | 14 +- .../registerAggregateFunctions.cpp | 2 + 5 files changed, 309 insertions(+), 3 deletions(-) create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionResample.cpp create mode 100644 dbms/src/AggregateFunctions/AggregateFunctionResample.h diff --git a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp index 6aeaaef2bfa..ce7adf5b96d 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp @@ -77,6 +77,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get( throw Exception("Logical error: cannot find aggregate function combinator to apply a function to Nullable arguments.", ErrorCodes::LOGICAL_ERROR); DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); + Array nested_parameters = combinator->transformParameters(parameters); AggregateFunctionPtr nested_function; @@ -84,7 +85,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get( /// Combinator will check if nested_function was created. if (name == "count" || std::none_of(argument_types.begin(), argument_types.end(), [](const auto & type) { return type->onlyNull(); })) - nested_function = getImpl(name, nested_types, parameters, recursion_level); + nested_function = getImpl(name, nested_types, nested_parameters, recursion_level); return combinator->transformAggregateFunction(nested_function, argument_types, parameters); } @@ -126,7 +127,10 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl( String nested_name = name.substr(0, name.size() - combinator->getName().size()); DataTypes nested_types = combinator->transformArguments(argument_types); - AggregateFunctionPtr nested_function = get(nested_name, nested_types, parameters, recursion_level + 1); + Array nested_parameters = combinator->transformParameters(parameters); + + AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); + return combinator->transformAggregateFunction(nested_function, argument_types, parameters); } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp new file mode 100644 index 00000000000..4ef51e5ee08 --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp @@ -0,0 +1,106 @@ +#include + +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} + +class AggregateFunctionCombinatorResample final : public + IAggregateFunctionCombinator +{ +public: + String getName() const override { + return "Resample"; + } + + DataTypes transformArguments(const DataTypes & arguments) const override + { + if (arguments.empty()) + throw Exception { + "Incorrect number of arguments for aggregate function with " + + getName() + " suffix", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH + }; + + return DataTypes(arguments.begin(), arguments.end() - 1); + } + + Array transformParameters(const Array & params) const override + { + if (params.size() < 3) + throw Exception { + "Incorrect number of parameters for aggregate function with " + + getName() + " suffix", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH + }; + + return Array(params.begin(), params.end() - 3); + } + + AggregateFunctionPtr transformAggregateFunction( + const AggregateFunctionPtr & nested_function, + const DataTypes & arguments, + const Array & params + ) const override + { + for (const Field & param : params) + { + if ( + param.getType() != Field::Types::UInt64 + && param.getType() != Field::Types::Int64 + ) + return nullptr; + } + + WhichDataType which { + arguments.back() + }; + + if ( + which.isNativeUInt() + || which.isDateOrDateTime() + ) + return std::make_shared>( + nested_function, + params.front().get(), + params[1].get(), + params.back().get(), + arguments, + params + ); + + if ( + which.isNativeInt() + || which.isEnum() + || which.isInterval() + ) + return std::make_shared>( + nested_function, + params.front().get(), + params[1].get(), + params.back().get(), + arguments, + params + ); + + // TODO + return nullptr; + } +}; + +void registerAggregateFunctionCombinatorResample( + AggregateFunctionCombinatorFactory & factory +) +{ + factory.registerCombinator( + std::make_shared() + ); +} + +} diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.h b/dbms/src/AggregateFunctions/AggregateFunctionResample.h new file mode 100644 index 00000000000..40ab8b24a1c --- /dev/null +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.h @@ -0,0 +1,182 @@ +#pragma once + +#include +#include +#include + + +namespace DB +{ + +template +class AggregateFunctionResample final : public IAggregateFunctionHelper< + AggregateFunctionResample +> +{ +private: + AggregateFunctionPtr nested_function; + + size_t last_col; + + Key begin; + Key end; + Key step; + + size_t total; + size_t aod; + size_t sod; + +public: + AggregateFunctionResample( + AggregateFunctionPtr nested_function, + Key begin, + Key end, + Key step, + const DataTypes & arguments, + const Array & params + ) : + IAggregateFunctionHelper< + AggregateFunctionResample + > {arguments, params}, + nested_function {nested_function}, + last_col {arguments.size() - 1}, + begin {begin}, + end {end}, + step {step}, + total { + static_cast( + (end - begin + step - (step >= 0 ? 1 : -1)) / step + ) + }, + aod {nested_function->alignOfData()}, + sod {(nested_function->sizeOfData() + aod - 1) / aod * aod} + { + // notice: argument types has been checked before + } + + String getName() const override + { + return nested_function->getName() + "Resample"; + } + + const char * getHeaderFilePath() const override + { + return __FILE__; + } + + bool isState() const override + { + return nested_function->isState(); + } + + bool allocatesMemoryInArena() const override + { + return nested_function->allocatesMemoryInArena(); + } + + bool hasTrivialDestructor() const override + { + return nested_function->hasTrivialDestructor(); + } + + size_t sizeOfData() const override + { + return total * sod; + } + + size_t alignOfData() const override + { + return aod; + } + + void create(AggregateDataPtr place) const override + { + for (size_t i = 0; i < total; ++i) + nested_function->create(place + i * sod); + } + + void destroy(AggregateDataPtr place) const noexcept override + { + for (size_t i = 0; i < total; ++i) + nested_function->destroy(place + i * sod); + } + + void add( + AggregateDataPtr place, + const IColumn ** columns, + size_t row_num, + Arena * arena + ) const override + { + // Key key { + // static_cast *>( + // columns[last_col] + // )->getData()[row_num] + // }; + Key key; + + if constexpr (static_cast(-1) < 0) + key = columns[last_col]->getInt(row_num); + else + key = columns[last_col]->getUInt(row_num); + + size_t pos = (key - begin) / step; + + if (pos >= 0 && pos < total) + nested_function->add(place + pos * sod, columns, row_num, arena); + } + + void merge( + AggregateDataPtr place, + ConstAggregateDataPtr rhs, + Arena * arena + ) const override + { + for (size_t i = 0; i < total; ++i) + nested_function->merge(place + i * sod, rhs + i * sod, arena); + } + + void serialize( + ConstAggregateDataPtr place, + WriteBuffer & buf + ) const override + { + for (size_t i = 0; i < total; ++i) + nested_function->serialize(place + i * sod, buf); + } + + void deserialize( + AggregateDataPtr place, + ReadBuffer & buf, + Arena * arena + ) const override + { + for (size_t i = 0; i < total; ++i) + nested_function->deserialize(place + i * sod, buf, arena); + } + + DataTypePtr getReturnType() const override + { + return std::make_shared( + nested_function->getReturnType() + ); + } + + void insertResultInto( + ConstAggregateDataPtr place, + IColumn & to + ) const override + { + auto & col = static_cast(to); + auto & col_offsets = static_cast( + col.getOffsetsColumn() + ); + + for (size_t i = 0; i < total; ++i) + nested_function->insertResultInto(place + i * sod, col.getData()); + + col_offsets.getData().push_back(col.getData().size()); + } +}; + +} diff --git a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h index 0ac9a3d41cd..03e2766dc2c 100644 --- a/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h +++ b/dbms/src/AggregateFunctions/IAggregateFunctionCombinator.h @@ -38,7 +38,19 @@ public: * get the arguments for nested function (ex: UInt64 for sum). * If arguments are not suitable for combined function, throw an exception. */ - virtual DataTypes transformArguments(const DataTypes & arguments) const = 0; + virtual DataTypes transformArguments(const DataTypes & arguments) const + { + return arguments; + } + + /** From the parameters for combined function, + * get the parameters for nested function. + * If arguments are not suitable for combined function, throw an exception. + */ + virtual Array transformParameters(const Array & parameters) const + { + return parameters; + } /** Create combined aggregate function (ex: sumIf) * from nested function (ex: sum) diff --git a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp index 45507b678fb..d94f2ff1f8b 100644 --- a/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp +++ b/dbms/src/AggregateFunctions/registerAggregateFunctions.cpp @@ -38,6 +38,7 @@ void registerAggregateFunctionCombinatorForEach(AggregateFunctionCombinatorFacto void registerAggregateFunctionCombinatorState(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorMerge(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory &); +void registerAggregateFunctionCombinatorResample(AggregateFunctionCombinatorFactory &); void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory); void registerAggregateFunctionRetention(AggregateFunctionFactory & factory); @@ -85,6 +86,7 @@ void registerAggregateFunctions() registerAggregateFunctionCombinatorState(factory); registerAggregateFunctionCombinatorMerge(factory); registerAggregateFunctionCombinatorNull(factory); + registerAggregateFunctionCombinatorResample(factory); } } From 347d3828034ee820fc148b51e56bd49193df0374 Mon Sep 17 00:00:00 2001 From: hcz Date: Wed, 12 Jun 2019 15:47:52 +0800 Subject: [PATCH 2/6] Add tests --- .../00954_resample_combinator.reference | 16 ++++++++++++++++ .../0_stateless/00954_resample_combinator.sql | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 dbms/tests/queries/0_stateless/00954_resample_combinator.reference create mode 100644 dbms/tests/queries/0_stateless/00954_resample_combinator.sql diff --git a/dbms/tests/queries/0_stateless/00954_resample_combinator.reference b/dbms/tests/queries/0_stateless/00954_resample_combinator.reference new file mode 100644 index 00000000000..5238c774aae --- /dev/null +++ b/dbms/tests/queries/0_stateless/00954_resample_combinator.reference @@ -0,0 +1,16 @@ +[11,12,13,14,15,16] +[27,31,35] +[39,48,37] +[19,35,31,27,23,10] +[0,0,0,0,0,0] +[0.5,0.5,0.5] +[0.816496580927726,0.816496580927726,0.5] +[0,0.5,0.5,0.5,0.5,0] +[[11],[12],[13],[14],[15],[16]] +[[13,14],[15,16],[17,18]] +[[12,13,14],[15,16,17],[18,19]] +[[19],[17,18],[15,16],[13,14],[11,12],[10]] +[1,1,1,1,1,1] +[2,2,2] +[3,3,2] +[1,2,2,2,2,1] diff --git a/dbms/tests/queries/0_stateless/00954_resample_combinator.sql b/dbms/tests/queries/0_stateless/00954_resample_combinator.sql new file mode 100644 index 00000000000..21215567b1c --- /dev/null +++ b/dbms/tests/queries/0_stateless/00954_resample_combinator.sql @@ -0,0 +1,16 @@ +select arrayReduce('sumResample(1, 7, 1)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('sumResample(3, 8, 2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('sumResample(2, 9, 3)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('sumResample(10, -1, -2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [-0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('stddevPopResample(1, 7, 1)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('stddevPopResample(3, 8, 2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('stddevPopResample(2, 9, 3)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('stddevPopResample(10, -1, -2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [-0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('groupArrayResample(1, 7, 1)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('groupArrayResample(3, 8, 2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('groupArrayResample(2, 9, 3)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('groupArrayResample(10, -1, -2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [-0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('uniqResample(1, 7, 1)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('uniqResample(3, 8, 2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('uniqResample(2, 9, 3)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); +select arrayReduce('uniqResample(10, -1, -2)', [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], [-0, 1, 2, 3, 4, 5, 6, 7, 8, 9]); From 6b6f2293fc81f6765f65930a2decf87a0c6b49e8 Mon Sep 17 00:00:00 2001 From: hcz Date: Wed, 12 Jun 2019 15:48:40 +0800 Subject: [PATCH 3/6] Change style --- dbms/src/AggregateFunctions/AggregateFunctionResample.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp index 4ef51e5ee08..ecb805fbf52 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp @@ -22,11 +22,11 @@ public: DataTypes transformArguments(const DataTypes & arguments) const override { if (arguments.empty()) - throw Exception { + throw Exception( "Incorrect number of arguments for aggregate function with " + getName() + " suffix", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH - }; + ); return DataTypes(arguments.begin(), arguments.end() - 1); } @@ -34,11 +34,11 @@ public: Array transformParameters(const Array & params) const override { if (params.size() < 3) - throw Exception { + throw Exception( "Incorrect number of parameters for aggregate function with " + getName() + " suffix", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH - }; + ); return Array(params.begin(), params.end() - 3); } From 45529d8489daf299cebe9c57d7aadc282e005c89 Mon Sep 17 00:00:00 2001 From: hcz Date: Wed, 12 Jun 2019 16:25:29 +0800 Subject: [PATCH 4/6] Fix bugs --- .../AggregateFunctionResample.h | 40 +++++++++++++++---- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.h b/dbms/src/AggregateFunctions/AggregateFunctionResample.h index 40ab8b24a1c..d2073d470f5 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionResample.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.h @@ -8,12 +8,19 @@ namespace DB { +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} + template class AggregateFunctionResample final : public IAggregateFunctionHelper< AggregateFunctionResample > { private: + const size_t MAX_ELEMENTS = 4096; + AggregateFunctionPtr nested_function; size_t last_col; @@ -52,6 +59,25 @@ public: sod {(nested_function->sizeOfData() + aod - 1) / aod * aod} { // notice: argument types has been checked before + if (step == 0) + throw Exception { + "The step given in function " + + getName() + " should not be zero", + ErrorCodes::BAD_ARGUMENTS + }; + + if (total > MAX_ELEMENTS) + throw Exception { + "The range given in function " + + getName() + " contains too many elements", + ErrorCodes::BAD_ARGUMENTS + }; + + if ((step > 0 && end < begin) || (step < 0 && end > begin)) + { + end = begin; + total = 0; + } } String getName() const override @@ -108,11 +134,6 @@ public: Arena * arena ) const override { - // Key key { - // static_cast *>( - // columns[last_col] - // )->getData()[row_num] - // }; Key key; if constexpr (static_cast(-1) < 0) @@ -120,10 +141,15 @@ public: else key = columns[last_col]->getUInt(row_num); + if (step > 0 && (key < begin || key >= end)) + return; + + if (step < 0 && (key > begin || key <= end)) + return; + size_t pos = (key - begin) / step; - if (pos >= 0 && pos < total) - nested_function->add(place + pos * sod, columns, row_num, arena); + nested_function->add(place + pos * sod, columns, row_num, arena); } void merge( From 712aefca2a75e5a0711ac2a80808ea08c069a924 Mon Sep 17 00:00:00 2001 From: hcz Date: Thu, 13 Jun 2019 14:30:59 +0800 Subject: [PATCH 5/6] Fix wrong test file --- .../00954_resample_combinator.reference | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dbms/tests/queries/0_stateless/00954_resample_combinator.reference b/dbms/tests/queries/0_stateless/00954_resample_combinator.reference index 5238c774aae..abcbff8435e 100644 --- a/dbms/tests/queries/0_stateless/00954_resample_combinator.reference +++ b/dbms/tests/queries/0_stateless/00954_resample_combinator.reference @@ -1,16 +1,16 @@ [11,12,13,14,15,16] -[27,31,35] -[39,48,37] +[27,31,17] +[39,48,18] [19,35,31,27,23,10] [0,0,0,0,0,0] -[0.5,0.5,0.5] -[0.816496580927726,0.816496580927726,0.5] +[0.5,0.5,0] +[0.816496580927726,0.816496580927726,0] [0,0.5,0.5,0.5,0.5,0] [[11],[12],[13],[14],[15],[16]] -[[13,14],[15,16],[17,18]] -[[12,13,14],[15,16,17],[18,19]] +[[13,14],[15,16],[17]] +[[12,13,14],[15,16,17],[18]] [[19],[17,18],[15,16],[13,14],[11,12],[10]] [1,1,1,1,1,1] -[2,2,2] -[3,3,2] +[2,2,1] +[3,3,1] [1,2,2,2,2,1] From 0385e0923a8f0e1b1f0465bebaa5a0d0ce37467a Mon Sep 17 00:00:00 2001 From: hcz Date: Fri, 14 Jun 2019 21:20:21 +0800 Subject: [PATCH 6/6] Fix style --- dbms/src/AggregateFunctions/AggregateFunctionResample.cpp | 3 ++- dbms/src/AggregateFunctions/AggregateFunctionResample.h | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp index ecb805fbf52..c60724a35fa 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.cpp @@ -15,7 +15,8 @@ class AggregateFunctionCombinatorResample final : public IAggregateFunctionCombinator { public: - String getName() const override { + String getName() const override + { return "Resample"; } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionResample.h b/dbms/src/AggregateFunctions/AggregateFunctionResample.h index d2073d470f5..a7bf98ffdc0 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionResample.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionResample.h @@ -60,18 +60,18 @@ public: { // notice: argument types has been checked before if (step == 0) - throw Exception { + throw Exception( "The step given in function " + getName() + " should not be zero", ErrorCodes::BAD_ARGUMENTS - }; + ); if (total > MAX_ELEMENTS) - throw Exception { + throw Exception( "The range given in function " + getName() + " contains too many elements", ErrorCodes::BAD_ARGUMENTS - }; + ); if ((step > 0 && end < begin) || (step < 0 && end > begin)) {