Cosmetics: target vector --> reference vector

This commit is contained in:
Robert Schulze 2023-05-25 21:42:46 +00:00
parent 567d54a268
commit 6fe208832d
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
3 changed files with 38 additions and 38 deletions

View File

@ -25,19 +25,19 @@ namespace
{
template <typename Literal>
void extractTargetVectorFromLiteral(ApproximateNearestNeighborInformation::Embedding & target, Literal literal)
void extraceReferenceVectorFromLiteral(ApproximateNearestNeighborInformation::Embedding & reference_vector, Literal literal)
{
Float64 float_element_of_target_vector;
Int64 int_element_of_target_vector;
Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector;
for (const auto & value : literal.value())
{
if (value.tryGet(float_element_of_target_vector))
target.emplace_back(float_element_of_target_vector);
else if (value.tryGet(int_element_of_target_vector))
target.emplace_back(static_cast<float>(int_element_of_target_vector));
if (value.tryGet(float_element_of_reference_vector))
reference_vector.emplace_back(float_element_of_reference_vector);
else if (value.tryGet(int_element_of_reference_vector))
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
else
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in target vector. Only float or int are supported.");
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
}
}
@ -82,17 +82,17 @@ UInt64 ApproximateNearestNeighborCondition::getLimit() const
throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");
}
std::vector<float> ApproximateNearestNeighborCondition::getTargetVector() const
std::vector<float> ApproximateNearestNeighborCondition::getReferenceVector() const
{
if (index_is_useful && query_information.has_value())
return query_information->target;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Target vector was requested for useless or uninitialized index.");
return query_information->reference_vector;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
}
size_t ApproximateNearestNeighborCondition::getNumOfDimensions() const
{
if (index_is_useful && query_information.has_value())
return query_information->target.size();
return query_information->reference_vector.size();
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
}
@ -327,7 +327,7 @@ bool ApproximateNearestNeighborCondition::matchRPNWhere(RPN & rpn, ApproximateNe
ann_info.query_type = ApproximateNearestNeighborInformation::Type::Where;
// WHERE section must have at least 5 expressions
// Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(TargetVector(floats))
// Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(ReferenceVector(floats))
if (rpn.size() < 5)
return false;
@ -363,12 +363,12 @@ bool ApproximateNearestNeighborCondition::matchRPNWhere(RPN & rpn, ApproximateNe
if (greater_case)
{
if (ann_info.target.size() < 2)
if (ann_info.reference_vector.size() < 2)
return false;
ann_info.distance = ann_info.target.back();
ann_info.distance = ann_info.reference_vector.back();
if (ann_info.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", ann_info.distance);
ann_info.target.pop_back();
ann_info.reference_vector.pop_back();
}
// query is ok
@ -403,12 +403,12 @@ bool ApproximateNearestNeighborCondition::matchRPNLimit(RPNElement & rpn, UInt64
return false;
}
/* Matches dist function, target vector, column name */
/* Matches dist function, reference vector, column name */
bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info)
{
bool identifier_found = false;
// Matches DistanceFunc->[Column]->[Tuple(array)Func]->TargetVector(floats)->[Column]
// Matches DistanceFunc->[Column]->[Tuple(array)Func]->ReferenceVector(floats)->[Column]
if (iter->function != RPNElement::FUNCTION_DISTANCE)
return false;
@ -436,13 +436,13 @@ bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, c
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
{
extractTargetVectorFromLiteral(ann_info.target, iter->tuple_literal);
extraceReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
++iter;
}
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractTargetVectorFromLiteral(ann_info.target, iter->array_literal);
extraceReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
++iter;
}
@ -457,12 +457,12 @@ bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, c
++iter;
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
{
extractTargetVectorFromLiteral(ann_info.target, iter->tuple_literal);
extraceReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
++iter;
}
else if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractTargetVectorFromLiteral(ann_info.target, iter->array_literal);
extraceReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
++iter;
}
else
@ -473,7 +473,7 @@ bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, c
{
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
ann_info.target.emplace_back(getFloatOrIntLiteralOrPanic(iter));
ann_info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
if (identifier_found)
@ -488,7 +488,7 @@ bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, c
}
// Final checks of correctness
return identifier_found && !ann_info.target.empty();
return identifier_found && !ann_info.reference_vector.empty();
}
// Gets float or int from AST node

