From 8ad1a55f3bfe1015f41fb76aecc98df221d08b08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Ercolanelli?= Date: Tue, 22 Jan 2019 17:47:43 +0100 Subject: [PATCH] implement sumMapFiltered --- .../AggregateFunctionSumMap.cpp | 35 ++++++++++-- .../AggregateFunctionSumMap.h | 54 +++++++++++++++++-- .../0_stateless/00502_sum_map.reference | 2 + .../queries/0_stateless/00502_sum_map.sql | 3 ++ .../agg_functions/parametric_functions.md | 5 ++ 5 files changed, 92 insertions(+), 7 deletions(-) diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp index 571d6f5c0a1..5138d8f1f02 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -12,10 +12,10 @@ namespace DB namespace { -AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) -{ - assertNoParameters(name, params); +using SumMapArgs = std::pair; +SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments) +{ if (arguments.size() < 2) throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -25,6 +25,7 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con throw Exception("First argument for function " + name + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + const DataTypePtr & keys_type = array_type->getNestedType(); DataTypes values_types; @@ -37,6 +38,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con values_types.push_back(array_type->getNestedType()); } + return {keys_type, std::move(values_types)}; +} + +AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) +{ + assertNoParameters(name, params); + + auto [keys_type, values_types] = parseArguments(name, arguments); + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types)); @@ -46,11 +56,30 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con return res; } +AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params) +{ + if (params.size() != 1) + throw Exception("Aggregate function " + name + "requires exactly one parameter of Array type.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + Array keys_to_keep = params.front().safeGet(); + + auto [keys_type, values_types] = parseArguments(name, arguments); + + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, keys_to_keep)); + if (!res) + res.reset(createWithDecimalType(*keys_type, keys_type, values_types, keys_to_keep)); + if (!res) + throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; +} } void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) { factory.registerFunction("sumMap", createAggregateFunctionSumMap); + factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index 4a20a314789..e9c70eaa5f1 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) */ -template -class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper< - AggregateFunctionSumMapData>, AggregateFunctionSumMap> +template +class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< + AggregateFunctionSumMapData>, Derived> { private: using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; @@ -61,7 +61,7 @@ private: DataTypes values_types; public: - AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types) + AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types) : keys_type(keys_type), values_types(values_types) {} String getName() const override { return "sumMap"; } @@ -109,6 +109,11 @@ public: array_column.getData().get(values_vec_offset + i, value); const auto & key = keys_vec.getData()[keys_vec_offset + i]; + if (!keepKey(key)) + { + continue; + } + IteratorType it; if constexpr (IsDecimalNumber) { @@ -253,6 +258,47 @@ public: } const char * getHeaderFilePath() const override { return __FILE__; } + + virtual bool keepKey(const T & key) const = 0; +}; + +template +class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase> +{ +public: + AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types) + : AggregateFunctionSumMapBase>{keys_type, values_types} + {} + + String getName() const override { return "sumMap"; } + + bool keepKey(const T &) const override { return true; } +}; + +template +class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase> +{ +private: + std::vector keys_to_keep; + +public: + AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep) + : AggregateFunctionSumMapBase>{keys_type, values_types} + { + this->keys_to_keep.reserve(keys_to_keep.size()); + for (const Field & f : keys_to_keep) + { + this->keys_to_keep.emplace_back(f.safeGet>()); + } + std::sort(begin(this->keys_to_keep), end(this->keys_to_keep)); + } + + String getName() const override { return "sumMapFiltered"; } + + bool keepKey(const T & key) const override + { + return std::binary_search(begin(keys_to_keep), end(keys_to_keep), key); + } }; } diff --git a/dbms/tests/queries/0_stateless/00502_sum_map.reference b/dbms/tests/queries/0_stateless/00502_sum_map.reference index 6da96805974..ac5678ebeab 100644 --- a/dbms/tests/queries/0_stateless/00502_sum_map.reference +++ b/dbms/tests/queries/0_stateless/00502_sum_map.reference @@ -8,6 +8,8 @@ 2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10]) 2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10] 2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10] +([1],[10]) +([1, 4, 8], [10, 20, 10]) ([1],[1]) ([1],[1]) (['a'],[1]) diff --git a/dbms/tests/queries/0_stateless/00502_sum_map.sql b/dbms/tests/queries/0_stateless/00502_sum_map.sql index e6377155dac..9cf941dd908 100644 --- a/dbms/tests/queries/0_stateless/00502_sum_map.sql +++ b/dbms/tests/queries/0_stateless/00502_sum_map.sql @@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; +SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map; +SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map; + DROP TABLE test.sum_map; select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt ); diff --git a/docs/en/query_language/agg_functions/parametric_functions.md b/docs/en/query_language/agg_functions/parametric_functions.md index 1cbe784e621..15b9c3360fa 100644 --- a/docs/en/query_language/agg_functions/parametric_functions.md +++ b/docs/en/query_language/agg_functions/parametric_functions.md @@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >= ``` [Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) + + +## sumMapFiltered(keys_to_keep)(keys, values) + +Same behavior as [sumMap](reference.md#sumMap) except that an array of keys is passed as a parameter. This can be especially useful when working with a high cardinality of keys.