diff --git a/.gitmodules b/.gitmodules index 5847e7456a7..dd64c9c05f6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -259,8 +259,8 @@ [submodule "contrib/minizip-ng"] path = contrib/minizip-ng url = https://github.com/zlib-ng/minizip-ng -[submodule "contrib/spotify-annoy"] - path = contrib/spotify-annoy +[submodule "contrib/annoy"] + path = contrib/annoy url = https://github.com/Vector-Similarity-Search-for-ClickHouse/annoy.git [submodule "contrib/wyhash"] path = contrib/wyhash diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index 3274a98377b..f285fb337fd 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -159,14 +159,9 @@ add_contrib (base-x-cmake base-x) set (ENABLE_ANNOY_DEFAULT ${ENABLE_LIBRARIES}) option(ENABLE_ANNOY "Enable Annoy index support" ${ENABLE_ANNOY_DEFAULT}) -if (CMAKE_SYSTEM_NAME MATCHES "Darwin") - message (WARNING "Annoy disabled. Doesn't support Darwin.") - set (ENABLE_ANNOY OFF PARENT_SCOPE) - set (ENABLE_ANNOY OFF) -endif () if (ENABLE_ANNOY) - add_contrib (spotify-annoy-cmake spotify-annoy) - target_compile_definitions(_spotify_annoy PUBLIC ENABLE_ANNOY) + add_contrib(annoy-cmake annoy) + target_compile_definitions(_annoy PUBLIC ENABLE_ANNOY) endif() # Put all targets defined here and in subdirectories under "contrib/" folders in GUI-based IDEs. diff --git a/contrib/annoy b/contrib/annoy new file mode 160000 index 00000000000..301ff04e221 --- /dev/null +++ b/contrib/annoy @@ -0,0 +1 @@ +Subproject commit 301ff04e2213abaa7cbe30041b9b576c968bd994 diff --git a/contrib/annoy-cmake/CMakeLists.txt b/contrib/annoy-cmake/CMakeLists.txt new file mode 100644 index 00000000000..967659fd2b8 --- /dev/null +++ b/contrib/annoy-cmake/CMakeLists.txt @@ -0,0 +1,9 @@ +set(ANNOY_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/annoy") +set(ANNOY_SOURCE_DIR "${ANNOY_PROJECT_DIR}/src") +set(ANNOY_INCLUDE_DIR "${ANNOY_PROJECT_DIR}/src") + +add_library(_annoy ${ANNOY_SOURCE_DIR}/mman.h) +target_include_directories(_annoy SYSTEM PUBLIC ${ANNOY_INCLUDE_DIR}) +set_target_properties(_annoy PROPERTIES LINKER_LANGUAGE CXX) + +add_library(ch_contrib::annoy ALIAS _annoy) diff --git a/contrib/spotify-annoy-cmake/CMakeLists.txt b/contrib/spotify-annoy-cmake/CMakeLists.txt deleted file mode 100644 index 4172c9a13a2..00000000000 --- a/contrib/spotify-annoy-cmake/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -set(SPOTIFY_ANNOY_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/spotify-annoy") -set(SPOTIFY_ANNOY_SOURCE_DIR "${SPOTIFY_ANNOY_PROJECT_DIR}/src") -set(SPOTIFY_ANNOY_INCLUDE_DIR "${SPOTIFY_ANNOY_PROJECT_DIR}/src") - -set(SPOTIFY_ANNOY_SRC - ${SPOTIFY_ANNOY_SOURCE_DIR}/mman.h -) - -add_library(_spotify_annoy ${SPOTIFY_ANNOY_SRC}) -target_include_directories(_spotify_annoy SYSTEM PUBLIC ${SPOTIFY_ANNOY_INCLUDE_DIR}) -set_target_properties(_spotify_annoy PROPERTIES LINKER_LANGUAGE CXX) - -add_library(ch_contrib::spotify-annoy ALIAS _spotify_annoy) diff --git a/docs/en/engines/table-engines/mergetree-family/ann_indexes.md b/docs/en/engines/table-engines/mergetree-family/ann_indexes.md new file mode 100644 index 00000000000..a14cd49064f --- /dev/null +++ b/docs/en/engines/table-engines/mergetree-family/ann_indexes.md @@ -0,0 +1,76 @@ +# Approximate Nearest Neighbor Search Indexes [experimental] {#table_engines-ANNIndex} + +# TODO Embedings + +Approximate Nearest Neighbor Search Indexes (`ANNIndexes`) are simmilar to skip indexes. They are constructed by some granules and determine which of them should be skipped. Compared to skip indices, ANN indices use their results not only to skip some group of granules, but also to select particular granules from a set of granules. + +`ANNIndexes` are designed to speed up two types of queries: + +- ###### Type 1: Where + ``` sql + SELECT * + FROM table_name + WHERE DistanceFunction(Column, TargetVector) < Value + LIMIT N + ``` +- ###### Type 2: Order by + ``` sql + SELECT * + FROM table_name [WHERE ...] + ORDER BY DistanceFunction(Column, TargetVector) + LIMIT N + ``` + +In these queries, `DistanceFunction` is selected from tuples of distance functions. `TargetVector` is a known embedding (something like `(0.1, 0.1, ... )`). `Value` - a float value that will bound the neighbourhood. + +!!! note "Note" + ANNIndex can't speed up query that satisfies both types and they work only for Tuples. All queries must have the limit, as algorithms are used to find nearest neighbors and need a specific number of them. + +Both types of queries are handled the same way. The indexes get `n` neighbors (where `n` is taken from the `LIMIT` section) and work with them. In `ORDER BY` query they remember the numbers of all parts of the granule that have at least one of neighbor. In `WHERE` query they remember only those parts that satisfy the requirements. + +###### Create table with ANNIndex +``` +CREATE TABLE t +( + `id` Int64, + `number` Tuple(Float32, Float32, Float32), + INDEX x number TYPE annoy GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` + +!!! note "Note" + ANNIndexes work only when setting `index_granularity=8192`. + +Number of granules in granularity should be large. With greater `GRANULARITY` indexes remember the data structure better. But some indexes can't be built if they don't have enough data, so this granule will always participate in the query. For more information, see the description of indexes. + +As the indexes are built only during insertions into table, `INSERT` and `OPTIMIZE` queries are slower than for ordinary table. At this stage indexes remember all the information about the given data. ANNIndexes should be used if you have immutable or rarely changed data and many read requests. + +You can create your table with index which uses certain algorithm. Now only indices based on the following algorithms are supported: + +##### Index list +- Annoy + +# Annoy {#annoy} +Implementation of the algorithm was taken from [this repository](https://github.com/spotify/annoy). + +Short description of the algorithm: +The algorithm recursively divides in half all space by random linear surfaces (lines in 2D, planes in 3D e.t.c.). Thus it makes tree of polyhedrons and points that they contains. Repeating the operation several times for greater accuracy it creates a forest. +To find K Nearest Neighbours it goes down through the trees and fills the buffer of closest points using the priority queue of polyhedrons. Next, it sorts buffer and return the nearest K points. + +__Example__: +```sql +CREATE TABLE t +( + id Int64, + number Tuple(Float32, Float32, Float32), + INDEX x number TYPE annoy(T) GRANULARITY N +) +ENGINE = MergeTree +ORDER BY id; +``` +Parameter `T` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). + +In the `SELECT` in the settings (`ann_index_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). +This parameter may help you to adjust the trade-off between query speed and accuracy. diff --git a/docs/en/engines/table-engines/mergetree-family/mergetree.md b/docs/en/engines/table-engines/mergetree-family/mergetree.md index 3793526ba64..6ad6bc877a2 100644 --- a/docs/en/engines/table-engines/mergetree-family/mergetree.md +++ b/docs/en/engines/table-engines/mergetree-family/mergetree.md @@ -480,6 +480,8 @@ For example: - `NOT startsWith(s, 'test')` ::: +In addition to skip indices, there are also [Approximate Nearest Neighbor Search Indexes](../../../engines/table-engines/mergetree-family/replication.md). + ## Projections {#projections} Projections are like [materialized views](../../../sql-reference/statements/create/view.md#materialized) but defined in part-level. It provides consistency guarantees along with automatic usage in queries. @@ -1032,73 +1034,3 @@ Examples of working configurations can be found in integration tests directory ( - `_partition_value` — Values (a tuple) of a `partition by` expression. - `_sample_factor` — Sample factor (from the query). -# ANN Skip Index [experimental] {#table_engines-ANNIndex} - -`ANNIndexes` are designed to speed up two types of queries: - -- ###### Type 1: Where - ``` sql - SELECT * FROM table_name WHERE - DistanceFunction(Column, TargetVector) < Value - LIMIT N - ``` -- ###### Type 2: OrderBy - ``` sql - SELECT * FROM table_name [WHERE ...] OrderBy - DistanceFunction(Column, TargetVector) - LIMIT N - ``` - -In these queries, `DistanceFunction` is selected from tuples of distance functions. `TargetVector` is a known embedding (something like `(0.1, 0.1, ... )`). `Value` - a float value that will bound the neighbourhood. - -!!! note "Note" - ANNIndex can't speed up query that satisfies both types and they work only for Tuples. All queries must have the limit, as algorithms are used to find nearest neighbors and need a specific number of them. - -Both types of queries are handled the same way. The indexes get `n` neighbors (where `n` is taken from the `LIMIT` section) and work with them. In `ORDER BY` query they remember the numbers of all parts of the granule that have at least one of neighbor. In `WHERE` query they remember only those parts that satisfy the requirements. - -###### Create table with ANNIndex -``` -CREATE TABLE t -( - `id` Int64, - `number` Tuple(Float32, Float32, Float32), - INDEX x number TYPE annoy GRANULARITY N -) -ENGINE = MergeTree -ORDER BY id; -``` - -!!! note "Note" - ANNIndexes work only when setting `index_granularity=8192`. - -Number of granules in granularity should be large. With greater `GRANULARITY` indexes remember the data structure better. But some indexes can't be built if they don't have enough data, so this granule will always participate in the query. For more information, see the description of indexes. - -As the indexes are built only during insertions into table, `INSERT` and `OPTIMIZE` queries are slower than for ordinary table. OAt this stage indexes remember all the information about the given data. ANNIndexes should be used if you have immutable or rarely changed data and many read requests. - -You can create your table with index which uses certain algorithm. Now only indices based on the following algorithms are supported: - -##### Index list -- Annoy - -# Annoy {#annoy} -Implementation of the algorithm was taken from [this repository](https://github.com/spotify/annoy). - -Short description of the algorithm: -The algorithm recursively divides in half all space by random linear surfaces (lines in 2D, planes in 3D e.t.c.). Thus it makes tree of polyhedrons and points that they contains. Repeating the operation several times for greater accuracy it creates a forest. -To find K Nearest Neighbours it goes down through the trees and fills the buffer of closest points using the priority queue of polyhedrons. Next, it sorts buffer and return the nearest K points. - -__Example__: -```sql -CREATE TABLE t -( - id Int64, - number Tuple(Float32, Float32, Float32), - INDEX x number TYPE annoy(T) GRANULARITY N -) -ENGINE = MergeTree -ORDER BY id; -``` -Parameter `T` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). - -In the `SELECT` in the settings (`ann_index_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). -This parameter may help you to adjust the trade-off between query speed and accuracy. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1865ce66d09..b4ec86e2991 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -545,8 +545,8 @@ if (TARGET ch_contrib::rapidjson) endif() dbms_target_link_libraries(PUBLIC ch_contrib::consistent_hashing) -if (ENABLE_ANNOY) - dbms_target_link_libraries(PUBLIC ch_contrib::spotify-annoy) +if (TARGET ch_contrib::annoy AND ENABLE_ANNOY) + dbms_target_link_libraries(PUBLIC ch_contrib::annoy) endif() include ("${ClickHouse_SOURCE_DIR}/cmake/add_check.cmake") diff --git a/src/Core/Settings.h b/src/Core/Settings.h index 2b7e43f2c81..c45fdd0b290 100644 --- a/src/Core/Settings.h +++ b/src/Core/Settings.h @@ -599,7 +599,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value) M(Bool, allow_experimental_hash_functions, false, "Enable experimental hash functions (hashid, etc)", 0) \ M(Bool, allow_experimental_object_type, false, "Allow Object and JSON data types", 0) \ M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \ - M(String, ann_index_params, "", "Parameters for ANNIndexes in select queries", 0) \ + M(String, ann_index_params, "", "Parameters for ANNIndexes in select queries. String of parameters like `param1=x,param2=y...`. See ANNIndexes documentation for each index", 0) \ M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \ M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \ M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \ diff --git a/src/Storages/MergeTree/CommonANNIndexes.cpp b/src/Storages/MergeTree/CommonANNIndexes.cpp index d6518676d5b..a3b6e018d7a 100644 --- a/src/Storages/MergeTree/CommonANNIndexes.cpp +++ b/src/Storages/MergeTree/CommonANNIndexes.cpp @@ -1,3 +1,6 @@ +#include +#include + #include #include #include @@ -7,8 +10,6 @@ #include -#include - namespace DB { @@ -18,23 +19,14 @@ namespace ErrorCodes extern const int INCORRECT_QUERY; } -namespace ANNCondition +namespace ApproximateNearestNeighbour { ANNCondition::ANNCondition(const SelectQueryInfo & query_info, - ContextPtr context) -{ - // Initialize - block_with_constants = KeyCondition:: - getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context); - // Build rpns for query sections - buildRPN(query_info); - - // Match rpns with supported types - index_is_useful = matchAllRPNS(); - // Get from settings ANNIndex parameters - ann_index_params = context->getSettings().get("ann_index_params").get(); -} + ContextPtr context) : + block_with_constants{KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context)}, + ann_index_params{context->getSettings().get("ann_index_params").get()}, + index_is_useful{checkQueryStructure(query_info)} {} bool ANNCondition::alwaysUnknownOrTrue(String metric_name) const { @@ -43,87 +35,60 @@ bool ANNCondition::alwaysUnknownOrTrue(String metric_name) const return true; // Query isn't supported } // If query is supported, check metrics for match - return !(metric_name == ann_expr->metric_name); + return !(metric_name == query_information->metric_name); } -float ANNCondition::getComparisonDistance() const +///TODO: check for all getters? +float ANNCondition::getComparisonDistanceForWhereQuery() const { - if (where_query_type) + ///TODO: query_information->??? + if (query_information->query_type == ANNQueryInformation::Type::WhereQuery) { - return ann_expr->distance; + return query_information->distance; } throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type"); } -std::vector ANNCondition::getTargetVector() const -{ - return ann_expr->target; -} - -String ANNCondition::getColumnName() const -{ - return ann_expr->column_name; -} - -String ANNCondition::getMetric() const -{ - return ann_expr->metric_name; -} - -size_t ANNCondition::getSpaceDim() const -{ - return ann_expr->target.size(); -} - -float ANNCondition::getPForLpDistance() const -{ - return ann_expr->p_for_lp_dist; -} - -bool ANNCondition::queryHasWhereClause() const -{ - return where_query_type; -} - -bool ANNCondition::queryHasOrderByClause() const -{ - return order_by_query_type; -} - UInt64 ANNCondition::getLimitCount() const { if (index_is_useful) { - return limit_count; + return query_information->limit; } throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported"); } -String ANNCondition::getSettingsStr() const +bool ANNCondition::checkQueryStructure(const SelectQueryInfo & query) { - return ann_index_params; -} + // RPN-s for different sections of the query + RPN rpn_prewhere_clause; + RPN rpn_where_clause; + RPN rpn_order_by_clause; + RPNElement rpn_limit; -void ANNCondition::buildRPN(const SelectQueryInfo & query) -{ + ANNQueryInformation prewhere_info; + ANNQueryInformation where_info; + ANNQueryInformation order_by_info; + + // Build rpns for query sections const auto & select = query.query->as(); - if (select.prewhere()) // If query has PREWHERE section + if (select.prewhere()) // If query has PREWHERE clause { traverseAST(select.prewhere(), rpn_prewhere_clause); } - if (select.where()) // If query has WHERE section + if (select.where()) // If query has WHERE clause { traverseAST(select.where(), rpn_where_clause); } - if (select.limitLength()) // If query has LIMIT section + if (select.limitLength()) // If query has LIMIT clause { - traverseAST(select.limitLength(), rpn_limit_clause); + traverseAtomAST(select.limitLength(), rpn_limit); } - if (select.orderBy()) // If query has ORDERBY section + if (select.orderBy()) // If query has ORDERBY clause { traverseOrderByAST(select.orderBy(), rpn_order_by_clause); } @@ -132,18 +97,55 @@ void ANNCondition::buildRPN(const SelectQueryInfo & query) std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end()); std::reverse(rpn_where_clause.begin(), rpn_where_clause.end()); std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end()); + + // Match rpns with supported types and extract information + const bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, prewhere_info); + const bool where_is_valid = matchRPNWhere(rpn_where_clause, where_info); + const bool limit_is_valid = matchRPNLimit(rpn_limit, query_information->limit); + const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, order_by_info); + + // Query without LIMIT clause is not supported + if (!limit_is_valid) + { + return false; + } + + // Search type query in both sections isn't supported + if (prewhere_is_valid && where_is_valid) + { + return false; + } + + // Search type should be in WHERE or PREWHERE clause + if (prewhere_is_valid || where_is_valid) + { + query_information = std::move(where_is_valid ? where_info : prewhere_info); + } + + if (order_by_is_valid) + { + // Query with valid where and order by type is not supported + if (query_information.has_value()) + { + return false; + } + + query_information = std::move(order_by_info); + } + + return query_information.has_value(); } void ANNCondition::traverseAST(const ASTPtr & node, RPN & rpn) { - // If the node is ASTFUunction, it may have children nodes + // If the node is ASTFunction, it may have children nodes if (const auto * func = node->as()) { - const ASTs & args = func->arguments->children; + const ASTs & children = func->arguments->children; // Traverse children nodes - for (const auto& arg : args) + for (const auto& child : children) { - traverseAST(arg, rpn); + traverseAST(child, rpn); } } @@ -221,6 +223,8 @@ bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out) out.func_name = "Float literal"; return true; } + + /// TODO: Uint? if (const_value.getType() == Field::Types::UInt64) { out.function = RPNElement::FUNCTION_INT_LITERAL; @@ -228,6 +232,7 @@ bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out) out.func_name = "Int literal"; return true; } + if (const_value.getType() == Field::Types::Int64) { out.function = RPNElement::FUNCTION_INT_LITERAL; @@ -235,6 +240,7 @@ bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out) out.func_name = "Int literal"; return true; } + if (const_value.getType() == Field::Types::Tuple) { out.function = RPNElement::FUNCTION_LITERAL_TUPLE; @@ -258,153 +264,8 @@ void ANNCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn) } } -bool ANNCondition::matchAllRPNS() -{ - ANNExpression expr_prewhere; - ANNExpression expr_where; - ANNExpression expr_order_by; - LimitExpression expr_limit; - bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, expr_prewhere); - bool where_is_valid = matchRPNWhere(rpn_where_clause, expr_where); - bool limit_is_valid = matchRPNLimit(rpn_limit_clause, expr_limit); - bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, expr_order_by); - - // Query without LIMIT section is not supported - if (!limit_is_valid) - { - return false; - } - // Set LIMIT count - limit_count = expr_limit.length; - // Search type query in both sections isn't supported - if (prewhere_is_valid && where_is_valid) - { - return false; - } - // Search type should be in WHERE or PREWHERE section - if (prewhere_is_valid || where_is_valid) - { - ann_expr = std::move(where_is_valid ? expr_where : expr_prewhere); - where_query_type = true; - } - if (order_by_is_valid) - { - ann_expr = std::move(expr_order_by); - order_by_query_type = true; - } - // Query with valid search and orderby type is not supported - if (where_query_type && order_by_query_type) - { - return false; - } - - return where_query_type || order_by_query_type; -} - -bool ANNCondition::matchRPNLimit(RPN & rpn, LimitExpression & expr) -{ - // LIMIT section must have least 1 expression - if (rpn.size() != 1) - { - return false; - } - if (rpn.front().function == RPNElement::FUNCTION_INT_LITERAL) - { - expr.length = rpn.front().int_literal.value(); - return true; - } - return false; -} - -bool ANNCondition::matchRPNOrderBy(RPN & rpn, ANNExpression & expr) -{ - // ORDERBY section must have at least 3 expressions - if (rpn.size() < 3) - { - return false; - } - - auto iter = rpn.begin(); - auto end = rpn.end(); - bool identifier_found = false; - - return ANNCondition::matchMainParts(iter, end, expr, identifier_found); -} - -bool ANNCondition::matchMainParts(RPN::iterator & iter, RPN::iterator & end, - ANNExpression & expr, bool & identifier_found) - { - // Matches DistanceFunc->[Column]->[TupleFunc]->TargetVector(floats)->[Column] - if (iter->function != RPNElement::FUNCTION_DISTANCE) - { - return false; - } - - expr.metric_name = iter->func_name; - ++iter; - - if (expr.metric_name == "LpDistance") - { - if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL && - iter->function != RPNElement::FUNCTION_INT_LITERAL) - { - return false; - } - expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter); - ++iter; - } - - - if (iter->function == RPNElement::FUNCTION_IDENTIFIER) - { - identifier_found = true; - expr.column_name = getIdentifierOrPanic(iter); - ++iter; - } - - if (iter->function == RPNElement::FUNCTION_TUPLE) - { - ++iter; - } - - if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE) - { - for (const auto & value : iter->tuple_literal.value()) - { - expr.target.emplace_back(value.get()); - } - ++iter; - } - - - while (iter != end) - { - if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL || - iter->function == RPNElement::FUNCTION_INT_LITERAL) - { - expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter)); - } - else if (iter->function == RPNElement::FUNCTION_IDENTIFIER) - { - if (identifier_found) - { - return false; - } - expr.column_name = getIdentifierOrPanic(iter); - identifier_found = true; - } - else - { - return false; - } - - ++iter; - } - - return true; -} - -bool ANNCondition::matchRPNWhere(RPN & rpn, ANNExpression & expr) +// Returns true and stores ANNQueryInformation if the query has valid WHERE clause +bool ANNCondition::matchRPNWhere(RPN & rpn, ANNQueryInformation & expr) { // WHERE section must have at least 5 expressions // Operator->Distance(float)->DistanceFunc->Column->TupleFunc(TargetVector(floats)) @@ -465,24 +326,112 @@ bool ANNCondition::matchRPNWhere(RPN & rpn, ANNExpression & expr) expr.target.pop_back(); } - // Querry is ok + // query is ok return true; } -String ANNCondition::getIdentifierOrPanic(RPN::iterator& iter) +// Returns true and stores ANNExpr if the query has valid ORDERBY clause +bool ANNCondition::matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr) { - String identifier; - try + // ORDER BY clause must have at least 3 expressions + if (rpn.size() < 3) { - identifier = std::move(iter->identifier.value()); + return false; } - catch (...) - { - ANNCondition::panicIfWrongBuiltRPN(); - } - return identifier; + + auto iter = rpn.begin(); + auto end = rpn.end(); + bool identifier_found = false; + + return ANNCondition::matchMainParts(iter, end, expr, identifier_found); } +// Returns true and stores Length if we have valid LIMIT clause in query +bool ANNCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit) +{ + if (rpn.function == RPNElement::FUNCTION_INT_LITERAL) + { + limit = rpn.int_literal.value(); + return true; + } + + return false; +} + +/* Matches dist function, target vector, column name */ +bool ANNCondition::matchMainParts(RPN::iterator & iter, RPN::iterator & end, ANNQueryInformation & expr, bool & identifier_found) +{ + // Matches DistanceFunc->[Column]->[TupleFunc]->TargetVector(floats)->[Column] + if (iter->function != RPNElement::FUNCTION_DISTANCE) + { + return false; + } + + expr.metric_name = iter->func_name; + ++iter; + + if (expr.metric_name == "LpDistance") + { + if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL && + iter->function != RPNElement::FUNCTION_INT_LITERAL) + { + return false; + } + expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter); + ++iter; + } + + + if (iter->function == RPNElement::FUNCTION_IDENTIFIER) + { + identifier_found = true; + expr.column_name = std::move(iter->identifier.value()); + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_TUPLE) + { + ++iter; + } + + if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE) + { + for (const auto & value : iter->tuple_literal.value()) + { + expr.target.emplace_back(value.get()); + } + ++iter; + } + + + while (iter != end) + { + if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL || + iter->function == RPNElement::FUNCTION_INT_LITERAL) + { + expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter)); + } + else if (iter->function == RPNElement::FUNCTION_IDENTIFIER) + { + if (identifier_found) + { + return false; + } + expr.column_name = std::move(iter->identifier.value()); + identifier_found = true; + } + else + { + return false; + } + + ++iter; + } + + return true; +} + +// Gets float or int from AST node float ANNCondition::getFloatOrIntLiteralOrPanic(RPN::iterator& iter) { if (iter->float_literal.has_value()) @@ -493,14 +442,7 @@ float ANNCondition::getFloatOrIntLiteralOrPanic(RPN::iterator& iter) { return static_cast(iter->int_literal.value()); } - ANNCondition::panicIfWrongBuiltRPN(); -} - -void ANNCondition::panicIfWrongBuiltRPN() -{ - LOG_DEBUG(&Poco::Logger::get("ANNCondition"), "Wrong parsing of AST"); - throw Exception( - "Wrong parsed AST in buildRPN\n", DB::ErrorCodes::INCORRECT_QUERY); + throw Exception("Wrong parsed AST in buildRPN\n", ErrorCodes::INCORRECT_QUERY); } } diff --git a/src/Storages/MergeTree/CommonANNIndexes.h b/src/Storages/MergeTree/CommonANNIndexes.h index d18a90c2f11..533461559be 100644 --- a/src/Storages/MergeTree/CommonANNIndexes.h +++ b/src/Storages/MergeTree/CommonANNIndexes.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include "base/types.h" @@ -10,9 +9,43 @@ namespace DB { -namespace ANNCondition +namespace ApproximateNearestNeighbour { +/** + * Queries for Approximate Nearest Neighbour Search + * have similar structure: + * 1) target vector from which all distances are calculated + * 2) metric name (e.g L2Distance, LpDistance, etc.) + * 3) name of column with embeddings + * 4) type of query + * 5) Number of elements, that should be taken (limit) + * + * And two optional parameters: + * 1) p for LpDistance function + * 2) distance to compare with (only for where queries) + */ +struct ANNQueryInformation +{ + using Embedding = std::vector; + + // Extracted data from valid query + Embedding target; + String metric_name; + String column_name; + UInt64 limit; + + enum class Type + { + Undefined, + OrderByQuery, + WhereQuery + } query_type; + + float p_for_lp_dist = -1.0; + float distance = -1.0; +}; + /** Class ANNCondition, is responsible for recognizing special query types which can be speeded up by ANN Indexes. It parses the SQL query and checks @@ -44,12 +77,12 @@ namespace ANNCondition * distance to compare(ONLY for search types, otherwise you get exception) * spaceDimension(which is targetVector's components count) * column - * objects count from LIMIT section(for both queries) + * objects count from LIMIT clause(for both queries) * settings str, if query has settings section with new 'ann_index_params' value, than you can get the new value(empty by default) calling method getSettingsStr * queryHasOrderByClause and queryHasWhereClause return true if query matches the type - Search query type is also recognized for PREWHERE section + Search query type is also recognized for PREWHERE clause */ class ANNCondition @@ -58,60 +91,37 @@ public: ANNCondition(const SelectQueryInfo & query_info, ContextPtr context); - // flase if query can be speeded up, true otherwise + // false if query can be speeded up, true otherwise bool alwaysUnknownOrTrue(String metric_name) const; // returns the distance to compare with for search query - float getComparisonDistance() const; + float getComparisonDistanceForWhereQuery() const; // distance should be calculated regarding to targetVector - std::vector getTargetVector() const; + std::vector getTargetVector() const { return query_information->target; } // targetVector dimension size - size_t getSpaceDim() const; + size_t getNumOfDimensions() const { return query_information->target.size(); } - // data Column Name in DB - String getColumnName() const; + ///TODO: nullptr + String getColumnName() const { return query_information->column_name; } - // Distance function name - String getMetric() const; + String getMetricName() const { return query_information->metric_name; } // the P- value if the metric is 'LpDistance' - float getPForLpDistance() const; + float getPValueForLpDistance() const { return query_information->p_for_lp_dist; } - // true if query match ORDERBY type - bool queryHasOrderByClause() const; + bool queryHasOrderByClause() const { return query_information->query_type == ANNQueryInformation::Type::OrderByQuery; } - // true if query match Search type - bool queryHasWhereClause() const; + bool queryHasWhereClause() const { return query_information->query_type == ANNQueryInformation::Type::WhereQuery; } - // length's value from LIMIT section, nullopt if not any + // length's value from LIMIT clause, nullopt if not any UInt64 getLimitCount() const; - // value of 'ann_index_params' if have in SETTINGS section, empty string otherwise - String getSettingsStr() const; + // value of 'ann_index_params' if have in SETTINGS clause, empty string otherwise + String getParamsStr() const { return ann_index_params; } private: - // Type of the vector to use as a target in the distance function - using Target = std::vector; - - // Extracted data from valid query - struct ANNExpression - { - Target target; - float distance = -1.0; - String metric_name; - String column_name; - float p_for_lp_dist = -1.0; // The P parameter for LpDistance - }; - - struct LimitExpression - { - Int64 length; - }; - - using ANNExprOpt = std::optional; - using LimitExprOpt = std::optional; struct RPNElement { @@ -150,15 +160,15 @@ private: std::optional float_literal; std::optional identifier; - std::optional int_literal{std::nullopt}; - std::optional tuple_literal{std::nullopt}; + std::optional int_literal; + std::optional tuple_literal; - UInt32 dim{0}; + UInt32 dim = 0; }; using RPN = std::vector; - void buildRPN(const SelectQueryInfo & query); + bool checkQueryStructure(const SelectQueryInfo & query); // Util functions for the traversal of AST, parses AST and builds rpn void traverseAST(const ASTPtr & node, RPN & rpn); @@ -171,45 +181,34 @@ private: // Checks that at least one rpn is matching for index // New RPNs for other query types can be added here - bool matchAllRPNS(); + bool matchAllRPNS(); // Returns true and stores ANNExpr if the query has valid WHERE section - static bool matchRPNWhere(RPN & rpn, ANNExpression & expr); + static bool matchRPNWhere(RPN & rpn, ANNQueryInformation & expr); // Returns true and stores ANNExpr if the query has valid ORDERBY section - static bool matchRPNOrderBy(RPN & rpn, ANNExpression & expr); + static bool matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr); // Returns true and stores Length if we have valid LIMIT clause in query - static bool matchRPNLimit(RPN & rpn, LimitExpression & expr); + static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit); /* Matches dist function, target vector, column name */ - static bool matchMainParts(RPN::iterator & iter, RPN::iterator & end, ANNExpression & expr, bool & identifier_found); + static bool matchMainParts(RPN::iterator & iter, RPN::iterator & end, ANNQueryInformation & expr, bool & identifier_found); // Util methods static void panicIfWrongBuiltRPN [[noreturn]] (); - static String getIdentifierOrPanic(RPN::iterator& iter); + // Gets float or int from AST node static float getFloatOrIntLiteralOrPanic(RPN::iterator& iter); - - // RPN-s for different sections of the query - RPN rpn_prewhere_clause; - RPN rpn_where_clause; - RPN rpn_limit_clause; - RPN rpn_order_by_clause; - Block block_with_constants; - // Data extracted from query, in case query has supported type - ANNExprOpt ann_expr{std::nullopt}; - UInt64 limit_count{0}; - String ann_index_params; // Empty string if no params - - bool order_by_query_type{false}; - bool where_query_type{false}; - // true if we have one of two supported query types - bool index_is_useful{false}; + std::optional query_information; + + // Get from settings ANNIndex parameters + String ann_index_params; + bool index_is_useful = false; }; // condition interface for Ann indexes. Returns vector of indexes of ranges in granule which are useful for query. @@ -219,6 +218,8 @@ public: virtual std::vector getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const = 0; }; -} +} // namespace ApproximateNearestNeighbour -} +namespace ANN = ApproximateNearestNeighbour; + +} // namespace DB diff --git a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp index f350e9039cd..e90821195f1 100644 --- a/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp +++ b/src/Storages/MergeTree/MergeTreeDataSelectExecutor.cpp @@ -1641,7 +1641,7 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex( if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin) granule = reader.read(); // Cast to Ann condition - auto ann_condition = std::dynamic_pointer_cast(condition); + auto ann_condition = std::dynamic_pointer_cast(condition); if (ann_condition != nullptr) { // vector of indexes of useful ranges diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp index 6132563ba5c..6bc711ab2ca 100644 --- a/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp @@ -17,10 +17,7 @@ namespace Annoy template void AnnoyIndexSerialize::serialize(WriteBuffer& ostr) const { - if (!Base::_built) - { - throw Exception("Annoy Index should be built before serialization", ErrorCodes::LOGICAL_ERROR); - } + assert(Base::_built); writeIntBinary(Base::_s, ostr); writeIntBinary(Base::_n_items, ostr); writeIntBinary(Base::_n_nodes, ostr); @@ -53,7 +50,7 @@ void AnnoyIndexSerialize::deserialize(ReadBuffer& istr) } template -float AnnoyIndexSerialize::getSpaceDim() const +float AnnoyIndexSerialize::getNumOfDimensions() const { return Base::get_f(); } @@ -89,7 +86,7 @@ bool MergeTreeIndexGranuleAnnoy::empty() const void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const { - writeIntBinary(index_base->getSpaceDim(), ostr); // write dimension + writeIntBinary(index_base->getNumOfDimensions(), ostr); // write dimension index_base->serialize(ostr); } @@ -102,9 +99,10 @@ void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeI } -MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy(const String & index_name_, - const Block & index_sample_block_, - int index_param_) +MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy( + const String & index_name_, + const Block & index_sample_block_, + int index_param_) : index_name(index_name_) , index_sample_block(index_sample_block_) , index_param(index_param_) @@ -127,8 +125,9 @@ void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, si { if (*pos >= block.rows()) throw Exception( - "The provided position is not less than the number of block rows. Position: " - + toString(*pos) + ", Block rows: " + toString(block.rows()) + ".", ErrorCodes::LOGICAL_ERROR); + ErrorCodes::LOGICAL_ERROR, + "The provided position is not less than the number of block rows. Position: {}, Block rows: {}.", + toString(*pos), toString(block.rows())); size_t rows_read = std::min(limit, block.rows() - *pos); @@ -187,7 +186,11 @@ std::vector MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex { UInt64 limit = condition.getLimitCount(); std::optional comp_dist - = condition.queryHasWhereClause() ? std::optional(condition.getComparisonDistance()) : std::nullopt; + = condition.queryHasWhereClause() ? std::optional(condition.getComparisonDistanceForWhereQuery()) : std::nullopt; + + if (comp_dist && comp_dist.value() < 0) + throw Exception(ErrorCodes::LOGICAL_ERROR, "Attemp to optimize query with where without distance"); + std::vector target_vec = condition.getTargetVector(); auto granule = std::dynamic_pointer_cast(idx_granule); @@ -197,10 +200,10 @@ std::vector MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex } auto annoy = granule->index_base; - if (condition.getSpaceDim() != annoy->getSpaceDim()) + if (condition.getNumOfDimensions() != annoy->getNumOfDimensions()) { - throw Exception("The dimension of the space in the request (" + toString(condition.getSpaceDim()) + ") " - + "does not match with the dimension in the index (" + toString(annoy->getSpaceDim()) + ")", ErrorCodes::INCORRECT_QUERY); + throw Exception("The dimension of the space in the request (" + toString(condition.getNumOfDimensions()) + ") " + + "does not match with the dimension in the index (" + toString(annoy->getNumOfDimensions()) + ")", ErrorCodes::INCORRECT_QUERY); } std::vector items; @@ -209,12 +212,12 @@ std::vector MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex dist.reserve(limit); int k_search = -1; - auto settings_str = condition.getSettingsStr(); - if (!settings_str.empty()) + auto params_str = condition.getParamsStr(); + if (!params_str.empty()) { try { - k_search = std::stoi(settings_str); + k_search = std::stoi(params_str); } catch (...) { @@ -259,11 +262,11 @@ MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition( return std::make_shared(index, query, context); }; -MergeTreeIndexFormat MergeTreeIndexAnnoy::getDeserializedFormat(const DiskPtr disk, const std::string & relative_path_prefix) const +MergeTreeIndexFormat MergeTreeIndexAnnoy::getDeserializedFormat(const DataPartStoragePtr & data_part_storage, const std::string & relative_path_prefix) const { - if (disk->exists(relative_path_prefix + ".idx2")) + if (data_part_storage->exists(relative_path_prefix + ".idx2")) return {2, ".idx2"}; - else if (disk->exists(relative_path_prefix + ".idx")) + else if (data_part_storage->exists(relative_path_prefix + ".idx")) return {1, ".idx"}; return {0 /* unknown */, ""}; } diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.h b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h index 9d65bfc7d35..228bdd00052 100644 --- a/src/Storages/MergeTree/MergeTreeIndexAnnoy.h +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h @@ -7,7 +7,7 @@ #include #include - +///TODO:Arrays namespace DB { @@ -27,7 +27,7 @@ namespace Annoy explicit AnnoyIndexSerialize(const int dim) : Base::AnnoyIndex(dim) {} void serialize(WriteBuffer& ostr) const; void deserialize(ReadBuffer& istr); - float getSpaceDim() const; + float getNumOfDimensions() const; }; } @@ -74,7 +74,7 @@ struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator }; -class MergeTreeIndexConditionAnnoy final : public ANNCondition::IMergeTreeIndexConditionAnn +class MergeTreeIndexConditionAnnoy final : public ANN::IMergeTreeIndexConditionAnn { public: MergeTreeIndexConditionAnnoy( @@ -91,14 +91,14 @@ public: ~MergeTreeIndexConditionAnnoy() override = default; private: - ANNCondition::ANNCondition condition; + ANN::ANNCondition condition; }; class MergeTreeIndexAnnoy : public IMergeTreeIndex { public: - explicit MergeTreeIndexAnnoy(const IndexDescription & index_, int index_param_) + MergeTreeIndexAnnoy(const IndexDescription & index_, int index_param_) : IMergeTreeIndex(index_) , index_param(index_param_) {} @@ -114,10 +114,10 @@ public: bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return true; } const char* getSerializedFileExtension() const override { return ".idx2"; } - MergeTreeIndexFormat getDeserializedFormat(const DiskPtr disk, const std::string & path_prefix) const override; + MergeTreeIndexFormat getDeserializedFormat(const DataPartStoragePtr & data_part_storage, const std::string & path_prefix) const override; private: - int index_param; + const int index_param; }; diff --git a/src/Storages/MergeTree/MergeTreeIndices.cpp b/src/Storages/MergeTree/MergeTreeIndices.cpp index 2374b2fbf8b..faf2c8a14e9 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.cpp +++ b/src/Storages/MergeTree/MergeTreeIndices.cpp @@ -102,10 +102,10 @@ MergeTreeIndexFactory::MergeTreeIndexFactory() registerCreator("hypothesis", hypothesisIndexCreator); registerValidator("hypothesis", hypothesisIndexValidator); - #ifdef ENABLE_ANNOY +#ifdef ENABLE_ANNOY registerCreator("annoy", AnnoyIndexCreator); registerValidator("annoy", AnnoyIndexValidator); - #endif +#endif } MergeTreeIndexFactory & MergeTreeIndexFactory::instance()