This commit is contained in:
Robert Schulze 2024-11-20 15:16:37 -08:00 committed by GitHub
commit f659541db2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 212 additions and 545 deletions

View File

@ -84,8 +84,6 @@ to load and compare. The library also has several hardware-specific SIMD optimiz
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
files, without loading them into RAM.
USearch indexes are currently experimental, to use them you first need to `SET allow_experimental_vector_similarity_index = 1`.
Vector similarity indexes currently support two distance functions:
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).

View File

@ -71,6 +71,8 @@ void tryRemoveRedundantSorting(QueryPlan::Node * root);
/// Remove redundant distinct steps
size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
/// Put some steps under union, so that plan optimization could be applied to union parts separately.
/// For example, the plan can be rewritten like:
/// - Something - - Expression - Something -
@ -82,7 +84,7 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No
inline const auto & getOptimizations()
{
static const std::array<Optimization, 12> optimizations = {{
static const std::array<Optimization, 13> optimizations = {{
{tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join},
{tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit},
{trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter},
@ -95,6 +97,7 @@ inline const auto & getOptimizations()
{tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union},
{tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently},
{tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct},
{tryUseVectorSearch, "useVectorSearch", &QueryPlanOptimizationSettings::use_vector_search},
}};
return optimizations;
@ -110,7 +113,9 @@ using Stack = std::vector<Frame>;
/// Second pass optimizations
void optimizePrimaryKeyConditionAndLimit(const Stack & stack);
void optimizeVectorSearch(const Stack & stack);
void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes);
void optimizeVectorSearch(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
void optimizeDistinctInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);

View File

@ -70,6 +70,9 @@ struct QueryPlanOptimizationSettings
/// If remove-redundant-distinct-steps optimization is enabled.
bool remove_redundant_distinct = true;
/// If use vector search is enabled, the query will use the vector similarity index
bool use_vector_search = true;
bool optimize_prewhere = true;
/// If reading from projection can be applied

View File

@ -14,6 +14,7 @@
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/ITransformingStep.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/JoinStep.h>
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
#include <Processors/QueryPlan/Optimizations/actionsDAGUtils.h>

View File

@ -126,7 +126,6 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s
if (frame.next_child == 0)
{
if (optimization_settings.read_in_order)
optimizeReadInOrder(*frame.node, nodes);

View File

@ -0,0 +1,53 @@
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/FilterStep.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/SourceStepWithFilter.h>
namespace DB::QueryPlanOptimizations
{
void xx([[maybe_unused]] const QueryPlan::Node & node)
{
/// const auto & frame = stack.back();
///
/// auto * source_step_with_filter = dynamic_cast<SourceStepWithFilter *>(frame.node->step.get());
/// if (!source_step_with_filter)
/// return;
///
/// const auto & storage_prewhere_info = source_step_with_filter->getPrewhereInfo();
/// if (storage_prewhere_info)
/// {
/// source_step_with_filter->addFilter(storage_prewhere_info->prewhere_actions.clone(), storage_prewhere_info->prewhere_column_name);
/// if (storage_prewhere_info->row_level_filter)
/// source_step_with_filter->addFilter(storage_prewhere_info->row_level_filter->clone(), storage_prewhere_info->row_level_column_name);
/// }
///
/// for (auto iter = stack.rbegin() + 1; iter != stack.rend(); ++iter)
/// {
/// if (auto * filter_step = typeid_cast<FilterStep *>(iter->node->step.get()))
/// {
/// source_step_with_filter->addFilter(filter_step->getExpression().clone(), filter_step->getFilterColumnName());
/// }
/// else if (auto * limit_step = typeid_cast<LimitStep *>(iter->node->step.get()))
/// {
/// source_step_with_filter->setLimit(limit_step->getLimitForSorting());
/// break;
/// }
/// else if (typeid_cast<ExpressionStep *>(iter->node->step.get()))
/// {
/// /// Note: actually, plan optimizations merge Filter and Expression steps.
/// /// Ideally, chain should look like (Expression -> ...) -> (Filter -> ...) -> ReadFromStorage,
/// /// So this is likely not needed.
/// continue;
/// }
/// else
/// {
/// break;
/// }
/// }
///
/// source_step_with_filter->applyFilters();
}
}

View File

@ -0,0 +1,76 @@
#include <memory>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/ReadFromMergeTree.h>
#include <Processors/QueryPlan/LimitStep.h>
#include <Processors/QueryPlan/SortingStep.h>
namespace DB::QueryPlanOptimizations
{
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & /* nodes*/)
{
auto * limit_step = typeid_cast<LimitStep *>(parent_node->step.get());
if (!limit_step)
return 0;
if (parent_node->children.size() != 1)
return 0;
QueryPlan::Node * child_node = parent_node->children.front();
auto * sorting_step = typeid_cast<SortingStep *>(child_node->step.get());
if (!sorting_step)
return 0;
if (child_node->children.size() != 1)
return 0;
child_node = child_node->children.front();
auto * expression_step = typeid_cast<ExpressionStep *>(child_node->step.get());
if (!expression_step)
return 0;
if (child_node->children.size() != 1)
return 0;
child_node = child_node->children.front();
auto * read_from_mergetree_step = typeid_cast<ReadFromMergeTree *>(child_node->step.get());
if (!read_from_mergetree_step)
return 0;
const auto & sort_description = sorting_step->getSortDescription();
if (sort_description.size() != 1)
return 0;
[[maybe_unused]] ReadFromMergeTree::DistanceFunction distance_function;
/// lol
if (sort_description[0].column_name.starts_with("L2Distance"))
distance_function = ReadFromMergeTree::DistanceFunction::L2Distance;
else if (sort_description[0].column_name.starts_with("cosineDistance"))
distance_function = ReadFromMergeTree::DistanceFunction::cosineDistance;
else
return 0;
[[maybe_unused]] size_t limit = sorting_step->getLimit();
[[maybe_unused]] std::vector<Float64> reference_vector = {0.0, 0.2}; /// TODO
/// TODO check that ReadFromMergeTree has a vector similarity index
read_from_mergetree_step->vec_sim_idx_input = std::make_optional<ReadFromMergeTree::VectorSimilarityIndexInput>(
distance_function, limit, reference_vector);
/// --- --- ---
/// alwaysUnknownOrTrue:
///
/// String index_distance_function;
/// switch (metric_kind)
/// {
/// case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
/// case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
/// default: std::unreachable();
/// }
/// return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
return 0;
}
}

View File

@ -1489,7 +1489,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(Merge
all_column_names,
log,
indexes,
find_exact_ranges);
find_exact_ranges,
vec_sim_idx_input);
}
static void buildIndexes(
@ -1682,7 +1683,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
const Names & all_column_names,
LoggerPtr log,
std::optional<Indexes> & indexes,
bool find_exact_ranges)
bool find_exact_ranges,
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input)
{
AnalysisResult result;
const auto & settings = context_->getSettingsRef();
@ -1775,7 +1777,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
num_streams,
result.index_stats,
indexes->use_skip_indexes,
find_exact_ranges);
find_exact_ranges,
vec_sim_idx_input);
}
size_t sum_marks_pk = total_marks_pk;

