From 939a15d29a0e2b2d4739d0b7fb32e2bfdc4f0560 Mon Sep 17 00:00:00 2001 From: vdimir Date: Thu, 17 Feb 2022 15:40:26 +0000 Subject: [PATCH] Upd FunctionArrayMapped for Map --- src/Functions/array/FunctionArrayMapped.h | 19 +++++++++- src/Functions/mapFilter.cpp | 38 +++++++++---------- .../registerFunctionsHigherOrder.cpp | 2 - .../0_stateless/02169_map_functions.reference | 5 ++- .../0_stateless/02169_map_functions.sql | 19 +++++++--- 5 files changed, 53 insertions(+), 30 deletions(-) diff --git a/src/Functions/array/FunctionArrayMapped.h b/src/Functions/array/FunctionArrayMapped.h index 32fccd89244..28540354b94 100644 --- a/src/Functions/array/FunctionArrayMapped.h +++ b/src/Functions/array/FunctionArrayMapped.h @@ -180,11 +180,28 @@ public: throw Exception("Expression for function " + getName() + " must return UInt8, found " + return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + static_assert( + std::is_same_v || + std::is_same_v, + "unsupported type"); + + if (arguments.size() < 2) + { + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "{}", arguments.size()); + } + const auto * first_array_type = checkAndGetDataType(arguments[1].type.get()); + + if (!first_array_type) + throw DB::Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1].type->getName()); + if constexpr (std::is_same_v) return Impl::getReturnType(return_type, first_array_type->getNestedType()); - else + + if constexpr (std::is_same_v) return Impl::getReturnType(return_type, first_array_type->getKeyValueTypes()); + + throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached"); } } diff --git a/src/Functions/mapFilter.cpp b/src/Functions/mapFilter.cpp index 2308b22d3db..78a7934b2ba 100644 --- a/src/Functions/mapFilter.cpp +++ b/src/Functions/mapFilter.cpp @@ -13,6 +13,7 @@ namespace DB namespace ErrorCodes { extern const int ILLEGAL_COLUMN; + extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; } @@ -29,6 +30,8 @@ struct MapFilterImpl using data_type = DataTypeMap; using column_type = ColumnMap; + static constexpr auto name = "mapFilter"; + static bool needBoolean() { return true; } static bool needExpression() { return true; } static bool needOneArray() { return false; } @@ -85,14 +88,6 @@ struct MapFilterImpl } }; -struct NameMapFilter { static constexpr auto name = "mapFilter"; }; -using FunctionMapFilter = FunctionArrayMapped; - -void registerFunctionMapFilter(FunctionFactory & factory) -{ - factory.registerFunction(); -} - /** mapApply((k,v) -> expression, map) - apply the expression to the map. */ @@ -101,6 +96,8 @@ struct MapApplyImpl using data_type = DataTypeMap; using column_type = ColumnMap; + static constexpr auto name = "mapApply"; + /// true if the expression (for an overload of f(expression, maps)) or a map (for f(map)) should be boolean. static bool needBoolean() { return false; } static bool needExpression() { return true; } @@ -108,12 +105,15 @@ struct MapApplyImpl static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypes & /*elems*/) { - const auto & tuple_types = typeid_cast(&*expression_return)->getElements(); - if (tuple_types.size() != 2) - throw Exception("Expected 2 columns as map's key and value, but found " - + toString(tuple_types.size()) + " columns", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + const auto * tuple_types = typeid_cast(expression_return.get()); + if (!tuple_types) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Expected return type is tuple, got {}", expression_return->getName()); + if (tuple_types->getElements().size() != 2) + throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + "Expected 2 columns as map's key and value, but found {}", tuple_types->getElements().size()); - return std::make_shared(tuple_types); + return std::make_shared(tuple_types->getElements()); } static ColumnPtr execute(const ColumnMap & map, ColumnPtr mapped) @@ -123,9 +123,9 @@ struct MapApplyImpl { const ColumnConst * column_const_tuple = checkAndGetColumnConst(mapped.get()); if (!column_const_tuple) - throw Exception("Expected tuple column, found " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN); - ColumnPtr column_tuple_ptr = recursiveRemoveLowCardinality(column_const_tuple->convertToFullColumn()); - column_tuple = checkAndGetColumn(column_tuple_ptr.get()); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Expected tuple column, found {}", mapped->getName()); + auto cols = convertConstTupleToConstantElements(*column_const_tuple); + return ColumnMap::create(cols[0]->convertToFullColumnIfConst(), cols[1]->convertToFullColumnIfConst(), map.getNestedColumn().getOffsetsPtr()); } return ColumnMap::create(column_tuple->getColumnPtr(0), column_tuple->getColumnPtr(1), @@ -133,12 +133,10 @@ struct MapApplyImpl } }; -struct NameMapApply { static constexpr auto name = "mapApply"; }; -using FunctionMapApply = FunctionArrayMapped; - void registerFunctionMapApply(FunctionFactory & factory) { - factory.registerFunction(); + factory.registerFunction>(); + factory.registerFunction>(); } } diff --git a/src/Functions/registerFunctionsHigherOrder.cpp b/src/Functions/registerFunctionsHigherOrder.cpp index 00b300b18b9..00bea58b918 100644 --- a/src/Functions/registerFunctionsHigherOrder.cpp +++ b/src/Functions/registerFunctionsHigherOrder.cpp @@ -18,7 +18,6 @@ void registerFunctionsArraySort(FunctionFactory & factory); void registerFunctionArrayCumSum(FunctionFactory & factory); void registerFunctionArrayCumSumNonNegative(FunctionFactory & factory); void registerFunctionArrayDifference(FunctionFactory & factory); -void registerFunctionMapFilter(FunctionFactory & factory); void registerFunctionMapApply(FunctionFactory & factory); void registerFunctionsHigherOrder(FunctionFactory & factory) @@ -38,7 +37,6 @@ void registerFunctionsHigherOrder(FunctionFactory & factory) registerFunctionArrayCumSum(factory); registerFunctionArrayCumSumNonNegative(factory); registerFunctionArrayDifference(factory); - registerFunctionMapFilter(factory); registerFunctionMapApply(factory); } diff --git a/tests/queries/0_stateless/02169_map_functions.reference b/tests/queries/0_stateless/02169_map_functions.reference index c570ba4e724..160aebbc852 100644 --- a/tests/queries/0_stateless/02169_map_functions.reference +++ b/tests/queries/0_stateless/02169_map_functions.reference @@ -19,7 +19,6 @@ {'key1':1112,'key2':2223} {'key1':1113,'key2':2225} {'key1':1114,'key2':2227} -{1:2,2:3} {} {} {} @@ -28,3 +27,7 @@ {} {} {3:2,1:0,2:0} +{1:2,2:3} +{1:2,2:3} +{'x':'y','x':'y'} +{'x':'y','x':'y'} diff --git a/tests/queries/0_stateless/02169_map_functions.sql b/tests/queries/0_stateless/02169_map_functions.sql index 1d8f90e8a90..ee2e70f82cd 100644 --- a/tests/queries/0_stateless/02169_map_functions.sql +++ b/tests/queries/0_stateless/02169_map_functions.sql @@ -1,13 +1,20 @@ DROP TABLE IF EXISTS table_map; -create TABLE table_map (id UInt32, col Map(String, UInt64)) engine = MergeTree() ORDER BY tuple(); +CREATE TABLE table_map (id UInt32, col Map(String, UInt64)) engine = MergeTree() ORDER BY tuple(); INSERT INTO table_map SELECT number, map('key1', number, 'key2', number * 2) FROM numbers(1111, 3); INSERT INTO table_map SELECT number, map('key3', number, 'key2', number + 1, 'key4', number + 2) FROM numbers(100, 4); -SELECT mapFilter((k,v)->k like '%3' and v > 102, col) FROM table_map ORDER BY id; +SELECT mapFilter((k, v) -> k like '%3' and v > 102, col) FROM table_map ORDER BY id; SELECT col, mapFilter((k, v) -> ((v % 10) > 1), col) FROM table_map ORDER BY id ASC; -SELECT mapApply((k,v)->(k,v+1), col) FROM table_map ORDER BY id; -SELECT mapApply((x, y) -> (x, x + 1), map(1, 0, 2, 0)); -SELECT mapFilter((k,v)->0, col) from table_map; +SELECT mapApply((k, v) -> (k, v + 1), col) FROM table_map ORDER BY id; +SELECT mapFilter((k, v) -> 0, col) from table_map; +SELECT mapApply((k, v) -> tuple(v + 9223372036854775806), col) FROM table_map; -- { serverError 42 } + SELECT mapUpdate(map(1, 3, 3, 2), map(1, 0, 2, 0)); -SELECT mapApply((k, v) -> tuple(v + 9223372036854775806), col) FROM table_map; -- { serverError 42 } +SELECT mapApply((x, y) -> (x, x + 1), map(1, 0, 2, 0)); +SELECT mapApply((x, y) -> (x, x + 1), materialize(map(1, 0, 2, 0))); +SELECT mapApply((x, y) -> ('x', 'y'), map(1, 0, 2, 0)); +SELECT mapApply((x, y) -> ('x', 'y'), materialize(map(1, 0, 2, 0))); +SELECT mapApply((x, y) -> (x), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } +SELECT mapApply((x, y) -> ('x'), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } + DROP TABLE table_map;