mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Merge fa635358d2
into 44b4bd38b9
This commit is contained in:
commit
f659541db2
@ -84,8 +84,6 @@ to load and compare. The library also has several hardware-specific SIMD optimiz
|
||||
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
|
||||
files, without loading them into RAM.
|
||||
|
||||
USearch indexes are currently experimental, to use them you first need to `SET allow_experimental_vector_similarity_index = 1`.
|
||||
|
||||
Vector similarity indexes currently support two distance functions:
|
||||
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
|
||||
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
|
||||
|
@ -71,6 +71,8 @@ void tryRemoveRedundantSorting(QueryPlan::Node * root);
|
||||
/// Remove redundant distinct steps
|
||||
size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
|
||||
|
||||
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
|
||||
|
||||
/// Put some steps under union, so that plan optimization could be applied to union parts separately.
|
||||
/// For example, the plan can be rewritten like:
|
||||
/// - Something - - Expression - Something -
|
||||
@ -82,7 +84,7 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No
|
||||
|
||||
inline const auto & getOptimizations()
|
||||
{
|
||||
static const std::array<Optimization, 12> optimizations = {{
|
||||
static const std::array<Optimization, 13> optimizations = {{
|
||||
{tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join},
|
||||
{tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit},
|
||||
{trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter},
|
||||
@ -95,6 +97,7 @@ inline const auto & getOptimizations()
|
||||
{tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union},
|
||||
{tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently},
|
||||
{tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct},
|
||||
{tryUseVectorSearch, "useVectorSearch", &QueryPlanOptimizationSettings::use_vector_search},
|
||||
}};
|
||||
|
||||
return optimizations;
|
||||
@ -110,7 +113,9 @@ using Stack = std::vector<Frame>;
|
||||
|
||||
/// Second pass optimizations
|
||||
void optimizePrimaryKeyConditionAndLimit(const Stack & stack);
|
||||
void optimizeVectorSearch(const Stack & stack);
|
||||
void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes);
|
||||
void optimizeVectorSearch(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
|
||||
void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
|
||||
void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
||||
void optimizeDistinctInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
||||
|
@ -70,6 +70,9 @@ struct QueryPlanOptimizationSettings
|
||||
/// If remove-redundant-distinct-steps optimization is enabled.
|
||||
bool remove_redundant_distinct = true;
|
||||
|
||||
/// If use vector search is enabled, the query will use the vector similarity index
|
||||
bool use_vector_search = true;
|
||||
|
||||
bool optimize_prewhere = true;
|
||||
|
||||
/// If reading from projection can be applied
|
||||
|
@ -14,6 +14,7 @@
|
||||
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||
#include <Processors/QueryPlan/FilterStep.h>
|
||||
#include <Processors/QueryPlan/ITransformingStep.h>
|
||||
#include <Processors/QueryPlan/LimitStep.h>
|
||||
#include <Processors/QueryPlan/JoinStep.h>
|
||||
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
|
||||
#include <Processors/QueryPlan/Optimizations/actionsDAGUtils.h>
|
||||
|
@ -126,7 +126,6 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s
|
||||
|
||||
if (frame.next_child == 0)
|
||||
{
|
||||
|
||||
if (optimization_settings.read_in_order)
|
||||
optimizeReadInOrder(*frame.node, nodes);
|
||||
|
||||
|
@ -0,0 +1,53 @@
|
||||
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
|
||||
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||
#include <Processors/QueryPlan/FilterStep.h>
|
||||
#include <Processors/QueryPlan/LimitStep.h>
|
||||
#include <Processors/QueryPlan/SourceStepWithFilter.h>
|
||||
|
||||
namespace DB::QueryPlanOptimizations
|
||||
{
|
||||
|
||||
void xx([[maybe_unused]] const QueryPlan::Node & node)
|
||||
{
|
||||
/// const auto & frame = stack.back();
|
||||
///
|
||||
/// auto * source_step_with_filter = dynamic_cast<SourceStepWithFilter *>(frame.node->step.get());
|
||||
/// if (!source_step_with_filter)
|
||||
/// return;
|
||||
///
|
||||
/// const auto & storage_prewhere_info = source_step_with_filter->getPrewhereInfo();
|
||||
/// if (storage_prewhere_info)
|
||||
/// {
|
||||
/// source_step_with_filter->addFilter(storage_prewhere_info->prewhere_actions.clone(), storage_prewhere_info->prewhere_column_name);
|
||||
/// if (storage_prewhere_info->row_level_filter)
|
||||
/// source_step_with_filter->addFilter(storage_prewhere_info->row_level_filter->clone(), storage_prewhere_info->row_level_column_name);
|
||||
/// }
|
||||
///
|
||||
/// for (auto iter = stack.rbegin() + 1; iter != stack.rend(); ++iter)
|
||||
/// {
|
||||
/// if (auto * filter_step = typeid_cast<FilterStep *>(iter->node->step.get()))
|
||||
/// {
|
||||
/// source_step_with_filter->addFilter(filter_step->getExpression().clone(), filter_step->getFilterColumnName());
|
||||
/// }
|
||||
/// else if (auto * limit_step = typeid_cast<LimitStep *>(iter->node->step.get()))
|
||||
/// {
|
||||
/// source_step_with_filter->setLimit(limit_step->getLimitForSorting());
|
||||
/// break;
|
||||
/// }
|
||||
/// else if (typeid_cast<ExpressionStep *>(iter->node->step.get()))
|
||||
/// {
|
||||
/// /// Note: actually, plan optimizations merge Filter and Expression steps.
|
||||
/// /// Ideally, chain should look like (Expression -> ...) -> (Filter -> ...) -> ReadFromStorage,
|
||||
/// /// So this is likely not needed.
|
||||
/// continue;
|
||||
/// }
|
||||
/// else
|
||||
/// {
|
||||
/// break;
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// source_step_with_filter->applyFilters();
|
||||
}
|
||||
|
||||
}
|
76
src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp
Normal file
76
src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp
Normal file
@ -0,0 +1,76 @@
|
||||
#include <memory>
|
||||
#include <Processors/QueryPlan/QueryPlan.h>
|
||||
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||
#include <Processors/QueryPlan/ReadFromMergeTree.h>
|
||||
#include <Processors/QueryPlan/LimitStep.h>
|
||||
#include <Processors/QueryPlan/SortingStep.h>
|
||||
|
||||
namespace DB::QueryPlanOptimizations
|
||||
{
|
||||
|
||||
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & /* nodes*/)
|
||||
{
|
||||
auto * limit_step = typeid_cast<LimitStep *>(parent_node->step.get());
|
||||
if (!limit_step)
|
||||
return 0;
|
||||
|
||||
if (parent_node->children.size() != 1)
|
||||
return 0;
|
||||
QueryPlan::Node * child_node = parent_node->children.front();
|
||||
auto * sorting_step = typeid_cast<SortingStep *>(child_node->step.get());
|
||||
if (!sorting_step)
|
||||
return 0;
|
||||
|
||||
if (child_node->children.size() != 1)
|
||||
return 0;
|
||||
child_node = child_node->children.front();
|
||||
auto * expression_step = typeid_cast<ExpressionStep *>(child_node->step.get());
|
||||
if (!expression_step)
|
||||
return 0;
|
||||
|
||||
if (child_node->children.size() != 1)
|
||||
return 0;
|
||||
child_node = child_node->children.front();
|
||||
auto * read_from_mergetree_step = typeid_cast<ReadFromMergeTree *>(child_node->step.get());
|
||||
if (!read_from_mergetree_step)
|
||||
return 0;
|
||||
|
||||
const auto & sort_description = sorting_step->getSortDescription();
|
||||
|
||||
if (sort_description.size() != 1)
|
||||
return 0;
|
||||
|
||||
[[maybe_unused]] ReadFromMergeTree::DistanceFunction distance_function;
|
||||
|
||||
/// lol
|
||||
if (sort_description[0].column_name.starts_with("L2Distance"))
|
||||
distance_function = ReadFromMergeTree::DistanceFunction::L2Distance;
|
||||
else if (sort_description[0].column_name.starts_with("cosineDistance"))
|
||||
distance_function = ReadFromMergeTree::DistanceFunction::cosineDistance;
|
||||
else
|
||||
return 0;
|
||||
|
||||
[[maybe_unused]] size_t limit = sorting_step->getLimit();
|
||||
[[maybe_unused]] std::vector<Float64> reference_vector = {0.0, 0.2}; /// TODO
|
||||
|
||||
/// TODO check that ReadFromMergeTree has a vector similarity index
|
||||
|
||||
read_from_mergetree_step->vec_sim_idx_input = std::make_optional<ReadFromMergeTree::VectorSimilarityIndexInput>(
|
||||
distance_function, limit, reference_vector);
|
||||
|
||||
/// --- --- ---
|
||||
/// alwaysUnknownOrTrue:
|
||||
///
|
||||
/// String index_distance_function;
|
||||
/// switch (metric_kind)
|
||||
/// {
|
||||
/// case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
|
||||
/// case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
|
||||
/// default: std::unreachable();
|
||||
/// }
|
||||
/// return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
}
|
@ -1489,7 +1489,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(Merge
|
||||
all_column_names,
|
||||
log,
|
||||
indexes,
|
||||
find_exact_ranges);
|
||||
find_exact_ranges,
|
||||
vec_sim_idx_input);
|
||||
}
|
||||
|
||||
static void buildIndexes(
|
||||
@ -1682,7 +1683,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
|
||||
const Names & all_column_names,
|
||||
LoggerPtr log,
|
||||
std::optional<Indexes> & indexes,
|
||||
bool find_exact_ranges)
|
||||
bool find_exact_ranges,
|
||||
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||
{
|
||||
AnalysisResult result;
|
||||
const auto & settings = context_->getSettingsRef();
|
||||
@ -1775,7 +1777,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
|
||||
num_streams,
|
||||
result.index_stats,
|
||||
indexes->use_skip_indexes,
|
||||
find_exact_ranges);
|
||||
find_exact_ranges,
|
||||
vec_sim_idx_input);
|
||||
}
|
||||
|
||||
size_t sum_marks_pk = total_marks_pk;
|
||||
|
@ -166,6 +166,19 @@ public:
|
||||
std::optional<std::unordered_set<String>> part_values;
|
||||
};
|
||||
|
||||
enum class DistanceFunction
|
||||
{
|
||||
L2Distance,
|
||||
cosineDistance
|
||||
};
|
||||
|
||||
struct VectorSimilarityIndexInput
|
||||
{
|
||||
DistanceFunction distance_function;
|
||||
size_t limit;
|
||||
std::vector<Float64> reference_vector;
|
||||
};
|
||||
|
||||
static AnalysisResultPtr selectRangesToRead(
|
||||
MergeTreeData::DataPartsVector parts,
|
||||
MergeTreeData::MutationsSnapshotPtr mutations_snapshot,
|
||||
@ -178,7 +191,8 @@ public:
|
||||
const Names & all_column_names,
|
||||
LoggerPtr log,
|
||||
std::optional<Indexes> & indexes,
|
||||
bool find_exact_ranges);
|
||||
bool find_exact_ranges,
|
||||
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||
|
||||
AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const;
|
||||
|
||||
@ -212,6 +226,8 @@ public:
|
||||
|
||||
void applyFilters(ActionDAGNodes added_filter_nodes) override;
|
||||
|
||||
std::optional<VectorSimilarityIndexInput> vec_sim_idx_input;
|
||||
|
||||
private:
|
||||
MergeTreeReaderSettings reader_settings;
|
||||
|
||||
|
@ -11,7 +11,6 @@
|
||||
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
|
||||
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
|
||||
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
|
||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
||||
#include <Storages/ReadInOrderOptimizer.h>
|
||||
#include <Storages/VirtualColumnUtils.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
@ -629,7 +628,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
|
||||
size_t num_streams,
|
||||
ReadFromMergeTree::IndexStats & index_stats,
|
||||
bool use_skip_indexes,
|
||||
bool find_exact_ranges)
|
||||
bool find_exact_ranges,
|
||||
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||
{
|
||||
RangesInDataParts parts_with_ranges;
|
||||
parts_with_ranges.resize(parts.size());
|
||||
@ -734,7 +734,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
|
||||
reader_settings,
|
||||
mark_cache.get(),
|
||||
uncompressed_cache.get(),
|
||||
log);
|
||||
log,
|
||||
vec_sim_idx_input);
|
||||
|
||||
stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed);
|
||||
if (ranges.ranges.empty())
|
||||
@ -950,7 +951,8 @@ ReadFromMergeTree::AnalysisResultPtr MergeTreeDataSelectExecutor::estimateNumMar
|
||||
column_names_to_return,
|
||||
log,
|
||||
indexes,
|
||||
/*find_exact_ranges*/false);
|
||||
/*find_exact_ranges*/false,
|
||||
std::nullopt);
|
||||
}
|
||||
|
||||
QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts(
|
||||
@ -1364,7 +1366,8 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
|
||||
const MergeTreeReaderSettings & reader_settings,
|
||||
MarkCache * mark_cache,
|
||||
UncompressedCache * uncompressed_cache,
|
||||
LoggerPtr log)
|
||||
LoggerPtr log,
|
||||
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||
{
|
||||
if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName()))
|
||||
{
|
||||
@ -1424,9 +1427,9 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
|
||||
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
|
||||
reader.read(granule);
|
||||
|
||||
if (index_helper->isVectorSimilarityIndex())
|
||||
if (index_helper->isVectorSimilarityIndex() && vec_sim_idx_input)
|
||||
{
|
||||
auto rows = condition->calculateApproximateNearestNeighbors(granule);
|
||||
auto rows = condition->calculateApproximateNearestNeighbors(granule, vec_sim_idx_input->limit, vec_sim_idx_input->reference_vector); /// TODO
|
||||
|
||||
for (auto row : rows)
|
||||
{
|
||||
|
@ -94,7 +94,8 @@ private:
|
||||
const MergeTreeReaderSettings & reader_settings,
|
||||
MarkCache * mark_cache,
|
||||
UncompressedCache * uncompressed_cache,
|
||||
LoggerPtr log);
|
||||
LoggerPtr log,
|
||||
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||
|
||||
static MarkRanges filterMarksUsingMergedIndex(
|
||||
MergeTreeIndices indices,
|
||||
@ -197,7 +198,8 @@ public:
|
||||
size_t num_streams,
|
||||
ReadFromMergeTree::IndexStats & index_stats,
|
||||
bool use_skip_indexes,
|
||||
bool find_exact_ranges);
|
||||
bool find_exact_ranges,
|
||||
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||
|
||||
/// Create expression for sampling.
|
||||
/// Also, calculate _sample_factor if needed.
|
||||
|
@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
||||
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||
|
||||
/// Walking corpse implementation for removed skipping index of type "annoy" and "usearch".
|
||||
/// Its only purpose is to allow loading old tables with indexes of these types.
|
||||
|
@ -398,11 +398,9 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
|
||||
|
||||
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
|
||||
const IndexDescription & /*index_description*/,
|
||||
const SelectQueryInfo & query,
|
||||
unum::usearch::metric_kind_t metric_kind_,
|
||||
ContextPtr context)
|
||||
: vector_similarity_condition(query, context)
|
||||
, metric_kind(metric_kind_)
|
||||
: metric_kind(metric_kind_)
|
||||
, expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search])
|
||||
{
|
||||
if (expansion_search == 0)
|
||||
@ -417,31 +415,23 @@ bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexG
|
||||
|
||||
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
|
||||
{
|
||||
String index_distance_function;
|
||||
switch (metric_kind)
|
||||
{
|
||||
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
|
||||
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
|
||||
default: std::unreachable();
|
||||
}
|
||||
return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule_) const
|
||||
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(
|
||||
MergeTreeIndexGranulePtr granule_,
|
||||
size_t limit,
|
||||
const std::vector<Float64> & reference_vector) const
|
||||
{
|
||||
const UInt64 limit = vector_similarity_condition.getLimit();
|
||||
|
||||
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
|
||||
if (granule == nullptr)
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
||||
|
||||
const USearchIndexWithSerializationPtr index = granule->index;
|
||||
|
||||
if (vector_similarity_condition.getDimensions() != index->dimensions())
|
||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) does not match the dimension in the index ({})",
|
||||
vector_similarity_condition.getDimensions(), index->dimensions());
|
||||
|
||||
const std::vector<Float64> reference_vector = vector_similarity_condition.getReferenceVector();
|
||||
if (reference_vector.size() != index->dimensions())
|
||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the reference vector in the query ({}) does not match the dimension in the index ({})",
|
||||
reference_vector.size(), index->dimensions());
|
||||
|
||||
/// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search).
|
||||
/// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to
|
||||
@ -495,9 +485,9 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregato
|
||||
return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params);
|
||||
}
|
||||
|
||||
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
|
||||
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & /*query*/, ContextPtr context) const
|
||||
{
|
||||
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, query, metric_kind, context);
|
||||
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, metric_kind, context);
|
||||
};
|
||||
|
||||
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#if USE_USEARCH
|
||||
|
||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
||||
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||
#include <Common/Logger.h>
|
||||
#include <usearch/index_dense.hpp>
|
||||
|
||||
@ -133,7 +133,6 @@ class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCond
|
||||
public:
|
||||
MergeTreeIndexConditionVectorSimilarity(
|
||||
const IndexDescription & index_description,
|
||||
const SelectQueryInfo & query,
|
||||
unum::usearch::metric_kind_t metric_kind_,
|
||||
ContextPtr context);
|
||||
|
||||
@ -141,10 +140,12 @@ public:
|
||||
|
||||
bool alwaysUnknownOrTrue() const override;
|
||||
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
|
||||
std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule) const override;
|
||||
std::vector<UInt64> calculateApproximateNearestNeighbors(
|
||||
MergeTreeIndexGranulePtr granule,
|
||||
size_t limit,
|
||||
const std::vector<Float64> & reference_vector) const override;
|
||||
|
||||
private:
|
||||
const VectorSimilarityCondition vector_similarity_condition;
|
||||
const unum::usearch::metric_kind_t metric_kind;
|
||||
const size_t expansion_search;
|
||||
};
|
||||
|
@ -99,7 +99,10 @@ public:
|
||||
/// Special method for vector similarity indexes:
|
||||
/// Returns the row positions of the N nearest neighbors in the index granule
|
||||
/// The returned row numbers are guaranteed to be sorted and unique.
|
||||
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr) const
|
||||
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(
|
||||
MergeTreeIndexGranulePtr /*granule*/,
|
||||
size_t /*limit*/,
|
||||
const std::vector<Float64> & /*reference_vector*/) const
|
||||
{
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes");
|
||||
}
|
||||
|
@ -1,354 +0,0 @@
|
||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
||||
|
||||
#include <Core/Settings.h>
|
||||
#include <Interpreters/Context.h>
|
||||
#include <Parsers/ASTFunction.h>
|
||||
#include <Parsers/ASTIdentifier.h>
|
||||
#include <Parsers/ASTLiteral.h>
|
||||
#include <Parsers/ASTOrderByElement.h>
|
||||
#include <Parsers/ASTSelectQuery.h>
|
||||
#include <Parsers/ASTSetQuery.h>
|
||||
#include <Storages/MergeTree/KeyCondition.h>
|
||||
#include <Storages/MergeTree/MergeTreeSettings.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace Setting
|
||||
{
|
||||
extern const SettingsUInt64 max_limit_for_ann_queries;
|
||||
}
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int INCORRECT_QUERY;
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
template <typename Literal>
|
||||
void extractReferenceVectorFromLiteral(std::vector<Float64> & reference_vector, Literal literal)
|
||||
{
|
||||
Float64 float_element_of_reference_vector;
|
||||
Int64 int_element_of_reference_vector;
|
||||
|
||||
for (const auto & value : literal.value())
|
||||
{
|
||||
if (value.tryGet(float_element_of_reference_vector))
|
||||
reference_vector.emplace_back(float_element_of_reference_vector);
|
||||
else if (value.tryGet(int_element_of_reference_vector))
|
||||
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
|
||||
else
|
||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
|
||||
}
|
||||
}
|
||||
|
||||
VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(const String & distance_function)
|
||||
{
|
||||
if (distance_function == "L2Distance")
|
||||
return VectorSimilarityCondition::Info::DistanceFunction::L2;
|
||||
if (distance_function == "cosineDistance")
|
||||
return VectorSimilarityCondition::Info::DistanceFunction::Cosine;
|
||||
return VectorSimilarityCondition::Info::DistanceFunction::Unknown;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context)
|
||||
: block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
|
||||
, max_limit_for_ann_queries(context->getSettingsRef()[Setting::max_limit_for_ann_queries])
|
||||
, index_is_useful(checkQueryStructure(query_info))
|
||||
{}
|
||||
|
||||
bool VectorSimilarityCondition::alwaysUnknownOrTrue(const String & distance_function) const
|
||||
{
|
||||
if (!index_is_useful)
|
||||
return true; /// query isn't supported
|
||||
/// If query is supported, check if distance function of index is the same as distance function in query
|
||||
return !(stringToDistanceFunction(distance_function) == query_information->distance_function);
|
||||
}
|
||||
|
||||
UInt64 VectorSimilarityCondition::getLimit() const
|
||||
{
|
||||
if (index_is_useful && query_information.has_value())
|
||||
return query_information->limit;
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");
|
||||
}
|
||||
|
||||
std::vector<Float64> VectorSimilarityCondition::getReferenceVector() const
|
||||
{
|
||||
if (index_is_useful && query_information.has_value())
|
||||
return query_information->reference_vector;
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
|
||||
}
|
||||
|
||||
size_t VectorSimilarityCondition::getDimensions() const
|
||||
{
|
||||
if (index_is_useful && query_information.has_value())
|
||||
return query_information->reference_vector.size();
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
|
||||
}
|
||||
|
||||
String VectorSimilarityCondition::getColumnName() const
|
||||
{
|
||||
if (index_is_useful && query_information.has_value())
|
||||
return query_information->column_name;
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");
|
||||
}
|
||||
|
||||
VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const
|
||||
{
|
||||
if (index_is_useful && query_information.has_value())
|
||||
return query_information->distance_function;
|
||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index.");
|
||||
}
|
||||
|
||||
bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query)
|
||||
{
|
||||
Info order_by_info;
|
||||
|
||||
/// Build rpns for query sections
|
||||
const auto & select = query.query->as<ASTSelectQuery &>();
|
||||
|
||||
RPN rpn_order_by;
|
||||
RPNElement rpn_limit;
|
||||
UInt64 limit;
|
||||
|
||||
if (select.limitLength())
|
||||
traverseAtomAST(select.limitLength(), rpn_limit);
|
||||
|
||||
if (select.orderBy())
|
||||
traverseOrderByAST(select.orderBy(), rpn_order_by);
|
||||
|
||||
/// Reverse RPNs for conveniences during parsing
|
||||
std::reverse(rpn_order_by.begin(), rpn_order_by.end());
|
||||
|
||||
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info);
|
||||
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
|
||||
|
||||
if (!limit_is_valid || limit > max_limit_for_ann_queries)
|
||||
return false;
|
||||
|
||||
if (order_by_is_valid)
|
||||
{
|
||||
query_information = std::move(order_by_info);
|
||||
query_information->limit = limit;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn)
|
||||
{
|
||||
/// If the node is ASTFunction, it may have children nodes
|
||||
if (const auto * func = node->as<ASTFunction>())
|
||||
{
|
||||
const ASTs & children = func->arguments->children;
|
||||
/// Traverse children nodes
|
||||
for (const auto& child : children)
|
||||
traverseAST(child, rpn);
|
||||
}
|
||||
|
||||
RPNElement element;
|
||||
/// Get the data behind node
|
||||
if (!traverseAtomAST(node, element))
|
||||
element.function = RPNElement::FUNCTION_UNKNOWN;
|
||||
|
||||
rpn.emplace_back(std::move(element));
|
||||
}
|
||||
|
||||
bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
|
||||
{
|
||||
/// Match Functions
|
||||
if (const auto * function = node->as<ASTFunction>())
|
||||
{
|
||||
/// Set the name
|
||||
out.func_name = function->name;
|
||||
|
||||
if (function->name == "L1Distance" ||
|
||||
function->name == "L2Distance" ||
|
||||
function->name == "LinfDistance" ||
|
||||
function->name == "cosineDistance" ||
|
||||
function->name == "dotProduct")
|
||||
out.function = RPNElement::FUNCTION_DISTANCE;
|
||||
else if (function->name == "array")
|
||||
out.function = RPNElement::FUNCTION_ARRAY;
|
||||
else if (function->name == "_CAST")
|
||||
out.function = RPNElement::FUNCTION_CAST;
|
||||
else
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
/// Match identifier
|
||||
if (const auto * identifier = node->as<ASTIdentifier>())
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_IDENTIFIER;
|
||||
out.identifier.emplace(identifier->name());
|
||||
out.func_name = "column identifier";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Check if we have constants behind the node
|
||||
return tryCastToConstType(node, out);
|
||||
}
|
||||
|
||||
bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
|
||||
{
|
||||
Field const_value;
|
||||
DataTypePtr const_type;
|
||||
|
||||
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
|
||||
{
|
||||
/// Check for constant types
|
||||
if (const_value.getType() == Field::Types::Float64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
|
||||
out.float_literal.emplace(const_value.safeGet<Float32>());
|
||||
out.func_name = "Float literal";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (const_value.getType() == Field::Types::UInt64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
||||
out.int_literal.emplace(const_value.safeGet<UInt64>());
|
||||
out.func_name = "Int literal";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (const_value.getType() == Field::Types::Int64)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
||||
out.int_literal.emplace(const_value.safeGet<Int64>());
|
||||
out.func_name = "Int literal";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (const_value.getType() == Field::Types::Array)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
|
||||
out.array_literal = const_value.safeGet<Array>();
|
||||
out.func_name = "Array literal";
|
||||
return true;
|
||||
}
|
||||
|
||||
if (const_value.getType() == Field::Types::String)
|
||||
{
|
||||
out.function = RPNElement::FUNCTION_STRING_LITERAL;
|
||||
out.func_name = const_value.safeGet<String>();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
|
||||
{
|
||||
if (const auto * expr_list = node->as<ASTExpressionList>())
|
||||
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
|
||||
traverseAST(order_by_element->children.front(), rpn);
|
||||
}
|
||||
|
||||
/// Returns true and stores ANNExpr if the query has valid ORDERBY clause
|
||||
bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info)
|
||||
{
|
||||
/// ORDER BY clause must have at least 3 expressions
|
||||
if (rpn.size() < 3)
|
||||
return false;
|
||||
|
||||
auto iter = rpn.begin();
|
||||
auto end = rpn.end();
|
||||
|
||||
bool identifier_found = false;
|
||||
|
||||
/// Matches DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column]
|
||||
if (iter->function != RPNElement::FUNCTION_DISTANCE)
|
||||
return false;
|
||||
|
||||
info.distance_function = stringToDistanceFunction(iter->func_name);
|
||||
++iter;
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
||||
{
|
||||
identifier_found = true;
|
||||
info.column_name = std::move(iter->identifier.value());
|
||||
++iter;
|
||||
}
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_ARRAY)
|
||||
++iter;
|
||||
|
||||
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
|
||||
{
|
||||
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
|
||||
++iter;
|
||||
}
|
||||
|
||||
/// further conditions are possible if there is no array, or no identifier is found
|
||||
/// the array can be inside a cast function. For other cases, see the loop after this condition
|
||||
if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
|
||||
{
|
||||
++iter;
|
||||
/// Cast should be made to array
|
||||
if (!iter->func_name.starts_with("Array"))
|
||||
return false;
|
||||
++iter;
|
||||
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
|
||||
{
|
||||
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
|
||||
++iter;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
while (iter != end)
|
||||
{
|
||||
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
|
||||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
|
||||
info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
|
||||
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
||||
{
|
||||
if (identifier_found)
|
||||
return false;
|
||||
info.column_name = std::move(iter->identifier.value());
|
||||
identifier_found = true;
|
||||
}
|
||||
else
|
||||
return false;
|
||||
|
||||
++iter;
|
||||
}
|
||||
|
||||
/// Final checks of correctness
|
||||
return identifier_found && !info.reference_vector.empty();
|
||||
}
|
||||
|
||||
/// Returns true and stores Length if we have valid LIMIT clause in query
|
||||
bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
|
||||
{
|
||||
if (rpn.function == RPNElement::FUNCTION_INT_LITERAL)
|
||||
{
|
||||
limit = rpn.int_literal.value();
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Gets float or int from AST node
|
||||
float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
|
||||
{
|
||||
if (iter->float_literal.has_value())
|
||||
return iter->float_literal.value();
|
||||
if (iter->int_literal.has_value())
|
||||
return static_cast<float>(iter->int_literal.value());
|
||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
|
||||
}
|
||||
|
||||
}
|
@ -1,147 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||
#include "base/types.h"
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Class VectorSimilarityCondition is responsible for recognizing if the query
|
||||
/// can utilize vector similarity indexes.
|
||||
///
|
||||
/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise.
|
||||
/// It has only one argument, the name of the distance function with which index was built.
|
||||
/// Only queries with ORDER BY DistanceFunc and LIMIT, i.e.:
|
||||
///
|
||||
/// SELECT * FROM * ... ORDER BY DistanceFunc(column, reference_vector) LIMIT count
|
||||
///
|
||||
/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5]
|
||||
class VectorSimilarityCondition
|
||||
{
|
||||
public:
|
||||
VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context);
|
||||
|
||||
/// vector similarity queries have a similar structure:
|
||||
/// - reference vector from which all distances are calculated
|
||||
/// - distance function, e.g L2Distance
|
||||
/// - name of column with embeddings
|
||||
/// - type of query
|
||||
/// - maximum number of returned elements (LIMIT)
|
||||
///
|
||||
/// And one optional parameter:
|
||||
/// - distance to compare with (only for where queries)
|
||||
///
|
||||
/// This struct holds all these components.
|
||||
struct Info
|
||||
{
|
||||
enum class DistanceFunction : uint8_t
|
||||
{
|
||||
Unknown,
|
||||
L2,
|
||||
Cosine
|
||||
};
|
||||
|
||||
std::vector<Float64> reference_vector;
|
||||
DistanceFunction distance_function;
|
||||
String column_name;
|
||||
UInt64 limit;
|
||||
float distance = -1.0;
|
||||
};
|
||||
|
||||
/// Returns false if query can be speeded up by an ANN index, true otherwise.
|
||||
bool alwaysUnknownOrTrue(const String & distance_function) const;
|
||||
|
||||
std::vector<Float64> getReferenceVector() const;
|
||||
size_t getDimensions() const;
|
||||
String getColumnName() const;
|
||||
Info::DistanceFunction getDistanceFunction() const;
|
||||
UInt64 getLimit() const;
|
||||
|
||||
private:
|
||||
struct RPNElement
|
||||
{
|
||||
enum Function
|
||||
{
|
||||
/// DistanceFunctions
|
||||
FUNCTION_DISTANCE,
|
||||
|
||||
/// array(0.1, ..., 0.1)
|
||||
FUNCTION_ARRAY,
|
||||
|
||||
/// Operators <, >, <=, >=
|
||||
FUNCTION_COMPARISON,
|
||||
|
||||
/// Numeric float value
|
||||
FUNCTION_FLOAT_LITERAL,
|
||||
|
||||
/// Numeric int value
|
||||
FUNCTION_INT_LITERAL,
|
||||
|
||||
/// Column identifier
|
||||
FUNCTION_IDENTIFIER,
|
||||
|
||||
/// Unknown, can be any value
|
||||
FUNCTION_UNKNOWN,
|
||||
|
||||
/// [0.1, ...., 0.1] vector without word 'array'
|
||||
FUNCTION_LITERAL_ARRAY,
|
||||
|
||||
/// if client parameters are used, cast will always be in the query
|
||||
FUNCTION_CAST,
|
||||
|
||||
/// name of type in cast function
|
||||
FUNCTION_STRING_LITERAL,
|
||||
};
|
||||
|
||||
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
|
||||
: function(function_)
|
||||
{}
|
||||
|
||||
Function function;
|
||||
String func_name = "Unknown";
|
||||
|
||||
std::optional<float> float_literal;
|
||||
std::optional<String> identifier;
|
||||
std::optional<int64_t> int_literal;
|
||||
std::optional<Array> array_literal;
|
||||
|
||||
UInt32 dim = 0;
|
||||
};
|
||||
|
||||
using RPN = std::vector<RPNElement>;
|
||||
|
||||
bool checkQueryStructure(const SelectQueryInfo & query);
|
||||
|
||||
/// Util functions for the traversal of AST, parses AST and builds rpn
|
||||
void traverseAST(const ASTPtr & node, RPN & rpn);
|
||||
/// Return true if we can identify our node type
|
||||
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
|
||||
/// Checks if the AST stores ConstType expression
|
||||
bool tryCastToConstType(const ASTPtr & node, RPNElement & out);
|
||||
/// Traverses the AST of ORDERBY section
|
||||
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
|
||||
|
||||
/// Returns true and stores ANNExpr if the query has valid ORDERBY section
|
||||
static bool matchRPNOrderBy(RPN & rpn, Info & info);
|
||||
|
||||
/// Returns true and stores Length if we have valid LIMIT clause in query
|
||||
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
|
||||
|
||||
/// Gets float or int from AST node
|
||||
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
|
||||
|
||||
Block block_with_constants;
|
||||
|
||||
/// true if we have one of two supported query types
|
||||
std::optional<Info> query_information;
|
||||
|
||||
/// only queries with a lower limit can be considered to avoid memory overflow
|
||||
const UInt64 max_limit_for_ann_queries;
|
||||
|
||||
bool index_is_useful = false;
|
||||
};
|
||||
|
||||
}
|
@ -2,6 +2,21 @@
|
||||
|
||||
-- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes.
|
||||
|
||||
SET allow_experimental_vector_similarity_index = 1;
|
||||
SET enable_analyzer = 0;
|
||||
|
||||
CREATE OR REPLACE TABLE tab(id Int32, vec Array(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192;
|
||||
INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [0.0, 2.0]), (6, [0.0, 2.1]), (7, [0.0, 2.2]), (8, [0.0, 2.3]), (9, [0.0, 2.4]);
|
||||
|
||||
SELECT id, vec, L2Distance(vec, [0.0, 2.0])
|
||||
FROM tab
|
||||
ORDER BY L2Distance(vec, [0.0, 2.0])
|
||||
LIMIT 3;
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
SET allow_experimental_vector_similarity_index = 1;
|
||||
|
||||
SET enable_analyzer = 0;
|
||||
|
Loading…
Reference in New Issue
Block a user