Merge remote-tracking branch 'upstream/master' into distinct-combinator

This commit is contained in:
Anton Popov 2020-06-18 01:44:36 +03:00
commit 8ba5bd8530
314 changed files with 7022 additions and 2541 deletions

2
contrib/hyperscan vendored

@ -1 +1 @@
Subproject commit 3058c9c20cba3accdf92544d8513a26240c4ff70 Subproject commit 3907fd00ee8b2538739768fa9533f8635a276531

View File

@ -51,9 +51,6 @@ ORDER BY expr
For a description of parameters, see the [CREATE query description](../../../sql-reference/statements/create.md). For a description of parameters, see the [CREATE query description](../../../sql-reference/statements/create.md).
!!! note "Note"
`INDEX` is an experimental feature, see [Data Skipping Indexes](#table_engine-mergetree-data_skipping-indexes).
### Query Clauses {#mergetree-query-clauses} ### Query Clauses {#mergetree-query-clauses}
- `ENGINE` — Name and parameters of the engine. `ENGINE = MergeTree()`. The `MergeTree` engine does not have parameters. - `ENGINE` — Name and parameters of the engine. `ENGINE = MergeTree()`. The `MergeTree` engine does not have parameters.
@ -257,7 +254,7 @@ ClickHouse cannot use an index if the values of the primary key in the query par
ClickHouse uses this logic not only for days of the month sequences, but for any primary key that represents a partially-monotonic sequence. ClickHouse uses this logic not only for days of the month sequences, but for any primary key that represents a partially-monotonic sequence.
### Data Skipping Indexes (experimental) {#table_engine-mergetree-data_skipping-indexes} ### Data Skipping Indexes {#table_engine-mergetree-data_skipping-indexes}
The index declaration is in the columns section of the `CREATE` query. The index declaration is in the columns section of the `CREATE` query.

View File

@ -291,7 +291,7 @@ ClickHouse supports specific queries through the HTTP interface. For example, yo
$ echo '(4),(5),(6)' | curl 'http://localhost:8123/?query=INSERT%20INTO%20t%20VALUES' --data-binary @- $ echo '(4),(5),(6)' | curl 'http://localhost:8123/?query=INSERT%20INTO%20t%20VALUES' --data-binary @-
``` ```
ClickHouse also supports Predefined HTTP Interface which can help you more easy integration with third party tools like [Prometheus exporter](https://github.com/percona-lab/clickhouse_exporter). ClickHouse also supports Predefined HTTP Interface which can help you more easily integrate with third-party tools like [Prometheus exporter](https://github.com/percona-lab/clickhouse_exporter).
Example: Example:
@ -314,7 +314,7 @@ Example:
</http_handlers> </http_handlers>
``` ```
- You can now request the url directly for data in the Prometheus format: - You can now request the URL directly for data in the Prometheus format:
<!-- --> <!-- -->
@ -361,41 +361,40 @@ $ curl -v 'http://localhost:8123/predefined_query'
* Connection #0 to host localhost left intact * Connection #0 to host localhost left intact
* Connection #0 to host localhost left intact * Connection #0 to host localhost left intact
``` ```
As you can see from the example, if `<http_handlers>` is configured in the config.xml file and `<http_handlers>` can contain many `<rule>s`. ClickHouse will match the HTTP requests received to the predefined type in `<rule>` and the first matched runs the handler. Then ClickHouse will execute the corresponding predefined query if the match is successful. As you can see from the example if `http_handlers` is configured in the config.xml file and `http_handlers` can contain many `rules`. ClickHouse will match the HTTP requests received to the predefined type in `rule` and the first matched runs the handler. Then ClickHouse will execute the corresponding predefined query if the match is successful.
> Now `<rule>` can configure `<method>`, `<headers>`, `<url>`,`<handler>`: Now `rule` can configure `method`, `headers`, `url`, `handler`:
> `<method>` is responsible for matching the method part of the HTTP request. `<method>` fully conforms to the definition of [method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods) in the HTTP protocol. It is an optional configuration. If it is not defined in the configuration file, it does not match the method portion of the HTTP request. - `method` is responsible for matching the method part of the HTTP request. `method` fully conforms to the definition of [method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods) in the HTTP protocol. It is an optional configuration. If it is not defined in the configuration file, it does not match the method portion of the HTTP request.
>
> `<url>` is responsible for matching the url part of the HTTP request. It is compatible with [RE2](https://github.com/google/re2)s regular expressions. It is an optional configuration. If it is not defined in the configuration file, it does not match the url portion of the HTTP request.
>
> `<headers>` is responsible for matching the header part of the HTTP request. It is compatible with RE2s regular expressions. It is an optional configuration. If it is not defined in the configuration file, it does not match the header portion of the HTTP request.
>
> `<handler>` contains the main processing part. Now `<handler>` can configure `<type>`, `<status>`, `<content_type>`, `<response_content>`, `<query>`, `<query_param_name>`.
> \> `<type>` currently supports three types: **predefined\_query\_handler**, **dynamic\_query\_handler**, **static**.
> \>
> \> `<query>` - use with predefined\_query\_handler type, executes query when the handler is called.
> \>
> \> `<query_param_name>` - use with dynamic\_query\_handler type, extracts and executes the value corresponding to the `<query_param_name>` value in HTTP request params.
> \>
> \> `<status>` - use with static type, response status code.
> \>
> \> `<content_type>` - use with static type, response [content-type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type).
> \>
> \> `<response_content>` - use with static type, Response content sent to client, when using the prefix file:// or config://, find the content from the file or configuration send to client.
Next are the configuration methods for the different `<type>`. - `url` is responsible for matching the URL part of the HTTP request. It is compatible with [RE2](https://github.com/google/re2)s regular expressions. It is an optional configuration. If it is not defined in the configuration file, it does not match the URL portion of the HTTP request.
## predefined\_query\_handler {#predefined_query_handler} - `headers` are responsible for matching the header part of the HTTP request. It is compatible with RE2s regular expressions. It is an optional configuration. If it is not defined in the configuration file, it does not match the header portion of the HTTP request.
`<predefined_query_handler>` supports setting Settings and query\_params values. You can configure `<query>` in the type of `<predefined_query_handler>`. - `handler` contains the main processing part. Now `handler` can configure `type`, `status`, `content_type`, `response_content`, `query`, `query_param_name`.
`type` currently supports three types: [predefined_query_handler](#predefined_query_handler), [dynamic_query_handler](#dynamic_query_handler), [static](#static).
- `query` — use with `predefined_query_handler` type, executes query when the handler is called.
- `query_param_name` — use with `dynamic_query_handler` type, extracts and executes the value corresponding to the `query_param_name` value in HTTP request params.
- `status` — use with `static` type, response status code.
- `content_type` — use with `static` type, response [content-type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type).
`<query>` value is a predefined query of `<predefined_query_handler>`, which is executed by ClickHouse when an HTTP request is matched and the result of the query is returned. It is a must configuration. - `response_content` — use with `static` type, response content sent to client, when using the prefix file:// or config://, find the content from the file or configuration sends to client.
The following example defines the values of `max_threads` and `max_alter_threads` settings, then queries the system table to check whether these settings were set successfully. Next are the configuration methods for different `type`.
### predefined_query_handler {#predefined_query_handler}
`predefined_query_handler` supports setting `Settings` and `query_params` values. You can configure `query` in the type of `predefined_query_handler`.
`query` value is a predefined query of `predefined_query_handler`, which is executed by ClickHouse when an HTTP request is matched and the result of the query is returned. It is a must configuration.
The following example defines the values of [max_threads](../operations/settings/settings.md#settings-max_threads) and `max_alter_threads` settings, then queries the system table to check whether these settings were set successfully.
Example: Example:
@ -424,15 +423,15 @@ max_alter_threads 2
``` ```
!!! note "caution" !!! note "caution"
In one `<predefined_query_handler>` only supports one `<query>` of an insert type. In one `predefined_query_handler` only supports one `query` of an insert type.
## dynamic\_query\_handler {#dynamic_query_handler} ### dynamic_query_handler {#dynamic_query_handler}
In `<dynamic_query_handler>`, query is written in the form of param of the HTTP request. The difference is that in `<predefined_query_handler>`, query is wrote in the configuration file. You can configure `<query_param_name>` in `<dynamic_query_handler>`. In `dynamic_query_handler`, the query is written in the form of param of the HTTP request. The difference is that in `predefined_query_handler`, the query is written in the configuration file. You can configure `query_param_name` in `dynamic_query_handler`.
ClickHouse extracts and executes the value corresponding to the `<query_param_name>` value in the url of the HTTP request. The default value of `<query_param_name>` is `/query` . It is an optional configuration. If there is no definition in the configuration file, the param is not passed in. ClickHouse extracts and executes the value corresponding to the `query_param_name` value in the URL of the HTTP request. The default value of `query_param_name` is `/query` . It is an optional configuration. If there is no definition in the configuration file, the param is not passed in.
To experiment with this functionality, the example defines the values of max\_threads and max\_alter\_threads and queries whether the Settings were set successfully. To experiment with this functionality, the example defines the values of [max_threads](../operations/settings/settings.md#settings-max_threads) and `max_alter_threads` and `queries` whether the settings were set successfully.
Example: Example:
@ -455,9 +454,9 @@ max_threads 1
max_alter_threads 2 max_alter_threads 2
``` ```
## static {#static} ### static {#static}
`<static>` can return [content\_type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type), [status](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status) and response\_content. response\_content can return the specified content `static` can return [content_type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type), [status](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status) and `response_content`. `response_content` can return the specified content.
Example: Example:

View File

@ -822,6 +822,7 @@ ClickHouse supports the following algorithms of choosing replicas:
- [Nearest hostname](#load_balancing-nearest_hostname) - [Nearest hostname](#load_balancing-nearest_hostname)
- [In order](#load_balancing-in_order) - [In order](#load_balancing-in_order)
- [First or random](#load_balancing-first_or_random) - [First or random](#load_balancing-first_or_random)
- [Round robin](#load_balancing-round_robin)
### Random (by Default) {#load_balancing-random} ### Random (by Default) {#load_balancing-random}
@ -865,6 +866,14 @@ This algorithm chooses the first replica in the set or a random replica if the f
The `first_or_random` algorithm solves the problem of the `in_order` algorithm. With `in_order`, if one replica goes down, the next one gets a double load while the remaining replicas handle the usual amount of traffic. When using the `first_or_random` algorithm, the load is evenly distributed among replicas that are still available. The `first_or_random` algorithm solves the problem of the `in_order` algorithm. With `in_order`, if one replica goes down, the next one gets a double load while the remaining replicas handle the usual amount of traffic. When using the `first_or_random` algorithm, the load is evenly distributed among replicas that are still available.
### Round robin {#load_balancing-round_robin}
``` sql
load_balancing = round_robin
```
This algorithm uses round robin policy across replicas with the same number of errors (only the queries with `round_robin` policy is accounted).
## prefer\_localhost\_replica {#settings-prefer-localhost-replica} ## prefer\_localhost\_replica {#settings-prefer-localhost-replica}
Enables/disables preferable using the localhost replica when processing distributed queries. Enables/disables preferable using the localhost replica when processing distributed queries.

View File

@ -5,7 +5,7 @@ toc_title: ClickHouse Update
# ClickHouse Update {#clickhouse-update} # ClickHouse Update {#clickhouse-update}
If ClickHouse was installed from deb packages, execute the following commands on the server: If ClickHouse was installed from `deb` packages, execute the following commands on the server:
``` bash ``` bash
$ sudo apt-get update $ sudo apt-get update
@ -13,6 +13,6 @@ $ sudo apt-get install clickhouse-client clickhouse-server
$ sudo service clickhouse-server restart $ sudo service clickhouse-server restart
``` ```
If you installed ClickHouse using something other than the recommended deb packages, use the appropriate update method. If you installed ClickHouse using something other than the recommended `deb` packages, use the appropriate update method.
ClickHouse does not support a distributed update. The operation should be performed consecutively on each separate server. Do not update all the servers on a cluster simultaneously, or the cluster will be unavailable for some time. ClickHouse does not support a distributed update. The operation should be performed consecutively on each separate server. Do not update all the servers on a cluster simultaneously, or the cluster will be unavailable for some time.

View File

@ -370,6 +370,46 @@ GROUP BY timeslot
└─────────────────────┴──────────────────────────────────────────────┴────────────────────────────────┘ └─────────────────────┴──────────────────────────────────────────────┴────────────────────────────────┘
``` ```
## minMap(key, value), minMap(Tuple(key, value)) {#agg_functions-minmap}
Calculates the minimum from value array according to the keys specified in the key array.
Passing tuple of keys and values arrays is synonymical to passing two arrays of keys and values.
The number of elements in key and value must be the same for each row that is totaled.
Returns a tuple of two arrays: keys in sorted order, and values calculated for the corresponding keys.
Example:
```sql
SELECT minMap(a, b)
FROM values('a Array(Int32), b Array(Int64)', ([1, 2], [2, 2]), ([2, 3], [1, 1]))
```
```text
┌─minMap(a, b)──────┐
│ ([1,2,3],[2,1,1]) │
└───────────────────┘
```
## maxMap(key, value), maxMap(Tuple(key, value)) {#agg_functions-maxmap}
Calculates the maximum from value array according to the keys specified in the key array.
Passing tuple of keys and values arrays is synonymical to passing two arrays of keys and values.
The number of elements in key and value must be the same for each row that is totaled.
Returns a tuple of two arrays: keys in sorted order, and values calculated for the corresponding keys.
Example:
```sql
SELECT maxMap(a, b)
FROM values('a Array(Int32), b Array(Int64)', ([1, 2], [2, 2]), ([2, 3], [1, 1]))
```
```text
┌─maxMap(a, b)──────┐
│ ([1,2,3],[2,2,1]) │
└───────────────────┘
```
## skewPop {#skewpop} ## skewPop {#skewpop}
Computes the [skewness](https://en.wikipedia.org/wiki/Skewness) of a sequence. Computes the [skewness](https://en.wikipedia.org/wiki/Skewness) of a sequence.

View File

@ -79,6 +79,7 @@ Complied expression cache used when query/user/profile enable option [compile](.
## FLUSH LOGS {#query_language-system-flush_logs} ## FLUSH LOGS {#query_language-system-flush_logs}
Flushes buffers of log messages to system tables (e.g. system.query\_log). Allows you to not wait 7.5 seconds when debugging. Flushes buffers of log messages to system tables (e.g. system.query\_log). Allows you to not wait 7.5 seconds when debugging.
This will also create system tables even if message queue is empty.
## RELOAD CONFIG {#query_language-system-reload-config} ## RELOAD CONFIG {#query_language-system-reload-config}

View File

@ -279,4 +279,332 @@ $ curl -sS 'http://localhost:8123/?max_result_bytes=4000000&buffer_size=3000000&
$ curl -sS "<address>?param_id=2¶m_phrase=test" -d "SELECT * FROM table WHERE int_column = {id:UInt8} and string_column = {phrase:String}" $ curl -sS "<address>?param_id=2¶m_phrase=test" -d "SELECT * FROM table WHERE int_column = {id:UInt8} and string_column = {phrase:String}"
``` ```
## Предопределенный HTTP интерфейс {#predefined_http_interface}
ClickHouse поддерживает определенные запросы через HTTP-интерфейс. Например, вы можете записать данные в таблицу следующим образом:
``` bash
$ echo '(4),(5),(6)' | curl 'http://localhost:8123/?query=INSERT%20INTO%20t%20VALUES' --data-binary @-
```
ClickHouse также поддерживает предопределенный HTTP-интерфейс, который может помочь вам легче интегрироваться со сторонними инструментами, такими как [Prometheus exporter](https://github.com/percona-lab/clickhouse_exporter).
Пример:
- Прежде всего, добавьте раздел в конфигурационный файл сервера:
<!-- -->
``` xml
<http_handlers>
<rule>
<url>/predefined_query</url>
<methods>POST,GET</methods>
<handler>
<type>predefined_query_handler</type>
<query>SELECT * FROM system.metrics LIMIT 5 FORMAT Template SETTINGS format_template_resultset = 'prometheus_template_output_format_resultset', format_template_row = 'prometheus_template_output_format_row', format_template_rows_between_delimiter = '\n'</query>
</handler>
</rule>
<rule>...</rule>
<rule>...</rule>
</http_handlers>
```
- Теперь вы можете напрямую запросить URL-адрес для получения данных в формате Prometheus:
<!-- -->
``` bash
$ curl -v 'http://localhost:8123/predefined_query'
* Trying ::1...
* Connected to localhost (::1) port 8123 (#0)
> GET /predefined_query HTTP/1.1
> Host: localhost:8123
> User-Agent: curl/7.47.0
> Accept: */*
>
< HTTP/1.1 200 OK
< Date: Tue, 28 Apr 2020 08:52:56 GMT
< Connection: Keep-Alive
< Content-Type: text/plain; charset=UTF-8
< X-ClickHouse-Server-Display-Name: i-mloy5trc
< Transfer-Encoding: chunked
< X-ClickHouse-Query-Id: 96fe0052-01e6-43ce-b12a-6b7370de6e8a
< X-ClickHouse-Format: Template
< X-ClickHouse-Timezone: Asia/Shanghai
< Keep-Alive: timeout=3
< X-ClickHouse-Summary: {"read_rows":"0","read_bytes":"0","written_rows":"0","written_bytes":"0","total_rows_to_read":"0"}
<
# HELP "Query" "Number of executing queries"
# TYPE "Query" counter
"Query" 1
# HELP "Merge" "Number of executing background merges"
# TYPE "Merge" counter
"Merge" 0
# HELP "PartMutation" "Number of mutations (ALTER DELETE/UPDATE)"
# TYPE "PartMutation" counter
"PartMutation" 0
# HELP "ReplicatedFetch" "Number of data parts being fetched from replica"
# TYPE "ReplicatedFetch" counter
"ReplicatedFetch" 0
# HELP "ReplicatedSend" "Number of data parts being sent to replicas"
# TYPE "ReplicatedSend" counter
"ReplicatedSend" 0
* Connection #0 to host localhost left intact
* Connection #0 to host localhost left intact
```
Как вы можете видеть из примера, `http_handlers` настраивается в файле config.xml и может содержать несколько правил. ClickHouse будет сопоставлять полученные HTTP-запросы с предопределенным типом в правиле, и первое совпадение запустит обработчик. Затем ClickHouse выполнит соответствующий предопределенный запрос.
В настоящий момент с помощью `rule` можно настроить `method`, `headers`, `url`, `handler`:
- `method` отвечает за соответствие метода HTTP-запроса. `method` соответствует методу [method](https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods) протокола HTTP. Это необязательная настройка. Если она не определена в файле конфигурации, она не соответствует методу HTTP-запроса.
- `url` отвечает за соответствие URL HTTP-запроса. Она совместима с регулярными выражениями [RE2](https://github.com/google/re2). Это необязательная настройка. Если она не определена в файле конфигурации, она не соответствует URL-адресу HTTP-запроса.
- `headers` отвечают за соответствие заголовка HTTP-запроса. Она совместим с регулярными выражениями RE2. Это необязательная настройка. Если она не определен в файле конфигурации, она не соответствует заголовку HTTP-запроса.
- `handler` содержит основную часть обработчика. Сейчас `handler` может настраивать `type`, `status`, `content_type`, `response_content`, `query`, `query_param_name`.
`type` на данный момент поддерживает три типа: [predefined_query_handler](#predefined_query_handler), [dynamic_query_handler](#dynamic_query_handler), [static](#static).
- `query` — используется с типом `predefined_query_handler`, выполняет запрос при вызове обработчика.
- `query_param_name` — используется с типом `dynamic_query_handler`, извлекает и выполняет значение, соответствующее значению `query_param_name` в параметрах HTTP-запроса.
- `status` — используется с типом `static`, возвращает код состояния ответа.
- `content_type` — используется с типом `static`, возвращает [content-type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type).
- `response_content` — используется с типом`static`, содержимое ответа, отправленное клиенту, при использовании префикса file:// or config://, находит содержимое из файла или конфигурации, отправленного клиенту.
Далее приведены методы настройки для различных типов.
### predefined_query_handler {#predefined_query_handler}
`predefined_query_handler` поддерживает настройки `Settings` и `query_params` значений. Вы можете настроить запрос в типе `predefined_query_handler`.
Значение `query` — это предопределенный запрос `predefined_query_handler`, который выполняется ClickHouse при совпадении HTTP-запроса и возврате результата запроса. Это обязательная настройка.
В следующем примере определяются настройки [max_threads](../operations/settings/settings.md#settings-max_threads) и `max_alter_threads`, а затем запрашивается системная таблица, чтобы проверить, были ли эти параметры успешно установлены.
Пример:
``` xml
<http_handlers>
<rule>
<url><![CDATA[/query_param_with_url/\w+/(?P<name_1>[^/]+)(/(?P<name_2>[^/]+))?]]></url>
<method>GET</method>
<headers>
<XXX>TEST_HEADER_VALUE</XXX>
<PARAMS_XXX><![CDATA[(?P<name_1>[^/]+)(/(?P<name_2>[^/]+))?]]></PARAMS_XXX>
</headers>
<handler>
<type>predefined_query_handler</type>
<query>SELECT value FROM system.settings WHERE name = {name_1:String}</query>
<query>SELECT name, value FROM system.settings WHERE name = {name_2:String}</query>
</handler>
</rule>
</http_handlers>
```
``` bash
$ curl -H 'XXX:TEST_HEADER_VALUE' -H 'PARAMS_XXX:max_threads' 'http://localhost:8123/query_param_with_url/1/max_threads/max_alter_threads?max_threads=1&max_alter_threads=2'
1
max_alter_threads 2
```
!!! note "Предупреждение"
В одном `predefined_query_handler` поддерживается только один запрос типа `INSERT`.
### dynamic_query_handler {#dynamic_query_handler}
В `dynamic_query_handler`, запрос пишется в виде параметров HTTP-запроса. Разница в том, что в `predefined_query_handler`, запрос записывается в конфигурационный файл. Вы можете настроить `query_param_name` в `dynamic_query_handler`.
ClickHouse извлекает и выполняет значение, соответствующее значению `query_param_name` URL-адресе HTTP-запроса. Значение по умолчанию `query_param_name` — это `/query` . Это необязательная настройка. Если в файле конфигурации нет определения, параметр не передается.
Чтобы поэкспериментировать с этой функциональностью, в примере определяются значения [max_threads](../operations/settings/settings.md#settings-max_threads) и `max_alter_threads` и запрашивается, успешно ли были установлены настройки.
Пример:
``` xml
<http_handlers>
<rule>
<headers>
<XXX>TEST_HEADER_VALUE_DYNAMIC</XXX> </headers>
<handler>
<type>dynamic_query_handler</type>
<query_param_name>query_param</query_param_name>
</handler>
</rule>
</http_handlers>
```
``` bash
$ curl -H 'XXX:TEST_HEADER_VALUE_DYNAMIC' 'http://localhost:8123/own?max_threads=1&max_alter_threads=2&param_name_1=max_threads&param_name_2=max_alter_threads&query_param=SELECT%20name,value%20FROM%20system.settings%20where%20name%20=%20%7Bname_1:String%7D%20OR%20name%20=%20%7Bname_2:String%7D'
max_threads 1
max_alter_threads 2
```
### static {#static}
`static` может возвращать [content_type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type), [status](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status) и `response_content`. `response_content` может возвращать конкретное содержимое.
Пример:
Возвращает сообщение.
``` xml
<http_handlers>
<rule>
<methods>GET</methods>
<headers><XXX>xxx</XXX></headers>
<url>/hi</url>
<handler>
<type>static</type>
<status>402</status>
<content_type>text/html; charset=UTF-8</content_type>
<response_content>Say Hi!</response_content>
</handler>
</rule>
<http_handlers>
```
``` bash
$ curl -vv -H 'XXX:xxx' 'http://localhost:8123/hi'
* Trying ::1...
* Connected to localhost (::1) port 8123 (#0)
> GET /hi HTTP/1.1
> Host: localhost:8123
> User-Agent: curl/7.47.0
> Accept: */*
> XXX:xxx
>
< HTTP/1.1 402 Payment Required
< Date: Wed, 29 Apr 2020 03:51:26 GMT
< Connection: Keep-Alive
< Content-Type: text/html; charset=UTF-8
< Transfer-Encoding: chunked
< Keep-Alive: timeout=3
< X-ClickHouse-Summary: {"read_rows":"0","read_bytes":"0","written_rows":"0","written_bytes":"0","total_rows_to_read":"0"}
<
* Connection #0 to host localhost left intact
Say Hi!%
```
Находит содержимое настроек отправленных клиенту.
``` xml
<get_config_static_handler><![CDATA[<html ng-app="SMI2"><head><base href="http://ui.tabix.io/"></head><body><div ui-view="" class="content-ui"></div><script src="http://loader.tabix.io/master.js"></script></body></html>]]></get_config_static_handler>
<http_handlers>
<rule>
<methods>GET</methods>
<headers><XXX>xxx</XXX></headers>
<url>/get_config_static_handler</url>
<handler>
<type>static</type>
<response_content>config://get_config_static_handler</response_content>
</handler>
</rule>
</http_handlers>
```
``` bash
$ curl -v -H 'XXX:xxx' 'http://localhost:8123/get_config_static_handler'
* Trying ::1...
* Connected to localhost (::1) port 8123 (#0)
> GET /get_config_static_handler HTTP/1.1
> Host: localhost:8123
> User-Agent: curl/7.47.0
> Accept: */*
> XXX:xxx
>
< HTTP/1.1 200 OK
< Date: Wed, 29 Apr 2020 04:01:24 GMT
< Connection: Keep-Alive
< Content-Type: text/plain; charset=UTF-8
< Transfer-Encoding: chunked
< Keep-Alive: timeout=3
< X-ClickHouse-Summary: {"read_rows":"0","read_bytes":"0","written_rows":"0","written_bytes":"0","total_rows_to_read":"0"}
<
* Connection #0 to host localhost left intact
<html ng-app="SMI2"><head><base href="http://ui.tabix.io/"></head><body><div ui-view="" class="content-ui"></div><script src="http://loader.tabix.io/master.js"></script></body></html>%
```
Находит содержимое файла, отправленного клиенту.
``` xml
<http_handlers>
<rule>
<methods>GET</methods>
<headers><XXX>xxx</XXX></headers>
<url>/get_absolute_path_static_handler</url>
<handler>
<type>static</type>
<content_type>text/html; charset=UTF-8</content_type>
<response_content>file:///absolute_path_file.html</response_content>
</handler>
</rule>
<rule>
<methods>GET</methods>
<headers><XXX>xxx</XXX></headers>
<url>/get_relative_path_static_handler</url>
<handler>
<type>static</type>
<content_type>text/html; charset=UTF-8</content_type>
<response_content>file://./relative_path_file.html</response_content>
</handler>
</rule>
</http_handlers>
```
``` bash
$ user_files_path='/var/lib/clickhouse/user_files'
$ sudo echo "<html><body>Relative Path File</body></html>" > $user_files_path/relative_path_file.html
$ sudo echo "<html><body>Absolute Path File</body></html>" > $user_files_path/absolute_path_file.html
$ curl -vv -H 'XXX:xxx' 'http://localhost:8123/get_absolute_path_static_handler'
* Trying ::1...
* Connected to localhost (::1) port 8123 (#0)
> GET /get_absolute_path_static_handler HTTP/1.1
> Host: localhost:8123
> User-Agent: curl/7.47.0
> Accept: */*
> XXX:xxx
>
< HTTP/1.1 200 OK
< Date: Wed, 29 Apr 2020 04:18:16 GMT
< Connection: Keep-Alive
< Content-Type: text/html; charset=UTF-8
< Transfer-Encoding: chunked
< Keep-Alive: timeout=3
< X-ClickHouse-Summary: {"read_rows":"0","read_bytes":"0","written_rows":"0","written_bytes":"0","total_rows_to_read":"0"}
<
<html><body>Absolute Path File</body></html>
* Connection #0 to host localhost left intact
$ curl -vv -H 'XXX:xxx' 'http://localhost:8123/get_relative_path_static_handler'
* Trying ::1...
* Connected to localhost (::1) port 8123 (#0)
> GET /get_relative_path_static_handler HTTP/1.1
> Host: localhost:8123
> User-Agent: curl/7.47.0
> Accept: */*
> XXX:xxx
>
< HTTP/1.1 200 OK
< Date: Wed, 29 Apr 2020 04:18:31 GMT
< Connection: Keep-Alive
< Content-Type: text/html; charset=UTF-8
< Transfer-Encoding: chunked
< Keep-Alive: timeout=3
< X-ClickHouse-Summary: {"read_rows":"0","read_bytes":"0","written_rows":"0","written_bytes":"0","total_rows_to_read":"0"}
<
<html><body>Relative Path File</body></html>
* Connection #0 to host localhost left intact
```
[Оригинальная статья](https://clickhouse.tech/docs/ru/interfaces/http_interface/) <!--hide--> [Оригинальная статья](https://clickhouse.tech/docs/ru/interfaces/http_interface/) <!--hide-->

View File

@ -74,6 +74,7 @@ SELECT name, status FROM system.dictionaries;
## FLUSH LOGS {#query_language-system-flush_logs} ## FLUSH LOGS {#query_language-system-flush_logs}
Записывает буферы логов в системные таблицы (например system.query\_log). Позволяет не ждать 7.5 секунд при отладке. Записывает буферы логов в системные таблицы (например system.query\_log). Позволяет не ждать 7.5 секунд при отладке.
Если буфер логов пустой, то этот запрос просто создаст системные таблицы.
## RELOAD CONFIG {#query_language-system-reload-config} ## RELOAD CONFIG {#query_language-system-reload-config}

View File

@ -145,10 +145,13 @@ def build_website(args):
'public', 'public',
'node_modules', 'node_modules',
'templates', 'templates',
'feathericons',
'locale' 'locale'
) )
) )
shutil.copy2(
os.path.join(args.website_dir, 'js', 'embedd.min.js'),
os.path.join(args.output_dir, 'js', 'embedd.min.js')
)
for root, _, filenames in os.walk(args.output_dir): for root, _, filenames in os.walk(args.output_dir):
for filename in filenames: for filename in filenames:

View File

@ -986,7 +986,10 @@ private:
/// Process the query that doesn't require transferring data blocks to the server. /// Process the query that doesn't require transferring data blocks to the server.
void processOrdinaryQuery() void processOrdinaryQuery()
{ {
/// We will always rewrite query (even if there are no query_parameters) because it will help to find errors in query formatter. /// Rewrite query only when we have query parameters.
/// Note that if query is rewritten, comments in query are lost.
/// But the user often wants to see comments in server logs, query log, processlist, etc.
if (!query_parameters.empty())
{ {
/// Replace ASTQueryParameter with ASTLiteral for prepared statements. /// Replace ASTQueryParameter with ASTLiteral for prepared statements.
ReplaceQueryParameterVisitor visitor(query_parameters); ReplaceQueryParameterVisitor visitor(query_parameters);

View File

@ -14,6 +14,7 @@ set (CLICKHOUSE_ODBC_BRIDGE_SOURCES
set (CLICKHOUSE_ODBC_BRIDGE_LINK set (CLICKHOUSE_ODBC_BRIDGE_LINK
PRIVATE PRIVATE
clickhouse_parsers clickhouse_parsers
clickhouse_aggregate_functions
daemon daemon
dbms dbms
Poco::Data Poco::Data

View File

@ -64,19 +64,23 @@ namespace
bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override bool parseImpl(Pos & pos, ASTPtr & node, Expected & expected) override
{ {
if (ParserCreateUserQuery{}.enableAttachMode(true).parse(pos, node, expected)) ParserCreateUserQuery create_user_p;
return true; ParserCreateRoleQuery create_role_p;
if (ParserCreateRoleQuery{}.enableAttachMode(true).parse(pos, node, expected)) ParserCreateRowPolicyQuery create_policy_p;
return true; ParserCreateQuotaQuery create_quota_p;
if (ParserCreateRowPolicyQuery{}.enableAttachMode(true).parse(pos, node, expected)) ParserCreateSettingsProfileQuery create_profile_p;
return true; ParserGrantQuery grant_p;
if (ParserCreateQuotaQuery{}.enableAttachMode(true).parse(pos, node, expected))
return true; create_user_p.useAttachMode();
if (ParserCreateSettingsProfileQuery{}.enableAttachMode(true).parse(pos, node, expected)) create_role_p.useAttachMode();
return true; create_policy_p.useAttachMode();
if (ParserGrantQuery{}.enableAttachMode(true).parse(pos, node, expected)) create_quota_p.useAttachMode();
return true; create_profile_p.useAttachMode();
return false; grant_p.useAttachMode();
return create_user_p.parse(pos, node, expected) || create_role_p.parse(pos, node, expected)
|| create_policy_p.parse(pos, node, expected) || create_quota_p.parse(pos, node, expected)
|| create_profile_p.parse(pos, node, expected) || grant_p.parse(pos, node, expected);
} }
}; };
@ -261,7 +265,9 @@ namespace
/// Calculates the path for storing a map of name of access entity to UUID for access entities of some type. /// Calculates the path for storing a map of name of access entity to UUID for access entities of some type.
std::filesystem::path getListFilePath(const String & directory_path, EntityType type) std::filesystem::path getListFilePath(const String & directory_path, EntityType type)
{ {
std::string_view file_name = EntityTypeInfo::get(type).list_filename; String file_name = EntityTypeInfo::get(type).plural_raw_name;
boost::to_lower(file_name);
file_name += ".list";
return std::filesystem::path(directory_path).append(file_name); return std::filesystem::path(directory_path).append(file_name);
} }

View File

@ -45,11 +45,13 @@ struct IAccessEntity
struct TypeInfo struct TypeInfo
{ {
const char * const raw_name; const char * const raw_name;
const char * const plural_raw_name;
const String name; /// Uppercased with spaces instead of underscores, e.g. "SETTINGS PROFILE". const String name; /// Uppercased with spaces instead of underscores, e.g. "SETTINGS PROFILE".
const String alias; /// Alias of the keyword or empty string, e.g. "PROFILE". const String alias; /// Alias of the keyword or empty string, e.g. "PROFILE".
const String plural_name; /// Uppercased with spaces plural name, e.g. "SETTINGS PROFILES".
const String plural_alias; /// Uppercased with spaces plural name alias, e.g. "PROFILES".
const String name_for_output_with_entity_name; /// Lowercased with spaces instead of underscores, e.g. "settings profile". const String name_for_output_with_entity_name; /// Lowercased with spaces instead of underscores, e.g. "settings profile".
const char unique_char; /// Unique character for this type. E.g. 'P' for SETTINGS_PROFILE. const char unique_char; /// Unique character for this type. E.g. 'P' for SETTINGS_PROFILE.
const String list_filename; /// Name of the file containing list of objects of this type, including the file extension ".list".
const int not_found_error_code; const int not_found_error_code;
static const TypeInfo & get(Type type_); static const TypeInfo & get(Type type_);
@ -69,6 +71,18 @@ struct IAccessEntity
friend bool operator ==(const IAccessEntity & lhs, const IAccessEntity & rhs) { return lhs.equal(rhs); } friend bool operator ==(const IAccessEntity & lhs, const IAccessEntity & rhs) { return lhs.equal(rhs); }
friend bool operator !=(const IAccessEntity & lhs, const IAccessEntity & rhs) { return !(lhs == rhs); } friend bool operator !=(const IAccessEntity & lhs, const IAccessEntity & rhs) { return !(lhs == rhs); }
struct LessByName
{
bool operator()(const IAccessEntity & lhs, const IAccessEntity & rhs) const { return (lhs.getName() < rhs.getName()); }
bool operator()(const std::shared_ptr<const IAccessEntity> & lhs, const std::shared_ptr<const IAccessEntity> & rhs) const { return operator()(*lhs, *rhs); }
};
struct LessByTypeAndName
{
bool operator()(const IAccessEntity & lhs, const IAccessEntity & rhs) const { return (lhs.getType() < rhs.getType()) || ((lhs.getType() == rhs.getType()) && (lhs.getName() < rhs.getName())); }
bool operator()(const std::shared_ptr<const IAccessEntity> & lhs, const std::shared_ptr<const IAccessEntity> & rhs) const { return operator()(*lhs, *rhs); }
};
protected: protected:
String name; String name;
@ -87,44 +101,49 @@ using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
inline const IAccessEntity::TypeInfo & IAccessEntity::TypeInfo::get(Type type_) inline const IAccessEntity::TypeInfo & IAccessEntity::TypeInfo::get(Type type_)
{ {
static constexpr auto make_info = [](const char * raw_name_, char unique_char_, const char * list_filename_, int not_found_error_code_) static constexpr auto make_info = [](const char * raw_name_, const char * plural_raw_name_, char unique_char_, int not_found_error_code_)
{ {
String init_name = raw_name_; String init_names[2] = {raw_name_, plural_raw_name_};
boost::to_upper(init_name); String init_aliases[2];
boost::replace_all(init_name, "_", " "); for (size_t i = 0; i != std::size(init_names); ++i)
String init_alias; {
if (auto underscore_pos = init_name.find_first_of(" "); underscore_pos != String::npos) String & init_name = init_names[i];
init_alias = init_name.substr(underscore_pos + 1); String & init_alias = init_aliases[i];
String init_name_for_output_with_entity_name = init_name; boost::to_upper(init_name);
boost::replace_all(init_name, "_", " ");
if (auto underscore_pos = init_name.find_first_of(" "); underscore_pos != String::npos)
init_alias = init_name.substr(underscore_pos + 1);
}
String init_name_for_output_with_entity_name = init_names[0];
boost::to_lower(init_name_for_output_with_entity_name); boost::to_lower(init_name_for_output_with_entity_name);
return TypeInfo{raw_name_, std::move(init_name), std::move(init_alias), std::move(init_name_for_output_with_entity_name), unique_char_, list_filename_, not_found_error_code_}; return TypeInfo{raw_name_, plural_raw_name_, std::move(init_names[0]), std::move(init_aliases[0]), std::move(init_names[1]), std::move(init_aliases[1]), std::move(init_name_for_output_with_entity_name), unique_char_, not_found_error_code_};
}; };
switch (type_) switch (type_)
{ {
case Type::USER: case Type::USER:
{ {
static const auto info = make_info("USER", 'U', "users.list", ErrorCodes::UNKNOWN_USER); static const auto info = make_info("USER", "USERS", 'U', ErrorCodes::UNKNOWN_USER);
return info; return info;
} }
case Type::ROLE: case Type::ROLE:
{ {
static const auto info = make_info("ROLE", 'R', "roles.list", ErrorCodes::UNKNOWN_ROLE); static const auto info = make_info("ROLE", "ROLES", 'R', ErrorCodes::UNKNOWN_ROLE);
return info; return info;
} }
case Type::SETTINGS_PROFILE: case Type::SETTINGS_PROFILE:
{ {
static const auto info = make_info("SETTINGS_PROFILE", 'S', "settings_profiles.list", ErrorCodes::THERE_IS_NO_PROFILE); static const auto info = make_info("SETTINGS_PROFILE", "SETTINGS_PROFILES", 'S', ErrorCodes::THERE_IS_NO_PROFILE);
return info; return info;
} }
case Type::ROW_POLICY: case Type::ROW_POLICY:
{ {
static const auto info = make_info("ROW_POLICY", 'P', "row_policies.list", ErrorCodes::UNKNOWN_ROW_POLICY); static const auto info = make_info("ROW_POLICY", "ROW_POLICIES", 'P', ErrorCodes::UNKNOWN_ROW_POLICY);
return info; return info;
} }
case Type::QUOTA: case Type::QUOTA:
{ {
static const auto info = make_info("QUOTA", 'Q', "quotas.list", ErrorCodes::UNKNOWN_QUOTA); static const auto info = make_info("QUOTA", "QUOTAS", 'Q', ErrorCodes::UNKNOWN_QUOTA);
return info; return info;
} }
case Type::MAX: break; case Type::MAX: break;

View File

@ -24,16 +24,141 @@ namespace
using EntityType = IAccessStorage::EntityType; using EntityType = IAccessStorage::EntityType;
using EntityTypeInfo = IAccessStorage::EntityTypeInfo; using EntityTypeInfo = IAccessStorage::EntityTypeInfo;
bool isNotFoundErrorCode(int error_code)
String outputID(const UUID & id)
{ {
if (error_code == ErrorCodes::ACCESS_ENTITY_NOT_FOUND) return "ID(" + toString(id) + ")";
return true; }
for (auto type : ext::range(EntityType::MAX)) String outputTypeAndNameOrID(const IAccessStorage & storage, const UUID & id)
if (error_code == EntityTypeInfo::get(type).not_found_error_code) {
return true; auto entity = storage.tryRead(id);
if (entity)
return entity->outputTypeAndName();
return outputID(id);
}
return false;
template <typename Func, typename ResultType = std::result_of_t<Func()>>
ResultType doTry(const Func & func)
{
try
{
return func();
}
catch (Exception &)
{
return {};
}
}
template <bool ignore_errors, typename T, typename ApplyFunc, typename GetNameFunc = std::nullptr_t,
typename ResultTypeOfApplyFunc = std::result_of_t<ApplyFunc(T)>,
typename ResultType = std::conditional_t<std::is_same_v<ResultTypeOfApplyFunc, void>, void, std::vector<ResultTypeOfApplyFunc>>>
ResultType applyToMultipleEntities(
const std::vector<T> & multiple_entities,
const ApplyFunc & apply_function,
const char * error_message_format [[maybe_unused]] = nullptr,
const GetNameFunc & get_name_function [[maybe_unused]] = nullptr)
{
std::optional<Exception> exception;
std::vector<bool> success;
auto helper = [&](const auto & apply_and_store_result_function)
{
for (size_t i = 0; i != multiple_entities.size(); ++i)
{
try
{
apply_and_store_result_function(multiple_entities[i]);
if constexpr (!ignore_errors)
success[i] = true;
}
catch (Exception & e)
{
if (!ignore_errors && !exception)
exception.emplace(e);
}
catch (Poco::Exception & e)
{
if (!ignore_errors && !exception)
exception.emplace(Exception::CreateFromPocoTag{}, e);
}
catch (std::exception & e)
{
if (!ignore_errors && !exception)
exception.emplace(Exception::CreateFromSTDTag{}, e);
}
}
};
if constexpr (std::is_same_v<ResultType, void>)
{
if (multiple_entities.empty())
return;
if (multiple_entities.size() == 1)
{
apply_function(multiple_entities.front());
return;
}
if constexpr (!ignore_errors)
success.resize(multiple_entities.size(), false);
helper(apply_function);
if (ignore_errors || !exception)
return;
}
else
{
ResultType result;
if (multiple_entities.empty())
return result;
if (multiple_entities.size() == 1)
{
result.emplace_back(apply_function(multiple_entities.front()));
return result;
}
result.reserve(multiple_entities.size());
if constexpr (!ignore_errors)
success.resize(multiple_entities.size(), false);
helper([&](const T & entity) { result.emplace_back(apply_function(entity)); });
if (ignore_errors || !exception)
return result;
}
if constexpr (!ignore_errors)
{
Strings succeeded_names_list;
Strings failed_names_list;
for (size_t i = 0; i != multiple_entities.size(); ++i)
{
const auto & entity = multiple_entities[i];
String name = get_name_function(entity);
if (success[i])
succeeded_names_list.emplace_back(name);
else
failed_names_list.emplace_back(name);
}
String succeeded_names = boost::algorithm::join(succeeded_names_list, ", ");
String failed_names = boost::algorithm::join(failed_names_list, ", ");
if (succeeded_names.empty())
succeeded_names = "none";
String error_message = error_message_format;
boost::replace_all(error_message, "{succeeded_names}", succeeded_names);
boost::replace_all(error_message, "{failed_names}", failed_names);
exception->addMessage(error_message);
exception->rethrow();
}
__builtin_unreachable();
} }
} }
@ -91,14 +216,7 @@ bool IAccessStorage::exists(const UUID & id) const
AccessEntityPtr IAccessStorage::tryReadBase(const UUID & id) const AccessEntityPtr IAccessStorage::tryReadBase(const UUID & id) const
{ {
try return doTry([&] { return readImpl(id); });
{
return readImpl(id);
}
catch (Exception &)
{
return nullptr;
}
} }
@ -110,14 +228,7 @@ String IAccessStorage::readName(const UUID & id) const
std::optional<String> IAccessStorage::tryReadName(const UUID & id) const std::optional<String> IAccessStorage::tryReadName(const UUID & id) const
{ {
try return doTry([&] { return std::optional<String>{readNameImpl(id)}; });
{
return readNameImpl(id);
}
catch (Exception &)
{
return {};
}
} }
@ -129,56 +240,25 @@ UUID IAccessStorage::insert(const AccessEntityPtr & entity)
std::vector<UUID> IAccessStorage::insert(const std::vector<AccessEntityPtr> & multiple_entities) std::vector<UUID> IAccessStorage::insert(const std::vector<AccessEntityPtr> & multiple_entities)
{ {
std::vector<UUID> ids; return applyToMultipleEntities</* ignore_errors = */ false>(
ids.reserve(multiple_entities.size()); multiple_entities,
String error_message; [this](const AccessEntityPtr & entity) { return insertImpl(entity, /* replace_if_exists = */ false); },
for (const auto & entity : multiple_entities) "Couldn't insert {failed_names}. Successfully inserted: {succeeded_names}",
{ [](const AccessEntityPtr & entity) { return entity->outputTypeAndName(); });
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS)
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
}
}
if (!error_message.empty())
throw Exception(error_message, ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
return ids;
} }
std::optional<UUID> IAccessStorage::tryInsert(const AccessEntityPtr & entity) std::optional<UUID> IAccessStorage::tryInsert(const AccessEntityPtr & entity)
{ {
try return doTry([&] { return std::optional<UUID>{insertImpl(entity, false)}; });
{
return insertImpl(entity, false);
}
catch (Exception &)
{
return {};
}
} }
std::vector<UUID> IAccessStorage::tryInsert(const std::vector<AccessEntityPtr> & multiple_entities) std::vector<UUID> IAccessStorage::tryInsert(const std::vector<AccessEntityPtr> & multiple_entities)
{ {
std::vector<UUID> ids; return applyToMultipleEntities</* ignore_errors = */ true>(
ids.reserve(multiple_entities.size()); multiple_entities,
for (const auto & entity : multiple_entities) [this](const AccessEntityPtr & entity) { return insertImpl(entity, /* replace_if_exists = */ false); });
{
try
{
ids.push_back(insertImpl(entity, false));
}
catch (Exception &)
{
}
}
return ids;
} }
@ -190,11 +270,11 @@ UUID IAccessStorage::insertOrReplace(const AccessEntityPtr & entity)
std::vector<UUID> IAccessStorage::insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities) std::vector<UUID> IAccessStorage::insertOrReplace(const std::vector<AccessEntityPtr> & multiple_entities)
{ {
std::vector<UUID> ids; return applyToMultipleEntities</* ignore_errors = */ false>(
ids.reserve(multiple_entities.size()); multiple_entities,
for (const auto & entity : multiple_entities) [this](const AccessEntityPtr & entity) { return insertImpl(entity, /* replace_if_exists = */ true); },
ids.push_back(insertImpl(entity, true)); "Couldn't insert {failed_names}. Successfully inserted: {succeeded_names}",
return ids; [](const AccessEntityPtr & entity) -> String { return entity->outputTypeAndName(); });
} }
@ -206,60 +286,25 @@ void IAccessStorage::remove(const UUID & id)
void IAccessStorage::remove(const std::vector<UUID> & ids) void IAccessStorage::remove(const std::vector<UUID> & ids)
{ {
String error_message; applyToMultipleEntities</* ignore_errors = */ false>(
std::optional<int> error_code; ids,
for (const auto & id : ids) [this](const UUID & id) { removeImpl(id); },
{ "Couldn't remove {failed_names}. Successfully removed: {succeeded_names}",
try [this](const UUID & id) { return outputTypeAndNameOrID(*this, id); });
{
removeImpl(id);
}
catch (Exception & e)
{
if (!isNotFoundErrorCode(e.code()))
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
if (error_code && (*error_code != e.code()))
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
else
error_code = e.code();
}
}
if (!error_message.empty())
throw Exception(error_message, *error_code);
} }
bool IAccessStorage::tryRemove(const UUID & id) bool IAccessStorage::tryRemove(const UUID & id)
{ {
try return doTry([&] { removeImpl(id); return true; });
{
removeImpl(id);
return true;
}
catch (Exception &)
{
return false;
}
} }
std::vector<UUID> IAccessStorage::tryRemove(const std::vector<UUID> & ids) std::vector<UUID> IAccessStorage::tryRemove(const std::vector<UUID> & ids)
{ {
std::vector<UUID> removed; return applyToMultipleEntities</* ignore_errors = */ true>(
removed.reserve(ids.size()); ids,
for (const auto & id : ids) [this](const UUID & id) { removeImpl(id); return id; });
{
try
{
removeImpl(id);
removed.push_back(id);
}
catch (Exception &)
{
}
}
return removed;
} }
@ -271,60 +316,25 @@ void IAccessStorage::update(const UUID & id, const UpdateFunc & update_func)
void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func) void IAccessStorage::update(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{ {
String error_message; applyToMultipleEntities</* ignore_errors = */ false>(
std::optional<int> error_code; ids,
for (const auto & id : ids) [this, &update_func](const UUID & id) { updateImpl(id, update_func); },
{ "Couldn't update {failed_names}. Successfully updated: {succeeded_names}",
try [this](const UUID & id) { return outputTypeAndNameOrID(*this, id); });
{
updateImpl(id, update_func);
}
catch (Exception & e)
{
if (!isNotFoundErrorCode(e.code()))
throw;
error_message += (error_message.empty() ? "" : ". ") + e.message();
if (error_code && (*error_code != e.code()))
error_code = ErrorCodes::ACCESS_ENTITY_NOT_FOUND;
else
error_code = e.code();
}
}
if (!error_message.empty())
throw Exception(error_message, *error_code);
} }
bool IAccessStorage::tryUpdate(const UUID & id, const UpdateFunc & update_func) bool IAccessStorage::tryUpdate(const UUID & id, const UpdateFunc & update_func)
{ {
try return doTry([&] { updateImpl(id, update_func); return true; });
{
updateImpl(id, update_func);
return true;
}
catch (Exception &)
{
return false;
}
} }
std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func) std::vector<UUID> IAccessStorage::tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func)
{ {
std::vector<UUID> updated; return applyToMultipleEntities</* ignore_errors = */ true>(
updated.reserve(ids.size()); ids,
for (const auto & id : ids) [this, &update_func](const UUID & id) { updateImpl(id, update_func); return id; });
{
try
{
updateImpl(id, update_func);
updated.push_back(id);
}
catch (Exception &)
{
}
}
return updated;
} }
@ -388,7 +398,7 @@ Poco::Logger * IAccessStorage::getLogger() const
void IAccessStorage::throwNotFound(const UUID & id) const void IAccessStorage::throwNotFound(const UUID & id) const
{ {
throw Exception("ID {" + toString(id) + "} not found in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_NOT_FOUND); throw Exception(outputID(id) + " not found in [" + getStorageName() + "]", ErrorCodes::ACCESS_ENTITY_NOT_FOUND);
} }
@ -402,7 +412,7 @@ void IAccessStorage::throwNotFound(EntityType type, const String & name) const
void IAccessStorage::throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type) void IAccessStorage::throwBadCast(const UUID & id, EntityType type, const String & name, EntityType required_type)
{ {
throw Exception( throw Exception(
"ID {" + toString(id) + "}: " + outputEntityTypeAndName(type, name) + " expected to be of type " + toString(required_type), outputID(id) + ": " + outputEntityTypeAndName(type, name) + " expected to be of type " + toString(required_type),
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
} }
@ -410,7 +420,7 @@ void IAccessStorage::throwBadCast(const UUID & id, EntityType type, const String
void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const void IAccessStorage::throwIDCollisionCannotInsert(const UUID & id, EntityType type, const String & name, EntityType existing_type, const String & existing_name) const
{ {
throw Exception( throw Exception(
outputEntityTypeAndName(type, name) + ": cannot insert because the ID {" + toString(id) + "} is already used by " outputEntityTypeAndName(type, name) + ": cannot insert because the " + outputID(id) + " is already used by "
+ outputEntityTypeAndName(existing_type, existing_name) + " in [" + getStorageName() + "]", + outputEntityTypeAndName(existing_type, existing_name) + " in [" + getStorageName() + "]",
ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS); ErrorCodes::ACCESS_ENTITY_ALREADY_EXISTS);
} }

View File

@ -1,7 +1,9 @@
#pragma once #pragma once
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <ext/range.h>
#include <boost/algorithm/string/split.hpp>
#include <boost/lexical_cast.hpp> #include <boost/lexical_cast.hpp>
#include <chrono> #include <chrono>
@ -84,14 +86,15 @@ struct Quota : public IAccessEntity
struct KeyTypeInfo struct KeyTypeInfo
{ {
const char * const raw_name; const char * const raw_name;
const String name; /// Lowercased with spaces, e.g. "client key". const String name; /// Lowercased with underscores, e.g. "client_key".
const std::vector<KeyType> base_types; /// For combined types keeps base types, e.g. for CLIENT_KEY_OR_USER_NAME it keeps [KeyType::CLIENT_KEY, KeyType::USER_NAME].
static const KeyTypeInfo & get(KeyType type); static const KeyTypeInfo & get(KeyType type);
}; };
KeyType key_type = KeyType::NONE; KeyType key_type = KeyType::NONE;
/// Which roles or users should use this quota. /// Which roles or users should use this quota.
ExtendedRoleSet to_roles; RolesOrUsersSet to_roles;
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<Quota>(); }
@ -195,8 +198,21 @@ inline const Quota::KeyTypeInfo & Quota::KeyTypeInfo::get(KeyType type)
{ {
String init_name = raw_name_; String init_name = raw_name_;
boost::to_lower(init_name); boost::to_lower(init_name);
boost::replace_all(init_name, "_", " "); std::vector<KeyType> init_base_types;
return KeyTypeInfo{raw_name_, std::move(init_name)}; String replaced = boost::algorithm::replace_all_copy(init_name, "_or_", "|");
Strings tokens;
boost::algorithm::split(tokens, replaced, boost::is_any_of("|"));
if (tokens.size() > 1)
{
for (const auto & token : tokens)
for (auto kt : ext::range(KeyType::MAX))
if (KeyTypeInfo::get(kt).name == token)
{
init_base_types.push_back(kt);
break;
}
}
return KeyTypeInfo{raw_name_, std::move(init_name), std::move(init_base_types)};
}; };
switch (type) switch (type)

View File

@ -39,7 +39,7 @@ private:
QuotaPtr quota; QuotaPtr quota;
UUID quota_id; UUID quota_id;
const ExtendedRoleSet * roles = nullptr; const RolesOrUsersSet * roles = nullptr;
std::unordered_map<String /* quota key */, boost::shared_ptr<const Intervals>> key_to_intervals; std::unordered_map<String /* quota key */, boost::shared_ptr<const Intervals>> key_to_intervals;
}; };

View File

@ -1,9 +1,8 @@
#include <Access/RolesOrUsersSet.h>
#include <Access/ExtendedRoleSet.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/Role.h> #include <Access/Role.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <IO/ReadHelpers.h> #include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h> #include <IO/WriteHelpers.h>
@ -20,51 +19,51 @@ namespace ErrorCodes
} }
ExtendedRoleSet::ExtendedRoleSet() = default; RolesOrUsersSet::RolesOrUsersSet() = default;
ExtendedRoleSet::ExtendedRoleSet(const ExtendedRoleSet & src) = default; RolesOrUsersSet::RolesOrUsersSet(const RolesOrUsersSet & src) = default;
ExtendedRoleSet & ExtendedRoleSet::operator =(const ExtendedRoleSet & src) = default; RolesOrUsersSet & RolesOrUsersSet::operator =(const RolesOrUsersSet & src) = default;
ExtendedRoleSet::ExtendedRoleSet(ExtendedRoleSet && src) = default; RolesOrUsersSet::RolesOrUsersSet(RolesOrUsersSet && src) = default;
ExtendedRoleSet & ExtendedRoleSet::operator =(ExtendedRoleSet && src) = default; RolesOrUsersSet & RolesOrUsersSet::operator =(RolesOrUsersSet && src) = default;
ExtendedRoleSet::ExtendedRoleSet(AllTag) RolesOrUsersSet::RolesOrUsersSet(AllTag)
{ {
all = true; all = true;
} }
ExtendedRoleSet::ExtendedRoleSet(const UUID & id) RolesOrUsersSet::RolesOrUsersSet(const UUID & id)
{ {
add(id); add(id);
} }
ExtendedRoleSet::ExtendedRoleSet(const std::vector<UUID> & ids_) RolesOrUsersSet::RolesOrUsersSet(const std::vector<UUID> & ids_)
{ {
add(ids_); add(ids_);
} }
ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast) RolesOrUsersSet::RolesOrUsersSet(const ASTRolesOrUsersSet & ast)
{ {
init(ast, nullptr); init(ast, nullptr);
} }
ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const std::optional<UUID> & current_user_id) RolesOrUsersSet::RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const std::optional<UUID> & current_user_id)
{ {
init(ast, nullptr, current_user_id); init(ast, nullptr, current_user_id);
} }
ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager) RolesOrUsersSet::RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const AccessControlManager & manager)
{ {
init(ast, &manager); init(ast, &manager);
} }
ExtendedRoleSet::ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id) RolesOrUsersSet::RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id)
{ {
init(ast, &manager, current_user_id); init(ast, &manager, current_user_id);
} }
void ExtendedRoleSet::init(const ASTExtendedRoleSet & ast, const AccessControlManager * manager, const std::optional<UUID> & current_user_id) void RolesOrUsersSet::init(const ASTRolesOrUsersSet & ast, const AccessControlManager * manager, const std::optional<UUID> & current_user_id)
{ {
all = ast.all; all = ast.all;
@ -73,20 +72,20 @@ void ExtendedRoleSet::init(const ASTExtendedRoleSet & ast, const AccessControlMa
if (ast.id_mode) if (ast.id_mode)
return parse<UUID>(name); return parse<UUID>(name);
assert(manager); assert(manager);
if (ast.can_contain_users && ast.can_contain_roles) if (ast.allow_user_names && ast.allow_role_names)
{ {
auto id = manager->find<User>(name); auto id = manager->find<User>(name);
if (id) if (id)
return *id; return *id;
return manager->getID<Role>(name); return manager->getID<Role>(name);
} }
else if (ast.can_contain_users) else if (ast.allow_user_names)
{ {
return manager->getID<User>(name); return manager->getID<User>(name);
} }
else else
{ {
assert(ast.can_contain_roles); assert(ast.allow_role_names);
return manager->getID<Role>(name); return manager->getID<Role>(name);
} }
}; };
@ -122,9 +121,9 @@ void ExtendedRoleSet::init(const ASTExtendedRoleSet & ast, const AccessControlMa
} }
std::shared_ptr<ASTExtendedRoleSet> ExtendedRoleSet::toAST() const std::shared_ptr<ASTRolesOrUsersSet> RolesOrUsersSet::toAST() const
{ {
auto ast = std::make_shared<ASTExtendedRoleSet>(); auto ast = std::make_shared<ASTRolesOrUsersSet>();
ast->id_mode = true; ast->id_mode = true;
ast->all = all; ast->all = all;
@ -148,9 +147,9 @@ std::shared_ptr<ASTExtendedRoleSet> ExtendedRoleSet::toAST() const
} }
std::shared_ptr<ASTExtendedRoleSet> ExtendedRoleSet::toASTWithNames(const AccessControlManager & manager) const std::shared_ptr<ASTRolesOrUsersSet> RolesOrUsersSet::toASTWithNames(const AccessControlManager & manager) const
{ {
auto ast = std::make_shared<ASTExtendedRoleSet>(); auto ast = std::make_shared<ASTRolesOrUsersSet>();
ast->all = all; ast->all = all;
if (!ids.empty()) if (!ids.empty())
@ -181,21 +180,21 @@ std::shared_ptr<ASTExtendedRoleSet> ExtendedRoleSet::toASTWithNames(const Access
} }
String ExtendedRoleSet::toString() const String RolesOrUsersSet::toString() const
{ {
auto ast = toAST(); auto ast = toAST();
return serializeAST(*ast); return serializeAST(*ast);
} }
String ExtendedRoleSet::toStringWithNames(const AccessControlManager & manager) const String RolesOrUsersSet::toStringWithNames(const AccessControlManager & manager) const
{ {
auto ast = toASTWithNames(manager); auto ast = toASTWithNames(manager);
return serializeAST(*ast); return serializeAST(*ast);
} }
Strings ExtendedRoleSet::toStringsWithNames(const AccessControlManager & manager) const Strings RolesOrUsersSet::toStringsWithNames(const AccessControlManager & manager) const
{ {
if (!all && ids.empty()) if (!all && ids.empty())
return {}; return {};
@ -233,13 +232,13 @@ Strings ExtendedRoleSet::toStringsWithNames(const AccessControlManager & manager
} }
bool ExtendedRoleSet::empty() const bool RolesOrUsersSet::empty() const
{ {
return ids.empty() && !all; return ids.empty() && !all;
} }
void ExtendedRoleSet::clear() void RolesOrUsersSet::clear()
{ {
ids.clear(); ids.clear();
all = false; all = false;
@ -247,26 +246,26 @@ void ExtendedRoleSet::clear()
} }
void ExtendedRoleSet::add(const UUID & id) void RolesOrUsersSet::add(const UUID & id)
{ {
ids.insert(id); ids.insert(id);
} }
void ExtendedRoleSet::add(const std::vector<UUID> & ids_) void RolesOrUsersSet::add(const std::vector<UUID> & ids_)
{ {
for (const auto & id : ids_) for (const auto & id : ids_)
add(id); add(id);
} }
bool ExtendedRoleSet::match(const UUID & id) const bool RolesOrUsersSet::match(const UUID & id) const
{ {
return (all || ids.count(id)) && !except_ids.count(id); return (all || ids.count(id)) && !except_ids.count(id);
} }
bool ExtendedRoleSet::match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const bool RolesOrUsersSet::match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const
{ {
if (!all && !ids.count(user_id)) if (!all && !ids.count(user_id))
{ {
@ -285,7 +284,7 @@ bool ExtendedRoleSet::match(const UUID & user_id, const boost::container::flat_s
} }
std::vector<UUID> ExtendedRoleSet::getMatchingIDs() const std::vector<UUID> RolesOrUsersSet::getMatchingIDs() const
{ {
if (all) if (all)
throw Exception("getAllMatchingIDs() can't get ALL ids without manager", ErrorCodes::LOGICAL_ERROR); throw Exception("getAllMatchingIDs() can't get ALL ids without manager", ErrorCodes::LOGICAL_ERROR);
@ -295,7 +294,7 @@ std::vector<UUID> ExtendedRoleSet::getMatchingIDs() const
} }
std::vector<UUID> ExtendedRoleSet::getMatchingIDs(const AccessControlManager & manager) const std::vector<UUID> RolesOrUsersSet::getMatchingIDs(const AccessControlManager & manager) const
{ {
if (!all) if (!all)
return getMatchingIDs(); return getMatchingIDs();
@ -316,7 +315,7 @@ std::vector<UUID> ExtendedRoleSet::getMatchingIDs(const AccessControlManager & m
} }
bool operator ==(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs) bool operator ==(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs)
{ {
return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids); return (lhs.all == rhs.all) && (lhs.ids == rhs.ids) && (lhs.except_ids == rhs.except_ids);
} }

View File

@ -8,35 +8,35 @@
namespace DB namespace DB
{ {
class ASTExtendedRoleSet; class ASTRolesOrUsersSet;
class AccessControlManager; class AccessControlManager;
/// Represents a set of users/roles like /// Represents a set of users/roles like
/// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...] /// {user_name | role_name | CURRENT_USER} [,...] | NONE | ALL | ALL EXCEPT {user_name | role_name | CURRENT_USER} [,...]
/// Similar to ASTExtendedRoleSet, but with IDs instead of names. /// Similar to ASTRolesOrUsersSet, but with IDs instead of names.
struct ExtendedRoleSet struct RolesOrUsersSet
{ {
ExtendedRoleSet(); RolesOrUsersSet();
ExtendedRoleSet(const ExtendedRoleSet & src); RolesOrUsersSet(const RolesOrUsersSet & src);
ExtendedRoleSet & operator =(const ExtendedRoleSet & src); RolesOrUsersSet & operator =(const RolesOrUsersSet & src);
ExtendedRoleSet(ExtendedRoleSet && src); RolesOrUsersSet(RolesOrUsersSet && src);
ExtendedRoleSet & operator =(ExtendedRoleSet && src); RolesOrUsersSet & operator =(RolesOrUsersSet && src);
struct AllTag {}; struct AllTag {};
ExtendedRoleSet(AllTag); RolesOrUsersSet(AllTag);
ExtendedRoleSet(const UUID & id); RolesOrUsersSet(const UUID & id);
ExtendedRoleSet(const std::vector<UUID> & ids_); RolesOrUsersSet(const std::vector<UUID> & ids_);
/// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`. /// The constructor from AST requires the AccessControlManager if `ast.id_mode == false`.
ExtendedRoleSet(const ASTExtendedRoleSet & ast); RolesOrUsersSet(const ASTRolesOrUsersSet & ast);
ExtendedRoleSet(const ASTExtendedRoleSet & ast, const std::optional<UUID> & current_user_id); RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const std::optional<UUID> & current_user_id);
ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager); RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const AccessControlManager & manager);
ExtendedRoleSet(const ASTExtendedRoleSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id); RolesOrUsersSet(const ASTRolesOrUsersSet & ast, const AccessControlManager & manager, const std::optional<UUID> & current_user_id);
std::shared_ptr<ASTExtendedRoleSet> toAST() const; std::shared_ptr<ASTRolesOrUsersSet> toAST() const;
std::shared_ptr<ASTExtendedRoleSet> toASTWithNames(const AccessControlManager & manager) const; std::shared_ptr<ASTRolesOrUsersSet> toASTWithNames(const AccessControlManager & manager) const;
String toString() const; String toString() const;
String toStringWithNames(const AccessControlManager & manager) const; String toStringWithNames(const AccessControlManager & manager) const;
@ -47,7 +47,7 @@ struct ExtendedRoleSet
void add(const UUID & id); void add(const UUID & id);
void add(const std::vector<UUID> & ids_); void add(const std::vector<UUID> & ids_);
/// Checks if a specified ID matches this ExtendedRoleSet. /// Checks if a specified ID matches this RolesOrUsersSet.
bool match(const UUID & id) const; bool match(const UUID & id) const;
bool match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const; bool match(const UUID & user_id, const boost::container::flat_set<UUID> & enabled_roles) const;
@ -57,15 +57,15 @@ struct ExtendedRoleSet
/// Returns a list of matching users and roles. /// Returns a list of matching users and roles.
std::vector<UUID> getMatchingIDs(const AccessControlManager & manager) const; std::vector<UUID> getMatchingIDs(const AccessControlManager & manager) const;
friend bool operator ==(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs); friend bool operator ==(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs);
friend bool operator !=(const ExtendedRoleSet & lhs, const ExtendedRoleSet & rhs) { return !(lhs == rhs); } friend bool operator !=(const RolesOrUsersSet & lhs, const RolesOrUsersSet & rhs) { return !(lhs == rhs); }
boost::container::flat_set<UUID> ids; boost::container::flat_set<UUID> ids;
bool all = false; bool all = false;
boost::container::flat_set<UUID> except_ids; boost::container::flat_set<UUID> except_ids;
private: private:
void init(const ASTExtendedRoleSet & ast, const AccessControlManager * manager = nullptr, const std::optional<UUID> & current_user_id = {}); void init(const ASTRolesOrUsersSet & ast, const AccessControlManager * manager = nullptr, const std::optional<UUID> & current_user_id = {});
}; };
} }

View File

@ -11,22 +11,6 @@ namespace ErrorCodes
} }
String RowPolicy::NameParts::getName() const
{
String name;
name.reserve(database.length() + table_name.length() + short_name.length() + 6);
name += backQuoteIfNeed(short_name);
name += " ON ";
if (!database.empty())
{
name += backQuoteIfNeed(database);
name += '.';
}
name += backQuoteIfNeed(table_name);
return name;
}
void RowPolicy::setDatabase(const String & database) void RowPolicy::setDatabase(const String & database)
{ {
name_parts.database = database; name_parts.database = database;

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <array> #include <array>
@ -23,7 +23,9 @@ struct RowPolicy : public IAccessEntity
String database; String database;
String table_name; String table_name;
bool empty() const { return short_name.empty(); }
String getName() const; String getName() const;
String toString() const { return getName(); }
auto toTuple() const { return std::tie(short_name, database, table_name); } auto toTuple() const { return std::tie(short_name, database, table_name); }
friend bool operator ==(const NameParts & left, const NameParts & right) { return left.toTuple() == right.toTuple(); } friend bool operator ==(const NameParts & left, const NameParts & right) { return left.toTuple() == right.toTuple(); }
friend bool operator !=(const NameParts & left, const NameParts & right) { return left.toTuple() != right.toTuple(); } friend bool operator !=(const NameParts & left, const NameParts & right) { return left.toTuple() != right.toTuple(); }
@ -89,7 +91,7 @@ struct RowPolicy : public IAccessEntity
Type getType() const override { return TYPE; } Type getType() const override { return TYPE; }
/// Which roles or users should use this row policy. /// Which roles or users should use this row policy.
ExtendedRoleSet to_roles; RolesOrUsersSet to_roles;
private: private:
void setName(const String & name_) override; void setName(const String & name_) override;
@ -153,4 +155,20 @@ inline String toString(RowPolicy::ConditionType type)
return RowPolicy::ConditionTypeInfo::get(type).raw_name; return RowPolicy::ConditionTypeInfo::get(type).raw_name;
} }
inline String RowPolicy::NameParts::getName() const
{
String name;
name.reserve(database.length() + table_name.length() + short_name.length() + 6);
name += backQuoteIfNeed(short_name);
name += " ON ";
if (!database.empty())
{
name += backQuoteIfNeed(database);
name += '.';
}
name += backQuoteIfNeed(table_name);
return name;
}
} }

View File

@ -27,7 +27,7 @@ private:
void setPolicy(const RowPolicyPtr & policy_); void setPolicy(const RowPolicyPtr & policy_);
RowPolicyPtr policy; RowPolicyPtr policy;
const ExtendedRoleSet * roles = nullptr; const RolesOrUsersSet * roles = nullptr;
std::shared_ptr<const std::pair<String, String>> database_and_table_name; std::shared_ptr<const std::pair<String, String>> database_and_table_name;
ASTPtr parsed_conditions[RowPolicy::MAX_CONDITION_TYPE]; ASTPtr parsed_conditions[RowPolicy::MAX_CONDITION_TYPE];
}; };

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
#include <Access/IAccessEntity.h> #include <Access/IAccessEntity.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <Access/SettingsProfileElement.h> #include <Access/SettingsProfileElement.h>
@ -14,7 +14,7 @@ struct SettingsProfile : public IAccessEntity
SettingsProfileElements elements; SettingsProfileElements elements;
/// Which roles or users should use this settings profile. /// Which roles or users should use this settings profile.
ExtendedRoleSet to_roles; RolesOrUsersSet to_roles;
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;
std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<SettingsProfile>(); } std::shared_ptr<IAccessEntity> clone() const override { return cloneImpl<SettingsProfile>(); }

View File

@ -5,7 +5,7 @@
#include <Access/AllowedClientHosts.h> #include <Access/AllowedClientHosts.h>
#include <Access/GrantedAccess.h> #include <Access/GrantedAccess.h>
#include <Access/GrantedRoles.h> #include <Access/GrantedRoles.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <Access/SettingsProfileElement.h> #include <Access/SettingsProfileElement.h>
@ -19,7 +19,7 @@ struct User : public IAccessEntity
AllowedClientHosts allowed_client_hosts = AllowedClientHosts::AnyHostTag{}; AllowedClientHosts allowed_client_hosts = AllowedClientHosts::AnyHostTag{};
GrantedAccess access; GrantedAccess access;
GrantedRoles granted_roles; GrantedRoles granted_roles;
ExtendedRoleSet default_roles = ExtendedRoleSet::AllTag{}; RolesOrUsersSet default_roles = RolesOrUsersSet::AllTag{};
SettingsProfileElements settings; SettingsProfileElements settings;
bool equal(const IAccessEntity & other) const override; bool equal(const IAccessEntity & other) const override;

View File

@ -17,7 +17,6 @@ SRCS(
EnabledRolesInfo.cpp EnabledRolesInfo.cpp
EnabledRowPolicies.cpp EnabledRowPolicies.cpp
EnabledSettings.cpp EnabledSettings.cpp
ExtendedRoleSet.cpp
GrantedAccess.cpp GrantedAccess.cpp
GrantedRoles.cpp GrantedRoles.cpp
IAccessEntity.cpp IAccessEntity.cpp
@ -29,6 +28,7 @@ SRCS(
QuotaUsage.cpp QuotaUsage.cpp
Role.cpp Role.cpp
RoleCache.cpp RoleCache.cpp
RolesOrUsersSet.cpp
RowPolicy.cpp RowPolicy.cpp
RowPolicyCache.cpp RowPolicyCache.cpp
SettingsConstraints.cpp SettingsConstraints.cpp

View File

@ -36,7 +36,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionArray>(nested_function, arguments); return std::make_shared<AggregateFunctionArray>(nested_function, arguments);
} }

View File

@ -7,6 +7,12 @@
namespace DB namespace DB
{ {
AggregateFunctionPtr AggregateFunctionCount::getOwnNullAdapter(
const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const
{
return std::make_shared<AggregateFunctionCountNotNullUnary>(types[0], params);
}
namespace namespace
{ {
@ -22,7 +28,7 @@ AggregateFunctionPtr createAggregateFunctionCount(const std::string & name, cons
void registerAggregateFunctionCount(AggregateFunctionFactory & factory) void registerAggregateFunctionCount(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("count", createAggregateFunctionCount, AggregateFunctionFactory::CaseInsensitive); factory.registerFunction("count", {createAggregateFunctionCount, {true}}, AggregateFunctionFactory::CaseInsensitive);
} }
} }

View File

@ -68,16 +68,14 @@ public:
data(place).count = new_count; data(place).count = new_count;
} }
/// The function returns non-Nullable type even when wrapped with Null combinator. AggregateFunctionPtr getOwnNullAdapter(
bool returnDefaultWhenOnlyNull() const override const AggregateFunctionPtr &, const DataTypes & types, const Array & params) const override;
{
return true;
}
}; };
/// Simply count number of not-NULL values. /// Simply count number of not-NULL values.
class AggregateFunctionCountNotNullUnary final : public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary> class AggregateFunctionCountNotNullUnary final
: public IAggregateFunctionDataHelper<AggregateFunctionCountData, AggregateFunctionCountNotNullUnary>
{ {
public: public:
AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params) AggregateFunctionCountNotNullUnary(const DataTypePtr & argument, const Array & params)

View File

@ -29,18 +29,18 @@ namespace ErrorCodes
} }
void AggregateFunctionFactory::registerFunction(const String & name, Creator creator, CaseSensitiveness case_sensitiveness) void AggregateFunctionFactory::registerFunction(const String & name, Value creator_with_properties, CaseSensitiveness case_sensitiveness)
{ {
if (creator == nullptr) if (creator_with_properties.creator == nullptr)
throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided " throw Exception("AggregateFunctionFactory: the aggregate function " + name + " has been provided "
" a null constructor", ErrorCodes::LOGICAL_ERROR); " a null constructor", ErrorCodes::LOGICAL_ERROR);
if (!aggregate_functions.emplace(name, creator).second) if (!aggregate_functions.emplace(name, creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique", throw Exception("AggregateFunctionFactory: the aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
if (case_sensitiveness == CaseInsensitive if (case_sensitiveness == CaseInsensitive
&& !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator).second) && !case_insensitive_aggregate_functions.emplace(Poco::toLower(name), creator_with_properties).second)
throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique", throw Exception("AggregateFunctionFactory: the case insensitive aggregate function name '" + name + "' is not unique",
ErrorCodes::LOGICAL_ERROR); ErrorCodes::LOGICAL_ERROR);
} }
@ -59,6 +59,7 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level) const int recursion_level) const
{ {
auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types); auto type_without_low_cardinality = convertLowCardinalityTypesToNested(argument_types);
@ -76,18 +77,15 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality); DataTypes nested_types = combinator->transformArguments(type_without_low_cardinality);
Array nested_parameters = combinator->transformParameters(parameters); Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function; bool has_null_arguments = std::any_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(),
[](const auto & type) { return type->onlyNull(); });
/// A little hack - if we have NULL arguments, don't even create nested function. AggregateFunctionPtr nested_function = getImpl(
/// Combinator will check if nested_function was created. name, nested_types, nested_parameters, out_properties, has_null_arguments, recursion_level);
if (name == "count" || std::none_of(type_without_low_cardinality.begin(), type_without_low_cardinality.end(), return combinator->transformAggregateFunction(nested_function, out_properties, type_without_low_cardinality, parameters);
[](const auto & type) { return type->onlyNull(); }))
nested_function = getImpl(name, nested_types, nested_parameters, recursion_level);
return combinator->transformAggregateFunction(nested_function, type_without_low_cardinality, parameters);
} }
auto res = getImpl(name, type_without_low_cardinality, parameters, recursion_level); auto res = getImpl(name, type_without_low_cardinality, parameters, out_properties, false, recursion_level);
if (!res) if (!res)
throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: AggregateFunctionFactory returned nullptr", ErrorCodes::LOGICAL_ERROR);
return res; return res;
@ -98,19 +96,35 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
const String & name_param, const String & name_param,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
bool has_null_arguments,
int recursion_level) const int recursion_level) const
{ {
String name = getAliasToOrName(name_param); String name = getAliasToOrName(name_param);
Value found;
/// Find by exact match. /// Find by exact match.
if (auto it = aggregate_functions.find(name); it != aggregate_functions.end()) if (auto it = aggregate_functions.find(name); it != aggregate_functions.end())
return it->second(name, argument_types, parameters); {
found = it->second;
}
/// Find by case-insensitive name. /// Find by case-insensitive name.
/// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names. /// Combinators cannot apply for case insensitive (SQL-style) aggregate function names. Only for native names.
if (recursion_level == 0) else if (recursion_level == 0)
{ {
if (auto it = case_insensitive_aggregate_functions.find(Poco::toLower(name)); it != case_insensitive_aggregate_functions.end()) if (auto jt = case_insensitive_aggregate_functions.find(Poco::toLower(name)); jt != case_insensitive_aggregate_functions.end())
return it->second(name, argument_types, parameters); found = jt->second;
}
if (found.creator)
{
out_properties = found.properties;
/// The case when aggregate function should return NULL on NULL arguments. This case is handled in "get" method.
if (!out_properties.returns_default_when_only_null && has_null_arguments)
return nullptr;
return found.creator(name, argument_types, parameters);
} }
/// Combinators of aggregate functions. /// Combinators of aggregate functions.
@ -126,9 +140,8 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
DataTypes nested_types = combinator->transformArguments(argument_types); DataTypes nested_types = combinator->transformArguments(argument_types);
Array nested_parameters = combinator->transformParameters(parameters); Array nested_parameters = combinator->transformParameters(parameters);
AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, recursion_level + 1); AggregateFunctionPtr nested_function = get(nested_name, nested_types, nested_parameters, out_properties, recursion_level + 1);
return combinator->transformAggregateFunction(nested_function, out_properties, argument_types, parameters);
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
} }
auto hints = this->getHints(name); auto hints = this->getHints(name);
@ -140,10 +153,11 @@ AggregateFunctionPtr AggregateFunctionFactory::getImpl(
} }
AggregateFunctionPtr AggregateFunctionFactory::tryGet(const String & name, const DataTypes & argument_types, const Array & parameters) const AggregateFunctionPtr AggregateFunctionFactory::tryGet(
const String & name, const DataTypes & argument_types, const Array & parameters, AggregateFunctionProperties & out_properties) const
{ {
return isAggregateFunctionName(name) return isAggregateFunctionName(name)
? get(name, argument_types, parameters) ? get(name, argument_types, parameters, out_properties)
: nullptr; : nullptr;
} }

