#ifdef ENABLE_ANNOY

// NOTE(review): the include targets and the template argument lists in this file were missing from the
// text under review. They are restored here from the identifiers the code uses — confirm against the
// corresponding header and the build before merging.
#include <Storages/MergeTree/MergeTreeIndexAnnoy.h>

#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>


namespace DB
{

namespace ErrorCodes
{
    extern const int ILLEGAL_COLUMN;
    extern const int INCORRECT_DATA;
    extern const int INCORRECT_NUMBER_OF_COLUMNS;
    extern const int INCORRECT_QUERY;
    extern const int LOGICAL_ERROR;
}


/// Creates an empty Annoy index for vectors with the given number of dimensions.
template <typename Distance>
AnnoyIndexWithSerialization<Distance>::AnnoyIndexWithSerialization(size_t dimensions)
    : Base::AnnoyIndex(dimensions)
{
}

/// Writes the internal state of a built index to `ostr`: the scalar fields, the roots vector
/// and finally the raw node storage (`_s * _n_nodes` bytes).
/// The field order must stay in sync with deserialize().
template <typename Distance>
void AnnoyIndexWithSerialization<Distance>::serialize(WriteBuffer & ostr) const
{
    chassert(Base::_built);
    writeIntBinary(Base::_s, ostr);
    writeIntBinary(Base::_n_items, ostr);
    writeIntBinary(Base::_n_nodes, ostr);
    writeIntBinary(Base::_nodes_size, ostr);
    writeIntBinary(Base::_K, ostr);
    writeIntBinary(Base::_seed, ostr);
    writeVectorBinary(Base::_roots, ostr);
    ostr.write(reinterpret_cast<const char *>(Base::_nodes), Base::_s * Base::_n_nodes);
}

/// Reads the state written by serialize() from `istr` into an index that has not been built yet,
/// then marks the index as built.
template <typename Distance>
void AnnoyIndexWithSerialization<Distance>::deserialize(ReadBuffer & istr)
{
    chassert(!Base::_built);
    readIntBinary(Base::_s, istr);
    readIntBinary(Base::_n_items, istr);
    readIntBinary(Base::_n_nodes, istr);
    readIntBinary(Base::_nodes_size, istr);
    readIntBinary(Base::_K, istr);
    readIntBinary(Base::_seed, istr);
    readVectorBinary(Base::_roots, istr);

    /// Resize the node storage to exactly the serialized size and fill it from the buffer.
    Base::_nodes = realloc(Base::_nodes, Base::_s * Base::_n_nodes);
    istr.readStrict(reinterpret_cast<char *>(Base::_nodes), Base::_s * Base::_n_nodes);

    Base::_fd = 0;
    // set flags
    Base::_loaded = false;
    Base::_verbose = false;
    Base::_on_disk = false;
    Base::_built = true;
}

/// Returns the value of the base class' get_f(), used throughout this file as the vector dimension.
template <typename Distance>
size_t AnnoyIndexWithSerialization<Distance>::getDimensions() const
{
    return Base::get_f();
}


/// Creates a granule without an index; the index is set later by deserializeBinary().
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
    : index_name(index_name_)
    , index_sample_block(index_sample_block_)
    , index(nullptr)
{}

/// Creates a granule that takes over the given index.
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
    const String & index_name_,
    const Block & index_sample_block_,
    AnnoyIndexWithSerializationPtr<Distance> index_)
    : index_name(index_name_)
    , index_sample_block(index_sample_block_)
    , index(std::move(index_))
{}

/// Writes the number of dimensions followed by the serialized index.
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
{
    /// Number of dimensions is required in the index constructor,
    /// so it must be written and read separately from the other part
    writeIntBinary(static_cast<UInt64>(index->getDimensions()), ostr); // write dimension
    index->serialize(ostr);
}

/// Reads the number of dimensions, constructs a fresh index of that dimension and deserializes into it.
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
    UInt64 dimension;
    readIntBinary(dimension, istr);
    index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(dimension);
    index->deserialize(istr);
}


/// Creates an aggregator that builds its index with `trees_` trees. No index exists until update() is called.
template <typename Distance>
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
    const String & index_name_,
    const Block & index_sample_block_,
    UInt64 trees_)
    : index_name(index_name_)
    , index_sample_block(index_sample_block_)
    , trees(trees_)
{}

/// Builds the accumulated index, wraps it into a granule and resets the aggregator.
/// Throws LOGICAL_ERROR if no vectors were accumulated via update() beforehand.
template <typename Distance>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
{
    /// Without this check a call that was not preceded by update() would dereference a null pointer.
    if (!index)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Annoy index aggregator has no accumulated data to build a granule from");

    // NOLINTNEXTLINE(*)
    index->build(static_cast<int>(trees), /*number_of_threads=*/1);
    auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance>>(index_name, index_sample_block, index);
    index = nullptr;
    return granule;
}

/// Adds up to `limit` rows of `block`, starting at row `*pos`, to the index and advances `*pos`
/// by the number of rows consumed.
///
/// The indexed column must be an Array or a Tuple column; within a granule all vectors must have the same length.
/// The index is created on the first call after construction / getGranuleAndReset() and is extended by
/// subsequent calls, so that rows from every block of the granule end up in the index.
///
/// Throws LOGICAL_ERROR on an out-of-range position, a multi-column sample block, an unexpected column
/// type or empty data, and INCORRECT_DATA if array lengths differ.
template <typename Distance>
void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t * pos, size_t limit)
{
    if (*pos >= block.rows())
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
            *pos,
            block.rows());

    size_t rows_read = std::min(limit, block.rows() - *pos);
    if (rows_read == 0)
        return;

    if (index_sample_block.columns() > 1)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");

    const String & index_column_name = index_sample_block.getByPosition(0).name;
    ColumnPtr column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);

    if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
    {
        const auto & data = column_array->getData();
        const auto & array = typeid_cast<const ColumnFloat32 &>(data).getData();

        if (array.empty())
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Array has 0 rows, {} rows expected", rows_read);

        const auto & offsets = column_array->getOffsets();
        const size_t num_rows = offsets.size();

        /// Check all sizes are the same
        size_t size = offsets[0];
        for (size_t i = 0; i < num_rows - 1; ++i)
            if (offsets[i + 1] - offsets[i] != size)
                throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column {} must have equal length", index_column_name);

        /// Vectors added by earlier calls for this granule must have the same length as the ones in this block.
        if (index && index->getDimensions() != size)
            throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column {} must have equal length", index_column_name);

        /// Create the index only once per granule: recreating it here would drop the rows of previous blocks.
        if (!index)
            index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(size);

        /// Add all rows of block
        index->add_item(index->get_n_items(), array.data());
        for (size_t current_row = 1; current_row < num_rows; ++current_row)
            index->add_item(index->get_n_items(), &array[offsets[current_row - 1]]);
    }
    else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
    {
        const auto & columns = column_tuple->getColumns();

        /// TODO check if calling index->add_item() directly on the block's tuples is faster than materializing everything
        /// Transpose the per-element columns into one vector per row.
        std::vector<std::vector<Float32>> data(column_tuple->size(), std::vector<Float32>());
        for (const auto & column : columns)
        {
            const auto & pod_array = typeid_cast<const ColumnFloat32 *>(column.get())->getData();
            for (size_t i = 0; i < pod_array.size(); ++i)
                data[i].push_back(pod_array[i]);
        }

        if (data.empty())
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Tuple has 0 rows, {} rows expected", rows_read);

        /// Create the index only once per granule: recreating it here would drop the rows of previous blocks.
        if (!index)
            index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(data[0].size());

        for (const auto & item : data)
            index->add_item(index->get_n_items(), item.data());
    }
    else
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");

    *pos += rows_read;
}


/// Builds the ANN condition from the query and remembers the distance function and the
/// `annoy_index_search_k_nodes` setting for later lookups.
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
    const IndexDescription & /*index_description*/,
    const SelectQueryInfo & query,
    const String & distance_function_,
    ContextPtr context)
    : ann_condition(query, context)
    , distance_function(distance_function_)
    , search_k(context->getSettings().annoy_index_search_k_nodes)
{}

/// Not supported for this index type: always throws LOGICAL_ERROR. Use getUsefulRanges() instead.
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /*idx_granule*/) const
{
    throw Exception(ErrorCodes::LOGICAL_ERROR, "mayBeTrueOnGranule is not supported for ANN skip indexes");
}

/// Delegates to the ANN condition, passing the distance function of this index.
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
{
    return ann_condition.alwaysUnknownOrTrue(distance_function);
}

/// Dispatches to the implementation instantiated for the configured distance function.
/// Only "L2Distance" and "cosineDistance" are handled; any other value reaches std::unreachable().
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
    if (distance_function == "L2Distance")
        return getUsefulRangesImpl<Annoy::Euclidean>(idx_granule);
    else if (distance_function == "cosineDistance")
        return getUsefulRangesImpl<Annoy::Angular>(idx_granule);
    std::unreachable();
}

