Initial Analyzer support for vector similarity search

This code is extremely shameful, don't look at it.

The only goal of this PR is to demonstrate that the analyzer is able to see
vector similarity indexes:

ip-10-19-83-181.eu-central-1.compute.internal :) explain plan indexes=1 SELECT id, vec, L2Distance(vec, [0.0, 2.0])
FROM tab
ORDER BY L2Distance(vec, [0.0, 2.0])
LIMIT 3 settings enable_analyzer = 1

EXPLAIN indexes = 1
SELECT
    id,
    vec,
    L2Distance(vec, [0., 2.])
FROM tab
ORDER BY L2Distance(vec, [0., 2.]) ASC
LIMIT 3
SETTINGS enable_analyzer = 1

Query id: 086df46e-2e27-4342-a47c-6900762a7c8c

    ┌─explain─────────────────────────────────────────────────────────────────────────────────────────┐
 1. │ Expression (Project names)                                                                      │
 2. │   Limit (preliminary LIMIT (without OFFSET))                                                    │
 3. │     Sorting (Sorting for ORDER BY)                                                              │
 4. │       Expression ((Before ORDER BY + (Projection + Change column names to column identifiers))) │
 5. │         ReadFromMergeTree (default.tab)                                                         │
 6. │         Indexes:                                                                                │
 7. │           PrimaryKey                                                                            │
 8. │             Condition: true                                                                     │
 9. │             Parts: 1/1                                                                          │
10. │             Granules: 1/1                                                                       │
11. │           Skip                                                                                  │
12. │             Name: idx                                                                           │
13. │             Description: vector_similarity GRANULARITY 100000000                                │
14. │             Parts: 1/1                                                                          │
15. │             Granules: 1/1                                                                       │
    └─────────────────────────────────────────────────────────────────────────────────────────────────┘

15 rows in set. Elapsed: 0.007 sec.
This commit is contained in:
Robert Schulze 2024-11-06 20:34:05 +00:00
parent bac948ec0e
commit fa635358d2
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
18 changed files with 212 additions and 545 deletions

View File

@ -84,8 +84,6 @@ to load and compare. The library also has several hardware-specific SIMD optimiz
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
files, without loading them into RAM.
USearch indexes are currently experimental, to use them you first need to `SET allow_experimental_vector_similarity_index = 1`.
Vector similarity indexes currently support two distance functions:
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).

View File

