ClickHouse/src/Functions/vectorFunctions.cpp

334 lines
14 KiB
C++
Raw Normal View History

#include <Columns/ColumnTuple.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypesNumber.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/TupleIFunction.h>
#include <Functions/castTypeToEither.h>
namespace DB
{
/// Error codes thrown by the functions in this file.
/// Only declared here (`extern`); the definitions live in another translation unit.
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
/// tuplePlus(t1, t2): combines two tuples of the same size element by element.
/// The work for every pair of elements is delegated to the function registered as "plus"
/// in FunctionFactory, so each result element has whatever type "plus" yields for that pair.
class FunctionTuplePlus : public TupleIFunction
{
private:
    /// Handed to FunctionFactory when the per-element function is looked up.
    ContextPtr context;

public:
    static constexpr auto name = "tuplePlus";

    explicit FunctionTuplePlus(ContextPtr context_) : context(context_) {}

    static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTuplePlus>(context); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Checks that both arguments are tuples with the same number of elements and derives the result type.
    /// Returns UInt8 for empty tuples, otherwise a Tuple made of the per-element "plus" result types.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT if an argument is not a tuple or the sizes differ.
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        const auto & lhs = arguments[0];
        const auto & rhs = arguments[1];

        const auto * lhs_tuple_type = checkAndGetDataType<DataTypeTuple>(lhs.type.get());
        if (!lhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
                            getName(), lhs.type->getName());

        const auto * rhs_tuple_type = checkAndGetDataType<DataTypeTuple>(rhs.type.get());
        if (!rhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
                            getName(), rhs.type->getName());

        const auto & lhs_types = lhs_tuple_type->getElements();
        const auto & rhs_types = rhs_tuple_type->getElements();

        /// An argument may come without a column; its element columns are then left empty
        /// and nullptr is passed to the per-element function below.
        Columns lhs_columns = lhs.column ? getTupleElements(*lhs.column) : Columns{};
        Columns rhs_columns = rhs.column ? getTupleElements(*rhs.column) : Columns{};

        const size_t tuple_size = lhs_types.size();
        if (tuple_size != rhs_types.size())
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                            "Expected tuples of the same size as arguments of function {}. Got {} and {}",
                            getName(), lhs.type->getName(), rhs.type->getName());

        if (tuple_size == 0)
            return std::make_shared<DataTypeUInt8>();

        auto element_function = FunctionFactory::instance().get("plus", context);

        DataTypes result_types;
        result_types.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            try
            {
                ColumnWithTypeAndName lhs_element{lhs_columns.empty() ? nullptr : lhs_columns[pos], lhs_types[pos], {}};
                ColumnWithTypeAndName rhs_element{rhs_columns.empty() ? nullptr : rhs_columns[pos], rhs_types[pos], {}};
                auto resolved = element_function->build(ColumnsWithTypeAndName{lhs_element, rhs_element});
                result_types.push_back(resolved->getResultType());
            }
            catch (DB::Exception & ex)
            {
                /// Tell the user which tuple element the nested function rejected.
                ex.addMessage("While executing function {} for tuple element {}", getName(), pos);
                throw;
            }
        }

        return std::make_shared<DataTypeTuple>(result_types);
    }

    /// Applies "plus" to every pair of element columns and packs the results into a tuple column.
    /// For empty tuples a constant UInt8 column with the default value is returned.
    /// Argument types are not re-validated here; both are dereferenced as tuple types.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const auto & lhs_types = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get())->getElements();
        const auto & rhs_types = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get())->getElements();

        auto lhs_columns = getTupleElements(*arguments[0].column);
        auto rhs_columns = getTupleElements(*arguments[1].column);

        const size_t tuple_size = lhs_columns.size();
        if (tuple_size == 0)
            return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);

        auto element_function = FunctionFactory::instance().get("plus", context);

        Columns result_columns;
        result_columns.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            ColumnsWithTypeAndName element_arguments{
                ColumnWithTypeAndName{lhs_columns[pos], lhs_types[pos], {}},
                ColumnWithTypeAndName{rhs_columns[pos], rhs_types[pos], {}}};

            auto resolved = element_function->build(element_arguments);
            result_columns.push_back(resolved->execute(element_arguments, resolved->getResultType(), input_rows_count));
        }

        return ColumnTuple::create(result_columns);
    }
};
/// tupleMinus(t1, t2): combines two tuples of the same size element by element.
/// The work for every pair of elements is delegated to the function registered as "minus"
/// in FunctionFactory, so each result element has whatever type "minus" yields for that pair.
class FunctionTupleMinus : public TupleIFunction
{
private:
    /// Handed to FunctionFactory when the per-element function is looked up.
    ContextPtr context;

public:
    static constexpr auto name = "tupleMinus";

    explicit FunctionTupleMinus(ContextPtr context_) : context(context_) {}

    static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTupleMinus>(context); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Checks that both arguments are tuples with the same number of elements and derives the result type.
    /// Returns UInt8 for empty tuples, otherwise a Tuple made of the per-element "minus" result types.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT if an argument is not a tuple or the sizes differ.
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        const auto & lhs = arguments[0];
        const auto & rhs = arguments[1];

        const auto * lhs_tuple_type = checkAndGetDataType<DataTypeTuple>(lhs.type.get());
        if (!lhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
                            getName(), lhs.type->getName());

        const auto * rhs_tuple_type = checkAndGetDataType<DataTypeTuple>(rhs.type.get());
        if (!rhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
                            getName(), rhs.type->getName());

        const auto & lhs_types = lhs_tuple_type->getElements();
        const auto & rhs_types = rhs_tuple_type->getElements();

        /// An argument may come without a column; its element columns are then left empty
        /// and nullptr is passed to the per-element function below.
        Columns lhs_columns = lhs.column ? getTupleElements(*lhs.column) : Columns{};
        Columns rhs_columns = rhs.column ? getTupleElements(*rhs.column) : Columns{};

        const size_t tuple_size = lhs_types.size();
        if (tuple_size != rhs_types.size())
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                            "Expected tuples of the same size as arguments of function {}. Got {} and {}",
                            getName(), lhs.type->getName(), rhs.type->getName());

        if (tuple_size == 0)
            return std::make_shared<DataTypeUInt8>();

        auto element_function = FunctionFactory::instance().get("minus", context);

        DataTypes result_types;
        result_types.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            try
            {
                ColumnWithTypeAndName lhs_element{lhs_columns.empty() ? nullptr : lhs_columns[pos], lhs_types[pos], {}};
                ColumnWithTypeAndName rhs_element{rhs_columns.empty() ? nullptr : rhs_columns[pos], rhs_types[pos], {}};
                auto resolved = element_function->build(ColumnsWithTypeAndName{lhs_element, rhs_element});
                result_types.push_back(resolved->getResultType());
            }
            catch (DB::Exception & ex)
            {
                /// Tell the user which tuple element the nested function rejected.
                ex.addMessage("While executing function {} for tuple element {}", getName(), pos);
                throw;
            }
        }

        return std::make_shared<DataTypeTuple>(result_types);
    }

    /// Applies "minus" to every pair of element columns and packs the results into a tuple column.
    /// For empty tuples a constant UInt8 column with the default value is returned.
    /// Argument types are not re-validated here; both are dereferenced as tuple types.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const auto & lhs_types = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get())->getElements();
        const auto & rhs_types = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get())->getElements();

        auto lhs_columns = getTupleElements(*arguments[0].column);
        auto rhs_columns = getTupleElements(*arguments[1].column);

        const size_t tuple_size = lhs_columns.size();
        if (tuple_size == 0)
            return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);

        auto element_function = FunctionFactory::instance().get("minus", context);

        Columns result_columns;
        result_columns.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            ColumnsWithTypeAndName element_arguments{
                ColumnWithTypeAndName{lhs_columns[pos], lhs_types[pos], {}},
                ColumnWithTypeAndName{rhs_columns[pos], rhs_types[pos], {}}};

            auto resolved = element_function->build(element_arguments);
            result_columns.push_back(resolved->execute(element_arguments, resolved->getResultType(), input_rows_count));
        }

        return ColumnTuple::create(result_columns);
    }
};
/// tupleMultiply(t1, t2): combines two tuples of the same size element by element.
/// The work for every pair of elements is delegated to the function registered as "multiply"
/// in FunctionFactory, so each result element has whatever type "multiply" yields for that pair.
class FunctionTupleMultiply : public TupleIFunction
{
private:
    /// Handed to FunctionFactory when the per-element function is looked up.
    ContextPtr context;

public:
    static constexpr auto name = "tupleMultiply";

    explicit FunctionTupleMultiply(ContextPtr context_) : context(context_) {}

    static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionTupleMultiply>(context); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    bool useDefaultImplementationForConstants() const override { return true; }

    /// Checks that both arguments are tuples with the same number of elements and derives the result type.
    /// Returns UInt8 for empty tuples, otherwise a Tuple made of the per-element "multiply" result types.
    /// Throws ILLEGAL_TYPE_OF_ARGUMENT if an argument is not a tuple or the sizes differ.
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        const auto & lhs = arguments[0];
        const auto & rhs = arguments[1];

        const auto * lhs_tuple_type = checkAndGetDataType<DataTypeTuple>(lhs.type.get());
        if (!lhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
                            getName(), lhs.type->getName());