View File

@ -26,34 +26,51 @@ using DataTypes = std::vector<DataTypePtr>;
*/ */
using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>; using AggregateFunctionCreator = std::function<AggregateFunctionPtr(const String &, const DataTypes &, const Array &)>;
struct AggregateFunctionWithProperties
{
AggregateFunctionCreator creator;
AggregateFunctionProperties properties;
AggregateFunctionWithProperties() = default;
AggregateFunctionWithProperties(const AggregateFunctionWithProperties &) = default;
AggregateFunctionWithProperties & operator = (const AggregateFunctionWithProperties &) = default;
template <typename Creator, std::enable_if_t<!std::is_same_v<Creator, AggregateFunctionWithProperties>> * = nullptr>
AggregateFunctionWithProperties(Creator creator_, AggregateFunctionProperties properties_ = {})
: creator(std::forward<Creator>(creator_)), properties(std::move(properties_))
{
}
};
/** Creates an aggregate function by name. /** Creates an aggregate function by name.
*/ */
class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionCreator> class AggregateFunctionFactory final : private boost::noncopyable, public IFactoryWithAliases<AggregateFunctionWithProperties>
{ {
public: public:
static AggregateFunctionFactory & instance(); static AggregateFunctionFactory & instance();
/// Register a function by its name. /// Register a function by its name.
/// No locking, you must register all functions before usage of get. /// No locking, you must register all functions before usage of get.
void registerFunction( void registerFunction(
const String & name, const String & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive); CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Throws an exception if not found. /// Throws an exception if not found.
AggregateFunctionPtr get( AggregateFunctionPtr get(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters = {}, const Array & parameters,
AggregateFunctionProperties & out_properties,
int recursion_level = 0) const; int recursion_level = 0) const;
/// Returns nullptr if not found. /// Returns nullptr if not found.
AggregateFunctionPtr tryGet( AggregateFunctionPtr tryGet(
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters = {}) const; const Array & parameters,
AggregateFunctionProperties & out_properties) const;
bool isAggregateFunctionName(const String & name, int recursion_level = 0) const; bool isAggregateFunctionName(const String & name, int recursion_level = 0) const;
@ -62,19 +79,21 @@ private:
const String & name, const String & name,
const DataTypes & argument_types, const DataTypes & argument_types,
const Array & parameters, const Array & parameters,
AggregateFunctionProperties & out_properties,
bool has_null_arguments,
int recursion_level) const; int recursion_level) const;
private: private:
using AggregateFunctions = std::unordered_map<String, Creator>; using AggregateFunctions = std::unordered_map<String, Value>;
AggregateFunctions aggregate_functions; AggregateFunctions aggregate_functions;
/// Case insensitive aggregate functions will be additionally added here with lowercased name. /// Case insensitive aggregate functions will be additionally added here with lowercased name.
AggregateFunctions case_insensitive_aggregate_functions; AggregateFunctions case_insensitive_aggregate_functions;
const AggregateFunctions & getCreatorMap() const override { return aggregate_functions; } const AggregateFunctions & getMap() const override { return aggregate_functions; }
const AggregateFunctions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_aggregate_functions; } const AggregateFunctions & getCaseInsensitiveMap() const override { return case_insensitive_aggregate_functions; }
String getFactoryName() const override { return "AggregateFunctionFactory"; } String getFactoryName() const override { return "AggregateFunctionFactory"; }

