Remove useless templatization

Makes the code cleaner, compile faster, and the binary smaller.
This commit is contained in:
Robert Schulze 2024-08-09 11:18:01 +00:00
parent 4f23f7754b
commit 8853b3359b
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
3 changed files with 75 additions and 94 deletions

View File

@ -9,9 +9,6 @@
namespace DB namespace DB
{ {
static constexpr auto DISTANCE_FUNCTION_L2 = "L2Distance";
static constexpr auto DISTANCE_FUNCTION_COSINE = "cosineDistance";
/// Approximate Nearest Neighbour queries have a similar structure: /// Approximate Nearest Neighbour queries have a similar structure:
/// - reference vector from which all distances are calculated /// - reference vector from which all distances are calculated
/// - metric name, e.g L2Distance /// - metric name, e.g L2Distance

View File

@ -41,22 +41,37 @@ namespace ErrorCodes
namespace namespace
{ {
std::unordered_map<String, unum::usearch::metric_kind_t> nameToMetricKind = {
{"L2Distance", unum::usearch::metric_kind_t::l2sq_k},
{"cosineDistance", unum::usearch::metric_kind_t::cos_k}};
std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = { std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
{"f64", unum::usearch::scalar_kind_t::f64_k}, {"f64", unum::usearch::scalar_kind_t::f64_k},
{"f32", unum::usearch::scalar_kind_t::f32_k}, {"f32", unum::usearch::scalar_kind_t::f32_k},
{"f16", unum::usearch::scalar_kind_t::f16_k}, {"f16", unum::usearch::scalar_kind_t::f16_k},
{"i8", unum::usearch::scalar_kind_t::i8_k}}; {"i8", unum::usearch::scalar_kind_t::i8_k}};
template <typename T>
String keysAsString(const T & t)
{
String result;
for (const auto & [k, _] : t)
{
if (!result.empty())
result += ", ";
result += k;
}
return result;
} }
template <unum::usearch::metric_kind_t Metric> }
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric, scalar_kind))) USearchIndexWithSerialization::USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, metric_kind, scalar_kind)))
{ {
} }
template <unum::usearch::metric_kind_t Metric> void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const
void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
{ {
auto callback = [&ostr](void * from, size_t n) auto callback = [&ostr](void * from, size_t n)
{ {
@ -69,8 +84,7 @@ void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save USearch index, error: " + String(result.error.release())); throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save USearch index, error: " + String(result.error.release()));
} }
template <unum::usearch::metric_kind_t Metric> void USearchIndexWithSerialization::deserialize(ReadBuffer & istr)
void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
{ {
auto callback = [&istr](void * from, size_t n) auto callback = [&istr](void * from, size_t n)
{ {
@ -81,33 +95,34 @@ void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
Base::load_from_stream(callback); Base::load_from_stream(callback);
} }
template <unum::usearch::metric_kind_t Metric> MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_, const String & index_name_,
const Block & index_sample_block_, const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_) unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_) : index_name(index_name_)
, index_sample_block(index_sample_block_) , index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_) , scalar_kind(scalar_kind_)
, index(nullptr) , index(nullptr)
{ {
} }
template <unum::usearch::metric_kind_t Metric> MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_, const String & index_name_,
const Block & index_sample_block_, const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_, unum::usearch::scalar_kind_t scalar_kind_,
USearchIndexWithSerializationPtr<Metric> index_) USearchIndexWithSerializationPtr index_)
: index_name(index_name_) : index_name(index_name_)
, index_sample_block(index_sample_block_) , index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_) , scalar_kind(scalar_kind_)
, index(std::move(index_)) , index(std::move(index_))
{ {
} }
template <unum::usearch::metric_kind_t Metric> void MergeTreeIndexGranuleUSearch::serializeBinary(WriteBuffer & ostr) const
void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) const
{ {
if (empty()) if (empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty minmax index {}", backQuote(index_name)); throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty minmax index {}", backQuote(index_name));
@ -118,36 +133,34 @@ void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) c
index->serialize(ostr); index->serialize(ostr);
} }
template <unum::usearch::metric_kind_t Metric> void MergeTreeIndexGranuleUSearch::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
void MergeTreeIndexGranuleUSearch<Metric>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{ {
UInt64 dimension; UInt64 dimension;
readIntBinary(dimension, istr); readIntBinary(dimension, istr);
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension, scalar_kind); index = std::make_shared<USearchIndexWithSerialization>(dimension, metric_kind, scalar_kind);
index->deserialize(istr); index->deserialize(istr);
} }
template <unum::usearch::metric_kind_t Metric> MergeTreeIndexAggregatorUSearch::MergeTreeIndexAggregatorUSearch(
MergeTreeIndexAggregatorUSearch<Metric>::MergeTreeIndexAggregatorUSearch(
const String & index_name_, const String & index_name_,
const Block & index_sample_block_, const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_) unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_) : index_name(index_name_)
, index_sample_block(index_sample_block_) , index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_) , scalar_kind(scalar_kind_)
{ {
} }
template <unum::usearch::metric_kind_t Metric> MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch::getGranuleAndReset()
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch<Metric>::getGranuleAndReset()
{ {
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch<Metric>>(index_name, index_sample_block, scalar_kind, index); auto granule = std::make_shared<MergeTreeIndexGranuleUSearch>(index_name, index_sample_block, metric_kind, scalar_kind, index);
index = nullptr; index = nullptr;
return granule; return granule;
} }
template <unum::usearch::metric_kind_t Metric> void MergeTreeIndexAggregatorUSearch::update(const Block & block, size_t * pos, size_t limit)
void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t * pos, size_t limit)
{ {
if (*pos >= block.rows()) if (*pos >= block.rows())
throw Exception( throw Exception(
@ -201,7 +214,7 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name); throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
if (!index) if (!index)
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimensions, scalar_kind); index = std::make_shared<USearchIndexWithSerialization>(dimensions, metric_kind, scalar_kind);
/// Add all rows of block /// Add all rows of block
if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows))) if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
@ -227,10 +240,10 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch( MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch(
const IndexDescription & /*index_description*/, const IndexDescription & /*index_description*/,
const SelectQueryInfo & query, const SelectQueryInfo & query,
const String & distance_function_, unum::usearch::metric_kind_t metric_kind_,
ContextPtr context) ContextPtr context)
: ann_condition(query, context) : ann_condition(query, context)
, distance_function(distance_function_) , metric_kind(metric_kind_)
{ {
} }
@ -241,31 +254,28 @@ bool MergeTreeIndexConditionUSearch::mayBeTrueOnGranule(MergeTreeIndexGranulePtr
bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const
{ {
return ann_condition.alwaysUnknownOrTrue(distance_function); String index_distance_function;
switch (metric_kind)
{
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
default: std::unreachable();
}
return ann_condition.alwaysUnknownOrTrue(index_distance_function);
} }
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return getUsefulRangesImpl<unum::usearch::metric_kind_t::l2sq_k>(idx_granule);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return getUsefulRangesImpl<unum::usearch::metric_kind_t::cos_k>(idx_granule);
std::unreachable();
}
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{ {
const UInt64 limit = ann_condition.getLimit(); const UInt64 limit = ann_condition.getLimit();
const UInt64 index_granularity = ann_condition.getIndexGranularity(); const UInt64 index_granularity = ann_condition.getIndexGranularity();
const std::vector<float> reference_vector = ann_condition.getReferenceVector(); const std::vector<float> reference_vector = ann_condition.getReferenceVector();
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch<Metric>>(idx_granule); const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch>(idx_granule);
if (granule == nullptr) if (granule == nullptr)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type"); throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
const USearchIndexWithSerializationPtr<Metric> index = granule->index; const USearchIndexWithSerializationPtr index = granule->index;
if (ann_condition.getDimensions() != index->dimensions()) if (ann_condition.getDimensions() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) " throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
@ -296,34 +306,26 @@ std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTre
return granules; return granules;
} }
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_) MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_)
: IMergeTreeIndex(index_) : IMergeTreeIndex(index_)
, distance_function(distance_function_) , metric_kind(metric_kind_)
, scalar_kind(scalar_kind_) , scalar_kind(scalar_kind_)
{ {
} }
MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
{ {
if (distance_function == DISTANCE_FUNCTION_L2) return std::make_shared<MergeTreeIndexGranuleUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
} }
MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
{ {
if (distance_function == DISTANCE_FUNCTION_L2) return std::make_shared<MergeTreeIndexAggregatorUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
} }
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{ {
return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, distance_function, context); return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, metric_kind, context);
}; };
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const
@ -333,17 +335,17 @@ MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const Act
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index) MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
{ {
static constexpr auto default_distance_function = DISTANCE_FUNCTION_L2; static constexpr auto default_metric_kind = unum::usearch::metric_kind_t::l2sq_k;
String distance_function = default_distance_function; auto metric_kind = default_metric_kind;
if (!index.arguments.empty()) if (!index.arguments.empty())
distance_function = index.arguments[0].safeGet<String>(); metric_kind = nameToMetricKind.at(index.arguments[0].safeGet<String>());
static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k; static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;
auto scalar_kind = default_scalar_kind; auto scalar_kind = default_scalar_kind;
if (index.arguments.size() > 1) if (index.arguments.size() > 1)
scalar_kind = nameToScalarKind.at(index.arguments[1].safeGet<String>()); scalar_kind = nameToScalarKind.at(index.arguments[1].safeGet<String>());
return std::make_shared<MergeTreeIndexUSearch>(index, distance_function, scalar_kind); return std::make_shared<MergeTreeIndexUSearch>(index, metric_kind, scalar_kind);
} }
void usearchIndexValidator(const IndexDescription & index, bool /* attach */) void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
@ -365,26 +367,13 @@ void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
/// Check that a supported metric was passed as first argument /// Check that a supported metric was passed as first argument
if (!index.arguments.empty()) if (!index.arguments.empty() && !nameToMetricKind.contains(index.arguments[0].safeGet<String>()))
{ throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized metric kind (first argument) for vector index. Supported kinds are: {}", keysAsString(nameToMetricKind));
String distance_name = index.arguments[0].safeGet<String>();
if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
throw Exception(ErrorCodes::INCORRECT_DATA, "USearch index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
}
/// Check that a supported kind was passed as a second argument /// Check that a supported kind was passed as a second argument
if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>())) if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>()))
{ throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for vector index. Supported kinds are: {}", keysAsString(nameToScalarKind));
String supported_kinds;
for (const auto & [name, kind] : nameToScalarKind)
{
if (!supported_kinds.empty())
supported_kinds += ", ";
supported_kinds += name;
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for USearch index. Supported kinds are: {}", supported_kinds);
}
/// Check data type of indexed column: /// Check data type of indexed column:

View File

@ -15,26 +15,23 @@ namespace DB
using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>; using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>;
template <unum::usearch::metric_kind_t Metric>
class USearchIndexWithSerialization : public USearchIndex class USearchIndexWithSerialization : public USearchIndex
{ {
using Base = USearchIndex; using Base = USearchIndex;
public: public:
USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind); USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind);
void serialize(WriteBuffer & ostr) const; void serialize(WriteBuffer & ostr) const;
void deserialize(ReadBuffer & istr); void deserialize(ReadBuffer & istr);
}; };
template <unum::usearch::metric_kind_t Metric> using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization>;
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization<Metric>>;
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
{ {
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_); MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr<Metric> index_); MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr index_);
~MergeTreeIndexGranuleUSearch() override = default; ~MergeTreeIndexGranuleUSearch() override = default;
@ -45,15 +42,15 @@ struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
const String index_name; const String index_name;
const Block index_sample_block; const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind; const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index; USearchIndexWithSerializationPtr index;
}; };
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
{ {
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::scalar_kind_t scalar_kind_); MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexAggregatorUSearch() override = default; ~MergeTreeIndexAggregatorUSearch() override = default;
bool empty() const override { return !index || index->size() == 0; } bool empty() const override { return !index || index->size() == 0; }
@ -62,8 +59,9 @@ struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
const String index_name; const String index_name;
const Block index_sample_block; const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind; const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index; USearchIndexWithSerializationPtr index;
}; };
@ -73,7 +71,7 @@ public:
MergeTreeIndexConditionUSearch( MergeTreeIndexConditionUSearch(
const IndexDescription & index_description, const IndexDescription & index_description,
const SelectQueryInfo & query, const SelectQueryInfo & query,
const String & distance_function, unum::usearch::metric_kind_t metric_kind_,
ContextPtr context); ContextPtr context);
~MergeTreeIndexConditionUSearch() override = default; ~MergeTreeIndexConditionUSearch() override = default;
@ -83,18 +81,15 @@ public:
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override; std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
private: private:
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
const ApproximateNearestNeighborCondition ann_condition; const ApproximateNearestNeighborCondition ann_condition;
const String distance_function; const unum::usearch::metric_kind_t metric_kind;
}; };
class MergeTreeIndexUSearch : public IMergeTreeIndex class MergeTreeIndexUSearch : public IMergeTreeIndex
{ {
public: public:
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_); MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexUSearch() override = default; ~MergeTreeIndexUSearch() override = default;
@ -105,7 +100,7 @@ public:
bool isVectorSearch() const override { return true; } bool isVectorSearch() const override { return true; }
private: private:
const String distance_function; const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind; const unum::usearch::scalar_kind_t scalar_kind;
}; };