This commit is contained in:
Alexey Boykov 2021-08-26 21:17:44 +03:00
parent 657a8e13f9
commit 6fbdda6dd0
3 changed files with 142 additions and 3 deletions

View File

@ -14,7 +14,8 @@ namespace ErrorCodes
}
template <const char * func_name>
class TuplesToTupleFunction : public TupleIFunction {
class TuplesToTupleFunction : public TupleIFunction
{
public:
explicit TuplesToTupleFunction(ContextPtr context_) : TupleIFunction(context_) {}
@ -643,6 +644,131 @@ public:
}
};
/** LpNorm(tuple, p)
  *
  * Composes the result out of other registered functions:
  *   pow(pow(abs(x_1), p) + ... + pow(abs(x_n), p), divide(1, p))
  * where x_1 ... x_n are the elements of the tuple passed as the first argument
  * and p is the second argument.
  *
  * An empty tuple yields a UInt8 result filled with the default value.
  */
class FunctionLpNorm : public TupleIFunction
{
public:
    static constexpr auto name = "LpNorm";

    explicit FunctionLpNorm(ContextPtr context_) : TupleIFunction(context_) {}
    static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLpNorm>(context_); }

    String getName() const override { return name; }

    size_t getNumberOfArguments() const override { return 2; }

    /** Derives the result type by chaining the result types of abs, pow, plus and divide
      * over the element types of the tuple.
      *
      * Throws ILLEGAL_TYPE_OF_ARGUMENT if the first argument is not a tuple.
      * Exceptions raised while processing a tuple element get the element index appended.
      *
      * The columns of `arguments` may be absent at this stage (the code already has to
      * treat `arguments[0].column` as optional), so the nested functions are executed
      * only when there is an actual column to execute them on; otherwise only the types
      * are propagated.
      */
    DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
    {
        const auto * cur_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());

        if (!cur_tuple)
            throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument 0 of function {} should be tuples, got {}",
                            getName(), arguments[0].type->getName());

        const auto & cur_types = cur_tuple->getElements();

        Columns cur_elements;
        if (arguments[0].column)
            cur_elements = getTupleElements(*arguments[0].column);

        size_t tuple_size = cur_types.size();
        if (tuple_size == 0)
            return std::make_shared<DataTypeUInt8>();

        const auto & p_column = arguments[1];
        auto abs = FunctionFactory::instance().get("abs", context);
        auto pow = FunctionFactory::instance().get("pow", context);
        auto plus = FunctionFactory::instance().get("plus", context);

        DataTypePtr res_type;
        for (size_t i = 0; i < tuple_size; ++i)
        {
            try
            {
                ColumnWithTypeAndName cur{cur_elements.empty() ? nullptr : cur_elements[i], cur_types[i], {}};

                auto elem_abs = abs->build(ColumnsWithTypeAndName{cur});
                /// Do not execute on a missing column: there is nothing to compute,
                /// only the type has to be carried over to the next step.
                if (cur.column)
                    cur.column = elem_abs->execute({cur}, elem_abs->getResultType(), 1);
                cur.type = elem_abs->getResultType();

                auto elem_pow = pow->build(ColumnsWithTypeAndName{cur, p_column});

                if (i == 0)
                {
                    res_type = elem_pow->getResultType();
                    continue;
                }

                /// Accumulate the type of the running sum.
                ColumnWithTypeAndName left_type{res_type, {}};
                ColumnWithTypeAndName right_type{elem_pow->getResultType(), {}};
                auto plus_elem = plus->build({left_type, right_type});
                res_type = plus_elem->getResultType();
            }
            catch (DB::Exception & e)
            {
                e.addMessage("While executing function {} for tuple element {}", getName(), i);
                throw;
            }
        }

        /// Type (and, when p has a column, value) of the 1 / p exponent.
        auto divide = FunctionFactory::instance().get("divide", context);
        ColumnWithTypeAndName one{DataTypeFloat64().createColumnConst(1, 1.), std::make_shared<DataTypeFloat64>(), {}};
        auto div_elem = divide->build({one, p_column});

        ColumnWithTypeAndName inv_p_column;
        inv_p_column.type = div_elem->getResultType();
        /// Same as above: p may come without a column, in which case only its type is known.
        if (p_column.column)
            inv_p_column.column = div_elem->execute({one, p_column}, inv_p_column.type, 1);

        return pow->build({ColumnWithTypeAndName{res_type, {}}, inv_p_column})->getResultType();
    }

    /** Evaluates the norm for `input_rows_count` rows.
      *
      * Each tuple element goes through abs and pow(., p); the per-element results are
      * summed with plus, and the sum is raised to the power divide(1, p).
      */
    ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
    {
        const auto * cur_tuple = checkAndGetDataType<DataTypeTuple>(arguments[0].type.get());
        const auto & cur_types = cur_tuple->getElements();
        auto cur_elements = getTupleElements(*arguments[0].column);

        size_t tuple_size = cur_elements.size();
        if (tuple_size == 0)
            return DataTypeUInt8().createColumnConstWithDefaultValue(input_rows_count);

        const auto & p_column = arguments[1];
        auto abs = FunctionFactory::instance().get("abs", context);
        auto pow = FunctionFactory::instance().get("pow", context);
        auto plus = FunctionFactory::instance().get("plus", context);

        /// Running sum of pow(abs(x_i), p).
        ColumnWithTypeAndName res;
        for (size_t i = 0; i < tuple_size; ++i)
        {
            ColumnWithTypeAndName cur{cur_elements[i], cur_types[i], {}};

            auto elem_abs = abs->build(ColumnsWithTypeAndName{cur});
            cur.column = elem_abs->execute({cur}, elem_abs->getResultType(), input_rows_count);
            cur.type = elem_abs->getResultType();

            auto elem_pow = pow->build(ColumnsWithTypeAndName{cur, p_column});
            ColumnWithTypeAndName column;
            column.type = elem_pow->getResultType();
            column.column = elem_pow->execute({cur, p_column}, column.type, input_rows_count);

            if (i == 0)
            {
                res = std::move(column);
            }
            else
            {
                auto plus_elem = plus->build({res, column});
                auto res_type = plus_elem->getResultType();
                res.column = plus_elem->execute({res, column}, res_type, input_rows_count);
                res.type = res_type;
            }
        }

        /// Exponent 1 / p, computed per row from a materialized column of ones.
        auto divide = FunctionFactory::instance().get("divide", context);
        ColumnWithTypeAndName one{DataTypeFloat64().createColumnConst(input_rows_count, 1.)->convertToFullColumnIfConst(), std::make_shared<DataTypeFloat64>(), {}};
        auto div_elem = divide->build({one, p_column});

        ColumnWithTypeAndName inv_p_column;
        inv_p_column.type = div_elem->getResultType();
        inv_p_column.column = div_elem->execute({one, p_column}, inv_p_column.type, input_rows_count);

        auto pow_elem = pow->build({res, inv_p_column});
        return pow_elem->execute({res, inv_p_column}, pow_elem->getResultType(), input_rows_count);
    }
};
void registerVectorFunctions(FunctionFactory & factory)
{
factory.registerFunction<FunctionTuplePlus>();
@ -662,6 +788,7 @@ void registerVectorFunctions(FunctionFactory & factory)
factory.registerFunction<FunctionL1Norm>();
factory.registerFunction<FunctionL2Norm>();
factory.registerFunction<FunctionLinfNorm>();
factory.registerFunction<FunctionLpNorm>();
/*factory.registerFunction<FunctionL1Distance>();
factory.registerFunction<FunctionL1Normalize>();
@ -671,7 +798,6 @@ void registerVectorFunctions(FunctionFactory & factory)
factory.registerFunction<FunctionLinfDistance>();
factory.registerFunction<FunctionLinfNormalize>();
factory.registerFunction<FunctionLpNorm>();
factory.registerFunction<FunctionLpDistance>();
factory.registerFunction<FunctionLpNormalize>();

View File

@ -7,9 +7,15 @@
(-1,0,-3.5)
20
16.808
6
7.1
1.4142135623730951
5
1.5
-3
2.3
1
1.1
1
1
2.0000887587111964

View File

@ -11,10 +11,17 @@ SELECT tupleNegate(tuple(1, 0, 3.5));
SELECT dotProduct(tuple(1, 2, 3), tuple(2, 3, 4));
SELECT scalarProduct(tuple(-1, 2, 3.002), tuple(2, 3.4, 4));
SELECT L1Norm(tuple(-1,2.5,-3.6));
SELECT L1Norm(tuple(-1, 2, -3));
SELECT L1Norm(tuple(-1, 2.5, -3.6));
SELECT L2Norm(tuple(1, 1));
SELECT L2Norm(tuple(3, 4));
SELECT max2(1, 1.5);
SELECT min2(-1, -3);
SELECT LinfNorm(tuple(1, -2.3, 1.7));
SELECT LpNorm(tuple(-1), 3);
SELECT LpNorm(tuple(-1.1), 3);
SELECT LpNorm(tuple(13, -84.4, 91, 63.1), 2) = L2Norm(tuple(13, -84.4, 91, 63.1));
SELECT LpNorm(tuple(13, -84.4, 91, 63.1), 1) = L1Norm(tuple(13, -84.4, 91, 63.1));
SELECT LpNorm(tuple(-1, -2), 11);