View File

@ -33,7 +33,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionForEach>(nested_function, arguments); return std::make_shared<AggregateFunctionForEach>(nested_function, arguments);
} }

View File

@ -31,7 +31,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
return std::make_shared<AggregateFunctionIf>(nested_function, arguments); return std::make_shared<AggregateFunctionIf>(nested_function, arguments);
} }

View File

@ -34,7 +34,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array &) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array &) const override
{ {
const DataTypePtr & argument = arguments[0]; const DataTypePtr & argument = arguments[0];

View File

@ -25,7 +25,7 @@ public:
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
return std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>()); return argument_types.front();
} }
void create(AggregateDataPtr) const override void create(AggregateDataPtr) const override

View File

@ -31,13 +31,11 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments,
const Array & params) const override
{ {
/// Special case for 'count' function. It could be called with Nullable arguments
/// - that means - count number of calls, when all arguments are not NULL.
if (nested_function && nested_function->getName() == "count")
return std::make_shared<AggregateFunctionCountNotNullUnary>(arguments[0], params);
bool has_nullable_types = false; bool has_nullable_types = false;
bool has_null_types = false; bool has_null_types = false;
for (const auto & arg_type : arguments) for (const auto & arg_type : arguments)
@ -58,15 +56,23 @@ public:
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
if (has_null_types) if (has_null_types)
return std::make_shared<AggregateFunctionNothing>(arguments, params); {
/// Currently the only functions that returns not-NULL on all NULL arguments are count and uniq, and they returns UInt64.
if (properties.returns_default_when_only_null)
return std::make_shared<AggregateFunctionNothing>(DataTypes{
std::make_shared<DataTypeUInt64>()}, params);
else
return std::make_shared<AggregateFunctionNothing>(DataTypes{
std::make_shared<DataTypeNullable>(std::make_shared<DataTypeNothing>())}, params);
}
assert(nested_function); assert(nested_function);
if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params)) if (auto adapter = nested_function->getOwnNullAdapter(nested_function, arguments, params))
return adapter; return adapter;
bool return_type_is_nullable = !nested_function->returnDefaultWhenOnlyNull() && nested_function->getReturnType()->canBeInsideNullable(); bool return_type_is_nullable = !properties.returns_default_when_only_null && nested_function->getReturnType()->canBeInsideNullable();
bool serialize_flag = return_type_is_nullable || nested_function->returnDefaultWhenOnlyNull(); bool serialize_flag = return_type_is_nullable || properties.returns_default_when_only_null;
if (arguments.size() == 1) if (arguments.size() == 1)
{ {

View File

@ -21,6 +21,7 @@ public:
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const override const Array & params) const override
{ {

View File

@ -43,6 +43,7 @@ public:
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const override const Array & params) const override
{ {

View File

@ -24,7 +24,10 @@ public:
} }
AggregateFunctionPtr transformAggregateFunction( AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const DataTypes & arguments, const Array & params) const override const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties &,
const DataTypes & arguments,
const Array & params) const override
{ {
return std::make_shared<AggregateFunctionState>(nested_function, arguments, params); return std::make_shared<AggregateFunctionState>(nested_function, arguments, params);
} }

View File

@ -18,21 +18,6 @@ namespace ErrorCodes
namespace namespace
{ {
template <bool overflow, bool tuple_argument>
struct SumMap
{
template <typename T>
using F = AggregateFunctionSumMap<T, overflow, tuple_argument>;
};
template <bool overflow, bool tuple_argument>
struct SumMapFiltered
{
template <typename T>
using F = AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>;
};
auto parseArguments(const std::string & name, const DataTypes & arguments) auto parseArguments(const std::string & name, const DataTypes & arguments)
{ {
DataTypes args; DataTypes args;
@ -85,30 +70,32 @@ auto parseArguments(const std::string & name, const DataTypes & arguments)
tuple_argument}; tuple_argument};
} }
template <bool overflow> // This function instantiates a particular overload of the sumMap family of
AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, const DataTypes & arguments, const Array & params) // functions.
// The template parameter MappedFunction<bool template_argument> is an aggregate
// function template that allows to choose the aggregate function variant that
// accepts either normal arguments or tuple argument.
template<template <bool tuple_argument> typename MappedFunction>
AggregateFunctionPtr createAggregateFunctionMap(const std::string & name, const DataTypes & arguments, const Array & params)
{ {
assertNoParameters(name, params); auto [keys_type, values_types, tuple_argument] = parseArguments(name, arguments);
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
arguments);
AggregateFunctionPtr res; AggregateFunctionPtr res;
if (tuple_argument) if (tuple_argument)
{ {
res.reset(createWithNumericBasedType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithNumericBasedType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
if (!res) if (!res)
res.reset(createWithDecimalType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithDecimalType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
if (!res) if (!res)
res.reset(createWithStringType<SumMap<overflow, true>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithStringType<MappedFunction<true>::template F>(*keys_type, keys_type, values_types, arguments, params));
} }
else else
{ {
res.reset(createWithNumericBasedType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithNumericBasedType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
if (!res) if (!res)
res.reset(createWithDecimalType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithDecimalType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
if (!res) if (!res)
res.reset(createWithStringType<SumMap<overflow, false>::template F>(*keys_type, keys_type, values_types, arguments)); res.reset(createWithStringType<MappedFunction<false>::template F>(*keys_type, keys_type, values_types, arguments, params));
} }
if (!res) if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT); throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
@ -116,52 +103,66 @@ AggregateFunctionPtr createAggregateFunctionSumMap(const std::string & name, con
return res; return res;
} }
template <bool overflow> // This template chooses the sumMap variant with given filtering and overflow
AggregateFunctionPtr createAggregateFunctionSumMapFiltered(const std::string & name, const DataTypes & arguments, const Array & params) // handling.
template <bool filtered, bool overflow>
struct SumMapVariants
{ {
if (params.size() != 1) // SumMapVariants chooses the `overflow` and `filtered` parameters of the
throw Exception("Aggregate function " + name + " requires exactly one parameter of Array type.", // aggregate functions. The `tuple_argument` and the value type `T` are left
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH); // as free parameters.
// DispatchOnTupleArgument chooses `tuple_argument`, and the value type `T`
Array keys_to_keep; // is left free.
if (!params.front().tryGet<Array>(keys_to_keep)) template <bool tuple_argument>
throw Exception("Aggregate function " + name + " requires an Array as parameter.", struct DispatchOnTupleArgument
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
auto [keys_type, values_types, tuple_argument] = parseArguments(name,
arguments);
AggregateFunctionPtr res;
if (tuple_argument)
{ {
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); template <typename T>
if (!res) using F = std::conditional_t<filtered,
res.reset(createWithDecimalType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>,
if (!res) AggregateFunctionSumMap<T, overflow, tuple_argument>>;
res.reset(createWithStringType<SumMapFiltered<overflow, true>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params)); };
} };
else
{
res.reset(createWithNumericBasedType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithDecimalType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
if (!res)
res.reset(createWithStringType<SumMapFiltered<overflow, false>::template F>(*keys_type, keys_type, values_types, keys_to_keep, arguments, params));
}
if (!res)
throw Exception("Illegal type of argument for aggregate function " + name, ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
return res; // This template gives an aggregate function template that is narrowed
} // to accept either tuple argumen or normal argumens.
template <bool tuple_argument>
struct MinMapDispatchOnTupleArgument
{
template <typename T>
using F = AggregateFunctionMinMap<T, tuple_argument>;
};
// This template gives an aggregate function template that is narrowed
// to accept either tuple argumen or normal argumens.
template <bool tuple_argument>
struct MaxMapDispatchOnTupleArgument
{
template <typename T>
using F = AggregateFunctionMaxMap<T, tuple_argument>;
};
} }
void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory) void registerAggregateFunctionSumMap(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("sumMap", createAggregateFunctionSumMap<false /*overflow*/>); factory.registerFunction("sumMap", createAggregateFunctionMap<
factory.registerFunction("sumMapWithOverflow", createAggregateFunctionSumMap<true /*overflow*/>); SumMapVariants<false, false>::DispatchOnTupleArgument>);
factory.registerFunction("sumMapFiltered", createAggregateFunctionSumMapFiltered<false /*overflow*/>);
factory.registerFunction("sumMapFilteredWithOverflow", createAggregateFunctionSumMapFiltered<true /*overflow*/>); factory.registerFunction("sumMapWithOverflow", createAggregateFunctionMap<
SumMapVariants<false, true>::DispatchOnTupleArgument>);
factory.registerFunction("sumMapFiltered", createAggregateFunctionMap<
SumMapVariants<true, false>::DispatchOnTupleArgument>);
factory.registerFunction("sumMapFilteredWithOverflow",
createAggregateFunctionMap<
SumMapVariants<true, true>::DispatchOnTupleArgument>);
factory.registerFunction("minMap",
createAggregateFunctionMap<MinMapDispatchOnTupleArgument>);
factory.registerFunction("maxMap",
createAggregateFunctionMap<MaxMapDispatchOnTupleArgument>);
} }
} }