View File

@ -166,6 +166,19 @@ public:
std::optional<std::unordered_set<String>> part_values;
};
enum class DistanceFunction
{
L2Distance,
cosineDistance
};
struct VectorSimilarityIndexInput
{
DistanceFunction distance_function;
size_t limit;
std::vector<Float64> reference_vector;
};
static AnalysisResultPtr selectRangesToRead(
MergeTreeData::DataPartsVector parts,
MergeTreeData::MutationsSnapshotPtr mutations_snapshot,
@ -178,7 +191,8 @@ public:
const Names & all_column_names,
LoggerPtr log,
std::optional<Indexes> & indexes,
bool find_exact_ranges);
bool find_exact_ranges,
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input);
AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const;
@ -212,6 +226,8 @@ public:
void applyFilters(ActionDAGNodes added_filter_nodes) override;
std::optional<VectorSimilarityIndexInput> vec_sim_idx_input;
private:
MergeTreeReaderSettings reader_settings;

View File

@ -11,7 +11,6 @@
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/ReadInOrderOptimizer.h>
#include <Storages/VirtualColumnUtils.h>
#include <Parsers/ASTIdentifier.h>
@ -629,7 +628,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
size_t num_streams,
ReadFromMergeTree::IndexStats & index_stats,
bool use_skip_indexes,
bool find_exact_ranges)
bool find_exact_ranges,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
{
RangesInDataParts parts_with_ranges;
parts_with_ranges.resize(parts.size());
@ -734,7 +734,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
reader_settings,
mark_cache.get(),
uncompressed_cache.get(),
log);
log,
vec_sim_idx_input);
stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed);
if (ranges.ranges.empty())
@ -950,7 +951,8 @@ ReadFromMergeTree::AnalysisResultPtr MergeTreeDataSelectExecutor::estimateNumMar
column_names_to_return,
log,
indexes,
/*find_exact_ranges*/false);
/*find_exact_ranges*/false,
std::nullopt);
}
QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts(
@ -1364,7 +1366,8 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
const MergeTreeReaderSettings & reader_settings,
MarkCache * mark_cache,
UncompressedCache * uncompressed_cache,
LoggerPtr log)
LoggerPtr log,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
{
if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName()))
{
@ -1424,9 +1427,9 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
reader.read(granule);
if (index_helper->isVectorSimilarityIndex())
if (index_helper->isVectorSimilarityIndex() && vec_sim_idx_input)
{
auto rows = condition->calculateApproximateNearestNeighbors(granule);
auto rows = condition->calculateApproximateNearestNeighbors(granule, vec_sim_idx_input->limit, vec_sim_idx_input->reference_vector); /// TODO
for (auto row : rows)
{

View File

@ -94,7 +94,8 @@ private:
const MergeTreeReaderSettings & reader_settings,
MarkCache * mark_cache,
UncompressedCache * uncompressed_cache,
LoggerPtr log);
LoggerPtr log,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
static MarkRanges filterMarksUsingMergedIndex(
MergeTreeIndices indices,
@ -197,7 +198,8 @@ public:
size_t num_streams,
ReadFromMergeTree::IndexStats & index_stats,
bool use_skip_indexes,
bool find_exact_ranges);
bool find_exact_ranges,
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
/// Create expression for sampling.
/// Also, calculate _sample_factor if needed.

View File

@ -1,6 +1,6 @@
#pragma once
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/MergeTree/MergeTreeIndices.h>
/// Walking corpse implementation for removed skipping index of type "annoy" and "usearch".
/// Its only purpose is to allow loading old tables with indexes of these types.

View File

@ -398,11 +398,9 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context)
: vector_similarity_condition(query, context)
, metric_kind(metric_kind_)
: metric_kind(metric_kind_)
, expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search])
{
if (expansion_search == 0)
@ -417,31 +415,23 @@ bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexG
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
{
String index_distance_function;
switch (metric_kind)
{
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
default: std::unreachable();
}
return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
return false;
}
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule_) const
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr granule_,
size_t limit,
const std::vector<Float64> & reference_vector) const
{
const UInt64 limit = vector_similarity_condition.getLimit();
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
if (granule == nullptr)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
const USearchIndexWithSerializationPtr index = granule->index;
if (vector_similarity_condition.getDimensions() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) does not match the dimension in the index ({})",
vector_similarity_condition.getDimensions(), index->dimensions());
const std::vector<Float64> reference_vector = vector_similarity_condition.getReferenceVector();
if (reference_vector.size() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the reference vector in the query ({}) does not match the dimension in the index ({})",
reference_vector.size(), index->dimensions());
/// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search).
/// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to
@ -495,9 +485,9 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregato
return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params);
}
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & /*query*/, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, query, metric_kind, context);
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, metric_kind, context);
};
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const

