From 5a5cb238d8b29b8de37baceb290547a2790231fc Mon Sep 17 00:00:00 2001 From: Ildus Kurbangaliev Date: Wed, 26 May 2021 16:47:58 +0200 Subject: [PATCH] Map combinator: add more arguments checks, fix memory align errors, support more key types, add some docs and tests --- .../aggregate-functions/combinators.md | 6 +++ .../AggregateFunctionMap.cpp | 40 +++++++++++-------- src/AggregateFunctions/AggregateFunctionMap.h | 16 ++++---- .../01852_map_combinator.reference | 4 ++ .../0_stateless/01852_map_combinator.sql | 15 +++++++ 5 files changed, 56 insertions(+), 25 deletions(-) diff --git a/docs/en/sql-reference/aggregate-functions/combinators.md b/docs/en/sql-reference/aggregate-functions/combinators.md index 3fc5121ebcc..44615628eef 100644 --- a/docs/en/sql-reference/aggregate-functions/combinators.md +++ b/docs/en/sql-reference/aggregate-functions/combinators.md @@ -25,6 +25,12 @@ Example 2: `uniqArray(arr)` – Counts the number of unique elements in all ‘a -If and -Array can be combined. However, ‘Array’ must come first, then ‘If’. Examples: `uniqArrayIf(arr, cond)`, `quantilesTimingArrayIf(level1, level2)(arr, cond)`. Due to this order, the ‘cond’ argument won’t be an array. +## -Map {#agg-functions-combinator-map} + +The -Map suffix can be appended to any aggregate function. This will create an aggregate function which gets Map type as an argument, and aggregates values of each key of the map separately using the specified aggregate function. The result is also of a Map type. + +Examples: `sumMap(map(1,1))`, `avgMap(map('a', 1))`. + ## -SimpleState {#agg-functions-combinator-simplestate} If you apply this combinator, the aggregate function returns the same value but with a different type. This is a [SimpleAggregateFunction(...)](../../sql-reference/data-types/simpleaggregatefunction.md) that can be stored in a table to work with [AggregatingMergeTree](../../engines/table-engines/mergetree-family/aggregatingmergetree.md) tables. diff --git a/src/AggregateFunctions/AggregateFunctionMap.cpp b/src/AggregateFunctions/AggregateFunctionMap.cpp index 31505b89fe2..09214427ad6 100644 --- a/src/AggregateFunctions/AggregateFunctionMap.cpp +++ b/src/AggregateFunctions/AggregateFunctionMap.cpp @@ -25,19 +25,14 @@ public: const auto * map_type = checkAndGetDataType(arguments[0].get()); if (map_type) { - if (arguments->size() > 1) - throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - getName() + " combinator takes only one map argument"); + if (arguments.size() > 1) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, getName() + " combinator takes only one map argument"); return DataTypes({map_type->getValueType()}); } // we need this part just to pass to redirection for mapped arrays - auto check_func = [](DataTypePtr t) - { - return t->getTypeId() == TypeIndex::Array; - }; + auto check_func = [](DataTypePtr t) { return t->getTypeId() == TypeIndex::Array; }; const auto * tup_type = checkAndGetDataType(arguments[0].get()); if (tup_type) @@ -46,8 +41,8 @@ public: bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func); if (arrays_match) { - const auto & val_array_type = assert_cast(types[1]); - return DataTypes({val_array_type.getNestedType()}); + const auto * val_array_type = assert_cast(types[1].get()); + return DataTypes({val_array_type->getNestedType()}); } } else @@ -55,7 +50,7 @@ public: bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func); if (arrays_match) { - const auto & val_array_type = assert_cast(arguments[1]); + const auto * val_array_type = assert_cast(arguments[1].get()); return DataTypes({val_array_type->getNestedType()}); } } @@ -72,9 +67,9 @@ public: const auto * map_type = checkAndGetDataType(arguments[0].get()); if (map_type) { - auto key_type_id = map_type->getKeyType()->getTypeId(); + const auto & key_type = map_type->getKeyType(); - switch (key_type_id) + switch (key_type->getTypeId()) { case TypeIndex::Enum8: case TypeIndex::Int8: @@ -86,6 +81,10 @@ public: return std::make_shared>(nested_function, arguments); case TypeIndex::Int64: return std::make_shared>(nested_function, arguments); + case TypeIndex::Int128: + return std::make_shared>(nested_function, arguments); + case TypeIndex::Int256: + return std::make_shared>(nested_function, arguments); case TypeIndex::UInt8: return std::make_shared>(nested_function, arguments); case TypeIndex::Date: @@ -96,13 +95,19 @@ public: return std::make_shared>(nested_function, arguments); case TypeIndex::UInt64: return std::make_shared>(nested_function, arguments); - case TypeIndex::UUID: + case TypeIndex::UInt128: return std::make_shared>(nested_function, arguments); + case TypeIndex::UInt256: + return std::make_shared>(nested_function, arguments); + case TypeIndex::UUID: + return std::make_shared>(nested_function, arguments); case TypeIndex::FixedString: case TypeIndex::String: return std::make_shared>(nested_function, arguments); default: - throw Exception{"Illegal columns in arguments for combinator " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; + throw Exception{ + "Map key type " + key_type->getName() + " is not is not supported by combinator " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; } } else @@ -115,9 +120,10 @@ public: auto & aggr_func_factory = AggregateFunctionFactory::instance(); return aggr_func_factory.get(nested_func_name + "MappedArrays", arguments, params, out_properties); } + else + throw Exception{ + "Aggregation '" + nested_func_name + "Map' is not implemented for mapped arrays", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; } - - throw Exception{"Illegal columns in arguments for combinator " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; } }; diff --git a/src/AggregateFunctions/AggregateFunctionMap.h b/src/AggregateFunctions/AggregateFunctionMap.h index 1ae836a13d8..75bb2e75840 100644 --- a/src/AggregateFunctions/AggregateFunctionMap.h +++ b/src/AggregateFunctions/AggregateFunctionMap.h @@ -58,7 +58,11 @@ public: { if (types.empty()) throw Exception( - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " require at least one argument"); + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires at least one argument"); + + if (types.size() > 1) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Aggregate function " + getName() + " requires only one map argument"); const auto * map_type = checkAndGetDataType(types[0].get()); if (!map_type) @@ -103,7 +107,7 @@ public: if (it == merged_maps.end()) { // create a new place for each key - nested_place = arena->alloc(nested_func->sizeOfData()); + nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData()); nested_func->create(nested_place); merged_maps.emplace(key, nested_place); } @@ -157,7 +161,7 @@ public: AggregateDataPtr nested_place; this->data(place).readKey(key, buf); - nested_place = arena->alloc(nested_func->sizeOfData()); + nested_place = arena->alignedAlloc(nested_func->sizeOfData(), nested_func->alignOfData()); nested_func->create(nested_place); merged_maps.emplace(key, nested_place); nested_func->deserialize(nested_place, buf, arena); @@ -175,8 +179,6 @@ public: auto & merged_maps = this->data(place).merged_maps; - size_t res_offset = 0; - // sort the keys std::vector keys; keys.reserve(merged_maps.size()); @@ -189,14 +191,12 @@ public: // insert using sorted keys to result column for (auto & key : keys) { - res_offset++; key_column.insert(key); nested_func->insertResultInto(merged_maps[key], val_column, arena); } IColumn::Offsets & res_offsets = nested_column.getOffsets(); - auto last_offset = res_offsets[res_offsets.size() - 1]; - res_offsets.push_back(last_offset + res_offset); + res_offsets.push_back(val_column.size()); } bool allocatesMemoryInArena() const override { return true; } diff --git a/tests/queries/0_stateless/01852_map_combinator.reference b/tests/queries/0_stateless/01852_map_combinator.reference index 59a2d22933b..7c0648ccb65 100644 --- a/tests/queries/0_stateless/01852_map_combinator.reference +++ b/tests/queries/0_stateless/01852_map_combinator.reference @@ -26,5 +26,9 @@ Map(UInt16,Float64) {1:10,2:10,3:10,4:10,5:10,6:10,7:10,8:10} {'1970-01-01 03:00:01':1} {'a':1} {'1':'2'} +{1:1} +{1:1} +{1:1} +{1:1} {1:1.00000,2:2.00000,3:6.00000,4:8.00000,5:10.00000,6:12.00000,7:7.00000,8:8.00000} {1:1.00000,2:2.00000,3:6.00000,4:8.00000,5:10.00000,6:12.00000,7:7.00000,8:8.00000} diff --git a/tests/queries/0_stateless/01852_map_combinator.sql b/tests/queries/0_stateless/01852_map_combinator.sql index 26911b983ae..20923460eb6 100644 --- a/tests/queries/0_stateless/01852_map_combinator.sql +++ b/tests/queries/0_stateless/01852_map_combinator.sql @@ -29,6 +29,21 @@ select minMap(val) from values ('val Map(Date, Int16)', (map(1, 1)), (map(1, 2) select minMap(val) from values ('val Map(DateTime(\'Europe/Moscow\'), Int32)', (map(1, 1)), (map(1, 2))); select minMap(val) from values ('val Map(Enum16(\'a\'=1), Int16)', (map('a', 1)), (map('a', 2))); select maxMap(val) from values ('val Map(String, String)', (map('1', '1')), (map('1', '2'))); +select minMap(val) from values ('val Map(Int128, Int128)', (map(1, 1)), (map(1, 2))); +select minMap(val) from values ('val Map(Int256, Int256)', (map(1, 1)), (map(1, 2))); +select minMap(val) from values ('val Map(UInt128, UInt128)', (map(1, 1)), (map(1, 2))); +select minMap(val) from values ('val Map(UInt256, UInt256)', (map(1, 1)), (map(1, 2))); + +select sumMap(map(1,2), 1, 2); -- { serverError 42 } +select sumMap(map(1,2), map(1,3)); -- { serverError 42 } + +-- array and tuple arguments +select avgMap([1,1,1], [2,2,2]); -- { serverError 43 } +select minMap((1,1)); -- { serverError 43 } +select minMap(([1,1,1],1)); -- { serverError 43 } +select minMap([1,1,1],1); -- { serverError 43 } +select minMap([1,1,1]); -- { serverError 43 } +select minMap(([1,1,1])); -- { serverError 43 } DROP TABLE IF EXISTS sum_map_decimal;