mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-10 01:25:21 +00:00
Merge pull request #37601 from ClickHouse/array_norm_dist_fixes
Added LpNorm and LpDistance functions for arrays
This commit is contained in:
commit
6a57e1a970
@ -7,21 +7,25 @@
|
||||
#include <DataTypes/getLeastSupertype.h>
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include "base/range.h"
|
||||
#include <base/range.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int SIZES_OF_ARRAYS_DOESNT_MATCH;
|
||||
extern const int ARGUMENT_OUT_OF_BOUND;
|
||||
}
|
||||
|
||||
struct L1Distance
|
||||
{
|
||||
static inline String name = "L1";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename FloatType>
|
||||
struct State
|
||||
{
|
||||
@ -29,13 +33,13 @@ struct L1Distance
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y)
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams &)
|
||||
{
|
||||
state.sum += fabs(x - y);
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state)
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
return state.sum;
|
||||
}
|
||||
@ -45,6 +49,8 @@ struct L2Distance
|
||||
{
|
||||
static inline String name = "L2";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename FloatType>
|
||||
struct State
|
||||
{
|
||||
@ -52,22 +58,53 @@ struct L2Distance
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y)
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams &)
|
||||
{
|
||||
state.sum += (x - y) * (x - y);
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state)
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
return sqrt(state.sum);
|
||||
}
|
||||
};
|
||||
|
||||
struct LpDistance
|
||||
{
|
||||
static inline String name = "Lp";
|
||||
|
||||
struct ConstParams
|
||||
{
|
||||
Float64 power;
|
||||
Float64 inverted_power;
|
||||
};
|
||||
|
||||
template <typename FloatType>
|
||||
struct State
|
||||
{
|
||||
FloatType sum = 0;
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams & params)
|
||||
{
|
||||
state.sum += std::pow(fabs(x - y), params.power);
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams & params)
|
||||
{
|
||||
return std::pow(state.sum, params.inverted_power);
|
||||
}
|
||||
};
|
||||
|
||||
struct LinfDistance
|
||||
{
|
||||
static inline String name = "Linf";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename FloatType>
|
||||
struct State
|
||||
{
|
||||
@ -75,21 +112,24 @@ struct LinfDistance
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y)
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams &)
|
||||
{
|
||||
state.dist = fmax(state.dist, fabs(x - y));
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state)
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
return state.dist;
|
||||
}
|
||||
};
|
||||
|
||||
struct CosineDistance
|
||||
{
|
||||
static inline String name = "Cosine";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename FloatType>
|
||||
struct State
|
||||
{
|
||||
@ -99,7 +139,7 @@ struct CosineDistance
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y)
|
||||
static void accumulate(State<ResultType> & state, ResultType x, ResultType y, const ConstParams &)
|
||||
{
|
||||
state.dot_prod += x * y;
|
||||
state.x_squared += x * x;
|
||||
@ -107,7 +147,7 @@ struct CosineDistance
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
static ResultType finalize(const State<ResultType> & state)
|
||||
static ResultType finalize(const State<ResultType> & state, const ConstParams &)
|
||||
{
|
||||
return 1 - state.dot_prod / sqrt(state.x_squared * state.y_squared);
|
||||
}
|
||||
@ -121,17 +161,18 @@ public:
|
||||
String getName() const override { return name; }
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayDistance<Kernel>>(); }
|
||||
size_t getNumberOfArguments() const override { return 2; }
|
||||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {}; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
DataTypes types;
|
||||
for (const auto & argument : arguments)
|
||||
for (size_t i = 0; i < 2; ++i)
|
||||
{
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(argument.type.get());
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[i].type.get());
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be array.", getName());
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument {} of function {} must be array.", i, getName());
|
||||
|
||||
types.push_back(array_type->getNestedType());
|
||||
}
|
||||
@ -221,7 +262,7 @@ private:
|
||||
{
|
||||
#define ON_TYPE(type) \
|
||||
case TypeIndex::type: \
|
||||
return executeWithTypes<ResultType, FirstArgType, type>(arguments[0].column, arguments[1].column, input_rows_count); \
|
||||
return executeWithTypes<ResultType, FirstArgType, type>(arguments[0].column, arguments[1].column, input_rows_count, arguments); \
|
||||
break;
|
||||
|
||||
SUPPORTED_TYPES(ON_TYPE)
|
||||
@ -237,15 +278,15 @@ private:
|
||||
}
|
||||
|
||||
template <typename ResultType, typename FirstArgType, typename SecondArgType>
|
||||
ColumnPtr executeWithTypes(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const
|
||||
ColumnPtr executeWithTypes(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
if (typeid_cast<const ColumnConst *>(col_x.get()))
|
||||
{
|
||||
return executeWithTypesFirstArgConst<ResultType, FirstArgType, SecondArgType>(col_x, col_y, input_rows_count);
|
||||
return executeWithTypesFirstArgConst<ResultType, FirstArgType, SecondArgType>(col_x, col_y, input_rows_count, arguments);
|
||||
}
|
||||
else if (typeid_cast<const ColumnConst *>(col_y.get()))
|
||||
{
|
||||
return executeWithTypesFirstArgConst<ResultType, SecondArgType, FirstArgType>(col_y, col_x, input_rows_count);
|
||||
return executeWithTypesFirstArgConst<ResultType, SecondArgType, FirstArgType>(col_y, col_x, input_rows_count, arguments);
|
||||
}
|
||||
|
||||
col_x = col_x->convertToFullColumnIfConst();
|
||||
@ -273,6 +314,8 @@ private:
|
||||
}
|
||||
}
|
||||
|
||||
const typename Kernel::ConstParams kernel_params = initConstParams(arguments);
|
||||
|
||||
auto result = ColumnVector<ResultType>::create(input_rows_count);
|
||||
auto & result_data = result->getData();
|
||||
|
||||
@ -284,9 +327,9 @@ private:
|
||||
typename Kernel::template State<Float64> state;
|
||||
for (; prev < off; ++prev)
|
||||
{
|
||||
Kernel::template accumulate<Float64>(state, data_x[prev], data_y[prev]);
|
||||
Kernel::template accumulate<Float64>(state, data_x[prev], data_y[prev], kernel_params);
|
||||
}
|
||||
result_data[row] = Kernel::finalize(state);
|
||||
result_data[row] = Kernel::finalize(state, kernel_params);
|
||||
row++;
|
||||
}
|
||||
return result;
|
||||
@ -294,7 +337,7 @@ private:
|
||||
|
||||
/// Special case when the 1st parameter is Const
|
||||
template <typename ResultType, typename FirstArgType, typename SecondArgType>
|
||||
ColumnPtr executeWithTypesFirstArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const
|
||||
ColumnPtr executeWithTypesFirstArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
col_x = assert_cast<const ColumnConst *>(col_x.get())->getDataColumnPtr();
|
||||
col_y = col_y->convertToFullColumnIfConst();
|
||||
@ -322,6 +365,8 @@ private:
|
||||
prev_offset = offsets_y[row];
|
||||
}
|
||||
|
||||
const typename Kernel::ConstParams kernel_params = initConstParams(arguments);
|
||||
|
||||
auto result = ColumnVector<ResultType>::create(input_rows_count);
|
||||
auto & result_data = result->getData();
|
||||
|
||||
@ -333,19 +378,59 @@ private:
|
||||
typename Kernel::template State<Float64> state;
|
||||
for (size_t i = 0; prev < off; ++i, ++prev)
|
||||
{
|
||||
Kernel::template accumulate<Float64>(state, data_x[i], data_y[prev]);
|
||||
Kernel::template accumulate<Float64>(state, data_x[i], data_y[prev], kernel_params);
|
||||
}
|
||||
result_data[row] = Kernel::finalize(state);
|
||||
result_data[row] = Kernel::finalize(state, kernel_params);
|
||||
row++;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
typename Kernel::ConstParams initConstParams(const ColumnsWithTypeAndName &) const { return {}; }
|
||||
};
|
||||
|
||||
|
||||
template <>
|
||||
size_t FunctionArrayDistance<LpDistance>::getNumberOfArguments() const { return 3; }
|
||||
|
||||
template <>
|
||||
ColumnNumbers FunctionArrayDistance<LpDistance>::getArgumentsThatAreAlwaysConstant() const { return {2}; }
|
||||
|
||||
template <>
|
||||
LpDistance::ConstParams FunctionArrayDistance<LpDistance>::initConstParams(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
if (arguments.size() < 3)
|
||||
throw Exception(
|
||||
ErrorCodes::LOGICAL_ERROR,
|
||||
"Argument p of function {} was not provided",
|
||||
getName());
|
||||
|
||||
if (!arguments[2].column->isNumeric())
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Argument p of function {} must be numeric constant",
|
||||
getName());
|
||||
|
||||
if (!isColumnConst(*arguments[2].column) && arguments[2].column->size() != 1)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_COLUMN,
|
||||
"Second argument for function {} must be either constant Float64 or constant UInt",
|
||||
getName());
|
||||
|
||||
Float64 p = arguments[2].column->getFloat64(0);
|
||||
if (p < 1 || p >= HUGE_VAL)
|
||||
throw Exception(
|
||||
ErrorCodes::ARGUMENT_OUT_OF_BOUND,
|
||||
"Second argument for function {} must be not less than one and not be an infinity",
|
||||
getName());
|
||||
|
||||
return LpDistance::ConstParams{p, 1 / p};
|
||||
}
|
||||
|
||||
/// These functions are used by TupleOrArrayFunction
|
||||
FunctionPtr createFunctionArrayL1Distance(ContextPtr context_) { return FunctionArrayDistance<L1Distance>::create(context_); }
|
||||
FunctionPtr createFunctionArrayL2Distance(ContextPtr context_) { return FunctionArrayDistance<L2Distance>::create(context_); }
|
||||
FunctionPtr createFunctionArrayLpDistance(ContextPtr context_) { return FunctionArrayDistance<LpDistance>::create(context_); }
|
||||
FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_) { return FunctionArrayDistance<LinfDistance>::create(context_); }
|
||||
FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_) { return FunctionArrayDistance<CosineDistance>::create(context_); }
|
||||
|
||||
|
@ -13,22 +13,26 @@ namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int ILLEGAL_COLUMN;
|
||||
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int ARGUMENT_OUT_OF_BOUND;
|
||||
}
|
||||
|
||||
struct L1Norm
|
||||
{
|
||||
static inline String name = "L1";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType accumulate(ResultType result, ResultType value)
|
||||
inline static ResultType accumulate(ResultType result, ResultType value, const ConstParams &)
|
||||
{
|
||||
return result + fabs(value);
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType finalize(ResultType result)
|
||||
inline static ResultType finalize(ResultType result, const ConstParams &)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
@ -38,32 +42,59 @@ struct L2Norm
|
||||
{
|
||||
static inline String name = "L2";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType accumulate(ResultType result, ResultType value)
|
||||
inline static ResultType accumulate(ResultType result, ResultType value, const ConstParams &)
|
||||
{
|
||||
return result + value * value;
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType finalize(ResultType result)
|
||||
inline static ResultType finalize(ResultType result, const ConstParams &)
|
||||
{
|
||||
return sqrt(result);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
struct LpNorm
|
||||
{
|
||||
static inline String name = "Lp";
|
||||
|
||||
struct ConstParams
|
||||
{
|
||||
Float64 power;
|
||||
Float64 inverted_power = 1 / power;
|
||||
};
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType accumulate(ResultType result, ResultType value, const ConstParams & params)
|
||||
{
|
||||
return result + std::pow(fabs(value), params.power);
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType finalize(ResultType result, const ConstParams & params)
|
||||
{
|
||||
return std::pow(result, params.inverted_power);
|
||||
}
|
||||
};
|
||||
|
||||
struct LinfNorm
|
||||
{
|
||||
static inline String name = "Linf";
|
||||
|
||||
struct ConstParams {};
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType accumulate(ResultType result, ResultType value)
|
||||
inline static ResultType accumulate(ResultType result, ResultType value, const ConstParams &)
|
||||
{
|
||||
return fmax(result, fabs(value));
|
||||
}
|
||||
|
||||
template <typename ResultType>
|
||||
inline static ResultType finalize(ResultType result)
|
||||
inline static ResultType finalize(ResultType result, const ConstParams &)
|
||||
{
|
||||
return result;
|
||||
}
|
||||
@ -78,22 +109,17 @@ public:
|
||||
String getName() const override { return name; }
|
||||
static FunctionPtr create(ContextPtr) { return std::make_shared<FunctionArrayNorm<Kernel>>(); }
|
||||
size_t getNumberOfArguments() const override { return 1; }
|
||||
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override { return {}; }
|
||||
bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; }
|
||||
bool useDefaultImplementationForConstants() const override { return true; }
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
|
||||
{
|
||||
DataTypes types;
|
||||
for (const auto & argument : arguments)
|
||||
{
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(argument.type.get());
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be array.", getName());
|
||||
const auto * array_type = checkAndGetDataType<DataTypeArray>(arguments[0].type.get());
|
||||
if (!array_type)
|
||||
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Argument of function {} must be array.", getName());
|
||||
|
||||
types.push_back(array_type->getNestedType());
|
||||
}
|
||||
const auto & common_type = getLeastSupertype(types);
|
||||
switch (common_type->getTypeId())
|
||||
switch (array_type->getNestedType()->getTypeId())
|
||||
{
|
||||
case TypeIndex::UInt8:
|
||||
case TypeIndex::UInt16:
|
||||
@ -111,7 +137,7 @@ public:
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Arguments of function {} has nested type {}. "
|
||||
"Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.",
|
||||
getName(), common_type->getName());
|
||||
getName(), array_type->getNestedType()->getName());
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,7 +151,7 @@ public:
|
||||
switch (result_type->getTypeId())
|
||||
{
|
||||
case TypeIndex::Float64:
|
||||
return executeWithResultType<Float64>(*arr, type, input_rows_count);
|
||||
return executeWithResultType<Float64>(*arr, type, input_rows_count, arguments);
|
||||
break;
|
||||
default:
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName());
|
||||
@ -148,13 +174,13 @@ private:
|
||||
|
||||
|
||||
template <typename ResultType>
|
||||
ColumnPtr executeWithResultType(const ColumnArray & array, const DataTypePtr & nested_type, size_t input_rows_count) const
|
||||
ColumnPtr executeWithResultType(const ColumnArray & array, const DataTypePtr & nested_type, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
switch (nested_type->getTypeId())
|
||||
{
|
||||
#define ON_TYPE(type) \
|
||||
case TypeIndex::type: \
|
||||
return executeWithTypes<ResultType, type>(array, input_rows_count); \
|
||||
return executeWithTypes<ResultType, type>(array, input_rows_count, arguments); \
|
||||
break;
|
||||
|
||||
SUPPORTED_TYPES(ON_TYPE)
|
||||
@ -170,7 +196,7 @@ private:
|
||||
}
|
||||
|
||||
template <typename ResultType, typename ArgumentType>
|
||||
static ColumnPtr executeWithTypes(const ColumnArray & array, size_t input_rows_count)
|
||||
ColumnPtr executeWithTypes(const ColumnArray & array, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
const auto & data = typeid_cast<const ColumnVector<ArgumentType> &>(array.getData()).getData();
|
||||
const auto & offsets = array.getOffsets();
|
||||
@ -178,6 +204,8 @@ private:
|
||||
auto result_col = ColumnVector<ResultType>::create(input_rows_count);
|
||||
auto & result_data = result_col->getData();
|
||||
|
||||
const typename Kernel::ConstParams kernel_params = initConstParams(arguments);
|
||||
|
||||
ColumnArray::Offset prev = 0;
|
||||
size_t row = 0;
|
||||
for (auto off : offsets)
|
||||
@ -185,18 +213,59 @@ private:
|
||||
Float64 result = 0;
|
||||
for (; prev < off; ++prev)
|
||||
{
|
||||
result = Kernel::template accumulate<Float64>(result, data[prev]);
|
||||
result = Kernel::template accumulate<Float64>(result, data[prev], kernel_params);
|
||||
}
|
||||
result_data[row] = Kernel::finalize(result);
|
||||
result_data[row] = Kernel::finalize(result, kernel_params);
|
||||
row++;
|
||||
}
|
||||
return result_col;
|
||||
}
|
||||
|
||||
typename Kernel::ConstParams initConstParams(const ColumnsWithTypeAndName &) const { return {}; }
|
||||
};
|
||||
|
||||
template <>
|
||||
size_t FunctionArrayNorm<LpNorm>::getNumberOfArguments() const { return 2; }
|
||||
|
||||
template <>
|
||||
ColumnNumbers FunctionArrayNorm<LpNorm>::getArgumentsThatAreAlwaysConstant() const { return {1}; }
|
||||
|
||||
template <>
|
||||
LpNorm::ConstParams FunctionArrayNorm<LpNorm>::initConstParams(const ColumnsWithTypeAndName & arguments) const
|
||||
{
|
||||
if (arguments.size() < 2)
|
||||
throw Exception(
|
||||
ErrorCodes::LOGICAL_ERROR,
|
||||
"Argument p of function {} was not provided",
|
||||
getName());
|
||||
|
||||
if (!arguments[1].column->isNumeric())
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
|
||||
"Argument p of function {} must be numeric constant",
|
||||
getName());
|
||||
|
||||
if (!isColumnConst(*arguments[1].column) && arguments[1].column->size() != 1)
|
||||
throw Exception(
|
||||
ErrorCodes::ILLEGAL_COLUMN,
|
||||
"Second argument for function {} must be either constant Float64 or constant UInt",
|
||||
getName());
|
||||
|
||||
Float64 p = arguments[1].column->getFloat64(0);
|
||||
if (p < 1 || p >= HUGE_VAL)
|
||||
throw Exception(
|
||||
ErrorCodes::ARGUMENT_OUT_OF_BOUND,
|
||||
"Second argument for function {} must be not less than one and not be an infinity",
|
||||
getName());
|
||||
|
||||
return LpNorm::ConstParams{p, 1 / p};
|
||||
}
|
||||
|
||||
|
||||
/// These functions are used by TupleOrArrayFunction
|
||||
FunctionPtr createFunctionArrayL1Norm(ContextPtr context_) { return FunctionArrayNorm<L1Norm>::create(context_); }
|
||||
FunctionPtr createFunctionArrayL2Norm(ContextPtr context_) { return FunctionArrayNorm<L2Norm>::create(context_); }
|
||||
FunctionPtr createFunctionArrayLpNorm(ContextPtr context_) { return FunctionArrayNorm<LpNorm>::create(context_); }
|
||||
FunctionPtr createFunctionArrayLinfNorm(ContextPtr context_) { return FunctionArrayNorm<LinfNorm>::create(context_); }
|
||||
|
||||
}
|
||||
|
@ -810,16 +810,18 @@ public:
|
||||
|
||||
const auto & p_column = arguments[1];
|
||||
|
||||
const auto * p_column_const = assert_cast<const ColumnConst *>(p_column.column.get());
|
||||
if (!isColumnConst(*p_column.column) && p_column.column->size() != 1)
|
||||
throw Exception{"Second argument for function " + getName() + " must be either constant Float64 or constant UInt", ErrorCodes::ILLEGAL_COLUMN};
|
||||
|
||||
double p;
|
||||
if (isFloat(p_column_const->getDataType()))
|
||||
p = p_column_const->getFloat64(0);
|
||||
else if (isUnsignedInteger(p_column_const->getDataType()))
|
||||
p = p_column_const->getUInt(0);
|
||||
if (isFloat(p_column.column->getDataType()))
|
||||
p = p_column.column->getFloat64(0);
|
||||
else if (isUnsignedInteger(p_column.column->getDataType()))
|
||||
p = p_column.column->getUInt(0);
|
||||
else
|
||||
throw Exception{"Second argument for function " + getName() + " must be either constant Float64 or constant UInt", ErrorCodes::ILLEGAL_COLUMN};
|
||||
|
||||
if (p < 1 || p == HUGE_VAL)
|
||||
if (p < 1 || p >= HUGE_VAL)
|
||||
throw Exception{"Second argument for function " + getName() + " must be not less than one and not be an infinity", ErrorCodes::ARGUMENT_OUT_OF_BOUND};
|
||||
|
||||
auto abs = FunctionFactory::instance().get("abs", context);
|
||||
@ -1109,10 +1111,12 @@ private:
|
||||
|
||||
extern FunctionPtr createFunctionArrayL1Norm(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayL2Norm(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayLpNorm(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayLinfNorm(ContextPtr context_);
|
||||
|
||||
extern FunctionPtr createFunctionArrayL1Distance(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayL2Distance(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayLpDistance(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_);
|
||||
extern FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_);
|
||||
|
||||
@ -1132,6 +1136,14 @@ struct L2NormTraits
|
||||
static constexpr auto CreateArrayFunction = createFunctionArrayL2Norm;
|
||||
};
|
||||
|
||||
struct LpNormTraits
|
||||
{
|
||||
static inline String name = "LpNorm";
|
||||
|
||||
static constexpr auto CreateTupleFunction = FunctionLpNorm::create;
|
||||
static constexpr auto CreateArrayFunction = createFunctionArrayLpNorm;
|
||||
};
|
||||
|
||||
struct LinfNormTraits
|
||||
{
|
||||
static inline String name = "LinfNorm";
|
||||
@ -1156,6 +1168,14 @@ struct L2DistanceTraits
|
||||
static constexpr auto CreateArrayFunction = createFunctionArrayL2Distance;
|
||||
};
|
||||
|
||||
struct LpDistanceTraits
|
||||
{
|
||||
static inline String name = "LpDistance";
|
||||
|
||||
static constexpr auto CreateTupleFunction = FunctionLpDistance::create;
|
||||
static constexpr auto CreateArrayFunction = createFunctionArrayLpDistance;
|
||||
};
|
||||
|
||||
struct LinfDistanceTraits
|
||||
{
|
||||
static inline String name = "LinfDistance";
|
||||
@ -1174,10 +1194,12 @@ struct CosineDistanceTraits
|
||||
|
||||
using TupleOrArrayFunctionL1Norm = TupleOrArrayFunction<L1NormTraits>;
|
||||
using TupleOrArrayFunctionL2Norm = TupleOrArrayFunction<L2NormTraits>;
|
||||
using TupleOrArrayFunctionLpNorm = TupleOrArrayFunction<LpNormTraits>;
|
||||
using TupleOrArrayFunctionLinfNorm = TupleOrArrayFunction<LinfNormTraits>;
|
||||
|
||||
using TupleOrArrayFunctionL1Distance = TupleOrArrayFunction<L1DistanceTraits>;
|
||||
using TupleOrArrayFunctionL2Distance = TupleOrArrayFunction<L2DistanceTraits>;
|
||||
using TupleOrArrayFunctionLpDistance = TupleOrArrayFunction<LpDistanceTraits>;
|
||||
using TupleOrArrayFunctionLinfDistance = TupleOrArrayFunction<LinfDistanceTraits>;
|
||||
using TupleOrArrayFunctionCosineDistance = TupleOrArrayFunction<CosineDistanceTraits>;
|
||||
|
||||
@ -1200,7 +1222,7 @@ void registerVectorFunctions(FunctionFactory & factory)
|
||||
factory.registerFunction<TupleOrArrayFunctionL1Norm>();
|
||||
factory.registerFunction<TupleOrArrayFunctionL2Norm>();
|
||||
factory.registerFunction<TupleOrArrayFunctionLinfNorm>();
|
||||
factory.registerFunction<FunctionLpNorm>();
|
||||
factory.registerFunction<TupleOrArrayFunctionLpNorm>();
|
||||
|
||||
factory.registerAlias("normL1", TupleOrArrayFunctionL1Norm::name, FunctionFactory::CaseInsensitive);
|
||||
factory.registerAlias("normL2", TupleOrArrayFunctionL2Norm::name, FunctionFactory::CaseInsensitive);
|
||||
@ -1210,7 +1232,7 @@ void registerVectorFunctions(FunctionFactory & factory)
|
||||
factory.registerFunction<TupleOrArrayFunctionL1Distance>();
|
||||
factory.registerFunction<TupleOrArrayFunctionL2Distance>();
|
||||
factory.registerFunction<TupleOrArrayFunctionLinfDistance>();
|
||||
factory.registerFunction<FunctionLpDistance>();
|
||||
factory.registerFunction<TupleOrArrayFunctionLpDistance>();
|
||||
|
||||
factory.registerAlias("distanceL1", FunctionL1Distance::name, FunctionFactory::CaseInsensitive);
|
||||
factory.registerAlias("distanceL2", FunctionL2Distance::name, FunctionFactory::CaseInsensitive);
|
||||
|
@ -1,5 +1,6 @@
|
||||
6
|
||||
3.7416573867739413
|
||||
3.2071843327373397
|
||||
3
|
||||
0.00258509695694209
|
||||
\N
|
||||
@ -11,6 +12,9 @@ nan
|
||||
7.0710678118654755
|
||||
9.16515138991168
|
||||
12.12435565298214
|
||||
5.917593844525055
|
||||
8.308858759453505
|
||||
9.932246380845738
|
||||
2
|
||||
5
|
||||
4
|
||||
|
@ -1,5 +1,6 @@
|
||||
SELECT L1Distance([0, 0, 0], [1, 2, 3]);
|
||||
SELECT L2Distance([1, 2, 3], [0, 0, 0]);
|
||||
SELECT LpDistance([1, 2, 3], [0, 0, 0], 3.5);
|
||||
SELECT LinfDistance([1, 2, 3], [0, 0, 0]);
|
||||
SELECT cosineDistance([1, 2, 3], [3, 5, 7]);
|
||||
|
||||
@ -26,6 +27,7 @@ CREATE TABLE vec2d (id UInt64, v Array(Float64)) ENGINE = Memory;
|
||||
INSERT INTO vec1 VALUES (1, [3, 4, 5]), (2, [2, 4, 8]), (3, [7, 7, 7]);
|
||||
SELECT L1Distance(v, [0, 0, 0]) FROM vec1;
|
||||
SELECT L2Distance(v, [0, 0, 0]) FROM vec1;
|
||||
SELECT LpDistance(v, [0, 0, 0], 3.14) FROM vec1;
|
||||
SELECT LinfDistance([5, 4, 3], v) FROM vec1;
|
||||
SELECT cosineDistance([3, 2, 1], v) FROM vec1;
|
||||
SELECT LinfDistance(v, materialize([0, -2, 0])) FROM vec1;
|
||||
@ -42,6 +44,10 @@ SELECT v1.id, v2.id, L2Distance(v1.v, v2.v) as dist FROM vec1 v1, vec2d v2;
|
||||
|
||||
SELECT L1Distance([0, 0], [1]); -- { serverError 190 }
|
||||
SELECT L2Distance([1, 2], (3,4)); -- { serverError 43 }
|
||||
SELECT LpDistance([1, 2], [3,4]); -- { serverError 42 }
|
||||
SELECT LpDistance([1, 2], [3,4], -1.); -- { serverError 69 }
|
||||
SELECT LpDistance([1, 2], [3,4], 'aaa'); -- { serverError 43 }
|
||||
SELECT LpDistance([1, 2], [3,4], materialize(2.7)); -- { serverError 44 }
|
||||
|
||||
DROP TABLE vec1;
|
||||
DROP TABLE vec2;
|
||||
|
@ -1,27 +1,28 @@
|
||||
6
|
||||
7.0710678118654755
|
||||
10.882246697870885
|
||||
2
|
||||
10803059573 4234902446.7343364 2096941042
|
||||
1 5
|
||||
2 2
|
||||
3 5.196152422706632
|
||||
4 0
|
||||
10803059573 4234902446.7343364 10803059573 4234902446.7343364 3122003357.3280888 2096941042
|
||||
1 7 5 4.601724723020627 4
|
||||
2 2 2 2 2
|
||||
3 9 5.196152422706632 4.506432087111623 3
|
||||
4 0 0 0 0
|
||||
1 11
|
||||
2 11
|
||||
3 11
|
||||
4 11
|
||||
1 5
|
||||
2 2
|
||||
3 5.196152422706632
|
||||
4 0
|
||||
1 7 5 4.601724723020627 4
|
||||
2 2 2 2 2
|
||||
3 9 5.196152422706632 4.506432087111623 3
|
||||
4 0 0 0 0
|
||||
1 11
|
||||
2 11
|
||||
3 11
|
||||
4 11
|
||||
1 5
|
||||
2 2
|
||||
3 5.196152422706632
|
||||
4 0
|
||||
1 7 5 4.601724723020627 4
|
||||
2 2 2 2 2
|
||||
3 9 5.196152422706632 4.506432087111623 3
|
||||
4 0 0 0 0
|
||||
1 11
|
||||
2 11
|
||||
3 11
|
||||
|
@ -1,5 +1,6 @@
|
||||
SELECT L1Norm([1, 2, 3]);
|
||||
SELECT L2Norm([3., 4., 5.]);
|
||||
SELECT LpNorm([3., 4., 5.], 1.1);
|
||||
SELECT LinfNorm([0, 0, 2]);
|
||||
|
||||
-- Overflows
|
||||
@ -7,6 +8,9 @@ WITH CAST([-547274980, 1790553898, 1981517754, 1908431500, 1352428565, -57341255
|
||||
SELECT
|
||||
L1Norm(a),
|
||||
L2Norm(a),
|
||||
LpNorm(a,1),
|
||||
LpNorm(a,2),
|
||||
LpNorm(a,3.14),
|
||||
LinfNorm(a);
|
||||
|
||||
DROP TABLE IF EXISTS vec1;
|
||||
@ -19,17 +23,23 @@ INSERT INTO vec1 VALUES (1, [3, 4]), (2, [2]), (3, [3, 3, 3]), (4, NULL);
|
||||
INSERT INTO vec1f VALUES (1, [3, 4]), (2, [2]), (3, [3, 3, 3]), (4, NULL);
|
||||
INSERT INTO vec1d VALUES (1, [3, 4]), (2, [2]), (3, [3, 3, 3]), (4, NULL);
|
||||
|
||||
SELECT id, L2Norm(v) FROM vec1;
|
||||
SELECT id, L1Norm(v), L2Norm(v), LpNorm(v, 2.7), LinfNorm(v) FROM vec1;
|
||||
SELECT id, L1Norm(materialize([5., 6.])) FROM vec1;
|
||||
|
||||
SELECT id, L2Norm(v) FROM vec1f;
|
||||
SELECT id, L1Norm(v), L2Norm(v), LpNorm(v, 2.7), LinfNorm(v) FROM vec1f;
|
||||
SELECT id, L1Norm(materialize([5., 6.])) FROM vec1f;
|
||||
|
||||
SELECT id, L2Norm(v) FROM vec1d;
|
||||
SELECT id, L1Norm(v), L2Norm(v), LpNorm(v, 2.7), LinfNorm(v) FROM vec1d;
|
||||
SELECT id, L1Norm(materialize([5., 6.])) FROM vec1d;
|
||||
|
||||
SELECT L1Norm(1, 2); -- { serverError 42 }
|
||||
|
||||
SELECT LpNorm([1,2]); -- { serverError 42 }
|
||||
SELECT LpNorm([1,2], -3.4); -- { serverError 69 }
|
||||
SELECT LpNorm([1,2], 'aa'); -- { serverError 43 }
|
||||
SELECT LpNorm([1,2], [1]); -- { serverError 43 }
|
||||
SELECT LpNorm([1,2], materialize(3.14)); -- { serverError 44 }
|
||||
|
||||
DROP TABLE vec1;
|
||||
DROP TABLE vec1f;
|
||||
DROP TABLE vec1d;
|
||||
|
Loading…
Reference in New Issue
Block a user