From fa635358d2352123f56616d918ad9f96d6c62ba1 Mon Sep 17 00:00:00 2001 From: Robert Schulze Date: Wed, 6 Nov 2024 20:34:05 +0000 Subject: [PATCH] Initial Analyzer support for vector similarity search MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This code is extremely shameful, don't look at it. Only goal of this PR is to demonstrate that the analyzer is able to see vector similarity indexes: ip-10-19-83-181.eu-central-1.compute.internal :) explain plan indexes=1 SELECT id, vec, L2Distance(vec, [0.0, 2.0]) FROM tab ORDER BY L2Distance(vec, [0.0, 2.0]) LIMIT 3 settings enable_analyzer = 1 EXPLAIN indexes = 1 SELECT id, vec, L2Distance(vec, [0., 2.]) FROM tab ORDER BY L2Distance(vec, [0., 2.]) ASC LIMIT 3 SETTINGS enable_analyzer = 1 Query id: 086df46e-2e27-4342-a47c-6900762a7c8c ┌─explain─────────────────────────────────────────────────────────────────────────────────────────┐ 1. │ Expression (Project names) │ 2. │ Limit (preliminary LIMIT (without OFFSET)) │ 3. │ Sorting (Sorting for ORDER BY) │ 4. │ Expression ((Before ORDER BY + (Projection + Change column names to column identifiers))) │ 5. │ ReadFromMergeTree (default.tab) │ 6. │ Indexes: │ 7. │ PrimaryKey │ 8. │ Condition: true │ 9. │ Parts: 1/1 │ 10. │ Granules: 1/1 │ 11. │ Skip │ 12. │ Name: idx │ 13. │ Description: vector_similarity GRANULARITY 100000000 │ 14. │ Parts: 1/1 │ 15. │ Granules: 1/1 │ └─────────────────────────────────────────────────────────────────────────────────────────────────┘ 15 rows in set. Elapsed: 0.007 sec. --- .../mergetree-family/annindexes.md | 2 - .../QueryPlan/Optimizations/Optimizations.h | 7 +- .../QueryPlanOptimizationSettings.h | 3 + .../Optimizations/optimizeReadInOrder.cpp | 1 + .../QueryPlan/Optimizations/optimizeTree.cpp | 1 - .../Optimizations/optimizeVectorSearch.cpp | 53 +++ .../Optimizations/useVectorSearch.cpp | 76 ++++ .../QueryPlan/ReadFromMergeTree.cpp | 9 +- src/Processors/QueryPlan/ReadFromMergeTree.h | 18 +- .../MergeTree/MergeTreeDataSelectExecutor.cpp | 17 +- .../MergeTree/MergeTreeDataSelectExecutor.h | 6 +- .../MergeTreeIndexLegacyVectorSimilarity.h | 2 +- .../MergeTreeIndexVectorSimilarity.cpp | 32 +- .../MergeTreeIndexVectorSimilarity.h | 9 +- src/Storages/MergeTree/MergeTreeIndices.h | 5 +- .../MergeTree/VectorSimilarityCondition.cpp | 354 ------------------ .../MergeTree/VectorSimilarityCondition.h | 147 -------- .../02354_vector_search_queries.sql | 15 + 18 files changed, 212 insertions(+), 545 deletions(-) create mode 100644 src/Processors/QueryPlan/Optimizations/optimizeVectorSearch.cpp create mode 100644 src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp delete mode 100644 src/Storages/MergeTree/VectorSimilarityCondition.cpp delete mode 100644 src/Storages/MergeTree/VectorSimilarityCondition.h diff --git a/docs/en/engines/table-engines/mergetree-family/annindexes.md b/docs/en/engines/table-engines/mergetree-family/annindexes.md index dc12a60e8ef..3cf8a2fe5c2 100644 --- a/docs/en/engines/table-engines/mergetree-family/annindexes.md +++ b/docs/en/engines/table-engines/mergetree-family/annindexes.md @@ -84,8 +84,6 @@ to load and compare. The library also has several hardware-specific SIMD optimiz Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent files, without loading them into RAM. -USearch indexes are currently experimental, to use them you first need to `SET allow_experimental_vector_similarity_index = 1`. - Vector similarity indexes currently support two distance functions: - `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space ([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)). diff --git a/src/Processors/QueryPlan/Optimizations/Optimizations.h b/src/Processors/QueryPlan/Optimizations/Optimizations.h index c1c4d1e1635..7e0220d1d10 100644 --- a/src/Processors/QueryPlan/Optimizations/Optimizations.h +++ b/src/Processors/QueryPlan/Optimizations/Optimizations.h @@ -71,6 +71,8 @@ void tryRemoveRedundantSorting(QueryPlan::Node * root); /// Remove redundant distinct steps size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes); +size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes); + /// Put some steps under union, so that plan optimization could be applied to union parts separately. /// For example, the plan can be rewritten like: /// - Something - - Expression - Something - @@ -82,7 +84,7 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No inline const auto & getOptimizations() { - static const std::array optimizations = {{ + static const std::array optimizations = {{ {tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join}, {tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit}, {trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter}, @@ -95,6 +97,7 @@ inline const auto & getOptimizations() {tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union}, {tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently}, {tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct}, + {tryUseVectorSearch, "useVectorSearch", &QueryPlanOptimizationSettings::use_vector_search}, }}; return optimizations; @@ -110,7 +113,9 @@ using Stack = std::vector; /// Second pass optimizations void optimizePrimaryKeyConditionAndLimit(const Stack & stack); +void optimizeVectorSearch(const Stack & stack); void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes); +void optimizeVectorSearch(QueryPlan::Node & node, QueryPlan::Nodes & nodes); void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes); void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &); void optimizeJoin(QueryPlan::Node & node, QueryPlan::Nodes &); diff --git a/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h b/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h index 6232fc7f54f..c8dbd8f65ac 100644 --- a/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h +++ b/src/Processors/QueryPlan/Optimizations/QueryPlanOptimizationSettings.h @@ -70,6 +70,9 @@ struct QueryPlanOptimizationSettings /// If remove-redundant-distinct-steps optimization is enabled. bool remove_redundant_distinct = true; + /// If use vector search is enabled, the query will use the vector similarity index + bool use_vector_search = true; + bool optimize_prewhere = true; /// If reading from projection can be applied diff --git a/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp b/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp index e64a88de62e..b2868194977 100644 --- a/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp +++ b/src/Processors/QueryPlan/Optimizations/optimizeReadInOrder.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp b/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp index c034ca79181..96eb0f98647 100644 --- a/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp +++ b/src/Processors/QueryPlan/Optimizations/optimizeTree.cpp @@ -126,7 +126,6 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s if (frame.next_child == 0) { - if (optimization_settings.read_in_order) optimizeReadInOrder(*frame.node, nodes); diff --git a/src/Processors/QueryPlan/Optimizations/optimizeVectorSearch.cpp b/src/Processors/QueryPlan/Optimizations/optimizeVectorSearch.cpp new file mode 100644 index 00000000000..a352027f7b6 --- /dev/null +++ b/src/Processors/QueryPlan/Optimizations/optimizeVectorSearch.cpp @@ -0,0 +1,53 @@ +#include +#include +#include +#include +#include + +namespace DB::QueryPlanOptimizations +{ + +void xx([[maybe_unused]] const QueryPlan::Node & node) +{ + /// const auto & frame = stack.back(); + /// + /// auto * source_step_with_filter = dynamic_cast(frame.node->step.get()); + /// if (!source_step_with_filter) + /// return; + /// + /// const auto & storage_prewhere_info = source_step_with_filter->getPrewhereInfo(); + /// if (storage_prewhere_info) + /// { + /// source_step_with_filter->addFilter(storage_prewhere_info->prewhere_actions.clone(), storage_prewhere_info->prewhere_column_name); + /// if (storage_prewhere_info->row_level_filter) + /// source_step_with_filter->addFilter(storage_prewhere_info->row_level_filter->clone(), storage_prewhere_info->row_level_column_name); + /// } + /// + /// for (auto iter = stack.rbegin() + 1; iter != stack.rend(); ++iter) + /// { + /// if (auto * filter_step = typeid_cast(iter->node->step.get())) + /// { + /// source_step_with_filter->addFilter(filter_step->getExpression().clone(), filter_step->getFilterColumnName()); + /// } + /// else if (auto * limit_step = typeid_cast(iter->node->step.get())) + /// { + /// source_step_with_filter->setLimit(limit_step->getLimitForSorting()); + /// break; + /// } + /// else if (typeid_cast(iter->node->step.get())) + /// { + /// /// Note: actually, plan optimizations merge Filter and Expression steps. + /// /// Ideally, chain should look like (Expression -> ...) -> (Filter -> ...) -> ReadFromStorage, + /// /// So this is likely not needed. + /// continue; + /// } + /// else + /// { + /// break; + /// } + /// } + /// + /// source_step_with_filter->applyFilters(); +} + +} diff --git a/src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp b/src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp new file mode 100644 index 00000000000..4ae9b706791 --- /dev/null +++ b/src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp @@ -0,0 +1,76 @@ +#include +#include +#include +#include +#include +#include + +namespace DB::QueryPlanOptimizations +{ + +size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & /* nodes*/) +{ + auto * limit_step = typeid_cast(parent_node->step.get()); + if (!limit_step) + return 0; + + if (parent_node->children.size() != 1) + return 0; + QueryPlan::Node * child_node = parent_node->children.front(); + auto * sorting_step = typeid_cast(child_node->step.get()); + if (!sorting_step) + return 0; + + if (child_node->children.size() != 1) + return 0; + child_node = child_node->children.front(); + auto * expression_step = typeid_cast(child_node->step.get()); + if (!expression_step) + return 0; + + if (child_node->children.size() != 1) + return 0; + child_node = child_node->children.front(); + auto * read_from_mergetree_step = typeid_cast(child_node->step.get()); + if (!read_from_mergetree_step) + return 0; + + const auto & sort_description = sorting_step->getSortDescription(); + + if (sort_description.size() != 1) + return 0; + + [[maybe_unused]] ReadFromMergeTree::DistanceFunction distance_function; + + /// lol + if (sort_description[0].column_name.starts_with("L2Distance")) + distance_function = ReadFromMergeTree::DistanceFunction::L2Distance; + else if (sort_description[0].column_name.starts_with("cosineDistance")) + distance_function = ReadFromMergeTree::DistanceFunction::cosineDistance; + else + return 0; + + [[maybe_unused]] size_t limit = sorting_step->getLimit(); + [[maybe_unused]] std::vector reference_vector = {0.0, 0.2}; /// TODO + + /// TODO check that ReadFromMergeTree has a vector similarity index + + read_from_mergetree_step->vec_sim_idx_input = std::make_optional( + distance_function, limit, reference_vector); + + /// --- --- --- + /// alwaysUnknownOrTrue: + /// + /// String index_distance_function; + /// switch (metric_kind) + /// { + /// case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break; + /// case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break; + /// default: std::unreachable(); + /// } + /// return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function); + + return 0; +} + +} diff --git a/src/Processors/QueryPlan/ReadFromMergeTree.cpp b/src/Processors/QueryPlan/ReadFromMergeTree.cpp index 3186df6a6b3..e0bd003127b 100644 --- a/src/Processors/QueryPlan/ReadFromMergeTree.cpp +++ b/src/Processors/QueryPlan/ReadFromMergeTree.cpp @@ -1457,7 +1457,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(Merge all_column_names, log, indexes, - find_exact_ranges); + find_exact_ranges, + vec_sim_idx_input); } static void buildIndexes( @@ -1650,7 +1651,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead( const Names & all_column_names, LoggerPtr log, std::optional & indexes, - bool find_exact_ranges) + bool find_exact_ranges, + const std::optional & vec_sim_idx_input) { AnalysisResult result; const auto & settings = context_->getSettingsRef(); @@ -1743,7 +1745,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead( num_streams, result.index_stats, indexes->use_skip_indexes, - find_exact_ranges); + find_exact_ranges, + vec_sim_idx_input); } size_t sum_marks_pk = total_marks_pk; diff --git a/src/Processors/QueryPlan/ReadFromMergeTree.h b/src/Processors/QueryPlan/ReadFromMergeTree.h index 46a02f5643b..47f2d117cfa 100644 --- a/src/Processors/QueryPlan/ReadFromMergeTree.h +++ b/src/Processors/QueryPlan/ReadFromMergeTree.h @@ -166,6 +166,19 @@ public: std::optional> part_values; }; + enum class DistanceFunction + { + L2Distance, + cosineDistance + }; + + struct VectorSimilarityIndexInput + { + DistanceFunction distance_function; + size_t limit; + std::vector reference_vector; + }; + static AnalysisResultPtr selectRangesToRead( MergeTreeData::DataPartsVector parts, MergeTreeData::MutationsSnapshotPtr mutations_snapshot, @@ -178,7 +191,8 @@ public: const Names & all_column_names, LoggerPtr log, std::optional & indexes, - bool find_exact_ranges); + bool find_exact_ranges, + const std::optional & vec_sim_idx_input); AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const; @@ -212,6 +226,8 @@ public: void applyFilters(ActionDAGNodes added_filter_nodes) override; + std::optional vec_sim_idx_input; + private: MergeTreeReaderSettings reader_settings; diff --git a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp index 1b3c58000e7..33369f85576 100644 --- a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp +++ b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp @@ -11,7 +11,6 @@ #include #include #include -#include #include #include #include @@ -629,7 +628,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd size_t num_streams, ReadFromMergeTree::IndexStats & index_stats, bool use_skip_indexes, - bool find_exact_ranges) + bool find_exact_ranges, + const std::optional & vec_sim_idx_input) { RangesInDataParts parts_with_ranges; parts_with_ranges.resize(parts.size()); @@ -734,7 +734,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd reader_settings, mark_cache.get(), uncompressed_cache.get(), - log); + log, + vec_sim_idx_input); stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed); if (ranges.ranges.empty()) @@ -950,7 +951,8 @@ ReadFromMergeTree::AnalysisResultPtr MergeTreeDataSelectExecutor::estimateNumMar column_names_to_return, log, indexes, - /*find_exact_ranges*/false); + /*find_exact_ranges*/false, + std::nullopt); } QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts( @@ -1360,7 +1362,8 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex( const MergeTreeReaderSettings & reader_settings, MarkCache * mark_cache, UncompressedCache * uncompressed_cache, - LoggerPtr log) + LoggerPtr log, + const std::optional & vec_sim_idx_input) { if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName())) { @@ -1421,9 +1424,9 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex( if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin) reader.read(granule); - if (index_helper->isVectorSimilarityIndex()) + if (index_helper->isVectorSimilarityIndex() && vec_sim_idx_input) { - auto rows = condition->calculateApproximateNearestNeighbors(granule); + auto rows = condition->calculateApproximateNearestNeighbors(granule, vec_sim_idx_input->limit, vec_sim_idx_input->reference_vector); /// TODO for (auto row : rows) { diff --git a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.h b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.h index d16d9243c14..12b9780c10a 100644 --- a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.h +++ b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.h @@ -94,7 +94,8 @@ private: const MergeTreeReaderSettings & reader_settings, MarkCache * mark_cache, UncompressedCache * uncompressed_cache, - LoggerPtr log); + LoggerPtr log, + const std::optional & vec_sim_idx_input); static MarkRanges filterMarksUsingMergedIndex( MergeTreeIndices indices, @@ -197,7 +198,8 @@ public: size_t num_streams, ReadFromMergeTree::IndexStats & index_stats, bool use_skip_indexes, - bool find_exact_ranges); + bool find_exact_ranges, + const std::optional & vec_sim_idx_input); /// Create expression for sampling. /// Also, calculate _sample_factor if needed. diff --git a/src/Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h b/src/Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h index 1015401823d..d9ba24dc13f 100644 --- a/src/Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h +++ b/src/Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h @@ -1,6 +1,6 @@ #pragma once -#include +#include /// Walking corpse implementation for removed skipping index of type "annoy" and "usearch". /// Its only purpose is to allow loading old tables with indexes of these types. diff --git a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp index 5a725922e14..6f26473ae8f 100644 --- a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp +++ b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.cpp @@ -401,11 +401,9 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_ MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity( const IndexDescription & /*index_description*/, - const SelectQueryInfo & query, unum::usearch::metric_kind_t metric_kind_, ContextPtr context) - : vector_similarity_condition(query, context) - , metric_kind(metric_kind_) + : metric_kind(metric_kind_) , expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search]) { if (expansion_search == 0) @@ -420,31 +418,23 @@ bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexG bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const { - String index_distance_function; - switch (metric_kind) - { - case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break; - case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break; - default: std::unreachable(); - } - return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function); + return false; } -std::vector MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule_) const +std::vector MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors( + MergeTreeIndexGranulePtr granule_, + size_t limit, + const std::vector & reference_vector) const { - const UInt64 limit = vector_similarity_condition.getLimit(); - const auto granule = std::dynamic_pointer_cast(granule_); if (granule == nullptr) throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type"); const USearchIndexWithSerializationPtr index = granule->index; - if (vector_similarity_condition.getDimensions() != index->dimensions()) - throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) does not match the dimension in the index ({})", - vector_similarity_condition.getDimensions(), index->dimensions()); - - const std::vector reference_vector = vector_similarity_condition.getReferenceVector(); + if (reference_vector.size() != index->dimensions()) + throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the reference vector in the query ({}) does not match the dimension in the index ({})", + reference_vector.size(), index->dimensions()); /// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search). /// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to @@ -498,9 +488,9 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregato return std::make_shared(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params); } -MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const +MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & /*query*/, ContextPtr context) const { - return std::make_shared(index, query, metric_kind, context); + return std::make_shared(index, metric_kind, context); }; MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const diff --git a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.h b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.h index 9a81e168393..75aae459aa4 100644 --- a/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.h +++ b/src/Storages/MergeTree/MergeTreeIndexVectorSimilarity.h @@ -4,7 +4,7 @@ #if USE_USEARCH -#include +#include #include #include @@ -136,7 +136,6 @@ class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCond public: MergeTreeIndexConditionVectorSimilarity( const IndexDescription & index_description, - const SelectQueryInfo & query, unum::usearch::metric_kind_t metric_kind_, ContextPtr context); @@ -144,10 +143,12 @@ public: bool alwaysUnknownOrTrue() const override; bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override; - std::vector calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule) const override; + std::vector calculateApproximateNearestNeighbors( + MergeTreeIndexGranulePtr granule, + size_t limit, + const std::vector & reference_vector) const override; private: - const VectorSimilarityCondition vector_similarity_condition; const unum::usearch::metric_kind_t metric_kind; const size_t expansion_search; }; diff --git a/src/Storages/MergeTree/MergeTreeIndices.h b/src/Storages/MergeTree/MergeTreeIndices.h index 9a358cb4b58..c472aa97007 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.h +++ b/src/Storages/MergeTree/MergeTreeIndices.h @@ -99,7 +99,10 @@ public: /// Special method for vector similarity indexes: /// Returns the row positions of the N nearest neighbors in the index granule /// The returned row numbers are guaranteed to be sorted and unique. - virtual std::vector calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr) const + virtual std::vector calculateApproximateNearestNeighbors( + MergeTreeIndexGranulePtr /*granule*/, + size_t /*limit*/, + const std::vector & /*reference_vector*/) const { throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes"); } diff --git a/src/Storages/MergeTree/VectorSimilarityCondition.cpp b/src/Storages/MergeTree/VectorSimilarityCondition.cpp deleted file mode 100644 index ea3b1fbad8d..00000000000 --- a/src/Storages/MergeTree/VectorSimilarityCondition.cpp +++ /dev/null @@ -1,354 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace DB -{ -namespace Setting -{ - extern const SettingsUInt64 max_limit_for_ann_queries; -} - -namespace ErrorCodes -{ - extern const int LOGICAL_ERROR; - extern const int INCORRECT_QUERY; -} - -namespace -{ - -template -void extractReferenceVectorFromLiteral(std::vector & reference_vector, Literal literal) -{ - Float64 float_element_of_reference_vector; - Int64 int_element_of_reference_vector; - - for (const auto & value : literal.value()) - { - if (value.tryGet(float_element_of_reference_vector)) - reference_vector.emplace_back(float_element_of_reference_vector); - else if (value.tryGet(int_element_of_reference_vector)) - reference_vector.emplace_back(static_cast(int_element_of_reference_vector)); - else - throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported."); - } -} - -VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(const String & distance_function) -{ - if (distance_function == "L2Distance") - return VectorSimilarityCondition::Info::DistanceFunction::L2; - if (distance_function == "cosineDistance") - return VectorSimilarityCondition::Info::DistanceFunction::Cosine; - return VectorSimilarityCondition::Info::DistanceFunction::Unknown; -} - -} - -VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context) - : block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context)) - , max_limit_for_ann_queries(context->getSettingsRef()[Setting::max_limit_for_ann_queries]) - , index_is_useful(checkQueryStructure(query_info)) -{} - -bool VectorSimilarityCondition::alwaysUnknownOrTrue(const String & distance_function) const -{ - if (!index_is_useful) - return true; /// query isn't supported - /// If query is supported, check if distance function of index is the same as distance function in query - return !(stringToDistanceFunction(distance_function) == query_information->distance_function); -} - -UInt64 VectorSimilarityCondition::getLimit() const -{ - if (index_is_useful && query_information.has_value()) - return query_information->limit; - throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported"); -} - -std::vector VectorSimilarityCondition::getReferenceVector() const -{ - if (index_is_useful && query_information.has_value()) - return query_information->reference_vector; - throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index."); -} - -size_t VectorSimilarityCondition::getDimensions() const -{ - if (index_is_useful && query_information.has_value()) - return query_information->reference_vector.size(); - throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index."); -} - -String VectorSimilarityCondition::getColumnName() const -{ - if (index_is_useful && query_information.has_value()) - return query_information->column_name; - throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index."); -} - -VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const -{ - if (index_is_useful && query_information.has_value()) - return query_information->distance_function; - throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index."); -} - -bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query) -{ - Info order_by_info; - - /// Build rpns for query sections - const auto & select = query.query->as(); - - RPN rpn_order_by; - RPNElement rpn_limit; - UInt64 limit; - - if (select.limitLength()) - traverseAtomAST(select.limitLength(), rpn_limit); - - if (select.orderBy()) - traverseOrderByAST(select.orderBy(), rpn_order_by); - - /// Reverse RPNs for conveniences during parsing - std::reverse(rpn_order_by.begin(), rpn_order_by.end()); - - const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info); - const bool limit_is_valid = matchRPNLimit(rpn_limit, limit); - - if (!limit_is_valid || limit > max_limit_for_ann_queries) - return false; - - if (order_by_is_valid) - { - query_information = std::move(order_by_info); - query_information->limit = limit; - return true; - } - - return false; -} - -void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn) -{ - /// If the node is ASTFunction, it may have children nodes - if (const auto * func = node->as()) - { - const ASTs & children = func->arguments->children; - /// Traverse children nodes - for (const auto& child : children) - traverseAST(child, rpn); - } - - RPNElement element; - /// Get the data behind node - if (!traverseAtomAST(node, element)) - element.function = RPNElement::FUNCTION_UNKNOWN; - - rpn.emplace_back(std::move(element)); -} - -bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out) -{ - /// Match Functions - if (const auto * function = node->as()) - { - /// Set the name - out.func_name = function->name; - - if (function->name == "L1Distance" || - function->name == "L2Distance" || - function->name == "LinfDistance" || - function->name == "cosineDistance" || - function->name == "dotProduct") - out.function = RPNElement::FUNCTION_DISTANCE; - else if (function->name == "array") - out.function = RPNElement::FUNCTION_ARRAY; - else if (function->name == "_CAST") - out.function = RPNElement::FUNCTION_CAST; - else - return false; - - return true; - } - /// Match identifier - if (const auto * identifier = node->as()) - { - out.function = RPNElement::FUNCTION_IDENTIFIER; - out.identifier.emplace(identifier->name()); - out.func_name = "column identifier"; - - return true; - } - - /// Check if we have constants behind the node - return tryCastToConstType(node, out); -} - -bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out) -{ - Field const_value; - DataTypePtr const_type; - - if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type)) - { - /// Check for constant types - if (const_value.getType() == Field::Types::Float64) - { - out.function = RPNElement::FUNCTION_FLOAT_LITERAL; - out.float_literal.emplace(const_value.safeGet()); - out.func_name = "Float literal"; - return true; - } - - if (const_value.getType() == Field::Types::UInt64) - { - out.function = RPNElement::FUNCTION_INT_LITERAL; - out.int_literal.emplace(const_value.safeGet()); - out.func_name = "Int literal"; - return true; - } - - if (const_value.getType() == Field::Types::Int64) - { - out.function = RPNElement::FUNCTION_INT_LITERAL; - out.int_literal.emplace(const_value.safeGet()); - out.func_name = "Int literal"; - return true; - } - - if (const_value.getType() == Field::Types::Array) - { - out.function = RPNElement::FUNCTION_LITERAL_ARRAY; - out.array_literal = const_value.safeGet(); - out.func_name = "Array literal"; - return true; - } - - if (const_value.getType() == Field::Types::String) - { - out.function = RPNElement::FUNCTION_STRING_LITERAL; - out.func_name = const_value.safeGet(); - return true; - } - } - - return false; -} - -void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn) -{ - if (const auto * expr_list = node->as()) - if (const auto * order_by_element = expr_list->children.front()->as()) - traverseAST(order_by_element->children.front(), rpn); -} - -/// Returns true and stores ANNExpr if the query has valid ORDERBY clause -bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info) -{ - /// ORDER BY clause must have at least 3 expressions - if (rpn.size() < 3) - return false; - - auto iter = rpn.begin(); - auto end = rpn.end(); - - bool identifier_found = false; - - /// Matches DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column] - if (iter->function != RPNElement::FUNCTION_DISTANCE) - return false; - - info.distance_function = stringToDistanceFunction(iter->func_name); - ++iter; - - if (iter->function == RPNElement::FUNCTION_IDENTIFIER) - { - identifier_found = true; - info.column_name = std::move(iter->identifier.value()); - ++iter; - } - - if (iter->function == RPNElement::FUNCTION_ARRAY) - ++iter; - - if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY) - { - extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal); - ++iter; - } - - /// further conditions are possible if there is no array, or no identifier is found - /// the array can be inside a cast function. For other cases, see the loop after this condition - if (iter != end && iter->function == RPNElement::FUNCTION_CAST) - { - ++iter; - /// Cast should be made to array - if (!iter->func_name.starts_with("Array")) - return false; - ++iter; - if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY) - { - extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal); - ++iter; - } - else - return false; - } - - while (iter != end) - { - if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL || - iter->function == RPNElement::FUNCTION_INT_LITERAL) - info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter)); - else if (iter->function == RPNElement::FUNCTION_IDENTIFIER) - { - if (identifier_found) - return false; - info.column_name = std::move(iter->identifier.value()); - identifier_found = true; - } - else - return false; - - ++iter; - } - - /// Final checks of correctness - return identifier_found && !info.reference_vector.empty(); -} - -/// Returns true and stores Length if we have valid LIMIT clause in query -bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit) -{ - if (rpn.function == RPNElement::FUNCTION_INT_LITERAL) - { - limit = rpn.int_literal.value(); - return true; - } - - return false; -} - -/// Gets float or int from AST node -float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter) -{ - if (iter->float_literal.has_value()) - return iter->float_literal.value(); - if (iter->int_literal.has_value()) - return static_cast(iter->int_literal.value()); - throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n"); -} - -} diff --git a/src/Storages/MergeTree/VectorSimilarityCondition.h b/src/Storages/MergeTree/VectorSimilarityCondition.h deleted file mode 100644 index 86e77e88d33..00000000000 --- a/src/Storages/MergeTree/VectorSimilarityCondition.h +++ /dev/null @@ -1,147 +0,0 @@ -#pragma once - -#include -#include "base/types.h" - -#include -#include - -namespace DB -{ - -/// Class VectorSimilarityCondition is responsible for recognizing if the query -/// can utilize vector similarity indexes. -/// -/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise. -/// It has only one argument, the name of the distance function with which index was built. -/// Only queries with ORDER BY DistanceFunc and LIMIT, i.e.: -/// -/// SELECT * FROM * ... ORDER BY DistanceFunc(column, reference_vector) LIMIT count -/// -/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5] -class VectorSimilarityCondition -{ -public: - VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context); - - /// vector similarity queries have a similar structure: - /// - reference vector from which all distances are calculated - /// - distance function, e.g L2Distance - /// - name of column with embeddings - /// - type of query - /// - maximum number of returned elements (LIMIT) - /// - /// And one optional parameter: - /// - distance to compare with (only for where queries) - /// - /// This struct holds all these components. - struct Info - { - enum class DistanceFunction : uint8_t - { - Unknown, - L2, - Cosine - }; - - std::vector reference_vector; - DistanceFunction distance_function; - String column_name; - UInt64 limit; - float distance = -1.0; - }; - - /// Returns false if query can be speeded up by an ANN index, true otherwise. - bool alwaysUnknownOrTrue(const String & distance_function) const; - - std::vector getReferenceVector() const; - size_t getDimensions() const; - String getColumnName() const; - Info::DistanceFunction getDistanceFunction() const; - UInt64 getLimit() const; - -private: - struct RPNElement - { - enum Function - { - /// DistanceFunctions - FUNCTION_DISTANCE, - - /// array(0.1, ..., 0.1) - FUNCTION_ARRAY, - - /// Operators <, >, <=, >= - FUNCTION_COMPARISON, - - /// Numeric float value - FUNCTION_FLOAT_LITERAL, - - /// Numeric int value - FUNCTION_INT_LITERAL, - - /// Column identifier - FUNCTION_IDENTIFIER, - - /// Unknown, can be any value - FUNCTION_UNKNOWN, - - /// [0.1, ...., 0.1] vector without word 'array' - FUNCTION_LITERAL_ARRAY, - - /// if client parameters are used, cast will always be in the query - FUNCTION_CAST, - - /// name of type in cast function - FUNCTION_STRING_LITERAL, - }; - - explicit RPNElement(Function function_ = FUNCTION_UNKNOWN) - : function(function_) - {} - - Function function; - String func_name = "Unknown"; - - std::optional float_literal; - std::optional identifier; - std::optional int_literal; - std::optional array_literal; - - UInt32 dim = 0; - }; - - using RPN = std::vector; - - bool checkQueryStructure(const SelectQueryInfo & query); - - /// Util functions for the traversal of AST, parses AST and builds rpn - void traverseAST(const ASTPtr & node, RPN & rpn); - /// Return true if we can identify our node type - bool traverseAtomAST(const ASTPtr & node, RPNElement & out); - /// Checks if the AST stores ConstType expression - bool tryCastToConstType(const ASTPtr & node, RPNElement & out); - /// Traverses the AST of ORDERBY section - void traverseOrderByAST(const ASTPtr & node, RPN & rpn); - - /// Returns true and stores ANNExpr if the query has valid ORDERBY section - static bool matchRPNOrderBy(RPN & rpn, Info & info); - - /// Returns true and stores Length if we have valid LIMIT clause in query - static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit); - - /// Gets float or int from AST node - static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter); - - Block block_with_constants; - - /// true if we have one of two supported query types - std::optional query_information; - - /// only queries with a lower limit can be considered to avoid memory overflow - const UInt64 max_limit_for_ann_queries; - - bool index_is_useful = false; -}; - -} diff --git a/tests/queries/0_stateless/02354_vector_search_queries.sql b/tests/queries/0_stateless/02354_vector_search_queries.sql index 0941f9a43d6..930fc9a3c60 100644 --- a/tests/queries/0_stateless/02354_vector_search_queries.sql +++ b/tests/queries/0_stateless/02354_vector_search_queries.sql @@ -2,6 +2,21 @@ -- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes. +SET allow_experimental_vector_similarity_index = 1; +SET enable_analyzer = 0; + +CREATE OR REPLACE TABLE tab(id Int32, vec Array(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192; +INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [0.0, 2.0]), (6, [0.0, 2.1]), (7, [0.0, 2.2]), (8, [0.0, 2.3]), (9, [0.0, 2.4]); + +SELECT id, vec, L2Distance(vec, [0.0, 2.0]) +FROM tab +ORDER BY L2Distance(vec, [0.0, 2.0]) +LIMIT 3; + + + + + SET allow_experimental_vector_similarity_index = 1; SET enable_analyzer = 0;