diff --git a/dbms/src/Functions/FunctionHelpers.cpp b/dbms/src/Functions/FunctionHelpers.cpp index 212a107e37c..16708564160 100644 --- a/dbms/src/Functions/FunctionHelpers.cpp +++ b/dbms/src/Functions/FunctionHelpers.cpp @@ -116,4 +116,45 @@ void validateArgumentType(const IFunction & func, const DataTypes & arguments, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } +namespace +{ +void validateArgumentsImpl(const IFunction & func, const ColumnsWithTypeAndName & arguments, size_t argument_offset, const FunctionArgumentTypeValidators & validators) +{ + for (size_t i = 0; i < validators.size(); ++i) + { + const auto argument_index = i + argument_offset; + if (argument_index >= arguments.size()) + { + break; + } + + const auto & arg = arguments[i + argument_offset]; + const auto validator = validators[i]; + if (validator.validator_func(*arg.type) == false) + throw Exception("Illegal type " + arg.type->getName() + + " of " + std::to_string(i) + + " argument of function " + func.getName() + + " expected " + validator.expected_type_description, + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } +} + +} + +void validateFunctionArgumentTypes(const IFunction & func, const ColumnsWithTypeAndName & arguments, const FunctionArgumentTypeValidators & mandatory_args, const FunctionArgumentTypeValidators & optional_args) +{ + if (arguments.size() < mandatory_args.size()) + throw Exception("Incorrect number of arguments of function " + func.getName() + + " provided: " + std::to_string(arguments.size()) + + " expected: " + std::to_string(mandatory_args.size()) + + (optional_args.size() ? " or " + std::to_string(mandatory_args.size() + optional_args.size()) : ""), + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + + validateArgumentsImpl(func, arguments, 0, mandatory_args); + if (optional_args.size()) + { + validateArgumentsImpl(func, arguments, mandatory_args.size(), optional_args); + } +} + } diff --git a/dbms/src/Functions/FunctionHelpers.h b/dbms/src/Functions/FunctionHelpers.h index ac116510b7e..e67a8b3b14c 100644 --- a/dbms/src/Functions/FunctionHelpers.h +++ b/dbms/src/Functions/FunctionHelpers.h @@ -89,4 +89,14 @@ void validateArgumentType(const IFunction & func, const DataTypes & arguments, size_t argument_index, bool (* validator_func)(const IDataType &), const char * expected_type_description); +struct FunctionArgumentTypeValidator +{ + bool (* validator_func)(const IDataType &); + const char * expected_type_description; +}; + +using FunctionArgumentTypeValidators = std::vector; + +void validateFunctionArgumentTypes(const IFunction & func, const ColumnsWithTypeAndName & arguments, const FunctionArgumentTypeValidators & mandatory_args, const FunctionArgumentTypeValidators & optional_args = {}); + } diff --git a/dbms/src/Functions/FunctionsConversion.h b/dbms/src/Functions/FunctionsConversion.h index b27ed681f84..563f46719cd 100644 --- a/dbms/src/Functions/FunctionsConversion.h +++ b/dbms/src/Functions/FunctionsConversion.h @@ -832,16 +832,19 @@ public: DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - if (to_decimal && arguments.size() != 2) + FunctionArgumentTypeValidators mandatory_args = {{[](const auto &) {return true;}, "ANY TYPE"}}; + FunctionArgumentTypeValidators optional_args; + + if constexpr (to_decimal || std::is_same_v) { - throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " - + toString(arguments.size()) + ", should be 2.", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + mandatory_args.push_back(FunctionArgumentTypeValidator{&isNativeInteger, "Integer"}); // scale } - else if (arguments.size() != 1 && arguments.size() != 2) - throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " - + toString(arguments.size()) + ", should be 1 or 2. Second argument (time zone) is optional only make sense for DateTime.", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); + else + { + optional_args.push_back(FunctionArgumentTypeValidator{&isString, "String"}); // timezone + } + + validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args); if constexpr (std::is_same_v) { @@ -865,45 +868,20 @@ public: } else { - UInt8 max_args = 2; + UInt8 timezone_arg_position = 1; UInt32 scale = DataTypeDateTime64::default_scale; + // DateTime64 requires more arguments: scale and timezone. Since timezone is optional, scale should be first. if constexpr (std::is_same_v) { - max_args += 1; - if (isNativeInteger(*arguments[max_args - 1].type)) - { - scale = static_cast(arguments[max_args - 1].column->get64(0)); - } + timezone_arg_position += 1; + scale = static_cast(arguments[1].column->get64(0)); } - /** Optional (could be first or second) argument with time zone is supported: - * - for functions toDateTime, toUnixTimestamp, toDate; - * - for function toString of DateTime argument. - */ - if (arguments.size() == max_args) - { - if (!checkAndGetDataType(arguments[max_args - 1].type.get())) - throw Exception("Illegal type " + arguments[max_args - 1].type->getName() + " of 2nd argument of function " + getName(), - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); - - static constexpr bool to_date_or_time = std::is_same_v - || std::is_same_v - || std::is_same_v; - - if (!(to_date_or_time - || (std::is_same_v && (WhichDataType(arguments[0].type).isDateTime() || WhichDataType(arguments[0].type).isDateTime64())))) - { - throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " - + toString(arguments.size()) + ", should be " + std::to_string(max_args) + ".", - ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - } - } - - if (std::is_same_v) - return std::make_shared(extractTimeZoneNameFromFunctionArguments(arguments, max_args - 1, 0)); - else if (std::is_same_v) - return std::make_shared(scale, extractTimeZoneNameFromFunctionArguments(arguments, max_args - 1, 0)); + if constexpr (std::is_same_v) + return std::make_shared(extractTimeZoneNameFromFunctionArguments(arguments, timezone_arg_position, 0)); + else if constexpr (std::is_same_v) + return std::make_shared(scale, extractTimeZoneNameFromFunctionArguments(arguments, timezone_arg_position, 0)); else return std::make_shared(); }