Normalize

This commit is contained in:
Alexey Boykov 2021-08-30 15:52:00 +03:00
parent 91199dc73a
commit 072104135a
3 changed files with 111 additions and 35 deletions

View File

@ -13,6 +13,16 @@ namespace ErrorCodes
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
}
static const char PLUS_NAME[] = "plus";
static const char MINUS_NAME[] = "minus";
static const char MULTIPLY_NAME[] = "multiply";
static const char DIVIDE_NAME[] = "divide";
static constexpr char L1_LABEL[] = "1";
static constexpr char L2_LABEL[] = "2";
static constexpr char Linf_LABEL[] = "inf";
static constexpr char Lp_LABEL[] = "p";
/// str starts from the lowercase letter; not constexpr due to the compiler version
/*constexpr*/ std::string makeFirstLetterUppercase(std::string && str)
{
@ -116,16 +126,12 @@ public:
}
};
static const char PLUS_NAME[] = "plus";
using FunctionTuplePlus = FunctionTupleOperator<PLUS_NAME>;
static const char MINUS_NAME[] = "minus";
using FunctionTupleMinus = FunctionTupleOperator<MINUS_NAME>;
static const char MULTIPLY_NAME[] = "multiply";
using FunctionTupleMultiply = FunctionTupleOperator<MULTIPLY_NAME>;
static const char DIVIDE_NAME[] = "divide";
using FunctionTupleDivide = FunctionTupleOperator<DIVIDE_NAME>;
class FunctionTupleNegate : public TupleIFunction
@ -401,13 +407,18 @@ public:
}
};
class FunctionL1Norm : public TupleIFunction
/// this is for convenient usage in LNormalize
template <const char * func_label>
class FunctionLNorm : public TupleIFunction {};
template <>
class FunctionLNorm<L1_LABEL> : public TupleIFunction
{
public:
static constexpr auto name = "L1Norm";
explicit FunctionL1Norm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionL1Norm>(context_); }
explicit FunctionLNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLNorm>(context_); }
String getName() const override { return name; }
@ -500,14 +511,16 @@ public:
return res.column;
}
};
using FunctionL1Norm = FunctionLNorm<L1_LABEL>;
class FunctionL2Norm : public TupleIFunction
template <>
class FunctionLNorm<L2_LABEL> : public TupleIFunction
{
public:
static constexpr auto name = "L2Norm";
explicit FunctionL2Norm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionL2Norm>(context_); }
explicit FunctionLNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLNorm>(context_); }
String getName() const override { return name; }
@ -603,14 +616,16 @@ public:
return sqrt_elem->execute({res}, sqrt_elem->getResultType(), input_rows_count);
}
};
using FunctionL2Norm = FunctionLNorm<L2_LABEL>;
class FunctionLinfNorm : public TupleIFunction
template <>
class FunctionLNorm<Linf_LABEL> : public TupleIFunction
{
public:
static constexpr auto name = "LinfNorm";
explicit FunctionLinfNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLinfNorm>(context_); }
explicit FunctionLNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLNorm>(context_); }
String getName() const override { return name; }
@ -703,14 +718,16 @@ public:
return res.column;
}
};
using FunctionLinfNorm = FunctionLNorm<Linf_LABEL>;
class FunctionLpNorm : public TupleIFunction
template <>
class FunctionLNorm<Lp_LABEL> : public TupleIFunction
{
public:
static constexpr auto name = "LpNorm";
explicit FunctionLpNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLpNorm>(context_); }
explicit FunctionLNorm(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLNorm>(context_); }
String getName() const override { return name; }
@ -828,16 +845,17 @@ public:
return pow_elem->execute({res, inv_p_column}, pow_elem->getResultType(), input_rows_count);
}
};
using FunctionLpNorm = FunctionLNorm<Lp_LABEL>;
template <const char * func_label>
class FunctionVectorDistance : public TupleIFunction
class FunctionLDistance : public TupleIFunction
{
public:
/// constexpr cannot be used due to std::string has not constexpr constructor in this compiler version
static inline auto name = "L" + std::string(func_label) + "Distance";
explicit FunctionVectorDistance(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionVectorDistance>(context_); }
explicit FunctionLDistance(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLDistance>(context_); }
String getName() const override { return name; }
@ -886,17 +904,66 @@ public:
}
};
static constexpr char L1DISTANCE_LABEL[] = "1";
using FunctionL1Distance = FunctionVectorDistance<L1DISTANCE_LABEL>;
using FunctionL1Distance = FunctionLDistance<L1_LABEL>;
static constexpr char L2DISTANCE_LABEL[] = "2";
using FunctionL2Distance = FunctionVectorDistance<L2DISTANCE_LABEL>;
using FunctionL2Distance = FunctionLDistance<L2_LABEL>;
static constexpr char LinfDISTANCE_LABEL[] = "inf";
using FunctionLinfDistance = FunctionVectorDistance<LinfDISTANCE_LABEL>;
using FunctionLinfDistance = FunctionLDistance<Linf_LABEL>;
static constexpr char LpDISTANCE_LABEL[] = "p";
using FunctionLpDistance = FunctionVectorDistance<LpDISTANCE_LABEL>;
using FunctionLpDistance = FunctionLDistance<Lp_LABEL>;
template <const char * func_label>
class FunctionLNormalize : public TupleIFunction
{
public:
/// constexpr cannot be used due to std::string has not constexpr constructor in this compiler version
static inline auto name = "L" + std::string(func_label) + "Normalize";
explicit FunctionLNormalize(ContextPtr context_) : TupleIFunction(context_) {}
static FunctionPtr create(ContextPtr context_) { return std::make_shared<FunctionLNormalize>(context_); }
String getName() const override { return name; }
size_t getNumberOfArguments() const override
{
if constexpr (func_label[0] == 'p')
return 2;
else
return 1;
}
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
FunctionLNorm<func_label> norm(context);
auto type = norm.getReturnTypeImpl(arguments);
auto column = norm.executeImpl(arguments, DataTypePtr(), 1);
ColumnWithTypeAndName norm_res{column, type, {}};
FunctionTupleDivideByNumber divide(context);
return divide.getReturnTypeImpl({arguments[0], norm_res});
}
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
{
FunctionLNorm<func_label> norm(context);
auto type = norm.getReturnTypeImpl(arguments);
auto column = norm.executeImpl(arguments, DataTypePtr(), input_rows_count);
ColumnWithTypeAndName norm_res{column, type, {}};
FunctionTupleDivideByNumber divide(context);
return divide.executeImpl({arguments[0], norm_res}, DataTypePtr(), input_rows_count);
}
};
using FunctionL1Normalize = FunctionLNormalize<L1_LABEL>;
using FunctionL2Normalize = FunctionLNormalize<L2_LABEL>;
using FunctionLinfNormalize = FunctionLNormalize<Linf_LABEL>;
using FunctionLpNormalize = FunctionLNormalize<Lp_LABEL>;
void registerVectorFunctions(FunctionFactory & factory)
{
@ -924,10 +991,10 @@ void registerVectorFunctions(FunctionFactory & factory)
factory.registerFunction<FunctionLinfDistance>();
factory.registerFunction<FunctionLpDistance>();
// factory.registerFunction<FunctionL1Normalize>();
// factory.registerFunction<FunctionL2Normalize>();
// factory.registerFunction<FunctionLinfNormalize>();
// factory.registerFunction<FunctionLpNormalize>();
factory.registerFunction<FunctionL1Normalize>();
factory.registerFunction<FunctionL2Normalize>();
factory.registerFunction<FunctionLinfNormalize>();
factory.registerFunction<FunctionLpNormalize>();
//
// factory.registerFunction<FunctionCosineDistance>();
}