View File

@ -25,19 +25,20 @@ namespace ErrorCodes
{ {
extern const int BAD_ARGUMENTS; extern const int BAD_ARGUMENTS;
extern const int ILLEGAL_TYPE_OF_ARGUMENT; extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
} }
template <typename T> template <typename T>
struct AggregateFunctionSumMapData struct AggregateFunctionMapData
{ {
// Map needs to be ordered to maintain function properties // Map needs to be ordered to maintain function properties
std::map<T, Array> merged_maps; std::map<T, Array> merged_maps;
}; };
/** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of of at least 2 arrays - /** Aggregate function, that takes at least two arguments: keys and values, and as a result, builds a tuple of of at least 2 arrays -
* ordered keys and variable number of argument values summed up by corresponding keys. * ordered keys and variable number of argument values aggregated by corresponding keys.
* *
* This function is the most useful when using SummingMergeTree to sum Nested columns, which name ends in "Map". * sumMap function is the most useful when using SummingMergeTree to sum Nested columns, which name ends in "Map".
* *
* Example: sumMap(k, v...) of: * Example: sumMap(k, v...) of:
* k v * k v
@ -49,24 +50,27 @@ struct AggregateFunctionSumMapData
* [8,9,10] [20,20,20] * [8,9,10] [20,20,20]
* will return: * will return:
* ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20]) * ([1,2,3,4,5,6,7,8,9,10],[10,10,45,20,35,20,15,30,20,20])
*
* minMap and maxMap share the same idea, but calculate min and max correspondingly.
*/ */
template <typename T, typename Derived, bool overflow, bool tuple_argument> template <typename T, typename Derived, typename Visitor, bool overflow, bool tuple_argument>
class AggregateFunctionSumMapBase : public IAggregateFunctionDataHelper< class AggregateFunctionMapBase : public IAggregateFunctionDataHelper<
AggregateFunctionSumMapData<NearestFieldType<T>>, Derived> AggregateFunctionMapData<NearestFieldType<T>>, Derived>
{ {
private: private:
DataTypePtr keys_type; DataTypePtr keys_type;
DataTypes values_types; DataTypes values_types;
public: public:
AggregateFunctionSumMapBase( using Base = IAggregateFunctionDataHelper<
const DataTypePtr & keys_type_, const DataTypes & values_types_, AggregateFunctionMapData<NearestFieldType<T>>, Derived>;
const DataTypes & argument_types_, const Array & params_)
: IAggregateFunctionDataHelper<AggregateFunctionSumMapData<NearestFieldType<T>>, Derived>(argument_types_, params_)
, keys_type(keys_type_), values_types(values_types_) {}
String getName() const override { return "sumMap"; } AggregateFunctionMapBase(const DataTypePtr & keys_type_,
const DataTypes & values_types_, const DataTypes & argument_types_)
: Base(argument_types_, {} /* parameters */), keys_type(keys_type_),
values_types(values_types_)
{}
DataTypePtr getReturnType() const override DataTypePtr getReturnType() const override
{ {
@ -88,7 +92,7 @@ public:
// No overflow, meaning we promote the types if necessary. // No overflow, meaning we promote the types if necessary.
if (!value_type->canBePromoted()) if (!value_type->canBePromoted())
{ {
throw Exception{"Values to be summed are expected to be Numeric, Float or Decimal.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; throw Exception{"Values for " + getName() + " are expected to be Numeric, Float or Decimal.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
} }
result_type = value_type->promoteNumericType(); result_type = value_type->promoteNumericType();
@ -161,7 +165,7 @@ public:
if (it != merged_maps.end()) if (it != merged_maps.end())
{ {
applyVisitor(FieldVisitorSum(value), it->second[col]); applyVisitor(Visitor(value), it->second[col]);
} }
else else
{ {
@ -198,7 +202,7 @@ public:
if (it != merged_maps.end()) if (it != merged_maps.end())
{ {
for (size_t col = 0; col < values_types.size(); ++col) for (size_t col = 0; col < values_types.size(); ++col)
applyVisitor(FieldVisitorSum(elem.second[col]), it->second[col]); applyVisitor(Visitor(elem.second[col]), it->second[col]);
} }
else else
merged_maps[elem.first] = elem.second; merged_maps[elem.first] = elem.second;
@ -300,20 +304,27 @@ public:
} }
bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); } bool keepKey(const T & key) const { return static_cast<const Derived &>(*this).keepKey(key); }
String getName() const override { return static_cast<const Derived &>(*this).getName(); }
}; };
template <typename T, bool overflow, bool tuple_argument> template <typename T, bool overflow, bool tuple_argument>
class AggregateFunctionSumMap final : class AggregateFunctionSumMap final :
public AggregateFunctionSumMapBase<T, AggregateFunctionSumMap<T, overflow, tuple_argument>, overflow, tuple_argument> public AggregateFunctionMapBase<T, AggregateFunctionSumMap<T, overflow, tuple_argument>, FieldVisitorSum, overflow, tuple_argument>
{ {
private: private:
using Self = AggregateFunctionSumMap<T, overflow, tuple_argument>; using Self = AggregateFunctionSumMap<T, overflow, tuple_argument>;
using Base = AggregateFunctionSumMapBase<T, Self, overflow, tuple_argument>; using Base = AggregateFunctionMapBase<T, Self, FieldVisitorSum, overflow, tuple_argument>;
public: public:
AggregateFunctionSumMap(const DataTypePtr & keys_type_, DataTypes & values_types_, const DataTypes & argument_types_) AggregateFunctionSumMap(const DataTypePtr & keys_type_,
: Base{keys_type_, values_types_, argument_types_, {}} DataTypes & values_types_, const DataTypes & argument_types_,
{} const Array & params_)
: Base{keys_type_, values_types_, argument_types_}
{
// The constructor accepts parameters to have a uniform interface with
// sumMapFiltered, but this function doesn't have any parameters.
assertNoParameters(getName(), params_);
}
String getName() const override { return "sumMap"; } String getName() const override { return "sumMap"; }
@ -322,23 +333,35 @@ public:
template <typename T, bool overflow, bool tuple_argument> template <typename T, bool overflow, bool tuple_argument>
class AggregateFunctionSumMapFiltered final : class AggregateFunctionSumMapFiltered final :
public AggregateFunctionSumMapBase<T, public AggregateFunctionMapBase<T,
AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>, AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>,
FieldVisitorSum,
overflow, overflow,
tuple_argument> tuple_argument>
{ {
private: private:
using Self = AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>; using Self = AggregateFunctionSumMapFiltered<T, overflow, tuple_argument>;
using Base = AggregateFunctionSumMapBase<T, Self, overflow, tuple_argument>; using Base = AggregateFunctionMapBase<T, Self, FieldVisitorSum, overflow, tuple_argument>;
std::unordered_set<T> keys_to_keep; std::unordered_set<T> keys_to_keep;
public: public:
AggregateFunctionSumMapFiltered( AggregateFunctionSumMapFiltered(const DataTypePtr & keys_type_,
const DataTypePtr & keys_type_, const DataTypes & values_types_, const Array & keys_to_keep_, const DataTypes & values_types_, const DataTypes & argument_types_,
const DataTypes & argument_types_, const Array & params_) const Array & params_)
: Base{keys_type_, values_types_, argument_types_, params_} : Base{keys_type_, values_types_, argument_types_}
{ {
if (params_.size() != 1)
throw Exception(ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH,
"Aggregate function '{}' requires exactly one parameter "
"of Array type", getName());
Array keys_to_keep_;
if (!params_.front().tryGet<Array>(keys_to_keep_))
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT,
"Aggregate function {} requires an Array as a parameter",
getName());
keys_to_keep.reserve(keys_to_keep_.size()); keys_to_keep.reserve(keys_to_keep_.size());
for (const Field & f : keys_to_keep_) for (const Field & f : keys_to_keep_)
{ {
@ -346,9 +369,58 @@ public:
} }
} }
String getName() const override { return "sumMapFiltered"; } String getName() const override
{ return overflow ? "sumMapFilteredWithOverflow" : "sumMapFiltered"; }
bool keepKey(const T & key) const { return keys_to_keep.count(key); } bool keepKey(const T & key) const { return keys_to_keep.count(key); }
}; };
template <typename T, bool tuple_argument>
class AggregateFunctionMinMap final :
public AggregateFunctionMapBase<T, AggregateFunctionMinMap<T, tuple_argument>, FieldVisitorMin, true, tuple_argument>
{
private:
using Self = AggregateFunctionMinMap<T, tuple_argument>;
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorMin, true, tuple_argument>;
public:
AggregateFunctionMinMap(const DataTypePtr & keys_type_,
DataTypes & values_types_, const DataTypes & argument_types_,
const Array & params_)
: Base{keys_type_, values_types_, argument_types_}
{
// The constructor accepts parameters to have a uniform interface with
// sumMapFiltered, but this function doesn't have any parameters.
assertNoParameters(getName(), params_);
}
String getName() const override { return "minMap"; }
bool keepKey(const T &) const { return true; }
};
template <typename T, bool tuple_argument>
class AggregateFunctionMaxMap final :
public AggregateFunctionMapBase<T, AggregateFunctionMaxMap<T, tuple_argument>, FieldVisitorMax, true, tuple_argument>
{
private:
using Self = AggregateFunctionMaxMap<T, tuple_argument>;
using Base = AggregateFunctionMapBase<T, Self, FieldVisitorMax, true, tuple_argument>;
public:
AggregateFunctionMaxMap(const DataTypePtr & keys_type_,
DataTypes & values_types_, const DataTypes & argument_types_,
const Array & params_)
: Base{keys_type_, values_types_, argument_types_}
{
// The constructor accepts parameters to have a uniform interface with
// sumMapFiltered, but this function doesn't have any parameters.
assertNoParameters(getName(), params_);
}
String getName() const override { return "maxMap"; }
bool keepKey(const T &) const { return true; }
};
} }

View File

