From 38a2b0dcc7c0d84136f00a3d5a7a3afbf86cccbb Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Thu, 15 Aug 2024 10:42:06 +0000 Subject: [PATCH] Allow Array(Float64) as type of underlying column --- .../mergetree-family/annindexes.md | 8 +- .../MergeTreeIndexVectorSimilarity.cpp | 79 ++++++++++++------- .../MergeTree/VectorSimilarityCondition.cpp | 4 +- .../MergeTree/VectorSimilarityCondition.h | 4 +- ..._vector_search_index_creation_negative.sql | 2 +- .../02354_vector_search_queries.reference | 4 + .../02354_vector_search_queries.sql | 12 +++ 7 files changed, 76 insertions(+), 37 deletions(-) diff --git a/docs/en/engines/table-engines/mergetree-family/annindexes.md b/docs/en/engines/table-engines/mergetree-family/annindexes.md index 097b0f5850a..1057ccb5fee 100644 --- a/docs/en/engines/table-engines/mergetree-family/annindexes.md +++ b/docs/en/engines/table-engines/mergetree-family/annindexes.md @@ -22,10 +22,10 @@ ORDER BY Distance(vectors, Point) LIMIT N ``` -`vectors` contains N-dimensional values of type [Array(Float32)](../../../sql-reference/data-types/array.md), for example embeddings. -Function `Distance` computes the distance between two vectors. Often, the Euclidean (L2) distance is chosen as distance function but [other -distance functions](/docs/en/sql-reference/functions/distance-functions.md) are also possible. `Point` is the reference point, e.g. `(0.17, -0.33, ...)`, and `N` limits the number of search results. +`vectors` contains N-dimensional values of type [Array(Float32)](../../../sql-reference/data-types/array.md) or Array(Float64), for example +embeddings. Function `Distance` computes the distance between two vectors. Often, the Euclidean (L2) distance is chosen as distance function +but [other distance functions](/docs/en/sql-reference/functions/distance-functions.md) are also possible. `Point` is the reference point, +e.g. `(0.17, 0.33, ...)`, and `N` limits the number of search results. This query returns the top-`N` closest points to the reference point. Parameter `N` limits the number of returned values which is useful for situations where `MaxDistance` is difficult to determine in advance. diff --git a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp index b215aeae495..9376cdf7562 100644 --- a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp +++ b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp @@ -248,6 +248,40 @@ MergeTreeIndexGranulePtr MergeTreeIndexAggregatorVectorSimilarity::getGranuleAnd return granule; } +namespace +{ + +template +void updateImpl(const ColumnArray * column_array, const ColumnArray::Offsets & column_array_offsets, USearchIndexWithSerializationPtr & index, size_t dimensions, size_t rows) +{ + const auto & column_array_data = column_array->getData(); + const auto & column_array_data_float = typeid_cast(column_array_data); + const auto & column_array_data_float_data = column_array_data_float.getData(); + + /// Check all sizes are the same + for (size_t row = 0; row < rows - 1; ++row) + if (column_array_offsets[row + 1] - column_array_offsets[row] != dimensions) + throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column with vector similarity index must have equal length"); + + /// Reserving space is mandatory + if (!index->try_reserve(roundUpToPowerOfTwoOrZero(index->size() + rows))) + throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for vector similarity index"); + + for (size_t row = 0; row < rows; ++row) + { + if (auto result = index->add(static_cast(index->size()), &column_array_data_float_data[column_array_offsets[row - 1]]); !result) + throw Exception(ErrorCodes::INCORRECT_DATA, "Could not add data to vector similarity index. Error: {}", String(result.error.release())); + else + { + ProfileEvents::increment(ProfileEvents::USearchAddCount); + ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, result.visited_members); + ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, result.computed_distances); + } + } +} + +} + void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_t * pos, size_t limit) { if (*pos >= block.rows()) @@ -273,7 +307,7 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_ const auto * column_array = typeid_cast(column_cut.get()); if (!column_array) - throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array(Float32) column"); + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array(Float*) column"); if (column_array->empty()) throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty"); @@ -302,30 +336,19 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_ if (index->size() + rows > std::numeric_limits::max()) throw Exception(ErrorCodes::INCORRECT_DATA, "Size of vector similarity index would exceed 4 billion entries"); - const auto & column_array_data = column_array->getData(); - const auto & column_array_data_float = typeid_cast(column_array_data); - const auto & column_array_data_float_data = column_array_data_float.getData(); + DataTypePtr data_type = block.getDataTypes()[0]; + const auto * data_type_array = typeid_cast(data_type.get()); + if (!data_type_array) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected data type Array(Float*)"); + const TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId(); - /// Check all sizes are the same - for (size_t row = 0; row < rows - 1; ++row) - if (column_array_offsets[row + 1] - column_array_offsets[row] != dimensions) - throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column with vector similarity index must have equal length"); + if (WhichDataType(nested_type_index).isFloat32()) + updateImpl(column_array, column_array_offsets, index, dimensions, rows); + else if (WhichDataType(nested_type_index).isFloat64()) + updateImpl(column_array, column_array_offsets, index, dimensions, rows); + else + throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected data type Array(Float*)"); - /// Reserving space is mandatory - if (!index->try_reserve(roundUpToPowerOfTwoOrZero(index->size() + rows))) - throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for vector similarity index"); - - for (size_t row = 0; row < rows; ++row) - { - if (auto result = index->add(static_cast(index->size()), &column_array_data_float_data[column_array_offsets[row - 1]]); !result) - throw Exception(ErrorCodes::INCORRECT_DATA, "Could not add data to vector similarity index. Error: {}", String(result.error.release())); - else - { - ProfileEvents::increment(ProfileEvents::USearchAddCount); - ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, result.visited_members); - ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, result.computed_distances); - } - } *pos += rows_read; } @@ -373,7 +396,7 @@ std::vector MergeTreeIndexConditionVectorSimilarity::getUsefulRanges(Mer "does not match the dimension in the index ({})", vector_similarity_condition.getDimensions(), index->dimensions()); - const std::vector reference_vector = vector_similarity_condition.getReferenceVector(); + const std::vector reference_vector = vector_similarity_condition.getReferenceVector(); auto search_result = index->search(reference_vector.data(), limit); if (!search_result) @@ -499,14 +522,14 @@ void vectorSimilarityIndexValidator(const IndexDescription & index, bool /* atta if (index.column_names.size() != 1 || index.data_types.size() != 1) throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Vector similarity indexes must be created on a single column"); - /// Check that the data type is Array(Float32) + /// Check that the data type is Array(Float*) DataTypePtr data_type = index.sample_block.getDataTypes()[0]; const auto * data_type_array = typeid_cast(data_type.get()); if (!data_type_array) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float32)"); + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float*)"); TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId(); - if (!WhichDataType(nested_type_index).isFloat32()) - throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float32)"); + if (!WhichDataType(nested_type_index).isFloat()) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float*)"); } } diff --git a/src/Storages/MergeTree/VectorSimilarityCondition.cpp b/src/Storages/MergeTree/VectorSimilarityCondition.cpp index 2e53b4ecb3a..c8f33857640 100644 --- a/src/Storages/MergeTree/VectorSimilarityCondition.cpp +++ b/src/Storages/MergeTree/VectorSimilarityCondition.cpp @@ -24,7 +24,7 @@ namespace { template -void extractReferenceVectorFromLiteral(std::vector & reference_vector, Literal literal) +void extractReferenceVectorFromLiteral(std::vector & reference_vector, Literal literal) { Float64 float_element_of_reference_vector; Int64 int_element_of_reference_vector; @@ -72,7 +72,7 @@ UInt64 VectorSimilarityCondition::getLimit() const throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported"); } -std::vector VectorSimilarityCondition::getReferenceVector() const +std::vector VectorSimilarityCondition::getReferenceVector() const { if (index_is_useful && query_information.has_value()) return query_information->reference_vector; diff --git a/src/Storages/MergeTree/VectorSimilarityCondition.h b/src/Storages/MergeTree/VectorSimilarityCondition.h index fd339ed715d..2380f8f46b0 100644 --- a/src/Storages/MergeTree/VectorSimilarityCondition.h +++ b/src/Storages/MergeTree/VectorSimilarityCondition.h @@ -60,7 +60,7 @@ public: L2 }; - std::vector reference_vector; + std::vector reference_vector; DistanceFunction distance_function; String column_name; UInt64 limit; @@ -70,7 +70,7 @@ public: /// Returns false if query can be speeded up by an ANN index, true otherwise. bool alwaysUnknownOrTrue(String distance_function) const; - std::vector getReferenceVector() const; + std::vector getReferenceVector() const; size_t getDimensions() const; String getColumnName() const; Info::DistanceFunction getDistanceFunction() const; diff --git a/tests/queries/0_stateless/02354_vector_search_index_creation_negative.sql b/tests/queries/0_stateless/02354_vector_search_index_creation_negative.sql index b39b8d3e754..e8e6aaee1b2 100644 --- a/tests/queries/0_stateless/02354_vector_search_index_creation_negative.sql +++ b/tests/queries/0_stateless/02354_vector_search_index_creation_negative.sql @@ -35,6 +35,6 @@ SELECT 'Must be created on Array(Float32) columns'; SET allow_suspicious_low_cardinality_types = 1; CREATE TABLE tab(id Int32, vec UInt64, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec Float32, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } -CREATE TABLE tab(id Int32, vec Array(Float64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } +CREATE TABLE tab(id Int32, vec Array(UInt64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec LowCardinality(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec Nullable(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } diff --git a/tests/queries/0_stateless/02354_vector_search_queries.reference b/tests/queries/0_stateless/02354_vector_search_queries.reference index 374ed1f8abd..cb3a8c801b1 100644 --- a/tests/queries/0_stateless/02354_vector_search_queries.reference +++ b/tests/queries/0_stateless/02354_vector_search_queries.reference @@ -107,3 +107,7 @@ Expression (Projection) Description: vector_similarity GRANULARITY 2 Parts: 1/1 Granules: 2/4 +-- Index on Array(Float64) column +6 [0,2] 0 +7 [0,2.1] 0.10000000000000009 +8 [0,2.2] 0.20000000000000018 diff --git a/tests/queries/0_stateless/02354_vector_search_queries.sql b/tests/queries/0_stateless/02354_vector_search_queries.sql index 2c6a7f10776..fbf8427d8fe 100644 --- a/tests/queries/0_stateless/02354_vector_search_queries.sql +++ b/tests/queries/0_stateless/02354_vector_search_queries.sql @@ -124,3 +124,15 @@ LIMIT 3; DROP TABLE tab_f32; DROP TABLE tab_f16; DROP TABLE tab_i8; + +SELECT '-- Index on Array(Float64) column'; +CREATE TABLE tab(id Int32, vec Array(Float64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance') GRANULARITY 2) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 3; +INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [1.5, 0.0]), (6, [0.0, 2.0]), (7, [0.0, 2.1]), (8, [0.0, 2.2]), (9, [0.0, 2.3]), (10, [0.0, 2.4]), (11, [0.0, 2.5]); + +WITH [0.0, 2.0] AS reference_vec +SELECT id, vec, L2Distance(vec, reference_vec) +FROM tab +ORDER BY L2Distance(vec, reference_vec) +LIMIT 3; + +DROP TABLE tab;