Allow Array(Float64) as type of underlying column

This commit is contained in:
Robert Schulze 2024-08-15 10:42:06 +00:00
parent b9548504d9
commit 38a2b0dcc7
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
7 changed files with 76 additions and 37 deletions

View File

@ -22,10 +22,10 @@ ORDER BY Distance(vectors, Point)
LIMIT N LIMIT N
``` ```
`vectors` contains N-dimensional values of type [Array(Float32)](../../../sql-reference/data-types/array.md), for example embeddings. `vectors` contains N-dimensional values of type [Array(Float32)](../../../sql-reference/data-types/array.md) or Array(Float64), for example
Function `Distance` computes the distance between two vectors. Often, the Euclidean (L2) distance is chosen as distance function but [other embeddings. Function `Distance` computes the distance between two vectors. Often, the Euclidean (L2) distance is chosen as distance function
distance functions](/docs/en/sql-reference/functions/distance-functions.md) are also possible. `Point` is the reference point, e.g. `(0.17, but [other distance functions](/docs/en/sql-reference/functions/distance-functions.md) are also possible. `Point` is the reference point,
0.33, ...)`, and `N` limits the number of search results. e.g. `(0.17, 0.33, ...)`, and `N` limits the number of search results.
This query returns the top-`N` closest points to the reference point. Parameter `N` limits the number of returned values which is useful for This query returns the top-`N` closest points to the reference point. Parameter `N` limits the number of returned values which is useful for
situations where `MaxDistance` is difficult to determine in advance. situations where `MaxDistance` is difficult to determine in advance.

View File