@ -123,13 +123,13 @@ AggregateFunctionPtr createAggregateFunctionUniq(const std::string & name, const
void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory) void registerAggregateFunctionsUniq(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("uniq", factory.registerFunction("uniq",
createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>); {createAggregateFunctionUniq<AggregateFunctionUniqUniquesHashSetData, AggregateFunctionUniqUniquesHashSetDataForVariadic>, {true}});
factory.registerFunction("uniqHLL12", factory.registerFunction("uniqHLL12",
createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>); {createAggregateFunctionUniq<false, AggregateFunctionUniqHLL12Data, AggregateFunctionUniqHLL12DataForVariadic>, {true}});
factory.registerFunction("uniqExact", factory.registerFunction("uniqExact",
createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>); {createAggregateFunctionUniq<true, AggregateFunctionUniqExactData, AggregateFunctionUniqExactData<String>>, {true}});
} }
} }

View File

@ -244,12 +244,6 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
@ -304,12 +298,6 @@ public:
{ {
assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size()); assert_cast<ColumnUInt64 &>(to).getData().push_back(this->data(place).set.size());
} }
/// The function returns non-Nullable type even when wrapped with Null combinator.
bool returnDefaultWhenOnlyNull() const override
{
return true;
}
}; };
} }

View File

@ -85,7 +85,7 @@ AggregateFunctionPtr createAggregateFunctionUniqUpTo(const std::string & name, c
void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory) void registerAggregateFunctionUniqUpTo(AggregateFunctionFactory & factory)
{ {
factory.registerFunction("uniqUpTo", createAggregateFunctionUniqUpTo); factory.registerFunction("uniqUpTo", {createAggregateFunctionUniqUpTo, {true}});
} }
} }

View File

@ -166,17 +166,12 @@ public:
* nested_function is a smart pointer to this aggregate function itself. * nested_function is a smart pointer to this aggregate function itself.
* arguments and params are for nested_function. * arguments and params are for nested_function.
*/ */
virtual AggregateFunctionPtr getOwnNullAdapter(const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const virtual AggregateFunctionPtr getOwnNullAdapter(
const AggregateFunctionPtr & /*nested_function*/, const DataTypes & /*arguments*/, const Array & /*params*/) const
{ {
return nullptr; return nullptr;
} }
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
virtual bool returnDefaultWhenOnlyNull() const { return false; }
const DataTypes & getArgumentTypes() const { return argument_types; } const DataTypes & getArgumentTypes() const { return argument_types; }
const Array & getParameters() const { return parameters; } const Array & getParameters() const { return parameters; }
@ -286,4 +281,15 @@ public:
}; };
/// Properties of aggregate function that are independent of argument types and parameters.
struct AggregateFunctionProperties
{
/** When the function is wrapped with Null combinator,
* should we return Nullable type with NULL when no values were aggregated
* or we should return non-Nullable type with default value (example: count, countDistinct).
*/
bool returns_default_when_only_null = false;
};
} }

View File

@ -59,6 +59,7 @@ public:
*/ */
virtual AggregateFunctionPtr transformAggregateFunction( virtual AggregateFunctionPtr transformAggregateFunction(
const AggregateFunctionPtr & nested_function, const AggregateFunctionPtr & nested_function,
const AggregateFunctionProperties & properties,
const DataTypes & arguments, const DataTypes & arguments,
const Array & params) const = 0; const Array & params) const = 0;

View File

@ -381,6 +381,6 @@ if (ENABLE_TESTS AND USE_GTEST)
-Wno-gnu-zero-variadic-macro-arguments -Wno-gnu-zero-variadic-macro-arguments
) )
target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils) target_link_libraries(unit_tests_dbms PRIVATE ${GTEST_BOTH_LIBRARIES} clickhouse_functions clickhouse_aggregate_functions clickhouse_parsers dbms clickhouse_common_zookeeper string_utils)
add_check(unit_tests_dbms) add_check(unit_tests_dbms)
endif () endif ()

View File

@ -11,7 +11,6 @@
#include <IO/ConnectionTimeouts.h> #include <IO/ConnectionTimeouts.h>
namespace ProfileEvents namespace ProfileEvents
{ {
extern const Event DistributedConnectionMissingTable; extern const Event DistributedConnectionMissingTable;
@ -71,6 +70,18 @@ IConnectionPool::Entry ConnectionPoolWithFailover::get(const ConnectionTimeouts
case LoadBalancing::FIRST_OR_RANDOM: case LoadBalancing::FIRST_OR_RANDOM:
get_priority = [](size_t i) -> size_t { return i >= 1; }; get_priority = [](size_t i) -> size_t { return i >= 1; };
break; break;
case LoadBalancing::ROUND_ROBIN:
if (last_used >= nested_pools.size())
last_used = 0;
++last_used;
/* Consider nested_pools.size() equals to 5
* last_used = 1 -> get_priority: 0 1 2 3 4
* last_used = 2 -> get_priority: 5 0 1 2 3
* last_used = 3 -> get_priority: 5 4 0 1 2
* ...
* */
get_priority = [&](size_t i) { ++i; return i < last_used ? nested_pools.size() - i : i - last_used; };
break;
} }
return Base::get(try_get_entry, get_priority); return Base::get(try_get_entry, get_priority);
@ -181,6 +192,18 @@ std::vector<ConnectionPoolWithFailover::TryResult> ConnectionPoolWithFailover::g
case LoadBalancing::FIRST_OR_RANDOM: case LoadBalancing::FIRST_OR_RANDOM:
get_priority = [](size_t i) -> size_t { return i >= 1; }; get_priority = [](size_t i) -> size_t { return i >= 1; };
break; break;
case LoadBalancing::ROUND_ROBIN:
if (last_used >= nested_pools.size())
last_used = 0;
++last_used;
/* Consider nested_pools.size() equals to 5
* last_used = 1 -> get_priority: 0 1 2 3 4
* last_used = 2 -> get_priority: 5 0 1 2 3
* last_used = 3 -> get_priority: 5 4 0 1 2
* ...
* */
get_priority = [&](size_t i) { ++i; return i < last_used ? nested_pools.size() - i : i - last_used; };
break;
} }
bool fallback_to_stale_replicas = settings ? bool(settings->fallback_to_stale_replicas_for_distributed_queries) : true; bool fallback_to_stale_replicas = settings ? bool(settings->fallback_to_stale_replicas_for_distributed_queries) : true;

View File

@ -97,6 +97,7 @@ private:
private: private:
std::vector<size_t> hostname_differences; /// Distances from name of this host to the names of hosts of pools. std::vector<size_t> hostname_differences; /// Distances from name of this host to the names of hosts of pools.
size_t last_used = 0; /// Last used for round_robin policy.
LoadBalancing default_load_balancing; LoadBalancing default_load_balancing;
}; };

View File

@ -1,2 +0,0 @@
add_executable(test-connect test_connect.cpp)
target_link_libraries (test-connect PRIVATE dbms)

View File

@ -1,59 +0,0 @@
#include <sys/types.h>
#include <sys/socket.h>
#include <unistd.h>
#include <iostream>
#include <Poco/Net/StreamSocket.h>
#include <Common/Exception.h>
#include <IO/ReadHelpers.h>
/** In a loop it connects to the server and immediately breaks the connection.
* Using the SO_LINGER option, we ensure that the connection is terminated by sending a RST packet (not FIN).
* This behavior causes a bug in the TCPServer implementation in the Poco library.
*/
int main(int argc, char ** argv)
try
{
for (size_t i = 0, num_iters = argc >= 2 ? DB::parse<size_t>(argv[1]) : 1; i < num_iters; ++i)
{
std::cerr << ".";
Poco::Net::SocketAddress address("localhost", 9000);
int fd = socket(PF_INET, SOCK_STREAM, IPPROTO_IP);
if (fd < 0)
DB::throwFromErrno("Cannot create socket", 0);
linger linger_value;
linger_value.l_onoff = 1;
linger_value.l_linger = 0;
if (0 != setsockopt(fd, SOL_SOCKET, SO_LINGER, &linger_value, sizeof(linger_value)))
DB::throwFromErrno("Cannot set linger", 0);
try
{
int res = connect(fd, address.addr(), address.length());
if (res != 0 && errno != EINPROGRESS && errno != EWOULDBLOCK)
{
close(fd);
DB::throwFromErrno("Cannot connect", 0);
}
close(fd);
}
catch (const Poco::Exception & e)
{
std::cerr << e.displayText() << "\n";
}
}
std::cerr << "\n";
}
catch (const Poco::Exception & e)
{
std::cerr << e.displayText() << "\n";
}

View File

@ -210,4 +210,88 @@ public:
} }
}; };
/** Implements `Max` operation.
* Returns true if changed
*/
class FieldVisitorMax : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMax(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot compare Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot compare Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot compare AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val > x)
{
x = val;
return true;
}
return false;
}
};
/** Implements `Min` operation.
* Returns true if changed
*/
class FieldVisitorMin : public StaticVisitor<bool>
{
private:
const Field & rhs;
public:
explicit FieldVisitorMin(const Field & rhs_) : rhs(rhs_) {}
bool operator() (Null &) const { throw Exception("Cannot compare Nulls", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Array &) const { throw Exception("Cannot sum Arrays", ErrorCodes::LOGICAL_ERROR); }
bool operator() (Tuple &) const { throw Exception("Cannot sum Tuples", ErrorCodes::LOGICAL_ERROR); }
bool operator() (AggregateFunctionStateData &) const { throw Exception("Cannot sum AggregateFunctionStates", ErrorCodes::LOGICAL_ERROR); }
template <typename T>
bool operator() (DecimalField<T> & x) const
{
auto val = get<DecimalField<T>>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
template <typename T>
bool operator() (T & x) const
{
auto val = get<T>(rhs);
if (val < x)
{
x = val;
return true;
}
return false;
}
};
} }

View File

@ -16,14 +16,14 @@ namespace ErrorCodes
} }
/** If stored objects may have several names (aliases) /** If stored objects may have several names (aliases)
* this interface may be helpful * this interface may be helpful
* template parameter is available as Creator * template parameter is available as Value
*/ */
template <typename CreatorFunc> template <typename ValueType>
class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<CreatorFunc>> class IFactoryWithAliases : public IHints<2, IFactoryWithAliases<ValueType>>
{ {
protected: protected:
using Creator = CreatorFunc; using Value = ValueType;
String getAliasToOrName(const String & name) const String getAliasToOrName(const String & name) const
{ {
@ -43,13 +43,13 @@ public:
CaseInsensitive CaseInsensitive
}; };
/** Register additional name for creator /** Register additional name for value
* real_name have to be already registered. * real_name have to be already registered.
*/ */
void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive) void registerAlias(const String & alias_name, const String & real_name, CaseSensitiveness case_sensitiveness = CaseSensitive)
{ {
const auto & creator_map = getCreatorMap(); const auto & creator_map = getMap();
const auto & case_insensitive_creator_map = getCaseInsensitiveCreatorMap(); const auto & case_insensitive_creator_map = getCaseInsensitiveMap();
const String factory_name = getFactoryName(); const String factory_name = getFactoryName();
String real_dict_name; String real_dict_name;
@ -80,7 +80,7 @@ public:
{ {
std::vector<String> result; std::vector<String> result;
auto getter = [](const auto & pair) { return pair.first; }; auto getter = [](const auto & pair) { return pair.first; };
std::transform(getCreatorMap().begin(), getCreatorMap().end(), std::back_inserter(result), getter); std::transform(getMap().begin(), getMap().end(), std::back_inserter(result), getter);
std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter); std::transform(aliases.begin(), aliases.end(), std::back_inserter(result), getter);
return result; return result;
} }
@ -88,7 +88,7 @@ public:
bool isCaseInsensitive(const String & name) const bool isCaseInsensitive(const String & name) const
{ {
String name_lowercase = Poco::toLower(name); String name_lowercase = Poco::toLower(name);
return getCaseInsensitiveCreatorMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase); return getCaseInsensitiveMap().count(name_lowercase) || case_insensitive_aliases.count(name_lowercase);
} }
const String & aliasTo(const String & name) const const String & aliasTo(const String & name) const
@ -109,11 +109,11 @@ public:
virtual ~IFactoryWithAliases() override {} virtual ~IFactoryWithAliases() override {}
private: private:
using InnerMap = std::unordered_map<String, Creator>; // name -> creator using InnerMap = std::unordered_map<String, Value>; // name -> creator
using AliasMap = std::unordered_map<String, String>; // alias -> original type using AliasMap = std::unordered_map<String, String>; // alias -> original type
virtual const InnerMap & getCreatorMap() const = 0; virtual const InnerMap & getMap() const = 0;
virtual const InnerMap & getCaseInsensitiveCreatorMap() const = 0; virtual const InnerMap & getCaseInsensitiveMap() const = 0;
virtual String getFactoryName() const = 0; virtual String getFactoryName() const = 0;
/// Alias map to data_types from previous two maps /// Alias map to data_types from previous two maps

View File

@ -83,6 +83,23 @@ const char * IntervalKind::toKeyword() const
} }
const char * IntervalKind::toLowercasedKeyword() const
{
switch (kind)
{
case IntervalKind::Second: return "second";
case IntervalKind::Minute: return "minute";
case IntervalKind::Hour: return "hour";
case IntervalKind::Day: return "day";
case IntervalKind::Week: return "week";
case IntervalKind::Month: return "month";
case IntervalKind::Quarter: return "quarter";
case IntervalKind::Year: return "year";
}
__builtin_unreachable();
}
const char * IntervalKind::toDateDiffUnit() const const char * IntervalKind::toDateDiffUnit() const
{ {
switch (kind) switch (kind)

View File

@ -37,6 +37,8 @@ struct IntervalKind
/// Returns an uppercased version of what `toString()` returns. /// Returns an uppercased version of what `toString()` returns.
const char * toKeyword() const; const char * toKeyword() const;
const char * toLowercasedKeyword() const;
/// Returns the string which can be passed to the `unit` parameter of the dateDiff() function. /// Returns the string which can be passed to the `unit` parameter of the dateDiff() function.
/// For example, `IntervalKind{IntervalKind::Day}.getDateDiffParameter()` returns "day". /// For example, `IntervalKind{IntervalKind::Day}.getDateDiffParameter()` returns "day".
const char * toDateDiffUnit() const; const char * toDateDiffUnit() const;

View File

@ -203,6 +203,11 @@
\ \
M(CannotWriteToWriteBufferDiscard, "Number of stack traces dropped by query profiler or signal handler because pipe is full or cannot write to pipe.") \ M(CannotWriteToWriteBufferDiscard, "Number of stack traces dropped by query profiler or signal handler because pipe is full or cannot write to pipe.") \
M(QueryProfilerSignalOverruns, "Number of times we drop processing of a signal due to overrun plus the number of signals that OS has not delivered due to overrun.") \ M(QueryProfilerSignalOverruns, "Number of times we drop processing of a signal due to overrun plus the number of signals that OS has not delivered due to overrun.") \
\
M(CreatedLogEntryForMerge, "Successfully created log entry to merge parts in ReplicatedMergeTree.") \
M(NotCreatedLogEntryForMerge, "Log entry to merge parts in ReplicatedMergeTree is not created due to concurrent log update by another replica.") \
M(CreatedLogEntryForMutation, "Successfully created log entry to mutate parts in ReplicatedMergeTree.") \
M(NotCreatedLogEntryForMutation, "Log entry to mutate parts in ReplicatedMergeTree is not created due to concurrent log update by another replica.") \
namespace ProfileEvents namespace ProfileEvents
{ {

View File

@ -109,7 +109,8 @@ public:
uri.setPath(IDENTIFIER_QUOTE_HANDLER); uri.setPath(IDENTIFIER_QUOTE_HANDLER);
uri.addQueryParameter("connection_string", getConnectionString()); uri.addQueryParameter("connection_string", getConnectionString());
ReadWriteBufferFromHTTP buf(uri, Poco::Net::HTTPRequest::HTTP_POST, nullptr); ReadWriteBufferFromHTTP buf(
uri, Poco::Net::HTTPRequest::HTTP_POST, {}, ConnectionTimeouts::getHTTPTimeouts(context));
std::string character; std::string character;
readStringBinary(character, buf); readStringBinary(character, buf);
if (character.length() > 1) if (character.length() > 1)
@ -208,7 +209,8 @@ private:
{ {
try try
{ {
ReadWriteBufferFromHTTP buf(ping_url, Poco::Net::HTTPRequest::HTTP_GET, nullptr); ReadWriteBufferFromHTTP buf(
ping_url, Poco::Net::HTTPRequest::HTTP_GET, {}, ConnectionTimeouts::getHTTPTimeouts(context));
return checkString(XDBCBridgeHelper::PING_OK_ANSWER, buf); return checkString(XDBCBridgeHelper::PING_OK_ANSWER, buf);
} }
catch (...) catch (...)

View File

@ -14,15 +14,27 @@ using namespace DB;
TEST(zkutil, ZookeeperConnected) TEST(zkutil, ZookeeperConnected)
{ {
try /// In our CI infrastructure it is typical that ZooKeeper is unavailable for some amount of time.
size_t i;
for (i = 0; i < 100; ++i)
{ {
auto zookeeper = std::make_unique<zkutil::ZooKeeper>("localhost:2181"); try
zookeeper->exists("/"); {
zookeeper->createIfNotExists("/clickhouse_test", "Unit tests of ClickHouse"); auto zookeeper = std::make_unique<zkutil::ZooKeeper>("localhost:2181");
zookeeper->exists("/");
zookeeper->createIfNotExists("/clickhouse_test", "Unit tests of ClickHouse");
}
catch (...)
{
std::cerr << "Zookeeper is unavailable, try " << i << std::endl;
sleep(1);
continue;
}
break;
} }
catch (...) if (i == 100)
{ {
std::cerr << "No zookeeper. skip tests." << std::endl; std::cerr << "No zookeeper after " << i << " tries. skip tests." << std::endl;
exit(0); exit(0);
} }
} }

View File

@ -163,6 +163,8 @@ using BlocksPtrs = std::shared_ptr<std::vector<BlocksPtr>>;
struct ExtraBlock struct ExtraBlock
{ {
Block block; Block block;
bool empty() const { return !block; }
}; };
using ExtraBlockPtr = std::shared_ptr<ExtraBlock>; using ExtraBlockPtr = std::shared_ptr<ExtraBlock>;

View File

@ -289,6 +289,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingJoinAlgorithm, join_algorithm, JoinAlgorithm::HASH, "Specify join algorithm: 'auto', 'hash', 'partial_merge', 'prefer_partial_merge'. 'auto' tries to change HashJoin to MergeJoin on the fly to avoid out of memory.", 0) \ M(SettingJoinAlgorithm, join_algorithm, JoinAlgorithm::HASH, "Specify join algorithm: 'auto', 'hash', 'partial_merge', 'prefer_partial_merge'. 'auto' tries to change HashJoin to MergeJoin on the fly to avoid out of memory.", 0) \
M(SettingBool, partial_merge_join_optimizations, true, "Enable optimizations in partial merge join", 0) \ M(SettingBool, partial_merge_join_optimizations, true, "Enable optimizations in partial merge join", 0) \
M(SettingUInt64, default_max_bytes_in_join, 1000000000, "Maximum size of right-side table if limit is required but max_bytes_in_join is not set.", 0) \ M(SettingUInt64, default_max_bytes_in_join, 1000000000, "Maximum size of right-side table if limit is required but max_bytes_in_join is not set.", 0) \
M(SettingUInt64, partial_merge_join_left_table_buffer_bytes, 32000000, "If not 0 group left table blocks in bigger ones for left-side table in partial merge join. It uses up to 2x of specified memory per joining thread. In current version work only with 'partial_merge_join_optimizations = 1'.", 0) \
M(SettingUInt64, partial_merge_join_rows_in_right_blocks, 65536, "Split right-hand joining data in blocks of specified size. It's a portion of data indexed by min-max values and possibly unloaded on disk.", 0) \ M(SettingUInt64, partial_merge_join_rows_in_right_blocks, 65536, "Split right-hand joining data in blocks of specified size. It's a portion of data indexed by min-max values and possibly unloaded on disk.", 0) \
M(SettingUInt64, join_on_disk_max_files_to_merge, 64, "For MergeJoin on disk set how much files it's allowed to sort simultaneously. Then this value bigger then more memory used and then less disk I/O needed. Minimum is 2.", 0) \ M(SettingUInt64, join_on_disk_max_files_to_merge, 64, "For MergeJoin on disk set how much files it's allowed to sort simultaneously. Then this value bigger then more memory used and then less disk I/O needed. Minimum is 2.", 0) \
M(SettingString, temporary_files_codec, "LZ4", "Set compression codec for temporary files (sort and join on disk). I.e. LZ4, NONE.", 0) \ M(SettingString, temporary_files_codec, "LZ4", "Set compression codec for temporary files (sort and join on disk). I.e. LZ4, NONE.", 0) \
@ -403,6 +404,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingBool, input_format_skip_unknown_fields, false, "Skip columns with unknown names from input data (it works for JSONEachRow, CSVWithNames, TSVWithNames and TSKV formats).", 0) \ M(SettingBool, input_format_skip_unknown_fields, false, "Skip columns with unknown names from input data (it works for JSONEachRow, CSVWithNames, TSVWithNames and TSKV formats).", 0) \
M(SettingBool, input_format_with_names_use_header, true, "For TSVWithNames and CSVWithNames input formats this controls whether format parser is to assume that column data appear in the input exactly as they are specified in the header.", 0) \ M(SettingBool, input_format_with_names_use_header, true, "For TSVWithNames and CSVWithNames input formats this controls whether format parser is to assume that column data appear in the input exactly as they are specified in the header.", 0) \
M(SettingBool, input_format_import_nested_json, false, "Map nested JSON data to nested tables (it works for JSONEachRow format).", 0) \ M(SettingBool, input_format_import_nested_json, false, "Map nested JSON data to nested tables (it works for JSONEachRow format).", 0) \
M(SettingBool, optimize_aggregators_of_group_by_keys, true, "Eliminates min/max/any/anyLast aggregators of GROUP BY keys in SELECT section", 0) \
M(SettingBool, input_format_defaults_for_omitted_fields, true, "For input data calculate default expressions for omitted fields (it works for JSONEachRow, CSV and TSV formats).", IMPORTANT) \ M(SettingBool, input_format_defaults_for_omitted_fields, true, "For input data calculate default expressions for omitted fields (it works for JSONEachRow, CSV and TSV formats).", IMPORTANT) \
M(SettingBool, input_format_tsv_empty_as_default, false, "Treat empty fields in TSV input as default values.", 0) \ M(SettingBool, input_format_tsv_empty_as_default, false, "Treat empty fields in TSV input as default values.", 0) \
M(SettingBool, input_format_null_as_default, false, "For text input formats initialize null fields with default values if data type of this field is not nullable", 0) \ M(SettingBool, input_format_null_as_default, false, "For text input formats initialize null fields with default values if data type of this field is not nullable", 0) \

View File

@ -481,7 +481,8 @@ void SettingURI::deserialize(ReadBuffer & buf, SettingsBinaryFormat)
M(RANDOM, "random") \ M(RANDOM, "random") \
M(NEAREST_HOSTNAME, "nearest_hostname") \ M(NEAREST_HOSTNAME, "nearest_hostname") \
M(IN_ORDER, "in_order") \ M(IN_ORDER, "in_order") \
M(FIRST_OR_RANDOM, "first_or_random") M(FIRST_OR_RANDOM, "first_or_random") \
M(ROUND_ROBIN, "round_robin")
IMPLEMENT_SETTING_ENUM(LoadBalancing, LOAD_BALANCING_LIST_OF_NAMES, ErrorCodes::UNKNOWN_LOAD_BALANCING) IMPLEMENT_SETTING_ENUM(LoadBalancing, LOAD_BALANCING_LIST_OF_NAMES, ErrorCodes::UNKNOWN_LOAD_BALANCING)

View File

@ -225,11 +225,14 @@ enum class LoadBalancing
/// a replica is selected among the replicas with the minimum number of errors /// a replica is selected among the replicas with the minimum number of errors
/// with the minimum number of distinguished characters in the replica name and local hostname /// with the minimum number of distinguished characters in the replica name and local hostname
NEAREST_HOSTNAME, NEAREST_HOSTNAME,
/// replicas are walked through strictly in order; the number of errors does not matter // replicas with the same number of errors are accessed in the same order
// as they are specified in the configuration.
IN_ORDER, IN_ORDER,
/// if first replica one has higher number of errors, /// if first replica one has higher number of errors,
/// pick a random one from replicas with minimum number of errors /// pick a random one from replicas with minimum number of errors
FIRST_OR_RANDOM, FIRST_OR_RANDOM,
// round robin across replicas with the same number of errors.
ROUND_ROBIN,
}; };
using SettingLoadBalancing = SettingEnum<LoadBalancing>; using SettingLoadBalancing = SettingEnum<LoadBalancing>;

View File

@ -55,10 +55,14 @@ Block InflatingExpressionBlockInputStream::readImpl()
} }
Block res; Block res;
if (likely(!not_processed)) bool keep_going = not_processed && not_processed->empty(); /// There's data inside expression.
if (!not_processed || keep_going)
{ {
not_processed.reset();
res = children.back()->read(); res = children.back()->read();
if (res) if (res || keep_going)
expression->execute(res, not_processed, action_number); expression->execute(res, not_processed, action_number);
} }
else else

View File

@ -1,4 +1,4 @@
set(SRCS) set(SRCS)
add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS}) add_executable (finish_sorting_stream finish_sorting_stream.cpp ${SRCS})
target_link_libraries (finish_sorting_stream PRIVATE dbms) target_link_libraries (finish_sorting_stream PRIVATE clickhouse_aggregate_functions dbms)

View File

