mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-21 23:21:59 +00:00
Merge fa635358d2
into 44b4bd38b9
This commit is contained in:
commit
f659541db2
@ -84,8 +84,6 @@ to load and compare. The library also has several hardware-specific SIMD optimiz
|
|||||||
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
|
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
|
||||||
files, without loading them into RAM.
|
files, without loading them into RAM.
|
||||||
|
|
||||||
USearch indexes are currently experimental. To use them, you first need to `SET allow_experimental_vector_similarity_index = 1`.
|
|
||||||
|
|
||||||
Vector similarity indexes currently support two distance functions:
|
Vector similarity indexes currently support two distance functions:
|
||||||
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
|
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
|
||||||
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
|
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
|
||||||
|
@ -71,6 +71,8 @@ void tryRemoveRedundantSorting(QueryPlan::Node * root);
|
|||||||
/// Remove redundant distinct steps
|
/// Remove redundant distinct steps
|
||||||
size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
|
size_t tryRemoveRedundantDistinct(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
|
||||||
|
|
||||||
|
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & nodes);
|
||||||
|
|
||||||
/// Put some steps under union, so that plan optimization could be applied to union parts separately.
|
/// Put some steps under union, so that plan optimization could be applied to union parts separately.
|
||||||
/// For example, the plan can be rewritten like:
|
/// For example, the plan can be rewritten like:
|
||||||
/// - Something - - Expression - Something -
|
/// - Something - - Expression - Something -
|
||||||
@ -82,7 +84,7 @@ size_t tryAggregatePartitionsIndependently(QueryPlan::Node * node, QueryPlan::No
|
|||||||
|
|
||||||
inline const auto & getOptimizations()
|
inline const auto & getOptimizations()
|
||||||
{
|
{
|
||||||
static const std::array<Optimization, 12> optimizations = {{
|
static const std::array<Optimization, 13> optimizations = {{
|
||||||
{tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join},
|
{tryLiftUpArrayJoin, "liftUpArrayJoin", &QueryPlanOptimizationSettings::lift_up_array_join},
|
||||||
{tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit},
|
{tryPushDownLimit, "pushDownLimit", &QueryPlanOptimizationSettings::push_down_limit},
|
||||||
{trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter},
|
{trySplitFilter, "splitFilter", &QueryPlanOptimizationSettings::split_filter},
|
||||||
@ -95,6 +97,7 @@ inline const auto & getOptimizations()
|
|||||||
{tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union},
|
{tryLiftUpUnion, "liftUpUnion", &QueryPlanOptimizationSettings::lift_up_union},
|
||||||
{tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently},
|
{tryAggregatePartitionsIndependently, "aggregatePartitionsIndependently", &QueryPlanOptimizationSettings::aggregate_partitions_independently},
|
||||||
{tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct},
|
{tryRemoveRedundantDistinct, "removeRedundantDistinct", &QueryPlanOptimizationSettings::remove_redundant_distinct},
|
||||||
|
{tryUseVectorSearch, "useVectorSearch", &QueryPlanOptimizationSettings::use_vector_search},
|
||||||
}};
|
}};
|
||||||
|
|
||||||
return optimizations;
|
return optimizations;
|
||||||
@ -110,7 +113,9 @@ using Stack = std::vector<Frame>;
|
|||||||
|
|
||||||
/// Second pass optimizations
|
/// Second pass optimizations
|
||||||
void optimizePrimaryKeyConditionAndLimit(const Stack & stack);
|
void optimizePrimaryKeyConditionAndLimit(const Stack & stack);
|
||||||
|
void optimizeVectorSearch(const Stack & stack);
|
||||||
void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes);
|
void optimizePrewhere(Stack & stack, QueryPlan::Nodes & nodes);
|
||||||
|
void optimizeVectorSearch(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
|
||||||
void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
|
void optimizeReadInOrder(QueryPlan::Node & node, QueryPlan::Nodes & nodes);
|
||||||
void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
void optimizeAggregationInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
||||||
void optimizeDistinctInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
void optimizeDistinctInOrder(QueryPlan::Node & node, QueryPlan::Nodes &);
|
||||||
|
@ -70,6 +70,9 @@ struct QueryPlanOptimizationSettings
|
|||||||
/// If remove-redundant-distinct-steps optimization is enabled.
|
/// If remove-redundant-distinct-steps optimization is enabled.
|
||||||
bool remove_redundant_distinct = true;
|
bool remove_redundant_distinct = true;
|
||||||
|
|
||||||
|
/// If use-vector-search optimization is enabled, i.e. whether eligible queries may use the vector similarity index.
|
||||||
|
bool use_vector_search = true;
|
||||||
|
|
||||||
bool optimize_prewhere = true;
|
bool optimize_prewhere = true;
|
||||||
|
|
||||||
/// If reading from projection can be applied
|
/// If reading from projection can be applied
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
#include <Processors/QueryPlan/ExpressionStep.h>
|
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||||
#include <Processors/QueryPlan/FilterStep.h>
|
#include <Processors/QueryPlan/FilterStep.h>
|
||||||
#include <Processors/QueryPlan/ITransformingStep.h>
|
#include <Processors/QueryPlan/ITransformingStep.h>
|
||||||
|
#include <Processors/QueryPlan/LimitStep.h>
|
||||||
#include <Processors/QueryPlan/JoinStep.h>
|
#include <Processors/QueryPlan/JoinStep.h>
|
||||||
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
|
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
|
||||||
#include <Processors/QueryPlan/Optimizations/actionsDAGUtils.h>
|
#include <Processors/QueryPlan/Optimizations/actionsDAGUtils.h>
|
||||||
|
@ -126,7 +126,6 @@ void optimizeTreeSecondPass(const QueryPlanOptimizationSettings & optimization_s
|
|||||||
|
|
||||||
if (frame.next_child == 0)
|
if (frame.next_child == 0)
|
||||||
{
|
{
|
||||||
|
|
||||||
if (optimization_settings.read_in_order)
|
if (optimization_settings.read_in_order)
|
||||||
optimizeReadInOrder(*frame.node, nodes);
|
optimizeReadInOrder(*frame.node, nodes);
|
||||||
|
|
||||||
|
@ -0,0 +1,53 @@
|
|||||||
|
#include <Processors/QueryPlan/Optimizations/Optimizations.h>
|
||||||
|
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||||
|
#include <Processors/QueryPlan/FilterStep.h>
|
||||||
|
#include <Processors/QueryPlan/LimitStep.h>
|
||||||
|
#include <Processors/QueryPlan/SourceStepWithFilter.h>
|
||||||
|
|
||||||
|
namespace DB::QueryPlanOptimizations
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Placeholder for a vector search optimization. It is currently a no-op: the whole body is a
/// commented-out draft and `node` is never read.
///
/// NOTE(review): the draft refers to `stack`, which is not a parameter of this function, so it cannot be
/// enabled as-is.
void xx([[maybe_unused]] const QueryPlan::Node & node)
{
    /// Disabled draft, kept for reference:
    ///
    ///     const auto & frame = stack.back();
    ///
    ///     auto * source_step_with_filter = dynamic_cast<SourceStepWithFilter *>(frame.node->step.get());
    ///     if (!source_step_with_filter)
    ///         return;
    ///
    ///     const auto & storage_prewhere_info = source_step_with_filter->getPrewhereInfo();
    ///     if (storage_prewhere_info)
    ///     {
    ///         source_step_with_filter->addFilter(storage_prewhere_info->prewhere_actions.clone(), storage_prewhere_info->prewhere_column_name);
    ///         if (storage_prewhere_info->row_level_filter)
    ///             source_step_with_filter->addFilter(storage_prewhere_info->row_level_filter->clone(), storage_prewhere_info->row_level_column_name);
    ///     }
    ///
    ///     for (auto iter = stack.rbegin() + 1; iter != stack.rend(); ++iter)
    ///     {
    ///         if (auto * filter_step = typeid_cast<FilterStep *>(iter->node->step.get()))
    ///         {
    ///             source_step_with_filter->addFilter(filter_step->getExpression().clone(), filter_step->getFilterColumnName());
    ///         }
    ///         else if (auto * limit_step = typeid_cast<LimitStep *>(iter->node->step.get()))
    ///         {
    ///             source_step_with_filter->setLimit(limit_step->getLimitForSorting());
    ///             break;
    ///         }
    ///         else if (typeid_cast<ExpressionStep *>(iter->node->step.get()))
    ///         {
    ///             /// Note: actually, plan optimizations merge Filter and Expression steps.
    ///             /// Ideally, chain should look like (Expression -> ...) -> (Filter -> ...) -> ReadFromStorage,
    ///             /// So this is likely not needed.
    ///             continue;
    ///         }
    ///         else
    ///         {
    ///             break;
    ///         }
    ///     }
    ///
    ///     source_step_with_filter->applyFilters();
}
|
||||||
|
|
||||||
|
}
|
76
src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp
Normal file
76
src/Processors/QueryPlan/Optimizations/useVectorSearch.cpp
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
#include <memory>
|
||||||
|
#include <Processors/QueryPlan/QueryPlan.h>
|
||||||
|
#include <Processors/QueryPlan/ExpressionStep.h>
|
||||||
|
#include <Processors/QueryPlan/ReadFromMergeTree.h>
|
||||||
|
#include <Processors/QueryPlan/LimitStep.h>
|
||||||
|
#include <Processors/QueryPlan/SortingStep.h>
|
||||||
|
|
||||||
|
namespace DB::QueryPlanOptimizations
|
||||||
|
{
|
||||||
|
|
||||||
|
/// Tries to route a "top-N nearest vectors" query to a vector similarity index.
///
/// The optimization fires only for the plan shape
///     Limit -> Sorting -> Expression -> ReadFromMergeTree
/// (each node having exactly one child) where the sort description consists of a single ascending column
/// whose name starts with a supported distance function (`L2Distance` or `cosineDistance`).
/// On a match, the search parameters are stored in `ReadFromMergeTree::vec_sim_idx_input`.
///
/// @param parent_node  The node to start matching from; expected to hold the Limit step.
/// @return Always 0: the function only annotates the ReadFromMergeTree step, it does not add, remove or
///         replace plan nodes.
size_t tryUseVectorSearch(QueryPlan::Node * parent_node, QueryPlan::Nodes & /* nodes*/)
{
    /// --- Match the plan shape, bailing out at the first mismatch. ---

    auto * limit_step = typeid_cast<LimitStep *>(parent_node->step.get());
    if (!limit_step)
        return 0;

    if (parent_node->children.size() != 1)
        return 0;
    QueryPlan::Node * child_node = parent_node->children.front();
    auto * sorting_step = typeid_cast<SortingStep *>(child_node->step.get());
    if (!sorting_step)
        return 0;

    if (child_node->children.size() != 1)
        return 0;
    child_node = child_node->children.front();
    auto * expression_step = typeid_cast<ExpressionStep *>(child_node->step.get());
    if (!expression_step)
        return 0;

    if (child_node->children.size() != 1)
        return 0;
    child_node = child_node->children.front();
    auto * read_from_mergetree_step = typeid_cast<ReadFromMergeTree *>(child_node->step.get());
    if (!read_from_mergetree_step)
        return 0;

    /// --- Inspect the sort key. ---

    const auto & sort_description = sorting_step->getSortDescription();
    if (sort_description.size() != 1)
        return 0;

    const auto & sort_column = sort_description[0];

    /// The index returns the nearest neighbors, so it can only serve an ascending sort by distance.
    /// A descending sort asks for the farthest vectors and must not be rewritten.
    if (sort_column.direction != 1)
        return 0;

    /// The distance function is recognized by the prefix of the sort column name.
    /// NOTE(review): this is a heuristic - the arguments of the function are not inspected here.
    ReadFromMergeTree::DistanceFunction distance_function;
    if (sort_column.column_name.starts_with("L2Distance"))
        distance_function = ReadFromMergeTree::DistanceFunction::L2Distance;
    else if (sort_column.column_name.starts_with("cosineDistance"))
        distance_function = ReadFromMergeTree::DistanceFunction::cosineDistance;
    else
        return 0;

    /// Asking the index for zero neighbors makes no sense, leave such plans untouched.
    /// NOTE(review): presumably a limit of 0 on the sorting step means "no limit" - confirm.
    const size_t limit = sorting_step->getLimit();
    if (limit == 0)
        return 0;

    /// TODO: extract the real reference vector from the query, the value below is a placeholder.
    const std::vector<Float64> reference_vector = {0.0, 0.2};

    /// TODO check that ReadFromMergeTree has a vector similarity index

    read_from_mergetree_step->vec_sim_idx_input
        = ReadFromMergeTree::VectorSimilarityIndexInput{distance_function, limit, reference_vector};

    /// TODO: the metric of the index should also be compared with `distance_function`.
    /// Sketch of the check (previously part of alwaysUnknownOrTrue):
    ///
    ///     String index_distance_function;
    ///     switch (metric_kind)
    ///     {
    ///         case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
    ///         case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
    ///         default: std::unreachable();
    ///     }
    ///     return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);

    return 0;
}
|
||||||
|
|
||||||
|
}
|
@ -1489,7 +1489,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(Merge
|
|||||||
all_column_names,
|
all_column_names,
|
||||||
log,
|
log,
|
||||||
indexes,
|
indexes,
|
||||||
find_exact_ranges);
|
find_exact_ranges,
|
||||||
|
vec_sim_idx_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void buildIndexes(
|
static void buildIndexes(
|
||||||
@ -1682,7 +1683,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
|
|||||||
const Names & all_column_names,
|
const Names & all_column_names,
|
||||||
LoggerPtr log,
|
LoggerPtr log,
|
||||||
std::optional<Indexes> & indexes,
|
std::optional<Indexes> & indexes,
|
||||||
bool find_exact_ranges)
|
bool find_exact_ranges,
|
||||||
|
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||||
{
|
{
|
||||||
AnalysisResult result;
|
AnalysisResult result;
|
||||||
const auto & settings = context_->getSettingsRef();
|
const auto & settings = context_->getSettingsRef();
|
||||||
@ -1775,7 +1777,8 @@ ReadFromMergeTree::AnalysisResultPtr ReadFromMergeTree::selectRangesToRead(
|
|||||||
num_streams,
|
num_streams,
|
||||||
result.index_stats,
|
result.index_stats,
|
||||||
indexes->use_skip_indexes,
|
indexes->use_skip_indexes,
|
||||||
find_exact_ranges);
|
find_exact_ranges,
|
||||||
|
vec_sim_idx_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t sum_marks_pk = total_marks_pk;
|
size_t sum_marks_pk = total_marks_pk;
|
||||||
|
@ -166,6 +166,19 @@ public:
|
|||||||
std::optional<std::unordered_set<String>> part_values;
|
std::optional<std::unordered_set<String>> part_values;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Distance functions for which a vector similarity search can be requested.
/// The enumerator names mirror the names of the corresponding SQL functions.
enum class DistanceFunction
{
    L2Distance,
    cosineDistance,
};
|
||||||
|
|
||||||
|
struct VectorSimilarityIndexInput
|
||||||
|
{
|
||||||
|
DistanceFunction distance_function;
|
||||||
|
size_t limit;
|
||||||
|
std::vector<Float64> reference_vector;
|
||||||
|
};
|
||||||
|
|
||||||
static AnalysisResultPtr selectRangesToRead(
|
static AnalysisResultPtr selectRangesToRead(
|
||||||
MergeTreeData::DataPartsVector parts,
|
MergeTreeData::DataPartsVector parts,
|
||||||
MergeTreeData::MutationsSnapshotPtr mutations_snapshot,
|
MergeTreeData::MutationsSnapshotPtr mutations_snapshot,
|
||||||
@ -178,7 +191,8 @@ public:
|
|||||||
const Names & all_column_names,
|
const Names & all_column_names,
|
||||||
LoggerPtr log,
|
LoggerPtr log,
|
||||||
std::optional<Indexes> & indexes,
|
std::optional<Indexes> & indexes,
|
||||||
bool find_exact_ranges);
|
bool find_exact_ranges,
|
||||||
|
const std::optional<VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||||
|
|
||||||
AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const;
|
AnalysisResultPtr selectRangesToRead(MergeTreeData::DataPartsVector parts, bool find_exact_ranges = false) const;
|
||||||
|
|
||||||
@ -212,6 +226,8 @@ public:
|
|||||||
|
|
||||||
void applyFilters(ActionDAGNodes added_filter_nodes) override;
|
void applyFilters(ActionDAGNodes added_filter_nodes) override;
|
||||||
|
|
||||||
|
std::optional<VectorSimilarityIndexInput> vec_sim_idx_input;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MergeTreeReaderSettings reader_settings;
|
MergeTreeReaderSettings reader_settings;
|
||||||
|
|
||||||
|
@ -11,7 +11,6 @@
|
|||||||
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
|
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
|
||||||
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
|
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
|
||||||
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
|
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
|
||||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
|
||||||
#include <Storages/ReadInOrderOptimizer.h>
|
#include <Storages/ReadInOrderOptimizer.h>
|
||||||
#include <Storages/VirtualColumnUtils.h>
|
#include <Storages/VirtualColumnUtils.h>
|
||||||
#include <Parsers/ASTIdentifier.h>
|
#include <Parsers/ASTIdentifier.h>
|
||||||
@ -629,7 +628,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
|
|||||||
size_t num_streams,
|
size_t num_streams,
|
||||||
ReadFromMergeTree::IndexStats & index_stats,
|
ReadFromMergeTree::IndexStats & index_stats,
|
||||||
bool use_skip_indexes,
|
bool use_skip_indexes,
|
||||||
bool find_exact_ranges)
|
bool find_exact_ranges,
|
||||||
|
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||||
{
|
{
|
||||||
RangesInDataParts parts_with_ranges;
|
RangesInDataParts parts_with_ranges;
|
||||||
parts_with_ranges.resize(parts.size());
|
parts_with_ranges.resize(parts.size());
|
||||||
@ -734,7 +734,8 @@ RangesInDataParts MergeTreeDataSelectExecutor::filterPartsByPrimaryKeyAndSkipInd
|
|||||||
reader_settings,
|
reader_settings,
|
||||||
mark_cache.get(),
|
mark_cache.get(),
|
||||||
uncompressed_cache.get(),
|
uncompressed_cache.get(),
|
||||||
log);
|
log,
|
||||||
|
vec_sim_idx_input);
|
||||||
|
|
||||||
stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed);
|
stat.granules_dropped.fetch_add(total_granules - ranges.ranges.getNumberOfMarks(), std::memory_order_relaxed);
|
||||||
if (ranges.ranges.empty())
|
if (ranges.ranges.empty())
|
||||||
@ -950,7 +951,8 @@ ReadFromMergeTree::AnalysisResultPtr MergeTreeDataSelectExecutor::estimateNumMar
|
|||||||
column_names_to_return,
|
column_names_to_return,
|
||||||
log,
|
log,
|
||||||
indexes,
|
indexes,
|
||||||
/*find_exact_ranges*/false);
|
/*find_exact_ranges*/false,
|
||||||
|
std::nullopt);
|
||||||
}
|
}
|
||||||
|
|
||||||
QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts(
|
QueryPlanStepPtr MergeTreeDataSelectExecutor::readFromParts(
|
||||||
@ -1364,7 +1366,8 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
|
|||||||
const MergeTreeReaderSettings & reader_settings,
|
const MergeTreeReaderSettings & reader_settings,
|
||||||
MarkCache * mark_cache,
|
MarkCache * mark_cache,
|
||||||
UncompressedCache * uncompressed_cache,
|
UncompressedCache * uncompressed_cache,
|
||||||
LoggerPtr log)
|
LoggerPtr log,
|
||||||
|
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input)
|
||||||
{
|
{
|
||||||
if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName()))
|
if (!index_helper->getDeserializedFormat(part->getDataPartStorage(), index_helper->getFileName()))
|
||||||
{
|
{
|
||||||
@ -1424,9 +1427,9 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
|
|||||||
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
|
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
|
||||||
reader.read(granule);
|
reader.read(granule);
|
||||||
|
|
||||||
if (index_helper->isVectorSimilarityIndex())
|
if (index_helper->isVectorSimilarityIndex() && vec_sim_idx_input)
|
||||||
{
|
{
|
||||||
auto rows = condition->calculateApproximateNearestNeighbors(granule);
|
auto rows = condition->calculateApproximateNearestNeighbors(granule, vec_sim_idx_input->limit, vec_sim_idx_input->reference_vector); /// TODO
|
||||||
|
|
||||||
for (auto row : rows)
|
for (auto row : rows)
|
||||||
{
|
{
|
||||||
|
@ -94,7 +94,8 @@ private:
|
|||||||
const MergeTreeReaderSettings & reader_settings,
|
const MergeTreeReaderSettings & reader_settings,
|
||||||
MarkCache * mark_cache,
|
MarkCache * mark_cache,
|
||||||
UncompressedCache * uncompressed_cache,
|
UncompressedCache * uncompressed_cache,
|
||||||
LoggerPtr log);
|
LoggerPtr log,
|
||||||
|
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||||
|
|
||||||
static MarkRanges filterMarksUsingMergedIndex(
|
static MarkRanges filterMarksUsingMergedIndex(
|
||||||
MergeTreeIndices indices,
|
MergeTreeIndices indices,
|
||||||
@ -197,7 +198,8 @@ public:
|
|||||||
size_t num_streams,
|
size_t num_streams,
|
||||||
ReadFromMergeTree::IndexStats & index_stats,
|
ReadFromMergeTree::IndexStats & index_stats,
|
||||||
bool use_skip_indexes,
|
bool use_skip_indexes,
|
||||||
bool find_exact_ranges);
|
bool find_exact_ranges,
|
||||||
|
const std::optional<ReadFromMergeTree::VectorSimilarityIndexInput> & vec_sim_idx_input);
|
||||||
|
|
||||||
/// Create expression for sampling.
|
/// Create expression for sampling.
|
||||||
/// Also, calculate _sample_factor if needed.
|
/// Also, calculate _sample_factor if needed.
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||||
|
|
||||||
/// Walking corpse implementation for removed skipping index of type "annoy" and "usearch".
|
/// Walking corpse implementation for removed skipping index of type "annoy" and "usearch".
|
||||||
/// Its only purpose is to allow loading old tables with indexes of these types.
|
/// Its only purpose is to allow loading old tables with indexes of these types.
|
||||||
|
@ -398,11 +398,9 @@ void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_
|
|||||||
|
|
||||||
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
|
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
|
||||||
const IndexDescription & /*index_description*/,
|
const IndexDescription & /*index_description*/,
|
||||||
const SelectQueryInfo & query,
|
|
||||||
unum::usearch::metric_kind_t metric_kind_,
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
ContextPtr context)
|
ContextPtr context)
|
||||||
: vector_similarity_condition(query, context)
|
: metric_kind(metric_kind_)
|
||||||
, metric_kind(metric_kind_)
|
|
||||||
, expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search])
|
, expansion_search(context->getSettingsRef()[Setting::hnsw_candidate_list_size_for_search])
|
||||||
{
|
{
|
||||||
if (expansion_search == 0)
|
if (expansion_search == 0)
|
||||||
@ -417,31 +415,23 @@ bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexG
|
|||||||
|
|
||||||
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
|
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
|
||||||
{
|
{
|
||||||
String index_distance_function;
|
return false;
|
||||||
switch (metric_kind)
|
|
||||||
{
|
|
||||||
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
|
|
||||||
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
|
|
||||||
default: std::unreachable();
|
|
||||||
}
|
|
||||||
return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule_) const
|
std::vector<UInt64> MergeTreeIndexConditionVectorSimilarity::calculateApproximateNearestNeighbors(
|
||||||
|
MergeTreeIndexGranulePtr granule_,
|
||||||
|
size_t limit,
|
||||||
|
const std::vector<Float64> & reference_vector) const
|
||||||
{
|
{
|
||||||
const UInt64 limit = vector_similarity_condition.getLimit();
|
|
||||||
|
|
||||||
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
|
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
|
||||||
if (granule == nullptr)
|
if (granule == nullptr)
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
||||||
|
|
||||||
const USearchIndexWithSerializationPtr index = granule->index;
|
const USearchIndexWithSerializationPtr index = granule->index;
|
||||||
|
|
||||||
if (vector_similarity_condition.getDimensions() != index->dimensions())
|
if (reference_vector.size() != index->dimensions())
|
||||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) does not match the dimension in the index ({})",
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the reference vector in the query ({}) does not match the dimension in the index ({})",
|
||||||
vector_similarity_condition.getDimensions(), index->dimensions());
|
reference_vector.size(), index->dimensions());
|
||||||
|
|
||||||
const std::vector<Float64> reference_vector = vector_similarity_condition.getReferenceVector();
|
|
||||||
|
|
||||||
/// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search).
|
/// We want to run the search with the user-provided value for setting hnsw_candidate_list_size_for_search (aka. expansion_search).
|
||||||
/// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to
|
/// The way to do this in USearch is to call index_dense_gt::change_expansion_search. Unfortunately, this introduces a need to
|
||||||
@ -495,9 +485,9 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregato
|
|||||||
return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params);
|
return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(index.name, index.sample_block, metric_kind, scalar_kind, usearch_hnsw_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
|
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & /*query*/, ContextPtr context) const
|
||||||
{
|
{
|
||||||
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, query, metric_kind, context);
|
return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, metric_kind, context);
|
||||||
};
|
};
|
||||||
|
|
||||||
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const
|
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
#if USE_USEARCH
|
#if USE_USEARCH
|
||||||
|
|
||||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
#include <Storages/MergeTree/MergeTreeIndices.h>
|
||||||
#include <Common/Logger.h>
|
#include <Common/Logger.h>
|
||||||
#include <usearch/index_dense.hpp>
|
#include <usearch/index_dense.hpp>
|
||||||
|
|
||||||
@ -133,7 +133,6 @@ class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCond
|
|||||||
public:
|
public:
|
||||||
MergeTreeIndexConditionVectorSimilarity(
|
MergeTreeIndexConditionVectorSimilarity(
|
||||||
const IndexDescription & index_description,
|
const IndexDescription & index_description,
|
||||||
const SelectQueryInfo & query,
|
|
||||||
unum::usearch::metric_kind_t metric_kind_,
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
ContextPtr context);
|
ContextPtr context);
|
||||||
|
|
||||||
@ -141,10 +140,12 @@ public:
|
|||||||
|
|
||||||
bool alwaysUnknownOrTrue() const override;
|
bool alwaysUnknownOrTrue() const override;
|
||||||
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
|
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
|
||||||
std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr granule) const override;
|
std::vector<UInt64> calculateApproximateNearestNeighbors(
|
||||||
|
MergeTreeIndexGranulePtr granule,
|
||||||
|
size_t limit,
|
||||||
|
const std::vector<Float64> & reference_vector) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const VectorSimilarityCondition vector_similarity_condition;
|
|
||||||
const unum::usearch::metric_kind_t metric_kind;
|
const unum::usearch::metric_kind_t metric_kind;
|
||||||
const size_t expansion_search;
|
const size_t expansion_search;
|
||||||
};
|
};
|
||||||
|
@ -99,7 +99,10 @@ public:
|
|||||||
/// Special method for vector similarity indexes:
|
/// Special method for vector similarity indexes:
|
||||||
/// Returns the row positions of the N nearest neighbors in the index granule
|
/// Returns the row positions of the N nearest neighbors in the index granule
|
||||||
/// The returned row numbers are guaranteed to be sorted and unique.
|
/// The returned row numbers are guaranteed to be sorted and unique.
|
||||||
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(MergeTreeIndexGranulePtr) const
|
virtual std::vector<UInt64> calculateApproximateNearestNeighbors(
|
||||||
|
MergeTreeIndexGranulePtr /*granule*/,
|
||||||
|
size_t /*limit*/,
|
||||||
|
const std::vector<Float64> & /*reference_vector*/) const
|
||||||
{
|
{
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes");
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "calculateApproximateNearestNeighbors is not implemented for non-vector-similarity indexes");
|
||||||
}
|
}
|
||||||
|
@ -1,354 +0,0 @@
|
|||||||
#include <Storages/MergeTree/VectorSimilarityCondition.h>
|
|
||||||
|
|
||||||
#include <Core/Settings.h>
|
|
||||||
#include <Interpreters/Context.h>
|
|
||||||
#include <Parsers/ASTFunction.h>
|
|
||||||
#include <Parsers/ASTIdentifier.h>
|
|
||||||
#include <Parsers/ASTLiteral.h>
|
|
||||||
#include <Parsers/ASTOrderByElement.h>
|
|
||||||
#include <Parsers/ASTSelectQuery.h>
|
|
||||||
#include <Parsers/ASTSetQuery.h>
|
|
||||||
#include <Storages/MergeTree/KeyCondition.h>
|
|
||||||
#include <Storages/MergeTree/MergeTreeSettings.h>
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
namespace Setting
|
|
||||||
{
|
|
||||||
extern const SettingsUInt64 max_limit_for_ann_queries;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace ErrorCodes
|
|
||||||
{
|
|
||||||
extern const int LOGICAL_ERROR;
|
|
||||||
extern const int INCORRECT_QUERY;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
|
|
||||||
template <typename Literal>
|
|
||||||
void extractReferenceVectorFromLiteral(std::vector<Float64> & reference_vector, Literal literal)
|
|
||||||
{
|
|
||||||
Float64 float_element_of_reference_vector;
|
|
||||||
Int64 int_element_of_reference_vector;
|
|
||||||
|
|
||||||
for (const auto & value : literal.value())
|
|
||||||
{
|
|
||||||
if (value.tryGet(float_element_of_reference_vector))
|
|
||||||
reference_vector.emplace_back(float_element_of_reference_vector);
|
|
||||||
else if (value.tryGet(int_element_of_reference_vector))
|
|
||||||
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
|
|
||||||
else
|
|
||||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Translates the textual name of a distance function into the corresponding enum value.
/// Only "L2Distance" and "cosineDistance" are recognized, every other name maps to Unknown.
VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(const String & distance_function)
{
    using DistanceFunction = VectorSimilarityCondition::Info::DistanceFunction;

    const bool is_l2 = (distance_function == "L2Distance");
    if (is_l2)
        return DistanceFunction::L2;

    const bool is_cosine = (distance_function == "cosineDistance");
    return is_cosine ? DistanceFunction::Cosine : DistanceFunction::Unknown;
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Analyzes the query once at construction time.
/// checkQueryStructure() reads `block_with_constants` and `max_limit_for_ann_queries` and fills `query_information`,
/// therefore `index_is_useful` has to be the last member that is initialized here.
VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context)
    : block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
    , max_limit_for_ann_queries(context->getSettingsRef()[Setting::max_limit_for_ann_queries])
    , index_is_useful(checkQueryStructure(query_info))
{
}
|
|
||||||
|
|
||||||
bool VectorSimilarityCondition::alwaysUnknownOrTrue(const String & distance_function) const
|
|
||||||
{
|
|
||||||
if (!index_is_useful)
|
|
||||||
return true; /// query isn't supported
|
|
||||||
/// If query is supported, check if distance function of index is the same as distance function in query
|
|
||||||
return !(stringToDistanceFunction(distance_function) == query_information->distance_function);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the LIMIT extracted from the query.
/// Throws LOGICAL_ERROR if the query was not recognized as a vector similarity query.
UInt64 VectorSimilarityCondition::getLimit() const
{
    const bool info_available = index_is_useful && query_information.has_value();
    if (!info_available)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");

    return query_information->limit;
}
|
|
||||||
|
|
||||||
std::vector<Float64> VectorSimilarityCondition::getReferenceVector() const
|
|
||||||
{
|
|
||||||
if (index_is_useful && query_information.has_value())
|
|
||||||
return query_information->reference_vector;
|
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
|
|
||||||
}
|
|
||||||
|
|
||||||
size_t VectorSimilarityCondition::getDimensions() const
|
|
||||||
{
|
|
||||||
if (index_is_useful && query_information.has_value())
|
|
||||||
return query_information->reference_vector.size();
|
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the name of the column that is passed to the distance function in the ORDER BY clause.
/// Throws LOGICAL_ERROR if the query was not recognized as a vector similarity query.
String VectorSimilarityCondition::getColumnName() const
{
    const bool info_available = index_is_useful && query_information.has_value();
    if (!info_available)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");

    return query_information->column_name;
}
|
|
||||||
|
|
||||||
/// Returns the distance function used in the ORDER BY clause of the query.
/// Throws LOGICAL_ERROR if the query was not recognized as a vector similarity query.
VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const
{
    const bool info_available = index_is_useful && query_information.has_value();
    if (!info_available)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index.");

    return query_information->distance_function;
}
|
|
||||||
|
|
||||||
bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query)
|
|
||||||
{
|
|
||||||
Info order_by_info;
|
|
||||||
|
|
||||||
/// Build rpns for query sections
|
|
||||||
const auto & select = query.query->as<ASTSelectQuery &>();
|
|
||||||
|
|
||||||
RPN rpn_order_by;
|
|
||||||
RPNElement rpn_limit;
|
|
||||||
UInt64 limit;
|
|
||||||
|
|
||||||
if (select.limitLength())
|
|
||||||
traverseAtomAST(select.limitLength(), rpn_limit);
|
|
||||||
|
|
||||||
if (select.orderBy())
|
|
||||||
traverseOrderByAST(select.orderBy(), rpn_order_by);
|
|
||||||
|
|
||||||
/// Reverse RPNs for conveniences during parsing
|
|
||||||
std::reverse(rpn_order_by.begin(), rpn_order_by.end());
|
|
||||||
|
|
||||||
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info);
|
|
||||||
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
|
|
||||||
|
|
||||||
if (!limit_is_valid || limit > max_limit_for_ann_queries)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
if (order_by_is_valid)
|
|
||||||
{
|
|
||||||
query_information = std::move(order_by_info);
|
|
||||||
query_information->limit = limit;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn)
|
|
||||||
{
|
|
||||||
/// If the node is ASTFunction, it may have children nodes
|
|
||||||
if (const auto * func = node->as<ASTFunction>())
|
|
||||||
{
|
|
||||||
const ASTs & children = func->arguments->children;
|
|
||||||
/// Traverse children nodes
|
|
||||||
for (const auto& child : children)
|
|
||||||
traverseAST(child, rpn);
|
|
||||||
}
|
|
||||||
|
|
||||||
RPNElement element;
|
|
||||||
/// Get the data behind node
|
|
||||||
if (!traverseAtomAST(node, element))
|
|
||||||
element.function = RPNElement::FUNCTION_UNKNOWN;
|
|
||||||
|
|
||||||
rpn.emplace_back(std::move(element));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
|
|
||||||
{
|
|
||||||
/// Match Functions
|
|
||||||
if (const auto * function = node->as<ASTFunction>())
|
|
||||||
{
|
|
||||||
/// Set the name
|
|
||||||
out.func_name = function->name;
|
|
||||||
|
|
||||||
if (function->name == "L1Distance" ||
|
|
||||||
function->name == "L2Distance" ||
|
|
||||||
function->name == "LinfDistance" ||
|
|
||||||
function->name == "cosineDistance" ||
|
|
||||||
function->name == "dotProduct")
|
|
||||||
out.function = RPNElement::FUNCTION_DISTANCE;
|
|
||||||
else if (function->name == "array")
|
|
||||||
out.function = RPNElement::FUNCTION_ARRAY;
|
|
||||||
else if (function->name == "_CAST")
|
|
||||||
out.function = RPNElement::FUNCTION_CAST;
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
/// Match identifier
|
|
||||||
if (const auto * identifier = node->as<ASTIdentifier>())
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_IDENTIFIER;
|
|
||||||
out.identifier.emplace(identifier->name());
|
|
||||||
out.func_name = "column identifier";
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if we have constants behind the node
|
|
||||||
return tryCastToConstType(node, out);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
|
|
||||||
{
|
|
||||||
Field const_value;
|
|
||||||
DataTypePtr const_type;
|
|
||||||
|
|
||||||
if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
|
|
||||||
{
|
|
||||||
/// Check for constant types
|
|
||||||
if (const_value.getType() == Field::Types::Float64)
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
|
|
||||||
out.float_literal.emplace(const_value.safeGet<Float32>());
|
|
||||||
out.func_name = "Float literal";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (const_value.getType() == Field::Types::UInt64)
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
|
||||||
out.int_literal.emplace(const_value.safeGet<UInt64>());
|
|
||||||
out.func_name = "Int literal";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (const_value.getType() == Field::Types::Int64)
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_INT_LITERAL;
|
|
||||||
out.int_literal.emplace(const_value.safeGet<Int64>());
|
|
||||||
out.func_name = "Int literal";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (const_value.getType() == Field::Types::Array)
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
|
|
||||||
out.array_literal = const_value.safeGet<Array>();
|
|
||||||
out.func_name = "Array literal";
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (const_value.getType() == Field::Types::String)
|
|
||||||
{
|
|
||||||
out.function = RPNElement::FUNCTION_STRING_LITERAL;
|
|
||||||
out.func_name = const_value.safeGet<String>();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
|
|
||||||
{
|
|
||||||
if (const auto * expr_list = node->as<ASTExpressionList>())
|
|
||||||
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
|
|
||||||
traverseAST(order_by_element->children.front(), rpn);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true and stores ANNExpr if the query has valid ORDERBY clause
|
|
||||||
bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info)
|
|
||||||
{
|
|
||||||
/// ORDER BY clause must have at least 3 expressions
|
|
||||||
if (rpn.size() < 3)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
auto iter = rpn.begin();
|
|
||||||
auto end = rpn.end();
|
|
||||||
|
|
||||||
bool identifier_found = false;
|
|
||||||
|
|
||||||
/// Matches DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column]
|
|
||||||
if (iter->function != RPNElement::FUNCTION_DISTANCE)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
info.distance_function = stringToDistanceFunction(iter->func_name);
|
|
||||||
++iter;
|
|
||||||
|
|
||||||
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
|
||||||
{
|
|
||||||
identifier_found = true;
|
|
||||||
info.column_name = std::move(iter->identifier.value());
|
|
||||||
++iter;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (iter->function == RPNElement::FUNCTION_ARRAY)
|
|
||||||
++iter;
|
|
||||||
|
|
||||||
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
|
|
||||||
{
|
|
||||||
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
|
|
||||||
++iter;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// further conditions are possible if there is no array, or no identifier is found
|
|
||||||
/// the array can be inside a cast function. For other cases, see the loop after this condition
|
|
||||||
if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
|
|
||||||
{
|
|
||||||
++iter;
|
|
||||||
/// Cast should be made to array
|
|
||||||
if (!iter->func_name.starts_with("Array"))
|
|
||||||
return false;
|
|
||||||
++iter;
|
|
||||||
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
|
|
||||||
{
|
|
||||||
extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
|
|
||||||
++iter;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
while (iter != end)
|
|
||||||
{
|
|
||||||
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
|
|
||||||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
|
|
||||||
info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
|
|
||||||
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
|
|
||||||
{
|
|
||||||
if (identifier_found)
|
|
||||||
return false;
|
|
||||||
info.column_name = std::move(iter->identifier.value());
|
|
||||||
identifier_found = true;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
return false;
|
|
||||||
|
|
||||||
++iter;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Final checks of correctness
|
|
||||||
return identifier_found && !info.reference_vector.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns true and stores Length if we have valid LIMIT clause in query
|
|
||||||
bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
|
|
||||||
{
|
|
||||||
if (rpn.function == RPNElement::FUNCTION_INT_LITERAL)
|
|
||||||
{
|
|
||||||
limit = rpn.int_literal.value();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Gets float or int from AST node
|
|
||||||
float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
|
|
||||||
{
|
|
||||||
if (iter->float_literal.has_value())
|
|
||||||
return iter->float_literal.value();
|
|
||||||
if (iter->int_literal.has_value())
|
|
||||||
return static_cast<float>(iter->int_literal.value());
|
|
||||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
@ -1,147 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include <Storages/MergeTree/MergeTreeIndices.h>
|
|
||||||
#include "base/types.h"
|
|
||||||
|
|
||||||
#include <optional>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace DB
|
|
||||||
{
|
|
||||||
|
|
||||||
/// Class VectorSimilarityCondition is responsible for recognizing if the query
|
|
||||||
/// can utilize vector similarity indexes.
|
|
||||||
///
|
|
||||||
/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise.
|
|
||||||
/// It has only one argument, the name of the distance function with which index was built.
|
|
||||||
/// Only queries with ORDER BY DistanceFunc and LIMIT, i.e.:
|
|
||||||
///
|
|
||||||
/// SELECT * FROM * ... ORDER BY DistanceFunc(column, reference_vector) LIMIT count
|
|
||||||
///
|
|
||||||
/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5]
|
|
||||||
class VectorSimilarityCondition
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context);
|
|
||||||
|
|
||||||
/// vector similarity queries have a similar structure:
|
|
||||||
/// - reference vector from which all distances are calculated
|
|
||||||
/// - distance function, e.g L2Distance
|
|
||||||
/// - name of column with embeddings
|
|
||||||
/// - type of query
|
|
||||||
/// - maximum number of returned elements (LIMIT)
|
|
||||||
///
|
|
||||||
/// And one optional parameter:
|
|
||||||
/// - distance to compare with (only for where queries)
|
|
||||||
///
|
|
||||||
/// This struct holds all these components.
|
|
||||||
struct Info
|
|
||||||
{
|
|
||||||
enum class DistanceFunction : uint8_t
|
|
||||||
{
|
|
||||||
Unknown,
|
|
||||||
L2,
|
|
||||||
Cosine
|
|
||||||
};
|
|
||||||
|
|
||||||
std::vector<Float64> reference_vector;
|
|
||||||
DistanceFunction distance_function;
|
|
||||||
String column_name;
|
|
||||||
UInt64 limit;
|
|
||||||
float distance = -1.0;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Returns false if query can be speeded up by an ANN index, true otherwise.
|
|
||||||
bool alwaysUnknownOrTrue(const String & distance_function) const;
|
|
||||||
|
|
||||||
std::vector<Float64> getReferenceVector() const;
|
|
||||||
size_t getDimensions() const;
|
|
||||||
String getColumnName() const;
|
|
||||||
Info::DistanceFunction getDistanceFunction() const;
|
|
||||||
UInt64 getLimit() const;
|
|
||||||
|
|
||||||
private:
|
|
||||||
struct RPNElement
|
|
||||||
{
|
|
||||||
enum Function
|
|
||||||
{
|
|
||||||
/// DistanceFunctions
|
|
||||||
FUNCTION_DISTANCE,
|
|
||||||
|
|
||||||
/// array(0.1, ..., 0.1)
|
|
||||||
FUNCTION_ARRAY,
|
|
||||||
|
|
||||||
/// Operators <, >, <=, >=
|
|
||||||
FUNCTION_COMPARISON,
|
|
||||||
|
|
||||||
/// Numeric float value
|
|
||||||
FUNCTION_FLOAT_LITERAL,
|
|
||||||
|
|
||||||
/// Numeric int value
|
|
||||||
FUNCTION_INT_LITERAL,
|
|
||||||
|
|
||||||
/// Column identifier
|
|
||||||
FUNCTION_IDENTIFIER,
|
|
||||||
|
|
||||||
/// Unknown, can be any value
|
|
||||||
FUNCTION_UNKNOWN,
|
|
||||||
|
|
||||||
/// [0.1, ...., 0.1] vector without word 'array'
|
|
||||||
FUNCTION_LITERAL_ARRAY,
|
|
||||||
|
|
||||||
/// if client parameters are used, cast will always be in the query
|
|
||||||
FUNCTION_CAST,
|
|
||||||
|
|
||||||
/// name of type in cast function
|
|
||||||
FUNCTION_STRING_LITERAL,
|
|
||||||
};
|
|
||||||
|
|
||||||
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
|
|
||||||
: function(function_)
|
|
||||||
{}
|
|
||||||
|
|
||||||
Function function;
|
|
||||||
String func_name = "Unknown";
|
|
||||||
|
|
||||||
std::optional<float> float_literal;
|
|
||||||
std::optional<String> identifier;
|
|
||||||
std::optional<int64_t> int_literal;
|
|
||||||
std::optional<Array> array_literal;
|
|
||||||
|
|
||||||
UInt32 dim = 0;
|
|
||||||
};
|
|
||||||
|
|
||||||
using RPN = std::vector<RPNElement>;
|
|
||||||
|
|
||||||
bool checkQueryStructure(const SelectQueryInfo & query);
|
|
||||||
|
|
||||||
/// Util functions for the traversal of AST, parses AST and builds rpn
|
|
||||||
void traverseAST(const ASTPtr & node, RPN & rpn);
|
|
||||||
/// Return true if we can identify our node type
|
|
||||||
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
|
|
||||||
/// Checks if the AST stores ConstType expression
|
|
||||||
bool tryCastToConstType(const ASTPtr & node, RPNElement & out);
|
|
||||||
/// Traverses the AST of ORDERBY section
|
|
||||||
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
|
|
||||||
|
|
||||||
/// Returns true and stores ANNExpr if the query has valid ORDERBY section
|
|
||||||
static bool matchRPNOrderBy(RPN & rpn, Info & info);
|
|
||||||
|
|
||||||
/// Returns true and stores Length if we have valid LIMIT clause in query
|
|
||||||
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
|
|
||||||
|
|
||||||
/// Gets float or int from AST node
|
|
||||||
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
|
|
||||||
|
|
||||||
Block block_with_constants;
|
|
||||||
|
|
||||||
/// true if we have one of two supported query types
|
|
||||||
std::optional<Info> query_information;
|
|
||||||
|
|
||||||
/// only queries with a lower limit can be considered to avoid memory overflow
|
|
||||||
const UInt64 max_limit_for_ann_queries;
|
|
||||||
|
|
||||||
bool index_is_useful = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
}
|
|
@ -2,6 +2,21 @@
|
|||||||
|
|
||||||
-- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes.
|
-- Tests various simple approximate nearest neighborhood (ANN) queries that utilize vector search indexes.
|
||||||
|
|
||||||
|
SET allow_experimental_vector_similarity_index = 1;
|
||||||
|
SET enable_analyzer = 0;
|
||||||
|
|
||||||
|
CREATE OR REPLACE TABLE tab(id Int32, vec Array(Float32), INDEX idx vec TYPE vector_similarity('hnsw', 'L2Distance')) ENGINE = MergeTree ORDER BY id SETTINGS index_granularity = 8192;
|
||||||
|
INSERT INTO tab VALUES (0, [1.0, 0.0]), (1, [1.1, 0.0]), (2, [1.2, 0.0]), (3, [1.3, 0.0]), (4, [1.4, 0.0]), (5, [0.0, 2.0]), (6, [0.0, 2.1]), (7, [0.0, 2.2]), (8, [0.0, 2.3]), (9, [0.0, 2.4]);
|
||||||
|
|
||||||
|
SELECT id, vec, L2Distance(vec, [0.0, 2.0])
|
||||||
|
FROM tab
|
||||||
|
ORDER BY L2Distance(vec, [0.0, 2.0])
|
||||||
|
LIMIT 3;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
SET allow_experimental_vector_similarity_index = 1;
|
SET allow_experimental_vector_similarity_index = 1;
|
||||||
|
|
||||||
SET enable_analyzer = 0;
|
SET enable_analyzer = 0;
|
||||||
|
Loading…
Reference in New Issue
Block a user