diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp index 571d6f5c0a1..02303b953d9 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -12,10 +12,10 @@ namespace DB namespace { -AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) -{ - assertNoParameters(name, params); +using SumMapArgs = std::pair; +SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments) +{ if (arguments.size() < 2) throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -25,9 +25,11 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con throw Exception("First argument for function " + name + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - const DataTypePtr & keys_type = array_type->getNestedType(); + + DataTypePtr keys_type = array_type->getNestedType(); DataTypes values_types; + values_types.reserve(arguments.size() - 1); for (size_t i = 1; i < arguments.size(); ++i) { array_type = checkAndGetDataType(arguments[i].get()); @@ -37,6 +39,15 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con values_types.push_back(array_type->getNestedType()); } + return {std::move(keys_type), std::move(values_types)}; +} + +AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) +{ + assertNoParameters(name, params); + + auto [keys_type, values_types] = parseArguments(name, arguments); + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types)); if (!res) res.reset(createWithDecimalType(*keys_type, keys_type, values_types)); @@ -46,11 +57,33 @@ AggregateFunctionPtr 
createAggregateFunctionSumMap(const std::string & name, con return res; } +AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params) +{ + if (params.size() != 1) + throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + Array keys_to_keep; + if (!params.front().tryGet(keys_to_keep)) + throw Exception("Aggregate function " + name + " requires an Array as parameter.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + auto [keys_type, values_types] = parseArguments(name, arguments); + + AggregateFunctionPtr res(createWithNumericBasedType(*keys_type, keys_type, values_types, keys_to_keep)); + if (!res) + res.reset(createWithDecimalType(*keys_type, keys_type, values_types, keys_to_keep)); + if (!res) + throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return res; +} } void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) { factory.registerFunction("sumMap", createAggregateFunctionSumMap); + factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered); } } diff --git a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h index 4a20a314789..1e5f3e38cd2 100644 --- a/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/dbms/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -50,9 +50,9 @@ struct AggregateFunctionSumMapData * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) */ -template -class AggregateFunctionSumMap final : public IAggregateFunctionDataHelper< - AggregateFunctionSumMapData>, AggregateFunctionSumMap> +template +class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< + AggregateFunctionSumMapData>, Derived> { private: using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; @@ -61,7 +61,7 
@@ private: DataTypes values_types; public: - AggregateFunctionSumMap(const DataTypePtr & keys_type, const DataTypes & values_types) + AggregateFunctionSumMapBase(const DataTypePtr & keys_type, const DataTypes & values_types) : keys_type(keys_type), values_types(values_types) {} String getName() const override { return "sumMap"; } @@ -109,6 +109,11 @@ public: array_column.getData().get(values_vec_offset + i, value); const auto & key = keys_vec.getData()[keys_vec_offset + i]; + if (!keepKey(key)) + { + continue; + } + IteratorType it; if constexpr (IsDecimalNumber) { @@ -253,6 +258,43 @@ public: } const char * getHeaderFilePath() const override { return __FILE__; } + + bool keepKey(const T & key) const { return static_cast(*this).keepKey(key); } +}; + +template +class AggregateFunctionSumMap final : public AggregateFunctionSumMapBase> +{ +public: + AggregateFunctionSumMap(const DataTypePtr & keys_type, DataTypes & values_types) + : AggregateFunctionSumMapBase>{keys_type, values_types} + {} + + String getName() const override { return "sumMap"; } + + bool keepKey(const T &) const { return true; } +}; + +template +class AggregateFunctionSumMapFiltered final : public AggregateFunctionSumMapBase> +{ +private: + std::unordered_set keys_to_keep; + +public: + AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type, const DataTypes & values_types, const Array & keys_to_keep_) + : AggregateFunctionSumMapBase>{keys_type, values_types} + { + keys_to_keep.reserve(keys_to_keep_.size()); + for (const Field & f : keys_to_keep_) + { + keys_to_keep.emplace(f.safeGet>()); + } + } + + String getName() const override { return "sumMapFiltered"; } + + bool keepKey(const T & key) const { return keys_to_keep.count(key); } }; } diff --git a/dbms/src/Core/Types.h b/dbms/src/Core/Types.h index 5e2cd47f440..e4882cd64f7 100644 --- a/dbms/src/Core/Types.h +++ b/dbms/src/Core/Types.h @@ -166,3 +166,20 @@ template <> constexpr bool IsDecimalNumber = true; template <> constexpr bool 
IsDecimalNumber = true; } + +/// Specialization of `std::hash` for the Decimal types. +namespace std +{ + template + struct hash> { size_t operator()(const DB::Decimal & x) const { return hash()(x.value); } }; + + template <> + struct hash + { + size_t operator()(const DB::Decimal128 & x) const + { + return std::hash()(x.value >> 64) + ^ std::hash()(x.value & std::numeric_limits::max()); + } + }; +} diff --git a/dbms/tests/queries/0_stateless/00502_sum_map.reference b/dbms/tests/queries/0_stateless/00502_sum_map.reference index 6da96805974..7bb325be814 100644 --- a/dbms/tests/queries/0_stateless/00502_sum_map.reference +++ b/dbms/tests/queries/0_stateless/00502_sum_map.reference @@ -8,6 +8,8 @@ 2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10]) 2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10] 2000-01-01 00:01:00 [4,5,6,7,8] [10,10,20,10,10] +([1],[10]) +([1,4,8],[10,20,10]) ([1],[1]) ([1],[1]) (['a'],[1]) diff --git a/dbms/tests/queries/0_stateless/00502_sum_map.sql b/dbms/tests/queries/0_stateless/00502_sum_map.sql index e6377155dac..9cf941dd908 100644 --- a/dbms/tests/queries/0_stateless/00502_sum_map.sql +++ b/dbms/tests/queries/0_stateless/00502_sum_map.sql @@ -12,6 +12,9 @@ SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.reque SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM test.sum_map GROUP BY timeslot ORDER BY timeslot; +SELECT sumMapFiltered([1])(statusMap.status, statusMap.requests) FROM test.sum_map; +SELECT sumMapFiltered([1, 4, 8])(statusMap.status, statusMap.requests) FROM test.sum_map; + DROP TABLE test.sum_map; select sumMap(val, cnt) from ( SELECT [ CAST(1, 'UInt64') ] as val, [1] as cnt ); diff --git a/docs/en/query_language/agg_functions/parametric_functions.md b/docs/en/query_language/agg_functions/parametric_functions.md index 
1cbe784e621..15b9c3360fa 100644 --- a/docs/en/query_language/agg_functions/parametric_functions.md +++ b/docs/en/query_language/agg_functions/parametric_functions.md @@ -155,3 +155,8 @@ Solution: Write in the GROUP BY query SearchPhrase HAVING uniqUpTo(4)(UserID) >= ``` [Original article](https://clickhouse.yandex/docs/en/query_language/agg_functions/parametric_functions/) + + +## sumMapFiltered(keys_to_keep)(keys, values) + +Same behavior as [sumMap](reference.md#sumMap), except that an array of keys to keep is passed as a parameter: only the keys listed in `keys_to_keep` are summed and returned, and all other keys are ignored. For example, `sumMapFiltered([1, 4, 8])(keys, values)` returns sums only for the keys 1, 4 and 8. This can be especially useful when working with a high cardinality of keys.