Operators, refactoring

Added operators, tupleHammingDistance has been refactored
This commit is contained in:
Alexey Boykov 2021-08-20 17:06:57 +03:00
parent 974acb615a
commit b3bd6b5b29
8 changed files with 453 additions and 43 deletions

View File

@ -645,6 +645,36 @@ class FunctionBinaryArithmetic : public IFunction
return FunctionFactory::instance().get(function_name, context);
}
static FunctionOverloadResolverPtr
getFunctionForTupleArithmetic(const DataTypePtr & type0, const DataTypePtr & type1, ContextPtr context)
{
std::cout << "Into getFunctionForTupleArithmetic" << std::endl;
if (!isTuple(type0) || !isTuple(type1))
return {};
/// Special case when the function is plus, minus or multiply, both arguments are tuples.
/// We construct another function (example: tuplePlus) and call it.
if constexpr (!is_plus && !is_minus && !is_multiply)
return {};
std::string function_name;
if (is_plus)
{
function_name = "tuplePlus";
}
else if (is_minus)
{
function_name = "tupleMinus";
}
else
{
function_name = "tupleMultiply";
}
return FunctionFactory::instance().get(function_name, context);
}
static bool isAggregateMultiply(const DataTypePtr & type0, const DataTypePtr & type1)
{
if constexpr (!is_multiply)
@ -999,6 +1029,20 @@ public:
return function->getResultType();
}
/// Special case when the function is plus, minus or multiply, both arguments are tuples.
if (auto function_builder = getFunctionForTupleArithmetic(arguments[0], arguments[1], context))
{
std::cerr << "Tuple op" << std::endl;
ColumnsWithTypeAndName new_arguments(2);
for (size_t i = 0; i < 2; ++i)
new_arguments[i].type = arguments[i];
auto function = function_builder->build(new_arguments);
return function->getResultType();
}
std::cerr << "Wow, it's here!" << std::endl;
DataTypePtr type_res;
const bool valid = castBothTypes(arguments[0].get(), arguments[1].get(), [&](const auto & left, const auto & right)
@ -1270,6 +1314,13 @@ public:
return executeDateTimeIntervalPlusMinus(arguments, result_type, input_rows_count, function_builder);
}
/// Special case when the function is plus, minus or multiply, both arguments are tuples.
if (auto function_builder
= getFunctionForTupleArithmetic(arguments[0].type, arguments[1].type, context))
{
return function_builder->build(arguments)->execute(arguments, result_type, input_rows_count);
}
const auto & left_argument = arguments[0];
const auto & right_argument = arguments[1];
const auto * const left_generic = left_argument.type.get();

View File

@ -3,7 +3,7 @@
namespace DB
{
/// These classes should be present in DB namespace (cannot place them into nemelesspace)
/// These classes should be present in DB namespace (cannot place them into namelesspace)
template <typename> struct AbsImpl;
template <typename> struct NegateImpl;
template <typename, typename> struct PlusImpl;

View File

@ -0,0 +1,25 @@
#include <Functions/IFunction.h>
namespace DB
{
class TupleIFunction : public IFunction
{
public:
Columns getTupleElements(const IColumn & column) const
{
if (const auto * const_column = typeid_cast<const ColumnConst *>(&column))
return convertConstTupleToConstantElements(*const_column);
if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(&column))
{
Columns columns(column_tuple->tupleSize());
for (size_t i = 0; i < columns.size(); ++i)
columns[i] = column_tuple->getColumnPtr(i);
return columns;
}
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} should be tuples, got {}",
getName(), column.getName());
}
};
}

View File

@ -34,6 +34,7 @@ void registerFunctionAcosh(FunctionFactory & factory);
void registerFunctionAtanh(FunctionFactory & factory);
void registerFunctionPow(FunctionFactory & factory);
void registerFunctionSign(FunctionFactory & factory);
void registerVectorFunctions(FunctionFactory &);
void registerFunctionsMath(FunctionFactory & factory)
@ -70,6 +71,7 @@ void registerFunctionsMath(FunctionFactory & factory)
registerFunctionAtanh(factory);
registerFunctionPow(factory);
registerFunctionSign(factory);
registerVectorFunctions(factory);
}
}

