Merge remote-tracking branch 'origin/master' into pr-local-plan

This commit is contained in:
Igor Nikonov 2024-08-13 16:16:31 +00:00
commit 5339717c11
115 changed files with 1752 additions and 2790 deletions

View File

@ -59,6 +59,9 @@ At a minimum, the following information should be added (but add more as needed)
- [ ] <!---ci_exclude_tsan|msan|ubsan|coverage--> Exclude: All with TSAN, MSAN, UBSAN, Coverage
- [ ] <!---ci_exclude_aarch64|release|debug--> Exclude: All with aarch64, release, debug
---
- [ ] <!---ci_include_fuzzer--> Run only fuzzers related jobs (libFuzzer fuzzers, AST fuzzers, etc.)
- [ ] <!---ci_exclude_ast--> Exclude: AST fuzzers
---
- [ ] <!---do_not_test--> Do not test
- [ ] <!---woolen_wolfdog--> Woolen Wolfdog
- [ ] <!---upload_all--> Upload binaries for special builds

3
.gitmodules vendored
View File

@ -230,9 +230,6 @@
[submodule "contrib/minizip-ng"]
path = contrib/minizip-ng
url = https://github.com/zlib-ng/minizip-ng
[submodule "contrib/annoy"]
path = contrib/annoy
url = https://github.com/ClickHouse/annoy
[submodule "contrib/qpl"]
path = contrib/qpl
url = https://github.com/intel/qpl

View File

@ -1,4 +1,4 @@
add_compile_options($<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>)
add_compile_options("$<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>")
if (USE_CLANG_TIDY)
set (CMAKE_CXX_CLANG_TIDY "${CLANG_TIDY_PATH}")

View File

@ -58,6 +58,10 @@ namespace Net
void setKeepAliveTimeout(Poco::Timespan keepAliveTimeout);
size_t getKeepAliveTimeout() const { return _keepAliveTimeout.totalSeconds(); }
size_t getMaxKeepAliveRequests() const { return _maxKeepAliveRequests; }
private:
bool _firstRequest;
Poco::Timespan _keepAliveTimeout;

View File

@ -19,11 +19,11 @@ namespace Poco {
namespace Net {
HTTPServerSession::HTTPServerSession(const StreamSocket& socket, HTTPServerParams::Ptr pParams):
HTTPSession(socket, pParams->getKeepAlive()),
_firstRequest(true),
_keepAliveTimeout(pParams->getKeepAliveTimeout()),
_maxKeepAliveRequests(pParams->getMaxKeepAliveRequests())
HTTPServerSession::HTTPServerSession(const StreamSocket & socket, HTTPServerParams::Ptr pParams)
: HTTPSession(socket, pParams->getKeepAlive())
, _firstRequest(true)
, _keepAliveTimeout(pParams->getKeepAliveTimeout())
, _maxKeepAliveRequests(pParams->getMaxKeepAliveRequests())
{
setTimeout(pParams->getTimeout());
}
@ -56,7 +56,8 @@ bool HTTPServerSession::hasMoreRequests()
--_maxKeepAliveRequests;
return buffered() > 0 || socket().poll(_keepAliveTimeout, Socket::SELECT_READ);
}
else return false;
else
return false;
}

View File