@ -71,6 +71,8 @@ void tryRemoveRedundantSorting(QueryPlan::Node * root);
/// Remove redundant distinct steps
size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
/// Put some steps under union, so that plan optimization could be applied to union parts separately.
/// For example, the plan can be rewritten like:
/// - Something - - Expression - Something -
@ -82,7 +84,7 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No
inline const auto & getOptimizations()
{
static const std::array<Optimization, 12> optimizations = {{
static const std::array<Optimization, 13> optimizations = {{
{tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join},
{tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit},
{trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter},
@ -95,6 +97,7 @@ inline const auto & getOptimizations()
{tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union},
{tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently},
{tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct},
{tryUseVectorSearch, "useVectorSearch", &QueryPlanOptimizationSettings::use_vector_search},
}};
return optimizations;
@ -110,7 +113,9 @@ using Stack = std::vector<Frame>;
/// Second pass optimizations
void optimizePrimaryKeyConditionAndLimit(const Stack & stack);
void optimizeVectorSearch(const Stack & stack);
void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes);
void optimizeVectorSearch(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
void optimizeJoin(QueryPlan::Node & node, QueryPlan::Nodes &);

View File

@ -70,6 +70,9 @@ struct QueryPlanOptimizationSettings
/// If remove-redundant-distinct-steps optimization is enabled.
bool remove_redundant_distinct = true;
/// If enabled, the optimizer tries to use a vector similarity index for the query.
bool use_vector_search = true;
bool optimize_prewhere = true;
/// If reading from projection can be applied

View File

@ -14,6 +14,7 @@
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/ITransformingStep.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/JoinStep.h>
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
#include <Processors/QueryPlan/Optimizations/actionsDAGUtils.h>

View File

@ -126,7 +126,6 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s
if (frame.next_child == 0)
{
if (optimization_settings.read_in_order)
optimizeReadInOrder(*frame.node, nodes);

View File

@ -0,0 +1,53 @@
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/SourceStepWithFilter.h>
namespace DB::QueryPlanOptimizations
{
/// Placeholder function: the body contains no executable code, only a commented-out sketch.
/// The sketch starts from a SourceStepWithFilter at the back of a plan stack, re-adds its PREWHERE and row-level
/// filters, then walks the stack towards the root: FilterStep expressions are added as filters, the first LimitStep
/// sets the limit on the source step and stops the walk, ExpressionSteps are skipped, and any other step stops the walk.
/// NOTE(review): the sketch refers to a `stack` variable that is not a parameter of this function, and no caller of
/// `xx` is visible in this chunk - confirm whether this stub is still needed.
void xx([[maybe_unused]] const QueryPlan::Node & node)
{
/// const auto & frame = stack.back();
///
/// auto * source_step_with_filter = dynamic_cast<SourceStepWithFilter *>(frame.node->step.get());
/// if (!source_step_with_filter)
/// return;
///
/// const auto & storage_prewhere_info = source_step_with_filter->getPrewhereInfo();
/// if (storage_prewhere_info)
/// {
/// source_step_with_filter->addFilter(storage_prewhere_info->prewhere_actions.clone(), storage_prewhere_info->prewhere_column_name);
/// if (storage_prewhere_info->row_level_filter)
/// source_step_with_filter->addFilter(storage_prewhere_info->row_level_filter->clone(), storage_prewhere_info->row_level_column_name);
/// }
///
/// for (auto iter = stack.rbegin() + 1; iter != stack.rend(); ++iter)
/// {
/// if (auto * filter_step = typeid_cast<FilterStep *>(iter->node->step.get()))
/// {
/// source_step_with_filter->addFilter(filter_step->getExpression().clone(), filter_step->getFilterColumnName());
/// }
/// else if (auto * limit_step = typeid_cast<LimitStep *>(iter->node->step.get()))
/// {
/// source_step_with_filter->setLimit(limit_step->getLimitForSorting());
/// break;
/// }
/// else if (typeid_cast<ExpressionStep *>(iter->node->step.get()))
/// {
/// /// Note: actually, plan optimizations merge Filter and Expression steps.
/// /// Ideally, chain should look like (Expression -> ...) -> (Filter -> ...) -> ReadFromStorage,
/// /// So this is likely not needed.
/// continue;
/// }
/// else
/// {
/// break;
/// }
/// }
///
/// source_step_with_filter->applyFilters();
}
}

View File

@ -0,0 +1,76 @@
#include <memory>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/ReadFromMergeTree.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/SortingStep.h>
namespace DB::QueryPlanOptimizations
{
/// Plan optimization which recognizes the chain
///     Limit -> Sorting -> Expression -> ReadFromMergeTree
/// (each step having exactly one child) where the sorting is done by a single column whose name starts with a
/// supported distance function. If the chain matches, the parameters of the vector search are stored in
/// ReadFromMergeTree::vec_sim_idx_input.
///
/// The plan tree itself is never modified, therefore 0 is returned on every path.
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & /* nodes*/)
{
    /// Match the plan shape step by step, bail out on the first mismatch.
    auto * limit_step = typeid_cast<LimitStep *>(parent_node->step.get());
    if (!limit_step)
        return 0;
    if (parent_node->children.size() != 1)
        return 0;

    QueryPlan::Node * child_node = parent_node->children.front();
    auto * sorting_step = typeid_cast<SortingStep *>(child_node->step.get());
    if (!sorting_step)
        return 0;
    if (child_node->children.size() != 1)
        return 0;

    child_node = child_node->children.front();
    auto * expression_step = typeid_cast<ExpressionStep *>(child_node->step.get());
    if (!expression_step)
        return 0;
    if (child_node->children.size() != 1)
        return 0;

    child_node = child_node->children.front();
    auto * read_from_mergetree_step = typeid_cast<ReadFromMergeTree *>(child_node->step.get());
    if (!read_from_mergetree_step)
        return 0;

    /// Only ORDER BY with exactly one sort column is supported.
    const auto & sort_description = sorting_step->getSortDescription();
    if (sort_description.size() != 1)
        return 0;

    /// The distance function is detected by the prefix of the sort column name.
    /// TODO: inspect the expression instead of relying on the column name.
    const auto & sort_column_name = sort_description[0].column_name;
    ReadFromMergeTree::DistanceFunction distance_function;
    if (sort_column_name.starts_with("L2Distance"))
        distance_function = ReadFromMergeTree::DistanceFunction::L2Distance;
    else if (sort_column_name.starts_with("cosineDistance"))
        distance_function = ReadFromMergeTree::DistanceFunction::cosineDistance;
    else
        return 0;

    /// The limit is forwarded as the number of nearest neighbors to search for. A value of 0 would make the index
    /// return no rows at all, so don't request a vector search in this case.
    /// NOTE(review): presumably 0 means "no limit was pushed down into the sorting step" - confirm.
    const size_t limit = sorting_step->getLimit();
    if (limit == 0)
        return 0;

    /// TODO: placeholder, the reference vector must be extracted from the ORDER BY expression.
    const std::vector<Float64> reference_vector = {0.0, 0.2};

    /// TODO check that ReadFromMergeTree has a vector similarity index
    read_from_mergetree_step->vec_sim_idx_input = std::make_optional<ReadFromMergeTree::VectorSimilarityIndexInput>(
        distance_function, limit, reference_vector);

    /// --- --- ---
    /// alwaysUnknownOrTrue:
    ///
    /// String index_distance_function;
    /// switch (metric_kind)
    /// {
    /// case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
    /// case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
    /// default: std::unreachable();
    /// }
    /// return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);

    return 0;
}
}

View File

@ -1457,7 +1457,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(Merge
all_column_names,
log,
indexes,
find_exact_ranges);
find_exact_ranges,
vec_sim_idx_input);
}
static void buildIndexes(
@ -1650,7 +1651,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
const Names & all_column_names,
LoggerPtr log,
std::optional<Indexes> & indexes,
bool find_exact_ranges)
bool find_exact_ranges,
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input)
{
AnalysisResult result;
const auto & settings = context_->getSettingsRef();
@ -1743,7 +1745,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
num_streams,
result.index_stats,
indexes->use_skip_indexes,
find_exact_ranges);
find_exact_ranges,
vec_sim_idx_input);
}
size_t sum_marks_pk = total_marks_pk;