@ -392,7 +392,8 @@ static DataTypePtr create(const ASTPtr & arguments)
if (function_name.empty()) if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row); return std::make_shared<DataTypeAggregateFunction>(function, argument_types, params_row);
} }

View File

@ -110,7 +110,8 @@ static std::pair<DataTypePtr, DataTypeCustomDescPtr> create(const ASTPtr & argum
if (function_name.empty()) if (function_name.empty())
throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR); throw Exception("Logical error: empty name of aggregate function passed", ErrorCodes::LOGICAL_ERROR);
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row); AggregateFunctionProperties properties;
function = AggregateFunctionFactory::instance().get(function_name, argument_types, params_row, properties);
// check function // check function
if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions)) if (std::find(std::begin(supported_functions), std::end(supported_functions), function->getName()) == std::end(supported_functions))

View File

@ -80,7 +80,7 @@ DataTypePtr DataTypeFactory::get(const String & family_name_param, const ASTPtr
} }
void DataTypeFactory::registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness) void DataTypeFactory::registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness)
{ {
if (creator == nullptr) if (creator == nullptr)
throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided " throw Exception("DataTypeFactory: the data type family " + family_name + " has been provided "
@ -136,7 +136,7 @@ void DataTypeFactory::registerSimpleDataTypeCustom(const String &name, SimpleCre
}, case_sensitiveness); }, case_sensitiveness);
} }
const DataTypeFactory::Creator& DataTypeFactory::findCreatorByName(const String & family_name) const const DataTypeFactory::Value & DataTypeFactory::findCreatorByName(const String & family_name) const
{ {
{ {
DataTypesDictionary::const_iterator it = data_types.find(family_name); DataTypesDictionary::const_iterator it = data_types.find(family_name);

View File

@ -23,7 +23,7 @@ class DataTypeFactory final : private boost::noncopyable, public IFactoryWithAli
{ {
private: private:
using SimpleCreator = std::function<DataTypePtr()>; using SimpleCreator = std::function<DataTypePtr()>;
using DataTypesDictionary = std::unordered_map<String, Creator>; using DataTypesDictionary = std::unordered_map<String, Value>;
using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>; using CreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>(const ASTPtr & parameters)>;
using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>; using SimpleCreatorWithCustom = std::function<std::pair<DataTypePtr,DataTypeCustomDescPtr>()>;
@ -35,7 +35,7 @@ public:
DataTypePtr get(const ASTPtr & ast) const; DataTypePtr get(const ASTPtr & ast) const;
/// Register a type family by its name. /// Register a type family by its name.
void registerDataType(const String & family_name, Creator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerDataType(const String & family_name, Value creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
/// Register a simple data type, that have no parameters. /// Register a simple data type, that have no parameters.
void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerSimpleDataType(const String & name, SimpleCreator creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
@ -47,7 +47,7 @@ public:
void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive); void registerSimpleDataTypeCustom(const String & name, SimpleCreatorWithCustom creator, CaseSensitiveness case_sensitiveness = CaseSensitive);
private: private:
const Creator& findCreatorByName(const String & family_name) const; const Value & findCreatorByName(const String & family_name) const;
private: private:
DataTypesDictionary data_types; DataTypesDictionary data_types;
@ -57,9 +57,9 @@ private:
DataTypeFactory(); DataTypeFactory();
const DataTypesDictionary & getCreatorMap() const override { return data_types; } const DataTypesDictionary & getMap() const override { return data_types; }
const DataTypesDictionary & getCaseInsensitiveCreatorMap() const override { return case_insensitive_data_types; } const DataTypesDictionary & getCaseInsensitiveMap() const override { return case_insensitive_data_types; }
String getFactoryName() const override { return "DataTypeFactory"; } String getFactoryName() const override { return "DataTypeFactory"; }
}; };

View File

@ -1,5 +0,0 @@
add_executable (data_types_number_fixed data_types_number_fixed.cpp)
target_link_libraries (data_types_number_fixed PRIVATE dbms)
add_executable (data_type_string data_type_string.cpp)
target_link_libraries (data_type_string PRIVATE dbms)

View File

@ -1,81 +0,0 @@
#include <string>
#include <iostream>
#include <fstream>
#include <Common/Stopwatch.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/WriteBufferFromFile.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeString.h>
int main(int, char **)
try
{
using namespace DB;
Stopwatch stopwatch;
size_t n = 50000000;
const char * s = "";
size_t size = strlen(s) + 1;
DataTypeString data_type;
{
auto column = ColumnString::create();
ColumnString::Chars & data = column->getChars();
ColumnString::Offsets & offsets = column->getOffsets();
data.resize(n * size);
offsets.resize(n);
for (size_t i = 0; i < n; ++i)
{
memcpy(&data[i * size], s, size);
offsets[i] = (i + 1) * size;
}
WriteBufferFromFile out_buf("test");
IDataType::SerializeBinaryBulkSettings settings;
IDataType::SerializeBinaryBulkStatePtr state;
settings.getter = [&](const IDataType::SubstreamPath &){ return &out_buf; };
stopwatch.restart();
data_type.serializeBinaryBulkStatePrefix(settings, state);
data_type.serializeBinaryBulkWithMultipleStreams(*column, 0, 0, settings, state);
data_type.serializeBinaryBulkStateSuffix(settings, state);
stopwatch.stop();
std::cout << "Writing, elapsed: " << stopwatch.elapsedSeconds() << std::endl;
}
{
auto column = ColumnString::create();
ReadBufferFromFile in_buf("test");
IDataType::DeserializeBinaryBulkSettings settings;
IDataType::DeserializeBinaryBulkStatePtr state;
settings.getter = [&](const IDataType::SubstreamPath &){ return &in_buf; };
stopwatch.restart();
data_type.deserializeBinaryBulkStatePrefix(settings, state);
data_type.deserializeBinaryBulkWithMultipleStreams(*column, n, settings, state);
stopwatch.stop();
std::cout << "Reading, elapsed: " << stopwatch.elapsedSeconds() << std::endl;
std::cout << std::endl
<< get<const String &>((*column)[0]) << std::endl
<< get<const String &>((*column)[n - 1]) << std::endl;
}
return 0;
}
catch (const DB::Exception & e)
{
std::cerr << e.what() << ", " << e.displayText() << std::endl;
return 1;
}

View File

@ -1,41 +0,0 @@
#include <iostream>
#include <fstream>
#include <Common/Stopwatch.h>
#include <IO/WriteBufferFromOStream.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
int main(int, char **)
{
using namespace DB;
auto column = ColumnUInt64::create();
ColumnUInt64::Container & vec = column->getData();
DataTypeUInt64 data_type;
Stopwatch stopwatch;
size_t n = 10000000;
vec.resize(n);
for (size_t i = 0; i < n; ++i)
vec[i] = i;
std::ofstream ostr("test");
WriteBufferFromOStream out_buf(ostr);
stopwatch.restart();
IDataType::SerializeBinaryBulkSettings settings;
settings.getter = [&](const IDataType::SubstreamPath &){ return &out_buf; };
IDataType::SerializeBinaryBulkStatePtr state;
data_type.serializeBinaryBulkStatePrefix(settings, state);
data_type.serializeBinaryBulkWithMultipleStreams(*column, 0, 0, settings, state);
data_type.serializeBinaryBulkStateSuffix(settings, state);
stopwatch.stop();
std::cout << "Elapsed: " << stopwatch.elapsedSeconds() << std::endl;
return 0;
}

View File

@ -1,4 +1,4 @@
set(SRCS ) set(SRCS )
add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS}) add_executable (tab_separated_streams tab_separated_streams.cpp ${SRCS})
target_link_libraries (tab_separated_streams PRIVATE dbms) target_link_libraries (tab_separated_streams PRIVATE clickhouse_aggregate_functions dbms)

View File

@ -20,7 +20,7 @@ namespace ErrorCodes
void FunctionFactory::registerFunction(const void FunctionFactory::registerFunction(const
std::string & name, std::string & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness) CaseSensitiveness case_sensitiveness)
{ {
if (!functions.emplace(name, creator).second) if (!functions.emplace(name, creator).second)

View File

@ -53,7 +53,7 @@ public:
FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const; FunctionOverloadResolverImplPtr tryGetImpl(const std::string & name, const Context & context) const;
private: private:
using Functions = std::unordered_map<std::string, Creator>; using Functions = std::unordered_map<std::string, Value>;
Functions functions; Functions functions;
Functions case_insensitive_functions; Functions case_insensitive_functions;
@ -64,9 +64,9 @@ private:
return std::make_unique<DefaultOverloadResolver>(Function::create(context)); return std::make_unique<DefaultOverloadResolver>(Function::create(context));
} }
const Functions & getCreatorMap() const override { return functions; } const Functions & getMap() const override { return functions; }
const Functions & getCaseInsensitiveCreatorMap() const override { return case_insensitive_functions; } const Functions & getCaseInsensitiveMap() const override { return case_insensitive_functions; }
String getFactoryName() const override { return "FunctionFactory"; } String getFactoryName() const override { return "FunctionFactory"; }
@ -74,7 +74,7 @@ private:
/// No locking, you must register all functions before usage of get. /// No locking, you must register all functions before usage of get.
void registerFunction( void registerFunction(
const std::string & name, const std::string & name,
Creator creator, Value creator,
CaseSensitiveness case_sensitiveness = CaseSensitive); CaseSensitiveness case_sensitiveness = CaseSensitive);
}; };

View File

@ -113,8 +113,9 @@ public:
auto nested_type = array_type->getNestedType(); auto nested_type = array_type->getNestedType();
DataTypes argument_types = {nested_type}; DataTypes argument_types = {nested_type};
Array params_row; Array params_row;
AggregateFunctionPtr bitmap_function AggregateFunctionProperties properties;
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row); AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
return std::make_shared<DataTypeAggregateFunction>(bitmap_function, argument_types, params_row); return std::make_shared<DataTypeAggregateFunction>(bitmap_function, argument_types, params_row);
} }
@ -156,8 +157,9 @@ private:
// output data // output data
Array params_row; Array params_row;
AggregateFunctionPtr bitmap_function AggregateFunctionProperties properties;
= AggregateFunctionFactory::instance().get(AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row); AggregateFunctionPtr bitmap_function = AggregateFunctionFactory::instance().get(
AggregateFunctionGroupBitmapData<UInt32>::name(), argument_types, params_row, properties);
auto col_to = ColumnAggregateFunction::create(bitmap_function); auto col_to = ColumnAggregateFunction::create(bitmap_function);
col_to->reserve(offsets.size()); col_to->reserve(offsets.size());

View File

@ -97,7 +97,8 @@ DataTypePtr FunctionArrayReduce::getReturnTypeImpl(const ColumnsWithTypeAndName
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName()); aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return aggregate_function->getReturnType(); return aggregate_function->getReturnType();

View File

@ -115,7 +115,8 @@ DataTypePtr FunctionArrayReduceInRanges::getReturnTypeImpl(const ColumnsWithType
getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params, getAggregateFunctionNameAndParametersArray(aggregate_function_name_with_params,
aggregate_function_name, params_row, "function " + getName()); aggregate_function_name, params_row, "function " + getName());
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row); AggregateFunctionProperties properties;
aggregate_function = AggregateFunctionFactory::instance().get(aggregate_function_name, argument_types, params_row, properties);
} }
return std::make_shared<DataTypeArray>(aggregate_function->getReturnType()); return std::make_shared<DataTypeArray>(aggregate_function->getReturnType());

View File

@ -156,7 +156,8 @@ namespace detail
public: public:
using OutStreamCallback = std::function<void(std::ostream &)>; using OutStreamCallback = std::function<void(std::ostream &)>;
explicit ReadWriteBufferFromHTTPBase(UpdatableSessionPtr session_, explicit ReadWriteBufferFromHTTPBase(
UpdatableSessionPtr session_,
Poco::URI uri_, Poco::URI uri_,
const std::string & method_ = {}, const std::string & method_ = {},
OutStreamCallback out_stream_callback_ = {}, OutStreamCallback out_stream_callback_ = {},
@ -245,9 +246,9 @@ class ReadWriteBufferFromHTTP : public detail::ReadWriteBufferFromHTTPBase<std::
public: public:
explicit ReadWriteBufferFromHTTP(Poco::URI uri_, explicit ReadWriteBufferFromHTTP(Poco::URI uri_,
const std::string & method_ = {}, const std::string & method_,
OutStreamCallback out_stream_callback_ = {}, OutStreamCallback out_stream_callback_,
const ConnectionTimeouts & timeouts = {}, const ConnectionTimeouts & timeouts,
const SettingUInt64 max_redirects = 0, const SettingUInt64 max_redirects = 0,
const Poco::Net::HTTPBasicCredentials & credentials_ = {}, const Poco::Net::HTTPBasicCredentials & credentials_ = {},
size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE, size_t buffer_size_ = DBMS_DEFAULT_BUFFER_SIZE,

View File

@ -0,0 +1,114 @@
#pragma once
#include <Functions/FunctionFactory.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/InDepthNodeVisitor.h>
#include <Parsers/ASTFunction.h>
#include <Parsers/ASTLiteral.h>
#include <Parsers/ASTSelectQuery.h>
#include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTTablesInSelectQuery.h>
#include <Parsers/IAST.h>
#include <Common/typeid_cast.h>
namespace DB
{
///recursive traversal and check for optimizeAggregateFunctionsOfGroupByKeys
struct KeepAggregateFunctionMatcher
{
struct Data
{
std::unordered_set<String> & group_by_keys;
bool & keep_aggregator;
};
using Visitor = InDepthNodeVisitor<KeepAggregateFunctionMatcher, true>;
static bool needChildVisit(const ASTPtr & node, const ASTPtr &)
{
return !(node->as<ASTFunction>());
}
static void visit(ASTFunction & function_node, Data & data)
{
if ((function_node.arguments->children).empty())
{
data.keep_aggregator = true;
return;
}
if (!data.group_by_keys.count(function_node.getColumnName()))
{
Visitor(data).visit(function_node.arguments);
}
}
static void visit(ASTIdentifier & ident, Data & data)
{
if (!data.group_by_keys.count(ident.shortName()))
{
/// if variable of a function is not in GROUP BY keys, this function should not be deleted
data.keep_aggregator = true;
return;
}
}
static void visit(const ASTPtr & ast, Data & data)
{
if (data.keep_aggregator)
return;
if (auto * function_node = ast->as<ASTFunction>())
{
visit(*function_node, data);
}
else if (auto * ident = ast->as<ASTIdentifier>())
{
visit(*ident, data);
}
else if (!ast->as<ASTExpressionList>())
{
data.keep_aggregator = true;
}
}
};
using KeepAggregateFunctionVisitor = InDepthNodeVisitor<KeepAggregateFunctionMatcher, true>;
class SelectAggregateFunctionOfGroupByKeysMatcher
{
public:
struct Data
{
std::unordered_set<String> & group_by_keys;
};
static bool needChildVisit(const ASTPtr & node, const ASTPtr &)
{
return !(node->as<ASTFunction>());
}
static void visit(ASTPtr & ast, Data & data)
{
///check if function is min/max/any
auto * function_node = ast->as<ASTFunction>();
if (function_node && (function_node->name == "min" || function_node->name == "max" ||
function_node->name == "any" || function_node->name == "anyLast"))
{
bool keep_aggregator = false;
KeepAggregateFunctionVisitor::Data keep_data{data.group_by_keys, keep_aggregator};
KeepAggregateFunctionVisitor(keep_data).visit(function_node->arguments);
if (!keep_aggregator)
{
///place argument of an aggregate function instead of function
ast = (function_node->arguments->children[0])->clone();
}
}
}
};
using SelectAggregateFunctionOfGroupByKeysVisitor = InDepthNodeVisitor<SelectAggregateFunctionOfGroupByKeysMatcher, true>;
}

View File

@ -47,7 +47,8 @@ public:
String current_user; String current_user;
String current_query_id; String current_query_id;
Poco::Net::SocketAddress current_address; Poco::Net::SocketAddress current_address;
/// Use current user and password when sending query to replica leader
/// This field is only used in foreign "Arcadia" build.
String current_password; String current_password;
/// When query_kind == INITIAL_QUERY, these values are equal to current. /// When query_kind == INITIAL_QUERY, these values are equal to current.

View File

@ -660,9 +660,13 @@ void Context::setUser(const String & name, const String & password, const Poco::
auto lock = getLock(); auto lock = getLock();
client_info.current_user = name; client_info.current_user = name;
client_info.current_password = password;
client_info.current_address = address; client_info.current_address = address;
#if defined(ARCADIA_BUILD)
/// This is harmful field that is used only in foreign "Arcadia" build.
client_info.current_password = password;
#endif
auto new_user_id = getAccessControlManager().find<User>(name); auto new_user_id = getAccessControlManager().find<User>(name);
std::shared_ptr<const ContextAccess> new_access; std::shared_ptr<const ContextAccess> new_access;
if (new_user_id) if (new_user_id)

View File

@ -481,6 +481,7 @@ void DDLWorker::parseQueryAndResolveHost(DDLTask & task)
const auto & shards = task.cluster->getShardsAddresses(); const auto & shards = task.cluster->getShardsAddresses();
bool found_exact_match = false; bool found_exact_match = false;
String default_database;
for (size_t shard_num = 0; shard_num < shards.size(); ++shard_num) for (size_t shard_num = 0; shard_num < shards.size(); ++shard_num)
{ {
for (size_t replica_num = 0; replica_num < shards[shard_num].size(); ++replica_num) for (size_t replica_num = 0; replica_num < shards[shard_num].size(); ++replica_num)
@ -491,14 +492,38 @@ void DDLWorker::parseQueryAndResolveHost(DDLTask & task)
{ {
if (found_exact_match) if (found_exact_match)
{ {
throw Exception("There are two exactly the same ClickHouse instances " + address.readableString() if (default_database == address.default_database)
+ " in cluster " + task.cluster_name, ErrorCodes::INCONSISTENT_CLUSTER_DEFINITION); {
throw Exception(
"There are two exactly the same ClickHouse instances " + address.readableString() + " in cluster "
+ task.cluster_name,
ErrorCodes::INCONSISTENT_CLUSTER_DEFINITION);
}
else
{
/* Circular replication is used.
* It is when every physical node contains
* replicas of different shards of the same table.
* To distinguish one replica from another on the same node,
* every shard is placed into separate database.
* */
is_circular_replicated = true;
auto * query_with_table = dynamic_cast<ASTQueryWithTableAndOutput *>(task.query.get());
if (!query_with_table || query_with_table->database.empty())
{
throw Exception(
"For a distributed DDL on circular replicated cluster its table name must be qualified by database name.",
ErrorCodes::INCONSISTENT_CLUSTER_DEFINITION);
}
if (default_database == query_with_table->database)
return;
}
} }
found_exact_match = true; found_exact_match = true;
task.host_shard_num = shard_num; task.host_shard_num = shard_num;
task.host_replica_num = replica_num; task.host_replica_num = replica_num;
task.address_in_cluster = address; task.address_in_cluster = address;
default_database = address.default_database;
} }
} }
} }
@ -621,6 +646,7 @@ void DDLWorker::processTask(DDLTask & task, const ZooKeeperPtr & zookeeper)
{ {
try try
{ {
is_circular_replicated = false;
parseQueryAndResolveHost(task); parseQueryAndResolveHost(task);
ASTPtr rewritten_ast = task.query_on_cluster->getRewrittenASTWithoutOnCluster(task.address_in_cluster.default_database); ASTPtr rewritten_ast = task.query_on_cluster->getRewrittenASTWithoutOnCluster(task.address_in_cluster.default_database);
@ -643,7 +669,7 @@ void DDLWorker::processTask(DDLTask & task, const ZooKeeperPtr & zookeeper)
if (storage && query_with_table->as<ASTAlterQuery>()) if (storage && query_with_table->as<ASTAlterQuery>())
checkShardConfig(query_with_table->table, task, storage); checkShardConfig(query_with_table->table, task, storage);
if (storage && taskShouldBeExecutedOnLeader(rewritten_ast, storage)) if (storage && taskShouldBeExecutedOnLeader(rewritten_ast, storage) && !is_circular_replicated)
tryExecuteQueryOnLeaderReplica(task, storage, rewritten_query, task.entry_path, zookeeper); tryExecuteQueryOnLeaderReplica(task, storage, rewritten_query, task.entry_path, zookeeper);
else else
tryExecuteQuery(rewritten_query, task, task.execution_status); tryExecuteQuery(rewritten_query, task, task.execution_status);

View File

@ -100,6 +100,7 @@ private:
void attachToThreadGroup(); void attachToThreadGroup();
private: private:
bool is_circular_replicated;
Context & context; Context & context;
Poco::Logger * log; Poco::Logger * log;
std::unique_ptr<Context> current_context; std::unique_ptr<Context> current_context;

View File

@ -420,8 +420,9 @@ bool ExpressionAnalyzer::makeAggregateDescriptions(ExpressionActionsPtr & action
aggregate.argument_names[i] = name; aggregate.argument_names[i] = name;
} }
AggregateFunctionProperties properties;
aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array(); aggregate.parameters = (node->parameters) ? getAggregateFunctionParametersArray(node->parameters) : Array();
aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters); aggregate.function = AggregateFunctionFactory::instance().get(node->name, types, aggregate.parameters, properties);
aggregate_descriptions.push_back(aggregate); aggregate_descriptions.push_back(aggregate);
} }
@ -1151,13 +1152,20 @@ ExpressionAnalysisResult::ExpressionAnalysisResult(
} }
} }
bool has_stream_with_non_joined_rows = (before_join && before_join->getTableJoinAlgo()->hasStreamWithNonJoinedRows()); bool join_allow_read_in_order = true;
if (before_join)
{
/// You may find it strange but we support read_in_order for HashJoin and do not support for MergeJoin.
auto join = before_join->getTableJoinAlgo();
join_allow_read_in_order = typeid_cast<HashJoin *>(join.get()) && !join->hasStreamWithNonJoinedRows();
}
optimize_read_in_order = optimize_read_in_order =
settings.optimize_read_in_order settings.optimize_read_in_order
&& storage && query.orderBy() && storage && query.orderBy()
&& !query_analyzer.hasAggregation() && !query_analyzer.hasAggregation()
&& !query.final() && !query.final()
&& !has_stream_with_non_joined_rows; && join_allow_read_in_order;
/// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers. /// If there is aggregation, we execute expressions in SELECT and ORDER BY on the initiating server, otherwise on the source servers.
query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage)); query_analyzer.appendSelect(chain, only_types || (need_aggregate ? !second_stage : !first_stage));

View File

@ -1,6 +1,6 @@
#include <Interpreters/InterpreterCreateQuotaQuery.h> #include <Interpreters/InterpreterCreateQuotaQuery.h>
#include <Parsers/ASTCreateQuotaQuery.h> #include <Parsers/ASTCreateQuotaQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
@ -15,15 +15,18 @@ namespace DB
{ {
namespace namespace
{ {
void updateQuotaFromQueryImpl(Quota & quota, const ASTCreateQuotaQuery & query, const std::optional<ExtendedRoleSet> & roles_from_query = {}) void updateQuotaFromQueryImpl(
Quota & quota,
const ASTCreateQuotaQuery & query,
const String & override_name,
const std::optional<RolesOrUsersSet> & override_to_roles)
{ {
if (query.alter) if (!override_name.empty())
{ quota.setName(override_name);
if (!query.new_name.empty()) else if (!query.new_name.empty())
quota.setName(query.new_name); quota.setName(query.new_name);
} else if (query.names.size() == 1)
else quota.setName(query.names.front());
quota.setName(query.name);
if (query.key_type) if (query.key_type)
quota.key_type = *query.key_type; quota.key_type = *query.key_type;
@ -59,15 +62,10 @@ void updateQuotaFromQueryImpl(Quota & quota, const ASTCreateQuotaQuery & query,
quota_limits.max[resource_type] = query_limits.max[resource_type]; quota_limits.max[resource_type] = query_limits.max[resource_type];
} }
const ExtendedRoleSet * roles = nullptr; if (override_to_roles)
std::optional<ExtendedRoleSet> temp_role_set; quota.to_roles = *override_to_roles;
if (roles_from_query)
roles = &*roles_from_query;
else if (query.roles) else if (query.roles)
roles = &temp_role_set.emplace(*query.roles); quota.to_roles = *query.roles;
if (roles)
quota.to_roles = *roles;
} }
} }
@ -84,37 +82,42 @@ BlockIO InterpreterCreateQuotaQuery::execute()
return executeDDLQueryOnCluster(query_ptr, context); return executeDDLQueryOnCluster(query_ptr, context);
} }
std::optional<ExtendedRoleSet> roles_from_query; std::optional<RolesOrUsersSet> roles_from_query;
if (query.roles) if (query.roles)
roles_from_query = ExtendedRoleSet{*query.roles, access_control, context.getUserID()}; roles_from_query = RolesOrUsersSet{*query.roles, access_control, context.getUserID()};
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_quota = typeid_cast<std::shared_ptr<Quota>>(entity->clone()); auto updated_quota = typeid_cast<std::shared_ptr<Quota>>(entity->clone());
updateQuotaFromQueryImpl(*updated_quota, query, roles_from_query); updateQuotaFromQueryImpl(*updated_quota, query, {}, roles_from_query);
return updated_quota; return updated_quota;
}; };
if (query.if_exists) if (query.if_exists)
{ {
if (auto id = access_control.find<Quota>(query.name)) auto ids = access_control.find<Quota>(query.names);
access_control.tryUpdate(*id, update_func); access_control.tryUpdate(ids, update_func);
} }
else else
access_control.update(access_control.getID<Quota>(query.name), update_func); access_control.update(access_control.getIDs<Quota>(query.names), update_func);
} }
else else
{ {
auto new_quota = std::make_shared<Quota>(); std::vector<AccessEntityPtr> new_quotas;
updateQuotaFromQueryImpl(*new_quota, query, roles_from_query); for (const String & name : query.names)
{
auto new_quota = std::make_shared<Quota>();
updateQuotaFromQueryImpl(*new_quota, query, name, roles_from_query);
new_quotas.emplace_back(std::move(new_quota));
}
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_quota); access_control.tryInsert(new_quotas);
else if (query.or_replace) else if (query.or_replace)
access_control.insertOrReplace(new_quota); access_control.insertOrReplace(new_quotas);
else else
access_control.insert(new_quota); access_control.insert(new_quotas);
} }
return {}; return {};
@ -123,7 +126,7 @@ BlockIO InterpreterCreateQuotaQuery::execute()
void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query) void InterpreterCreateQuotaQuery::updateQuotaFromQuery(Quota & quota, const ASTCreateQuotaQuery & query)
{ {
updateQuotaFromQueryImpl(quota, query); updateQuotaFromQueryImpl(quota, query, {}, {});
} }
} }

