Added comments, simplified and fixed review

This commit is contained in:
Yarik Briukhovetskyi 2023-12-12 17:55:53 +01:00 committed by GitHub
parent 87bda03da1
commit d1c49cc9bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,7 @@
#include <memory>
#include <Common/IntervalKind.h>
#include <Columns/ColumnsDateTime.h>
#include <Columns/ColumnsNumber.h>
#include <Common/DateLUTImpl.h>
#include <Common/Exception.h>
#include <Common/IntervalKind.h>
#include "DataTypes/IDataType.h"
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDate32.h>
@ -16,7 +14,7 @@
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <IO/WriteHelpers.h>
#include <base/types.h>
#include <base/arithmeticOverflow.h>
namespace DB
@ -31,21 +29,23 @@ namespace ErrorCodes
}
namespace
{
class FunctionToStartOfInterval : public IFunction
{
public:
/// Selects which call form of toStartOfInterval is being handled.
/// The argument checks in this class set it from the type of the 3rd argument:
/// a String selects Default, a Date/DateTime/DateTime64 selects Origin.
enum class Overload
{
Default, /// toStartOfInterval(time, interval) or toStartOfInterval(time, interval, timezone)
Origin /// toStartOfInterval(time, interval, origin) or toStartOfInterval(time, interval, origin, timezone)
};
mutable Overload overload;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionToStartOfInterval>(); }
static constexpr auto name = "toStartOfInterval";
String getName() const override { return name; }
bool isVariadic() const override { return true; }
size_t getNumberOfArguments() const override { return 0; }
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return false; }
bool useDefaultImplementationForConstants() const override { return true; }
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {1, 2, 3}; }
@ -59,8 +59,9 @@ public:
{
const DataTypePtr & type_arg1 = arguments[0].type;
if (!isDate(type_arg1) && !isDateTime(type_arg1) && !isDateTime64(type_arg1))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 1st argument of function {}. "
"Should be a date or a date with time", type_arg1->getName(), getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of 1st argument of function {}, expected a Date, DateTime or DateTime64",
type_arg1->getName(), getName());
value_is_date = isDate(type_arg1);
};
@ -75,10 +76,14 @@ public:
auto check_second_argument = [&]
{
const DataTypePtr & type_arg2 = arguments[1].type;
interval_type = checkAndGetDataType<DataTypeInterval>(type_arg2.get());
if (!interval_type)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 2nd argument of function {}. "
"Should be an interval of time", type_arg2->getName(), getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of 2nd argument of function {}, expected a time interval",
type_arg2->getName(), getName());
/// Result here is determined for default overload (without origin)
switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case)
{
case IntervalKind::Nanosecond:
@ -89,7 +94,7 @@ public:
case IntervalKind::Second:
case IntervalKind::Minute:
case IntervalKind::Hour:
case IntervalKind::Day:
case IntervalKind::Day: /// weird why Day leads to DateTime but too afraid to change it
result_type = ResultType::DateTime;
break;
case IntervalKind::Week:
@ -101,31 +106,26 @@ public:
}
};
enum class ThirdArgument
{
IsTimezone,
IsOrigin
};
ThirdArgument third_argument; /// valid only if 3rd argument is given
auto check_third_argument = [&]
{
const DataTypePtr & type_arg3 = arguments[2].type;
if (isString(type_arg3))
{
third_argument = ThirdArgument::IsTimezone;
overload = Overload::Default;
if (value_is_date && result_type == ResultType::Date)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"The timezone argument of function {} with interval type {} is allowed only when the 1st argument has the type DateTime or DateTime64",
getName(), interval_type->getKind().toString());
}
else if (isDateOrDate32OrDateTimeOrDateTime64(type_arg3))
else if (isDateTimeOrDateTime64(type_arg3) || isDate(type_arg3))
{
third_argument = ThirdArgument::IsOrigin;
overload = Overload::Origin;
if (isDateTime64(arguments[0].type) && isDateTime64(arguments[2].type))
result_type = ResultType::DateTime64;
else if (isDateTime(arguments[0].type) && isDateTime(arguments[2].type))
result_type = ResultType::DateTime;
else if ((isDate(arguments[0].type) || isDate32(arguments[0].type)) && (isDate(arguments[2].type) || isDate32(arguments[2].type)))
else if (isDate(arguments[0].type) && isDate(arguments[2].type))
result_type = ResultType::Date;
else
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Datetime argument and origin argument for function {} must have the same type", getName());
@ -138,7 +138,7 @@ public:
auto check_fourth_argument = [&]
{
if (third_argument != ThirdArgument::IsOrigin) /// sanity check
if (overload != Overload::Origin) /// sanity check
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 3rd argument of function {}. "
"The third argument must a Date/Date32/DateTime/DateTime64 with a constant origin",
arguments[2].type->getName(), getName());
@ -185,7 +185,7 @@ public:
return std::make_shared<DataTypeDate>();
case ResultType::DateTime:
{
const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && third_argument == ThirdArgument::IsTimezone)) ? 2 : 3;
const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3;
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false));
}
case ResultType::DateTime64:
@ -200,7 +200,7 @@ public:
else if (interval_type->getKind() == IntervalKind::Millisecond)
scale = 3 > scale ? 3 : scale;
const size_t time_zone_arg_num = (arguments.size() == 2 || (arguments.size() == 3 && third_argument == ThirdArgument::IsTimezone)) ? 2 : 3;
const size_t time_zone_arg_num = (overload == Overload::Default) ? 2 : 3;
return std::make_shared<DataTypeDateTime64>(scale, extractTimeZoneNameFromFunctionArguments(arguments, time_zone_arg_num, 0, false));
}
}
@ -278,25 +278,25 @@ private:
if (time_column_vec)
return dispatchForIntervalColumn<ReturnType>(assert_cast<const DataTypeDate32 &>(time_column_type), *time_column_vec, interval_column, origin_column, result_type, time_zone);
}
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for first argument of function {}. Must contain dates or dates with time", getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal column for 1st argument of function {}, expected a Date, DateTime or DateTime64", getName());
}
template <typename ReturnType, typename TimeColumnType, typename TimeDataType>
ColumnPtr dispatchForIntervalColumn(
const TimeDataType & time_data_type, const TimeColumnType & time_column, const ColumnWithTypeAndName & interval_column, const ColumnWithTypeAndName & origin_column,
const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale = 1) const
const DataTypePtr & result_type, const DateLUTImpl & time_zone, UInt16 scale = 1) const
{
const auto * interval_type = checkAndGetDataType<DataTypeInterval>(interval_column.type.get());
if (!interval_type)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for second argument of function {}, must be an interval of time.", getName());
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for 2nd argument of function {}, must be a time interval", getName());
const auto * interval_column_const_int64 = checkAndGetColumnConst<ColumnInt64>(interval_column.column.get());
if (!interval_column_const_int64)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for second argument of function {}, must be a const interval of time.", getName());
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column for 2nd argument of function {}, must be a const time interval", getName());
Int64 num_units = interval_column_const_int64->getValue<Int64>();
const Int64 num_units = interval_column_const_int64->getValue<Int64>();
if (num_units <= 0)
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Value for second argument of function {} must be positive.", getName());
throw Exception(ErrorCodes::ARGUMENT_OUT_OF_BOUND, "Value for 2nd argument of function {} must be positive", getName());
switch (interval_type->getKind()) // NOLINT(bugprone-switch-missing-default-case)
{
@ -328,12 +328,11 @@ private:
}
template <IntervalKind::Kind unit>
Int64 decideScaleOnPrecision(const UInt16 scale) const
Int64 decideScaleOnPrecision() const
{
static constexpr Int64 MILLISECOND_SCALE = 1000;
static constexpr Int64 MICROSECOND_SCALE = 1000000;
static constexpr Int64 NANOSECOND_SCALE = 1000000000;
Int64 scale_multiplier = DecimalUtils::scaleMultiplier<DateTime64>(scale);
switch (unit)
{
case IntervalKind::Millisecond:
@ -343,37 +342,41 @@ private:
case IntervalKind::Nanosecond:
return NANOSECOND_SCALE;
default:
return scale_multiplier;
return 1;
}
}
template <typename TimeDataType, typename ToDataType, IntervalKind::Kind unit, typename ColumnType>
template <typename TimeDataType, typename ResultDataType, IntervalKind::Kind unit, typename ColumnType>
ColumnPtr execute(const TimeDataType &, const ColumnType & time_column_type, Int64 num_units, const ColumnWithTypeAndName & origin_column, const DataTypePtr & result_type, const DateLUTImpl & time_zone, const UInt16 scale) const
{
using ToColumnType = typename ToDataType::ColumnType;
using ToFieldType = typename ToDataType::FieldType;
using ResultColumnType = typename ResultDataType::ColumnType;
using ResultFieldType = typename ResultDataType::FieldType;
const auto & time_data = time_column_type.getData();
size_t size = time_data.size();
auto result_col = result_type->createColumn();
auto * col_to = assert_cast<ToColumnType *>(result_col.get());
auto * col_to = assert_cast<ResultColumnType *>(result_col.get());
auto & result_data = col_to->getData();
result_data.resize(size);
Int64 scale_multiplier = DecimalUtils::scaleMultiplier<DateTime64>(scale);
Int64 scale_on_interval = decideScaleOnPrecision<unit>(scale);
Int64 scale_diff = scale_on_interval > scale_multiplier ? scale_on_interval / scale_multiplier : scale_multiplier / scale_on_interval;
Int64 scale_on_time = DecimalUtils::scaleMultiplier<DateTime64>(scale); // scale that depends on type of arguments
Int64 scale_on_interval = decideScaleOnPrecision<unit>(); // scale that depends on the Interval
/// If the scale of the time argument differs from the scale of the interval, compute the ratio between the two scales;
/// it is needed to get the right precision for the result.
Int64 scale_diff = scale_on_interval > scale_on_time ? scale_on_interval / scale_on_time : scale_on_time / scale_on_interval;
if (origin_column.column == nullptr)
{
for (size_t i = 0; i != size; ++i)
{
result_data[i] = 0;
if (scale_on_interval < scale_multiplier)
result_data[i] += static_cast<ToFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_multiplier)) * scale_diff;
if (scale_on_interval < scale_on_time)
/// If the time argument has a bigger scale than the interval can represent, we need
/// to return a value with higher precision, so multiply the result by the scale difference.
result_data[i] += static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_on_interval)) * scale_diff;
else
result_data[i] = static_cast<ToFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_multiplier));
result_data[i] = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(time_data[i], num_units, time_zone, scale_on_time));
}
}
else
@ -387,31 +390,54 @@ private:
throw Exception(ErrorCodes::BAD_ARGUMENTS, "The origin must be before the end date/datetime");
t -= origin;
auto res = static_cast<ToFieldType>(ToStartOfInterval<unit>::execute(t, num_units, time_zone, scale_multiplier));
auto res = static_cast<ResultFieldType>(ToStartOfInterval<unit>::execute(t, num_units, time_zone, scale_on_time));
static constexpr size_t SECONDS_PER_DAY = 86400;
result_data[i] = 0;
if (unit == IntervalKind::Week || unit == IntervalKind::Month || unit == IntervalKind::Quarter || unit == IntervalKind::Year)
{
/// By default, when we use week, month, quarter or year interval, we get date return type. So, simply add values.
if (isDate(result_type) || isDate32(result_type))
result_data[i] += origin + res;
else if (isDateTime64(result_type))
result_data[i] += origin + (res * SECONDS_PER_DAY * scale_multiplier);
else
/// When we use DateTime arguments, we should keep in mind that we also have hours, minutes and seconds there,
/// so we need to multiply result by amount of seconds per day.
else if (isDateTime(result_type))
result_data[i] += origin + res * SECONDS_PER_DAY;
/// When we use DateTime64 arguments, we should also multiply it by the right scale.
else
result_data[i] += origin + (res * SECONDS_PER_DAY * scale_on_time);
}
else
{
/// In this case result will be calculated as datetime, so we need to get the amount of days if the arguments are Date.
if (isDate(result_type) || isDate32(result_type))
res = res / SECONDS_PER_DAY;
if (scale_on_interval > scale_multiplier)
result_data[i] += (origin + res / scale_diff) * scale_diff;
else if (scale_on_interval == scale_multiplier && scale_on_interval % 1000 != 0 && scale_multiplier != 10)
result_data[i] += origin + (res * scale_on_interval);
/// Case when Interval has default scale
if (scale_on_interval == 1)
{
/// Case when the arguments are DateTime64 with precision like 4,5,7,8. Here res has right precision and origin doesn't.
if (scale_on_time % 1000 != 0 && scale_on_time >= 1000)
result_data[i] += (origin + res / scale_on_time) * scale_on_time;
/// Special case when the arguments are DateTime64 with precision 2. Here origin has right precision and res doesn't
else if (scale_on_time == 100)
result_data[i] += (origin + res * scale_on_time);
/// Cases when the precision of DateTime64 is 1, 3, 6 or 9, i.e. both res and origin have the right precision.
else
result_data[i] += (origin + res);
}
/// Case when Interval has some specific scale (3,6,9)
else
{
/// If the time argument has a bigger scale than the interval can represent, we need
/// to return a value with higher precision, so multiply the result by the scale difference.
if (scale_on_interval < scale_on_time)
result_data[i] += origin + res * scale_diff;
/// The other case: the interval has a bigger scale than the time argument, or they have the same scale, so res has the right precision and origin doesn't
else
result_data[i] += (origin + res / scale_diff) * scale_diff;
}
}
}
}
@ -419,8 +445,6 @@ private:
}
};
}
REGISTER_FUNCTION(ToStartOfInterval)
{
factory.registerFunction<FunctionToStartOfInterval>();