Merge pull request #53444 from rschu1ze/factorize-constants

Minor: Factorize constants in Annoy index
This commit is contained in:
Robert Schulze 2023-08-17 11:05:56 +02:00 committed by GitHub
commit 067623a4c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,6 +25,11 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}
static constexpr auto DISTANCE_FUNCTION_L2 = "L2Distance";
static constexpr auto DISTANCE_FUNCTION_COSINE = "cosineDistance";
static constexpr auto DEFAULT_TREES = 100uz;
static constexpr auto DEFAULT_DISTANCE_FUNCTION = DISTANCE_FUNCTION_L2;
template <typename Distance>
AnnoyIndexWithSerialization<Distance>::AnnoyIndexWithSerialization(size_t dimensions)
@ -224,9 +229,9 @@ bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
if (distance_function == "L2Distance")
if (distance_function == DISTANCE_FUNCTION_L2)
return getUsefulRangesImpl<Annoy::Euclidean>(idx_granule);
else if (distance_function == "cosineDistance")
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return getUsefulRangesImpl<Annoy::Angular>(idx_granule);
std::unreachable();
}
@ -289,9 +294,9 @@ MergeTreeIndexAnnoy::MergeTreeIndexAnnoy(const IndexDescription & index_, UInt64
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
if (distance_function == "L2Distance")
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);
else if (distance_function == "cosineDistance")
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);
std::unreachable();
}
@ -299,9 +304,9 @@ MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
{
/// TODO: Support more metrics. Available metrics: https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
if (distance_function == "L2Distance")
if (distance_function == DISTANCE_FUNCTION_L2)
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees);
else if (distance_function == "cosineDistance")
else if (distance_function == DISTANCE_FUNCTION_COSINE)
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees);
std::unreachable();
}
@ -313,14 +318,11 @@ MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const Selec
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
static constexpr auto default_trees = 100uz;
static constexpr auto default_distance_function = "L2Distance";
String distance_function = default_distance_function;
String distance_function = DEFAULT_DISTANCE_FUNCTION;
if (!index.arguments.empty())
distance_function = index.arguments[0].get<String>();
UInt64 trees = default_trees;
UInt64 trees = DEFAULT_TREES;
if (index.arguments.size() > 1)
trees = index.arguments[1].get<UInt64>();
@ -350,8 +352,8 @@ void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
if (!index.arguments.empty())
{
String distance_name = index.arguments[0].get<String>();
if (distance_name != "L2Distance" && distance_name != "cosineDistance")
throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index only supports distance functions 'L2Distance' and 'cosineDistance'");
if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
}
/// Check data type of indexed column: