diff --git a/src/Functions/array/arrayDistance.cpp b/src/Functions/array/arrayDistance.cpp index 71564f6fa93..6ed4bf24f99 100644 --- a/src/Functions/array/arrayDistance.cpp +++ b/src/Functions/array/arrayDistance.cpp @@ -18,11 +18,11 @@ namespace DB { namespace ErrorCodes { + extern const int ARGUMENT_OUT_OF_BOUND; extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int LOGICAL_ERROR; extern const int SIZES_OF_ARRAYS_DONT_MATCH; - extern const int ARGUMENT_OUT_OF_BOUND; } struct L1Distance @@ -357,7 +357,7 @@ public: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", getName(), common_type->getName()); } @@ -379,17 +379,17 @@ public: } -#define SUPPORTED_TYPES(action) \ - action(UInt8) \ - action(UInt16) \ - action(UInt32) \ - action(UInt64) \ - action(Int8) \ - action(Int16) \ - action(Int32) \ - action(Int64) \ - action(Float32) \ - action(Float64) +#define SUPPORTED_TYPES(ACTION) \ + ACTION(UInt8) \ + ACTION(UInt16) \ + ACTION(UInt32) \ + ACTION(UInt64) \ + ACTION(Int8) \ + ACTION(Int16) \ + ACTION(Int32) \ + ACTION(Int64) \ + ACTION(Float32) \ + ACTION(Float64) private: @@ -398,12 +398,11 @@ private: { DataTypePtr type_x = typeid_cast(arguments[0].type.get())->getNestedType(); - /// Dynamic disaptch based on the 1st argument type switch (type_x->getTypeId()) { #define ON_TYPE(type) \ case TypeIndex::type: \ - return executeWithFirstType(arguments, input_rows_count); \ + return executeWithResultTypeAndLeftType(arguments, input_rows_count); \ break; SUPPORTED_TYPES(ON_TYPE) @@ -413,23 +412,22 @@ private: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", getName(), type_x->getName()); } } - template - ColumnPtr executeWithFirstType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const + template + ColumnPtr executeWithResultTypeAndLeftType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const { DataTypePtr type_y = typeid_cast(arguments[1].type.get())->getNestedType(); - /// Dynamic disaptch based on the 2nd argument type switch (type_y->getTypeId()) { #define ON_TYPE(type) \ case TypeIndex::type: \ - return executeWithTypes(arguments[0].column, arguments[1].column, input_rows_count, arguments); \ + return executeWithResultTypeAndLeftTypeAndRightType(arguments[0].column, arguments[1].column, input_rows_count, arguments); \ break; SUPPORTED_TYPES(ON_TYPE) @@ -439,59 +437,43 @@ private: throw Exception( ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Arguments of function {} has nested type {}. " - "Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", getName(), type_y->getName()); } } - template - ColumnPtr executeWithTypes(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const + template + ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const { if (typeid_cast(col_x.get())) { - return executeWithTypesFirstArgConst(col_x, col_y, input_rows_count, arguments); + return executeWithLeftArgConst(col_x, col_y, input_rows_count, arguments); } else if (typeid_cast(col_y.get())) { - return executeWithTypesFirstArgConst(col_y, col_x, input_rows_count, arguments); + return executeWithLeftArgConst(col_y, col_x, input_rows_count, arguments); } - col_x = col_x->convertToFullColumnIfConst(); - col_y = col_y->convertToFullColumnIfConst(); - const auto & array_x = *assert_cast(col_x.get()); const auto & array_y = *assert_cast(col_y.get()); - const auto & data_x = typeid_cast &>(array_x.getData()).getData(); - const auto & data_y = typeid_cast &>(array_y.getData()).getData(); + const auto & data_x = typeid_cast &>(array_x.getData()).getData(); + const auto & data_y = typeid_cast &>(array_y.getData()).getData(); const auto & offsets_x = array_x.getOffsets(); - const auto & offsets_y = array_y.getOffsets(); - /// Check that arrays in both columns are the sames size - for (size_t row = 0; row < offsets_x.size(); ++row) - { - if (offsets_x[row] != offsets_y[row]) [[unlikely]] - { - ColumnArray::Offset prev_offset = row > 0 ? offsets_x[row] : 0; - throw Exception( - ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, - "Arguments of function {} have different array sizes: {} and {}", - getName(), - offsets_x[row] - prev_offset, - offsets_y[row] - prev_offset); - } - } + if (!array_x.hasEqualOffsets(array_y)) + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Array arguments for function {} must have equal sizes", getName()); const typename Kernel::ConstParams kernel_params = initConstParams(arguments); - auto result = ColumnVector::create(input_rows_count); - auto & result_data = result->getData(); + auto col_res = ColumnVector::create(input_rows_count); + auto & result_data = col_res->getData(); - /// Do the actual computation ColumnArray::Offset prev = 0; size_t row = 0; + for (auto off : offsets_x) { /// Process chunks in vectorized manner @@ -517,12 +499,12 @@ private: result_data[row] = Kernel::finalize(state, kernel_params); row++; } - return result; + return col_res; } /// Special case when the 1st parameter is Const - template - ColumnPtr executeWithTypesFirstArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const + template + ColumnPtr executeWithLeftArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count, const ColumnsWithTypeAndName & arguments) const { col_x = assert_cast(col_x.get())->getDataColumnPtr(); col_y = col_y->convertToFullColumnIfConst(); @@ -530,26 +512,25 @@ private: const auto & array_x = *assert_cast(col_x.get()); const auto & array_y = *assert_cast(col_y.get()); - const auto & data_x = typeid_cast &>(array_x.getData()).getData(); - const auto & data_y = typeid_cast &>(array_y.getData()).getData(); + const auto & data_x = typeid_cast &>(array_x.getData()).getData(); + const auto & data_y = typeid_cast &>(array_y.getData()).getData(); const auto & offsets_x = array_x.getOffsets(); const auto & offsets_y = array_y.getOffsets(); - /// Check that arrays in both columns are the sames size ColumnArray::Offset prev_offset = 0; - for (size_t row : collections::range(0, offsets_y.size())) + for (auto offset_y : offsets_y) { - if (offsets_x[0] != offsets_y[row] - prev_offset) [[unlikely]] + if (offsets_x[0] != offset_y - prev_offset) [[unlikely]] { throw Exception( ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Arguments of function {} have different array sizes: {} and {}", getName(), offsets_x[0], - offsets_y[row] - prev_offset); + offset_y - prev_offset); } - prev_offset = offsets_y[row]; + prev_offset = offset_y; } const typename Kernel::ConstParams kernel_params = initConstParams(arguments); @@ -557,7 +538,6 @@ private: auto result = ColumnVector::create(input_rows_count); auto & result_data = result->getData(); - /// Do the actual computation size_t prev = 0; size_t row = 0; @@ -574,7 +554,7 @@ private: /// - the two most common metrics L2 and cosine distance, /// - the most powerful SIMD instruction set (AVX-512F). #if USE_MULTITARGET_CODE - if constexpr (std::is_same_v && std::is_same_v) /// ResultType is Float32 or Float64 + if constexpr (std::is_same_v && std::is_same_v) /// ResultType is Float32 or Float64 { if constexpr (std::is_same_v || std::is_same_v) diff --git a/src/Functions/array/arrayDotProduct.cpp b/src/Functions/array/arrayDotProduct.cpp index 6c615a058c3..783843a89d5 100644 --- a/src/Functions/array/arrayDotProduct.cpp +++ b/src/Functions/array/arrayDotProduct.cpp @@ -18,10 +18,9 @@ namespace DB namespace ErrorCodes { - extern const int BAD_ARGUMENTS; - extern const int ILLEGAL_COLUMN; extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int LOGICAL_ERROR; + extern const int SIZES_OF_ARRAYS_DONT_MATCH; } @@ -141,6 +140,7 @@ public: static FunctionPtr create(ContextPtr) { return std::make_shared(); } size_t getNumberOfArguments() const override { return 2; } bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + bool useDefaultImplementationForConstants() const override { return true; } DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override { @@ -163,26 +163,29 @@ public: return Kernel::getReturnType(nested_types[0], nested_types[1]); } - ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /* input_rows_count */) const override +#define SUPPORTED_TYPES(ACTION) \ + ACTION(UInt8) \ + ACTION(UInt16) \ + ACTION(UInt32) \ + ACTION(UInt64) \ + ACTION(Int8) \ + ACTION(Int16) \ + ACTION(Int32) \ + ACTION(Int64) \ + ACTION(Float32) \ + ACTION(Float64) + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { switch (result_type->getTypeId()) { - #define SUPPORTED_TYPE(type) \ + #define ON_TYPE(type) \ case TypeIndex::type: \ - return executeWithResultType(arguments); \ + return executeWithResultType(arguments, input_rows_count); \ break; - SUPPORTED_TYPE(UInt8) - SUPPORTED_TYPE(UInt16) - SUPPORTED_TYPE(UInt32) - SUPPORTED_TYPE(UInt64) - SUPPORTED_TYPE(Int8) - SUPPORTED_TYPE(Int16) - SUPPORTED_TYPE(Int32) - SUPPORTED_TYPE(Int64) - SUPPORTED_TYPE(Float32) - SUPPORTED_TYPE(Float64) - #undef SUPPORTED_TYPE + SUPPORTED_TYPES(ON_TYPE) + #undef ON_TYPE default: throw Exception(ErrorCodes::LOGICAL_ERROR, "Unexpected result type {}", result_type->getName()); @@ -191,90 +194,150 @@ public: private: template - ColumnPtr executeWithResultType(const ColumnsWithTypeAndName & arguments) const + ColumnPtr executeWithResultType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const { - ColumnPtr res; - if (!((res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)) - || (res = executeWithResultTypeAndLeft(arguments)))) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, - "Illegal column {} of first argument of function {}", arguments[0].column->getName(), getName()); + DataTypePtr type_x = typeid_cast(arguments[0].type.get())->getNestedType(); - return res; + switch (type_x->getTypeId()) + { +#define ON_TYPE(type) \ + case TypeIndex::type: \ + return executeWithResultTypeAndLeftType(arguments, input_rows_count); \ + break; + + SUPPORTED_TYPES(ON_TYPE) +#undef ON_TYPE + + default: + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Arguments of function {} has nested type {}. " + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + getName(), + type_x->getName()); + } } template - ColumnPtr executeWithResultTypeAndLeft(const ColumnsWithTypeAndName & arguments) const + ColumnPtr executeWithResultTypeAndLeftType(const ColumnsWithTypeAndName & arguments, size_t input_rows_count) const { - ColumnPtr res; - if ( (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments)) - || (res = executeWithResultTypeAndLeftAndRight(arguments))) - return res; + DataTypePtr type_y = typeid_cast(arguments[1].type.get())->getNestedType(); - return nullptr; + switch (type_y->getTypeId()) + { + #define ON_TYPE(type) \ + case TypeIndex::type: \ + return executeWithResultTypeAndLeftTypeAndRightType(arguments[0].column, arguments[1].column, input_rows_count); \ + break; + + SUPPORTED_TYPES(ON_TYPE) + #undef ON_TYPE + + default: + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, + "Arguments of function {} has nested type {}. " + "Supported types: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", + getName(), + type_y->getName()); + } } template - ColumnPtr executeWithResultTypeAndLeftAndRight(const ColumnsWithTypeAndName & arguments) const + ColumnPtr executeWithResultTypeAndLeftTypeAndRightType(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const { - ColumnPtr col_left = arguments[0].column->convertToFullColumnIfConst(); - ColumnPtr col_right = arguments[1].column->convertToFullColumnIfConst(); - if (!col_left || !col_right) - return nullptr; + if (typeid_cast(col_x.get())) + { + return executeWithLeftArgConst(col_x, col_y, input_rows_count); + } + else if (typeid_cast(col_y.get())) + { + return executeWithLeftArgConst(col_y, col_x, input_rows_count); + } - const ColumnArray * col_arr_left = checkAndGetColumn(col_left.get()); - const ColumnArray * cokl_arr_right = checkAndGetColumn(col_right.get()); - if (!col_arr_left || !cokl_arr_right) - return nullptr; + const auto & array_x = *assert_cast(col_x.get()); + const auto & array_y = *assert_cast(col_y.get()); - const ColumnVector * col_arr_nested_left = checkAndGetColumn>(col_arr_left->getData()); - const ColumnVector * col_arr_nested_right = checkAndGetColumn>(cokl_arr_right->getData()); - if (!col_arr_nested_left || !col_arr_nested_right) - return nullptr; + const auto & data_x = typeid_cast &>(array_x.getData()).getData(); + const auto & data_y = typeid_cast &>(array_y.getData()).getData(); - if (!col_arr_left->hasEqualOffsets(*cokl_arr_right)) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "Array arguments for function {} must have equal sizes", getName()); + const auto & offsets_x = array_x.getOffsets(); - auto col_res = ColumnVector::create(); + if (!array_x.hasEqualOffsets(array_y)) + throw Exception(ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, "Array arguments for function {} must have equal sizes", getName()); - vector( - col_arr_nested_left->getData(), - col_arr_nested_right->getData(), - col_arr_left->getOffsets(), - col_res->getData()); + auto col_res = ColumnVector::create(input_rows_count); + auto & result_data = col_res->getData(); + + ColumnArray::Offset current_offset = 0; + for (size_t row = 0; row < input_rows_count; ++row) + { + const size_t array_size = offsets_x[row] - current_offset; + + size_t i = 0; + + /// Process chunks in vectorized manner + static constexpr size_t VEC_SIZE = 4; + typename Kernel::template State states[VEC_SIZE]; + for (; i + VEC_SIZE < array_size; i += VEC_SIZE) + { + for (size_t j = 0; j < VEC_SIZE; ++j) + Kernel::template accumulate(states[j], static_cast(data_x[current_offset + i + j]), static_cast(data_y[current_offset + i + j])); + } + + typename Kernel::template State state; + for (const auto & other_state : states) + Kernel::template combine(state, other_state); + + /// Process the tail + for (; i < array_size; ++i) + Kernel::template accumulate(state, static_cast(data_x[current_offset + i]), static_cast(data_y[current_offset + i])); + + result_data[row] = Kernel::template finalize(state); + + current_offset = offsets_x[row]; + } return col_res; } template - static void vector( - const PaddedPODArray & left, - const PaddedPODArray & right, - const ColumnArray::Offsets & offsets, - PaddedPODArray & result) + ColumnPtr executeWithLeftArgConst(ColumnPtr col_x, ColumnPtr col_y, size_t input_rows_count) const { - size_t size = offsets.size(); - result.resize(size); + col_x = assert_cast(col_x.get())->getDataColumnPtr(); + col_y = col_y->convertToFullColumnIfConst(); + + const auto & array_x = *assert_cast(col_x.get()); + const auto & array_y = *assert_cast(col_y.get()); + + const auto & data_x = typeid_cast &>(array_x.getData()).getData(); + const auto & data_y = typeid_cast &>(array_y.getData()).getData(); + + const auto & offsets_x = array_x.getOffsets(); + const auto & offsets_y = array_y.getOffsets(); + + ColumnArray::Offset prev_offset = 0; + for (auto offset_y : offsets_y) + { + if (offsets_x[0] != offset_y - prev_offset) [[unlikely]] + { + throw Exception( + ErrorCodes::SIZES_OF_ARRAYS_DONT_MATCH, + "Arguments of function {} have different array sizes: {} and {}", + getName(), + offsets_x[0], + offset_y - prev_offset); + } + prev_offset = offset_y; + } + + auto col_res = ColumnVector::create(input_rows_count); + auto & result = col_res->getData(); ColumnArray::Offset current_offset = 0; - for (size_t row = 0; row < size; ++row) + for (size_t row = 0; row < input_rows_count; ++row) { - size_t array_size = offsets[row] - current_offset; + const size_t array_size = offsets_x[0]; typename Kernel::template State state; size_t i = 0; @@ -283,13 +346,14 @@ private: /// To avoid combinatorial explosion of SIMD kernels, focus on /// - the two most common input/output types (Float32 x Float32) --> Float32 and (Float64 x Float64) --> Float64 instead of 10 x /// 10 input types x 8 output types, + /// - const/non-const inputs instead of non-const/non-const inputs /// - the most powerful SIMD instruction set (AVX-512F). #if USE_MULTITARGET_CODE if constexpr ((std::is_same_v || std::is_same_v) && std::is_same_v && std::is_same_v) { if (isArchSupported(TargetArch::AVX512F)) - Kernel::template accumulateCombine(&left[current_offset], &right[current_offset], array_size, i, state); + Kernel::template accumulateCombine(&data_x[0], &data_y[current_offset], array_size, i, state); } #else /// Process chunks in vectorized manner @@ -298,7 +362,7 @@ private: for (; i + VEC_SIZE < array_size; i += VEC_SIZE) { for (size_t j = 0; j < VEC_SIZE; ++j) - Kernel::template accumulate(states[j], static_cast(left[i + j]), static_cast(right[i + j])); + Kernel::template accumulate(states[j], static_cast(data_x[i + j]), static_cast(data_y[current_offset + i + j])); } for (const auto & other_state : states) @@ -307,13 +371,14 @@ private: /// Process the tail for (; i < array_size; ++i) - Kernel::template accumulate(state, static_cast(left[i]), static_cast(right[i])); + Kernel::template accumulate(state, static_cast(data_x[i]), static_cast(data_y[current_offset + i])); - /// ResultType res = Kernel::template finalize(state); result[row] = Kernel::template finalize(state); - current_offset = offsets[row]; + current_offset = offsets_y[row]; } + + return col_res; } }; diff --git a/src/Functions/array/arrayNorm.cpp b/src/Functions/array/arrayNorm.cpp index 027a33d094c..e87eff6add1 100644 --- a/src/Functions/array/arrayNorm.cpp +++ b/src/Functions/array/arrayNorm.cpp @@ -175,8 +175,7 @@ public: } } - ColumnPtr - executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t input_rows_count) const override { DataTypePtr type = typeid_cast(arguments[0].type.get())->getNestedType(); ColumnPtr column = arguments[0].column->convertToFullColumnIfConst(); diff --git a/tests/performance/dotProduct.xml b/tests/performance/dotProduct.xml index 6e056964ebb..a0ab7beea9c 100644 --- a/tests/performance/dotProduct.xml +++ b/tests/performance/dotProduct.xml @@ -56,6 +56,7 @@ SELECT sum(dp) FROM (SELECT dotProduct(v, v) AS dp FROM vecs_{element_type}) + WITH (SELECT v FROM vecs_{element_type} limit 1) AS a SELECT sum(dp) FROM (SELECT dotProduct(a, v) AS dp FROM vecs_{element_type}) DROP TABLE vecs_{element_type} diff --git a/tests/queries/0_stateless/02708_dotProduct.reference b/tests/queries/0_stateless/02708_dotProduct.reference index 5cc9a9f0502..93a67e4c0be 100644 --- a/tests/queries/0_stateless/02708_dotProduct.reference +++ b/tests/queries/0_stateless/02708_dotProduct.reference @@ -11,6 +11,8 @@ [-1,-2,-3] [4,5,6] -32 Int64 [1,2,3] [4,5,6] 32 Float32 [1,2,3] [4,5,6] 32 Float64 +[] [] 0 Float32 +[] [] 0 UInt16 -- Tuple (1,2,3) (4,5,6) 32 UInt64 (1,2,3) (4,5,6) 32 UInt64 @@ -24,6 +26,8 @@ (1,2,3) (4,5,6) 32 Float64 -- Non-const argument [1,2,3] [4,5,6] 32 UInt16 +[] [] 0 Float32 +[] [] 0 UInt16 -- Array with mixed element arguments types (result type is the supertype) [1,2,3] [4,5,6] 32 Float32 -- Tuple with mixed element arguments types @@ -32,3 +36,18 @@ 32 32 32 +-- Tests that trigger special paths + -- non-const / non-const +0 61 +1 186 +0 61 +1 186 +0 61 +1 186 + -- const / non-const +0 62 +1 187 +0 62 +1 187 +0 62 +1 187 diff --git a/tests/queries/0_stateless/02708_dotProduct.sql b/tests/queries/0_stateless/02708_dotProduct.sql index 6ad615664e8..05c66777dff 100644 --- a/tests/queries/0_stateless/02708_dotProduct.sql +++ b/tests/queries/0_stateless/02708_dotProduct.sql @@ -4,7 +4,7 @@ SELECT arrayDotProduct([1, 2]); -- { serverError NUMBER_OF_ARGUMENTS_DOESNT_MATC SELECT arrayDotProduct([1, 2], 'abc'); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT arrayDotProduct('abc', [1, 2]); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT arrayDotProduct([1, 2], ['abc', 'def']); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } -SELECT arrayDotProduct([1, 2], [3, 4, 5]); -- { serverError BAD_ARGUMENTS } +SELECT arrayDotProduct([1, 2], [3, 4, 5]); -- { serverError SIZES_OF_ARRAYS_DONT_MATCH } SELECT dotProduct([1, 2], (3, 4, 5)); -- { serverError ILLEGAL_TYPE_OF_ARGUMENT } SELECT '-- Tests'; @@ -19,6 +19,9 @@ SELECT [-1, -2, -3]::Array(Int32) AS x, [4, 5, 6]::Array(Int32) AS y, dotProduct SELECT [-1, -2, -3]::Array(Int64) AS x, [4, 5, 6]::Array(Int64) AS y, dotProduct(x, y) AS res, toTypeName(res); SELECT [1, 2, 3]::Array(Float32) AS x, [4, 5, 6]::Array(Float32) AS y, dotProduct(x, y) AS res, toTypeName(res); SELECT [1, 2, 3]::Array(Float64) AS x, [4, 5, 6]::Array(Float64) AS y, dotProduct(x, y) AS res, toTypeName(res); +-- empty arrays +SELECT []::Array(Float32) AS x, []::Array(Float32) AS y, dotProduct(x, y) AS res, toTypeName(res); +SELECT []::Array(UInt8) AS x, []::Array(UInt8) AS y, dotProduct(x, y) AS res, toTypeName(res); SELECT ' -- Tuple'; SELECT (1::UInt8, 2::UInt8, 3::UInt8) AS x, (4::UInt8, 5::UInt8, 6::UInt8) AS y, dotProduct(x, y) AS res, toTypeName(res); @@ -34,6 +37,8 @@ SELECT (1::Float64, 2::Float64, 3::Float64) AS x, (4::Float64, 5::Float64, 6::Fl SELECT '-- Non-const argument'; SELECT materialize([1::UInt8, 2::UInt8, 3::UInt8]) AS x, [4::UInt8, 5::UInt8, 6::UInt8] AS y, dotProduct(x, y) AS res, toTypeName(res); +SELECT materialize([]::Array(Float32)) AS x, []::Array(Float32) AS y, dotProduct(x, y) AS res, toTypeName(res); +SELECT materialize([]::Array(UInt8)) AS x, []::Array(UInt8) AS y, dotProduct(x, y) AS res, toTypeName(res); SELECT ' -- Array with mixed element arguments types (result type is the supertype)'; SELECT [1::UInt16, 2::UInt8, 3::Float32] AS x, [4::Int16, 5::Float32, 6::UInt8] AS y, dotProduct(x, y) AS res, toTypeName(res); @@ -45,3 +50,17 @@ SELECT '-- Aliases'; SELECT scalarProduct([1, 2, 3], [4, 5, 6]); SELECT scalarProduct((1, 2, 3), (4, 5, 6)); SELECT arrayDotProduct([1, 2, 3], [4, 5, 6]); -- actually no alias but the internal function for arrays + +SELECT '-- Tests that trigger special paths'; +DROP TABLE IF EXISTS tab; +CREATE TABLE tab(id UInt64, vec Array(Float32)) ENGINE = MergeTree ORDER BY id; +INSERT INTO tab VALUES (0, [0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0, 3.0, 0.0, 1.0, 2.0]) (1, [5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0]); +SELECT ' -- non-const / non-const'; +SELECT id, arrayDotProduct(vec, vec) FROM tab ORDER BY id; +SELECT id, arrayDotProduct(vec::Array(Float64), vec::Array(Float64)) FROM tab ORDER BY id; +SELECT id, arrayDotProduct(vec::Array(UInt32), vec::Array(UInt32)) FROM tab ORDER BY id; +SELECT ' -- const / non-const'; +SELECT id, arrayDotProduct([5.0, 2.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0]::Array(Float32), vec) FROM tab ORDER BY id; +SELECT id, arrayDotProduct([5.0, 2.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0, 3.0, 5.0, 1.0, 2.0]::Array(Float64), vec) FROM tab ORDER BY id; +SELECT id, arrayDotProduct([5, 2, 2, 3, 5, 1, 2, 3, 5, 1, 2, 3, 5, 1, 2, 3, 5, 1, 2]::Array(UInt32), vec) FROM tab ORDER BY id; +DROP TABLE tab;