Upd FunctionArrayMapped for Map

This commit is contained in:
vdimir 2022-02-17 15:40:26 +00:00
parent 82a76d47ff
commit 939a15d29a
No known key found for this signature in database
GPG Key ID: 6EE4CE2BEDC51862
5 changed files with 53 additions and 30 deletions

View File

@ -180,11 +180,28 @@ public:
throw Exception("Expression for function " + getName() + " must return UInt8, found "
+ return_type->getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
static_assert(
std::is_same_v<typename Impl::data_type, DataTypeMap> ||
std::is_same_v<typename Impl::data_type, DataTypeArray>,
"unsupported type");
if (arguments.size() < 2)
{
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "{}", arguments.size());
}
const auto * first_array_type = checkAndGetDataType<typename Impl::data_type>(arguments[1].type.get());
if (!first_array_type)
throw DB::Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Unsupported type {}", arguments[1].type->getName());
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeArray>)
return Impl::getReturnType(return_type, first_array_type->getNestedType());
else
if constexpr (std::is_same_v<typename Impl::data_type, DataTypeMap>)
return Impl::getReturnType(return_type, first_array_type->getKeyValueTypes());
throw DB::Exception(ErrorCodes::LOGICAL_ERROR, "Unreachable code reached");
}
}

View File

@ -13,6 +13,7 @@ namespace DB
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
}
@ -29,6 +30,8 @@ struct MapFilterImpl
using data_type = DataTypeMap;
using column_type = ColumnMap;
static constexpr auto name = "mapFilter";
static bool needBoolean() { return true; }
static bool needExpression() { return true; }
static bool needOneArray() { return false; }
@ -85,14 +88,6 @@ struct MapFilterImpl
}
};
struct NameMapFilter { static constexpr auto name = "mapFilter"; };
using FunctionMapFilter = FunctionArrayMapped<MapFilterImpl, NameMapFilter>;
void registerFunctionMapFilter(FunctionFactory & factory)
{
factory.registerFunction<FunctionMapFilter>();
}
/** mapApply((k,v) -> expression, map) - apply the expression to the map.
*/
@ -101,6 +96,8 @@ struct MapApplyImpl
using data_type = DataTypeMap;
using column_type = ColumnMap;
static constexpr auto name = "mapApply";
/// true if the expression (for an overload of f(expression, maps)) or a map (for f(map)) should be boolean.
static bool needBoolean() { return false; }
static bool needExpression() { return true; }
@ -108,12 +105,15 @@ struct MapApplyImpl
static DataTypePtr getReturnType(const DataTypePtr & expression_return, const DataTypes & /*elems*/)
{
const auto & tuple_types = typeid_cast<const DataTypeTuple *>(&*expression_return)->getElements();
if (tuple_types.size() != 2)
throw Exception("Expected 2 columns as map's key and value, but found "
+ toString(tuple_types.size()) + " columns", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
const auto * tuple_types = typeid_cast<const DataTypeTuple *>(expression_return.get());
if (!tuple_types)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expected return type is tuple, got {}", expression_return->getName());
if (tuple_types->getElements().size() != 2)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Expected 2 columns as map's key and value, but found {}", tuple_types->getElements().size());
return std::make_shared<DataTypeMap>(tuple_types);
return std::make_shared<DataTypeMap>(tuple_types->getElements());
}
static ColumnPtr execute(const ColumnMap & map, ColumnPtr mapped)
@ -123,9 +123,9 @@ struct MapApplyImpl
{
const ColumnConst * column_const_tuple = checkAndGetColumnConst<ColumnTuple>(mapped.get());
if (!column_const_tuple)
throw Exception("Expected tuple column, found " + mapped->getName(), ErrorCodes::ILLEGAL_COLUMN);
ColumnPtr column_tuple_ptr = recursiveRemoveLowCardinality(column_const_tuple->convertToFullColumn());
column_tuple = checkAndGetColumn<ColumnTuple>(column_tuple_ptr.get());
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Expected tuple column, found {}", mapped->getName());
auto cols = convertConstTupleToConstantElements(*column_const_tuple);
return ColumnMap::create(cols[0]->convertToFullColumnIfConst(), cols[1]->convertToFullColumnIfConst(), map.getNestedColumn().getOffsetsPtr());
}
return ColumnMap::create(column_tuple->getColumnPtr(0), column_tuple->getColumnPtr(1),
@ -133,12 +133,10 @@ struct MapApplyImpl
}
};
struct NameMapApply { static constexpr auto name = "mapApply"; };
using FunctionMapApply = FunctionArrayMapped<MapApplyImpl, NameMapApply>;
void registerFunctionMapApply(FunctionFactory & factory)
{
factory.registerFunction<FunctionMapApply>();
factory.registerFunction<FunctionArrayMapped<MapFilterImpl, MapFilterImpl>>();
factory.registerFunction<FunctionArrayMapped<MapApplyImpl, MapApplyImpl>>();
}
}