View File

@ -166,6 +166,19 @@ public:
std::optional<std::unordered_set<String>> part_values;
};
/// Distance functions for which a vector similarity search can be requested.
enum class DistanceFunction
{
    L2Distance,     /// Euclidean distance
    cosineDistance, /// Cosine distance
};
/// Parameters of a vector similarity search. They are filled in by the query plan optimization and handed down
/// to the index analysis.
/// The order of the fields matters: the struct is constructed positionally as (distance_function, limit, reference_vector).
struct VectorSimilarityIndexInput
{
/// Distance function detected in the ORDER BY clause of the query.
DistanceFunction distance_function;
/// Passed as `limit` to calculateApproximateNearestNeighbors().
size_t limit;
/// Passed as `reference_vector` to calculateApproximateNearestNeighbors().
std::vector<Float64> reference_vector;
};
static AnalysisResultPtr selectRangesToRead(
MergeTreeData::DataPartsVector parts,
MergeTreeData::MutationsSnapshotPtr mutations_snapshot,
@ -178,7 +191,8 @@ public:
const Names & all_column_names,
LoggerPtr log,
std::optional<Indexes> & indexes,
bool find_exact_ranges);
bool find_exact_ranges,
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input);
AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const;
@ -212,6 +226,8 @@ public:
void applyFilters(ActionDAGNodes added_filter_nodes) override;
std::optional<VectorSimilarityIndexInput> vec_sim_idx_input;
private:
MergeTreeReaderSettings reader_settings;

