From 2a0b98b19c3f514d4734f365b86102b0ff2dbcf0 Mon Sep 17 00:00:00 2001 From: Artem Zuikov Date: Thu, 27 Aug 2020 16:17:13 +0300 Subject: [PATCH] add countDigits() function --- src/Functions/countDigits.cpp | 155 ++++++++++++++++++ src/Functions/isDecimalOverflow.cpp | 48 ++---- .../registerFunctionsMiscellaneous.cpp | 2 + src/Functions/ya.make | 2 + .../0_stateless/01458_count_digits.reference | 6 + .../0_stateless/01458_count_digits.sql | 26 +++ 6 files changed, 204 insertions(+), 35 deletions(-) create mode 100644 src/Functions/countDigits.cpp create mode 100644 tests/queries/0_stateless/01458_count_digits.reference create mode 100644 tests/queries/0_stateless/01458_count_digits.sql diff --git a/src/Functions/countDigits.cpp b/src/Functions/countDigits.cpp new file mode 100644 index 00000000000..e0376d9a568 --- /dev/null +++ b/src/Functions/countDigits.cpp @@ -0,0 +1,155 @@ +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; +} + +/// Returns 1 if and Decimal value has more digits then it's Precision allow, 0 otherwise. +/// Precision could be set as second argument or omitted. If ommited function uses Decimal presicion of the first argument. +class FunctionCountDigits : public IFunction +{ +public: + static constexpr auto name = "countDigits"; + + static FunctionPtr create(const Context &) + { + return std::make_shared(); + } + + String getName() const override { return name; } + bool useDefaultImplementationForNulls() const override { return false; } + size_t getNumberOfArguments() const override { return 1; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + WhichDataType which_first(arguments[0]->getTypeId()); + + if (!which_first.isInt() && !which_first.isUInt() && !which_first.isDecimal()) + throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + return std::make_shared(); /// Up to 255 decimal digits. + } + + void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result_pos, size_t input_rows_count) const override + { + const auto & src_column = block.getByPosition(arguments[0]); + if (!src_column.column) + throw Exception("Illegal column while execute function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + auto result_column = ColumnUInt8::create(); + + auto call = [&](const auto & types) -> bool + { + using Types = std::decay_t; + using Type = typename Types::RightType; + using ColVecType = std::conditional_t, ColumnDecimal, ColumnVector>; + + if (const ColumnConst * const_column = checkAndGetColumnConst(src_column.column.get())) + { + Type const_value = checkAndGetColumn(const_column->getDataColumnPtr().get())->getData()[0]; + UInt32 num_digits = 0; + if constexpr (IsDecimalNumber) + num_digits = digits(const_value.value); + else + num_digits = digits(const_value); + result_column->getData().resize_fill(input_rows_count, num_digits); + return true; + } + else if (const ColVecType * col_vec = checkAndGetColumn(src_column.column.get())) + { + execute(*col_vec, *result_column, input_rows_count); + return true; + } + + throw Exception("Illegal column while execute function " + getName(), ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + }; + + TypeIndex dec_type_idx = src_column.type->getTypeId(); + if (!callOnBasicType(dec_type_idx, call)) + throw Exception("Wrong call for " + getName() + " with " + src_column.type->getName(), + ErrorCodes::ILLEGAL_COLUMN); + + block.getByPosition(result_pos).column = std::move(result_column); + } + +private: + template + static void execute(const ColVecType & col, ColumnUInt8 & result_column, size_t rows_count) + { + using NativeT = typename NativeType::Type; + + const auto & src_data = col.getData(); + auto & dst_data = result_column.getData(); + dst_data.resize(rows_count); + + for (size_t i = 0; i < rows_count; ++i) + { + if constexpr (IsDecimalNumber) + dst_data[i] = digits(src_data[i].value); + else + dst_data[i] = digits(src_data[i]); + } + } + + template + static UInt32 digits(T value) + { + static_assert(!IsDecimalNumber); + using DivT = std::conditional_t, Int32, UInt32>; + + UInt32 res = 0; + T tmp; + + if constexpr (sizeof(T) > sizeof(Int32)) + { + static constexpr const DivT e9 = 1000000000; + + tmp = value / e9; + while (tmp != 0) + { + value = tmp; + tmp /= e9; + res += 9; + } + } + + static constexpr const DivT e3 = 1000; + + tmp = value / e3; + while (tmp != 0) + { + value = tmp; + tmp /= e3; + res += 3; + } + + while (value != 0) + { + value /= 10; + ++res; + } + return res; + } +}; + + +void registerFunctionCountDigits(FunctionFactory & factory) +{ + factory.registerFunction(); +} + +} diff --git a/src/Functions/isDecimalOverflow.cpp b/src/Functions/isDecimalOverflow.cpp index de35689c9c5..f1b63fd0844 100644 --- a/src/Functions/isDecimalOverflow.cpp +++ b/src/Functions/isDecimalOverflow.cpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace DB @@ -37,7 +38,7 @@ public: DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { - if (arguments.size() < 1 || arguments.size() > 2) + if (arguments.empty() || arguments.size() > 2) throw Exception("Number of arguments for function " + getName() + " doesn't match: passed " + toString(arguments.size()) + ", should be 1 or 2.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); @@ -87,13 +88,12 @@ public: { using Types = std::decay_t; using Type = typename Types::RightType; - using NativeT = typename Type::NativeType; using ColVecType = ColumnDecimal; if (const ColumnConst * const_column = checkAndGetColumnConst(src_column.column.get())) { Type const_decimal = checkAndGetColumn(const_column->getDataColumnPtr().get())->getData()[0]; - UInt8 res_value = (digits(const_decimal.value) > precision); + UInt8 res_value = outOfDigits(const_decimal, precision); result_column->getData().resize_fill(input_rows_count, res_value); return true; } @@ -118,50 +118,28 @@ private: template static void execute(const ColumnDecimal & col, ColumnUInt8 & result_column, size_t rows_count, UInt32 precision) { - using NativeT = typename T::NativeType; - const auto & src_data = col.getData(); auto & dst_data = result_column.getData(); dst_data.resize(rows_count); for (size_t i = 0; i < rows_count; ++i) - dst_data[i] = (digits(src_data[i].value) > precision); + dst_data[i] = outOfDigits(src_data[i], precision); } template - static UInt32 digits(T value) + static bool outOfDigits(T dec, UInt32 precision) { - UInt32 res = 0; - T tmp; + static_assert(IsDecimalNumber); + using NativeT = typename T::NativeType; - static constexpr const Int32 e3 = 1000; - static constexpr const Int32 e9 = 1000000000; + if (precision > DecimalUtils::maxPrecision()) + return false; - if constexpr (sizeof(T) > sizeof(Int32)) - { - tmp = value / e9; - while (tmp) - { - value = tmp; - tmp /= e9; - res += 9; - } - } + NativeT pow10 = intExp10OfSize(precision); - tmp = value / e3; - while (tmp) - { - value = tmp; - tmp /= e3; - res += 3; - } - - while (value) - { - value /= 10; - ++res; - } - return res; + if (dec.value < 0) + return dec.value <= -pow10; + return dec.value >= pow10; } }; diff --git a/src/Functions/registerFunctionsMiscellaneous.cpp b/src/Functions/registerFunctionsMiscellaneous.cpp index 0a9180240a8..414f6ec5f8e 100644 --- a/src/Functions/registerFunctionsMiscellaneous.cpp +++ b/src/Functions/registerFunctionsMiscellaneous.cpp @@ -60,6 +60,7 @@ void registerFunctionGetScalar(FunctionFactory &); void registerFunctionGetSetting(FunctionFactory &); void registerFunctionIsConstant(FunctionFactory &); void registerFunctionIsDecimalOverflow(FunctionFactory &); +void registerFunctionCountDigits(FunctionFactory &); void registerFunctionGlobalVariable(FunctionFactory &); void registerFunctionHasThreadFuzzer(FunctionFactory &); void registerFunctionInitializeAggregation(FunctionFactory &); @@ -123,6 +124,7 @@ void registerFunctionsMiscellaneous(FunctionFactory & factory) registerFunctionGetSetting(factory); registerFunctionIsConstant(factory); registerFunctionIsDecimalOverflow(factory); + registerFunctionCountDigits(factory); registerFunctionGlobalVariable(factory); registerFunctionHasThreadFuzzer(factory); registerFunctionInitializeAggregation(factory); diff --git a/src/Functions/ya.make b/src/Functions/ya.make index d0c261e1265..31d5dfa9fd3 100644 --- a/src/Functions/ya.make +++ b/src/Functions/ya.make @@ -132,6 +132,7 @@ SRCS( concat.cpp convertCharset.cpp cos.cpp + countDigits.cpp CRC.cpp currentDatabase.cpp currentUser.cpp @@ -243,6 +244,7 @@ SRCS( intExp10.cpp intExp2.cpp isConstant.cpp + isDecimalOverflow.cpp isFinite.cpp isInfinite.cpp isNaN.cpp diff --git a/tests/queries/0_stateless/01458_count_digits.reference b/tests/queries/0_stateless/01458_count_digits.reference new file mode 100644 index 00000000000..46b87cd22b7 --- /dev/null +++ b/tests/queries/0_stateless/01458_count_digits.reference @@ -0,0 +1,6 @@ +0 2 2 0 2 3 0 2 4 +2 3 4 +10 10 19 19 39 39 +2 2 2 2 2 2 2 2 2 2 2 2 +0 0 0 0 0 0 0 0 0 0 0 0 +3 3 3 5 5 5 10 10 10 19 19 20 diff --git a/tests/queries/0_stateless/01458_count_digits.sql b/tests/queries/0_stateless/01458_count_digits.sql new file mode 100644 index 00000000000..91ca07469e5 --- /dev/null +++ b/tests/queries/0_stateless/01458_count_digits.sql @@ -0,0 +1,26 @@ +SELECT countDigits(toDecimal32(0, 0)), countDigits(toDecimal32(42, 0)), countDigits(toDecimal32(4.2, 1)), + countDigits(toDecimal64(0, 0)), countDigits(toDecimal64(42, 0)), countDigits(toDecimal64(4.2, 2)), + countDigits(toDecimal128(0, 0)), countDigits(toDecimal128(42, 0)), countDigits(toDecimal128(4.2, 3)); + +SELECT countDigits(materialize(toDecimal32(4.2, 1))), + countDigits(materialize(toDecimal64(4.2, 2))), + countDigits(materialize(toDecimal128(4.2, 3))); + +SELECT countDigits(toDecimal32(1, 9)), countDigits(toDecimal32(-1, 9)), + countDigits(toDecimal64(1, 18)), countDigits(toDecimal64(-1, 18)), + countDigits(toDecimal128(1, 38)), countDigits(toDecimal128(-1, 38)); + +SELECT countDigits(toInt8(42)), countDigits(toInt8(-42)), countDigits(toUInt8(42)), + countDigits(toInt16(42)), countDigits(toInt16(-42)), countDigits(toUInt16(42)), + countDigits(toInt32(42)), countDigits(toInt32(-42)), countDigits(toUInt32(42)), + countDigits(toInt64(42)), countDigits(toInt64(-42)), countDigits(toUInt64(42)); + +SELECT countDigits(toInt8(0)), countDigits(toInt8(0)), countDigits(toUInt8(0)), + countDigits(toInt16(0)), countDigits(toInt16(0)), countDigits(toUInt16(0)), + countDigits(toInt32(0)), countDigits(toInt32(0)), countDigits(toUInt32(0)), + countDigits(toInt64(0)), countDigits(toInt64(0)), countDigits(toUInt64(0)); + +SELECT countDigits(toInt8(127)), countDigits(toInt8(-128)), countDigits(toUInt8(255)), + countDigits(toInt16(32767)), countDigits(toInt16(-32768)), countDigits(toUInt16(65535)), + countDigits(toInt32(2147483647)), countDigits(toInt32(-2147483648)), countDigits(toUInt32(4294967295)), + countDigits(toInt64(9223372036854775807)), countDigits(toInt64(-9223372036854775808)), countDigits(toUInt64(18446744073709551615));