From 8f8c622e7cee4379e6cae16de6a20be481c2bc78 Mon Sep 17 00:00:00 2001 From: Vladimir Makarov Date: Sat, 14 May 2022 08:24:54 +0000 Subject: [PATCH] fix --- contrib/CMakeLists.txt | 1 - src/CMakeLists.txt | 1 + .../MergeTree/MergeTreeIndexAnnoy.cpp | 112 ++++++++++++++---- src/Storages/MergeTree/MergeTreeIndexAnnoy.h | 4 +- src/Storages/MergeTree/MergeTreeIndices.cpp | 4 +- src/Storages/MergeTree/MergeTreeIndices.h | 2 + 6 files changed, 98 insertions(+), 26 deletions(-) diff --git a/contrib/CMakeLists.txt b/contrib/CMakeLists.txt index 864063d106d..7e5a75b9646 100644 --- a/contrib/CMakeLists.txt +++ b/contrib/CMakeLists.txt @@ -151,7 +151,6 @@ add_contrib (s2geometry-cmake s2geometry) set (ENABLE_ANNOY_DEFAULT ${ENABLE_LIBRARIES}) if (ENABLE_ANNOY) - add_compile_definitions(ENABLE_ANNOY) add_contrib (spotify-annoy-cmake spotify-annoy) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f7e4b326faf..fcd29032e7a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -523,6 +523,7 @@ endif() dbms_target_link_libraries(PUBLIC ch_contrib::consistent_hashing) if (ENABLE_ANNOY) + add_compile_definitions(ENABLE_ANNOY) dbms_target_link_libraries(PUBLIC ch_contrib::spotify-annoy) endif() diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp index bd6540e2c34..cd7ae7e099f 100644 --- a/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.cpp @@ -30,7 +30,8 @@ namespace Annoy template void AnnoyIndexSerialize::serialize(WriteBuffer& ostr) const { - if (!Base::_built) { + if (!Base::_built) + { throw Exception("Annoy Index should be built before serialization", ErrorCodes::LOGICAL_ERROR); } writeIntBinary(Base::_s, ostr); @@ -65,7 +66,8 @@ void AnnoyIndexSerialize::deserialize(ReadBuffer& istr) } template -float AnnoyIndexSerialize::getSpaceDim() const { +float AnnoyIndexSerialize::getSpaceDim() const +{ return Base::get_f(); } @@ -93,7 +95,8 @@ MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy( , index_base(std::move(index_base_)) {} -bool MergeTreeIndexGranuleAnnoy::empty() const { +bool MergeTreeIndexGranuleAnnoy::empty() const +{ return !static_cast(index_base); } @@ -127,7 +130,8 @@ bool MergeTreeIndexAggregatorAnnoy::empty() const MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy::getGranuleAndReset() { - if (empty()) { + if (empty()) + { return std::make_shared(index_name, index_sample_block); } index_base->build(index_param); @@ -145,7 +149,8 @@ void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, si size_t rows_read = std::min(limit, block.rows() - *pos); - if (index_sample_block.columns() > 1) { + if (index_sample_block.columns() > 1) + { throw Exception("Only one column is supported", ErrorCodes::LOGICAL_ERROR); } @@ -155,17 +160,21 @@ void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, si const auto & columns = column_tuple->getColumns(); std::vector> data{column_tuple->size(), std::vector()}; - for (size_t j = 0; j < columns.size(); ++j) { + for (size_t j = 0; j < columns.size(); ++j) + { const auto& pod_array = typeid_cast(columns[j].get())->getData(); - for (size_t i = 0; i < pod_array.size(); ++i) { + for (size_t i = 0; i < pod_array.size(); ++i) + { data[i].push_back(pod_array[i]); } } assert(!data.empty()); - if (!index_base) { + if (!index_base) + { index_base = std::make_shared(data[0].size()); } - for (const auto& item : data) { + for (const auto& item : data) + { index_base->add_item(index_base->get_n_items(), &item[0]); } @@ -178,24 +187,21 @@ MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy( const SelectQueryInfo & query, ContextPtr context) : condition(query, context) -{ -} +{} bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const { auto granule = std::dynamic_pointer_cast(idx_granule); - if (granule == nullptr) { + if (granule == nullptr) + { throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR); } auto annoy = granule->index_base; - if (condition.getMetric() != "L2Distance") { - throw Exception("The metric in the request (" + condition.getMetric() + ")" - + "does not match with the metric in the index (L2Distance)", ErrorCodes::INCORRECT_QUERY); - } - if (condition.getSpaceDim() == annoy->getSpaceDim()) { - throw Exception("The dimension of the space in the request (" + toString(condition.getSpaceDim()) + ")" + if (condition.getSpaceDim() != annoy->getSpaceDim()) + { + throw Exception("The dimension of the space in the request (" + toString(condition.getSpaceDim()) + ") " + "does not match with the dimension in the index (" + toString(annoy->getSpaceDim()) + ")", ErrorCodes::INCORRECT_QUERY); } std::vector target_vec = condition.getTargetVector(); @@ -206,14 +212,15 @@ bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr i int k_search = -1; auto settings_str = condition.getSettingsStr(); - if (!settings_str.empty()) { + if (!settings_str.empty()) + { try { k_search = std::stoi(settings_str); } catch (...) { - throw Exception("Setting of the annoy index should be int"); + throw Exception("Setting of the annoy index should be int", ErrorCodes::INCORRECT_QUERY); } } annoy->get_nns_by_vector(&target_vec[0], 1, k_search, &items, &dist); @@ -225,6 +232,65 @@ bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const return condition.alwaysUnknownOrTrue("L2Distance"); } +std::vector MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const +{ + UInt64 limit = condition.getLimitCount(); + std::optional comp_dist + = condition.queryHasWhereClause() ? std::optional(condition.getComparisonDistance()) : std::nullopt; + std::vector target_vec = condition.getTargetVector(); + + auto granule = std::dynamic_pointer_cast(idx_granule); + if (granule == nullptr) + { + throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR); + } + auto annoy = granule->index_base; + + if (condition.getSpaceDim() != annoy->getSpaceDim()) + { + throw Exception("The dimension of the space in the request (" + toString(condition.getSpaceDim()) + ") " + + "does not match with the dimension in the index (" + toString(annoy->getSpaceDim()) + ")", ErrorCodes::INCORRECT_QUERY); + } + + std::vector items; + std::vector dist; + items.reserve(limit); + dist.reserve(limit); + + int k_search = -1; + auto settings_str = condition.getSettingsStr(); + if (!settings_str.empty()) + { + try + { + k_search = std::stoi(settings_str); + } + catch (...) + { + throw Exception("Setting of the annoy index should be int", ErrorCodes::INCORRECT_QUERY); + } + } + annoy->get_nns_by_vector(&target_vec[0], 1, k_search, &items, &dist); + std::unordered_set result; + for (size_t i = 0; i < items.size(); ++i) + { + if (comp_dist && dist[i] > comp_dist) + { + continue; + } + result.insert(items[i] / 8192); + } + + std::vector result_vector; + result_vector.reserve(result.size()); + for (auto range : result) + { + result_vector.push_back(range); + } + + return result_vector; +} + MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const { @@ -260,10 +326,12 @@ MergeTreeIndexPtr AnnoyIndexCreator( void AnnoyIndexValidator(const IndexDescription & index, bool /* attach */) { - if (index.arguments.size() != 1) { + if (index.arguments.size() != 1) + { throw Exception("Annoy index must have exactly one argument.", ErrorCodes::INCORRECT_QUERY); } - if (index.arguments[0].getType() != Field::Types::UInt64) { + if (index.arguments[0].getType() != Field::Types::UInt64) + { throw Exception("Annoy index argument must be UInt64.", ErrorCodes::INCORRECT_QUERY); } } diff --git a/src/Storages/MergeTree/MergeTreeIndexAnnoy.h b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h index e58fd553916..06cb4ecae71 100644 --- a/src/Storages/MergeTree/MergeTreeIndexAnnoy.h +++ b/src/Storages/MergeTree/MergeTreeIndexAnnoy.h @@ -74,7 +74,7 @@ struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator }; -class MergeTreeIndexConditionAnnoy final : public IMergeTreeIndexCondition +class MergeTreeIndexConditionAnnoy final : public ANNCondition::IMergeTreeIndexConditionAnn { public: MergeTreeIndexConditionAnnoy( @@ -86,6 +86,8 @@ public: bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override; + std::vector getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override; + ~MergeTreeIndexConditionAnnoy() override = default; private: diff --git a/src/Storages/MergeTree/MergeTreeIndices.cpp b/src/Storages/MergeTree/MergeTreeIndices.cpp index 10a5d121436..2374b2fbf8b 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.cpp +++ b/src/Storages/MergeTree/MergeTreeIndices.cpp @@ -103,8 +103,8 @@ MergeTreeIndexFactory::MergeTreeIndexFactory() registerValidator("hypothesis", hypothesisIndexValidator); #ifdef ENABLE_ANNOY - registerCreator("annoy", AnnoyIndexCreator); - registerValidator("annoy", AnnoyIndexValidator); + registerCreator("annoy", AnnoyIndexCreator); + registerValidator("annoy", AnnoyIndexValidator); #endif } diff --git a/src/Storages/MergeTree/MergeTreeIndices.h b/src/Storages/MergeTree/MergeTreeIndices.h index 4c0a3fd679b..525677b635d 100644 --- a/src/Storages/MergeTree/MergeTreeIndices.h +++ b/src/Storages/MergeTree/MergeTreeIndices.h @@ -223,7 +223,9 @@ void bloomFilterIndexValidatorNew(const IndexDescription & index, bool attach); MergeTreeIndexPtr hypothesisIndexCreator(const IndexDescription & index); void hypothesisIndexValidator(const IndexDescription & index, bool attach); +#ifdef ENABLE_ANNOY MergeTreeIndexPtr AnnoyIndexCreator(const IndexDescription & index); void AnnoyIndexValidator(const IndexDescription & index, bool attach); +#endif }