Merge branch 'master' into avro-negative-decimals

This commit is contained in:
Kruglov Pavel 2023-06-08 12:51:36 +02:00 committed by GitHub
commit ea8ba2287b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
155 changed files with 1977 additions and 1482 deletions

View File

@ -12,6 +12,7 @@ add_library (_lz4 ${SRCS})
add_library (ch_contrib::lz4 ALIAS _lz4)
target_compile_definitions (_lz4 PUBLIC LZ4_DISABLE_DEPRECATE_WARNINGS=1)
target_compile_definitions (_lz4 PUBLIC LZ4_FAST_DEC_LOOP=1)
if (SANITIZE STREQUAL "undefined")
target_compile_options (_lz4 PRIVATE -fno-sanitize=undefined)
endif ()

View File

@ -1,147 +1,156 @@
# Approximate Nearest Neighbor Search Indexes [experimental] {#table_engines-ANNIndex}
The main task that indexes achieve is to quickly find nearest neighbors for multidimensional data. An example of such a problem can be finding similar pictures (texts) for a given picture (text). That problem can be reduced to finding the nearest [embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning). They can be created from data using [UDF](/docs/en/sql-reference/functions/index.md/#executable-user-defined-functions).
Nearest neighborhood search refers to the problem of finding the point(s) with the smallest distance to a given point in an n-dimensional
space. Since exact search is in practice usually typically too slow, the task is often solved with approximate algorithms. A popular use
case of of neighbor search is finding similar pictures (texts) for a given picture (text). Pictures (texts) can be decomposed into
[embeddings](https://cloud.google.com/architecture/overview-extracting-and-serving-feature-embeddings-for-machine-learning), and instead of
comparing pictures (texts) pixel-by-pixel (character-by-character), only the embeddings are compared.
In terms of SQL, the problem can be expressed as follows:
The next queries find the closest neighbors in N-dimensional space using the L2 (Euclidean) distance:
``` sql
SELECT *
FROM table_name
WHERE L2Distance(Column, Point) < MaxDistance
FROM table
WHERE L2Distance(column, Point) < MaxDistance
LIMIT N
```
``` sql
SELECT *
FROM table_name
ORDER BY L2Distance(Column, Point)
FROM table
ORDER BY L2Distance(column, Point)
LIMIT N
```
But it will take some time for execution because of the long calculation of the distance between `TargetEmbedding` and all other vectors. This is where ANN indexes can help. They store a compact approximation of the search space (e.g. using clustering, search trees, etc.) and are able to compute approximate neighbors quickly.
## Indexes Structure
The queries are expensive because the L2 (Euclidean) distance between `Point` and all points in `column` and must be computed. To speed this process up, Approximate Nearest Neighbor Search Indexes (ANN indexes) store a compact representation of the search space (using clustering, search trees, etc.) which allows to compute an approximate answer quickly.
Approximate Nearest Neighbor Search Indexes (`ANNIndexes`) are similar to skip indexes. They are constructed by some granules and determine which of them should be skipped. Compared to skip indices, ANN indices use their results not only to skip some group of granules, but also to select particular granules from a set of granules.
# Creating ANN Indexes
`ANNIndexes` are designed to speed up two types of queries:
As long as ANN indexes are experimental, you first need to `SET allow_experimental_annoy_index = 1`.
- ###### Type 1: Where
``` sql
SELECT *
FROM table_name
WHERE DistanceFunction(Column, Point) < MaxDistance
LIMIT N
```
- ###### Type 2: Order by
``` sql
SELECT *
FROM table_name [WHERE ...]
ORDER BY DistanceFunction(Column, Point)
LIMIT N
```
In these queries, `DistanceFunction` is selected from [distance functions](/docs/en/sql-reference/functions/distance-functions.md). `Point` is a known vector (something like `(0.1, 0.1, ... )`). To avoid writing large vectors, use [client parameters](/docs/en//interfaces/cli.md#queries-with-parameters-cli-queries-with-parameters). `Value` - a float value that will bound the neighbourhood.
:::note
ANN index can't speed up query that satisfies both types (`where + order by`, only one of them). All queries must have the limit, as algorithms are used to find nearest neighbors and need a specific number of them.
:::
:::note
Indexes are applied only to queries with a limit less than the `max_limit_for_ann_queries` setting. This helps to avoid memory overflows in queries with a large limit. `max_limit_for_ann_queries` setting can be changed if you know you can provide enough memory. The default value is `1000000`.
:::
Both types of queries are handled the same way. The indexes get `n` neighbors (where `n` is taken from the `LIMIT` clause) and work with them. In `ORDER BY` query they remember the numbers of all parts of the granule that have at least one of neighbor. In `WHERE` query they remember only those parts that satisfy the requirements.
## Create table with ANNIndex
This feature is disabled by default. To enable it, set `allow_experimental_annoy_index` to 1. Also, this feature is disabled on ARM, due to likely problems with the algorithm.
Syntax to create an ANN index over an `Array` column:
```sql
CREATE TABLE t
CREATE TABLE table
(
`id` Int64,
`data` Tuple(Float32, Float32, Float32),
INDEX ann_index_name data TYPE ann_index_type(ann_index_parameters) GRANULARITY N
`embedding` Array(Float32),
INDEX <ann_index_name> embedding TYPE <ann_index_type>(<ann_index_parameters>) GRANULARITY <N>
)
ENGINE = MergeTree
ORDER BY id;
```
Syntax to create an ANN index over a `Tuple` column:
```sql
CREATE TABLE t
CREATE TABLE table
(
`id` Int64,
`data` Array(Float32),
INDEX ann_index_name data TYPE ann_index_type(ann_index_parameters) GRANULARITY N
`embedding` Tuple(Float32[, Float32[, ...]]),
INDEX <ann_index_name> embedding TYPE <ann_index_type>(<ann_index_parameters>) GRANULARITY <N>
)
ENGINE = MergeTree
ORDER BY id;
```
With greater `GRANULARITY` indexes remember the data structure better. The `GRANULARITY` indicates how many granules will be used to construct the index. The more data is provided for the index, the more of it can be handled by one index and the more chances that with the right hyper parameters the index will remember the data structure better. But some indexes can't be built if they don't have enough data, so this granule will always participate in the query. For more information, see the description of indexes.
ANN indexes are built during column insertion and merge and `INSERT` and `OPTIMIZE` statements will be slower than for ordinary tables. ANNIndexes are ideally used only with immutable or rarely changed data, respectively there are much more read requests than write requests.
As the indexes are built only during insertions into table, `INSERT` and `OPTIMIZE` queries are slower than for ordinary table. At this stage indexes remember all the information about the given data. ANNIndexes should be used if you have immutable or rarely changed data and many read requests.
Similar to regular skip indexes, ANN indexes are constructed over granules and each indexed block consists of `GRANULARITY = <N>`-many
granules. For example, if the primary index granularity of the table is 8192 (setting `index_granularity = 8192`) and `GRANULARITY = 2`,
then each indexed block will consist of 16384 rows. However, unlike skip indexes, ANN indexes are not only able to skip the entire indexed
block, they are able to skip individual granules in indexed blocks. As a result, the `GRANULARITY` parameter has a different meaning in ANN
indexes than in normal skip indexes. Basically, the bigger `GRANULARITY` is chosen, the more data is provided to a single ANN index, and the
higher the chance that with the right hyper parameters, the index will remember the data structure better.
You can create your table with index which uses certain algorithm. Now only indices based on the following algorithms are supported:
# Using ANN Indexes
ANN indexes support two types of queries:
- WHERE queries:
``` sql
SELECT *
FROM table
WHERE DistanceFunction(column, Point) < MaxDistance
LIMIT N
```
- ORDER BY queries:
``` sql
SELECT *
FROM table
[WHERE ...]
ORDER BY DistanceFunction(column, Point)
LIMIT N
```
`DistanceFunction` is a [distance function](/docs/en/sql-reference/functions/distance-functions.md), `Point` is a reference vector (e.g. `(0.17, 0.33, ...)`) and `MaxDistance` is a floating point value which restricts the size of the neighbourhood.
:::tip
To avoid writing out large vectors, you can use [query parameters](/docs/en//interfaces/cli.md#queries-with-parameters-cli-queries-with-parameters), e.g.
```bash
clickhouse-client --param_vec='hello' --query="SELECT * FROM table WHERE L2Distance(embedding, {vec: Array(Float32)}) < 1.0"
```
:::
ANN indexes cannot speed up queries that contain both a `WHERE DistanceFunction(column, Point) < MaxDistance` and an `ORDER BY DistanceFunction(column, Point)` clause. Also, the approximate algorithms used to determine the nearest neighbors require a limit, hence queries that use an ANN index must have a `LIMIT` clause.
An ANN index is only used if the query has a `LIMIT` value smaller than setting `max_limit_for_ann_queries` (default: 1 million rows). This is a safety measure which helps to avoid large memory consumption by external libraries for approximate neighbor search.
# Available ANN Indexes
# Index list
- [Annoy](/docs/en/engines/table-engines/mergetree-family/annindexes.md#annoy-annoy)
# Annoy {#annoy}
Implementation of the algorithm was taken from [this repository](https://github.com/spotify/annoy).
## Annoy {#annoy}
Short description of the algorithm:
The algorithm recursively divides in half all space by random linear surfaces (lines in 2D, planes in 3D etc.). Thus it makes tree of polyhedrons and points that they contains. Repeating the operation several times for greater accuracy it creates a forest.
To find K Nearest Neighbours it goes down through the trees and fills the buffer of closest points using the priority queue of polyhedrons. Next, it sorts buffer and return the nearest K points.
(currently disabled on ARM due to memory safety problems with the algorithm)
This type of ANN index implements [the Annoy algorithm](https://github.com/spotify/annoy) which uses a recursive division of the space in random linear surfaces (lines in 2D, planes in 3D etc.).
Syntax to create a Annoy index over a `Array` column:
__Examples__:
```sql
CREATE TABLE t
CREATE TABLE table
(
id Int64,
data Tuple(Float32, Float32, Float32),
INDEX ann_index_name data TYPE annoy(NumTrees, DistanceName) GRANULARITY N
embedding Array(Float32),
INDEX <ann_index_name> embedding TYPE annoy([DistanceName[, NumTrees]]) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
```
Syntax to create a Annoy index over a `Tuple` column:
```sql
CREATE TABLE t
CREATE TABLE table
(
id Int64,
data Array(Float32),
INDEX ann_index_name data TYPE annoy(NumTrees, DistanceName) GRANULARITY N
embedding Tuple(Float32[, Float32[, ...]]),
INDEX <ann_index_name> embedding TYPE annoy([DistanceName[, NumTrees]]) GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
```
Parameter `DistanceName` is name of a distance function (default `L2Distance`). Annoy currently supports `L2Distance` and `cosineDistance` as distance functions. Parameter `NumTrees` (default: 100) is the number of trees which the algorithm will create. Higher values of `NumTree` mean slower `CREATE` and `SELECT` statements (approximately linearly), but increase the accuracy of search results.
:::note
Table with array field will work faster, but all arrays **must** have same length. Use [CONSTRAINT](/docs/en/sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(data) = 256`.
Indexes over columns of type `Array` will generally work faster than indexes on `Tuple` columns. All arrays **must** have same length. Use [CONSTRAINT](/docs/en/sql-reference/statements/create/table.md#constraints) to avoid errors. For example, `CONSTRAINT constraint_name_1 CHECK length(embedding) = 256`.
:::
Parameter `NumTrees` is the number of trees which the algorithm will create. The bigger it is, the slower (approximately linear) it works (in both `CREATE` and `SELECT` requests), but the better accuracy you get (adjusted for randomness). By default it is set to `100`. Parameter `DistanceName` is name of distance function. By default it is set to `L2Distance`. It can be set without changing first parameter, for example
```sql
CREATE TABLE t
(
id Int64,
data Array(Float32),
INDEX ann_index_name data TYPE annoy('cosineDistance') GRANULARITY N
)
ENGINE = MergeTree
ORDER BY id;
```
Setting `annoy_index_search_k_nodes` (default: `NumTrees * LIMIT`) determines how many tree nodes are inspected during SELECTs. It can be used to
balance runtime and accuracy at runtime.
Annoy supports `L2Distance` and `cosineDistance`.
Example:
In the `SELECT` in the settings (`ann_index_select_query_params`) you can specify the size of the internal buffer (more details in the description above or in the [original repository](https://github.com/spotify/annoy)). During the query it will inspect up to `search_k` nodes which defaults to `n_trees * n` if not provided. `search_k` gives you a run-time trade-off between better accuracy and speed.
__Example__:
``` sql
SELECT *
FROM table_name [WHERE ...]
ORDER BY L2Distance(Column, Point)
ORDER BY L2Distance(column, Point)
LIMIT N
SETTING ann_index_select_query_params=`k_search=100`
SETTINGS annoy_index_search_k_nodes=100
```

View File

@ -73,7 +73,7 @@ FROM t_null_big
└────────────────────┴─────────────────────┘
```
Also you can use [Tuple](../data-types/tuple.md) to work around NULL skipping behavior. The a `Tuple` that contains only a `NULL` value is not `NULL`, so the aggregate functions won't skip that row because of that `NULL` value.
Also you can use [Tuple](/docs/en/sql-reference/data-types/tuple.md) to work around NULL skipping behavior. The a `Tuple` that contains only a `NULL` value is not `NULL`, so the aggregate functions won't skip that row because of that `NULL` value.
```sql
SELECT

View File

@ -6,7 +6,7 @@ sidebar_position: 106
# argMax
Calculates the `arg` value for a maximum `val` value. If there are several different values of `arg` for maximum values of `val`, returns the first of these values encountered.
Both parts the `arg` and the `max` behave as [aggregate functions](../index.md), they both [skip `Null`](../index.md#null-processing) during processing and return not `Null` values if not `Null` values are available.
Both parts the `arg` and the `max` behave as [aggregate functions](/docs/en/sql-reference/aggregate-functions/index.md), they both [skip `Null`](/docs/en/sql-reference/aggregate-functions/index.md#null-processing) during processing and return not `Null` values if not `Null` values are available.
**Syntax**
@ -106,4 +106,4 @@ SELECT argMax(a, tuple(b)) FROM test;
**See also**
- [Tuple](../../data-types/tuple.md)
- [Tuple](/docs/en/sql-reference/data-types/tuple.md)

View File

@ -6,7 +6,7 @@ sidebar_position: 105
# argMin
Calculates the `arg` value for a minimum `val` value. If there are several different values of `arg` for minimum values of `val`, returns the first of these values encountered.
Both parts the `arg` and the `min` behave as [aggregate functions](../index.md), they both [skip `Null`](../index.md#null-processing) during processing and return not `Null` values if not `Null` values are available.
Both parts the `arg` and the `min` behave as [aggregate functions](/docs/en/sql-reference/aggregate-functions/index.md), they both [skip `Null`](/docs/en/sql-reference/aggregate-functions/index.md#null-processing) during processing and return not `Null` values if not `Null` values are available.
**Syntax**
@ -103,7 +103,7 @@ SELECT argMin((a, b), (b, a)), min(tuple(b, a)) FROM test;
│ (NULL,NULL) │ (NULL,NULL) │ -- argMin returns (NULL,NULL) here because `Tuple` allows to don't skip `NULL` and min(tuple(b, a)) in this case is minimal value for this dataset
└──────────────────────────────────┴──────────────────┘
select argMin(a, tuple(b)) from test;
SELECT argMin(a, tuple(b)) FROM test;
┌─argMax(a, tuple(b))─┐
│ d │ -- `Tuple` can be used in `min` to not skip rows with `NULL` values as b.
└─────────────────────┘
@ -111,4 +111,4 @@ select argMin(a, tuple(b)) from test;
**See also**
- [Tuple](../../data-types/tuple.md)
- [Tuple](/docs/en/sql-reference/data-types/tuple.md)

View File

@ -10,7 +10,7 @@ sidebar_label: INDEX
The following operations are available:
- `ALTER TABLE [db].table_name [ON CLUSTER cluster] ADD INDEX name expression TYPE type GRANULARITY value [FIRST|AFTER name]` - Adds index description to tables metadata.
- `ALTER TABLE [db].table_name [ON CLUSTER cluster] ADD INDEX name expression TYPE type [GRANULARITY value] [FIRST|AFTER name]` - Adds index description to tables metadata.
- `ALTER TABLE [db].table_name [ON CLUSTER cluster] DROP INDEX name` - Removes index description from tables metadata and deletes index files from disk. Implemented as a [mutation](/docs/en/sql-reference/statements/alter/index.md#mutations).

View File

@ -141,6 +141,13 @@ public:
nested_func->merge(place, rhs, arena);
}
bool isAbleToParallelizeMerge() const override { return nested_func->isAbleToParallelizeMerge(); }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena * arena) const override
{
nested_func->merge(place, rhs, thread_pool, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_func->serialize(place, buf, version);

View File

@ -110,6 +110,13 @@ public:
nested_func->merge(place, rhs, arena);
}
bool isAbleToParallelizeMerge() const override { return nested_func->isAbleToParallelizeMerge(); }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena * arena) const override
{
nested_func->merge(place, rhs, thread_pool, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_func->serialize(place, buf, version);

View File

@ -148,6 +148,13 @@ public:
nested_function->merge(nestedPlace(place), nestedPlace(rhs), arena);
}
bool isAbleToParallelizeMerge() const override { return nested_function->isAbleToParallelizeMerge(); }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena * arena) const override
{
nested_function->merge(nestedPlace(place), nestedPlace(rhs), thread_pool, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
bool flag = getFlag(place);

View File

@ -91,6 +91,13 @@ public:
nested_func->merge(place, rhs, arena);
}
bool isAbleToParallelizeMerge() const override { return nested_func->isAbleToParallelizeMerge(); }
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, ThreadPool & thread_pool, Arena * arena) const override
{
nested_func->merge(place, rhs, thread_pool, arena);
}
void serialize(ConstAggregateDataPtr __restrict place, WriteBuffer & buf, std::optional<size_t> version) const override
{
nested_func->serialize(place, buf, version);

View File

@ -117,7 +117,10 @@ ASTPtr ColumnNode::toASTImpl(const ConvertToASTOptions & options) const
else
{
const auto & table_storage_id = table_node->getStorageID();
if (table_storage_id.hasDatabase() && options.qualify_indentifiers_with_database)
column_identifier_parts = { table_storage_id.getDatabaseName(), table_storage_id.getTableName() };
else
column_identifier_parts = { table_storage_id.getTableName() };
}
}
}

View File

@ -187,10 +187,13 @@ public:
/// Identifiers are fully qualified (`database.table.column`), otherwise names are just column names (`column`)
bool fully_qualified_identifiers = true;
/// Identifiers are qualified but database name is not added (`table.column`) if set to false.
bool qualify_indentifiers_with_database = true;
};
/// Convert query tree to AST
ASTPtr toAST(const ConvertToASTOptions & options = { .add_cast_for_constants = true, .fully_qualified_identifiers = true }) const;
ASTPtr toAST(const ConvertToASTOptions & options = { .add_cast_for_constants = true, .fully_qualified_identifiers = true, .qualify_indentifiers_with_database = true }) const;
/// Convert query tree to AST and then format it for error message.
String formatConvertedASTForErrorMessage() const;

View File

@ -10,9 +10,10 @@
namespace DB
{
LambdaNode::LambdaNode(Names argument_names_, QueryTreeNodePtr expression_)
LambdaNode::LambdaNode(Names argument_names_, QueryTreeNodePtr expression_, DataTypePtr result_type_)
: IQueryTreeNode(children_size)
, argument_names(std::move(argument_names_))
, result_type(std::move(result_type_))
{
auto arguments_list_node = std::make_shared<ListNode>();
auto & nodes = arguments_list_node->getNodes();
@ -63,7 +64,7 @@ void LambdaNode::updateTreeHashImpl(HashState & state) const
QueryTreeNodePtr LambdaNode::cloneImpl() const
{
return std::make_shared<LambdaNode>(argument_names, getExpression());
return std::make_shared<LambdaNode>(argument_names, getExpression(), result_type);
}
ASTPtr LambdaNode::toASTImpl(const ConvertToASTOptions & options) const

View File

@ -35,7 +35,7 @@ class LambdaNode final : public IQueryTreeNode
{
public:
/// Initialize lambda with argument names and lambda body expression
explicit LambdaNode(Names argument_names_, QueryTreeNodePtr expression_);
explicit LambdaNode(Names argument_names_, QueryTreeNodePtr expression_, DataTypePtr result_type_ = {});
/// Get argument names
const Names & getArgumentNames() const

View File

@ -4767,13 +4767,14 @@ ProjectionNames QueryAnalyzer::resolveFunction(QueryTreeNodePtr & node, Identifi
auto * table_node = in_second_argument->as<TableNode>();
auto * table_function_node = in_second_argument->as<TableFunctionNode>();
if (table_node && dynamic_cast<StorageSet *>(table_node->getStorage().get()) != nullptr)
if (table_node)
{
/// If table is already prepared set, we do not replace it with subquery
/// If table is already prepared set, we do not replace it with subquery.
/// If table is not a StorageSet, we'll create plan to build set in the Planner.
}
else if (table_node || table_function_node)
else if (table_function_node)
{
const auto & storage_snapshot = table_node ? table_node->getStorageSnapshot() : table_function_node->getStorageSnapshot();
const auto & storage_snapshot = table_function_node->getStorageSnapshot();
auto columns_to_select = storage_snapshot->getColumns(GetColumnsOptions(GetColumnsOptions::Ordinary));
size_t columns_to_select_size = columns_to_select.size();

View File

@ -91,6 +91,11 @@ ASTPtr TableNode::toASTImpl(const ConvertToASTOptions & /* options */) const
if (!temporary_table_name.empty())
return std::make_shared<ASTTableIdentifier>(temporary_table_name);
// In case of cross-replication we don't know what database is used for the table.
// `storage_id.hasDatabase()` can return false only on the initiator node.
// Each shard will use the default database (in the case of cross-replication shards may have different defaults).
if (!storage_id.hasDatabase())
return std::make_shared<ASTTableIdentifier>(storage_id.getTableName());
return std::make_shared<ASTTableIdentifier>(storage_id.getDatabaseName(), storage_id.getTableName());
}

View File

@ -313,6 +313,11 @@ MutableColumnPtr ColumnLowCardinality::cloneResized(size_t size) const
MutableColumnPtr ColumnLowCardinality::cloneNullable() const
{
auto res = cloneFinalized();
/* Compact required not to share dictionary.
* If `shared` flag is not set `cloneFinalized` will return shallow copy
* and `nestedToNullable` will mutate source column.
*/
assert_cast<ColumnLowCardinality &>(*res).compactInplace();
assert_cast<ColumnLowCardinality &>(*res).nestedToNullable();
return res;
}

View File

@ -48,3 +48,16 @@ TEST(ColumnLowCardinality, Insert)
testLowCardinalityNumberInsert<Float32>(std::make_shared<DataTypeFloat32>());
testLowCardinalityNumberInsert<Float64>(std::make_shared<DataTypeFloat64>());
}
TEST(ColumnLowCardinality, Clone)
{
auto data_type = std::make_shared<DataTypeInt32>();
auto low_cardinality_type = std::make_shared<DataTypeLowCardinality>(data_type);
auto column = low_cardinality_type->createColumn();
ASSERT_FALSE(assert_cast<const ColumnLowCardinality &>(*column).nestedIsNullable());
auto nullable_column = assert_cast<const ColumnLowCardinality &>(*column).cloneNullable();
ASSERT_TRUE(assert_cast<const ColumnLowCardinality &>(*nullable_column).nestedIsNullable());
ASSERT_FALSE(assert_cast<const ColumnLowCardinality &>(*column).nestedIsNullable());
}

View File

@ -167,7 +167,7 @@ void ExternalTablesHandler::handlePart(const Poco::Net::MessageHeader & header,
auto temporary_table = TemporaryTableHolder(getContext(), ColumnsDescription{columns}, {});
auto storage = temporary_table.getTable();
getContext()->addExternalTable(data->table_name, std::move(temporary_table));
auto sink = storage->write(ASTPtr(), storage->getInMemoryMetadataPtr(), getContext());
auto sink = storage->write(ASTPtr(), storage->getInMemoryMetadataPtr(), getContext(), /*async_insert=*/false);
/// Write data
auto pipeline = QueryPipelineBuilder::getPipeline(std::move(*data->pipe));

View File

@ -160,6 +160,7 @@ class IColumn;
M(UInt64, allow_experimental_parallel_reading_from_replicas, 0, "Use all the replicas from a shard for SELECT query execution. Reading is parallelized and coordinated dynamically. 0 - disabled, 1 - enabled, silently disable them in case of failure, 2 - enabled, throw an exception in case of failure", 0) \
M(Float, parallel_replicas_single_task_marks_count_multiplier, 2, "A multiplier which will be added during calculation for minimal number of marks to retrieve from coordinator. This will be applied only for remote replicas.", 0) \
M(Bool, parallel_replicas_for_non_replicated_merge_tree, false, "If true, ClickHouse will use parallel replicas algorithm also for non-replicated MergeTree tables", 0) \
M(UInt64, parallel_replicas_min_number_of_granules_to_enable, 0, "If the number of marks to read is less than the value of this setting - parallel replicas will be disabled", 0) \
\
M(Bool, skip_unavailable_shards, false, "If true, ClickHouse silently skips unavailable shards and nodes unresolvable through DNS. Shard is marked as unavailable when none of the replicas can be reached.", 0) \
\
@ -721,7 +722,6 @@ class IColumn;
\
M(Bool, parallelize_output_from_storages, true, "Parallelize output for reading step from storage. It allows parallelizing query processing right after reading from storage if possible", 0) \
M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \
M(String, ann_index_select_query_params, "", "Parameters passed to ANN indexes in SELECT queries, the format is 'param1=x, param2=y, ...'", 0) \
M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \
M(Bool, throw_if_no_data_to_insert, true, "Enables or disables empty INSERTs, enabled by default", 0) \
M(Bool, compatibility_ignore_auto_increment_in_create_table, false, "Ignore AUTO_INCREMENT keyword in column declaration if true, otherwise return error. It simplifies migration from MySQL", 0) \
@ -744,7 +744,8 @@ class IColumn;
M(Bool, allow_experimental_hash_functions, false, "Enable experimental hash functions (hashid, etc)", 0) \
M(Bool, allow_experimental_object_type, false, "Allow Object and JSON data types", 0) \
M(Bool, allow_experimental_annoy_index, false, "Allows to use Annoy index. Disabled by default because this feature is experimental", 0) \
M(UInt64, max_limit_for_ann_queries, 1000000, "Maximum limit value for using ANN indexes is used to prevent memory overflow in search queries for indexes", 0) \
M(UInt64, max_limit_for_ann_queries, 1'000'000, "SELECT queries with LIMIT bigger than this setting cannot use ANN indexes. Helps to prevent memory overflows in ANN search indexes.", 0) \
M(Int64, annoy_index_search_k_nodes, -1, "SELECT queries search up to this many nodes in Annoy indexes.", 0) \
M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \
M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \
M(Bool, implicit_transaction, false, "If enabled and not already inside a transaction, wraps the query inside a full transaction (begin + commit or rollback)", 0) \

View File

@ -52,18 +52,20 @@ bool FileSegmentRangeWriter::write(const char * data, size_t size, size_t offset
FileSegment * file_segment;
if (file_segments.empty() || file_segments.back().isDownloaded())
if (!file_segments || file_segments->empty() || file_segments->front().isDownloaded())
{
file_segment = &allocateFileSegment(expected_write_offset, segment_kind);
}
else
{
file_segment = &file_segments.back();
file_segment = &file_segments->front();
}
SCOPE_EXIT({
if (file_segments.back().isDownloader())
file_segments.back().completePartAndResetDownloader();
if (!file_segments || file_segments->empty())
return;
if (file_segments->front().isDownloader())
file_segments->front().completePartAndResetDownloader();
});
while (size > 0)
@ -71,7 +73,7 @@ bool FileSegmentRangeWriter::write(const char * data, size_t size, size_t offset
size_t available_size = file_segment->range().size() - file_segment->getDownloadedSize(false);
if (available_size == 0)
{
completeFileSegment(*file_segment);
completeFileSegment();
file_segment = &allocateFileSegment(expected_write_offset, segment_kind);
continue;
}
@ -114,10 +116,7 @@ void FileSegmentRangeWriter::finalize()
if (finalized)
return;
if (file_segments.empty())
return;
completeFileSegment(file_segments.back());
completeFileSegment();
finalized = true;
}
@ -145,10 +144,9 @@ FileSegment & FileSegmentRangeWriter::allocateFileSegment(size_t offset, FileSeg
/// We set max_file_segment_size to be downloaded,
/// if we have less size to write, file segment will be resized in complete() method.
auto holder = cache->set(key, offset, cache->getMaxFileSegmentSize(), create_settings);
chassert(holder->size() == 1);
holder->moveTo(file_segments);
return file_segments.back();
file_segments = cache->set(key, offset, cache->getMaxFileSegmentSize(), create_settings);
chassert(file_segments->size() == 1);
return file_segments->front();
}
void FileSegmentRangeWriter::appendFilesystemCacheLog(const FileSegment & file_segment)
@ -176,8 +174,12 @@ void FileSegmentRangeWriter::appendFilesystemCacheLog(const FileSegment & file_s
cache_log->add(elem);
}
void FileSegmentRangeWriter::completeFileSegment(FileSegment & file_segment)
void FileSegmentRangeWriter::completeFileSegment()
{
if (!file_segments || file_segments->empty())
return;
auto & file_segment = file_segments->front();
/// File segment can be detached if space reservation failed.
if (file_segment.isDetached() || file_segment.isCompleted())
return;

View File

@ -43,7 +43,7 @@ private:
void appendFilesystemCacheLog(const FileSegment & file_segment);
void completeFileSegment(FileSegment & file_segment);
void completeFileSegment();
FileCache * cache;
FileSegment::Key key;
@ -53,7 +53,7 @@ private:
String query_id;
String source_path;
FileSegmentsHolder file_segments{};
FileSegmentsHolderPtr file_segments;
size_t expected_write_offset = 0;

View File

@ -1,6 +1,6 @@
#include <Columns/ColumnConst.h>
#include <Columns/ColumnString.h>
#include <Columns/ColumnsNumber.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <DataTypes/DataTypeInterval.h>
@ -25,7 +25,7 @@ class FunctionDateTrunc : public IFunction
public:
static constexpr auto name = "dateTrunc";
explicit FunctionDateTrunc(ContextPtr context_) : context(context_) { }
explicit FunctionDateTrunc(ContextPtr context_) : context(context_) {}
static FunctionPtr create(ContextPtr context) { return std::make_shared<FunctionDateTrunc>(context); }
@ -39,58 +39,51 @@ public:
{
/// The first argument is a constant string with the name of datepart.
intermediate_type_is_date = false;
auto result_type_is_date = false;
String datepart_param;
auto check_first_argument = [&]
{
auto check_first_argument = [&] {
const ColumnConst * datepart_column = checkAndGetColumnConst<ColumnString>(arguments[0].column.get());
if (!datepart_column)
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"First argument for function {} must be constant string: "
"name of datepart",
getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "First argument for function {} must be constant string: "
"name of datepart", getName());
datepart_param = datepart_column->getValue<String>();
if (datepart_param.empty())
throw Exception(
ErrorCodes::BAD_ARGUMENTS, "First argument (name of datepart) for function {} cannot be empty", getName());
throw Exception(ErrorCodes::BAD_ARGUMENTS, "First argument (name of datepart) for function {} cannot be empty",
getName());
if (!IntervalKind::tryParseString(datepart_param, datepart_kind))
throw Exception(ErrorCodes::BAD_ARGUMENTS, "{} doesn't look like datepart name in {}", datepart_param, getName());
intermediate_type_is_date = (datepart_kind == IntervalKind::Year) || (datepart_kind == IntervalKind::Quarter)
|| (datepart_kind == IntervalKind::Month) || (datepart_kind == IntervalKind::Week);
result_type_is_date = (datepart_kind == IntervalKind::Year)
|| (datepart_kind == IntervalKind::Quarter) || (datepart_kind == IntervalKind::Month)
|| (datepart_kind == IntervalKind::Week);
};
bool second_argument_is_date = false;
auto check_second_argument = [&]
{
auto check_second_argument = [&] {
if (!isDate(arguments[1].type) && !isDateTime(arguments[1].type) && !isDateTime64(arguments[1].type))
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of 2nd argument of function {}. "
"Should be a date or a date with time",
arguments[1].type->getName(),
getName());
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of 2nd argument of function {}. "
"Should be a date or a date with time", arguments[1].type->getName(), getName());
second_argument_is_date = isDate(arguments[1].type);
if (second_argument_is_date
&& ((datepart_kind == IntervalKind::Hour) || (datepart_kind == IntervalKind::Minute)
|| (datepart_kind == IntervalKind::Second)))
if (second_argument_is_date && ((datepart_kind == IntervalKind::Hour)
|| (datepart_kind == IntervalKind::Minute) || (datepart_kind == IntervalKind::Second)))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type Date of argument for function {}", getName());
};
auto check_timezone_argument = [&]
{
auto check_timezone_argument = [&] {
if (!WhichDataType(arguments[2].type).isString())
throw Exception(
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Illegal type {} of argument of function {}. "
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}. "
"This argument is optional and must be a constant string with timezone name",
arguments[2].type->getName(),
getName());
arguments[2].type->getName(), getName());
if (second_argument_is_date && result_type_is_date)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"The timezone argument of function {} with datepart '{}' "
"is allowed only when the 2nd argument has the type DateTime",
getName(), datepart_param);
};
if (arguments.size() == 2)
@ -106,13 +99,14 @@ public:
}
else
{
throw Exception(
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Number of arguments for function {} doesn't match: passed {}, should be 2 or 3",
getName(),
arguments.size());
getName(), arguments.size());
}
if (result_type_is_date)
return std::make_shared<DataTypeDate>();
else
return std::make_shared<DataTypeDateTime>(extractTimeZoneNameFromFunctionArguments(arguments, 2, 1));
}
@ -130,40 +124,26 @@ public:
auto to_start_of_interval = FunctionFactory::instance().get("toStartOfInterval", context);
ColumnPtr truncated_column;
auto date_type = std::make_shared<DataTypeDate>();
if (arguments.size() == 2)
truncated_column = to_start_of_interval->build(temp_columns)
->execute(temp_columns, intermediate_type_is_date ? date_type : result_type, input_rows_count);
else
{
return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count);
temp_columns[2] = arguments[2];
truncated_column = to_start_of_interval->build(temp_columns)
->execute(temp_columns, intermediate_type_is_date ? date_type : result_type, input_rows_count);
return to_start_of_interval->build(temp_columns)->execute(temp_columns, result_type, input_rows_count);
}
if (!intermediate_type_is_date)
return truncated_column;
ColumnsWithTypeAndName temp_truncated_column(1);
temp_truncated_column[0] = {truncated_column, date_type, ""};
auto to_date_time_or_default = FunctionFactory::instance().get("toDateTime", context);
return to_date_time_or_default->build(temp_truncated_column)->execute(temp_truncated_column, result_type, input_rows_count);
bool hasInformationAboutMonotonicity() const override
{
return true;
}
bool hasInformationAboutMonotonicity() const override { return true; }
Monotonicity getMonotonicityForRange(const IDataType &, const Field &, const Field &) const override
{
return {.is_monotonic = true, .is_always_monotonic = true};
return { .is_monotonic = true, .is_always_monotonic = true };
}
private:
ContextPtr context;
mutable IntervalKind::Kind datepart_kind = IntervalKind::Kind::Second;
mutable bool intermediate_type_is_date = false;
};
}

View File

@ -360,12 +360,6 @@ struct FileSegmentsHolder : private boost::noncopyable
FileSegments::const_iterator begin() const { return file_segments.begin(); }
FileSegments::const_iterator end() const { return file_segments.end(); }
void moveTo(FileSegmentsHolder & holder)
{
holder.file_segments.insert(holder.file_segments.end(), file_segments.begin(), file_segments.end());
file_segments.clear();
}
private:
FileSegments file_segments{};
const bool complete_on_dtor = true;

View File

@ -170,7 +170,7 @@ public:
else if (getContext()->getSettingsRef().use_index_for_in_with_subqueries)
{
auto external_table = external_storage_holder->getTable();
auto table_out = external_table->write({}, external_table->getInMemoryMetadataPtr(), getContext());
auto table_out = external_table->write({}, external_table->getInMemoryMetadataPtr(), getContext(), /*async_insert=*/false);
auto io = interpreter->execute();
io.pipeline.complete(std::move(table_out));
CompletedPipelineExecutor executor(io.pipeline);

View File

@ -707,8 +707,9 @@ Block HashJoin::prepareRightBlock(const Block & block, const Block & saved_block
for (const auto & sample_column : saved_block_sample_.getColumnsWithTypeAndName())
{
ColumnWithTypeAndName column = block.getByName(sample_column.name);
if (sample_column.column->isNullable())
JoinCommon::convertColumnToNullable(column);
/// There's no optimization for right side const columns. Remove constness if any.
column.column = recursiveRemoveSparse(column.column->convertToFullColumnIfConst());
if (column.column->lowCardinality() && !sample_column.column->lowCardinality())
{
@ -716,8 +717,9 @@ Block HashJoin::prepareRightBlock(const Block & block, const Block & saved_block
column.type = removeLowCardinality(column.type);
}
/// There's no optimization for right side const columns. Remove constness if any.
column.column = recursiveRemoveSparse(column.column->convertToFullColumnIfConst());
if (sample_column.column->isNullable())
JoinCommon::convertColumnToNullable(column);
structured_block.insert(std::move(column));
}

View File

@ -282,7 +282,7 @@ Chain InterpreterInsertQuery::buildSink(
/// Otherwise we'll get duplicates when MV reads same rows again from Kafka.
if (table->noPushingToViews() && !no_destination)
{
auto sink = table->write(query_ptr, metadata_snapshot, context_ptr);
auto sink = table->write(query_ptr, metadata_snapshot, context_ptr, async_insert);
sink->setRuntimeData(thread_status, elapsed_counter_ms);
out.addSource(std::move(sink));
}
@ -290,7 +290,7 @@ Chain InterpreterInsertQuery::buildSink(
{
out = buildPushingToViewsChain(table, metadata_snapshot, context_ptr,
query_ptr, no_destination,
thread_status_holder, running_group, elapsed_counter_ms);
thread_status_holder, running_group, elapsed_counter_ms, async_insert);
}
return out;

View File

@ -160,16 +160,14 @@ static ColumnPtr tryConvertColumnToNullable(ColumnPtr col)
if (col->lowCardinality())
{
auto mut_col = IColumn::mutate(std::move(col));
ColumnLowCardinality * col_lc = assert_cast<ColumnLowCardinality *>(mut_col.get());
if (col_lc->nestedIsNullable())
const ColumnLowCardinality & col_lc = assert_cast<const ColumnLowCardinality &>(*col);
if (col_lc.nestedIsNullable())
{
return mut_col;
return col;
}
else if (col_lc->nestedCanBeInsideNullable())
else if (col_lc.nestedCanBeInsideNullable())
{
col_lc->nestedToNullable();
return mut_col;
return col_lc.cloneNullable();
}
}
else if (const ColumnConst * col_const = checkAndGetColumn<ColumnConst>(*col))
@ -232,11 +230,7 @@ void removeColumnNullability(ColumnWithTypeAndName & column)
if (column.column && column.column->lowCardinality())
{
auto mut_col = IColumn::mutate(std::move(column.column));
ColumnLowCardinality * col_as_lc = typeid_cast<ColumnLowCardinality *>(mut_col.get());
if (col_as_lc && col_as_lc->nestedIsNullable())
col_as_lc->nestedRemoveNullable();
column.column = std::move(mut_col);
column.column = assert_cast<const ColumnLowCardinality *>(column.column.get())->cloneWithDefaultOnNull();
}
}
else

View File

@ -232,8 +232,17 @@ public:
bool allowParallelHashJoin() const;
bool joinUseNulls() const { return join_use_nulls; }
bool forceNullableRight() const { return join_use_nulls && isLeftOrFull(kind()); }
bool forceNullableLeft() const { return join_use_nulls && isRightOrFull(kind()); }
bool forceNullableRight() const
{
return join_use_nulls && isLeftOrFull(kind());
}
bool forceNullableLeft() const
{
return join_use_nulls && isRightOrFull(kind());
}
size_t defaultMaxBytes() const { return default_max_bytes; }
size_t maxJoinedBlockRows() const { return max_joined_block_rows; }
size_t maxRowsInRightBlock() const { return partial_merge_join_rows_in_right_blocks; }

View File

@ -192,6 +192,22 @@ Field convertFieldToTypeImpl(const Field & src, const IDataType & type, const ID
{
return static_cast<const DataTypeDateTime &>(type).getTimeZone().fromDayNum(DayNum(src.get<Int32>()));
}
else if (which_type.isDateTime64() && which_from_type.isDate())
{
const auto & date_time64_type = static_cast<const DataTypeDateTime64 &>(type);
const auto value = date_time64_type.getTimeZone().fromDayNum(DayNum(src.get<UInt16>()));
return DecimalField(
DecimalUtils::decimalFromComponentsWithMultiplier<DateTime64>(value, 0, date_time64_type.getScaleMultiplier()),
date_time64_type.getScale());
}
else if (which_type.isDateTime64() && which_from_type.isDate32())
{
const auto & date_time64_type = static_cast<const DataTypeDateTime64 &>(type);
const auto value = date_time64_type.getTimeZone().fromDayNum(ExtendedDayNum(static_cast<Int32>(src.get<Int32>())));
return DecimalField(
DecimalUtils::decimalFromComponentsWithMultiplier<DateTime64>(value, 0, date_time64_type.getScaleMultiplier()),
date_time64_type.getScale());
}
else if (type.isValueRepresentedByNumber() && src.getType() != Field::Types::String)
{
if (which_type.isUInt8()) return convertNumericType<UInt8>(src, type);

View File

@ -0,0 +1,184 @@
#include <initializer_list>
#include <limits>
#include <ostream>
#include <Core/Field.h>
#include <Core/iostream_debug_helpers.h>
#include <Interpreters/convertFieldToType.h>
#include <DataTypes/DataTypeFactory.h>
#include <gtest/gtest.h>
#include "base/Decimal.h"
#include "base/types.h"
using namespace DB;
struct ConvertFieldToTypeTestParams
{
const char * from_type; // MUST NOT BE NULL
const Field from_value;
const char * to_type; // MUST NOT BE NULL
const std::optional<Field> expected_value;
};
std::ostream & operator << (std::ostream & ostr, const ConvertFieldToTypeTestParams & params)
{
return ostr << "{"
<< "\n\tfrom_type : " << params.from_type
<< "\n\tfrom_value : " << params.from_value
<< "\n\tto_type : " << params.to_type
<< "\n\texpected : " << (params.expected_value ? *params.expected_value : Field())
<< "\n}";
}
class ConvertFieldToTypeTest : public ::testing::TestWithParam<ConvertFieldToTypeTestParams>
{};
TEST_P(ConvertFieldToTypeTest, convert)
{
const auto & params = GetParam();
ASSERT_NE(nullptr, params.from_type);
ASSERT_NE(nullptr, params.to_type);
const auto & type_factory = DataTypeFactory::instance();
const auto from_type = type_factory.get(params.from_type);
const auto to_type = type_factory.get(params.to_type);
if (params.expected_value)
{
const auto result = convertFieldToType(params.from_value, *to_type, from_type.get());
EXPECT_EQ(*params.expected_value, result);
}
else
{
EXPECT_ANY_THROW(convertFieldToType(params.from_value, *to_type, from_type.get()));
}
}
// Basically, the number of seconds in a day works for UTC here
const Int64 Day = 24 * 60 * 60;
// 123 is arbitrary value here
INSTANTIATE_TEST_SUITE_P(
DateToDateTime64,
ConvertFieldToTypeTest,
::testing::ValuesIn(std::initializer_list<ConvertFieldToTypeTestParams>{
// min value of Date
{
"Date",
Field(0),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(0), 0)
},
// Max value of Date
{
"Date",
Field(std::numeric_limits<DB::UInt16>::max()),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(std::numeric_limits<DB::UInt16>::max() * Day), 0)
},
// check that scale is respected
{
"Date",
Field(123),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(123 * Day), 0)
},
{
"Date",
Field(1),
"DateTime64(1, 'UTC')",
DecimalField(DateTime64(Day * 10), 1)
},
{
"Date",
Field(123),
"DateTime64(3, 'UTC')",
DecimalField(DateTime64(123 * Day * 1000), 3)
},
{
"Date",
Field(123),
"DateTime64(6, 'UTC')",
DecimalField(DateTime64(123 * Day * 1'000'000), 6)
},
})
);
INSTANTIATE_TEST_SUITE_P(
Date32ToDateTime64,
ConvertFieldToTypeTest,
::testing::ValuesIn(std::initializer_list<ConvertFieldToTypeTestParams>{
// min value of Date32: 1st Jan 1900 (see DATE_LUT_MIN_YEAR)
{
"Date32",
Field(-25'567),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(-25'567 * Day), 0)
},
// max value of Date32: 31 Dec 2299 (see DATE_LUT_MAX_YEAR)
{
"Date32",
Field(120'529),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(120'529 * Day), 0)
},
// check that scale is respected
{
"Date32",
Field(123),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(123 * Day), 0)
},
{
"Date32",
Field(123),
"DateTime64(1, 'UTC')",
DecimalField(DateTime64(123 * Day * 10), 1)
},
{
"Date32",
Field(123),
"DateTime64(3, 'UTC')",
DecimalField(DateTime64(123 * Day * 1000), 3)
},
{
"Date32",
Field(123),
"DateTime64(6, 'UTC')",
DecimalField(DateTime64(123 * Day * 1'000'000), 6)
}
})
);
INSTANTIATE_TEST_SUITE_P(
DateTimeToDateTime64,
ConvertFieldToTypeTest,
::testing::ValuesIn(std::initializer_list<ConvertFieldToTypeTestParams>{
{
"DateTime",
Field(1),
"DateTime64(0, 'UTC')",
DecimalField(DateTime64(1), 0)
},
{
"DateTime",
Field(1),
"DateTime64(1, 'UTC')",
DecimalField(DateTime64(1'0), 1)
},
{
"DateTime",
Field(123),
"DateTime64(3, 'UTC')",
DecimalField(DateTime64(123'000), 3)
},
{
"DateTime",
Field(123),
"DateTime64(6, 'UTC')",
DecimalField(DateTime64(123'000'000), 6)
},
})
);

View File

@ -36,17 +36,17 @@ bool ParserCreateIndexDeclaration::parseImpl(Pos & pos, ASTPtr & node, Expected
if (!data_type_p.parse(pos, type, expected))
return false;
if (!s_granularity.ignore(pos, expected))
return false;
if (s_granularity.ignore(pos, expected))
{
if (!granularity_p.parse(pos, granularity, expected))
return false;
}
auto index = std::make_shared<ASTIndexDeclaration>();
index->part_of_create_index_query = true;
index->granularity = granularity->as<ASTLiteral &>().value.safeGet<UInt64>();
index->set(index->expr, expr);
index->set(index->type, type);
index->granularity = granularity ? granularity->as<ASTLiteral &>().value.safeGet<UInt64>() : 1;
node = index;
return true;

View File

@ -139,9 +139,9 @@ bool ParserIndexDeclaration::parseImpl(Pos & pos, ASTPtr & node, Expected & expe
auto index = std::make_shared<ASTIndexDeclaration>();
index->name = name->as<ASTIdentifier &>().name();
index->granularity = granularity ? granularity->as<ASTLiteral &>().value.safeGet<UInt64>() : 1;
index->set(index->expr, expr);
index->set(index->type, type);
index->granularity = granularity ? granularity->as<ASTLiteral &>().value.safeGet<UInt64>() : 1;
node = index;
return true;

View File

@ -67,7 +67,8 @@ public:
planner_context.registerSet(set_key, PlannerSet(FutureSet(std::move(set))));
}
else if (in_second_argument_node_type == QueryTreeNodeType::QUERY ||
in_second_argument_node_type == QueryTreeNodeType::UNION)
in_second_argument_node_type == QueryTreeNodeType::UNION ||
in_second_argument_node_type == QueryTreeNodeType::TABLE)
{
planner_context.registerSet(set_key, PlannerSet(in_second_argument));
}

View File

@ -43,6 +43,7 @@
#include <Storages/IStorage.h>
#include <Analyzer/Utils.h>
#include <Analyzer/ColumnNode.h>
#include <Analyzer/ConstantNode.h>
#include <Analyzer/FunctionNode.h>
#include <Analyzer/SortNode.h>
@ -909,12 +910,42 @@ void addBuildSubqueriesForSetsStepIfNeeded(QueryPlan & query_plan,
if (!planner_set)
continue;
if (planner_set->getSet().isCreated() || !planner_set->getSubqueryNode())
auto subquery_to_execute = planner_set->getSubqueryNode();
if (planner_set->getSet().isCreated() || !subquery_to_execute)
continue;
if (auto * table_node = subquery_to_execute->as<TableNode>())
{
auto storage_snapshot = table_node->getStorageSnapshot();
auto columns_to_select = storage_snapshot->getColumns(GetColumnsOptions(GetColumnsOptions::Ordinary));
size_t columns_to_select_size = columns_to_select.size();
auto column_nodes_to_select = std::make_shared<ListNode>();
column_nodes_to_select->getNodes().reserve(columns_to_select_size);
NamesAndTypes projection_columns;
projection_columns.reserve(columns_to_select_size);
for (auto & column : columns_to_select)
{
column_nodes_to_select->getNodes().emplace_back(std::make_shared<ColumnNode>(column, subquery_to_execute));
projection_columns.emplace_back(column.name, column.type);
}
auto subquery_for_table = std::make_shared<QueryNode>(Context::createCopy(planner_context->getQueryContext()));
subquery_for_table->setIsSubquery(true);
subquery_for_table->getProjectionNode() = std::move(column_nodes_to_select);
subquery_for_table->getJoinTree() = std::move(subquery_to_execute);
subquery_for_table->resolveProjectionColumns(std::move(projection_columns));
subquery_to_execute = std::move(subquery_for_table);
}
auto subquery_options = select_query_options.subquery();
Planner subquery_planner(
planner_set->getSubqueryNode(),
subquery_to_execute,
subquery_options,
planner_context->getGlobalPlannerContext());
subquery_planner.buildQueryPlanIfNeeded();

View File

@ -19,18 +19,10 @@ const ColumnIdentifier & GlobalPlannerContext::createColumnIdentifier(const Quer
return createColumnIdentifier(column_node_typed.getColumn(), column_source_node);
}
const ColumnIdentifier & GlobalPlannerContext::createColumnIdentifier(const NameAndTypePair & column, const QueryTreeNodePtr & column_source_node)
const ColumnIdentifier & GlobalPlannerContext::createColumnIdentifier(const NameAndTypePair & column, const QueryTreeNodePtr & /*column_source_node*/)
{
std::string column_identifier;
if (column_source_node->hasAlias())
column_identifier += column_source_node->getAlias();
else if (const auto * table_source_node = column_source_node->as<TableNode>())
column_identifier += table_source_node->getStorageID().getFullNameNotQuoted();
if (!column_identifier.empty())
column_identifier += '.';
column_identifier += column.name;
column_identifier += '_' + std::to_string(column_identifiers.size());
@ -137,7 +129,8 @@ void PlannerContext::registerSet(const SetKey & key, PlannerSet planner_set)
auto node_type = subquery_node->getNodeType();
if (node_type != QueryTreeNodeType::QUERY &&
node_type != QueryTreeNodeType::UNION)
node_type != QueryTreeNodeType::UNION &&
node_type != QueryTreeNodeType::TABLE)
throw Exception(ErrorCodes::LOGICAL_ERROR,
"Invalid node for set table expression. Expected query or union. Actual {}",
subquery_node->formatASTForErrorMessage());

View File

@ -106,6 +106,10 @@ void checkAccessRights(const TableNode & table_node, const Names & column_names,
storage_id.getFullTableName());
}
// In case of cross-replication we don't know what database is used for the table.
// `storage_id.hasDatabase()` can return false only on the initiator node.
// Each shard will use the default database (in the case of cross-replication shards may have different defaults).
if (storage_id.hasDatabase())
query_context->checkAccess(AccessType::SELECT, storage_id, column_names);
}
@ -873,10 +877,11 @@ JoinTreeQueryPlan buildQueryPlanForJoinNode(const QueryTreeNodePtr & join_table_
JoinClausesAndActions join_clauses_and_actions;
JoinKind join_kind = join_node.getKind();
JoinStrictness join_strictness = join_node.getStrictness();
std::optional<bool> join_constant;
if (join_node.getStrictness() == JoinStrictness::All)
if (join_strictness == JoinStrictness::All)
join_constant = tryExtractConstantFromJoinNode(join_table_expression);
if (join_constant)

View File

@ -107,7 +107,10 @@ Block buildCommonHeaderForUnion(const Blocks & queries_headers, SelectUnionMode
ASTPtr queryNodeToSelectQuery(const QueryTreeNodePtr & query_node)
{
auto & query_node_typed = query_node->as<QueryNode &>();
auto result_ast = query_node_typed.toAST();
// In case of cross-replication we don't know what database is used for the table.
// Each shard will use the default database (in the case of cross-replication shards may have different defaults).
auto result_ast = query_node_typed.toAST({ .qualify_indentifiers_with_database = false });
while (true)
{

View File

@ -91,7 +91,7 @@ void CreatingSetsTransform::startSubquery()
if (subquery.table)
/// TODO: make via port
table_out = QueryPipeline(subquery.table->write({}, subquery.table->getInMemoryMetadataPtr(), getContext()));
table_out = QueryPipeline(subquery.table->write({}, subquery.table->getInMemoryMetadataPtr(), getContext(), /*async_insert=*/false));
done_with_set = !subquery.set_in_progress;
done_with_table = !subquery.table;

View File

@ -196,6 +196,7 @@ Chain buildPushingToViewsChain(
ThreadStatusesHolderPtr thread_status_holder,
ThreadGroupPtr running_group,
std::atomic_uint64_t * elapsed_counter_ms,
bool async_insert,
const Block & live_view_header)
{
checkStackSize();
@ -347,7 +348,7 @@ Chain buildPushingToViewsChain(
out = buildPushingToViewsChain(
view, view_metadata_snapshot, insert_context, ASTPtr(),
/* no_destination= */ true,
thread_status_holder, running_group, view_counter_ms, storage_header);
thread_status_holder, running_group, view_counter_ms, async_insert, storage_header);
}
else if (auto * window_view = dynamic_cast<StorageWindowView *>(view.get()))
{
@ -356,13 +357,13 @@ Chain buildPushingToViewsChain(
out = buildPushingToViewsChain(
view, view_metadata_snapshot, insert_context, ASTPtr(),
/* no_destination= */ true,
thread_status_holder, running_group, view_counter_ms);
thread_status_holder, running_group, view_counter_ms, async_insert);
}
else
out = buildPushingToViewsChain(
view, view_metadata_snapshot, insert_context, ASTPtr(),
/* no_destination= */ false,
thread_status_holder, running_group, view_counter_ms);
thread_status_holder, running_group, view_counter_ms, async_insert);
views_data->views.emplace_back(ViewRuntimeData{
std::move(query),
@ -444,7 +445,7 @@ Chain buildPushingToViewsChain(
/// Do not push to destination table if the flag is set
else if (!no_destination)
{
auto sink = storage->write(query_ptr, metadata_snapshot, context);
auto sink = storage->write(query_ptr, metadata_snapshot, context, async_insert);
metadata_snapshot->check(sink->getHeader().getColumnsWithTypeAndName());
sink->setRuntimeData(thread_status, elapsed_counter_ms);
result_chain.addSource(std::move(sink));

View File

@ -69,6 +69,8 @@ Chain buildPushingToViewsChain(
ThreadGroupPtr running_group,
/// Counter to measure time spent separately per view. Should be improved.
std::atomic_uint64_t * elapsed_counter_ms,
/// True if it's part of async insert flush
bool async_insert,
/// LiveView executes query itself, it needs source block structure.
const Block & live_view_header = {});

View File

@ -1101,7 +1101,7 @@ namespace
{
/// The data will be written directly to the table.
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
auto sink = storage->write(ASTPtr(), metadata_snapshot, query_context);
auto sink = storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false);
std::unique_ptr<ReadBuffer> buf = std::make_unique<ReadBufferFromMemory>(external_table.data().data(), external_table.data().size());
buf = wrapReadBufferWithCompressionMethod(std::move(buf), chooseCompressionMethod("", external_table.compression_type()));

View File

@ -1692,7 +1692,7 @@ bool TCPHandler::receiveData(bool scalar)
}
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
/// The data will be written directly to the table.
QueryPipeline temporary_table_out(storage->write(ASTPtr(), metadata_snapshot, query_context));
QueryPipeline temporary_table_out(storage->write(ASTPtr(), metadata_snapshot, query_context, /*async_insert=*/false));
PushingPipelineExecutor executor(temporary_table_out);
executor.start();
executor.push(block);

View File

@ -624,7 +624,7 @@ Pipe StorageHDFS::read(
return Pipe::unitePipes(std::move(pipes));
}
SinkToStoragePtr StorageHDFS::write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context_)
SinkToStoragePtr StorageHDFS::write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context_, bool /*async_insert*/)
{
String current_uri = uris.back();

View File

@ -41,7 +41,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override;
void truncate(
const ASTPtr & query,

View File

@ -905,7 +905,7 @@ HiveFiles StorageHive::collectHiveFiles(
return hive_files;
}
SinkToStoragePtr StorageHive::write(const ASTPtr & /*query*/, const StorageMetadataPtr & /* metadata_snapshot*/, ContextPtr /*context*/)
SinkToStoragePtr StorageHive::write(const ASTPtr & /*query*/, const StorageMetadataPtr & /* metadata_snapshot*/, ContextPtr /*context*/, bool /*async_insert*/)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method write is not implemented for StorageHive");
}

View File

@ -61,7 +61,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/) override;
SinkToStoragePtr write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/, bool async_insert) override;
NamesAndTypesList getVirtuals() const override;

View File

@ -402,11 +402,14 @@ public:
* passed in all parts of the returned streams. Storage metadata can be
* changed during lifetime of the returned streams, but the snapshot is
* guaranteed to be immutable.
*
* async_insert - set to true if the write is part of async insert flushing
*/
virtual SinkToStoragePtr write(
const ASTPtr & /*query*/,
const StorageMetadataPtr & /*metadata_snapshot*/,
ContextPtr /*context*/)
ContextPtr /*context*/,
bool /*async_insert*/)
{
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method write is not supported by storage {}", getName());
}

View File

@ -374,7 +374,7 @@ Pipe StorageKafka::read(
}
SinkToStoragePtr StorageKafka::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageKafka::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
auto modified_context = Context::createCopy(local_context);
modified_context->applySettingsChanges(settings_adjustments);

View File

@ -60,7 +60,8 @@ public:
SinkToStoragePtr write(
const ASTPtr & query,
const StorageMetadataPtr & /*metadata_snapshot*/,
ContextPtr context) override;
ContextPtr context,
bool async_insert) override;
/// We want to control the number of rows in a chunk inserted into Kafka
bool prefersLargeBlocks() const override { return false; }

View File

@ -137,7 +137,7 @@ Pipe StorageMeiliSearch::read(
return Pipe(std::make_shared<MeiliSearchSource>(config, sample_block, max_block_size, route, kv_pairs_params));
}
SinkToStoragePtr StorageMeiliSearch::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageMeiliSearch::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
LOG_TRACE(log, "Trying update index: {}", config.index);
return std::make_shared<SinkMeiliSearch>(config, metadata_snapshot->getSampleBlock(), local_context);

View File

@ -26,7 +26,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool async_insert) override;
static MeiliSearchConfiguration getConfiguration(ASTs engine_args, ContextPtr context);

View File

@ -1,17 +1,15 @@
#include <Storages/MergeTree/CommonANNIndexes.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
#include <Interpreters/Context.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTIdentifier.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTOrderByElement.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Storages/MergeTree/KeyCondition.h>
#include <Storages/MergeTree/MergeTreeSettings.h>
#include <Interpreters/Context.h>
namespace DB
{
@ -24,208 +22,166 @@ namespace ErrorCodes
namespace
{
namespace ANN = ApproximateNearestNeighbour;
template <typename Literal>
void extractTargetVectorFromLiteral(ANN::ANNQueryInformation::Embedding & target, Literal literal)
void extractReferenceVectorFromLiteral(ApproximateNearestNeighborInformation::Embedding & reference_vector, Literal literal)
{
Float64 float_element_of_target_vector;
Int64 int_element_of_target_vector;
Float64 float_element_of_reference_vector;
Int64 int_element_of_reference_vector;
for (const auto & value : literal.value())
{
if (value.tryGet(float_element_of_target_vector))
{
target.emplace_back(float_element_of_target_vector);
}
else if (value.tryGet(int_element_of_target_vector))
{
target.emplace_back(static_cast<float>(int_element_of_target_vector));
}
if (value.tryGet(float_element_of_reference_vector))
reference_vector.emplace_back(float_element_of_reference_vector);
else if (value.tryGet(int_element_of_reference_vector))
reference_vector.emplace_back(static_cast<float>(int_element_of_reference_vector));
else
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in target vector. Only float or int are supported.");
}
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type of elements in reference vector. Only float or int are supported.");
}
}
ANN::ANNQueryInformation::Metric castMetricFromStringToType(String metric_name)
ApproximateNearestNeighborInformation::Metric stringToMetric(std::string_view metric)
{
if (metric_name == "L2Distance")
return ANN::ANNQueryInformation::Metric::L2;
if (metric_name == "LpDistance")
return ANN::ANNQueryInformation::Metric::Lp;
return ANN::ANNQueryInformation::Metric::Unknown;
if (metric == "L2Distance")
return ApproximateNearestNeighborInformation::Metric::L2;
else if (metric == "LpDistance")
return ApproximateNearestNeighborInformation::Metric::Lp;
else
return ApproximateNearestNeighborInformation::Metric::Unknown;
}
}
namespace ApproximateNearestNeighbour
{
ApproximateNearestNeighborCondition::ApproximateNearestNeighborCondition(const SelectQueryInfo & query_info, ContextPtr context)
: block_with_constants(KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context))
, index_granularity(context->getMergeTreeSettings().index_granularity)
, max_limit_for_ann_queries(context->getSettings().max_limit_for_ann_queries)
, index_is_useful(checkQueryStructure(query_info))
{}
ANNCondition::ANNCondition(const SelectQueryInfo & query_info,
ContextPtr context) :
block_with_constants{KeyCondition::getBlockWithConstants(query_info.query, query_info.syntax_analyzer_result, context)},
ann_index_select_query_params{context->getSettings().get("ann_index_select_query_params").get<String>()},
index_granularity{context->getMergeTreeSettings().get("index_granularity").get<UInt64>()},
limit_restriction{context->getSettings().get("max_limit_for_ann_queries").get<UInt64>()},
index_is_useful{checkQueryStructure(query_info)} {}
bool ANNCondition::alwaysUnknownOrTrue(String metric_name) const
bool ApproximateNearestNeighborCondition::alwaysUnknownOrTrue(String metric) const
{
if (!index_is_useful)
{
return true; // Query isn't supported
}
// If query is supported, check metrics for match
return !(castMetricFromStringToType(metric_name) == query_information->metric);
return !(stringToMetric(metric) == query_information->metric);
}
float ANNCondition::getComparisonDistanceForWhereQuery() const
float ApproximateNearestNeighborCondition::getComparisonDistanceForWhereQuery() const
{
if (index_is_useful && query_information.has_value()
&& query_information->query_type == ANNQueryInformation::Type::Where)
{
&& query_information->type == ApproximateNearestNeighborInformation::Type::Where)
return query_information->distance;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Not supported method for this query type");
}
UInt64 ANNCondition::getLimit() const
UInt64 ApproximateNearestNeighborCondition::getLimit() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->limit;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "No LIMIT section in query, not supported");
}
std::vector<float> ANNCondition::getTargetVector() const
std::vector<float> ApproximateNearestNeighborCondition::getReferenceVector() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->target;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Target vector was requested for useless or uninitialized index.");
return query_information->reference_vector;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Reference vector was requested for useless or uninitialized index.");
}
size_t ANNCondition::getNumOfDimensions() const
size_t ApproximateNearestNeighborCondition::getNumOfDimensions() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->target.size();
}
return query_information->reference_vector.size();
throw Exception(ErrorCodes::LOGICAL_ERROR, "Number of dimensions was requested for useless or uninitialized index.");
}
String ANNCondition::getColumnName() const
String ApproximateNearestNeighborCondition::getColumnName() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->column_name;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Column name was requested for useless or uninitialized index.");
}
ANNQueryInformation::Metric ANNCondition::getMetricType() const
ApproximateNearestNeighborInformation::Metric ApproximateNearestNeighborCondition::getMetricType() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->metric;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "Metric name was requested for useless or uninitialized index.");
}
float ANNCondition::getPValueForLpDistance() const
float ApproximateNearestNeighborCondition::getPValueForLpDistance() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->p_for_lp_dist;
}
throw Exception(ErrorCodes::LOGICAL_ERROR, "P from LPDistance was requested for useless or uninitialized index.");
}
ANNQueryInformation::Type ANNCondition::getQueryType() const
ApproximateNearestNeighborInformation::Type ApproximateNearestNeighborCondition::getQueryType() const
{
if (index_is_useful && query_information.has_value())
{
return query_information->query_type;
}
return query_information->type;
throw Exception(ErrorCodes::LOGICAL_ERROR, "Query type was requested for useless or uninitialized index.");
}
bool ANNCondition::checkQueryStructure(const SelectQueryInfo & query)
bool ApproximateNearestNeighborCondition::checkQueryStructure(const SelectQueryInfo & query)
{
// RPN-s for different sections of the query
/// RPN-s for different sections of the query
RPN rpn_prewhere_clause;
RPN rpn_where_clause;
RPN rpn_order_by_clause;
RPNElement rpn_limit;
UInt64 limit;
ANNQueryInformation prewhere_info;
ANNQueryInformation where_info;
ANNQueryInformation order_by_info;
ApproximateNearestNeighborInformation prewhere_info;
ApproximateNearestNeighborInformation where_info;
ApproximateNearestNeighborInformation order_by_info;
// Build rpns for query sections
/// Build rpns for query sections
const auto & select = query.query->as<ASTSelectQuery &>();
if (select.prewhere()) // If query has PREWHERE clause
{
/// If query has PREWHERE clause
if (select.prewhere())
traverseAST(select.prewhere(), rpn_prewhere_clause);
}
if (select.where()) // If query has WHERE clause
{
/// If query has WHERE clause
if (select.where())
traverseAST(select.where(), rpn_where_clause);
}
if (select.limitLength()) // If query has LIMIT clause
{
/// If query has LIMIT clause
if (select.limitLength())
traverseAtomAST(select.limitLength(), rpn_limit);
}
if (select.orderBy()) // If query has ORDERBY clause
{
traverseOrderByAST(select.orderBy(), rpn_order_by_clause);
}
// Reverse RPNs for conveniences during parsing
/// Reverse RPNs for conveniences during parsing
std::reverse(rpn_prewhere_clause.begin(), rpn_prewhere_clause.end());
std::reverse(rpn_where_clause.begin(), rpn_where_clause.end());
std::reverse(rpn_order_by_clause.begin(), rpn_order_by_clause.end());
// Match rpns with supported types and extract information
/// Match rpns with supported types and extract information
const bool prewhere_is_valid = matchRPNWhere(rpn_prewhere_clause, prewhere_info);
const bool where_is_valid = matchRPNWhere(rpn_where_clause, where_info);
const bool order_by_is_valid = matchRPNOrderBy(rpn_order_by_clause, order_by_info);
const bool limit_is_valid = matchRPNLimit(rpn_limit, limit);
// Query without a LIMIT clause or with a limit greater than a restriction is not supported
if (!limit_is_valid || limit_restriction < limit)
{
/// Query without a LIMIT clause or with a limit greater than a restriction is not supported
if (!limit_is_valid || max_limit_for_ann_queries < limit)
return false;
}
// Search type query in both sections isn't supported
/// Search type query in both sections isn't supported
if (prewhere_is_valid && where_is_valid)
{
return false;
}
// Search type should be in WHERE or PREWHERE clause
/// Search type should be in WHERE or PREWHERE clause
if (prewhere_is_valid || where_is_valid)
{
query_information = std::move(prewhere_is_valid ? prewhere_info : where_info);
}
if (order_by_is_valid)
{
// Query with valid where and order by type is not supported
/// Query with valid where and order by type is not supported
if (query_information.has_value())
{
return false;
}
query_information = std::move(order_by_info);
}
@ -236,7 +192,7 @@ bool ANNCondition::checkQueryStructure(const SelectQueryInfo & query)
return query_information.has_value();
}
void ANNCondition::traverseAST(const ASTPtr & node, RPN & rpn)
void ApproximateNearestNeighborCondition::traverseAST(const ASTPtr & node, RPN & rpn)
{
// If the node is ASTFunction, it may have children nodes
if (const auto * func = node->as<ASTFunction>())
@ -244,27 +200,23 @@ void ANNCondition::traverseAST(const ASTPtr & node, RPN & rpn)
const ASTs & children = func->arguments->children;
// Traverse children nodes
for (const auto& child : children)
{
traverseAST(child, rpn);
}
}
RPNElement element;
// Get the data behind node
/// Get the data behind node
if (!traverseAtomAST(node, element))
{
element.function = RPNElement::FUNCTION_UNKNOWN;
}
rpn.emplace_back(std::move(element));
}
bool ANNCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
bool ApproximateNearestNeighborCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
{
// Match Functions
/// Match Functions
if (const auto * function = node->as<ASTFunction>())
{
// Set the name
/// Set the name
out.func_name = function->name;
if (function->name == "L1Distance" ||
@ -273,36 +225,24 @@ bool ANNCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
function->name == "cosineDistance" ||
function->name == "dotProduct" ||
function->name == "LpDistance")
{
out.function = RPNElement::FUNCTION_DISTANCE;
}
else if (function->name == "tuple")
{
out.function = RPNElement::FUNCTION_TUPLE;
}
else if (function->name == "array")
{
out.function = RPNElement::FUNCTION_ARRAY;
}
else if (function->name == "less" ||
function->name == "greater" ||
function->name == "lessOrEquals" ||
function->name == "greaterOrEquals")
{
out.function = RPNElement::FUNCTION_COMPARISON;
}
else if (function->name == "_CAST")
{
out.function = RPNElement::FUNCTION_CAST;
}
else
{
return false;
}
return true;
}
// Match identifier
/// Match identifier
else if (const auto * identifier = node->as<ASTIdentifier>())
{
out.function = RPNElement::FUNCTION_IDENTIFIER;
@ -312,11 +252,11 @@ bool ANNCondition::traverseAtomAST(const ASTPtr & node, RPNElement & out)
return true;
}
// Check if we have constants behind the node
/// Check if we have constants behind the node
return tryCastToConstType(node, out);
}
bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
bool ApproximateNearestNeighborCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
{
Field const_value;
DataTypePtr const_type;
@ -375,37 +315,29 @@ bool ANNCondition::tryCastToConstType(const ASTPtr & node, RPNElement & out)
return false;
}
void ANNCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
void ApproximateNearestNeighborCondition::traverseOrderByAST(const ASTPtr & node, RPN & rpn)
{
if (const auto * expr_list = node->as<ASTExpressionList>())
{
if (const auto * order_by_element = expr_list->children.front()->as<ASTOrderByElement>())
{
traverseAST(order_by_element->children.front(), rpn);
}
}
}
// Returns true and stores ANNQueryInformation if the query has valid WHERE clause
bool ANNCondition::matchRPNWhere(RPN & rpn, ANNQueryInformation & expr)
/// Returns true and stores ApproximateNearestNeighborInformation if the query has valid WHERE clause
bool ApproximateNearestNeighborCondition::matchRPNWhere(RPN & rpn, ApproximateNearestNeighborInformation & ann_info)
{
/// Fill query type field
expr.query_type = ANNQueryInformation::Type::Where;
ann_info.type = ApproximateNearestNeighborInformation::Type::Where;
// WHERE section must have at least 5 expressions
// Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(TargetVector(floats))
/// WHERE section must have at least 5 expressions
/// Operator->Distance(float)->DistanceFunc->Column->Tuple(Array)Func(ReferenceVector(floats))
if (rpn.size() < 5)
{
return false;
}
auto iter = rpn.begin();
// Query starts from operator less
/// Query starts from operator less
if (iter->function != RPNElement::FUNCTION_COMPARISON)
{
return false;
}
const bool greater_case = iter->func_name == "greater" || iter->func_name == "greaterOrEquals";
const bool less_case = iter->func_name == "less" || iter->func_name == "lessOrEquals";
@ -415,64 +347,54 @@ bool ANNCondition::matchRPNWhere(RPN & rpn, ANNQueryInformation & expr)
if (less_case)
{
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL)
{
return false;
}
expr.distance = getFloatOrIntLiteralOrPanic(iter);
if (expr.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", expr.distance);
ann_info.distance = getFloatOrIntLiteralOrPanic(iter);
if (ann_info.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", ann_info.distance);
++iter;
}
else if (!greater_case)
{
return false;
}
auto end = rpn.end();
if (!matchMainParts(iter, end, expr))
{
if (!matchMainParts(iter, end, ann_info))
return false;
}
if (greater_case)
{
if (expr.target.size() < 2)
{
if (ann_info.reference_vector.size() < 2)
return false;
}
expr.distance = expr.target.back();
if (expr.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", expr.distance);
expr.target.pop_back();
ann_info.distance = ann_info.reference_vector.back();
if (ann_info.distance < 0)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance can't be negative. Got {}", ann_info.distance);
ann_info.reference_vector.pop_back();
}
// query is ok
/// query is ok
return true;
}
// Returns true and stores ANNExpr if the query has valid ORDERBY clause
bool ANNCondition::matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr)
/// Returns true and stores ANNExpr if the query has valid ORDERBY clause
bool ApproximateNearestNeighborCondition::matchRPNOrderBy(RPN & rpn, ApproximateNearestNeighborInformation & ann_info)
{
/// Fill query type field
expr.query_type = ANNQueryInformation::Type::OrderBy;
ann_info.type = ApproximateNearestNeighborInformation::Type::OrderBy;
// ORDER BY clause must have at least 3 expressions
if (rpn.size() < 3)
{
return false;
}
auto iter = rpn.begin();
auto end = rpn.end();
return ANNCondition::matchMainParts(iter, end, expr);
return ApproximateNearestNeighborCondition::matchMainParts(iter, end, ann_info);
}
// Returns true and stores Length if we have valid LIMIT clause in query
bool ANNCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
/// Returns true and stores Length if we have valid LIMIT clause in query
bool ApproximateNearestNeighborCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
{
if (rpn.function == RPNElement::FUNCTION_INT_LITERAL)
{
@ -483,52 +405,46 @@ bool ANNCondition::matchRPNLimit(RPNElement & rpn, UInt64 & limit)
return false;
}
/* Matches dist function, target vector, column name */
bool ANNCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ANNQueryInformation & expr)
/// Matches dist function, referencer vector, column name
bool ApproximateNearestNeighborCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info)
{
bool identifier_found = false;
// Matches DistanceFunc->[Column]->[Tuple(array)Func]->TargetVector(floats)->[Column]
/// Matches DistanceFunc->[Column]->[Tuple(array)Func]->ReferenceVector(floats)->[Column]
if (iter->function != RPNElement::FUNCTION_DISTANCE)
{
return false;
}
expr.metric = castMetricFromStringToType(iter->func_name);
ann_info.metric = stringToMetric(iter->func_name);
++iter;
if (expr.metric == ANN::ANNQueryInformation::Metric::Lp)
if (ann_info.metric == ApproximateNearestNeighborInformation::Metric::Lp)
{
if (iter->function != RPNElement::FUNCTION_FLOAT_LITERAL &&
iter->function != RPNElement::FUNCTION_INT_LITERAL)
{
return false;
}
expr.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter);
ann_info.p_for_lp_dist = getFloatOrIntLiteralOrPanic(iter);
++iter;
}
if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
identifier_found = true;
expr.column_name = std::move(iter->identifier.value());
ann_info.column_name = std::move(iter->identifier.value());
++iter;
}
if (iter->function == RPNElement::FUNCTION_TUPLE || iter->function == RPNElement::FUNCTION_ARRAY)
{
++iter;
}
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
{
extractTargetVectorFromLiteral(expr.target, iter->tuple_literal);
extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
++iter;
}
if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractTargetVectorFromLiteral(expr.target, iter->array_literal);
extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
++iter;
}
@ -539,68 +455,52 @@ bool ANNCondition::matchMainParts(RPN::iterator & iter, const RPN::iterator & en
++iter;
/// Cast should be made to array or tuple
if (!iter->func_name.starts_with("Array") && !iter->func_name.starts_with("Tuple"))
{
return false;
}
++iter;
if (iter->function == RPNElement::FUNCTION_LITERAL_TUPLE)
{
extractTargetVectorFromLiteral(expr.target, iter->tuple_literal);
extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->tuple_literal);
++iter;
}
else if (iter->function == RPNElement::FUNCTION_LITERAL_ARRAY)
{
extractTargetVectorFromLiteral(expr.target, iter->array_literal);
extractReferenceVectorFromLiteral(ann_info.reference_vector, iter->array_literal);
++iter;
}
else
{
return false;
}
}
while (iter != end)
{
if (iter->function == RPNElement::FUNCTION_FLOAT_LITERAL ||
iter->function == RPNElement::FUNCTION_INT_LITERAL)
{
expr.target.emplace_back(getFloatOrIntLiteralOrPanic(iter));
}
ann_info.reference_vector.emplace_back(getFloatOrIntLiteralOrPanic(iter));
else if (iter->function == RPNElement::FUNCTION_IDENTIFIER)
{
if (identifier_found)
{
return false;
}
expr.column_name = std::move(iter->identifier.value());
ann_info.column_name = std::move(iter->identifier.value());
identifier_found = true;
}
else
{
return false;
}
++iter;
}
// Final checks of correctness
return identifier_found && !expr.target.empty();
/// Final checks of correctness
return identifier_found && !ann_info.reference_vector.empty();
}
// Gets float or int from AST node
float ANNCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
/// Gets float or int from AST node
float ApproximateNearestNeighborCondition::getFloatOrIntLiteralOrPanic(const RPN::iterator& iter)
{
if (iter->float_literal.has_value())
{
return iter->float_literal.value();
}
if (iter->int_literal.has_value())
{
return static_cast<float>(iter->int_literal.value());
}
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong parsed AST in buildRPN\n");
}
}
}

View File

@ -0,0 +1,223 @@
#pragma once
#include <Storages/MergeTree/MergeTreeIndices.h>
#include "base/types.h"
#include <optional>
#include <vector>
namespace DB
{
/// Approximate Nearest Neighbour queries have a similar structure:
/// - reference vector from which all distances are calculated
/// - metric name (e.g L2Distance, LpDistance, etc.)
/// - name of column with embeddings
/// - type of query
/// - maximum number of returned elements (LIMIT)
///
/// And two optional parameters:
/// - p for LpDistance function
/// - distance to compare with (only for where queries)
///
/// This struct holds all these components.
struct ApproximateNearestNeighborInformation
{
using Embedding = std::vector<float>;
Embedding reference_vector;
enum class Metric
{
Unknown,
L2,
Lp
};
Metric metric;
String column_name;
UInt64 limit;
enum class Type
{
OrderBy,
Where
};
Type type;
float p_for_lp_dist = -1.0;
float distance = -1.0;
};
// Class ANNCondition, is responsible for recognizing if the query is an ANN queries which can utilize ANN indexes. It parses the SQL query
/// and checks if it matches ANNIndexes. Method alwaysUnknownOrTrue returns false if we can speed up the query, and true otherwise. It has
/// only one argument, the name of the metric with which index was built. Two main patterns of queries are supported
///
/// - 1. WHERE queries:
/// SELECT * FROM * WHERE DistanceFunc(column, reference_vector) < floatLiteral LIMIT count
///
/// - 2. ORDER BY queries:
/// SELECT * FROM * WHERE * ORDER BY DistanceFunc(column, reference_vector) LIMIT count
///
/// Queries without LIMIT count are not supported
/// If the query is both of type 1. and 2., than we can't use the index and alwaysUnknownOrTrue returns true.
/// reference_vector should have float coordinates, e.g. (0.2, 0.1, .., 0.5)
///
/// If the query matches one of these two types, then this class extracts the main information needed for ANN indexes from the query.
///
/// From matching query it extracts
/// - referenceVector
/// - metricName(DistanceFunction)
/// - dimension size if query uses LpDistance
/// - distance to compare(ONLY for search types, otherwise you get exception)
/// - spaceDimension(which is referenceVector's components count)
/// - column
/// - objects count from LIMIT clause(for both queries)
/// - queryHasOrderByClause and queryHasWhereClause return true if query matches the type
///
/// Search query type is also recognized for PREWHERE clause
class ApproximateNearestNeighborCondition
{
public:
ApproximateNearestNeighborCondition(const SelectQueryInfo & query_info, ContextPtr context);
/// Returns false if query can be speeded up by an ANN index, true otherwise.
bool alwaysUnknownOrTrue(String metric) const;
/// Returns the distance to compare with for search query
float getComparisonDistanceForWhereQuery() const;
/// Distance should be calculated regarding to referenceVector
std::vector<float> getReferenceVector() const;
/// Reference vector's dimension size
size_t getNumOfDimensions() const;
String getColumnName() const;
ApproximateNearestNeighborInformation::Metric getMetricType() const;
/// The P- value if the metric is 'LpDistance'
float getPValueForLpDistance() const;
ApproximateNearestNeighborInformation::Type getQueryType() const;
UInt64 getIndexGranularity() const { return index_granularity; }
/// Length's value from LIMIT clause
UInt64 getLimit() const;
private:
struct RPNElement
{
enum Function
{
/// DistanceFunctions
FUNCTION_DISTANCE,
//tuple(0.1, ..., 0.1)
FUNCTION_TUPLE,
//array(0.1, ..., 0.1)
FUNCTION_ARRAY,
/// Operators <, >, <=, >=
FUNCTION_COMPARISON,
/// Numeric float value
FUNCTION_FLOAT_LITERAL,
/// Numeric int value
FUNCTION_INT_LITERAL,
/// Column identifier
FUNCTION_IDENTIFIER,
/// Unknown, can be any value
FUNCTION_UNKNOWN,
/// (0.1, ...., 0.1) vector without word 'tuple'
FUNCTION_LITERAL_TUPLE,
/// [0.1, ...., 0.1] vector without word 'array'
FUNCTION_LITERAL_ARRAY,
/// if client parameters are used, cast will always be in the query
FUNCTION_CAST,
/// name of type in cast function
FUNCTION_STRING_LITERAL,
};
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
: function(function_)
, func_name("Unknown")
, float_literal(std::nullopt)
, identifier(std::nullopt)
{}
Function function;
String func_name;
std::optional<float> float_literal;
std::optional<String> identifier;
std::optional<int64_t> int_literal;
std::optional<Tuple> tuple_literal;
std::optional<Array> array_literal;
UInt32 dim = 0;
};
using RPN = std::vector<RPNElement>;
bool checkQueryStructure(const SelectQueryInfo & query);
/// Util functions for the traversal of AST, parses AST and builds rpn
void traverseAST(const ASTPtr & node, RPN & rpn);
/// Return true if we can identify our node type
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
/// Checks if the AST stores ConstType expression
bool tryCastToConstType(const ASTPtr & node, RPNElement & out);
/// Traverses the AST of ORDERBY section
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
/// Returns true and stores ANNExpr if the query has valid WHERE section
static bool matchRPNWhere(RPN & rpn, ApproximateNearestNeighborInformation & ann_info);
/// Returns true and stores ANNExpr if the query has valid ORDERBY section
static bool matchRPNOrderBy(RPN & rpn, ApproximateNearestNeighborInformation & ann_info);
/// Returns true and stores Length if we have valid LIMIT clause in query
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
/* Matches dist function, reference vector, column name */
static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ApproximateNearestNeighborInformation & ann_info);
/// Gets float or int from AST node
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
Block block_with_constants;
/// true if we have one of two supported query types
std::optional<ApproximateNearestNeighborInformation> query_information;
// Get from settings ANNIndex parameters
const UInt64 index_granularity;
/// only queries with a lower limit can be considered to avoid memory overflow
const UInt64 max_limit_for_ann_queries;
bool index_is_useful = false;
};
/// Common interface of ANN indexes.
class IMergeTreeIndexConditionApproximateNearestNeighbor : public IMergeTreeIndexCondition
{
public:
/// Returns vector of indexes of ranges in granule which are useful for query.
virtual std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const = 0;
};
}

View File

@ -1,236 +0,0 @@
#pragma once
#include <Storages/MergeTree/MergeTreeIndices.h>
#include "base/types.h"
#include <optional>
#include <vector>
namespace DB
{
namespace ApproximateNearestNeighbour
{
/**
* Queries for Approximate Nearest Neighbour Search
* have similar structure:
* 1) target vector from which all distances are calculated
* 2) metric name (e.g L2Distance, LpDistance, etc.)
* 3) name of column with embeddings
* 4) type of query
* 5) Number of elements, that should be taken (limit)
*
* And two optional parameters:
* 1) p for LpDistance function
* 2) distance to compare with (only for where queries)
*/
struct ANNQueryInformation
{
using Embedding = std::vector<float>;
// Extracted data from valid query
Embedding target;
enum class Metric
{
Unknown,
L2,
Lp
} metric;
String column_name;
UInt64 limit;
enum class Type
{
OrderBy,
Where
} query_type;
float p_for_lp_dist = -1.0;
float distance = -1.0;
};
/**
Class ANNCondition, is responsible for recognizing special query types which
can be speeded up by ANN Indexes. It parses the SQL query and checks
if it matches ANNIndexes. The recognizing method - alwaysUnknownOrTrue
returns false if we can speed up the query, and true otherwise.
It has only one argument, name of the metric with which index was built.
There are two main patterns of queries being supported
1) Search query type
SELECT * FROM * WHERE DistanceFunc(column, target_vector) < floatLiteral LIMIT count
2) OrderBy query type
SELECT * FROM * WHERE * ORDERBY DistanceFunc(column, target_vector) LIMIT count
*Query without LIMIT count is not supported*
target_vector(should have float coordinates) examples:
tuple(0.1, 0.1, ...., 0.1) or (0.1, 0.1, ...., 0.1)
[the word tuple is not needed]
If the query matches one of these two types, than the class extracts useful information
from the query. If the query has both 1 and 2 types, than we can't speed and alwaysUnknownOrTrue
returns true.
From matching query it extracts
* targetVector
* metricName(DistanceFunction)
* dimension size if query uses LpDistance
* distance to compare(ONLY for search types, otherwise you get exception)
* spaceDimension(which is targetVector's components count)
* column
* objects count from LIMIT clause(for both queries)
* settings str, if query has settings section with new 'ann_index_select_query_params' value,
than you can get the new value(empty by default) calling method getSettingsStr
* queryHasOrderByClause and queryHasWhereClause return true if query matches the type
Search query type is also recognized for PREWHERE clause
*/
class ANNCondition
{
public:
ANNCondition(const SelectQueryInfo & query_info,
ContextPtr context);
// false if query can be speeded up, true otherwise
bool alwaysUnknownOrTrue(String metric_name) const;
// returns the distance to compare with for search query
float getComparisonDistanceForWhereQuery() const;
// distance should be calculated regarding to targetVector
std::vector<float> getTargetVector() const;
// targetVector dimension size
size_t getNumOfDimensions() const;
String getColumnName() const;
ANNQueryInformation::Metric getMetricType() const;
// the P- value if the metric is 'LpDistance'
float getPValueForLpDistance() const;
ANNQueryInformation::Type getQueryType() const;
UInt64 getIndexGranularity() const { return index_granularity; }
// length's value from LIMIT clause
UInt64 getLimit() const;
// value of 'ann_index_select_query_params' if have in SETTINGS clause, empty string otherwise
String getParamsStr() const { return ann_index_select_query_params; }
private:
struct RPNElement
{
enum Function
{
// DistanceFunctions
FUNCTION_DISTANCE,
//tuple(0.1, ..., 0.1)
FUNCTION_TUPLE,
//array(0.1, ..., 0.1)
FUNCTION_ARRAY,
// Operators <, >, <=, >=
FUNCTION_COMPARISON,
// Numeric float value
FUNCTION_FLOAT_LITERAL,
// Numeric int value
FUNCTION_INT_LITERAL,
// Column identifier
FUNCTION_IDENTIFIER,
// Unknown, can be any value
FUNCTION_UNKNOWN,
// (0.1, ...., 0.1) vector without word 'tuple'
FUNCTION_LITERAL_TUPLE,
// [0.1, ...., 0.1] vector without word 'array'
FUNCTION_LITERAL_ARRAY,
// if client parameters are used, cast will always be in the query
FUNCTION_CAST,
// name of type in cast function
FUNCTION_STRING_LITERAL,
};
explicit RPNElement(Function function_ = FUNCTION_UNKNOWN)
: function(function_), func_name("Unknown"), float_literal(std::nullopt), identifier(std::nullopt) {}
Function function;
String func_name;
std::optional<float> float_literal;
std::optional<String> identifier;
std::optional<int64_t> int_literal;
std::optional<Tuple> tuple_literal;
std::optional<Array> array_literal;
UInt32 dim = 0;
};
using RPN = std::vector<RPNElement>;
bool checkQueryStructure(const SelectQueryInfo & query);
// Util functions for the traversal of AST, parses AST and builds rpn
void traverseAST(const ASTPtr & node, RPN & rpn);
// Return true if we can identify our node type
bool traverseAtomAST(const ASTPtr & node, RPNElement & out);
// Checks if the AST stores ConstType expression
bool tryCastToConstType(const ASTPtr & node, RPNElement & out);
// Traverses the AST of ORDERBY section
void traverseOrderByAST(const ASTPtr & node, RPN & rpn);
// Returns true and stores ANNExpr if the query has valid WHERE section
static bool matchRPNWhere(RPN & rpn, ANNQueryInformation & expr);
// Returns true and stores ANNExpr if the query has valid ORDERBY section
static bool matchRPNOrderBy(RPN & rpn, ANNQueryInformation & expr);
// Returns true and stores Length if we have valid LIMIT clause in query
static bool matchRPNLimit(RPNElement & rpn, UInt64 & limit);
/* Matches dist function, target vector, column name */
static bool matchMainParts(RPN::iterator & iter, const RPN::iterator & end, ANNQueryInformation & expr);
// Gets float or int from AST node
static float getFloatOrIntLiteralOrPanic(const RPN::iterator& iter);
Block block_with_constants;
// true if we have one of two supported query types
std::optional<ANNQueryInformation> query_information;
// Get from settings ANNIndex parameters
String ann_index_select_query_params;
UInt64 index_granularity;
/// only queries with a lower limit can be considered to avoid memory overflow
UInt64 limit_restriction;
bool index_is_useful = false;
};
// condition interface for Ann indexes. Returns vector of indexes of ranges in granule which are useful for query.
class IMergeTreeIndexConditionAnn : public IMergeTreeIndexCondition
{
public:
virtual std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const = 0;
};
}
}

View File

@ -7154,6 +7154,9 @@ QueryProcessingStage::Enum MergeTreeData::getQueryProcessingStage(
/// Parallel replicas
if (query_context->canUseParallelReplicasOnInitiator() && to_stage >= QueryProcessingStage::WithMergeableState)
{
if (!canUseParallelReplicasBasedOnPKAnalysis(query_context, storage_snapshot, query_info))
return QueryProcessingStage::Enum::FetchColumns;
/// ReplicatedMergeTree
if (supportsReplication())
return QueryProcessingStage::Enum::WithMergeableState;
@ -7179,6 +7182,42 @@ QueryProcessingStage::Enum MergeTreeData::getQueryProcessingStage(
}
bool MergeTreeData::canUseParallelReplicasBasedOnPKAnalysis(
ContextPtr query_context,
const StorageSnapshotPtr & storage_snapshot,
SelectQueryInfo & query_info) const
{
const auto & snapshot_data = assert_cast<const MergeTreeData::SnapshotData &>(*storage_snapshot->data);
const auto & parts = snapshot_data.parts;
MergeTreeDataSelectExecutor reader(*this);
auto result_ptr = reader.estimateNumMarksToRead(
parts,
query_info.prewhere_info,
storage_snapshot->getMetadataForQuery()->getColumns().getAll().getNames(),
storage_snapshot->metadata,
storage_snapshot->metadata,
query_info,
/*added_filter_nodes*/ActionDAGNodes{},
query_context,
query_context->getSettingsRef().max_threads);
if (result_ptr->error())
std::rethrow_exception(std::get<std::exception_ptr>(result_ptr->result));
LOG_TRACE(log, "Estimated number of granules to read is {}", result_ptr->marks());
bool decision = result_ptr->marks() >= query_context->getSettingsRef().parallel_replicas_min_number_of_granules_to_enable;
if (!decision)
LOG_DEBUG(log, "Parallel replicas will be disabled, because the estimated number of granules to read {} is less than the threshold which is {}",
result_ptr->marks(),
query_context->getSettingsRef().parallel_replicas_min_number_of_granules_to_enable);
return decision;
}
MergeTreeData & MergeTreeData::checkStructureAndGetMergeTreeData(IStorage & source_table, const StorageMetadataPtr & src_snapshot, const StorageMetadataPtr & my_snapshot) const
{
MergeTreeData * src_data = dynamic_cast<MergeTreeData *>(&source_table);

View File

@ -1536,6 +1536,13 @@ private:
static MutableDataPartPtr asMutableDeletingPart(const DataPartPtr & part);
mutable TemporaryParts temporary_parts;
/// Estimate the number of marks to read to make a decision whether to enable parallel replicas (distributed processing) or not
/// Note: it could be very rough.
bool canUseParallelReplicasBasedOnPKAnalysis(
ContextPtr query_context,
const StorageSnapshotPtr & storage_snapshot,
SelectQueryInfo & query_info) const;
};
/// RAII struct to record big parts that are submerging or emerging.

View File

@ -46,7 +46,7 @@
#include <IO/WriteBufferFromOStream.h>
#include <Storages/MergeTree/CommonANNIndexes.h>
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
namespace CurrentMetrics
{
@ -1714,17 +1714,14 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
{
if (index_mark != index_range.begin || !granule || last_index_mark != index_range.begin)
granule = reader.read();
const auto * gin_filter_condition = dynamic_cast<const MergeTreeConditionInverted *>(&*condition);
// Cast to Ann condition
auto ann_condition = std::dynamic_pointer_cast<ApproximateNearestNeighbour::IMergeTreeIndexConditionAnn>(condition);
auto ann_condition = std::dynamic_pointer_cast<IMergeTreeIndexConditionApproximateNearestNeighbor>(condition);
if (ann_condition != nullptr)
{
// vector of indexes of useful ranges
auto result = ann_condition->getUsefulRanges(granule);
if (result.empty())
{
++granules_dropped;
}
for (auto range : result)
{
@ -1742,6 +1739,7 @@ MarkRanges MergeTreeDataSelectExecutor::filterMarksUsingIndex(
}
bool result = false;
const auto * gin_filter_condition = dynamic_cast<const MergeTreeConditionInverted *>(&*condition);
if (!gin_filter_condition)
result = condition->mayBeTrueOnGranule(granule);
else

View File

@ -42,7 +42,7 @@ void MergeTreeIndexAggregatorBloomFilter::update(const Block & block, size_t * p
{
if (*pos >= block.rows())
throw Exception(ErrorCodes::LOGICAL_ERROR, "The provided position is not less than the number of block rows. "
"Position: {}, Block rows: {}.", toString(*pos), toString(block.rows()));
"Position: {}, Block rows: {}.", *pos, block.rows());
Block granule_index_block;
size_t max_read_rows = std::min(block.rows() - *pos, limit);

View File

@ -2,26 +2,40 @@
#include <Storages/MergeTree/MergeTreeIndexAnnoy.h>
#include <Columns/ColumnArray.h>
#include <Common/typeid_cast.h>
#include <Core/Field.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/Context.h>
#include <Interpreters/castColumn.h>
namespace DB
{
namespace ApproximateNearestNeighbour
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int INCORRECT_DATA;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int INCORRECT_QUERY;
extern const int LOGICAL_ERROR;
}
template<typename Dist>
void AnnoyIndex<Dist>::serialize(WriteBuffer& ostr) const
template <typename Distance>
AnnoyIndexWithSerialization<Distance>::AnnoyIndexWithSerialization(uint64_t dim)
: Base::AnnoyIndex(dim)
{
assert(Base::_built);
}
template<typename Distance>
void AnnoyIndexWithSerialization<Distance>::serialize(WriteBuffer& ostr) const
{
chassert(Base::_built);
writeIntBinary(Base::_s, ostr);
writeIntBinary(Base::_n_items, ostr);
writeIntBinary(Base::_n_nodes, ostr);
@ -32,10 +46,10 @@ void AnnoyIndex<Dist>::serialize(WriteBuffer& ostr) const
ostr.write(reinterpret_cast<const char*>(Base::_nodes), Base::_s * Base::_n_nodes);
}
template<typename Dist>
void AnnoyIndex<Dist>::deserialize(ReadBuffer& istr)
template<typename Distance>
void AnnoyIndexWithSerialization<Distance>::deserialize(ReadBuffer& istr)
{
assert(!Base::_built);
chassert(!Base::_built);
readIntBinary(Base::_s, istr);
readIntBinary(Base::_n_items, istr);
readIntBinary(Base::_n_nodes, istr);
@ -54,24 +68,12 @@ void AnnoyIndex<Dist>::deserialize(ReadBuffer& istr)
Base::_built = true;
}
template<typename Dist>
uint64_t AnnoyIndex<Dist>::getNumOfDimensions() const
template<typename Distance>
uint64_t AnnoyIndexWithSerialization<Distance>::getNumOfDimensions() const
{
return Base::get_f();
}
}
namespace ErrorCodes
{
extern const int ILLEGAL_COLUMN;
extern const int INCORRECT_DATA;
extern const int INCORRECT_NUMBER_OF_COLUMNS;
extern const int INCORRECT_QUERY;
extern const int LOGICAL_ERROR;
extern const int BAD_ARGUMENTS;
}
template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_)
@ -84,16 +86,16 @@ template <typename Distance>
MergeTreeIndexGranuleAnnoy<Distance>::MergeTreeIndexGranuleAnnoy(
const String & index_name_,
const Block & index_sample_block_,
AnnoyIndexPtr index_base_)
AnnoyIndexWithSerializationPtr<Distance> index_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, index(std::move(index_base_))
, index(std::move(index_))
{}
template <typename Distance>
void MergeTreeIndexGranuleAnnoy<Distance>::serializeBinary(WriteBuffer & ostr) const
{
/// number of dimensions is required in the constructor,
/// Number of dimensions is required in the index constructor,
/// so it must be written and read separately from the other part
writeIntBinary(index->getNumOfDimensions(), ostr); // write dimension
index->serialize(ostr);
@ -104,7 +106,7 @@ void MergeTreeIndexGranuleAnnoy<Distance>::deserializeBinary(ReadBuffer & istr,
{
uint64_t dimension;
readIntBinary(dimension, istr);
index = std::make_shared<AnnoyIndex>(dimension);
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(dimension);
index->deserialize(istr);
}
@ -112,18 +114,18 @@ template <typename Distance>
MergeTreeIndexAggregatorAnnoy<Distance>::MergeTreeIndexAggregatorAnnoy(
const String & index_name_,
const Block & index_sample_block_,
uint64_t number_of_trees_)
uint64_t trees_)
: index_name(index_name_)
, index_sample_block(index_sample_block_)
, number_of_trees(number_of_trees_)
, trees(trees_)
{}
template <typename Distance>
MergeTreeIndexGranulePtr MergeTreeIndexAggregatorAnnoy<Distance>::getGranuleAndReset()
{
// NOLINTNEXTLINE(*)
index->build(static_cast<int>(number_of_trees), /*number_of_threads=*/1);
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance> >(index_name, index_sample_block, index);
index->build(static_cast<int>(trees), /*number_of_threads=*/1);
auto granule = std::make_shared<MergeTreeIndexGranuleAnnoy<Distance>>(index_name, index_sample_block, index);
index = nullptr;
return granule;
}
@ -135,270 +137,255 @@ void MergeTreeIndexAggregatorAnnoy<Distance>::update(const Block & block, size_t
throw Exception(
ErrorCodes::LOGICAL_ERROR,
"The provided position is not less than the number of block rows. Position: {}, Block rows: {}.",
toString(*pos), toString(block.rows()));
*pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);
if (rows_read == 0)
return;
if (index_sample_block.columns() > 1)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Only one column is supported");
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected block with single column");
auto index_column_name = index_sample_block.getByPosition(0).name;
const auto & column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);
const auto & column_array = typeid_cast<const ColumnArray*>(column_cut.get());
if (column_array)
const String & index_column_name = index_sample_block.getByPosition(0).name;
ColumnPtr column_cut = block.getByName(index_column_name).column->cut(*pos, rows_read);
if (const auto & column_array = typeid_cast<const ColumnArray *>(column_cut.get()))
{
const auto & data = column_array->getData();
const auto & array = typeid_cast<const ColumnFloat32&>(data).getData();
const auto & array = typeid_cast<const ColumnFloat32 &>(data).getData();
if (array.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Array has 0 rows, {} rows expected", rows_read);
const auto & offsets = column_array->getOffsets();
size_t num_rows = offsets.size();
const size_t num_rows = offsets.size();
/// Check all sizes are the same
size_t size = offsets[0];
for (size_t i = 0; i < num_rows - 1; ++i)
if (offsets[i + 1] - offsets[i] != size)
throw Exception(ErrorCodes::INCORRECT_DATA, "Arrays should have same length");
throw Exception(ErrorCodes::INCORRECT_DATA, "All arrays in column {} must have equal length", index_column_name);
index = std::make_shared<AnnoyIndex>(size);
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(size);
/// Add all rows of block
index->add_item(index->get_n_items(), array.data());
/// add all rows from 1 to num_rows - 1 (this is the same as the beginning of the last element)
for (size_t current_row = 1; current_row < num_rows; ++current_row)
index->add_item(index->get_n_items(), &array[offsets[current_row - 1]]);
}
else
else if (const auto & column_tuple = typeid_cast<const ColumnTuple *>(column_cut.get()))
{
/// Other possible type of column is Tuple
const auto & column_tuple = typeid_cast<const ColumnTuple*>(column_cut.get());
if (!column_tuple)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Wrong type was given to index.");
const auto & columns = column_tuple->getColumns();
/// TODO check if calling index->add_item() directly on the block's tuples is faster than materializing everything
std::vector<std::vector<Float32>> data{column_tuple->size(), std::vector<Float32>()};
for (const auto& column : columns)
for (const auto & column : columns)
{
const auto& pod_array = typeid_cast<const ColumnFloat32*>(column.get())->getData();
const auto & pod_array = typeid_cast<const ColumnFloat32 *>(column.get())->getData();
for (size_t i = 0; i < pod_array.size(); ++i)
data[i].push_back(pod_array[i]);
}
assert(!data.empty());
if (!index)
index = std::make_shared<AnnoyIndex>(data[0].size());
for (const auto& item : data)
if (data.empty())
throw Exception(ErrorCodes::LOGICAL_ERROR, "Tuple has 0 rows, {} rows expected", rows_read);
index = std::make_shared<AnnoyIndexWithSerialization<Distance>>(data[0].size());
for (const auto & item : data)
index->add_item(index->get_n_items(), item.data());
}
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Expected Array or Tuple column");
*pos += rows_read;
}
MergeTreeIndexConditionAnnoy::MergeTreeIndexConditionAnnoy(
const IndexDescription & /*index*/,
const IndexDescription & /*index_description*/,
const SelectQueryInfo & query,
ContextPtr context,
const String& distance_name_)
: condition(query, context), distance_name(distance_name_)
const String & distance_function_,
ContextPtr context)
: ann_condition(query, context)
, distance_function(distance_function_)
, search_k(context->getSettings().annoy_index_search_k_nodes)
{}
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /* idx_granule */) const
bool MergeTreeIndexConditionAnnoy::mayBeTrueOnGranule(MergeTreeIndexGranulePtr /*idx_granule*/) const
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "mayBeTrueOnGranule is not supported for ANN skip indexes");
}
bool MergeTreeIndexConditionAnnoy::alwaysUnknownOrTrue() const
{
return condition.alwaysUnknownOrTrue(distance_name);
return ann_condition.alwaysUnknownOrTrue(distance_function);
}
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const
{
if (distance_name == "L2Distance")
{
return getUsefulRangesImpl<::Annoy::Euclidean>(idx_granule);
}
else if (distance_name == "cosineDistance")
{
return getUsefulRangesImpl<::Annoy::Angular>(idx_granule);
}
else
{
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
}
if (distance_function == "L2Distance")
return getUsefulRangesImpl<Annoy::Euclidean>(idx_granule);
else if (distance_function == "cosineDistance")
return getUsefulRangesImpl<Annoy::Angular>(idx_granule);
std::unreachable();
}
template <typename Distance>
std::vector<size_t> MergeTreeIndexConditionAnnoy::getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const
{
UInt64 limit = condition.getLimit();
UInt64 index_granularity = condition.getIndexGranularity();
std::optional<float> comp_dist = condition.getQueryType() == ApproximateNearestNeighbour::ANNQueryInformation::Type::Where ?
std::optional<float>(condition.getComparisonDistanceForWhereQuery()) : std::nullopt;
const UInt64 limit = ann_condition.getLimit();
const UInt64 index_granularity = ann_condition.getIndexGranularity();
const std::optional<float> comparison_distance = ann_condition.getQueryType() == ApproximateNearestNeighborInformation::Type::Where
? std::optional<float>(ann_condition.getComparisonDistanceForWhereQuery())
: std::nullopt;
if (comp_dist && comp_dist.value() < 0)
if (comparison_distance && comparison_distance.value() < 0)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Attempt to optimize query with where without distance");
std::vector<float> target_vec = condition.getTargetVector();
const std::vector<float> reference_vector = ann_condition.getReferenceVector();
auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance> >(idx_granule);
const auto granule = std::dynamic_pointer_cast<MergeTreeIndexGranuleAnnoy<Distance>>(idx_granule);
if (granule == nullptr)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Granule has the wrong type");
auto annoy = granule->index;
const AnnoyIndexWithSerializationPtr<Distance> annoy = granule->index;
if (condition.getNumOfDimensions() != annoy->getNumOfDimensions())
if (ann_condition.getNumOfDimensions() != annoy->getNumOfDimensions())
throw Exception(ErrorCodes::INCORRECT_QUERY, "The dimension of the space in the request ({}) "
"does not match with the dimension in the index ({})",
toString(condition.getNumOfDimensions()), toString(annoy->getNumOfDimensions()));
"does not match the dimension in the index ({})",
ann_condition.getNumOfDimensions(), annoy->getNumOfDimensions());
/// neighbors contain indexes of dots which were closest to target vector
std::vector<UInt64> neighbors;
std::vector<UInt64> neighbors; /// indexes of dots which were closest to the reference vector
std::vector<Float32> distances;
neighbors.reserve(limit);
distances.reserve(limit);
int k_search = -1;
String params_str = condition.getParamsStr();
if (!params_str.empty())
{
try
{
/// k_search=... (algorithm will inspect up to search_k nodes which defaults to n_trees * n if not provided)
k_search = std::stoi(params_str.data() + 9);
}
catch (...)
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Setting of the annoy index should be int");
}
}
annoy->get_nns_by_vector(target_vec.data(), limit, k_search, &neighbors, &distances);
std::unordered_set<size_t> granule_numbers;
annoy->get_nns_by_vector(reference_vector.data(), limit, static_cast<int>(search_k), &neighbors, &distances);
chassert(neighbors.size() == distances.size());
std::vector<size_t> granule_numbers;
granule_numbers.reserve(neighbors.size());
for (size_t i = 0; i < neighbors.size(); ++i)
{
if (comp_dist && distances[i] > comp_dist)
if (comparison_distance && distances[i] > comparison_distance)
continue;
granule_numbers.insert(neighbors[i] / index_granularity);
granule_numbers.push_back(neighbors[i] / index_granularity);
}
std::vector<size_t> result_vector;
result_vector.reserve(granule_numbers.size());
for (auto granule_number : granule_numbers)
result_vector.push_back(granule_number);
/// make unique
std::sort(granule_numbers.begin(), granule_numbers.end());
granule_numbers.erase(std::unique(granule_numbers.begin(), granule_numbers.end()), granule_numbers.end());
return result_vector;
return granule_numbers;
}
MergeTreeIndexAnnoy::MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t trees_, const String & distance_function_)
: IMergeTreeIndex(index_)
, trees(trees_)
, distance_function(distance_function_)
{}
MergeTreeIndexGranulePtr MergeTreeIndexAnnoy::createIndexGranule() const
{
if (distance_name == "L2Distance")
{
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block);
}
if (distance_name == "cosineDistance")
{
return std::make_shared<MergeTreeIndexGranuleAnnoy<::Annoy::Angular> >(index.name, index.sample_block);
}
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
if (distance_function == "L2Distance")
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Euclidean>>(index.name, index.sample_block);
else if (distance_function == "cosineDistance")
return std::make_shared<MergeTreeIndexGranuleAnnoy<Annoy::Angular>>(index.name, index.sample_block);
std::unreachable();
}
MergeTreeIndexAggregatorPtr MergeTreeIndexAnnoy::createIndexAggregator() const
{
if (distance_name == "L2Distance")
{
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Euclidean> >(index.name, index.sample_block, number_of_trees);
}
if (distance_name == "cosineDistance")
{
return std::make_shared<MergeTreeIndexAggregatorAnnoy<::Annoy::Angular> >(index.name, index.sample_block, number_of_trees);
}
throw Exception(ErrorCodes::BAD_ARGUMENTS, "Unknown distance name. Must be 'L2Distance' or 'cosineDistance'. Got {}", distance_name);
/// TODO: Support more metrics. Available metrics: https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
if (distance_function == "L2Distance")
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Euclidean>>(index.name, index.sample_block, trees);
else if (distance_function == "cosineDistance")
return std::make_shared<MergeTreeIndexAggregatorAnnoy<Annoy::Angular>>(index.name, index.sample_block, trees);
std::unreachable();
}
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(
const SelectQueryInfo & query, ContextPtr context) const
MergeTreeIndexConditionPtr MergeTreeIndexAnnoy::createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const
{
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, context, distance_name);
return std::make_shared<MergeTreeIndexConditionAnnoy>(index, query, distance_function, context);
};
MergeTreeIndexPtr annoyIndexCreator(const IndexDescription & index)
{
uint64_t param = 100;
String distance_name = "L2Distance";
if (!index.arguments.empty() && !index.arguments[0].tryGet<uint64_t>(param))
{
if (!index.arguments[0].tryGet<String>(distance_name))
{
throw Exception(ErrorCodes::INCORRECT_DATA, "Can't parse first argument");
}
}
if (index.arguments.size() > 1 && !index.arguments[1].tryGet<String>(distance_name))
{
throw Exception(ErrorCodes::INCORRECT_DATA, "Can't parse second argument");
}
return std::make_shared<MergeTreeIndexAnnoy>(index, param, distance_name);
}
static constexpr auto default_trees = 100uz;
static constexpr auto default_distance_function = "L2Distance";
static void assertIndexColumnsType(const Block & header)
{
DataTypePtr column_data_type_ptr = header.getDataTypes()[0];
String distance_function = default_distance_function;
if (!index.arguments.empty())
distance_function = index.arguments[0].get<String>();
if (const auto * array_type = typeid_cast<const DataTypeArray *>(column_data_type_ptr.get()))
{
TypeIndex nested_type_index = array_type->getNestedType()->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported.",
column_data_type_ptr->getName());
}
else if (const auto * tuple_type = typeid_cast<const DataTypeTuple *>(column_data_type_ptr.get()))
{
const DataTypes & nested_types = tuple_type->getElements();
for (const auto & type : nested_types)
{
TypeIndex nested_type_index = type->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported.",
column_data_type_ptr->getName());
}
}
else
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Unexpected type {} of Annoy index. Only Array(Float32) and Tuple(Float32) are supported.",
column_data_type_ptr->getName());
uint64_t trees = default_trees;
if (index.arguments.size() > 1)
trees = index.arguments[1].get<uint64_t>();
return std::make_shared<MergeTreeIndexAnnoy>(index, trees, distance_function);
}
void annoyIndexValidator(const IndexDescription & index, bool /* attach */)
{
/// Check number and type of Annoy index arguments:
if (index.arguments.size() > 2)
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index must not have more than two parameters");
}
if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::UInt64
&& index.arguments[0].getType() != Field::Types::String)
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index first argument must be UInt64 or String.");
}
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::String)
{
throw Exception(ErrorCodes::INCORRECT_QUERY, "Annoy index second argument must be String.");
}
if (!index.arguments.empty() && index.arguments[0].getType() != Field::Types::String)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Distance function argument of Annoy index must be of type String");
if (index.arguments.size() > 1 && index.arguments[1].getType() != Field::Types::UInt64)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Number of trees argument of Annoy index must be UInt64");
/// Check that the index is created on a single column
if (index.column_names.size() != 1 || index.data_types.size() != 1)
throw Exception(ErrorCodes::INCORRECT_NUMBER_OF_COLUMNS, "Annoy indexes must be created on a single column");
assertIndexColumnsType(index.sample_block);
/// Check that a supported metric was passed as first argument
if (!index.arguments.empty())
{
String distance_name = index.arguments[0].get<String>();
if (distance_name != "L2Distance" && distance_name != "cosineDistance")
throw Exception(ErrorCodes::INCORRECT_DATA, "Annoy index supports only distance functions 'L2Distance' and 'cosineDistance'. Given distance function: {}", distance_name);
}
/// Check data type of indexed column:
auto throw_unsupported_underlying_column_exception = [](DataTypePtr data_type)
{
throw Exception(
ErrorCodes::ILLEGAL_COLUMN,
"Annoy indexes can only be created on columns of type Array(Float32) and Tuple(Float32). Given type: {}",
data_type->getName());
};
DataTypePtr data_type = index.sample_block.getDataTypes()[0];
if (const auto * data_type_array = typeid_cast<const DataTypeArray *>(data_type.get()))
{
TypeIndex nested_type_index = data_type_array->getNestedType()->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw_unsupported_underlying_column_exception(data_type);
}
else if (const auto * data_type_tuple = typeid_cast<const DataTypeTuple *>(data_type.get()))
{
const DataTypes & inner_types = data_type_tuple->getElements();
for (const auto & inner_type : inner_types)
{
TypeIndex nested_type_index = inner_type->getTypeId();
if (!WhichDataType(nested_type_index).isFloat32())
throw_unsupported_underlying_column_exception(data_type);
}
}
else
throw_unsupported_underlying_column_exception(data_type);
}
}
#endif // ENABLE_ANNOY
#endif

View File

@ -2,7 +2,7 @@
#ifdef ENABLE_ANNOY
#include <Storages/MergeTree/CommonANNIndexes.h>
#include <Storages/MergeTree/ApproximateNearestNeighborIndexesCommon.h>
#include <annoylib.h>
#include <kissrandom.h>
@ -10,36 +10,26 @@
namespace DB
{
// auxiliary namespace for working with spotify-annoy library
// mainly for serialization and deserialization of the index
namespace ApproximateNearestNeighbour
template <typename Distance>
class AnnoyIndexWithSerialization : public Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>
{
using AnnoyIndexThreadedBuildPolicy = ::Annoy::AnnoyIndexMultiThreadedBuildPolicy;
// TODO: Support different metrics. List of available metrics can be taken from here:
// https://github.com/spotify/annoy/blob/master/src/annoymodule.cc#L151-L171
template <typename Distance>
class AnnoyIndex : public ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>
{
using Base = ::Annoy::AnnoyIndex<UInt64, Float32, Distance, ::Annoy::Kiss64Random, AnnoyIndexThreadedBuildPolicy>;
public:
explicit AnnoyIndex(const uint64_t dim) : Base::AnnoyIndex(dim) {}
using Base = Annoy::AnnoyIndex<UInt64, Float32, Distance, Annoy::Kiss64Random, Annoy::AnnoyIndexMultiThreadedBuildPolicy>;
public:
explicit AnnoyIndexWithSerialization(uint64_t dim);
void serialize(WriteBuffer& ostr) const;
void deserialize(ReadBuffer& istr);
uint64_t getNumOfDimensions() const;
};
}
};
template <typename Distance>
using AnnoyIndexWithSerializationPtr = std::shared_ptr<AnnoyIndexWithSerialization<Distance>>;
template <typename Distance>
struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
{
using AnnoyIndex = ApproximateNearestNeighbour::AnnoyIndex<Distance>;
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_);
MergeTreeIndexGranuleAnnoy(
const String & index_name_,
const Block & index_sample_block_,
AnnoyIndexPtr index_base_);
MergeTreeIndexGranuleAnnoy(const String & index_name_, const Block & index_sample_block_, AnnoyIndexWithSerializationPtr<Distance> index_);
~MergeTreeIndexGranuleAnnoy() override = default;
@ -48,54 +38,50 @@ struct MergeTreeIndexGranuleAnnoy final : public IMergeTreeIndexGranule
bool empty() const override { return !index.get(); }
String index_name;
Block index_sample_block;
AnnoyIndexPtr index;
const String index_name;
const Block index_sample_block;
AnnoyIndexWithSerializationPtr<Distance> index;
};
template <typename Distance>
struct MergeTreeIndexAggregatorAnnoy final : IMergeTreeIndexAggregator
{
using AnnoyIndex = ApproximateNearestNeighbour::AnnoyIndex<Distance>;
using AnnoyIndexPtr = std::shared_ptr<AnnoyIndex>;
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t number_of_trees);
MergeTreeIndexAggregatorAnnoy(const String & index_name_, const Block & index_sample_block, uint64_t trees);
~MergeTreeIndexAggregatorAnnoy() override = default;
bool empty() const override { return !index || index->get_n_items() == 0; }
MergeTreeIndexGranulePtr getGranuleAndReset() override;
void update(const Block & block, size_t * pos, size_t limit) override;
String index_name;
Block index_sample_block;
const uint64_t number_of_trees;
AnnoyIndexPtr index;
const String index_name;
const Block index_sample_block;
const uint64_t trees;
AnnoyIndexWithSerializationPtr<Distance> index;
};
class MergeTreeIndexConditionAnnoy final : public ApproximateNearestNeighbour::IMergeTreeIndexConditionAnn
class MergeTreeIndexConditionAnnoy final : public IMergeTreeIndexConditionApproximateNearestNeighbor
{
public:
MergeTreeIndexConditionAnnoy(
const IndexDescription & index,
const IndexDescription & index_description,
const SelectQueryInfo & query,
ContextPtr context,
const String& distance_name);
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
const String & distance_function,
ContextPtr context);
~MergeTreeIndexConditionAnnoy() override = default;
bool alwaysUnknownOrTrue() const override;
bool mayBeTrueOnGranule(MergeTreeIndexGranulePtr idx_granule) const override;
std::vector<size_t> getUsefulRanges(MergeTreeIndexGranulePtr idx_granule) const override;
private:
template <typename Distance>
std::vector<size_t> getUsefulRangesImpl(MergeTreeIndexGranulePtr idx_granule) const;
ApproximateNearestNeighbour::ANNCondition condition;
const String distance_name;
const ApproximateNearestNeighborCondition ann_condition;
const String distance_function;
const Int64 search_k;
};
@ -103,28 +89,22 @@ class MergeTreeIndexAnnoy : public IMergeTreeIndex
{
public:
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t number_of_trees_, const String& distance_name_)
: IMergeTreeIndex(index_)
, number_of_trees(number_of_trees_)
, distance_name(distance_name_)
{}
MergeTreeIndexAnnoy(const IndexDescription & index_, uint64_t trees_, const String & distance_function_);
~MergeTreeIndexAnnoy() override = default;
MergeTreeIndexGranulePtr createIndexGranule() const override;
MergeTreeIndexAggregatorPtr createIndexAggregator() const override;
MergeTreeIndexConditionPtr createIndexCondition(
const SelectQueryInfo & query, ContextPtr context) const override;
MergeTreeIndexConditionPtr createIndexCondition(const SelectQueryInfo & query, ContextPtr context) const override;
bool mayBenefitFromIndexForIn(const ASTPtr & /*node*/) const override { return false; }
private:
const uint64_t number_of_trees;
const String distance_name;
const uint64_t trees;
const String distance_function;
};
}
#endif // ENABLE_ANNOY
#endif

View File

@ -92,7 +92,7 @@ void MergeTreeIndexAggregatorFullText::update(const Block & block, size_t * pos,
{
if (*pos >= block.rows())
throw Exception(ErrorCodes::LOGICAL_ERROR, "The provided position is not less than the number of block rows. "
"Position: {}, Block rows: {}.", toString(*pos), toString(block.rows()));
"Position: {}, Block rows: {}.", *pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);

View File

@ -123,7 +123,7 @@ void MergeTreeIndexAggregatorInverted::update(const Block & block, size_t * pos,
{
if (*pos >= block.rows())
throw Exception(ErrorCodes::LOGICAL_ERROR, "The provided position is not less than the number of block rows. "
"Position: {}, Block rows: {}.", toString(*pos), toString(block.rows()));
"Position: {}, Block rows: {}.", *pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);
auto row_id = store->getNextRowIDRange(rows_read);

View File

@ -122,7 +122,7 @@ void MergeTreeIndexAggregatorMinMax::update(const Block & block, size_t * pos, s
{
if (*pos >= block.rows())
throw Exception(ErrorCodes::LOGICAL_ERROR, "The provided position is not less than the number of block rows. "
"Position: {}, Block rows: {}.", toString(*pos), toString(block.rows()));
"Position: {}, Block rows: {}.", *pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);

View File

@ -146,7 +146,7 @@ void MergeTreeIndexAggregatorSet::update(const Block & block, size_t * pos, size
{
if (*pos >= block.rows())
throw Exception(ErrorCodes::LOGICAL_ERROR, "The provided position is not less than the number of block rows. "
"Position: {}, Block rows: {}.", toString(*pos), toString(block.rows()));
"Position: {}, Block rows: {}.", *pos, block.rows());
size_t rows_read = std::min(limit, block.rows() - *pos);

View File

@ -353,7 +353,7 @@ void StorageNATS::read(
}
SinkToStoragePtr StorageNATS::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageNATS::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
auto modified_context = addSettings(local_context);
std::string subject = modified_context->getSettingsRef().stream_like_engine_insert_queue.changed

View File

@ -51,7 +51,7 @@ public:
size_t /* max_block_size */,
size_t /* num_streams */) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override;
/// We want to control the number of rows in a chunk inserted into NATS
bool prefersLargeBlocks() const override { return false; }

View File

@ -764,7 +764,7 @@ void StorageRabbitMQ::read(
}
SinkToStoragePtr StorageRabbitMQ::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageRabbitMQ::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
auto producer = std::make_unique<RabbitMQProducer>(
configuration, routing_keys, exchange_name, exchange_type, producer_id.fetch_add(1), persistent, shutdown_called, log);

View File

@ -57,7 +57,8 @@ public:
SinkToStoragePtr write(
const ASTPtr & query,
const StorageMetadataPtr & metadata_snapshot,
ContextPtr context) override;
ContextPtr context,
bool async_insert) override;
/// We want to control the number of rows in a chunk inserted into RabbitMQ
bool prefersLargeBlocks() const override { return false; }

View File

@ -461,7 +461,7 @@ Pipe StorageEmbeddedRocksDB::read(
}
SinkToStoragePtr StorageEmbeddedRocksDB::write(
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/)
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/, bool /*async_insert*/)
{
return std::make_shared<EmbeddedRocksDBSink>(*this, metadata_snapshot);
}

View File

@ -48,7 +48,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;
void truncate(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr, TableExclusiveLockHolder &) override;
void checkMutationIsPossible(const MutationCommands & commands, const Settings & settings) const override;

View File

@ -656,7 +656,7 @@ private:
};
SinkToStoragePtr StorageBuffer::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/)
SinkToStoragePtr StorageBuffer::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /*context*/, bool /*async_insert*/)
{
return std::make_shared<BufferSink>(*this, metadata_snapshot);
}

View File

@ -88,7 +88,7 @@ public:
bool supportsSubcolumns() const override { return true; }
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool /*async_insert*/) override;
void startup() override;
/// Flush all buffers into the subordinate table and stop background thread.

View File

@ -847,7 +847,7 @@ private:
/** Execute subquery node and put result in mutable context temporary table.
* Returns table node that is initialized with temporary table storage.
*/
QueryTreeNodePtr executeSubqueryNode(const QueryTreeNodePtr & subquery_node,
TableNodePtr executeSubqueryNode(const QueryTreeNodePtr & subquery_node,
ContextMutablePtr & mutable_context,
size_t subquery_depth)
{
@ -897,7 +897,7 @@ QueryTreeNodePtr executeSubqueryNode(const QueryTreeNodePtr & subquery_node,
auto temporary_table_expression_node = std::make_shared<TableNode>(external_storage, mutable_context);
temporary_table_expression_node->setTemporaryTableName(temporary_table_name);
auto table_out = external_storage->write({}, external_storage->getInMemoryMetadataPtr(), mutable_context);
auto table_out = external_storage->write({}, external_storage->getInMemoryMetadataPtr(), mutable_context, /*async_insert=*/false);
auto io = interpreter.execute();
io.pipeline.complete(std::move(table_out));
CompletedPipelineExecutor executor(io.pipeline);
@ -943,8 +943,14 @@ QueryTreeNodePtr buildQueryTreeDistributed(SelectQueryInfo & query_info,
}
else
{
auto resolved_remote_storage_id = query_context->resolveStorageID(remote_storage_id);
auto storage = std::make_shared<StorageDummy>(resolved_remote_storage_id, distributed_storage_snapshot->metadata->getColumns());
auto resolved_remote_storage_id = remote_storage_id;
// In case of cross-replication we don't know what database is used for the table.
// `storage_id.hasDatabase()` can return false only on the initiator node.
// Each shard will use the default database (in the case of cross-replication shards may have different defaults).
if (remote_storage_id.hasDatabase())
resolved_remote_storage_id = query_context->resolveStorageID(remote_storage_id);
auto storage = std::make_shared<StorageDummy>(resolved_remote_storage_id, distributed_storage_snapshot->metadata->getColumns(), distributed_storage_snapshot->object_columns);
auto table_node = std::make_shared<TableNode>(std::move(storage), query_context);
if (table_expression_modifiers)
@ -1001,6 +1007,7 @@ QueryTreeNodePtr buildQueryTreeDistributed(SelectQueryInfo & query_info,
planner_context->getMutableQueryContext(),
global_in_or_join_node.subquery_depth);
temporary_table_expression_node->setAlias(join_right_table_expression->getAlias());
replacement_map.emplace(join_right_table_expression.get(), std::move(temporary_table_expression_node));
continue;
}
@ -1014,6 +1021,7 @@ QueryTreeNodePtr buildQueryTreeDistributed(SelectQueryInfo & query_info,
auto temporary_table_expression_node = executeSubqueryNode(in_function_subquery_node,
planner_context->getMutableQueryContext(),
global_in_or_join_node.subquery_depth);
in_function_subquery_node = std::move(temporary_table_expression_node);
}
else
@ -1057,9 +1065,8 @@ void StorageDistributed::read(
storage_snapshot,
remote_storage_id,
remote_table_function_ptr);
header = InterpreterSelectQueryAnalyzer::getSampleBlock(query_tree_distributed, local_context, SelectQueryOptions(processed_stage).analyze());
query_ast = queryNodeToSelectQuery(query_tree_distributed);
header = InterpreterSelectQueryAnalyzer::getSampleBlock(query_ast, local_context, SelectQueryOptions(processed_stage).analyze());
}
else
{
@ -1132,7 +1139,7 @@ void StorageDistributed::read(
}
SinkToStoragePtr StorageDistributed::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageDistributed::write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
auto cluster = getCluster();
const auto & settings = local_context->getSettingsRef();

View File

@ -118,7 +118,7 @@ public:
bool supportsParallelInsert() const override { return true; }
std::optional<UInt64> totalBytes(const Settings &) const override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool /*async_insert*/) override;
std::optional<QueryPipeline> distributedWrite(const ASTInsertQuery & query, ContextPtr context) override;

View File

@ -9,8 +9,9 @@
namespace DB
{
StorageDummy::StorageDummy(const StorageID & table_id_, const ColumnsDescription & columns_)
StorageDummy::StorageDummy(const StorageID & table_id_, const ColumnsDescription & columns_, ColumnsDescription object_columns_)
: IStorage(table_id_)
, object_columns(std::move(object_columns_))
{
StorageInMemoryMetadata storage_metadata;
storage_metadata.setColumns(columns_);

View File

@ -11,7 +11,7 @@ namespace DB
class StorageDummy : public IStorage
{
public:
StorageDummy(const StorageID & table_id_, const ColumnsDescription & columns_);
StorageDummy(const StorageID & table_id_, const ColumnsDescription & columns_, ColumnsDescription object_columns_ = {});
std::string getName() const override { return "StorageDummy"; }
@ -22,6 +22,11 @@ public:
bool supportsDynamicSubcolumns() const override { return true; }
bool canMoveConditionsToPrewhere() const override { return false; }
StorageSnapshotPtr getStorageSnapshot(const StorageMetadataPtr & metadata_snapshot, ContextPtr /*query_context*/) const override
{
return std::make_shared<StorageSnapshot>(*this, metadata_snapshot, object_columns);
}
QueryProcessingStage::Enum getQueryProcessingStage(
ContextPtr local_context,
QueryProcessingStage::Enum to_stage,
@ -37,6 +42,8 @@ public:
QueryProcessingStage::Enum processed_stage,
size_t max_block_size,
size_t num_streams) override;
private:
const ColumnsDescription object_columns;
};
class ReadFromDummy : public SourceStepWithFilter

View File

@ -1049,7 +1049,8 @@ private:
SinkToStoragePtr StorageFile::write(
const ASTPtr & query,
const StorageMetadataPtr & metadata_snapshot,
ContextPtr context)
ContextPtr context,
bool /*async_insert*/)
{
if (format_name == "Distributed")
throw Exception(ErrorCodes::NOT_IMPLEMENTED, "Method write is not implemented for Distributed format");

View File

@ -50,7 +50,8 @@ public:
SinkToStoragePtr write(
const ASTPtr & query,
const StorageMetadataPtr & /*metadata_snapshot*/,
ContextPtr context) override;
ContextPtr context,
bool async_insert) override;
void truncate(
const ASTPtr & /*query*/,

View File

@ -89,10 +89,10 @@ RWLockImpl::LockHolder StorageJoin::tryLockForCurrentQueryTimedWithContext(const
return lock->getLock(type, query_id, acquire_timeout, false);
}
SinkToStoragePtr StorageJoin::write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context)
SinkToStoragePtr StorageJoin::write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool /*async_insert*/)
{
std::lock_guard mutate_lock(mutate_mutex);
return StorageSetOrJoinBase::write(query, metadata_snapshot, context);
return StorageSetOrJoinBase::write(query, metadata_snapshot, context, /*async_insert=*/false);
}
void StorageJoin::truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr context, TableExclusiveLockHolder &)

View File

@ -59,7 +59,7 @@ public:
/// (but not during processing whole query, it's safe for joinGet that doesn't involve `used_flags` from HashJoin)
ColumnWithTypeAndName joinGet(const Block & block, const Block & block_with_columns_to_add, ContextPtr context) const;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override;
Pipe read(
const Names & column_names,

View File

@ -524,7 +524,7 @@ Pipe StorageKeeperMap::read(
return process_keys(std::move(filtered_keys));
}
SinkToStoragePtr StorageKeeperMap::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageKeeperMap::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
checkTable<true>();
return std::make_shared<StorageKeeperMapSink>(*this, metadata_snapshot->getSampleBlock(), local_context);

View File

@ -42,7 +42,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override;
void truncate(const ASTPtr &, const StorageMetadataPtr &, ContextPtr, TableExclusiveLockHolder &) override;
void drop() override;

View File

@ -855,7 +855,7 @@ Pipe StorageLog::read(
return Pipe::unitePipes(std::move(pipes));
}
SinkToStoragePtr StorageLog::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageLog::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
WriteLock lock{rwlock, getLockTimeout(local_context)};
if (!lock)

View File

@ -55,7 +55,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool async_insert) override;
void rename(const String & new_path_to_table_data, const StorageID & new_table_id) override;

View File

@ -32,7 +32,7 @@ public:
QueryProcessingStage::Enum processed_stage,
size_t max_block_size, size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr &, const StorageMetadataPtr &, ContextPtr) override { throwNotAllowed(); }
SinkToStoragePtr write(const ASTPtr &, const StorageMetadataPtr &, ContextPtr, bool) override { throwNotAllowed(); }
NamesAndTypesList getVirtuals() const override;
ColumnSizeByName getColumnSizes() const override;

View File

@ -192,13 +192,13 @@ void StorageMaterializedView::read(
}
}
SinkToStoragePtr StorageMaterializedView::write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr local_context)
SinkToStoragePtr StorageMaterializedView::write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr local_context, bool async_insert)
{
auto storage = getTargetTable();
auto lock = storage->lockForShare(local_context->getCurrentQueryId(), local_context->getSettingsRef().lock_acquire_timeout);
auto metadata_snapshot = storage->getInMemoryMetadataPtr();
auto sink = storage->write(query, metadata_snapshot, local_context);
auto sink = storage->write(query, metadata_snapshot, local_context, async_insert);
sink->addTableLock(lock);
return sink;

View File

@ -39,7 +39,7 @@ public:
return target_table->mayBenefitFromIndexForIn(left_in_operand, query_context, metadata_snapshot);
}
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;
void drop() override;
void dropInnerTableIfAny(bool sync, ContextPtr local_context) override;

View File

@ -159,7 +159,7 @@ void StorageMemory::read(
}
SinkToStoragePtr StorageMemory::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr context)
SinkToStoragePtr StorageMemory::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool /*async_insert*/)
{
return std::make_shared<MemorySink>(*this, metadata_snapshot, context);
}

View File

@ -64,7 +64,7 @@ public:
bool hasEvenlyDistributedRead() const override { return true; }
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override;
void drop() override;

View File

@ -274,7 +274,7 @@ std::optional<UInt64> StorageMergeTree::totalBytes(const Settings &) const
}
SinkToStoragePtr
StorageMergeTree::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
StorageMergeTree::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
const auto & settings = local_context->getSettingsRef();
return std::make_shared<MergeTreeSink>(

View File

@ -71,7 +71,7 @@ public:
std::optional<UInt64> totalRowsByPartitionPredicate(const SelectQueryInfo &, ContextPtr) const override;
std::optional<UInt64> totalBytes(const Settings &) const override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;
/** Perform the next step in combining the parts.
*/

View File

@ -165,7 +165,7 @@ Pipe StorageMongoDB::read(
return Pipe(std::make_shared<MongoDBSource>(connection, createCursor(database_name, collection_name, sample_block), sample_block, max_block_size));
}
SinkToStoragePtr StorageMongoDB::write(const ASTPtr & /* query */, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */)
SinkToStoragePtr StorageMongoDB::write(const ASTPtr & /* query */, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */, bool /*async_insert*/)
{
connectIfNotConnected();
return std::make_shared<StorageMongoDBSink>(collection_name, database_name, metadata_snapshot, connection);

View File

@ -41,7 +41,8 @@ public:
SinkToStoragePtr write(
const ASTPtr & query,
const StorageMetadataPtr & /*metadata_snapshot*/,
ContextPtr context) override;
ContextPtr context,
bool async_insert) override;
struct Configuration
{

View File

@ -252,7 +252,7 @@ private:
};
SinkToStoragePtr StorageMySQL::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context)
SinkToStoragePtr StorageMySQL::write(const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr local_context, bool /*async_insert*/)
{
return std::make_shared<StorageMySQLSink>(
*this,

View File

@ -49,7 +49,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;
struct Configuration
{

View File

@ -46,7 +46,7 @@ public:
bool supportsParallelInsert() const override { return true; }
SinkToStoragePtr write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr) override
SinkToStoragePtr write(const ASTPtr &, const StorageMetadataPtr & metadata_snapshot, ContextPtr, bool) override
{
return std::make_shared<NullSinkToStorage>(metadata_snapshot->getSampleBlock());
}

View File

@ -451,7 +451,7 @@ private:
SinkToStoragePtr StoragePostgreSQL::write(
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */)
const ASTPtr & /*query*/, const StorageMetadataPtr & metadata_snapshot, ContextPtr /* context */, bool /*async_insert*/)
{
return std::make_shared<PostgreSQLSink>(metadata_snapshot, pool->get(), remote_table_name, remote_table_schema, on_conflict);
}

View File

@ -46,7 +46,7 @@ public:
size_t max_block_size,
size_t num_streams) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context) override;
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & /*metadata_snapshot*/, ContextPtr context, bool async_insert) override;
struct Configuration
{

View File

@ -68,9 +68,9 @@ public:
return getNested()->read(query_plan, column_names, storage_snapshot, query_info, context, processed_stage, max_block_size, num_streams);
}
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context) override
SinkToStoragePtr write(const ASTPtr & query, const StorageMetadataPtr & metadata_snapshot, ContextPtr context, bool async_insert) override
{
return getNested()->write(query, metadata_snapshot, context);
return getNested()->write(query, metadata_snapshot, context, async_insert);
}
void drop() override { getNested()->drop(); }

Some files were not shown because too many files have changed in this diff Show More