View File

@ -4,7 +4,7 @@
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/IFunction.h>
#include <Functions/TupleIFunction.h>
#include <Functions/castTypeToEither.h>
namespace DB
@ -16,14 +16,13 @@ namespace ErrorCodes
/// tupleHammingDistance function: (Tuple(...), Tuple(...))-> N
/// Return the number of non-equal tuple elements
class FunctionTupleHammingDistance : public IFunction
class FunctionTupleHammingDistance : public TupleIFunction
{
private:
ContextPtr context;
public:
static constexpr auto name = "tupleHammingDistance";
using ResultType = UInt8;
explicit FunctionTupleHammingDistance(ContextPtr context_) : context(context_) {}
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTupleHammingDistance>(context); }
@ -34,23 +33,6 @@ public:
bool useDefaultImplementationForConstants() const override { return true; }
Columns getTupleElements(const IColumn & column) const
{
if (const auto * const_column = typeid_cast<const ColumnConst *>(&column))
return convertConstTupleToConstantElements(*const_column);
if (const auto * column_tuple = typeid_cast<const ColumnTuple *>(&column))
{
Columns columns(column_tuple->tupleSize());
for (size_t i = 0; i < columns.size(); ++i)
columns[i] = column_tuple->getColumnPtr(i);
return columns;
}
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} should be tuples, got {}",
getName(), column.getName());
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
@ -85,7 +67,7 @@ public:
auto compare = FunctionFactory::instance().get("notEquals", context);
auto plus = FunctionFactory::instance().get("plus", context);
DataTypes types(tuple_size);
DataTypePtr res_type;
for (size_t i = 0; i < tuple_size; ++i)
{
try
@ -93,7 +75,17 @@ public:
ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements.empty() ? nullptr : right_elements[i], right_types[i], {}};
auto elem_compare = compare->build(ColumnsWithTypeAndName{left, right});
types[i] = elem_compare->getResultType();
if (i == 0)
{
res_type = elem_compare->getResultType();
continue;
}
ColumnWithTypeAndName left_type{res_type, {}};
ColumnWithTypeAndName right_type{elem_compare->getResultType(), {}};
auto plus_elem = plus->build({left_type, right_type});
res_type = plus_elem->getResultType();
}
catch (DB::Exception & e)
{
@ -102,15 +94,6 @@ public:
}
}
auto res_type = types[0];
for (size_t i = 1; i < tuple_size; ++i)
{
ColumnWithTypeAndName left{res_type, {}};
ColumnWithTypeAndName right{types[i], {}};
auto plus_elem = plus->build({left, right});
res_type = plus_elem->getResultType();
}
return res_type;
}
@ -129,24 +112,29 @@ public:
auto compare = FunctionFactory::instance().get("notEquals", context);
auto plus = FunctionFactory::instance().get("plus", context);
ColumnsWithTypeAndName columns(tuple_size);
ColumnWithTypeAndName res;
for (size_t i = 0; i < tuple_size; ++i)
{
ColumnWithTypeAndName left{left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements[i], right_types[i], {}};
auto elem_compare = compare->build(ColumnsWithTypeAndName{left, right});
columns[i].type = elem_compare->getResultType();
columns[i].column = elem_compare->execute({left, right}, columns[i].type, input_rows_count);
}
auto res = columns[0];
for (size_t i = 1; i < tuple_size; ++i)
ColumnWithTypeAndName column;
column.type = elem_compare->getResultType();
column.column = elem_compare->execute({left, right}, column.type, input_rows_count);
if (i == 0)
{
auto plus_elem = plus->build({res, columns[i]});
res = std::move(column);
}
else
{
auto plus_elem = plus->build({res, column});
auto res_type = plus_elem->getResultType();
res.column = plus_elem->execute({res, columns[i]}, res_type, input_rows_count);
res.column = plus_elem->execute({res, column}, res_type, input_rows_count);
res.type = res_type;
}
}
return res.column;
}

View File