View File

@ -11,7 +11,6 @@
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/ReadInOrderOptimizer.h>
#include <Storages/VirtualColumnUtils.h>
#include <Parsers/ASTIdentifier.h>
@ -629,7 +628,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
size_t num_streams,
ReadFromMergeTree::IndexStats & index_stats,
bool use_skip_indexes,
bool find_exact_ranges)
bool find_exact_ranges,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
{
RangesInDataParts parts_with_ranges;
parts_with_ranges.resize(parts.size());
@ -734,7 +734,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
reader_settings,
mark_cache.get(),
uncompressed_cache.get(),
log);
log,
vec_sim_idx_input);
stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed);
if (ranges.ranges.empty())
@ -950,7 +951,8 @@ ReadFromMergeTree::AnalysisResultPtr MergeTreeDataSelectExecutor::estimateNumMar
column_names_to_return,
log,
indexes,
/*find_exact_ranges*/false);
/*find_exact_ranges*/false,
std::nullopt);
}
QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts(
@ -1360,7 +1362,8 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
const MergeTreeReaderSettings & reader_settings,
MarkCache * mark_cache,
UncompressedCache * uncompressed_cache,
LoggerPtr log)
LoggerPtr log,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
{
if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName()))
{
@ -1421,9 +1424,9 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
reader.read(granule);
if (index_helper->isVectorSimilarityIndex())
if (index_helper->isVectorSimilarityIndex() && vec_sim_idx_input)
{
auto rows = condition->calculateApproximateNearestNeighbors(granule);
auto rows = condition->calculateApproximateNearestNeighbors(granule, vec_sim_idx_input->limit, vec_sim_idx_input->reference_vector); /// TODO
for (auto row : rows)
{

View File

@ -94,7 +94,8 @@ private:
const MergeTreeReaderSettings & reader_settings,
MarkCache * mark_cache,
UncompressedCache * uncompressed_cache,
LoggerPtr log);
LoggerPtr log,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
static MarkRanges filterMarksUsingMergedIndex(
MergeTreeIndices indices,
@ -197,7 +198,8 @@ public:
size_t num_streams,
ReadFromMergeTree::IndexStats & index_stats,
bool use_skip_indexes,
bool find_exact_ranges);
bool find_exact_ranges,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
/// Create expression for sampling.
/// Also, calculate _sample_factor if needed.

View File

@ -1,6 +1,6 @@
#pragma once
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/MergeTree/MergeTreeIndices.h>
/// Walking corpse implementation for removed skipping index of type "annoy" and "usearch".
/// Its only purpose is to allow loading old tables with indexes of these types.

View File

@ -401,11 +401,9 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context)
: vector_similarity_condition(query, context)
, metric_kind(metric_kind_)
: metric_kind(metric_kind_)
, expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search])
{
if (expansion_search == 0)
@ -420,31 +418,23 @@ bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexG
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
{
String index_distance_function;
switch (metric_kind)
{
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
default: std::unreachable();
}
return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
return false;
}
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule_) const
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr granule_,
size_t limit,
const std::vector<Float64> & reference_vector) const
{
const UInt64 limit = vector_similarity_condition.getLimit();
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
if (granule == nullptr)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
const USearchIndexWithSerializationPtr index = granule->index;
if (vector_similarity_condition.getDimensions() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) does not match the dimension in the index ({})",
vector_similarity_condition.getDimensions(), index->dimensions());
const std::vector<Float64> reference_vector = vector_similarity_condition.getReferenceVector();
if (reference_vector.size() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the reference vector in the query ({}) does not match the dimension in the index ({})",
reference_vector.size(), index->dimensions());
/// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search).
/// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to
@ -498,9 +488,9 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregato
return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params);
}
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & /*query*/, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, query, metric_kind, context);
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, metric_kind, context);
};
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const

