unary operator

This commit is contained in:
Alexey Boykov 2021-08-30 20:13:41 +03:00
parent 65804040e4
commit 75604817b2
3 changed files with 49 additions and 2 deletions

View File

@ -83,6 +83,8 @@ class FunctionUnaryArithmetic : public IFunction
static constexpr bool allow_fixed_string = Op<UInt8>::allow_fixed_string; static constexpr bool allow_fixed_string = Op<UInt8>::allow_fixed_string;
static constexpr bool is_sign_function = IsUnaryOperation<Op>::sign; static constexpr bool is_sign_function = IsUnaryOperation<Op>::sign;
ContextPtr context;
template <typename F> template <typename F>
static bool castType(const IDataType * type, F && f) static bool castType(const IDataType * type, F && f)
{ {
@ -109,10 +111,29 @@ class FunctionUnaryArithmetic : public IFunction
>(type, std::forward<F>(f)); >(type, std::forward<F>(f));
} }
static FunctionOverloadResolverPtr
getFunctionForTupleArithmetic(const DataTypePtr & type, ContextPtr context)
{
if (!isTuple(type))
return {};
/// Special case when the function is negate, argument is tuple.
/// We construct another function (example: tupleNegate) and call it.
if constexpr (!IsUnaryOperation<Op>::negate)
return {};
return FunctionFactory::instance().get("tupleNegate", context);
}
public: public:
static constexpr auto name = Name::name; static constexpr auto name = Name::name;
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionUnaryArithmetic>(); } static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionUnaryArithmetic>(); }
FunctionUnaryArithmetic() = default;
explicit FunctionUnaryArithmetic(ContextPtr context_) : context(context_) {}
String getName() const override String getName() const override
{ {
return name; return name;
@ -126,6 +147,22 @@ public:
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
{ {
return getReturnTypeImplStatic(arguments, context);
}
static DataTypePtr getReturnTypeImplStatic(const DataTypes & arguments, ContextPtr context)
{
/// Special case when the function is negate, argument is tuple.
if (auto function_builder = getFunctionForTupleArithmetic(arguments[0], context))
{
ColumnsWithTypeAndName new_arguments(1);
new_arguments[0].type = arguments[0];
auto function = function_builder->build(new_arguments);
return function->getResultType();
}
DataTypePtr result; DataTypePtr result;
bool valid = castType(arguments[0].get(), [&](const auto & type) bool valid = castType(arguments[0].get(), [&](const auto & type)
{ {
@ -152,13 +189,19 @@ public:
return true; return true;
}); });
if (!valid) if (!valid)
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(), throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + String(name),
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return result; return result;
} }
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /*input_rows_count*/) const override ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override
{ {
/// Special case when the function is negate, argument is tuple.
if (auto function_builder = getFunctionForTupleArithmetic(arguments[0].type, context))
{
return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count);
}
ColumnPtr result_column; ColumnPtr result_column;
bool valid = castType(arguments[0].type.get(), [&](const auto & type) bool valid = castType(arguments[0].type.get(), [&](const auto & type)
{ {

View File

@ -6,6 +6,8 @@
(-2.5,4,2.75) (-2.5,4,2.75)
(3) (3)
(-1,0,-3.5) (-1,0,-3.5)
(-1,-2,-3)
(-1)
(0.5,1,1.5) (0.5,1,1.5)
(2,5,6) (2,5,6)
(1) (1)

View File

@ -8,6 +8,8 @@ SELECT tupleDivide((5, 8, 11), (-2, 2, 4));
SELECT tuple(1) + tuple(2); SELECT tuple(1) + tuple(2);
SELECT tupleNegate((1, 0, 3.5)); SELECT tupleNegate((1, 0, 3.5));
SELECT -(1, 2, 3);
SELECT -tuple(1);
SELECT tupleMultiplyByNumber((1, 2, 3), 0.5); SELECT tupleMultiplyByNumber((1, 2, 3), 0.5);
SELECT tupleDivideByNumber((1, 2.5, 3), 0.5); SELECT tupleDivideByNumber((1, 2.5, 3), 0.5);