From 34d82625b58e157d937d48a9660079d63deab785 Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Sun, 23 Apr 2023 02:58:00 +0800 Subject: [PATCH 1/8] feat: add dotProduct for array --- src/Functions/array/arrayAUC.cpp | 2 +- src/Functions/array/arrayDotProduct.cpp | 75 ++++++++++++++++++++++++ src/Functions/array/arrayScalarProduct.h | 71 +++++++++++++--------- src/Functions/vectorFunctions.cpp | 16 ++++- 4 files changed, 133 insertions(+), 31 deletions(-) create mode 100644 src/Functions/array/arrayDotProduct.cpp diff --git a/src/Functions/array/arrayAUC.cpp b/src/Functions/array/arrayAUC.cpp index 2890ae55886..297394822d9 100644 --- a/src/Functions/array/arrayAUC.cpp +++ b/src/Functions/array/arrayAUC.cpp @@ -91,7 +91,7 @@ public: return std::make_shared>(); } - template + template static ResultType apply( const T * scores, const U * labels, diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp new file mode 100644 index 00000000000..4b9433f683d --- /dev/null +++ b/src/Functions/array/arrayDotProduct.cpp @@ -0,0 +1,75 @@ +#include +#include +#include +#include "arrayScalarProduct.h" + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; +} + +struct NameArrayDotProduct +{ + static constexpr auto name = "arrayDotProduct"; +}; + + +class ArrayDotProductImpl +{ +public: + + static DataTypePtr getReturnType(const DataTypePtr left_type, const DataTypePtr & right_type) + { + const auto & common_type = getLeastSupertype(DataTypes{left_type, right_type}); + switch (common_type->getTypeId()) + { + case TypeIndex::UInt8: + case TypeIndex::UInt16: + case TypeIndex::UInt32: + case TypeIndex::Int8: + case TypeIndex::Int16: + case TypeIndex::Int32: + case TypeIndex::UInt64: + case TypeIndex::Int64: + case TypeIndex::Float64: + return std::make_shared(); + case TypeIndex::Float32: + return std::make_shared(); + default: + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Arguments of function {} has nested type {}. " + "Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + std::string(NameArrayDotProduct::name), + common_type->getName()); + } + } + + template + static ResultType apply( + const T * left, + const U * right, + size_t size) + { + ResultType result = 0; + for (size_t i = 0; i < size; ++i) + result += static_cast(left[i]) * static_cast(right[i]); + return result; + } + +}; + +using FunctionArrayDotProduct = FunctionArrayScalarProduct; + +REGISTER_FUNCTION(ArrayDotProduct) +{ + factory.registerFunction(); +} + +/// These functions are used by TupleOrArrayFunction in Function/vectorFunctions.cpp +FunctionPtr createFunctionArrayDotProduct(ContextPtr context_) { return FunctionArrayDotProduct::create(context_); } +} diff --git a/src/Functions/array/arrayScalarProduct.h b/src/Functions/array/arrayScalarProduct.h index 94ce1bc533c..ded6ec8ae29 100644 --- a/src/Functions/array/arrayScalarProduct.h +++ b/src/Functions/array/arrayScalarProduct.h @@ -29,29 +29,28 @@ public: static FunctionPtr create(ContextPtr) { return std::make_shared(); } private: - using ResultColumnType = ColumnVector; - template + template ColumnPtr executeNumber(const ColumnsWithTypeAndName & arguments) const { ColumnPtr res; - if ( (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments)) - || (res = executeNumberNumber(arguments))) + if ( (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments)) + || (res = executeNumberNumber(arguments))) return res; return nullptr; } - template + template ColumnPtr executeNumberNumber(const ColumnsWithTypeAndName & arguments) const { ColumnPtr col1 = arguments[0].column->convertToFullColumnIfConst(); @@ -72,7 +71,7 @@ private: if (!col_nested1 || !col_nested2) return nullptr; - auto col_res = ResultColumnType::create(); + auto col_res = ColumnVector::create(); vector( col_nested1->getData(), @@ -83,12 +82,12 @@ private: return col_res; } - template + template static NO_INLINE void vector( const PaddedPODArray & data1, const PaddedPODArray & data2, const ColumnArray::Offsets & offsets, - PaddedPODArray & result) + PaddedPODArray & result) { size_t size = offsets.size(); result.resize(size); @@ -97,7 +96,7 @@ private: for (size_t i = 0; i < size; ++i) { size_t array_size = offsets[i] - current_offset; - result[i] = Method::apply(&data1[current_offset], &data2[current_offset], array_size); + result[i] = Method::template apply(&data1[current_offset], &data2[current_offset], array_size); current_offset = offsets[i]; } } @@ -130,24 +129,40 @@ public: return Method::getReturnType(nested_types[0], nested_types[1]); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t /* input_rows_count */) const override + template + ColumnPtr executeWithResultType(const ColumnsWithTypeAndName & arguments) const { ColumnPtr res; - if (!((res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)) - || (res = executeNumber(arguments)))) + if ( !((res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)) + || (res = executeNumber(arguments)))) throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}", arguments[0].column->getName(), getName()); return res; } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /* input_rows_count */) const override + { + switch (result_type->getTypeId()) + { + case TypeIndex::Float32: + return executeWithResultType(arguments); + break; + case TypeIndex::Float64: + return executeWithResultType(arguments); + break; + default: + throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName()); + } + } }; } diff --git a/src/Functions/vectorFunctions.cpp b/src/Functions/vectorFunctions.cpp index a7ab09612cf..b52def28755 100644 --- a/src/Functions/vectorFunctions.cpp +++ b/src/Functions/vectorFunctions.cpp @@ -1429,6 +1429,8 @@ private: FunctionPtr array_function; }; +extern FunctionPtr createFunctionArrayDotProduct(ContextPtr context_); + extern FunctionPtr createFunctionArrayL1Norm(ContextPtr context_); extern FunctionPtr createFunctionArrayL2Norm(ContextPtr context_); extern FunctionPtr createFunctionArrayL2SquaredNorm(ContextPtr context_); @@ -1442,6 +1444,14 @@ extern FunctionPtr createFunctionArrayLpDistance(ContextPtr context_); extern FunctionPtr createFunctionArrayLinfDistance(ContextPtr context_); extern FunctionPtr createFunctionArrayCosineDistance(ContextPtr context_); +struct DotProduct +{ + static constexpr auto name = "dotProduct"; + + static constexpr auto CreateTupleFunction = FunctionDotProduct::create; + static constexpr auto CreateArrayFunction = createFunctionArrayDotProduct; +}; + struct L1NormTraits { static constexpr auto name = "L1Norm"; @@ -1530,6 +1540,8 @@ struct CosineDistanceTraits static constexpr auto CreateArrayFunction = createFunctionArrayCosineDistance; }; +using TupleOrArrayFunctionDotProduct = TupleOrArrayFunction; + using TupleOrArrayFunctionL1Norm = TupleOrArrayFunction; using TupleOrArrayFunctionL2Norm = TupleOrArrayFunction; using TupleOrArrayFunctionL2SquaredNorm = TupleOrArrayFunction; @@ -1615,8 +1627,8 @@ If the types of the first interval (or the interval in the tuple) and the second factory.registerFunction(); factory.registerFunction(); - factory.registerFunction(); - factory.registerAlias("scalarProduct", FunctionDotProduct::name, FunctionFactory::CaseInsensitive); + factory.registerFunction(); + factory.registerAlias("scalarProduct", TupleOrArrayFunctionDotProduct::name, FunctionFactory::CaseInsensitive); factory.registerFunction(); factory.registerFunction(); From 81fa4701aaa00974aeb8173c9847dfbe220c2a27 Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Sun, 23 Apr 2023 03:23:14 +0800 Subject: [PATCH 2/8] feat: add dotProduct for array --- .../0_stateless/02708_dot_product.reference | 10 ++++++ .../queries/0_stateless/02708_dot_product.sql | 34 +++++++++++++++++++ 2 files changed, 44 insertions(+) create mode 100644 tests/queries/0_stateless/02708_dot_product.reference create mode 100644 tests/queries/0_stateless/02708_dot_product.sql diff --git a/tests/queries/0_stateless/02708_dot_product.reference b/tests/queries/0_stateless/02708_dot_product.reference new file mode 100644 index 00000000000..7106b870fab --- /dev/null +++ b/tests/queries/0_stateless/02708_dot_product.reference @@ -0,0 +1,10 @@ +3881.304 +3881.304 +3881.304 +376.5 +230 +0 +Float64 +Float32 +Float64 +Float64 diff --git a/tests/queries/0_stateless/02708_dot_product.sql b/tests/queries/0_stateless/02708_dot_product.sql new file mode 100644 index 00000000000..46450ae6394 --- /dev/null +++ b/tests/queries/0_stateless/02708_dot_product.sql @@ -0,0 +1,34 @@ +SELECT dotProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); +SELECT scalarProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); +SELECT arrayDotProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); + +SELECT dotProduct([1.3, 2, 3, 4, 5], [222, 12, 5.3, 2, 8]); + +SELECT dotProduct([1, 1, 1, 1, 1], [222, 12, 0, -12, 8]); + +SELECT round(dotProduct([-1, 2, 3.002], [2, 3.4, 4]) - dotProduct((-1, 2, 3.002), (2, 3.4, 4)), 2); + + +DROP TABLE IF EXISTS product_fp64_fp64; +CREATE TABLE product_fp64_fp64 (x Array(Float64), y Array(Float64)) engine = MergeTree() order by x; +INSERT INTO TABLE product_fp64_fp64 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_fp64_fp64; +DROP TABLE product_fp64_fp64; + +DROP TABLE IF EXISTS product_fp32_fp32; +CREATE TABLE product_fp32_fp32 (x Array(Float32), y Array(Float32)) engine = MergeTree() order by x; +INSERT INTO TABLE product_fp32_fp32 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_fp32_fp32; +DROP TABLE product_fp32_fp32; + +DROP TABLE IF EXISTS product_fp32_fp64; +CREATE TABLE product_fp32_fp64 (x Array(Float32), y Array(Float64)) engine = MergeTree() order by x; +INSERT INTO TABLE product_fp32_fp64 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_fp32_fp64; +DROP TABLE product_fp32_fp64; + +DROP TABLE IF EXISTS product_uint8_fp64; +CREATE TABLE product_uint8_fp64 (x Array(UInt8), y Array(Float64)) engine = MergeTree() order by x; +INSERT INTO TABLE product_uint8_fp64 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_uint8_fp64; +DROP TABLE product_uint8_fp64; From 79ae949b566e5023500d037db6327a050dc7a11f Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Sun, 23 Apr 2023 03:35:09 +0800 Subject: [PATCH 3/8] fix style --- src/Functions/array/arrayDotProduct.cpp | 7 ++----- src/Functions/array/arrayScalarProduct.h | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp index 4b9433f683d..0d21bae90e3 100644 --- a/src/Functions/array/arrayDotProduct.cpp +++ b/src/Functions/array/arrayDotProduct.cpp @@ -17,12 +17,10 @@ struct NameArrayDotProduct static constexpr auto name = "arrayDotProduct"; }; - class ArrayDotProductImpl { public: - - static DataTypePtr getReturnType(const DataTypePtr left_type, const DataTypePtr & right_type) + static DataTypePtr getReturnType(const DataTypePtr & left_type, const DataTypePtr & right_type) { const auto & common_type = getLeastSupertype(DataTypes{left_type, right_type}); switch (common_type->getTypeId()) @@ -56,11 +54,10 @@ public: size_t size) { ResultType result = 0; - for (size_t i = 0; i < size; ++i) + for (size_t i = 0; i < size; ++i) result += static_cast(left[i]) * static_cast(right[i]); return result; } - }; using FunctionArrayDotProduct = FunctionArrayScalarProduct; diff --git a/src/Functions/array/arrayScalarProduct.h b/src/Functions/array/arrayScalarProduct.h index ded6ec8ae29..0d1bf44a3e7 100644 --- a/src/Functions/array/arrayScalarProduct.h +++ b/src/Functions/array/arrayScalarProduct.h @@ -133,7 +133,7 @@ public: ColumnPtr executeWithResultType(const ColumnsWithTypeAndName & arguments) const { ColumnPtr res; - if ( !((res = executeNumber(arguments)) + if (!((res = executeNumber(arguments)) || (res = executeNumber(arguments)) || (res = executeNumber(arguments)) || (res = executeNumber(arguments)) From 4145abf547c3f09b849aa8e3429a318a340ef668 Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Sun, 23 Apr 2023 04:04:05 +0800 Subject: [PATCH 4/8] fix style --- src/Functions/array/arrayScalarProduct.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Functions/array/arrayScalarProduct.h b/src/Functions/array/arrayScalarProduct.h index 0d1bf44a3e7..5c36f2492c6 100644 --- a/src/Functions/array/arrayScalarProduct.h +++ b/src/Functions/array/arrayScalarProduct.h @@ -18,6 +18,7 @@ namespace ErrorCodes extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int BAD_ARGUMENTS; + extern const int LOGICAL_ERROR; } From 69e39aba80b4fadea79bfdf0f83a23848520948e Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Mon, 24 Apr 2023 22:13:13 +0800 Subject: [PATCH 5/8] fix test --- .../02415_all_new_functions_must_be_documented.reference | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference index 4ff8c2d3af1..98cae995b47 100644 --- a/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference +++ b/tests/queries/0_stateless/02415_all_new_functions_must_be_documented.reference @@ -97,6 +97,7 @@ arrayCumSum arrayCumSumNonNegative arrayDifference arrayDistinct +arrayDotProduct arrayElement arrayEnumerate arrayEnumerateDense From a7e04b7576d5dc5cc2b7feb4a043eb4361d9eeb9 Mon Sep 17 00:00:00 2001 From: fhbai Date: Tue, 9 May 2023 11:36:15 +0800 Subject: [PATCH 6/8] fix return type --- src/Functions/array/arrayDotProduct.cpp | 59 +++++++++++-------- src/Functions/array/arrayScalarProduct.h | 22 +++++-- .../0_stateless/02708_dot_product.reference | 4 ++ .../queries/0_stateless/02708_dot_product.sql | 23 +++++++- 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp index 0d21bae90e3..7aa9f1d49c7 100644 --- a/src/Functions/array/arrayDotProduct.cpp +++ b/src/Functions/array/arrayDotProduct.cpp @@ -1,7 +1,12 @@ #include #include #include -#include "arrayScalarProduct.h" +#include +#include +#include +#include +#include +#include namespace DB @@ -20,31 +25,33 @@ struct NameArrayDotProduct class ArrayDotProductImpl { public: - static DataTypePtr getReturnType(const DataTypePtr & left_type, const DataTypePtr & right_type) + static DataTypePtr getReturnType(const DataTypePtr & left, const DataTypePtr & right) { - const auto & common_type = getLeastSupertype(DataTypes{left_type, right_type}); - switch (common_type->getTypeId()) - { - case TypeIndex::UInt8: - case TypeIndex::UInt16: - case TypeIndex::UInt32: - case TypeIndex::Int8: - case TypeIndex::Int16: - case TypeIndex::Int32: - case TypeIndex::UInt64: - case TypeIndex::Int64: - case TypeIndex::Float64: - return std::make_shared(); - case TypeIndex::Float32: - return std::make_shared(); - default: - throw Exception( - ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, - "Arguments of function {} has nested type {}. " - "Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", - std::string(NameArrayDotProduct::name), - common_type->getName()); - } + using Types = TypeList; + + DataTypePtr result_type; + bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_) { + return castTypeToEither(Types{}, right.get(), [&](const auto & right_) { + using LeftDataType = typename std::decay_t::FieldType; + using RightDataType = typename std::decay_t::FieldType; + using ResultType = typename NumberTraits::ResultOfAdditionMultiplication::Type; + if (std::is_same_v && std::is_same_v) + result_type = std::make_shared(); + else + result_type = std::make_shared>(); + return true; + }); + }); + + if (!valid) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Arguments of function {} " + "only support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + std::string(NameArrayDotProduct::name)); + return result_type; } template @@ -67,6 +74,6 @@ REGISTER_FUNCTION(ArrayDotProduct) factory.registerFunction(); } -/// These functions are used by TupleOrArrayFunction in Function/vectorFunctions.cpp +// These functions are used by TupleOrArrayFunction in Function/vectorFunctions.cpp FunctionPtr createFunctionArrayDotProduct(ContextPtr context_) { return FunctionArrayDotProduct::create(context_); } } diff --git a/src/Functions/array/arrayScalarProduct.h b/src/Functions/array/arrayScalarProduct.h index 5c36f2492c6..374a2d8a194 100644 --- a/src/Functions/array/arrayScalarProduct.h +++ b/src/Functions/array/arrayScalarProduct.h @@ -6,6 +6,7 @@ #include #include #include +#include namespace DB @@ -154,12 +155,23 @@ public: { switch (result_type->getTypeId()) { - case TypeIndex::Float32: - return executeWithResultType(arguments); - break; - case TypeIndex::Float64: - return executeWithResultType(arguments); + #define SUPPORTED_TYPE(type) \ + case TypeIndex::type: \ + return executeWithResultType(arguments); \ break; + + SUPPORTED_TYPE(UInt8) + SUPPORTED_TYPE(UInt16) + SUPPORTED_TYPE(UInt32) + SUPPORTED_TYPE(UInt64) + SUPPORTED_TYPE(Int8) + SUPPORTED_TYPE(Int16) + SUPPORTED_TYPE(Int32) + SUPPORTED_TYPE(Int64) + SUPPORTED_TYPE(Float32) + SUPPORTED_TYPE(Float64) + #undef SUPPORTED_TYPE + default: throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName()); } diff --git a/tests/queries/0_stateless/02708_dot_product.reference b/tests/queries/0_stateless/02708_dot_product.reference index 7106b870fab..45e53871aa2 100644 --- a/tests/queries/0_stateless/02708_dot_product.reference +++ b/tests/queries/0_stateless/02708_dot_product.reference @@ -4,7 +4,11 @@ 376.5 230 0 +0 Float64 Float32 Float64 Float64 +UInt16 +UInt64 +Int64 diff --git a/tests/queries/0_stateless/02708_dot_product.sql b/tests/queries/0_stateless/02708_dot_product.sql index 46450ae6394..e94cb577bf4 100644 --- a/tests/queries/0_stateless/02708_dot_product.sql +++ b/tests/queries/0_stateless/02708_dot_product.sql @@ -1,13 +1,16 @@ SELECT dotProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); + SELECT scalarProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); + SELECT arrayDotProduct([12, 2.22, 302], [1.32, 231.2, 11.1]); SELECT dotProduct([1.3, 2, 3, 4, 5], [222, 12, 5.3, 2, 8]); SELECT dotProduct([1, 1, 1, 1, 1], [222, 12, 0, -12, 8]); -SELECT round(dotProduct([-1, 2, 3.002], [2, 3.4, 4]) - dotProduct((-1, 2, 3.002), (2, 3.4, 4)), 2); +SELECT round(dotProduct([12345678901234567], [1]) - dotProduct(tuple(12345678901234567), tuple(1)), 2); +SELECT round(dotProduct([-1, 2, 3.002], [2, 3.4, 4]) - dotProduct((-1, 2, 3.002), (2, 3.4, 4)), 2); DROP TABLE IF EXISTS product_fp64_fp64; CREATE TABLE product_fp64_fp64 (x Array(Float64), y Array(Float64)) engine = MergeTree() order by x; @@ -32,3 +35,21 @@ CREATE TABLE product_uint8_fp64 (x Array(UInt8), y Array(Float64)) engine = Merg INSERT INTO TABLE product_uint8_fp64 (x, y) values ([1, 2], [3, 4]); SELECT toTypeName(dotProduct(x, y)) from product_uint8_fp64; DROP TABLE product_uint8_fp64; + +DROP TABLE IF EXISTS product_uint8_uint8; +CREATE TABLE product_uint8_uint8 (x Array(UInt8), y Array(UInt8)) engine = MergeTree() order by x; +INSERT INTO TABLE product_uint8_uint8 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_uint8_uint8; +DROP TABLE product_uint8_uint8; + +DROP TABLE IF EXISTS product_uint64_uint64; +CREATE TABLE product_uint64_uint64 (x Array(UInt64), y Array(UInt64)) engine = MergeTree() order by x; +INSERT INTO TABLE product_uint64_uint64 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_uint64_uint64; +DROP TABLE product_uint64_uint64; + +DROP TABLE IF EXISTS product_int32_uint64; +CREATE TABLE product_int32_uint64 (x Array(Int32), y Array(UInt64)) engine = MergeTree() order by x; +INSERT INTO TABLE product_int32_uint64 (x, y) values ([1, 2], [3, 4]); +SELECT toTypeName(dotProduct(x, y)) from product_int32_uint64; +DROP TABLE product_int32_uint64; From 79398f612f4a94cee996d6256155c5e229ee090b Mon Sep 17 00:00:00 2001 From: FFFFFFFHHHHHHH <916677625@qq.com> Date: Tue, 9 May 2023 11:50:38 +0800 Subject: [PATCH 7/8] fix style --- src/Functions/array/arrayDotProduct.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp index 7aa9f1d49c7..e3c80775f1b 100644 --- a/src/Functions/array/arrayDotProduct.cpp +++ b/src/Functions/array/arrayDotProduct.cpp @@ -32,8 +32,10 @@ public: DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64>; DataTypePtr result_type; - bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_) { - return castTypeToEither(Types{}, right.get(), [&](const auto & right_) { + bool valid = castTypeToEither(Types{}, left.get(), [&](const auto & left_) + { + return castTypeToEither(Types{}, right.get(), [&](const auto & right_) + { using LeftDataType = typename std::decay_t::FieldType; using RightDataType = typename std::decay_t::FieldType; using ResultType = typename NumberTraits::ResultOfAdditionMultiplication::Type; From c10435489487626448b8e861074ef9b41ecb9f4e Mon Sep 17 00:00:00 2001 From: fhbai Date: Wed, 17 May 2023 14:39:30 +0800 Subject: [PATCH 8/8] fix --- src/Functions/array/arrayDotProduct.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp index e3c80775f1b..d17c223cc2f 100644 --- a/src/Functions/array/arrayDotProduct.cpp +++ b/src/Functions/array/arrayDotProduct.cpp @@ -57,7 +57,7 @@ public: } template - static ResultType apply( + static inline NO_SANITIZE_UNDEFINED ResultType apply( const T * left, const U * right, size_t size)