Remove useless templatization

Makes the code cleaner, compile faster, and the binary smaller.
This commit is contained in:
Robert Schulze 2024-08-09 11:18:01 +00:00
parent 4f23f7754b
commit 8853b3359b
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
3 changed files with 75 additions and 94 deletions

View File

@ -9,9 +9,6 @@
namespace DB
{
static constexpr auto DISTANCE_FUNCTION_L2 = "L2Distance";
static constexpr auto DISTANCE_FUNCTION_COSINE = "cosineDistance";
/// Approximate Nearest Neighbour queries have a similar structure:
/// - reference vector from which all distances are calculated
/// - metric name, e.g L2Distance

View File

@ -41,22 +41,37 @@ namespace ErrorCodes
namespace
{
std::unordered_map<String, unum::usearch::metric_kind_t> nameToMetricKind = {
{"L2Distance", unum::usearch::metric_kind_t::l2sq_k},
{"cosineDistance", unum::usearch::metric_kind_t::cos_k}};
std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
{"f64", unum::usearch::scalar_kind_t::f64_k},
{"f32", unum::usearch::scalar_kind_t::f32_k},
{"f16", unum::usearch::scalar_kind_t::f16_k},
{"i8", unum::usearch::scalar_kind_t::i8_k}};
template <typename T>
String keysAsString(const T & t)
{
String result;
for (const auto & [k, _] : t)
{
if (!result.empty())
result += ", ";
result += k;
}
return result;
}
template <unum::usearch::metric_kind_t Metric>
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric, scalar_kind)))
}
USearchIndexWithSerialization::USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, metric_kind, scalar_kind)))
{
}
template <unum::usearch::metric_kind_t Metric>
void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const
{
auto callback = [&ostr](void * from, size_t n)
{
@ -69,8 +84,7 @@ void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save USearch index, error: " + String(result.error.release()));
}
template <unum::usearch::metric_kind_t Metric>
void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
void USearchIndexWithSerialization::deserialize(ReadBuffer & istr)
{
auto callback = [&istr](void * from, size_t n)
{
@ -81,33 +95,34 @@ void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
Base::load_from_stream(callback);
}
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_)
, index(nullptr)
{
}
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_,
USearchIndexWithSerializationPtr<Metric> index_)
USearchIndexWithSerializationPtr index_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_)
, index(std::move(index_))
{
}
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) const
void MergeTreeIndexGranuleUSearch::serializeBinary(WriteBuffer & ostr) const
{
if (empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty minmax index {}", backQuote(index_name));
@ -118,36 +133,34 @@ void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) c
index->serialize(ostr);
}
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexGranuleUSearch<Metric>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
void MergeTreeIndexGranuleUSearch::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
UInt64 dimension;
readIntBinary(dimension, istr);
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension, scalar_kind);
index = std::make_shared<USearchIndexWithSerialization>(dimension, metric_kind, scalar_kind);
index->deserialize(istr);
}
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexAggregatorUSearch<Metric>::MergeTreeIndexAggregatorUSearch(
MergeTreeIndexAggregatorUSearch::MergeTreeIndexAggregatorUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_)
{
}
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch<Metric>::getGranuleAndReset()
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch::getGranuleAndReset()
{
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch<Metric>>(index_name, index_sample_block, scalar_kind, index);
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch>(index_name, index_sample_block, metric_kind, scalar_kind, index);
index = nullptr;
return granule;
}
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t * pos, size_t limit)
void MergeTreeIndexAggregatorUSearch::update(const Block & block, size_t * pos, size_t limit)
{
if (*pos >= block.rows())
throw Exception(
@ -201,7 +214,7 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
if (!index)
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimensions, scalar_kind);
index = std::make_shared<USearchIndexWithSerialization>(dimensions, metric_kind, scalar_kind);
/// Add all rows of block
if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
@ -227,10 +240,10 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch(
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
const String & distance_function_,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context)
: ann_condition(query, context)
, distance_function(distance_function_)
, metric_kind(metric_kind_)
{
}
@ -241,31 +254,28 @@ bool MergeTreeIndexConditionUSearch::mayBeTrueOnGranule(MergeTreeIndexGranulePtr
bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const
{
return ann_condition.alwaysUnknownOrTrue(distance_function);
String index_distance_function;
switch (metric_kind)
{
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
default: std::unreachable();
}
return ann_condition.alwaysUnknownOrTrue(index_distance_function);
}
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return getUsefulRangesImpl<unum::usearch::metric_kind_t::l2sq_k>(idx_granule);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return getUsefulRangesImpl<unum::usearch::metric_kind_t::cos_k>(idx_granule);
std::unreachable();
}
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
const UInt64 limit = ann_condition.getLimit();
const UInt64 index_granularity = ann_condition.getIndexGranularity();
const std::vector<float> reference_vector = ann_condition.getReferenceVector();
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch<Metric>>(idx_granule);
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch>(idx_granule);
if (granule == nullptr)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
const USearchIndexWithSerializationPtr<Metric> index = granule->index;
const USearchIndexWithSerializationPtr index = granule->index;
if (ann_condition.getDimensions() != index->dimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
@ -296,34 +306,26 @@ std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTre
return granules;
}
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_)
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_)
: IMergeTreeIndex(index_)
, distance_function(distance_function_)
, metric_kind(metric_kind_)
, scalar_kind(scalar_kind_)
{
}
MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
return std::make_shared<MergeTreeIndexGranuleUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
}
MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
{
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
std::unreachable();
return std::make_shared<MergeTreeIndexAggregatorUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
}
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, distance_function, context);
return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, metric_kind, context);
};
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const
@ -333,17 +335,17 @@ MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const Act
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
{
static constexpr auto default_distance_function = DISTANCE_FUNCTION_L2;
String distance_function = default_distance_function;
static constexpr auto default_metric_kind = unum::usearch::metric_kind_t::l2sq_k;
auto metric_kind = default_metric_kind;
if (!index.arguments.empty())
distance_function = index.arguments[0].safeGet<String>();
metric_kind = nameToMetricKind.at(index.arguments[0].safeGet<String>());
static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;
auto scalar_kind = default_scalar_kind;
if (index.arguments.size() > 1)
scalar_kind = nameToScalarKind.at(index.arguments[1].safeGet<String>());
return std::make_shared<MergeTreeIndexUSearch>(index, distance_function, scalar_kind);
return std::make_shared<MergeTreeIndexUSearch>(index, metric_kind, scalar_kind);
}
void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
@ -365,26 +367,13 @@ void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
/// Check that a supported metric was passed as first argument
if (!index.arguments.empty())
{
String distance_name = index.arguments[0].safeGet<String>();
if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
throw Exception(ErrorCodes::INCORRECT_DATA, "USearch index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
}
if (!index.arguments.empty() && !nameToMetricKind.contains(index.arguments[0].safeGet<String>()))
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized metric kind (first argument) for vector index. Supported kinds are: {}", keysAsString(nameToMetricKind));
/// Check that a supported kind was passed as a second argument
if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>()))
{
String supported_kinds;
for (const auto & [name, kind] : nameToScalarKind)
{
if (!supported_kinds.empty())
supported_kinds += ", ";
supported_kinds += name;
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for USearch index. Supported kinds are: {}", supported_kinds);
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for vector index. Supported kinds are: {}", keysAsString(nameToScalarKind));
/// Check data type of indexed column:

View File

@ -15,26 +15,23 @@ namespace DB
using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>;
template <unum::usearch::metric_kind_t Metric>
class USearchIndexWithSerialization : public USearchIndex
{
using Base = USearchIndex;
public:
USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind);
USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind);
void serialize(WriteBuffer & ostr) const;
void deserialize(ReadBuffer & istr);
};
template <unum::usearch::metric_kind_t Metric>
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization<Metric>>;
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization>;
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
{
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr<Metric> index_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr index_);
~MergeTreeIndexGranuleUSearch() override = default;
@ -45,15 +42,15 @@ struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
const String index_name;
const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
USearchIndexWithSerializationPtr index;
};
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexAggregatorUSearch() override = default;
bool empty() const override { return !index || index->size() == 0; }
@ -62,8 +59,9 @@ struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
const String index_name;
const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
USearchIndexWithSerializationPtr index;
};
@ -73,7 +71,7 @@ public:
MergeTreeIndexConditionUSearch(
const IndexDescription & index_description,
const SelectQueryInfo & query,
const String & distance_function,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context);
~MergeTreeIndexConditionUSearch() override = default;
@ -83,18 +81,15 @@ public:
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
private:
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
const ApproximateNearestNeighborCondition ann_condition;
const String distance_function;
const unum::usearch::metric_kind_t metric_kind;
};
class MergeTreeIndexUSearch : public IMergeTreeIndex
{
public:
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexUSearch() override = default;
@ -105,7 +100,7 @@ public:
bool isVectorSearch() const override { return true; }
private:
const String distance_function;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind;
};