View File

@ -4,7 +4,7 @@
#if USE_USEARCH
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/MergeTree/MergeTreeIndices.h>
#include <Common/Logger.h>
#include <usearch/index_dense.hpp>
@ -133,7 +133,6 @@ class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCond
public:
MergeTreeIndexConditionVectorSimilarity(
const IndexDescription & index_description,
const SelectQueryInfo & query,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context);
@ -141,10 +140,12 @@ public:
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule) const override;
std::vector<UInt64> calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr granule,
size_t limit,
const std::vector<Float64> & reference_vector) const override;
private:
const VectorSimilarityCondition vector_similarity_condition;
const unum::usearch::metric_kind_t metric_kind;
const size_t expansion_search;
};

View File

@ -99,7 +99,10 @@ public:
/// Special method for vector similarity indexes:
/// Returns the row positions of the N nearest neighbors in the index granule
/// The returned row numbers are guaranteed to be sorted and unique.
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr) const
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(
MergeTreeIndexGranulePtr /*granule*/,
size_t /*limit*/,
const std::vector<Float64> & /*reference_vector*/) const
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes");
}

View File

@ -1,354 +0,0 @@
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
namespace DB
{
namespace Setting
{
extern const SettingsUInt64 max_limit_for_ann_queries;
}
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int INCORRECT_QUERY;
}
namespace
{
template <typename Literal>
void extractReferenceVectorFromLiteral(std::vector<Float64> & reference_vector, Literal literal)
{
Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector;
for (const auto & value : literal.value())
{
if (value.tryGet(float_element_of_reference_vector))
reference_vector.emplace_back(float_element_of_reference_vector);
else if (value.tryGet(int_element_of_reference_vector))
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
else
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
}
}
VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(const String & distance_function)
{
if (distance_function == "L2Distance")
return VectorSimilarityCondition::Info::DistanceFunction::L2;
if (distance_function == "cosineDistance")
return VectorSimilarityCondition::Info::DistanceFunction::Cosine;
return VectorSimilarityCondition::Info::DistanceFunction::Unknown;
}
}
VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context)
: block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
, max_limit_for_ann_queries(context->getSettingsRef()[Setting::max_limit_for_ann_queries])
, index_is_useful(checkQueryStructure(query_info))
{}
bool VectorSimilarityCondition::alwaysUnknownOrTrue(const String & distance_function) const
{
if (!index_is_useful)
return true; /// query isn't supported
/// If query is supported, check if distance function of index is the same as distance function in query
return !(stringToDistanceFunction(distance_function) == query_information->distance_function);
}
UInt64 VectorSimilarityCondition::getLimit() const
{
if (index_is_useful && query_information.has_value())
return query_information->limit;
throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");
}
std::vector<Float64> VectorSimilarityCondition::getReferenceVector() const
{
if (index_is_useful && query_information.has_value())
return query_information->reference_vector;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
}
size_t VectorSimilarityCondition::getDimensions() const
{
if (index_is_useful && query_information.has_value())
return query_information->reference_vector.size();
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
}
String VectorSimilarityCondition::getColumnName() const
{
if (index_is_useful && query_information.has_value())
return query_information->column_name;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");
}
VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const
{
if (index_is_useful && query_information.has_value())
return query_information->distance_function;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index.");
}
bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query)
{
Info order_by_info;
/// Build rpns for query sections
const auto & select = query.query->as<ASTSelectQuery &>();
RPN rpn_order_by;
RPNElement rpn_limit;
UInt64 limit;
if (select.limitLength())
traverseAtomAST(select.limitLength(), rpn_limit);
if (select.orderBy())
traverseOrderByAST(select.orderBy(), rpn_order_by);
/// Reverse RPNs for conveniences during parsing
std::reverse(rpn_order_by.begin(), rpn_order_by.end());
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info);
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
if (!limit_is_valid || limit > max_limit_for_ann_queries)
return false;
if (order_by_is_valid)
{
query_information = std::move(order_by_info);
query_information->limit = limit;
return true;
}
return false;
}
void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
/// If the node is ASTFunction, it may have children nodes
if (const auto * func = node->as<ASTFunction>())
{
const ASTs & children = func->arguments->children;
/// Traverse children nodes
for (const auto& child : children)
traverseAST(child, rpn);
}
RPNElement element;
/// Get the data behind node
if (!traverseAtomAST(node, element))
element.function = RPNElement::FUNCTION_UNKNOWN;
rpn.emplace_back(std::move(element));
}
bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
/// Match Functions
if (const auto * function = node->as<ASTFunction>())
{
/// Set the name
out.func_name = function->name;
if (function->name == "L1Distance" ||
function->name == "L2Distance" ||
function->name == "LinfDistance" ||
function->name == "cosineDistance" ||
function->name == "dotProduct")
out.function = RPNElement::FUNCTION_DISTANCE;
else if (function->name == "array")
out.function = RPNElement::FUNCTION_ARRAY;
else if (function->name == "_CAST")
out.function = RPNElement::FUNCTION_CAST;
else
return false;
return true;
}
/// Match identifier
if (const auto * identifier = node->as<ASTIdentifier>())
{
out.function = RPNElement::FUNCTION_IDENTIFIER;
out.identifier.emplace(identifier->name());
out.func_name = "column identifier";
return true;
}
/// Check if we have constants behind the node
return tryCastToConstType(node, out);
}
bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
{
Field const_value;
DataTypePtr const_type;
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
{
/// Check for constant types
if (const_value.getType() == Field::Types::Float64)
{
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
out.float_literal.emplace(const_value.safeGet<Float32>());
out.func_name = "Float literal";
return true;
}
if (const_value.getType() == Field::Types::UInt64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.safeGet<UInt64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::Int64)
{
out.function = RPNElement::FUNCTION_INT_LITERAL;
out.int_literal.emplace(const_value.safeGet<Int64>());
out.func_name = "Int literal";
return true;
}
if (const_value.getType() == Field::Types::Array)
{
out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
out.array_literal = const_value.safeGet<Array>();
out.func_name = "Array literal";
return true;
}
if (const_value.getType() == Field::Types::String)
{
out.function = RPNElement::FUNCTION_STRING_LITERAL;
out.func_name = const_value.safeGet<String>();
return true;
}
}
return false;
}
void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
{
if (const auto * expr_list = node->as<ASTExpressionList>())
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
traverseAST(order_by_element->children.front(), rpn);
}
/// Returns true and stores ANNExpr if the query has valid ORDERBY clause
bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info)
{
/// ORDER BY clause must have at least 3 expressions
if (rpn.size() < 3)
return false;
auto iter = rpn.begin();
auto end = rpn.end();
bool identifier_found = false;
/// Matches DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column]
if (iter->function != RPNElement::FUNCTION_DISTANCE)
return false;
info.distance_function = stringToDistanceFunction(iter->func_name);
++iter;
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
identifier_found = true;
info.column_name = std::move(iter->identifier.value());
++iter;
}
if (iter->function == RPNElement::FUNCTION_ARRAY)
++iter;
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
++iter;
}
/// further conditions are possible if there is no array, or no identifier is found
/// the array can be inside a cast function. For other cases, see the loop after this condition
if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
{
++iter;
/// Cast should be made to array
if (!iter->func_name.starts_with("Array"))
return false;
++iter;
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
++iter;
}
else
return false;
}
while (iter != end)
{
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
if (identifier_found)
return false;
info.column_name = std::move(iter->identifier.value());
identifier_found = true;
}
else
return false;
++iter;
}
/// Final checks of correctness
return identifier_found && !info.reference_vector.empty();
}
/// Returns true and stores Length if we have valid LIMIT clause in query
bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
{
if (rpn.function == RPNElement::FUNCTION_INT_LITERAL)
{
limit = rpn.int_literal.value();
return true;
}
return false;
}
/// Gets float or int from AST node
float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
{
if (iter->float_literal.has_value())
return iter->float_literal.value();
if (iter->int_literal.has_value())
return static_cast<float>(iter->int_literal.value());
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
}
}

