diff --git a/docs/en/sql_reference/aggregate_functions/reference.md b/docs/en/sql_reference/aggregate_functions/reference.md index d7bc8e963e2..f20e32db14e 100644 --- a/docs/en/sql_reference/aggregate_functions/reference.md +++ b/docs/en/sql_reference/aggregate_functions/reference.md @@ -330,9 +330,10 @@ Computes the sum of the numbers, using the same data type for the result as for Only works for numbers. -## sumMap(key, value) {#agg_functions-summap} +## sumMap(key, value), sumMap(Tuple(key, value)) {#agg_functions-summap} Totals the ‘value’ array according to the keys specified in the ‘key’ array. +Passing tuple of keys and values arrays is synonymical to passing two arrays of keys and values. The number of elements in ‘key’ and ‘value’ must be the same for each row that is totaled. Returns a tuple of two arrays: keys in sorted order, and values ​​summed for the corresponding keys. @@ -345,25 +346,28 @@ CREATE TABLE sum_map( statusMap Nested( status UInt16, requests UInt64 - ) + ), + statusMapTuple Tuple(Array(Int32), Array(Int32)) ) ENGINE = Log; INSERT INTO sum_map VALUES - ('2000-01-01', '2000-01-01 00:00:00', [1, 2, 3], [10, 10, 10]), - ('2000-01-01', '2000-01-01 00:00:00', [3, 4, 5], [10, 10, 10]), - ('2000-01-01', '2000-01-01 00:01:00', [4, 5, 6], [10, 10, 10]), - ('2000-01-01', '2000-01-01 00:01:00', [6, 7, 8], [10, 10, 10]); + ('2000-01-01', '2000-01-01 00:00:00', [1, 2, 3], [10, 10, 10], ([1, 2, 3], [10, 10, 10])), + ('2000-01-01', '2000-01-01 00:00:00', [3, 4, 5], [10, 10, 10], ([3, 4, 5], [10, 10, 10])), + ('2000-01-01', '2000-01-01 00:01:00', [4, 5, 6], [10, 10, 10], ([4, 5, 6], [10, 10, 10])), + ('2000-01-01', '2000-01-01 00:01:00', [6, 7, 8], [10, 10, 10], ([6, 7, 8], [10, 10, 10])); + SELECT timeslot, - sumMap(statusMap.status, statusMap.requests) + sumMap(statusMap.status, statusMap.requests), + sumMap(statusMapTuple) FROM sum_map GROUP BY timeslot ``` ``` text -┌────────────timeslot─┬─sumMap(statusMap.status, statusMap.requests)─┐ -│ 2000-01-01 00:00:00 │ ([1,2,3,4,5],[10,10,20,10,10]) │ -│ 2000-01-01 00:01:00 │ ([4,5,6,7,8],[10,10,20,10,10]) │ -└─────────────────────┴──────────────────────────────────────────────┘ +┌────────────timeslot─┬─sumMap(statusMap.status, statusMap.requests)─┬─sumMap(statusMapTuple)─────────┐ +│ 2000-01-01 00:00:00 │ ([1,2,3,4,5],[10,10,20,10,10]) │ ([1,2,3,4,5],[10,10,20,10,10]) │ +│ 2000-01-01 00:01:00 │ ([4,5,6,7,8],[10,10,20,10,10]) │ ([4,5,6,7,8],[10,10,20,10,10]) │ +└─────────────────────┴──────────────────────────────────────────────┴────────────────────────────────┘ ``` ## skewPop {#skewpop} diff --git a/src/AggregateFunctions/AggregateFunctionSumMap.cpp b/src/AggregateFunctions/AggregateFunctionSumMap.cpp index 5bedf72c39b..6191b26a855 100644 --- a/src/AggregateFunctions/AggregateFunctionSumMap.cpp +++ b/src/AggregateFunctions/AggregateFunctionSumMap.cpp @@ -52,23 +52,37 @@ using SumMapArgs = std::pair; SumMapArgs parseArguments(const std::string & name, const DataTypes & arguments) { - if (arguments.size() < 2) - throw Exception("Aggregate function " + name + " requires at least two arguments of Array type.", + DataTypes args; + + if (arguments.size() == 1) + { + const auto * tuple_type = checkAndGetDataType(arguments[0].get()); + if (!tuple_type) + throw Exception("When function " + name + " gets one argument it must be a tuple", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + const auto elems = tuple_type->getElements(); + args.insert(args.end(), elems.begin(), elems.end()); + } + else + args.insert(args.end(), arguments.begin(), arguments.end()); + + if (args.size() < 2) + throw Exception("Aggregate function " + name + " requires at least two arguments of Array type or one argument of tuple of two arrays", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - const auto * array_type = checkAndGetDataType(arguments[0].get()); + const auto * array_type = checkAndGetDataType(args[0].get()); if (!array_type) - throw Exception("First argument for function " + name + " must be an array.", + throw Exception("First argument for function " + name + " must be an array, not " + args[0]->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - DataTypePtr keys_type = array_type->getNestedType(); DataTypes values_types; - values_types.reserve(arguments.size() - 1); - for (size_t i = 1; i < arguments.size(); ++i) + values_types.reserve(args.size() - 1); + for (size_t i = 1; i < args.size(); ++i) { - array_type = checkAndGetDataType(arguments[i].get()); + array_type = checkAndGetDataType(args[i].get()); if (!array_type) throw Exception("Argument #" + toString(i) + " for function " + name + " must be an array.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); diff --git a/src/AggregateFunctions/AggregateFunctionSumMap.h b/src/AggregateFunctions/AggregateFunctionSumMap.h index 88f99b73841..88ee10f4627 100644 --- a/src/AggregateFunctions/AggregateFunctionSumMap.h +++ b/src/AggregateFunctions/AggregateFunctionSumMap.h @@ -80,6 +80,18 @@ public: void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override { + // Check if tuple + std::unique_ptr tuple_columns; + auto tuple_col = checkAndGetColumn(columns[0]); + if (tuple_col) + { + tuple_columns.reset(new const IColumn*[tuple_col->tupleSize()]); + for (size_t i = 0; i < tuple_col->tupleSize(); i++) + tuple_columns.get()[i] = &const_cast(tuple_col->getColumn(i)); + + columns = tuple_columns.get(); + } + // Column 0 contains array of keys of known type Field key_field; const ColumnArray & array_column0 = assert_cast(*columns[0]); diff --git a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp index 3072e4a40c1..bf22845a5f6 100644 --- a/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp +++ b/src/DataTypes/DataTypeCustomSimpleAggregateFunction.cpp @@ -30,7 +30,7 @@ namespace ErrorCodes extern const int LOGICAL_ERROR; } -static const std::vector supported_functions{"any", "anyLast", "min", "max", "sum", "groupBitAnd", "groupBitOr", "groupBitXor"}; +static const std::vector supported_functions{"any", "anyLast", "min", "max", "sum", "groupBitAnd", "groupBitOr", "groupBitXor", "sumMap"}; String DataTypeCustomSimpleAggregateFunction::getName() const diff --git a/tests/queries/0_stateless/00502_sum_map.reference b/tests/queries/0_stateless/00502_sum_map.reference index 6cddf662424..0002c43945a 100644 --- a/tests/queries/0_stateless/00502_sum_map.reference +++ b/tests/queries/0_stateless/00502_sum_map.reference @@ -4,6 +4,7 @@ 2000-01-01 2000-01-01 00:01:00 [6,7,8] [10,10,10] ([1,2,3,4,5,6,7,8],[10,10,20,20,20,20,10,10]) ([1,2,3,4,5,6,7,8],[10,10,20,20,20,20,10,10]) +([1,2,3,4,5,6,7,8],[10,10,20,20,20,20,10,10]) 2000-01-01 00:00:00 ([1,2,3,4,5],[10,10,20,10,10]) 2000-01-01 00:01:00 ([4,5,6,7,8],[10,10,20,10,10]) 2000-01-01 00:00:00 [1,2,3,4,5] [10,10,20,10,10] diff --git a/tests/queries/0_stateless/00502_sum_map.sql b/tests/queries/0_stateless/00502_sum_map.sql index dba8bb5549f..6a4035a3782 100644 --- a/tests/queries/0_stateless/00502_sum_map.sql +++ b/tests/queries/0_stateless/00502_sum_map.sql @@ -7,6 +7,7 @@ INSERT INTO sum_map VALUES ('2000-01-01', '2000-01-01 00:00:00', [1, 2, 3], [10, SELECT * FROM sum_map ORDER BY timeslot; SELECT sumMap(statusMap.status, statusMap.requests) FROM sum_map; +SELECT sumMap((statusMap.status, statusMap.requests)) FROM sum_map; SELECT sumMapMerge(s) FROM (SELECT sumMapState(statusMap.status, statusMap.requests) AS s FROM sum_map); SELECT timeslot, sumMap(statusMap.status, statusMap.requests) FROM sum_map GROUP BY timeslot ORDER BY timeslot; SELECT timeslot, sumMap(statusMap.status, statusMap.requests).1, sumMap(statusMap.status, statusMap.requests).2 FROM sum_map GROUP BY timeslot ORDER BY timeslot; diff --git a/tests/queries/0_stateless/00915_simple_aggregate_function.reference b/tests/queries/0_stateless/00915_simple_aggregate_function.reference index 6fb5e7b3744..d9e0a92cb01 100644 --- a/tests/queries/0_stateless/00915_simple_aggregate_function.reference +++ b/tests/queries/0_stateless/00915_simple_aggregate_function.reference @@ -39,6 +39,6 @@ SimpleAggregateFunction(sum, Float64) 7 14 8 16 9 18 -1 1 2 2.2.2.2 3 -10 2222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222 20 20.20.20.20 5 -SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4) SimpleAggregateFunction(groupBitOr, UInt32) +1 1 2 2.2.2.2 3 ([1,2,3],[2,1,1]) +10 2222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222 20 20.20.20.20 5 ([2,3,4],[2,1,1]) +SimpleAggregateFunction(anyLast, Nullable(String)) SimpleAggregateFunction(anyLast, LowCardinality(Nullable(String))) SimpleAggregateFunction(anyLast, IPv4) SimpleAggregateFunction(groupBitOr, UInt32) SimpleAggregateFunction(sumMap, Tuple(Array(Int32), Array(Int64))) diff --git a/tests/queries/0_stateless/00915_simple_aggregate_function.sql b/tests/queries/0_stateless/00915_simple_aggregate_function.sql index 030893e3ea1..1866e2bc8c5 100644 --- a/tests/queries/0_stateless/00915_simple_aggregate_function.sql +++ b/tests/queries/0_stateless/00915_simple_aggregate_function.sql @@ -24,16 +24,17 @@ create table simple ( nullable_str SimpleAggregateFunction(anyLast,Nullable(String)), low_str SimpleAggregateFunction(anyLast,LowCardinality(Nullable(String))), ip SimpleAggregateFunction(anyLast,IPv4), - status SimpleAggregateFunction(groupBitOr, UInt32) + status SimpleAggregateFunction(groupBitOr, UInt32), + tup SimpleAggregateFunction(sumMap, Tuple(Array(Int32), Array(Int64))) ) engine=AggregatingMergeTree order by id; -insert into simple values(1,'1','1','1.1.1.1', 1); -insert into simple values(1,null,'2','2.2.2.2', 2); +insert into simple values(1,'1','1','1.1.1.1', 1, ([1,2], [1,1])); +insert into simple values(1,null,'2','2.2.2.2', 2, ([1,3], [1,1])); -- String longer then MAX_SMALL_STRING_SIZE (actual string length is 100) -insert into simple values(10,'10','10','10.10.10.10', 4); -insert into simple values(10,'2222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222','20','20.20.20.20', 1); +insert into simple values(10,'10','10','10.10.10.10', 4, ([2,3], [1,1])); +insert into simple values(10,'2222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222','20','20.20.20.20', 1, ([2, 4], [1,1])); select * from simple final; -select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip),toTypeName(status) from simple limit 1; +select toTypeName(nullable_str),toTypeName(low_str),toTypeName(ip),toTypeName(status), toTypeName(tup) from simple limit 1; optimize table simple final;