mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-27 10:02:01 +00:00
Remove useless templatization
Makes the code cleaner, compile faster, and the binary smaller.
This commit is contained in:
parent
4f23f7754b
commit
8853b3359b
@ -9,9 +9,6 @@
|
|||||||
namespace DB
|
namespace DB
|
||||||
{
|
{
|
||||||
|
|
||||||
static constexpr auto DISTANCE_FUNCTION_L2 = "L2Distance";
|
|
||||||
static constexpr auto DISTANCE_FUNCTION_COSINE = "cosineDistance";
|
|
||||||
|
|
||||||
/// Approximate Nearest Neighbour queries have a similar structure:
|
/// Approximate Nearest Neighbour queries have a similar structure:
|
||||||
/// - reference vector from which all distances are calculated
|
/// - reference vector from which all distances are calculated
|
||||||
/// - metric name, e.g L2Distance
|
/// - metric name, e.g L2Distance
|
||||||
|
@ -41,22 +41,37 @@ namespace ErrorCodes
|
|||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
|
std::unordered_map<String, unum::usearch::metric_kind_t> nameToMetricKind = {
|
||||||
|
{"L2Distance", unum::usearch::metric_kind_t::l2sq_k},
|
||||||
|
{"cosineDistance", unum::usearch::metric_kind_t::cos_k}};
|
||||||
|
|
||||||
std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
|
std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
|
||||||
{"f64", unum::usearch::scalar_kind_t::f64_k},
|
{"f64", unum::usearch::scalar_kind_t::f64_k},
|
||||||
{"f32", unum::usearch::scalar_kind_t::f32_k},
|
{"f32", unum::usearch::scalar_kind_t::f32_k},
|
||||||
{"f16", unum::usearch::scalar_kind_t::f16_k},
|
{"f16", unum::usearch::scalar_kind_t::f16_k},
|
||||||
{"i8", unum::usearch::scalar_kind_t::i8_k}};
|
{"i8", unum::usearch::scalar_kind_t::i8_k}};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
String keysAsString(const T & t)
|
||||||
|
{
|
||||||
|
String result;
|
||||||
|
for (const auto & [k, _] : t)
|
||||||
|
{
|
||||||
|
if (!result.empty())
|
||||||
|
result += ", ";
|
||||||
|
result += k;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
}
|
||||||
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind)
|
|
||||||
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric, scalar_kind)))
|
USearchIndexWithSerialization::USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind)
|
||||||
|
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, metric_kind, scalar_kind)))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const
|
||||||
void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
|
|
||||||
{
|
{
|
||||||
auto callback = [&ostr](void * from, size_t n)
|
auto callback = [&ostr](void * from, size_t n)
|
||||||
{
|
{
|
||||||
@ -69,8 +84,7 @@ void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
|
|||||||
throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save USearch index, error: " + String(result.error.release()));
|
throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save USearch index, error: " + String(result.error.release()));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
void USearchIndexWithSerialization::deserialize(ReadBuffer & istr)
|
||||||
void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
|
|
||||||
{
|
{
|
||||||
auto callback = [&istr](void * from, size_t n)
|
auto callback = [&istr](void * from, size_t n)
|
||||||
{
|
{
|
||||||
@ -81,33 +95,34 @@ void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
|
|||||||
Base::load_from_stream(callback);
|
Base::load_from_stream(callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
|
||||||
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
|
|
||||||
const String & index_name_,
|
const String & index_name_,
|
||||||
const Block & index_sample_block_,
|
const Block & index_sample_block_,
|
||||||
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
unum::usearch::scalar_kind_t scalar_kind_)
|
unum::usearch::scalar_kind_t scalar_kind_)
|
||||||
: index_name(index_name_)
|
: index_name(index_name_)
|
||||||
, index_sample_block(index_sample_block_)
|
, index_sample_block(index_sample_block_)
|
||||||
|
, metric_kind(metric_kind_)
|
||||||
, scalar_kind(scalar_kind_)
|
, scalar_kind(scalar_kind_)
|
||||||
, index(nullptr)
|
, index(nullptr)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
MergeTreeIndexGranuleUSearch::MergeTreeIndexGranuleUSearch(
|
||||||
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
|
|
||||||
const String & index_name_,
|
const String & index_name_,
|
||||||
const Block & index_sample_block_,
|
const Block & index_sample_block_,
|
||||||
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
unum::usearch::scalar_kind_t scalar_kind_,
|
unum::usearch::scalar_kind_t scalar_kind_,
|
||||||
USearchIndexWithSerializationPtr<Metric> index_)
|
USearchIndexWithSerializationPtr index_)
|
||||||
: index_name(index_name_)
|
: index_name(index_name_)
|
||||||
, index_sample_block(index_sample_block_)
|
, index_sample_block(index_sample_block_)
|
||||||
|
, metric_kind(metric_kind_)
|
||||||
, scalar_kind(scalar_kind_)
|
, scalar_kind(scalar_kind_)
|
||||||
, index(std::move(index_))
|
, index(std::move(index_))
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
void MergeTreeIndexGranuleUSearch::serializeBinary(WriteBuffer & ostr) const
|
||||||
void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) const
|
|
||||||
{
|
{
|
||||||
if (empty())
|
if (empty())
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty minmax index {}", backQuote(index_name));
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty minmax index {}", backQuote(index_name));
|
||||||
@ -118,36 +133,34 @@ void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) c
|
|||||||
index->serialize(ostr);
|
index->serialize(ostr);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
void MergeTreeIndexGranuleUSearch::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
|
||||||
void MergeTreeIndexGranuleUSearch<Metric>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
|
|
||||||
{
|
{
|
||||||
UInt64 dimension;
|
UInt64 dimension;
|
||||||
readIntBinary(dimension, istr);
|
readIntBinary(dimension, istr);
|
||||||
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension, scalar_kind);
|
index = std::make_shared<USearchIndexWithSerialization>(dimension, metric_kind, scalar_kind);
|
||||||
index->deserialize(istr);
|
index->deserialize(istr);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
MergeTreeIndexAggregatorUSearch::MergeTreeIndexAggregatorUSearch(
|
||||||
MergeTreeIndexAggregatorUSearch<Metric>::MergeTreeIndexAggregatorUSearch(
|
|
||||||
const String & index_name_,
|
const String & index_name_,
|
||||||
const Block & index_sample_block_,
|
const Block & index_sample_block_,
|
||||||
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
unum::usearch::scalar_kind_t scalar_kind_)
|
unum::usearch::scalar_kind_t scalar_kind_)
|
||||||
: index_name(index_name_)
|
: index_name(index_name_)
|
||||||
, index_sample_block(index_sample_block_)
|
, index_sample_block(index_sample_block_)
|
||||||
|
, metric_kind(metric_kind_)
|
||||||
, scalar_kind(scalar_kind_)
|
, scalar_kind(scalar_kind_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch::getGranuleAndReset()
|
||||||
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch<Metric>::getGranuleAndReset()
|
|
||||||
{
|
{
|
||||||
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch<Metric>>(index_name, index_sample_block, scalar_kind, index);
|
auto granule = std::make_shared<MergeTreeIndexGranuleUSearch>(index_name, index_sample_block, metric_kind, scalar_kind, index);
|
||||||
index = nullptr;
|
index = nullptr;
|
||||||
return granule;
|
return granule;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
void MergeTreeIndexAggregatorUSearch::update(const Block & block, size_t * pos, size_t limit)
|
||||||
void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t * pos, size_t limit)
|
|
||||||
{
|
{
|
||||||
if (*pos >= block.rows())
|
if (*pos >= block.rows())
|
||||||
throw Exception(
|
throw Exception(
|
||||||
@ -201,7 +214,7 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
|
|||||||
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
|
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
|
||||||
|
|
||||||
if (!index)
|
if (!index)
|
||||||
index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimensions, scalar_kind);
|
index = std::make_shared<USearchIndexWithSerialization>(dimensions, metric_kind, scalar_kind);
|
||||||
|
|
||||||
/// Add all rows of block
|
/// Add all rows of block
|
||||||
if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
|
if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
|
||||||
@ -227,10 +240,10 @@ void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t
|
|||||||
MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch(
|
MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch(
|
||||||
const IndexDescription & /*index_description*/,
|
const IndexDescription & /*index_description*/,
|
||||||
const SelectQueryInfo & query,
|
const SelectQueryInfo & query,
|
||||||
const String & distance_function_,
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
ContextPtr context)
|
ContextPtr context)
|
||||||
: ann_condition(query, context)
|
: ann_condition(query, context)
|
||||||
, distance_function(distance_function_)
|
, metric_kind(metric_kind_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -241,31 +254,28 @@ bool MergeTreeIndexConditionUSearch::mayBeTrueOnGranule(MergeTreeIndexGranulePtr
|
|||||||
|
|
||||||
bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const
|
bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const
|
||||||
{
|
{
|
||||||
return ann_condition.alwaysUnknownOrTrue(distance_function);
|
String index_distance_function;
|
||||||
|
switch (metric_kind)
|
||||||
|
{
|
||||||
|
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
|
||||||
|
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
|
||||||
|
default: std::unreachable();
|
||||||
|
}
|
||||||
|
return ann_condition.alwaysUnknownOrTrue(index_distance_function);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
|
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
|
||||||
{
|
|
||||||
if (distance_function == DISTANCE_FUNCTION_L2)
|
|
||||||
return getUsefulRangesImpl<unum::usearch::metric_kind_t::l2sq_k>(idx_granule);
|
|
||||||
else if (distance_function == DISTANCE_FUNCTION_COSINE)
|
|
||||||
return getUsefulRangesImpl<unum::usearch::metric_kind_t::cos_k>(idx_granule);
|
|
||||||
std::unreachable();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
|
||||||
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
|
|
||||||
{
|
{
|
||||||
const UInt64 limit = ann_condition.getLimit();
|
const UInt64 limit = ann_condition.getLimit();
|
||||||
const UInt64 index_granularity = ann_condition.getIndexGranularity();
|
const UInt64 index_granularity = ann_condition.getIndexGranularity();
|
||||||
|
|
||||||
const std::vector<float> reference_vector = ann_condition.getReferenceVector();
|
const std::vector<float> reference_vector = ann_condition.getReferenceVector();
|
||||||
|
|
||||||
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch<Metric>>(idx_granule);
|
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch>(idx_granule);
|
||||||
if (granule == nullptr)
|
if (granule == nullptr)
|
||||||
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
|
||||||
|
|
||||||
const USearchIndexWithSerializationPtr<Metric> index = granule->index;
|
const USearchIndexWithSerializationPtr index = granule->index;
|
||||||
|
|
||||||
if (ann_condition.getDimensions() != index->dimensions())
|
if (ann_condition.getDimensions() != index->dimensions())
|
||||||
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
|
||||||
@ -296,34 +306,26 @@ std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTre
|
|||||||
return granules;
|
return granules;
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_)
|
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_)
|
||||||
: IMergeTreeIndex(index_)
|
: IMergeTreeIndex(index_)
|
||||||
, distance_function(distance_function_)
|
, metric_kind(metric_kind_)
|
||||||
, scalar_kind(scalar_kind_)
|
, scalar_kind(scalar_kind_)
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
|
MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
|
||||||
{
|
{
|
||||||
if (distance_function == DISTANCE_FUNCTION_L2)
|
return std::make_shared<MergeTreeIndexGranuleUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
|
||||||
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
|
|
||||||
else if (distance_function == DISTANCE_FUNCTION_COSINE)
|
|
||||||
return std::make_shared<MergeTreeIndexGranuleUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
|
|
||||||
std::unreachable();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
|
MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
|
||||||
{
|
{
|
||||||
if (distance_function == DISTANCE_FUNCTION_L2)
|
return std::make_shared<MergeTreeIndexAggregatorUSearch>(index.name, index.sample_block, metric_kind, scalar_kind);
|
||||||
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);
|
|
||||||
else if (distance_function == DISTANCE_FUNCTION_COSINE)
|
|
||||||
return std::make_shared<MergeTreeIndexAggregatorUSearch<unum::usearch::metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);
|
|
||||||
std::unreachable();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
|
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
|
||||||
{
|
{
|
||||||
return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, distance_function, context);
|
return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, metric_kind, context);
|
||||||
};
|
};
|
||||||
|
|
||||||
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const
|
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const
|
||||||
@ -333,17 +335,17 @@ MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const Act
|
|||||||
|
|
||||||
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
|
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
|
||||||
{
|
{
|
||||||
static constexpr auto default_distance_function = DISTANCE_FUNCTION_L2;
|
static constexpr auto default_metric_kind = unum::usearch::metric_kind_t::l2sq_k;
|
||||||
String distance_function = default_distance_function;
|
auto metric_kind = default_metric_kind;
|
||||||
if (!index.arguments.empty())
|
if (!index.arguments.empty())
|
||||||
distance_function = index.arguments[0].safeGet<String>();
|
metric_kind = nameToMetricKind.at(index.arguments[0].safeGet<String>());
|
||||||
|
|
||||||
static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;
|
static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;
|
||||||
auto scalar_kind = default_scalar_kind;
|
auto scalar_kind = default_scalar_kind;
|
||||||
if (index.arguments.size() > 1)
|
if (index.arguments.size() > 1)
|
||||||
scalar_kind = nameToScalarKind.at(index.arguments[1].safeGet<String>());
|
scalar_kind = nameToScalarKind.at(index.arguments[1].safeGet<String>());
|
||||||
|
|
||||||
return std::make_shared<MergeTreeIndexUSearch>(index, distance_function, scalar_kind);
|
return std::make_shared<MergeTreeIndexUSearch>(index, metric_kind, scalar_kind);
|
||||||
}
|
}
|
||||||
|
|
||||||
void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
|
void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
|
||||||
@ -365,26 +367,13 @@ void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
|
|||||||
|
|
||||||
/// Check that a supported metric was passed as first argument
|
/// Check that a supported metric was passed as first argument
|
||||||
|
|
||||||
if (!index.arguments.empty())
|
if (!index.arguments.empty() && !nameToMetricKind.contains(index.arguments[0].safeGet<String>()))
|
||||||
{
|
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized metric kind (first argument) for vector index. Supported kinds are: {}", keysAsString(nameToMetricKind));
|
||||||
String distance_name = index.arguments[0].safeGet<String>();
|
|
||||||
if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
|
|
||||||
throw Exception(ErrorCodes::INCORRECT_DATA, "USearch index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check that a supported kind was passed as a second argument
|
/// Check that a supported kind was passed as a second argument
|
||||||
|
|
||||||
if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>()))
|
if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>()))
|
||||||
{
|
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for vector index. Supported kinds are: {}", keysAsString(nameToScalarKind));
|
||||||
String supported_kinds;
|
|
||||||
for (const auto & [name, kind] : nameToScalarKind)
|
|
||||||
{
|
|
||||||
if (!supported_kinds.empty())
|
|
||||||
supported_kinds += ", ";
|
|
||||||
supported_kinds += name;
|
|
||||||
}
|
|
||||||
throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for USearch index. Supported kinds are: {}", supported_kinds);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check data type of indexed column:
|
/// Check data type of indexed column:
|
||||||
|
|
||||||
|
@ -15,26 +15,23 @@ namespace DB
|
|||||||
|
|
||||||
using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>;
|
using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>;
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
|
||||||
class USearchIndexWithSerialization : public USearchIndex
|
class USearchIndexWithSerialization : public USearchIndex
|
||||||
{
|
{
|
||||||
using Base = USearchIndex;
|
using Base = USearchIndex;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind);
|
USearchIndexWithSerialization(size_t dimensions, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind);
|
||||||
void serialize(WriteBuffer & ostr) const;
|
void serialize(WriteBuffer & ostr) const;
|
||||||
void deserialize(ReadBuffer & istr);
|
void deserialize(ReadBuffer & istr);
|
||||||
};
|
};
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization>;
|
||||||
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization<Metric>>;
|
|
||||||
|
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
|
||||||
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
|
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
|
||||||
{
|
{
|
||||||
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_);
|
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_);
|
||||||
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr<Metric> index_);
|
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::metric_kind_t metric_kind, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr index_);
|
||||||
|
|
||||||
~MergeTreeIndexGranuleUSearch() override = default;
|
~MergeTreeIndexGranuleUSearch() override = default;
|
||||||
|
|
||||||
@ -45,15 +42,15 @@ struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
|
|||||||
|
|
||||||
const String index_name;
|
const String index_name;
|
||||||
const Block index_sample_block;
|
const Block index_sample_block;
|
||||||
|
const unum::usearch::metric_kind_t metric_kind;
|
||||||
const unum::usearch::scalar_kind_t scalar_kind;
|
const unum::usearch::scalar_kind_t scalar_kind;
|
||||||
USearchIndexWithSerializationPtr<Metric> index;
|
USearchIndexWithSerializationPtr index;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
|
||||||
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
|
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
|
||||||
{
|
{
|
||||||
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::scalar_kind_t scalar_kind_);
|
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
|
||||||
~MergeTreeIndexAggregatorUSearch() override = default;
|
~MergeTreeIndexAggregatorUSearch() override = default;
|
||||||
|
|
||||||
bool empty() const override { return !index || index->size() == 0; }
|
bool empty() const override { return !index || index->size() == 0; }
|
||||||
@ -62,8 +59,9 @@ struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
|
|||||||
|
|
||||||
const String index_name;
|
const String index_name;
|
||||||
const Block index_sample_block;
|
const Block index_sample_block;
|
||||||
|
const unum::usearch::metric_kind_t metric_kind;
|
||||||
const unum::usearch::scalar_kind_t scalar_kind;
|
const unum::usearch::scalar_kind_t scalar_kind;
|
||||||
USearchIndexWithSerializationPtr<Metric> index;
|
USearchIndexWithSerializationPtr index;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -73,7 +71,7 @@ public:
|
|||||||
MergeTreeIndexConditionUSearch(
|
MergeTreeIndexConditionUSearch(
|
||||||
const IndexDescription & index_description,
|
const IndexDescription & index_description,
|
||||||
const SelectQueryInfo & query,
|
const SelectQueryInfo & query,
|
||||||
const String & distance_function,
|
unum::usearch::metric_kind_t metric_kind_,
|
||||||
ContextPtr context);
|
ContextPtr context);
|
||||||
|
|
||||||
~MergeTreeIndexConditionUSearch() override = default;
|
~MergeTreeIndexConditionUSearch() override = default;
|
||||||
@ -83,18 +81,15 @@ public:
|
|||||||
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
|
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
template <unum::usearch::metric_kind_t Metric>
|
|
||||||
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
|
|
||||||
|
|
||||||
const ApproximateNearestNeighborCondition ann_condition;
|
const ApproximateNearestNeighborCondition ann_condition;
|
||||||
const String distance_function;
|
const unum::usearch::metric_kind_t metric_kind;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class MergeTreeIndexUSearch : public IMergeTreeIndex
|
class MergeTreeIndexUSearch : public IMergeTreeIndex
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_);
|
MergeTreeIndexUSearch(const IndexDescription & index_, unum::usearch::metric_kind_t metric_kind_, unum::usearch::scalar_kind_t scalar_kind_);
|
||||||
|
|
||||||
~MergeTreeIndexUSearch() override = default;
|
~MergeTreeIndexUSearch() override = default;
|
||||||
|
|
||||||
@ -105,7 +100,7 @@ public:
|
|||||||
bool isVectorSearch() const override { return true; }
|
bool isVectorSearch() const override { return true; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const String distance_function;
|
const unum::usearch::metric_kind_t metric_kind;
|
||||||
const unum::usearch::scalar_kind_t scalar_kind;
|
const unum::usearch::scalar_kind_t scalar_kind;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user