From e7abc06c89002d28ebfc10beac1e164a335f47e6 Mon Sep 17 00:00:00 2001 From: FArthur-cmd <613623@mail.ru> Date: Tue, 30 Aug 2022 15:26:56 +0000 Subject: [PATCH] Revert "Revert "Add Annoy index"" This reverts commit 6fdfb964d0a70e5d2a78cb727bdf30d0ba5a1a34. --- .gitmodules | 4 + contrib/CMakeLists.txt | 2 + contrib/annoy | 1 + contrib/annoy-cmake/CMakeLists.txt | 16 + .../mergetree-family/annindexes.md | 125 ++++ .../mergetree-family/mergetree.md | 4 + src/CMakeLists.txt | 4 + src/Core/Settings.h | 2 + src/Storages/MergeTree/CommonANNIndexes.cpp | 595 ++++++++++++++++++ src/Storages/MergeTree/CommonANNIndexes.h | 236 +++++++ .../MergeTree/MergeTreeDataSelectExecutor.cpp | 27 + .../MergeTree/MergeTreeIndexAnnoy.cpp | 317 ++++++++++ src/Storages/MergeTree/MergeTreeIndexAnnoy.h | 123 ++++ src/Storages/MergeTree/MergeTreeIndices.cpp | 5 + src/Storages/MergeTree/MergeTreeIndices.h | 5 + .../queries/0_stateless/02354_annoy.reference | 16 + tests/queries/0_stateless/02354_annoy.sql | 44 ++ 17 files changed, 1526 insertions(+) create mode 160000 contrib/annoy create mode 100644 contrib/annoy-cmake/CMakeLists.txt create mode 100644 docs/en/engines/table-engines/mergetree-family/annindexes.md create mode 100644 src/Storages/MergeTree/CommonANNIndexes.cpp create mode 100644 src/Storages/MergeTree/CommonANNIndexes.h create mode 100644 src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp create mode 100644 src/Storages/MergeTree/MergeTreeIndexAnnoy.h create mode 100644 tests/queries/0_stateless/02354_annoy.reference create mode 100644 tests/queries/0_stateless/02354_annoy.sql diff --git a/.gitmodules b/.gitmodules index 62b2f9d7766..f372a309cad 100644 --- a/.gitmodules +++ b/.gitmodules @@ -259,6 +259,10 @@ [submodule "contrib/minizip-ng"] path = contrib/minizip-ng url = https://github.com/zlib-ng/minizip-ng +[submodule "contrib/annoy"] + path = contrib/annoy + url = https://github.com/ClickHouse/annoy.git + branch = ClickHouse-master [submodule "contrib/qpl"] path = contrib/qpl url = 
https://github.com/intel/qpl.git diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index 08b91c1b81c..486fca60912 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -159,6 +159,8 @@ add_contrib (s2geometry-cmake s2geometry) add_contrib (c-ares-cmake c-ares) add_contrib (qpl-cmake qpl) +add_contrib(annoy-cmake annoy) + # Put all targets defined here and in subdirectories under "contrib/" folders in GUI-based IDEs. # Some of third-party projects may override CMAKE_FOLDER or FOLDER property of their targets, so they would not appear # in "contrib/..." as originally planned, so we workaround this by fixing FOLDER properties of all targets manually, diff --git a/contrib/annoy b/contrib/annoy new file mode 160000 index 00000000000..9d8a603a4cd --- /dev/null +++ b/contrib/annoy @@ -0,0 +1 @@ +Subproject commit 9d8a603a4cd252448589e84c9846f94368d5a289 diff --git a/contrib/annoy-cmake/CMakeLists.txt b/contrib/annoy-cmake/CMakeLists.txt new file mode 100644 index 00000000000..f2535ba7fde --- /dev/null +++ b/contrib/annoy-cmake/CMakeLists.txt @@ -0,0 +1,16 @@ +option(ENABLE_ANNOY "Enable Annoy index support" ${ENABLE_LIBRARIES}) + +if ((NOT ENABLE_ANNOY) OR (SANITIZE STREQUAL "undefined")) + message (STATUS "Not using annoy") + return() +endif() + +set(ANNOY_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/annoy") +set(ANNOY_SOURCE_DIR "${ANNOY_PROJECT_DIR}/src") + +add_library(_annoy INTERFACE) +target_include_directories(_annoy SYSTEM INTERFACE ${ANNOY_SOURCE_DIR}) + +add_library(ch_contrib::annoy ALIAS _annoy) +target_compile_definitions(_annoy INTERFACE ENABLE_ANNOY) +target_compile_definitions(_annoy INTERFACE ANNOYLIB_MULTITHREADED_BUILD) diff --git a/docs/en/engines/table-engines/mergetree-family/annindexes.md b/docs/en/engines/table-engines/mergetree-family/annindexes.md new file mode 100644 index 00000000000..6c669a4f7b6 --- /dev/null +++ b/docs/en/engines/table-engines/mergetree-family/annindexes.md @@ -0,0 +1,125 @@ +# Approximate Nearest 
Neighbor Search Indexes [experimental] {#table_engines-ANNIndex} + +The main task of these indexes is to quickly find nearest neighbors for multidimensional data. An example of such a problem is finding similar pictures (texts) for a given picture (text). That problem can be reduced to finding the nearest [embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning). They can be created from data using [UDF](../../../sql-reference/functions/index.md#executable-user-defined-functions). + +The next query finds the closest neighbors in N-dimensional space using the L2 (Euclidean) distance: +``` sql +SELECT * +FROM table_name +WHERE L2Distance(Column, Point) < MaxDistance +LIMIT N +``` +However, executing it takes a long time because the distance between `Point` and every other vector has to be calculated. This is where ANN indexes can help. They store a compact approximation of the search space (e.g. using clustering, search trees, etc.) and are able to compute approximate neighbors quickly. + +## Indexes Structure + +Approximate Nearest Neighbor Search Indexes (`ANNIndexes`) are similar to skip indexes. They are built over groups of granules and determine which of them should be skipped. Compared to skip indices, ANN indices use their results not only to skip some group of granules, but also to select particular granules from a set of granules. + +`ANNIndexes` are designed to speed up two types of queries: + +- ###### Type 1: Where + ``` sql + SELECT * + FROM table_name + WHERE DistanceFunction(Column, Point) < MaxDistance + LIMIT N + ``` +- ###### Type 2: Order by + ``` sql + SELECT * + FROM table_name [WHERE ...] + ORDER BY DistanceFunction(Column, Point) + LIMIT N + ``` + +In these queries, `DistanceFunction` is selected from [distance functions](../../../sql-reference/functions/distance-functions). `Point` is a known vector (something like `(0.1, 0.1, ... )`). 
To avoid writing large vectors, use [client parameters](../../../interfaces/cli.md#queries-with-parameters-cli-queries-with-parameters). `MaxDistance` is a float value that bounds the neighbourhood. + +!!! note "Note" + An ANN index can't speed up a query that matches both types at once (`WHERE` + `ORDER BY`); the query must match exactly one of them. Every query must have a `LIMIT` clause, because the algorithms search for a specific number of nearest neighbors. + +!!! note "Note" + Indexes are applied only to queries with a limit not greater than the `max_limit_for_ann_queries` setting. This helps to avoid memory overflows in queries with a large limit. The `max_limit_for_ann_queries` setting can be changed if you know you can provide enough memory. The default value is `1000000`. + +Both types of queries are handled the same way. The indexes get `n` neighbors (where `n` is taken from the `LIMIT` clause) and work with them. In an `ORDER BY` query they remember the numbers of all parts of the granule that contain at least one neighbor. In a `WHERE` query they remember only those parts that satisfy the requirements. + + + +## Create table with ANNIndex + +```sql +CREATE TABLE t +( + `id` Int64, + `number` Tuple(Float32, Float32, Float32), + INDEX x number TYPE annoy GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` + +```sql +CREATE TABLE t +( + `id` Int64, + `number` Array(Float32), + INDEX x number TYPE annoy GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` + +With greater `GRANULARITY` indexes remember the data structure better. The `GRANULARITY` indicates how many granules will be used to construct the index. The more data is provided for the index, the more of it can be handled by one index and the more chances that with the right hyperparameters the index will remember the data structure better. But some indexes can't be built if they don't have enough data, so this granule will always participate in the query. For more information, see the description of indexes. 
+ +As the indexes are built only during insertions into the table, `INSERT` and `OPTIMIZE` queries are slower than for an ordinary table. At this stage indexes remember all the information about the given data. ANNIndexes should be used if you have immutable or rarely changed data and many read requests. + +You can create your table with an index that uses a certain algorithm. Currently only indices based on the following algorithms are supported: + +# Index list +- [Annoy](../../../engines/table-engines/mergetree-family/annindexes.md#annoy-annoy) + +# Annoy {#annoy} +Implementation of the algorithm was taken from [this repository](https://github.com/spotify/annoy). + +Short description of the algorithm: +The algorithm recursively divides the whole space in half by random linear surfaces (lines in 2D, planes in 3D, etc.). Thus it builds a tree of polyhedrons and the points they contain. Repeating the operation several times for greater accuracy, it creates a forest. +To find K Nearest Neighbours it goes down through the trees and fills the buffer of closest points using the priority queue of polyhedrons. Next, it sorts the buffer and returns the nearest K points. + +__Examples__: +```sql +CREATE TABLE t +( + id Int64, + number Tuple(Float32, Float32, Float32), + INDEX x number TYPE annoy(T) GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` + +```sql +CREATE TABLE t +( + id Int64, + number Array(Float32), + INDEX x number TYPE annoy(T) GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` +!!! note "Note" + A table with an array field will work faster, but all arrays **must** have the same length. Use [CONSTRAINT](../../../sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(number) = 256`. + +Parameter `T` is the number of trees which the algorithm will create. The bigger it is, the slower (approximately linearly) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). 
+ +Annoy supports only `L2Distance`. + +In a `SELECT` query, the `ann_index_select_query_params` setting lets you specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). During the query it will inspect up to `search_k` nodes which defaults to `n_trees * n` if not provided. `search_k` gives you a run-time tradeoff between better accuracy and speed. + +__Example__: +``` sql +SELECT * +FROM table_name [WHERE ...] +ORDER BY L2Distance(Column, Point) +LIMIT N +SETTINGS ann_index_select_query_params='k_search=100' +``` diff --git a/docs/en/engines/table-engines/mergetree-family/mergetree.md b/docs/en/engines/table-engines/mergetree-family/mergetree.md index 0ebe3c99f35..9dc7e300d45 100644 --- a/docs/en/engines/table-engines/mergetree-family/mergetree.md +++ b/docs/en/engines/table-engines/mergetree-family/mergetree.md @@ -481,6 +481,10 @@ For example: - `NOT startsWith(s, 'test')` ::: + +## Approximate Nearest Neighbor Search Indexes [experimental] {#table_engines-ANNIndex} +In addition to skip indices, there are also [Approximate Nearest Neighbor Search Indexes](../../../engines/table-engines/mergetree-family/annindexes.md). + ## Projections {#projections} Projections are like [materialized views](../../../sql-reference/statements/create/view.md#materialized) but defined in part-level. It provides consistency guarantees along with automatic usage in queries. 
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3ece5fd410b..c8aa0c84a24 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -570,6 +570,10 @@ endif() dbms_target_link_libraries(PUBLIC ch_contrib::consistent_hashing) +if (TARGET ch_contrib::annoy) + dbms_target_link_libraries(PUBLIC ch_contrib::annoy) +endif() + include ("${ClickHouse_SOURCE_DIR}/cmake/add_check.cmake") if (ENABLE_TESTS) diff --git a/src/Core/Settings.h b/src/Core/Settings.h index af32c15a867..026b603177c 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -629,6 +629,8 @@ static constexpr UInt64 operator""_GiB(unsigned long long value) M(Bool, allow_experimental_hash_functions, false, "Enable experimental hash functions (hashid, etc)", 0) \ M(Bool, allow_experimental_object_type, false, "Allow Object and JSON data types", 0) \ M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \ + M(String, ann_index_select_query_params, "", "Parameters passed to ANN indexes in SELECT queries, the format is 'param1=x, param2=y, ...'", 0) \ + M(UInt64, max_limit_for_ann_queries, 1000000, "Maximum limit value for using ANN indexes is used to prevent memory overflow in search queries for indexes", 0) \ M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \ M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \ M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \ diff --git a/src/Storages/MergeTree/CommonANNIndexes.cpp b/src/Storages/MergeTree/CommonANNIndexes.cpp new file mode 100644 index 00000000000..886f9ab1c0f --- /dev/null +++ b/src/Storages/MergeTree/CommonANNIndexes.cpp @@ -0,0 +1,595 @@ +#include +#include + +#include +#include +#include +#include +#include 
+#include + +#include + +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int INCORRECT_QUERY; +} + +namespace +{ + +namespace ANN = ApproximateNearestNeighbour; + +template +void extractTargetVectorFromLiteral(ANN::ANNQueryInformation::Embedding & target, Literal literal) +{ + Float64 float_element_of_target_vector; + Int64 int_element_of_target_vector; + + for (const auto & value : literal.value()) + { + if (value.tryGet(float_element_of_target_vector)) + { + target.emplace_back(float_element_of_target_vector); + } + else if (value.tryGet(int_element_of_target_vector)) + { + target.emplace_back(static_cast(int_element_of_target_vector)); + } + else + { + throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in target vector. Only float or int are supported."); + } + } +} + +ANN::ANNQueryInformation::Metric castMetricFromStringToType(String metric_name) +{ + if (metric_name == "L2Distance") + return ANN::ANNQueryInformation::Metric::L2; + if (metric_name == "LpDistance") + return ANN::ANNQueryInformation::Metric::Lp; + return ANN::ANNQueryInformation::Metric::Unknown; +} + +} + +namespace ApproximateNearestNeighbour +{ + +ANNCondition::ANNCondition(const SelectQueryInfo & query_info, + ContextPtr context) : + block_with_constants{KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context)}, + ann_index_select_query_params{context->getSettings().get("ann_index_select_query_params").get()}, + index_granularity{context->getMergeTreeSettings().get("index_granularity").get()}, + limit_restriction{context->getSettings().get("max_limit_for_ann_queries").get()}, + index_is_useful{checkQueryStructure(query_info)} {} + +bool ANNCondition::alwaysUnknownOrTrue(String metric_name) const +{ + if (!index_is_useful) + { + return true; // Query isn't supported + } + // If query is supported, check metrics for match + return !(castMetricFromStringToType(metric_name) == 
query_information->metric); +} + +float ANNCondition::getComparisonDistanceForWhereQuery() const +{ + if (index_is_useful && query_information.has_value() + && query_information->query_type == ANNQueryInformation::Type::Where) + { + return query_information->distance; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type"); +} + +UInt64 ANNCondition::getLimit() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->limit; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported"); +} + +std::vector ANNCondition::getTargetVector() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->target; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Target vector was requested for useless or uninitialized index."); +} + +size_t ANNCondition::getNumOfDimensions() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->target.size(); + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index."); +} + +String ANNCondition::getColumnName() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->column_name; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index."); +} + +ANNQueryInformation::Metric ANNCondition::getMetricType() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->metric; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Metric name was requested for useless or uninitialized index."); +} + +float ANNCondition::getPValueForLpDistance() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->p_for_lp_dist; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "P from LPDistance was requested for useless or uninitialized index."); +} + 
+ANNQueryInformation::Type ANNCondition::getQueryType() const +{ + if (index_is_useful && query_information.has_value()) + { + return query_information->query_type; + } + throw Exception(ErrorCodes::LOGICAL_ERROR, "Query type was requested for useless or uninitialized index."); +} + +bool ANNCondition::checkQueryStructure(const SelectQueryInfo & query) +{ + // RPN-s for different sections of the query + RPN rpn_prewhere_clause; + RPN rpn_where_clause; + RPN rpn_order_by_clause; + RPNElement rpn_limit; + UInt64 limit; + + ANNQueryInformation prewhere_info; + ANNQueryInformation where_info; + ANNQueryInformation order_by_info; + + // Build rpns for query sections + const auto & select = query.query->as(); + + if (select.prewhere()) // If query has PREWHERE clause + { + traverseAST(select.prewhere(), rpn_prewhere_clause); + } + + if (select.where()) // If query has WHERE clause + { + traverseAST(select.where(), rpn_where_clause); + } + + if (select.limitLength()) // If query has LIMIT clause + { + traverseAtomAST(select.limitLength(), rpn_limit); + } + + if (select.orderBy()) // If query has ORDERBY clause + { + traverseOrderByAST(select.orderBy(), rpn_order_by_clause); + } + + // Reverse RPNs for conveniences during parsing + std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end()); + std::reverse(rpn_where_clause.begin(), rpn_where_clause.end()); + std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end()); + + // Match rpns with supported types and extract information + const bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, prewhere_info); + const bool where_is_valid = matchRPNWhere(rpn_where_clause, where_info); + const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, order_by_info); + const bool limit_is_valid = matchRPNLimit(rpn_limit, limit); + + // Query without a LIMIT clause or with a limit greater than a restriction is not supported + if (!limit_is_valid || limit_restriction < limit) + { + return false; + } + + 
// Search type query in both sections isn't supported + if (prewhere_is_valid && where_is_valid) + { + return false; + } + + // Search type should be in WHERE or PREWHERE clause + if (prewhere_is_valid || where_is_valid) + { + query_information = std::move(prewhere_is_valid ? prewhere_info : where_info); + } + + if (order_by_is_valid) + { + // Query with valid where and order by type is not supported + if (query_information.has_value()) + { + return false; + } + + query_information = std::move(order_by_info); + } + + if (query_information) + query_information->limit = limit; + + return query_information.has_value(); +} + +void ANNCondition::traverseAST(const ASTPtr & node, RPN & rpn) +{ + // If the node is ASTFunction, it may have children nodes + if (const auto * func = node->as()) + { + const ASTs & children = func->arguments->children; + // Traverse children nodes + for (const auto& child : children) + { + traverseAST(child, rpn); + } + } + + RPNElement element; + // Get the data behind node + if (!traverseAtomAST(node, element)) + { + element.function = RPNElement::FUNCTION_UNKNOWN; + } + + rpn.emplace_back(std::move(element)); +} + +bool ANNCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out) +{ + // Match Functions + if (const auto * function = node->as()) + { + // Set the name + out.func_name = function->name; + + if (function->name == "L1Distance" || + function->name == "L2Distance" || + function->name == "LinfDistance" || + function->name == "cosineDistance" || + function->name == "dotProduct" || + function->name == "LpDistance") + { + out.function = RPNElement::FUNCTION_DISTANCE; + } + else if (function->name == "tuple") + { + out.function = RPNElement::FUNCTION_TUPLE; + } + else if (function->name == "array") + { + out.function = RPNElement::FUNCTION_ARRAY; + } + else if (function->name == "less" || + function->name == "greater" || + function->name == "lessOrEquals" || + function->name == "greaterOrEquals") + { + out.function = 
RPNElement::FUNCTION_COMPARISON; + } + else if (function->name == "_CAST") + { + out.function = RPNElement::FUNCTION_CAST; + } + else + { + return false; + } + + return true; + } + // Match identifier + else if (const auto * identifier = node->as()) + { + out.function = RPNElement::FUNCTION_IDENTIFIER; + out.identifier.emplace(identifier->name()); + out.func_name = "column identifier"; + + return true; + } + + // Check if we have constants behind the node + return tryCastToConstType(node, out); +} + +bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out) +{ + Field const_value; + DataTypePtr const_type; + + if (KeyCondition::getConstant(node, block_with_constants, const_value, const_type)) + { + /// Check for constant types + if (const_value.getType() == Field::Types::Float64) + { + out.function = RPNElement::FUNCTION_FLOAT_LITERAL; + out.float_literal.emplace(const_value.get()); + out.func_name = "Float literal"; + return true; + } + + if (const_value.getType() == Field::Types::UInt64) + { + out.function = RPNElement::FUNCTION_INT_LITERAL; + out.int_literal.emplace(const_value.get()); + out.func_name = "Int literal"; + return true; + } + + if (const_value.getType() == Field::Types::Int64) + { + out.function = RPNElement::FUNCTION_INT_LITERAL; + out.int_literal.emplace(const_value.get()); + out.func_name = "Int literal"; + return true; + } + + if (const_value.getType() == Field::Types::Tuple) + { + out.function = RPNElement::FUNCTION_LITERAL_TUPLE; + out.tuple_literal = const_value.get(); + out.func_name = "Tuple literal"; + return true; + } + + if (const_value.getType() == Field::Types::Array) + { + out.function = RPNElement::FUNCTION_LITERAL_ARRAY; + out.array_literal = const_value.get(); + out.func_name = "Array literal"; + return true; + } + + if (const_value.getType() == Field::Types::String) + { + out.function = RPNElement::FUNCTION_STRING_LITERAL; + out.func_name = const_value.get(); + return true; + } + } + + return false; +} + +void 
ANNCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn) +{ + if (const auto * expr_list = node->as()) + { + if (const auto * order_by_element = expr_list->children.front()->as()) + { + traverseAST(order_by_element->children.front(), rpn); + } + } +} + +// Returns true and stores ANNQueryInformation if the query has valid WHERE clause +bool ANNCondition::matchRPNWhere(RPN & rpn, ANNQueryInformation & expr) +{ + // WHERE section must have at least 5 expressions + // Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(TargetVector(floats)) + if (rpn.size() < 5) + { + return false; + } + + auto iter = rpn.begin(); + + // Query starts from operator less + if (iter->function != RPNElement::FUNCTION_COMPARISON) + { + return false; + } + + const bool greater_case = iter->func_name == "greater" || iter->func_name == "greaterOrEquals"; + const bool less_case = iter->func_name == "less" || iter->func_name == "lessOrEquals"; + + ++iter; + + if (less_case) + { + if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL) + { + return false; + } + + expr.distance = getFloatOrIntLiteralOrPanic(iter); + ++iter; + + } + else if (!greater_case) + { + return false; + } + + auto end = rpn.end(); + if (!matchMainParts(iter, end, expr)) + { + return false; + } + + if (greater_case) + { + if (expr.target.size() < 2) + { + return false; + } + expr.distance = expr.target.back(); + expr.target.pop_back(); + } + + // query is ok + return true; +} + +// Returns true and stores ANNExpr if the query has valid ORDERBY clause +bool ANNCondition::matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr) +{ + // ORDER BY clause must have at least 3 expressions + if (rpn.size() < 3) + { + return false; + } + + auto iter = rpn.begin(); + auto end = rpn.end(); + + return ANNCondition::matchMainParts(iter, end, expr); +} + +// Returns true and stores Length if we have valid LIMIT clause in query +bool ANNCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit) +{ + if (rpn.function 
== RPNElement::FUNCTION_INT_LITERAL) + { + limit = rpn.int_literal.value(); + return true; + } + + return false; +} + +/* Matches dist function, target vector, column name */ +bool ANNCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ANNQueryInformation & expr) +{ + bool identifier_found = false; + + // Matches DistanceFunc->[Column]->[Tuple(array)Func]->TargetVector(floats)->[Column] + if (iter->function != RPNElement::FUNCTION_DISTANCE) + { + return false; + } + + expr.metric = castMetricFromStringToType(iter->func_name); + ++iter; + + if (expr.metric == ANN::ANNQueryInformation::Metric::Lp) + { + if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL && + iter->function != RPNElement::FUNCTION_INT_LITERAL) + { + return false; + } + expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter); + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_IDENTIFIER) + { + identifier_found = true; + expr.column_name = std::move(iter->identifier.value()); + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_TUPLE || iter->function == RPNElement::FUNCTION_ARRAY) + { + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE) + { + extractTargetVectorFromLiteral(expr.target, iter->tuple_literal); + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY) + { + extractTargetVectorFromLiteral(expr.target, iter->array_literal); + ++iter; + } + + /// further conditions are possible if there is no tuple or array, or no identifier is found + /// the tuple or array can be inside a cast function. 
For other cases, see the loop after this condition + if (iter != end && iter->function == RPNElement::FUNCTION_CAST) + { + ++iter; + /// Cast should be made to array or tuple + if (!iter->func_name.starts_with("Array") && !iter->func_name.starts_with("Tuple")) + { + return false; + } + ++iter; + if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE) + { + extractTargetVectorFromLiteral(expr.target, iter->tuple_literal); + ++iter; + } + else if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY) + { + extractTargetVectorFromLiteral(expr.target, iter->array_literal); + ++iter; + } + else + { + return false; + } + } + + while (iter != end) + { + if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL || + iter->function == RPNElement::FUNCTION_INT_LITERAL) + { + expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter)); + } + else if (iter->function == RPNElement::FUNCTION_IDENTIFIER) + { + if (identifier_found) + { + return false; + } + expr.column_name = std::move(iter->identifier.value()); + identifier_found = true; + } + else + { + return false; + } + + ++iter; + } + + // Final checks of correctness + return identifier_found && !expr.target.empty(); +} + +// Gets float or int from AST node +float ANNCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter) +{ + if (iter->float_literal.has_value()) + { + return iter->float_literal.value(); + } + if (iter->int_literal.has_value()) + { + return static_cast(iter->int_literal.value()); + } + throw Exception("Wrong parsed AST in buildRPN\n", ErrorCodes::INCORRECT_QUERY); +} + +} + +} diff --git a/src/Storages/MergeTree/CommonANNIndexes.h b/src/Storages/MergeTree/CommonANNIndexes.h new file mode 100644 index 00000000000..fefb9584863 --- /dev/null +++ b/src/Storages/MergeTree/CommonANNIndexes.h @@ -0,0 +1,236 @@ +#pragma once + +#include +#include "base/types.h" + +#include +#include + +namespace DB +{ + +namespace ApproximateNearestNeighbour +{ + +/** + * Queries for Approximate Nearest Neighbour Search 
+ * have similar structure: + * 1) target vector from which all distances are calculated + * 2) metric name (e.g. L2Distance, LpDistance, etc.) + * 3) name of column with embeddings + * 4) type of query + * 5) number of elements that should be taken (limit) + * + * And two optional parameters: + * 1) p for LpDistance function + * 2) distance to compare with (only for where queries) + */ +struct ANNQueryInformation +{ + using Embedding = std::vector; + + // Extracted data from valid query + Embedding target; + enum class Metric + { + Unknown, + L2, + Lp + } metric; + String column_name; + UInt64 limit; + + enum class Type + { + OrderBy, + Where + } query_type; + + float p_for_lp_dist = -1.0; + float distance = -1.0; +}; + +/** + Class ANNCondition is responsible for recognizing special query types which + can be sped up by ANN Indexes. It parses the SQL query and checks + if it matches ANNIndexes. The recognizing method - alwaysUnknownOrTrue + returns false if we can speed up the query, and true otherwise. + It has only one argument, name of the metric with which index was built. + There are two main patterns of queries being supported + + 1) Search query type + SELECT * FROM * WHERE DistanceFunc(column, target_vector) < floatLiteral LIMIT count + + 2) OrderBy query type + SELECT * FROM * WHERE * ORDER BY DistanceFunc(column, target_vector) LIMIT count + + *Query without LIMIT count is not supported* + + target_vector(should have float coordinates) examples: + tuple(0.1, 0.1, ...., 0.1) or (0.1, 0.1, ...., 0.1) + [the word tuple is not needed] + + If the query matches one of these two types, then the class extracts useful information + from the query. If the query has both 1 and 2 types, then we can't speed it up and alwaysUnknownOrTrue + returns true. 
+ + From matching query it extracts + * targetVector + * metricName(DistanceFunction) + * p value if query uses LpDistance (see getPValueForLpDistance) + * distance to compare(ONLY for search types, otherwise you get exception) + * spaceDimension(which is targetVector's components count) + * column + * objects count from LIMIT clause(for both queries) + * settings str, if query has settings section with new 'ann_index_select_query_params' value, + then you can get the new value(empty by default) calling method getParamsStr + * query type (OrderBy or Where) that the query matched, returned by getQueryType + + Search query type is also recognized for PREWHERE clause +*/ + +class ANNCondition +{ +public: + ANNCondition(const SelectQueryInfo & query_info, + ContextPtr context); + + // false if query can be sped up, true otherwise + bool alwaysUnknownOrTrue(String metric_name) const; + + // returns the distance to compare with for search query + float getComparisonDistanceForWhereQuery() const; + + // the vector from which all distances are calculated + std::vector getTargetVector() const; + + // targetVector dimension size + size_t getNumOfDimensions() const; + + String getColumnName() const; + + ANNQueryInformation::Metric getMetricType() const; + + // the p value if the metric is 'LpDistance' + float getPValueForLpDistance() const; + + ANNQueryInformation::Type getQueryType() const; + + UInt64 getIndexGranularity() const { return index_granularity; } + + // length's value from LIMIT clause + UInt64 getLimit() const; + + // value of 'ann_index_select_query_params' if present in the SETTINGS clause, empty string otherwise + String getParamsStr() const { return ann_index_select_query_params; } + +private: + + struct RPNElement + { + enum Function + { + // DistanceFunctions + FUNCTION_DISTANCE, + + //tuple(0.1, ..., 0.1) + FUNCTION_TUPLE, + + //array(0.1, ..., 0.1) + FUNCTION_ARRAY, + + // Operators <, >, <=, >= + FUNCTION_COMPARISON, + + // Numeric float value + 
FUNCTION_FLOAT_LITERAL, + + // Numeric int value + FUNCTION_INT_LITERAL, + + // Column identifier + FUNCTION_IDENTIFIER, + + // Unknown, can be any value + FUNCTION_UNKNOWN, + + // (0.1, ...., 0.1) vector without word 'tuple' + FUNCTION_LITERAL_TUPLE, + + // [0.1, ...., 0.1] vector without word 'array' + FUNCTION_LITERAL_ARRAY, + + // if client parameters are used, cast will always be in the query + FUNCTION_CAST, + + // name of type in cast function + FUNCTION_STRING_LITERAL, + }; + + explicit RPNElement(Function function_ = FUNCTION_UNKNOWN) + : function(function_), func_name("Unknown"), float_literal(std::nullopt), identifier(std::nullopt) {} + + Function function; + String func_name; + + std::optional float_literal; + std::optional identifier; + std::optional int_literal; + + std::optional tuple_literal; + std::optional array_literal; + + UInt32 dim = 0; + }; + + using RPN = std::vector; + + bool checkQueryStructure(const SelectQueryInfo & query); + + // Util functions for the traversal of AST, parses AST and builds rpn + void traverseAST(const ASTPtr & node, RPN & rpn); + // Return true if we can identify our node type + bool traverseAtomAST(const ASTPtr & node, RPNElement & out); + // Checks if the AST stores ConstType expression + bool tryCastToConstType(const ASTPtr & node, RPNElement & out); + // Traverses the AST of ORDERBY section + void traverseOrderByAST(const ASTPtr & node, RPN & rpn); + + // Returns true and stores ANNExpr if the query has valid WHERE section + static bool matchRPNWhere(RPN & rpn, ANNQueryInformation & expr); + + // Returns true and stores ANNExpr if the query has valid ORDERBY section + static bool matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr); + + // Returns true and stores Length if we have valid LIMIT clause in query + static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit); + + /* Matches dist function, target vector, column name */ + static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, 
ANNQueryInformation & expr); + + // Gets float or int from AST node + static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter); + + Block block_with_constants; + + // true if we have one of two supported query types + std::optional query_information; + + // Get from settings ANNIndex parameters + String ann_index_select_query_params; + UInt64 index_granularity; + /// only queries with a lower limit can be considered to avoid memory overflow + UInt64 limit_restriction; + bool index_is_useful = false; +}; + +// condition interface for Ann indexes. Returns vector of indexes of ranges in granule which are useful for query. +class IMergeTreeIndexConditionAnn : public IMergeTreeIndexCondition +{ +public: + virtual std::vector getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const = 0; +}; + +} + +} diff --git a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp index c5f546a9c36..12aec29eab6 100644 --- a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp +++ b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp @@ -43,6 +43,8 @@ #include +#include + namespace DB { @@ -1669,6 +1671,31 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex( { if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin) granule = reader.read(); + // Cast to Ann condition + auto ann_condition = std::dynamic_pointer_cast(condition); + if (ann_condition != nullptr) + { + // vector of indexes of useful ranges + auto result = ann_condition->getUsefulRanges(granule); + if (result.empty()) + { + ++granules_dropped; + } + + for (auto range : result) + { + // range for corresponding index + MarkRange data_range( + std::max(ranges[i].begin, index_mark * index_granularity + range), + std::min(ranges[i].end, index_mark * index_granularity + range + 1)); + + if (res.empty() || res.back().end - data_range.begin > min_marks_for_seek) + res.push_back(data_range); + else + res.back().end 
= data_range.end; + } + continue; + } if (!condition->mayBeTrueOnGranule(granule)) { diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp new file mode 100644 index 00000000000..a8b825d832d --- /dev/null +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp @@ -0,0 +1,317 @@ +#ifdef ENABLE_ANNOY + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ApproximateNearestNeighbour +{ + +template +void AnnoyIndex::serialize(WriteBuffer& ostr) const +{ + assert(Base::_built); + writeIntBinary(Base::_s, ostr); + writeIntBinary(Base::_n_items, ostr); + writeIntBinary(Base::_n_nodes, ostr); + writeIntBinary(Base::_nodes_size, ostr); + writeIntBinary(Base::_K, ostr); + writeIntBinary(Base::_seed, ostr); + writeVectorBinary(Base::_roots, ostr); + ostr.write(reinterpret_cast(Base::_nodes), Base::_s * Base::_n_nodes); +} + +template +void AnnoyIndex::deserialize(ReadBuffer& istr) +{ + assert(!Base::_built); + readIntBinary(Base::_s, istr); + readIntBinary(Base::_n_items, istr); + readIntBinary(Base::_n_nodes, istr); + readIntBinary(Base::_nodes_size, istr); + readIntBinary(Base::_K, istr); + readIntBinary(Base::_seed, istr); + readVectorBinary(Base::_roots, istr); + Base::_nodes = realloc(Base::_nodes, Base::_s * Base::_n_nodes); + istr.read(reinterpret_cast(Base::_nodes), Base::_s * Base::_n_nodes); + + Base::_fd = 0; + // set flags + Base::_loaded = false; + Base::_verbose = false; + Base::_on_disk = false; + Base::_built = true; +} + +template +uint64_t AnnoyIndex::getNumOfDimensions() const +{ + return Base::get_f(); +} + +} + + +namespace ErrorCodes +{ + extern const int LOGICAL_ERROR; + extern const int INCORRECT_QUERY; + extern const int INCORRECT_DATA; +} + +MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_) + : index_name(index_name_) + , index_sample_block(index_sample_block_) + , 
index(nullptr) +{} + +MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy( + const String & index_name_, + const Block & index_sample_block_, + AnnoyIndexPtr index_base_) + : index_name(index_name_) + , index_sample_block(index_sample_block_) + , index(std::move(index_base_)) +{} + +void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const +{ + /// number of dimensions is required in the constructor, + /// so it must be written and read separately from the other part + writeIntBinary(index->getNumOfDimensions(), ostr); // write dimension + index->serialize(ostr); +} + +void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/) +{ + uint64_t dimension; + readIntBinary(dimension, istr); + index = std::make_shared(dimension); + index->deserialize(istr); +} + + +MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy( + const String & index_name_, + const Block & index_sample_block_, + uint64_t number_of_trees_) + : index_name(index_name_) + , index_sample_block(index_sample_block_) + , number_of_trees(number_of_trees_) +{} + +MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy::getGranuleAndReset() +{ + // NOLINTNEXTLINE(*) + index->build(number_of_trees, /*number_of_threads=*/1); + auto granule = std::make_shared(index_name, index_sample_block, index); + index = nullptr; + return granule; +} + +void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, size_t limit) +{ + if (*pos >= block.rows()) + throw Exception( + ErrorCodes::LOGICAL_ERROR, + "The provided position is not less than the number of block rows. 
Position: {}, Block rows: {}.", + toString(*pos), toString(block.rows())); + + size_t rows_read = std::min(limit, block.rows() - *pos); + + if (index_sample_block.columns() > 1) + { + throw Exception("Only one column is supported", ErrorCodes::LOGICAL_ERROR); + } + + auto index_column_name = index_sample_block.getByPosition(0).name; + const auto & column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read); + const auto & column_array = typeid_cast(column_cut.get()); + if (column_array) + { + const auto & data = column_array->getData(); + const auto & array = typeid_cast(data).getData(); + const auto & offsets = column_array->getOffsets(); + size_t num_rows = column_array->size(); + + /// All sizes are the same + size_t size = offsets[0]; + for (size_t i = 0; i < num_rows - 1; ++i) + { + if (offsets[i + 1] - offsets[i] != size) + { + throw Exception(ErrorCodes::INCORRECT_DATA, "Arrays should have same length"); + } + } + index = std::make_shared(size); + + for (size_t current_row = 0; current_row < num_rows; ++current_row) + { + index->add_item(index->get_n_items(), &array[offsets[current_row]]); + } + } + else + { + /// Other possible type of column is Tuple + const auto & column_tuple = typeid_cast(column_cut.get()); + + if (!column_tuple) + throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type was given to index."); + + const auto & columns = column_tuple->getColumns(); + + std::vector> data{column_tuple->size(), std::vector()}; + for (const auto& column : columns) + { + const auto& pod_array = typeid_cast(column.get())->getData(); + for (size_t i = 0; i < pod_array.size(); ++i) + { + data[i].push_back(pod_array[i]); + } + } + assert(!data.empty()); + if (!index) + { + index = std::make_shared(data[0].size()); + } + for (const auto& item : data) + { + index->add_item(index->get_n_items(), item.data()); + } + } + + *pos += rows_read; +} + + +MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy( + const IndexDescription & /*index*/, + 
const SelectQueryInfo & query, + ContextPtr context) + : condition(query, context) +{} + + +bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /* idx_granule */) const +{ + throw Exception("mayBeTrueOnGranule is not supported for ANN skip indexes", ErrorCodes::LOGICAL_ERROR); +} + +bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const +{ + return condition.alwaysUnknownOrTrue("L2Distance"); +} + +std::vector MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const +{ + UInt64 limit = condition.getLimit(); + UInt64 index_granularity = condition.getIndexGranularity(); + std::optional comp_dist = condition.getQueryType() == ANN::ANNQueryInformation::Type::Where ? + std::optional(condition.getComparisonDistanceForWhereQuery()) : std::nullopt; + + if (comp_dist && comp_dist.value() < 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance"); + + std::vector target_vec = condition.getTargetVector(); + + auto granule = std::dynamic_pointer_cast(idx_granule); + if (granule == nullptr) + { + throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR); + } + auto annoy = granule->index; + + if (condition.getNumOfDimensions() != annoy->getNumOfDimensions()) + { + throw Exception("The dimension of the space in the request (" + toString(condition.getNumOfDimensions()) + ") " + + "does not match with the dimension in the index (" + toString(annoy->getNumOfDimensions()) + ")", ErrorCodes::INCORRECT_QUERY); + } + + /// neighbors contain indexes of dots which were closest to target vector + std::vector neighbors; + std::vector distances; + neighbors.reserve(limit); + distances.reserve(limit); + + int k_search = -1; + String params_str = condition.getParamsStr(); + if (!params_str.empty()) + { + try + { + /// k_search=... 
(algorithm will inspect up to search_k nodes which defaults to n_trees * n if not provided) + k_search = std::stoi(params_str.data() + 9); + } + catch (...) + { + throw Exception("Setting of the annoy index should be int", ErrorCodes::INCORRECT_QUERY); + } + } + annoy->get_nns_by_vector(target_vec.data(), limit, k_search, &neighbors, &distances); + std::unordered_set granule_numbers; + for (size_t i = 0; i < neighbors.size(); ++i) + { + if (comp_dist && distances[i] > comp_dist) + { + continue; + } + granule_numbers.insert(neighbors[i] / index_granularity); + } + + std::vector result_vector; + result_vector.reserve(granule_numbers.size()); + for (auto granule_number : granule_numbers) + { + result_vector.push_back(granule_number); + } + + return result_vector; +} + + +MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const +{ + return std::make_shared(index.name, index.sample_block); +} + +MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const +{ + return std::make_shared(index.name, index.sample_block, number_of_trees); +} + +MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition( + const SelectQueryInfo & query, ContextPtr context) const +{ + return std::make_shared(index, query, context); +}; + +MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index) +{ + uint64_t param = index.arguments[0].get(); + return std::make_shared(index, param); +} + +void annoyIndexValidator(const IndexDescription & index, bool /* attach */) +{ + if (index.arguments.size() != 1) + { + throw Exception("Annoy index must have exactly one argument.", ErrorCodes::INCORRECT_QUERY); + } + if (index.arguments[0].getType() != Field::Types::UInt64) + { + throw Exception("Annoy index argument must be UInt64.", ErrorCodes::INCORRECT_QUERY); + } +} + +} +#endif // ENABLE_ANNOY diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.h b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h new file mode 100644 index 00000000000..85bbb0a1bd2 
--- /dev/null +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h @@ -0,0 +1,123 @@ +#pragma once + +#ifdef ENABLE_ANNOY + +#include + +#include +#include + +namespace DB +{ + +namespace ANN = ApproximateNearestNeighbour; + +// auxiliary namespace for working with spotify-annoy library +// mainly for serialization and deserialization of the index +namespace ApproximateNearestNeighbour +{ + using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexMultiThreadedBuildPolicy; + // TODO: Support different metrics. List of available metrics can be taken from here: + // https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171 + template + class AnnoyIndex : public ::Annoy::AnnoyIndex + { + using Base = ::Annoy::AnnoyIndex; + public: + explicit AnnoyIndex(const uint64_t dim) : Base::AnnoyIndex(dim) {} + void serialize(WriteBuffer& ostr) const; + void deserialize(ReadBuffer& istr); + uint64_t getNumOfDimensions() const; + }; +} + +struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule +{ + using AnnoyIndex = ANN::AnnoyIndex<>; + using AnnoyIndexPtr = std::shared_ptr; + + MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_); + MergeTreeIndexGranuleAnnoy( + const String & index_name_, + const Block & index_sample_block_, + AnnoyIndexPtr index_base_); + + ~MergeTreeIndexGranuleAnnoy() override = default; + + void serializeBinary(WriteBuffer & ostr) const override; + void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override; + + bool empty() const override { return !index.get(); } + + String index_name; + Block index_sample_block; + AnnoyIndexPtr index; +}; + + +struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator +{ + using AnnoyIndex = ANN::AnnoyIndex<>; + using AnnoyIndexPtr = std::shared_ptr; + + MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t number_of_trees); + ~MergeTreeIndexAggregatorAnnoy() override = 
default; + + bool empty() const override { return !index || index->get_n_items() == 0; } + MergeTreeIndexGranulePtr getGranuleAndReset() override; + void update(const Block & block, size_t * pos, size_t limit) override; + + String index_name; + Block index_sample_block; + const uint64_t number_of_trees; + AnnoyIndexPtr index; +}; + + +class MergeTreeIndexConditionAnnoy final : public ANN::IMergeTreeIndexConditionAnn +{ +public: + MergeTreeIndexConditionAnnoy( + const IndexDescription & index, + const SelectQueryInfo & query, + ContextPtr context); + + bool alwaysUnknownOrTrue() const override; + + bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override; + + std::vector getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override; + + ~MergeTreeIndexConditionAnnoy() override = default; + +private: + ANN::ANNCondition condition; +}; + + +class MergeTreeIndexAnnoy : public IMergeTreeIndex +{ +public: + MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_) + : IMergeTreeIndex(index_) + , number_of_trees(number_of_trees_) + {} + + ~MergeTreeIndexAnnoy() override = default; + + MergeTreeIndexGranulePtr createIndexGranule() const override; + MergeTreeIndexAggregatorPtr createIndexAggregator() const override; + + MergeTreeIndexConditionPtr createIndexCondition( + const SelectQueryInfo & query, ContextPtr context) const override; + + bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return false; } + +private: + const uint64_t number_of_trees; +}; + + +} + +#endif // ENABLE_ANNOY diff --git a/src/Storages/MergeTree/MergeTreeIndices.cpp b/src/Storages/MergeTree/MergeTreeIndices.cpp index 9d7e0cdfdbe..eeeef27699f 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.cpp +++ b/src/Storages/MergeTree/MergeTreeIndices.cpp @@ -101,6 +101,11 @@ MergeTreeIndexFactory::MergeTreeIndexFactory() registerCreator("hypothesis", hypothesisIndexCreator); registerValidator("hypothesis", hypothesisIndexValidator); + 
+#ifdef ENABLE_ANNOY + registerCreator("annoy", annoyIndexCreator); + registerValidator("annoy", annoyIndexValidator); +#endif } MergeTreeIndexFactory & MergeTreeIndexFactory::instance() diff --git a/src/Storages/MergeTree/MergeTreeIndices.h b/src/Storages/MergeTree/MergeTreeIndices.h index 051edd630cb..14002534c94 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.h +++ b/src/Storages/MergeTree/MergeTreeIndices.h @@ -224,4 +224,9 @@ void bloomFilterIndexValidatorNew(const IndexDescription & index, bool attach); MergeTreeIndexPtr hypothesisIndexCreator(const IndexDescription & index); void hypothesisIndexValidator(const IndexDescription & index, bool attach); +#ifdef ENABLE_ANNOY +MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index); +void annoyIndexValidator(const IndexDescription & index, bool attach); +#endif + } diff --git a/tests/queries/0_stateless/02354_annoy.reference b/tests/queries/0_stateless/02354_annoy.reference new file mode 100644 index 00000000000..2cc62ef4c86 --- /dev/null +++ b/tests/queries/0_stateless/02354_annoy.reference @@ -0,0 +1,16 @@ +1 [0,0,10] +2 [0,0,10.5] +3 [0,0,9.5] +4 [0,0,9.7] +5 [0,0,10.2] +1 [0,0,10] +5 [0,0,10.2] +4 [0,0,9.7] +1 [0,0,10] +2 [0,0,10.5] +3 [0,0,9.5] +4 [0,0,9.7] +5 [0,0,10.2] +1 [0,0,10] +5 [0,0,10.2] +4 [0,0,9.7] diff --git a/tests/queries/0_stateless/02354_annoy.sql b/tests/queries/0_stateless/02354_annoy.sql new file mode 100644 index 00000000000..da0799cecaa --- /dev/null +++ b/tests/queries/0_stateless/02354_annoy.sql @@ -0,0 +1,44 @@ +-- Tags: no-fasttest, no-ubsan + +DROP TABLE IF EXISTS 02354_annoy; + +CREATE TABLE 02354_annoy +( + id Int32, + embedding Array(Float32), + INDEX annoy_index embedding TYPE annoy(100) GRANULARITY 1 +) +ENGINE = MergeTree +ORDER BY id +SETTINGS index_granularity=5; + +INSERT INTO 02354_annoy VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, 
[9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]); + +SELECT * +FROM 02354_annoy +WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0 +LIMIT 5; + +SELECT * +FROM 02354_annoy +ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0]) +LIMIT 3; + +SET param_02354_target_vector='[0.0, 0.0, 10.0]'; + +SELECT * +FROM 02354_annoy +WHERE L2Distance(embedding, {02354_target_vector: Array(Float32)}) < 1.0 +LIMIT 5; + +SELECT * +FROM 02354_annoy +ORDER BY L2Distance(embedding, {02354_target_vector: Array(Float32)}) +LIMIT 3; + +SELECT * +FROM 02354_annoy +ORDER BY L2Distance(embedding, [0.0, 0.0]) +LIMIT 3; -- { serverError 80 } + +DROP TABLE IF EXISTS 02354_annoy;