@ -57,7 +57,7 @@ option(WITH_COVERAGE "Instrumentation for code coverage with default implementat
if (WITH_COVERAGE)
message (STATUS "Enabled instrumentation for code coverage")
set(COVERAGE_FLAGS "SHELL:-fprofile-instr-generate -fcoverage-mapping")
set (COVERAGE_FLAGS -fprofile-instr-generate -fcoverage-mapping)
set (CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fprofile-instr-generate -fcoverage-mapping")
endif()

View File

@ -205,9 +205,8 @@ add_contrib (morton-nd-cmake morton-nd)
if (ARCH_S390X)
add_contrib(crc32-s390x-cmake crc32-s390x)
endif()
add_contrib (annoy-cmake annoy)
option(ENABLE_USEARCH "Enable USearch (Approximate Neighborhood Search, HNSW) support" ${ENABLE_LIBRARIES})
option(ENABLE_USEARCH "Enable USearch" ${ENABLE_LIBRARIES})
if (ENABLE_USEARCH)
add_contrib (FP16-cmake FP16)
add_contrib (robin-map-cmake robin-map)

1
contrib/annoy vendored

@ -1 +0,0 @@
Subproject commit f2ac8e7b48f9a9cf676d3b58286e5455aba8e956

View File

@ -1,24 +0,0 @@
option(ENABLE_ANNOY "Enable Annoy index support" ${ENABLE_LIBRARIES})
# Annoy index should be disabled with undefined sanitizer. Because of memory storage optimizations
# (https://github.com/ClickHouse/annoy/blob/9d8a603a4cd252448589e84c9846f94368d5a289/src/annoylib.h#L442-L463)
# UBSan fails and leads to crash. A similar issue is already open in the Annoy repo:
# https://github.com/spotify/annoy/issues/456
# Problems with alignment can lead to errors like
# (https://stackoverflow.com/questions/46790550/c-undefined-behavior-strict-aliasing-rule-or-incorrect-alignment)
# or will lead to crash on arm https://developer.arm.com/documentation/ka003038/latest
# These issues should be resolved before annoy becomes non-experimental (--> setting "allow_experimental_annoy_index")
if ((NOT ENABLE_ANNOY) OR (SANITIZE STREQUAL "undefined") OR (ARCH_AARCH64))
message (STATUS "Not using annoy")
return()
endif()
set(ANNOY_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/annoy")
set(ANNOY_SOURCE_DIR "${ANNOY_PROJECT_DIR}/src")
add_library(_annoy INTERFACE)
target_include_directories(_annoy SYSTEM INTERFACE ${ANNOY_SOURCE_DIR})
add_library(ch_contrib::annoy ALIAS _annoy)
target_compile_definitions(_annoy INTERFACE ENABLE_ANNOY)
target_compile_definitions(_annoy INTERFACE ANNOYLIB_MULTITHREADED_BUILD)

@ -1 +1 @@
Subproject commit 1f95f8083066f5b38fd2db172e7e7f9aa7c49d2d
Subproject commit b922c8ab9004ef9944982e4f165e2747b13223fa

View File

@ -1,9 +1,7 @@
set(USEARCH_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/usearch")
set(USEARCH_SOURCE_DIR "${USEARCH_PROJECT_DIR}/include")
set(FP16_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/FP16")
set(ROBIN_MAP_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/robin-map")
set(SIMSIMD_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/SimSIMD-map")
set(SIMSIMD_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/SimSIMD")
set(USEARCH_PROJECT_DIR "${ClickHouse_SOURCE_DIR}/contrib/usearch")
add_library(_usearch INTERFACE)
@ -11,7 +9,6 @@ target_include_directories(_usearch SYSTEM INTERFACE
${FP16_PROJECT_DIR}/include
${ROBIN_MAP_PROJECT_DIR}/include
${SIMSIMD_PROJECT_DIR}/include
${USEARCH_SOURCE_DIR})
${USEARCH_PROJECT_DIR}/include)
add_library(ch_contrib::usearch ALIAS _usearch)
target_compile_definitions(_usearch INTERFACE ENABLE_USEARCH)

View File

@ -108,7 +108,8 @@ if [ -n "$MAKE_DEB" ]; then
bash -x /build/packages/build
fi
mv ./programs/clickhouse* /output || mv ./programs/*_fuzzer /output
mv ./programs/clickhouse* /output ||:
mv ./programs/*_fuzzer /output ||:
[ -x ./programs/self-extracting/clickhouse ] && mv ./programs/self-extracting/clickhouse /output
[ -x ./programs/self-extracting/clickhouse-stripped ] && mv ./programs/self-extracting/clickhouse-stripped /output
[ -x ./programs/self-extracting/clickhouse-keeper ] && mv ./programs/self-extracting/clickhouse-keeper /output

View File

@ -17,7 +17,7 @@ In terms of SQL, the nearest neighborhood problem can be expressed as follows:
``` sql
SELECT *
FROM table_with_ann_index
FROM table
ORDER BY Distance(vectors, Point)
LIMIT N
```
@ -27,75 +27,109 @@ Function `Distance` computes the distance between two vectors. Often, the Euclid
distance functions](/docs/en/sql-reference/functions/distance-functions.md) are also possible. `Point` is the reference point, e.g. `(0.17,
0.33, ...)`, and `N` limits the number of search results.
An alternative formulation of the nearest neighborhood search problem looks as follows:
This query returns the top-`N` closest points to the reference point. Parameter `N` limits the number of returned values which is useful for
situations where `MaxDistance` is difficult to determine in advance.
``` sql
SELECT *
FROM table_with_ann_index
WHERE Distance(vectors, Point) < MaxDistance
LIMIT N
```
While the first query returns the top-`N` closest points to the reference point, the second query returns all points closer to the reference
point than a maximally allowed radius `MaxDistance`. Parameter `N` limits the number of returned values which is useful for situations where
`MaxDistance` is difficult to determine in advance.
With brute force search, both queries are expensive (linear in the number of points) because the distance between all points in `vectors` and
With brute force search, the query is expensive (linear in the number of points) because the distance between all points in `vectors` and
`Point` must be computed. To speed this process up, Approximate Nearest Neighbor Search Indexes (ANN indexes) store a compact representation
of the search space (using clustering, search trees, etc.) which allows to compute an approximate answer much quicker (in sub-linear time).
# Creating and Using ANN Indexes {#creating_using_ann_indexes}
# Creating and Using Vector Similarity Indexes
Syntax to create an ANN index over an [Array(Float32)](../../../sql-reference/data-types/array.md) column:
Syntax to create a vector similarity index over an [Array(Float32)](../../../sql-reference/data-types/array.md) column:
```sql
CREATE TABLE table_with_ann_index
CREATE TABLE table
(
`id` Int64,
`vectors` Array(Float32),
INDEX [ann_index_name vectors TYPE [ann_index_type]([ann_index_parameters]) [GRANULARITY [N]]
id Int64,
vectors Array(Float32),
INDEX index_name vectors TYPE vector_similarity(method, distance_function[, quantization, connectivity, expansion_add, expansion_search]) [GRANULARITY N]
)
ENGINE = MergeTree
ORDER BY id;
```
Parameters:
- `method`: Currently only `hnsw` is supported.
- `distance_function`: either `L2Distance` (the [Euclidean distance](https://en.wikipedia.org/wiki/Euclidean_distance) - the length of a
line between two points in Euclidean space), or `cosineDistance` (the [cosine
distance](https://en.wikipedia.org/wiki/Cosine_similarity#Cosine_distance) - the angle between two non-zero vectors).
- `quantization`: either `f32`, `f16`, or `i8` for storing the vector with reduced precision (optional, default: `f32`)
- `m`: the number of neighbors per graph node (optional, default: 16)
- `ef_construction`: (optional, default: 128)
- `ef_search`: (optional, default: 64)
Example:
```sql
CREATE TABLE table
(
id Int64,
vectors Array(Float32),
INDEX idx vectors TYPE vector_similarity('hnsw', 'L2Distance') -- Alternative syntax: TYPE vector_similarity(hnsw, L2Distance)
)
ENGINE = MergeTree
ORDER BY id;
```
Vector similarity indexes are based on the [USearch library](https://github.com/unum-cloud/usearch), which implements the [HNSW
algorithm](https://arxiv.org/abs/1603.09320), i.e., a hierarchical graph where each point represents a vector and the edges represent
similarity. Such hierarchical structures can be very efficient on large collections. They may often fetch 0.05% or less data from the
overall dataset, while still providing 99% recall. This is especially useful when working with high-dimensional vectors, that are expensive
to load and compare. The library also has several hardware-specific SIMD optimizations to accelerate further distance computations on modern
Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient navigation around immutable persistent
files, without loading them into RAM.
Vector similarity indexes are currently experimental; to use them, you first need to `SET allow_experimental_vector_similarity_index = 1`.
Vector similarity indexes currently support two distance functions:
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
- `cosineDistance`, the cosine distance, is one minus the cosine of the angle between two (non-zero) vectors
([Wikipedia](https://en.wikipedia.org/wiki/Cosine_similarity)).
Vector similarity indexes allow storing the vectors in reduced precision formats. Supported scalar kinds are `f64`, `f32`, `f16` or `i8`.
If no scalar kind was specified during index creation, `f16` is used as default.
For normalized data, `L2Distance` is usually a better choice, otherwise `cosineDistance` is recommended to compensate for scale. If no
distance function was specified during index creation, `L2Distance` is used as default.
:::note
All arrays must have the same length. To avoid errors, you can use a
[CONSTRAINT](/docs/en/sql-reference/statements/create/table.md#constraints), for example, `CONSTRAINT constraint_name_1 CHECK
length(vectors) = 256`. Also, empty `Arrays` and unspecified `Array` values in INSERT statements (i.e. default values) are not supported.
:::
:::note
The vector similarity index currently does not work with per-table, non-default `index_granularity` settings (see
[here](https://github.com/ClickHouse/ClickHouse/pull/51325#issuecomment-1605920475)). If necessary, the value must be changed in config.xml.
:::
ANN indexes are built during column insertion and merge. As a result, `INSERT` and `OPTIMIZE` statements will be slower than for ordinary
tables. ANN indexes are ideally used only with immutable or rarely changed data, i.e. when there are far more read requests than write
requests.
ANN indexes support two types of queries:
- ORDER BY queries:
ANN indexes support these queries:
``` sql
SELECT *
FROM table_with_ann_index
FROM table
[WHERE ...]
ORDER BY Distance(vectors, Point)
LIMIT N
```
- WHERE queries:
``` sql
SELECT *
FROM table_with_ann_index
WHERE Distance(vectors, Point) < MaxDistance
LIMIT N
```
:::tip
To avoid writing out large vectors, you can use [query
parameters](/docs/en/interfaces/cli.md#queries-with-parameters-cli-queries-with-parameters), e.g.
```bash
clickhouse-client --param_vec='hello' --query="SELECT * FROM table_with_ann_index WHERE L2Distance(vectors, {vec: Array(Float32)}) < 1.0"
clickhouse-client --param_vec='hello' --query="SELECT * FROM table WHERE L2Distance(vectors, {vec: Array(Float32)}) < 1.0"
```
:::
**Restrictions**: Queries that contain both a `WHERE Distance(vectors, Point) < MaxDistance` and an `ORDER BY Distance(vectors, Point)`
clause cannot use ANN indexes. Also, the approximate algorithms used to determine the nearest neighbors require a limit, hence queries
without `LIMIT` clause cannot utilize ANN indexes. Also, ANN indexes are only used if the query has a `LIMIT` value smaller than setting
**Restrictions**: Approximate algorithms used to determine the nearest neighbors require a limit, hence queries without `LIMIT` clause
cannot utilize ANN indexes. Also, ANN indexes are only used if the query has a `LIMIT` value smaller than setting
`max_limit_for_ann_queries` (default: 1 million rows). This is a safeguard to prevent large memory allocations by external libraries for
approximate neighbor search.
@ -122,128 +156,3 @@ brute-force distance calculation over all rows of the granules. With a small `GR
equally good, only the processing performance differs. It is generally recommended to use a large `GRANULARITY` for ANN indexes and fall
back to a smaller `GRANULARITY` values only in case of problems like excessive memory consumption of the ANN structures. If no `GRANULARITY`
was specified for ANN indexes, the default value is 100 million.
# Available ANN Indexes {#available_ann_indexes}
- [Annoy](/docs/en/engines/table-engines/mergetree-family/annindexes.md#annoy-annoy)
- [USearch](/docs/en/engines/table-engines/mergetree-family/annindexes.md#usearch-usearch)
## Annoy {#annoy}
Annoy indexes are currently experimental, to use them you first need to `SET allow_experimental_annoy_index = 1`. They are also currently
disabled on ARM due to memory safety problems with the algorithm.
This type of ANN index is based on the [Annoy library](https://github.com/spotify/annoy) which recursively divides the space into random
linear surfaces (lines in 2D, planes in 3D etc.).
<div class='vimeo-container'>
<iframe src="//www.youtube.com/embed/QkCCyLW0ehU"
width="640"
height="360"
frameborder="0"
allow="autoplay;
fullscreen;
picture-in-picture"
allowfullscreen>
</iframe>
</div>
Syntax to create an Annoy index over an [Array(Float32)](../../../sql-reference/data-types/array.md) column:
```sql
CREATE TABLE table_with_annoy_index
(
id Int64,
vectors Array(Float32),
INDEX [ann_index_name] vectors TYPE annoy([Distance[, NumTrees]]) [GRANULARITY N]
)
ENGINE = MergeTree
ORDER BY id;
```
Annoy currently supports two distance functions:
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
- `cosineDistance`, also called cosine similarity, is the cosine of the angle between two (non-zero) vectors
([Wikipedia](https://en.wikipedia.org/wiki/Cosine_similarity)).
For normalized data, `L2Distance` is usually a better choice, otherwise `cosineDistance` is recommended to compensate for scale. If no
distance function was specified during index creation, `L2Distance` is used as default.
Parameter `NumTrees` is the number of trees which the algorithm creates (default if not specified: 100). Higher values of `NumTree` mean
more accurate search results but slower index creation / query times (approximately linearly) as well as larger index sizes.
:::note
All arrays must have same length. To avoid errors, you can use a
[CONSTRAINT](/docs/en/sql-reference/statements/create/table.md#constraints), for example, `CONSTRAINT constraint_name_1 CHECK
length(vectors) = 256`. Also, empty `Arrays` and unspecified `Array` values in INSERT statements (i.e. default values) are not supported.
:::
The creation of Annoy indexes (whenever a new part is build, e.g. at the end of a merge) is a relatively slow process. You can increase
setting `max_threads_for_annoy_index_creation` (default: 4) which controls how many threads are used to create an Annoy index. Please be
careful with this setting, it is possible that multiple indexes are created in parallel in which case there can be overparallelization.
Setting `annoy_index_search_k_nodes` (default: `NumTrees * LIMIT`) determines how many tree nodes are inspected during SELECTs. Larger
values mean more accurate results at the cost of longer query runtime:
```sql
SELECT *
FROM table_name
ORDER BY L2Distance(vectors, Point)
LIMIT N
SETTINGS annoy_index_search_k_nodes=100;
```
:::note
The Annoy index currently does not work with per-table, non-default `index_granularity` settings (see
[here](https://github.com/ClickHouse/ClickHouse/pull/51325#issuecomment-1605920475)). If necessary, the value must be changed in config.xml.
:::
## USearch {#usearch}
This type of ANN index is based on the [USearch library](https://github.com/unum-cloud/usearch), which implements the [HNSW
algorithm](https://arxiv.org/abs/1603.09320), i.e., builds a hierarchical graph where each point represents a vector and the edges represent
similarity. Such hierarchical structures can be very efficient on large collections. They may often fetch 0.05% or less data from the
overall dataset, while still providing 99% recall. This is especially useful when working with high-dimensional vectors,
that are expensive to load and compare. The library also has several hardware-specific SIMD optimizations to accelerate further
distance computations on modern Arm (NEON and SVE) and x86 (AVX2 and AVX-512) CPUs and OS-specific optimizations to allow efficient
navigation around immutable persistent files, without loading them into RAM.
<div class='vimeo-container'>
<iframe src="//www.youtube.com/embed/UMrhB3icP9w"
width="640"
height="360"
frameborder="0"
allow="autoplay;
fullscreen;
picture-in-picture"
allowfullscreen>
</iframe>
</div>
Syntax to create an USearch index over an [Array](../../../sql-reference/data-types/array.md) column:
```sql
CREATE TABLE table_with_usearch_index
(
id Int64,
vectors Array(Float32),
INDEX [ann_index_name] vectors TYPE usearch([Distance[, ScalarKind]]) [GRANULARITY N]
)
ENGINE = MergeTree
ORDER BY id;
```
USearch currently supports two distance functions:
- `L2Distance`, also called Euclidean distance, is the length of a line segment between two points in Euclidean space
([Wikipedia](https://en.wikipedia.org/wiki/Euclidean_distance)).
- `cosineDistance`, also called cosine similarity, is the cosine of the angle between two (non-zero) vectors
([Wikipedia](https://en.wikipedia.org/wiki/Cosine_similarity)).
USearch allows storing the vectors in reduced precision formats. Supported scalar kinds are `f64`, `f32`, `f16` or `i8`. If no scalar kind
was specified during index creation, `f16` is used as default.
For normalized data, `L2Distance` is usually a better choice, otherwise `cosineDistance` is recommended to compensate for scale. If no
distance function was specified during index creation, `L2Distance` is used as default.

View File

@ -1400,6 +1400,16 @@ The number of seconds that ClickHouse waits for incoming requests before closing
<keep_alive_timeout>10</keep_alive_timeout>
```
## max_keep_alive_requests {#max-keep-alive-requests}
Maximum number of requests served through a single keep-alive connection before the ClickHouse server closes it. Default: 10000.
**Example**
``` xml
<max_keep_alive_requests>10</max_keep_alive_requests>
```
## listen_host {#listen_host}
Restriction on hosts that requests can come from. If you want the server to answer all of them, specify `::`.

View File

@ -1,4 +1,4 @@
add_compile_options($<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>)
add_compile_options("$<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>")
if (USE_CLANG_TIDY)
set (CMAKE_CXX_CLANG_TIDY "${CLANG_TIDY_PATH}")

View File

@ -27,7 +27,7 @@ std::string LibraryBridge::bridgeName() const
LibraryBridge::HandlerFactoryPtr LibraryBridge::getHandlerFactoryPtr(ContextPtr context) const
{
return std::make_shared<LibraryBridgeHandlerFactory>("LibraryRequestHandlerFactory", keep_alive_timeout, context);
return std::make_shared<LibraryBridgeHandlerFactory>("LibraryRequestHandlerFactory", context);
}
}

View File

@ -9,12 +9,10 @@ namespace DB
{
LibraryBridgeHandlerFactory::LibraryBridgeHandlerFactory(
const std::string & name_,
size_t keep_alive_timeout_,
ContextPtr context_)
: WithContext(context_)
, log(getLogger(name_))
, name(name_)
, keep_alive_timeout(keep_alive_timeout_)
{
}
@ -26,17 +24,17 @@ std::unique_ptr<HTTPRequestHandler> LibraryBridgeHandlerFactory::createRequestHa
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET)
{
if (uri.getPath() == "/extdict_ping")
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
return std::make_unique<ExternalDictionaryLibraryBridgeExistsHandler>(getContext());
else if (uri.getPath() == "/catboost_ping")
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(keep_alive_timeout, getContext());
return std::make_unique<CatBoostLibraryBridgeExistsHandler>(getContext());
}
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/extdict_request")
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
return std::make_unique<ExternalDictionaryLibraryBridgeRequestHandler>(getContext());
else if (uri.getPath() == "/catboost_request")
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(keep_alive_timeout, getContext());
return std::make_unique<CatBoostLibraryBridgeRequestHandler>(getContext());
}
return nullptr;

View File

@ -13,7 +13,6 @@ class LibraryBridgeHandlerFactory : public HTTPRequestHandlerFactory, WithContex
public:
LibraryBridgeHandlerFactory(
const std::string & name_,
size_t keep_alive_timeout_,
ContextPtr context_);
std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override;
@ -21,7 +20,6 @@ public:
private:
LoggerPtr log;
const std::string name;
const size_t keep_alive_timeout;
};
}

View File

@ -87,10 +87,8 @@ static void writeData(Block data, OutputFormatPtr format)
}
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("ExternalDictionaryLibraryBridgeRequestHandler"))
ExternalDictionaryLibraryBridgeRequestHandler::ExternalDictionaryLibraryBridgeRequestHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("ExternalDictionaryLibraryBridgeRequestHandler"))
{
}
@ -137,7 +135,7 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
const String & dictionary_id = params.get("dictionary_id");
LOG_TRACE(log, "Library method: '{}', dictionary id: {}", method, dictionary_id);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
@ -374,10 +372,8 @@ void ExternalDictionaryLibraryBridgeRequestHandler::handleRequest(HTTPServerRequ
}
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("ExternalDictionaryLibraryBridgeExistsHandler"))
ExternalDictionaryLibraryBridgeExistsHandler::ExternalDictionaryLibraryBridgeExistsHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("ExternalDictionaryLibraryBridgeExistsHandler"))
{
}
@ -401,7 +397,7 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
String res = library_handler ? "1" : "0";
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
LOG_TRACE(log, "Sending ping response: {} (dictionary id: {})", res, dictionary_id);
response.sendBuffer(res.data(), res.size());
}
@ -412,11 +408,8 @@ void ExternalDictionaryLibraryBridgeExistsHandler::handleRequest(HTTPServerReque
}
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(
size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("CatBoostLibraryBridgeRequestHandler"))
CatBoostLibraryBridgeRequestHandler::CatBoostLibraryBridgeRequestHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("CatBoostLibraryBridgeRequestHandler"))
{
}
@ -455,7 +448,7 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
const String & method = params.get("method");
LOG_TRACE(log, "Library method: '{}'", method);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
@ -617,10 +610,8 @@ void CatBoostLibraryBridgeRequestHandler::handleRequest(HTTPServerRequest & requ
}
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, keep_alive_timeout(keep_alive_timeout_)
, log(getLogger("CatBoostLibraryBridgeExistsHandler"))
CatBoostLibraryBridgeExistsHandler::CatBoostLibraryBridgeExistsHandler(ContextPtr context_)
: WithContext(context_), log(getLogger("CatBoostLibraryBridgeExistsHandler"))
{
}
@ -634,7 +625,7 @@ void CatBoostLibraryBridgeExistsHandler::handleRequest(HTTPServerRequest & reque
String res = "1";
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
LOG_TRACE(log, "Sending ping response: {}", res);
response.sendBuffer(res.data(), res.size());
}

View File

@ -18,14 +18,13 @@ namespace DB
class ExternalDictionaryLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
ExternalDictionaryLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit ExternalDictionaryLibraryBridgeRequestHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
static constexpr auto FORMAT = "RowBinary";
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -34,12 +33,11 @@ private:
class ExternalDictionaryLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
ExternalDictionaryLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit ExternalDictionaryLibraryBridgeExistsHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -63,12 +61,11 @@ private:
class CatBoostLibraryBridgeRequestHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeRequestHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit CatBoostLibraryBridgeRequestHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};
@ -77,12 +74,11 @@ private:
class CatBoostLibraryBridgeExistsHandler : public HTTPRequestHandler, WithContext
{
public:
CatBoostLibraryBridgeExistsHandler(size_t keep_alive_timeout_, ContextPtr context_);
explicit CatBoostLibraryBridgeExistsHandler(ContextPtr context_);
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
const size_t keep_alive_timeout;
LoggerPtr log;
};

View File

@ -202,10 +202,7 @@ void ODBCColumnsInfoHandler::handleRequest(HTTPServerRequest & request, HTTPServ
if (columns.empty())
throw Exception(ErrorCodes::UNKNOWN_TABLE, "Columns definition was not returned");
WriteBufferFromHTTPServerResponse out(
response,
request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD,
keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeStringBinary(columns.toString(), out);

View File

@ -15,18 +15,12 @@ namespace DB
class ODBCColumnsInfoHandler : public HTTPRequestHandler, WithContext
{
public:
ODBCColumnsInfoHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("ODBCColumnsInfoHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit ODBCColumnsInfoHandler(ContextPtr context_) : WithContext(context_), log(getLogger("ODBCColumnsInfoHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -74,7 +74,7 @@ void IdentifierQuoteHandler::handleRequest(HTTPServerRequest & request, HTTPServ
auto identifier = getIdentifierQuote(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeStringBinary(identifier, out);

View File

@ -14,18 +14,12 @@ namespace DB
class IdentifierQuoteHandler : public HTTPRequestHandler, WithContext
{
public:
IdentifierQuoteHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("IdentifierQuoteHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit IdentifierQuoteHandler(ContextPtr context_) : WithContext(context_), log(getLogger("IdentifierQuoteHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -132,7 +132,7 @@ void ODBCHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse
return;
}
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{

View File

@ -20,12 +20,10 @@ class ODBCHandler : public HTTPRequestHandler, WithContext
{
public:
ODBCHandler(
size_t keep_alive_timeout_,
ContextPtr context_,
const String & mode_)
: WithContext(context_)
, log(getLogger("ODBCHandler"))
, keep_alive_timeout(keep_alive_timeout_)
, mode(mode_)
{
}
@ -35,7 +33,6 @@ public:
private:
LoggerPtr log;
size_t keep_alive_timeout;
String mode;
static inline std::mutex mutex;

View File

@ -27,7 +27,7 @@ std::string ODBCBridge::bridgeName() const
ODBCBridge::HandlerFactoryPtr ODBCBridge::getHandlerFactoryPtr(ContextPtr context) const
{
return std::make_shared<ODBCBridgeHandlerFactory>("ODBCRequestHandlerFactory-factory", keep_alive_timeout, context);
return std::make_shared<ODBCBridgeHandlerFactory>("ODBCRequestHandlerFactory-factory", context);
}
}

View File

@ -9,11 +9,8 @@
namespace DB
{
ODBCBridgeHandlerFactory::ODBCBridgeHandlerFactory(const std::string & name_, size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger(name_))
, name(name_)
, keep_alive_timeout(keep_alive_timeout_)
ODBCBridgeHandlerFactory::ODBCBridgeHandlerFactory(const std::string & name_, ContextPtr context_)
: WithContext(context_), log(getLogger(name_)), name(name_)
{
}
@ -23,33 +20,33 @@ std::unique_ptr<HTTPRequestHandler> ODBCBridgeHandlerFactory::createRequestHandl
LOG_TRACE(log, "Request URI: {}", uri.toString());
if (uri.getPath() == "/ping" && request.getMethod() == Poco::Net::HTTPRequest::HTTP_GET)
return std::make_unique<PingHandler>(keep_alive_timeout);
return std::make_unique<PingHandler>();
if (request.getMethod() == Poco::Net::HTTPRequest::HTTP_POST)
{
if (uri.getPath() == "/columns_info")
#if USE_ODBC
return std::make_unique<ODBCColumnsInfoHandler>(keep_alive_timeout, getContext());
return std::make_unique<ODBCColumnsInfoHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/identifier_quote")
#if USE_ODBC
return std::make_unique<IdentifierQuoteHandler>(keep_alive_timeout, getContext());
return std::make_unique<IdentifierQuoteHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/schema_allowed")
#if USE_ODBC
return std::make_unique<SchemaAllowedHandler>(keep_alive_timeout, getContext());
return std::make_unique<SchemaAllowedHandler>(getContext());
#else
return nullptr;
#endif
else if (uri.getPath() == "/write")
return std::make_unique<ODBCHandler>(keep_alive_timeout, getContext(), "write");
return std::make_unique<ODBCHandler>(getContext(), "write");
else
return std::make_unique<ODBCHandler>(keep_alive_timeout, getContext(), "read");
return std::make_unique<ODBCHandler>(getContext(), "read");
}
return nullptr;
}

View File

@ -17,14 +17,13 @@ namespace DB
class ODBCBridgeHandlerFactory : public HTTPRequestHandlerFactory, WithContext
{
public:
ODBCBridgeHandlerFactory(const std::string & name_, size_t keep_alive_timeout_, ContextPtr context_);
ODBCBridgeHandlerFactory(const std::string & name_, ContextPtr context_);
std::unique_ptr<HTTPRequestHandler> createRequestHandler(const HTTPServerRequest & request) override;
private:
LoggerPtr log;
std::string name;
size_t keep_alive_timeout;
};
}

View File

@ -10,7 +10,7 @@ void PingHandler::handleRequest(HTTPServerRequest & /* request */, HTTPServerRes
{
try
{
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
const char * data = "Ok.\n";
response.sendBuffer(data, strlen(data));
}

View File

@ -9,11 +9,7 @@ namespace DB
class PingHandler : public HTTPRequestHandler
{
public:
explicit PingHandler(size_t keep_alive_timeout_) : keep_alive_timeout(keep_alive_timeout_) {}
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
size_t keep_alive_timeout;
};
}

View File

@ -88,7 +88,7 @@ void SchemaAllowedHandler::handleRequest(HTTPServerRequest & request, HTTPServer
bool result = isSchemaAllowed(std::move(connection));
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout);
WriteBufferFromHTTPServerResponse out(response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD);
try
{
writeBoolText(result, out);

View File

@ -17,18 +17,12 @@ class Context;
class SchemaAllowedHandler : public HTTPRequestHandler, WithContext
{
public:
SchemaAllowedHandler(size_t keep_alive_timeout_, ContextPtr context_)
: WithContext(context_)
, log(getLogger("SchemaAllowedHandler"))
, keep_alive_timeout(keep_alive_timeout_)
{
}
explicit SchemaAllowedHandler(ContextPtr context_) : WithContext(context_), log(getLogger("SchemaAllowedHandler")) { }
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event) override;
private:
LoggerPtr log;
size_t keep_alive_timeout;
};
}

View File

@ -2428,6 +2428,7 @@ void Server::createServers(
Poco::Net::HTTPServerParams::Ptr http_params = new Poco::Net::HTTPServerParams;
http_params->setTimeout(settings.http_receive_timeout);
http_params->setKeepAliveTimeout(global_context->getServerSettings().keep_alive_timeout);
http_params->setMaxKeepAliveRequests(static_cast<int>(global_context->getServerSettings().max_keep_alive_requests));
Poco::Util::AbstractConfiguration::Keys protocols;
config.keys("protocols", protocols);

View File

@ -10,6 +10,7 @@
#include <Poco/Net/SocketAddress.h>
#include <Poco/Net/StreamSocket.h>
#include <Daemon/BaseDaemon.h>
#include <Interpreters/Context.h>
@ -25,6 +26,12 @@ static int64_t port = 9000;
using namespace std::chrono_literals;
/// Process-exit hook for this fuzzer; it is registered via atexit() in LLVMFuzzerInitialize.
/// Calls BaseDaemon::terminate() first and only then main_app.wait(), so the order of the
/// two statements matters.
/// NOTE(review): main_app is declared outside the visible hunk — presumably wait() blocks until
/// the application started during fuzzer initialization has stopped; confirm against its declaration.
void on_exit()
{
BaseDaemon::terminate();
main_app.wait();
}
extern "C"
int LLVMFuzzerInitialize(int * argc, char ***argv)
{
@ -60,6 +67,8 @@ int LLVMFuzzerInitialize(int * argc, char ***argv)
exit(-1);
}
atexit(on_exit);
return 0;
}

View File

@ -1,2 +1,2 @@
clickhouse_add_executable(aggregate_function_state_deserialization_fuzzer aggregate_function_state_deserialization_fuzzer.cpp ${SRCS})
target_link_libraries(aggregate_function_state_deserialization_fuzzer PRIVATE dbms clickhouse_aggregate_functions clickhouse_functions)
target_link_libraries(aggregate_function_state_deserialization_fuzzer PRIVATE clickhouse_functions clickhouse_aggregate_functions)

View File

@ -1,4 +1,4 @@
add_compile_options($<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>)
add_compile_options("$<$<OR:$<COMPILE_LANGUAGE:C>,$<COMPILE_LANGUAGE:CXX>>:${COVERAGE_FLAGS}>")
if (USE_INCLUDE_WHAT_YOU_USE)
set (CMAKE_CXX_INCLUDE_WHAT_YOU_USE ${IWYU_PATH})
@ -601,10 +601,6 @@ endif()
dbms_target_link_libraries(PUBLIC ch_contrib::consistent_hashing)
if (TARGET ch_contrib::annoy)
dbms_target_link_libraries(PUBLIC ch_contrib::annoy)
endif()
if (TARGET ch_contrib::usearch)
dbms_target_link_libraries(PUBLIC ch_contrib::usearch)
endif()

View File

@ -2751,7 +2751,7 @@ void ClientBase::runLibFuzzer()
for (auto & arg : fuzzer_args_holder)
fuzzer_args.emplace_back(arg.data());
int fuzzer_argc = fuzzer_args.size();
int fuzzer_argc = static_cast<int>(fuzzer_args.size());
char ** fuzzer_argv = fuzzer_args.data();
LLVMFuzzerRunDriver(&fuzzer_argc, &fuzzer_argv, [](const uint8_t * data, size_t size)

View File

@ -58,6 +58,7 @@
#cmakedefine01 USE_FILELOG
#cmakedefine01 USE_ODBC
#cmakedefine01 USE_BLAKE3
#cmakedefine01 USE_USEARCH
#cmakedefine01 USE_SKIM
#cmakedefine01 USE_PRQL
#cmakedefine01 USE_ULID

View File

@ -134,6 +134,7 @@ namespace DB
M(Bool, async_load_databases, false, "Enable asynchronous loading of databases and tables to speedup server startup. Queries to not yet loaded entity will be blocked until load is finished.", 0) \
M(Bool, display_secrets_in_show_and_select, false, "Allow showing secrets in SHOW and SELECT queries via a format setting and a grant", 0) \
M(Seconds, keep_alive_timeout, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT, "The number of seconds that ClickHouse waits for incoming requests before closing the connection.", 0) \
M(UInt64, max_keep_alive_requests, 10000, "The maximum number of requests handled via a single http keepalive connection before the server closes this connection.", 0) \
M(Seconds, replicated_fetches_http_connection_timeout, 0, "HTTP connection timeout for part fetch requests. Inherited from default profile `http_connection_timeout` if not set explicitly.", 0) \
M(Seconds, replicated_fetches_http_send_timeout, 0, "HTTP send timeout for part fetch requests. Inherited from default profile `http_send_timeout` if not set explicitly.", 0) \
M(Seconds, replicated_fetches_http_receive_timeout, 0, "HTTP receive timeout for fetch part requests. Inherited from default profile `http_receive_timeout` if not set explicitly.", 0) \

View File

@ -908,14 +908,11 @@ class IColumn;
M(Bool, allow_experimental_hash_functions, false, "Enable experimental hash functions", 0) \
M(Bool, allow_experimental_object_type, false, "Allow Object and JSON data types", 0) \
M(Bool, allow_experimental_time_series_table, false, "Allows experimental TimeSeries table engine", 0) \
M(Bool, allow_experimental_vector_similarity_index, false, "Allow experimental vector similarity index", 0) \
M(Bool, allow_experimental_variant_type, false, "Allow Variant data type", 0) \
M(Bool, allow_experimental_dynamic_type, false, "Allow Dynamic data type", 0) \
M(Bool, allow_experimental_annoy_index, false, "Allows to use Annoy index. Disabled by default because this feature is experimental", 0) \
M(Bool, allow_experimental_usearch_index, false, "Allows to use USearch index. Disabled by default because this feature is experimental", 0) \
M(Bool, allow_experimental_codecs, false, "If it is set to true, allow to specify experimental compression codecs (but we don't have those yet and this option does nothing).", 0) \
M(UInt64, max_limit_for_ann_queries, 1'000'000, "SELECT queries with LIMIT bigger than this setting cannot use ANN indexes. Helps to prevent memory overflows in ANN search indexes.", 0) \
M(UInt64, max_threads_for_annoy_index_creation, 4, "Number of threads used to build Annoy indexes (0 means all cores, not recommended)", 0) \
M(Int64, annoy_index_search_k_nodes, -1, "SELECT queries search up to this many nodes in Annoy indexes.", 0) \
M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \
M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \
M(Bool, implicit_transaction, false, "If enabled and not already inside a transaction, wraps the query inside a full transaction (begin + commit or rollback)", 0) \
@ -1039,6 +1036,10 @@ class IColumn;
MAKE_OBSOLETE(M, UInt64, parallel_replicas_min_number_of_granules_to_enable, 0) \
MAKE_OBSOLETE(M, Bool, query_plan_optimize_projection, true) \
MAKE_OBSOLETE(M, Bool, query_cache_store_results_of_queries_with_nondeterministic_functions, false) \
MAKE_OBSOLETE(M, Bool, allow_experimental_annoy_index, false) \
MAKE_OBSOLETE(M, UInt64, max_threads_for_annoy_index_creation, 4) \
MAKE_OBSOLETE(M, Int64, annoy_index_search_k_nodes, -1) \
MAKE_OBSOLETE(M, Bool, allow_experimental_usearch_index, false) \
MAKE_OBSOLETE(M, Bool, optimize_move_functions_out_of_any, false) \
MAKE_OBSOLETE(M, Bool, allow_experimental_undrop_table_query, true) \
MAKE_OBSOLETE(M, Bool, allow_experimental_s3queue, true) \

View File

@ -88,6 +88,7 @@ static std::initializer_list<std::pair<ClickHouseVersion, SettingsChangesHistory
{"enable_analyzer", 1, 1, "Added an alias to a setting `allow_experimental_analyzer`."},
{"optimize_functions_to_subcolumns", false, true, "Enabled settings by default"},
{"parallel_replicas_local_plan", false, false, "Use local plan for local replica in a query with parallel replicas"},
{"allow_experimental_vector_similarity_index", false, false, "Added new setting to allow experimental vector similarity indexes"},
}
},
{"24.7",

View File

@ -1,2 +1,2 @@
clickhouse_add_executable (names_and_types_fuzzer names_and_types_fuzzer.cpp)
target_link_libraries (names_and_types_fuzzer PRIVATE dbms clickhouse_functions)
target_link_libraries (names_and_types_fuzzer PRIVATE clickhouse_functions)

View File

@ -1,2 +1,2 @@
clickhouse_add_executable(data_type_deserialization_fuzzer data_type_deserialization_fuzzer.cpp ${SRCS})
target_link_libraries(data_type_deserialization_fuzzer PRIVATE dbms clickhouse_aggregate_functions clickhouse_functions)
target_link_libraries(data_type_deserialization_fuzzer PRIVATE clickhouse_functions clickhouse_aggregate_functions)

View File

@ -1153,8 +1153,7 @@ void DatabaseReplicated::recoverLostReplica(const ZooKeeperPtr & current_zookeep
query_context->setSetting("allow_experimental_object_type", 1);
query_context->setSetting("allow_experimental_variant_type", 1);
query_context->setSetting("allow_experimental_dynamic_type", 1);
query_context->setSetting("allow_experimental_annoy_index", 1);
query_context->setSetting("allow_experimental_usearch_index", 1);
query_context->setSetting("allow_experimental_vector_similarity_index", 1);
query_context->setSetting("allow_experimental_bigint_types", 1);
query_context->setSetting("allow_experimental_window_functions", 1);
query_context->setSetting("allow_experimental_geo_types", 1);

View File

@ -1,2 +1,2 @@
clickhouse_add_executable(format_fuzzer format_fuzzer.cpp ${SRCS})
target_link_libraries(format_fuzzer PRIVATE dbms clickhouse_aggregate_functions clickhouse_functions)
target_link_libraries(format_fuzzer PRIVATE clickhouse_functions clickhouse_aggregate_functions)

View File

@ -3,7 +3,6 @@
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadHelpers.h>
#include <Formats/FormatFactory.h>
#include <Formats/registerFormats.h>
#include <QueryPipeline/Pipe.h>

View File

@ -34,14 +34,20 @@ namespace ErrorCodes
extern const int RECEIVED_ERROR_TOO_MANY_REQUESTS;
}
void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout)
void setResponseDefaultHeaders(HTTPServerResponse & response)
{
if (!response.getKeepAlive())
return;
Poco::Timespan timeout(keep_alive_timeout, 0);
if (timeout.totalSeconds())
response.set("Keep-Alive", "timeout=" + std::to_string(timeout.totalSeconds()));
const size_t keep_alive_timeout = response.getSession().getKeepAliveTimeout();
const size_t keep_alive_max_requests = response.getSession().getMaxKeepAliveRequests();
if (keep_alive_timeout)
{
if (keep_alive_max_requests)
response.set("Keep-Alive", fmt::format("timeout={}, max={}", keep_alive_timeout, keep_alive_max_requests));
else
response.set("Keep-Alive", fmt::format("timeout={}", keep_alive_timeout));
}
}
HTTPSessionPtr makeHTTPSession(

View File

@ -54,7 +54,7 @@ private:
using HTTPSessionPtr = std::shared_ptr<Poco::Net::HTTPClientSession>;
void setResponseDefaultHeaders(HTTPServerResponse & response, size_t keep_alive_timeout);
void setResponseDefaultHeaders(HTTPServerResponse & response);
/// Create session object to perform requests and set required parameters.
HTTPSessionPtr makeHTTPSession(

View File

@ -787,10 +787,8 @@ InterpreterCreateQuery::TableProperties InterpreterCreateQuery::getTableProperti
if (index_desc.type == INVERTED_INDEX_NAME && !settings.allow_experimental_inverted_index)
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Please use index type 'full_text' instead of 'inverted'");
/// ----
if (index_desc.type == "annoy" && !settings.allow_experimental_annoy_index)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index is disabled. Turn on allow_experimental_annoy_index");
if (index_desc.type == "usearch" && !settings.allow_experimental_usearch_index)
throw Exception(ErrorCodes::INCORRECT_QUERY, "USearch index is disabled. Turn on allow_experimental_usearch_index");
if (index_desc.type == "vector_similarity" && !settings.allow_experimental_vector_similarity_index)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Vector similarity index is disabled. Turn on allow_experimental_vector_similarity_index");
properties.indices.push_back(index_desc);
}

View File

@ -13,8 +13,7 @@ class ASTIndexDeclaration : public IAST
{
public:
static const auto DEFAULT_INDEX_GRANULARITY = 1uz;
static const auto DEFAULT_ANNOY_INDEX_GRANULARITY = 100'000'000uz;
static const auto DEFAULT_USEARCH_INDEX_GRANULARITY = 100'000'000uz;
static const auto DEFAULT_VECTOR_SIMILARITY_INDEX_GRANULARITY = 100'000'000uz;
ASTIndexDeclaration(ASTPtr expression, ASTPtr type, const String & name_);

View File

@ -89,10 +89,8 @@ bool ParserCreateIndexDeclaration::parseImpl(Pos & pos, ASTPtr & node, Expected
else
{
auto index_type = index->getType();
if (index_type && index_type->name == "annoy")
index->granularity = ASTIndexDeclaration::DEFAULT_ANNOY_INDEX_GRANULARITY;
else if (index_type && index_type->name == "usearch")
index->granularity = ASTIndexDeclaration::DEFAULT_USEARCH_INDEX_GRANULARITY;
if (index_type && index_type->name == "vector_similarity")
index->granularity = ASTIndexDeclaration::DEFAULT_VECTOR_SIMILARITY_INDEX_GRANULARITY;
else
index->granularity = ASTIndexDeclaration::DEFAULT_INDEX_GRANULARITY;
}

View File

@ -214,10 +214,8 @@ bool ParserIndexDeclaration::parseImpl(Pos & pos, ASTPtr & node, Expected & expe
else
{
auto index_type = index->getType();
if (index_type->name == "annoy")
index->granularity = ASTIndexDeclaration::DEFAULT_ANNOY_INDEX_GRANULARITY;
else if (index_type->name == "usearch")
index->granularity = ASTIndexDeclaration::DEFAULT_USEARCH_INDEX_GRANULARITY;
if (index_type->name == "vector_similarity")
index->granularity = ASTIndexDeclaration::DEFAULT_VECTOR_SIMILARITY_INDEX_GRANULARITY;
else
index->granularity = ASTIndexDeclaration::DEFAULT_INDEX_GRANULARITY;
}

View File

@ -39,7 +39,7 @@ set(CMAKE_INCLUDE_CURRENT_DIR TRUE)
clickhouse_add_executable(codegen_select_fuzzer ${FUZZER_SRCS})
set_source_files_properties("${PROTO_SRCS}" "out.cpp" PROPERTIES COMPILE_FLAGS "-Wno-reserved-identifier")
set_source_files_properties("${PROTO_SRCS}" "out.cpp" PROPERTIES COMPILE_FLAGS "-Wno-reserved-identifier -Wno-extra-semi-stmt -Wno-used-but-marked-unused")
# contrib/libprotobuf-mutator/src/libfuzzer/libfuzzer_macro.h:143:44: error: no newline at end of file [-Werror,-Wnewline-eof]
target_compile_options (codegen_select_fuzzer PRIVATE -Wno-newline-eof)

View File

@ -12,6 +12,7 @@
#include <Columns/ColumnString.h>
#include <Columns/ColumnTuple.h>
#include <Columns/ColumnMap.h>
#include <Columns/ColumnsCommon.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeDateTime64.h>
@ -203,33 +204,23 @@ template <typename NumberType, typename NumberVectorBatch, typename ConvertFunc>
void ORCBlockOutputFormat::writeNumbers(
orc::ColumnVectorBatch & orc_column,
const IColumn & column,
const PaddedPODArray<UInt8> * null_bytemap,
const PaddedPODArray<UInt8> * /*null_bytemap*/,
ConvertFunc convert)
{
NumberVectorBatch & number_orc_column = dynamic_cast<NumberVectorBatch &>(orc_column);
const auto & number_column = assert_cast<const ColumnVector<NumberType> &>(column);
number_orc_column.resize(number_column.size());
number_orc_column.data.resize(number_column.size());
for (size_t i = 0; i != number_column.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
{
number_orc_column.notNull[i] = 0;
continue;
}
number_orc_column.notNull[i] = 1;
number_orc_column.data[i] = convert(number_column.getElement(i));
}
number_orc_column.numElements = number_column.size();
}
template <typename Decimal, typename DecimalVectorBatch, typename ConvertFunc>
void ORCBlockOutputFormat::writeDecimals(
orc::ColumnVectorBatch & orc_column,
const IColumn & column,
DataTypePtr & type,
const PaddedPODArray<UInt8> * null_bytemap,
const PaddedPODArray<UInt8> * /*null_bytemap*/,
ConvertFunc convert)
{
DecimalVectorBatch & decimal_orc_column = dynamic_cast<DecimalVectorBatch &>(orc_column);
@ -238,71 +229,49 @@ void ORCBlockOutputFormat::writeDecimals(
decimal_orc_column.precision = decimal_type->getPrecision();
decimal_orc_column.scale = decimal_type->getScale();
decimal_orc_column.resize(decimal_column.size());
for (size_t i = 0; i != decimal_column.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
{
decimal_orc_column.notNull[i] = 0;
continue;
}
decimal_orc_column.notNull[i] = 1;
decimal_orc_column.values.resize(decimal_column.size());
for (size_t i = 0; i != decimal_column.size(); ++i)
decimal_orc_column.values[i] = convert(decimal_column.getElement(i).value);
}
decimal_orc_column.numElements = decimal_column.size();
}
template <typename ColumnType>
void ORCBlockOutputFormat::writeStrings(
orc::ColumnVectorBatch & orc_column,
const IColumn & column,
const PaddedPODArray<UInt8> * null_bytemap)
const PaddedPODArray<UInt8> * /*null_bytemap*/)
{
orc::StringVectorBatch & string_orc_column = dynamic_cast<orc::StringVectorBatch &>(orc_column);
const auto & string_column = assert_cast<const ColumnType &>(column);
string_orc_column.resize(string_column.size());
string_orc_column.data.resize(string_column.size());
string_orc_column.length.resize(string_column.size());
for (size_t i = 0; i != string_column.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
{
string_orc_column.notNull[i] = 0;
continue;
}
string_orc_column.notNull[i] = 1;
const std::string_view & string = string_column.getDataAt(i).toView();
string_orc_column.data[i] = const_cast<char *>(string.data());
string_orc_column.length[i] = string.size();
}
string_orc_column.numElements = string_column.size();
}
template <typename ColumnType, typename GetSecondsFunc, typename GetNanosecondsFunc>
void ORCBlockOutputFormat::writeDateTimes(
orc::ColumnVectorBatch & orc_column,
const IColumn & column,
const PaddedPODArray<UInt8> * null_bytemap,
const PaddedPODArray<UInt8> * /*null_bytemap*/,
GetSecondsFunc get_seconds,
GetNanosecondsFunc get_nanoseconds)
{
orc::TimestampVectorBatch & timestamp_orc_column = dynamic_cast<orc::TimestampVectorBatch &>(orc_column);
const auto & timestamp_column = assert_cast<const ColumnType &>(column);
timestamp_orc_column.resize(timestamp_column.size());
timestamp_orc_column.data.resize(timestamp_column.size());
timestamp_orc_column.nanoseconds.resize(timestamp_column.size());
for (size_t i = 0; i != timestamp_column.size(); ++i)
{
if (null_bytemap && (*null_bytemap)[i])
{
timestamp_orc_column.notNull[i] = 0;
continue;
}
timestamp_orc_column.notNull[i] = 1;
timestamp_orc_column.data[i] = static_cast<int64_t>(get_seconds(timestamp_column.getElement(i)));
timestamp_orc_column.nanoseconds[i] = static_cast<int64_t>(get_nanoseconds(timestamp_column.getElement(i)));
}
timestamp_orc_column.numElements = timestamp_column.size();
}
void ORCBlockOutputFormat::writeColumn(
@ -311,9 +280,27 @@ void ORCBlockOutputFormat::writeColumn(
DataTypePtr & type,
const PaddedPODArray<UInt8> * null_bytemap)
{
orc_column.notNull.resize(column.size());
size_t rows = column.size();
orc_column.resize(rows);
orc_column.numElements = rows;
/// Calculate orc_column.hasNulls
if (null_bytemap)
orc_column.hasNulls = true;
orc_column.hasNulls = !memoryIsZero(null_bytemap->data(), 0, null_bytemap->size());
else
orc_column.hasNulls = false;
/// Fill orc_column.notNull
if (orc_column.hasNulls)
{
for (size_t i = 0; i < rows; ++i)
orc_column.notNull[i] = !(*null_bytemap)[i];
}
else
{
for (size_t i = 0; i < rows; ++i)
orc_column.notNull[i] = 1;
}
/// ORC doesn't have unsigned types, so cast everything to signed and sign-extend to Int64 to
/// make the ORC library calculate min and max correctly.
@ -471,6 +458,7 @@ void ORCBlockOutputFormat::writeColumn(
}
case TypeIndex::Nullable:
{
chassert(!null_bytemap);
const auto & nullable_column = assert_cast<const ColumnNullable &>(column);
const PaddedPODArray<UInt8> & new_null_bytemap = assert_cast<const ColumnVector<UInt8> &>(*nullable_column.getNullMapColumnPtr()).getData();
auto nested_type = removeNullable(type);
@ -485,19 +473,15 @@ void ORCBlockOutputFormat::writeColumn(
const ColumnArray::Offsets & offsets = list_column.getOffsets();
size_t column_size = list_column.size();
list_orc_column.resize(column_size);
list_orc_column.offsets.resize(column_size + 1);
/// The length of list i in ListVectorBatch is offsets[i+1] - offsets[i].
list_orc_column.offsets[0] = 0;
for (size_t i = 0; i != column_size; ++i)
{
list_orc_column.offsets[i + 1] = offsets[i];
list_orc_column.notNull[i] = 1;
}
orc::ColumnVectorBatch & nested_orc_column = *list_orc_column.elements;
writeColumn(nested_orc_column, list_column.getData(), nested_type, null_bytemap);
list_orc_column.numElements = column_size;
writeColumn(nested_orc_column, list_column.getData(), nested_type, nullptr);
break;
}
case TypeIndex::Tuple:
@ -505,10 +489,8 @@ void ORCBlockOutputFormat::writeColumn(
orc::StructVectorBatch & struct_orc_column = dynamic_cast<orc::StructVectorBatch &>(orc_column);
const auto & tuple_column = assert_cast<const ColumnTuple &>(column);
auto nested_types = assert_cast<const DataTypeTuple *>(type.get())->getElements();
for (size_t i = 0; i != tuple_column.size(); ++i)
struct_orc_column.notNull[i] = 1;
for (size_t i = 0; i != tuple_column.tupleSize(); ++i)
writeColumn(*struct_orc_column.fields[i], tuple_column.getColumn(i), nested_types[i], null_bytemap);
writeColumn(*struct_orc_column.fields[i], tuple_column.getColumn(i), nested_types[i], nullptr);
break;
}
case TypeIndex::Map:
@ -520,25 +502,21 @@ void ORCBlockOutputFormat::writeColumn(
size_t column_size = list_column.size();
map_orc_column.resize(list_column.size());
map_orc_column.offsets.resize(column_size + 1);
/// The length of list i in ListVectorBatch is offsets[i+1] - offsets[i].
map_orc_column.offsets[0] = 0;
for (size_t i = 0; i != column_size; ++i)
{
map_orc_column.offsets[i + 1] = offsets[i];
map_orc_column.notNull[i] = 1;
}
const auto nested_columns = assert_cast<const ColumnTuple *>(list_column.getDataPtr().get())->getColumns();
orc::ColumnVectorBatch & keys_orc_column = *map_orc_column.keys;
auto key_type = map_type.getKeyType();
writeColumn(keys_orc_column, *nested_columns[0], key_type, null_bytemap);
writeColumn(keys_orc_column, *nested_columns[0], key_type, nullptr);
orc::ColumnVectorBatch & values_orc_column = *map_orc_column.elements;
auto value_type = map_type.getValueType();
writeColumn(values_orc_column, *nested_columns[1], value_type, null_bytemap);
map_orc_column.numElements = column_size;
writeColumn(values_orc_column, *nested_columns[1], value_type, nullptr);
break;
}
default:
@ -546,27 +524,6 @@ void ORCBlockOutputFormat::writeColumn(
}
}
size_t ORCBlockOutputFormat::getColumnSize(const IColumn & column, DataTypePtr & type)
{
if (type->getTypeId() == TypeIndex::Array)
{
auto nested_type = assert_cast<const DataTypeArray &>(*type).getNestedType();
const IColumn & nested_column = assert_cast<const ColumnArray &>(column).getData();
return std::max(column.size(), getColumnSize(nested_column, nested_type));
}
return column.size();
}
size_t ORCBlockOutputFormat::getMaxColumnSize(Chunk & chunk)
{
size_t columns_num = chunk.getNumColumns();
size_t max_column_size = 0;
for (size_t i = 0; i != columns_num; ++i)
max_column_size = std::max(max_column_size, getColumnSize(*chunk.getColumns()[i], data_types[i]));
return max_column_size;
}
void ORCBlockOutputFormat::consume(Chunk chunk)
{
if (!writer)
@ -575,10 +532,7 @@ void ORCBlockOutputFormat::consume(Chunk chunk)
size_t columns_num = chunk.getNumColumns();
size_t rows_num = chunk.getNumRows();
/// getMaxColumnSize is needed to write arrays.
/// The size of the batch must be no less than total amount of array elements
/// and no less than the number of rows (ORC writes a null bit for every row).
std::unique_ptr<orc::ColumnVectorBatch> batch = writer->createRowBatch(getMaxColumnSize(chunk));
std::unique_ptr<orc::ColumnVectorBatch> batch = writer->createRowBatch(chunk.getNumRows());
orc::StructVectorBatch & root = dynamic_cast<orc::StructVectorBatch &>(*batch);
auto columns = chunk.detachColumns();

View File

@ -69,11 +69,6 @@ private:
void writeColumn(orc::ColumnVectorBatch & orc_column, const IColumn & column, DataTypePtr & type, const PaddedPODArray<UInt8> * null_bytemap);
/// These two functions are needed to know maximum nested size of arrays to
/// create an ORC Batch with the appropriate size
size_t getColumnSize(const IColumn & column, DataTypePtr & type);
size_t getMaxColumnSize(Chunk & chunk);
void prepareWriter();
const FormatSettings format_settings;

View File

@ -24,8 +24,8 @@
#include <Processors/Transforms/SelectByIndicesTransform.h>
#include <QueryPipeline/QueryPipelineBuilder.h>
#include <Storages/MergeTree/MergeTreeDataSelectExecutor.h>
#include <Storages/MergeTree/MergeTreeIndexAnnoy.h>
#include <Storages/MergeTree/MergeTreeIndexUSearch.h>
#include <Storages/MergeTree/MergeTreeIndexVectorSimilarity.h>
#include <Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h>
#include <Storages/MergeTree/MergeTreeReadPool.h>
#include <Storages/MergeTree/MergeTreePrefetchedReadPool.h>
#include <Storages/MergeTree/MergeTreeReadPoolInOrder.h>
@ -52,6 +52,8 @@
#include <memory>
#include <unordered_map>
#include "config.h"
using namespace DB;
namespace
@ -1519,16 +1521,14 @@ static void buildIndexes(
else
{
MergeTreeIndexConditionPtr condition;
if (index_helper->isVectorSearch())
if (index_helper->isVectorSimilarityIndex())
{
#ifdef ENABLE_ANNOY
if (const auto * annoy = typeid_cast<const MergeTreeIndexAnnoy *>(index_helper.get()))
condition = annoy->createIndexCondition(query_info, context);
#endif
#ifdef ENABLE_USEARCH
if (const auto * usearch = typeid_cast<const MergeTreeIndexUSearch *>(index_helper.get()))
condition = usearch->createIndexCondition(query_info, context);
#if USE_USEARCH
if (const auto * vector_similarity_index = typeid_cast<const MergeTreeIndexVectorSimilarity *>(index_helper.get()))
condition = vector_similarity_index->createIndexCondition(query_info, context);
#endif
if (const auto * legacy_vector_similarity_index = typeid_cast<const MergeTreeIndexLegacyVectorSimilarity *>(index_helper.get()))
condition = legacy_vector_similarity_index->createIndexCondition(query_info, context);
if (!condition)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Unknown vector search index {}", index_helper->index.name);
}

View File

@ -248,6 +248,8 @@ public:
void attachRequest(HTTPServerRequest * request_) { request = request_; }
const Poco::Net::HTTPServerSession & getSession() const { return session; }
private:
Poco::Net::HTTPServerSession & session;
HTTPServerRequest * request = nullptr;

View File

@ -30,7 +30,7 @@ void WriteBufferFromHTTPServerResponse::startSendHeaders()
if (add_cors_header)
response.set("Access-Control-Allow-Origin", "*");
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
std::stringstream header; //STYLE_CHECK_ALLOW_STD_STRING_STREAM
response.beginWrite(header);
@ -119,12 +119,10 @@ void WriteBufferFromHTTPServerResponse::nextImpl()
WriteBufferFromHTTPServerResponse::WriteBufferFromHTTPServerResponse(
HTTPServerResponse & response_,
bool is_http_method_head_,
UInt64 keep_alive_timeout_,
const ProfileEvents::Event & write_event_)
: HTTPWriteBuffer(response_.getSocket(), write_event_)
, response(response_)
, is_http_method_head(is_http_method_head_)
, keep_alive_timeout(keep_alive_timeout_)
{
}

View File

@ -29,7 +29,6 @@ public:
WriteBufferFromHTTPServerResponse(
HTTPServerResponse & response_,
bool is_http_method_head_,
UInt64 keep_alive_timeout_,
const ProfileEvents::Event & write_event_ = ProfileEvents::end());
~WriteBufferFromHTTPServerResponse() override;
@ -91,7 +90,6 @@ private:
bool is_http_method_head;
bool add_cors_header = false;
size_t keep_alive_timeout = 0;
bool initialized = false;

View File

@ -29,7 +29,7 @@ void sendExceptionToHTTPClient(
if (!out)
{
/// If nothing was sent yet.
WriteBufferFromHTTPServerResponse out_for_message{response, request.getMethod() == HTTPRequest::HTTP_HEAD, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT};
WriteBufferFromHTTPServerResponse out_for_message{response, request.getMethod() == HTTPRequest::HTTP_HEAD};
out_for_message.writeln(exception_message);
out_for_message.finalize();

View File

@ -266,7 +266,6 @@ void HTTPHandler::processQuery(
std::make_shared<WriteBufferFromHTTPServerResponse>(
response,
request.getMethod() == HTTPRequest::HTTP_HEAD,
context->getServerSettings().keep_alive_timeout.totalSeconds(),
write_event);
used_output.out = used_output.out_holder;
used_output.out_maybe_compressed = used_output.out_holder;
@ -558,7 +557,7 @@ try
if (!used_output.out_holder && !used_output.exception_is_written)
{
/// If nothing was sent yet and we don't even know if we must compress the response.
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, DEFAULT_HTTP_KEEP_ALIVE_TIMEOUT).writeln(s);
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD).writeln(s);
}
else if (used_output.out_maybe_compressed)
{

View File

@ -122,7 +122,8 @@ static inline auto createHandlersFactoryFromConfig(
}
else if (handler_type == "prometheus")
{
main_handler_factory->addHandler(createPrometheusHandlerFactoryForHTTPRule(server, config, prefix + "." + key, async_metrics));
main_handler_factory->addHandler(
createPrometheusHandlerFactoryForHTTPRule(server, config, prefix + "." + key, async_metrics));
}
else if (handler_type == "replicas_status")
{

View File

@ -87,9 +87,8 @@ void InterserverIOHTTPHandler::handleRequest(HTTPServerRequest & request, HTTPSe
response.setChunkedTransferEncoding(true);
Output used_output;
const auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(
response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, keep_alive_timeout, write_event);
response, request.getMethod() == Poco::Net::HTTPRequest::HTTP_HEAD, write_event);
auto finalize_output = [&]
{

View File

@ -95,7 +95,7 @@ public:
class PrometheusRequestHandler::ImplWithContext : public Impl
{
public:
explicit ImplWithContext(PrometheusRequestHandler & parent) : Impl(parent), default_settings(parent.server.context()->getSettingsRef()) { }
explicit ImplWithContext(PrometheusRequestHandler & parent) : Impl(parent), default_settings(server().context()->getSettingsRef()) { }
virtual void handlingRequestWithContext(HTTPServerRequest & request, HTTPServerResponse & response) = 0;
@ -353,7 +353,7 @@ void PrometheusRequestHandler::handleRequest(HTTPServerRequest & request, HTTPSe
if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
setResponseDefaultHeaders(response, config.keep_alive_timeout);
setResponseDefaultHeaders(response);
impl->beforeHandlingRequest(request);
impl->handleRequest(request, response);
@ -379,7 +379,7 @@ WriteBufferFromHTTPServerResponse & PrometheusRequestHandler::getOutputStream(HT
if (write_buffer_from_response)
return *write_buffer_from_response;
write_buffer_from_response = std::make_unique<WriteBufferFromHTTPServerResponse>(
response, http_method == HTTPRequest::HTTP_HEAD, config.keep_alive_timeout, write_event);
response, http_method == HTTPRequest::HTTP_HEAD, write_event);
return *write_buffer_from_response;
}
@ -399,7 +399,7 @@ void PrometheusRequestHandler::finalizeResponse(HTTPServerResponse & response)
if (write_buffer_from_response)
std::exchange(write_buffer_from_response, {})->finalize();
else
WriteBufferFromHTTPServerResponse{response, http_method == HTTPRequest::HTTP_HEAD, config.keep_alive_timeout, write_event}.finalize();
WriteBufferFromHTTPServerResponse{response, http_method == HTTPRequest::HTTP_HEAD, write_event}.finalize();
}
chassert(response_finalized && !write_buffer_from_response);
}

View File

@ -15,8 +15,11 @@ class WriteBufferFromHTTPServerResponse;
class PrometheusRequestHandler : public HTTPRequestHandler
{
public:
PrometheusRequestHandler(IServer & server_, const PrometheusRequestHandlerConfig & config_,
const AsynchronousMetrics & async_metrics_, std::shared_ptr<PrometheusMetricsWriter> metrics_writer_);
PrometheusRequestHandler(
IServer & server_,
const PrometheusRequestHandlerConfig & config_,
const AsynchronousMetrics & async_metrics_,
std::shared_ptr<PrometheusMetricsWriter> metrics_writer_);
~PrometheusRequestHandler() override;
void handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & write_event_) override;

View File

@ -89,8 +89,7 @@ void ReplicasStatusHandler::handleRequest(HTTPServerRequest & request, HTTPServe
}
}
const auto & server_settings = getContext()->getServerSettings();
setResponseDefaultHeaders(response, server_settings.keep_alive_timeout.totalSeconds());
setResponseDefaultHeaders(response);
if (!ok)
{

View File

@ -35,10 +35,9 @@ namespace ErrorCodes
extern const int INVALID_CONFIG_PARAMETER;
}
static inline std::unique_ptr<WriteBuffer>
responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response, UInt64 keep_alive_timeout)
static inline std::unique_ptr<WriteBuffer> responseWriteBuffer(HTTPServerRequest & request, HTTPServerResponse & response)
{
auto buf = std::unique_ptr<WriteBuffer>(new WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, keep_alive_timeout));
auto buf = std::unique_ptr<WriteBuffer>(new WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD));
/// The client can pass a HTTP header indicating supported compression method (gzip or deflate).
String http_response_compression_methods = request.get("Accept-Encoding", "");
@ -91,8 +90,7 @@ static inline void trySendExceptionToClient(
void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event & /*write_event*/)
{
auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
auto out = responseWriteBuffer(request, response, keep_alive_timeout);
auto out = responseWriteBuffer(request, response);
try
{
@ -107,7 +105,7 @@ void StaticRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServer
"The Transfer-Encoding is not chunked and there "
"is no Content-Length header for POST request");
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTPStatus(status));
writeResponse(*out);
}

View File

@ -30,23 +30,20 @@ DashboardWebUIRequestHandler::DashboardWebUIRequestHandler(IServer & server_) :
BinaryWebUIRequestHandler::BinaryWebUIRequestHandler(IServer & server_) : server(server_) {}
JavaScriptWebUIRequestHandler::JavaScriptWebUIRequestHandler(IServer & server_) : server(server_) {}
static void handle(const IServer & server, HTTPServerRequest & request, HTTPServerResponse & response, std::string_view html)
static void handle(HTTPServerRequest & request, HTTPServerResponse & response, std::string_view html)
{
auto keep_alive_timeout = server.context()->getServerSettings().keep_alive_timeout.totalSeconds();
response.setContentType("text/html; charset=UTF-8");
if (request.getVersion() == HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
setResponseDefaultHeaders(response, keep_alive_timeout);
setResponseDefaultHeaders(response);
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_OK);
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD, keep_alive_timeout).write(html.data(), html.size());
WriteBufferFromHTTPServerResponse(response, request.getMethod() == HTTPRequest::HTTP_HEAD).write(html.data(), html.size());
}
void PlayWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_play_htmlData), gresource_play_htmlSize});
}
void DashboardWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
@ -64,23 +61,23 @@ void DashboardWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HT
static re2::RE2 lz_string_url = R"(https://[^\s"'`]+lz-string[^\s"'`]*\.js)";
RE2::Replace(&html, lz_string_url, "/js/lz-string.js");
handle(server, request, response, html);
handle(request, response, html);
}
void BinaryWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
}
void JavaScriptWebUIRequestHandler::handleRequest(HTTPServerRequest & request, HTTPServerResponse & response, const ProfileEvents::Event &)
{
if (request.getURI() == "/js/uplot.js")
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_uplot_jsData), gresource_uplot_jsSize});
}
else if (request.getURI() == "/js/lz-string.js")
{
handle(server, request, response, {reinterpret_cast<const char *>(gresource_lz_string_jsData), gresource_lz_string_jsSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_lz_string_jsData), gresource_lz_string_jsSize});
}
else
{
@ -88,7 +85,7 @@ void JavaScriptWebUIRequestHandler::handleRequest(HTTPServerRequest & request, H
*response.send() << "Not found.\n";
}
handle(server, request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
handle(request, response, {reinterpret_cast<const char *>(gresource_binary_htmlData), gresource_binary_htmlSize});
}
}

View File

@ -3,6 +3,7 @@
#include <Storages/IndicesDescription.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTIndexDeclaration.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ParserCreateQuery.h>
@ -130,10 +131,15 @@ IndexDescription IndexDescription::getIndexFromAST(const ASTPtr & definition_ast
{
for (size_t i = 0; i < index_type->arguments->children.size(); ++i)
{
const auto * argument = index_type->arguments->children[i]->as<ASTLiteral>();
if (!argument)
const auto & child = index_type->arguments->children[i];
if (const auto * ast_literal = child->as<ASTLiteral>(); ast_literal != nullptr)
/// E.g. INDEX index_name column_name TYPE vector_similarity('hnsw', 'f32')
result.arguments.emplace_back(ast_literal->value);
else if (const auto * ast_identifier = child->as<ASTIdentifier>(); ast_identifier != nullptr)
/// E.g. INDEX index_name column_name TYPE vector_similarity(hnsw, f32)
result.arguments.emplace_back(ast_identifier->name());
else
throw Exception(ErrorCodes::INCORRECT_QUERY, "Only literals can be skip index arguments");
result.arguments.emplace_back(argument->value);
}
}

View File

@ -1,507 +0,0 @@
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int INCORRECT_QUERY;
}
namespace
{
/// Appends every element of `literal` to `reference_vector`.
/// `literal` must expose `value()` returning an iterable of fields (it is called with the tuple/array literal of an RPN element).
/// Elements readable as Float64 are appended directly, elements readable as Int64 are converted to float first.
/// Throws INCORRECT_QUERY for any other element type.
template <typename Literal>
void extractReferenceVectorFromLiteral(ApproximateNearestNeighborInformation::Embedding & reference_vector, Literal literal)
{
    for (const auto & element : literal.value())
    {
        Float64 as_float;
        if (element.tryGet(as_float))
        {
            reference_vector.emplace_back(as_float);
            continue;
        }

        Int64 as_int;
        if (element.tryGet(as_int))
        {
            reference_vector.emplace_back(static_cast<float>(as_int));
            continue;
        }

        throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
    }
}
/// Maps the name of a distance function to the metric enum.
/// Only "L2Distance" and "LpDistance" are recognized, every other name yields Metric::Unknown.
ApproximateNearestNeighborInformation::Metric stringToMetric(std::string_view metric)
{
    using Metric = ApproximateNearestNeighborInformation::Metric;

    if (metric == "L2Distance")
        return Metric::L2;
    if (metric == "LpDistance")
        return Metric::Lp;
    return Metric::Unknown;
}
}
/// Analyzes the SELECT query once at construction time and remembers whether it has a shape the index can serve.
/// - query_info: the query to analyze (its AST and syntax analyzer result are used).
/// - context: source of the MergeTree setting `index_granularity` and the query setting `max_limit_for_ann_queries`.
ApproximateNearestNeighborCondition::ApproximateNearestNeighborCondition(const SelectQueryInfo & query_info, ContextPtr context)
: block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
, index_granularity(context->getMergeTreeSettings().index_granularity)
, max_limit_for_ann_queries(context->getSettingsRef().max_limit_for_ann_queries)
/// checkQueryStructure() fills `query_information` and reads `max_limit_for_ann_queries`; the constant lookup it triggers reads `block_with_constants`.
/// NOTE(review): this relies on those members being declared before `index_is_useful` in the class - confirm against the header.
, index_is_useful(checkQueryStructure(query_info))
{}
/// Returns true if the condition cannot narrow anything down for the given distance function name:
/// either the query shape is not supported at all, or the query uses a different metric than `metric`.
bool ApproximateNearestNeighborCondition::alwaysUnknownOrTrue(String metric) const
{
    /// Unsupported query
    if (!index_is_useful)
        return true;

    /// Supported query: useful only when the metrics match
    const auto requested_metric = stringToMetric(metric);
    return requested_metric != query_information->metric;
}
/// Returns the distance the query compares against. Only valid for queries of type Where.
/// Throws LOGICAL_ERROR if the query is unsupported, not analyzed, or of a different type.
float ApproximateNearestNeighborCondition::getComparisonDistanceForWhereQuery() const
{
    const bool is_where_query = index_is_useful && query_information.has_value()
        && query_information->type == ApproximateNearestNeighborInformation::Type::Where;

    if (!is_where_query)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type");

    return query_information->distance;
}
/// Returns the value of the LIMIT clause extracted from the query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
UInt64 ApproximateNearestNeighborCondition::getLimit() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");

    return query_information->limit;
}
/// Returns a copy of the reference vector extracted from the query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
std::vector<float> ApproximateNearestNeighborCondition::getReferenceVector() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");

    return query_information->reference_vector;
}
/// Returns the number of elements of the reference vector extracted from the query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
size_t ApproximateNearestNeighborCondition::getDimensions() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");

    return query_information->reference_vector.size();
}
/// Returns the name of the column identifier found in the distance expression of the query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
String ApproximateNearestNeighborCondition::getColumnName() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");

    return query_information->column_name;
}
/// Returns the metric derived from the distance function used in the query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
ApproximateNearestNeighborInformation::Metric ApproximateNearestNeighborCondition::getMetricType() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Metric name was requested for useless or uninitialized index.");

    return query_information->metric;
}
/// Returns the `p` argument extracted for an LpDistance query.
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
float ApproximateNearestNeighborCondition::getPValueForLpDistance() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "P from LPDistance was requested for useless or uninitialized index.");

    return query_information->p_for_lp_dist;
}
/// Returns the kind of query that was recognized (Where or OrderBy).
/// Throws LOGICAL_ERROR if the query is unsupported or no query information was extracted.
ApproximateNearestNeighborInformation::Type ApproximateNearestNeighborCondition::getQueryType() const
{
    if (!index_is_useful || !query_information.has_value())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Query type was requested for useless or uninitialized index.");

    return query_information->type;
}
/// Decides whether the SELECT query has a shape this condition supports and, if so, stores what was extracted in `query_information`.
/// A query is accepted when it has an integer LIMIT not exceeding `max_limit_for_ann_queries` and exactly one of
/// PREWHERE, WHERE or ORDER BY matches the expected distance expression pattern.
/// Returns true iff `query_information` was filled.
bool ApproximateNearestNeighborCondition::checkQueryStructure(const SelectQueryInfo & query)
{
/// RPN-s for different sections of the query
RPN rpn_prewhere_clause;
RPN rpn_where_clause;
RPN rpn_order_by_clause;
RPNElement rpn_limit;
/// Left uninitialized on purpose: it is only read after matchRPNLimit() has reported success (see the short-circuit check below).
UInt64 limit;
ApproximateNearestNeighborInformation prewhere_info;
ApproximateNearestNeighborInformation where_info;
ApproximateNearestNeighborInformation order_by_info;
/// Build rpns for query sections
const auto & select = query.query->as<ASTSelectQuery &>();
/// If query has PREWHERE clause
if (select.prewhere())
traverseAST(select.prewhere(), rpn_prewhere_clause);
/// If query has WHERE clause
if (select.where())
traverseAST(select.where(), rpn_where_clause);
/// If query has LIMIT clause
if (select.limitLength())
traverseAtomAST(select.limitLength(), rpn_limit);
if (select.orderBy()) // If query has ORDERBY clause
traverseOrderByAST(select.orderBy(), rpn_order_by_clause);
/// traverseAST() emits arguments before the function that uses them; reversing puts each function in front of its arguments,
/// which is the order the matchers below walk through.
std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end());
std::reverse(rpn_where_clause.begin(), rpn_where_clause.end());
std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end());
/// Match rpns with supported types and extract information
const bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, prewhere_info);
const bool where_is_valid = matchRPNWhere(rpn_where_clause, where_info);
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, order_by_info);
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
/// Query without a LIMIT clause or with a limit greater than a restriction is not supported
if (!limit_is_valid || max_limit_for_ann_queries < limit)
return false;
/// Search type query in both sections isn't supported
if (prewhere_is_valid && where_is_valid)
return false;
/// Search type should be in WHERE or PREWHERE clause
if (prewhere_is_valid || where_is_valid)
query_information = std::move(prewhere_is_valid ? prewhere_info : where_info);
if (order_by_is_valid)
{
/// Query with valid where and order by type is not supported
if (query_information.has_value())
return false;
query_information = std::move(order_by_info);
}
/// The limit is stored together with the rest of the extracted information
if (query_information)
query_information->limit = limit;
return query_information.has_value();
}
/// Appends the subtree rooted at `node` to `rpn` in post-order: the arguments of a function come first, the function itself last.
/// Nodes that traverseAtomAST() does not recognize are appended as FUNCTION_UNKNOWN.
void ApproximateNearestNeighborCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
    /// Only function nodes are descended into
    if (const auto * function_node = node->as<ASTFunction>())
        for (const auto & argument : function_node->arguments->children)
            traverseAST(argument, rpn);

    /// Describe the node itself
    RPNElement element;
    const bool recognized = traverseAtomAST(node, element);
    if (!recognized)
        element.function = RPNElement::FUNCTION_UNKNOWN;

    rpn.emplace_back(std::move(element));
}
/// Classifies a single AST node and writes the result into `out`.
/// Recognized are: the supported distance functions, `tuple`, `array`, the four ordering comparisons, `_CAST`,
/// identifiers, and constants handled by tryCastToConstType().
/// Returns false if the node is none of these. For an unrecognized function `out.func_name` is still set to its name.
bool ApproximateNearestNeighborCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
    /// Functions
    if (const auto * function = node->as<ASTFunction>())
    {
        const auto & name = function->name;
        out.func_name = name;

        const bool is_distance = name == "L1Distance"
            || name == "L2Distance"
            || name == "LinfDistance"
            || name == "cosineDistance"
            || name == "dotProduct"
            || name == "LpDistance";

        const bool is_comparison = name == "less"
            || name == "greater"
            || name == "lessOrEquals"
            || name == "greaterOrEquals";

        if (is_distance)
            out.function = RPNElement::FUNCTION_DISTANCE;
        else if (name == "tuple")
            out.function = RPNElement::FUNCTION_TUPLE;
        else if (name == "array")
            out.function = RPNElement::FUNCTION_ARRAY;
        else if (is_comparison)
            out.function = RPNElement::FUNCTION_COMPARISON;
        else if (name == "_CAST")
            out.function = RPNElement::FUNCTION_CAST;
        else
            return false;

        return true;
    }

    /// Identifiers
    if (const auto * identifier = node->as<ASTIdentifier>())
    {
        out.function = RPNElement::FUNCTION_IDENTIFIER;
        out.identifier.emplace(identifier->name());
        out.func_name = "column identifier";
        return true;
    }

    /// Anything else can only be a constant
    return tryCastToConstType(node, out);
}
/// Tries to interpret `node` as a constant and stores it in `out`.
/// Supported constant types are Float64, UInt64, Int64, Tuple, Array and String; for a String the value is stored in `out.func_name`.
/// Returns false if the node is not a constant or the constant has any other type.
bool ApproximateNearestNeighborCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
{
    Field const_value;
    DataTypePtr const_type;
    if (!KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
        return false;

    const auto value_type = const_value.getType();

    if (value_type == Field::Types::Float64)
    {
        out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
        out.float_literal.emplace(const_value.safeGet<Float32>());
        out.func_name = "Float literal";
        return true;
    }

    if (value_type == Field::Types::UInt64)
    {
        out.function = RPNElement::FUNCTION_INT_LITERAL;
        out.int_literal.emplace(const_value.safeGet<UInt64>());
        out.func_name = "Int literal";
        return true;
    }

    if (value_type == Field::Types::Int64)
    {
        out.function = RPNElement::FUNCTION_INT_LITERAL;
        out.int_literal.emplace(const_value.safeGet<Int64>());
        out.func_name = "Int literal";
        return true;
    }

    if (value_type == Field::Types::Tuple)
    {
        out.function = RPNElement::FUNCTION_LITERAL_TUPLE;
        out.tuple_literal = const_value.safeGet<Tuple>();
        out.func_name = "Tuple literal";
        return true;
    }

    if (value_type == Field::Types::Array)
    {
        out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
        out.array_literal = const_value.safeGet<Array>();
        out.func_name = "Array literal";
        return true;
    }

    if (value_type == Field::Types::String)
    {
        out.function = RPNElement::FUNCTION_STRING_LITERAL;
        out.func_name = const_value.safeGet<String>();
        return true;
    }

    return false;
}
/// Builds the RPN for the expression of the first ORDER BY element only.
/// Does nothing if `node` is not an expression list or its first child is not an ORDER BY element.
void ApproximateNearestNeighborCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
{
    const auto * expr_list = node->as<ASTExpressionList>();
    if (!expr_list)
        return;

    const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>();
    if (!order_by_element)
        return;

    traverseAST(order_by_element->children.front(), rpn);
}
/// Returns true and stores ApproximateNearestNeighborInformation if the query has valid WHERE clause.
/// `rpn` is expected in the reversed order produced by checkQueryStructure(), i.e. starting with the comparison operator.
/// `ann_info.type` is set to Where even when the function returns false.
/// Throws INCORRECT_QUERY if the extracted distance is negative.
bool ApproximateNearestNeighborCondition::matchRPNWhere(RPN & rpn, ApproximateNearestNeighborInformation & ann_info)
{
/// Fill query type field
ann_info.type = ApproximateNearestNeighborInformation::Type::Where;
/// WHERE section must have at least 5 expressions
/// Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(ReferenceVector(floats))
if (rpn.size() < 5)
return false;
auto iter = rpn.begin();
/// Query must start with a comparison operator (less, lessOrEquals, greater or greaterOrEquals)
if (iter->function != RPNElement::FUNCTION_COMPARISON)
return false;
const bool greater_case = iter->func_name == "greater" || iter->func_name == "greaterOrEquals";
const bool less_case = iter->func_name == "less" || iter->func_name == "lessOrEquals";
++iter;
if (less_case)
{
/// For less/lessOrEquals the distance literal directly follows the operator; only a float literal is accepted here
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL)
return false;
ann_info.distance = getFloatOrIntLiteralOrPanic(iter);
if (ann_info.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", ann_info.distance);
++iter;
}
else if (!greater_case)
return false;
auto end = rpn.end();
if (!matchMainParts(iter, end, ann_info))
return false;
if (greater_case)
{
/// For greater/greaterOrEquals no distance was consumed before the distance function, so matchMainParts() collected it
/// as the last numeric element of the reference vector: take it from there and remove it.
/// At least two elements are required so that a non-empty reference vector remains.
if (ann_info.reference_vector.size() < 2)
return false;
ann_info.distance = ann_info.reference_vector.back();
if (ann_info.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", ann_info.distance);
ann_info.reference_vector.pop_back();
}
/// query is ok
return true;
}
/// Returns true and fills `ann_info` if the ORDER BY clause is a supported distance expression.
/// `ann_info.type` is set to OrderBy even when the function returns false.
bool ApproximateNearestNeighborCondition::matchRPNOrderBy(RPN & rpn, ApproximateNearestNeighborInformation & ann_info)
{
    ann_info.type = ApproximateNearestNeighborInformation::Type::OrderBy;

    /// A valid ORDER BY expression consists of at least 3 elements
    const bool has_enough_elements = rpn.size() >= 3;
    if (!has_enough_elements)
        return false;

    auto iter = rpn.begin();
    auto end = rpn.end();
    return ApproximateNearestNeighborCondition::matchMainParts(iter, end, ann_info);
}
/// Returns true and writes the LIMIT value into `limit` if the element is an integer literal.
/// Otherwise returns false and leaves `limit` untouched.
bool ApproximateNearestNeighborCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
{
    if (rpn.function != RPNElement::FUNCTION_INT_LITERAL)
        return false;

    limit = rpn.int_literal.value();
    return true;
}
/// Matches the distance function, the reference vector and the column name, starting at `iter`.
///
/// Expected sequence: DistanceFunc->[p for LpDistance]->[Column]->[Tuple(array)Func]->ReferenceVector(floats)->[Column]
/// The reference vector may be given as a tuple/array literal, as a literal wrapped into a cast to Array/Tuple,
/// or as a plain run of numeric literals.
///
/// - iter: current position in the RPN; it is advanced while matching.
/// - end: past-the-end position of the RPN. `iter` is never dereferenced once it reaches `end`.
/// - ann_info: receives metric, p (for LpDistance), column name and reference vector.
///
/// Returns true iff exactly one column identifier and a non-empty reference vector were found.
/// A sequence that ends before a mandatory element is reported as a mismatch (false).
bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info)
{
    bool identifier_found = false;

    if (iter == end || iter->function != RPNElement::FUNCTION_DISTANCE)
        return false;

    ann_info.metric = stringToMetric(iter->func_name);
    ++iter;

    if (ann_info.metric == ApproximateNearestNeighborInformation::Metric::Lp)
    {
        /// LpDistance carries its p argument right after the function
        if (iter == end
            || (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL && iter->function != RPNElement::FUNCTION_INT_LITERAL))
            return false;
        ann_info.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter);
        ++iter;
    }

    if (iter != end && iter->function == RPNElement::FUNCTION_IDENTIFIER)
    {
        identifier_found = true;
        ann_info.column_name = std::move(iter->identifier.value());
        ++iter;
    }

    /// The tuple()/array() function element itself carries no data, its elements follow
    if (iter != end && (iter->function == RPNElement::FUNCTION_TUPLE || iter->function == RPNElement::FUNCTION_ARRAY))
        ++iter;

    if (iter != end && iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
    {
        extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
        ++iter;
    }

    if (iter != end && iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
    {
        extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
        ++iter;
    }

    /// further conditions are possible if there is no tuple or array, or no identifier is found
    /// the tuple or array can be inside a cast function. For other cases, see the loop after this condition
    if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
    {
        ++iter;
        if (iter == end)
            return false;

        /// Cast should be made to array or tuple
        if (!iter->func_name.starts_with("Array") && !iter->func_name.starts_with("Tuple"))
            return false;

        ++iter;
        if (iter == end)
            return false;

        if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
        {
            extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
            ++iter;
        }
        else if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
        {
            extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
            ++iter;
        }
        else
            return false;
    }

    /// Whatever is left must be numeric elements of the reference vector and at most one (not yet seen) column identifier
    while (iter != end)
    {
        if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
            iter->function == RPNElement::FUNCTION_INT_LITERAL)
            ann_info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
        else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
        {
            if (identifier_found)
                return false;
            ann_info.column_name = std::move(iter->identifier.value());
            identifier_found = true;
        }
        else
            return false;

        ++iter;
    }

    /// Final checks of correctness
    return identifier_found && !ann_info.reference_vector.empty();
}
/// Returns the numeric literal stored in the RPN element `iter` points to.
/// A float literal is returned as is, an integer literal is converted to float.
/// Throws INCORRECT_QUERY if the element holds neither.
float ApproximateNearestNeighborCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
{
    if (const auto & float_literal = iter->float_literal; float_literal.has_value())
        return *float_literal;

    if (const auto & int_literal = iter->int_literal; int_literal.has_value())
        return static_cast<float>(*int_literal);

    throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
}
}

View File

@ -11,6 +11,7 @@
#include <Storages/MergeTree/MergeTreeDataPartUUID.h>
#include <Storages/MergeTree/StorageFromMergeTreeDataPart.h>
#include <Storages/MergeTree/MergeTreeIndexFullText.h>
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Storages/ReadInOrderOptimizer.h>
#include <Storages/VirtualColumnUtils.h>
#include <Parsers/ASTIdentifier.h>
@ -48,7 +49,6 @@
#include <Functions/IFunction.h>
#include <IO/WriteBufferFromOStream.h>
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
namespace CurrentMetrics
{
@ -1406,11 +1406,10 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
reader.read(granule);
auto ann_condition = std::dynamic_pointer_cast<IMergeTreeIndexConditionApproximateNearestNeighbor>(condition);
if (ann_condition != nullptr)
if (index_helper->isVectorSimilarityIndex())
{
/// An array of indices of useful ranges.
auto result = ann_condition->getUsefulRanges(granule);
auto result = condition->getUsefulRanges(granule);
for (auto range : result)
{

View File

@ -27,7 +27,6 @@ MergeTreeWriterSettings::MergeTreeWriterSettings(
, rewrite_primary_key(rewrite_primary_key_)
, blocks_are_granules_size(blocks_are_granules_size_)
, query_write_settings(query_write_settings_)
, max_threads_for_annoy_index_creation(global_settings.max_threads_for_annoy_index_creation)
, low_cardinality_max_dictionary_size(global_settings.low_cardinality_max_dictionary_size)
, low_cardinality_use_single_dictionary_for_part(global_settings.low_cardinality_use_single_dictionary_for_part != 0)
, use_compact_variant_discriminators_serialization(storage_settings->use_compact_variant_discriminators_serialization)

View File

@ -77,8 +77,6 @@ struct MergeTreeWriterSettings
bool blocks_are_granules_size;
WriteSettings query_write_settings;
size_t max_threads_for_annoy_index_creation;
size_t low_cardinality_max_dictionary_size;
bool low_cardinality_use_single_dictionary_for_part;
bool use_compact_variant_discriminators_serialization;

View File

@ -1,416 +0,0 @@
#ifdef ENABLE_ANNOY
#include <Storages/MergeTree/MergeTreeIndexAnnoy.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <Core/Field.h>
#include <Core/Settings.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>
#include <new>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int INCORRECT_DATA;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int INCORRECT_QUERY;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
/// Creates an empty index for vectors with `dimensions` elements.
/// The base class constructor is called with the dimension count narrowed to `int`.
template <typename Distance>
AnnoyIndexWithSerialization<Distance>::AnnoyIndexWithSerialization(size_t dimensions)
: Base::AnnoyIndex(static_cast<int>(dimensions))
{
}
/// Writes the internal state of a built index to `ostr` in binary form.
/// The fields are written in exactly the order in which deserialize() reads them back; do not reorder one without the other.
template<typename Distance>
void AnnoyIndexWithSerialization<Distance>::serialize(WriteBuffer & ostr) const
{
/// Only a built index may be serialized
chassert(Base::_built);
writeIntBinary(Base::_s, ostr);
writeIntBinary(Base::_n_items, ostr);
writeIntBinary(Base::_n_nodes, ostr);
writeIntBinary(Base::_nodes_size, ostr);
writeIntBinary(Base::_K, ostr);
writeIntBinary(Base::_seed, ostr);
writeVectorBinary(Base::_roots, ostr);
/// The node storage is dumped as one raw blob of `_s * _n_nodes` bytes
ostr.write(reinterpret_cast<const char *>(Base::_nodes), Base::_s * Base::_n_nodes);
}
/// Restores the internal state of the index from `istr`; the counterpart of serialize().
/// Must be called on an index that has not been built yet. Afterwards the index is marked as built.
///
/// Throws std::bad_alloc if the node storage cannot be allocated. The sizes are taken from the stream,
/// so the result of realloc() must be checked before it replaces the current buffer: otherwise a failed
/// allocation would leak the old buffer and the following read would write through a null pointer.
template<typename Distance>
void AnnoyIndexWithSerialization<Distance>::deserialize(ReadBuffer & istr)
{
    chassert(!Base::_built);

    /// Same field order as in serialize()
    readIntBinary(Base::_s, istr);
    readIntBinary(Base::_n_items, istr);
    readIntBinary(Base::_n_nodes, istr);
    readIntBinary(Base::_nodes_size, istr);
    readIntBinary(Base::_K, istr);
    readIntBinary(Base::_seed, istr);
    readVectorBinary(Base::_roots, istr);

    const size_t nodes_size_in_bytes = Base::_s * Base::_n_nodes;

    /// Keep the old pointer until the reallocation is known to have succeeded.
    /// A null result is only an error for a non-zero size.
    void * new_nodes = realloc(Base::_nodes, nodes_size_in_bytes);
    if (new_nodes == nullptr && nodes_size_in_bytes != 0)
        throw std::bad_alloc();
    Base::_nodes = new_nodes;

    istr.readStrict(reinterpret_cast<char *>(Base::_nodes), nodes_size_in_bytes);

    Base::_fd = 0;
    // set flags
    Base::_loaded = false;
    Base::_verbose = false;
    Base::_on_disk = false;
    Base::_built = true;
}
/// Returns the number of dimensions of the index, as reported by the base class' get_f().
template<typename Distance>
size_t AnnoyIndexWithSerialization<Distance>::getDimensions() const
{
return Base::get_f();
}
/// Creates a granule without an index (`index` is null).
/// NOTE(review): presumably the index is filled later by deserializeBinary() - confirm against the callers.
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, index(nullptr)
{}
/// Creates a granule that takes over the given index (`index_` is moved into the granule).
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
const String & index_name_,
const Block & index_sample_block_,
AnnoyIndexWithSerializationPtr<Distance> index_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, index(std::move(index_))
{}
/// Writes the granule to `ostr`: first the number of dimensions as UInt64, then the index itself.
/// `index` is dereferenced unconditionally, so the granule must hold an index when this is called.
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
{
/// Number of dimensions is required in the index constructor,
/// so it must be written and read separately from the other part
writeIntBinary(static_cast<UInt64>(index->getDimensions()), ostr); // write dimension
index->serialize(ostr);
}
/// Reads a granule written by serializeBinary(): the dimension count first, then the index state.
/// Any index currently held by the granule is replaced by a freshly constructed one. The version argument is ignored.
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
UInt64 dimension;
readIntBinary(dimension, istr);
/// The dimension count is needed to construct the index before the rest can be read into it
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(dimension);
index->deserialize(istr);
}
/// Creates an aggregator that collects rows into an index and later builds it.
/// - trees_: passed to the index' build() call in getGranuleAndReset().
/// - max_threads_for_creation_: thread count for that build() call; 0 is translated to -1 there.
template <typename Distance>
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
const String & index_name_,
const Block & index_sample_block_,
UInt64 trees_,
size_t max_threads_for_creation_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, trees(trees_)
, max_threads_for_creation(max_threads_for_creation_)
{}
/// Builds the accumulated Annoy index with `trees` trees, wraps it into a granule and resets the aggregator
/// so that the next update() starts a new index.
template <typename Distance>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
{
    /// A configured limit of 0 is passed to Annoy as -1.
    /// NOTE(review): presumably -1 makes Annoy pick the thread count itself - confirm against annoylib.
    int build_threads = -1;
    if (max_threads_for_creation != 0)
        build_threads = static_cast<int>(max_threads_for_creation);

    /// clang-tidy reports a false positive: it considers %p with an outdated pointer in fprintf() (used by logging which we don't do) dereferencing
    index->build(static_cast<int>(trees), build_threads);

    auto built_granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance>>(index_name, index_sample_block, index);
    index = nullptr;

    return built_granule;
}
/// Adds up to `limit` rows of the indexed column, starting at row `*pos` of `block`, to the Annoy index
/// and advances `*pos` by the number of rows consumed.
///
/// The indexed column must be either a ColumnArray whose nested column is ColumnFloat32, or a ColumnTuple.
/// The Annoy index is created lazily on the first call, using the dimension of the first row.
///
/// Throws LOGICAL_ERROR if `*pos` is out of range, if the sample block has more than one column, or if the
/// column is neither Array nor Tuple; throws INCORRECT_DATA for empty arrays, arrays of differing length,
/// or more than 2^32 - 1 rows.
template <typename Distance>
void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t * pos, size_t limit)
{
if (*pos >= block.rows())
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
*pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);
if (rows_read == 0)
return;
if (rows_read > std::numeric_limits<uint32_t>::max())
throw Exception(ErrorCodes::INCORRECT_DATA, "Index granularity is too big: more than 4B rows per index granule.");
if (index_sample_block.columns() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");
const String & index_column_name = index_sample_block.getByPosition(0).name;
/// Work on a copy that contains only the rows [*pos, *pos + rows_read)
ColumnPtr column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);
if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
{
const auto & column_array_data = column_array->getData();
const auto & column_array_data_float = typeid_cast<const ColumnFloat32 &>(column_array_data);
const auto & column_array_data_float_data = column_array_data_float.getData();
const auto & column_array_offsets = column_array->getOffsets();
const size_t num_rows = column_array_offsets.size();
if (column_array->empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty");
/// The Annoy algorithm naturally assumes that the indexed vectors have dimension >= 1. This condition is violated if empty arrays
/// are INSERTed into an Annoy-indexed column or if no value was specified at all in which case the arrays take on their default
/// value which is also empty.
if (column_array->isDefaultAt(0))
throw Exception(ErrorCodes::INCORRECT_DATA, "The arrays in column '{}' must not be empty. Did you try to INSERT default values?", index_column_name);
/// Check all sizes are the same
/// The first offset is used as the length of the first array; the difference of neighboring offsets gives the other lengths.
size_t dimension = column_array_offsets[0];
for (size_t i = 0; i < num_rows - 1; ++i)
if (column_array_offsets[i + 1] - column_array_offsets[i] != dimension)
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
/// Also check that previously inserted blocks have the same size as this block.
/// Note that this guarantees consistency of dimension only within parts. We are unable to detect inconsistent dimensions across
/// parts - for this, a little help from the user is needed, e.g. CONSTRAINT cnstr CHECK length(array) = 42.
if (index && index->getDimensions() != dimension)
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
if (!index)
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(dimension);
/// Add all rows of block
/// The first row starts at the beginning of the flattened data, every later row starts at the previous row's offset.
/// Each row gets the current number of items in the index as its id.
index->add_item(index->get_n_items(), column_array_data_float_data.data());
for (size_t current_row = 1; current_row < num_rows; ++current_row)
index->add_item(index->get_n_items(), &column_array_data_float_data[column_array_offsets[current_row - 1]]);
}
else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
{
const auto & column_tuple_columns = column_tuple->getColumns();
/// TODO check if calling index->add_item() directly on the block's tuples is faster than materializing everything
/// Transpose the tuple element columns into one vector per row
std::vector<std::vector<Float32>> data(column_tuple->size(), std::vector<Float32>());
for (const auto & column : column_tuple_columns)
{
/// NOTE(review): typeid_cast to a pointer is dereferenced without a null check - assumes every tuple element is Float32.
const auto & pod_array = typeid_cast<const ColumnFloat32 *>(column.get())->getData();
for (size_t i = 0; i < pod_array.size(); ++i)
data[i].push_back(pod_array[i]);
}
if (data.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Tuple has 0 rows, {} rows expected", rows_read);
if (!index)
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(data[0].size());
for (const auto & item : data)
index->add_item(index->get_n_items(), item.data());
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");
*pos += rows_read;
}
/// Builds the ANN condition from the query and remembers the distance function of the index.
/// `search_k` is taken from the setting `annoy_index_search_k_nodes`. The index description is not used.
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
const String & distance_function_,
ContextPtr context)
: ann_condition(query, context)
, distance_function(distance_function_)
, search_k(context->getSettingsRef().annoy_index_search_k_nodes)
{}
/// Not supported for ANN skip indexes: always throws LOGICAL_ERROR.
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /*idx_granule*/) const
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "mayBeTrueOnGranule is not supported for ANN skip indexes");
}
/// Delegates to the ANN condition, passing the distance function of this index.
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
{
return ann_condition.alwaysUnknownOrTrue(distance_function);
}
/// Dispatches to the implementation that matches the distance function of the index
/// (L2 -> Annoy::Euclidean, cosine -> Annoy::Angular). Any other value is treated as unreachable.
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
    if (distance_function == DISTANCE_FUNCTION_L2)
        return getUsefulRangesImpl<Annoy::Euclidean>(idx_granule);

    if (distance_function == DISTANCE_FUNCTION_COSINE)
        return getUsefulRangesImpl<Annoy::Angular>(idx_granule);

    std::unreachable();
}
/// Searches the Annoy index of the given granule for the nearest neighbors of the query's reference vector and
/// returns the sorted, de-duplicated numbers of the ranges (neighbor id / index granularity) that contain them.
/// For WHERE-type queries, neighbors farther away than the comparison distance are dropped.
///
/// Throws LOGICAL_ERROR if the comparison distance is negative or the granule has a different type, and
/// INCORRECT_QUERY if the dimension of the reference vector differs from the dimension of the index.
template <typename Distance>
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
    const UInt64 limit = ann_condition.getLimit();
    const UInt64 index_granularity = ann_condition.getIndexGranularity();

    /// Only WHERE-type queries carry a maximum distance
    const std::optional<float> comparison_distance = ann_condition.getQueryType() == ApproximateNearestNeighborInformation::Type::Where
        ? std::optional<float>(ann_condition.getComparisonDistanceForWhereQuery())
        : std::nullopt;

    if (comparison_distance && comparison_distance.value() < 0)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");

    const std::vector<float> reference_vector = ann_condition.getReferenceVector();

    const auto annoy_granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance>>(idx_granule);
    if (annoy_granule == nullptr)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");

    const AnnoyIndexWithSerializationPtr<Distance> annoy = annoy_granule->index;

    if (ann_condition.getDimensions() != annoy->getDimensions())
        throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
            "does not match the dimension in the index ({})",
            ann_condition.getDimensions(), annoy->getDimensions());

    /// Ids of the items closest to the reference vector, and their distances
    std::vector<UInt64> neighbors;
    std::vector<Float32> distances;
    neighbors.reserve(limit);
    distances.reserve(limit);

    annoy->get_nns_by_vector(reference_vector.data(), limit, static_cast<int>(search_k), &neighbors, &distances);

    chassert(neighbors.size() == distances.size());

    std::vector<size_t> useful_ranges;
    useful_ranges.reserve(neighbors.size());
    for (size_t i = 0; i < neighbors.size(); ++i)
    {
        const bool too_far = comparison_distance && distances[i] > *comparison_distance;
        if (!too_far)
            useful_ranges.push_back(neighbors[i] / index_granularity);
    }

    /// Several neighbors may fall into the same range: sort and remove duplicates
    std::sort(useful_ranges.begin(), useful_ranges.end());
    const auto duplicates_begin = std::unique(useful_ranges.begin(), useful_ranges.end());
    useful_ranges.erase(duplicates_begin, useful_ranges.end());

    return useful_ranges;
}
/// Stores the index description (in the base class), the number of trees and the name of the distance function.
MergeTreeIndexAnnoy::MergeTreeIndexAnnoy(const IndexDescription & index_, UInt64 trees_, const String & distance_function_)
: IMergeTreeIndex(index_)
, trees(trees_)
, distance_function(distance_function_)
{}
/// Creates a granule without an index, typed by the distance function of this index
/// (L2 -> Annoy::Euclidean, cosine -> Annoy::Angular). Any other value is treated as unreachable.
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
    if (distance_function == DISTANCE_FUNCTION_L2)
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);

    if (distance_function == DISTANCE_FUNCTION_COSINE)
        return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);

    std::unreachable();
}
/// Creates an aggregator typed by the distance function of this index. The thread limit for building the index
/// is taken from `settings.max_threads_for_annoy_index_creation`. Any other distance function is treated as unreachable.
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator(const MergeTreeWriterSettings & settings) const
{
    /// TODO: Support more metrics. Available metrics: https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
    if (distance_function == DISTANCE_FUNCTION_L2)
    {
        using Aggregator = MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>;
        return std::make_shared<Aggregator>(index.name, index.sample_block, trees, settings.max_threads_for_annoy_index_creation);
    }

    if (distance_function == DISTANCE_FUNCTION_COSINE)
    {
        using Aggregator = MergeTreeIndexAggregatorAnnoy<Annoy::Angular>;
        return std::make_shared<Aggregator>(index.name, index.sample_block, trees, settings.max_threads_for_annoy_index_creation);
    }

    std::unreachable();
}
/// Creates the ANN condition for this index from the query, forwarding the index description, the distance
/// function and the query context.
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
    return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, distance_function, context);
}
/// Creating the condition from an ActionsDAG is not implemented: always throws NOT_IMPLEMENTED.
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const ActionsDAG *, ContextPtr) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeTreeIndexAnnoy cannot be created with ActionsDAG");
}
/// Creates an Annoy index object from the index description.
/// Argument 0 (optional) is the distance function name, default DISTANCE_FUNCTION_L2.
/// Argument 1 (optional) is the number of trees, default 100.
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
    static constexpr auto DEFAULT_DISTANCE_FUNCTION = DISTANCE_FUNCTION_L2;
    static constexpr auto DEFAULT_TREES = 100uz;

    const auto & arguments = index.arguments;

    String distance_function = DEFAULT_DISTANCE_FUNCTION;
    if (arguments.size() >= 1)
        distance_function = arguments[0].safeGet<String>();

    UInt64 trees = DEFAULT_TREES;
    if (arguments.size() >= 2)
        trees = arguments[1].safeGet<UInt64>();

    return std::make_shared<MergeTreeIndexAnnoy>(index, trees, distance_function);
}
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
{
/// Check number and type of Annoy index arguments:
if (index.arguments.size() > 2)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index must not have more than two parameters");
if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::String)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance function argument of Annoy index must be of type String");
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::UInt64)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Number of trees argument of Annoy index must be of type UInt64");
/// Check that the index is created on a single column
if (index.column_names.size() != 1 || index.data_types.size() != 1)
throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column");
/// Check that a supported metric was passed as first argument
if (!index.arguments.empty())
{
String distance_name = index.arguments[0].safeGet<String>();
if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
}
/// Check data type of indexed column:
auto throw_unsupported_underlying_column_exception = []()
{
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Annoy indexes can only be created on columns of type Array(Float32) and Tuple(Float32[, Float32[, ...]])");
};
DataTypePtr data_type = index.sample_block.getDataTypes()[0];
if (const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get()))
{
TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw_unsupported_underlying_column_exception();
}
else if (const auto * data_type_tuple = typeid_cast<const DataTypeTuple *>(data_type.get()))
{
const DataTypes & inner_types = data_type_tuple->getElements();
for (const auto & inner_type : inner_types)
{
TypeIndex nested_type_index = inner_type->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw_unsupported_underlying_column_exception();
}
}
else
throw_unsupported_underlying_column_exception();
}
}
#endif

View File

@ -1,112 +0,0 @@
#pragma once
#ifdef ENABLE_ANNOY
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
#include <annoylib.h>
#include <kissrandom.h>
namespace DB
{
/// Annoy index (UInt64 item ids, Float32 components, multi-threaded build policy) extended with
/// serialization to / deserialization from ClickHouse buffers.
template <typename Distance>
class AnnoyIndexWithSerialization : public Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>
{
using Base = Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>;
public:
/// Creates an empty index for vectors with the given number of dimensions.
explicit AnnoyIndexWithSerialization(size_t dimensions);
/// Writes the index to the buffer.
void serialize(WriteBuffer & ostr) const;
/// Reads the index from the buffer.
void deserialize(ReadBuffer & istr);
/// Returns the number of dimensions of the indexed vectors.
size_t getDimensions() const;
};
/// Shared-ownership handle to an Annoy index; the aggregator and the granule structs below hold the index through it.
template <typename Distance>
using AnnoyIndexWithSerializationPtr = std::shared_ptr<AnnoyIndexWithSerialization<Distance>>;
/// Index granule that holds one Annoy index, together with the index name and sample block.
template <typename Distance>
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
{
/// Creates a granule without an index.
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
/// Creates a granule around an existing index.
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_, AnnoyIndexWithSerializationPtr<Distance> index_);
~MergeTreeIndexGranuleAnnoy() override = default;
void serializeBinary(WriteBuffer & ostr) const override;
void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;
/// The granule counts as empty as long as no index object is set.
bool empty() const override { return !index.get(); }
const String index_name;
const Block index_sample_block;
/// May be nullptr, see empty()
AnnoyIndexWithSerializationPtr<Distance> index;
};
/// Accumulates rows of the indexed column into an Annoy index and hands the result out as a granule.
template <typename Distance>
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, UInt64 trees, size_t max_threads_for_creation);
~MergeTreeIndexAggregatorAnnoy() override = default;
/// Empty if no index exists yet or the index contains no items.
bool empty() const override { return !index || index->get_n_items() == 0; }
MergeTreeIndexGranulePtr getGranuleAndReset() override;
void update(const Block & block, size_t * pos, size_t limit) override;
const String index_name;
const Block index_sample_block;
/// Number of trees
const UInt64 trees;
/// Thread limit for building the index
const size_t max_threads_for_creation;
AnnoyIndexWithSerializationPtr<Distance> index;
};
/// Index condition for Annoy indexes: holds the ANN condition of a query and the distance function and search_k
/// value used for it.
class MergeTreeIndexConditionAnnoy final : public IMergeTreeIndexConditionApproximateNearestNeighbor
{
public:
MergeTreeIndexConditionAnnoy(
const IndexDescription & index_description,
const SelectQueryInfo & query,
const String & distance_function,
ContextPtr context);
~MergeTreeIndexConditionAnnoy() override = default;
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
private:
/// Implementation of getUsefulRanges() for a concrete Annoy distance type.
template <typename Distance>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
const ApproximateNearestNeighborCondition ann_condition;
const String distance_function;
const Int64 search_k;
};
/// Skipping index of type Annoy: factory for the granule, aggregator and condition objects of this index.
class MergeTreeIndexAnnoy final : public IMergeTreeIndex
{
public:
MergeTreeIndexAnnoy(const IndexDescription & index_, UInt64 trees_, const String & distance_function_);
~MergeTreeIndexAnnoy() override = default;
MergeTreeIndexGranulePtr createIndexGranule() const override;
MergeTreeIndexAggregatorPtr createIndexAggregator(const MergeTreeWriterSettings & settings) const override;
/// NOTE(review): this overload is not marked `override` - confirm whether it is meant to override anything in the base class.
MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const;
MergeTreeIndexConditionPtr createIndexCondition(const ActionsDAG *, ContextPtr) const override;
bool isVectorSearch() const override { return true; }
private:
const UInt64 trees;
const String distance_function;
};
}
#endif

View File

@ -0,0 +1,45 @@
#include <Storages/MergeTree/MergeTreeIndexLegacyVectorSimilarity.h>
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_INDEX;
}
/// Only forwards the index description to the base class; the legacy index has no state of its own.
MergeTreeIndexLegacyVectorSimilarity::MergeTreeIndexLegacyVectorSimilarity(const IndexDescription & index_)
: IMergeTreeIndex(index_)
{
}
/// Always throws ILLEGAL_INDEX: granules of the removed index types cannot be created.
MergeTreeIndexGranulePtr MergeTreeIndexLegacyVectorSimilarity::createIndexGranule() const
{
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Indexes of type 'annoy' or 'usearch' are no longer supported. Please drop and recreate the index as type 'vector_similarity'");
}
/// Always throws ILLEGAL_INDEX: aggregators of the removed index types cannot be created.
MergeTreeIndexAggregatorPtr MergeTreeIndexLegacyVectorSimilarity::createIndexAggregator(const MergeTreeWriterSettings &) const
{
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Indexes of type 'annoy' or 'usearch' are no longer supported. Please drop and recreate the index as type 'vector_similarity'");
}
/// Always throws ILLEGAL_INDEX: conditions of the removed index types cannot be created.
MergeTreeIndexConditionPtr MergeTreeIndexLegacyVectorSimilarity::createIndexCondition(const SelectQueryInfo &, ContextPtr) const
{
    throw Exception(ErrorCodes::ILLEGAL_INDEX, "Indexes of type 'annoy' or 'usearch' are no longer supported. Please drop and recreate the index as type 'vector_similarity'");
}
/// Always throws ILLEGAL_INDEX: conditions of the removed index types cannot be created.
MergeTreeIndexConditionPtr MergeTreeIndexLegacyVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const
{
throw Exception(ErrorCodes::ILLEGAL_INDEX, "Indexes of type 'annoy' or 'usearch' are no longer supported. Please drop and recreate the index as type 'vector_similarity'");
}
/// Creates the placeholder index object for the given index description. Creation itself never throws here;
/// the exceptions are raised by the create*() methods of the returned object.
MergeTreeIndexPtr legacyVectorSimilarityIndexCreator(const IndexDescription & index)
{
return std::make_shared<MergeTreeIndexLegacyVectorSimilarity>(index);
}
/// Performs no validation: the body is empty, so every index description is accepted.
/// NOTE(review): presumably intentional so that old tables with such indexes still load - confirm.
void legacyVectorSimilarityIndexValidator(const IndexDescription &, bool)
{
}
}

View File

@ -0,0 +1,26 @@
#pragma once
#include <Storages/MergeTree/VectorSimilarityCondition.h>
/// Walking corpse implementation for the removed skipping indexes of types "annoy" and "usearch".
/// Its only purpose is to allow loading old tables with indexes of these types.
/// Data insertion and index usage/search will throw an exception that suggests migrating to "vector_similarity" indexes.
namespace DB
{
/// Placeholder index class: it can be constructed from an index description, but it declares no data members
/// and only the create*() factory methods plus isVectorSimilarityIndex().
class MergeTreeIndexLegacyVectorSimilarity : public IMergeTreeIndex
{
public:
explicit MergeTreeIndexLegacyVectorSimilarity(const IndexDescription & index_);
~MergeTreeIndexLegacyVectorSimilarity() override = default;
MergeTreeIndexGranulePtr createIndexGranule() const override;
MergeTreeIndexAggregatorPtr createIndexAggregator(const MergeTreeWriterSettings &) const override;
/// NOTE(review): this overload is not marked `override` - confirm whether it is meant to override anything in the base class.
MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo &, ContextPtr) const;
MergeTreeIndexConditionPtr createIndexCondition(const ActionsDAG *, ContextPtr) const override;
bool isVectorSimilarityIndex() const override { return true; }
};
}

View File

@ -1,463 +0,0 @@
#ifdef ENABLE_USEARCH
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#include <Storages/MergeTree/MergeTreeIndexUSearch.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>
namespace ProfileEvents
{
extern const Event USearchAddCount;
extern const Event USearchAddVisitedMembers;
extern const Event USearchAddComputedDistances;
extern const Event USearchSearchCount;
extern const Event USearchSearchVisitedMembers;
extern const Event USearchSearchComputedDistances;
}
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_ALLOCATE_MEMORY;
extern const int ILLEGAL_COLUMN;
extern const int INCORRECT_DATA;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int INCORRECT_QUERY;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
namespace
{
/// Maps the scalar kind names accepted as index argument to the corresponding USearch scalar kinds.
/// The table is only ever read, so it is declared const to rule out accidental modification
/// (e.g. by operator[] inserting a default entry).
const std::unordered_map<String, unum::usearch::scalar_kind_t> nameToScalarKind = {
    {"f64", unum::usearch::scalar_kind_t::f64_k},
    {"f32", unum::usearch::scalar_kind_t::f32_k},
    {"f16", unum::usearch::scalar_kind_t::f16_k},
    {"i8", unum::usearch::scalar_kind_t::i8_k}};
}
/// Creates an empty USearch index for the given number of dimensions and scalar kind, using `Metric` as distance metric.
template <unum::usearch::metric_kind_t Metric>
USearchIndexWithSerialization<Metric>::USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind)
: Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric, scalar_kind)))
{
}
/// Writes the index to `ostr` by handing USearch a callback that appends each chunk of bytes to the buffer.
template <unum::usearch::metric_kind_t Metric>
void USearchIndexWithSerialization<Metric>::serialize(WriteBuffer & ostr) const
{
    /// Returns true after each chunk; failures of the buffer surface as exceptions from write()
    auto write_chunk = [&ostr](void * chunk, size_t chunk_size)
    {
        const char * bytes = reinterpret_cast<const char *>(chunk);
        ostr.write(bytes, chunk_size);
        return true;
    };

    Base::save_to_stream(write_chunk);
}
/// Loads the index from `istr` by handing USearch a callback that fills each requested chunk from the buffer.
template <unum::usearch::metric_kind_t Metric>
void USearchIndexWithSerialization<Metric>::deserialize(ReadBuffer & istr)
{
    /// Returns true after each chunk; readStrict() is expected to report short reads itself
    auto read_chunk = [&istr](void * chunk, size_t chunk_size)
    {
        char * bytes = reinterpret_cast<char *>(chunk);
        istr.readStrict(bytes, chunk_size);
        return true;
    };

    Base::load_from_stream(read_chunk);
}
/// Returns the value of Base::dimensions() of the underlying USearch index.
template <unum::usearch::metric_kind_t Metric>
size_t USearchIndexWithSerialization<Metric>::getDimensions() const
{
return Base::dimensions();
}
/// Creates a granule without an index: stores the index name, sample block and scalar kind and leaves `index` as nullptr.
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
, index(nullptr)
{
}
/// Creates a granule that takes over the given USearch index (the shared pointer is moved into the granule).
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranuleUSearch<Metric>::MergeTreeIndexGranuleUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_,
USearchIndexWithSerializationPtr<Metric> index_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
, index(std::move(index_))
{
}
/// Writes the granule to `ostr`: first the number of dimensions as UInt64, then the serialized USearch index.
/// The dimension is stored separately because it is needed to construct the index object before the rest can be read.
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexGranuleUSearch<Metric>::serializeBinary(WriteBuffer & ostr) const
{
    const auto dimensions = static_cast<UInt64>(index->getDimensions());
    writeIntBinary(dimensions, ostr);

    index->serialize(ostr);
}
/// Reads a granule from `istr`: the number of dimensions (UInt64) comes first and is used, together with the
/// scalar kind of the granule, to construct a fresh USearch index, which then loads the remaining data.
/// The index version argument is not used.
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexGranuleUSearch<Metric>::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
    UInt64 num_dimensions = 0;
    readIntBinary(num_dimensions, istr);

    using Index = USearchIndexWithSerialization<Metric>;
    index = std::make_shared<Index>(num_dimensions, scalar_kind);
    index->deserialize(istr);
}
/// Stores the index name, the sample block and the scalar kind.
/// The USearch index itself is not created here; `index` stays empty until update() is called.
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexAggregatorUSearch<Metric>::MergeTreeIndexAggregatorUSearch(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::scalar_kind_t scalar_kind_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, scalar_kind(scalar_kind_)
{
}
/// Wraps the accumulated USearch index into a granule and resets the aggregator so that the next update()
/// starts a new index.
template <unum::usearch::metric_kind_t Metric>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorUSearch<Metric>::getGranuleAndReset()
{
    using Granule = MergeTreeIndexGranuleUSearch<Metric>;

    auto finished_granule = std::make_shared<Granule>(index_name, index_sample_block, scalar_kind, index);
    index = nullptr;

    return finished_granule;
}
/// Adds up to `limit` rows of the indexed column, starting at row `*pos` of `block`, to the USearch index
/// and advances `*pos` by the number of rows consumed.
///
/// The indexed column must be either a ColumnArray whose nested column is ColumnFloat32, or a ColumnTuple.
/// The USearch index is created lazily on the first call, using the dimension of the first row.
///
/// Throws LOGICAL_ERROR if `*pos` is out of range, if the sample block has more than one column, or if the
/// column is neither Array nor Tuple; INCORRECT_DATA for empty arrays, arrays of differing length, more than
/// 2^32 - 1 rows, or if USearch rejects a vector; CANNOT_ALLOCATE_MEMORY if the index cannot reserve space.
template <unum::usearch::metric_kind_t Metric>
void MergeTreeIndexAggregatorUSearch<Metric>::update(const Block & block, size_t * pos, size_t limit)
{
    if (*pos >= block.rows())
        throw Exception(
            ErrorCodes::LOGICAL_ERROR,
            "The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
            *pos,
            block.rows());

    size_t rows_read = std::min(limit, block.rows() - *pos);
    if (rows_read == 0)
        return;

    if (rows_read > std::numeric_limits<uint32_t>::max())
        throw Exception(ErrorCodes::INCORRECT_DATA, "Index granularity is too big: more than 4B rows per index granule.");

    if (index_sample_block.columns() > 1)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");

    const String & index_column_name = index_sample_block.getByPosition(0).name;

    /// Work on a copy that contains only the rows [*pos, *pos + rows_read)
    ColumnPtr column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);

    if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
    {
        const auto & column_array_data = column_array->getData();
        const auto & column_array_data_float = typeid_cast<const ColumnFloat32 &>(column_array_data);
        const auto & column_array_data_float_data = column_array_data_float.getData();

        const auto & column_array_offsets = column_array->getOffsets();
        const size_t num_rows = column_array_offsets.size();

        if (column_array->empty())
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty");

        /// The USearch algorithm naturally assumes that the indexed vectors have dimension >= 1. This condition is violated if empty
        /// arrays are INSERTed into a USearch-indexed column or if no value was specified at all, in which case the arrays take on
        /// their default value which is also empty.
        if (column_array->isDefaultAt(0))
            throw Exception(ErrorCodes::INCORRECT_DATA, "The arrays in column '{}' must not be empty. Did you try to INSERT default values?", index_column_name);

        /// Check all sizes are the same
        /// The first offset is used as the length of the first array; the difference of neighboring offsets gives the other lengths.
        size_t dimension = column_array_offsets[0];
        for (size_t i = 0; i < num_rows - 1; ++i)
            if (column_array_offsets[i + 1] - column_array_offsets[i] != dimension)
                throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);

        /// Also check that previously inserted blocks have the same size as this block.
        /// Note that this guarantees consistency of dimension only within parts. We are unable to detect inconsistent dimensions across
        /// parts - for this, a little help from the user is needed, e.g. CONSTRAINT cnstr CHECK length(array) = 42.
        if (index && index->getDimensions() != dimension)
            throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);

        if (!index)
            index = std::make_shared<USearchIndexWithSerialization<Metric>>(dimension, scalar_kind);

        /// Add all rows of block
        if (!index->reserve(unum::usearch::ceil2(index->size() + num_rows)))
            throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for usearch index");

        for (size_t current_row = 0; current_row < num_rows; ++current_row)
        {
            /// A row starts where the previous row ends. The first row starts at the beginning of the flattened data;
            /// this is spelled out explicitly instead of indexing the offsets with `current_row - 1` for row 0,
            /// which would wrap around to element -1.
            const size_t row_begin = (current_row == 0) ? 0 : column_array_offsets[current_row - 1];

            /// Each row gets the current size of the index as its key
            auto rc = index->add(static_cast<uint32_t>(index->size()), &column_array_data_float_data[row_begin]);
            if (!rc)
                throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, rc.error.release());

            ProfileEvents::increment(ProfileEvents::USearchAddCount);
            ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, rc.visited_members);
            ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, rc.computed_distances);
        }
    }
    else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
    {
        const auto & column_tuple_columns = column_tuple->getColumns();

        /// Transpose the tuple element columns into one vector per row
        std::vector<std::vector<Float32>> data(column_tuple->size(), std::vector<Float32>());
        for (const auto & column : column_tuple_columns)
        {
            /// NOTE(review): typeid_cast to a pointer is dereferenced without a null check - assumes every tuple element is Float32.
            const auto & pod_array = typeid_cast<const ColumnFloat32 *>(column.get())->getData();
            for (size_t i = 0; i < pod_array.size(); ++i)
                data[i].push_back(pod_array[i]);
        }

        if (data.empty())
            throw Exception(ErrorCodes::LOGICAL_ERROR, "Tuple has 0 rows, {} rows expected", rows_read);

        if (!index)
            index = std::make_shared<USearchIndexWithSerialization<Metric>>(data[0].size(), scalar_kind);

        if (!index->reserve(unum::usearch::ceil2(index->size() + data.size())))
            throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for usearch index");

        for (const auto & item : data)
        {
            auto rc = index->add(static_cast<uint32_t>(index->size()), item.data());
            if (!rc)
                throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, rc.error.release());

            ProfileEvents::increment(ProfileEvents::USearchAddCount);
            ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, rc.visited_members);
            ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, rc.computed_distances);
        }
    }
    else
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");

    *pos += rows_read;
}
/// Builds the ANN condition from the query and remembers the distance function of the index.
/// The index description is not used.
MergeTreeIndexConditionUSearch::MergeTreeIndexConditionUSearch(
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
const String & distance_function_,
ContextPtr context)
: ann_condition(query, context)
, distance_function(distance_function_)
{
}
/// Not supported for ANN skip indexes: always throws LOGICAL_ERROR.
bool MergeTreeIndexConditionUSearch::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /*idx_granule*/) const
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "mayBeTrueOnGranule is not supported for ANN skip indexes");
}
/// Delegates to the ANN condition, passing the distance function of this index.
bool MergeTreeIndexConditionUSearch::alwaysUnknownOrTrue() const
{
return ann_condition.alwaysUnknownOrTrue(distance_function);
}
/// Dispatches to the implementation that matches the distance function of the index
/// (L2 -> l2sq_k, cosine -> cos_k). Any other value is treated as unreachable.
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
    using unum::usearch::metric_kind_t;

    if (distance_function == DISTANCE_FUNCTION_L2)
        return getUsefulRangesImpl<metric_kind_t::l2sq_k>(idx_granule);

    if (distance_function == DISTANCE_FUNCTION_COSINE)
        return getUsefulRangesImpl<metric_kind_t::cos_k>(idx_granule);

    std::unreachable();
}
/// Searches the USearch index of the given granule for the nearest neighbors of the query's reference vector and
/// returns the sorted, de-duplicated numbers of the ranges (neighbor id / index granularity) that contain them.
/// For WHERE-type queries, neighbors farther away than the comparison distance are dropped.
///
/// Throws LOGICAL_ERROR if the comparison distance is negative or the granule has a different type, and
/// INCORRECT_QUERY if the dimension of the reference vector differs from the dimension of the index.
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> MergeTreeIndexConditionUSearch::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
    const UInt64 limit = ann_condition.getLimit();
    const UInt64 index_granularity = ann_condition.getIndexGranularity();

    /// Only WHERE-type queries carry a maximum distance
    const std::optional<float> comparison_distance = ann_condition.getQueryType() == ApproximateNearestNeighborInformation::Type::Where
        ? std::optional<float>(ann_condition.getComparisonDistanceForWhereQuery())
        : std::nullopt;

    if (comparison_distance && comparison_distance.value() < 0)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");

    const std::vector<float> reference_vector = ann_condition.getReferenceVector();

    const auto usearch_granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleUSearch<Metric>>(idx_granule);
    if (usearch_granule == nullptr)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");

    const USearchIndexWithSerializationPtr<Metric> usearch_index = usearch_granule->index;

    if (ann_condition.getDimensions() != usearch_index->dimensions())
        throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
            "does not match the dimension in the index ({})",
            ann_condition.getDimensions(), usearch_index->dimensions());

    auto result = usearch_index->search(reference_vector.data(), limit);

    ProfileEvents::increment(ProfileEvents::USearchSearchCount);
    ProfileEvents::increment(ProfileEvents::USearchSearchVisitedMembers, result.visited_members);
    ProfileEvents::increment(ProfileEvents::USearchSearchComputedDistances, result.computed_distances);

    /// Ids of the items closest to the reference vector, and their distances
    std::vector<UInt32> neighbors(result.size());
    std::vector<Float32> distances(result.size());
    result.dump_to(neighbors.data(), distances.data());

    std::vector<size_t> useful_ranges;
    useful_ranges.reserve(neighbors.size());
    for (size_t i = 0; i < neighbors.size(); ++i)
    {
        const bool too_far = comparison_distance && distances[i] > *comparison_distance;
        if (!too_far)
            useful_ranges.push_back(neighbors[i] / index_granularity);
    }

    /// Several neighbors may fall into the same range: sort and remove duplicates
    std::sort(useful_ranges.begin(), useful_ranges.end());
    const auto duplicates_begin = std::unique(useful_ranges.begin(), useful_ranges.end());
    useful_ranges.erase(duplicates_begin, useful_ranges.end());

    return useful_ranges;
}
/// Stores the index description (in the base class), the name of the distance function and the scalar kind.
MergeTreeIndexUSearch::MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_)
: IMergeTreeIndex(index_)
, distance_function(distance_function_)
, scalar_kind(scalar_kind_)
{
}
/// Creates a granule without an index, typed by the distance function of this index
/// (L2 -> l2sq_k, cosine -> cos_k). Any other value is treated as unreachable.
MergeTreeIndexGranulePtr MergeTreeIndexUSearch::createIndexGranule() const
{
    using unum::usearch::metric_kind_t;

    if (distance_function == DISTANCE_FUNCTION_L2)
        return std::make_shared<MergeTreeIndexGranuleUSearch<metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);

    if (distance_function == DISTANCE_FUNCTION_COSINE)
        return std::make_shared<MergeTreeIndexGranuleUSearch<metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);

    std::unreachable();
}
/// Creates an aggregator typed by the distance function of this index (L2 -> l2sq_k, cosine -> cos_k).
/// The writer settings are not used. Any other distance function is treated as unreachable.
MergeTreeIndexAggregatorPtr MergeTreeIndexUSearch::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
{
    using unum::usearch::metric_kind_t;

    if (distance_function == DISTANCE_FUNCTION_L2)
        return std::make_shared<MergeTreeIndexAggregatorUSearch<metric_kind_t::l2sq_k>>(index.name, index.sample_block, scalar_kind);

    if (distance_function == DISTANCE_FUNCTION_COSINE)
        return std::make_shared<MergeTreeIndexAggregatorUSearch<metric_kind_t::cos_k>>(index.name, index.sample_block, scalar_kind);

    std::unreachable();
}
/// Creates the ANN condition for this index from the query, forwarding the index description, the distance
/// function and the query context.
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
    return std::make_shared<MergeTreeIndexConditionUSearch>(index, query, distance_function, context);
}
/// Creating the condition from an ActionsDAG is not implemented: always throws NOT_IMPLEMENTED.
MergeTreeIndexConditionPtr MergeTreeIndexUSearch::createIndexCondition(const ActionsDAG *, ContextPtr) const
{
    /// The message names this class; it previously named MergeTreeIndexAnnoy (copy-paste leftover).
    throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeTreeIndexUSearch cannot be created with ActionsDAG");
}
/// Creates a USearch index object from the index description.
/// Argument 0 (optional) is the distance function name, default DISTANCE_FUNCTION_L2.
/// Argument 1 (optional) is the scalar kind name, looked up in nameToScalarKind, default f16.
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription & index)
{
    static constexpr auto default_distance_function = DISTANCE_FUNCTION_L2;
    static constexpr auto default_scalar_kind = unum::usearch::scalar_kind_t::f16_k;

    const auto & arguments = index.arguments;

    String distance_function = default_distance_function;
    if (arguments.size() >= 1)
        distance_function = arguments[0].safeGet<String>();

    auto scalar_kind = default_scalar_kind;
    if (arguments.size() >= 2)
    {
        /// at() throws for names that are not in the table
        const auto & scalar_kind_name = arguments[1].safeGet<String>();
        scalar_kind = nameToScalarKind.at(scalar_kind_name);
    }

    return std::make_shared<MergeTreeIndexUSearch>(index, distance_function, scalar_kind);
}
/// Validates the definition of a USearch index and throws if it is not acceptable.
/// Checks, in order: number and types of the arguments, that exactly one column is indexed,
/// that the distance function and scalar kind are known, and that the indexed column is
/// Array(Float32) or a Tuple consisting only of Float32 elements.
void usearchIndexValidator(const IndexDescription & index, bool /* attach */)
{
    /// Check number and type of USearch index arguments:
    /// Up to two arguments are accepted (distance function, scalar kind).
    if (index.arguments.size() > 2)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "USearch index must not have more than two parameters");

    if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::String)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "First argument of USearch index (distance function) must be of type String");

    if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Second argument of USearch index (scalar type) must be of type String");

    /// Check that the index is created on a single column
    if (index.column_names.size() != 1 || index.data_types.size() != 1)
        throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "USearch indexes must be created on a single column");

    /// Check that a supported metric was passed as first argument
    if (!index.arguments.empty())
    {
        String distance_name = index.arguments[0].safeGet<String>();
        if (distance_name != DISTANCE_FUNCTION_L2 && distance_name != DISTANCE_FUNCTION_COSINE)
            throw Exception(ErrorCodes::INCORRECT_DATA, "USearch index only supports distance functions '{}' and '{}'", DISTANCE_FUNCTION_L2, DISTANCE_FUNCTION_COSINE);
    }

    /// Check that a supported kind was passed as a second argument
    if (index.arguments.size() > 1 && !nameToScalarKind.contains(index.arguments[1].safeGet<String>()))
    {
        /// Build a comma-separated list of all known kinds for the error message.
        String supported_kinds;
        for (const auto & [name, kind] : nameToScalarKind)
        {
            if (!supported_kinds.empty())
                supported_kinds += ", ";
            supported_kinds += name;
        }
        throw Exception(ErrorCodes::INCORRECT_DATA, "Unrecognized scalar kind (second argument) for USearch index. Supported kinds are: {}", supported_kinds);
    }

    /// Check data type of indexed column:
    auto throw_unsupported_underlying_column_exception = []()
    {
        throw Exception(
            ErrorCodes::ILLEGAL_COLUMN,
            "USearch can only be created on columns of type Array(Float32) and Tuple(Float32[, Float32[, ...]])");
    };

    DataTypePtr data_type = index.sample_block.getDataTypes()[0];
    if (const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get()))
    {
        TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
        if (!WhichDataType(nested_type_index).isFloat32())
            throw_unsupported_underlying_column_exception();
    }
    else if (const auto * data_type_tuple = typeid_cast<const DataTypeTuple *>(data_type.get()))
    {
        /// Every tuple element must be Float32.
        const DataTypes & inner_types = data_type_tuple->getElements();
        for (const auto & inner_type : inner_types)
        {
            TypeIndex nested_type_index = inner_type->getTypeId();
            if (!WhichDataType(nested_type_index).isFloat32())
                throw_unsupported_underlying_column_exception();
        }
    }
    else
        throw_unsupported_underlying_column_exception();
}
}
#endif

View File

@ -1,116 +0,0 @@
#pragma once
#ifdef ENABLE_USEARCH
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#include <usearch/index_dense.hpp>
#pragma clang diagnostic pop
namespace DB
{
using USearchImplType = unum::usearch::index_dense_gt</* key_at */ uint32_t, /* compressed_slot_at */ uint32_t>;
/// USearch dense index (USearchImplType) extended with methods to write itself to a WriteBuffer
/// and read itself from a ReadBuffer. The distance metric is a compile-time template parameter.
template <unum::usearch::metric_kind_t Metric>
class USearchIndexWithSerialization : public USearchImplType
{
using Base = USearchImplType;
public:
/// Constructs an index for vectors with the given number of dimensions and scalar kind.
USearchIndexWithSerialization(size_t dimensions, unum::usearch::scalar_kind_t scalar_kind);
/// Writes the index to `ostr`.
void serialize(WriteBuffer & ostr) const;
/// Reads the index from `istr`.
void deserialize(ReadBuffer & istr);
size_t getDimensions() const;
};
template <unum::usearch::metric_kind_t Metric>
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization<Metric>>;
/// Index granule that holds one USearch index for the given metric.
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexGranuleUSearch final : public IMergeTreeIndexGranule
{
/// Constructs a granule without / with an already built index.
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_);
MergeTreeIndexGranuleUSearch(const String & index_name_, const Block & index_sample_block_, unum::usearch::scalar_kind_t scalar_kind_, USearchIndexWithSerializationPtr<Metric> index_);
~MergeTreeIndexGranuleUSearch() override = default;
void serializeBinary(WriteBuffer & ostr) const override;
void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;
/// The granule is empty as long as no index object is attached.
bool empty() const override { return !index.get(); }
const String index_name;
const Block index_sample_block;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
};
/// Aggregator that accumulates rows of a block into a USearch index and hands the result out as a granule.
template <unum::usearch::metric_kind_t Metric>
struct MergeTreeIndexAggregatorUSearch final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorUSearch(const String & index_name_, const Block & index_sample_block, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexAggregatorUSearch() override = default;
/// Empty if no index exists yet or the index contains no vectors.
bool empty() const override { return !index || index->size() == 0; }
MergeTreeIndexGranulePtr getGranuleAndReset() override;
void update(const Block & block, size_t * pos, size_t limit) override;
const String index_name;
const Block index_sample_block;
const unum::usearch::scalar_kind_t scalar_kind;
USearchIndexWithSerializationPtr<Metric> index;
};
/// Index condition for USearch indexes. Holds the nearest-neighbor condition extracted from the query
/// and the name of the distance function of the index.
class MergeTreeIndexConditionUSearch final : public IMergeTreeIndexConditionApproximateNearestNeighbor
{
public:
MergeTreeIndexConditionUSearch(
const IndexDescription & index_description,
const SelectQueryInfo & query,
const String & distance_function,
ContextPtr context);
~MergeTreeIndexConditionUSearch() override = default;
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
private:
/// Metric-specific implementation of getUsefulRanges().
template <unum::usearch::metric_kind_t Metric>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
const ApproximateNearestNeighborCondition ann_condition;
const String distance_function;
};
/// Skip index of type USearch: factory for the granules, aggregators and conditions of this index type.
class MergeTreeIndexUSearch : public IMergeTreeIndex
{
public:
MergeTreeIndexUSearch(const IndexDescription & index_, const String & distance_function_, unum::usearch::scalar_kind_t scalar_kind_);
~MergeTreeIndexUSearch() override = default;
MergeTreeIndexGranulePtr createIndexGranule() const override;
MergeTreeIndexAggregatorPtr createIndexAggregator(const MergeTreeWriterSettings & settings) const override;
/// Note: this overload is not marked `override`.
MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const;
MergeTreeIndexConditionPtr createIndexCondition(const ActionsDAG *, ContextPtr) const override;
bool isVectorSearch() const override { return true; }
private:
const String distance_function;
const unum::usearch::scalar_kind_t scalar_kind;
};
}
#endif

View File

@ -0,0 +1,492 @@
#include <Storages/MergeTree/MergeTreeIndexVectorSimilarity.h>
#if USE_USEARCH
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#include <Columns/ColumnArray.h>
#include <Common/BitHelpers.h>
#include <Common/formatReadable.h>
#include <Common/logger_useful.h>
#include <Common/typeid_cast.h>
#include <Core/Field.h>
#include <DataTypes/DataTypeArray.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>
/// Profile event counters used in this file; they are only declared here.
/// The *Add* events are incremented when vectors are added to an index, the *Search* events when an index is searched.
namespace ProfileEvents
{
extern const Event USearchAddCount;
extern const Event USearchAddVisitedMembers;
extern const Event USearchAddComputedDistances;
extern const Event USearchSearchCount;
extern const Event USearchSearchVisitedMembers;
extern const Event USearchSearchComputedDistances;
}
namespace DB
{
/// Error codes thrown by the vector similarity index code in this file; they are only declared here.
namespace ErrorCodes
{
extern const int CANNOT_ALLOCATE_MEMORY;
extern const int FORMAT_VERSION_TOO_OLD;
extern const int ILLEGAL_COLUMN;
extern const int INCORRECT_DATA;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int INCORRECT_QUERY;
extern const int LOGICAL_ERROR;
extern const int NOT_IMPLEMENTED;
}
namespace
{
/// Lookup tables used to validate and translate index arguments.
/// They are never modified after initialization, hence `const`.

/// The only indexing method currently supported by USearch
const std::set<String> methods = {"hnsw"};

/// Maps from user-facing name to internal name
const std::unordered_map<String, unum::usearch::metric_kind_t> distanceFunctionToMetricKind = {
    {"L2Distance", unum::usearch::metric_kind_t::l2sq_k},
    {"cosineDistance", unum::usearch::metric_kind_t::cos_k}};

/// Maps from user-facing name to internal name
const std::unordered_map<String, unum::usearch::scalar_kind_t> quantizationToScalarKind = {
    {"f32", unum::usearch::scalar_kind_t::f32_k},
    {"f16", unum::usearch::scalar_kind_t::f16_k},
    {"i8", unum::usearch::scalar_kind_t::i8_k}};
/// Satisfied if T is a specialization of std::set.
template<typename T>
concept is_set = std::same_as<T, std::set<typename T::key_type, typename T::key_compare, typename T::allocator_type>>;
/// Satisfied if T is a specialization of std::unordered_map.
template<typename T>
concept is_unordered_map = std::same_as<T, std::unordered_map<typename T::key_type, typename T::mapped_type, typename T::hasher, typename T::key_equal, typename T::allocator_type>>;
template <typename T>
String joinByComma(const T & t)
{
if constexpr (is_set<T>)
{
return fmt::format("{}", fmt::join(t, ", "));
}
else if constexpr (is_unordered_map<T>)
{
String joined_keys;
for (const auto & [k, _] : t)
{
if (!joined_keys.empty())
joined_keys += ", ";
joined_keys += k;
}
return joined_keys;
}
/// TODO once our libcxx is recent enough, replace above by
/// return fmt::format("{}", fmt::join(std::views::keys(t)), ", "));
std::unreachable();
}
}
/// Constructs the underlying USearch index from the vector dimensionality, the metric and scalar kind,
/// and the HNSW parameters (m, ef_construction, ef_search).
USearchIndexWithSerialization::USearchIndexWithSerialization(
    size_t dimensions,
    unum::usearch::metric_kind_t metric_kind,
    unum::usearch::scalar_kind_t scalar_kind,
    UsearchHnswParams usearch_hnsw_params)
    : Base(Base::make(
          unum::usearch::metric_punned_t(dimensions, metric_kind, scalar_kind),
          unum::usearch::index_dense_config_t(
              usearch_hnsw_params.m,
              usearch_hnsw_params.ef_construction,
              usearch_hnsw_params.ef_search)))
{
}
/// Writes the index to `ostr`. Throws INCORRECT_DATA if USearch reports an error.
void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const
{
    /// Passed to save_to_stream(); forwards each chunk (pointer + byte count) to the output buffer.
    auto write_chunk = [&ostr](void * data, size_t size)
    {
        ostr.write(reinterpret_cast<const char *>(data), size);
        return true;
    };

    auto save_result = Base::save_to_stream(write_chunk);
    if (save_result.error)
        throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not save vector similarity index, error: " + String(save_result.error.release()));
}
/// Reads the index from `istr`. Throws INCORRECT_DATA if USearch reports an error.
void USearchIndexWithSerialization::deserialize(ReadBuffer & istr)
{
    /// Passed to load_from_stream(); fills each requested chunk (pointer + byte count) from the input buffer.
    auto read_chunk = [&istr](void * data, size_t size)
    {
        istr.readStrict(reinterpret_cast<char *>(data), size);
        return true;
    };

    auto load_result = Base::load_from_stream(read_chunk);
    if (load_result.error)
    {
        /// See the comment in MergeTreeIndexGranuleVectorSimilarity::deserializeBinary why we throw here
        throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not load vector similarity index, error: " + String(load_result.error.release()) + " Please drop the index and create it again.");
    }
}
/// Collects the current figures of the underlying USearch index into one Statistics struct.
USearchIndexWithSerialization::Statistics USearchIndexWithSerialization::getStatistics() const
{
    return Statistics{
        .max_level = max_level(),
        .connectivity = connectivity(),
        .size = size(), /// number of vectors
        .capacity = capacity(), /// number of vectors reserved
        .memory_usage = memory_usage(), /// in bytes, the value is not exact
        .bytes_per_vector = bytes_per_vector(),
        .scalar_words = scalar_words(),
        .statistics = stats()};
}
/// Constructs a granule without an index; delegates to the full constructor with a null index pointer.
MergeTreeIndexGranuleVectorSimilarity::MergeTreeIndexGranuleVectorSimilarity(
    const String & index_name_,
    const Block & index_sample_block_,
    unum::usearch::metric_kind_t metric_kind_,
    unum::usearch::scalar_kind_t scalar_kind_,
    UsearchHnswParams usearch_hnsw_params_)
    : MergeTreeIndexGranuleVectorSimilarity(
          index_name_,
          index_sample_block_,
          metric_kind_,
          scalar_kind_,
          usearch_hnsw_params_,
          nullptr)
{
}
/// Constructs a granule around an (optionally null) index and remembers the parameters
/// needed to re-create the index on deserialization.
MergeTreeIndexGranuleVectorSimilarity::MergeTreeIndexGranuleVectorSimilarity(
    const String & index_name_,
    const Block & index_sample_block_,
    unum::usearch::metric_kind_t metric_kind_,
    unum::usearch::scalar_kind_t scalar_kind_,
    UsearchHnswParams usearch_hnsw_params_,
    USearchIndexWithSerializationPtr index_)
    : index_name(index_name_)
    , index_sample_block(index_sample_block_)
    , metric_kind(metric_kind_)
    , scalar_kind(scalar_kind_)
    , usearch_hnsw_params(usearch_hnsw_params_)
    , index(std::move(index_))
{
}
/// Writes the granule to `ostr` as: file format version, number of dimensions, serialized USearch index.
/// Throws LOGICAL_ERROR if the granule is empty.
void MergeTreeIndexGranuleVectorSimilarity::serializeBinary(WriteBuffer & ostr) const
{
    if (empty())
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to write empty vector similarity index {}", backQuote(index_name));

    writeIntBinary(FILE_FORMAT_VERSION, ostr);

    /// Number of dimensions is required in the index constructor,
    /// so it must be written and read separately from the other part
    writeIntBinary(static_cast<UInt64>(index->dimensions()), ostr);

    index->serialize(ostr);

    auto statistics = index->getStatistics();
    LOG_TRACE(logger, "Wrote vector similarity index: max_level = {}, connectivity = {}, size = {}, capacity = {}, memory_usage = {}",
              statistics.max_level, statistics.connectivity, statistics.size, statistics.capacity, ReadableSize(statistics.memory_usage));
}
/// Reads the granule from `istr`: file format version, number of dimensions, then the USearch index itself.
/// Throws FORMAT_VERSION_TOO_OLD if the persisted version differs from FILE_FORMAT_VERSION.
void MergeTreeIndexGranuleVectorSimilarity::deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion /*version*/)
{
    UInt64 persisted_version;
    readIntBinary(persisted_version, istr);

    if (persisted_version != FILE_FORMAT_VERSION)
    {
        /// More fancy error handling would be: Set a flag on the index that it failed to load. During usage return all granules, i.e.
        /// behave as if the index does not exist. Since format changes are expected to happen only rarely and it is "only" an index, keep it simple for now.
        throw Exception(
            ErrorCodes::FORMAT_VERSION_TOO_OLD,
            "Vector similarity index could not be loaded because its version is too old (current version: {}, persisted version: {}). Please drop the index and create it again.",
            FILE_FORMAT_VERSION, persisted_version);
    }

    /// The dimension count is stored in front of the index data because the index constructor needs it.
    UInt64 num_dimensions;
    readIntBinary(num_dimensions, istr);

    index = std::make_shared<USearchIndexWithSerialization>(num_dimensions, metric_kind, scalar_kind, usearch_hnsw_params);
    index->deserialize(istr);

    auto index_statistics = index->getStatistics();
    LOG_TRACE(logger, "Loaded vector similarity index: max_level = {}, connectivity = {}, size = {}, capacity = {}, memory_usage = {}",
              index_statistics.max_level, index_statistics.connectivity, index_statistics.size, index_statistics.capacity, ReadableSize(index_statistics.memory_usage));
}
/// Constructs an aggregator without an index; the index is created on the first call of update().
MergeTreeIndexAggregatorVectorSimilarity::MergeTreeIndexAggregatorVectorSimilarity(
    const String & index_name_,
    const Block & index_sample_block_,
    unum::usearch::metric_kind_t metric_kind_,
    unum::usearch::scalar_kind_t scalar_kind_,
    UsearchHnswParams usearch_hnsw_params_)
    : index_name(index_name_)
    , index_sample_block(index_sample_block_)
    , metric_kind(metric_kind_)
    , scalar_kind(scalar_kind_)
    , usearch_hnsw_params(usearch_hnsw_params_)
{
}
/// Hands the accumulated index over to a new granule and leaves the aggregator without an index.
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorVectorSimilarity::getGranuleAndReset()
{
    auto result_granule = std::make_shared<MergeTreeIndexGranuleVectorSimilarity>(
        index_name, index_sample_block, metric_kind, scalar_kind, usearch_hnsw_params, std::move(index));

    /// A moved-from shared_ptr is already null; reset explicitly to make the intent obvious.
    index.reset();
    return result_granule;
}
/// Adds up to `limit` rows of the indexed column of `block`, starting at row `*pos`, to the index
/// (creating the index on first use) and advances `*pos` by the number of rows consumed.
/// Throws if the position is out of range, the column is not an array column, an array is empty,
/// the arrays have differing lengths, or USearch cannot reserve memory / add a vector.
void MergeTreeIndexAggregatorVectorSimilarity::update(const Block & block, size_t * pos, size_t limit)
{
if (*pos >= block.rows())
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
*pos,
block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);
if (rows_read == 0)
return;
/// Row positions are used as UInt32 keys in the index (see the static_cast in the add() call below).
if (rows_read > std::numeric_limits<UInt32>::max())
throw Exception(ErrorCodes::INCORRECT_DATA, "Index granularity is too big: more than {} rows per index granule.", std::numeric_limits<UInt32>::max());
if (index_sample_block.columns() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");
const String & index_column_name = index_sample_block.getByPosition(0).name;
/// Only the rows [*pos, *pos + rows_read) of the indexed column are processed.
ColumnPtr column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);
if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
{
const auto & column_array_data = column_array->getData();
const auto & column_array_data_float = typeid_cast<const ColumnFloat32 &>(column_array_data);
const auto & column_array_data_float_data = column_array_data_float.getData();
const auto & column_array_offsets = column_array->getOffsets();
const size_t num_rows = column_array_offsets.size();
if (column_array->empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Array is unexpectedly empty");
/// The vector similarity algorithm naturally assumes that the indexed vectors have dimension >= 1. This condition is violated if empty arrays
/// are INSERTed into a vector-similarity-indexed column or if no value was specified at all in which case the arrays take on their default
/// values which is also empty.
if (column_array->isDefaultAt(0))
throw Exception(ErrorCodes::INCORRECT_DATA, "The arrays in column '{}' must not be empty. Did you try to INSERT default values?", index_column_name);
/// Check all sizes are the same
/// offsets[0] is used as the length of the first array; every following array must have the same length.
const size_t dimensions = column_array_offsets[0];
for (size_t i = 0; i < num_rows - 1; ++i)
if (column_array_offsets[i + 1] - column_array_offsets[i] != dimensions)
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
/// Also check that previously inserted blocks have the same size as this block.
/// Note that this guarantees consistency of dimension only within parts. We are unable to detect inconsistent dimensions across
/// parts - for this, a little help from the user is needed, e.g. CONSTRAINT cnstr CHECK length(array) = 42.
if (index && index->dimensions() != dimensions)
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column '{}' must have equal length", index_column_name);
if (!index)
index = std::make_shared<USearchIndexWithSerialization>(dimensions, metric_kind, scalar_kind, usearch_hnsw_params);
/// Reserving space is mandatory
if (!index->reserve(roundUpToPowerOfTwoOrZero(index->size() + num_rows)))
throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for vector similarity index");
for (size_t row = 0; row < num_rows; ++row)
{
/// The key of a vector is the index size before insertion, i.e. its insertion position.
/// NOTE(review): for row == 0 this reads column_array_offsets[row - 1]; presumably the offsets container
/// permits index -1 and yields 0 there - confirm against the container implementation.
auto rc = index->add(static_cast<UInt32>(index->size()), &column_array_data_float_data[column_array_offsets[row - 1]]);
if (!rc)
throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not add data to vector similarity index, error: " + String(rc.error.release()));
ProfileEvents::increment(ProfileEvents::USearchAddCount);
ProfileEvents::increment(ProfileEvents::USearchAddVisitedMembers, rc.visited_members);
ProfileEvents::increment(ProfileEvents::USearchAddComputedDistances, rc.computed_distances);
}
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array(Float32) column");
*pos += rows_read;
}
/// Builds the condition from the query and remembers the metric of the index.
/// The index description is not used.
MergeTreeIndexConditionVectorSimilarity::MergeTreeIndexConditionVectorSimilarity(
    const IndexDescription & /*index_description*/,
    const SelectQueryInfo & query,
    unum::usearch::metric_kind_t metric_kind_,
    ContextPtr context)
    : vector_similarity_condition(query, context)
    , metric_kind(metric_kind_)
{
}
/// Not supported for this index type: always throws LOGICAL_ERROR.
bool MergeTreeIndexConditionVectorSimilarity::mayBeTrueOnGranule(MergeTreeIndexGranulePtr) const
{
    throw Exception(
        ErrorCodes::LOGICAL_ERROR,
        "mayBeTrueOnGranule is not supported for ANN skip indexes");
}
bool MergeTreeIndexConditionVectorSimilarity::alwaysUnknownOrTrue() const
{
String index_distance_function;
switch (metric_kind)
{
case unum::usearch::metric_kind_t::l2sq_k: index_distance_function = "L2Distance"; break;
case unum::usearch::metric_kind_t::cos_k: index_distance_function = "cosineDistance"; break;
default: std::unreachable();
}
return vector_similarity_condition.alwaysUnknownOrTrue(index_distance_function);
}
/// Searches the index of the given granule for the `limit` nearest neighbors of the reference vector
/// and returns the sorted, de-duplicated numbers of the granules that contain those neighbors.
std::vector<size_t> MergeTreeIndexConditionVectorSimilarity::getUsefulRanges(MergeTreeIndexGranulePtr granule_) const
{
    const UInt64 limit = vector_similarity_condition.getLimit();
    const UInt64 index_granularity = vector_similarity_condition.getIndexGranularity();

    const auto typed_granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleVectorSimilarity>(granule_);
    if (!typed_granule)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");

    const USearchIndexWithSerializationPtr usearch_index = typed_granule->index;

    /// The query vector and the indexed vectors must have the same dimensionality.
    if (vector_similarity_condition.getDimensions() != usearch_index->dimensions())
        throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
                                                     "does not match the dimension in the index ({})",
                        vector_similarity_condition.getDimensions(), usearch_index->dimensions());

    const std::vector<float> query_vector = vector_similarity_condition.getReferenceVector();

    auto search_result = usearch_index->search(query_vector.data(), limit);
    if (search_result.error)
        throw Exception::createRuntime(ErrorCodes::INCORRECT_DATA, "Could not search in vector similarity index, error: " + String(search_result.error.release()));

    ProfileEvents::increment(ProfileEvents::USearchSearchCount);
    ProfileEvents::increment(ProfileEvents::USearchSearchVisitedMembers, search_result.visited_members);
    ProfileEvents::increment(ProfileEvents::USearchSearchComputedDistances, search_result.computed_distances);

    const size_t num_found = search_result.size();
    std::vector<USearchIndex::key_t> neighbors(num_found); /// indexes of dots which were closest to the reference vector
    std::vector<USearchIndex::distance_t> distances(num_found);
    search_result.dump_to(neighbors.data(), distances.data());

    /// Map every neighbor key to the number of the granule it falls into.
    std::vector<size_t> granules;
    granules.reserve(neighbors.size());
    for (size_t i = 0; i < neighbors.size(); ++i)
        granules.push_back(neighbors[i] / index_granularity);

    /// make unique
    std::sort(granules.begin(), granules.end());
    const auto first_duplicate = std::unique(granules.begin(), granules.end());
    granules.erase(first_duplicate, granules.end());

    return granules;
}
/// Builds a vector similarity skip index from its description and stores the metric, scalar kind
/// and HNSW parameters that are handed to the granules and aggregators created by this index.
MergeTreeIndexVectorSimilarity::MergeTreeIndexVectorSimilarity(
    const IndexDescription & index_,
    unum::usearch::metric_kind_t metric_kind_,
    unum::usearch::scalar_kind_t scalar_kind_,
    UsearchHnswParams usearch_hnsw_params_)
    : IMergeTreeIndex(index_)
    , metric_kind(metric_kind_)
    , scalar_kind(scalar_kind_)
    , usearch_hnsw_params(usearch_hnsw_params_)
{
}
/// Creates a granule without an index, configured with the parameters of this index.
MergeTreeIndexGranulePtr MergeTreeIndexVectorSimilarity::createIndexGranule() const
{
    return std::make_shared<MergeTreeIndexGranuleVectorSimilarity>(
        index.name,
        index.sample_block,
        metric_kind,
        scalar_kind,
        usearch_hnsw_params);
}
/// Creates an aggregator configured with the parameters of this index. The writer settings are not used.
MergeTreeIndexAggregatorPtr MergeTreeIndexVectorSimilarity::createIndexAggregator(const MergeTreeWriterSettings & /*settings*/) const
{
    return std::make_shared<MergeTreeIndexAggregatorVectorSimilarity>(
        index.name,
        index.sample_block,
        metric_kind,
        scalar_kind,
        usearch_hnsw_params);
}
/// Creates the condition object that evaluates the given query against this index,
/// using the metric the index was built with.
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
    return std::make_shared<MergeTreeIndexConditionVectorSimilarity>(index, query, metric_kind, context);
}
/// The ActionsDAG-based overload is not supported by this index type; it always throws NOT_IMPLEMENTED.
MergeTreeIndexConditionPtr MergeTreeIndexVectorSimilarity::createIndexCondition(const ActionsDAG *, ContextPtr) const
{
    throw Exception(ErrorCodes::NOT_IMPLEMENTED, "MergeTreeIndexVectorSimilarity cannot be created with ActionsDAG");
}
/// Creates a vector similarity index from an index description.
/// Argument 1 is the distance function name. With six arguments, arguments 2-5 additionally provide
/// the quantization, M, ef_construction and ef_search; otherwise defaults are used for them.
/// Unknown names make the `.at()` lookups throw.
MergeTreeIndexPtr vectorSimilarityIndexCreator(const IndexDescription & index)
{
    const auto & args = index.arguments;

    const unum::usearch::metric_kind_t metric_kind = distanceFunctionToMetricKind.at(args[1].safeGet<String>());

    /// use defaults for the other parameters
    unum::usearch::scalar_kind_t scalar_kind = unum::usearch::scalar_kind_t::f32_k;
    UsearchHnswParams usearch_hnsw_params;

    if (args.size() == 6)
    {
        scalar_kind = quantizationToScalarKind.at(args[2].safeGet<String>());
        usearch_hnsw_params.m = args[3].safeGet<UInt64>();
        usearch_hnsw_params.ef_construction = args[4].safeGet<UInt64>();
        usearch_hnsw_params.ef_search = args[5].safeGet<UInt64>();
    }

    return std::make_shared<MergeTreeIndexVectorSimilarity>(index, metric_kind, scalar_kind, usearch_hnsw_params);
}
/// Validates the definition of a vector similarity index and throws if it is not acceptable.
/// Checks, in order: argument count (two or six), argument types, argument values,
/// that exactly one column is indexed, and that this column has type Array(Float32).
void vectorSimilarityIndexValidator(const IndexDescription & index, bool /* attach */)
{
    const auto & args = index.arguments;
    const size_t num_args = args.size();
    const bool with_hnsw_params = (num_args == 6);

    /// Check number and type of arguments
    if (num_args != 2 && !with_hnsw_params)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Vector similarity index must have two or six arguments");

    if (args[0].getType() != Field::Types::String)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "First argument of vector similarity index (method) must be of type String");

    if (args[1].getType() != Field::Types::String)
        throw Exception(ErrorCodes::INCORRECT_QUERY, "Second argument of vector similarity index (metric) must be of type String");

    if (with_hnsw_params)
    {
        if (args[2].getType() != Field::Types::String)
            throw Exception(ErrorCodes::INCORRECT_QUERY, "Third argument of vector similarity index (quantization) must be of type String");

        if (args[3].getType() != Field::Types::UInt64)
            throw Exception(ErrorCodes::INCORRECT_QUERY, "Fourth argument of vector similarity index (M) must be of type UInt64");

        if (args[4].getType() != Field::Types::UInt64)
            throw Exception(ErrorCodes::INCORRECT_QUERY, "Fifth argument of vector similarity index (ef_construction) must be of type UInt64");

        if (args[5].getType() != Field::Types::UInt64)
            throw Exception(ErrorCodes::INCORRECT_QUERY, "Sixth argument of vector similarity index (ef_search) must be of type UInt64");
    }

    /// Check that passed arguments are supported
    if (!methods.contains(args[0].safeGet<String>()))
        throw Exception(ErrorCodes::INCORRECT_DATA, "First argument (method) of vector similarity index is not supported. Supported methods are: {}", joinByComma(methods));

    if (!distanceFunctionToMetricKind.contains(args[1].safeGet<String>()))
        throw Exception(ErrorCodes::INCORRECT_DATA, "Second argument (distance function) of vector similarity index is not supported. Supported distance function are: {}", joinByComma(distanceFunctionToMetricKind));

    if (with_hnsw_params)
    {
        if (!quantizationToScalarKind.contains(args[2].safeGet<String>()))
            throw Exception(ErrorCodes::INCORRECT_DATA, "Third argument (quantization) of vector similarity index is not supported. Supported quantizations are: {}", joinByComma(quantizationToScalarKind));

        if (args[3].safeGet<UInt64>() < 2)
            throw Exception(ErrorCodes::INCORRECT_DATA, "Fourth argument (M) of vector similarity index must be > 1");

        if (args[4].safeGet<UInt64>() < 1)
            throw Exception(ErrorCodes::INCORRECT_DATA, "Fifth argument (ef_construction) of vector similarity index must be > 0");

        if (args[5].safeGet<UInt64>() < 1)
            throw Exception(ErrorCodes::INCORRECT_DATA, "Sixth argument (ef_search) of vector similarity index must be > 0");
    }

    /// Check that the index is created on a single column
    if (index.column_names.size() != 1 || index.data_types.size() != 1)
        throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Vector similarity indexes must be created on a single column");

    /// Check data type of the indexed column:
    /// Anything that is not an array, or an array whose nested type is not Float32, is rejected with the same error.
    DataTypePtr data_type = index.sample_block.getDataTypes()[0];
    const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get());
    if (!data_type_array || !WhichDataType(data_type_array->getNestedType()->getTypeId()).isFloat32())
        throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Vector similarity indexes can only be created on columns of type Array(Float32)");
}
}
#endif

View File

@ -0,0 +1,172 @@
#pragma once
#include "config.h"
#if USE_USEARCH
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
# include <Storages/MergeTree/VectorSimilarityCondition.h>
# include <Common/Logger.h>
# include <usearch/index_dense.hpp>
#pragma clang diagnostic pop
namespace DB
{
/// HNSW parameters passed to USearch when an index is constructed.
/// Each field defaults to the corresponding USearch default.
struct UsearchHnswParams
{
size_t m = unum::usearch::default_connectivity();
size_t ef_construction = unum::usearch::default_expansion_add();
size_t ef_search = unum::usearch::default_expansion_search();
};
using USearchIndex = unum::usearch::index_dense_gt</*key_at*/ uint32_t, /*compressed_slot_at*/ uint32_t>;
/// USearch dense index (USearchIndex) extended with methods to write itself to a WriteBuffer,
/// read itself from a ReadBuffer, and report statistics.
class USearchIndexWithSerialization : public USearchIndex
{
using Base = USearchIndex;
public:
/// Constructs an index for vectors with the given number of dimensions, metric, scalar kind and HNSW parameters.
USearchIndexWithSerialization(
size_t dimensions,
unum::usearch::metric_kind_t metric_kind,
unum::usearch::scalar_kind_t scalar_kind,
UsearchHnswParams usearch_hnsw_params);
/// Writes the index to `ostr`.
void serialize(WriteBuffer & ostr) const;
/// Reads the index from `istr`.
void deserialize(ReadBuffer & istr);
/// Snapshot of figures reported by the underlying USearch index.
struct Statistics
{
size_t max_level;
size_t connectivity;
size_t size;
size_t capacity;
size_t memory_usage;
/// advanced stats:
size_t bytes_per_vector;
size_t scalar_words;
Base::stats_t statistics;
};
Statistics getStatistics() const;
};
using USearchIndexWithSerializationPtr = std::shared_ptr<USearchIndexWithSerialization>;
/// Index granule that holds one USearch index together with the parameters it was configured with.
struct MergeTreeIndexGranuleVectorSimilarity final : public IMergeTreeIndexGranule
{
/// Constructs a granule without an index.
MergeTreeIndexGranuleVectorSimilarity(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_,
UsearchHnswParams usearch_hnsw_params_);
/// Constructs a granule around the given index.
MergeTreeIndexGranuleVectorSimilarity(
const String & index_name_,
const Block & index_sample_block_,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_,
UsearchHnswParams usearch_hnsw_params_,
USearchIndexWithSerializationPtr index_);
~MergeTreeIndexGranuleVectorSimilarity() override = default;
void serializeBinary(WriteBuffer & ostr) const override;
void deserializeBinary(ReadBuffer & istr, MergeTreeIndexVersion version) override;
/// Empty if no index exists or the index contains no vectors.
bool empty() const override { return !index || index->size() == 0; }
const String index_name;
const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind;
const UsearchHnswParams usearch_hnsw_params;
USearchIndexWithSerializationPtr index;
LoggerPtr logger = getLogger("VectorSimilarityIndex");
private:
/// The version of the persistence format of USearch index. Increment whenever you change the format.
/// Note: USearch prefixes the serialized data with its own version header. We can't rely on that because 1. the index in ClickHouse
/// is (at least in theory) agnostic of specific vector search libraries, and 2. additional data (e.g. the number of dimensions)
/// outside USearch exists which we should version separately.
static constexpr UInt64 FILE_FORMAT_VERSION = 1;
};
/// Aggregator that accumulates rows of a block into a USearch index and hands the result out as a granule.
struct MergeTreeIndexAggregatorVectorSimilarity final : IMergeTreeIndexAggregator
{
MergeTreeIndexAggregatorVectorSimilarity(
const String & index_name_,
const Block & index_sample_block,
unum::usearch::metric_kind_t metric_kind_,
unum::usearch::scalar_kind_t scalar_kind_,
UsearchHnswParams usearch_hnsw_params_);
~MergeTreeIndexAggregatorVectorSimilarity() override = default;
/// Empty if no index exists yet or the index contains no vectors.
bool empty() const override { return !index || index->size() == 0; }
MergeTreeIndexGranulePtr getGranuleAndReset() override;
void update(const Block & block, size_t * pos, size_t limit) override;
const String index_name;
const Block index_sample_block;
const unum::usearch::metric_kind_t metric_kind;
const unum::usearch::scalar_kind_t scalar_kind;
const UsearchHnswParams usearch_hnsw_params;
USearchIndexWithSerializationPtr index;
};
/// Index condition for vector similarity indexes. Holds the vector similarity condition built from
/// the query and the metric of the index.
class MergeTreeIndexConditionVectorSimilarity final : public IMergeTreeIndexCondition
{
public:
MergeTreeIndexConditionVectorSimilarity(
const IndexDescription & index_description,
const SelectQueryInfo & query,
unum::usearch::metric_kind_t metric_kind_,
ContextPtr context);
~MergeTreeIndexConditionVectorSimilarity() override = default;
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const override;
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr granule) const override;
private:
const VectorSimilarityCondition vector_similarity_condition;
const unum::usearch::metric_kind_t metric_kind;
};
/// The vector similarity index itself (implements IMergeTreeIndex). Acts as a factory for granules, aggregators and conditions
/// which share the USearch configuration stored in this object.
class MergeTreeIndexVectorSimilarity : public IMergeTreeIndex
{
public:
    MergeTreeIndexVectorSimilarity(
        const IndexDescription & index_,
        unum::usearch::metric_kind_t metric_kind_,
        unum::usearch::scalar_kind_t scalar_kind_,
        UsearchHnswParams usearch_hnsw_params_);

    ~MergeTreeIndexVectorSimilarity() override = default;

    MergeTreeIndexGranulePtr createIndexGranule() const override;
    MergeTreeIndexAggregatorPtr createIndexAggregator(const MergeTreeWriterSettings & settings) const override;
    /// Overload which takes the whole SELECT query. It is not declared `override`, i.e. it is specific to this index type.
    MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const;
    /// Generic IMergeTreeIndex entry point which takes a filter DAG.
    /// NOTE(review): behavior for this index type is defined in the .cpp file - confirm there before relying on it.
    MergeTreeIndexConditionPtr createIndexCondition(const ActionsDAG *, ContextPtr) const override;
    /// Marks this index as a vector similarity index for callers which need to treat such indexes specially.
    bool isVectorSimilarityIndex() const override { return true; }

private:
    /// USearch configuration: distance metric, scalar (quantization) kind and HNSW parameters.
    const unum::usearch::metric_kind_t metric_kind;
    const unum::usearch::scalar_kind_t scalar_kind;
    const UsearchHnswParams usearch_hnsw_params;
};
}
#endif

View File

@ -127,15 +127,21 @@ MergeTreeIndexFactory::MergeTreeIndexFactory()
registerCreator("hypothesis", hypothesisIndexCreator);
registerValidator("hypothesis", hypothesisIndexValidator);
#ifdef ENABLE_ANNOY
registerCreator("annoy", annoyIndexCreator);
registerValidator("annoy", annoyIndexValidator);
#endif
#ifdef ENABLE_USEARCH
registerCreator("usearch", usearchIndexCreator);
registerValidator("usearch", usearchIndexValidator);
#if USE_USEARCH
registerCreator("vector_similarity", vectorSimilarityIndexCreator);
registerValidator("vector_similarity", vectorSimilarityIndexValidator);
#endif
/// ------
/// TODO: remove this block at the end of 2024.
/// Index types 'annoy' and 'usearch' are no longer supported as of June 2024. Their successor is index type 'vector_similarity'.
/// To support loading tables with old indexes during a transition period, register dummy indexes which allow load/attaching but
/// throw an exception when the user attempts to use them.
registerCreator("annoy", legacyVectorSimilarityIndexCreator);
registerValidator("annoy", legacyVectorSimilarityIndexValidator);
registerCreator("usearch", legacyVectorSimilarityIndexCreator);
registerValidator("usearch", legacyVectorSimilarityIndexValidator);
/// ------
registerCreator("inverted", fullTextIndexCreator);
registerValidator("inverted", fullTextIndexValidator);

View File

@ -15,6 +15,7 @@
#include <Interpreters/ExpressionActions.h>
#include <DataTypes/DataTypeLowCardinality.h>
#include "config.h"
constexpr auto INDEX_FILE_PREFIX = "skp_idx_";
@ -92,6 +93,13 @@ public:
virtual bool alwaysUnknownOrTrue() const = 0;
virtual bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr granule) const = 0;
/// Special stuff for vector similarity indexes
/// - Returns vector of indexes of ranges in granule which are useful for query.
virtual std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr) const
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Not implemented for non-vector-similarity indexes.");
}
};
using MergeTreeIndexConditionPtr = std::shared_ptr<IMergeTreeIndexCondition>;
@ -169,7 +177,7 @@ struct IMergeTreeIndex
virtual MergeTreeIndexConditionPtr createIndexCondition(
const ActionsDAG * filter_actions_dag, ContextPtr context) const = 0;
virtual bool isVectorSearch() const { return false; }
virtual bool isVectorSimilarityIndex() const { return false; }
virtual MergeTreeIndexMergedConditionPtr createIndexMergedCondition(
const SelectQueryInfo & /*query_info*/, StorageMetadataPtr /*storage_metadata*/) const
@ -230,15 +238,13 @@ void bloomFilterIndexValidator(const IndexDescription & index, bool attach);
MergeTreeIndexPtr hypothesisIndexCreator(const IndexDescription & index);
void hypothesisIndexValidator(const IndexDescription & index, bool attach);
#ifdef ENABLE_ANNOY
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index);
void annoyIndexValidator(const IndexDescription & index, bool attach);
#if USE_USEARCH
MergeTreeIndexPtr vectorSimilarityIndexCreator(const IndexDescription & index);
void vectorSimilarityIndexValidator(const IndexDescription & index, bool attach);
#endif
#ifdef ENABLE_USEARCH
MergeTreeIndexPtr usearchIndexCreator(const IndexDescription& index);
void usearchIndexValidator(const IndexDescription& index, bool attach);
#endif
MergeTreeIndexPtr legacyVectorSimilarityIndexCreator(const IndexDescription & index);
void legacyVectorSimilarityIndexValidator(const IndexDescription & index, bool attach);
MergeTreeIndexPtr fullTextIndexCreator(const IndexDescription & index);
void fullTextIndexValidator(const IndexDescription & index, bool attach);

View File

@ -0,0 +1,350 @@
#include <Storages/MergeTree/VectorSimilarityCondition.h>
#include <Core/Settings.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int INCORRECT_QUERY;
}
namespace
{
template <typename Literal>
void extractReferenceVectorFromLiteral(std::vector<Float32> & reference_vector, Literal literal)
{
Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector;
for (const auto & value : literal.value())
{
if (value.tryGet(float_element_of_reference_vector))
reference_vector.emplace_back(float_element_of_reference_vector);
else if (value.tryGet(int_element_of_reference_vector))
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
else
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
}
}
/// Translates the name of a SQL distance function into the corresponding enum value.
/// Only "L2Distance" is recognized, every other name maps to DistanceFunction::Unknown.
VectorSimilarityCondition::Info::DistanceFunction stringToDistanceFunction(std::string_view distance_function)
{
    using DistanceFunction = VectorSimilarityCondition::Info::DistanceFunction;

    const bool is_l2 = (distance_function == "L2Distance");
    return is_l2 ? DistanceFunction::L2 : DistanceFunction::Unknown;
}
}
/// Analyzes the SELECT query once, at construction time. The outcome is stored in `index_is_useful` (and, on success, in
/// `query_information`, which is filled by checkQueryStructure()).
///
/// checkQueryStructure() reads `max_limit_for_ann_queries` and, via tryCastToConstType(), `block_with_constants`. Both must
/// therefore be initialized before `index_is_useful`.
/// NOTE(review): members are initialized in declaration order, not in the order of this list - confirm that the class declares
/// `block_with_constants`, `query_information` and `max_limit_for_ann_queries` before `index_is_useful`.
VectorSimilarityCondition::VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context)
    : block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
    , index_granularity(context->getMergeTreeSettings().index_granularity)
    , max_limit_for_ann_queries(context->getSettingsRef().max_limit_for_ann_queries)
    , index_is_useful(checkQueryStructure(query_info))
{}
/// Returns true if the index cannot speed up the query, false if it can.
/// `distance_function` is the name of the distance function the index was built with.
bool VectorSimilarityCondition::alwaysUnknownOrTrue(String distance_function) const
{
    /// The structure of the query is not supported at all
    if (!index_is_useful)
        return true;

    /// The query is supported. The index only helps if it was built for the distance function used in the query.
    /// NOTE(review): two different unrecognized function names both map to Unknown and therefore compare equal here.
    const auto distance_function_of_index = stringToDistanceFunction(distance_function);
    return distance_function_of_index != query_information->distance_function;
}
/// Returns the value of the LIMIT clause of the query.
/// Throws LOGICAL_ERROR if the query is not supported or no query information was extracted.
UInt64 VectorSimilarityCondition::getLimit() const
{
    const bool has_information = index_is_useful && query_information.has_value();
    if (!has_information)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");

    return query_information->limit;
}
std::vector<float> VectorSimilarityCondition::getReferenceVector() const
{
if (index_is_useful && query_information.has_value())
return query_information->reference_vector;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
}
size_t VectorSimilarityCondition::getDimensions() const
{
if (index_is_useful && query_information.has_value())
return query_information->reference_vector.size();
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
}
/// Returns the name of the column which the distance function in the query is applied to.
/// Throws LOGICAL_ERROR if the query is not supported or no query information was extracted.
String VectorSimilarityCondition::getColumnName() const
{
    const bool has_information = index_is_useful && query_information.has_value();
    if (!has_information)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");

    return query_information->column_name;
}
/// Returns the distance function used in the query.
/// Throws LOGICAL_ERROR if the query is not supported or no query information was extracted.
VectorSimilarityCondition::Info::DistanceFunction VectorSimilarityCondition::getDistanceFunction() const
{
    const bool has_information = index_is_useful && query_information.has_value();
    if (!has_information)
        throw Exception(ErrorCodes::LOGICAL_ERROR, "Distance function was requested for useless or uninitialized index.");

    return query_information->distance_function;
}
bool VectorSimilarityCondition::checkQueryStructure(const SelectQueryInfo & query)
{
Info order_by_info;
/// Build rpns for query sections
const auto & select = query.query->as<ASTSelectQuery &>();
RPN rpn_order_by;
RPNElement rpn_limit;
UInt64 limit;
if (select.limitLength())
traverseAtomAST(select.limitLength(), rpn_limit);
if (select.orderBy())
traverseOrderByAST(select.orderBy(), rpn_order_by);
/// Reverse RPNs for conveniences during parsing
std::reverse(rpn_order_by.begin(), rpn_order_by.end());
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by, order_by_info);
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
if (!limit_is_valid || limit > max_limit_for_ann_queries)
return false;
if (order_by_is_valid)
{
query_information = std::move(order_by_info);
query_information->limit = limit;
return true;
}
return false;
}
/// Converts the expression tree below `node` into postfix order: the arguments of a function are appended to `rpn` before the
/// function itself. Nodes which traverseAtomAST() does not recognize are appended as FUNCTION_UNKNOWN.
void VectorSimilarityCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
    /// Only function nodes have children which need to be visited
    if (const auto * func = node->as<ASTFunction>())
        for (const auto & argument : func->arguments->children)
            traverseAST(argument, rpn);

    /// Now the node itself
    RPNElement element;
    const bool is_recognized = traverseAtomAST(node, element);
    if (!is_recognized)
        element.function = RPNElement::FUNCTION_UNKNOWN;

    rpn.emplace_back(std::move(element));
}
/// Classifies a single AST node (without looking at its children) and stores the result in `out`.
/// Recognized are: distance functions, the functions `array` and `_CAST`, identifiers, and constants (see tryCastToConstType()).
/// Returns false if the node is none of these. For an unrecognized function, `out.func_name` is set nevertheless.
bool VectorSimilarityCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
    /// Functions
    if (const auto * function = node->as<ASTFunction>())
    {
        const auto & name = function->name;
        out.func_name = name;

        const bool is_distance_function = name == "L1Distance"
            || name == "L2Distance"
            || name == "LinfDistance"
            || name == "cosineDistance"
            || name == "dotProduct";

        if (is_distance_function)
        {
            out.function = RPNElement::FUNCTION_DISTANCE;
            return true;
        }
        if (name == "array")
        {
            out.function = RPNElement::FUNCTION_ARRAY;
            return true;
        }
        if (name == "_CAST")
        {
            out.function = RPNElement::FUNCTION_CAST;
            return true;
        }
        return false;
    }

    /// Identifiers
    if (const auto * identifier = node->as<ASTIdentifier>())
    {
        out.function = RPNElement::FUNCTION_IDENTIFIER;
        out.identifier.emplace(identifier->name());
        out.func_name = "column identifier";
        return true;
    }

    /// Everything else can only be a constant
    return tryCastToConstType(node, out);
}
/// Tries to evaluate `node` as a constant (using `block_with_constants`) and stores it in `out`.
/// Supported constant types: Float64, UInt64, Int64, Array and String. For strings, the value is stored in `out.func_name`.
/// Returns false if the node is not a constant or the constant has a different type.
bool VectorSimilarityCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
{
    Field const_value;
    DataTypePtr const_type;

    if (!KeyCondition::getConstant(node, block_with_constants, const_value, const_type))
        return false;

    const auto field_type = const_value.getType();

    if (field_type == Field::Types::Float64)
    {
        out.function = RPNElement::FUNCTION_FLOAT_LITERAL;
        out.float_literal.emplace(const_value.safeGet<Float32>());
        out.func_name = "Float literal";
        return true;
    }

    if (field_type == Field::Types::UInt64)
    {
        out.function = RPNElement::FUNCTION_INT_LITERAL;
        out.int_literal.emplace(const_value.safeGet<UInt64>());
        out.func_name = "Int literal";
        return true;
    }

    if (field_type == Field::Types::Int64)
    {
        out.function = RPNElement::FUNCTION_INT_LITERAL;
        out.int_literal.emplace(const_value.safeGet<Int64>());
        out.func_name = "Int literal";
        return true;
    }

    if (field_type == Field::Types::Array)
    {
        out.function = RPNElement::FUNCTION_LITERAL_ARRAY;
        out.array_literal = const_value.safeGet<Array>();
        out.func_name = "Array literal";
        return true;
    }

    if (field_type == Field::Types::String)
    {
        /// E.g. the target type name of a _CAST, see matchRPNOrderBy()
        out.function = RPNElement::FUNCTION_STRING_LITERAL;
        out.func_name = const_value.safeGet<String>();
        return true;
    }

    return false;
}
/// Builds the RPN of the expression of the first ORDER BY element. Further ORDER BY elements are not looked at.
/// Does nothing if `node` is not an expression list or its first child is not an ORDER BY element.
void VectorSimilarityCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
{
    const auto * expr_list = node->as<ASTExpressionList>();
    if (!expr_list)
        return;

    const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>();
    if (!order_by_element)
        return;

    /// The first child of the ORDER BY element is the expression to sort by
    traverseAST(order_by_element->children.front(), rpn);
}
/// Returns true and fills `info` (distance function, column name, reference vector) if the ORDER BY expression has a supported shape.
///
/// `rpn` is the reversed postfix form of the ORDER BY expression, i.e. the outermost function comes first. Accepted is
///     DistanceFunc -> [Column] -> [ArrayFunc] -> ReferenceVector -> [Column]
/// where the reference vector is an array literal, an array literal inside _CAST(..., 'Array(...)'), or the individual float / int
/// arguments of array(...), and where exactly one column identifier must be present.
///
/// Every dereference of `iter` is guarded by a check against `end`: a short or oddly shaped expression (e.g. the distance function
/// applied to a column and an `array()` call without arguments) is rejected instead of reading past the end of `rpn`.
///
/// Throws INCORRECT_QUERY if the reference vector contains elements of an unsupported type.
bool VectorSimilarityCondition::matchRPNOrderBy(RPN & rpn, Info & info)
{
    /// ORDER BY clause must have at least 3 expressions
    if (rpn.size() < 3)
        return false;

    auto iter = rpn.begin();
    const auto end = rpn.end();

    bool identifier_found = false;

    /// Matches DistanceFunc->[Column]->[ArrayFunc]->ReferenceVector(floats)->[Column]
    if (iter->function != RPNElement::FUNCTION_DISTANCE)
        return false;

    info.distance_function = stringToDistanceFunction(iter->func_name);
    ++iter;

    /// Because of the size check above, iter points to a valid element here (2nd element) ...
    if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
    {
        identifier_found = true;
        info.column_name = std::move(iter->identifier.value());
        ++iter;
    }

    /// ... and here (2nd or 3rd element). After this step, it may have reached the end.
    if (iter->function == RPNElement::FUNCTION_ARRAY)
        ++iter;

    if (iter != end && iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
    {
        extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
        ++iter;
    }

    /// further conditions are possible if there is no array, or no identifier is found
    /// the array can be inside a cast function. For other cases, see the loop after this condition
    if (iter != end && iter->function == RPNElement::FUNCTION_CAST)
    {
        ++iter;

        /// Cast should be made to array
        if (iter == end || !iter->func_name.starts_with("Array"))
            return false;
        ++iter;

        /// The casted value must be an array literal
        if (iter == end || iter->function != RPNElement::FUNCTION_LITERAL_ARRAY)
            return false;

        extractReferenceVectorFromLiteral(info.reference_vector, iter->array_literal);
        ++iter;
    }

    /// The remaining elements may only be scalar elements of the reference vector and the (single) column identifier
    while (iter != end)
    {
        if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
            iter->function == RPNElement::FUNCTION_INT_LITERAL)
        {
            info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
        }
        else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
        {
            if (identifier_found)
                return false;
            info.column_name = std::move(iter->identifier.value());
            identifier_found = true;
        }
        else
            return false;

        ++iter;
    }

    /// Final checks of correctness
    return identifier_found && !info.reference_vector.empty();
}
/// Returns true and stores the LIMIT value in `limit` if `rpn` is an integer literal.
/// Otherwise returns false and leaves `limit` untouched.
bool VectorSimilarityCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
{
    if (rpn.function != RPNElement::FUNCTION_INT_LITERAL)
        return false;

    limit = rpn.int_literal.value();
    return true;
}
/// Returns the numeric value of the RPN element `iter` points to as float. The float literal takes precedence over the int literal.
/// Throws INCORRECT_QUERY if the element holds neither.
float VectorSimilarityCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
{
    const auto & element = *iter;

    if (element.float_literal.has_value())
        return element.float_literal.value();

    if (element.int_literal.has_value())
        return static_cast<float>(element.int_literal.value());

    throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
}
}

View File

@ -9,52 +9,9 @@
namespace DB
{
static constexpr auto DISTANCE_FUNCTION_L2 = "L2Distance";
static constexpr auto DISTANCE_FUNCTION_COSINE = "cosineDistance";
/// Approximate Nearest Neighbour queries have a similar structure:
/// - reference vector from which all distances are calculated
/// - metric name (e.g L2Distance, LpDistance, etc.)
/// - name of column with embeddings
/// - type of query
/// - maximum number of returned elements (LIMIT)
///
/// And two optional parameters:
/// - p for LpDistance function
/// - distance to compare with (only for where queries)
///
/// This struct holds all these components.
struct ApproximateNearestNeighborInformation
{
using Embedding = std::vector<float>;
Embedding reference_vector;
enum class Metric : uint8_t
{
Unknown,
L2,
Lp
};
Metric metric;
String column_name;
UInt64 limit;
enum class Type : uint8_t
{
OrderBy,
Where
};
Type type;
float p_for_lp_dist = -1.0;
float distance = -1.0;
};
// Class ANNCondition, is responsible for recognizing if the query is an ANN queries which can utilize ANN indexes. It parses the SQL query
/// and checks if it matches ANNIndexes. Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise. It has
/// only one argument, the name of the metric with which index was built. Two main patterns of queries are supported
/// Class VectorSimilarityCondition is responsible for recognizing if the query can utilize vector similarity indexes.
/// Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise. It has
/// only one argument, the name of the distance function with which index was built. Two main patterns of queries are supported
///
/// - 1. WHERE queries:
/// SELECT * FROM * WHERE DistanceFunc(column, reference_vector) < floatLiteral LIMIT count
@ -64,14 +21,14 @@ struct ApproximateNearestNeighborInformation
///
/// Queries without LIMIT count are not supported
/// If the query is both of type 1. and 2., then we can't use the index and alwaysUnknownOrTrue returns true.
/// reference_vector should have float coordinates, e.g. (0.2, 0.1, .., 0.5)
/// reference_vector should have float coordinates, e.g. [0.2, 0.1, .., 0.5]
///
/// If the query matches one of these two types, then this class extracts the main information needed for ANN indexes from the query.
/// If the query matches one of these two types, then this class extracts the main information needed for vector similarity indexes from the
/// query.
///
/// From matching query it extracts
/// - referenceVector
/// - metricName(DistanceFunction)
/// - dimension size if query uses LpDistance
/// - distance function
/// - distance to compare(ONLY for search types, otherwise you get exception)
/// - spaceDimension(which is referenceVector's components count)
/// - column
@ -79,35 +36,45 @@ struct ApproximateNearestNeighborInformation
/// - queryHasOrderByClause and queryHasWhereClause return true if query matches the type
///
/// Search query type is also recognized for PREWHERE clause
class ApproximateNearestNeighborCondition
class VectorSimilarityCondition
{
public:
ApproximateNearestNeighborCondition(const SelectQueryInfo & query_info, ContextPtr context);
VectorSimilarityCondition(const SelectQueryInfo & query_info, ContextPtr context);
/// Approximate nearest neighbour (ANN) / vector similarity queries have a similar structure:
/// - reference vector from which all distances are calculated
/// - distance function, e.g L2Distance
/// - name of column with embeddings
/// - type of query
/// - maximum number of returned elements (LIMIT)
///
/// And one optional parameter:
/// - distance to compare with (only for where queries)
///
/// This struct holds all these components.
struct Info
{
enum class DistanceFunction : uint8_t
{
Unknown,
L2
};
std::vector<Float32> reference_vector;
DistanceFunction distance_function;
String column_name;
UInt64 limit;
float distance = -1.0;
};
/// Returns false if query can be sped up by an ANN index, true otherwise.
bool alwaysUnknownOrTrue(String metric) const;
bool alwaysUnknownOrTrue(String distance_function) const;
/// Returns the distance to compare with for search query
float getComparisonDistanceForWhereQuery() const;
/// Distance should be calculated regarding to referenceVector
std::vector<float> getReferenceVector() const;
/// Reference vector's dimension count
size_t getDimensions() const;
String getColumnName() const;
ApproximateNearestNeighborInformation::Metric getMetricType() const;
/// The P- value if the metric is 'LpDistance'
float getPValueForLpDistance() const;
ApproximateNearestNeighborInformation::Type getQueryType() const;
Info::DistanceFunction getDistanceFunction() const;
UInt64 getIndexGranularity() const { return index_granularity; }
/// Length's value from LIMIT clause
UInt64 getLimit() const;
private:
@ -118,9 +85,6 @@ private:
/// DistanceFunctions
FUNCTION_DISTANCE,
//tuple(0.1, ..., 0.1)
FUNCTION_TUPLE,
//array(0.1, ..., 0.1)
FUNCTION_ARRAY,
@ -139,9 +103,6 @@ private:
/// Unknown, can be any value
FUNCTION_UNKNOWN,
/// (0.1, ...., 0.1) vector without word 'tuple'
FUNCTION_LITERAL_TUPLE,
/// [0.1, ...., 0.1] vector without word 'array'
FUNCTION_LITERAL_ARRAY,
@ -154,19 +115,14 @@ private:
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
: function(function_)
, func_name("Unknown")
, float_literal(std::nullopt)
, identifier(std::nullopt)
{}
Function function;
String func_name;
String func_name = "Unknown";
std::optional<float> float_literal;
std::optional<String> identifier;
std::optional<int64_t> int_literal;
std::optional<Tuple> tuple_literal;
std::optional<Array> array_literal;
UInt32 dim = 0;
@ -186,16 +142,16 @@ private:
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
/// Returns true and stores ANNExpr if the query has valid WHERE section
static bool matchRPNWhere(RPN & rpn, ApproximateNearestNeighborInformation & ann_info);
static bool matchRPNWhere(RPN & rpn, Info & info);
/// Returns true and stores ANNExpr if the query has valid ORDERBY section
static bool matchRPNOrderBy(RPN & rpn, ApproximateNearestNeighborInformation & ann_info);
static bool matchRPNOrderBy(RPN & rpn, Info & info);
/// Returns true and stores Length if we have valid LIMIT clause in query
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
/* Matches dist function, reference vector, column name */
static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info);
/// Matches dist function, reference vector, column name
static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, Info & info);
/// Gets float or int from AST node
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
@ -203,7 +159,7 @@ private:
Block block_with_constants;
/// true if we have one of two supported query types
std::optional<ApproximateNearestNeighborInformation> query_information;
std::optional<Info> query_information;
// Get from settings ANNIndex parameters
const UInt64 index_granularity;
@ -214,13 +170,4 @@ private:
bool index_is_useful = false;
};
/// Common interface of ANN indexes.
class IMergeTreeIndexConditionApproximateNearestNeighbor : public IMergeTreeIndexCondition
{
public:
/// Returns vector of indexes of ranges in granule which are useful for query.
virtual std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const = 0;
};
}

View File

@ -4,4 +4,4 @@ clickhouse_add_executable (mergetree_checksum_fuzzer mergetree_checksum_fuzzer.c
target_link_libraries (mergetree_checksum_fuzzer PRIVATE dbms clickhouse_functions)
clickhouse_add_executable (columns_description_fuzzer columns_description_fuzzer.cpp)
target_link_libraries (columns_description_fuzzer PRIVATE dbms clickhouse_functions)
target_link_libraries (columns_description_fuzzer PRIVATE clickhouse_functions)

View File

@ -1,4 +1,5 @@
#include <Storages/ColumnsDescription.h>
#include <iostream>
#include <iostream>

View File

@ -164,6 +164,9 @@ endif()
if (TARGET ch_contrib::bcrypt)
set(USE_BCRYPT 1)
endif()
if (TARGET ch_contrib::usearch)
set(USE_USEARCH 1)
endif()
if (TARGET ch_contrib::ssh)
set(USE_SSH 1)
endif()

View File

@ -75,7 +75,7 @@ def get_run_command(
f"--volume={result_path}:/test_output "
"--security-opt seccomp=unconfined " # required to issue io_uring sys-calls
f"--cap-add=SYS_PTRACE {env_str} {additional_options_str} {image} "
"python3 ./utils/runner.py"
"python3 /usr/share/clickhouse-test/fuzz/runner.py"
)

View File

@ -11,7 +11,7 @@ FUZZER_ARGS = os.getenv("FUZZER_ARGS", "")
def run_fuzzer(fuzzer: str):
logging.info(f"Running fuzzer {fuzzer}...")
logging.info("Running fuzzer %s...", fuzzer)
corpus_dir = f"{fuzzer}.in"
with Path(corpus_dir) as path:
@ -29,28 +29,27 @@ def run_fuzzer(fuzzer: str):
if parser.has_section("asan"):
os.environ["ASAN_OPTIONS"] = (
f"{os.environ['ASAN_OPTIONS']}:{':'.join('%s=%s' % (key, value) for key, value in parser['asan'].items())}"
f"{os.environ['ASAN_OPTIONS']}:{':'.join(f'{key}={value}' for key, value in parser['asan'].items())}"
)
if parser.has_section("msan"):
os.environ["MSAN_OPTIONS"] = (
f"{os.environ['MSAN_OPTIONS']}:{':'.join('%s=%s' % (key, value) for key, value in parser['msan'].items())}"
f"{os.environ['MSAN_OPTIONS']}:{':'.join(f'{key}={value}' for key, value in parser['msan'].items())}"
)
if parser.has_section("ubsan"):
os.environ["UBSAN_OPTIONS"] = (
f"{os.environ['UBSAN_OPTIONS']}:{':'.join('%s=%s' % (key, value) for key, value in parser['ubsan'].items())}"
f"{os.environ['UBSAN_OPTIONS']}:{':'.join(f'{key}={value}' for key, value in parser['ubsan'].items())}"
)
if parser.has_section("libfuzzer"):
custom_libfuzzer_options = " ".join(
"-%s=%s" % (key, value)
for key, value in parser["libfuzzer"].items()
f"-{key}={value}" for key, value in parser["libfuzzer"].items()
)
if parser.has_section("fuzzer_arguments"):
fuzzer_arguments = " ".join(
("%s" % key) if value == "" else ("%s=%s" % (key, value))
(f"{key}") if value == "" else (f"{key}={value}")
for key, value in parser["fuzzer_arguments"].items()
)
@ -65,7 +64,7 @@ def run_fuzzer(fuzzer: str):
cmd_line += " < /dev/null"
logging.info(f"...will execute: {cmd_line}")
logging.info("...will execute: %s", cmd_line)
subprocess.check_call(cmd_line, shell=True)

View File

@ -0,0 +1,4 @@
[fuzzer_arguments]
--log-file=tcp_protocol_fuzzer.log
--=
--logging.terminal=0

View File

@ -0,0 +1,4 @@
<clickhouse>
<keep_alive_timeout>3600</keep_alive_timeout>
<max_keep_alive_requests>5</max_keep_alive_requests>
</clickhouse>

View File

@ -0,0 +1,55 @@
import logging
import pytest
import random
import requests
from helpers.cluster import ClickHouseCluster
cluster = ClickHouseCluster(__file__)
node = cluster.add_instance("node", main_configs=["configs/keep_alive_settings.xml"])
# Module-scoped fixture: the cluster is started once for all tests of this module.
@pytest.fixture(scope="module")
def start_cluster():
    try:
        logging.info("Starting cluster...")
        cluster.start()
        logging.info("Cluster started")

        yield cluster
    finally:
        # Runs after the last test of the module, and also if cluster.start() raised.
        cluster.shutdown()
def test_max_keep_alive_requests_on_user_side(start_cluster):
    """Check that the server closes a keep-alive connection after `max_keep_alive_requests` requests.

    Ten requests are sent through one `requests.Session`, each tagged with a unique `log_comment`.
    The client source port of every request is then read back from `system.query_log`.
    """
    # In this test we have `keep_alive_timeout` set to one hour to never trigger connection reset by timeout, `max_keep_alive_requests` is set to 5.
    # We expect server to close connection after each 5 requests. We detect connection reset by change in src port.
    # So the first 5 requests should come from the same port, the following 5 requests should come from another port.

    request_count = 10

    # random.sample() draws without replacement, so the ids (and thereby the log comments) are guaranteed to be distinct.
    # The comments are sorted so that `ORDER BY log_comment` below returns the ports in the order the requests were sent.
    log_comments = sorted(
        f"test_requests_with_keep_alive_{rand_id}"
        for rand_id in random.sample(range(1000001), request_count)
    )

    session = requests.Session()
    for log_comment in log_comments:
        session.get(
            f"http://{node.ip_address}:8123/?query=select%201&log_comment={log_comment}"
        )

    ports = node.query(
        f"""
        SYSTEM FLUSH LOGS;

        SELECT port
        FROM system.query_log
        WHERE log_comment IN ({", ".join(f"'{comment}'" for comment in log_comments)}) AND type = 'QueryFinish'
        ORDER BY log_comment
        """
    ).split("\n")[:-1]

    # Fail with a readable assertion (instead of an IndexError below) if some request was not logged.
    assert len(ports) == request_count

    expected = 5 * [ports[0]] + [ports[5]] * 5

    assert ports == expected

    # Without this check the test would also pass if the connection was never reset (all ten ports equal).
    assert ports[0] != ports[5]

View File

@ -1,6 +1,6 @@
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?
< Connection: Keep-Alive
< Keep-Alive: timeout=10
< Keep-Alive: timeout=10, max=?

View File

@ -6,9 +6,10 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
URL="${CLICKHOUSE_PORT_HTTP_PROTO}://${CLICKHOUSE_HOST}:${CLICKHOUSE_PORT_HTTP}/"
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< "SELECT 1" 2>&1 | perl -lnE 'print if /Keep-Alive/';
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< " error here " 2>&1 | perl -lnE 'print if /Keep-Alive/';
${CLICKHOUSE_CURL} -vsS "${URL}"ping 2>&1 | perl -lnE 'print if /Keep-Alive/';
# the sed command here replaces the real number of left requests with a question mark, because it can vary and we don't really have control over it
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< "SELECT 1" 2>&1 | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
${CLICKHOUSE_CURL} -vsS "${URL}" --data-binary @- <<< " error here " 2>&1 | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
${CLICKHOUSE_CURL} -vsS "${URL}"ping 2>&1 | perl -lnE 'print if /Keep-Alive/' | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -i 'keep-alive';
# no keep-alive:
${CLICKHOUSE_CURL} -vsS "${URL}"404/not/found/ 2>&1 | perl -lnE 'print if /Keep-Alive/';

View File

@ -2,11 +2,11 @@ HTTP/1.1 200 OK
Connection: Keep-Alive
Content-Type: text/tab-separated-values; charset=UTF-8
Transfer-Encoding: chunked
Keep-Alive: timeout=10
Keep-Alive: timeout=10, max=?
HTTP/1.1 200 OK
Connection: Keep-Alive
Content-Type: text/tab-separated-values; charset=UTF-8
Transfer-Encoding: chunked
Keep-Alive: timeout=10
Keep-Alive: timeout=10, max=?

View File

@ -4,8 +4,9 @@ CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
# shellcheck source=../shell_config.sh
. "$CURDIR"/../shell_config.sh
( ${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=SELECT%201";
${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=select+*+from+system.numbers+limit+1000000" ) | grep -v "Date:" | grep -v "X-ClickHouse-Server-Display-Name:" | grep -v "X-ClickHouse-Query-Id:" | grep -v "X-ClickHouse-Format:" | grep -v "X-ClickHouse-Timezone:"
# the sed command here replaces the real number of left requests with a question mark, because it can vary and we don't really have control over it
( ${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=SELECT%201" | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I';
${CLICKHOUSE_CURL} -s --head "${CLICKHOUSE_URL}&query=select+*+from+system.numbers+limit+1000000" ) | sed -r 's/(keep-alive: timeout=10, max=)[0-9]+/\1?/I' | grep -v "Date:" | grep -v "X-ClickHouse-Server-Display-Name:" | grep -v "X-ClickHouse-Query-Id:" | grep -v "X-ClickHouse-Format:" | grep -v "X-ClickHouse-Timezone:"
if [[ $(${CLICKHOUSE_CURL} -sS -X POST -I "${CLICKHOUSE_URL}&query=SELECT+1" | grep -c '411 Length Required') -ne 1 ]]; then
echo FAIL

View File

@ -1,17 +1,5 @@
Issue #52258: Empty Arrays or Arrays with default values are rejected
- Annoy
- Usearch
It is possible to create parts with different Array vector sizes but there will be an error at query time
- Annoy
- Usearch
Correctness of index with > 1 mark
- Annoy
1 [1,0] 0
9000 [9000,0] 0
1 (1,0) 0
9000 (9000,0) 0
- Usearch
1 [1,0] 0
9000 [9000,0] 0
1 (1,0) 0
9000 (9000,0) 0

Some files were not shown because too many files have changed in this diff Show More