View File

@ -12,7 +12,7 @@ namespace DB
/**
* Queries for Approximate Nearest Neighbour Search
* have similar structure:
* 1) target vector from which all distances are calculated
* 1) reference vector from which all distances are calculated
* 2) metric name (e.g L2Distance, LpDistance, etc.)
* 3) name of column with embeddings
* 4) type of query
@ -27,7 +27,7 @@ struct ApproximateNearestNeighborInformation
using Embedding = std::vector<float>;
// Extracted data from valid query
Embedding target;
Embedding reference_vector;
enum class Metric
{
Unknown,
@ -56,14 +56,14 @@ struct ApproximateNearestNeighborInformation
There are two main patterns of queries being supported
1) Search query type
SELECT * FROM * WHERE DistanceFunc(column, target_vector) < floatLiteral LIMIT count
SELECT * FROM * WHERE DistanceFunc(column, reference) < floatLiteral LIMIT count
2) OrderBy query type
SELECT * FROM * WHERE * ORDERBY DistanceFunc(column, target_vector) LIMIT count
SELECT * FROM * WHERE * ORDERBY DistanceFunc(column, reference) LIMIT count
*Query without LIMIT count is not supported*
target_vector(should have float coordinates) examples:
reference(should have float coordinates) examples:
tuple(0.1, 0.1, ...., 0.1) or (0.1, 0.1, ...., 0.1)
[the word tuple is not needed]
@ -72,11 +72,11 @@ struct ApproximateNearestNeighborInformation
returns true.
From matching query it extracts
* targetVector
* referenceVector
* metricName(DistanceFunction)
* dimension size if query uses LpDistance
* distance to compare(ONLY for search types, otherwise you get exception)
* spaceDimension(which is targetVector's components count)
* spaceDimension(which is reference vector's components count)
* column
* objects count from LIMIT clause(for both queries)
* queryHasOrderByClause and queryHasWhereClause return true if query matches the type
@ -96,10 +96,10 @@ public:
// returns the distance to compare with for search query
float getComparisonDistanceForWhereQuery() const;
// distance should be calculated regarding to targetVector
std::vector<float> getTargetVector() const;
// distance should be calculated relative to the reference vector
std::vector<float> getReferenceVector() const;
// targetVector dimension size
// reference vector's dimension size
size_t getNumOfDimensions() const;
String getColumnName() const;
@ -196,7 +196,7 @@ private:
// Returns true and stores Length if we have valid LIMIT clause in query
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
/* Matches dist function, target vector, column name */
/* Matches dist function, reference vector, column name */
static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info);
// Gets float or int from AST node

View File

@ -247,7 +247,7 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeI
if (comp_dist && comp_dist.value() < 0)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");
std::vector<float> target_vec = condition.getTargetVector();
std::vector<float> reference_vector = condition.getReferenceVector();
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance>>(idx_granule);
if (granule == nullptr)
@ -260,13 +260,13 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeI
"does not match with the dimension in the index ({})",
toString(condition.getNumOfDimensions()), toString(annoy->getNumOfDimensions()));
/// neighbors contain indexes of dots which were closest to target vector
/// neighbors contain indexes of dots which were closest to the reference vector
std::vector<UInt64> neighbors;
std::vector<Float32> distances;
neighbors.reserve(limit);
distances.reserve(limit);
annoy->get_nns_by_vector(target_vec.data(), limit, static_cast<int>(search_k), &neighbors, &distances);
annoy->get_nns_by_vector(reference_vector.data(), limit, static_cast<int>(search_k), &neighbors, &distances);
std::unordered_set<size_t> granule_numbers;
for (size_t i = 0; i < neighbors.size(); ++i)