@ -0,0 +1,333 @@
#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/TupleIFunction.h>
#include <Functions/castTypeToEither.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
class FunctionTuplePlus : public TupleIFunction
{
private:
ContextPtr context;
public:
static constexpr auto name = "tuplePlus";
explicit FunctionTuplePlus(ContextPtr context_) : context(context_) {}
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTuplePlus>(context); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
if (!left_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
getName(), arguments[0].type->getName());
if (!right_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
getName(), arguments[1].type->getName());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
Columns left_elements;
Columns right_elements;
if (arguments[0].column)
left_elements = getTupleElements(*arguments[0].column);
if (arguments[1].column)
right_elements = getTupleElements(*arguments[1].column);
if (left_types.size() != right_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expected tuples of the same size as arguments of function {}. Got {} and {}",
getName(), arguments[0].type->getName(), arguments[1].type->getName());
size_t tuple_size = left_types.size();
if (tuple_size == 0)
return std::make_shared<DataTypeUInt8>();
auto plus = FunctionFactory::instance().get("plus", context);
DataTypes types(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
try
{
ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements.empty() ? nullptr : right_elements[i], right_types[i], {}};
auto elem_plus = plus->build(ColumnsWithTypeAndName{left, right});
types[i] = elem_plus->getResultType();
}
catch (DB::Exception & e)
{
e.addMessage("While executing function {} for tuple element {}", getName(), i);
throw;
}
}
return std::make_shared<DataTypeTuple>(types);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
auto left_elements = getTupleElements(*arguments[0].column);
auto right_elements = getTupleElements(*arguments[1].column);
size_t tuple_size = left_elements.size();
if (tuple_size == 0)
return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);
auto plus = FunctionFactory::instance().get("plus", context);
Columns columns(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
ColumnWithTypeAndName left{left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements[i], right_types[i], {}};
auto elem_plus = plus->build(ColumnsWithTypeAndName{left, right});
columns[i] = elem_plus->execute({left, right}, elem_plus->getResultType(), input_rows_count);
}
return ColumnTuple::create(columns);
}
};
class FunctionTupleMinus : public TupleIFunction
{
private:
ContextPtr context;
public:
static constexpr auto name = "tupleMinus";
explicit FunctionTupleMinus(ContextPtr context_) : context(context_) {}
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTupleMinus>(context); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
if (!left_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
getName(), arguments[0].type->getName());
if (!right_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
getName(), arguments[1].type->getName());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
Columns left_elements;
Columns right_elements;
if (arguments[0].column)
left_elements = getTupleElements(*arguments[0].column);
if (arguments[1].column)
right_elements = getTupleElements(*arguments[1].column);
if (left_types.size() != right_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expected tuples of the same size as arguments of function {}. Got {} and {}",
getName(), arguments[0].type->getName(), arguments[1].type->getName());
size_t tuple_size = left_types.size();
if (tuple_size == 0)
return std::make_shared<DataTypeUInt8>();
auto minus = FunctionFactory::instance().get("minus", context);
DataTypes types(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
try
{
ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements.empty() ? nullptr : right_elements[i], right_types[i], {}};
auto elem_minus = minus->build(ColumnsWithTypeAndName{left, right});
types[i] = elem_minus->getResultType();
}
catch (DB::Exception & e)
{
e.addMessage("While executing function {} for tuple element {}", getName(), i);
throw;
}
}
return std::make_shared<DataTypeTuple>(types);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
auto left_elements = getTupleElements(*arguments[0].column);
auto right_elements = getTupleElements(*arguments[1].column);
size_t tuple_size = left_elements.size();
if (tuple_size == 0)
return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);
auto minus = FunctionFactory::instance().get("minus", context);
Columns columns(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
ColumnWithTypeAndName left{left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements[i], right_types[i], {}};
auto elem_minus = minus->build(ColumnsWithTypeAndName{left, right});
columns[i] = elem_minus->execute({left, right}, elem_minus->getResultType(), input_rows_count);
}
return ColumnTuple::create(columns);
}
};
class FunctionTupleMultiply : public TupleIFunction
{
private:
ContextPtr context;
public:
static constexpr auto name = "tupleMultiply";
explicit FunctionTupleMultiply(ContextPtr context_) : context(context_) {}
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTupleMultiply>(context); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override { return 2; }
bool useDefaultImplementationForConstants() const override { return true; }
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
if (!left_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
getName(), arguments[0].type->getName());
if (!right_tuple)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
getName(), arguments[1].type->getName());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
Columns left_elements;
Columns right_elements;
if (arguments[0].column)
left_elements = getTupleElements(*arguments[0].column);
if (arguments[1].column)
right_elements = getTupleElements(*arguments[1].column);
if (left_types.size() != right_types.size())
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Expected tuples of the same size as arguments of function {}. Got {} and {}",
getName(), arguments[0].type->getName(), arguments[1].type->getName());
size_t tuple_size = left_types.size();
if (tuple_size == 0)
return std::make_shared<DataTypeUInt8>();
auto multiply = FunctionFactory::instance().get("multiply", context);
DataTypes types(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
try
{
ColumnWithTypeAndName left{left_elements.empty() ? nullptr : left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements.empty() ? nullptr : right_elements[i], right_types[i], {}};
auto elem_multiply = multiply->build(ColumnsWithTypeAndName{left, right});
types[i] = elem_multiply->getResultType();
}
catch (DB::Exception & e)
{
e.addMessage("While executing function {} for tuple element {}", getName(), i);
throw;
}
}
return std::make_shared<DataTypeTuple>(types);
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
const auto * left_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
const auto * right_tuple = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get());
const auto & left_types = left_tuple->getElements();
const auto & right_types = right_tuple->getElements();
auto left_elements = getTupleElements(*arguments[0].column);
auto right_elements = getTupleElements(*arguments[1].column);
size_t tuple_size = left_elements.size();
if (tuple_size == 0)
return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);
auto multiply = FunctionFactory::instance().get("multiply", context);
Columns columns(tuple_size);
for (size_t i = 0; i < tuple_size; ++i)
{
ColumnWithTypeAndName left{left_elements[i], left_types[i], {}};
ColumnWithTypeAndName right{right_elements[i], right_types[i], {}};
auto elem_multiply = multiply->build(ColumnsWithTypeAndName{left, right});
columns[i] = elem_multiply->execute({left, right}, elem_multiply->getResultType(), input_rows_count);
}
return ColumnTuple::create(columns);
}
};
void registerVectorFunctions(FunctionFactory & factory)
{
factory.registerFunction<FunctionTuplePlus>();
factory.registerAlias("vectorSum", FunctionTuplePlus::name, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTupleMinus>();
factory.registerAlias("vectorDifference", FunctionTupleMinus::name, FunctionFactory::CaseInsensitive);
factory.registerFunction<FunctionTupleMultiply>();
/*factory.registerFunction<FunctionL1Length>();
factory.registerFunction<FunctionL1Distance>();
factory.registerFunction<FunctionL1Norm>();
factory.registerFunction<FunctionL2Length>();
factory.registerFunction<FunctionL2Distance>();
factory.registerFunction<FunctionL2Norm>();
factory.registerFunction<FunctionLinfLength>();
factory.registerFunction<FunctionLinfDistance>();
factory.registerFunction<FunctionLinfNorm>();
factory.registerFunction<FunctionLpLength>();
factory.registerFunction<FunctionLpDistance>();
factory.registerFunction<FunctionLpNorm>();
factory.registerFunction<FunctionCosineDistance>();*/
}
}

View File

@ -0,0 +1,5 @@
0
1
(10,3)
(-1,0)
(-23,-27)

View File

@ -0,0 +1,6 @@
SELECT tupleHammingDistance(tuple(1), tuple(1));
SELECT tupleHammingDistance(tuple(1, 3), tuple(1, 2));
SELECT tuple(1, 2) + tuple(3, 4) * tuple(5, 1) - tuple(6, 3);
SELECT vectorDifference(tuplePlus(tuple(1, 2), tuple(3, 4)), tuple(5, 6));
SELECT tupleMinus(vectorSum(tupleMultiply(tuple(1, 2), tuple(3, 4)), tuple(5, 6)), tuple(31, 41));