From b3bd6b5b29fb59b9b19da280a54987c18f8abd7f Mon Sep 17 00:00:00 2001 From: Alexey Boykov Date: Fri, 20 Aug 2021 17:06:57 +0300 Subject: [PATCH] Operators, refactoring Added operators, tupleHammingDistance has been refactored --- src/Functions/FunctionBinaryArithmetic.h | 51 +++ src/Functions/IsOperation.h | 2 +- src/Functions/TupleIFunction.h | 25 ++ src/Functions/registerFunctionsMath.cpp | 2 + src/Functions/tupleHammingDistance.cpp | 72 ++-- src/Functions/vectorFunctions.cpp | 333 ++++++++++++++++++ .../02011_tuple_vector_functions.reference | 5 + .../02011_tuple_vector_functions.sql | 6 + 8 files changed, 453 insertions(+), 43 deletions(-) create mode 100644 src/Functions/TupleIFunction.h create mode 100644 src/Functions/vectorFunctions.cpp create mode 100644 tests/queries/0_stateless/02011_tuple_vector_functions.reference create mode 100644 tests/queries/0_stateless/02011_tuple_vector_functions.sql diff --git a/src/Functions/FunctionBinaryArithmetic.h b/src/Functions/FunctionBinaryArithmetic.h index fc36541a078..8952e2a4b2a 100644 --- a/src/Functions/FunctionBinaryArithmetic.h +++ b/src/Functions/FunctionBinaryArithmetic.h @@ -645,6 +645,36 @@ class FunctionBinaryArithmetic : public IFunction return FunctionFactory::instance().get(function_name, context); } + static FunctionOverloadResolverPtr + getFunctionForTupleArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context) + { + std::cout << "Into getFunctionForTupleArithmetic" << std::endl; + if (!isTuple(type0) || !isTuple(type1)) + return {}; + + /// Special case when the function is plus, minus or multiply, both arguments are tuples. + /// We construct another function (example: tuplePlus) and call it. 
+ + if constexpr (!is_plus && !is_minus && !is_multiply) + return {}; + + std::string function_name; + if (is_plus) + { + function_name = "tuplePlus"; + } + else if (is_minus) + { + function_name = "tupleMinus"; + } + else + { + function_name = "tupleMultiply"; + } + + return FunctionFactory::instance().get(function_name, context); + } + static bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1) { if constexpr (!is_multiply) @@ -999,6 +1029,20 @@ public: return function->getResultType(); } + /// Special case when the function is plus, minus or multiply, both arguments are tuples. + if (auto function_builder = getFunctionForTupleArithmetic(arguments[0], arguments[1], context)) + { + std::cerr << "Tuple op" << std::endl; + ColumnsWithTypeAndName new_arguments(2); + + for (size_t i = 0; i < 2; ++i) + new_arguments[i].type = arguments[i]; + + auto function = function_builder->build(new_arguments); + return function->getResultType(); + } + + std::cerr << "Wow, it's here!" << std::endl; DataTypePtr type_res; const bool valid = castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right) @@ -1270,6 +1314,13 @@ public: return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder); } + /// Special case when the function is plus, minus or multiply, both arguments are tuples. 
+ if (auto function_builder + = getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context)) + { + return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count); + } + const auto & left_argument = arguments[0]; const auto & right_argument = arguments[1]; const auto * const left_generic = left_argument.type.get(); diff --git a/src/Functions/IsOperation.h b/src/Functions/IsOperation.h index 5b03ae3d189..369978fe271 100644 --- a/src/Functions/IsOperation.h +++ b/src/Functions/IsOperation.h @@ -3,7 +3,7 @@ namespace DB { -/// These classes should be present in DB namespace (cannot place them into nemelesspace) +/// These classes should be present in DB namespace (cannot place them into nameless namespace) template struct AbsImpl; template struct NegateImpl; template struct PlusImpl; diff --git a/src/Functions/TupleIFunction.h b/src/Functions/TupleIFunction.h new file mode 100644 index 00000000000..adc2e4aa387 --- /dev/null +++ b/src/Functions/TupleIFunction.h @@ -0,0 +1,25 @@ +#include + +namespace DB +{ +class TupleIFunction : public IFunction +{ +public: + Columns getTupleElements(const IColumn & column) const + { + if (const auto * const_column = typeid_cast(&column)) + return convertConstTupleToConstantElements(*const_column); + + if (const auto * column_tuple = typeid_cast(&column)) + { + Columns columns(column_tuple->tupleSize()); + for (size_t i = 0; i < columns.size(); ++i) + columns[i] = column_tuple->getColumnPtr(i); + return columns; + } + + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} should be tuples, got {}", + getName(), column.getName()); + } +}; +} diff --git a/src/Functions/registerFunctionsMath.cpp b/src/Functions/registerFunctionsMath.cpp index bc851de9b93..8444dafe166 100644 --- a/src/Functions/registerFunctionsMath.cpp +++ b/src/Functions/registerFunctionsMath.cpp @@ -34,6 +34,7 @@ void registerFunctionAcosh(FunctionFactory & factory); void 
registerFunctionAtanh(FunctionFactory & factory); void registerFunctionPow(FunctionFactory & factory); void registerFunctionSign(FunctionFactory & factory); +void registerVectorFunctions(FunctionFactory &); void registerFunctionsMath(FunctionFactory & factory) @@ -70,6 +71,7 @@ void registerFunctionsMath(FunctionFactory & factory) registerFunctionAtanh(factory); registerFunctionPow(factory); registerFunctionSign(factory); + registerVectorFunctions(factory); } } diff --git a/src/Functions/tupleHammingDistance.cpp b/src/Functions/tupleHammingDistance.cpp index 9d660e388cb..badf5720110 100644 --- a/src/Functions/tupleHammingDistance.cpp +++ b/src/Functions/tupleHammingDistance.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include namespace DB @@ -16,14 +16,13 @@ namespace ErrorCodes /// tupleHammingDistance function: (Tuple(...), Tuple(...))-> N /// Return the number of non-equal tuple elements -class FunctionTupleHammingDistance : public IFunction +class FunctionTupleHammingDistance : public TupleIFunction { private: ContextPtr context; public: static constexpr auto name = "tupleHammingDistance"; - using ResultType = UInt8; explicit FunctionTupleHammingDistance(ContextPtr context_) : context(context_) {} static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } @@ -34,23 +33,6 @@ public: bool useDefaultImplementationForConstants() const override { return true; } - Columns getTupleElements(const IColumn & column) const - { - if (const auto * const_column = typeid_cast(&column)) - return convertConstTupleToConstantElements(*const_column); - - if (const auto * column_tuple = typeid_cast(&column)) - { - Columns columns(column_tuple->tupleSize()); - for (size_t i = 0; i < columns.size(); ++i) - columns[i] = column_tuple->getColumnPtr(i); - return columns; - } - - throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} should be tuples, got {}", - getName(), column.getName()); - } - DataTypePtr 
getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); @@ -85,7 +67,7 @@ public: auto compare = FunctionFactory::instance().get("notEquals", context); auto plus = FunctionFactory::instance().get("plus", context); - DataTypes types(tuple_size); + DataTypePtr res_type; for (size_t i = 0; i < tuple_size; ++i) { try @@ -93,7 +75,17 @@ public: ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}}; ColumnWithTypeAndName right{right_elements.empty() ? nullptr : right_elements[i], right_types[i], {}}; auto elem_compare = compare->build(ColumnsWithTypeAndName{left, right}); - types[i] = elem_compare->getResultType(); + + if (i == 0) + { + res_type = elem_compare->getResultType(); + continue; + } + + ColumnWithTypeAndName left_type{res_type, {}}; + ColumnWithTypeAndName right_type{elem_compare->getResultType(), {}}; + auto plus_elem = plus->build({left_type, right_type}); + res_type = plus_elem->getResultType(); } catch (DB::Exception & e) { @@ -102,15 +94,6 @@ public: } } - auto res_type = types[0]; - for (size_t i = 1; i < tuple_size; ++i) - { - ColumnWithTypeAndName left{res_type, {}}; - ColumnWithTypeAndName right{types[i], {}}; - auto plus_elem = plus->build({left, right}); - res_type = plus_elem->getResultType(); - } - return res_type; } @@ -129,23 +112,28 @@ public: auto compare = FunctionFactory::instance().get("notEquals", context); auto plus = FunctionFactory::instance().get("plus", context); - ColumnsWithTypeAndName columns(tuple_size); + ColumnWithTypeAndName res; for (size_t i = 0; i < tuple_size; ++i) { ColumnWithTypeAndName left{left_elements[i], left_types[i], {}}; ColumnWithTypeAndName right{right_elements[i], right_types[i], {}}; auto elem_compare = compare->build(ColumnsWithTypeAndName{left, right}); - columns[i].type = elem_compare->getResultType(); - columns[i].column = elem_compare->execute({left, right}, 
columns[i].type, input_rows_count); - } - auto res = columns[0]; - for (size_t i = 1; i < tuple_size; ++i) - { - auto plus_elem = plus->build({res, columns[i]}); - auto res_type = plus_elem->getResultType(); - res.column = plus_elem->execute({res, columns[i]}, res_type, input_rows_count); - res.type = res_type; + ColumnWithTypeAndName column; + column.type = elem_compare->getResultType(); + column.column = elem_compare->execute({left, right}, column.type, input_rows_count); + + if (i == 0) + { + res = std::move(column); + } + else + { + auto plus_elem = plus->build({res, column}); + auto res_type = plus_elem->getResultType(); + res.column = plus_elem->execute({res, column}, res_type, input_rows_count); + res.type = res_type; + } } return res.column; diff --git a/src/Functions/vectorFunctions.cpp b/src/Functions/vectorFunctions.cpp new file mode 100644 index 00000000000..7f77d2165df --- /dev/null +++ b/src/Functions/vectorFunctions.cpp @@ -0,0 +1,333 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + +class FunctionTuplePlus : public TupleIFunction +{ +private: + ContextPtr context; + +public: + static constexpr auto name = "tuplePlus"; + + explicit FunctionTuplePlus(ContextPtr context_) : context(context_) {} + static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + + if (!left_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got 
{}", + getName(), arguments[0].type->getName()); + + if (!right_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}", + getName(), arguments[1].type->getName()); + + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + + Columns left_elements; + Columns right_elements; + if (arguments[0].column) + left_elements = getTupleElements(*arguments[0].column); + if (arguments[1].column) + right_elements = getTupleElements(*arguments[1].column); + + if (left_types.size() != right_types.size()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Expected tuples of the same size as arguments of function {}. Got {} and {}", + getName(), arguments[0].type->getName(), arguments[1].type->getName()); + + size_t tuple_size = left_types.size(); + if (tuple_size == 0) + return std::make_shared(); + + auto plus = FunctionFactory::instance().get("plus", context); + DataTypes types(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + try + { + ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements.empty() ? 
nullptr : right_elements[i], right_types[i], {}}; + auto elem_plus = plus->build(ColumnsWithTypeAndName{left, right}); + types[i] = elem_plus->getResultType(); + } + catch (DB::Exception & e) + { + e.addMessage("While executing function {} for tuple element {}", getName(), i); + throw; + } + } + + return std::make_shared(types); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + auto left_elements = getTupleElements(*arguments[0].column); + auto right_elements = getTupleElements(*arguments[1].column); + + size_t tuple_size = left_elements.size(); + if (tuple_size == 0) + return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count); + + auto plus = FunctionFactory::instance().get("plus", context); + Columns columns(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + ColumnWithTypeAndName left{left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements[i], right_types[i], {}}; + auto elem_plus = plus->build(ColumnsWithTypeAndName{left, right}); + columns[i] = elem_plus->execute({left, right}, elem_plus->getResultType(), input_rows_count); + } + + return ColumnTuple::create(columns); + } +}; + +class FunctionTupleMinus : public TupleIFunction +{ +private: + ContextPtr context; + +public: + static constexpr auto name = "tupleMinus"; + + explicit FunctionTupleMinus(ContextPtr context_) : context(context_) {} + static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const override { return true; 
} + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + + if (!left_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}", + getName(), arguments[0].type->getName()); + + if (!right_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}", + getName(), arguments[1].type->getName()); + + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + + Columns left_elements; + Columns right_elements; + if (arguments[0].column) + left_elements = getTupleElements(*arguments[0].column); + if (arguments[1].column) + right_elements = getTupleElements(*arguments[1].column); + + if (left_types.size() != right_types.size()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Expected tuples of the same size as arguments of function {}. Got {} and {}", + getName(), arguments[0].type->getName(), arguments[1].type->getName()); + + size_t tuple_size = left_types.size(); + if (tuple_size == 0) + return std::make_shared(); + + auto minus = FunctionFactory::instance().get("minus", context); + DataTypes types(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + try + { + ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements.empty() ? 
nullptr : right_elements[i], right_types[i], {}}; + auto elem_minus = minus->build(ColumnsWithTypeAndName{left, right}); + types[i] = elem_minus->getResultType(); + } + catch (DB::Exception & e) + { + e.addMessage("While executing function {} for tuple element {}", getName(), i); + throw; + } + } + + return std::make_shared(types); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + auto left_elements = getTupleElements(*arguments[0].column); + auto right_elements = getTupleElements(*arguments[1].column); + + size_t tuple_size = left_elements.size(); + if (tuple_size == 0) + return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count); + + auto minus = FunctionFactory::instance().get("minus", context); + Columns columns(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + ColumnWithTypeAndName left{left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements[i], right_types[i], {}}; + auto elem_minus = minus->build(ColumnsWithTypeAndName{left, right}); + columns[i] = elem_minus->execute({left, right}, elem_minus->getResultType(), input_rows_count); + } + + return ColumnTuple::create(columns); + } +}; + +class FunctionTupleMultiply : public TupleIFunction +{ +private: + ContextPtr context; + +public: + static constexpr auto name = "tupleMultiply"; + + explicit FunctionTupleMultiply(ContextPtr context_) : context(context_) {} + static FunctionPtr create(ContextPtr context) { return std::make_shared(context); } + + String getName() const override { return name; } + + size_t getNumberOfArguments() const override { return 2; } + + bool useDefaultImplementationForConstants() const 
override { return true; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + + if (!left_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}", + getName(), arguments[0].type->getName()); + + if (!right_tuple) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}", + getName(), arguments[1].type->getName()); + + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + + Columns left_elements; + Columns right_elements; + if (arguments[0].column) + left_elements = getTupleElements(*arguments[0].column); + if (arguments[1].column) + right_elements = getTupleElements(*arguments[1].column); + + if (left_types.size() != right_types.size()) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Expected tuples of the same size as arguments of function {}. Got {} and {}", + getName(), arguments[0].type->getName(), arguments[1].type->getName()); + + size_t tuple_size = left_types.size(); + if (tuple_size == 0) + return std::make_shared(); + + auto multiply = FunctionFactory::instance().get("multiply", context); + DataTypes types(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + try + { + ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements.empty() ? 
nullptr : right_elements[i], right_types[i], {}}; + auto elem_multiply = multiply->build(ColumnsWithTypeAndName{left, right}); + types[i] = elem_multiply->getResultType(); + } + catch (DB::Exception & e) + { + e.addMessage("While executing function {} for tuple element {}", getName(), i); + throw; + } + } + + return std::make_shared(types); + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + const auto * left_tuple = checkAndGetDataType(arguments[0].type.get()); + const auto * right_tuple = checkAndGetDataType(arguments[1].type.get()); + const auto & left_types = left_tuple->getElements(); + const auto & right_types = right_tuple->getElements(); + auto left_elements = getTupleElements(*arguments[0].column); + auto right_elements = getTupleElements(*arguments[1].column); + + size_t tuple_size = left_elements.size(); + if (tuple_size == 0) + return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count); + + auto multiply = FunctionFactory::instance().get("multiply", context); + Columns columns(tuple_size); + for (size_t i = 0; i < tuple_size; ++i) + { + ColumnWithTypeAndName left{left_elements[i], left_types[i], {}}; + ColumnWithTypeAndName right{right_elements[i], right_types[i], {}}; + auto elem_multiply = multiply->build(ColumnsWithTypeAndName{left, right}); + columns[i] = elem_multiply->execute({left, right}, elem_multiply->getResultType(), input_rows_count); + } + + return ColumnTuple::create(columns); + } +}; + +void registerVectorFunctions(FunctionFactory & factory) +{ + factory.registerFunction(); + factory.registerAlias("vectorSum", FunctionTuplePlus::name, FunctionFactory::CaseInsensitive); + factory.registerFunction(); + factory.registerAlias("vectorDifference", FunctionTupleMinus::name, FunctionFactory::CaseInsensitive); + factory.registerFunction(); + + /*factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + + 
factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + + factory.registerFunction(); + factory.registerFunction(); + factory.registerFunction(); + + factory.registerFunction();*/ +} +} diff --git a/tests/queries/0_stateless/02011_tuple_vector_functions.reference b/tests/queries/0_stateless/02011_tuple_vector_functions.reference new file mode 100644 index 00000000000..598728930e0 --- /dev/null +++ b/tests/queries/0_stateless/02011_tuple_vector_functions.reference @@ -0,0 +1,5 @@ +0 +1 +(10,3) +(-1,0) +(-23,-27) diff --git a/tests/queries/0_stateless/02011_tuple_vector_functions.sql b/tests/queries/0_stateless/02011_tuple_vector_functions.sql new file mode 100644 index 00000000000..d2c8779183a --- /dev/null +++ b/tests/queries/0_stateless/02011_tuple_vector_functions.sql @@ -0,0 +1,6 @@ +SELECT tupleHammingDistance(tuple(1), tuple(1)); +SELECT tupleHammingDistance(tuple(1, 3), tuple(1, 2)); + +SELECT tuple(1, 2) + tuple(3, 4) * tuple(5, 1) - tuple(6, 3); +SELECT vectorDifference(tuplePlus(tuple(1, 2), tuple(3, 4)), tuple(5, 6)); +SELECT tupleMinus(vectorSum(tupleMultiply(tuple(1, 2), tuple(3, 4)), tuple(5, 6)), tuple(31, 41));