From d1c49cc9bcb869030361821e60349e8054a51c4b Mon Sep 17 00:00:00 2001 From: Yarik Briukhovetskyi <114298166+yariks5s@users.noreply.github.com> Date: Tue, 12 Dec 2023 17:55:53 +0100 Subject: [PATCH] Added comments, simplified and fixed review --- src/Functions/toStartOfInterval.cpp | 138 ++++++++++++++++------------ 1 file changed, 81 insertions(+), 57 deletions(-) diff --git a/src/Functions/toStartOfInterval.cpp b/src/Functions/toStartOfInterval.cpp index b6a3a9389d6..6c71b357590 100644 --- a/src/Functions/toStartOfInterval.cpp +++ b/src/Functions/toStartOfInterval.cpp @@ -1,9 +1,7 @@ -#include -#include #include #include #include -#include +#include #include "DataTypes/IDataType.h" #include #include @@ -16,7 +14,7 @@ #include #include #include -#include +#include namespace DB @@ -31,21 +29,23 @@ namespace ErrorCodes } -namespace -{ - class FunctionToStartOfInterval : public IFunction { public: + enum class Overload + { + Default, /// toStartOfInterval(time, interval) or toStartOfInterval(time, interval, timezone) + Origin /// toStartOfInterval(time, interval, origin) or toStartOfInterval(time, interval, origin, timezone) + }; + mutable Overload overload; + static FunctionPtr create(ContextPtr) { return std::make_shared(); } static constexpr auto name = "toStartOfInterval"; String getName() const override { return name; } - bool isVariadic() const override { return true; } size_t getNumberOfArguments() const override { return 0; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; } - bool useDefaultImplementationForConstants() const override { return true; } ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2, 3}; } @@ -59,8 +59,9 @@ public: { const DataTypePtr & type_arg1 = arguments[0].type; if (!isDate(type_arg1) && !isDateTime(type_arg1) && !isDateTime64(type_arg1)) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 1st argument of function {}. " - "Should be a date or a date with time", type_arg1->getName(), getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of 1st argument of function {}, expected a Date, DateTime or DateTime64", + type_arg1->getName(), getName()); value_is_date = isDate(type_arg1); }; @@ -75,10 +76,14 @@ public: auto check_second_argument = [&] { const DataTypePtr & type_arg2 = arguments[1].type; + interval_type = checkAndGetDataType(type_arg2.get()); if (!interval_type) - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 2nd argument of function {}. " - "Should be an interval of time", type_arg2->getName(), getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Illegal type {} of 2nd argument of function {}, expected a time interval", + type_arg2->getName(), getName()); + + /// Result here is determined for default overload (without origin) switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case) { case IntervalKind::Nanosecond: @@ -89,7 +94,7 @@ public: case IntervalKind::Second: case IntervalKind::Minute: case IntervalKind::Hour: - case IntervalKind::Day: + case IntervalKind::Day: /// weird why Day leads to DateTime but too afraid to change it result_type = ResultType::DateTime; break; case IntervalKind::Week: @@ -101,31 +106,26 @@ public: } }; - enum class ThirdArgument - { - IsTimezone, - IsOrigin - }; - ThirdArgument third_argument; /// valid only if 3rd argument is given auto check_third_argument = [&] { const DataTypePtr & type_arg3 = arguments[2].type; if (isString(type_arg3)) { - third_argument = ThirdArgument::IsTimezone; + overload = Overload::Default; + if (value_is_date && result_type == ResultType::Date) throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64", getName(), interval_type->getKind().toString()); } - else if (isDateOrDate32OrDateTimeOrDateTime64(type_arg3)) + else if (isDateTimeOrDateTime64(type_arg3) || isDate(type_arg3)) { - third_argument = ThirdArgument::IsOrigin; + overload = Overload::Origin; if (isDateTime64(arguments[0].type) && isDateTime64(arguments[2].type)) result_type = ResultType::DateTime64; else if (isDateTime(arguments[0].type) && isDateTime(arguments[2].type)) result_type = ResultType::DateTime; - else if ((isDate(arguments[0].type) || isDate32(arguments[0].type)) && (isDate(arguments[2].type) || isDate32(arguments[2].type))) + else if (isDate(arguments[0].type) && isDate(arguments[2].type)) result_type = ResultType::Date; else throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName()); @@ -138,7 +138,7 @@ public: auto check_fourth_argument = [&] { - if (third_argument != ThirdArgument::IsOrigin) /// sanity check + if (overload != Overload::Origin) /// sanity check throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 3rd argument of function {}. " "The third argument must a Date/Date32/DateTime/DateTime64 with a constant origin", arguments[2].type->getName(), getName()); @@ -185,7 +185,7 @@ public: return std::make_shared(); case ResultType::DateTime: { - const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && third_argument == ThirdArgument::IsTimezone)) ? 2 : 3; + const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3; return std::make_shared(extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false)); } case ResultType::DateTime64: @@ -200,7 +200,7 @@ public: else if (interval_type->getKind() == IntervalKind::Millisecond) scale = 3 > scale ? 3 : scale; - const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && third_argument == ThirdArgument::IsTimezone)) ? 2 : 3; + const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3; return std::make_shared(scale, extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false)); } } @@ -278,25 +278,25 @@ private: if (time_column_vec) return dispatchForIntervalColumn(assert_cast(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone); } - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for first argument of function {}. Must contain dates or dates with time", getName()); + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for 1st argument of function {}, expected a Date, DateTime or DateTime64", getName()); } template ColumnPtr dispatchForIntervalColumn( const TimeDataType & time_data_type, const TimeColumnType & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column, - const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale = 1) const + const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale = 1) const { const auto * interval_type = checkAndGetDataType(interval_column.type.get()); if (!interval_type) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for second argument of function {}, must be an interval of time.", getName()); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for 2nd argument of function {}, must be a time interval", getName()); const auto * interval_column_const_int64 = checkAndGetColumnConst(interval_column.column.get()); if (!interval_column_const_int64) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for second argument of function {}, must be a const interval of time.", getName()); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for 2nd argument of function {}, must be a const time interval", getName()); - Int64 num_units = interval_column_const_int64->getValue(); + const Int64 num_units = interval_column_const_int64->getValue(); if (num_units <= 0) - throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Value for second argument of function {} must be positive.", getName()); + throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Value for 2nd argument of function {} must be positive", getName()); switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case) { @@ -328,12 +328,11 @@ private: } template - Int64 decideScaleOnPrecision(const UInt16 scale) const + Int64 decideScaleOnPrecision() const { static constexpr Int64 MILLISECOND_SCALE = 1000; static constexpr Int64 MICROSECOND_SCALE = 1000000; static constexpr Int64 NANOSECOND_SCALE = 1000000000; - Int64 scale_multiplier = DecimalUtils::scaleMultiplier(scale); switch (unit) { case IntervalKind::Millisecond: @@ -343,37 +342,41 @@ private: case IntervalKind::Nanosecond: return NANOSECOND_SCALE; default: - return scale_multiplier; + return 1; } } - template + template ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale) const { - using ToColumnType = typename ToDataType::ColumnType; - using ToFieldType = typename ToDataType::FieldType; + using ResultColumnType = typename ResultDataType::ColumnType; + using ResultFieldType = typename ResultDataType::FieldType; const auto & time_data = time_column_type.getData(); size_t size = time_data.size(); auto result_col = result_type->createColumn(); - auto * col_to = assert_cast(result_col.get()); + auto * col_to = assert_cast(result_col.get()); auto & result_data = col_to->getData(); result_data.resize(size); - Int64 scale_multiplier = DecimalUtils::scaleMultiplier(scale); - Int64 scale_on_interval = decideScaleOnPrecision(scale); - Int64 scale_diff = scale_on_interval > scale_multiplier ? scale_on_interval / scale_multiplier : scale_multiplier / scale_on_interval; + Int64 scale_on_time = DecimalUtils::scaleMultiplier(scale); // scale that depends on type of arguments + Int64 scale_on_interval = decideScaleOnPrecision(); // scale that depends on the Interval + /// In case if we have a difference between time arguments and Interval, we need to calculate the difference between them + /// to get the right precision for the result. + Int64 scale_diff = scale_on_interval > scale_on_time ? scale_on_interval / scale_on_time : scale_on_time / scale_on_interval; if (origin_column.column == nullptr) { for (size_t i = 0; i != size; ++i) { result_data[i] = 0; - if (scale_on_interval < scale_multiplier) - result_data[i] += static_cast(ToStartOfInterval::execute(time_data[i], num_units, time_zone, scale_multiplier)) * scale_diff; + if (scale_on_interval < scale_on_time) + /// if we have a time argument that has bigger scale than the interval can contain, we need + /// to return a value with bigger precision and thus we should multiply result on the scale difference. + result_data[i] += static_cast(ToStartOfInterval::execute(time_data[i], num_units, time_zone, scale_on_interval)) * scale_diff; else - result_data[i] = static_cast(ToStartOfInterval::execute(time_data[i], num_units, time_zone, scale_multiplier)); + result_data[i] = static_cast(ToStartOfInterval::execute(time_data[i], num_units, time_zone, scale_on_time)); } } else @@ -387,31 +390,54 @@ private: throw Exception(ErrorCodes::BAD_ARGUMENTS, "The origin must be before the end date/datetime"); t -= origin; - auto res = static_cast(ToStartOfInterval::execute(t, num_units, time_zone, scale_multiplier)); + auto res = static_cast(ToStartOfInterval::execute(t, num_units, time_zone, scale_on_time)); static constexpr size_t SECONDS_PER_DAY = 86400; result_data[i] = 0; if (unit == IntervalKind::Week || unit == IntervalKind::Month || unit == IntervalKind::Quarter || unit == IntervalKind::Year) { + /// By default, when we use week, month, quarter or year interval, we get date return type. So, simply add values. if (isDate(result_type) || isDate32(result_type)) result_data[i] += origin + res; - else if (isDateTime64(result_type)) - result_data[i] += origin + (res * SECONDS_PER_DAY * scale_multiplier); - else + /// When we use DateTime arguments, we should keep in mind that we also have hours, minutes and seconds there, + /// so we need to multiply result by amount of seconds per day. + else if (isDateTime(result_type)) result_data[i] += origin + res * SECONDS_PER_DAY; + /// When we use DateTime64 arguments, we also should multiply it on right scale. + else + result_data[i] += origin + (res * SECONDS_PER_DAY * scale_on_time); } else { + /// In this case result will be calculated as datetime, so we need to get the amount of days if the arguments are Date. if (isDate(result_type) || isDate32(result_type)) res = res / SECONDS_PER_DAY; - if (scale_on_interval > scale_multiplier) - result_data[i] += (origin + res / scale_diff) * scale_diff; - else if (scale_on_interval == scale_multiplier && scale_on_interval % 1000 != 0 && scale_multiplier != 10) - result_data[i] += origin + (res * scale_on_interval); + /// Case when Interval has default scale + if (scale_on_interval == 1) + { + /// Case when the arguments are DateTime64 with precision like 4,5,7,8. Here res has right precision and origin doesn't. + if (scale_on_time % 1000 != 0 && scale_on_time >= 1000) + result_data[i] += (origin + res / scale_on_time) * scale_on_time; + /// Special case when the arguments are DateTime64 with precision 2. Here origin has right precision and res doesn't + else if (scale_on_time == 100) + result_data[i] += (origin + res * scale_on_time); + /// Cases when precision of DateTime64 is 1, 3, 6, 9 e.g. has right precision in res and origin. + else + result_data[i] += (origin + res); + } + /// Case when Interval has some specific scale (3,6,9) else - result_data[i] += origin + res * scale_diff; + { + /// If we have a time argument that has bigger scale than the interval can contain, we need + /// to return a value with bigger precision and thus we should multiply result on the scale difference. + if (scale_on_interval < scale_on_time) + result_data[i] += origin + res * scale_diff; + /// The other case: interval has bigger scale than the interval or they have the same scale, so res has the right precision and origin doesn't + else + result_data[i] += (origin + res / scale_diff) * scale_diff; + } } } } @@ -419,8 +445,6 @@ private: } }; -} - REGISTER_FUNCTION(ToStartOfInterval) { factory.registerFunction();