diff --git a/dbms/include/DB/Functions/FunctionsArray.h b/dbms/include/DB/Functions/FunctionsArray.h index e8e0f71508d..b6262298db8 100644 --- a/dbms/include/DB/Functions/FunctionsArray.h +++ b/dbms/include/DB/Functions/FunctionsArray.h @@ -17,6 +17,8 @@ #include #include +#include + #include @@ -1226,6 +1228,7 @@ struct FunctionEmptyArray : public IFunction static const String name; static IFunction * create(const Context & context) { return new FunctionEmptyArray; } +private: String getName() const { return name; @@ -1255,6 +1258,118 @@ struct FunctionEmptyArray : public IFunction template const String FunctionEmptyArray::name = FunctionEmptyArray::base_name + DataTypeToName::get(); +class FunctionRange : public IFunction +{ +public: + static constexpr auto name = "range"; + static IFunction * create(const Context &) { return new FunctionRange; } + +private: + String getName() const override + { + return name; + } + + DataTypePtr getReturnType(const DataTypes & arguments) const override + { + if (arguments.size() != 1) + throw Exception{ + "Number of arguments for function " + getName() + " doesn't match: passed " + + toString(arguments.size()) + ", should be 1.", + ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH + }; + + const auto arg = arguments.front().get(); + + if (!typeid_cast(arg) && + !typeid_cast(arg) && + !typeid_cast(arg) & + !typeid_cast(arg)) + { + throw Exception{ + "Illegal type " + arg->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT + }; + } + + return new DataTypeArray{arg->clone()}; + } + + template + bool execute(Block & block, const IColumn * const arg, const size_t result) override + { + if (const auto in = typeid_cast *>(arg)) + { + const auto & in_data = in->getData(); + const auto total_values = std::accumulate(std::begin(in_data), std::end(in_data), std::size_t{}, + std::plus{}); + + const auto data_col = new ColumnVector{total_values}; + const auto out = new ColumnArray{ + data_col, + new ColumnArray::ColumnOffsets_t{in->size()} + }; + block.getByPosition(result).column = out; + + auto & out_data = data_col->getData(); + auto & out_offsets = out->getOffsets(); + + IColumn::Offset_t offset{}; + for (const auto i : ext::range(0, in->size())) + { + std::copy(ext::make_range_iterator(T{}), ext::make_range_iterator(in_data[i]), &out_data[offset]); + offset += in_data[i]; + out_offsets[i] = offset; + } + + return true; + } + else if (const auto in = typeid_cast *>(arg)) + { + const auto & in_data = in->getData(); + const std::size_t total_values = in->size() * in_data; + + const auto data_col = new ColumnVector{total_values}; + const auto out = new ColumnArray{ + data_col, + new ColumnArray::ColumnOffsets_t{in->size()} + }; + block.getByPosition(result).column = out; + + auto & out_data = data_col->getData(); + auto & out_offsets = out->getOffsets(); + + IColumn::Offset_t offset{}; + for (const auto i : ext::range(0, in->size())) + { + std::copy(ext::make_range_iterator(T{}), ext::make_range_iterator(in_data), &out_data[offset]); + offset += in_data; + out_offsets[i] = offset; + } + + return true; + } + + return false; + } + + void execute(Block & block, const ColumnNumbers & arguments, const size_t result) override + { + const auto col = block.getByPosition(arguments[0]).column.get(); + + if (!execute(block, col, result) && + !execute(block, col, result) && + !execute(block, col, result) && + !execute(block, col, result)) + { + throw Exception{ + "Illegal column " + col->getName() + " of argument of function " + getName(), + ErrorCodes::ILLEGAL_COLUMN + }; + } + } +}; + struct NameHas { static constexpr auto name = "has"; }; struct NameIndexOf { static constexpr auto name = "indexOf"; }; diff --git a/dbms/include/DB/Functions/FunctionsMath.h b/dbms/include/DB/Functions/FunctionsMath.h index 9a4a8e8ba78..e74d73f2d28 100644 --- a/dbms/include/DB/Functions/FunctionsMath.h +++ b/dbms/include/DB/Functions/FunctionsMath.h @@ -230,19 +230,17 @@ private: } template - bool executeRight(Block & block, const ColumnNumbers & arguments, const size_t result, - const ColumnConst * const left_arg) + bool executeRight(Block & block, const size_t result, const ColumnConst * const left_arg, + const IColumn * const right_arg) { - const auto arg = block.getByPosition(arguments[1]).column.get(); - - if (const auto right_arg = typeid_cast *>(arg)) + if (const auto right_arg_typed = typeid_cast *>(right_arg)) { const auto dst = new ColumnVector; block.getByPosition(result).column = dst; LeftType left_src_data[Impl::rows_per_iteration]; std::fill(std::begin(left_src_data), std::end(left_src_data), left_arg->getData()); - const auto & right_src_data = right_arg->getData(); + const auto & right_src_data = right_arg_typed->getData(); const auto src_size = right_src_data.size(); auto & dst_data = dst->getData(); dst_data.resize(src_size); @@ -267,10 +265,10 @@ private: return true; } - else if (const auto right_arg = typeid_cast *>(arg)) + else if (const auto right_arg_typed = typeid_cast *>(right_arg)) { const LeftType left_src[Impl::rows_per_iteration] { left_arg->getData() }; - const RightType right_src[Impl::rows_per_iteration] { right_arg->getData() }; + const RightType right_src[Impl::rows_per_iteration] { right_arg_typed->getData() }; Float64 dst[Impl::rows_per_iteration]; Impl::execute(left_src, right_src, dst); @@ -284,18 +282,16 @@ private: } template - bool executeRight(Block & block, const ColumnNumbers & arguments, const size_t result, - const ColumnVector * const left_arg) + bool executeRight(Block & block, const size_t result, const ColumnVector * const left_arg, + const IColumn * const right_arg) { - const auto arg = block.getByPosition(arguments[1]).column.get(); - - if (const auto right_arg = typeid_cast *>(arg)) + if (const auto right_arg_typed = typeid_cast *>(right_arg)) { const auto dst = new ColumnVector; block.getByPosition(result).column = dst; const auto & left_src_data = left_arg->getData(); - const auto & right_src_data = right_arg->getData(); + const auto & right_src_data = right_arg_typed->getData(); const auto src_size = left_src_data.size(); auto & dst_data = dst->getData(); dst_data.resize(src_size); @@ -323,14 +319,14 @@ private: return true; } - else if (const auto right_arg = typeid_cast *>(arg)) + else if (const auto right_arg_typed = typeid_cast *>(right_arg)) { const auto dst = new ColumnVector; block.getByPosition(result).column = dst; const auto & left_src_data = left_arg->getData(); RightType right_src_data[Impl::rows_per_iteration]; - std::fill(std::begin(right_src_data), std::end(right_src_data), right_arg->getData()); + std::fill(std::begin(right_src_data), std::end(right_src_data), right_arg_typed->getData()); const auto src_size = left_src_data.size(); auto & dst_data = dst->getData(); dst_data.resize(src_size); @@ -360,21 +356,23 @@ private: } template class LeftColumnType> - bool executeLeftImpl(Block & block, const ColumnNumbers & arguments, const size_t result) + bool executeLeftImpl(Block & block, const ColumnNumbers & arguments, const size_t result, + const IColumn * const left_arg) { - if (const auto arg = typeid_cast *>( - block.getByPosition(arguments[0]).column.get())) + if (const auto left_arg_typed = typeid_cast *>(left_arg)) { - if (executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg) || - executeRight(block, arguments, result, arg)) + const auto right_arg = block.getByPosition(arguments[1]).column.get(); + + if (executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg) || + executeRight(block, result, left_arg_typed, right_arg)) { return true; } @@ -392,10 +390,11 @@ private: } template - bool executeLeft(Block & block, const ColumnNumbers & arguments, const size_t result) + bool executeLeft(Block & block, const ColumnNumbers & arguments, const size_t result, + const IColumn * const left_arg) { - if (executeLeftImpl(block, arguments, result) || - executeLeftImpl(block, arguments, result)) + if (executeLeftImpl(block, arguments, result, left_arg) || + executeLeftImpl(block, arguments, result, left_arg)) return true; return false; @@ -403,19 +402,21 @@ private: void execute(Block & block, const ColumnNumbers & arguments, const size_t result) override { - if (!executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result) && - !executeLeft(block, arguments, result)) + const auto left_arg = block.getByPosition(arguments[0]).column.get(); + + if (!executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg) && + !executeLeft(block, arguments, result, left_arg)) { throw Exception{ - "Illegal column " + block.getByPosition(arguments[0]).column->getName() + " of argument of function " + getName(), + "Illegal column " + left_arg->getName() + " of argument of function " + getName(), ErrorCodes::ILLEGAL_COLUMN }; } diff --git a/dbms/src/Functions/FunctionsArray.cpp b/dbms/src/Functions/FunctionsArray.cpp index 15c79eeb454..80aacd534a6 100644 --- a/dbms/src/Functions/FunctionsArray.cpp +++ b/dbms/src/Functions/FunctionsArray.cpp @@ -26,6 +26,7 @@ void registerFunctionsArray(FunctionFactory & factory) factory.registerFunction(); factory.registerFunction(); factory.registerFunction(); + factory.registerFunction(); } } diff --git a/dbms/tests/queries/0_stateless/00087_math_functions.reference b/dbms/tests/queries/0_stateless/00087_math_functions.reference new file mode 100644 index 00000000000..f018a0a276a --- /dev/null +++ b/dbms/tests/queries/0_stateless/00087_math_functions.reference @@ -0,0 +1,78 @@ +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/dbms/tests/queries/0_stateless/00087_math_functions.sql b/dbms/tests/queries/0_stateless/00087_math_functions.sql new file mode 100644 index 00000000000..ad02b5502d3 --- /dev/null +++ b/dbms/tests/queries/0_stateless/00087_math_functions.sql @@ -0,0 +1,96 @@ +select abs(0) = 0; +select abs(1) = 1; +select abs(1) = 1; +select abs(0.0) = 0; +select abs(1.0) = 1.0; +select abs(-1.0) = 1.0; +select abs(-128) = 128; +select abs(127) = 127; +select sum(abs(number - 10 as x) = (x < 0 ? -x : x)) / count() array join range(1000000) as number; + +select square(0) = 0; +select square(1) = 1; +select square(2) = 4; +select sum(square(x) = x * x) / count() array join range(1000000) as x; + +select sqrt(0) = 0; +select sqrt(1) = 1; +select sqrt(4) = 2; +select sum(sqrt(square(x)) = x) / count() array join range(1000000) as x; + +select cbrt(0) = 0; +select cbrt(1) = 1; +select cbrt(8) = 2; +select sum(abs(cbrt(x * square(x)) - x) < 1.0e-9) / count() array join range(1000000) as x; + +select pow(1, 0) = 1; +select pow(2, 0) = 1; +select sum(pow(x, 0) = 1) / count() array join range(1000000) as x; +select pow(1, 1) = 1; +select pow(2, 1) = 2; +select sum(abs(pow(x, 1) - x) < 1.0e-9) / count() array join range(1000000) as x; +select sum(pow(x, 2) = square(x)) / count() array join range(10000) as x; + +select tgamma(0) = inf; +select tgamma(1) = 1; +select tgamma(2) = 1; +select tgamma(3) = 2; +select tgamma(4) = 6; + +select sum(abs(lgamma(x + 1) - log(tgamma(x + 1))) < 1.0e-9) / count() array join range(10) as x; + +select abs(e() - arraySum(arrayMap(x -> 1 / tgamma(x + 1), range(13)))) < 1.0e-9; + +select log(0) = -inf; +select log(1) = 0; +select log(e()) = 1; +select log(exp(1)) = 1; +select log(exp(2)) = 2; +select sum(abs(log(exp(x)) - x) < 1.0e-9) / count() array join range(100) as x; + +select exp2(-1) = 1/2; +select exp2(0) = 1; +select exp2(1) = 2; +select exp2(2) = 4; +select exp2(3) = 8; +select sum(exp2(x) = pow(2, x)) / count() array join range(1000) as x; + +select log2(0) = -inf; +select log2(1) = 0; +select log2(2) = 1; +select log2(4) = 2; +select sum(abs(log2(exp2(x)) - x) < 1.0e-9) / count() array join range(1000) as x; + +select sin(0) = 0; +select sin(pi() / 4) = 1 / sqrt(2); +select sin(pi() / 2) = 1; +select sin(3 * pi() / 2) = -1; +select sum(sin(pi() / 2 + 2 * pi() * x) = 1) / count() array join range(1000000) as x; + +select cos(0) = 1; +select abs(cos(pi() / 4) - 1 / sqrt(2)) < 1.0e-9; +select cos(pi() / 2) < 1.0e-9; +select sum(abs(cos(2 * pi() * x)) - 1 < 1.0e-9) / count() array join range(1000000) as x; + +select tan(0) = 0; +select abs(tan(pi() / 4) - 1) < 1.0e-9; +select sum(abs(tan(pi() / 4 + 2 * pi() * x) - 1) < 1.0e-8) / count() array join range(1000000) as x; + +select asin(0) = 0; +select asin(1) = pi() / 2; +select asin(-1) = -pi() / 2; + +select acos(0) = pi() / 2; +select acos(1) = 0; +select acos(-1) = pi(); + +select atan(0) = 0; +select atan(1) = pi() / 4; + +select erf(0) = 0; +select erf(-10) = -1; +select erf(10) = 1; + +select erfc(0) = 1; +select erfc(-10) = 2; +select erfc(28) = 0;