        const auto * rhs_tuple_type = checkAndGetDataType<DataTypeTuple>(rhs.type.get());
        if (!rhs_tuple_type)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 1 of function {} should be tuples, got {}",
                            getName(), rhs.type->getName());

        const auto & lhs_types = lhs_tuple_type->getElements();
        const auto & rhs_types = rhs_tuple_type->getElements();

        /// An argument may come without a column; its element columns are then left empty
        /// and nullptr is passed to the per-element function below.
        Columns lhs_columns = lhs.column ? getTupleElements(*lhs.column) : Columns{};
        Columns rhs_columns = rhs.column ? getTupleElements(*rhs.column) : Columns{};

        const size_t tuple_size = lhs_types.size();
        if (tuple_size != rhs_types.size())
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
                            "Expected tuples of the same size as arguments of function {}. Got {} and {}",
                            getName(), lhs.type->getName(), rhs.type->getName());

        if (tuple_size == 0)
            return std::make_shared<DataTypeUInt8>();

        auto element_function = FunctionFactory::instance().get("multiply", context);

        DataTypes result_types;
        result_types.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            try
            {
                ColumnWithTypeAndName lhs_element{lhs_columns.empty() ? nullptr : lhs_columns[pos], lhs_types[pos], {}};
                ColumnWithTypeAndName rhs_element{rhs_columns.empty() ? nullptr : rhs_columns[pos], rhs_types[pos], {}};
                auto resolved = element_function->build(ColumnsWithTypeAndName{lhs_element, rhs_element});
                result_types.push_back(resolved->getResultType());
            }
            catch (DB::Exception & ex)
            {
                /// Tell the user which tuple element the nested function rejected.
                ex.addMessage("While executing function {} for tuple element {}", getName(), pos);
                throw;
            }
        }

        return std::make_shared<DataTypeTuple>(result_types);
    }

    /// Applies "multiply" to every pair of element columns and packs the results into a tuple column.
    /// For empty tuples a constant UInt8 column with the default value is returned.
    /// Argument types are not re-validated here; both are dereferenced as tuple types.
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const auto & lhs_types = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get())->getElements();
        const auto & rhs_types = checkAndGetDataType<DataTypeTuple>(arguments[1].type.get())->getElements();

        auto lhs_columns = getTupleElements(*arguments[0].column);
        auto rhs_columns = getTupleElements(*arguments[1].column);

        const size_t tuple_size = lhs_columns.size();
        if (tuple_size == 0)
            return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);

        auto element_function = FunctionFactory::instance().get("multiply", context);

        Columns result_columns;
        result_columns.reserve(tuple_size);
        for (size_t pos = 0; pos < tuple_size; ++pos)
        {
            ColumnsWithTypeAndName element_arguments{
                ColumnWithTypeAndName{lhs_columns[pos], lhs_types[pos], {}},
                ColumnWithTypeAndName{rhs_columns[pos], rhs_types[pos], {}}};

            auto resolved = element_function->build(element_arguments);
            result_columns.push_back(resolved->execute(element_arguments, resolved->getResultType(), input_rows_count));
        }

        return ColumnTuple::create(result_columns);
    }
};
/// Registers tuplePlus, tupleMinus and tupleMultiply in the given factory.
/// tuplePlus and tupleMinus additionally get the case-insensitive aliases
/// vectorSum and vectorDifference; tupleMultiply gets no alias.
void registerVectorFunctions(FunctionFactory & factory)
{
    /// Every alias in this file is registered as case-insensitive.
    const auto add_alias = [&factory](const char * alias, const char * target)
    {
        factory.registerAlias(alias, target, FunctionFactory::CaseInsensitive);
    };

    /// Each alias is added right after the function it points to.
    factory.registerFunction<FunctionTuplePlus>();
    add_alias("vectorSum", FunctionTuplePlus::name);

    factory.registerFunction<FunctionTupleMinus>();
    add_alias("vectorDifference", FunctionTupleMinus::name);

    factory.registerFunction<FunctionTupleMultiply>();

    /* The registrations below are commented out in this revision:
    factory.registerFunction<FunctionL1Length>();
    factory.registerFunction<FunctionL1Distance>();
    factory.registerFunction<FunctionL1Norm>();
    factory.registerFunction<FunctionL2Length>();
    factory.registerFunction<FunctionL2Distance>();
    factory.registerFunction<FunctionL2Norm>();
    factory.registerFunction<FunctionLinfLength>();
    factory.registerFunction<FunctionLinfDistance>();
    factory.registerFunction<FunctionLinfNorm>();
    factory.registerFunction<FunctionLpLength>();
    factory.registerFunction<FunctionLpDistance>();
    factory.registerFunction<FunctionLpNorm>();
    factory.registerFunction<FunctionCosineDistance>();*/
}
}