View File

@ -1,147 +0,0 @@
#pragma once
#include <Storages/MergeTree/MergeTreeIndices.h>
#include "base/types.h"
#include <optional>
#include <vector>
namespace DB
{
/// Class VectorSimilarityCondition is responsible for recognizing if the query
/// can utilize vector similarity indexes.
///
/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise.
/// It has only one argument, the name of the distance function with which index was built.
/// Only queries with ORDER BY DistanceFunc and LIMIT, i.e.:
///
/// SELECT * FROM * ... ORDER BY DistanceFunc(column, reference_vector) LIMIT count
///
/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5]
class VectorSimilarityCondition
{
public:
VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context);
/// vector similarity queries have a similar structure:
/// - reference vector from which all distances are calculated
/// - distance function, e.g L2Distance
/// - name of column with embeddings
/// - type of query
/// - maximum number of returned elements (LIMIT)
///
/// And one optional parameter:
/// - distance to compare with (only for where queries)
///
/// This struct holds all these components.
struct Info
{
enum class DistanceFunction : uint8_t
{
Unknown,
L2,
Cosine
};
std::vector<Float64> reference_vector;
DistanceFunction distance_function;
String column_name;
UInt64 limit;
float distance = -1.0;
};
/// Returns false if query can be speeded up by an ANN index, true otherwise.
bool alwaysUnknownOrTrue(const String & distance_function) const;
std::vector<Float64> getReferenceVector() const;
size_t getDimensions() const;
String getColumnName() const;
Info::DistanceFunction getDistanceFunction() const;
UInt64 getLimit() const;
private:
struct RPNElement
{
enum Function
{
/// DistanceFunctions
FUNCTION_DISTANCE,
/// array(0.1, ..., 0.1)
FUNCTION_ARRAY,
/// Operators <, >, <=, >=
FUNCTION_COMPARISON,
/// Numeric float value
FUNCTION_FLOAT_LITERAL,
/// Numeric int value
FUNCTION_INT_LITERAL,
/// Column identifier
FUNCTION_IDENTIFIER,
/// Unknown, can be any value
FUNCTION_UNKNOWN,
/// [0.1, ...., 0.1] vector without word 'array'
FUNCTION_LITERAL_ARRAY,
/// if client parameters are used, cast will always be in the query
FUNCTION_CAST,
/// name of type in cast function
FUNCTION_STRING_LITERAL,
};
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
: function(function_)
{}
Function function;
String func_name = "Unknown";
std::optional<float> float_literal;
std::optional<String> identifier;
std::optional<int64_t> int_literal;
std::optional<Array> array_literal;
UInt32 dim = 0;
};
using RPN = std::vector<RPNElement>;
bool checkQueryStructure(const SelectQueryInfo & query);
/// Util functions for the traversal of AST, parses AST and builds rpn
void traverseAST(const ASTPtr & node, RPN & rpn);
/// Return true if we can identify our node type
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
/// Checks if the AST stores ConstType expression
bool tryCastToConstType(const ASTPtr & node, RPNElement & out);
/// Traverses the AST of ORDERBY section
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
/// Returns true and stores ANNExpr if the query has valid ORDERBY section
static bool matchRPNOrderBy(RPN & rpn, Info & info);
/// Returns true and stores Length if we have valid LIMIT clause in query
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
/// Gets float or int from AST node
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
Block block_with_constants;
/// true if we have one of two supported query types
std::optional<Info> query_information;
/// only queries with a lower limit can be considered to avoid memory overflow
const UInt64 max_limit_for_ann_queries;
bool index_is_useful = false;
};
}

View File

@ -2,6 +2,21 @@
-- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes.
SET allow_experimental_vector_similarity_index = 1;
SET enable_analyzer = 0;
CREATE OR REPLACE TABLE tab(id Int32, vec Array(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192;
INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [0.0, 2.0]), (6, [0.0, 2.1]), (7, [0.0, 2.2]), (8, [0.0, 2.3]), (9, [0.0, 2.4]);
SELECT id, vec, L2Distance(vec, [0.0, 2.0])
FROM tab
ORDER BY L2Distance(vec, [0.0, 2.0])
LIMIT 3;
SET allow_experimental_vector_similarity_index = 1;
SET enable_analyzer = 0;