2022-08-30 15:26:56 +00:00
|
|
|
#ifdef ENABLE_ANNOY
|
|
|
|
|
|
|
|
#include <Storages/MergeTree/MergeTreeIndexAnnoy.h>
|
|
|
|
|
|
|
|
#include <Common/typeid_cast.h>
|
|
|
|
#include <Core/Field.h>
|
|
|
|
#include <IO/ReadHelpers.h>
|
|
|
|
#include <IO/WriteHelpers.h>
|
|
|
|
#include <Interpreters/castColumn.h>
|
|
|
|
#include <Columns/ColumnArray.h>
|
|
|
|
#include <DataTypes/DataTypeArray.h>
|
2022-10-19 12:35:47 +00:00
|
|
|
#include <DataTypes/DataTypeTuple.h>

#include <algorithm>
#include <new>
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
|
|
|
|
namespace DB
|
|
|
|
{
|
|
|
|
|
2023-05-25 19:47:05 +00:00
|
|
|
/// Error codes thrown from this file (declared extern here, defined elsewhere).
namespace ErrorCodes
{
    extern const int BAD_ARGUMENTS;
    extern const int ILLEGAL_COLUMN;
    extern const int INCORRECT_DATA;
    extern const int INCORRECT_NUMBER_OF_COLUMNS;
    extern const int INCORRECT_QUERY;
    extern const int LOGICAL_ERROR;
}
|
|
|
|
|
2022-08-30 15:26:56 +00:00
|
|
|
|
2023-05-25 19:39:04 +00:00
|
|
|
template <typename Distance>
|
2023-05-25 19:55:01 +00:00
|
|
|
AnnoyIndexWithSerialization<Distance>::AnnoyIndexWithSerialization(uint64_t dim)
|
2023-05-25 19:39:04 +00:00
|
|
|
: Base::AnnoyIndex(dim)
|
|
|
|
{
|
|
|
|
}
|
|
|
|
|
2023-05-25 19:36:27 +00:00
|
|
|
template<typename Distance>
|
2023-05-25 19:55:01 +00:00
|
|
|
void AnnoyIndexWithSerialization<Distance>::serialize(WriteBuffer& ostr) const
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
2023-05-25 19:36:27 +00:00
|
|
|
chassert(Base::_built);
|
2022-08-30 15:26:56 +00:00
|
|
|
writeIntBinary(Base::_s, ostr);
|
|
|
|
writeIntBinary(Base::_n_items, ostr);
|
|
|
|
writeIntBinary(Base::_n_nodes, ostr);
|
|
|
|
writeIntBinary(Base::_nodes_size, ostr);
|
|
|
|
writeIntBinary(Base::_K, ostr);
|
|
|
|
writeIntBinary(Base::_seed, ostr);
|
|
|
|
writeVectorBinary(Base::_roots, ostr);
|
|
|
|
ostr.write(reinterpret_cast<const char*>(Base::_nodes), Base::_s * Base::_n_nodes);
|
|
|
|
}
|
|
|
|
|
2023-05-25 19:36:27 +00:00
|
|
|
template<typename Distance>
|
2023-05-25 19:55:01 +00:00
|
|
|
void AnnoyIndexWithSerialization<Distance>::deserialize(ReadBuffer& istr)
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
2023-05-25 19:36:27 +00:00
|
|
|
chassert(!Base::_built);
|
2022-08-30 15:26:56 +00:00
|
|
|
readIntBinary(Base::_s, istr);
|
|
|
|
readIntBinary(Base::_n_items, istr);
|
|
|
|
readIntBinary(Base::_n_nodes, istr);
|
|
|
|
readIntBinary(Base::_nodes_size, istr);
|
|
|
|
readIntBinary(Base::_K, istr);
|
|
|
|
readIntBinary(Base::_seed, istr);
|
|
|
|
readVectorBinary(Base::_roots, istr);
|
|
|
|
Base::_nodes = realloc(Base::_nodes, Base::_s * Base::_n_nodes);
|
2022-11-11 09:56:18 +00:00
|
|
|
istr.readStrict(reinterpret_cast<char *>(Base::_nodes), Base::_s * Base::_n_nodes);
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
Base::_fd = 0;
|
|
|
|
// set flags
|
|
|
|
Base::_loaded = false;
|
|
|
|
Base::_verbose = false;
|
|
|
|
Base::_on_disk = false;
|
|
|
|
Base::_built = true;
|
|
|
|
}
|
|
|
|
|
2023-05-25 19:36:27 +00:00
|
|
|
template<typename Distance>
|
2023-05-25 19:55:01 +00:00
|
|
|
uint64_t AnnoyIndexWithSerialization<Distance>::getNumOfDimensions() const
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
return Base::get_f();
|
|
|
|
}
|
|
|
|
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
|
2022-08-30 15:26:56 +00:00
|
|
|
: index_name(index_name_)
|
|
|
|
, index_sample_block(index_sample_block_)
|
|
|
|
, index(nullptr)
|
|
|
|
{}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
|
2022-08-30 15:26:56 +00:00
|
|
|
const String & index_name_,
|
|
|
|
const Block & index_sample_block_,
|
2023-05-25 19:55:01 +00:00
|
|
|
AnnoyIndexWithSerializationPtr<Distance> index_)
|
2022-08-30 15:26:56 +00:00
|
|
|
: index_name(index_name_)
|
|
|
|
, index_sample_block(index_sample_block_)
|
2023-05-25 19:36:27 +00:00
|
|
|
, index(std::move(index_))
|
2022-08-30 15:26:56 +00:00
|
|
|
{}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
2023-05-25 19:36:27 +00:00
|
|
|
/// Number of dimensions is required in the index constructor,
|
2022-08-30 15:26:56 +00:00
|
|
|
/// so it must be written and read separately from the other part
|
|
|
|
writeIntBinary(index->getNumOfDimensions(), ostr); // write dimension
|
|
|
|
index->serialize(ostr);
|
|
|
|
}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
uint64_t dimension;
|
|
|
|
readIntBinary(dimension, istr);
|
2023-05-25 19:55:01 +00:00
|
|
|
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(dimension);
|
2022-08-30 15:26:56 +00:00
|
|
|
index->deserialize(istr);
|
|
|
|
}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
|
2022-08-30 15:26:56 +00:00
|
|
|
const String & index_name_,
|
|
|
|
const Block & index_sample_block_,
|
2023-05-25 19:44:20 +00:00
|
|
|
uint64_t trees_)
|
2022-08-30 15:26:56 +00:00
|
|
|
: index_name(index_name_)
|
|
|
|
, index_sample_block(index_sample_block_)
|
2023-05-25 19:44:20 +00:00
|
|
|
, trees(trees_)
|
2022-08-30 15:26:56 +00:00
|
|
|
{}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
// NOLINTNEXTLINE(*)
|
2023-05-25 19:44:20 +00:00
|
|
|
index->build(static_cast<int>(trees), /*number_of_threads=*/1);
|
2023-05-25 19:36:27 +00:00
|
|
|
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance>>(index_name, index_sample_block, index);
|
2022-08-30 15:26:56 +00:00
|
|
|
index = nullptr;
|
|
|
|
return granule;
|
|
|
|
}
|
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
template <typename Distance>
|
|
|
|
void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t * pos, size_t limit)
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
if (*pos >= block.rows())
|
|
|
|
throw Exception(
|
|
|
|
ErrorCodes::LOGICAL_ERROR,
|
|
|
|
"The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
|
2023-05-25 20:07:43 +00:00
|
|
|
*pos, block.rows());
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
size_t rows_read = std::min(limit, block.rows() - *pos);
|
2022-08-30 15:32:05 +00:00
|
|
|
if (rows_read == 0)
|
|
|
|
return;
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
if (index_sample_block.columns() > 1)
|
2023-05-25 20:11:10 +00:00
|
|
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
auto index_column_name = index_sample_block.getByPosition(0).name;
|
|
|
|
const auto & column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);
|
2023-05-25 20:12:29 +00:00
|
|
|
|
2023-05-25 20:13:22 +00:00
|
|
|
if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
const auto & data = column_array->getData();
|
2023-05-25 20:13:22 +00:00
|
|
|
const auto & array = typeid_cast<const ColumnFloat32 &>(data).getData();
|
2022-08-30 23:09:22 +00:00
|
|
|
if (array.empty())
|
2022-10-19 12:35:47 +00:00
|
|
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Array has 0 rows, {} rows expected", rows_read);
|
2022-08-30 15:26:56 +00:00
|
|
|
const auto & offsets = column_array->getOffsets();
|
2022-08-30 15:46:55 +00:00
|
|
|
size_t num_rows = offsets.size();
|
2022-08-30 15:26:56 +00:00
|
|
|
|
2022-10-19 12:35:47 +00:00
|
|
|
/// Check all sizes are the same
|
2022-08-30 15:26:56 +00:00
|
|
|
size_t size = offsets[0];
|
|
|
|
for (size_t i = 0; i < num_rows - 1; ++i)
|
|
|
|
if (offsets[i + 1] - offsets[i] != size)
|
2023-05-25 20:11:10 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column {} must have equal length", index_column_name);
|
2022-10-19 12:35:47 +00:00
|
|
|
|
2023-05-25 19:55:01 +00:00
|
|
|
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(size);
|
2022-08-30 15:26:56 +00:00
|
|
|
|
2022-08-30 23:09:22 +00:00
|
|
|
index->add_item(index->get_n_items(), array.data());
|
2022-08-30 15:32:05 +00:00
|
|
|
/// add all rows from 1 to num_rows - 1 (this is the same as the beginning of the last element)
|
2022-08-31 13:28:10 +00:00
|
|
|
for (size_t current_row = 1; current_row < num_rows; ++current_row)
|
|
|
|
index->add_item(index->get_n_items(), &array[offsets[current_row - 1]]);
|
2022-08-30 15:26:56 +00:00
|
|
|
}
|
2023-05-25 20:13:22 +00:00
|
|
|
else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
const auto & columns = column_tuple->getColumns();
|
|
|
|
|
|
|
|
std::vector<std::vector<Float32>> data{column_tuple->size(), std::vector<Float32>()};
|
2023-05-25 20:13:22 +00:00
|
|
|
for (const auto & column : columns)
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
2023-05-25 20:13:22 +00:00
|
|
|
const auto & pod_array = typeid_cast<const ColumnFloat32 *>(column.get())->getData();
|
2022-08-30 15:26:56 +00:00
|
|
|
for (size_t i = 0; i < pod_array.size(); ++i)
|
|
|
|
data[i].push_back(pod_array[i]);
|
|
|
|
}
|
|
|
|
assert(!data.empty());
|
|
|
|
if (!index)
|
2023-05-25 19:55:01 +00:00
|
|
|
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(data[0].size());
|
2023-05-25 20:13:22 +00:00
|
|
|
for (const auto & item : data)
|
2022-08-30 15:26:56 +00:00
|
|
|
index->add_item(index->get_n_items(), item.data());
|
|
|
|
}
|
2023-05-25 20:12:29 +00:00
|
|
|
else
|
|
|
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
*pos += rows_read;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Builds the ANN condition from the query.
/// `distance_name_` is kept so that getUsefulRanges() can pick the matching Annoy distance.
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
    const IndexDescription & /*index*/,
    const SelectQueryInfo & query,
    const String & distance_name_,
    ContextPtr context)
    : condition(query, context)
    , distance_name(distance_name_)
{
}
|
|
|
|
|
|
|
|
|
|
|
|
/// Not supported for ANN indexes: this condition exposes getUsefulRanges() instead.
/// Presumably that is what callers use for ANN skip indexes, so reaching here is treated as a logical error.
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /* idx_granule */) const
{
    throw Exception(
        ErrorCodes::LOGICAL_ERROR,
        "mayBeTrueOnGranule is not supported for ANN skip indexes");
}
|
|
|
|
|
|
|
|
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
|
|
|
|
{
|
2022-11-12 09:23:49 +00:00
|
|
|
return condition.alwaysUnknownOrTrue(distance_name);
|
2022-08-30 15:26:56 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
/// Dispatches to the getUsefulRangesImpl() instantiation that matches the configured distance function.
/// Throws BAD_ARGUMENTS for an unknown distance name.
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
    if (distance_name == "L2Distance")
        return getUsefulRangesImpl<Annoy::Euclidean>(idx_granule);

    if (distance_name == "cosineDistance")
        return getUsefulRangesImpl<Annoy::Angular>(idx_granule);

    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename Distance>
|
|
|
|
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
|
2022-08-30 15:26:56 +00:00
|
|
|
{
|
|
|
|
UInt64 limit = condition.getLimit();
|
|
|
|
UInt64 index_granularity = condition.getIndexGranularity();
|
2023-05-25 19:36:27 +00:00
|
|
|
std::optional<float> comp_dist = condition.getQueryType() == ApproximateNearestNeighbour::ANNQueryInformation::Type::Where
|
|
|
|
? std::optional<float>(condition.getComparisonDistanceForWhereQuery())
|
|
|
|
: std::nullopt;
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
if (comp_dist && comp_dist.value() < 0)
|
|
|
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");
|
|
|
|
|
|
|
|
std::vector<float> target_vec = condition.getTargetVector();
|
|
|
|
|
2023-05-25 19:36:27 +00:00
|
|
|
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance>>(idx_granule);
|
2022-08-30 15:26:56 +00:00
|
|
|
if (granule == nullptr)
|
2023-01-23 21:13:58 +00:00
|
|
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
2022-10-19 12:35:47 +00:00
|
|
|
|
2022-08-30 15:26:56 +00:00
|
|
|
auto annoy = granule->index;
|
|
|
|
|
|
|
|
if (condition.getNumOfDimensions() != annoy->getNumOfDimensions())
|
2023-01-23 21:13:58 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
|
|
|
|
"does not match with the dimension in the index ({})",
|
|
|
|
toString(condition.getNumOfDimensions()), toString(annoy->getNumOfDimensions()));
|
2022-08-30 15:26:56 +00:00
|
|
|
|
|
|
|
/// neighbors contain indexes of dots which were closest to target vector
|
|
|
|
std::vector<UInt64> neighbors;
|
|
|
|
std::vector<Float32> distances;
|
|
|
|
neighbors.reserve(limit);
|
|
|
|
distances.reserve(limit);
|
|
|
|
|
|
|
|
int k_search = -1;
|
|
|
|
String params_str = condition.getParamsStr();
|
|
|
|
if (!params_str.empty())
|
|
|
|
{
|
|
|
|
try
|
|
|
|
{
|
|
|
|
/// k_search=... (algorithm will inspect up to search_k nodes which defaults to n_trees * n if not provided)
|
|
|
|
k_search = std::stoi(params_str.data() + 9);
|
|
|
|
}
|
|
|
|
catch (...)
|
|
|
|
{
|
2023-01-23 21:13:58 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "Setting of the annoy index should be int");
|
2022-08-30 15:26:56 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
annoy->get_nns_by_vector(target_vec.data(), limit, k_search, &neighbors, &distances);
|
|
|
|
std::unordered_set<size_t> granule_numbers;
|
|
|
|
for (size_t i = 0; i < neighbors.size(); ++i)
|
|
|
|
{
|
|
|
|
if (comp_dist && distances[i] > comp_dist)
|
|
|
|
continue;
|
|
|
|
granule_numbers.insert(neighbors[i] / index_granularity);
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<size_t> result_vector;
|
|
|
|
result_vector.reserve(granule_numbers.size());
|
|
|
|
for (auto granule_number : granule_numbers)
|
|
|
|
result_vector.push_back(granule_number);
|
|
|
|
|
|
|
|
return result_vector;
|
|
|
|
}
|
|
|
|
|
2023-05-25 19:44:20 +00:00
|
|
|
/// Stores the index description, the tree count passed to aggregators, and the distance function name.
MergeTreeIndexAnnoy::MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t trees_, const String & distance_name_)
    : IMergeTreeIndex(index_)
    , trees(trees_)
    , distance_name(distance_name_)
{
}
|
|
|
|
|
2022-08-30 15:26:56 +00:00
|
|
|
/// Creates an empty granule templated on the Annoy distance that matches distance_name.
/// Throws BAD_ARGUMENTS for an unknown distance name.
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
    if (distance_name == "L2Distance")
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);

    if (distance_name == "cosineDistance")
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);

    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
}
|
|
|
|
|
|
|
|
/// Creates an aggregator templated on the Annoy distance that matches distance_name.
/// Throws BAD_ARGUMENTS for an unknown distance name.
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
{
    /// TODO: Support more metrics. Available metrics: https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
    if (distance_name == "L2Distance")
        return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees);

    if (distance_name == "cosineDistance")
        return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees);

    throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
}
|
|
|
|
|
|
|
|
/// Creates the ANN condition for a SELECT query.
/// The distance name is passed along so that the condition can pick the matching Annoy distance for granules.
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(
    const SelectQueryInfo & query, ContextPtr context) const
{
    return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, distance_name, context);
}
|
|
|
|
|
|
|
|
/// Factory for the Annoy skip index.
/// Accepted arguments: ([trees UInt64 | distance String][, distance String]).
/// Defaults are 100 trees and "L2Distance". Throws INCORRECT_DATA if an argument has an unexpected type.
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
    uint64_t trees = 100;
    String distance_name = "L2Distance";

    const auto & args = index.arguments;
    if (!args.empty())
    {
        /// The first argument is either the number of trees or the distance name.
        const bool first_parsed = args[0].tryGet<uint64_t>(trees) || args[0].tryGet<String>(distance_name);
        if (!first_parsed)
            throw Exception(ErrorCodes::INCORRECT_DATA, "Can't parse first argument");
    }

    if (args.size() > 1)
    {
        if (!args[1].tryGet<String>(distance_name))
            throw Exception(ErrorCodes::INCORRECT_DATA, "Can't parse second argument");
    }

    return std::make_shared<MergeTreeIndexAnnoy>(index, trees, distance_name);
}
|
|
|
|
|
2022-10-19 12:35:47 +00:00
|
|
|
static void assertIndexColumnsType(const Block & header)
|
|
|
|
{
|
|
|
|
DataTypePtr column_data_type_ptr = header.getDataTypes()[0];
|
|
|
|
|
|
|
|
if (const auto * array_type = typeid_cast<const DataTypeArray *>(column_data_type_ptr.get()))
|
|
|
|
{
|
|
|
|
TypeIndex nested_type_index = array_type->getNestedType()->getTypeId();
|
|
|
|
if (!WhichDataType(nested_type_index).isFloat32())
|
|
|
|
throw Exception(
|
|
|
|
ErrorCodes::ILLEGAL_COLUMN,
|
2023-05-25 20:05:25 +00:00
|
|
|
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported",
|
2022-10-19 12:35:47 +00:00
|
|
|
column_data_type_ptr->getName());
|
|
|
|
}
|
|
|
|
else if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(column_data_type_ptr.get()))
|
|
|
|
{
|
|
|
|
const DataTypes & nested_types = tuple_type->getElements();
|
|
|
|
for (const auto & type : nested_types)
|
|
|
|
{
|
|
|
|
TypeIndex nested_type_index = type->getTypeId();
|
|
|
|
if (!WhichDataType(nested_type_index).isFloat32())
|
|
|
|
throw Exception(
|
|
|
|
ErrorCodes::ILLEGAL_COLUMN,
|
2023-05-25 20:05:25 +00:00
|
|
|
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported",
|
2022-10-19 12:35:47 +00:00
|
|
|
column_data_type_ptr->getName());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
else
|
|
|
|
throw Exception(
|
|
|
|
ErrorCodes::ILLEGAL_COLUMN,
|
2023-05-25 20:05:25 +00:00
|
|
|
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported",
|
2022-10-19 12:35:47 +00:00
|
|
|
column_data_type_ptr->getName());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
2022-08-30 15:26:56 +00:00
|
|
|
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
|
|
|
|
{
|
2022-10-28 17:03:35 +00:00
|
|
|
if (index.arguments.size() > 2)
|
2023-01-23 21:13:58 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index must not have more than two parameters");
|
2023-05-25 19:31:34 +00:00
|
|
|
|
2022-10-30 14:57:10 +00:00
|
|
|
if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::UInt64
|
2022-10-28 17:03:35 +00:00
|
|
|
&& index.arguments[0].getType() != Field::Types::String)
|
2023-05-25 20:05:25 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index first argument must be UInt64 or String");
|
2023-05-25 19:31:34 +00:00
|
|
|
|
2022-10-28 17:03:35 +00:00
|
|
|
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
|
2023-05-25 20:05:25 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index second argument must be String");
|
2022-10-19 12:35:47 +00:00
|
|
|
|
|
|
|
if (index.column_names.size() != 1 || index.data_types.size() != 1)
|
2023-01-23 21:13:58 +00:00
|
|
|
throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column");
|
2022-10-19 12:35:47 +00:00
|
|
|
|
|
|
|
assertIndexColumnsType(index.sample_block);
|
2022-08-30 15:26:56 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
}
|
2023-05-25 19:31:34 +00:00
|
|
|
|
2023-05-25 19:36:27 +00:00
|
|
|
#endif
|