rewrite the main logic

This commit is contained in:
Yarik Briukhovetskyi 2024-09-03 19:53:28 +02:00 committed by GitHub
parent a2f2d88cec
commit 4fad12ecb0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 79 additions and 159 deletions

View File

@ -22,7 +22,6 @@
namespace DB
{
static Int64 Int64_max_value = std::numeric_limits<Int64>::max();
static constexpr auto millisecond_multiplier = 1'000;
static constexpr auto microsecond_multiplier = 1'000'000;
static constexpr auto nanosecond_multiplier = 1'000'000'000;
@ -701,12 +700,8 @@ struct ToStartOfInterval<IntervalKind::Kind::Week>
if (origin == 0)
return time_zone.toStartOfWeekInterval(time_zone.toDayNum(t / scale_multiplier), weeks);
else
{
if (weeks < Int64_max_value / 7) // Check if multiplication doesn't overflow Int64 value
return ToStartOfInterval<IntervalKind::Kind::Day>::execute(t, weeks * 7, time_zone, scale_multiplier, origin);
else
throw Exception(ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, "Value {} * 7 is out of bounds for type Int64", weeks);
}
return ToStartOfInterval<IntervalKind::Kind::Day>::execute(t, weeks * 7, time_zone, scale_multiplier, origin);
}
};
@ -727,18 +722,20 @@ struct ToStartOfInterval<IntervalKind::Kind::Month>
}
static Int64 execute(Int64 t, Int64 months, const DateLUTImpl & time_zone, Int64 scale_multiplier, Int64 origin = 0)
{
const Int64 scaled_time = t / scale_multiplier;
if (origin == 0)
return time_zone.toStartOfMonthInterval(time_zone.toDayNum(t / scale_multiplier), months);
return time_zone.toStartOfMonthInterval(time_zone.toDayNum(scaled_time), months);
else
{
Int64 days = time_zone.toDayOfMonth(t / scale_multiplier + origin) - time_zone.toDayOfMonth(origin);
Int64 months_to_add = time_zone.toMonth(t / scale_multiplier + origin) - time_zone.toMonth(origin);
Int64 years = time_zone.toYear(t / scale_multiplier + origin) - time_zone.toYear(origin);
const Int64 scaled_origin = origin / scale_multiplier;
const Int64 days = time_zone.toDayOfMonth(scaled_time + scaled_origin) - time_zone.toDayOfMonth(scaled_origin);
Int64 months_to_add = time_zone.toMonth(scaled_time + scaled_origin) - time_zone.toMonth(scaled_origin);
const Int64 years = time_zone.toYear(scaled_time + scaled_origin) - time_zone.toYear(scaled_origin);
months_to_add = days < 0 ? months_to_add - 1 : months_to_add;
months_to_add += years * 12;
Int64 month_multiplier = (months_to_add / months) * months;
return time_zone.addMonths(time_zone.toDate(origin), month_multiplier) - time_zone.toDate(origin);
return (time_zone.addMonths(time_zone.toDate(scaled_origin), month_multiplier) - time_zone.toDate(scaled_origin));
}
}
};
@ -763,12 +760,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Quarter>
if (origin == 0)
return time_zone.toStartOfQuarterInterval(time_zone.toDayNum(t / scale_multiplier), quarters);
else
{
if (quarters < Int64_max_value / 3) // Check if multiplication doesn't overflow Int64 value
return ToStartOfInterval<IntervalKind::Kind::Month>::execute(t, quarters * 3, time_zone, scale_multiplier, origin);
else
throw Exception(ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, "Value {} * 3 is out of bounds for type Int64", quarters);
}
return ToStartOfInterval<IntervalKind::Kind::Month>::execute(t, quarters * 3, time_zone, scale_multiplier, origin);
}
};
@ -792,12 +784,7 @@ struct ToStartOfInterval<IntervalKind::Kind::Year>
if (origin == 0)
return time_zone.toStartOfYearInterval(time_zone.toDayNum(t / scale_multiplier), years);
else
{
if (years < Int64_max_value / 12) // Check if multiplication doesn't overflow Int64 value
return ToStartOfInterval<IntervalKind::Kind::Month>::execute(t, years * 12, time_zone, scale_multiplier, origin);
else
throw Exception(ErrorCodes::VALUE_IS_OUT_OF_RANGE_OF_DATA_TYPE, "Value {} * 12 is out of bounds for type Int64", years);
}
return ToStartOfInterval<IntervalKind::Kind::Month>::execute(t, years * 12, time_zone, scale_multiplier, origin);z
}
};