View File

@ -16,7 +16,7 @@
6
7.1
1.4142135623730951
5
13
1.5
-3
2.3
@ -31,3 +31,7 @@
1
0
-4.413254828250501e-8
(0.2,-0.8)
(0.6,0.8)
(1,-1,1)
(0.5,0.993670377332229)

View File

@ -20,8 +20,8 @@ SELECT scalarProduct(tuple(1), tuple(0));
SELECT L1Norm((-1, 2, -3));
SELECT L1Norm((-1, 2.5, -3.6));
SELECT L2Norm((1, 1));
SELECT L2Norm((3, 4));
SELECT L2Norm((1, 1.0));
SELECT L2Norm((-12, 5));
SELECT max2(1, 1.5);
SELECT min2(-1, -3);
@ -38,4 +38,9 @@ SELECT L1Distance((1, 2, 3), (2, 3, 1));
SELECT L2Distance((1, 1), (3, -1));
SELECT LinfDistance((1, 1), (1, 2));
SELECT L2Distance((5, 5), (5, 5));
SELECT LpDistance((1800, 1900), (18, 59), 12) - LpDistance(tuple(-22), tuple(1900), 12);
SELECT LpDistance((1800, 1900), (18, 59), 12) - LpDistance(tuple(-22), tuple(1900), 12);
SELECT L1Normalize((1, -4));
SELECT L2Normalize((3, 4));
SELECT LinfNormalize((5, -5, 5.0));
SELECT LpNormalize((1, 1.98734075466445795857), 5);