From c781908e6dbeaaa553c72a31a3e03630494f07f2 Mon Sep 17 00:00:00 2001 From: Vasily Nemkov Date: Wed, 11 Dec 2019 11:56:32 +0300 Subject: [PATCH] Post-PR fixes * More precise overflow check in readIntTextImpl * writeDateTimeText now always writes sub-second part for DateTime64 * comment for validateFunctionArgumentTypes * DateTime64-related fixes for FunctionConvertFromString * other minoe fixes: comments, removed commented out code, variable renamings, etc. --- dbms/src/Functions/FunctionHelpers.cpp | 2 +- dbms/src/Functions/FunctionHelpers.h | 17 ++--- dbms/src/Functions/FunctionsConversion.h | 16 +++-- dbms/src/IO/ReadHelpers.h | 63 ++++++++++--------- dbms/src/IO/WriteHelpers.h | 4 +- .../00921_datetime64_basic.reference | 2 +- 6 files changed, 55 insertions(+), 49 deletions(-) diff --git a/dbms/src/Functions/FunctionHelpers.cpp b/dbms/src/Functions/FunctionHelpers.cpp index 50ef38ddd4a..a24ab0fa17e 100644 --- a/dbms/src/Functions/FunctionHelpers.cpp +++ b/dbms/src/Functions/FunctionHelpers.cpp @@ -136,7 +136,7 @@ void validateArgumentsImpl(const IFunction & func, const auto & arg = arguments[i + argument_offset]; const auto validator = validators[i]; - if (validator.validator_func(*arg.type) == false) + if (!validator.validator_func(*arg.type)) throw Exception("Illegal type " + arg.type->getName() + " of " + std::to_string(i) + " argument of function " + func.getName() + diff --git a/dbms/src/Functions/FunctionHelpers.h b/dbms/src/Functions/FunctionHelpers.h index 29c33e95703..933b3c83ff3 100644 --- a/dbms/src/Functions/FunctionHelpers.h +++ b/dbms/src/Functions/FunctionHelpers.h @@ -24,17 +24,6 @@ const Type * checkAndGetDataType(const IDataType * data_type) return typeid_cast(data_type); } -template -std::shared_ptr checkAndGetDataTypePtr(const DataTypePtr & data_type) -{ - if (typeid_cast(data_type.get())) - { - return std::static_pointer_cast(data_type); - } - - return std::shared_ptr(); -} - template const ColumnConst * checkAndGetColumnConst(const IColumn * column) { @@ -109,6 +98,12 @@ struct FunctionArgumentTypeValidator using FunctionArgumentTypeValidators = std::vector; +/** Validate that function arguments match specification. + * first, check that mandatory args present and have valid type. + * second, check optional arguents types, skipping ones that are missing. + * + * If any mandatory arg is missing, throw an exception, with explicit description of expected arguments. + */ void validateFunctionArgumentTypes(const IFunction & func, const ColumnsWithTypeAndName & arguments, const FunctionArgumentTypeValidators & mandatory_args, const FunctionArgumentTypeValidators & optional_args = {}); /// Checks if a list of array columns have equal offsets. Return a pair of nested columns and offsets if true, otherwise throw. diff --git a/dbms/src/Functions/FunctionsConversion.h b/dbms/src/Functions/FunctionsConversion.h index c7a2ad30b65..d1d0b33edfa 100644 --- a/dbms/src/Functions/FunctionsConversion.h +++ b/dbms/src/Functions/FunctionsConversion.h @@ -254,7 +254,7 @@ template struct ConvertImpl struct FromDateTime64Transform { - static constexpr auto name = "toDateTime64"; + static constexpr auto name = Transform::name; const DateTime64::NativeType scale_multiplier = 1; @@ -934,6 +934,7 @@ public: } else { + // Optional second argument with time zone for DateTime. UInt8 timezone_arg_position = 1; UInt32 scale [[maybe_unused]] = DataTypeDateTime64::default_scale; @@ -1079,8 +1080,7 @@ public: static constexpr bool to_decimal = std::is_same_v> || std::is_same_v> || - std::is_same_v> || - std::is_same_v; + std::is_same_v>; static FunctionPtr create(const Context &) { return std::make_shared(); } @@ -1144,7 +1144,7 @@ public: res = std::make_shared(extractTimeZoneNameFromFunctionArguments(arguments, 1, 0)); else if constexpr (to_decimal) { - UInt64 scale [[maybe_unused]] = extractToDecimalScale(arguments[1]); + UInt64 scale = extractToDecimalScale(arguments[1]); if constexpr (std::is_same_v>) res = createDecimal(9, scale); @@ -1156,6 +1156,12 @@ public: if (!res) throw Exception("Someting wrong with toDecimalNNOrZero() or toDecimalNNOrNull()", ErrorCodes::LOGICAL_ERROR); } + else if constexpr (std::is_same_v) + { + UInt64 scale = extractToDecimalScale(arguments[1]); + const auto timezone = extractTimeZoneNameFromFunctionArguments(arguments, 2, 0); + res = std::make_shared(scale, timezone); + } else res = std::make_shared(); @@ -1170,7 +1176,7 @@ public: const IDataType * from_type = block.getByPosition(arguments[0]).type.get(); bool ok = true; - if constexpr (to_decimal) + if constexpr (to_decimal || std::is_same_v) { if (arguments.size() != 2) throw Exception{"Function " + getName() + " expects 2 arguments for Decimal.", ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION}; diff --git a/dbms/src/IO/ReadHelpers.h b/dbms/src/IO/ReadHelpers.h index 9bd2a81bac6..714e3f28f52 100644 --- a/dbms/src/IO/ReadHelpers.h +++ b/dbms/src/IO/ReadHelpers.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -256,14 +257,14 @@ inline void readBoolTextWord(bool & x, ReadBuffer & buf) } } -enum ReadIntTextCheckOverflow +enum class ReadIntTextCheckOverflow { - READ_INT_DO_NOT_CHECK_OVERFLOW, - READ_INT_CHECK_OVERFLOW, + DO_NOT_CHECK_OVERFLOW, + CHECK_OVERFLOW, }; -template -ReturnType readIntTextImpl(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow check_overflow = READ_INT_DO_NOT_CHECK_OVERFLOW) +template +ReturnType readIntTextImpl(T & x, ReadBuffer & buf) { static constexpr bool throw_exception = std::is_same_v; @@ -277,7 +278,7 @@ ReturnType readIntTextImpl(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow che return ReturnType(false); } - size_t initial = buf.count(); + const size_t initial_pos = buf.count(); while (!buf.eof()) { switch (*buf.position()) @@ -285,7 +286,7 @@ ReturnType readIntTextImpl(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow che case '+': break; case '-': - if (is_signed_v) + if constexpr (is_signed_v) negative = true; else { @@ -305,44 +306,48 @@ ReturnType readIntTextImpl(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow che case '7': [[fallthrough]]; case '8': [[fallthrough]]; case '9': + if constexpr (check_overflow == ReadIntTextCheckOverflow::CHECK_OVERFLOW) + { + // perform relativelly slow overflow check only when number of decimal digits so far is close to the max for given type. + if (buf.count() - initial_pos >= std::numeric_limits::max_digits10) + { + if (common::mulOverflow(res, static_cast(10), res) + || common::addOverflow(res, static_cast(*buf.position() - '0'), res)) + return ReturnType(false); + break; + } + } res *= 10; res += *buf.position() - '0'; break; default: goto end; -// x = negative ? -res : res; -// return ReturnType(true); } ++buf.position(); } end: x = negative ? -res : res; - if (check_overflow && buf.count() - initial > std::numeric_limits::digits10) - { - // the int literal is too big and x overflowed - return ReturnType(false); - } return ReturnType(true); } -template -void readIntText(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow check_overflow = READ_INT_DO_NOT_CHECK_OVERFLOW) +template +void readIntText(T & x, ReadBuffer & buf) { - readIntTextImpl(x, buf, check_overflow); + readIntTextImpl(x, buf); } -template -bool tryReadIntText(T & x, ReadBuffer & buf, ReadIntTextCheckOverflow check_overflow = READ_INT_DO_NOT_CHECK_OVERFLOW) +template +bool tryReadIntText(T & x, ReadBuffer & buf) { - return readIntTextImpl(x, buf, check_overflow); + return readIntTextImpl(x, buf); } -template -void readIntText(Decimal & x, ReadBuffer & buf, ReadIntTextCheckOverflow check_overflow = READ_INT_DO_NOT_CHECK_OVERFLOW) +template +void readIntText(Decimal & x, ReadBuffer & buf) { - readIntText(x.value, buf, check_overflow); + readIntText(x.value, buf); } /** More efficient variant (about 1.5 times on real dataset). @@ -642,7 +647,7 @@ inline ReturnType readDateTimeTextImpl(time_t & datetime, ReadBuffer & buf, cons } else /// Why not readIntTextUnsafe? Because for needs of AdFox, parsing of unix timestamp with leading zeros is supported: 000...NNNN. - return readIntTextImpl(datetime, buf, READ_INT_CHECK_OVERFLOW); + return readIntTextImpl(datetime, buf); } else return readDateTimeTextFallback(datetime, buf, date_lut); @@ -662,8 +667,8 @@ inline ReturnType readDateTimeTextImpl(DateTime64 & datetime64, UInt32 scale, Re if (!buf.eof() && *buf.position() == '.') { buf.ignore(1); // skip separator - const auto count1 = buf.count(); - if (!tryReadIntText(c.fractional, buf, READ_INT_CHECK_OVERFLOW)) + const auto pos_before_fractional = buf.count(); + if (!tryReadIntText(c.fractional, buf)) { return ReturnType(false); } @@ -674,7 +679,7 @@ inline ReturnType readDateTimeTextImpl(DateTime64 & datetime64, UInt32 scale, Re // If scale is 3, but we read '12', promote fractional part to '120'. // And vice versa: if we read '1234', denote it to '123'. - const auto fractional_length = static_cast(buf.count() - count1); + const auto fractional_length = static_cast(buf.count() - pos_before_fractional); if (const auto adjust_scale = static_cast(scale) - fractional_length; adjust_scale > 0) { c.fractional *= common::exp10_i64(adjust_scale); @@ -935,11 +940,11 @@ void readAndThrowException(ReadBuffer & buf, const String & additional_message = /** Helper function for implementation. */ -template +template static inline const char * tryReadIntText(T & x, const char * pos, const char * end) { ReadBufferFromMemory in(pos, end - pos); - tryReadIntText(x, in); + tryReadIntText(x, in); return pos + in.count(); } diff --git a/dbms/src/IO/WriteHelpers.h b/dbms/src/IO/WriteHelpers.h index d9449d0da10..097489d258e 100644 --- a/dbms/src/IO/WriteHelpers.h +++ b/dbms/src/IO/WriteHelpers.h @@ -732,7 +732,7 @@ inline void writeDateTimeText(DateTime64 datetime64, UInt32 scale, WriteBuffer & LocalDateTime(values.year, values.month, values.day_of_month, date_lut.toHour(c.whole), date_lut.toMinute(c.whole), date_lut.toSecond(c.whole)), buf); - if (scale > 0 && c.fractional) + if (scale > 0) { buf.write(fractional_time_delimiter); @@ -740,7 +740,7 @@ inline void writeDateTimeText(DateTime64 datetime64, UInt32 scale, WriteBuffer & static_assert(sizeof(data) >= MaxScale); auto fractional = c.fractional; - for (Int32 pos = scale - 1; pos >= 0; --pos, fractional /= DateTime64(10)) + for (Int32 pos = scale - 1; pos >= 0 && fractional; --pos, fractional /= DateTime64(10)) data[pos] += fractional % DateTime64(10); writeString(&data[0], static_cast(scale), buf); diff --git a/dbms/tests/queries/0_stateless/00921_datetime64_basic.reference b/dbms/tests/queries/0_stateless/00921_datetime64_basic.reference index cdcf961727c..6352bd34f98 100644 --- a/dbms/tests/queries/0_stateless/00921_datetime64_basic.reference +++ b/dbms/tests/queries/0_stateless/00921_datetime64_basic.reference @@ -1,3 +1,3 @@ -2019-09-16 19:20:11 +2019-09-16 19:20:11.000 2019-05-03 11:25:25.123 2019-05-03 2019-05-02 21:00:00 2019-04-01 1970-01-02 11:25:25 2019-05-03 11:25:00 2019-09-16 19:20:11.234