Small fixups

This commit is contained in:
Robert Schulze 2023-12-19 20:48:30 +00:00
parent 202ca21e3f
commit 174309821a
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A

View File

@ -31,7 +31,7 @@ namespace ErrorCodes
class FunctionToStartOfInterval : public IFunction
{
public:
private:
enum class Overload
{
Default, /// toStartOfInterval(time, interval) or toStartOfInterval(time, interval, timezone)
@ -39,6 +39,7 @@ public:
};
mutable Overload overload;
public:
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionToStartOfInterval>(); }
static constexpr auto name = "toStartOfInterval";
@ -82,7 +83,9 @@ public:
"Illegal type {} of 2nd argument of function {}, expected a time interval",
type_arg2->getName(), getName());
/// Result here is determined for default overload (without origin)
overload = Overload::Default;
/// Determine result type for default overload (no origin)
switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case)
{
case IntervalKind::Nanosecond:
@ -110,21 +113,20 @@ public:
const DataTypePtr & type_arg3 = arguments[2].type;
if (isString(type_arg3))
{
overload = Overload::Default;
if (value_is_date && result_type == ResultType::Date)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64",
"A timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64",
getName(), interval_type->getKind().toString());
}
else if (isDateTimeOrDateTime64(type_arg3) || isDate(type_arg3))
else if (isDate(type_arg3) || isDateTime(type_arg3) || isDateTime64(type_arg3))
{
overload = Overload::Origin;
if (isDateTime64(arguments[0].type) && isDateTime64(arguments[2].type))
const DataTypePtr & type_arg1 = arguments[0].type;
if (isDateTime64(type_arg1) && isDateTime64(type_arg3))
result_type = ResultType::DateTime64;
else if (isDateTime(arguments[0].type) && isDateTime(arguments[2].type))
else if (isDateTime(type_arg1) && isDateTime(type_arg3))
result_type = ResultType::DateTime;
else if (isDate(arguments[0].type) && isDate(arguments[2].type))
else if (isDate(type_arg1) && isDate(type_arg3))
result_type = ResultType::Date;
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
@ -149,7 +151,7 @@ public:
type_arg4->getName(), getName());
if (value_is_date && result_type == ResultType::Date)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64",
"A timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64",
getName(), interval_type->getKind().toString());
};
@ -190,14 +192,14 @@ public:
case ResultType::DateTime64:
{
UInt32 scale = 0;
if (isDate32(arguments[0].type) || isDateTime64(arguments[0].type))
if (isDateTime64(arguments[0].type))
scale = assert_cast<const DataTypeDateTime64 &>(*arguments[0].type.get()).getScale();
if (interval_type->getKind() == IntervalKind::Nanosecond)
scale = 9 > scale ? 9 : scale;
scale = (9 > scale) ? 9 : scale;
else if (interval_type->getKind() == IntervalKind::Microsecond)
scale = 6 > scale ? 6 : scale;
scale = (6 > scale) ? 6 : scale;
else if (interval_type->getKind() == IntervalKind::Millisecond)
scale = 3 > scale ? 3 : scale;
scale = (3 > scale) ? 3 : scale;
const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3;
return std::make_shared<DataTypeDateTime64>(scale, extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false));
@ -213,20 +215,19 @@ public:
const auto & interval_column = arguments[1];
ColumnWithTypeAndName origin_column;
const bool has_origin_arg = (arguments.size() == 3 && isDateOrDate32OrDateTimeOrDateTime64(arguments[2].type)) || arguments.size() == 4;
if (has_origin_arg)
if (overload == Overload::Origin)
origin_column = arguments[2];
const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && isString(arguments[2].type))) ? 2 : 3;
const size_t time_zone_arg_num = (overload == Overload::Origin) ? 3 : 2;
const auto & time_zone = extractTimeZoneFromFunctionArguments(arguments, time_zone_arg_num, 0);
ColumnPtr result_column = nullptr;
if (isDateTime64(result_type))
result_column = dispatchForTimeColumn<DataTypeDateTime64>(time_column, interval_column, origin_column, result_type, time_zone);
ColumnPtr result_column;
if (isDate(result_type))
result_column = dispatchForTimeColumn<DataTypeDate>(time_column, interval_column, origin_column, result_type, time_zone);
else if (isDateTime(result_type))
result_column = dispatchForTimeColumn<DataTypeDateTime>(time_column, interval_column, origin_column, result_type, time_zone);
else
result_column = dispatchForTimeColumn<DataTypeDate>(time_column, interval_column, origin_column, result_type, time_zone);
else if (isDateTime64(result_type))
result_column = dispatchForTimeColumn<DataTypeDateTime64>(time_column, interval_column, origin_column, result_type, time_zone);
return result_column;
}
@ -238,44 +239,24 @@ private:
const auto & time_column_type = *time_column.type.get();
const auto & time_column_col = *time_column.column.get();
if (isDateTime64(time_column_type))
if (isDate(time_column_type))
{
if (origin_column.column != nullptr && !isDateTime64(origin_column.type.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
const auto * time_column_vec = checkAndGetColumn<ColumnDateTime64>(time_column_col);
auto scale = assert_cast<const DataTypeDateTime64 &>(time_column_type).getScale();
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDateTime64 &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone, scale);
}
else if (isDateTime(time_column_type))
{
if (origin_column.column != nullptr && !isDateTime(origin_column.type.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
const auto * time_column_vec = checkAndGetColumn<ColumnDateTime>(time_column_col);
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDateTime &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
}
else if (isDate(time_column_type))
{
if (origin_column.column != nullptr && !isDate(origin_column.type.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
const auto * time_column_vec = checkAndGetColumn<ColumnDate>(time_column_col);
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDate &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
}
else if (isDate32(time_column_type))
else if (isDateTime(time_column_type))
{
if (origin_column.column != nullptr)
if (!isDate32(origin_column.type.get()))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
const auto * time_column_vec = checkAndGetColumn<ColumnDate32>(time_column_col);
const auto * time_column_vec = checkAndGetColumn<ColumnDateTime>(time_column_col);
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDate32 &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDateTime &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
}
else if (isDateTime64(time_column_type))
{
const auto * time_column_vec = checkAndGetColumn<ColumnDateTime64>(time_column_col);
auto scale = assert_cast<const DataTypeDateTime64 &>(time_column_type).getScale();
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDateTime64 &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone, scale);
}
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for 1st argument of function {}, expected a Date, DateTime or DateTime64", getName());
}
@ -327,26 +308,23 @@ private:
}
template <IntervalKind::Kind unit>
Int64 decideScaleOnPrecision() const
static Int64 scaleFromInterval()
{
static constexpr Int64 MILLISECOND_SCALE = 1000;
static constexpr Int64 MICROSECOND_SCALE = 1000000;
static constexpr Int64 NANOSECOND_SCALE = 1000000000;
switch (unit)
{
case IntervalKind::Millisecond:
return MILLISECOND_SCALE;
return 1'000;
case IntervalKind::Microsecond:
return MICROSECOND_SCALE;
return 1'000'000;
case IntervalKind::Nanosecond:
return NANOSECOND_SCALE;
return 1'000'000'000;
default:
return 1;
}
}
template <typename TimeDataType, typename ResultDataType, IntervalKind::Kind unit, typename ColumnType>
ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale) const
ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale) const
{
using ResultColumnType = typename ResultDataType::ColumnType;
using ResultFieldType = typename ResultDataType::FieldType;
@ -359,23 +337,29 @@ private:
auto & result_data = col_to->getData();
result_data.resize(size);
Int64 scale_on_time = DecimalUtils::scaleMultiplier<DateTime64>(scale); // scale that depends on type of arguments
Int64 scale_on_interval = decideScaleOnPrecision<unit>(); // scale that depends on the Interval
const Int64 scale_time = DecimalUtils::scaleMultiplier<DateTime64>(scale);
const Int64 scale_interval = scaleFromInterval<unit>();
/// In case if we have a difference between time arguments and Interval, we need to calculate the difference between them
/// to get the right precision for the result.
Int64 scale_diff = scale_on_interval > scale_on_time ? scale_on_interval / scale_on_time : scale_on_time / scale_on_interval;
const Int64 scale_diff = (scale_interval > scale_time) ? (scale_interval / scale_time) : (scale_time / scale_interval);
if (origin_column.column == nullptr)
{
for (size_t i = 0; i != size; ++i)
if (scale_time > scale_interval && scale_interval != 1)
{
result_data[i] = 0;
if (scale_on_interval < scale_on_time && scale_on_interval != 1)
for (size_t i = 0; i != size; ++i)
{
/// If we have a time argument that has bigger scale than the interval can contain and interval is not default, we need
/// to return a value with bigger precision and thus we should multiply result on the scale difference.
result_data[i] += static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_on_time)) * scale_diff;
else
result_data[i] = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_on_time));
result_data[i] = 0;
result_data[i] += static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_time)) * scale_diff;
}
}
else
{
for (size_t i = 0; i != size; ++i)
result_data[i] = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_time));
}
}
else
@ -386,55 +370,60 @@ private:
{
auto t = time_data[i];
if (origin > static_cast<size_t>(t))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The origin must be before the end date/datetime");
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The origin must be before the end date / date with time");
/// The trick to calculate the interval starting from an offset is to
/// 1. subtract the offset,
/// 2. perform the calculation, and
/// 3. add the offset to the result.
t -= origin;
auto res = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(t, num_units, time_zone, scale_on_time));
auto res = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(t, num_units, time_zone, scale_time));
static constexpr size_t SECONDS_PER_DAY = 86400;
static constexpr size_t SECONDS_PER_DAY = 86'400;
result_data[i] = 0;
if (unit == IntervalKind::Week || unit == IntervalKind::Month || unit == IntervalKind::Quarter || unit == IntervalKind::Year)
{
/// By default, when we use week, month, quarter or year interval, we get date return type. So, simply add values.
if (isDate(result_type) || isDate32(result_type))
/// For such intervals, ToStartOfInterval<unit>::execute() returns days
if (isDate(result_type))
result_data[i] += origin + res;
/// When we use DateTime arguments, we should keep in mind that we also have hours, minutes and seconds there,
/// so we need to multiply result by amount of seconds per day.
else if (isDateTime(result_type))
result_data[i] += origin + res * SECONDS_PER_DAY;
/// When we use DateTime64 arguments, we also should multiply it on right scale.
else
result_data[i] += origin + (res * SECONDS_PER_DAY * scale_on_time);
else if (isDateTime64(result_type))
result_data[i] += origin + (res * SECONDS_PER_DAY * scale_time);
}
else
{
/// In this case result will be calculated as datetime, so we need to get the amount of days if the arguments are Date.
if (isDate(result_type) || isDate32(result_type))
/// ToStartOfInterval<unit>::execute() returns seconds
if (isDate(result_type))
res = res / SECONDS_PER_DAY;
/// Case when Interval has default scale
if (scale_on_interval == 1)
if (scale_interval == 1)
{
/// Case when the arguments are DateTime64 with precision like 4,5,7,8. Here res has right precision and origin doesn't.
if (scale_on_time % 1000 != 0 && scale_on_time >= 1000)
result_data[i] += (origin + res / scale_on_time) * scale_on_time;
/// Special case when the arguments are DateTime64 with precision 2. Here origin has right precision and res doesn't
else if (scale_on_time == 100)
result_data[i] += (origin + res * scale_on_time);
/// Cases when precision of DateTime64 is 1, 3, 6, 9 e.g. has right precision in res and origin.
/// Interval has default scale, i.e. Year - Second
if (scale_time % 1000 != 0 && scale_time >= 1000)
/// The arguments are DateTime64 with precision like 4,5,7,8. Here res has right precision and origin doesn't.
result_data[i] += (origin + res / scale_time) * scale_time;
else if (scale_time == 100)
/// The arguments are DateTime64 with precision 2. Here origin has right precision and res doesn't
result_data[i] += (origin + res * scale_time);
else
/// Precision of DateTime64 is 1, 3, 6, 9, e.g. has right precision in res and origin.
result_data[i] += (origin + res);
}
/// Case when Interval has some specific scale (3,6,9)
else
{
/// If we have a time argument that has bigger scale than the interval can contain, we need
/// to return a value with bigger precision and thus we should multiply result on the scale difference.
if (scale_on_interval < scale_on_time)
/// Interval has some specific scale (3,6,9), i.e. Millisecond - Nanosecond
if (scale_interval < scale_time)
/// If we have a time argument that has bigger scale than the interval can contain, we need
/// to return a value with bigger precision and thus we should multiply result on the scale difference.
result_data[i] += origin + res * scale_diff;
/// The other case: interval has bigger scale than the interval or they have the same scale, so res has the right precision and origin doesn't
else
/// The other case: interval has bigger scale than the interval or they have the same scale, so res has the right precision and origin doesn't
result_data[i] += (origin + res / scale_diff) * scale_diff;
}
}