diff --git a/src/Functions/makeDate.cpp b/src/Functions/makeDate.cpp index ae12659c8de..a3a8bf4ed21 100644 --- a/src/Functions/makeDate.cpp +++ b/src/Functions/makeDate.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -48,6 +49,19 @@ public: size_t getNumberOfArguments() const override { return 0; } protected: + template + Columns convertMandatoryArguments(const ColumnsWithTypeAndName & arguments, const ArgumentNames & argument_names) const + { + Columns converted_arguments; + const DataTypePtr converted_argument_type = std::make_shared(); + for (size_t i = 0; i < argument_names.size(); ++i) + { + ColumnPtr argument_column = castColumn(arguments[i], converted_argument_type); + argument_column = argument_column->convertToFullColumnIfConst(); + converted_arguments.push_back(argument_column); + } + return converted_arguments; + } }; /// Common implementation for makeDate, makeDate32 @@ -55,8 +69,8 @@ template class FunctionMakeDate : public FunctionWithNumericParamsBase { private: - static constexpr std::array argument_names_month_day = {"year", "month", "day"}; - static constexpr std::array argument_names_dayofyear = {"year", "dayofyear"}; + static constexpr std::array mandatory_argument_names_year_month_day = {"year", "month", "day"}; + static constexpr std::array mandatory_argument_names_year_dayofyear = {"year", "dayofyear"}; public: static constexpr auto name = Traits::name; @@ -67,19 +81,24 @@ public: DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - if (arguments.size() != argument_names_month_day.size() && arguments.size() != argument_names_dayofyear.size()) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Function {} requires {} or {} arguments, but {} given", - getName(), argument_names_month_day.size(), argument_names_dayofyear.size(), arguments.size()); + const bool isYearMonthDayVariant = (arguments.size() == 3); - for (size_t i = 0; i < arguments.size(); ++i) + if (isYearMonthDayVariant) { - DataTypePtr argument_type = arguments[i].type; - if (!isNumber(argument_type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Argument '{}' for function {} must be a number", - (arguments.size() == argument_names_month_day.size()) ? argument_names_month_day[i] : argument_names_dayofyear[i], getName()); - + FunctionArgumentDescriptors args{ + {mandatory_argument_names_year_month_day[0], &isNumber, nullptr, "Number"}, + {mandatory_argument_names_year_month_day[1], &isNumber, nullptr, "Number"}, + {mandatory_argument_names_year_month_day[2], &isNumber, nullptr, "Number"} + }; + validateFunctionArgumentTypes(*this, arguments, args); + } + else + { + FunctionArgumentDescriptors args{ + {mandatory_argument_names_year_dayofyear[0], &isNumber, nullptr, "Number"}, + {mandatory_argument_names_year_dayofyear[1], &isNumber, nullptr, "Number"} + }; + validateFunctionArgumentTypes(*this, arguments, args); } return std::make_shared(); @@ -87,15 +106,13 @@ public: ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override { + const bool isYearMonthDayVariant = (arguments.size() == 3); + Columns converted_arguments; - const DataTypePtr converted_argument_type = std::make_shared(); - converted_arguments.clear(); - for (const auto & argument : arguments) - { - ColumnPtr argument_column = castColumn(argument, converted_argument_type); - argument_column = argument_column->convertToFullColumnIfConst(); - converted_arguments.push_back(argument_column); - } + if (isYearMonthDayVariant) + converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names_year_month_day); + else + converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names_year_dayofyear); auto res_column = Traits::ReturnDataType::ColumnType::create(input_rows_count); auto & result_data = res_column->getData(); @@ -103,7 +120,7 @@ public: const auto & date_lut = DateLUT::instance(); const Int32 max_days_since_epoch = date_lut.makeDayNum(Traits::MAX_DATE[0], Traits::MAX_DATE[1], Traits::MAX_DATE[2]); - if (converted_arguments.size() == argument_names_month_day.size()) + if (isYearMonthDayVariant) { const auto & year_data = typeid_cast(*converted_arguments[0]).getData(); const auto & month_data = typeid_cast(*converted_arguments[1]).getData(); @@ -132,7 +149,6 @@ public: } else { - /// case argument_names_dayofyear: const auto & year_data = typeid_cast(*converted_arguments[0]).getData(); const auto & dayofyear_data = typeid_cast(*converted_arguments[1]).getData(); @@ -185,35 +201,7 @@ struct MakeDate32Traits class FunctionMakeDateTimeBase : public FunctionWithNumericParamsBase { protected: - static constexpr std::array argument_names = {"year", "month", "day", "hour", "minute", "second"}; - - void checkRequiredArguments(const ColumnsWithTypeAndName & arguments, size_t optional_argument_count) const - { - if (arguments.size() < argument_names.size() || arguments.size() > argument_names.size() + optional_argument_count) - throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, - "Function {} requires {} to {} arguments, but {} given", - getName(), argument_names.size(), argument_names.size() + optional_argument_count, arguments.size()); - - for (size_t i = 0; i < argument_names.size(); ++i) - { - DataTypePtr argument_type = arguments[i].type; - if (!isNumber(argument_type)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Argument '{}' for function {} must be a number", argument_names[i], getName()); - } - } - - void convertRequiredArguments(const ColumnsWithTypeAndName & arguments, Columns & converted_arguments) const - { - const DataTypePtr converted_argument_type = std::make_shared(); - converted_arguments.clear(); - for (size_t i = 0; i < argument_names.size(); ++i) - { - ColumnPtr argument_column = castColumn(arguments[i], converted_argument_type); - argument_column = argument_column->convertToFullColumnIfConst(); - converted_arguments.push_back(argument_column); - } - } + static constexpr std::array mandatory_argument_names = {"year", "month", "day", "hour", "minute", "second"}; template static Int64 dateTime(T year, T month, T day_of_month, T hour, T minute, T second, const DateLUTImpl & lut) @@ -271,11 +259,24 @@ public: DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - checkRequiredArguments(arguments, optional_argument_names.size()); + FunctionArgumentDescriptors mandatory_args{ + {mandatory_argument_names[0], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[1], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[2], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[3], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[4], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[5], &isNumber, nullptr, "Number"} + }; + + FunctionArgumentDescriptors optional_args{ + {optional_argument_names[0], &isString, nullptr, "String"} + }; + + validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args); /// Optional timezone argument std::string timezone; - if (arguments.size() == argument_names.size() + 1) + if (arguments.size() == mandatory_argument_names.size() + 1) timezone = extractTimezone(arguments.back()); return std::make_shared(timezone); @@ -285,11 +286,10 @@ public: { /// Optional timezone argument std::string timezone; - if (arguments.size() == argument_names.size() + 1) + if (arguments.size() == mandatory_argument_names.size() + 1) timezone = extractTimezone(arguments.back()); - Columns converted_arguments; - convertRequiredArguments(arguments, converted_arguments); + Columns converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names); auto res_column = ColumnDateTime::create(input_rows_count); auto & result_data = res_column->getData(); @@ -325,7 +325,7 @@ public: } }; -/// makeDateTime64(year, month, day, hour, minute, second, [fraction], [precision], [timezone]) +/// makeDateTime64(year, month, day, hour, minute, second[, fraction[, precision[, timezone]]]) class FunctionMakeDateTime64 : public FunctionMakeDateTimeBase { private: @@ -341,11 +341,26 @@ public: DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - checkRequiredArguments(arguments, optional_argument_names.size()); + FunctionArgumentDescriptors mandatory_args{ + {mandatory_argument_names[0], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[1], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[2], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[3], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[4], &isNumber, nullptr, "Number"}, + {mandatory_argument_names[5], &isNumber, nullptr, "Number"} + }; - if (arguments.size() >= argument_names.size() + 1) + FunctionArgumentDescriptors optional_args{ + {optional_argument_names[0], &isNumber, nullptr, "Number"}, + {optional_argument_names[1], &isNumber, nullptr, "Number"}, + {optional_argument_names[2], &isString, nullptr, "String"} + }; + + validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args); + + if (arguments.size() >= mandatory_argument_names.size() + 1) { - const auto& fraction_argument = arguments[argument_names.size()]; + const auto& fraction_argument = arguments[mandatory_argument_names.size()]; if (!isNumber(fraction_argument.type)) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 'fraction' for function {} must be a number", getName()); @@ -353,12 +368,12 @@ public: /// Optional precision argument Int64 precision = DEFAULT_PRECISION; - if (arguments.size() >= argument_names.size() + 2) - precision = extractPrecision(arguments[argument_names.size() + 1]); + if (arguments.size() >= mandatory_argument_names.size() + 2) + precision = extractPrecision(arguments[mandatory_argument_names.size() + 1]); /// Optional timezone argument std::string timezone; - if (arguments.size() == argument_names.size() + 3) + if (arguments.size() == mandatory_argument_names.size() + 3) timezone = extractTimezone(arguments.back()); return std::make_shared(precision, timezone); @@ -368,22 +383,21 @@ public: { /// Optional precision argument Int64 precision = DEFAULT_PRECISION; - if (arguments.size() >= argument_names.size() + 2) - precision = extractPrecision(arguments[argument_names.size() + 1]); + if (arguments.size() >= mandatory_argument_names.size() + 2) + precision = extractPrecision(arguments[mandatory_argument_names.size() + 1]); /// Optional timezone argument std::string timezone; - if (arguments.size() == argument_names.size() + 3) + if (arguments.size() == mandatory_argument_names.size() + 3) timezone = extractTimezone(arguments.back()); - Columns converted_arguments; - convertRequiredArguments(arguments, converted_arguments); + Columns converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names); /// Optional fraction argument const ColumnVector::Container * fraction_data = nullptr; - if (arguments.size() >= argument_names.size() + 1) + if (arguments.size() >= mandatory_argument_names.size() + 1) { - ColumnPtr fraction_column = castColumn(arguments[argument_names.size()], std::make_shared()); + ColumnPtr fraction_column = castColumn(arguments[mandatory_argument_names.size()], std::make_shared()); fraction_column = fraction_column->convertToFullColumnIfConst(); converted_arguments.push_back(fraction_column); fraction_data = &typeid_cast(*converted_arguments[6]).getData();