@ -248,6 +248,40 @@ MergeTreeIndexGranulePtr MergeTreeIndexAggregatorVectorSimilarity::getGranuleAnd
return granule; return granule;
} }
namespace
{
/// Adds the vectors stored in `column_array` to the vector similarity index `index`.
///
/// `Column` is the concrete column type of the array's nested data; the nested column is cast to it
/// with typeid_cast. `update()` instantiates this with ColumnFloat32 and ColumnFloat64.
///
/// Parameters:
/// - column_array:         array column holding one vector per row
/// - column_array_offsets: offsets of `column_array`; element `i` is the end position of row `i` in the nested data
/// - index:                index the vectors are added to
/// - dimensions:           expected length of every vector
/// - rows:                 number of vectors to add
///
/// Throws:
/// - INCORRECT_DATA if two consecutive offsets are not exactly `dimensions` apart, or if adding a vector fails
/// - CANNOT_ALLOCATE_MEMORY if the index cannot reserve space for the new vectors
template <typename Column>
void updateImpl(const ColumnArray * column_array, const ColumnArray::Offsets & column_array_offsets, USearchIndexWithSerializationPtr & index, size_t dimensions, size_t rows)
{
    const auto & column_array_data = column_array->getData();
    const auto & column_array_data_float = typeid_cast<const Column &>(column_array_data);
    const auto & column_array_data_float_data = column_array_data_float.getData();

    /// Check all sizes are the same.
    /// The condition is written as `row + 1 < rows` (not `row < rows - 1`) so that rows == 0 cannot wrap around
    /// in unsigned arithmetic and walk far beyond the end of the offsets.
    /// NOTE(review): this only compares rows 1 .. rows-1 against `dimensions`; presumably the caller derived
    /// `dimensions` from row 0 - confirm.
    for (size_t row = 0; row + 1 < rows; ++row)
        if (column_array_offsets[row + 1] - column_array_offsets[row] != dimensions)
            throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column with vector similarity index must have equal length");

    /// Reserving space is mandatory
    if (!index->try_reserve(roundUpToPowerOfTwoOrZero(index->size() + rows)))
        throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for vector similarity index");

    for (size_t row = 0; row < rows; ++row)
    {
        /// The key of a vector is its position in the index, i.e. the index size before the vector is added.
        const auto key = static_cast<USearchIndex::vector_key_t>(index->size());

        /// Row `row` starts where row `row - 1` ends. For row == 0 this reads offsets[-1].
        /// NOTE(review): presumably ColumnArray::Offsets is left-padded so that element -1 is readable and 0 - confirm.
        const auto * vector_begin = &column_array_data_float_data[column_array_offsets[row - 1]];

        auto result = index->add(key, vector_begin);
        if (!result)
            throw Exception(ErrorCodes::INCORRECT_DATA, "Could not add data to vector similarity index. Error: {}", String(result.error.release()));

        ProfileEvents::increment(ProfileEvents::USearchAddCount);
        ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, result.visited_members);
        ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, result.computed_distances);
    }
}
}
void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_t * pos, size_t limit) void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_t * pos, size_t limit)
{ {
if (*pos >= block.rows()) if (*pos >= block.rows())
@ -273,7 +307,7 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
const auto * column_array = typeid_cast<const ColumnArray *>(column_cut.get()); const auto * column_array = typeid_cast<const ColumnArray *>(column_cut.get());
if (!column_array) if (!column_array)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array(Float32) column"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array(Float*) column");
if (column_array->empty()) if (column_array->empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty");
@ -302,30 +336,19 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
if (index->size() + rows > std::numeric_limits<UInt32>::max()) if (index->size() + rows > std::numeric_limits<UInt32>::max())
throw Exception(ErrorCodes::INCORRECT_DATA, "Size of vector similarity index would exceed 4 billion entries"); throw Exception(ErrorCodes::INCORRECT_DATA, "Size of vector similarity index would exceed 4 billion entries");
const auto & column_array_data = column_array->getData(); DataTypePtr data_type = block.getDataTypes()[0];
const auto & column_array_data_float = typeid_cast<const ColumnFloat32 &>(column_array_data); const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get());
const auto & column_array_data_float_data = column_array_data_float.getData(); if (!data_type_array)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected data type Array(Float*)");
const TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
/// Check all sizes are the same if (WhichDataType(nested_type_index).isFloat32())
for (size_t row = 0; row < rows - 1; ++row) updateImpl<ColumnFloat32>(column_array, column_array_offsets, index, dimensions, rows);
if (column_array_offsets[row + 1] - column_array_offsets[row] != dimensions) else if (WhichDataType(nested_type_index).isFloat64())
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column with vector similarity index must have equal length"); updateImpl<ColumnFloat64>(column_array, column_array_offsets, index, dimensions, rows);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected data type Array(Float*)");
/// Reserving space is mandatory
if (!index->try_reserve(roundUpToPowerOfTwoOrZero(index->size() + rows)))
throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for vector similarity index");
for (size_t row = 0; row < rows; ++row)
{
if (auto result = index->add(static_cast<USearchIndex::vector_key_t>(index->size()), &column_array_data_float_data[column_array_offsets[row - 1]]); !result)
throw Exception(ErrorCodes::INCORRECT_DATA, "Could not add data to vector similarity index. Error: {}", String(result.error.release()));
else
{
ProfileEvents::increment(ProfileEvents::USearchAddCount);
ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, result.visited_members);
ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, result.computed_distances);
}
}
*pos += rows_read; *pos += rows_read;
} }
@ -373,7 +396,7 @@ std::vector<size_t> MergeTreeIndexConditionVectorSimilarity::getUsefulRanges(Mer
"does not match the dimension in the index ({})", "does not match the dimension in the index ({})",
vector_similarity_condition.getDimensions(), index->dimensions()); vector_similarity_condition.getDimensions(), index->dimensions());
const std::vector<float> reference_vector = vector_similarity_condition.getReferenceVector(); const std::vector<Float64> reference_vector = vector_similarity_condition.getReferenceVector();
auto search_result = index->search(reference_vector.data(), limit); auto search_result = index->search(reference_vector.data(), limit);
if (!search_result) if (!search_result)
@ -499,14 +522,14 @@ void vectorSimilarityIndexValidator(const IndexDescription & index, bool /* atta
if (index.column_names.size() != 1 || index.data_types.size() != 1) if (index.column_names.size() != 1 || index.data_types.size() != 1)
throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Vector similarity indexes must be created on a single column"); throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Vector similarity indexes must be created on a single column");
/// Check that the data type is Array(Float32) /// Check that the data type is Array(Float*)
DataTypePtr data_type = index.sample_block.getDataTypes()[0]; DataTypePtr data_type = index.sample_block.getDataTypes()[0];
const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get()); const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get());
if (!data_type_array) if (!data_type_array)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float32)"); throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float*)");
TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId(); TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32()) if (!WhichDataType(nested_type_index).isFloat())
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float32)"); throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float*)");
} }
} }

View File

@ -24,7 +24,7 @@ namespace
{ {
template <typename Literal> template <typename Literal>
void extractReferenceVectorFromLiteral(std::vector<Float32> & reference_vector, Literal literal) void extractReferenceVectorFromLiteral(std::vector<Float64> & reference_vector, Literal literal)
{ {
Float64 float_element_of_reference_vector; Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector; Int64 int_element_of_reference_vector;
@ -72,7 +72,7 @@ UInt64 VectorSimilarityCondition::getLimit() const
throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported"); throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");
} }
std::vector<float> VectorSimilarityCondition::getReferenceVector() const std::vector<Float64> VectorSimilarityCondition::getReferenceVector() const
{ {
if (index_is_useful && query_information.has_value()) if (index_is_useful && query_information.has_value())
return query_information->reference_vector; return query_information->reference_vector;

View File

@ -60,7 +60,7 @@ public:
L2 L2
}; };
std::vector<Float32> reference_vector; std::vector<Float64> reference_vector;
DistanceFunction distance_function; DistanceFunction distance_function;
String column_name; String column_name;
UInt64 limit; UInt64 limit;
@ -70,7 +70,7 @@ public:
/// Returns false if query can be speeded up by an ANN index, true otherwise. /// Returns false if query can be speeded up by an ANN index, true otherwise.
bool alwaysUnknownOrTrue(String distance_function) const; bool alwaysUnknownOrTrue(String distance_function) const;
std::vector<float> getReferenceVector() const; std::vector<Float64> getReferenceVector() const;
size_t getDimensions() const; size_t getDimensions() const;
String getColumnName() const; String getColumnName() const;
Info::DistanceFunction getDistanceFunction() const; Info::DistanceFunction getDistanceFunction() const;

View File

@ -35,6 +35,6 @@ SELECT 'Must be created on Array(Float32) columns';
SET allow_suspicious_low_cardinality_types = 1; SET allow_suspicious_low_cardinality_types = 1;
CREATE TABLE tab(id Int32, vec UInt64, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec UInt64, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, vec Float32, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec Float32, INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, vec Array(Float64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec Array(UInt64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, vec LowCardinality(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec LowCardinality(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, vec Nullable(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, vec Nullable(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }

View File

@ -107,3 +107,7 @@ Expression (Projection)
Description: vector_similarity GRANULARITY 2 Description: vector_similarity GRANULARITY 2
Parts: 1/1 Parts: 1/1
Granules: 2/4 Granules: 2/4
-- Index on Array(Float64) column
6 [0,2] 0
7 [0,2.1] 0.10000000000000009
8 [0,2.2] 0.20000000000000018

View File

@ -124,3 +124,15 @@ LIMIT 3;
DROP TABLE tab_f32; DROP TABLE tab_f32;
DROP TABLE tab_f16; DROP TABLE tab_f16;
DROP TABLE tab_i8; DROP TABLE tab_i8;
-- Checks that a vector similarity index can be created on an Array(Float64) column and that a
-- nearest-neighbour query over that column returns the expected rows.
SELECT '-- Index on Array(Float64) column';
-- index_granularity = 3 and GRANULARITY 2: the 12 rows inserted below span 4 granules.
CREATE TABLE tab(id Int32, vec Array(Float64), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance') GRANULARITY 2) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 3;
-- ids 0-5 vary in the first component ([1.0, 0] .. [1.5, 0]), ids 6-11 in the second ([0, 2.0] .. [0, 2.5]).
INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [1.5, 0.0]), (6, [0.0, 2.0]), (7, [0.0, 2.1]), (8, [0.0, 2.2]), (9, [0.0, 2.3]), (10, [0.0, 2.4]), (11, [0.0, 2.5]);
-- The three rows closest to [0.0, 2.0] by L2 distance; the reference output expects ids 6, 7 and 8.
WITH [0.0, 2.0] AS reference_vec
SELECT id, vec, L2Distance(vec, reference_vec)
FROM tab
ORDER BY L2Distance(vec, reference_vec)
LIMIT 3;
DROP TABLE tab;