Check distance function in CREATE TABLE instead of first INSERT

This commit is contained in:
Robert Schulze 2023-05-25 20:51:21 +00:00
parent 03b6856556
commit 18304f5aef
No known key found for this signature in database
GPG Key ID: 26703B55FB13728A
2 changed files with 14 additions and 2 deletions

View File

@ -307,7 +307,7 @@ MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block); return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);
else if (distance_function == "cosineDistance") else if (distance_function == "cosineDistance")
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block); return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_function); std::unreachable();
} }
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
@ -317,7 +317,7 @@ MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees); return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees);
if (distance_function == "cosineDistance") if (distance_function == "cosineDistance")
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees); return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees);
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_function); std::unreachable();
} }
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
@ -357,6 +357,15 @@ void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
if (index.column_names.size() != 1 || index.data_types.size() != 1) if (index.column_names.size() != 1 || index.data_types.size() != 1)
throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column"); throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column");
/// Check that a supported metric was passed as first argument
if (!index.arguments.empty())
{
String distance_name = index.arguments[0].get<String>();
if (distance_name != "L2Distance" && distance_name != "cosineDistance")
throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index supports only distance functions 'L2Distance' and 'cosineDistance'. Given distance function: {}", distance_name);
}
/// Check data type of indexed column: /// Check data type of indexed column:
auto throw_unsupported_underlying_column_exception = [](DataTypePtr data_type) auto throw_unsupported_underlying_column_exception = [](DataTypePtr data_type)

View File

@ -24,3 +24,6 @@ CREATE TABLE tab(id Int32, embedding Float32, INDEX annoy_index embedding TYPE a
CREATE TABLE tab(id Int32, embedding Array(Float64), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, embedding Array(Float64), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, embedding LowCardinality(Float32), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, embedding LowCardinality(Float32), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
CREATE TABLE tab(id Int32, embedding Nullable(Float32), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN } CREATE TABLE tab(id Int32, embedding Nullable(Float32), INDEX annoy_index embedding TYPE annoy()) ENGINE = MergeTree ORDER BY id; -- { serverError ILLEGAL_COLUMN }
-- reject unsupported distance functions
CREATE TABLE tab(id Int32, embedding Array(Float32), INDEX annoy_index embedding TYPE annoy('wormholeDistance')) ENGINE = MergeTree ORDER BY id; -- { serverError INCORRECT_DATA }