View File

@ -4,7 +4,7 @@
#if USE_USEARCH
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/MergeTree/MergeTreeIndices.h>
#include <Common/Logger.h>
#include <usearch/index_dense.hpp>
@ -136,7 +136,6 @@ class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCond
public:
MergeTreeIndexConditionVectorSimilarity(
const IndexDescription & index_description,
const SelectQueryInfo & query,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context);
@ -144,10 +143,12 @@ public:
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule) const override;
std::vector<UInt64> calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr granule,
size_t limit,
const std::vector<Float64> & reference_vector) const override;
private:
const VectorSimilarityCondition vector_similarity_condition;
const unum::usearch::metric_kind_t metric_kind;
const size_t expansion_search;
};

View File

@ -99,7 +99,10 @@ public:
/// Special method for vector similarity indexes:
/// Returns the row positions of the N nearest neighbors in the index granule
/// The returned row numbers are guaranteed to be sorted and unique.
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr) const
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr /*granule*/,
size_t /*limit*/,
const std::vector<Float64> & /*reference_vector*/) const
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes");
}

View File

@ -1,354 +0,0 @@
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
namespace DB
{
namespace Setting
{
extern const SettingsUInt64 max_limit_for_ann_queries;
}
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int INCORRECT_QUERY;
}
namespace
{
template <typename Literal>
void extractReferenceVectorFromLiteral(std::vector<Float64> & reference_vector, Literal literal)
{
Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector;
for (const auto & value : literal.value())
{
if (value.tryGet(float_element_of_reference_vector))
reference_vector.emplace_back(float_element_of_reference_vector);
else if (value.tryGet(int_element_of_reference_vector))
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
else
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
}
}
/// Maps the name of a distance function to the corresponding enum value.
/// Names other than "L2Distance" and "cosineDistance" are mapped to Unknown.
VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(const String & distance_function)
{
    using DistanceFunction = VectorSimilarityCondition::Info::DistanceFunction;

    if (distance_function == "L2Distance")
        return DistanceFunction::L2;
    else if (distance_function == "cosineDistance")
        return DistanceFunction::Cosine;
    else
        return DistanceFunction::Unknown;
}
}
/// Analyzes the SELECT query once at construction time. The result of the analysis is stored in `index_is_useful`
/// (and, by checkQueryStructure(), in `query_information`).
VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context)
: block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
, max_limit_for_ann_queries(context->getSettingsRef()[Setting::max_limit_for_ann_queries])
/// Must come last: checkQueryStructure() reads block_with_constants and max_limit_for_ann_queries.
, index_is_useful(checkQueryStructure(query_info))
{}
/// Returns true if the index cannot be used for the query, false if it can.
/// `distance_function` is the name of the distance function the index was built with.
bool VectorSimilarityCondition::alwaysUnknownOrTrue(const String & distance_function) const
{
    /// The query structure isn't supported at all
    if (!index_is_useful)
        return true;

    /// The query is supported, the index can only be used if it was built for the distance function of the query
    const auto index_distance_function = stringToDistanceFunction(distance_function);
    return index_distance_function != query_information->distance_function;
}
/// Returns the LIMIT of the analyzed query. Throws LOGICAL_ERROR if the query cannot use the index.
UInt64 VectorSimilarityCondition::getLimit() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");

    return query_information->limit;
}
/// Returns a copy of the reference vector of the analyzed query. Throws LOGICAL_ERROR if the query cannot use the index.
std::vector<Float64> VectorSimilarityCondition::getReferenceVector() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");

    return query_information->reference_vector;
}
/// Returns the number of elements of the reference vector. Throws LOGICAL_ERROR if the query cannot use the index.
size_t VectorSimilarityCondition::getDimensions() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");

    return query_information->reference_vector.size();
}
/// Returns the name of the column found in the ORDER BY clause. Throws LOGICAL_ERROR if the query cannot use the index.
String VectorSimilarityCondition::getColumnName() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");

    return query_information->column_name;
}
/// Returns the distance function found in the ORDER BY clause. Throws LOGICAL_ERROR if the query cannot use the index.
VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index.");

    return query_information->distance_function;
}
bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query)
{
Info order_by_info;
/// Build rpns for query sections
const auto & select = query.query->as<ASTSelectQuery &>();
RPN rpn_order_by;
RPNElement rpn_limit;
UInt64 limit;
if (select.limitLength())
traverseAtomAST(select.limitLength(), rpn_limit);
if (select.orderBy())
traverseOrderByAST(select.orderBy(), rpn_order_by);
/// Reverse RPNs for conveniences during parsing
std::reverse(rpn_order_by.begin(), rpn_order_by.end());
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info);
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
if (!limit_is_valid || limit > max_limit_for_ann_queries)
return false;
if (order_by_is_valid)
{
query_information = std::move(order_by_info);
query_information->limit = limit;
return true;
}
return false;
}
void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
/// If the node is ASTFunction, it may have children nodes
if (const auto * func = node->as<ASTFunction>())
{
const ASTs & children = func->arguments->children;
/// Traverse children nodes
for (const auto& child : children)
traverseAST(child, rpn);
}
RPNElement element;
/// Get the data behind node
if (!traverseAtomAST(node, element))
element.function = RPNElement::FUNCTION_UNKNOWN;
rpn.emplace_back(std::move(element));
}
/// Classifies a single AST node (without its children) and stores the result in `out`.
/// Returns false if the node is none of: a supported function, an identifier, a supported constant.
bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
    /// Functions
    if (const auto * function = node->as<ASTFunction>())
    {
        const auto & name = function->name;

        /// The name is stored even if the function turns out to be unsupported
        out.func_name = name;

        const bool is_distance_function = name == "L1Distance"
            || name == "L2Distance"
            || name == "LinfDistance"
            || name == "cosineDistance"
            || name == "dotProduct";

        if (is_distance_function)
        {
            out.function = RPNElement::FUNCTION_DISTANCE;
            return true;
        }
        if (name == "array")
        {
            out.function = RPNElement::FUNCTION_ARRAY;
            return true;
        }
        if (name == "_CAST")
        {
            out.function = RPNElement::FUNCTION_CAST;
            return true;
        }
        return false;
    }

    /// Identifiers
    if (const auto * identifier = node->as<ASTIdentifier>())
    {
        out.function = RPNElement::FUNCTION_IDENTIFIER;
        out.identifier.emplace(identifier->name());
        out.func_name = "column identifier";
        return true;
    }

    /// Everything else is only supported if it is a constant
    return tryCastToConstType(node, out);
}
bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
{
Field const_value;
DataTypePtr const_type;
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
{
/// Check for constant types
if (const_value.getType() == Field::Types::Float64)
{
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
out.float_literal.emplace(const_value.safeGet<Float32>());
out.func_name = "Float literal";
return true;
}
if (const_value.getType() == Field::Types::UInt64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.safeGet<UInt64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::Int64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.safeGet<Int64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::Array)
{
out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
out.array_literal = const_value.safeGet<Array>();
out.func_name = "Array literal";
return true;
}
if (const_value.getType() == Field::Types::String)
{
out.function = RPNElement::FUNCTION_STRING_LITERAL;
out.func_name = const_value.safeGet<String>();
return true;
}
}
return false;
}
void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
{
if (const auto * expr_list = node->as<ASTExpressionList>())
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
traverseAST(order_by_element->children.front(), rpn);
}
/// Returns true and fills `info` if the (reversed) RPN of the ORDER BY expression has the form
///     DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column]
/// The reference vector may alternatively be wrapped into a _CAST to an Array type.
///
/// Note: identifiers are moved out of `rpn`, and `info` may be partially filled even if false is returned.
/// The iterator is checked against the end of `rpn` before every access, malformed expressions are rejected.
bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info)
{
    /// ORDER BY clause must have at least 3 expressions
    if (rpn.size() < 3)
        return false;

    auto iter = rpn.begin();
    const auto end = rpn.end();

    bool identifier_found = false;

    /// The first element must be the distance function
    if (iter->function != RPNElement::FUNCTION_DISTANCE)
        return false;

    info.distance_function = stringToDistanceFunction(iter->func_name);
    ++iter;

    if (iter != end && iter->function == RPNElement::FUNCTION_IDENTIFIER)
    {
        identifier_found = true;
        info.column_name = std::move(iter->identifier.value());
        ++iter;
    }

    if (iter != end && iter->function == RPNElement::FUNCTION_ARRAY)
        ++iter;

    if (iter != end && iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
    {
        extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
        ++iter;
    }

    /// further conditions are possible if there is no array, or no identifier is found
    /// the array can be inside a cast function. For other cases, see the loop after this condition
    if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
    {
        ++iter;

        /// Cast should be made to array
        if (iter == end || !iter->func_name.starts_with("Array"))
            return false;
        ++iter;

        /// The casted value must be an array literal
        if (iter == end || iter->function != RPNElement::FUNCTION_LITERAL_ARRAY)
            return false;

        extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
        ++iter;
    }

    /// The remaining elements must be the elements of the reference vector and at most one column identifier
    while (iter != end)
    {
        if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
            iter->function == RPNElement::FUNCTION_INT_LITERAL)
        {
            info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
        }
        else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
        {
            if (identifier_found)
                return false;
            info.column_name = std::move(iter->identifier.value());
            identifier_found = true;
        }
        else
        {
            return false;
        }

        ++iter;
    }

    /// Final checks of correctness
    return identifier_found && !info.reference_vector.empty();
}
/// If the element is an integer literal, stores its value in `limit` and returns true.
/// Otherwise `limit` is left untouched and false is returned.
bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
{
    if (rpn.function != RPNElement::FUNCTION_INT_LITERAL)
        return false;

    limit = rpn.int_literal.value();
    return true;
}
/// Returns the numeric value of an RPN element as float. The float literal takes precedence over the int literal.
/// Throws INCORRECT_QUERY if the element holds neither of them.
float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
{
    const auto & element = *iter;

    if (element.float_literal)
        return *element.float_literal;

    if (element.int_literal)
        return static_cast<float>(*element.int_literal);

    throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
}
}

