mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-23 16:12:01 +00:00
add cosine distance for annoy and fix docs for cosineDistance
This commit is contained in:
parent
ba7c3c5eae
commit
f187d4e1d4
@ -2,13 +2,20 @@
|
|||||||
|
|
||||||
The main task that indexes achieve is to quickly find nearest neighbors for multidimensional data. An example of such a problem can be finding similar pictures (texts) for a given picture (text). That problem can be reduced to finding the nearest [embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning). They can be created from data using [UDF](../../../sql-reference/functions/index.md#executable-user-defined-functions).
|
The main task that indexes achieve is to quickly find nearest neighbors for multidimensional data. An example of such a problem can be finding similar pictures (texts) for a given picture (text). That problem can be reduced to finding the nearest [embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning). They can be created from data using [UDF](../../../sql-reference/functions/index.md#executable-user-defined-functions).
|
||||||
|
|
||||||
The next query finds the closest neighbors in N-dimensional space using the L2 (Euclidean) distance:
|
The next queries find the closest neighbors in N-dimensional space using the L2 (Euclidean) distance:
|
||||||
``` sql
|
``` sql
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM table_name
|
FROM table_name
|
||||||
WHERE L2Distance(Column, Point) < MaxDistance
|
WHERE L2Distance(Column, Point) < MaxDistance
|
||||||
LIMIT N
|
LIMIT N
|
||||||
```
|
```
|
||||||
|
|
||||||
|
``` sql
|
||||||
|
SELECT *
|
||||||
|
FROM table_name
|
||||||
|
ORDER BY L2Distance(Column, Point)
|
||||||
|
LIMIT N
|
||||||
|
```
|
||||||
But it will take some time for execution because of the long calculation of the distance between `TargetEmbedding` and all other vectors. This is where ANN indexes can help. They store a compact approximation of the search space (e.g. using clustering, search trees, etc.) and are able to compute approximate neighbors quickly.
|
But it will take some time for execution because of the long calculation of the distance between `TargetEmbedding` and all other vectors. This is where ANN indexes can help. They store a compact approximation of the search space (e.g. using clustering, search trees, etc.) and are able to compute approximate neighbors quickly.
|
||||||
|
|
||||||
## Indexes Structure
|
## Indexes Structure
|
||||||
@ -53,7 +60,7 @@ CREATE TABLE t
|
|||||||
(
|
(
|
||||||
`id` Int64,
|
`id` Int64,
|
||||||
`number` Tuple(Float32, Float32, Float32),
|
`number` Tuple(Float32, Float32, Float32),
|
||||||
INDEX x number TYPE annoy GRANULARITY N
|
INDEX x number TYPE index_name(parameters) GRANULARITY N
|
||||||
)
|
)
|
||||||
ENGINE = MergeTree
|
ENGINE = MergeTree
|
||||||
ORDER BY id;
|
ORDER BY id;
|
||||||
@ -64,7 +71,7 @@ CREATE TABLE t
|
|||||||
(
|
(
|
||||||
`id` Int64,
|
`id` Int64,
|
||||||
`number` Array(Float32),
|
`number` Array(Float32),
|
||||||
INDEX x number TYPE annoy GRANULARITY N
|
INDEX x number TYPE index_name(parameters) GRANULARITY N
|
||||||
)
|
)
|
||||||
ENGINE = MergeTree
|
ENGINE = MergeTree
|
||||||
ORDER BY id;
|
ORDER BY id;
|
||||||
@ -92,7 +99,7 @@ CREATE TABLE t
|
|||||||
(
|
(
|
||||||
id Int64,
|
id Int64,
|
||||||
number Tuple(Float32, Float32, Float32),
|
number Tuple(Float32, Float32, Float32),
|
||||||
INDEX x number TYPE annoy(T) GRANULARITY N
|
INDEX x number TYPE annoy(Trees, DistanceName) GRANULARITY N
|
||||||
)
|
)
|
||||||
ENGINE = MergeTree
|
ENGINE = MergeTree
|
||||||
ORDER BY id;
|
ORDER BY id;
|
||||||
@ -103,7 +110,7 @@ CREATE TABLE t
|
|||||||
(
|
(
|
||||||
id Int64,
|
id Int64,
|
||||||
number Array(Float32),
|
number Array(Float32),
|
||||||
INDEX x number TYPE annoy(T) GRANULARITY N
|
INDEX x number TYPE annoy(Trees, DistanceName) GRANULARITY N
|
||||||
)
|
)
|
||||||
ENGINE = MergeTree
|
ENGINE = MergeTree
|
||||||
ORDER BY id;
|
ORDER BY id;
|
||||||
@ -111,9 +118,19 @@ ORDER BY id;
|
|||||||
!!! note "Note"
|
!!! note "Note"
|
||||||
Table with array field will work faster, but all arrays **must** have same length. Use [CONSTRAINT](../../../sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(number) = 256`.
|
Table with array field will work faster, but all arrays **must** have same length. Use [CONSTRAINT](../../../sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(number) = 256`.
|
||||||
|
|
||||||
Parameter `T` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness).
|
Parameter `Trees` is the number of trees which algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). By default it is set to `100`. Parameter `DistanceName` is name of distance function. By default it is set to `L2Distance`. It can be set without changing first parameter, for example
|
||||||
|
```sql
|
||||||
|
CREATE TABLE t
|
||||||
|
(
|
||||||
|
id Int64,
|
||||||
|
number Array(Float32),
|
||||||
|
INDEX x number TYPE annoy('cosineDistance') GRANULARITY N
|
||||||
|
)
|
||||||
|
ENGINE = MergeTree
|
||||||
|
ORDER BY id;
|
||||||
|
```
|
||||||
|
|
||||||
Annoy supports only `L2Distance`.
|
Annoy supports `L2Distance` and `cosineDistance`.
|
||||||
|
|
||||||
In the `SELECT` in the settings (`ann_index_select_query_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). During the query it will inspect up to `search_k` nodes which defaults to `n_trees * n` if not provided. `search_k` gives you a run-time tradeoff between better accuracy and speed.
|
In the `SELECT` in the settings (`ann_index_select_query_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). During the query it will inspect up to `search_k` nodes which defaults to `n_trees * n` if not provided. `search_k` gives you a run-time tradeoff between better accuracy and speed.
|
||||||
|
|
||||||
|
@ -474,13 +474,13 @@ Calculates the cosine distance between two vectors (the values of the tuples are
|
|||||||
**Syntax**
|
**Syntax**
|
||||||
|
|
||||||
```sql
|
```sql
|
||||||
cosineDistance(tuple1, tuple2)
|
cosineDistance(vector1, vector2)
|
||||||
```
|
```
|
||||||
|
|
||||||
**Arguments**
|
**Arguments**
|
||||||
|
|
||||||
- `tuple1` — First tuple. [Tuple](../../sql-reference/data-types/tuple.md).
|
- `vector1` — First tuple. [Tuple](../../sql-reference/data-types/tuple.md) or [Array](../../sql-reference/data-types/array.md).
|
||||||
- `tuple2` — Second tuple. [Tuple](../../sql-reference/data-types/tuple.md).
|
- `vector2` — Second tuple. [Tuple](../../sql-reference/data-types/tuple.md) or [Array](../../sql-reference/data-types/array.md).
|
||||||
|
|
||||||
**Returned value**
|
**Returned value**
|
||||||
|
|
||||||
@ -488,7 +488,7 @@ cosineDistance(tuple1, tuple2)
|
|||||||
|
|
||||||
Type: [Float](../../sql-reference/data-types/float.md).
|
Type: [Float](../../sql-reference/data-types/float.md).
|
||||||
|
|
||||||
**Example**
|
**Examples**
|
||||||
|
|
||||||
Query:
|
Query:
|
||||||
|
|
||||||
@ -503,3 +503,17 @@ Result:
|
|||||||
│ 0.007722123286332261 │
|
│ 0.007722123286332261 │
|
||||||
└────────────────────────────────┘
|
└────────────────────────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Query:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT cosineDistance([1, 2], [2, 3]);
|
||||||
|
```
|
||||||
|
|
||||||
|
Result:
|
||||||
|
|
||||||
|
```text
|
||||||
|
┌─cosineDistance([1, 2], [2, 3])─┐
|
||||||
|
│ 0.007722123286332261 │
|
||||||
|
└────────────────────────────────┘
|
||||||
|
```
|
||||||
|
@ -69,13 +69,15 @@ namespace ErrorCodes
|
|||||||
extern const int INCORRECT_DATA;
|
extern const int INCORRECT_DATA;
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
|
template <typename Distance>
|
||||||
|
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
|
||||||
: index_name(index_name_)
|
: index_name(index_name_)
|
||||||
, index_sample_block(index_sample_block_)
|
, index_sample_block(index_sample_block_)
|
||||||
, index(nullptr)
|
, index(nullptr)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(
|
template <typename Distance>
|
||||||
|
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
|
||||||
const String & index_name_,
|
const String & index_name_,
|
||||||
const Block & index_sample_block_,
|
const Block & index_sample_block_,
|
||||||
AnnoyIndexPtr index_base_)
|
AnnoyIndexPtr index_base_)
|
||||||
@ -84,7 +86,8 @@ MergeTreeIndexGranuleAnnoy::MergeTreeIndexGranuleAnnoy(
|
|||||||
, index(std::move(index_base_))
|
, index(std::move(index_base_))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const
|
template <typename Distance>
|
||||||
|
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
|
||||||
{
|
{
|
||||||
/// number of dimensions is required in the constructor,
|
/// number of dimensions is required in the constructor,
|
||||||
/// so it must be written and read separately from the other part
|
/// so it must be written and read separately from the other part
|
||||||
@ -92,7 +95,8 @@ void MergeTreeIndexGranuleAnnoy::serializeBinary(WriteBuffer & ostr) const
|
|||||||
index->serialize(ostr);
|
index->serialize(ostr);
|
||||||
}
|
}
|
||||||
|
|
||||||
void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
|
template <typename Distance>
|
||||||
|
void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
|
||||||
{
|
{
|
||||||
uint64_t dimension;
|
uint64_t dimension;
|
||||||
readIntBinary(dimension, istr);
|
readIntBinary(dimension, istr);
|
||||||
@ -100,8 +104,8 @@ void MergeTreeIndexGranuleAnnoy::deserializeBinary(ReadBuffer & istr, MergeTreeI
|
|||||||
index->deserialize(istr);
|
index->deserialize(istr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Distance>
|
||||||
MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy(
|
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
|
||||||
const String & index_name_,
|
const String & index_name_,
|
||||||
const Block & index_sample_block_,
|
const Block & index_sample_block_,
|
||||||
uint64_t number_of_trees_)
|
uint64_t number_of_trees_)
|
||||||
@ -110,16 +114,18 @@ MergeTreeIndexAggregatorAnnoy::MergeTreeIndexAggregatorAnnoy(
|
|||||||
, number_of_trees(number_of_trees_)
|
, number_of_trees(number_of_trees_)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy::getGranuleAndReset()
|
template <typename Distance>
|
||||||
|
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
|
||||||
{
|
{
|
||||||
// NOLINTNEXTLINE(*)
|
// NOLINTNEXTLINE(*)
|
||||||
index->build(number_of_trees, /*number_of_threads=*/1);
|
index->build(number_of_trees, /*number_of_threads=*/1);
|
||||||
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy>(index_name, index_sample_block, index);
|
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance> >(index_name, index_sample_block, index);
|
||||||
index = nullptr;
|
index = nullptr;
|
||||||
return granule;
|
return granule;
|
||||||
}
|
}
|
||||||
|
|
||||||
void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, size_t limit)
|
template <typename Distance>
|
||||||
|
void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t * pos, size_t limit)
|
||||||
{
|
{
|
||||||
if (*pos >= block.rows())
|
if (*pos >= block.rows())
|
||||||
throw Exception(
|
throw Exception(
|
||||||
@ -203,8 +209,9 @@ void MergeTreeIndexAggregatorAnnoy::update(const Block & block, size_t * pos, si
|
|||||||
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
|
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
|
||||||
const IndexDescription & /*index*/,
|
const IndexDescription & /*index*/,
|
||||||
const SelectQueryInfo & query,
|
const SelectQueryInfo & query,
|
||||||
ContextPtr context)
|
ContextPtr context,
|
||||||
: condition(query, context)
|
const String& metric_name_)
|
||||||
|
: condition(query, context), metric_name(metric_name_)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
|
|
||||||
@ -215,10 +222,28 @@ bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /
|
|||||||
|
|
||||||
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
|
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
|
||||||
{
|
{
|
||||||
return condition.alwaysUnknownOrTrue("L2Distance");
|
return condition.alwaysUnknownOrTrue(metric_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
|
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
|
||||||
|
{
|
||||||
|
if (metric_name == "L2Distance")
|
||||||
|
{
|
||||||
|
return getUsefulRangesImpl<::Annoy::Euclidean>(idx_granule);
|
||||||
|
}
|
||||||
|
else if (metric_name == "cosineDistance")
|
||||||
|
{
|
||||||
|
return getUsefulRangesImpl<::Annoy::Angular>(idx_granule);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong metric type {}", metric_name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename Distance>
|
||||||
|
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
|
||||||
{
|
{
|
||||||
UInt64 limit = condition.getLimit();
|
UInt64 limit = condition.getLimit();
|
||||||
UInt64 index_granularity = condition.getIndexGranularity();
|
UInt64 index_granularity = condition.getIndexGranularity();
|
||||||
@ -230,7 +255,7 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex
|
|||||||
|
|
||||||
std::vector<float> target_vec = condition.getTargetVector();
|
std::vector<float> target_vec = condition.getTargetVector();
|
||||||
|
|
||||||
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy>(idx_granule);
|
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance> >(idx_granule);
|
||||||
if (granule == nullptr)
|
if (granule == nullptr)
|
||||||
{
|
{
|
||||||
throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR);
|
throw Exception("Granule has the wrong type", ErrorCodes::LOGICAL_ERROR);
|
||||||
@ -284,38 +309,70 @@ std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndex
|
|||||||
return result_vector;
|
return result_vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
|
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
|
||||||
{
|
{
|
||||||
return std::make_shared<MergeTreeIndexGranuleAnnoy>(index.name, index.sample_block);
|
if (metric_name == "L2Distance")
|
||||||
|
{
|
||||||
|
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block);
|
||||||
|
}
|
||||||
|
if (metric_name == "cosineDistance")
|
||||||
|
{
|
||||||
|
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Angular> >(index.name, index.sample_block);
|
||||||
|
}
|
||||||
|
throw Exception(ErrorCodes::INCORRECT_DATA, "Wrong metric name {}", metric_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
|
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
|
||||||
{
|
{
|
||||||
return std::make_shared<MergeTreeIndexAggregatorAnnoy>(index.name, index.sample_block, number_of_trees);
|
if (metric_name == "L2Distance")
|
||||||
|
{
|
||||||
|
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block, number_of_trees);
|
||||||
|
}
|
||||||
|
if (metric_name == "cosineDistance")
|
||||||
|
{
|
||||||
|
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Angular> >(index.name, index.sample_block, number_of_trees);
|
||||||
|
}
|
||||||
|
throw Exception(ErrorCodes::INCORRECT_DATA, "Wrong metric name {}", metric_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(
|
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(
|
||||||
const SelectQueryInfo & query, ContextPtr context) const
|
const SelectQueryInfo & query, ContextPtr context) const
|
||||||
{
|
{
|
||||||
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, context);
|
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, context, metric_name);
|
||||||
};
|
};
|
||||||
|
|
||||||
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
|
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
|
||||||
{
|
{
|
||||||
uint64_t param = index.arguments[0].get<uint64_t>();
|
uint64_t param = 100;
|
||||||
return std::make_shared<MergeTreeIndexAnnoy>(index, param);
|
String metric_name = "L2Distance";
|
||||||
|
if (index.arguments.size() > 0 && !index.arguments[0].tryGet<uint64_t>(param))
|
||||||
|
{
|
||||||
|
if (!index.arguments[0].tryGet<String>(metric_name))
|
||||||
|
{
|
||||||
|
throw Exception("Can't parse first argument", ErrorCodes::INCORRECT_DATA);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (index.arguments.size() > 1 && !index.arguments[1].tryGet<String>(metric_name))
|
||||||
|
{
|
||||||
|
throw Exception("Can't parse second argument", ErrorCodes::INCORRECT_DATA);
|
||||||
|
}
|
||||||
|
return std::make_shared<MergeTreeIndexAnnoy>(index, param, metric_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
|
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
|
||||||
{
|
{
|
||||||
if (index.arguments.size() != 1)
|
if (index.arguments.size() > 2)
|
||||||
{
|
{
|
||||||
throw Exception("Annoy index must have exactly one argument.", ErrorCodes::INCORRECT_QUERY);
|
throw Exception("Annoy index must no more than two arguments.", ErrorCodes::INCORRECT_QUERY);
|
||||||
}
|
}
|
||||||
if (index.arguments[0].getType() != Field::Types::UInt64)
|
if (index.arguments.size() > 0 && index.arguments[0].getType() != Field::Types::UInt64
|
||||||
|
&& index.arguments[0].getType() != Field::Types::String)
|
||||||
{
|
{
|
||||||
throw Exception("Annoy index argument must be UInt64.", ErrorCodes::INCORRECT_QUERY);
|
throw Exception("Annoy index first argument must be UInt64 or String.", ErrorCodes::INCORRECT_QUERY);
|
||||||
|
}
|
||||||
|
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
|
||||||
|
{
|
||||||
|
throw Exception("Annoy index second argument must be String.", ErrorCodes::INCORRECT_QUERY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ namespace ApproximateNearestNeighbour
|
|||||||
using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexMultiThreadedBuildPolicy;
|
using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexMultiThreadedBuildPolicy;
|
||||||
// TODO: Support different metrics. List of available metrics can be taken from here:
|
// TODO: Support different metrics. List of available metrics can be taken from here:
|
||||||
// https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
|
// https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
|
||||||
template <typename Distance = ::Annoy::Euclidean>
|
template <typename Distance>
|
||||||
class AnnoyIndex : public ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>
|
class AnnoyIndex : public ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>
|
||||||
{
|
{
|
||||||
using Base = ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
|
using Base = ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
|
||||||
@ -31,9 +31,10 @@ namespace ApproximateNearestNeighbour
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Distance>
|
||||||
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
|
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
|
||||||
{
|
{
|
||||||
using AnnoyIndex = ANN::AnnoyIndex<>;
|
using AnnoyIndex = ANN::AnnoyIndex<Distance>;
|
||||||
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
||||||
|
|
||||||
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
|
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
|
||||||
@ -54,10 +55,10 @@ struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
|
|||||||
AnnoyIndexPtr index;
|
AnnoyIndexPtr index;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Distance>
|
||||||
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
|
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
|
||||||
{
|
{
|
||||||
using AnnoyIndex = ANN::AnnoyIndex<>;
|
using AnnoyIndex = ANN::AnnoyIndex<Distance>;
|
||||||
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
|
||||||
|
|
||||||
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t number_of_trees);
|
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t number_of_trees);
|
||||||
@ -80,7 +81,8 @@ public:
|
|||||||
MergeTreeIndexConditionAnnoy(
|
MergeTreeIndexConditionAnnoy(
|
||||||
const IndexDescription & index,
|
const IndexDescription & index,
|
||||||
const SelectQueryInfo & query,
|
const SelectQueryInfo & query,
|
||||||
ContextPtr context);
|
ContextPtr context,
|
||||||
|
const String& metric_name);
|
||||||
|
|
||||||
bool alwaysUnknownOrTrue() const override;
|
bool alwaysUnknownOrTrue() const override;
|
||||||
|
|
||||||
@ -91,16 +93,21 @@ public:
|
|||||||
~MergeTreeIndexConditionAnnoy() override = default;
|
~MergeTreeIndexConditionAnnoy() override = default;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
template <typename Distance>
|
||||||
|
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
|
||||||
|
|
||||||
ANN::ANNCondition condition;
|
ANN::ANNCondition condition;
|
||||||
|
const String metric_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
class MergeTreeIndexAnnoy : public IMergeTreeIndex
|
class MergeTreeIndexAnnoy : public IMergeTreeIndex
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_)
|
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_, const String& metric_name_)
|
||||||
: IMergeTreeIndex(index_)
|
: IMergeTreeIndex(index_)
|
||||||
, number_of_trees(number_of_trees_)
|
, number_of_trees(number_of_trees_)
|
||||||
|
, metric_name(metric_name_)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
~MergeTreeIndexAnnoy() override = default;
|
~MergeTreeIndexAnnoy() override = default;
|
||||||
@ -115,6 +122,7 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
const uint64_t number_of_trees;
|
const uint64_t number_of_trees;
|
||||||
|
const String metric_name;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,3 +14,11 @@
|
|||||||
1 [0,0,10]
|
1 [0,0,10]
|
||||||
5 [0,0,10.2]
|
5 [0,0,10.2]
|
||||||
4 [0,0,9.7]
|
4 [0,0,9.7]
|
||||||
|
1 [0,0,10]
|
||||||
|
2 [0,0,10.5]
|
||||||
|
3 [0,0,9.5]
|
||||||
|
4 [0,0,9.7]
|
||||||
|
5 [0,0,10.2]
|
||||||
|
1 [0,0,10]
|
||||||
|
5 [0,0,10.2]
|
||||||
|
4 [0,0,9.7]
|
||||||
|
@ -2,45 +2,71 @@
|
|||||||
|
|
||||||
SET allow_experimental_annoy_index = 1;
|
SET allow_experimental_annoy_index = 1;
|
||||||
|
|
||||||
DROP TABLE IF EXISTS 02354_annoy;
|
DROP TABLE IF EXISTS 02354_annoy_l2;
|
||||||
|
|
||||||
CREATE TABLE 02354_annoy
|
CREATE TABLE 02354_annoy_l2
|
||||||
(
|
(
|
||||||
id Int32,
|
id Int32,
|
||||||
embedding Array(Float32),
|
embedding Array(Float32),
|
||||||
INDEX annoy_index embedding TYPE annoy(100) GRANULARITY 1
|
INDEX annoy_index embedding TYPE annoy() GRANULARITY 1
|
||||||
)
|
)
|
||||||
ENGINE = MergeTree
|
ENGINE = MergeTree
|
||||||
ORDER BY id
|
ORDER BY id
|
||||||
SETTINGS index_granularity=5;
|
SETTINGS index_granularity=5;
|
||||||
|
|
||||||
INSERT INTO 02354_annoy VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
|
INSERT INTO 02354_annoy_l2 VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
|
||||||
|
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM 02354_annoy
|
FROM 02354_annoy_l2
|
||||||
WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0
|
WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0
|
||||||
LIMIT 5;
|
LIMIT 5;
|
||||||
|
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM 02354_annoy
|
FROM 02354_annoy_l2
|
||||||
ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0])
|
ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0])
|
||||||
LIMIT 3;
|
LIMIT 3;
|
||||||
|
|
||||||
SET param_02354_target_vector='[0.0, 0.0, 10.0]';
|
SET param_02354_target_vector='[0.0, 0.0, 10.0]';
|
||||||
|
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM 02354_annoy
|
FROM 02354_annoy_l2
|
||||||
WHERE L2Distance(embedding, {02354_target_vector: Array(Float32)}) < 1.0
|
WHERE L2Distance(embedding, {02354_target_vector: Array(Float32)}) < 1.0
|
||||||
LIMIT 5;
|
LIMIT 5;
|
||||||
|
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM 02354_annoy
|
FROM 02354_annoy_l2
|
||||||
ORDER BY L2Distance(embedding, {02354_target_vector: Array(Float32)})
|
ORDER BY L2Distance(embedding, {02354_target_vector: Array(Float32)})
|
||||||
LIMIT 3;
|
LIMIT 3;
|
||||||
|
|
||||||
SELECT *
|
SELECT *
|
||||||
FROM 02354_annoy
|
FROM 02354_annoy_l2
|
||||||
ORDER BY L2Distance(embedding, [0.0, 0.0])
|
ORDER BY L2Distance(embedding, [0.0, 0.0])
|
||||||
LIMIT 3; -- { serverError 80 }
|
LIMIT 3; -- { serverError 80 }
|
||||||
|
|
||||||
DROP TABLE IF EXISTS 02354_annoy;
|
DROP TABLE IF EXISTS 02354_annoy_l2;
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS 02354_annoy_cosine;
|
||||||
|
|
||||||
|
CREATE TABLE 02354_annoy_cosine
|
||||||
|
(
|
||||||
|
id Int32,
|
||||||
|
embedding Array(Float32),
|
||||||
|
INDEX annoy_index embedding TYPE annoy(100, 'L2Distance') GRANULARITY 1
|
||||||
|
)
|
||||||
|
ENGINE = MergeTree
|
||||||
|
ORDER BY id
|
||||||
|
SETTINGS index_granularity=5;
|
||||||
|
|
||||||
|
INSERT INTO 02354_annoy_cosine VALUES (1, [0.0, 0.0, 10.0]), (2, [0.0, 0.0, 10.5]), (3, [0.0, 0.0, 9.5]), (4, [0.0, 0.0, 9.7]), (5, [0.0, 0.0, 10.2]), (6, [10.0, 0.0, 0.0]), (7, [9.5, 0.0, 0.0]), (8, [9.7, 0.0, 0.0]), (9, [10.2, 0.0, 0.0]), (10, [10.5, 0.0, 0.0]), (11, [0.0, 10.0, 0.0]), (12, [0.0, 9.5, 0.0]), (13, [0.0, 9.7, 0.0]), (14, [0.0, 10.2, 0.0]), (15, [0.0, 10.5, 0.0]);
|
||||||
|
|
||||||
|
SELECT *
|
||||||
|
FROM 02354_annoy_cosine
|
||||||
|
WHERE L2Distance(embedding, [0.0, 0.0, 10.0]) < 1.0
|
||||||
|
LIMIT 5;
|
||||||
|
|
||||||
|
SELECT *
|
||||||
|
FROM 02354_annoy_cosine
|
||||||
|
ORDER BY L2Distance(embedding, [0.0, 0.0, 10.0])
|
||||||
|
LIMIT 3;
|
||||||
|
|
||||||
|
DROP TABLE IF EXISTS 02354_annoy_cosine;
|
||||||
|
Loading…
Reference in New Issue
Block a user