add cosine distance for annoy and fix docs for cosineDistance

This commit is contained in:
FArthur-cmd 2022-10-28 17:03:35 +00:00
parent ba7c3c5eae
commit f187d4e1d4
6 changed files with 180 additions and 50 deletions

View File

@ -2,13 +2,20 @@
The main task that indexes achieve is to quickly find nearest neighbors for multidimensional data. An example of such a problem can be finding similar pictures (texts) for a given picture (text). That problem can be reduced to finding the nearest [embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning). They can be created from data using [UDF](../../../sql-reference/functions/index.md#executable-user-defined-functions).
The next query finds the closest neighbors in N-dimensional space using the L2 (Euclidean) distance:
The next queries find the closest neighbors in N-dimensional space using the L2 (Euclidean) distance:
``` sql
SELECT *
FROM table_name
WHERE L2Distance(Column, Point) < MaxDistance
LIMIT N
```
``` sql
SELECT *
FROM table_name
ORDER BY L2Distance(Column, Point)
LIMIT N
```
But it will take some time for execution because of the long calculation of the distance between `TargetEmbedding` and all other vectors. This is where ANN indexes can help. They store a compact approximation of the search space (e.g. using clustering, search trees, etc.) and are able to compute approximate neighbors quickly.
## Indexes Structure
@ -53,7 +60,7 @@ CREATE TABLE t
(
`id` Int64,
`number` Tuple(Float32, Float32, Float32),
INDEX x number TYPE annoy GRANULARITY N
INDEX x number TYPE index_name(parameters) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
@ -64,7 +71,7 @@ CREATE TABLE t
(
`id` Int64,
`number` Array(Float32),
INDEX x number TYPE annoy GRANULARITY N
INDEX x number TYPE index_name(parameters) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
@ -92,7 +99,7 @@ CREATE TABLE t
(
id Int64,
number Tuple(Float32, Float32, Float32),
INDEX x number TYPE annoy(T) GRANULARITY N
INDEX x number TYPE annoy(Trees, DistanceName) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
@ -103,7 +110,7 @@ CREATE TABLE t
(
id Int64,
number Array(Float32),
INDEX x number TYPE annoy(T) GRANULARITY N
INDEX x number TYPE annoy(Trees, DistanceName) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
@ -111,9 +118,19 @@ ORDER BY id;
!!! note "Note"
Table with array field will work faster, but all arrays **must** have same length. Use [CONSTRAINT](../../../sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(number) = 256`.
Parameter `T` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness).
Parameter `Trees` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). By default it is set to `100`. Parameter `DistanceName` is name of distance function. By default it is set to `L2Distance`. It can be set without changing first parameter, for example
```sql
CREATE TABLE t
(
id Int64,
number Array(Float32),
INDEX x number TYPE annoy('cosineDistance') GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
```
Annoy supports only `L2Distance`.
Annoy supports `L2Distance` and `cosineDistance`.
In the `SELECT` in the settings (`ann_index_select_query_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). During the query it will inspect up to `search_k` nodes which defaults to `n_trees * n` if not provided. `search_k` gives you a run-time tradeoff between better accuracy and speed.

View File

@ -474,13 +474,13 @@ Calculates the cosine distance between two vectors (the values of the tuples are
**Syntax**
```sql
cosineDistance(tuple1, tuple2)
cosineDistance(vector1, vector2)
```
**Arguments**
- `tuple1` — First tuple. [Tuple](../../sql-reference/data-types/tuple.md).
- `tuple2` — Second tuple. [Tuple](../../sql-reference/data-types/tuple.md).
- `vector1` — First tuple. [Tuple](../../sql-reference/data-types/tuple.md) or [Array](../../sql-reference/data-types/array.md).
- `vector2` — Second tuple. [Tuple](../../sql-reference/data-types/tuple.md) or [Array](../../sql-reference/data-types/array.md).
**Returned value**
@ -488,7 +488,7 @@ cosineDistance(tuple1, tuple2)
Type: [Float](../../sql-reference/data-types/float.md).
**Example**
**Examples**
Query:
@ -503,3 +503,17 @@ Result:
│ 0.007722123286332261 │
└────────────────────────────────┘
```
Query:
```sql
SELECT cosineDistance([1, 2], [2, 3]);
```
Result:
```text
┌─cosineDistance([1, 2], [2, 3])─┐
│ 0.007722123286332261 │
└────────────────────────────────┘
```

View File

@ -69,13 +69,15 @@ namespace ErrorCodes
extern const int INCORRECT_DATA;
}
MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, index(nullptr)
{}
MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
const String & index_name_,
const Block & index_sample_block_,
AnnoyIndexPtr index_base_)
@ -84,7 +86,8 @@ MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(
, index(std::move(index_base_))
{}
void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
{
/// number of dimensions is required in the constructor,
/// so it must be written and read separately from the other part
@ -92,7 +95,8 @@ void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const
index->serialize(ostr);
}
void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
uint64_t dimension;
readIntBinary(dimension, istr);
@ -100,8 +104,8 @@ void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeI
index->deserialize(istr);
}
MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy(
template <typename Distance>
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
const String & index_name_,
const Block & index_sample_block_,
uint64_t number_of_trees_)
@ -110,16 +114,18 @@ MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy(
, number_of_trees(number_of_trees_)
{}
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy::getGranuleAndReset()
template <typename Distance>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
{
// NOLINTNEXTLINE(*)
index->build(number_of_trees, /*number_of_threads=*/1);
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy>(index_name, index_sample_block, index);
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance> >(index_name, index_sample_block, index);
index = nullptr;
return granule;
}
void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, size_t limit)
template <typename Distance>
void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t * pos, size_t limit)
{
if (*pos >= block.rows())
throw Exception(
@ -203,8 +209,9 @@ void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, si
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
const IndexDescription & /*index*/,
const SelectQueryInfo & query,
ContextPtr context)
: condition(query, context)
ContextPtr context,
const String& metric_name_)
: condition(query, context), metric_name(metric_name_)
{}
@ -215,10 +222,28 @@ bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
{
return condition.alwaysUnknownOrTrue("L2Distance");
return condition.alwaysUnknownOrTrue(metric_name);
}
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
if (metric_name == "L2Distance")
{
return getUsefulRangesImpl<::Annoy::Euclidean>(idx_granule);
}
else if (metric_name == "cosineDistance")
{
return getUsefulRangesImpl<::Annoy::Angular>(idx_granule);
}
else
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong metric type {}", metric_name);
}
}
template <typename Distance>
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
UInt64 limit = condition.getLimit();
UInt64 index_granularity = condition.getIndexGranularity();
@ -230,7 +255,7 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex
std::vector<float> target_vec = condition.getTargetVector();
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy>(idx_granule);
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance> >(idx_granule);
if (granule == nullptr)
{
throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR);
@ -284,38 +309,70 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex
return result_vector;
}
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
return std::make_shared<MergeTreeIndexGranuleAnnoy>(index.name, index.sample_block);
if (metric_name == "L2Distance")
{
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block);
}
if (metric_name == "cosineDistance")
{
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Angular> >(index.name, index.sample_block);
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Wrong metric name {}", metric_name);
}
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
{
return std::make_shared<MergeTreeIndexAggregatorAnnoy>(index.name, index.sample_block, number_of_trees);
if (metric_name == "L2Distance")
{
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block, number_of_trees);
}
if (metric_name == "cosineDistance")
{
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Angular> >(index.name, index.sample_block, number_of_trees);
}
throw Exception(ErrorCodes::INCORRECT_DATA, "Wrong metric name {}", metric_name);
}
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(
const SelectQueryInfo & query, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, context);
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, context, metric_name);
};
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
uint64_t param = index.arguments[0].get<uint64_t>();
return std::make_shared<MergeTreeIndexAnnoy>(index, param);
uint64_t param = 100;
String metric_name = "L2Distance";
if (index.arguments.size() > 0 && !index.arguments[0].tryGet<uint64_t>(param))
{
if (!index.arguments[0].tryGet<String>(metric_name))
{
throw Exception("Can't parse first argument", ErrorCodes::INCORRECT_DATA);
}
}
if (index.arguments.size() > 1 && !index.arguments[1].tryGet<String>(metric_name))
{
throw Exception("Can't parse second argument", ErrorCodes::INCORRECT_DATA);
}
return std::make_shared<MergeTreeIndexAnnoy>(index, param, metric_name);
}
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
{
if (index.arguments.size() != 1)
if (index.arguments.size() > 2)
{
throw Exception("Annoy index must have exactly one argument.", ErrorCodes::INCORRECT_QUERY);
throw Exception("Annoy index must no more than two arguments.", ErrorCodes::INCORRECT_QUERY);
}
if (index.arguments[0].getType() != Field::Types::UInt64)
if (index.arguments.size() > 0 && index.arguments[0].getType() != Field::Types::UInt64
&& index.arguments[0].getType() != Field::Types::String)
{
throw Exception("Annoy index argument must be UInt64.", ErrorCodes::INCORRECT_QUERY);
throw Exception("Annoy index first argument must be UInt64 or String.", ErrorCodes::INCORRECT_QUERY);
}
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
{
throw Exception("Annoy index second argument must be String.", ErrorCodes::INCORRECT_QUERY);
}
}

View File

@ -19,7 +19,7 @@ namespace ApproximateNearestNeighbour
using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexMultiThreadedBuildPolicy;
// TODO: Support different metrics. List of available metrics can be taken from here:
// https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
template <typename Distance = ::Annoy::Euclidean>
template <typename Distance>
class AnnoyIndex : public ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>
{
using Base = ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
@ -31,9 +31,10 @@ namespace ApproximateNearestNeighbour
};
}
template <typename Distance>
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
{
using AnnoyIndex = ANN::AnnoyIndex<>;
using AnnoyIndex = ANN::AnnoyIndex<Distance>;
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
@ -54,10 +55,10 @@ struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
AnnoyIndexPtr index;
};
template <typename Distance>
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
{
using AnnoyIndex = ANN::AnnoyIndex<>;
using AnnoyIndex = ANN::AnnoyIndex<Distance>;
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t number_of_trees);
@ -80,7 +81,8 @@ public:
MergeTreeIndexConditionAnnoy(
const IndexDescription & index,
const SelectQueryInfo & query,
ContextPtr context);
ContextPtr context,
const String& metric_name);
bool alwaysUnknownOrTrue() const override;
@ -91,16 +93,21 @@ public:
~MergeTreeIndexConditionAnnoy() override = default;
private:
template <typename Distance>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
ANN::ANNCondition condition;
const String metric_name;
};
class MergeTreeIndexAnnoy : public IMergeTreeIndex
{
public:
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_)
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_, const String& metric_name_)
: IMergeTreeIndex(index_)
, number_of_trees(number_of_trees_)
, metric_name(metric_name_)
{}
~MergeTreeIndexAnnoy() override = default;
@ -115,6 +122,7 @@ public:
private:
const uint64_t number_of_trees;
const String metric_name;
};

View File

@ -14,3 +14,11 @@
1 [0,0,10]
5 [0,0,10.2]
4 [0,0,9.7]
1 [0,0,10]
2 [0,0,10.5]
3 [0,0,9.5]
4 [0,0,9.7]
5 [0,0,10.2]
1 [0,0,10]
5 [0,0,10.2]
4 [0,0,9.7]

View File

@ -2,45 +2,71 @@
SET allow_experimental_annoy_index = 1;
DROP TABLE IF EXISTS 02354_annoy;
DROP TABLE IF EXISTS 02354_annoy_l2;
CREATE TABLE 02354_annoy
CREATE TABLE 02354_annoy_l2
(
id Int32,
embedding Array(Float32),
INDEX annoy_index embedding TYPE annoy(100) GRANULARITY 1
INDEX annoy_index embedding TYPE annoy() GRANULARITY 1
)
ENGINE = MergeTree
ORDER BY id
SETTINGS index_granularity=5;
INSERT INTO 02354_annoy VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
INSERT INTO 02354_annoy_l2 VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
SELECT *
FROM 02354_annoy
FROM 02354_annoy_l2
WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0
LIMIT 5;
SELECT *
FROM 02354_annoy
FROM 02354_annoy_l2
ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0])
LIMIT 3;
SET param_02354_target_vector='[0.0, 0.0, 10.0]';
SELECT *
FROM 02354_annoy
FROM 02354_annoy_l2
WHERE L2Distance(embedding, {02354_target_vector: Array(Float32)}) < 1.0
LIMIT 5;
SELECT *
FROM 02354_annoy
FROM 02354_annoy_l2
ORDER BY L2Distance(embedding, {02354_target_vector: Array(Float32)})
LIMIT 3;
SELECT *
FROM 02354_annoy
FROM 02354_annoy_l2
ORDER BY L2Distance(embedding, [0.0, 0.0])
LIMIT 3; -- { serverError 80 }
DROP TABLE IF EXISTS 02354_annoy;
DROP TABLE IF EXISTS 02354_annoy_l2;
DROP TABLE IF EXISTS 02354_annoy_cosine;
CREATE TABLE 02354_annoy_cosine
(
id Int32,
embedding Array(Float32),
INDEX annoy_index embedding TYPE annoy(100, 'L2Distance') GRANULARITY 1
)
ENGINE = MergeTree
ORDER BY id
SETTINGS index_granularity=5;
INSERT INTO 02354_annoy_cosine VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
SELECT *
FROM 02354_annoy_cosine
WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0
LIMIT 5;
SELECT *
FROM 02354_annoy_cosine
ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0])
LIMIT 3;
DROP TABLE IF EXISTS 02354_annoy_cosine;