View File

@ -1,147 +0,0 @@
#pragma once
#include <Storages/MergeTree/MergeTreeIndices.h>
#include "base/types.h"
#include <optional>
#include <vector>
namespace DB
{
/// Class VectorSimilarityCondition is responsible for recognizing if the query
/// can utilize vector similarity indexes.
///
/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise.
/// It has only one argument, the name of the distance function with which the index was built.
/// Only queries with ORDER BY DistanceFunc and LIMIT are supported, i.e.:
///
///   SELECT * FROM * ... ORDER BY DistanceFunc(column, reference_vector) LIMIT count
///
/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5]
class VectorSimilarityCondition
{
public:
    /// Analyzes the query at construction time, the result can be queried with the methods below.
    VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context);

    /// Vector similarity queries have a similar structure:
    /// - reference vector from which all distances are calculated
    /// - distance function, e.g. L2Distance
    /// - name of column with embeddings
    /// - maximum number of returned elements (LIMIT)
    ///
    /// And one optional parameter:
    /// - distance to compare with (only for where queries)
    ///
    /// This struct holds all these components.
    struct Info
    {
        enum class DistanceFunction : uint8_t
        {
            Unknown,
            L2,
            Cosine
        };

        std::vector<Float64> reference_vector;
        /// Initialized so that a default-constructed Info never holds indeterminate values.
        DistanceFunction distance_function = DistanceFunction::Unknown;
        String column_name;
        UInt64 limit = 0;
        float distance = -1.0;
    };

    /// Returns false if the query can be sped up by an ANN index, true otherwise.
    bool alwaysUnknownOrTrue(const String & distance_function) const;

    /// The getters below throw LOGICAL_ERROR if the query cannot utilize the index.
    std::vector<Float64> getReferenceVector() const;
    size_t getDimensions() const;
    String getColumnName() const;
    Info::DistanceFunction getDistanceFunction() const;
    UInt64 getLimit() const;

