diff --git a/src/Functions/array/arrayDistance.cpp b/src/Functions/array/arrayDistance.cpp index 8004552a840..d405b9a4a94 100644 --- a/src/Functions/array/arrayDistance.cpp +++ b/src/Functions/array/arrayDistance.cpp @@ -1,30 +1,31 @@ -#include +#include #include #include #include #include #include - #include namespace DB { namespace ErrorCodes { - extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int LOGICAL_ERROR; + extern const int SIZES_OF_ARRAYS_DOESNT_MATCH; } template struct LpDistance { static inline String name = "L" + std::to_string(N); - template - static void compute(const Eigen::MatrixBase & left, const Eigen::MatrixBase & right, PaddedPODArray & array) + template + static void compute(const Eigen::MatrixX & left, const Eigen::MatrixX & right, PaddedPODArray & array) { auto & norms = (left - right).colwise().template lpNorm(); + // array.insert() failed to work with Eigen iterators for (auto n : norms) - array.emplace_back(n); + array.push_back(n); } }; @@ -36,19 +37,16 @@ struct LinfDistance : LpDistance struct CosineDistance { static inline String name = "Cosine"; - template - static void compute(const Eigen::MatrixBase & left, const Eigen::MatrixBase & right, PaddedPODArray & array) + template + static void compute(const Eigen::MatrixX & left, const Eigen::MatrixX & right, PaddedPODArray & array) { - // auto & nx = left.colwise().normalized().eval(); - // auto & ny = right.colwise().normalized().eval(); - // auto & dist = 1.0 - x.cwiseProduct(y).colwise().sum().array(); auto & prod = left.cwiseProduct(right).colwise().sum(); auto & nx = left.colwise().norm(); auto & ny = right.colwise().norm(); auto & nm = nx.cwiseProduct(ny).cwiseInverse(); auto & dist = 1.0 - prod.cwiseProduct(nm).array(); for (auto d : dist) - array.emplace_back(d); + array.push_back(d); } }; @@ -63,30 +61,18 @@ public: bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } bool useDefaultImplementationForConstants() const override { return true; } - DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & /*arguments*/) const override + DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override { - return std::make_shared(); - } - - ColumnPtr - executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override - { - const auto & left = arguments[0]; - const auto & right = arguments[1]; - - const auto & col_x = left.column->convertToFullColumnIfConst(); - const auto & col_y = right.column->convertToFullColumnIfConst(); - const auto * arr_x = checkAndGetColumn(col_x.get()); - const auto * arr_y = checkAndGetColumn(col_y.get()); - if (!arr_x || !arr_y) + DataTypes types; + for (const auto & argument : arguments) { - throw Exception("Argument of function " + String(name) + " must be array. ", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + const auto * array_type = checkAndGetDataType(argument.type.get()); + if (!array_type) + throw Exception("Argument of function " + getName() + " must be array. ", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); + + types.push_back(array_type->getNestedType()); } - - const auto & type_x = checkAndGetDataType(left.type.get())->getNestedType(); - const auto & type_y = checkAndGetDataType(right.type.get())->getNestedType(); - - const auto & common_type = getLeastSupertype(DataTypes{type_x, type_y}); + const auto & common_type = getLeastSupertype(types); switch (common_type->getTypeId()) { case TypeIndex::UInt8: @@ -96,31 +82,55 @@ public: case TypeIndex::Int16: case TypeIndex::Int32: case TypeIndex::Float32: - return executeWithType(*arr_x, *arr_y, type_x, type_y, result_type); + return std::make_shared(); case TypeIndex::UInt64: case TypeIndex::Int64: case TypeIndex::Float64: - return executeWithType(*arr_x, *arr_y, type_x, type_y, result_type); + return std::make_shared(); default: throw Exception( - "Array nested type " + common_type->getName() + "Arguments of function " + getName() + " has nested type " + common_type->getName() + ". Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } } + ColumnPtr + executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override + { + const auto & type_x = typeid_cast(arguments[0].type.get())->getNestedType(); + const auto & type_y = typeid_cast(arguments[1].type.get())->getNestedType(); + + const auto & col_x = arguments[0].column->convertToFullColumnIfConst(); + const auto & col_y = arguments[1].column->convertToFullColumnIfConst(); + + const auto * arr_x = assert_cast(col_x.get()); + const auto * arr_y = assert_cast(col_y.get()); + + auto result = result_type->createColumn(); + switch (result_type->getTypeId()) + { + case TypeIndex::Float32: + executeWithType(*arr_x, *arr_y, type_x, type_y, result); + break; + case TypeIndex::Float64: + executeWithType(*arr_x, *arr_y, type_x, type_y, result); + break; + default: + throw Exception("Unexpected result type.", ErrorCodes::LOGICAL_ERROR); + } + return result; + } + private: template - static ColumnPtr executeWithType( + void executeWithType( const ColumnArray & array_x, const ColumnArray & array_y, const DataTypePtr & type_x, const DataTypePtr & type_y, - const DataTypePtr & result_type) + MutableColumnPtr & column) const { - auto result = result_type->createColumn(); - auto & array = typeid_cast(*result).getData(); - Eigen::MatrixX mx, my; columnToMatrix(array_x, type_x, mx); columnToMatrix(array_y, type_y, my); @@ -128,15 +138,14 @@ private: if (mx.rows() && my.rows() && mx.rows() != my.rows()) { throw Exception( - "Arguments of function " + String(name) + " have different array sizes.", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); + "Arguments of function " + getName() + " have different array sizes.", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); } - - Kernel::compute(mx, my, array); - return result; + auto & data = assert_cast &>(*column).getData(); + Kernel::compute(mx, my, data); } template - static void columnToMatrix(const ColumnArray & array, const DataTypePtr & nested_type, Eigen::MatrixX & mat) + void columnToMatrix(const ColumnArray & array, const DataTypePtr & nested_type, Eigen::MatrixX & mat) const { switch (nested_type->getTypeId()) { @@ -172,7 +181,7 @@ private: break; default: throw Exception( - "Array nested type " + nested_type->getName() + "Arguments of function " + getName() + " has nested type " + nested_type->getName() + ". Support: UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); } @@ -180,25 +189,28 @@ private: // optimize for float/ double template ::value>::type> - static void fillMatrix(Eigen::MatrixX & mat, const ColumnArray & array) + void fillMatrix(Eigen::MatrixX & mat, const ColumnArray & array) const { - const auto & vec = typeid_cast &>(array.getData()); - const auto & data = vec.getData(); + const auto & data = typeid_cast &>(array.getData()).getData(); const auto & offsets = array.getOffsets(); mat = Eigen::Map>(data.data(), offsets.front(), offsets.size()); } template - static void fillMatrix(Eigen::MatrixX & mat, const ColumnArray & array) + void fillMatrix(Eigen::MatrixX & mat, const ColumnArray & array) const { - const auto & vec = typeid_cast &>(array.getData()); - const auto & data = vec.getData(); + const auto & data = typeid_cast &>(array.getData()).getData(); const auto & offsets = array.getOffsets(); - mat.resize(offsets.front(), offsets.size()); + size_t rows = offsets.front(), cols = offsets.size(); + mat.resize(rows, cols); ColumnArray::Offset prev = 0, col = 0; - for (auto off : offsets) + for (ColumnArray::Offset off : offsets) { + if (off - prev != rows) + throw Exception( + "Arguments of function " + getName() + " have different array sizes.", ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH); + for (ColumnArray::Offset row = 0; row < off - prev; ++row) { mat(row, col) = static_cast(data[prev + row]); diff --git a/tests/queries/0_stateless/02280_array_distance.reference b/tests/queries/0_stateless/02280_array_distance.reference index 9de9521befe..75e24833e1e 100644 --- a/tests/queries/0_stateless/02280_array_distance.reference +++ b/tests/queries/0_stateless/02280_array_distance.reference @@ -19,12 +19,12 @@ nan 6 8 9 -0.020204103 -0.118082896 +0.020204102886728692 +0.11808289631180302 0 -1 1 218.74643 -1 2 1348.2118 -2 1 219.28064 -2 2 1347.4009 -3 1 214.35251 -3 2 1342.8857 +1 1 218.74642854227358 +1 2 1348.2117786164013 +2 1 219.28064210048274 +2 2 1347.4008312302617 +3 1 214.35251339790725 +3 2 1342.8856987845243 diff --git a/tests/queries/0_stateless/02280_array_distance.sql b/tests/queries/0_stateless/02280_array_distance.sql index 0f269b11179..835b2cf15ae 100644 --- a/tests/queries/0_stateless/02280_array_distance.sql +++ b/tests/queries/0_stateless/02280_array_distance.sql @@ -22,6 +22,9 @@ SELECT arrayCosineDistance(v, materialize([1., 1., 1.])) FROM vec1; INSERT INTO vec2 VALUES (1, [100, 200, 0]), (2, [888, 777, 666]); SELECT v1.id, v2.id, arrayL2Distance(v1.v, v2.v) as dist FROM vec1 v1, vec2 v2; +INSERT INTO vec2 VALUES (3, [123]); +SELECT v1.id, v2.id, arrayL2Distance(v1.v, v2.v) as dist FROM vec1 v1, vec2 v2; -- { serverError 190 } + SELECT arrayL1Distance([0, 0], [1]); -- { serverError 190 } SELECT arrayL2Distance((1, 2), (3,4)); -- { serverError 43 }