Various cleanups

This commit is contained in:
Robert Schulze 2023-05-07 13:06:35 +00:00
parent c893302a08
commit aa09b6154b
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A

View File

@ -1,5 +1,6 @@
#include <Functions/IFunction.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h>
#include <DataTypes/DataTypeDateTime.h>
@ -48,6 +49,19 @@ public:
size_t getNumberOfArguments() const override { return 0; }
protected:
template <class ArgumentNames>
Columns convertMandatoryArguments(const ColumnsWithTypeAndName & arguments, const ArgumentNames & argument_names) const
{
Columns converted_arguments;
const DataTypePtr converted_argument_type = std::make_shared<DataTypeFloat32>();
for (size_t i = 0; i < argument_names.size(); ++i)
{
ColumnPtr argument_column = castColumn(arguments[i], converted_argument_type);
argument_column = argument_column->convertToFullColumnIfConst();
converted_arguments.push_back(argument_column);
}
return converted_arguments;
}
};
/// Common implementation for makeDate, makeDate32
@ -55,8 +69,8 @@ template <typename Traits>
class FunctionMakeDate : public FunctionWithNumericParamsBase
{
private:
static constexpr std::array argument_names_month_day = {"year", "month", "day"};
static constexpr std::array argument_names_dayofyear = {"year", "dayofyear"};
static constexpr std::array mandatory_argument_names_year_month_day = {"year", "month", "day"};
static constexpr std::array mandatory_argument_names_year_dayofyear = {"year", "dayofyear"};
public:
static constexpr auto name = Traits::name;
@ -67,19 +81,24 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
if (arguments.size() != argument_names_month_day.size() && arguments.size() != argument_names_dayofyear.size())
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} requires {} or {} arguments, but {} given",
getName(), argument_names_month_day.size(), argument_names_dayofyear.size(), arguments.size());
const bool isYearMonthDayVariant = (arguments.size() == 3);
for (size_t i = 0; i < arguments.size(); ++i)
if (isYearMonthDayVariant)
{
DataTypePtr argument_type = arguments[i].type;
if (!isNumber(argument_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument '{}' for function {} must be a number",
(arguments.size() == argument_names_month_day.size()) ? argument_names_month_day[i] : argument_names_dayofyear[i], getName());
FunctionArgumentDescriptors args{
{mandatory_argument_names_year_month_day[0], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names_year_month_day[1], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names_year_month_day[2], &isNumber<IDataType>, nullptr, "Number"}
};
validateFunctionArgumentTypes(*this, arguments, args);
}
else
{
FunctionArgumentDescriptors args{
{mandatory_argument_names_year_dayofyear[0], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names_year_dayofyear[1], &isNumber<IDataType>, nullptr, "Number"}
};
validateFunctionArgumentTypes(*this, arguments, args);
}
return std::make_shared<typename Traits::ReturnDataType>();
@ -87,15 +106,13 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const bool isYearMonthDayVariant = (arguments.size() == 3);
Columns converted_arguments;
const DataTypePtr converted_argument_type = std::make_shared<DataTypeFloat32>();
converted_arguments.clear();
for (const auto & argument : arguments)
{
ColumnPtr argument_column = castColumn(argument, converted_argument_type);
argument_column = argument_column->convertToFullColumnIfConst();
converted_arguments.push_back(argument_column);
}
if (isYearMonthDayVariant)
converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names_year_month_day);
else
converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names_year_dayofyear);
auto res_column = Traits::ReturnDataType::ColumnType::create(input_rows_count);
auto & result_data = res_column->getData();
@ -103,7 +120,7 @@ public:
const auto & date_lut = DateLUT::instance();
const Int32 max_days_since_epoch = date_lut.makeDayNum(Traits::MAX_DATE[0], Traits::MAX_DATE[1], Traits::MAX_DATE[2]);
if (converted_arguments.size() == argument_names_month_day.size())
if (isYearMonthDayVariant)
{
const auto & year_data = typeid_cast<const ColumnFloat32 &>(*converted_arguments[0]).getData();
const auto & month_data = typeid_cast<const ColumnFloat32 &>(*converted_arguments[1]).getData();
@ -132,7 +149,6 @@ public:
}
else
{
/// case argument_names_dayofyear:
const auto & year_data = typeid_cast<const ColumnFloat32 &>(*converted_arguments[0]).getData();
const auto & dayofyear_data = typeid_cast<const ColumnFloat32 &>(*converted_arguments[1]).getData();
@ -185,35 +201,7 @@ struct MakeDate32Traits
class FunctionMakeDateTimeBase : public FunctionWithNumericParamsBase
{
protected:
static constexpr std::array argument_names = {"year", "month", "day", "hour", "minute", "second"};
void checkRequiredArguments(const ColumnsWithTypeAndName & arguments, size_t optional_argument_count) const
{
if (arguments.size() < argument_names.size() || arguments.size() > argument_names.size() + optional_argument_count)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Function {} requires {} to {} arguments, but {} given",
getName(), argument_names.size(), argument_names.size() + optional_argument_count, arguments.size());
for (size_t i = 0; i < argument_names.size(); ++i)
{
DataTypePtr argument_type = arguments[i].type;
if (!isNumber(argument_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument '{}' for function {} must be a number", argument_names[i], getName());
}
}
void convertRequiredArguments(const ColumnsWithTypeAndName & arguments, Columns & converted_arguments) const
{
const DataTypePtr converted_argument_type = std::make_shared<DataTypeFloat32>();
converted_arguments.clear();
for (size_t i = 0; i < argument_names.size(); ++i)
{
ColumnPtr argument_column = castColumn(arguments[i], converted_argument_type);
argument_column = argument_column->convertToFullColumnIfConst();
converted_arguments.push_back(argument_column);
}
}
static constexpr std::array mandatory_argument_names = {"year", "month", "day", "hour", "minute", "second"};
template <typename T>
static Int64 dateTime(T year, T month, T day_of_month, T hour, T minute, T second, const DateLUTImpl & lut)
@ -271,11 +259,24 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
checkRequiredArguments(arguments, optional_argument_names.size());
FunctionArgumentDescriptors mandatory_args{
{mandatory_argument_names[0], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[1], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[2], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[3], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[4], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[5], &isNumber<IDataType>, nullptr, "Number"}
};
FunctionArgumentDescriptors optional_args{
{optional_argument_names[0], &isString<IDataType>, nullptr, "String"}
};
validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args);
/// Optional timezone argument
std::string timezone;
if (arguments.size() == argument_names.size() + 1)
if (arguments.size() == mandatory_argument_names.size() + 1)
timezone = extractTimezone(arguments.back());
return std::make_shared<DataTypeDateTime>(timezone);
@ -285,11 +286,10 @@ public:
{
/// Optional timezone argument
std::string timezone;
if (arguments.size() == argument_names.size() + 1)
if (arguments.size() == mandatory_argument_names.size() + 1)
timezone = extractTimezone(arguments.back());
Columns converted_arguments;
convertRequiredArguments(arguments, converted_arguments);
Columns converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names);
auto res_column = ColumnDateTime::create(input_rows_count);
auto & result_data = res_column->getData();
@ -325,7 +325,7 @@ public:
}
};
/// makeDateTime64(year, month, day, hour, minute, second, [fraction], [precision], [timezone])
/// makeDateTime64(year, month, day, hour, minute, second[, fraction[, precision[, timezone]]])
class FunctionMakeDateTime64 : public FunctionMakeDateTimeBase
{
private:
@ -341,11 +341,26 @@ public:
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
checkRequiredArguments(arguments, optional_argument_names.size());
FunctionArgumentDescriptors mandatory_args{
{mandatory_argument_names[0], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[1], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[2], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[3], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[4], &isNumber<IDataType>, nullptr, "Number"},
{mandatory_argument_names[5], &isNumber<IDataType>, nullptr, "Number"}
};
if (arguments.size() >= argument_names.size() + 1)
FunctionArgumentDescriptors optional_args{
{optional_argument_names[0], &isNumber<IDataType>, nullptr, "Number"},
{optional_argument_names[1], &isNumber<IDataType>, nullptr, "Number"},
{optional_argument_names[2], &isString<IDataType>, nullptr, "String"}
};
validateFunctionArgumentTypes(*this, arguments, mandatory_args, optional_args);
if (arguments.size() >= mandatory_argument_names.size() + 1)
{
const auto& fraction_argument = arguments[argument_names.size()];
const auto& fraction_argument = arguments[mandatory_argument_names.size()];
if (!isNumber(fraction_argument.type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Argument 'fraction' for function {} must be a number", getName());
@ -353,12 +368,12 @@ public:
/// Optional precision argument
Int64 precision = DEFAULT_PRECISION;
if (arguments.size() >= argument_names.size() + 2)
precision = extractPrecision(arguments[argument_names.size() + 1]);
if (arguments.size() >= mandatory_argument_names.size() + 2)
precision = extractPrecision(arguments[mandatory_argument_names.size() + 1]);
/// Optional timezone argument
std::string timezone;
if (arguments.size() == argument_names.size() + 3)
if (arguments.size() == mandatory_argument_names.size() + 3)
timezone = extractTimezone(arguments.back());
return std::make_shared<DataTypeDateTime64>(precision, timezone);
@ -368,22 +383,21 @@ public:
{
/// Optional precision argument
Int64 precision = DEFAULT_PRECISION;
if (arguments.size() >= argument_names.size() + 2)
precision = extractPrecision(arguments[argument_names.size() + 1]);
if (arguments.size() >= mandatory_argument_names.size() + 2)
precision = extractPrecision(arguments[mandatory_argument_names.size() + 1]);
/// Optional timezone argument
std::string timezone;
if (arguments.size() == argument_names.size() + 3)
if (arguments.size() == mandatory_argument_names.size() + 3)
timezone = extractTimezone(arguments.back());
Columns converted_arguments;
convertRequiredArguments(arguments, converted_arguments);
Columns converted_arguments = convertMandatoryArguments(arguments, mandatory_argument_names);
/// Optional fraction argument
const ColumnVector<Float64>::Container * fraction_data = nullptr;
if (arguments.size() >= argument_names.size() + 1)
if (arguments.size() >= mandatory_argument_names.size() + 1)
{
ColumnPtr fraction_column = castColumn(arguments[argument_names.size()], std::make_shared<DataTypeFloat64>());
ColumnPtr fraction_column = castColumn(arguments[mandatory_argument_names.size()], std::make_shared<DataTypeFloat64>());
fraction_column = fraction_column->convertToFullColumnIfConst();
converted_arguments.push_back(fraction_column);
fraction_data = &typeid_cast<const ColumnFloat64 &>(*converted_arguments[6]).getData();