diff --git a/src/Functions/toStartOfInterval.cpp b/src/Functions/toStartOfInterval.cpp index ec0deee8abd..da4eba9a594 100644 --- a/src/Functions/toStartOfInterval.cpp +++ b/src/Functions/toStartOfInterval.cpp @@ -1,8 +1,10 @@ -#include "Common/IntervalKind.h" +#include +#include #include #include #include #include +#include "DataTypes/IDataType.h" #include #include #include @@ -14,7 +16,6 @@ #include #include #include -#include #include @@ -116,9 +117,22 @@ public: throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64", getName(), interval_type->getKind().toString()); + + if (arguments[0].type.get() != arguments[2].type.get()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName()); } else if (isDateOrDate32OrDateTimeOrDateTime64(type_arg3)) + { third_argument = ThirdArgument::IsOrigin; + if (isDateTime64(arguments[0].type) && isDateTime64(arguments[2].type)) + result_type = ResultType::DateTime64; + else if (isDateTime(arguments[0].type) && isDateTime(arguments[2].type)) + result_type = ResultType::DateTime; + else if ((isDate(arguments[0].type) || isDate32(arguments[0].type)) && (isDate(arguments[2].type) || isDate32(arguments[2].type))) + result_type = ResultType::Date; + else + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName()); + } else throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 3rd argument of function {}. " "This argument is optional and must be a constant String with timezone name or a Date/Date32/DateTime/DateTime64 with a constant origin", @@ -180,13 +194,14 @@ public: } case ResultType::DateTime64: { - UInt32 scale = 0; + auto scale_date_time = assert_cast(*arguments[0].type.get()).getScale(); + UInt32 scale = scale_date_time; if (interval_type->getKind() == IntervalKind::Nanosecond) - scale = 9; + scale = 9 > scale_date_time ? 9 : scale_date_time; else if (interval_type->getKind() == IntervalKind::Microsecond) - scale = 6; + scale = 6 > scale_date_time ? 6 : scale_date_time; else if (interval_type->getKind() == IntervalKind::Millisecond) - scale = 3; + scale = 3 > scale_date_time ? 3 : scale_date_time; const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && third_argument == ThirdArgument::IsTimezone)) ? 2 : 3; return std::make_shared(scale, extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false)); @@ -209,11 +224,19 @@ public: const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && isString(arguments[2].type))) ? 2 : 3; const auto & time_zone = extractTimeZoneFromFunctionArguments(arguments, time_zone_arg_num, 0); - auto result_column = dispatchForTimeColumn(time_column, interval_column, origin_column, result_type, time_zone); + ColumnPtr result_column = nullptr; + if (isDateTime64(result_type)) + result_column = dispatchForTimeColumn(time_column, interval_column, origin_column, result_type, time_zone); + else if (isDateTime(result_type)) + result_column = dispatchForTimeColumn(time_column, interval_column, origin_column, result_type, time_zone); + else + result_column = dispatchForTimeColumn(time_column, interval_column, origin_column, result_type, time_zone); + return result_column; } private: + template ColumnPtr dispatchForTimeColumn( const ColumnWithTypeAndName & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone) const { @@ -229,7 +252,7 @@ private: auto scale = assert_cast(time_column_type).getScale(); if (time_column_vec) - return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone, scale); + return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone, scale); } else if (isDateTime(time_column_type)) { @@ -238,7 +261,7 @@ private: const auto * time_column_vec = checkAndGetColumn(time_column_col); if (time_column_vec) - return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); + return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); } else if (isDate(time_column_type)) { @@ -247,7 +270,7 @@ private: const auto * time_column_vec = checkAndGetColumn(time_column_col); if (time_column_vec) - return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); + return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); } else if (isDate32(time_column_type)) { @@ -257,12 +280,12 @@ private: const auto * time_column_vec = checkAndGetColumn(time_column_col); if (time_column_vec) - return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); + return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); } throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for first argument of function {}. Must contain dates or dates with time", getName()); } - template + template ColumnPtr dispatchForIntervalColumn( const TimeDataType & time_data_type, const TimeColumnType & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale = 1) const @@ -282,32 +305,52 @@ private: switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case) { case IntervalKind::Nanosecond: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Microsecond: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Millisecond: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Second: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Minute: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Hour: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Day: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Week: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Month: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Quarter: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); case IntervalKind::Year: - return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); + return execute(time_data_type, time_column, num_units, origin_column, result_type, time_zone, scale); } std::unreachable(); } + template + Int64 decideScaleOnPrecision(const UInt16 scale) const + { + static constexpr Int64 MILLISECOND_SCALE = 1000; + static constexpr Int64 MICROSECOND_SCALE = 1000000; + static constexpr Int64 NANOSECOND_SCALE = 1000000000; + Int64 scale_multiplier = DecimalUtils::scaleMultiplier(scale); + switch (unit) + { + case IntervalKind::Millisecond: + return MILLISECOND_SCALE; + case IntervalKind::Microsecond: + return MICROSECOND_SCALE; + case IntervalKind::Nanosecond: + return NANOSECOND_SCALE; + default: + return scale_multiplier; + } + } + template ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale) const { @@ -323,6 +366,8 @@ private: result_data.resize(size); Int64 scale_multiplier = DecimalUtils::scaleMultiplier(scale); + Int64 scale_on_precision = decideScaleOnPrecision(scale); + Int64 scale_diff = scale_on_precision > scale_multiplier ? scale_on_precision / scale_multiplier : scale_multiplier / scale_on_precision; if (origin_column.column == nullptr) { @@ -342,19 +387,40 @@ private: t -= origin; auto res = static_cast(ToStartOfInterval::execute(t, num_units, time_zone, scale_multiplier)); - if (!(unit == IntervalKind::Millisecond || unit == IntervalKind::Microsecond || unit == IntervalKind::Nanosecond) && scale_multiplier != 10) - origin = origin / scale_multiplier; - static constexpr size_t SECONDS_PER_DAY = 86400; result_data[i] = 0; if (unit == IntervalKind::Week || unit == IntervalKind::Month || unit == IntervalKind::Quarter || unit == IntervalKind::Year) - result_data[i] = static_cast(origin/SECONDS_PER_DAY + res); + { + if (isDate(result_type) || isDate32(result_type)) + { + result_data[i] += origin + res; + } + else if (isDateTime64(result_type)) + { + result_data[i] += origin + (res * SECONDS_PER_DAY * scale_multiplier); + } + else + { + result_data[i] += origin + res * SECONDS_PER_DAY; + } + } else - result_data[i] += origin + res; + { + if (isDate(result_type) || isDate32(result_type)) + res = res / SECONDS_PER_DAY; + + if (scale_on_precision > scale_multiplier) + { + result_data[i] += (origin + res / scale_diff) * scale_diff; + } + else + { + result_data[i] += origin + res * scale_diff; + } + } } } - return result_col; } };