View File

@ -18,7 +18,6 @@ void registerFunctionsArraySort(FunctionFactory & factory);
void registerFunctionArrayCumSum(FunctionFactory & factory);
void registerFunctionArrayCumSumNonNegative(FunctionFactory & factory);
void registerFunctionArrayDifference(FunctionFactory & factory);
void registerFunctionMapFilter(FunctionFactory & factory);
void registerFunctionMapApply(FunctionFactory & factory);
void registerFunctionsHigherOrder(FunctionFactory & factory)
@ -38,7 +37,6 @@ void registerFunctionsHigherOrder(FunctionFactory & factory)
registerFunctionArrayCumSum(factory);
registerFunctionArrayCumSumNonNegative(factory);
registerFunctionArrayDifference(factory);
registerFunctionMapFilter(factory);
registerFunctionMapApply(factory);
}

View File

@ -19,7 +19,6 @@
{'key1':1112,'key2':2223}
{'key1':1113,'key2':2225}
{'key1':1114,'key2':2227}
{1:2,2:3}
{}
{}
{}
@ -28,3 +27,7 @@
{}
{}
{3:2,1:0,2:0}
{1:2,2:3}
{1:2,2:3}
{'x':'y','x':'y'}
{'x':'y','x':'y'}

View File

@ -1,13 +1,20 @@
DROP TABLE IF EXISTS table_map;
create TABLE table_map (id UInt32, col Map(String, UInt64)) engine = MergeTree() ORDER BY tuple();
CREATE TABLE table_map (id UInt32, col Map(String, UInt64)) engine = MergeTree() ORDER BY tuple();
INSERT INTO table_map SELECT number, map('key1', number, 'key2', number * 2) FROM numbers(1111, 3);
INSERT INTO table_map SELECT number, map('key3', number, 'key2', number + 1, 'key4', number + 2) FROM numbers(100, 4);
SELECT mapFilter((k,v)->k like '%3' and v > 102, col) FROM table_map ORDER BY id;
SELECT mapFilter((k, v) -> k like '%3' and v > 102, col) FROM table_map ORDER BY id;
SELECT col, mapFilter((k, v) -> ((v % 10) > 1), col) FROM table_map ORDER BY id ASC;
SELECT mapApply((k,v)->(k,v+1), col) FROM table_map ORDER BY id;
SELECT mapApply((x, y) -> (x, x + 1), map(1, 0, 2, 0));
SELECT mapFilter((k,v)->0, col) from table_map;
SELECT mapApply((k, v) -> (k, v + 1), col) FROM table_map ORDER BY id;
SELECT mapFilter((k, v) -> 0, col) from table_map;
SELECT mapApply((k, v) -> tuple(v + 9223372036854775806), col) FROM table_map; -- { serverError 42 }
SELECT mapUpdate(map(1, 3, 3, 2), map(1, 0, 2, 0));
SELECT mapApply((k, v) -> tuple(v + 9223372036854775806), col) FROM table_map; -- { serverError 42 }
SELECT mapApply((x, y) -> (x, x + 1), map(1, 0, 2, 0));
SELECT mapApply((x, y) -> (x, x + 1), materialize(map(1, 0, 2, 0)));
SELECT mapApply((x, y) -> ('x', 'y'), map(1, 0, 2, 0));
SELECT mapApply((x, y) -> ('x', 'y'), materialize(map(1, 0, 2, 0)));
SELECT mapApply((x, y) -> (x), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
SELECT mapApply((x, y) -> ('x'), map(1, 0, 2, 0)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT }
DROP TABLE table_map;