diff --git a/dbms/include/DB/Functions/FunctionsRound.h b/dbms/include/DB/Functions/FunctionsRound.h index be73593f149..6caa02bc28b 100644 --- a/dbms/include/DB/Functions/FunctionsRound.h +++ b/dbms/include/DB/Functions/FunctionsRound.h @@ -1,8 +1,9 @@ #pragma once #include -#include // log2() -#include +#include +#include +#include namespace { @@ -28,11 +29,11 @@ struct PowerOf10<0> template struct TablePowersOf10 { - static const size_t value[sizeof...(TArgs)]; + static const std::array value; }; template -const size_t TablePowersOf10::value[sizeof...(TArgs)] = { TArgs... }; +const std::array TablePowersOf10::value = { TArgs... }; /// Сгенерить первые N степеней. @@ -54,8 +55,7 @@ struct FillArray using result = typename FillArrayImpl::result; }; -static const size_t powers_count = 16; -using powers_of_10 = FillArray::result; +using powers_of_10 = FillArray<16>::result; } @@ -145,44 +145,110 @@ namespace DB } }; + template struct RoundImpl { - static inline double apply(double x) + static inline T apply(T val) { - return round(x); + throw Exception("Invalid invokation", ErrorCodes::LOGICAL_ERROR); } }; + template<> + struct RoundImpl + { + static inline Float32 apply(Float32 val) + { + return roundf(val); + } + }; + + template<> + struct RoundImpl + { + static inline Float64 apply(Float64 val) + { + return round(val); + } + }; + + template struct CeilImpl { - static inline double apply(double x) + static inline T apply(T val) { - return ceil(x); + throw Exception("Invalid invokation", ErrorCodes::LOGICAL_ERROR); } }; + template<> + struct CeilImpl + { + static inline Float32 apply(Float32 val) + { + return ceilf(val); + } + }; + + template<> + struct CeilImpl + { + static inline Float64 apply(Float64 val) + { + return ceil(val); + } + }; + + template struct FloorImpl { - static inline double apply(double x) + static inline T apply(T val) { - return floor(x); + throw Exception("Invalid invokation", ErrorCodes::LOGICAL_ERROR); } }; - template + template<> + struct FloorImpl + { + static inline Float32 apply(Float32 val) + { + return floorf(val); + } + }; + + template<> + struct FloorImpl + { + static inline Float64 apply(Float64 val) + { + return floor(val); + } + }; + + template class Op, typename PowersTable> struct FunctionApproximatingImpl { - static inline A apply(A a, Int8 scale) + template + static inline A2 apply(A2 a, UInt8 scale, typename std::enable_if::value>::type * = nullptr) { - if (scale < 0) - scale = 0; + if (a == 0) + return a; + else + { + size_t power = PowersTable::value[scale]; + return Op::apply(a * power) / power; + } + } - size_t power = (scale < static_cast(powers_count)) ? powers_of_10::value[scale] : pow(10, scale); - return static_cast(Op::apply(a * power) / power); + template + static inline A2 apply(A2 a, UInt8 scale, typename std::enable_if::value>::type * = nullptr) + { + return a; } }; - template + template class Op, typename PowersTable, typename Name> class FunctionApproximating : public IFunction { public: @@ -190,7 +256,13 @@ namespace DB static IFunction * create(const Context & context) { return new FunctionApproximating; } private: - template + template + bool checkType(const IDataType * type) const + { + return typeid_cast(type) != nullptr; + } + + template bool executeType(Block & block, const ColumnNumbers & arguments, Int8 scale, size_t result) { if (ColumnVector * col = typeid_cast *>(&*block.getByPosition(arguments[0]).column)) @@ -204,13 +276,13 @@ namespace DB const PODArray & a = col->getData(); size_t size = a.size(); for (size_t i = 0; i < size; ++i) - vec_res[i] = FunctionApproximatingImpl::apply(a[i], scale); + vec_res[i] = FunctionApproximatingImpl::apply(a[i], scale); return true; } else if (ColumnConst * col = typeid_cast *>(&*block.getByPosition(arguments[0]).column)) { - T0 res = FunctionApproximatingImpl::apply(col->getData(), scale); + T0 res = FunctionApproximatingImpl::apply(col->getData(), scale); ColumnConst * col_res = new ColumnConst(col->size(), res); block.getByPosition(result).column = col_res; @@ -221,6 +293,48 @@ namespace DB return false; } + template + bool getScaleForType(const ColumnPtr & column, UInt8 & scale) + { + using ColumnType = ColumnConst; + + const ColumnType * scale_col = typeid_cast(&*column); + if (scale_col == nullptr) + return false; + + T val = scale_col->getData(); + if (std::is_signed::value && (val < 0)) + val = 0; + else if (val >= static_cast(PowersTable::value.size())) + val = static_cast(PowersTable::value.size()) - 1; + + scale = static_cast(val); + + return true; + } + + UInt8 getScale(const ColumnPtr & column) + { + UInt8 scale = 0; + + if (!( getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale) + || getScaleForType(column, scale))) + { + throw Exception("Illegal column " + column->getName() + + " of second ('scale') argument of function " + getName(), + ErrorCodes::ILLEGAL_COLUMN); + } + + return scale; + } + public: /// Получить имя функции. String getName() const override @@ -236,8 +350,21 @@ namespace DB + toString(arguments.size()) + ", should be 1.", ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); - if ((arguments.size() == 2) && (arguments[1]->getName() != TypeName::get())) - throw Exception("Illegal type in second argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + if (arguments.size() == 2) + { + const IDataType * type = &*arguments[1]; + if (!( checkType(type) + || checkType(type) + || checkType(type) + || checkType(type) + || checkType(type) + || checkType(type) + || checkType(type) + || checkType(type))) + { + throw Exception("Illegal type in second argument", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + } const IDataType * type = &*arguments[0]; if (!type->behavesAsNumber()) @@ -252,19 +379,9 @@ namespace DB /// Выполнить функцию над блоком. void execute(Block & block, const ColumnNumbers & arguments, size_t result) override { - Int8 scale = 0; + UInt8 scale = 0; if (arguments.size() == 2) - { - const ColumnConst * scale_col = typeid_cast(&*block.getByPosition(arguments[1]).column); - - if (scale_col == nullptr) - throw Exception("Illegal column " + block.getByPosition(arguments[1]).column->getName() - + " of second ('scale') argument of function " + name - + ". Must be constant int8.", - ErrorCodes::ILLEGAL_COLUMN); - - scale = scale_col->getData(); - } + scale = getScale(block.getByPosition(arguments[1]).column); if (!( executeType(block, arguments, scale, result) || executeType(block, arguments, scale, result) @@ -292,7 +409,7 @@ namespace DB typedef FunctionUnaryArithmetic FunctionRoundToExp2; typedef FunctionUnaryArithmetic FunctionRoundDuration; typedef FunctionUnaryArithmetic FunctionRoundAge; - typedef FunctionApproximating FunctionRound; - typedef FunctionApproximating FunctionCeil; - typedef FunctionApproximating FunctionFloor; + typedef FunctionApproximating FunctionRound; + typedef FunctionApproximating FunctionCeil; + typedef FunctionApproximating FunctionFloor; }