View File

@ -1,6 +1,3 @@
#include <cmath>
#include <string>
#include <type_traits>
#include <Columns/ColumnsDateTime.h>
#include <Columns/ColumnsNumber.h>
#include <Common/DateLUTImpl.h>
@ -13,6 +10,8 @@
#include <Functions/FunctionFactory.h>
#include <Functions/IFunction.h>
#include <IO/WriteHelpers.h>
#include <cmath>
#include <algorithm>
namespace DB
@ -176,7 +175,7 @@ public:
else
{
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2, 3 or 4",
"Number of arguments for function {} doesn't match: passed {}, must be 2, 3 or 4",
getName(), arguments.size());
}
@ -193,13 +192,17 @@ public:
{
UInt32 scale = 0;
if (isDateTime64(arguments[0].type) && overload == Overload::Origin)
{
scale = assert_cast<const DataTypeDateTime64 &>(*arguments[0].type.get()).getScale();
if (assert_cast<const DataTypeDateTime64 &>(*arguments[2].type.get()).getScale() != scale)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same scale", getName());
}
if (interval_type->getKind() == IntervalKind::Kind::Nanosecond)
scale = (9 > scale) ? 9 : scale;
scale = 9;
else if (interval_type->getKind() == IntervalKind::Kind::Microsecond)
scale = (6 > scale) ? 6 : scale;
scale = 6;
else if (interval_type->getKind() == IntervalKind::Kind::Millisecond)
scale = (3 > scale) ? 3 : scale;
scale = 3;
const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3;
return std::make_shared<DataTypeDateTime64>(scale, extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false));
@ -218,7 +221,7 @@ public:
if (overload == Overload::Origin)
origin_column = arguments[2];
const size_t time_zone_arg_num = (overload == Overload::Origin) ? 3 : 2;
const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3;
const auto & time_zone = extractTimeZoneFromFunctionArguments(arguments, time_zone_arg_num, 0);
ColumnPtr result_column;
@ -272,25 +275,22 @@ private:
if (!interval_type)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for 2nd argument of function {}, must be a time interval", getName());
if (isDate(time_data_type) || isDateTime(time_data_type))
switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case)
{
switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case)
{
case IntervalKind::Kind::Nanosecond:
case IntervalKind::Kind::Microsecond:
case IntervalKind::Kind::Millisecond:
if (isDate(time_data_type) || isDateTime(time_data_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal interval kind for argument data type {}", isDate(time_data_type) ? "Date" : "DateTime");
break;
case IntervalKind::Kind::Second:
case IntervalKind::Kind::Minute:
case IntervalKind::Kind::Hour:
if (isDate(time_data_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal interval kind for argument data type Date");
break;
default:
break;
}
case IntervalKind::Kind::Nanosecond:
case IntervalKind::Kind::Microsecond:
case IntervalKind::Kind::Millisecond:
if (isDate(time_data_type) || isDateTime(time_data_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal interval kind for argument data type {}", isDate(time_data_type) ? "Date" : "DateTime");
break;
case IntervalKind::Kind::Second:
case IntervalKind::Kind::Minute:
case IntervalKind::Kind::Hour:
if (isDate(time_data_type))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal interval kind for argument data type Date");
break;
default:
break;
}
const auto * interval_column_const_int64 = checkAndGetColumnConst<ColumnInt64>(interval_column.column.get());
@ -330,27 +330,10 @@ private:
std::unreachable();
}
template <IntervalKind::Kind unit>
static Int64 scaleFromInterval()
{
switch (unit)
{
case IntervalKind::Kind::Millisecond:
return 1'000;
case IntervalKind::Kind::Microsecond:
return 1'000'000;
case IntervalKind::Kind::Nanosecond:
return 1'000'000'000;
default:
return 1;
}
}
template <typename TimeDataType, typename ResultDataType, IntervalKind::Kind unit, typename ColumnType>
ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale) const
{
using ResultColumnType = typename ResultDataType::ColumnType;
using ResultFieldType = typename ResultDataType::FieldType;
const auto & time_data = time_column_type.getData();
size_t size = time_data.size();
@ -360,114 +343,64 @@ private:
auto & result_data = col_to->getData();
result_data.resize(size);
const Int64 scale_endtime = DecimalUtils::scaleMultiplier<DateTime64>(scale);
const Int64 scale_interval = scaleFromInterval<unit>();
Int64 scale_multiplier = DecimalUtils::scaleMultiplier<DateTime64>(scale);
/// In case if we have a difference between time arguments and Interval, we need to calculate the difference between them
/// to get the right precision for the result.
const Int64 scale_diff = (scale_interval > scale_endtime) ? (scale_interval / scale_endtime) : (scale_endtime / scale_interval);
if (origin_column.column) // Overload: Origin
{
const bool is_small_interval = (unit == IntervalKind::Kind::Nanosecond || unit == IntervalKind::Kind::Microsecond || unit == IntervalKind::Kind::Millisecond);
const bool is_result_date = isDate(result_type);
if (origin_column.column == nullptr)
{
if (scale_endtime > scale_interval && scale_interval != 1)
{
for (size_t i = 0; i != size; ++i)
{
/// If we have a time argument that has bigger scale than the interval can contain and interval is not default, we need
/// to return a value with bigger precision and thus we should multiply result on the scale difference.
result_data[i] = 0;
result_data[i] += static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_endtime));
}
}
else
{
for (size_t i = 0; i != size; ++i)
result_data[i] = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_endtime));
}
}
else
{
UInt64 origin = origin_column.column->get64(0);
Int64 result_scale = scale_multiplier;
Int64 origin_scale = 1;
if (isDateTime64(origin_column.type.get()))
origin_scale = assert_cast<const DataTypeDateTime64 &>(*origin_column.type.get()).getScale();
if (isDateTime64(result_type)) /// We have origin scale only in case if arguments are DateTime64.
origin_scale = assert_cast<const DataTypeDateTime64 &>(*origin_column.type).getScaleMultiplier();
else if (!is_small_interval) /// In case of large interval and arguments are not DateTime64, we should not have scale in result.
result_scale = 1;
if (is_small_interval)
result_scale = assert_cast<const DataTypeDateTime64 &>(*result_type).getScaleMultiplier();
/// In case if we have a difference between time arguments and Interval, we need to calculate the difference between them
/// to get the right precision for the result. In case of large intervals, we should not have scale difference.
Int64 scale_diff = is_small_interval ? std::max(result_scale / origin_scale, origin_scale / result_scale) : 1;
static constexpr Int64 SECONDS_PER_DAY = 86'400;
UInt64 origin = origin_column.column->get64(0);
for (size_t i = 0; i != size; ++i)
{
UInt64 end_time = time_data[i];
if (origin > static_cast<size_t>(end_time) && origin_scale == scale)
UInt64 time_arg = time_data[i];
if (origin > static_cast<size_t>(time_arg))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The origin must be before the end date / date with time");
else if (origin_scale > scale)
origin /= static_cast<UInt64>(std::pow(10, origin_scale - scale)); /// If arguments have different scales, we make
else if (origin_scale < scale) /// origin argument to have the same scale as the first argument.
origin *= static_cast<UInt64>(std::pow(10, scale - origin_scale));
/// The trick to calculate the interval starting from an offset is to
/// 1. subtract the offset,
/// 2. perform the calculation, and
/// 3. add the offset to the result.
static constexpr size_t SECONDS_PER_DAY = 86'400;
result_data[i] = 0;
if (isDate(origin_column.type.get())) /// We need to perform calculations on dateTime (dateTime64) values only.
if (is_result_date) /// All internal calculations of ToStartOfInterval<...> expect arguments to be seconds or milli-, micro-, nanoseconds.
{
end_time *= SECONDS_PER_DAY;
time_arg *= SECONDS_PER_DAY;
origin *= SECONDS_PER_DAY;
}
Int64 delta = (end_time - origin) * (isDateTime64(origin_column.type.get()) ? 1 : scale_endtime); /// No need to multiply on scale endtime if we have dateTime64 argument.
Int64 offset = 0;
Int64 offset = ToStartOfInterval<unit>::execute(time_arg - origin, num_units, time_zone, result_scale, origin);
/// In case if arguments are DateTime64 with large interval, we should apply scale on it.
offset *= (!is_small_interval) ? result_scale : 1;
if (is_result_date) /// Convert back to date after calculations.
{
auto origin_data = isDateTime64(result_type) ? origin / scale_endtime : origin;
offset = static_cast<DataTypeDateTime::FieldType>(ToStartOfInterval<unit>::execute(delta, num_units, time_zone, scale_endtime, origin_data));
offset /= SECONDS_PER_DAY;
origin /= SECONDS_PER_DAY;
}
if (isDate(result_type)) /// The result should be a date and the calculations were as datetime.
result_data[i] += (origin + offset) / SECONDS_PER_DAY;
else if (unit == IntervalKind::Kind::Week || unit == IntervalKind::Kind::Month || unit == IntervalKind::Kind::Quarter || unit == IntervalKind::Kind::Year)
{
if (isDateTime64(result_type)) /// We need to have the right scale for offset, origin already has the right scale.
offset *= scale_endtime;
result_data[i] += origin + offset;
}
else
{
/// ToStartOfInterval<unit>::execute() returns seconds.
if (scale_interval == 1)
{
if (isDateTime64(result_type)) /// We need to have the right scale for offset, origin already has the correct scale.
offset *= scale_endtime;
/// Interval has default scale, i.e. Year - Second.
if (scale_endtime % 1000 != 0 && scale_endtime >= 1000)
/// The arguments are DateTime64 with precision like 4,5,7,8. Here offset has correct precision and origin doesn't.
result_data[i] += (origin + offset / scale_endtime) * scale_endtime;
else
/// Precision of DateTime64 is 1, 2, 3, 6, 9, e.g. has correct precision in offset and origin.
result_data[i] += (origin + offset);
}
else
{
/// Interval has some specific scale (3,6,9), i.e. Millisecond - Nanosecond.
if (scale_interval < scale_endtime)
/// If we have a time argument that has bigger scale than the interval can contain, we need
/// to return a value with bigger precision and thus we should multiply result on the scale difference.
result_data[i] += origin + offset * scale_diff;
else
/// The other case: interval has bigger scale than the interval or they have the same scale, so offset has the right precision and origin doesn't.
result_data[i] += (origin + offset / scale_diff) * scale_diff;
}
}
result_data[i] = 0;
result_data[i] += (result_scale < origin_scale) ? (origin + offset) / scale_diff : (origin + offset) * scale_diff;
}
}
else // Overload: Default
{
for (size_t i = 0; i != size; ++i)
result_data[i] = static_cast<typename ResultDataType::FieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_multiplier));
}
return result_col;
}
};

View File

@ -23,8 +23,8 @@ Time and origin as DateTime64(9)
2023-10-09 10:10:07.123456789
2023-10-09 10:11:11.123456789
2023-10-09 10:11:12.123456789
2023-10-09 10:11:12.987456789
2023-10-09 10:11:12.987654789
2023-10-09 10:11:12.987
2023-10-09 10:11:12.987654
2023-10-09 10:11:12.987654321
Time and origin as DateTime64(3)
2023-02-01 09:08:07.123