/// Queries the granule's index for the nearest neighbours of the reference vector and returns the sorted,
/// de-duplicated numbers (neighbour id / index granularity) of the granules that contain them.
/// For WHERE-type queries, neighbours farther away than the comparison distance are dropped.
///
/// Throws LOGICAL_ERROR for a negative comparison distance or a granule of the wrong type, and
/// INCORRECT_QUERY if the reference vector and the index have different dimensions.
template <typename Distance>
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
    const UInt64 limit = ann_condition.getLimit();
    const UInt64 index_granularity = ann_condition.getIndexGranularity();
    const std::optional<float> comparison_distance = ann_condition.getQueryType() == ApproximateNearestNeighborInformation::Type::Where
        ? std::optional<float>(ann_condition.getComparisonDistanceForWhereQuery())
        : std::nullopt;

    if (comparison_distance && comparison_distance.value() < 0)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");

    const std::vector<float> reference_vector = ann_condition.getReferenceVector();

    const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance>>(idx_granule);
    if (granule == nullptr)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");

    const AnnoyIndexWithSerializationPtr<Distance> annoy = granule->index;

    if (ann_condition.getDimensions() != annoy->getDimensions())
        throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
                        "does not match the dimension in the index ({})",
                        ann_condition.getDimensions(), annoy->getDimensions());

    std::vector<UInt64> neighbors; /// indexes of dots which were closest to the reference vector
    std::vector<Float32> distances;
    neighbors.reserve(limit);
    distances.reserve(limit);

    annoy->get_nns_by_vector(reference_vector.data(), limit, static_cast<int>(search_k), &neighbors, &distances);

    chassert(neighbors.size() == distances.size());

    std::vector<size_t> granule_numbers;
    granule_numbers.reserve(neighbors.size());
    for (size_t i = 0; i < neighbors.size(); ++i)
    {
        if (comparison_distance && distances[i] > comparison_distance)
            continue;
        granule_numbers.push_back(neighbors[i] / index_granularity);
    }

    /// make unique
    std::sort(granule_numbers.begin(), granule_numbers.end());
    granule_numbers.erase(std::unique(granule_numbers.begin(), granule_numbers.end()), granule_numbers.end());

    return granule_numbers;
}


/// Stores the number of trees and the distance function name used by the factory methods below.
MergeTreeIndexAnnoy::MergeTreeIndexAnnoy(const IndexDescription & index_, UInt64 trees_, const String & distance_function_)
    : IMergeTreeIndex(index_)
    , trees(trees_)
    , distance_function(distance_function_)
{}

/// Creates an empty granule for the configured distance function.
/// Only "L2Distance" and "cosineDistance" are handled; any other value reaches std::unreachable().
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
    if (distance_function == "L2Distance")
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);
    else if (distance_function == "cosineDistance")
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);
    std::unreachable();
}

/// Creates an aggregator for the configured distance function.
/// Only "L2Distance" and "cosineDistance" are handled; any other value reaches std::unreachable().
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
{
    /// TODO: Support more metrics. Available metrics: https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
    if (distance_function == "L2Distance")
        return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees);
    else if (distance_function == "cosineDistance")
        return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees);
    std::unreachable();
}

/// Creates the query condition object for this index.
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
    return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, distance_function, context);
}

/// Creates an Annoy index from an index description.
/// Argument 0 (optional) is the distance function name, default "L2Distance";
/// argument 1 (optional) is the number of trees, default 100.
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
    static constexpr auto default_trees = 100uz;
    static constexpr auto default_distance_function = "L2Distance";

    String distance_function = default_distance_function;
    if (!index.arguments.empty())
        distance_function = index.arguments[0].get<String>();

    UInt64 trees = default_trees;
    if (index.arguments.size() > 1)
        trees = index.arguments[1].get<UInt64>();

    return std::make_shared<MergeTreeIndexAnnoy>(index, trees, distance_function);
}

/// Validates an Annoy index description.
/// Throws INCORRECT_QUERY for a wrong number or type of arguments, INCORRECT_NUMBER_OF_COLUMNS if the index
/// does not cover exactly one column, INCORRECT_DATA for an unsupported distance function and
/// ILLEGAL_COLUMN if the column is neither an array of Float32 nor a tuple of Float32 elements.
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
{
    /// Check number and type of Annoy index arguments:

    if (index.arguments.size() > 2)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index must not have more than two parameters");

    if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::String)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance function argument of Annoy index must be of type String");

    if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::UInt64)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Number of trees argument of Annoy index must be of type UInt64");

    /// Check that the index is created on a single column

    if (index.column_names.size() != 1 || index.data_types.size() != 1)
        throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column");

    /// Check that a supported metric was passed as first argument

    if (!index.arguments.empty())
    {
        String distance_name = index.arguments[0].get<String>();
        if (distance_name != "L2Distance" && distance_name != "cosineDistance")
            throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index only supports distance functions 'L2Distance' and 'cosineDistance'");
    }

    /// Check data type of indexed column:

    auto throw_unsupported_underlying_column_exception = []()
    {
        throw Exception(
            ErrorCodes::ILLEGAL_COLUMN,
            "Annoy indexes can only be created on columns of type Array(Float32) and Tuple(Float32)");
    };

    DataTypePtr data_type = index.sample_block.getDataTypes()[0];

    if (const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get()))
    {
        TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
        if (!WhichDataType(nested_type_index).isFloat32())
            throw_unsupported_underlying_column_exception();
    }
    else if (const auto * data_type_tuple = typeid_cast<const DataTypeTuple *>(data_type.get()))
    {
        /// Every element of the tuple must be Float32.
        const DataTypes & inner_types = data_type_tuple->getElements();
        for (const auto & inner_type : inner_types)
        {
            TypeIndex nested_type_index = inner_type->getTypeId();
            if (!WhichDataType(nested_type_index).isFloat32())
                throw_unsupported_underlying_column_exception();
        }
    }
    else
        throw_unsupported_underlying_column_exception();
}

}

#endif