diff --git a/src/AggregateFunctions/AggregateFunctionMap.cpp b/src/AggregateFunctions/AggregateFunctionMap.cpp index c9c9072ff6c..31505b89fe2 100644 --- a/src/AggregateFunctions/AggregateFunctionMap.cpp +++ b/src/AggregateFunctions/AggregateFunctionMap.cpp @@ -24,22 +24,40 @@ public: const auto * map_type = checkAndGetDataType(arguments[0].get()); if (map_type) + { + if (arguments->size() > 1) + throw Exception( + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, + getName() + " combinator takes only one map argument"); + return DataTypes({map_type->getValueType()}); + } // we need this part just to pass to redirection for mapped arrays + auto check_func = [](DataTypePtr t) + { + return t->getTypeId() == TypeIndex::Array; + }; + const auto * tup_type = checkAndGetDataType(arguments[0].get()); if (tup_type) { - const auto * val_array_type = checkAndGetDataType(tup_type->getElements()[1].get()); - if (val_array_type) - return DataTypes({val_array_type->getNestedType()}); + const auto & types = tup_type->getElements(); + bool arrays_match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func); + if (arrays_match) + { + const auto & val_array_type = assert_cast(types[1]); + return DataTypes({val_array_type.getNestedType()}); + } } - - if (arguments.size() >= 2) + else { - const auto * val_array_type = checkAndGetDataType(arguments[1].get()); - if (val_array_type) + bool arrays_match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func); + if (arrays_match) + { + const auto & val_array_type = assert_cast(arguments[1]); return DataTypes({val_array_type->getNestedType()}); + } } throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Aggregate function " + getName() + " requires map as argument"); @@ -87,37 +105,15 @@ public: throw Exception{"Illegal columns in arguments for combinator " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; } } - else if (!arguments.empty()) + else { - // check if we got tuple of arrays or just arrays and if so, try to redirect to sum/min/max-MappedArrays to implement old behavior + // in case of tuple of arrays or just arrays (checked in transformArguments), try to redirect to sum/min/max-MappedArrays to implement old behavior auto nested_func_name = nested_function->getName(); if (nested_func_name == "sum" || nested_func_name == "min" || nested_func_name == "max") { - bool match; - const auto * tup_type = checkAndGetDataType(arguments[0].get()); - - auto check_func = [](DataTypePtr t) - { - return t->getTypeId() == TypeIndex::Array; - }; - - if (tup_type) - { - const auto & types = tup_type->getElements(); - match = arguments.size() == 1 && types.size() >= 2 && std::all_of(types.begin(), types.end(), check_func); - } - else - { - // sumMappedArrays and others support more than 2 mapped arrays - match = arguments.size() >= 2 && std::all_of(arguments.begin(), arguments.end(), check_func); - } - - if (match) - { - AggregateFunctionProperties out_properties; - auto & aggr_func_factory = AggregateFunctionFactory::instance(); - return aggr_func_factory.get(nested_func_name + "MappedArrays", arguments, params, out_properties); - } + AggregateFunctionProperties out_properties; + auto & aggr_func_factory = AggregateFunctionFactory::instance(); + return aggr_func_factory.get(nested_func_name + "MappedArrays", arguments, params, out_properties); } }