private:
    /// A single element of the parsed ORDER BY / LIMIT expression
    struct RPNElement
    {
        enum Function
        {
            /// DistanceFunctions
            FUNCTION_DISTANCE,

            /// array(0.1, ..., 0.1)
            FUNCTION_ARRAY,

            /// Operators <, >, <=, >=
            FUNCTION_COMPARISON,

            /// Numeric float value
            FUNCTION_FLOAT_LITERAL,

            /// Numeric int value
            FUNCTION_INT_LITERAL,

            /// Column identifier
            FUNCTION_IDENTIFIER,

            /// Unknown, can be any value
            FUNCTION_UNKNOWN,

            /// [0.1, ...., 0.1] vector without word 'array'
            FUNCTION_LITERAL_ARRAY,

            /// if client parameters are used, cast will always be in the query
            FUNCTION_CAST,

            /// name of type in cast function
            FUNCTION_STRING_LITERAL,
        };

        explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
            : function(function_)
        {}

        Function function;
        String func_name = "Unknown";

        /// At most one of the following values is set, depending on `function`
        std::optional<float> float_literal;
        std::optional<String> identifier;
        std::optional<int64_t> int_literal;
        std::optional<Array> array_literal;

        UInt32 dim = 0;
    };

    using RPN = std::vector<RPNElement>;

    /// Returns true and fills query_information if the query has a supported ORDER BY and LIMIT clause
    bool checkQueryStructure(const SelectQueryInfo & query);

    /// Util functions for the traversal of AST, parses AST and builds rpn
    void traverseAST(const ASTPtr & node, RPN & rpn);

    /// Return true if we can identify our node type
    bool traverseAtomAST(const ASTPtr & node, RPNElement & out);

    /// Checks if the AST stores ConstType expression
    bool tryCastToConstType(const ASTPtr & node, RPNElement & out);

    /// Traverses the AST of ORDERBY section
    void traverseOrderByAST(const ASTPtr & node, RPN & rpn);

    /// Returns true and stores ANNExpr if the query has valid ORDERBY section
    static bool matchRPNOrderBy(RPN & rpn, Info & info);

    /// Returns true and stores Length if we have valid LIMIT clause in query
    static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);

    /// Gets float or int from AST node
    static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);

    /// Note: the constructor relies on the declaration order of the members below,
    /// index_is_useful must stay the last one.
    Block block_with_constants;

    /// Set if the query has the supported structure
    std::optional<Info> query_information;

    /// only queries with a lower limit can be considered to avoid memory overflow
    const UInt64 max_limit_for_ann_queries;

    bool index_is_useful = false;
};
}

View File

@ -2,6 +2,21 @@
-- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes.
SET allow_experimental_vector_similarity_index = 1;
SET enable_analyzer = 0;
CREATE OR REPLACE TABLE tab(id Int32, vec Array(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192;
INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [0.0, 2.0]), (6, [0.0, 2.1]), (7, [0.0, 2.2]), (8, [0.0, 2.3]), (9, [0.0, 2.4]);
SELECT id, vec, L2Distance(vec, [0.0, 2.0])
FROM tab
ORDER BY L2Distance(vec, [0.0, 2.0])
LIMIT 3;
SET allow_experimental_vector_similarity_index = 1;
SET enable_analyzer = 0;