View File

@ -13,25 +13,20 @@ namespace
void updateRoleFromQueryImpl( void updateRoleFromQueryImpl(
Role & role, Role & role,
const ASTCreateRoleQuery & query, const ASTCreateRoleQuery & query,
const std::optional<SettingsProfileElements> & settings_from_query = {}) const String & override_name,
const std::optional<SettingsProfileElements> & override_settings)
{ {
if (query.alter) if (!override_name.empty())
{ role.setName(override_name);
if (!query.new_name.empty()) else if (!query.new_name.empty())
role.setName(query.new_name); role.setName(query.new_name);
} else if (query.names.size() == 1)
else role.setName(query.names.front());
role.setName(query.name);
const SettingsProfileElements * settings = nullptr; if (override_settings)
std::optional<SettingsProfileElements> temp_settings; role.settings = *override_settings;
if (settings_from_query)
settings = &*settings_from_query;
else if (query.settings) else if (query.settings)
settings = &temp_settings.emplace(*query.settings); role.settings = *query.settings;
if (settings)
role.settings = *settings;
} }
} }
@ -57,28 +52,33 @@ BlockIO InterpreterCreateRoleQuery::execute()
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_role = typeid_cast<std::shared_ptr<Role>>(entity->clone()); auto updated_role = typeid_cast<std::shared_ptr<Role>>(entity->clone());
updateRoleFromQueryImpl(*updated_role, query, settings_from_query); updateRoleFromQueryImpl(*updated_role, query, {}, settings_from_query);
return updated_role; return updated_role;
}; };
if (query.if_exists) if (query.if_exists)
{ {
if (auto id = access_control.find<Role>(query.name)) auto ids = access_control.find<Role>(query.names);
access_control.tryUpdate(*id, update_func); access_control.tryUpdate(ids, update_func);
} }
else else
access_control.update(access_control.getID<Role>(query.name), update_func); access_control.update(access_control.getIDs<Role>(query.names), update_func);
} }
else else
{ {
auto new_role = std::make_shared<Role>(); std::vector<AccessEntityPtr> new_roles;
updateRoleFromQueryImpl(*new_role, query, settings_from_query); for (const auto & name : query.names)
{
auto new_role = std::make_shared<Role>();
updateRoleFromQueryImpl(*new_role, query, name, settings_from_query);
new_roles.emplace_back(std::move(new_role));
}
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_role); access_control.tryInsert(new_roles);
else if (query.or_replace) else if (query.or_replace)
access_control.insertOrReplace(new_role); access_control.insertOrReplace(new_roles);
else else
access_control.insert(new_role); access_control.insert(new_roles);
} }
return {}; return {};
@ -87,6 +87,6 @@ BlockIO InterpreterCreateRoleQuery::execute()
void InterpreterCreateRoleQuery::updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query) void InterpreterCreateRoleQuery::updateRoleFromQuery(Role & role, const ASTCreateRoleQuery & query)
{ {
updateRoleFromQueryImpl(role, query); updateRoleFromQueryImpl(role, query, {}, {});
} }
} }

View File

@ -1,6 +1,7 @@
#include <Interpreters/InterpreterCreateRowPolicyQuery.h> #include <Interpreters/InterpreterCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRowPolicyName.h>
#include <Parsers/ASTRolesOrUsersSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
@ -16,35 +17,26 @@ namespace
void updateRowPolicyFromQueryImpl( void updateRowPolicyFromQueryImpl(
RowPolicy & policy, RowPolicy & policy,
const ASTCreateRowPolicyQuery & query, const ASTCreateRowPolicyQuery & query,
const std::optional<ExtendedRoleSet> & roles_from_query = {}) const RowPolicy::NameParts & override_name,
const std::optional<RolesOrUsersSet> & override_to_roles)
{ {
if (query.alter) if (!override_name.empty())
{ policy.setNameParts(override_name);
if (!query.new_short_name.empty()) else if (!query.new_short_name.empty())
policy.setShortName(query.new_short_name); policy.setShortName(query.new_short_name);
} else if (query.names->name_parts.size() == 1)
else policy.setNameParts(query.names->name_parts.front());
policy.setNameParts(query.name_parts);
if (query.is_restrictive) if (query.is_restrictive)
policy.setRestrictive(*query.is_restrictive); policy.setRestrictive(*query.is_restrictive);
for (auto condition_type : ext::range(RowPolicy::MAX_CONDITION_TYPE)) for (const auto & [condition_type, condition] : query.conditions)
{ policy.conditions[condition_type] = condition ? serializeAST(*condition) : String{};
const auto & condition = query.conditions[condition_type];
if (condition)
policy.conditions[condition_type] = *condition ? serializeAST(**condition) : String{};
}
const ExtendedRoleSet * roles = nullptr; if (override_to_roles)
std::optional<ExtendedRoleSet> temp_role_set; policy.to_roles = *override_to_roles;
if (roles_from_query)
roles = &*roles_from_query;
else if (query.roles) else if (query.roles)
roles = &temp_role_set.emplace(*query.roles); policy.to_roles = *query.roles;
if (roles)
policy.to_roles = *roles;
} }
} }
@ -61,40 +53,46 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
return executeDDLQueryOnCluster(query_ptr, context); return executeDDLQueryOnCluster(query_ptr, context);
} }
std::optional<ExtendedRoleSet> roles_from_query; assert(query.names->cluster.empty());
std::optional<RolesOrUsersSet> roles_from_query;
if (query.roles) if (query.roles)
roles_from_query = ExtendedRoleSet{*query.roles, access_control, context.getUserID()}; roles_from_query = RolesOrUsersSet{*query.roles, access_control, context.getUserID()};
if (query.name_parts.database.empty()) query.replaceEmptyDatabaseWithCurrent(context.getCurrentDatabase());
query.name_parts.database = context.getCurrentDatabase();
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone()); auto updated_policy = typeid_cast<std::shared_ptr<RowPolicy>>(entity->clone());
updateRowPolicyFromQueryImpl(*updated_policy, query, roles_from_query); updateRowPolicyFromQueryImpl(*updated_policy, query, {}, roles_from_query);
return updated_policy; return updated_policy;
}; };
Strings names = query.names->toStrings();
if (query.if_exists) if (query.if_exists)
{ {
if (auto id = access_control.find<RowPolicy>(query.name_parts.getName())) auto ids = access_control.find<RowPolicy>(names);
access_control.tryUpdate(*id, update_func); access_control.tryUpdate(ids, update_func);
} }
else else
access_control.update(access_control.getID<RowPolicy>(query.name_parts.getName()), update_func); access_control.update(access_control.getIDs<RowPolicy>(names), update_func);
} }
else else
{ {
auto new_policy = std::make_shared<RowPolicy>(); std::vector<AccessEntityPtr> new_policies;
updateRowPolicyFromQueryImpl(*new_policy, query, roles_from_query); for (const auto & name_parts : query.names->name_parts)
{
auto new_policy = std::make_shared<RowPolicy>();
updateRowPolicyFromQueryImpl(*new_policy, query, name_parts, roles_from_query);
new_policies.emplace_back(std::move(new_policy));
}
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_policy); access_control.tryInsert(new_policies);
else if (query.or_replace) else if (query.or_replace)
access_control.insertOrReplace(new_policy); access_control.insertOrReplace(new_policies);
else else
access_control.insert(new_policy); access_control.insert(new_policies);
} }
return {}; return {};
@ -103,7 +101,7 @@ BlockIO InterpreterCreateRowPolicyQuery::execute()
void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query) void InterpreterCreateRowPolicyQuery::updateRowPolicyFromQuery(RowPolicy & policy, const ASTCreateRowPolicyQuery & query)
{ {
updateRowPolicyFromQueryImpl(policy, query); updateRowPolicyFromQueryImpl(policy, query, {}, {});
} }
} }

View File

@ -1,6 +1,6 @@
#include <Interpreters/InterpreterCreateSettingsProfileQuery.h> #include <Interpreters/InterpreterCreateSettingsProfileQuery.h>
#include <Parsers/ASTCreateSettingsProfileQuery.h> #include <Parsers/ASTCreateSettingsProfileQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
@ -15,36 +15,26 @@ namespace
void updateSettingsProfileFromQueryImpl( void updateSettingsProfileFromQueryImpl(
SettingsProfile & profile, SettingsProfile & profile,
const ASTCreateSettingsProfileQuery & query, const ASTCreateSettingsProfileQuery & query,
const std::optional<SettingsProfileElements> & settings_from_query = {}, const String & override_name,
const std::optional<ExtendedRoleSet> & roles_from_query = {}) const std::optional<SettingsProfileElements> & override_settings,
const std::optional<RolesOrUsersSet> & override_to_roles)
{ {
if (query.alter) if (!override_name.empty())
{ profile.setName(override_name);
if (!query.new_name.empty()) else if (!query.new_name.empty())
profile.setName(query.new_name); profile.setName(query.new_name);
} else if (query.names.size() == 1)
else profile.setName(query.names.front());
profile.setName(query.name);
const SettingsProfileElements * settings = nullptr; if (override_settings)
std::optional<SettingsProfileElements> temp_settings; profile.elements = *override_settings;
if (settings_from_query)
settings = &*settings_from_query;
else if (query.settings) else if (query.settings)
settings = &temp_settings.emplace(*query.settings); profile.elements = *query.settings;
if (settings) if (override_to_roles)
profile.elements = *settings; profile.to_roles = *override_to_roles;
const ExtendedRoleSet * roles = nullptr;
std::optional<ExtendedRoleSet> temp_role_set;
if (roles_from_query)
roles = &*roles_from_query;
else if (query.to_roles) else if (query.to_roles)
roles = &temp_role_set.emplace(*query.to_roles); profile.to_roles = *query.to_roles;
if (roles)
profile.to_roles = *roles;
} }
} }
@ -68,37 +58,42 @@ BlockIO InterpreterCreateSettingsProfileQuery::execute()
if (query.settings) if (query.settings)
settings_from_query = SettingsProfileElements{*query.settings, access_control}; settings_from_query = SettingsProfileElements{*query.settings, access_control};
std::optional<ExtendedRoleSet> roles_from_query; std::optional<RolesOrUsersSet> roles_from_query;
if (query.to_roles) if (query.to_roles)
roles_from_query = ExtendedRoleSet{*query.to_roles, access_control, context.getUserID()}; roles_from_query = RolesOrUsersSet{*query.to_roles, access_control, context.getUserID()};
if (query.alter) if (query.alter)
{ {
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_profile = typeid_cast<std::shared_ptr<SettingsProfile>>(entity->clone()); auto updated_profile = typeid_cast<std::shared_ptr<SettingsProfile>>(entity->clone());
updateSettingsProfileFromQueryImpl(*updated_profile, query, settings_from_query, roles_from_query); updateSettingsProfileFromQueryImpl(*updated_profile, query, {}, settings_from_query, roles_from_query);
return updated_profile; return updated_profile;
}; };
if (query.if_exists) if (query.if_exists)
{ {
if (auto id = access_control.find<SettingsProfile>(query.name)) auto ids = access_control.find<SettingsProfile>(query.names);
access_control.tryUpdate(*id, update_func); access_control.tryUpdate(ids, update_func);
} }
else else
access_control.update(access_control.getID<SettingsProfile>(query.name), update_func); access_control.update(access_control.getIDs<SettingsProfile>(query.names), update_func);
} }
else else
{ {
auto new_profile = std::make_shared<SettingsProfile>(); std::vector<AccessEntityPtr> new_profiles;
updateSettingsProfileFromQueryImpl(*new_profile, query, settings_from_query, roles_from_query); for (const auto & name : query.names)
{
auto new_profile = std::make_shared<SettingsProfile>();
updateSettingsProfileFromQueryImpl(*new_profile, query, name, settings_from_query, roles_from_query);
new_profiles.emplace_back(std::move(new_profile));
}
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_profile); access_control.tryInsert(new_profiles);
else if (query.or_replace) else if (query.or_replace)
access_control.insertOrReplace(new_profile); access_control.insertOrReplace(new_profiles);
else else
access_control.insert(new_profile); access_control.insert(new_profiles);
} }
return {}; return {};
@ -107,6 +102,6 @@ BlockIO InterpreterCreateSettingsProfileQuery::execute()
void InterpreterCreateSettingsProfileQuery::updateSettingsProfileFromQuery(SettingsProfile & SettingsProfile, const ASTCreateSettingsProfileQuery & query) void InterpreterCreateSettingsProfileQuery::updateSettingsProfileFromQuery(SettingsProfile & SettingsProfile, const ASTCreateSettingsProfileQuery & query)
{ {
updateSettingsProfileFromQueryImpl(SettingsProfile, query); updateSettingsProfileFromQueryImpl(SettingsProfile, query, {}, {}, {});
} }
} }

View File

@ -3,7 +3,8 @@
#include <Interpreters/InterpreterSetRoleQuery.h> #include <Interpreters/InterpreterSetRoleQuery.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
#include <Parsers/ASTCreateUserQuery.h> #include <Parsers/ASTCreateUserQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTUserNameWithHost.h>
#include <Parsers/ASTRolesOrUsersSet.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/ContextAccess.h> #include <Access/ContextAccess.h>
@ -17,51 +18,50 @@ namespace
void updateUserFromQueryImpl( void updateUserFromQueryImpl(
User & user, User & user,
const ASTCreateUserQuery & query, const ASTCreateUserQuery & query,
const std::optional<ExtendedRoleSet> & default_roles_from_query = {}, const std::shared_ptr<ASTUserNameWithHost> & override_name,
const std::optional<SettingsProfileElements> & settings_from_query = {}) const std::optional<RolesOrUsersSet> & override_default_roles,
const std::optional<SettingsProfileElements> & override_settings)
{ {
if (query.alter) if (override_name)
{ user.setName(override_name->toString());
if (!query.new_name.empty()) else if (!query.new_name.empty())
user.setName(query.new_name); user.setName(query.new_name);
} else if (query.names->size() == 1)
else user.setName(query.names->front()->toString());
user.setName(query.name);
if (query.authentication) if (query.authentication)
user.authentication = *query.authentication; user.authentication = *query.authentication;
if (query.hosts) if (override_name && !override_name->host_pattern.empty())
{
user.allowed_client_hosts = AllowedClientHosts{};
user.allowed_client_hosts.addLikePattern(override_name->host_pattern);
}
else if (query.hosts)
user.allowed_client_hosts = *query.hosts; user.allowed_client_hosts = *query.hosts;
if (query.remove_hosts) if (query.remove_hosts)
user.allowed_client_hosts.remove(*query.remove_hosts); user.allowed_client_hosts.remove(*query.remove_hosts);
if (query.add_hosts) if (query.add_hosts)
user.allowed_client_hosts.add(*query.add_hosts); user.allowed_client_hosts.add(*query.add_hosts);
const ExtendedRoleSet * default_roles = nullptr; auto set_default_roles = [&](const RolesOrUsersSet & default_roles_)
std::optional<ExtendedRoleSet> temp_role_set;
if (default_roles_from_query)
default_roles = &*default_roles_from_query;
else if (query.default_roles)
default_roles = &temp_role_set.emplace(*query.default_roles);
if (default_roles)
{ {
if (!query.alter && !default_roles->all) if (!query.alter && !default_roles_.all)
user.granted_roles.grant(default_roles->getMatchingIDs()); user.granted_roles.grant(default_roles_.getMatchingIDs());
InterpreterSetRoleQuery::updateUserSetDefaultRoles(user, *default_roles); InterpreterSetRoleQuery::updateUserSetDefaultRoles(user, default_roles_);
} };
const SettingsProfileElements * settings = nullptr; if (override_default_roles)
std::optional<SettingsProfileElements> temp_settings; set_default_roles(*override_default_roles);
if (settings_from_query) else if (query.default_roles)
settings = &*settings_from_query; set_default_roles(*query.default_roles);
if (override_settings)
user.settings = *override_settings;
else if (query.settings) else if (query.settings)
settings = &temp_settings.emplace(*query.settings); user.settings = *query.settings;
if (settings)
user.settings = *settings;
} }
} }
@ -73,10 +73,10 @@ BlockIO InterpreterCreateUserQuery::execute()
auto access = context.getAccess(); auto access = context.getAccess();
access->checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER); access->checkAccess(query.alter ? AccessType::ALTER_USER : AccessType::CREATE_USER);
std::optional<ExtendedRoleSet> default_roles_from_query; std::optional<RolesOrUsersSet> default_roles_from_query;
if (query.default_roles) if (query.default_roles)
{ {
default_roles_from_query = ExtendedRoleSet{*query.default_roles, access_control}; default_roles_from_query = RolesOrUsersSet{*query.default_roles, access_control};
if (!query.alter && !default_roles_from_query->all) if (!query.alter && !default_roles_from_query->all)
{ {
for (const UUID & role : default_roles_from_query->getMatchingIDs()) for (const UUID & role : default_roles_from_query->getMatchingIDs())
@ -96,28 +96,34 @@ BlockIO InterpreterCreateUserQuery::execute()
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone()); auto updated_user = typeid_cast<std::shared_ptr<User>>(entity->clone());
updateUserFromQueryImpl(*updated_user, query, default_roles_from_query, settings_from_query); updateUserFromQueryImpl(*updated_user, query, {}, default_roles_from_query, settings_from_query);
return updated_user; return updated_user;
}; };
Strings names = query.names->toStrings();
if (query.if_exists) if (query.if_exists)
{ {
if (auto id = access_control.find<User>(query.name)) auto ids = access_control.find<User>(names);
access_control.tryUpdate(*id, update_func); access_control.tryUpdate(ids, update_func);
} }
else else
access_control.update(access_control.getID<User>(query.name), update_func); access_control.update(access_control.getIDs<User>(names), update_func);
} }
else else
{ {
auto new_user = std::make_shared<User>(); std::vector<AccessEntityPtr> new_users;
updateUserFromQueryImpl(*new_user, query, default_roles_from_query, settings_from_query); for (const auto & name : *query.names)
{
auto new_user = std::make_shared<User>();
updateUserFromQueryImpl(*new_user, query, name, default_roles_from_query, settings_from_query);
new_users.emplace_back(std::move(new_user));
}
if (query.if_not_exists) if (query.if_not_exists)
access_control.tryInsert(new_user); access_control.tryInsert(new_users);
else if (query.or_replace) else if (query.or_replace)
access_control.insertOrReplace(new_user); access_control.insertOrReplace(new_users);
else else
access_control.insert(new_user); access_control.insert(new_users);
} }
return {}; return {};
@ -126,7 +132,7 @@ BlockIO InterpreterCreateUserQuery::execute()
void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query) void InterpreterCreateUserQuery::updateUserFromQuery(User & user, const ASTCreateUserQuery & query)
{ {
updateUserFromQueryImpl(user, query); updateUserFromQueryImpl(user, query, {}, {}, {});
} }
} }

View File

@ -1,5 +1,6 @@
#include <Interpreters/InterpreterDropAccessEntityQuery.h> #include <Interpreters/InterpreterDropAccessEntityQuery.h>
#include <Parsers/ASTDropAccessEntityQuery.h> #include <Parsers/ASTDropAccessEntityQuery.h>
#include <Parsers/ASTRowPolicyName.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
@ -30,26 +31,21 @@ BlockIO InterpreterDropAccessEntityQuery::execute()
if (!query.cluster.empty()) if (!query.cluster.empty())
return executeDDLQueryOnCluster(query_ptr, context); return executeDDLQueryOnCluster(query_ptr, context);
if (query.type == EntityType::ROW_POLICY) query.replaceEmptyDatabaseWithCurrent(context.getCurrentDatabase());
{
Strings names;
for (auto & name_parts : query.row_policies_name_parts)
{
if (name_parts.database.empty())
name_parts.database = context.getCurrentDatabase();
names.emplace_back(name_parts.getName());
}
if (query.if_exists)
access_control.tryRemove(access_control.find<RowPolicy>(names));
else
access_control.remove(access_control.getIDs<RowPolicy>(names));
return {};
}
if (query.if_exists) auto do_drop = [&](const Strings & names)
access_control.tryRemove(access_control.find(query.type, query.names)); {
if (query.if_exists)
access_control.tryRemove(access_control.find(query.type, names));
else
access_control.remove(access_control.getIDs(query.type, names));
};
if (query.type == EntityType::ROW_POLICY)
do_drop(query.row_policy_names->toStrings());
else else
access_control.remove(access_control.getIDs(query.type, query.names)); do_drop(query.names);
return {}; return {};
} }

View File

@ -17,6 +17,7 @@
#include <Parsers/ASTSetQuery.h> #include <Parsers/ASTSetQuery.h>
#include <Parsers/ASTSetRoleQuery.h> #include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTShowAccessEntitiesQuery.h> #include <Parsers/ASTShowAccessEntitiesQuery.h>
#include <Parsers/ASTShowAccessQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h> #include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTShowGrantsQuery.h> #include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTShowPrivilegesQuery.h> #include <Parsers/ASTShowPrivilegesQuery.h>
@ -51,6 +52,7 @@
#include <Interpreters/InterpreterSetQuery.h> #include <Interpreters/InterpreterSetQuery.h>
#include <Interpreters/InterpreterSetRoleQuery.h> #include <Interpreters/InterpreterSetRoleQuery.h>
#include <Interpreters/InterpreterShowAccessEntitiesQuery.h> #include <Interpreters/InterpreterShowAccessEntitiesQuery.h>
#include <Interpreters/InterpreterShowAccessQuery.h>
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h> #include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowGrantsQuery.h> #include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Interpreters/InterpreterShowPrivilegesQuery.h> #include <Interpreters/InterpreterShowPrivilegesQuery.h>
@ -231,6 +233,10 @@ std::unique_ptr<IInterpreter> InterpreterFactory::get(ASTPtr & query, Context &
{ {
return std::make_unique<InterpreterShowAccessEntitiesQuery>(query, context); return std::make_unique<InterpreterShowAccessEntitiesQuery>(query, context);
} }
else if (query->as<ASTShowAccessQuery>())
{
return std::make_unique<InterpreterShowAccessQuery>(query, context);
}
else if (query->as<ASTShowPrivilegesQuery>()) else if (query->as<ASTShowPrivilegesQuery>())
{ {
return std::make_unique<InterpreterShowPrivilegesQuery>(query, context); return std::make_unique<InterpreterShowPrivilegesQuery>(query, context);

View File

@ -1,11 +1,11 @@
#include <Interpreters/InterpreterGrantQuery.h> #include <Interpreters/InterpreterGrantQuery.h>
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Interpreters/DDLWorker.h> #include <Interpreters/DDLWorker.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/ContextAccess.h> #include <Access/ContextAccess.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/Role.h> #include <Access/Role.h>
#include <boost/range/algorithm/copy.hpp> #include <boost/range/algorithm/copy.hpp>
@ -74,7 +74,7 @@ BlockIO InterpreterGrantQuery::execute()
std::vector<UUID> roles_from_query; std::vector<UUID> roles_from_query;
if (query.roles) if (query.roles)
{ {
roles_from_query = ExtendedRoleSet{*query.roles, access_control}.getMatchingIDs(access_control); roles_from_query = RolesOrUsersSet{*query.roles, access_control}.getMatchingIDs(access_control);
for (const UUID & role_from_query : roles_from_query) for (const UUID & role_from_query : roles_from_query)
access->checkAdminOption(role_from_query); access->checkAdminOption(role_from_query);
} }
@ -85,7 +85,7 @@ BlockIO InterpreterGrantQuery::execute()
return executeDDLQueryOnCluster(query_ptr, context); return executeDDLQueryOnCluster(query_ptr, context);
} }
std::vector<UUID> to_roles = ExtendedRoleSet{*query.to_roles, access_control, context.getUserID()}.getMatchingIDs(access_control); std::vector<UUID> to_roles = RolesOrUsersSet{*query.to_roles, access_control, context.getUserID()}.getMatchingIDs(access_control);
String current_database = context.getCurrentDatabase(); String current_database = context.getCurrentDatabase();
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
@ -115,7 +115,7 @@ void InterpreterGrantQuery::updateUserFromQuery(User & user, const ASTGrantQuery
{ {
std::vector<UUID> roles_from_query; std::vector<UUID> roles_from_query;
if (query.roles) if (query.roles)
roles_from_query = ExtendedRoleSet{*query.roles}.getMatchingIDs(); roles_from_query = RolesOrUsersSet{*query.roles}.getMatchingIDs();
updateFromQueryImpl(user, query, roles_from_query, {}); updateFromQueryImpl(user, query, roles_from_query, {});
} }
@ -124,7 +124,7 @@ void InterpreterGrantQuery::updateRoleFromQuery(Role & role, const ASTGrantQuery
{ {
std::vector<UUID> roles_from_query; std::vector<UUID> roles_from_query;
if (query.roles) if (query.roles)
roles_from_query = ExtendedRoleSet{*query.roles}.getMatchingIDs(); roles_from_query = RolesOrUsersSet{*query.roles}.getMatchingIDs();
updateFromQueryImpl(role, query, roles_from_query, {}); updateFromQueryImpl(role, query, roles_from_query, {});
} }

View File

@ -1,8 +1,8 @@
#include <Interpreters/InterpreterSetRoleQuery.h> #include <Interpreters/InterpreterSetRoleQuery.h>
#include <Parsers/ASTSetRoleQuery.h> #include <Parsers/ASTSetRoleQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Access/ExtendedRoleSet.h> #include <Access/RolesOrUsersSet.h>
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
@ -38,7 +38,7 @@ void InterpreterSetRoleQuery::setRole(const ASTSetRoleQuery & query)
} }
else else
{ {
ExtendedRoleSet roles_from_query{*query.roles, access_control}; RolesOrUsersSet roles_from_query{*query.roles, access_control};
boost::container::flat_set<UUID> new_current_roles; boost::container::flat_set<UUID> new_current_roles;
if (roles_from_query.all) if (roles_from_query.all)
{ {
@ -65,8 +65,8 @@ void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query)
context.checkAccess(AccessType::ALTER_USER); context.checkAccess(AccessType::ALTER_USER);
auto & access_control = context.getAccessControlManager(); auto & access_control = context.getAccessControlManager();
std::vector<UUID> to_users = ExtendedRoleSet{*query.to_users, access_control, context.getUserID()}.getMatchingIDs(access_control); std::vector<UUID> to_users = RolesOrUsersSet{*query.to_users, access_control, context.getUserID()}.getMatchingIDs(access_control);
ExtendedRoleSet roles_from_query{*query.roles, access_control}; RolesOrUsersSet roles_from_query{*query.roles, access_control};
auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr
{ {
@ -79,7 +79,7 @@ void InterpreterSetRoleQuery::setDefaultRole(const ASTSetRoleQuery & query)
} }
void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const ExtendedRoleSet & roles_from_query) void InterpreterSetRoleQuery::updateUserSetDefaultRoles(User & user, const RolesOrUsersSet & roles_from_query)
{ {
if (!roles_from_query.all) if (!roles_from_query.all)
{ {

View File

@ -9,7 +9,7 @@ namespace DB
class Context; class Context;
class ASTSetRoleQuery; class ASTSetRoleQuery;
struct ExtendedRoleSet; struct RolesOrUsersSet;
struct User; struct User;
@ -20,7 +20,7 @@ public:
BlockIO execute() override; BlockIO execute() override;
static void updateUserSetDefaultRoles(User & user, const ExtendedRoleSet & roles_from_query); static void updateUserSetDefaultRoles(User & user, const RolesOrUsersSet & roles_from_query);
private: private:
void setRole(const ASTSetRoleQuery & query); void setRole(const ASTSetRoleQuery & query);

View File

@ -18,7 +18,7 @@ using EntityType = IAccessEntity::Type;
InterpreterShowAccessEntitiesQuery::InterpreterShowAccessEntitiesQuery(const ASTPtr & query_ptr_, Context & context_) InterpreterShowAccessEntitiesQuery::InterpreterShowAccessEntitiesQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_), ignore_quota(query_ptr->as<ASTShowAccessEntitiesQuery &>().type == EntityType::QUOTA) : query_ptr(query_ptr_), context(context_)
{ {
} }
@ -31,7 +31,8 @@ BlockIO InterpreterShowAccessEntitiesQuery::execute()
String InterpreterShowAccessEntitiesQuery::getRewrittenQuery() const String InterpreterShowAccessEntitiesQuery::getRewrittenQuery() const
{ {
const auto & query = query_ptr->as<ASTShowAccessEntitiesQuery &>(); auto & query = query_ptr->as<ASTShowAccessEntitiesQuery &>();
query.replaceEmptyDatabaseWithCurrent(context.getCurrentDatabase());
String origin; String origin;
String expr = "*"; String expr = "*";
String filter, order; String filter, order;
@ -42,14 +43,18 @@ String InterpreterShowAccessEntitiesQuery::getRewrittenQuery() const
{ {
origin = "row_policies"; origin = "row_policies";
expr = "name"; expr = "name";
const String & table_name = query.table_name; if (!query.short_name.empty())
if (!table_name.empty()) filter += String{filter.empty() ? "" : " AND "} + "short_name = " + quoteString(query.short_name);
if (query.database_and_table_name)
{ {
String database = query.database; const String & database = query.database_and_table_name->first;
if (database.empty()) const String & table_name = query.database_and_table_name->second;
database = context.getCurrentDatabase(); if (!database.empty())
filter = "database = " + quoteString(database) + " AND table = " + quoteString(table_name); filter += String{filter.empty() ? "" : " AND "} + "database = " + quoteString(database);
expr = "short_name"; if (!table_name.empty())
filter += String{filter.empty() ? "" : " AND "} + "table = " + quoteString(table_name);
if (!database.empty() && !table_name.empty())
expr = "short_name";
} }
break; break;
} }

View File

@ -15,15 +15,14 @@ public:
BlockIO execute() override; BlockIO execute() override;
bool ignoreQuota() const override { return ignore_quota; } bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return ignore_quota; } bool ignoreLimits() const override { return true; }
private: private:
String getRewrittenQuery() const; String getRewrittenQuery() const;
ASTPtr query_ptr; ASTPtr query_ptr;
Context & context; Context & context;
bool ignore_quota = false;
}; };
} }

View File

@ -0,0 +1,89 @@
#include <Interpreters/InterpreterShowAccessQuery.h>
#include <Parsers/formatAST.h>
#include <Interpreters/Context.h>
#include <Interpreters/InterpreterShowCreateAccessEntityQuery.h>
#include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Columns/ColumnString.h>
#include <DataStreams/OneBlockInputStream.h>
#include <DataTypes/DataTypeString.h>
#include <Access/AccessFlags.h>
#include <Access/AccessControlManager.h>
#include <ext/range.h>
#include <boost/range/algorithm/sort.hpp>
#include <boost/range/algorithm_ext/push_back.hpp>
namespace DB
{
using EntityType = IAccessEntity::Type;
BlockIO InterpreterShowAccessQuery::execute()
{
BlockIO res;
res.in = executeImpl();
return res;
}
BlockInputStreamPtr InterpreterShowAccessQuery::executeImpl() const
{
/// Build a create query.
ASTs queries = getCreateAndGrantQueries();
/// Build the result column.
MutableColumnPtr column = ColumnString::create();
std::stringstream ss;
for (const auto & query : queries)
{
ss.str("");
formatAST(*query, ss, false, true);
column->insert(ss.str());
}
String desc = "ACCESS";
return std::make_shared<OneBlockInputStream>(Block{{std::move(column), std::make_shared<DataTypeString>(), desc}});
}
std::vector<AccessEntityPtr> InterpreterShowAccessQuery::getEntities() const
{
const auto & access_control = context.getAccessControlManager();
context.checkAccess(AccessType::SHOW_ACCESS);
std::vector<AccessEntityPtr> entities;
for (auto type : ext::range(EntityType::MAX))
{
auto ids = access_control.findAll(type);
for (const auto & id : ids)
{
if (auto entity = access_control.tryRead(id))
entities.push_back(entity);
}
}
boost::range::sort(entities, IAccessEntity::LessByTypeAndName{});
return entities;
}
ASTs InterpreterShowAccessQuery::getCreateAndGrantQueries() const
{
auto entities = getEntities();
const auto & access_control = context.getAccessControlManager();
ASTs create_queries, grant_queries;
for (const auto & entity : entities)
{
create_queries.push_back(InterpreterShowCreateAccessEntityQuery::getCreateQuery(*entity, access_control));
if (entity->isTypeOf(EntityType::USER) || entity->isTypeOf(EntityType::USER))
boost::range::push_back(grant_queries, InterpreterShowGrantsQuery::getGrantQueries(*entity, access_control));
}
ASTs result = std::move(create_queries);
boost::range::push_back(result, std::move(grant_queries));
return result;
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <Interpreters/IInterpreter.h>
#include <Parsers/IAST_fwd.h>
namespace DB
{
class Context;
struct IAccessEntity;
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
/** Return all queries for creating access entities and grants.
*/
class InterpreterShowAccessQuery : public IInterpreter
{
public:
InterpreterShowAccessQuery(const ASTPtr & query_ptr_, Context & context_)
: query_ptr(query_ptr_), context(context_) {}
BlockIO execute() override;
bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return true; }
private:
BlockInputStreamPtr executeImpl() const;
ASTs getCreateAndGrantQueries() const;
std::vector<AccessEntityPtr> getEntities() const;
ASTPtr query_ptr;
Context & context;
};
}

View File

@ -6,8 +6,10 @@
#include <Parsers/ASTCreateRowPolicyQuery.h> #include <Parsers/ASTCreateRowPolicyQuery.h>
#include <Parsers/ASTCreateSettingsProfileQuery.h> #include <Parsers/ASTCreateSettingsProfileQuery.h>
#include <Parsers/ASTShowCreateAccessEntityQuery.h> #include <Parsers/ASTShowCreateAccessEntityQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTUserNameWithHost.h>
#include <Parsers/ASTRolesOrUsersSet.h>
#include <Parsers/ASTSettingsProfileElement.h> #include <Parsers/ASTSettingsProfileElement.h>
#include <Parsers/ASTRowPolicyName.h>
#include <Parsers/ExpressionListParsers.h> #include <Parsers/ExpressionListParsers.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Parsers/parseQuery.h> #include <Parsers/parseQuery.h>
@ -23,6 +25,7 @@
#include <Common/StringUtils/StringUtils.h> #include <Common/StringUtils/StringUtils.h>
#include <Core/Defines.h> #include <Core/Defines.h>
#include <ext/range.h> #include <ext/range.h>
#include <boost/range/algorithm/sort.hpp>
#include <sstream> #include <sstream>
@ -42,13 +45,14 @@ namespace
bool attach_mode) bool attach_mode)
{ {
auto query = std::make_shared<ASTCreateUserQuery>(); auto query = std::make_shared<ASTCreateUserQuery>();
query->name = user.getName(); query->names = std::make_shared<ASTUserNamesWithHost>();
query->names->push_back(user.getName());
query->attach = attach_mode; query->attach = attach_mode;
if (user.allowed_client_hosts != AllowedClientHosts::AnyHostTag{}) if (user.allowed_client_hosts != AllowedClientHosts::AnyHostTag{})
query->hosts = user.allowed_client_hosts; query->hosts = user.allowed_client_hosts;
if (user.default_roles != ExtendedRoleSet::AllTag{}) if (user.default_roles != RolesOrUsersSet::AllTag{})
{ {
if (attach_mode) if (attach_mode)
query->default_roles = user.default_roles.toAST(); query->default_roles = user.default_roles.toAST();
@ -77,7 +81,7 @@ namespace
ASTPtr getCreateQueryImpl(const Role & role, const AccessControlManager * manager, bool attach_mode) ASTPtr getCreateQueryImpl(const Role & role, const AccessControlManager * manager, bool attach_mode)
{ {
auto query = std::make_shared<ASTCreateRoleQuery>(); auto query = std::make_shared<ASTCreateRoleQuery>();
query->name = role.getName(); query->names.emplace_back(role.getName());
query->attach = attach_mode; query->attach = attach_mode;
if (!role.settings.empty()) if (!role.settings.empty())
@ -95,7 +99,7 @@ namespace
ASTPtr getCreateQueryImpl(const SettingsProfile & profile, const AccessControlManager * manager, bool attach_mode) ASTPtr getCreateQueryImpl(const SettingsProfile & profile, const AccessControlManager * manager, bool attach_mode)
{ {
auto query = std::make_shared<ASTCreateSettingsProfileQuery>(); auto query = std::make_shared<ASTCreateSettingsProfileQuery>();
query->name = profile.getName(); query->names.emplace_back(profile.getName());
query->attach = attach_mode; query->attach = attach_mode;
if (!profile.elements.empty()) if (!profile.elements.empty())
@ -126,10 +130,12 @@ namespace
bool attach_mode) bool attach_mode)
{ {
auto query = std::make_shared<ASTCreateQuotaQuery>(); auto query = std::make_shared<ASTCreateQuotaQuery>();
query->name = quota.getName(); query->names.emplace_back(quota.getName());
query->attach = attach_mode; query->attach = attach_mode;
query->key_type = quota.key_type; if (quota.key_type != Quota::KeyType::NONE)
query->key_type = quota.key_type;
query->all_limits.reserve(quota.all_limits.size()); query->all_limits.reserve(quota.all_limits.size());
for (const auto & limits : quota.all_limits) for (const auto & limits : quota.all_limits)
@ -160,7 +166,8 @@ namespace
bool attach_mode) bool attach_mode)
{ {
auto query = std::make_shared<ASTCreateRowPolicyQuery>(); auto query = std::make_shared<ASTCreateRowPolicyQuery>();
query->name_parts = policy.getNameParts(); query->names = std::make_shared<ASTRowPolicyNames>();
query->names->name_parts.emplace_back(policy.getNameParts());
query->attach = attach_mode; query->attach = attach_mode;
if (policy.isRestrictive()) if (policy.isRestrictive())
@ -173,7 +180,7 @@ namespace
{ {
ParserExpression parser; ParserExpression parser;
ASTPtr expr = parseQuery(parser, condition, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH); ASTPtr expr = parseQuery(parser, condition, 0, DBMS_DEFAULT_MAX_PARSER_DEPTH);
query->conditions[static_cast<size_t>(type)] = expr; query->conditions.emplace_back(type, std::move(expr));
} }
} }
@ -211,7 +218,7 @@ namespace
InterpreterShowCreateAccessEntityQuery::InterpreterShowCreateAccessEntityQuery(const ASTPtr & query_ptr_, const Context & context_) InterpreterShowCreateAccessEntityQuery::InterpreterShowCreateAccessEntityQuery(const ASTPtr & query_ptr_, const Context & context_)
: query_ptr(query_ptr_), context(context_), ignore_quota(query_ptr->as<ASTShowCreateAccessEntityQuery &>().type == EntityType::QUOTA) : query_ptr(query_ptr_), context(context_)
{ {
} }
@ -226,23 +233,22 @@ BlockIO InterpreterShowCreateAccessEntityQuery::execute()
BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl() BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl()
{ {
auto & show_query = query_ptr->as<ASTShowCreateAccessEntityQuery &>(); /// Build a create queries.
ASTs create_queries = getCreateQueries();
/// Build a create query.
ASTPtr create_query = getCreateQuery(show_query);
/// Build the result column. /// Build the result column.
MutableColumnPtr column = ColumnString::create(); MutableColumnPtr column = ColumnString::create();
if (create_query) std::stringstream create_query_ss;
for (const auto & create_query : create_queries)
{ {
std::stringstream create_query_ss;
formatAST(*create_query, create_query_ss, false, true); formatAST(*create_query, create_query_ss, false, true);
String create_query_str = create_query_ss.str(); column->insert(create_query_ss.str());
column->insert(create_query_str); create_query_ss.str("");
} }
/// Prepare description of the result column. /// Prepare description of the result column.
std::stringstream desc_ss; std::stringstream desc_ss;
const auto & show_query = query_ptr->as<const ASTShowCreateAccessEntityQuery &>();
formatAST(show_query, desc_ss, false, true); formatAST(show_query, desc_ss, false, true);
String desc = desc_ss.str(); String desc = desc_ss.str();
String prefix = "SHOW "; String prefix = "SHOW ";
@ -253,38 +259,91 @@ BlockInputStreamPtr InterpreterShowCreateAccessEntityQuery::executeImpl()
} }
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(ASTShowCreateAccessEntityQuery & show_query) const std::vector<AccessEntityPtr> InterpreterShowCreateAccessEntityQuery::getEntities() const
{ {
auto & show_query = query_ptr->as<ASTShowCreateAccessEntityQuery &>();
const auto & access_control = context.getAccessControlManager(); const auto & access_control = context.getAccessControlManager();
context.checkAccess(getRequiredAccess()); context.checkAccess(getRequiredAccess());
show_query.replaceEmptyDatabaseWithCurrent(context.getCurrentDatabase());
std::vector<AccessEntityPtr> entities;
if (show_query.current_user) if (show_query.all)
{ {
auto user = context.getUser(); auto ids = access_control.findAll(show_query.type);
if (!user) for (const auto & id : ids)
return nullptr; {
return getCreateQueryImpl(*user, &access_control, false); if (auto entity = access_control.tryRead(id))
entities.push_back(entity);
}
} }
else if (show_query.current_user)
if (show_query.current_quota) {
if (auto user = context.getUser())
entities.push_back(user);
}
else if (show_query.current_quota)
{ {
auto usage = context.getQuotaUsage(); auto usage = context.getQuotaUsage();
if (!usage) if (usage)
return nullptr; entities.push_back(access_control.read<Quota>(usage->quota_id));
auto quota = access_control.read<Quota>(usage->quota_id);
return getCreateQueryImpl(*quota, &access_control, false);
} }
else if (show_query.type == EntityType::ROW_POLICY)
if (show_query.type == EntityType::ROW_POLICY)
{ {
if (show_query.row_policy_name_parts.database.empty()) auto ids = access_control.findAll<RowPolicy>();
show_query.row_policy_name_parts.database = context.getCurrentDatabase(); if (show_query.row_policy_names)
RowPolicyPtr policy = access_control.read<RowPolicy>(show_query.row_policy_name_parts.getName()); {
return getCreateQueryImpl(*policy, &access_control, false); for (const String & name : show_query.row_policy_names->toStrings())
entities.push_back(access_control.read<RowPolicy>(name));
}
else
{
for (const auto & id : ids)
{
auto policy = access_control.tryRead<RowPolicy>(id);
if (!policy)
continue;
if (!show_query.short_name.empty() && (policy->getShortName() != show_query.short_name))
continue;
if (show_query.database_and_table_name)
{
const String & database = show_query.database_and_table_name->first;
const String & table_name = show_query.database_and_table_name->second;
if (!database.empty() && (policy->getDatabase() != database))
continue;
if (!table_name.empty() && (policy->getTableName() != table_name))
continue;
}
entities.push_back(policy);
}
}
}
else
{
for (const String & name : show_query.names)
entities.push_back(access_control.read(access_control.getID(show_query.type, name)));
} }
auto entity = access_control.read(access_control.getID(show_query.type, show_query.name)); boost::range::sort(entities, IAccessEntity::LessByName{});
return getCreateQueryImpl(*entity, &access_control, false); return entities;
}
ASTs InterpreterShowCreateAccessEntityQuery::getCreateQueries() const
{
auto entities = getEntities();
ASTs list;
const auto & access_control = context.getAccessControlManager();
for (const auto & entity : entities)
list.push_back(getCreateQuery(*entity, access_control));
return list;
}
ASTPtr InterpreterShowCreateAccessEntityQuery::getCreateQuery(const IAccessEntity & entity, const AccessControlManager & access_control)
{
return getCreateQueryImpl(entity, &access_control, false);
} }

View File

@ -7,10 +7,11 @@
namespace DB namespace DB
{ {
class AccessControlManager;
class Context; class Context;
class ASTShowCreateAccessEntityQuery;
class AccessRightsElements; class AccessRightsElements;
struct IAccessEntity; struct IAccessEntity;
using AccessEntityPtr = std::shared_ptr<const IAccessEntity>;
/** Returns a single item containing a statement which could be used to create a specified role. /** Returns a single item containing a statement which could be used to create a specified role.
@ -22,19 +23,20 @@ public:
BlockIO execute() override; BlockIO execute() override;
bool ignoreQuota() const override { return ignore_quota; } bool ignoreQuota() const override { return true; }
bool ignoreLimits() const override { return ignore_quota; } bool ignoreLimits() const override { return true; }
static ASTPtr getCreateQuery(const IAccessEntity & entity, const AccessControlManager & access_control);
static ASTPtr getAttachQuery(const IAccessEntity & entity); static ASTPtr getAttachQuery(const IAccessEntity & entity);
private: private:
BlockInputStreamPtr executeImpl(); BlockInputStreamPtr executeImpl();
ASTPtr getCreateQuery(ASTShowCreateAccessEntityQuery & show_query) const; std::vector<AccessEntityPtr> getEntities() const;
ASTs getCreateQueries() const;
AccessRightsElements getRequiredAccess() const; AccessRightsElements getRequiredAccess() const;
ASTPtr query_ptr; ASTPtr query_ptr;
const Context & context; const Context & context;
bool ignore_quota = false;
}; };

View File

@ -1,7 +1,7 @@
#include <Interpreters/InterpreterShowGrantsQuery.h> #include <Interpreters/InterpreterShowGrantsQuery.h>
#include <Parsers/ASTShowGrantsQuery.h> #include <Parsers/ASTShowGrantsQuery.h>
#include <Parsers/ASTGrantQuery.h> #include <Parsers/ASTGrantQuery.h>
#include <Parsers/ASTExtendedRoleSet.h> #include <Parsers/ASTRolesOrUsersSet.h>
#include <Parsers/formatAST.h> #include <Parsers/formatAST.h>
#include <Interpreters/Context.h> #include <Interpreters/Context.h>
#include <Columns/ColumnString.h> #include <Columns/ColumnString.h>
@ -10,6 +10,9 @@
#include <Access/AccessControlManager.h> #include <Access/AccessControlManager.h>
#include <Access/User.h> #include <Access/User.h>
#include <Access/Role.h> #include <Access/Role.h>
#include <Access/RolesOrUsersSet.h>
#include <boost/range/algorithm/sort.hpp>
#include <boost/range/algorithm_ext/push_back.hpp>
namespace DB namespace DB
@ -29,7 +32,7 @@ namespace
{ {
ASTs res; ASTs res;
std::shared_ptr<ASTExtendedRoleSet> to_roles = std::make_shared<ASTExtendedRoleSet>(); std::shared_ptr<ASTRolesOrUsersSet> to_roles = std::make_shared<ASTRolesOrUsersSet>();
to_roles->names.push_back(grantee.getName()); to_roles->names.push_back(grantee.getName());
auto grants_and_partial_revokes = grantee.access.getGrantsAndPartialRevokes(); auto grants_and_partial_revokes = grantee.access.getGrantsAndPartialRevokes();
@ -87,9 +90,9 @@ namespace
grant_query->admin_option = admin_option; grant_query->admin_option = admin_option;
grant_query->to_roles = to_roles; grant_query->to_roles = to_roles;
if (attach_mode) if (attach_mode)
grant_query->roles = ExtendedRoleSet{roles}.toAST(); grant_query->roles = RolesOrUsersSet{roles}.toAST();
else else
grant_query->roles = ExtendedRoleSet{roles}.toASTWithNames(*manager); grant_query->roles = RolesOrUsersSet{roles}.toASTWithNames(*manager);
res.push_back(std::move(grant_query)); res.push_back(std::move(grant_query));
} }
@ -121,10 +124,8 @@ BlockIO InterpreterShowGrantsQuery::execute()
BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl() BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
{ {
const auto & show_query = query_ptr->as<ASTShowGrantsQuery &>();
/// Build a create query. /// Build a create query.
ASTs grant_queries = getGrantQueries(show_query); ASTs grant_queries = getGrantQueries();
/// Build the result column. /// Build the result column.
MutableColumnPtr column = ColumnString::create(); MutableColumnPtr column = ColumnString::create();
@ -138,6 +139,7 @@ BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
/// Prepare description of the result column. /// Prepare description of the result column.
std::stringstream desc_ss; std::stringstream desc_ss;
const auto & show_query = query_ptr->as<const ASTShowGrantsQuery &>();
formatAST(show_query, desc_ss, false, true); formatAST(show_query, desc_ss, false, true);
String desc = desc_ss.str(); String desc = desc_ss.str();
String prefix = "SHOW "; String prefix = "SHOW ";
@ -148,21 +150,41 @@ BlockInputStreamPtr InterpreterShowGrantsQuery::executeImpl()
} }
ASTs InterpreterShowGrantsQuery::getGrantQueries(const ASTShowGrantsQuery & show_query) const std::vector<AccessEntityPtr> InterpreterShowGrantsQuery::getEntities() const
{ {
const auto & show_query = query_ptr->as<ASTShowGrantsQuery &>();
const auto & access_control = context.getAccessControlManager(); const auto & access_control = context.getAccessControlManager();
auto ids = RolesOrUsersSet{*show_query.for_roles, access_control, context.getUserID()}.getMatchingIDs(access_control);
AccessEntityPtr user_or_role; std::vector<AccessEntityPtr> entities;
if (show_query.current_user) for (const auto & id : ids)
user_or_role = context.getUser();
else
{ {
user_or_role = access_control.tryRead<User>(show_query.name); auto entity = access_control.tryRead(id);
if (!user_or_role) if (entity)
user_or_role = access_control.read<Role>(show_query.name); entities.push_back(entity);
} }
return getGrantQueriesImpl(*user_or_role, &access_control); boost::range::sort(entities, IAccessEntity::LessByTypeAndName{});
return entities;
}
ASTs InterpreterShowGrantsQuery::getGrantQueries() const
{
auto entities = getEntities();
const auto & access_control = context.getAccessControlManager();
ASTs grant_queries;
for (const auto & entity : entities)
boost::range::push_back(grant_queries, getGrantQueries(*entity, access_control));
return grant_queries;
}
ASTs InterpreterShowGrantsQuery::getGrantQueries(const IAccessEntity & user_or_role, const AccessControlManager & access_control)
{
return getGrantQueriesImpl(user_or_role, &access_control, false);
} }

Some files were not shown because too many files have changed in this diff Show More