From 307511aab4f618d013671d468c2f23db5660e953 Mon Sep 17 00:00:00 2001 From: qieqieplus Date: Mon, 9 May 2022 14:41:12 +0800 Subject: [PATCH] impl norm functions for array --- src/Functions/array/arrayNorm.cpp | 195 ++++++++++++++++++ .../array/registerFunctionsArray.cpp | 2 + 2 files changed, 197 insertions(+) create mode 100644 src/Functions/array/arrayNorm.cpp diff --git a/src/Functions/array/arrayNorm.cpp b/src/Functions/array/arrayNorm.cpp new file mode 100644 index 00000000000..cf8ae7524dc --- /dev/null +++ b/src/Functions/array/arrayNorm.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int LOGICAL_ERROR; +} + +template +struct LpNorm +{ + static inline String name = "L" + std::to_string(N); + template + static void compute(const std::vector> & vec, PaddedPODArray & array) + { + for (const auto & v : vec) + { + array.push_back(v.template lpNorm()); + } + } +}; + +struct LinfNorm : LpNorm +{ + static inline String name = "Linf"; +}; + +template +class FunctionArrayNorm : public IFunction +{ +public: + static inline auto name = "array" + Kernel::name + "Norm"; + String getName() const override { return name; } + static FunctionPtr create(ContextPtr) { return std::make_shared>(); } + size_t getNumberOfArguments() const override { return 1; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } + + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override + { + DataTypes types; + for (const auto & argument : arguments) + { + const auto * array_type = checkAndGetDataType(argument.type.get()); + if (!array_type) + throw Exception("Argument of function " + getName() + " must be array. ", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + types.push_back(array_type->getNestedType()); + } + const auto & common_type = getLeastSupertype(types); + switch (common_type->getTypeId()) + { + case TypeIndex::UInt8: + case TypeIndex::UInt16: + case TypeIndex::UInt32: + case TypeIndex::Int8: + case TypeIndex::Int16: + case TypeIndex::Int32: + case TypeIndex::Float32: + return std::make_shared(); + case TypeIndex::UInt64: + case TypeIndex::Int64: + case TypeIndex::Float64: + return std::make_shared(); + default: + throw Exception( + "Arguments of function " + getName() + " has nested type " + common_type->getName() + + ". Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + } + + ColumnPtr + executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override + { + const auto & type = typeid_cast(arguments[0].type.get())->getNestedType(); + const auto column = arguments[0].column->convertToFullColumnIfConst(); + const auto * arr = assert_cast(column.get()); + + auto result = result_type->createColumn(); + switch (result_type->getTypeId()) + { + case TypeIndex::Float32: + executeWithType(*arr, type, result); + break; + case TypeIndex::Float64: + executeWithType(*arr, type, result); + break; + default: + throw Exception("Unexpected result type.", ErrorCodes::LOGICAL_ERROR); + } + return result; + } + +private: + template + void executeWithType(const ColumnArray & array, const DataTypePtr & type, MutableColumnPtr & column) const + { + std::vector> v; + columnToVectors(array, type, v); + auto & data = assert_cast &>(*column).getData(); + Kernel::compute(v, data); + } + + template + void columnToVectors(const ColumnArray & array, const DataTypePtr & nested_type, std::vector> & vec) const + { + switch (nested_type->getTypeId()) + { + case TypeIndex::UInt8: + fillVectors(vec, array); + break; + case TypeIndex::UInt16: + fillVectors(vec, array); + break; + case TypeIndex::UInt32: + fillVectors(vec, array); + break; + case TypeIndex::UInt64: + fillVectors(vec, array); + break; + case TypeIndex::Int8: + fillVectors(vec, array); + break; + case TypeIndex::Int16: + fillVectors(vec, array); + break; + case TypeIndex::Int32: + fillVectors(vec, array); + break; + case TypeIndex::Int64: + fillVectors(vec, array); + break; + case TypeIndex::Float32: + fillVectors(vec, array); + break; + case TypeIndex::Float64: + fillVectors(vec, array); + break; + default: + throw Exception( + "Arguments of function " + getName() + " has nested type " + nested_type->getName() + + ". Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + } + } + template ::value>::type> + void fillVectors(std::vector> & vec, const ColumnArray & array) const + { + const auto & data = typeid_cast &>(array.getData()).getData(); + const auto & offsets = array.getOffsets(); + ColumnArray::Offset prev = 0; + for (auto off : offsets) + { + vec.emplace_back(Eigen::Map>(data.data() + prev, off - prev)); + prev = off; + } + } + + template + void fillVectors(std::vector> & vec, const ColumnArray & array) const + { + const auto & data = typeid_cast &>(array.getData()).getData(); + const auto & offsets = array.getOffsets(); + + ColumnArray::Offset prev = 0; + for (auto off : offsets) + { + Eigen::VectorX mat(off - prev); + for (ColumnArray::Offset row = 0; row + prev < off; ++row) + { + mat[row] = static_cast(data[prev + row]); + } + prev = off; + vec.emplace_back(mat); + } + } +}; + +void registerFunctionArrayNorm(FunctionFactory & factory) +{ + factory.registerFunction>>(); + factory.registerFunction>>(); + factory.registerFunction>(); +} + +} diff --git a/src/Functions/array/registerFunctionsArray.cpp b/src/Functions/array/registerFunctionsArray.cpp index f24a2023d40..e2e8b08fbf2 100644 --- a/src/Functions/array/registerFunctionsArray.cpp +++ b/src/Functions/array/registerFunctionsArray.cpp @@ -38,6 +38,7 @@ void registerFunctionArrayReduceInRanges(FunctionFactory &); void registerFunctionMapOp(FunctionFactory &); void registerFunctionMapPopulateSeries(FunctionFactory &); void registerFunctionArrayDistance(FunctionFactory &); +void registerFunctionArrayNorm(FunctionFactory &); void registerFunctionsArray(FunctionFactory & factory) { @@ -77,6 +78,7 @@ void registerFunctionsArray(FunctionFactory & factory) registerFunctionMapOp(factory); registerFunctionMapPopulateSeries(factory); registerFunctionArrayDistance(factory); + registerFunctionArrayNorm(factory); } }