Merge branch 'master' into llvm-14

This commit is contained in:
Alexey Milovidov 2022-05-25 01:56:24 +02:00
commit a5541dc2d5
288 changed files with 6078 additions and 3611 deletions

View File

@ -13,9 +13,9 @@ on: # yamllint disable-line rule:truthy
branches:
- master
paths:
- 'docker/docs/**'
- 'docs/**'
- 'website/**'
- 'docker/docs/**'
jobs:
CheckLabels:
runs-on: [self-hosted, style-checker]

View File

@ -7,16 +7,17 @@ env:
concurrency:
group: master-release
cancel-in-progress: true
on: # yamllint disable-line rule:truthy
'on':
push:
branches:
- master
paths:
- 'docs/**'
- 'website/**'
- 'benchmark/**'
- 'docker/**'
- '.github/**'
- 'benchmark/**'
- 'docker/docs/release/**'
- 'docs/**'
- 'utils/list-versions/version_date.tsv'
- 'website/**'
workflow_dispatch:
jobs:
DockerHubPushAarch64:

View File

@ -13,6 +13,7 @@ on: # yamllint disable-line rule:truthy
branches:
- master
paths-ignore:
- 'docker/docs/**'
- 'docs/**'
- 'website/**'
##########################################################################################

4
.gitmodules vendored
View File

@ -265,10 +265,6 @@
[submodule "contrib/wyhash"]
path = contrib/wyhash
url = https://github.com/wangyi-fudan/wyhash.git
[submodule "contrib/eigen"]
path = contrib/eigen
url = https://github.com/eigen-mirror/eigen
[submodule "contrib/hashidsxx"]
path = contrib/hashidsxx
url = https://github.com/schoentoon/hashidsxx.git

View File

@ -105,6 +105,25 @@
# define ASAN_POISON_MEMORY_REGION(a, b)
#endif
#if !defined(ABORT_ON_LOGICAL_ERROR)
#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) || defined(MEMORY_SANITIZER) || defined(UNDEFINED_BEHAVIOR_SANITIZER)
#define ABORT_ON_LOGICAL_ERROR
#endif
#endif
/// chassert(x) is similar to assert(x), but:
/// - works in builds with sanitizers, not only in debug builds
/// - tries to print failed assertion into server log
/// It can be used for all assertions except heavy ones.
/// Heavy assertions (that run loops or call complex functions) are allowed in debug builds only.
#if !defined(chassert)
#if defined(ABORT_ON_LOGICAL_ERROR)
#define chassert(x) static_cast<bool>(x) ? void(0) : abortOnFailedAssertion(#x)
#else
#define chassert(x) ((void)0)
#endif
#endif
/// A template function for suppressing warnings about unused variables or function results.
template <typename... Args>
constexpr void UNUSED(Args &&... args [[maybe_unused]])

View File

@ -153,7 +153,6 @@ endif()
add_contrib (sqlite-cmake sqlite-amalgamation)
add_contrib (s2geometry-cmake s2geometry)
add_contrib (eigen-cmake eigen)
# Put all targets defined here and in subdirectories under "contrib/<immediate-subdir>" folders in GUI-based IDEs.
# Some of third-party projects may override CMAKE_FOLDER or FOLDER property of their targets, so they would not appear

1
contrib/eigen vendored

@ -1 +0,0 @@
Subproject commit 3147391d946bb4b6c68edd901f2add6ac1f31f8c

View File

@ -1,16 +0,0 @@
set(EIGEN_LIBRARY_DIR "${ClickHouse_SOURCE_DIR}/contrib/eigen")
add_library (_eigen INTERFACE)
# Only include MPL2 code from Eigen library
target_compile_definitions(_eigen INTERFACE EIGEN_MPL2_ONLY)
# Clang by default mimics gcc 4.2.1 compatibility but Eigen checks __GNUC__ version to enable
# a workaround for bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=72867 fixed in 6.3
# So we fake gcc > 6.3 when building with clang
if (COMPILER_CLANG AND ARCH_PPC64LE)
target_compile_options(_eigen INTERFACE -fgnuc-version=6.4)
endif()
target_include_directories (_eigen SYSTEM INTERFACE ${EIGEN_LIBRARY_DIR})
add_library(ch_contrib::eigen ALIAS _eigen)

View File

@ -170,7 +170,13 @@ endif ()
target_compile_definitions(_jemalloc PRIVATE -DJEMALLOC_PROF=1)
if (USE_UNWIND)
target_compile_definitions (_jemalloc PRIVATE -DJEMALLOC_PROF_LIBUNWIND=1)
# jemalloc provides support for two different libunwind flavors: the original HP libunwind and the one coming with gcc / g++ / libstdc++.
# The latter is identified by `JEMALLOC_PROF_LIBGCC` and uses `_Unwind_Backtrace` method instead of `unw_backtrace`.
# At the time ClickHouse uses LLVM libunwind which follows libgcc's way of backtracing.
# ClickHouse has to provide `unw_backtrace` method by the means of [commit 8e2b31e](https://github.com/ClickHouse/libunwind/commit/8e2b31e766dd502f6df74909e04a7dbdf5182eb1).
target_compile_definitions (_jemalloc PRIVATE -DJEMALLOC_PROF_LIBGCC=1)
target_link_libraries (_jemalloc PRIVATE unwind)
endif ()

View File

@ -0,0 +1,44 @@
# docker build -t clickhouse/docs-release .
FROM ubuntu:20.04
# ARG for quick switch to a given ubuntu mirror
ARG apt_archive="http://archive.ubuntu.com"
RUN sed -i "s|http://archive.ubuntu.com|$apt_archive|g" /etc/apt/sources.list
ENV LANG=C.UTF-8
RUN apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install --yes --no-install-recommends \
wget \
bash \
python \
curl \
python3-requests \
sudo \
git \
openssl \
python3-pip \
software-properties-common \
fonts-arphic-ukai \
fonts-arphic-uming \
fonts-ipafont-mincho \
fonts-ipafont-gothic \
fonts-unfonts-core \
xvfb \
ssh-client \
&& apt-get autoremove --yes \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN pip3 install --ignore-installed --upgrade setuptools pip virtualenv
# We create the most popular default 1000:1000 ubuntu user to not have ssh issues when running with UID==1000
RUN useradd --create-home --uid 1000 --user-group ubuntu \
&& ssh-keyscan -t rsa github.com >> /etc/ssh/ssh_known_hosts
COPY run.sh /
ENV REPO_PATH=/repo_path
ENV OUTPUT_PATH=/output_path
CMD ["/bin/bash", "/run.sh"]

View File

@ -0,0 +1,12 @@
#!/usr/bin/env bash
set -euo pipefail
cd "$REPO_PATH/docs/tools"
if ! [ -d venv ]; then
mkdir -p venv
virtualenv -p "$(which python3)" venv
source venv/bin/activate
python3 -m pip install --ignore-installed -r requirements.txt
fi
source venv/bin/activate
./release.sh 2>&1 | tee "$OUTPUT_PATH/output.log"

View File

@ -146,5 +146,9 @@
"name": "clickhouse/docs-builder",
"dependent": [
]
},
"docker/docs/release": {
"name": "clickhouse/docs-release",
"dependent": []
}
}

View File

@ -177,7 +177,6 @@ function clone_submodules
contrib/jemalloc
contrib/replxx
contrib/wyhash
contrib/eigen
contrib/hashidsxx
)

View File

@ -3,8 +3,6 @@
from multiprocessing import cpu_count
from subprocess import Popen, call, check_output, STDOUT
import os
import sys
import shutil
import argparse
import logging
import time
@ -31,6 +29,9 @@ def get_options(i, backward_compatibility_check):
if i % 5 == 1:
client_options.append("join_use_nulls=1")
if i % 15 == 1:
client_options.append("join_algorithm='parallel_hash'")
if i % 15 == 6:
client_options.append("join_algorithm='partial_merge'")

View File

@ -9,11 +9,6 @@ cmake .. \
-DCMAKE_C_COMPILER=$(which clang-14) \
-DCMAKE_CXX_COMPILER=$(which clang++-14) \
-DCMAKE_BUILD_TYPE=Debug \
-DENABLE_CLICKHOUSE_ALL=OFF \
-DENABLE_CLICKHOUSE_SERVER=ON \
-DENABLE_CLICKHOUSE_CLIENT=ON \
-DENABLE_LIBRARIES=OFF \
-DUSE_UNWIND=ON \
-DENABLE_UTILS=OFF \
-DENABLE_TESTS=OFF
```

View File

@ -106,7 +106,7 @@ vim tests/queries/0_stateless/01521_dummy_test.sql
4) run the test, and put the result of that into the reference file:
```
clickhouse-client -nmT < tests/queries/0_stateless/01521_dummy_test.sql | tee tests/queries/0_stateless/01521_dummy_test.reference
clickhouse-client -nm < tests/queries/0_stateless/01521_dummy_test.sql | tee tests/queries/0_stateless/01521_dummy_test.reference
```
5) ensure everything is correct, if the test output is incorrect (due to some bug for example), adjust the reference file using text editor.

View File

@ -13,11 +13,6 @@ cmake .. \
-DCMAKE_C_COMPILER=$(which clang-13) \
-DCMAKE_CXX_COMPILER=$(which clang++-13) \
-DCMAKE_BUILD_TYPE=Debug \
-DENABLE_CLICKHOUSE_ALL=OFF \
-DENABLE_CLICKHOUSE_SERVER=ON \
-DENABLE_CLICKHOUSE_CLIENT=ON \
-DENABLE_LIBRARIES=OFF \
-DUSE_UNWIND=ON \
-DENABLE_UTILS=OFF \
-DENABLE_TESTS=OFF
```

View File

@ -31,8 +31,11 @@ The supported formats are:
| [JSON](#json) | ✗ | ✔ |
| [JSONAsString](#jsonasstring) | ✔ | ✗ |
| [JSONStrings](#jsonstrings) | ✗ | ✔ |
| [JSONColumns](#jsoncolumns) | ✔ | ✔ |
| [JSONColumnsWithMetadata](#jsoncolumnswithmetadata) | ✗ | ✔ |
| [JSONCompact](#jsoncompact) | ✗ | ✔ |
| [JSONCompactStrings](#jsoncompactstrings) | ✗ | ✔ |
| [JSONCompactColumns](#jsoncompactcolumns) | ✔ | ✔ |
| [JSONEachRow](#jsoneachrow) | ✔ | ✔ |
| [JSONEachRowWithProgress](#jsoneachrowwithprogress) | ✗ | ✔ |
| [JSONStringsEachRow](#jsonstringseachrow) | ✔ | ✔ |
@ -400,6 +403,8 @@ Both data output and parsing are supported in this format. For parsing, any orde
Parsing allows the presence of the additional field `tskv` without the equal sign or a value. This field is ignored.
During import, columns with unknown names will be skipped if setting [input_format_skip_unknown_fields](../operations/settings/settings.md#settings-input-format-skip-unknown-fields) is set to 1.
## CSV {#csv}
Comma Separated Values format ([RFC](https://tools.ietf.org/html/rfc4180)).
@ -459,15 +464,15 @@ SELECT SearchPhrase, count() AS c FROM test.hits GROUP BY SearchPhrase WITH TOTA
"meta":
[
{
"name": "'hello'",
"name": "num",
"type": "Int32"
},
{
"name": "str",
"type": "String"
},
{
"name": "multiply(42, number)",
"type": "UInt64"
},
{
"name": "range(5)",
"name": "arr",
"type": "Array(UInt8)"
}
],
@ -475,25 +480,32 @@ SELECT SearchPhrase, count() AS c FROM test.hits GROUP BY SearchPhrase WITH TOTA
"data":
[
{
"'hello'": "hello",
"multiply(42, number)": "0",
"range(5)": [0,1,2,3,4]
"num": 42,
"str": "hello",
"arr": [0,1]
},
{
"'hello'": "hello",
"multiply(42, number)": "42",
"range(5)": [0,1,2,3,4]
"num": 43,
"str": "hello",
"arr": [0,1,2]
},
{
"'hello'": "hello",
"multiply(42, number)": "84",
"range(5)": [0,1,2,3,4]
"num": 44,
"str": "hello",
"arr": [0,1,2,3]
}
],
"rows": 3,
"rows_before_limit_at_least": 3
"rows_before_limit_at_least": 3,
"statistics":
{
"elapsed": 0.001137687,
"rows_read": 3,
"bytes_read": 24
}
}
```
@ -528,15 +540,15 @@ Example:
"meta":
[
{
"name": "'hello'",
"name": "num",
"type": "Int32"
},
{
"name": "str",
"type": "String"
},
{
"name": "multiply(42, number)",
"type": "UInt64"
},
{
"name": "range(5)",
"name": "arr",
"type": "Array(UInt8)"
}
],
@ -544,25 +556,95 @@ Example:
"data":
[
{
"'hello'": "hello",
"multiply(42, number)": "0",
"range(5)": "[0,1,2,3,4]"
"num": "42",
"str": "hello",
"arr": "[0,1]"
},
{
"'hello'": "hello",
"multiply(42, number)": "42",
"range(5)": "[0,1,2,3,4]"
"num": "43",
"str": "hello",
"arr": "[0,1,2]"
},
{
"'hello'": "hello",
"multiply(42, number)": "84",
"range(5)": "[0,1,2,3,4]"
"num": "44",
"str": "hello",
"arr": "[0,1,2,3]"
}
],
"rows": 3,
"rows_before_limit_at_least": 3
"rows_before_limit_at_least": 3,
"statistics":
{
"elapsed": 0.001403233,
"rows_read": 3,
"bytes_read": 24
}
}
```
## JSONColumns {#jsoncolumns}
In this format, all data is represented as a single JSON Object.
Note that JSONColumns output format buffers all data in memory to output it as a single block and it can lead to high memory consumption.
Example:
```json
{
"num": [42, 43, 44],
"str": ["hello", "hello", "hello"],
"arr": [[0,1], [0,1,2], [0,1,2,3]]
}
```
During import, columns with unknown names will be skipped if setting [input_format_skip_unknown_fields](../operations/settings/settings.md#settings-input-format-skip-unknown-fields) is set to 1.
Columns that are not present in the block will be filled with default values (you can use [input_format_defaults_for_omitted_fields](../operations/settings/settings.md#session_settings-input_format_defaults_for_omitted_fields) setting here)
## JSONColumnsWithMetadata {#jsoncolumnsmonoblock}
Differs from JSONColumns output format in that it also outputs some metadata and statistics (similar to JSON output format).
This format buffers all data in memory and then outputs them as a single block, so, it can lead to high memory consumption.
Example:
```json
{
"meta":
[
{
"name": "num",
"type": "Int32"
},
{
"name": "str",
"type": "String"
},
{
"name": "arr",
"type": "Array(UInt8)"
}
],
"data":
{
"num": [42, 43, 44],
"str": ["hello", "hello", "hello"],
"arr": [[0,1], [0,1,2], [0,1,2,3]]
},
"rows": 3,
"rows_before_limit_at_least": 3,
"statistics":
{
"elapsed": 0.000272376,
"rows_read": 3,
"bytes_read": 24
}
}
```
@ -618,71 +700,101 @@ Result:
Differs from JSON only in that data rows are output in arrays, not in objects.
Examples:
1) JSONCompact:
```json
{
"meta":
[
{
"name": "num",
"type": "Int32"
},
{
"name": "str",
"type": "String"
},
{
"name": "arr",
"type": "Array(UInt8)"
}
],
"data":
[
[42, "hello", [0,1]],
[43, "hello", [0,1,2]],
[44, "hello", [0,1,2,3]]
],
"rows": 3,
"rows_before_limit_at_least": 3,
"statistics":
{
"elapsed": 0.001222069,
"rows_read": 3,
"bytes_read": 24
}
}
```
2) JSONCompactStrings
```json
{
"meta":
[
{
"name": "num",
"type": "Int32"
},
{
"name": "str",
"type": "String"
},
{
"name": "arr",
"type": "Array(UInt8)"
}
],
"data":
[
["42", "hello", "[0,1]"],
["43", "hello", "[0,1,2]"],
["44", "hello", "[0,1,2,3]"]
],
"rows": 3,
"rows_before_limit_at_least": 3,
"statistics":
{
"elapsed": 0.001572097,
"rows_read": 3,
"bytes_read": 24
}
}
```
## JSONCompactColumns {#jsoncompactcolumns}
In this format, all data is represented as a single JSON Array.
Note that JSONCompactColumns output format buffers all data in memory to output it as a single block and it can lead to high memory consumption
Example:
```
// JSONCompact
{
"meta":
[
{
"name": "'hello'",
"type": "String"
},
{
"name": "multiply(42, number)",
"type": "UInt64"
},
{
"name": "range(5)",
"type": "Array(UInt8)"
}
],
"data":
[
["hello", "0", [0,1,2,3,4]],
["hello", "42", [0,1,2,3,4]],
["hello", "84", [0,1,2,3,4]]
],
"rows": 3,
"rows_before_limit_at_least": 3
}
```json
[
[42, 43, 44],
["hello", "hello", "hello"],
[[0,1], [0,1,2], [0,1,2,3]]
]
```
```
// JSONCompactStrings
{
"meta":
[
{
"name": "'hello'",
"type": "String"
},
{
"name": "multiply(42, number)",
"type": "UInt64"
},
{
"name": "range(5)",
"type": "Array(UInt8)"
}
],
"data":
[
["hello", "0", "[0,1,2,3,4]"],
["hello", "42", "[0,1,2,3,4]"],
["hello", "84", "[0,1,2,3,4]"]
],
"rows": 3,
"rows_before_limit_at_least": 3
}
```
Columns that are not present in the block will be filled with default values (you can use [input_format_defaults_for_omitted_fields](../operations/settings/settings.md#session_settings-input_format_defaults_for_omitted_fields) setting here)
## JSONEachRow {#jsoneachrow}
## JSONStringsEachRow {#jsonstringseachrow}
@ -699,15 +811,17 @@ When using these formats, ClickHouse outputs rows as separated, newline-delimite
When inserting the data, you should provide a separate JSON value for each row.
In JSONEachRow/JSONStringsEachRow input formats columns with unknown names will be skipped if setting [input_format_skip_unknown_fields](../operations/settings/settings.md#settings-input-format-skip-unknown-fields) is set to 1.
## JSONEachRowWithProgress {#jsoneachrowwithprogress}
## JSONStringsEachRowWithProgress {#jsonstringseachrowwithprogress}
Differs from `JSONEachRow`/`JSONStringsEachRow` in that ClickHouse will also yield progress information as JSON values.
```json
{"row":{"'hello'":"hello","multiply(42, number)":"0","range(5)":[0,1,2,3,4]}}
{"row":{"'hello'":"hello","multiply(42, number)":"42","range(5)":[0,1,2,3,4]}}
{"row":{"'hello'":"hello","multiply(42, number)":"84","range(5)":[0,1,2,3,4]}}
{"row":{"num":42,"str":"hello","arr":[0,1]}}
{"row":{"num":43,"str":"hello","arr":[0,1,2]}}
{"row":{"num":44,"str":"hello","arr":[0,1,2,3]}}
{"progress":{"read_rows":"3","read_bytes":"24","written_rows":"0","written_bytes":"0","total_rows_to_read":"3"}}
```
@ -728,11 +842,11 @@ Differs from `JSONCompactStringsEachRow` in that in that it also prints the head
Differs from `JSONCompactStringsEachRow` in that it also prints two header rows with column names and types, similar to [TabSeparatedWithNamesAndTypes](#tabseparatedwithnamesandtypes).
```json
["'hello'", "multiply(42, number)", "range(5)"]
["String", "UInt64", "Array(UInt8)"]
["hello", "0", [0,1,2,3,4]]
["hello", "42", [0,1,2,3,4]]
["hello", "84", [0,1,2,3,4]]
["num", "str", "arr"]
["Int32", "String", "Array(UInt8)"]
[42, "hello", [0,1]]
[43, "hello", [0,1,2]]
[44, "hello", [0,1,2,3]]
```
### Inserting Data {#inserting-data}

View File

@ -11,10 +11,16 @@ The functions for working with UUID are listed below.
Generates the [UUID](../data-types/uuid.md) of [version 4](https://tools.ietf.org/html/rfc4122#section-4.4).
**Syntax**
``` sql
generateUUIDv4()
generateUUIDv4([x])
```
**Arguments**
- `x` — [Expression](../../sql-reference/syntax.md#syntax-expressions) resulting in any of the [supported data types](../../sql-reference/data-types/index.md#data_types). The resulting value is discarded, but the expression itself if used for bypassing [common subexpression elimination](../../sql-reference/functions/index.md#common-subexpression-elimination) if the function is called multiple times in one query. Optional parameter.
**Returned value**
The UUID type value.
@ -37,6 +43,15 @@ SELECT * FROM t_uuid
└──────────────────────────────────────┘
```
**Usage example if it is needed to generate multiple values in one row**
```sql
SELECT generateUUIDv4(1), generateUUIDv4(2)
┌─generateUUIDv4(1)────────────────────┬─generateUUIDv4(2)────────────────────┐
│ 2d49dc6e-ddce-4cd0-afb8-790956df54c1 │ 8abf8c13-7dea-4fdf-af3e-0e18767770e6 │
└──────────────────────────────────────┴──────────────────────────────────────┘
```
## empty {#empty}
Checks whether the input UUID is empty.

View File

@ -105,7 +105,7 @@ Example: `regionToCountry(toUInt32(213)) = 225` converts Moscow (213) to Russia
Converts a region to a continent. In every other way, this function is the same as regionToCity.
Example: `regionToContinent(toUInt32(213)) = 10001` converts Moscow (213) to Eurasia (10001).
### regionToTopContinent (#regiontotopcontinent) {#regiontotopcontinent-regiontotopcontinent}
### regionToTopContinent(id\[, geobase\]) {#regiontotopcontinentid-geobase}
Finds the highest continent in the hierarchy for the region.

View File

@ -9,10 +9,16 @@ sidebar_label: "Функции для работы с UUID"
Генерирует идентификатор [UUID версии 4](https://tools.ietf.org/html/rfc4122#section-4.4).
**Синтаксис**
``` sql
generateUUIDv4()
generateUUIDv4([x])
```
**Аргументы**
- `x` — [выражение](../syntax.md#syntax-expressions), возвращающее значение одного из [поддерживаемых типов данных](../data-types/index.md#data_types). Значение используется, чтобы избежать [склейки одинаковых выражений](index.md#common-subexpression-elimination), если функция вызывается несколько раз в одном запросе. Необязательный параметр.
**Возвращаемое значение**
Значение типа [UUID](../../sql-reference/functions/uuid-functions.md).
@ -35,6 +41,15 @@ SELECT * FROM t_uuid
└──────────────────────────────────────┘
```
**Пример использования, для генерации нескольких значений в одной строке**
```sql
SELECT generateUUIDv4(1), generateUUIDv4(2)
┌─generateUUIDv4(1)────────────────────┬─generateUUIDv4(2)────────────────────┐
│ 2d49dc6e-ddce-4cd0-afb8-790956df54c1 │ 8abf8c13-7dea-4fdf-af3e-0e18767770e6 │
└──────────────────────────────────────┴──────────────────────────────────────┘
```
## empty {#empty}
Проверяет, является ли входной UUID пустым.

View File

@ -1,113 +0,0 @@
#!/usr/bin/env python3
import datetime
import logging
import os
import time
import nav # monkey patches mkdocs
import mkdocs.commands
from mkdocs import config
from mkdocs import exceptions
import mdx_clickhouse
import redirects
import util
def build_for_lang(lang, args):
logging.info(f"Building {lang} blog")
try:
theme_cfg = {
"name": None,
"custom_dir": os.path.join(os.path.dirname(__file__), "..", args.theme_dir),
"language": lang,
"direction": "ltr",
"static_templates": ["404.html"],
"extra": {
"now": int(
time.mktime(datetime.datetime.now().timetuple())
) # TODO better way to avoid caching
},
}
# the following list of languages is sorted according to
# https://en.wikipedia.org/wiki/List_of_languages_by_total_number_of_speakers
languages = {"en": "English"}
site_names = {"en": "ClickHouse Blog"}
assert len(site_names) == len(languages)
site_dir = os.path.join(args.blog_output_dir, lang)
plugins = ["macros"]
if args.htmlproofer:
plugins.append("htmlproofer")
website_url = "https://clickhouse.com"
site_name = site_names.get(lang, site_names["en"])
blog_nav, post_meta = nav.build_blog_nav(lang, args)
raw_config = dict(
site_name=site_name,
site_url=f"{website_url}/blog/{lang}/",
docs_dir=os.path.join(args.blog_dir, lang),
site_dir=site_dir,
strict=True,
theme=theme_cfg,
nav=blog_nav,
copyright="©20162022 ClickHouse, Inc.",
use_directory_urls=True,
repo_name="ClickHouse/ClickHouse",
repo_url="https://github.com/ClickHouse/ClickHouse/",
edit_uri=f"edit/master/website/blog/{lang}",
markdown_extensions=mdx_clickhouse.MARKDOWN_EXTENSIONS,
plugins=plugins,
extra=dict(
now=datetime.datetime.now().isoformat(),
rev=args.rev,
rev_short=args.rev_short,
rev_url=args.rev_url,
website_url=website_url,
events=args.events,
languages=languages,
includes_dir=os.path.join(os.path.dirname(__file__), "..", "_includes"),
is_blog=True,
post_meta=post_meta,
today=datetime.date.today().isoformat(),
),
)
cfg = config.load_config(**raw_config)
mkdocs.commands.build.build(cfg)
redirects.build_blog_redirects(args)
env = util.init_jinja2_env(args)
with open(
os.path.join(args.website_dir, "templates", "blog", "rss.xml"), "rb"
) as f:
rss_template_string = f.read().decode("utf-8").strip()
rss_template = env.from_string(rss_template_string)
with open(os.path.join(args.blog_output_dir, lang, "rss.xml"), "w") as f:
f.write(rss_template.render({"config": raw_config}))
logging.info(f"Finished building {lang} blog")
except exceptions.ConfigurationError as e:
raise SystemExit("\n" + str(e))
def build_blog(args):
tasks = []
for lang in args.blog_lang.split(","):
if lang:
tasks.append(
(
lang,
args,
)
)
util.run_function_in_parallel(build_for_lang, tasks, threads=False)

View File

@ -1,144 +1,17 @@
#!/usr/bin/env python3
import argparse
import datetime
import logging
import os
import shutil
import subprocess
import sys
import time
import jinja2
import livereload
import markdown.util
import nav # monkey patches mkdocs
from mkdocs import config
from mkdocs import exceptions
import mkdocs.commands.build
import blog
import mdx_clickhouse
import redirects
import util
import website
from cmake_in_clickhouse_generator import generate_cmake_flags_files
class ClickHouseMarkdown(markdown.extensions.Extension):
class ClickHousePreprocessor(markdown.util.Processor):
def run(self, lines):
for line in lines:
if "<!--hide-->" not in line:
yield line
def extendMarkdown(self, md):
md.preprocessors.register(
self.ClickHousePreprocessor(), "clickhouse_preprocessor", 31
)
markdown.extensions.ClickHouseMarkdown = ClickHouseMarkdown
def build_for_lang(lang, args):
logging.info(f"Building {lang} docs")
try:
theme_cfg = {
"name": None,
"custom_dir": os.path.join(os.path.dirname(__file__), "..", args.theme_dir),
"language": lang,
"direction": "rtl" if lang == "fa" else "ltr",
"static_templates": ["404.html"],
"extra": {
"now": int(
time.mktime(datetime.datetime.now().timetuple())
) # TODO better way to avoid caching
},
}
# the following list of languages is sorted according to
# https://en.wikipedia.org/wiki/List_of_languages_by_total_number_of_speakers
languages = {"en": "English", "zh": "中文", "ru": "Русский", "ja": "日本語"}
site_names = {
"en": "ClickHouse %s Documentation",
"zh": "ClickHouse文档 %s",
"ru": "Документация ClickHouse %s",
"ja": "ClickHouseドキュメント %s",
}
assert len(site_names) == len(languages)
site_dir = os.path.join(args.docs_output_dir, lang)
plugins = ["macros"]
if args.htmlproofer:
plugins.append("htmlproofer")
website_url = "https://clickhouse.com"
site_name = site_names.get(lang, site_names["en"]) % ""
site_name = site_name.replace(" ", " ")
raw_config = dict(
site_name=site_name,
site_url=f"{website_url}/docs/{lang}/",
docs_dir=os.path.join(args.docs_dir, lang),
site_dir=site_dir,
strict=True,
theme=theme_cfg,
copyright="©20162022 ClickHouse, Inc.",
use_directory_urls=True,
repo_name="ClickHouse/ClickHouse",
repo_url="https://github.com/ClickHouse/ClickHouse/",
edit_uri=f"edit/master/docs/{lang}",
markdown_extensions=mdx_clickhouse.MARKDOWN_EXTENSIONS,
plugins=plugins,
extra=dict(
now=datetime.datetime.now().isoformat(),
rev=args.rev,
rev_short=args.rev_short,
rev_url=args.rev_url,
website_url=website_url,
events=args.events,
languages=languages,
includes_dir=os.path.join(os.path.dirname(__file__), "..", "_includes"),
is_blog=False,
),
)
raw_config["nav"] = nav.build_docs_nav(lang, args)
cfg = config.load_config(**raw_config)
if not args.skip_multi_page:
mkdocs.commands.build.build(cfg)
mdx_clickhouse.PatchedMacrosPlugin.disabled = False
logging.info(f"Finished building {lang} docs")
except exceptions.ConfigurationError as e:
raise SystemExit("\n" + str(e))
def build_docs(args):
tasks = []
for lang in args.lang.split(","):
if lang:
tasks.append(
(
lang,
args,
)
)
util.run_function_in_parallel(build_for_lang, tasks, threads=False)
redirects.build_docs_redirects(args)
def build(args):
if os.path.exists(args.output_dir):
@ -147,14 +20,6 @@ def build(args):
if not args.skip_website:
website.build_website(args)
if not args.skip_docs:
generate_cmake_flags_files()
build_docs(args)
if not args.skip_blog:
blog.build_blog(args)
if not args.skip_website:
website.process_benchmark_results(args)
website.minify_website(args)
@ -171,20 +36,14 @@ if __name__ == "__main__":
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument("--lang", default="en,ru,zh,ja")
arg_parser.add_argument("--blog-lang", default="en")
arg_parser.add_argument("--docs-dir", default=".")
arg_parser.add_argument("--theme-dir", default=website_dir)
arg_parser.add_argument("--website-dir", default=website_dir)
arg_parser.add_argument("--src-dir", default=src_dir)
arg_parser.add_argument("--blog-dir", default=os.path.join(website_dir, "blog"))
arg_parser.add_argument("--output-dir", default="build")
arg_parser.add_argument("--nav-limit", type=int, default="0")
arg_parser.add_argument("--skip-multi-page", action="store_true")
arg_parser.add_argument("--skip-website", action="store_true")
arg_parser.add_argument("--skip-blog", action="store_true")
arg_parser.add_argument("--skip-docs", action="store_true")
arg_parser.add_argument("--htmlproofer", action="store_true")
arg_parser.add_argument("--no-docs-macros", action="store_true")
arg_parser.add_argument("--livereload", type=int, default="0")
arg_parser.add_argument("--verbose", action="store_true")
@ -196,11 +55,6 @@ if __name__ == "__main__":
logging.getLogger("MARKDOWN").setLevel(logging.INFO)
args.docs_output_dir = os.path.join(os.path.abspath(args.output_dir), "docs")
args.blog_output_dir = os.path.join(os.path.abspath(args.output_dir), "blog")
from github import get_events
args.rev = (
subprocess.check_output("git rev-parse HEAD", shell=True)
.decode("utf-8")
@ -212,9 +66,6 @@ if __name__ == "__main__":
.strip()
)
args.rev_url = f"https://github.com/ClickHouse/ClickHouse/commit/{args.rev}"
args.events = get_events(args)
from build import build
build(args)
@ -223,9 +74,6 @@ if __name__ == "__main__":
new_args = sys.executable + " " + " ".join(new_args)
server = livereload.Server()
server.watch(
args.docs_dir + "**/*", livereload.shell(new_args, cwd="tools", shell=True)
)
server.watch(
args.website_dir + "**/*",
livereload.shell(new_args, cwd="tools", shell=True),

View File

@ -1,181 +0,0 @@
import re
import os
from typing import TextIO, List, Tuple, Optional, Dict
# name, default value, description
Entity = Tuple[str, str, str]
# https://regex101.com/r/R6iogw/12
cmake_option_regex: str = (
r"^\s*option\s*\(([A-Z_0-9${}]+)\s*(?:\"((?:.|\n)*?)\")?\s*(.*)?\).*$"
)
ch_master_url: str = "https://github.com/clickhouse/clickhouse/blob/master/"
name_str: str = '<a name="{anchor}"></a>[`{name}`](' + ch_master_url + "{path}#L{line})"
default_anchor_str: str = "[`{name}`](#{anchor})"
comment_var_regex: str = r"\${(.+)}"
comment_var_replace: str = "`\\1`"
table_header: str = """
| Name | Default value | Description | Comment |
|------|---------------|-------------|---------|
"""
# Needed to detect conditional variables (those which are defined twice)
# name -> (path, values)
entities: Dict[str, Tuple[str, str]] = {}
def make_anchor(t: str) -> str:
return "".join(
["-" if i == "_" else i.lower() for i in t if i.isalpha() or i == "_"]
)
def process_comment(comment: str) -> str:
return re.sub(comment_var_regex, comment_var_replace, comment, flags=re.MULTILINE)
def build_entity(path: str, entity: Entity, line_comment: Tuple[int, str]) -> None:
(line, comment) = line_comment
(name, description, default) = entity
if name in entities:
return
if len(default) == 0:
formatted_default: str = "`OFF`"
elif default[0] == "$":
formatted_default: str = "`{}`".format(default[2:-1])
else:
formatted_default: str = "`" + default + "`"
formatted_name: str = name_str.format(
anchor=make_anchor(name), name=name, path=path, line=line
)
formatted_description: str = "".join(description.split("\n"))
formatted_comment: str = process_comment(comment)
formatted_entity: str = "| {} | {} | {} | {} |".format(
formatted_name, formatted_default, formatted_description, formatted_comment
)
entities[name] = path, formatted_entity
def process_file(root_path: str, file_path: str, file_name: str) -> None:
with open(os.path.join(file_path, file_name), "r") as cmake_file:
contents: str = cmake_file.read()
def get_line_and_comment(target: str) -> Tuple[int, str]:
contents_list: List[str] = contents.split("\n")
comment: str = ""
for n, line in enumerate(contents_list):
if "option" not in line.lower() or target not in line:
continue
for maybe_comment_line in contents_list[n - 1 :: -1]:
if not re.match("\s*#\s*", maybe_comment_line):
break
comment = re.sub("\s*#\s*", "", maybe_comment_line) + " " + comment
# line numbering starts with 1
return n + 1, comment
matches: Optional[List[Entity]] = re.findall(
cmake_option_regex, contents, re.MULTILINE
)
file_rel_path_with_name: str = os.path.join(
file_path[len(root_path) :], file_name
)
if file_rel_path_with_name.startswith("/"):
file_rel_path_with_name = file_rel_path_with_name[1:]
if matches:
for entity in matches:
build_entity(
file_rel_path_with_name, entity, get_line_and_comment(entity[0])
)
def process_folder(root_path: str, name: str) -> None:
for root, _, files in os.walk(os.path.join(root_path, name)):
for f in files:
if f == "CMakeLists.txt" or ".cmake" in f:
process_file(root_path, root, f)
def generate_cmake_flags_files() -> None:
root_path: str = os.path.join(os.path.dirname(__file__), "..", "..")
output_file_name: str = os.path.join(
root_path, "docs/en/development/cmake-in-clickhouse.md"
)
header_file_name: str = os.path.join(
root_path, "docs/_includes/cmake_in_clickhouse_header.md"
)
footer_file_name: str = os.path.join(
root_path, "docs/_includes/cmake_in_clickhouse_footer.md"
)
process_file(root_path, root_path, "CMakeLists.txt")
process_file(root_path, os.path.join(root_path, "programs"), "CMakeLists.txt")
process_folder(root_path, "base")
process_folder(root_path, "cmake")
process_folder(root_path, "src")
with open(output_file_name, "w") as f:
with open(header_file_name, "r") as header:
f.write(header.read())
sorted_keys: List[str] = sorted(entities.keys())
ignored_keys: List[str] = []
f.write("### ClickHouse modes\n" + table_header)
for k in sorted_keys:
if k.startswith("ENABLE_CLICKHOUSE_"):
f.write(entities[k][1] + "\n")
ignored_keys.append(k)
f.write(
"\n### External libraries\nNote that ClickHouse uses forks of these libraries, see https://github.com/ClickHouse-Extras.\n"
+ table_header
)
for k in sorted_keys:
if k.startswith("ENABLE_") and ".cmake" in entities[k][0]:
f.write(entities[k][1] + "\n")
ignored_keys.append(k)
f.write("\n\n### Other flags\n" + table_header)
for k in sorted(set(sorted_keys).difference(set(ignored_keys))):
f.write(entities[k][1] + "\n")
with open(footer_file_name, "r") as footer:
f.write(footer.read())
other_languages = [
"docs/ja/development/cmake-in-clickhouse.md",
"docs/zh/development/cmake-in-clickhouse.md",
"docs/ru/development/cmake-in-clickhouse.md",
]
for lang in other_languages:
other_file_name = os.path.join(root_path, lang)
if os.path.exists(other_file_name):
os.unlink(other_file_name)
os.symlink(output_file_name, other_file_name)
if __name__ == "__main__":
generate_cmake_flags_files()

View File

@ -12,12 +12,11 @@
#
set -ex
BASE_DIR=$(dirname $(readlink -f $0))
BASE_DIR=$(dirname "$(readlink -f "$0")")
GIT_USER=${GIT_USER:-$USER}
GIT_TEST_URI=git@github.com:${GIT_USER}/clickhouse.github.io.git \
GIT_PROD_URI=git@github.com:${GIT_USER}/clickhouse.github.io.git \
BASE_DOMAIN=${GIT_USER}-test.clickhouse.com \
EXTRA_BUILD_ARGS="${@}" \
EXTRA_BUILD_ARGS="${*}" \
CLOUDFLARE_TOKEN="" \
HISTORY_SIZE=3 \
${BASE_DIR}/release.sh
"${BASE_DIR}/release.sh"

View File

@ -1,186 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys
import argparse
import subprocess
import contextlib
from git import cmd
from tempfile import NamedTemporaryFile
SCRIPT_DESCRIPTION = """
usage: ./easy_diff.py language/document path
Show the difference between a language document and an English document.
This script is based on the assumption that documents in other languages are fully synchronized with the en document at a commit.
For example:
Execute:
./easy_diff.py --no-pager zh/data_types
Output:
Need translate document:~/ClickHouse/docs/en/data_types/uuid.md
Need link document:~/ClickHouse/docs/en/data_types/decimal.md to ~/ClickHouse/docs/zh/data_types/decimal.md
diff --git a/docs/en/data_types/domains/ipv6.md b/docs/en/data_types/domains/ipv6.md
index 1bfbe3400b..e2abaff017 100644
--- a/docs/en/data_types/domains/ipv6.md
+++ b/docs/en/data_types/domains/ipv6.md
@@ -4,13 +4,13 @@
### Basic Usage
-``` sql
+```sql
CREATE TABLE hits (url String, from IPv6) ENGINE = MergeTree() ORDER BY url;
DESCRIBE TABLE hits;
```
-```
+```text
nametypedefault_typedefault_expressioncommentcodec_expression
url String
from IPv6
@@ -19,19 +19,19 @@ DESCRIBE TABLE hits;
OR you can use `IPv6` domain as a key:
-``` sql
+```sql
CREATE TABLE hits (url String, from IPv6) ENGINE = MergeTree() ORDER BY from;
... MORE
OPTIONS:
-h, --help show this help message and exit
--no-pager use stdout as difference result output
"""
SCRIPT_PATH = os.path.abspath(__file__)
CLICKHOUSE_REPO_HOME = os.path.join(os.path.dirname(SCRIPT_PATH), "..", "..")
SCRIPT_COMMAND_EXECUTOR = cmd.Git(CLICKHOUSE_REPO_HOME)
SCRIPT_COMMAND_PARSER = argparse.ArgumentParser(add_help=False)
SCRIPT_COMMAND_PARSER.add_argument("path", type=bytes, nargs="?", default=None)
SCRIPT_COMMAND_PARSER.add_argument("--no-pager", action="store_true", default=False)
SCRIPT_COMMAND_PARSER.add_argument("-h", "--help", action="store_true", default=False)
def execute(commands):
return SCRIPT_COMMAND_EXECUTOR.execute(commands)
def get_hash(file_name):
return execute(["git", "log", "-n", "1", '--pretty=format:"%H"', file_name])
def diff_file(reference_file, working_file, out):
    """Write a translation-status report line for a single document to ``out``.

    Cases:
      * ``working_file`` is a symlink -> the document still needs translating;
      * ``working_file`` is missing  -> it needs to be linked to the reference;
      * hashes differ                -> emit the git diff since the last sync.

    Raises RuntimeError when the reference (English) document does not exist.
    Always returns 0.
    """
    if not os.path.exists(reference_file):
        raise RuntimeError(
            "reference file [" + os.path.abspath(reference_file) + "] does not exist."
        )

    if os.path.islink(working_file):
        out.writelines(["Need translate document:" + os.path.abspath(reference_file)])
    elif not os.path.exists(working_file):
        out.writelines(
            [
                "Need link document "
                + os.path.abspath(reference_file)
                + " to "
                + os.path.abspath(working_file)
            ]
        )
    elif get_hash(working_file) != get_hash(reference_file):
        # BUGFIX: the diff text must stay a str — the original encoded it to
        # bytes, which fails when written to a text-mode stream (the pager's
        # NamedTemporaryFile("r+") or sys.stdout).
        out.writelines(
            [
                execute(
                    [
                        "git",
                        "diff",
                        get_hash(working_file).strip('"'),
                        reference_file,
                    ]
                )
            ]
        )

    return 0
def diff_directory(reference_directory, working_directory, out):
    """Recursively report differences between a reference (English) docs tree
    and a working (translated) tree.

    Returns 0 on success, 1 as soon as any recursive comparison reports a
    non-zero result.
    """
    if not os.path.isdir(reference_directory):
        # A single file was passed — delegate directly.
        return diff_file(reference_directory, working_directory, out)

    for list_item in os.listdir(reference_directory):
        working_item = os.path.join(working_directory, list_item)
        reference_item = os.path.join(reference_directory, list_item)
        # BUGFIX (clarity): in the original one-liner the `!= 0` bound only to
        # the diff_directory() branch; it only worked because diff_file()
        # happens to always return 0.  Apply the check to both branches.
        if os.path.isfile(reference_item):
            status = diff_file(reference_item, working_item, out)
        else:
            status = diff_directory(reference_item, working_item, out)
        if status != 0:
            return 1

    return 0
def find_language_doc(custom_document, other_language="en", children=None):
    """Map a document path inside one language's docs tree to the equivalent
    path in ``other_language``'s tree.

    Walks up from ``custom_document`` until it reaches the repo's docs/ root,
    collecting path components, then rebuilds the path under the other
    language.  Raises RuntimeError if the path is not under docs/.
    """
    # BUGFIX: the original used a mutable default (children=[]), which is
    # shared between top-level calls and accumulates stale components.
    if children is None:
        children = []
    if len(custom_document) == 0:
        raise RuntimeError(
            "The "
            + os.path.join(custom_document, *children)
            + " is not in docs directory."
        )

    if os.path.samefile(os.path.join(CLICKHOUSE_REPO_HOME, "docs"), custom_document):
        # children[0] is the source language folder — replace it.
        return os.path.join(CLICKHOUSE_REPO_HOME, "docs", other_language, *children[1:])
    children.insert(0, os.path.split(custom_document)[1])
    return find_language_doc(
        os.path.split(custom_document)[0], other_language, children
    )
class ToPager:
    """Writer that buffers report lines in a temporary file and, on close,
    displays the buffer through the user's git pager (``git var GIT_PAGER``)."""

    def __init__(self, temp_named_file):
        # temp_named_file: an already-open NamedTemporaryFile used as the buffer.
        self.temp_named_file = temp_named_file

    def writelines(self, lines):
        self.temp_named_file.writelines(lines)

    def close(self):
        self.temp_named_file.flush()
        # Ask git which pager to use, then hand it the buffered file.
        git_pager = execute(["git", "var", "GIT_PAGER"])
        subprocess.check_call([git_pager, self.temp_named_file.name])
        self.temp_named_file.close()
class ToStdOut:
    """Pager-compatible writer that forwards report lines straight to a
    stream (normally sys.stdout) instead of buffering them for a pager."""

    def __init__(self, system_stdout_stream):
        self.system_stdout_stream = system_stdout_stream

    def writelines(self, lines):
        self.system_stdout_stream.writelines(lines)

    def close(self):
        self.system_stdout_stream.flush()
if __name__ == "__main__":
arguments = SCRIPT_COMMAND_PARSER.parse_args()
if arguments.help or not arguments.path:
sys.stdout.write(SCRIPT_DESCRIPTION)
sys.exit(0)
working_language = os.path.join(CLICKHOUSE_REPO_HOME, "docs", arguments.path)
with contextlib.closing(
ToStdOut(sys.stdout)
if arguments.no_pager
else ToPager(NamedTemporaryFile("r+"))
) as writer:
exit(
diff_directory(
find_language_doc(working_language), working_language, writer
)
)

View File

@ -1,41 +0,0 @@
import collections
import copy
import io
import logging
import os
import random
import sys
import tarfile
import time
import requests
import util
def get_events(args):
    """Parse the "Upcoming Events" section of the README.md that lives next to
    ``args.docs_dir`` into a list of dicts with keys
    ``signup_link`` / ``event_name`` / ``event_date``.

    Expected bullet format: ``* [Event name](signup-url) on <date>.``
    Lines that don't match are silently skipped.
    """
    events = []
    skip = True
    with open(os.path.join(args.docs_dir, "..", "README.md")) as f:
        for line in f:
            if skip:
                # Everything before the "Upcoming Events" heading is ignored.
                if "Upcoming Events" in line:
                    skip = False
            else:
                # BUGFIX: lines read from a file always end with "\n", so the
                # original `if not line` check could never fire; strip first
                # so genuinely blank lines are skipped.
                if not line.strip():
                    continue
                line = line.strip().split("](")
                if len(line) == 2:
                    tail = line[1].split(") ")
                    events.append(
                        {
                            "signup_link": tail[0],
                            "event_name": line[0].replace("* [", ""),
                            "event_date": tail[1].replace("on ", "").replace(".", ""),
                        }
                    )
    return events
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG, stream=sys.stderr)

View File

@ -1,190 +0,0 @@
import collections
import datetime
import hashlib
import logging
import os
import mkdocs.structure.nav
import util
def find_first_header(content):
    """Return the text of the first Markdown heading in ``content``.

    Leading ``#`` markers and any trailing ``{...}`` attribute block are
    stripped.  Returns None when no heading line is present.
    """
    for current in content.split("\n"):
        if not current.startswith("#"):
            continue
        heading_text = current.lstrip("#").split("{", 1)[0]
        return heading_text.strip()
    return None
def build_nav_entry(root, args):
    """Recursively build one navigation subtree for the directory ``root``.

    Returns ``(priority, title, payload)`` where ``payload`` is an OrderedDict
    mapping title -> path-or-subtree, sorted by (toc_priority, title).
    Image directories yield ``(None, None, None)`` and are dropped by the caller.
    """
    if root.endswith("images"):
        return None, None, None
    result_items = []
    # The folder's own index.md supplies the folder title and metadata.
    index_meta, index_content = util.read_md_file(os.path.join(root, "index.md"))
    current_title = index_meta.get("toc_folder_title", index_meta.get("toc_title"))
    current_title = current_title or index_meta.get(
        "title", find_first_header(index_content)
    )
    for filename in os.listdir(root):
        path = os.path.join(root, filename)
        if os.path.isdir(path):
            # Sub-folder: recurse and keep it only if it produced content.
            prio, title, payload = build_nav_entry(path, args)
            if title and payload:
                result_items.append((prio, title, payload))
        elif filename.endswith(".md"):
            path = os.path.join(root, filename)
            meta = ""
            content = ""
            try:
                meta, content = util.read_md_file(path)
            except:
                print("Error in file: {}".format(path))
                raise
            # Drop the first two path components (presumably the docs dir and
            # the language folder — verify against callers) to get a page path.
            path = path.split("/", 2)[-1]
            title = meta.get("toc_title", find_first_header(content))
            if title:
                title = title.strip().rstrip(".")
            else:
                title = meta.get("toc_folder_title", "hidden")
            prio = meta.get("toc_priority", 9999)
            logging.debug(f"Nav entry: {prio}, {title}, {path}")
            if meta.get("toc_hidden") or not content.strip():
                title = "hidden"
            if title == "hidden":
                # Hidden pages get a unique title so they don't collide as dict keys.
                title = "hidden-" + hashlib.sha1(content.encode("utf-8")).hexdigest()
            if args.nav_limit and len(result_items) >= args.nav_limit:
                # args.nav_limit caps the number of entries collected per folder.
                break
            result_items.append((prio, title, path))
    result_items = sorted(result_items, key=lambda x: (x[0], x[1]))
    result = collections.OrderedDict([(item[1], item[2]) for item in result_items])
    if index_meta.get("toc_hidden_folder"):
        # Marker suffix; consumers of the nav recognize "|hidden-folder".
        current_title += "|hidden-folder"
    return index_meta.get("toc_priority", 10000), current_title, result
def build_docs_nav(lang, args):
    """Build the mkdocs 'nav' list for one documentation language.

    Walks ``args.docs_dir/<lang>`` via build_nav_entry() and returns a list of
    ``{title: path-or-subtree}`` dicts; the language's top-level index.md is
    re-inserted at the front of the first subtree so it renders first.
    """
    docs_dir = os.path.join(args.docs_dir, lang)
    _, _, nav = build_nav_entry(docs_dir, args)
    result = []
    index_key = None
    for key, value in list(nav.items()):
        if key and value:
            if value == "index.md":
                # Remember the top-level index entry; re-inserted below.
                index_key = key
                continue
            result.append({key: value})
        if args.nav_limit and len(result) >= args.nav_limit:
            break
    if index_key:
        # Put the index page at the beginning of the first subtree
        # (the subtree is an OrderedDict, so move_to_end(..., last=False)
        # makes it the first entry).
        key = list(result[0].keys())[0]
        result[0][key][index_key] = "index.md"
        result[0][key].move_to_end(index_key, last=False)
    return result
def build_blog_nav(lang, args):
    """Build the blog navigation for one language.

    Returns ``(result_nav, post_meta)``:
      * ``result_nav`` — mkdocs nav structure: a hidden index plus one section
        per year (newest year first), each mapping post title -> relative path,
        newest post first.
      * ``post_meta`` — OrderedDict of post title -> {date, title, image, url},
        newest first.

    Raises RuntimeError for non-markdown files in a year folder, for a title
    duplicating one from an earlier-processed year, or for a post whose date
    doesn't match its year folder.  Posts dated in the future are skipped.
    """
    blog_dir = os.path.join(args.blog_dir, lang)
    years = sorted(os.listdir(blog_dir), reverse=True)
    result_nav = [{"hidden": "index.md"}]
    post_meta = collections.OrderedDict()
    for year in years:
        year_dir = os.path.join(blog_dir, year)
        if not os.path.isdir(year_dir):
            # Skip stray files (non-directories) at the blog root.
            continue
        result_nav.append({year: collections.OrderedDict()})
        posts = []
        post_meta_items = []
        for post in os.listdir(year_dir):
            post_path = os.path.join(year_dir, post)
            if not post.endswith(".md"):
                raise RuntimeError(
                    f"Unexpected non-md file in posts folder: {post_path}"
                )
            meta, _ = util.read_md_file(post_path)
            post_date = meta["date"]
            post_title = meta["title"]
            if datetime.date.fromisoformat(post_date) > datetime.date.today():
                # Scheduled for the future: don't publish yet.
                continue
            posts.append(
                (
                    post_date,
                    post_title,
                    os.path.join(year, post),
                )
            )
            if post_title in post_meta:
                raise RuntimeError(f"Duplicate post title: {post_title}")
            if not post_date.startswith(f"{year}-"):
                raise RuntimeError(
                    f"Post date {post_date} doesn't match the folder year {year}: {post_title}"
                )
            post_url_part = post.replace(".md", "")
            post_meta_items.append(
                (
                    post_date,
                    {
                        "date": post_date,
                        "title": post_title,
                        "image": meta.get("image"),
                        "url": f"/blog/{lang}/{year}/{post_url_part}/",
                    },
                )
            )
        # Newest posts first within the year (sorted by date string).
        for _, title, path in sorted(posts, reverse=True):
            result_nav[-1][year][title] = path
        for _, post_meta_item in sorted(
            post_meta_items, reverse=True, key=lambda item: item[0]
        ):
            post_meta[post_meta_item["title"]] = post_meta_item
    return result_nav, post_meta
def _custom_get_navigation(files, config):
    """Customized replacement for ``mkdocs.structure.nav.get_navigation``.

    Builds the Navigation from the 'nav' config (or from the page tree when no
    'nav' is configured), and removes documentation pages that ended up without
    a nav entry from the files collection.
    """
    nav_config = config["nav"] or mkdocs.structure.nav.nest_paths(
        f.src_path for f in files.documentation_pages()
    )
    items = mkdocs.structure.nav._data_to_navigation(nav_config, files, config)
    if not isinstance(items, list):
        items = [items]

    pages = mkdocs.structure.nav._get_by_type(items, mkdocs.structure.nav.Page)

    mkdocs.structure.nav._add_previous_and_next_links(pages)
    mkdocs.structure.nav._add_parent_links(items)

    # Pages not reachable from the nav get no Page object; drop them.
    # NOTE(review): this mutates mkdocs' private file list.
    missing_from_config = [
        file for file in files.documentation_pages() if file.page is None
    ]
    if missing_from_config:
        files._files = [
            file for file in files._files if file not in missing_from_config
        ]

    # Log nav links: external URLs and absolute paths are fine (debug),
    # unresolved relative paths are warnings.
    links = mkdocs.structure.nav._get_by_type(items, mkdocs.structure.nav.Link)
    for link in links:
        scheme, netloc, path, params, query, fragment = mkdocs.structure.nav.urlparse(
            link.url
        )
        if scheme or netloc:
            mkdocs.structure.nav.log.debug(
                "An external link to '{}' is included in "
                "the 'nav' configuration.".format(link.url)
            )
        elif link.url.startswith("/"):
            mkdocs.structure.nav.log.debug(
                "An absolute path to '{}' is included in the 'nav' configuration, "
                "which presumably points to an external resource.".format(link.url)
            )
        else:
            msg = (
                "A relative path to '{}' is included in the 'nav' configuration, "
                "which is not found in the documentation files".format(link.url)
            )
            mkdocs.structure.nav.log.warning(msg)
    return mkdocs.structure.nav.Navigation(items, pages)


# Monkey-patch mkdocs so the customized navigation builder above is used.
mkdocs.structure.nav.get_navigation = _custom_get_navigation

View File

@ -27,45 +27,6 @@ def write_redirect_html(out_path, to_url):
)
def build_redirect_html(args, base_prefix, lang, output_dir, from_path, to_path):
    """Write a single redirect stub page.

    ``from_path``/``to_path`` are .md paths from a redirects.txt file; the
    stub is written under ``output_dir/<lang>/`` and points at either an
    absolute http(s) URL or a site-relative ``/<base_prefix>/<lang>/...`` URL.
    """
    out_path = os.path.join(
        output_dir,
        lang,
        from_path.replace("/index.md", "/index.html").replace(".md", "/index.html"),
    )
    target_path = to_path.replace("/index.md", "/").replace(".md", "/")

    # Idiomatic prefix test instead of the original slice comparisons
    # (target_path[0:7] != "http://" and target_path[0:8] != "https://").
    if target_path.startswith(("http://", "https://")):
        to_url = target_path
    else:
        to_url = f"/{base_prefix}/{lang}/{target_path}"

    # to_path comes straight from a file line and may carry a trailing newline.
    to_url = to_url.strip()

    write_redirect_html(out_path, to_url)
def build_docs_redirects(args):
    """Generate one redirect stub page per language for every
    "<from> <to>" pair listed in docs/redirects.txt."""
    redirects_file = os.path.join(args.docs_dir, "redirects.txt")
    languages = args.lang.split(",")
    with open(redirects_file, "r") as f:
        for line in f:
            from_path, to_path = line.split(" ", 1)
            for lang in languages:
                build_redirect_html(
                    args, "docs", lang, args.docs_output_dir, from_path, to_path
                )
def build_blog_redirects(args):
    """Generate redirect stub pages for the blog, one language at a time.

    Unlike the docs redirects, each language's redirects.txt is optional.
    """
    for lang in args.blog_lang.split(","):
        redirects_path = os.path.join(args.blog_dir, lang, "redirects.txt")
        if os.path.exists(redirects_path):
            with open(redirects_path, "r") as f:
                for line in f:
                    from_path, to_path = line.split(" ", 1)
                    build_redirect_html(
                        args, "blog", lang, args.blog_output_dir, from_path, to_path
                    )
def build_static_redirects(args):
for static_redirect in [
("benchmark.html", "/benchmark/dbms/"),

View File

@ -1,24 +1,24 @@
#!/usr/bin/env bash
set -ex
BASE_DIR=$(dirname $(readlink -f $0))
BASE_DIR=$(dirname "$(readlink -f "$0")")
BUILD_DIR="${BASE_DIR}/../build"
PUBLISH_DIR="${BASE_DIR}/../publish"
BASE_DOMAIN="${BASE_DOMAIN:-content.clickhouse.com}"
GIT_TEST_URI="${GIT_TEST_URI:-git@github.com:ClickHouse/clickhouse-com-content.git}"
GIT_PROD_URI="git@github.com:ClickHouse/clickhouse-website-content.git"
GIT_PROD_URI="${GIT_PROD_URI:-git@github.com:ClickHouse/clickhouse-com-content.git}"
EXTRA_BUILD_ARGS="${EXTRA_BUILD_ARGS:---verbose}"
if [[ -z "$1" ]]
then
source "${BASE_DIR}/venv/bin/activate"
# shellcheck disable=2086
python3 "${BASE_DIR}/build.py" ${EXTRA_BUILD_ARGS}
rm -rf "${PUBLISH_DIR}"
mkdir "${PUBLISH_DIR}" && cd "${PUBLISH_DIR}"
# Will make a repository with website content as the only commit.
git init
git remote add origin "${GIT_TEST_URI}"
git remote add origin "${GIT_PROD_URI}"
git config user.email "robot-clickhouse@clickhouse.com"
git config user.name "robot-clickhouse"
@ -28,7 +28,7 @@ then
echo -n "" > README.md
echo -n "" > ".nojekyll"
cp "${BASE_DIR}/../../LICENSE" .
git add *
git add ./*
git add ".nojekyll"
git commit --quiet -m "Add new release at $(date)"
@ -40,7 +40,7 @@ then
# Turn off logging.
set +x
if [[ ! -z "${CLOUDFLARE_TOKEN}" ]]
if [[ -n "${CLOUDFLARE_TOKEN}" ]]
then
sleep 1m
# https://api.cloudflare.com/#zone-purge-files-by-cache-tags,-host-or-prefix

View File

@ -1,39 +1,30 @@
Babel==2.9.1
backports-abc==0.5
backports.functools-lru-cache==1.6.1
beautifulsoup4==4.9.1
certifi==2020.4.5.2
chardet==3.0.4
click==7.1.2
closure==20191111
cssmin==0.2.0
future==0.18.2
htmlmin==0.1.12
idna==2.10
Jinja2==3.0.3
jinja2-highlight==0.6.1
jsmin==3.0.0
livereload==2.6.3
Markdown==3.3.2
MarkupSafe==2.1.0
mkdocs==1.3.0
mkdocs-htmlproofer-plugin==0.0.3
mkdocs-macros-plugin==0.4.20
nltk==3.7
nose==1.3.7
protobuf==3.14.0
numpy==1.21.2
pymdown-extensions==8.0
python-slugify==4.0.1
MarkupSafe==2.1.1
PyYAML==6.0
repackage==0.7.3
requests==2.25.1
singledispatch==3.4.0.3
Pygments>=2.12.0
beautifulsoup4==4.9.1
click==7.1.2
ghp_import==2.1.0
importlib_metadata==4.11.4
jinja2-highlight==0.6.1
livereload==2.6.3
mergedeep==1.3.4
mkdocs-macros-plugin==0.4.20
mkdocs-macros-test==0.1.0
mkdocs-material==8.2.15
mkdocs==1.3.0
mkdocs_material_extensions==1.0.3
packaging==21.3
pymdown_extensions==9.4
pyparsing==3.0.9
python-slugify==4.0.1
python_dateutil==2.8.2
pytz==2022.1
six==1.15.0
soupsieve==2.0.1
soupsieve==2.3.2
termcolor==1.1.0
text_unidecode==1.3
tornado==6.1
Unidecode==1.1.1
urllib3>=1.26.8
Pygments>=2.11.2
zipp==3.8.0

View File

@ -124,7 +124,7 @@ def init_jinja2_env(args):
env = jinja2.Environment(
loader=jinja2.FileSystemLoader(
[args.website_dir, os.path.join(args.docs_dir, "_includes")]
[args.website_dir, os.path.join(args.src_dir, "docs", "_includes")]
),
extensions=["jinja2.ext.i18n", "jinja2_highlight.HighlightExtension"],
)

View File

@ -1,81 +0,0 @@
const path = require('path');
const jsPath = path.resolve(__dirname, '../../website/src/js');
const scssPath = path.resolve(__dirname, '../../website/src/scss');
console.log(path.resolve(__dirname, 'node_modules/bootstrap', require('bootstrap/package.json').sass));
module.exports = {
mode: ('development' === process.env.NODE_ENV) && 'development' || 'production',
...(('development' === process.env.NODE_ENV) && {
watch: true,
}),
entry: [
path.resolve(scssPath, 'bootstrap.scss'),
path.resolve(scssPath, 'main.scss'),
path.resolve(jsPath, 'main.js'),
],
output: {
path: path.resolve(__dirname, '../../website'),
filename: 'js/main.js',
},
resolve: {
alias: {
bootstrap: path.resolve(__dirname, 'node_modules/bootstrap', require('bootstrap/package.json').sass),
},
},
module: {
rules: [{
test: /\.js$/,
exclude: /(node_modules)/,
use: [{
loader: 'babel-loader',
options: {
presets: ['@babel/preset-env'],
},
}],
}, {
test: /\.scss$/,
use: [{
loader: 'file-loader',
options: {
sourceMap: true,
outputPath: (url, entryPath, context) => {
if (0 === entryPath.indexOf(scssPath)) {
const outputFile = entryPath.slice(entryPath.lastIndexOf('/') + 1, -5)
const outputPath = entryPath.slice(0, entryPath.lastIndexOf('/')).slice(scssPath.length + 1)
return `./css/${outputPath}/${outputFile}.css`
}
return `./css/${url}`
},
},
}, {
loader: 'postcss-loader',
options: {
options: {},
plugins: () => ([
require('autoprefixer'),
('production' === process.env.NODE_ENV) && require('cssnano'),
].filter(plugin => plugin)),
}
}, {
loader: 'sass-loader',
options: {
implementation: require('sass'),
implementation: require('sass'),
sourceMap: ('development' === process.env.NODE_ENV),
sassOptions: {
importer: require('node-sass-glob-importer')(),
precision: 10,
},
},
}],
}],
},
};

View File

@ -1314,7 +1314,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
global_context->setConfigReloadCallback([&]()
{
main_config_reloader->reload();
access_control.reloadUsersConfigs();
access_control.reload();
});
/// Limit on total number of concurrently executed queries.
@ -1405,6 +1405,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
/// Stop reloading of the main config. This must be done before `global_context->shutdown()` because
/// otherwise the reloading may pass a changed config to some destroyed parts of ContextSharedPart.
main_config_reloader.reset();
access_control.stopPeriodicReloading();
async_metrics.stop();
@ -1628,7 +1629,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
buildLoggers(config(), logger());
main_config_reloader->start();
access_control.startPeriodicReloadingUsersConfigs();
access_control.startPeriodicReloading();
if (dns_cache_updater)
dns_cache_updater->start();

View File

@ -81,6 +81,8 @@
{
height: 100%;
margin: 0;
/* This enables position: sticky on controls */
overflow: auto;
}
html
@ -89,9 +91,26 @@
font-family: Liberation Sans, DejaVu Sans, sans-serif, Noto Color Emoji, Apple Color Emoji, Segoe UI Emoji;
background: var(--background-color);
color: var(--text-color);
}
body
{
/* This element will show scroll-bar on overflow, and the scroll-bar will be outside of the padding. */
padding: 0.5rem;
}
#controls
{
/* Make enough space for even huge queries. */
height: 20%;
/* When a page will be scrolled horizontally due to large table size, keep controls in place. */
position: sticky;
left: 0;
/* This allows query textarea to occupy the remaining height while other elements have fixed height. */
display: flex;
flex-direction: column;
}
/* Otherwise Webkit based browsers will display ugly border on focus. */
textarea, input, button
{
@ -129,8 +148,7 @@
#query_div
{
/* Make enough space for even huge queries. */
height: 20%;
height: 100%;
}
#query
@ -380,19 +398,21 @@
</head>
<body>
<div id="inputs">
<input class="monospace shadow" id="url" type="text" value="http://localhost:8123/" placeholder="url" /><input class="monospace shadow" id="user" type="text" value="default" placeholder="user" /><input class="monospace shadow" id="password" type="password" placeholder="password" />
</div>
<div id="query_div">
<textarea autofocus spellcheck="false" class="monospace shadow" id="query"></textarea>
</div>
<div id="run_div">
<button class="shadow" id="run">Run</button>
<span class="hint">&nbsp;(Ctrl/Cmd+Enter)</span>
<span id="hourglass"></span>
<span id="check-mark"></span>
<span id="stats"></span>
<span id="toggle-dark">🌑</span><span id="toggle-light">🌞</span>
<div id="controls">
<div id="inputs">
<input class="monospace shadow" id="url" type="text" value="http://localhost:8123/" placeholder="url" /><input class="monospace shadow" id="user" type="text" value="default" placeholder="user" /><input class="monospace shadow" id="password" type="password" placeholder="password" />
</div>
<div id="query_div">
<textarea autofocus spellcheck="false" class="monospace shadow" id="query"></textarea>
</div>
<div id="run_div">
<button class="shadow" id="run">Run</button>
<span class="hint">&nbsp;(Ctrl/Cmd+Enter)</span>
<span id="hourglass"></span>
<span id="check-mark"></span>
<span id="stats"></span>
<span id="toggle-dark">🌑</span><span id="toggle-light">🌞</span>
</div>
</div>
<div id="data_div">
<table class="monospace-table shadow" id="data-table"></table>

View File

@ -0,0 +1,122 @@
#include <Access/AccessChangesNotifier.h>
#include <boost/range/algorithm/copy.hpp>
namespace DB
{
/// The handler registry lives in a shared_ptr so that unsubscription guards
/// returned by subscribeForChanges() remain safe to run even if the notifier
/// itself has already been destroyed.
AccessChangesNotifier::AccessChangesNotifier() : handlers(std::make_shared<Handlers>())
{
}

AccessChangesNotifier::~AccessChangesNotifier() = default;
void AccessChangesNotifier::onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity)
{
    /// Enqueue the change; subscribers are invoked later from sendNotifications().
    Event event{id, new_entity, new_entity->getType()};

    std::lock_guard lock{queue_mutex};
    queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity)
{
    /// Enqueue the change; subscribers are invoked later from sendNotifications().
    Event event{id, changed_entity, changed_entity->getType()};

    std::lock_guard lock{queue_mutex};
    queue.push(std::move(event));
}
void AccessChangesNotifier::onEntityRemoved(const UUID & id, AccessEntityType type)
{
    /// No entity pointer for removals: subscribers receive a null entity.
    Event event{id, {}, type};

    std::lock_guard lock{queue_mutex};
    queue.push(std::move(event));
}
scope_guard AccessChangesNotifier::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler)
{
    std::lock_guard lock{handlers->mutex};
    auto & list = handlers->by_type[static_cast<size_t>(type)];
    list.push_back(handler);
    /// std::list iterators stay valid across insert/erase of other elements,
    /// so the guard can erase exactly this handler later.
    auto handler_it = std::prev(list.end());

    /// `handlers` is captured as a shared_ptr copy: the returned guard remains
    /// safe to run even after the notifier itself has been destroyed.
    return [handlers=handlers, type, handler_it]
    {
        std::lock_guard lock2{handlers->mutex};
        auto & list2 = handlers->by_type[static_cast<size_t>(type)];
        list2.erase(handler_it);
    };
}
scope_guard AccessChangesNotifier::subscribeForChanges(const UUID & id, const OnChangedHandler & handler)
{
    std::lock_guard lock{handlers->mutex};
    /// Creates the per-id handler list on first subscription for this id.
    auto it = handlers->by_id.emplace(id, std::list<OnChangedHandler>{}).first;
    auto & list = it->second;
    list.push_back(handler);
    auto handler_it = std::prev(list.end());

    /// `handlers` is captured as a shared_ptr copy: the returned guard remains
    /// safe to run even after the notifier itself has been destroyed.
    return [handlers=handlers, it, handler_it]
    {
        std::lock_guard lock2{handlers->mutex};
        auto & list2 = it->second;
        list2.erase(handler_it);
        /// Drop the map entry once the last handler for this id is gone.
        if (list2.empty())
            handlers->by_id.erase(it);
    };
}
scope_guard AccessChangesNotifier::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler)
{
    /// Combine the per-id subscriptions into a single guard.
    scope_guard all_subscriptions;
    for (const UUID & current_id : ids)
        all_subscriptions.join(subscribeForChanges(current_id, handler));
    return all_subscriptions;
}
/// Drains the queued change events and invokes the matching subscribers
/// (both the per-type and, if present, the per-id handlers for each event).
void AccessChangesNotifier::sendNotifications()
{
    /// Only one thread can send notification at any time.
    std::lock_guard sending_notifications_lock{sending_notifications};

    std::unique_lock queue_lock{queue_mutex};
    while (!queue.empty())
    {
        auto event = std::move(queue.front());
        queue.pop();
        /// Release the queue lock while handlers run so producers
        /// (onEntityAdded & co.) are not blocked by slow subscribers.
        queue_lock.unlock();

        /// Copy the relevant handlers under the handlers mutex and invoke them
        /// after releasing it, so a handler can subscribe/unsubscribe without
        /// deadlocking on the (non-recursive) mutex.
        std::vector<OnChangedHandler> current_handlers;
        {
            std::lock_guard handlers_lock{handlers->mutex};
            boost::range::copy(handlers->by_type[static_cast<size_t>(event.type)], std::back_inserter(current_handlers));
            auto it = handlers->by_id.find(event.id);
            if (it != handlers->by_id.end())
                boost::range::copy(it->second, std::back_inserter(current_handlers));
        }

        for (const auto & handler : current_handlers)
        {
            try
            {
                handler(event.id, event.entity);
            }
            catch (...)
            {
                /// A throwing subscriber must not break delivery to the rest.
                tryLogCurrentException(__PRETTY_FUNCTION__);
            }
        }

        queue_lock.lock();
    }
}
}

View File

@ -0,0 +1,73 @@
#pragma once
#include <Access/IAccessEntity.h>
#include <base/scope_guard.h>
#include <list>
#include <queue>
#include <unordered_map>
namespace DB
{
/// Helper class implementing subscriptions and notifications in access management.
/// Storages record changes via the onEntity*() callbacks into an internal queue;
/// AccessControl later calls sendNotifications() to deliver them to subscribers.
class AccessChangesNotifier
{
public:
    AccessChangesNotifier();
    ~AccessChangesNotifier();

    using OnChangedHandler
        = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;

    /// Subscribes for changes of all entities of the specified type.
    /// The returned guard removes the subscription when destroyed.
    scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler);

    template <typename EntityClassT>
    scope_guard subscribeForChanges(OnChangedHandler handler)
    {
        return subscribeForChanges(EntityClassT::TYPE, handler);
    }

    /// Subscribes for changes of a specific entity (or several entities).
    /// The returned guard removes the subscription when destroyed.
    scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler);
    scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler);

    /// Called by access storages after a new access entity has been added.
    void onEntityAdded(const UUID & id, const AccessEntityPtr & new_entity);

    /// Called by access storages after an access entity has been changed.
    void onEntityUpdated(const UUID & id, const AccessEntityPtr & changed_entity);

    /// Called by access storages after an access entity has been removed.
    void onEntityRemoved(const UUID & id, AccessEntityType type);

    /// Sends notifications to subscribers about changes in access entities
    /// (added with previous calls onEntityAdded(), onEntityUpdated(), onEntityRemoved()).
    void sendNotifications();

private:
    struct Handlers
    {
        std::unordered_map<UUID, std::list<OnChangedHandler>> by_id;
        std::list<OnChangedHandler> by_type[static_cast<size_t>(AccessEntityType::MAX)];
        std::mutex mutex;
    };

    /// shared_ptr is here for safety because AccessChangesNotifier can be destroyed before all subscriptions are removed.
    std::shared_ptr<Handlers> handlers;

    /// A pending change waiting to be delivered by sendNotifications().
    /// `entity` is null for removals.
    struct Event
    {
        UUID id;
        AccessEntityPtr entity;
        AccessEntityType type;
    };
    std::queue<Event> queue;
    std::mutex queue_mutex;

    /// Held for the whole duration of sendNotifications() so that only one
    /// thread delivers notifications at a time.
    std::mutex sending_notifications;
};
}

View File

@ -14,9 +14,10 @@
#include <Access/SettingsProfilesCache.h>
#include <Access/User.h>
#include <Access/ExternalAuthenticators.h>
#include <Access/AccessChangesNotifier.h>
#include <Core/Settings.h>
#include <base/find_symbols.h>
#include <Poco/ExpireCache.h>
#include <Poco/AccessExpireCache.h>
#include <boost/algorithm/string/join.hpp>
#include <boost/algorithm/string/split.hpp>
#include <boost/algorithm/string/trim.hpp>
@ -82,7 +83,7 @@ public:
private:
const AccessControl & access_control;
Poco::ExpireCache<ContextAccess::Params, std::shared_ptr<const ContextAccess>> cache;
Poco::AccessExpireCache<ContextAccess::Params, std::shared_ptr<const ContextAccess>> cache;
std::mutex mutex;
};
@ -142,7 +143,8 @@ AccessControl::AccessControl()
quota_cache(std::make_unique<QuotaCache>(*this)),
settings_profiles_cache(std::make_unique<SettingsProfilesCache>(*this)),
external_authenticators(std::make_unique<ExternalAuthenticators>()),
custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>())
custom_settings_prefixes(std::make_unique<CustomSettingsPrefixes>()),
changes_notifier(std::make_unique<AccessChangesNotifier>())
{
}
@ -231,35 +233,6 @@ void AccessControl::addUsersConfigStorage(
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
}
void AccessControl::reloadUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->reload();
}
}
void AccessControl::startPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->startPeriodicReloading();
}
}
void AccessControl::stopPeriodicReloadingUsersConfigs()
{
auto storages = getStoragesPtr();
for (const auto & storage : *storages)
{
if (auto users_config_storage = typeid_cast<std::shared_ptr<UsersConfigAccessStorage>>(storage))
users_config_storage->stopPeriodicReloading();
}
}
void AccessControl::addReplicatedStorage(
const String & storage_name_,
@ -272,10 +245,9 @@ void AccessControl::addReplicatedStorage(
if (auto replicated_storage = typeid_cast<std::shared_ptr<ReplicatedAccessStorage>>(storage))
return;
}
auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_);
auto new_storage = std::make_shared<ReplicatedAccessStorage>(storage_name_, zookeeper_path_, get_zookeeper_function_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
new_storage->startup();
}
void AccessControl::addDiskStorage(const String & directory_, bool readonly_)
@ -298,7 +270,7 @@ void AccessControl::addDiskStorage(const String & storage_name_, const String &
}
}
}
auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_);
auto new_storage = std::make_shared<DiskAccessStorage>(storage_name_, directory_, readonly_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', path: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getPath());
}
@ -312,7 +284,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
if (auto memory_storage = typeid_cast<std::shared_ptr<MemoryAccessStorage>>(storage))
return;
}
auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_);
auto new_storage = std::make_shared<MemoryAccessStorage>(storage_name_, *changes_notifier);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}'", String(new_storage->getStorageType()), new_storage->getStorageName());
}
@ -320,7 +292,7 @@ void AccessControl::addMemoryStorage(const String & storage_name_)
void AccessControl::addLDAPStorage(const String & storage_name_, const Poco::Util::AbstractConfiguration & config_, const String & prefix_)
{
auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, this, config_, prefix_);
auto new_storage = std::make_shared<LDAPAccessStorage>(storage_name_, *this, config_, prefix_);
addStorage(new_storage);
LOG_DEBUG(getLogger(), "Added {} access storage '{}', LDAP server name: {}", String(new_storage->getStorageType()), new_storage->getStorageName(), new_storage->getLDAPServerName());
}
@ -423,6 +395,57 @@ void AccessControl::addStoragesFromMainConfig(
}
/// Reloads entities in all storages and then delivers the accumulated change
/// notifications to subscribers. Used to implement SYSTEM RELOAD CONFIG.
void AccessControl::reload()
{
    MultipleAccessStorage::reload();
    changes_notifier->sendNotifications();
}
/// Subscription API: thin forwarders to the shared AccessChangesNotifier.
scope_guard AccessControl::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
    return changes_notifier->subscribeForChanges(type, handler);
}

scope_guard AccessControl::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
    return changes_notifier->subscribeForChanges(id, handler);
}

scope_guard AccessControl::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
    return changes_notifier->subscribeForChanges(ids, handler);
}
/// Inserts an entity via the base storage; on success flushes queued change
/// notifications so subscribers learn about the addition immediately.
std::optional<UUID> AccessControl::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{
    auto id = MultipleAccessStorage::insertImpl(entity, replace_if_exists, throw_if_exists);
    if (id)
        changes_notifier->sendNotifications();
    return id;
}

/// Removes an entity via the base storage; notifies subscribers on success.
bool AccessControl::removeImpl(const UUID & id, bool throw_if_not_exists)
{
    bool removed = MultipleAccessStorage::removeImpl(id, throw_if_not_exists);
    if (removed)
        changes_notifier->sendNotifications();
    return removed;
}

/// Updates an entity via the base storage; notifies subscribers on success.
bool AccessControl::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
    bool updated = MultipleAccessStorage::updateImpl(id, update_func, throw_if_not_exists);
    if (updated)
        changes_notifier->sendNotifications();
    return updated;
}
AccessChangesNotifier & AccessControl::getChangesNotifier()
{
return *changes_notifier;
}
UUID AccessControl::authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const
{
try

View File

@ -3,8 +3,8 @@
#include <Access/MultipleAccessStorage.h>
#include <Common/SettingsChanges.h>
#include <Common/ZooKeeper/Common.h>
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <Access/UsersConfigAccessStorage.h>
#include <memory>
@ -40,6 +40,7 @@ class SettingsProfilesCache;
class SettingsProfileElements;
class ClientInfo;
class ExternalAuthenticators;
class AccessChangesNotifier;
struct Settings;
@ -50,6 +51,7 @@ public:
AccessControl();
~AccessControl() override;
/// Initializes access storage (user directories).
void setUpFromMainConfig(const Poco::Util::AbstractConfiguration & config_, const String & config_path_,
const zkutil::GetZooKeeper & get_zookeeper_function_);
@ -74,9 +76,6 @@ public:
const String & preprocessed_dir_,
const zkutil::GetZooKeeper & get_zookeeper_function_ = {});
void reloadUsersConfigs();
void startPeriodicReloadingUsersConfigs();
void stopPeriodicReloadingUsersConfigs();
/// Loads access entities from the directory on the local disk.
/// Use that directory to keep created users/roles/etc.
void addDiskStorage(const String & directory_, bool readonly_ = false);
@ -106,6 +105,26 @@ public:
const String & config_path,
const zkutil::GetZooKeeper & get_zookeeper_function);
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
void reload() override;
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
/// Sets the default profile's name.
/// The default profile's settings are always applied before any other profile's.
void setDefaultProfileName(const String & default_profile_name);
@ -135,9 +154,6 @@ public:
void setOnClusterQueriesRequireClusterGrant(bool enable) { on_cluster_queries_require_cluster_grant = enable; }
bool doesOnClusterQueriesRequireClusterGrant() const { return on_cluster_queries_require_cluster_grant; }
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address) const;
void setExternalAuthenticatorsConfig(const Poco::Util::AbstractConfiguration & config);
std::shared_ptr<const ContextAccess> getContextAccess(
const UUID & user_id,
const std::vector<UUID> & current_roles,
@ -178,10 +194,17 @@ public:
const ExternalAuthenticators & getExternalAuthenticators() const;
/// Gets manager of notifications.
AccessChangesNotifier & getChangesNotifier();
private:
class ContextAccessCache;
class CustomSettingsPrefixes;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
std::unique_ptr<ContextAccessCache> context_access_cache;
std::unique_ptr<RoleCache> role_cache;
std::unique_ptr<RowPolicyCache> row_policy_cache;
@ -189,6 +212,7 @@ private:
std::unique_ptr<SettingsProfilesCache> settings_profiles_cache;
std::unique_ptr<ExternalAuthenticators> external_authenticators;
std::unique_ptr<CustomSettingsPrefixes> custom_settings_prefixes;
std::unique_ptr<AccessChangesNotifier> changes_notifier;
std::atomic_bool allow_plaintext_password = true;
std::atomic_bool allow_no_password = true;
std::atomic_bool users_without_row_policies_can_read_rows = false;

View File

@ -149,6 +149,21 @@ ContextAccess::ContextAccess(const AccessControl & access_control_, const Params
}
ContextAccess::~ContextAccess()
{
enabled_settings.reset();
enabled_quota.reset();
enabled_row_policies.reset();
access_with_implicit.reset();
access.reset();
roles_info.reset();
subscription_for_roles_changes.reset();
enabled_roles.reset();
subscription_for_user_change.reset();
user.reset();
}
void ContextAccess::initialize()
{
std::lock_guard lock{mutex};

View File

@ -155,6 +155,8 @@ public:
/// without any limitations. This is used for the global context.
static std::shared_ptr<const ContextAccess> getFullAccess();
~ContextAccess();
private:
friend class AccessControl;
ContextAccess() {} /// NOLINT

View File

@ -1,5 +1,6 @@
#include <Access/DiskAccessStorage.h>
#include <Access/AccessEntityIO.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/WriteHelpers.h>
#include <IO/ReadHelpers.h>
#include <IO/ReadBufferFromFile.h>
@ -164,13 +165,8 @@ namespace
}
DiskAccessStorage::DiskAccessStorage(const String & directory_path_, bool readonly_)
: DiskAccessStorage(STORAGE_TYPE, directory_path_, readonly_)
{
}
DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_)
: IAccessStorage(storage_name_)
DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{
directory_path = makeDirectoryPathCanonical(directory_path_);
readonly = readonly_;
@ -199,7 +195,15 @@ DiskAccessStorage::DiskAccessStorage(const String & storage_name_, const String
DiskAccessStorage::~DiskAccessStorage()
{
stopListsWritingThread();
writeLists();
try
{
writeLists();
}
catch (...)
{
tryLogCurrentException(__PRETTY_FUNCTION__);
}
}
@ -470,19 +474,16 @@ std::optional<String> DiskAccessStorage::readNameImpl(const UUID & id, bool thro
std::optional<UUID> DiskAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications))
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id;
return std::nullopt;
}
bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications)
bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType();
@ -514,7 +515,7 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
writeAccessEntityToDisk(id, *new_entity);
if (name_collision && replace_if_exists)
removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(it_by_name->second->id, /* throw_if_not_exists = */ false);
/// Do insertion.
auto & entry = entries_by_id[id];
@ -523,22 +524,20 @@ bool DiskAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & ne
entry.name = name;
entry.entity = new_entity;
entries_by_name[entry.name] = &entry;
prepareNotifications(id, entry, false, notifications);
changes_notifier.onEntityAdded(id, new_entity);
return true;
}
bool DiskAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications);
return removeNoLock(id, throw_if_not_exists);
}
bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications)
bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -559,25 +558,24 @@ bool DiskAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists,
deleteAccessEntityOnDisk(id);
/// Do removing.
prepareNotifications(id, entry, true, notifications);
UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(entry.name);
entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true;
}
bool DiskAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications);
return updateNoLock(id, update_func, throw_if_not_exists);
}
bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications)
bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -626,7 +624,8 @@ bool DiskAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_
entries_by_name[entry.name] = &entry;
}
prepareNotifications(id, entry, false, notifications);
changes_notifier.onEntityUpdated(id, new_entity);
return true;
}
@ -650,74 +649,4 @@ void DiskAccessStorage::deleteAccessEntityOnDisk(const UUID & id) const
throw Exception("Couldn't delete " + file_path, ErrorCodes::FILE_DOESNT_EXIST);
}
void DiskAccessStorage::prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const
{
if (!remove && !entry.entity)
return;
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.type)])
notifications.push_back({handler, id, entity});
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
scope_guard DiskAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
bool DiskAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool DiskAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -7,14 +7,15 @@
namespace DB
{
class AccessChangesNotifier;
/// Loads and saves access entities on a local disk to a specified directory.
class DiskAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "local directory";
DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_ = false);
DiskAccessStorage(const String & directory_path_, bool readonly_ = false);
DiskAccessStorage(const String & storage_name_, const String & directory_path_, bool readonly_, AccessChangesNotifier & changes_notifier_);
~DiskAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -27,8 +28,6 @@ public:
bool isReadOnly() const override { return readonly; }
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -38,8 +37,6 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
void clear();
bool readLists();
@ -50,9 +47,9 @@ private:
void listsWritingThreadFunc();
void stopListsWritingThread();
bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications);
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications);
bool insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
AccessEntityPtr readAccessEntityFromDisk(const UUID & id) const;
void writeAccessEntityToDisk(const UUID & id, const IAccessEntity & entity) const;
@ -65,11 +62,8 @@ private:
String name;
AccessEntityType type;
mutable AccessEntityPtr entity; /// may be nullptr, if the entity hasn't been loaded yet.
mutable std::list<OnChangedHandler> handlers_by_id;
};
void prepareNotifications(const UUID & id, const Entry & entry, bool remove, Notifications & notifications) const;
String directory_path;
std::atomic<bool> readonly;
std::unordered_map<UUID, Entry> entries_by_id;
@ -79,7 +73,7 @@ private:
ThreadFromGlobalPool lists_writing_thread; /// List files are written in a separate thread.
std::condition_variable lists_writing_thread_should_exit; /// Signals `lists_writing_thread` to exit.
bool lists_writing_thread_is_waiting = false;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
mutable std::mutex mutex;
};
}

View File

@ -6,7 +6,7 @@
namespace DB
{
EnabledRoles::EnabledRoles(const Params & params_) : params(params_)
EnabledRoles::EnabledRoles(const Params & params_) : params(params_), handlers(std::make_shared<Handlers>())
{
}
@ -15,42 +15,50 @@ EnabledRoles::~EnabledRoles() = default;
std::shared_ptr<const EnabledRolesInfo> EnabledRoles::getRolesInfo() const
{
std::lock_guard lock{mutex};
std::lock_guard lock{info_mutex};
return info;
}
scope_guard EnabledRoles::subscribeForChanges(const OnChangeHandler & handler) const
{
std::lock_guard lock{mutex};
handlers.push_back(handler);
auto it = std::prev(handlers.end());
std::lock_guard lock{handlers->mutex};
handlers->list.push_back(handler);
auto it = std::prev(handlers->list.end());
return [this, it]
return [handlers=handlers, it]
{
std::lock_guard lock2{mutex};
handlers.erase(it);
std::lock_guard lock2{handlers->mutex};
handlers->list.erase(it);
};
}
void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications)
void EnabledRoles::setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications)
{
std::lock_guard lock{mutex};
if (info && info_ && *info == *info_)
return;
info = info_;
std::vector<OnChangeHandler> handlers_to_notify;
boost::range::copy(handlers, std::back_inserter(handlers_to_notify));
notifications.join(scope_guard([info = info, handlers_to_notify = std::move(handlers_to_notify)]
{
for (const auto & handler : handlers_to_notify)
handler(info);
}));
std::lock_guard lock{info_mutex};
if (info && info_ && *info == *info_)
return;
info = info_;
}
if (notifications)
{
std::vector<OnChangeHandler> handlers_to_notify;
{
std::lock_guard lock{handlers->mutex};
boost::range::copy(handlers->list, std::back_inserter(handlers_to_notify));
}
notifications->join(scope_guard(
[info = info, handlers_to_notify = std::move(handlers_to_notify)]
{
for (const auto & handler : handlers_to_notify)
handler(info);
}));
}
}
}

View File

@ -4,6 +4,7 @@
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <list>
#include <memory>
#include <mutex>
#include <vector>
@ -43,12 +44,21 @@ private:
friend class RoleCache;
explicit EnabledRoles(const Params & params_);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard & notifications);
void setRolesInfo(const std::shared_ptr<const EnabledRolesInfo> & info_, scope_guard * notifications);
const Params params;
mutable std::shared_ptr<const EnabledRolesInfo> info;
mutable std::list<OnChangeHandler> handlers;
mutable std::mutex mutex;
std::shared_ptr<const EnabledRolesInfo> info;
mutable std::mutex info_mutex;
struct Handlers
{
std::list<OnChangeHandler> list;
std::mutex mutex;
};
/// shared_ptr is here for safety because EnabledRoles can be destroyed before all subscriptions are removed.
std::shared_ptr<Handlers> handlers;
};
}

View File

@ -410,34 +410,6 @@ bool IAccessStorage::updateImpl(const UUID & id, const UpdateFunc &, bool throw_
}
scope_guard IAccessStorage::subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(type, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const
{
return subscribeForChangesImpl(id, handler);
}
scope_guard IAccessStorage::subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const
{
scope_guard subscriptions;
for (const auto & id : ids)
subscriptions.join(subscribeForChangesImpl(id, handler));
return subscriptions;
}
void IAccessStorage::notify(const Notifications & notifications)
{
for (const auto & [fn, id, new_entity] : notifications)
fn(id, new_entity);
}
UUID IAccessStorage::authenticate(
const Credentials & credentials,
const Poco::Net::IPAddress & address,

View File

@ -3,7 +3,6 @@
#include <Access/IAccessEntity.h>
#include <Core/Types.h>
#include <Core/UUID.h>
#include <base/scope_guard.h>
#include <functional>
#include <optional>
#include <vector>
@ -22,7 +21,7 @@ enum class AuthenticationType;
/// Contains entities, i.e. instances of classes derived from IAccessEntity.
/// The implementations of this class MUST be thread-safe.
class IAccessStorage
class IAccessStorage : public boost::noncopyable
{
public:
explicit IAccessStorage(const String & storage_name_) : storage_name(storage_name_) {}
@ -41,6 +40,15 @@ public:
/// Returns true if this entity is readonly.
virtual bool isReadOnly(const UUID &) const { return isReadOnly(); }
/// Reloads and updates entities in this storage. This function is used to implement SYSTEM RELOAD CONFIG.
virtual void reload() {}
/// Starts periodic reloading and update of entities in this storage.
virtual void startPeriodicReloading() {}
/// Stops periodic reloading and update of entities in this storage.
virtual void stopPeriodicReloading() {}
/// Returns the identifiers of all the entities of a specified type contained in the storage.
std::vector<UUID> findAll(AccessEntityType type) const;
@ -130,23 +138,6 @@ public:
/// Updates multiple entities in the storage. Returns the list of successfully updated.
std::vector<UUID> tryUpdate(const std::vector<UUID> & ids, const UpdateFunc & update_func);
using OnChangedHandler = std::function<void(const UUID & /* id */, const AccessEntityPtr & /* new or changed entity, null if removed */)>;
/// Subscribes for all changes.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(AccessEntityType type, const OnChangedHandler & handler) const;
template <typename EntityClassT>
scope_guard subscribeForChanges(OnChangedHandler handler) const { return subscribeForChanges(EntityClassT::TYPE, handler); }
/// Subscribes for changes of a specific entry.
/// Can return nullptr if cannot subscribe (identifier not found) or if it doesn't make sense (the storage is read-only).
scope_guard subscribeForChanges(const UUID & id, const OnChangedHandler & handler) const;
scope_guard subscribeForChanges(const std::vector<UUID> & ids, const OnChangedHandler & handler) const;
virtual bool hasSubscription(AccessEntityType type) const = 0;
virtual bool hasSubscription(const UUID & id) const = 0;
/// Finds a user, check the provided credentials and returns the ID of the user if they are valid.
/// Throws an exception if no such user or credentials are invalid.
UUID authenticate(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool allow_no_password, bool allow_plaintext_password) const;
@ -160,8 +151,6 @@ protected:
virtual std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
virtual bool removeImpl(const UUID & id, bool throw_if_not_exists);
virtual bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const = 0;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const = 0;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const;
virtual bool areCredentialsValid(const User & user, const Credentials & credentials, const ExternalAuthenticators & external_authenticators) const;
virtual bool isAddressAllowed(const User & user, const Poco::Net::IPAddress & address) const;
@ -181,9 +170,6 @@ protected:
[[noreturn]] static void throwAddressNotAllowed(const Poco::Net::IPAddress & address);
[[noreturn]] static void throwInvalidCredentials();
[[noreturn]] static void throwAuthenticationTypeNotAllowed(AuthenticationType auth_type);
using Notification = std::tuple<OnChangedHandler, UUID, AccessEntityPtr>;
using Notifications = std::vector<Notification>;
static void notify(const Notifications & notifications);
private:
const String storage_name;

View File

@ -27,10 +27,10 @@ namespace ErrorCodes
}
LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
: IAccessStorage(storage_name_)
LDAPAccessStorage::LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
: IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{
setConfiguration(access_control_, config, prefix);
setConfiguration(config, prefix);
}
@ -40,7 +40,7 @@ String LDAPAccessStorage::getLDAPServerName() const
}
void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix)
void LDAPAccessStorage::setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix)
{
std::scoped_lock lock(mutex);
@ -80,7 +80,6 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
}
}
access_control = access_control_;
ldap_server_name = ldap_server_name_cfg;
role_search_params.swap(role_search_params_cfg);
common_role_names.swap(common_roles_cfg);
@ -91,7 +90,7 @@ void LDAPAccessStorage::setConfiguration(AccessControl * access_control_, const
granted_role_names.clear();
granted_role_ids.clear();
role_change_subscription = access_control->subscribeForChanges<Role>(
role_change_subscription = access_control.subscribeForChanges<Role>(
[this] (const UUID & id, const AccessEntityPtr & entity)
{
return this->processRoleChange(id, entity);
@ -215,7 +214,7 @@ void LDAPAccessStorage::assignRolesNoLock(User & user, const LDAPClient::SearchR
auto it = granted_role_ids.find(role_name);
if (it == granted_role_ids.end())
{
if (const auto role_id = access_control->find<Role>(role_name))
if (const auto role_id = access_control.find<Role>(role_name))
{
granted_role_names.insert_or_assign(*role_id, role_name);
it = granted_role_ids.insert_or_assign(role_name, *role_id).first;
@ -450,33 +449,6 @@ std::optional<String> LDAPAccessStorage::readNameImpl(const UUID & id, bool thro
}
scope_guard LDAPAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard LDAPAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::scoped_lock lock(mutex);
return memory_storage.subscribeForChanges(type, handler);
}
bool LDAPAccessStorage::hasSubscription(const UUID & id) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(id);
}
bool LDAPAccessStorage::hasSubscription(AccessEntityType type) const
{
std::scoped_lock lock(mutex);
return memory_storage.hasSubscription(type);
}
std::optional<UUID> LDAPAccessStorage::authenticateImpl(
const Credentials & credentials,
const Poco::Net::IPAddress & address,

View File

@ -32,7 +32,7 @@ class LDAPAccessStorage : public IAccessStorage
public:
static constexpr char STORAGE_TYPE[] = "ldap";
explicit LDAPAccessStorage(const String & storage_name_, AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
explicit LDAPAccessStorage(const String & storage_name_, AccessControl & access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
virtual ~LDAPAccessStorage() override = default;
String getLDAPServerName() const;
@ -42,19 +42,15 @@ public:
virtual String getStorageParamsJSON() const override;
virtual bool isReadOnly() const override { return true; }
virtual bool exists(const UUID & id) const override;
virtual bool hasSubscription(const UUID & id) const override;
virtual bool hasSubscription(AccessEntityType type) const override;
private: // IAccessStorage implementations.
virtual std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
virtual std::vector<UUID> findAllImpl(AccessEntityType type) const override;
virtual AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
virtual scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
virtual scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
virtual std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
void setConfiguration(AccessControl * access_control_, const Poco::Util::AbstractConfiguration & config, const String & prefix);
void setConfiguration(const Poco::Util::AbstractConfiguration & config, const String & prefix);
void processRoleChange(const UUID & id, const AccessEntityPtr & entity);
void applyRoleChangeNoLock(bool grant, const UUID & role_id, const String & role_name);
@ -66,7 +62,7 @@ private: // IAccessStorage implementations.
const ExternalAuthenticators & external_authenticators, LDAPClient::SearchResultsList & role_search_results) const;
mutable std::recursive_mutex mutex;
AccessControl * access_control = nullptr;
AccessControl & access_control;
String ldap_server_name;
LDAPClient::RoleSearchParamsList role_search_params;
std::set<String> common_role_names; // role name that should be granted to all users at all times

View File

@ -1,4 +1,5 @@
#include <Access/MemoryAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <base/scope_guard.h>
#include <boost/container/flat_set.hpp>
#include <boost/range/adaptor/map.hpp>
@ -7,8 +8,8 @@
namespace DB
{
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_)
: IAccessStorage(storage_name_)
MemoryAccessStorage::MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_), changes_notifier(changes_notifier_)
{
}
@ -63,19 +64,16 @@ AccessEntityPtr MemoryAccessStorage::readImpl(const UUID & id, bool throw_if_not
std::optional<UUID> MemoryAccessStorage::insertImpl(const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
UUID id = generateRandomID();
std::lock_guard lock{mutex};
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists, notifications))
if (insertNoLock(id, new_entity, replace_if_exists, throw_if_exists))
return id;
return std::nullopt;
}
bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications)
bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr & new_entity, bool replace_if_exists, bool throw_if_exists)
{
const String & name = new_entity->getName();
AccessEntityType type = new_entity->getType();
@ -103,7 +101,7 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
if (name_collision && replace_if_exists)
{
const auto & existing_entry = *(it_by_name->second);
removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(existing_entry.id, /* throw_if_not_exists = */ false);
}
/// Do insertion.
@ -111,22 +109,19 @@ bool MemoryAccessStorage::insertNoLock(const UUID & id, const AccessEntityPtr &
entry.id = id;
entry.entity = new_entity;
entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications);
changes_notifier.onEntityAdded(id, new_entity);
return true;
}
bool MemoryAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return removeNoLock(id, throw_if_not_exists, notifications);
return removeNoLock(id, throw_if_not_exists);
}
bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications)
bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -141,27 +136,25 @@ bool MemoryAccessStorage::removeNoLock(const UUID & id, bool throw_if_not_exists
const String & name = entry.entity->getName();
AccessEntityType type = entry.entity->getType();
prepareNotifications(entry, true, notifications);
/// Do removing.
UUID removed_id = id;
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
entries_by_name.erase(name);
entries_by_id.erase(it);
changes_notifier.onEntityRemoved(removed_id, type);
return true;
}
bool MemoryAccessStorage::updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
return updateNoLock(id, update_func, throw_if_not_exists, notifications);
return updateNoLock(id, update_func, throw_if_not_exists);
}
bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications)
bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists)
{
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
@ -195,7 +188,7 @@ bool MemoryAccessStorage::updateNoLock(const UUID & id, const UpdateFunc & updat
entries_by_name[new_entity->getName()] = &entry;
}
prepareNotifications(entry, false, notifications);
changes_notifier.onEntityUpdated(id, new_entity);
return true;
}
@ -212,16 +205,8 @@ void MemoryAccessStorage::setAll(const std::vector<AccessEntityPtr> & all_entiti
void MemoryAccessStorage::setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
setAllNoLock(all_entities, notifications);
}
void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications)
{
boost::container::flat_set<UUID> not_used_ids;
std::vector<UUID> conflicting_ids;
@ -256,7 +241,7 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
boost::container::flat_set<UUID> ids_to_remove = std::move(not_used_ids);
boost::range::copy(conflicting_ids, std::inserter(ids_to_remove, ids_to_remove.end()));
for (const auto & id : ids_to_remove)
removeNoLock(id, /* throw_if_not_exists = */ false, notifications);
removeNoLock(id, /* throw_if_not_exists = */ false);
/// Insert or update entities.
for (const auto & [id, entity] : all_entities)
@ -269,84 +254,14 @@ void MemoryAccessStorage::setAllNoLock(const std::vector<std::pair<UUID, AccessE
const AccessEntityPtr & changed_entity = entity;
updateNoLock(id,
[&changed_entity](const AccessEntityPtr &) { return changed_entity; },
/* throw_if_not_exists = */ true,
notifications);
/* throw_if_not_exists = */ true);
}
}
else
{
insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true, notifications);
insertNoLock(id, entity, /* replace_if_exists = */ false, /* throw_if_exists = */ true);
}
}
}
void MemoryAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard MemoryAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool MemoryAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
auto it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool MemoryAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -9,13 +9,15 @@
namespace DB
{
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in memory.
class MemoryAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "memory";
explicit MemoryAccessStorage(const String & storage_name_ = STORAGE_TYPE);
explicit MemoryAccessStorage(const String & storage_name_, AccessChangesNotifier & changes_notifier_);
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -24,8 +26,6 @@ public:
void setAll(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities);
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -34,25 +34,20 @@ private:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists);
bool removeNoLock(const UUID & id, bool throw_if_not_exists);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
bool insertNoLock(const UUID & id, const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists, Notifications & notifications);
bool removeNoLock(const UUID & id, bool throw_if_not_exists, Notifications & notifications);
bool updateNoLock(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists, Notifications & notifications);
void setAllNoLock(const std::vector<std::pair<UUID, AccessEntityPtr>> & all_entities, Notifications & notifications);
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
mutable std::recursive_mutex mutex;
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id; /// We want to search entries both by ID and by the pair of name and type.
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
};
}

View File

@ -45,7 +45,6 @@ void MultipleAccessStorage::setStorages(const std::vector<StoragePtr> & storages
std::unique_lock lock{mutex};
nested_storages = std::make_shared<const Storages>(storages);
ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
}
void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
@ -56,7 +55,6 @@ void MultipleAccessStorage::addStorage(const StoragePtr & new_storage)
auto new_storages = std::make_shared<Storages>(*nested_storages);
new_storages->push_back(new_storage);
nested_storages = new_storages;
updateSubscriptionsToNestedStorages(lock);
}
void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
@ -70,7 +68,6 @@ void MultipleAccessStorage::removeStorage(const StoragePtr & storage_to_remove)
new_storages->erase(new_storages->begin() + index);
nested_storages = new_storages;
ids_cache.reset();
updateSubscriptionsToNestedStorages(lock);
}
std::vector<StoragePtr> MultipleAccessStorage::getStorages()
@ -225,6 +222,28 @@ bool MultipleAccessStorage::isReadOnly(const UUID & id) const
}
void MultipleAccessStorage::reload()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->reload();
}
void MultipleAccessStorage::startPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->startPeriodicReloading();
}
void MultipleAccessStorage::stopPeriodicReloading()
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
storage->stopPeriodicReloading();
}
std::optional<UUID> MultipleAccessStorage::insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists)
{
std::shared_ptr<IAccessStorage> storage_for_insertion;
@ -310,145 +329,6 @@ bool MultipleAccessStorage::updateImpl(const UUID & id, const UpdateFunc & updat
}
scope_guard MultipleAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
auto storage = findStorage(id);
if (!storage)
return {};
return storage->subscribeForChanges(id, handler);
}
bool MultipleAccessStorage::hasSubscription(const UUID & id) const
{
auto storages = getStoragesInternal();
for (const auto & storage : *storages)
{
if (storage->hasSubscription(id))
return true;
}
return false;
}
scope_guard MultipleAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::unique_lock lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
if (handlers.size() == 1)
updateSubscriptionsToNestedStorages(lock);
return [this, type, handler_it]
{
std::unique_lock lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
if (handlers2.empty())
updateSubscriptionsToNestedStorages(lock2);
};
}
bool MultipleAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
/// Updates subscriptions to nested storages.
/// We need the subscriptions to the nested storages if someone has subscribed to us.
/// If any of the nested storages is changed we call our subscribers.
void MultipleAccessStorage::updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const
{
/// lock is already locked.
std::vector<std::pair<StoragePtr, scope_guard>> added_subscriptions[static_cast<size_t>(AccessEntityType::MAX)];
std::vector<scope_guard> removed_subscriptions;
for (auto type : collections::range(AccessEntityType::MAX))
{
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
if (handlers.empty())
{
/// None has subscribed to us, we need no subscriptions to the nested storages.
for (auto & subscription : subscriptions | boost::adaptors::map_values)
removed_subscriptions.push_back(std::move(subscription));
subscriptions.clear();
}
else
{
/// Someone has subscribed to us, now we need to have a subscription to each nested storage.
for (auto it = subscriptions.begin(); it != subscriptions.end();)
{
const auto & storage = it->first;
auto & subscription = it->second;
if (boost::range::find(*nested_storages, storage) == nested_storages->end())
{
removed_subscriptions.push_back(std::move(subscription));
it = subscriptions.erase(it);
}
else
++it;
}
for (const auto & storage : *nested_storages)
{
if (!subscriptions.contains(storage))
added_subscriptions[static_cast<size_t>(type)].push_back({storage, nullptr});
}
}
}
/// Unlock the mutex temporarily because it's much better to subscribe to the nested storages
/// with the mutex unlocked.
lock.unlock();
removed_subscriptions.clear();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto on_changed = [this, type](const UUID & id, const AccessEntityPtr & entity)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock2{mutex};
for (const auto & handler : handlers_by_type[static_cast<size_t>(type)])
notifications.push_back({handler, id, entity});
};
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
subscription = storage->subscribeForChanges(type, on_changed);
}
}
/// Lock the mutex again to store added subscriptions to the nested storages.
lock.lock();
for (auto type : collections::range(AccessEntityType::MAX))
{
if (!added_subscriptions[static_cast<size_t>(type)].empty())
{
auto & subscriptions = subscriptions_to_nested_storages[static_cast<size_t>(type)];
for (auto & [storage, subscription] : added_subscriptions[static_cast<size_t>(type)])
{
if (!subscriptions.contains(storage) && (boost::range::find(*nested_storages, storage) != nested_storages->end())
&& !handlers_by_type[static_cast<size_t>(type)].empty())
{
subscriptions.emplace(std::move(storage), std::move(subscription));
}
}
}
}
lock.unlock();
}
std::optional<UUID>
MultipleAccessStorage::authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address,
const ExternalAuthenticators & external_authenticators,

View File

@ -24,6 +24,10 @@ public:
bool isReadOnly() const override;
bool isReadOnly(const UUID & id) const override;
void reload() override;
void startPeriodicReloading() override;
void stopPeriodicReloading() override;
void setStorages(const std::vector<StoragePtr> & storages);
void addStorage(const StoragePtr & new_storage);
void removeStorage(const StoragePtr & storage_to_remove);
@ -37,8 +41,6 @@ public:
StoragePtr getStorage(const UUID & id);
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
protected:
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
@ -48,19 +50,14 @@ protected:
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
bool updateImpl(const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists) override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
std::optional<UUID> authenticateImpl(const Credentials & credentials, const Poco::Net::IPAddress & address, const ExternalAuthenticators & external_authenticators, bool throw_if_user_not_exists, bool allow_no_password, bool allow_plaintext_password) const override;
private:
using Storages = std::vector<StoragePtr>;
std::shared_ptr<const Storages> getStoragesInternal() const;
void updateSubscriptionsToNestedStorages(std::unique_lock<std::mutex> & lock) const;
std::shared_ptr<const Storages> nested_storages;
mutable LRUCache<UUID, Storage> ids_cache;
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::unordered_map<StoragePtr, scope_guard> subscriptions_to_nested_storages[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::mutex mutex;
};

View File

@ -1,12 +1,14 @@
#include <Access/AccessEntityIO.h>
#include <Access/MemoryAccessStorage.h>
#include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
#include <IO/ReadHelpers.h>
#include <boost/container/flat_set.hpp>
#include <Common/ZooKeeper/KeeperException.h>
#include <Common/ZooKeeper/Types.h>
#include <Common/ZooKeeper/ZooKeeper.h>
#include <Common/escapeForFileName.h>
#include <Common/setThreadName.h>
#include <base/range.h>
#include <base/sleep.h>
@ -30,11 +32,13 @@ static UUID parseUUID(const String & text)
ReplicatedAccessStorage::ReplicatedAccessStorage(
const String & storage_name_,
const String & zookeeper_path_,
zkutil::GetZooKeeper get_zookeeper_)
zkutil::GetZooKeeper get_zookeeper_,
AccessChangesNotifier & changes_notifier_)
: IAccessStorage(storage_name_)
, zookeeper_path(zookeeper_path_)
, get_zookeeper(get_zookeeper_)
, refresh_queue(std::numeric_limits<size_t>::max())
, watched_queue(std::make_shared<ConcurrentBoundedQueue<UUID>>(std::numeric_limits<size_t>::max()))
, changes_notifier(changes_notifier_)
{
if (zookeeper_path.empty())
throw Exception("ZooKeeper path must be non-empty", ErrorCodes::BAD_ARGUMENTS);
@ -45,29 +49,30 @@ ReplicatedAccessStorage::ReplicatedAccessStorage(
/// If zookeeper chroot prefix is used, path should start with '/', because chroot concatenates without it.
if (zookeeper_path.front() != '/')
zookeeper_path = "/" + zookeeper_path;
initializeZookeeper();
}
ReplicatedAccessStorage::~ReplicatedAccessStorage()
{
ReplicatedAccessStorage::shutdown();
stopWatchingThread();
}
void ReplicatedAccessStorage::startup()
void ReplicatedAccessStorage::startWatchingThread()
{
initializeZookeeper();
worker_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWorkerThread, this);
bool prev_watching_flag = watching.exchange(true);
if (!prev_watching_flag)
watching_thread = ThreadFromGlobalPool(&ReplicatedAccessStorage::runWatchingThread, this);
}
void ReplicatedAccessStorage::shutdown()
void ReplicatedAccessStorage::stopWatchingThread()
{
bool prev_stop_flag = stop_flag.exchange(true);
if (!prev_stop_flag)
bool prev_watching_flag = watching.exchange(false);
if (prev_watching_flag)
{
refresh_queue.finish();
if (worker_thread.joinable())
worker_thread.join();
watched_queue->finish();
if (watching_thread.joinable())
watching_thread.join();
}
}
@ -105,10 +110,8 @@ std::optional<UUID> ReplicatedAccessStorage::insertImpl(const AccessEntityPtr &
if (!ok)
return std::nullopt;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
return id;
}
@ -207,10 +210,8 @@ bool ReplicatedAccessStorage::removeImpl(const UUID & id, bool throw_if_not_exis
if (!ok)
return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
removeEntityNoLock(id, notifications);
removeEntityNoLock(id);
return true;
}
@ -261,10 +262,8 @@ bool ReplicatedAccessStorage::updateImpl(const UUID & id, const UpdateFunc & upd
if (!ok)
return false;
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
return true;
}
@ -328,16 +327,18 @@ bool ReplicatedAccessStorage::updateZooKeeper(const zkutil::ZooKeeperPtr & zooke
}
void ReplicatedAccessStorage::runWorkerThread()
void ReplicatedAccessStorage::runWatchingThread()
{
LOG_DEBUG(getLogger(), "Started worker thread");
while (!stop_flag)
LOG_DEBUG(getLogger(), "Started watching thread");
setThreadName("ReplACLWatch");
while (watching)
{
try
{
if (!initialized)
initializeZookeeper();
refresh();
if (refresh())
changes_notifier.sendNotifications();
}
catch (...)
{
@ -353,7 +354,7 @@ void ReplicatedAccessStorage::resetAfterError()
initialized = false;
UUID id;
while (refresh_queue.tryPop(id)) {}
while (watched_queue->tryPop(id)) {}
std::lock_guard lock{mutex};
for (const auto type : collections::range(AccessEntityType::MAX))
@ -389,21 +390,20 @@ void ReplicatedAccessStorage::createRootNodes(const zkutil::ZooKeeperPtr & zooke
}
}
void ReplicatedAccessStorage::refresh()
bool ReplicatedAccessStorage::refresh()
{
UUID id;
if (refresh_queue.tryPop(id, /* timeout_ms: */ 10000))
{
if (stop_flag)
return;
if (!watched_queue->tryPop(id, /* timeout_ms: */ 10000))
return false;
auto zookeeper = get_zookeeper();
auto zookeeper = get_zookeeper();
if (id == UUIDHelpers::Nil)
refreshEntities(zookeeper);
else
refreshEntity(zookeeper, id);
}
if (id == UUIDHelpers::Nil)
refreshEntities(zookeeper);
else
refreshEntity(zookeeper, id);
return true;
}
@ -412,9 +412,9 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
LOG_DEBUG(getLogger(), "Refreshing entities list");
const String zookeeper_uuids_path = zookeeper_path + "/uuid";
auto watch_entities_list = [this](const Coordination::WatchResponse &)
auto watch_entities_list = [watched_queue = watched_queue](const Coordination::WatchResponse &)
{
[[maybe_unused]] bool push_result = refresh_queue.push(UUIDHelpers::Nil);
[[maybe_unused]] bool push_result = watched_queue->push(UUIDHelpers::Nil);
};
Coordination::Stat stat;
const auto entity_uuid_strs = zookeeper->getChildrenWatch(zookeeper_uuids_path, &stat, watch_entities_list);
@ -424,8 +424,6 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
for (const String & entity_uuid_str : entity_uuid_strs)
entity_uuids.insert(parseUUID(entity_uuid_str));
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
std::vector<UUID> entities_to_remove;
@ -437,14 +435,14 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
entities_to_remove.push_back(entity_uuid);
}
for (const auto & entity_uuid : entities_to_remove)
removeEntityNoLock(entity_uuid, notifications);
removeEntityNoLock(entity_uuid);
/// Locally add entities that were added to ZooKeeper
for (const auto & entity_uuid : entity_uuids)
{
const auto it = entries_by_id.find(entity_uuid);
if (it == entries_by_id.end())
refreshEntityNoLock(zookeeper, entity_uuid, notifications);
refreshEntityNoLock(zookeeper, entity_uuid);
}
LOG_DEBUG(getLogger(), "Refreshing entities list finished");
@ -452,21 +450,18 @@ void ReplicatedAccessStorage::refreshEntities(const zkutil::ZooKeeperPtr & zooke
void ReplicatedAccessStorage::refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{
Notifications notifications;
SCOPE_EXIT({ notify(notifications); });
std::lock_guard lock{mutex};
refreshEntityNoLock(zookeeper, id, notifications);
refreshEntityNoLock(zookeeper, id);
}
void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications)
void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id)
{
LOG_DEBUG(getLogger(), "Refreshing entity {}", toString(id));
const auto watch_entity = [this, id](const Coordination::WatchResponse & response)
const auto watch_entity = [watched_queue = watched_queue, id](const Coordination::WatchResponse & response)
{
if (response.type == Coordination::Event::CHANGED)
[[maybe_unused]] bool push_result = refresh_queue.push(id);
[[maybe_unused]] bool push_result = watched_queue->push(id);
};
Coordination::Stat entity_stat;
const String entity_path = zookeeper_path + "/uuid/" + toString(id);
@ -475,16 +470,16 @@ void ReplicatedAccessStorage::refreshEntityNoLock(const zkutil::ZooKeeperPtr & z
if (exists)
{
const AccessEntityPtr entity = deserializeAccessEntity(entity_definition, entity_path);
setEntityNoLock(id, entity, notifications);
setEntityNoLock(id, entity);
}
else
{
removeEntityNoLock(id, notifications);
removeEntityNoLock(id);
}
}
void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications)
void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntityPtr & entity)
{
LOG_DEBUG(getLogger(), "Setting id {} to entity named {}", toString(id), entity->getName());
const AccessEntityType type = entity->getType();
@ -494,12 +489,14 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
if (auto it = entries_by_name.find(name); it != entries_by_name.end() && it->second->id != id)
{
removeEntityNoLock(it->second->id, notifications);
removeEntityNoLock(it->second->id);
}
/// If the entity already exists under a different type+name, remove old type+name
bool existed_before = false;
if (auto it = entries_by_id.find(id); it != entries_by_id.end())
{
existed_before = true;
const AccessEntityPtr & existing_entity = it->second.entity;
const AccessEntityType existing_type = existing_entity->getType();
const String & existing_name = existing_entity->getName();
@ -514,11 +511,18 @@ void ReplicatedAccessStorage::setEntityNoLock(const UUID & id, const AccessEntit
entry.id = id;
entry.entity = entity;
entries_by_name[name] = &entry;
prepareNotifications(entry, false, notifications);
if (initialized)
{
if (existed_before)
changes_notifier.onEntityUpdated(id, entity);
else
changes_notifier.onEntityAdded(id, entity);
}
}
void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications & notifications)
void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id)
{
LOG_DEBUG(getLogger(), "Removing entity with id {}", toString(id));
const auto it = entries_by_id.find(id);
@ -531,7 +535,6 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
const Entry & entry = it->second;
const AccessEntityType type = entry.entity->getType();
const String & name = entry.entity->getName();
prepareNotifications(entry, true, notifications);
auto & entries_by_name = entries_by_name_and_type[static_cast<size_t>(type)];
const auto name_it = entries_by_name.find(name);
@ -542,8 +545,11 @@ void ReplicatedAccessStorage::removeEntityNoLock(const UUID & id, Notifications
else
entries_by_name.erase(name);
UUID removed_id = id;
entries_by_id.erase(id);
LOG_DEBUG(getLogger(), "Removed entity with id {}", toString(id));
changes_notifier.onEntityRemoved(removed_id, type);
}
@ -594,73 +600,4 @@ AccessEntityPtr ReplicatedAccessStorage::readImpl(const UUID & id, bool throw_if
return entry.entity;
}
void ReplicatedAccessStorage::prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const
{
const AccessEntityPtr entity = remove ? nullptr : entry.entity;
for (const auto & handler : entry.handlers_by_id)
notifications.push_back({handler, entry.id, entity});
for (const auto & handler : handlers_by_type[static_cast<size_t>(entry.entity->getType())])
notifications.push_back({handler, entry.id, entity});
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
auto & handlers = handlers_by_type[static_cast<size_t>(type)];
handlers.push_back(handler);
auto handler_it = std::prev(handlers.end());
return [this, type, handler_it]
{
std::lock_guard lock2{mutex};
auto & handlers2 = handlers_by_type[static_cast<size_t>(type)];
handlers2.erase(handler_it);
};
}
scope_guard ReplicatedAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
std::lock_guard lock{mutex};
const auto it = entries_by_id.find(id);
if (it == entries_by_id.end())
return {};
const Entry & entry = it->second;
auto handler_it = entry.handlers_by_id.insert(entry.handlers_by_id.end(), handler);
return [this, id, handler_it]
{
std::lock_guard lock2{mutex};
auto it2 = entries_by_id.find(id);
if (it2 != entries_by_id.end())
{
const Entry & entry2 = it2->second;
entry2.handlers_by_id.erase(handler_it);
}
};
}
bool ReplicatedAccessStorage::hasSubscription(const UUID & id) const
{
std::lock_guard lock{mutex};
const auto & it = entries_by_id.find(id);
if (it != entries_by_id.end())
{
const Entry & entry = it->second;
return !entry.handlers_by_id.empty();
}
return false;
}
bool ReplicatedAccessStorage::hasSubscription(AccessEntityType type) const
{
std::lock_guard lock{mutex};
const auto & handlers = handlers_by_type[static_cast<size_t>(type)];
return !handlers.empty();
}
}

View File

@ -18,32 +18,33 @@
namespace DB
{
class AccessChangesNotifier;
/// Implementation of IAccessStorage which keeps all data in zookeeper.
class ReplicatedAccessStorage : public IAccessStorage
{
public:
static constexpr char STORAGE_TYPE[] = "replicated";
ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper);
ReplicatedAccessStorage(const String & storage_name, const String & zookeeper_path, zkutil::GetZooKeeper get_zookeeper, AccessChangesNotifier & changes_notifier_);
virtual ~ReplicatedAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
virtual void startup();
virtual void shutdown();
void startPeriodicReloading() override { startWatchingThread(); }
void stopPeriodicReloading() override { stopWatchingThread(); }
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
String zookeeper_path;
zkutil::GetZooKeeper get_zookeeper;
std::atomic<bool> initialized = false;
std::atomic<bool> stop_flag = false;
ThreadFromGlobalPool worker_thread;
ConcurrentBoundedQueue<UUID> refresh_queue;
std::atomic<bool> watching = false;
ThreadFromGlobalPool watching_thread;
std::shared_ptr<ConcurrentBoundedQueue<UUID>> watched_queue;
std::optional<UUID> insertImpl(const AccessEntityPtr & entity, bool replace_if_exists, bool throw_if_exists) override;
bool removeImpl(const UUID & id, bool throw_if_not_exists) override;
@ -53,37 +54,36 @@ private:
bool removeZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, bool throw_if_not_exists);
bool updateZooKeeper(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, const UpdateFunc & update_func, bool throw_if_not_exists);
void runWorkerThread();
void resetAfterError();
void initializeZookeeper();
void createRootNodes(const zkutil::ZooKeeperPtr & zookeeper);
void refresh();
void startWatchingThread();
void stopWatchingThread();
void runWatchingThread();
void resetAfterError();
bool refresh();
void refreshEntities(const zkutil::ZooKeeperPtr & zookeeper);
void refreshEntity(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id, Notifications & notifications);
void refreshEntityNoLock(const zkutil::ZooKeeperPtr & zookeeper, const UUID & id);
void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity, Notifications & notifications);
void removeEntityNoLock(const UUID & id, Notifications & notifications);
void setEntityNoLock(const UUID & id, const AccessEntityPtr & entity);
void removeEntityNoLock(const UUID & id);
struct Entry
{
UUID id;
AccessEntityPtr entity;
mutable std::list<OnChangedHandler> handlers_by_id;
};
std::optional<UUID> findImpl(AccessEntityType type, const String & name) const override;
std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
void prepareNotifications(const Entry & entry, bool remove, Notifications & notifications) const;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
mutable std::mutex mutex;
std::unordered_map<UUID, Entry> entries_by_id;
std::unordered_map<String, Entry *> entries_by_name_and_type[static_cast<size_t>(AccessEntityType::MAX)];
mutable std::list<OnChangedHandler> handlers_by_type[static_cast<size_t>(AccessEntityType::MAX)];
AccessChangesNotifier & changes_notifier;
};
}

View File

@ -66,9 +66,6 @@ RoleCache::~RoleCache() = default;
std::shared_ptr<const EnabledRoles>
RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UUID> & roles_with_admin_option)
{
/// Declared before `lock` to send notifications after the mutex will be unlocked.
scope_guard notifications;
std::lock_guard lock{mutex};
EnabledRoles::Params params;
params.current_roles.insert(roles.begin(), roles.end());
@ -83,13 +80,13 @@ RoleCache::getEnabledRoles(const std::vector<UUID> & roles, const std::vector<UU
}
auto res = std::shared_ptr<EnabledRoles>(new EnabledRoles(params));
collectEnabledRoles(*res, notifications);
collectEnabledRoles(*res, nullptr);
enabled_roles.emplace(std::move(params), res);
return res;
}
void RoleCache::collectEnabledRoles(scope_guard & notifications)
void RoleCache::collectEnabledRoles(scope_guard * notifications)
{
/// `mutex` is already locked.
@ -107,7 +104,7 @@ void RoleCache::collectEnabledRoles(scope_guard & notifications)
}
void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications)
void RoleCache::collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications)
{
/// `mutex` is already locked.
@ -170,7 +167,7 @@ void RoleCache::roleChanged(const UUID & role_id, const RolePtr & changed_role)
return;
role_from_cache->first = changed_role;
cache.update(role_id, role_from_cache);
collectEnabledRoles(notifications);
collectEnabledRoles(&notifications);
}
@ -181,7 +178,7 @@ void RoleCache::roleRemoved(const UUID & role_id)
std::lock_guard lock{mutex};
cache.remove(role_id);
collectEnabledRoles(notifications);
collectEnabledRoles(&notifications);
}
}

View File

@ -1,7 +1,7 @@
#pragma once
#include <Access/EnabledRoles.h>
#include <Poco/ExpireCache.h>
#include <Poco/AccessExpireCache.h>
#include <boost/container/flat_set.hpp>
#include <map>
#include <mutex>
@ -24,14 +24,14 @@ public:
const std::vector<UUID> & current_roles_with_admin_option);
private:
void collectEnabledRoles(scope_guard & notifications);
void collectEnabledRoles(EnabledRoles & enabled, scope_guard & notifications);
void collectEnabledRoles(scope_guard * notifications);
void collectEnabledRoles(EnabledRoles & enabled, scope_guard * notifications);
RolePtr getRole(const UUID & role_id);
void roleChanged(const UUID & role_id, const RolePtr & changed_role);
void roleRemoved(const UUID & role_id);
const AccessControl & access_control;
Poco::ExpireCache<UUID, std::pair<RolePtr, scope_guard>> cache;
Poco::AccessExpireCache<UUID, std::pair<RolePtr, scope_guard>> cache;
std::map<EnabledRoles::Params, std::weak_ptr<EnabledRoles>> enabled_roles;
mutable std::mutex mutex;
};

View File

@ -4,6 +4,7 @@
#include <Access/User.h>
#include <Access/SettingsProfile.h>
#include <Access/AccessControl.h>
#include <Access/AccessChangesNotifier.h>
#include <Dictionaries/IDictionary.h>
#include <Common/Config/ConfigReloader.h>
#include <Common/StringUtils/StringUtils.h>
@ -14,9 +15,6 @@
#include <Poco/JSON/JSON.h>
#include <Poco/JSON/Object.h>
#include <Poco/JSON/Stringifier.h>
#include <Common/logger_useful.h>
#include <boost/range/algorithm/copy.hpp>
#include <boost/range/adaptor/map.hpp>
#include <cstring>
#include <filesystem>
#include <base/FnTraits.h>
@ -525,8 +523,8 @@ namespace
}
}
UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_)
: IAccessStorage(storage_name_), access_control(access_control_)
UsersConfigAccessStorage::UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_)
: IAccessStorage(storage_name_), access_control(access_control_), memory_storage(storage_name_, access_control.getChangesNotifier())
{
}
@ -605,9 +603,9 @@ void UsersConfigAccessStorage::load(
std::make_shared<Poco::Event>(),
[&](Poco::AutoPtr<Poco::Util::AbstractConfiguration> new_config, bool /*initial_loading*/)
{
parseFromConfig(*new_config);
Settings::checkNoSettingNamesAtTopLevel(*new_config, users_config_path);
parseFromConfig(*new_config);
access_control.getChangesNotifier().sendNotifications();
},
/* already_loaded = */ false);
}
@ -662,27 +660,4 @@ std::optional<String> UsersConfigAccessStorage::readNameImpl(const UUID & id, bo
return memory_storage.readName(id, throw_if_not_exists);
}
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(id, handler);
}
scope_guard UsersConfigAccessStorage::subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const
{
return memory_storage.subscribeForChanges(type, handler);
}
bool UsersConfigAccessStorage::hasSubscription(const UUID & id) const
{
return memory_storage.hasSubscription(id);
}
bool UsersConfigAccessStorage::hasSubscription(AccessEntityType type) const
{
return memory_storage.hasSubscription(type);
}
}

View File

@ -22,7 +22,7 @@ public:
static constexpr char STORAGE_TYPE[] = "users.xml";
UsersConfigAccessStorage(const String & storage_name_, const AccessControl & access_control_);
UsersConfigAccessStorage(const String & storage_name_, AccessControl & access_control_);
~UsersConfigAccessStorage() override;
const char * getStorageType() const override { return STORAGE_TYPE; }
@ -37,13 +37,12 @@ public:
const String & include_from_path = {},
const String & preprocessed_dir = {},
const zkutil::GetZooKeeper & get_zookeeper_function = {});
void reload();
void startPeriodicReloading();
void stopPeriodicReloading();
void reload() override;
void startPeriodicReloading() override;
void stopPeriodicReloading() override;
bool exists(const UUID & id) const override;
bool hasSubscription(const UUID & id) const override;
bool hasSubscription(AccessEntityType type) const override;
private:
void parseFromConfig(const Poco::Util::AbstractConfiguration & config);
@ -51,10 +50,8 @@ private:
std::vector<UUID> findAllImpl(AccessEntityType type) const override;
AccessEntityPtr readImpl(const UUID & id, bool throw_if_not_exists) const override;
std::optional<String> readNameImpl(const UUID & id, bool throw_if_not_exists) const override;
scope_guard subscribeForChangesImpl(const UUID & id, const OnChangedHandler & handler) const override;
scope_guard subscribeForChangesImpl(AccessEntityType type, const OnChangedHandler & handler) const override;
const AccessControl & access_control;
AccessControl & access_control;
MemoryAccessStorage memory_storage;
String path;
std::unique_ptr<ConfigReloader> config_reloader;

View File

@ -1,5 +1,6 @@
#include <gtest/gtest.h>
#include <Access/ReplicatedAccessStorage.h>
#include <Access/AccessChangesNotifier.h>
using namespace DB;
@ -12,18 +13,6 @@ namespace ErrorCodes
}
TEST(ReplicatedAccessStorage, ShutdownWithoutStartup)
{
auto get_zk = []()
{
return std::shared_ptr<zkutil::ZooKeeper>();
};
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk);
storage.shutdown();
}
TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
{
auto get_zk = []()
@ -31,16 +20,16 @@ TEST(ReplicatedAccessStorage, ShutdownWithFailedStartup)
return std::shared_ptr<zkutil::ZooKeeper>();
};
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk);
AccessChangesNotifier changes_notifier;
try
{
storage.startup();
auto storage = ReplicatedAccessStorage("replicated", "/clickhouse/access", get_zk, changes_notifier);
}
catch (Exception & e)
{
if (e.code() != ErrorCodes::NO_ZOOKEEPER)
throw;
}
storage.shutdown();
}

View File

@ -59,11 +59,11 @@ struct AggregateFunctionSumData
}
/// Vectorized version
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(addManyImpl,
MULTITARGET_FH(
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(
template <typename Value>
void NO_SANITIZE_UNDEFINED NO_INLINE
), /*addManyImpl*/ MULTITARGET_FB((const Value * __restrict ptr, size_t start, size_t end) /// NOLINT
), addManyImpl, MULTITARGET_FUNCTION_BODY((const Value * __restrict ptr, size_t start, size_t end) /// NOLINT
{
ptr += start;
size_t count = end - start;
@ -122,11 +122,11 @@ struct AggregateFunctionSumData
addManyImpl(ptr, start, end);
}
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(addManyConditionalInternalImpl,
MULTITARGET_FH(
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(
template <typename Value, bool add_if_zero>
void NO_SANITIZE_UNDEFINED NO_INLINE
), /*addManyConditionalInternalImpl*/ MULTITARGET_FB((const Value * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end) /// NOLINT
), addManyConditionalInternalImpl, MULTITARGET_FUNCTION_BODY((const Value * __restrict ptr, const UInt8 * __restrict condition_map, size_t start, size_t end) /// NOLINT
{
ptr += start;
size_t count = end - start;

View File

@ -81,7 +81,8 @@ void IColumn::compareImpl(const Derived & rhs, size_t rhs_row_num,
if constexpr (use_indexes)
{
num_indexes = row_indexes->size();
next_index = indexes = row_indexes->data();
indexes = row_indexes->data();
next_index = indexes;
}
compare_results.resize(num_rows);
@ -100,15 +101,9 @@ void IColumn::compareImpl(const Derived & rhs, size_t rhs_row_num,
if constexpr (use_indexes)
row = indexes[i];
int res = compareAt(row, rhs_row_num, rhs, nan_direction_hint);
/// We need to convert int to Int8. Sometimes comparison return values which do not fit in one byte.
if (res < 0)
compare_results[row] = -1;
else if (res > 0)
compare_results[row] = 1;
else
compare_results[row] = 0;
int res = static_cast<const Derived *>(this)->compareAt(row, rhs_row_num, rhs, nan_direction_hint);
assert(res == 1 || res == -1 || res == 0);
compare_results[row] = static_cast<Int8>(res);
if constexpr (reversed)
compare_results[row] = -compare_results[row];
@ -124,7 +119,10 @@ void IColumn::compareImpl(const Derived & rhs, size_t rhs_row_num,
}
if constexpr (use_indexes)
row_indexes->resize(next_index - row_indexes->data());
{
size_t equal_row_indexes_size = next_index - row_indexes->data();
row_indexes->resize(equal_row_indexes_size);
}
}
template <typename Derived>

View File

@ -627,6 +627,7 @@
M(656, MEILISEARCH_EXCEPTION) \
M(657, UNSUPPORTED_MEILISEARCH_TYPE) \
M(658, MEILISEARCH_MISSING_SOME_COLUMNS) \
M(659, UNKNOWN_STATUS_OF_TRANSACTION) \
\
M(999, KEEPER_EXCEPTION) \
M(1000, POCO_EXCEPTION) \

View File

@ -35,6 +35,18 @@ namespace ErrorCodes
extern const int CANNOT_MREMAP;
}
void abortOnFailedAssertion(const String & description)
{
LOG_FATAL(&Poco::Logger::root(), "Logical error: '{}'.", description);
/// This is to suppress -Wmissing-noreturn
volatile bool always_false = false;
if (always_false)
return;
abort();
}
/// - Aborts the process if error code is LOGICAL_ERROR.
/// - Increments error codes statistics.
void handle_error_code([[maybe_unused]] const std::string & msg, int code, bool remote, const Exception::FramePointers & trace)
@ -44,8 +56,7 @@ void handle_error_code([[maybe_unused]] const std::string & msg, int code, bool
#ifdef ABORT_ON_LOGICAL_ERROR
if (code == ErrorCodes::LOGICAL_ERROR)
{
LOG_FATAL(&Poco::Logger::root(), "Logical error: '{}'.", msg);
abort();
abortOnFailedAssertion(msg);
}
#endif

View File

@ -12,16 +12,14 @@
#include <fmt/format.h>
#if !defined(NDEBUG) || defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) || defined(MEMORY_SANITIZER) || defined(UNDEFINED_BEHAVIOR_SANITIZER)
#define ABORT_ON_LOGICAL_ERROR
#endif
namespace Poco { class Logger; }
namespace DB
{
void abortOnFailedAssertion(const String & description);
class Exception : public Poco::Exception
{
public:

View File

@ -107,3 +107,4 @@ private:
};
using OptimizedRegularExpression = OptimizedRegularExpressionImpl<true>;
using OptimizedRegularExpressionSingleThreaded = OptimizedRegularExpressionImpl<false>;

View File

@ -233,8 +233,8 @@ DECLARE_AVX512F_SPECIFIC_CODE(
* class TestClass
* {
* public:
* MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(testFunctionImpl,
* MULTITARGET_FH(int), /\*testFunction*\/ MULTITARGET_FB((int value)
* MULTITARGET_FUNCTION_AVX2_SSE42(
* MULTITARGET_FUNCTION_HEADER(int), testFunctionImpl, MULTITARGET_FUNCTION_BODY((int value)
* {
* return value;
* })
@ -259,15 +259,15 @@ DECLARE_AVX512F_SPECIFIC_CODE(
*/
/// Function header
#define MULTITARGET_FH(...) __VA_ARGS__
#define MULTITARGET_FUNCTION_HEADER(...) __VA_ARGS__
/// Function body
#define MULTITARGET_FB(...) __VA_ARGS__
#define MULTITARGET_FUNCTION_BODY(...) __VA_ARGS__
#if ENABLE_MULTITARGET_CODE && defined(__GNUC__) && defined(__x86_64__)
/// NOLINTNEXTLINE
#define MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(name, FUNCTION_HEADER, FUNCTION_BODY) \
#define MULTITARGET_FUNCTION_AVX2_SSE42(FUNCTION_HEADER, name, FUNCTION_BODY) \
FUNCTION_HEADER \
\
AVX2_FUNCTION_SPECIFIC_ATTRIBUTE \
@ -288,7 +288,7 @@ DECLARE_AVX512F_SPECIFIC_CODE(
#else
/// NOLINTNEXTLINE
#define MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(name, FUNCTION_HEADER, FUNCTION_BODY) \
#define MULTITARGET_FUNCTION_AVX2_SSE42(FUNCTION_HEADER, name, FUNCTION_BODY) \
FUNCTION_HEADER \
\
name \

View File

@ -22,15 +22,22 @@ void CompressedWriteBuffer::nextImpl()
if (!offset())
return;
UInt32 compressed_size = 0;
size_t decompressed_size = offset();
UInt32 compressed_reserve_size = codec->getCompressedReserveSize(decompressed_size);
if (out.available() > compressed_reserve_size + CHECKSUM_SIZE)
/** During compression we need buffer with capacity >= compressed_reserve_size + CHECKSUM_SIZE.
*
* If output buffer has necessary capacity, we can compress data directly in output buffer.
* Then we can write checksum at the output buffer begin.
*
* If output buffer does not have necessary capacity. Compress data in temporary buffer.
* Then we can write checksum and temporary buffer in output buffer.
*/
if (out.available() >= compressed_reserve_size + CHECKSUM_SIZE)
{
char * out_checksum_ptr = out.position();
char * out_compressed_ptr = out.position() + CHECKSUM_SIZE;
compressed_size = codec->compress(working_buffer.begin(), decompressed_size, out_compressed_ptr);
UInt32 compressed_size = codec->compress(working_buffer.begin(), decompressed_size, out_compressed_ptr);
CityHash_v1_0_2::uint128 checksum = CityHash_v1_0_2::CityHash128(out_compressed_ptr, compressed_size);
memcpy(out_checksum_ptr, reinterpret_cast<const char *>(&checksum), CHECKSUM_SIZE);
@ -39,7 +46,7 @@ void CompressedWriteBuffer::nextImpl()
else
{
compressed_buffer.resize(compressed_reserve_size);
compressed_size = codec->compress(working_buffer.begin(), decompressed_size, compressed_buffer.data());
UInt32 compressed_size = codec->compress(working_buffer.begin(), decompressed_size, compressed_buffer.data());
CityHash_v1_0_2::uint128 checksum = CityHash_v1_0_2::CityHash128(compressed_buffer.data(), compressed_size);
out.write(reinterpret_cast<const char *>(&checksum), CHECKSUM_SIZE);

View File

@ -466,20 +466,23 @@ nuraft::cb_func::ReturnCode KeeperServer::callbackFunc(nuraft::cb_func::Type typ
{
if (is_recovering)
{
const auto finish_recovering = [&]
{
auto new_params = raft_instance->get_current_params();
new_params.custom_commit_quorum_size_ = 0;
new_params.custom_election_quorum_size_ = 0;
raft_instance->update_params(new_params);
LOG_INFO(log, "Recovery is done. You can continue using cluster normally.");
is_recovering = false;
};
switch (type)
{
case nuraft::cb_func::HeartBeat:
{
if (raft_instance->isClusterHealthy())
{
auto new_params = raft_instance->get_current_params();
new_params.custom_commit_quorum_size_ = 0;
new_params.custom_election_quorum_size_ = 0;
raft_instance->update_params(new_params);
LOG_INFO(log, "Recovery is done. You can continue using cluster normally.");
is_recovering = false;
}
finish_recovering();
break;
}
case nuraft::cb_func::NewConfig:
@ -490,8 +493,19 @@ nuraft::cb_func::ReturnCode KeeperServer::callbackFunc(nuraft::cb_func::Type typ
// Because we manually set the config to commit
// we need to call the reconfigure also
uint64_t log_idx = *static_cast<uint64_t *>(param->ctx);
if (log_idx == state_manager->load_config()->get_log_idx())
raft_instance->forceReconfigure(state_manager->load_config());
auto config = state_manager->load_config();
if (log_idx == config->get_log_idx())
{
raft_instance->forceReconfigure(config);
// Single node cluster doesn't need to wait for any other nodes
// so we can finish recovering immediately after applying
// new configuration
if (config->get_servers().size() == 1)
finish_recovering();
}
break;
}
case nuraft::cb_func::ProcessReq:

View File

@ -601,6 +601,15 @@ NamesAndTypesList Block::getNamesAndTypesList() const
return res;
}
NamesAndTypes Block::getNamesAndTypes() const
{
NamesAndTypes res;
for (const auto & elem : data)
res.emplace_back(elem.name, elem.type);
return res;
}
Names Block::getNames() const
{
@ -756,6 +765,17 @@ void Block::updateHash(SipHash & hash) const
col.column->updateHashWithValue(row_no, hash);
}
Serializations Block::getSerializations() const
{
Serializations res;
res.reserve(data.size());
for (const auto & column : data)
res.push_back(column.type->getDefaultSerialization());
return res;
}
void convertToFullIfSparse(Block & block)
{
for (auto & column : block)

View File

@ -89,11 +89,14 @@ public:
const ColumnsWithTypeAndName & getColumnsWithTypeAndName() const;
NamesAndTypesList getNamesAndTypesList() const;
NamesAndTypes getNamesAndTypes() const;
Names getNames() const;
DataTypes getDataTypes() const;
Names getDataTypeNames() const;
std::unordered_map<String, size_t> getNamesToIndexesMap() const;
Serializations getSerializations() const;
/// Returns number of rows from first column in block, not equal to nullptr. If no columns, returns 0.
size_t rows() const;

View File

@ -65,6 +65,13 @@ void BlockMissingValues::setBit(size_t column_idx, size_t row_idx)
mask[row_idx] = true;
}
void BlockMissingValues::setBits(size_t column_idx, size_t rows)
{
RowsBitMask & mask = rows_mask_by_column_id[column_idx];
mask.resize(rows);
std::fill(mask.begin(), mask.end(), true);
}
const BlockMissingValues::RowsBitMask & BlockMissingValues::getDefaultsBitmask(size_t column_idx) const
{
static RowsBitMask none;

View File

@ -56,7 +56,10 @@ public:
const RowsBitMask & getDefaultsBitmask(size_t column_idx) const;
/// Check that we have to replace default value at least in one of columns
bool hasDefaultBits(size_t column_idx) const;
/// Set bit for a specified row in a single column.
void setBit(size_t column_idx, size_t row_idx);
/// Set bits for all rows in a single column.
void setBits(size_t column_idx, size_t rows);
bool empty() const { return rows_mask_by_column_id.empty(); }
size_t size() const { return rows_mask_by_column_id.size(); }
void clear() { rows_mask_by_column_id.clear(); }

View File

@ -156,7 +156,7 @@ inline DecimalComponents<DecimalType> splitWithScaleMultiplier(
using T = typename DecimalType::NativeType;
const auto whole = decimal.value / scale_multiplier;
auto fractional = decimal.value % scale_multiplier;
if (fractional < T(0))
if (whole && fractional < T(0))
fractional *= T(-1);
return {whole, fractional};
@ -199,7 +199,7 @@ inline typename DecimalType::NativeType getFractionalPartWithScaleMultiplier(
/// Anycase we make modulo before compare to make scale_multiplier > 1 unaffected.
T result = decimal.value % scale_multiplier;
if constexpr (!keep_sign)
if (result < T(0))
if (decimal.value / scale_multiplier && result < T(0))
result = -result;
return result;

View File

@ -33,8 +33,6 @@
#define DEFAULT_TEMPORARY_LIVE_VIEW_TIMEOUT_SEC 5
#define DEFAULT_PERIODIC_LIVE_VIEW_REFRESH_SEC 60
#define DEFAULT_WINDOW_VIEW_CLEAN_INTERVAL_SEC 5
#define DEFAULT_WINDOW_VIEW_HEARTBEAT_INTERVAL_SEC 15
#define SHOW_CHARS_ON_SYNTAX_ERROR ptrdiff_t(160)
#define DBMS_CONNECTION_POOL_WITH_FAILOVER_DEFAULT_MAX_TRIES 3
/// each period reduces the error counter by 2 times

View File

@ -435,8 +435,9 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(Seconds, live_view_heartbeat_interval, 15, "The heartbeat interval in seconds to indicate live query is alive.", 0) \
M(UInt64, max_live_view_insert_blocks_before_refresh, 64, "Limit maximum number of inserted blocks after which mergeable blocks are dropped and query is re-executed.", 0) \
M(Bool, allow_experimental_window_view, false, "Enable WINDOW VIEW. Not mature enough.", 0) \
M(Seconds, window_view_clean_interval, DEFAULT_WINDOW_VIEW_CLEAN_INTERVAL_SEC, "The clean interval of window view in seconds to free outdated data.", 0) \
M(Seconds, window_view_heartbeat_interval, DEFAULT_WINDOW_VIEW_HEARTBEAT_INTERVAL_SEC, "The heartbeat interval in seconds to indicate watch query is alive.", 0) \
M(Seconds, window_view_clean_interval, 60, "The clean interval of window view in seconds to free outdated data.", 0) \
M(Seconds, window_view_heartbeat_interval, 15, "The heartbeat interval in seconds to indicate watch query is alive.", 0) \
M(Seconds, wait_for_window_view_fire_signal_timeout, 10, "Timeout for waiting for window view fire signal in event time processing", 0) \
M(UInt64, min_free_disk_space_for_temporary_data, 0, "The minimum disk space to keep while writing temporary data used in external sorting and aggregation.", 0) \
\
M(DefaultDatabaseEngine, default_database_engine, DefaultDatabaseEngine::Atomic, "Default database engine.", 0) \
@ -591,6 +592,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(String, insert_deduplication_token, "", "If not empty, used for duplicate detection instead of data digest", 0) \
M(Bool, count_distinct_optimization, false, "Rewrite count distinct to subquery of group by", 0) \
M(Bool, throw_on_unsupported_query_inside_transaction, true, "Throw exception if unsupported query is used inside transaction", 0) \
M(TransactionsWaitCSNMode, wait_changes_become_visible_after_commit_mode, TransactionsWaitCSNMode::WAIT_UNKNOWN, "Wait for committed changes to become actually visible in the latest snapshot", 0) \
M(Bool, throw_if_no_data_to_insert, true, "Enables or disables empty INSERTs, enabled by default", 0) \
M(Bool, compatibility_ignore_auto_increment_in_create_table, false, "Ignore AUTO_INCREMENT keyword in column declaration if true, otherwise return error. It simplifies migration from MySQL", 0) \
// End of COMMON_SETTINGS
@ -636,7 +638,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(Bool, output_format_csv_crlf_end_of_line, false, "If it is set true, end of line in CSV format will be \\r\\n instead of \\n.", 0) \
M(Bool, input_format_csv_enum_as_number, false, "Treat inserted enum values in CSV formats as enum indices \\N", 0) \
M(Bool, input_format_csv_arrays_as_nested_csv, false, R"(When reading Array from CSV, expect that its elements were serialized in nested CSV and then put into string. Example: "[""Hello"", ""world"", ""42"""" TV""]". Braces around array can be omitted.)", 0) \
M(Bool, input_format_skip_unknown_fields, false, "Skip columns with unknown names from input data (it works for JSONEachRow, -WithNames, -WithNamesAndTypes and TSKV formats).", 0) \
M(Bool, input_format_skip_unknown_fields, true, "Skip columns with unknown names from input data (it works for JSONEachRow, -WithNames, -WithNamesAndTypes and TSKV formats).", 0) \
M(Bool, input_format_with_names_use_header, true, "For -WithNames input formats this controls whether format parser is to assume that column data appear in the input exactly as they are specified in the header.", 0) \
M(Bool, input_format_with_types_use_header, true, "For -WithNamesAndTypes input formats this controls whether format parser should check if data types from the input match data types from the header.", 0) \
M(Bool, input_format_import_nested_json, false, "Map nested JSON data to nested tables (it works for JSONEachRow format).", 0) \
@ -699,6 +701,7 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(Bool, output_format_pretty_color, true, "Use ANSI escape sequences to paint colors in Pretty formats", 0) \
M(String, output_format_pretty_grid_charset, "UTF-8", "Charset for printing grid borders. Available charsets: ASCII, UTF-8 (default one).", 0) \
M(UInt64, output_format_parquet_row_group_size, 1000000, "Row group size in rows.", 0) \
M(Bool, output_format_parquet_string_as_string, false, "Use Parquet String type instead of Binary for String columns.", 0) \
M(String, output_format_avro_codec, "", "Compression codec used for output. Possible values: 'null', 'deflate', 'snappy'.", 0) \
M(UInt64, output_format_avro_sync_interval, 16 * 1024, "Sync interval in bytes.", 0) \
M(String, output_format_avro_string_column_pattern, "", "For Avro format: regexp of String columns to select as AVRO string.", 0) \
@ -736,6 +739,9 @@ static constexpr UInt64 operator""_GiB(unsigned long long value)
M(UInt64, cross_to_inner_join_rewrite, 1, "Use inner join instead of comma/cross join if possible. Possible values: 0 - no rewrite, 1 - apply if possible, 2 - force rewrite all cross joins", 0) \
\
M(Bool, output_format_arrow_low_cardinality_as_dictionary, false, "Enable output LowCardinality type as Dictionary Arrow type", 0) \
M(Bool, output_format_arrow_string_as_string, false, "Use Arrow String type instead of Binary for String columns", 0) \
\
M(Bool, output_format_orc_string_as_string, false, "Use ORC String type instead of Binary for String columns", 0) \
\
M(EnumComparingMode, format_capn_proto_enum_comparising_mode, FormatSettings::EnumComparingMode::BY_VALUES, "How to map ClickHouse Enum and CapnProto Enum", 0) \
\

View File

@ -131,6 +131,11 @@ IMPLEMENT_SETTING_ENUM(ShortCircuitFunctionEvaluation, ErrorCodes::BAD_ARGUMENTS
{"force_enable", ShortCircuitFunctionEvaluation::FORCE_ENABLE},
{"disable", ShortCircuitFunctionEvaluation::DISABLE}})
IMPLEMENT_SETTING_ENUM(TransactionsWaitCSNMode, ErrorCodes::BAD_ARGUMENTS,
{{"async", TransactionsWaitCSNMode::ASYNC},
{"wait", TransactionsWaitCSNMode::WAIT},
{"wait_unknown", TransactionsWaitCSNMode::WAIT_UNKNOWN}})
IMPLEMENT_SETTING_ENUM(EnumComparingMode, ErrorCodes::BAD_ARGUMENTS,
{{"by_names", FormatSettings::EnumComparingMode::BY_NAMES},
{"by_values", FormatSettings::EnumComparingMode::BY_VALUES},

View File

@ -183,6 +183,15 @@ enum class ShortCircuitFunctionEvaluation
DECLARE_SETTING_ENUM(ShortCircuitFunctionEvaluation)
enum class TransactionsWaitCSNMode
{
ASYNC,
WAIT,
WAIT_UNKNOWN,
};
DECLARE_SETTING_ENUM(TransactionsWaitCSNMode)
DECLARE_SETTING_ENUM_WITH_RENAME(EnumComparingMode, FormatSettings::EnumComparingMode)
DECLARE_SETTING_ENUM_WITH_RENAME(EscapingRule, FormatSettings::EscapingRule)

View File

@ -176,7 +176,7 @@ INSTANTIATE_TEST_SUITE_P(Basic,
}
},
{
"When scale is not 0 and whole part is 0.",
"For positive Decimal value, with scale not 0, and whole part is 0.",
123,
3,
{
@ -184,6 +184,16 @@ INSTANTIATE_TEST_SUITE_P(Basic,
123
}
},
{
"For negative Decimal value, with scale not 0, and whole part is 0.",
-123,
3,
{
0,
-123
}
},
{
"For negative Decimal value whole part is negative, fractional is non-negative.",
-1234567'89,
@ -216,6 +226,24 @@ INSTANTIATE_TEST_SUITE_P(Basic,
187618332,
123
}
},
{
"Negative timestamp 1969-12-31 23:59:59.123 UTC",
DateTime64(-877),
3,
{
0,
-877
}
},
{
"Positive timestamp 1970-01-01 00:00:00.123 UTC",
DateTime64(123),
3,
{
0,
123
}
}
})
);

View File

@ -1,5 +1,5 @@
#include <Formats/EscapingRuleUtils.h>
#include <Formats/JSONEachRowUtils.h>
#include <Formats/JSONUtils.h>
#include <Formats/ReadSchemaUtils.h>
#include <DataTypes/Serializations/SerializationNullable.h>
#include <DataTypes/DataTypeString.h>
@ -71,7 +71,7 @@ String escapingRuleToString(FormatSettings::EscapingRule escaping_rule)
void skipFieldByEscapingRule(ReadBuffer & buf, FormatSettings::EscapingRule escaping_rule, const FormatSettings & format_settings)
{
String tmp;
NullOutput out;
constexpr const char * field_name = "<SKIPPED COLUMN>";
constexpr size_t field_name_len = 16;
switch (escaping_rule)
@ -80,19 +80,19 @@ void skipFieldByEscapingRule(ReadBuffer & buf, FormatSettings::EscapingRule esca
/// Empty field, just skip spaces
break;
case FormatSettings::EscapingRule::Escaped:
readEscapedString(tmp, buf);
readEscapedStringInto(out, buf);
break;
case FormatSettings::EscapingRule::Quoted:
readQuotedFieldIntoString(tmp, buf);
readQuotedFieldInto(out, buf);
break;
case FormatSettings::EscapingRule::CSV:
readCSVString(tmp, buf, format_settings.csv);
readCSVStringInto(out, buf, format_settings.csv);
break;
case FormatSettings::EscapingRule::JSON:
skipJSONField(buf, StringRef(field_name, field_name_len));
break;
case FormatSettings::EscapingRule::Raw:
readString(tmp, buf);
readStringInto(out, buf);
break;
default:
__builtin_unreachable();
@ -219,13 +219,13 @@ String readByEscapingRule(ReadBuffer & buf, FormatSettings::EscapingRule escapin
if constexpr (read_string)
readQuotedString(result, buf);
else
readQuotedFieldIntoString(result, buf);
readQuotedField(result, buf);
break;
case FormatSettings::EscapingRule::JSON:
if constexpr (read_string)
readJSONString(result, buf);
else
readJSONFieldIntoString(result, buf);
readJSONField(result, buf);
break;
case FormatSettings::EscapingRule::Raw:
readString(result, buf);
@ -452,7 +452,7 @@ DataTypePtr determineDataTypeByEscapingRule(const String & field, const FormatSe
return buf.eof() ? type : nullptr;
}
case FormatSettings::EscapingRule::JSON:
return getDataTypeFromJSONField(field);
return JSONUtils::getDataTypeFromField(field);
case FormatSettings::EscapingRule::CSV:
{
if (!format_settings.csv.input_format_use_best_effort_in_schema_inference)

View File

@ -99,6 +99,7 @@ FormatSettings getFormatSettings(ContextPtr context, const Settings & settings)
format_settings.parquet.case_insensitive_column_matching = settings.input_format_parquet_case_insensitive_column_matching;
format_settings.parquet.allow_missing_columns = settings.input_format_parquet_allow_missing_columns;
format_settings.parquet.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_parquet_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.parquet.output_string_as_string = settings.output_format_parquet_string_as_string;
format_settings.pretty.charset = settings.output_format_pretty_grid_charset.toString() == "ASCII" ? FormatSettings::Pretty::Charset::ASCII : FormatSettings::Pretty::Charset::UTF8;
format_settings.pretty.color = settings.output_format_pretty_color;
format_settings.pretty.max_column_pad_width = settings.output_format_pretty_max_column_pad_width;
@ -132,17 +133,19 @@ FormatSettings getFormatSettings(ContextPtr context, const Settings & settings)
format_settings.arrow.import_nested = settings.input_format_arrow_import_nested;
format_settings.arrow.allow_missing_columns = settings.input_format_arrow_allow_missing_columns;
format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.arrow.case_insensitive_column_matching = settings.input_format_arrow_case_insensitive_column_matching;
format_settings.arrow.output_string_as_string = settings.output_format_arrow_string_as_string;
format_settings.orc.import_nested = settings.input_format_orc_import_nested;
format_settings.orc.allow_missing_columns = settings.input_format_orc_allow_missing_columns;
format_settings.orc.row_batch_size = settings.input_format_orc_row_batch_size;
format_settings.orc.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_orc_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.arrow.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_arrow_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.arrow.case_insensitive_column_matching = settings.input_format_arrow_case_insensitive_column_matching;
format_settings.orc.import_nested = settings.input_format_orc_import_nested;
format_settings.orc.allow_missing_columns = settings.input_format_orc_allow_missing_columns;
format_settings.orc.row_batch_size = settings.input_format_orc_row_batch_size;
format_settings.orc.skip_columns_with_unsupported_types_in_schema_inference = settings.input_format_orc_skip_columns_with_unsupported_types_in_schema_inference;
format_settings.orc.case_insensitive_column_matching = settings.input_format_orc_case_insensitive_column_matching;
format_settings.orc.output_string_as_string = settings.output_format_orc_string_as_string;
format_settings.defaults_for_omitted_fields = settings.input_format_defaults_for_omitted_fields;
format_settings.capn_proto.enum_comparing_mode = settings.format_capn_proto_enum_comparising_mode;
format_settings.seekable_read = settings.input_format_allow_seeks;
@ -538,19 +541,19 @@ void FormatFactory::markOutputFormatSupportsParallelFormatting(const String & na
}
void FormatFactory::markFormatAsColumnOriented(const String & name)
void FormatFactory::markFormatSupportsSubsetOfColumns(const String & name)
{
auto & target = dict[name].is_column_oriented;
auto & target = dict[name].supports_subset_of_columns;
if (target)
throw Exception("FormatFactory: Format " + name + " is already marked as column oriented", ErrorCodes::LOGICAL_ERROR);
throw Exception("FormatFactory: Format " + name + " is already marked as supporting subset of columns", ErrorCodes::LOGICAL_ERROR);
target = true;
}
bool FormatFactory::checkIfFormatIsColumnOriented(const String & name)
bool FormatFactory::checkIfFormatSupportsSubsetOfColumns(const String & name) const
{
const auto & target = getCreators(name);
return target.is_column_oriented;
return target.supports_subset_of_columns;
}
bool FormatFactory::isInputFormat(const String & name) const
@ -565,19 +568,19 @@ bool FormatFactory::isOutputFormat(const String & name) const
return it != dict.end() && it->second.output_creator;
}
bool FormatFactory::checkIfFormatHasSchemaReader(const String & name)
bool FormatFactory::checkIfFormatHasSchemaReader(const String & name) const
{
const auto & target = getCreators(name);
return bool(target.schema_reader_creator);
}
bool FormatFactory::checkIfFormatHasExternalSchemaReader(const String & name)
bool FormatFactory::checkIfFormatHasExternalSchemaReader(const String & name) const
{
const auto & target = getCreators(name);
return bool(target.external_schema_reader_creator);
}
bool FormatFactory::checkIfFormatHasAnySchemaReader(const String & name)
bool FormatFactory::checkIfFormatHasAnySchemaReader(const String & name) const
{
return checkIfFormatHasSchemaReader(name) || checkIfFormatHasExternalSchemaReader(name);
}

View File

@ -108,7 +108,7 @@ private:
SchemaReaderCreator schema_reader_creator;
ExternalSchemaReaderCreator external_schema_reader_creator;
bool supports_parallel_formatting{false};
bool is_column_oriented{false};
bool supports_subset_of_columns{false};
NonTrivialPrefixAndSuffixChecker non_trivial_prefix_and_suffix_checker;
AppendSupportChecker append_support_checker;
};
@ -194,13 +194,13 @@ public:
void registerExternalSchemaReader(const String & name, ExternalSchemaReaderCreator external_schema_reader_creator);
void markOutputFormatSupportsParallelFormatting(const String & name);
void markFormatAsColumnOriented(const String & name);
void markFormatSupportsSubsetOfColumns(const String & name);
bool checkIfFormatIsColumnOriented(const String & name);
bool checkIfFormatSupportsSubsetOfColumns(const String & name) const;
bool checkIfFormatHasSchemaReader(const String & name);
bool checkIfFormatHasExternalSchemaReader(const String & name);
bool checkIfFormatHasAnySchemaReader(const String & name);
bool checkIfFormatHasSchemaReader(const String & name) const;
bool checkIfFormatHasExternalSchemaReader(const String & name) const;
bool checkIfFormatHasAnySchemaReader(const String & name) const;
const FormatsDictionary & getAllFormats() const
{

View File

@ -81,6 +81,7 @@ struct FormatSettings
bool allow_missing_columns = false;
bool skip_columns_with_unsupported_types_in_schema_inference = false;
bool case_insensitive_column_matching = false;
bool output_string_as_string = false;
} arrow;
struct
@ -148,6 +149,7 @@ struct FormatSettings
bool skip_columns_with_unsupported_types_in_schema_inference = false;
bool case_insensitive_column_matching = false;
std::unordered_set<int> skip_row_groups = {};
bool output_string_as_string = false;
} parquet;
struct Pretty
@ -234,6 +236,7 @@ struct FormatSettings
bool skip_columns_with_unsupported_types_in_schema_inference = false;
bool case_insensitive_column_matching = false;
std::unordered_set<int> skip_stripes = {};
bool output_string_as_string = false;
} orc;
/// For capnProto format we should determine how to

View File

@ -1,387 +0,0 @@
#include <IO/ReadHelpers.h>
#include <Formats/JSONEachRowUtils.h>
#include <Formats/ReadSchemaUtils.h>
#include <IO/ReadBufferFromString.h>
#include <DataTypes/Serializations/SerializationNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeObject.h>
#include <DataTypes/DataTypeFactory.h>
#include <Common/JSONParsers/SimdJSONParser.h>
#include <Common/JSONParsers/RapidJSONParser.h>
#include <Common/JSONParsers/DummyJSONParser.h>
#include <base/find_symbols.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
}
template <const char opening_bracket, const char closing_bracket>
static std::pair<bool, size_t> fileSegmentationEngineJSONEachRowImpl(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows)
{
skipWhitespaceIfAny(in);
char * pos = in.position();
size_t balance = 0;
bool quotes = false;
size_t number_of_rows = 0;
while (loadAtPosition(in, memory, pos) && (balance || memory.size() + static_cast<size_t>(pos - in.position()) < min_chunk_size || number_of_rows < min_rows))
{
const auto current_object_size = memory.size() + static_cast<size_t>(pos - in.position());
if (min_chunk_size != 0 && current_object_size > 10 * min_chunk_size)
throw ParsingException("Size of JSON object is extremely large. Expected not greater than " +
std::to_string(min_chunk_size) + " bytes, but current is " + std::to_string(current_object_size) +
" bytes per row. Increase the value setting 'min_chunk_bytes_for_parallel_parsing' or check your data manually, most likely JSON is malformed", ErrorCodes::INCORRECT_DATA);
if (quotes)
{
pos = find_first_symbols<'\\', '"'>(pos, in.buffer().end());
if (pos > in.buffer().end())
throw Exception("Position in buffer is out of bounds. There must be a bug.", ErrorCodes::LOGICAL_ERROR);
else if (pos == in.buffer().end())
continue;
if (*pos == '\\')
{
++pos;
if (loadAtPosition(in, memory, pos))
++pos;
}
else if (*pos == '"')
{
++pos;
quotes = false;
}
}
else
{
pos = find_first_symbols<opening_bracket, closing_bracket, '\\', '"'>(pos, in.buffer().end());
if (pos > in.buffer().end())
throw Exception("Position in buffer is out of bounds. There must be a bug.", ErrorCodes::LOGICAL_ERROR);
else if (pos == in.buffer().end())
continue;
else if (*pos == opening_bracket)
{
++balance;
++pos;
}
else if (*pos == closing_bracket)
{
--balance;
++pos;
}
else if (*pos == '\\')
{
++pos;
if (loadAtPosition(in, memory, pos))
++pos;
}
else if (*pos == '"')
{
quotes = true;
++pos;
}
if (balance == 0)
++number_of_rows;
}
}
saveUpToPosition(in, memory, pos);
return {loadAtPosition(in, memory, pos), number_of_rows};
}
template <const char opening_bracket, const char closing_bracket>
static String readJSONEachRowLineIntoStringImpl(ReadBuffer & in)
{
Memory memory;
fileSegmentationEngineJSONEachRowImpl<opening_bracket, closing_bracket>(in, memory, 0, 1);
return String(memory.data(), memory.size());
}
template <class Element>
DataTypePtr getDataTypeFromJSONFieldImpl(const Element & field)
{
if (field.isNull())
return nullptr;
if (field.isBool())
return DataTypeFactory::instance().get("Nullable(Bool)");
if (field.isInt64() || field.isUInt64() || field.isDouble())
return makeNullable(std::make_shared<DataTypeFloat64>());
if (field.isString())
return makeNullable(std::make_shared<DataTypeString>());
if (field.isArray())
{
auto array = field.getArray();
/// Return nullptr in case of empty array because we cannot determine nested type.
if (array.size() == 0)
return nullptr;
DataTypes nested_data_types;
/// If this array contains fields with different types we will treat it as Tuple.
bool is_tuple = false;
for (const auto element : array)
{
auto type = getDataTypeFromJSONFieldImpl(element);
if (!type)
return nullptr;
if (!nested_data_types.empty() && type->getName() != nested_data_types.back()->getName())
is_tuple = true;
nested_data_types.push_back(std::move(type));
}
if (is_tuple)
return std::make_shared<DataTypeTuple>(nested_data_types);
return std::make_shared<DataTypeArray>(nested_data_types.back());
}
if (field.isObject())
{
auto object = field.getObject();
DataTypePtr value_type;
bool is_object = false;
for (const auto key_value_pair : object)
{
auto type = getDataTypeFromJSONFieldImpl(key_value_pair.second);
if (!type)
continue;
if (isObject(type))
{
is_object = true;
break;
}
if (!value_type)
{
value_type = type;
}
else if (!value_type->equals(*type))
{
is_object = true;
break;
}
}
if (is_object)
return std::make_shared<DataTypeObject>("json", true);
if (value_type)
return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), value_type);
return nullptr;
}
throw Exception{ErrorCodes::INCORRECT_DATA, "Unexpected JSON type"};
}
auto getJSONParserAndElement()
{
#if USE_SIMDJSON
return std::pair<SimdJSONParser, SimdJSONParser::Element>();
#elif USE_RAPIDJSON
return std::pair<RapidJSONParser, RapidJSONParser::Element>();
#else
return std::pair<DummyJSONParser, DummyJSONParser::Element>();
#endif
}
DataTypePtr getDataTypeFromJSONField(const String & field)
{
auto [parser, element] = getJSONParserAndElement();
bool parsed = parser.parse(field, element);
if (!parsed)
throw Exception(ErrorCodes::INCORRECT_DATA, "Cannot parse JSON object");
return getDataTypeFromJSONFieldImpl(element);
}
template <class Extractor, const char opening_bracket, const char closing_bracket>
static DataTypes determineColumnDataTypesFromJSONEachRowDataImpl(ReadBuffer & in, bool /*json_strings*/, Extractor & extractor)
{
String line = readJSONEachRowLineIntoStringImpl<opening_bracket, closing_bracket>(in);
auto [parser, element] = getJSONParserAndElement();
bool parsed = parser.parse(line, element);
if (!parsed)
throw Exception(ErrorCodes::INCORRECT_DATA, "Cannot parse JSON object");
auto fields = extractor.extract(element);
DataTypes data_types;
data_types.reserve(fields.size());
for (const auto & field : fields)
data_types.push_back(getDataTypeFromJSONFieldImpl(field));
/// TODO: For JSONStringsEachRow/JSONCompactStringsEach all types will be strings.
/// Should we try to parse data inside strings somehow in this case?
return data_types;
}
std::pair<bool, size_t> fileSegmentationEngineJSONEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size)
{
return fileSegmentationEngineJSONEachRowImpl<'{', '}'>(in, memory, min_chunk_size, 1);
}
std::pair<bool, size_t> fileSegmentationEngineJSONCompactEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows)
{
return fileSegmentationEngineJSONEachRowImpl<'[', ']'>(in, memory, min_chunk_size, min_rows);
}
struct JSONEachRowFieldsExtractor
{
template <class Element>
std::vector<Element> extract(const Element & element)
{
/// {..., "<column_name>" : <value>, ...}
if (!element.isObject())
throw Exception(ErrorCodes::INCORRECT_DATA, "Root JSON value is not an object");
auto object = element.getObject();
std::vector<Element> fields;
fields.reserve(object.size());
column_names.reserve(object.size());
for (const auto & key_value_pair : object)
{
column_names.emplace_back(key_value_pair.first);
fields.push_back(key_value_pair.second);
}
return fields;
}
std::vector<String> column_names;
};
NamesAndTypesList readRowAndGetNamesAndDataTypesForJSONEachRow(ReadBuffer & in, bool json_strings)
{
JSONEachRowFieldsExtractor extractor;
auto data_types = determineColumnDataTypesFromJSONEachRowDataImpl<JSONEachRowFieldsExtractor, '{', '}'>(in, json_strings, extractor);
NamesAndTypesList result;
for (size_t i = 0; i != extractor.column_names.size(); ++i)
result.emplace_back(extractor.column_names[i], data_types[i]);
return result;
}
struct JSONCompactEachRowFieldsExtractor
{
template <class Element>
std::vector<Element> extract(const Element & element)
{
/// [..., <value>, ...]
if (!element.isArray())
throw Exception(ErrorCodes::INCORRECT_DATA, "Root JSON value is not an array");
auto array = element.getArray();
std::vector<Element> fields;
fields.reserve(array.size());
for (size_t i = 0; i != array.size(); ++i)
fields.push_back(array[i]);
return fields;
}
};
DataTypes readRowAndGetDataTypesForJSONCompactEachRow(ReadBuffer & in, bool json_strings)
{
JSONCompactEachRowFieldsExtractor extractor;
return determineColumnDataTypesFromJSONEachRowDataImpl<JSONCompactEachRowFieldsExtractor, '[', ']'>(in, json_strings, extractor);
}
bool nonTrivialPrefixAndSuffixCheckerJSONEachRowImpl(ReadBuffer & buf)
{
/// For JSONEachRow we can safely skip whitespace characters
skipWhitespaceIfAny(buf);
return buf.eof() || *buf.position() == '[';
}
bool readFieldImpl(ReadBuffer & in, IColumn & column, const DataTypePtr & type, const SerializationPtr & serialization, const String & column_name, const FormatSettings & format_settings, bool yield_strings)
{
try
{
bool as_nullable = format_settings.null_as_default && !type->isNullable() && !type->isLowCardinalityNullable();
if (yield_strings)
{
String str;
readJSONString(str, in);
ReadBufferFromString buf(str);
if (as_nullable)
return SerializationNullable::deserializeWholeTextImpl(column, buf, format_settings, serialization);
serialization->deserializeWholeText(column, buf, format_settings);
return true;
}
if (as_nullable)
return SerializationNullable::deserializeTextJSONImpl(column, in, format_settings, serialization);
serialization->deserializeTextJSON(column, in, format_settings);
return true;
}
catch (Exception & e)
{
e.addMessage("(while reading the value of key " + column_name + ")");
throw;
}
}
DataTypePtr getCommonTypeForJSONFormats(const DataTypePtr & first, const DataTypePtr & second, bool allow_bools_as_numbers)
{
if (allow_bools_as_numbers)
{
auto not_nullable_first = removeNullable(first);
auto not_nullable_second = removeNullable(second);
/// Check if we have Bool and Number and if so make the result type Number
bool bool_type_presents = isBool(not_nullable_first) || isBool(not_nullable_second);
bool number_type_presents = isNumber(not_nullable_first) || isNumber(not_nullable_second);
if (bool_type_presents && number_type_presents)
{
if (isBool(not_nullable_first))
return second;
return first;
}
}
/// If we have Map and Object, make result type Object
bool object_type_presents = isObject(first) || isObject(second);
bool map_type_presents = isMap(first) || isMap(second);
if (object_type_presents && map_type_presents)
{
if (isObject(first))
return first;
return second;
}
/// If we have different Maps, make result type Object
if (isMap(first) && isMap(second) && !first->equals(*second))
return std::make_shared<DataTypeObject>("json", true);
return nullptr;
}
}

View File

@ -1,37 +0,0 @@
#pragma once
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Formats/FormatSettings.h>
#include <IO/BufferWithOwnMemory.h>
#include <IO/ReadBuffer.h>
#include <utility>
namespace DB
{
std::pair<bool, size_t> fileSegmentationEngineJSONEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size);
std::pair<bool, size_t> fileSegmentationEngineJSONCompactEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows);
/// Parse JSON from string and convert it's type to ClickHouse type. Make the result type always Nullable.
/// JSON array with different nested types is treated as Tuple.
/// If cannot convert (for example when field contains null), return nullptr.
DataTypePtr getDataTypeFromJSONField(const String & field);
/// Read row in JSONEachRow format and try to determine type for each field.
/// Return list of names and types.
/// If cannot determine the type of some field, return nullptr for it.
NamesAndTypesList readRowAndGetNamesAndDataTypesForJSONEachRow(ReadBuffer & in, bool json_strings);
/// Read row in JSONCompactEachRow format and try to determine type for each field.
/// If cannot determine the type of some field, return nullptr for it.
DataTypes readRowAndGetDataTypesForJSONCompactEachRow(ReadBuffer & in, bool json_strings);
bool nonTrivialPrefixAndSuffixCheckerJSONEachRowImpl(ReadBuffer & buf);
bool readFieldImpl(ReadBuffer & in, IColumn & column, const DataTypePtr & type, const SerializationPtr & serialization, const String & column_name, const FormatSettings & format_settings, bool yield_strings);
DataTypePtr getCommonTypeForJSONFormats(const DataTypePtr & first, const DataTypePtr & second, bool allow_bools_as_numbers);
}

603
src/Formats/JSONUtils.cpp Normal file
View File

@ -0,0 +1,603 @@
#include <IO/ReadHelpers.h>
#include <Formats/JSONUtils.h>
#include <Formats/ReadSchemaUtils.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferValidUTF8.h>
#include <DataTypes/Serializations/SerializationNullable.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeString.h>
#include <DataTypes/DataTypeArray.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeMap.h>
#include <DataTypes/DataTypeObject.h>
#include <DataTypes/DataTypeFactory.h>
#include <Common/JSONParsers/SimdJSONParser.h>
#include <Common/JSONParsers/RapidJSONParser.h>
#include <Common/JSONParsers/DummyJSONParser.h>
#include <base/find_symbols.h>
namespace DB
{
namespace ErrorCodes
{
extern const int INCORRECT_DATA;
extern const int LOGICAL_ERROR;
}
namespace JSONUtils
{
/// Finds a chunk boundary in JSONEachRow-like data for parallel parsing.
/// Accumulates bytes into `memory` until the buffer is exhausted, or until both
/// at least `min_chunk_size` bytes AND `min_rows` complete rows have been consumed
/// while the bracket balance is zero (i.e. we are between rows, not inside one).
/// A "row" is one balanced <opening_bracket>...<closing_bracket> group
/// ('{'/'}' for JSONEachRow, '['/']' for JSONCompactEachRow).
/// Returns {more data is available, number of complete rows accumulated}.
template <const char opening_bracket, const char closing_bracket>
static std::pair<bool, size_t>
fileSegmentationEngineJSONEachRowImpl(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows)
{
    skipWhitespaceIfAny(in);

    char * pos = in.position();
    size_t balance = 0;     /// Depth of currently open brackets; 0 means we are between rows.
    bool quotes = false;    /// True while inside a double-quoted JSON string (brackets there don't count).
    size_t number_of_rows = 0;

    while (loadAtPosition(in, memory, pos)
        && (balance || memory.size() + static_cast<size_t>(pos - in.position()) < min_chunk_size || number_of_rows < min_rows))
    {
        const auto current_object_size = memory.size() + static_cast<size_t>(pos - in.position());
        /// Guard against runaway single objects (e.g. malformed JSON with an unclosed bracket):
        /// refuse to buffer more than 10x the requested chunk size. min_chunk_size == 0 disables the guard.
        if (min_chunk_size != 0 && current_object_size > 10 * min_chunk_size)
            throw ParsingException(
                "Size of JSON object is extremely large. Expected not greater than " + std::to_string(min_chunk_size)
                + " bytes, but current is " + std::to_string(current_object_size)
                + " bytes per row. Increase the value setting 'min_chunk_bytes_for_parallel_parsing' or check your data manually, most likely JSON is malformed",
                ErrorCodes::INCORRECT_DATA);

        if (quotes)
        {
            /// Inside a string: only an unescaped '"' can end it; '\\' escapes the next char.
            pos = find_first_symbols<'\\', '"'>(pos, in.buffer().end());

            if (pos > in.buffer().end())
                throw Exception("Position in buffer is out of bounds. There must be a bug.", ErrorCodes::LOGICAL_ERROR);
            else if (pos == in.buffer().end())
                continue;   /// Buffer exhausted mid-string; loop condition reloads more data.

            if (*pos == '\\')
            {
                ++pos;
                /// Skip the escaped character too (it may be '"' or '\\').
                if (loadAtPosition(in, memory, pos))
                    ++pos;
            }
            else if (*pos == '"')
            {
                ++pos;
                quotes = false;
            }
        }
        else
        {
            /// Outside strings: track bracket balance and string starts.
            pos = find_first_symbols<opening_bracket, closing_bracket, '\\', '"'>(pos, in.buffer().end());

            if (pos > in.buffer().end())
                throw Exception("Position in buffer is out of bounds. There must be a bug.", ErrorCodes::LOGICAL_ERROR);
            else if (pos == in.buffer().end())
                continue;

            else if (*pos == opening_bracket)
            {
                ++balance;
                ++pos;
            }
            else if (*pos == closing_bracket)
            {
                --balance;
                ++pos;
            }
            else if (*pos == '\\')
            {
                ++pos;
                if (loadAtPosition(in, memory, pos))
                    ++pos;
            }
            else if (*pos == '"')
            {
                quotes = true;
                ++pos;
            }

            /// Balance returning to zero means one full row has just been closed.
            if (balance == 0)
                ++number_of_rows;
        }
    }

    saveUpToPosition(in, memory, pos);
    return {loadAtPosition(in, memory, pos), number_of_rows};
}
/// Reads exactly one JSON row (one balanced bracket group) from the buffer and
/// returns it as a string, delegating boundary detection to the segmentation engine.
template <const char opening_bracket, const char closing_bracket>
static String readJSONEachRowLineIntoStringImpl(ReadBuffer & in)
{
    Memory line_buffer;
    /// min_chunk_size = 0 and min_rows = 1: stop right after the first complete row.
    fileSegmentationEngineJSONEachRowImpl<opening_bracket, closing_bracket>(in, line_buffer, 0, 1);
    return {line_buffer.data(), line_buffer.size()};
}
/// Recursively infers a ClickHouse data type from a parsed JSON element.
/// Scalars map to Nullable(Bool)/Nullable(Float64)/Nullable(String);
/// arrays become Array (or Tuple when element types differ);
/// objects become Map(String, T) when all values share one type, otherwise Object('json').
/// Returns nullptr when the type cannot be determined (null value or empty array).
template <class Element>
DataTypePtr getDataTypeFromFieldImpl(const Element & field)
{
    if (field.isNull())
        return nullptr;

    if (field.isBool())
        return DataTypeFactory::instance().get("Nullable(Bool)");

    /// All JSON numbers (signed, unsigned, floating) are widened to Float64.
    if (field.isInt64() || field.isUInt64() || field.isDouble())
        return makeNullable(std::make_shared<DataTypeFloat64>());

    if (field.isString())
        return makeNullable(std::make_shared<DataTypeString>());

    if (field.isArray())
    {
        auto array = field.getArray();

        /// Return nullptr in case of empty array because we cannot determine nested type.
        if (array.size() == 0)
            return nullptr;

        DataTypes nested_data_types;
        /// If this array contains fields with different types we will treat it as Tuple.
        bool is_tuple = false;

        for (const auto element : array)
        {
            auto type = getDataTypeFromFieldImpl(element);
            /// One undeterminable element makes the whole array undeterminable.
            if (!type)
                return nullptr;

            /// Name comparison detects any mismatch with the previously seen element type.
            if (!nested_data_types.empty() && type->getName() != nested_data_types.back()->getName())
                is_tuple = true;

            nested_data_types.push_back(std::move(type));
        }

        if (is_tuple)
            return std::make_shared<DataTypeTuple>(nested_data_types);

        return std::make_shared<DataTypeArray>(nested_data_types.back());
    }

    if (field.isObject())
    {
        auto object = field.getObject();
        DataTypePtr value_type;
        bool is_object = false;
        for (const auto key_value_pair : object)
        {
            auto type = getDataTypeFromFieldImpl(key_value_pair.second);
            /// Undeterminable values (nulls) are skipped rather than failing the object.
            if (!type)
                continue;

            /// A nested Object forces the whole thing to Object('json').
            if (isObject(type))
            {
                is_object = true;
                break;
            }

            if (!value_type)
            {
                value_type = type;
            }
            else if (!value_type->equals(*type))
            {
                /// Heterogeneous value types cannot be a Map; fall back to Object('json').
                is_object = true;
                break;
            }
        }

        if (is_object)
            return std::make_shared<DataTypeObject>("json", true);

        if (value_type)
            return std::make_shared<DataTypeMap>(std::make_shared<DataTypeString>(), value_type);

        /// Object was empty or contained only nulls.
        return nullptr;
    }

    throw Exception{ErrorCodes::INCORRECT_DATA, "Unexpected JSON type"};
}
/// Returns a default-constructed {parser, element} pair for the JSON backend
/// selected at build time: simdjson, then rapidjson, then the dummy fallback.
/// The pair type differs per configuration, hence the deduced `auto` return.
auto getJSONParserAndElement()
{
#if USE_SIMDJSON
    return std::pair<SimdJSONParser, SimdJSONParser::Element>();
#elif USE_RAPIDJSON
    return std::pair<RapidJSONParser, RapidJSONParser::Element>();
#else
    return std::pair<DummyJSONParser, DummyJSONParser::Element>();
#endif
}
/// Parses `field` as a JSON value and infers a ClickHouse data type from it.
/// Throws INCORRECT_DATA when the string is not valid JSON.
DataTypePtr getDataTypeFromField(const String & field)
{
    auto [json_parser, json_element] = getJSONParserAndElement();
    if (!json_parser.parse(field, json_element))
        throw Exception(ErrorCodes::INCORRECT_DATA, "Cannot parse JSON object here: {}", field);
    return getDataTypeFromFieldImpl(json_element);
}
/// Reads one row from the buffer, splits it into per-column JSON fields using
/// `extractor`, and infers a data type for each field (nullptr where undeterminable).
/// The bracket template parameters select the row shape: '{'/'}' for objects,
/// '['/']' for compact arrays.
template <class Extractor, const char opening_bracket, const char closing_bracket>
static DataTypes determineColumnDataTypesFromJSONEachRowDataImpl(ReadBuffer & in, bool /*json_strings*/, Extractor & extractor)
{
    String line = readJSONEachRowLineIntoStringImpl<opening_bracket, closing_bracket>(in);
    auto [parser, element] = getJSONParserAndElement();
    bool parsed = parser.parse(line, element);
    if (!parsed)
        throw Exception(ErrorCodes::INCORRECT_DATA, "Cannot parse JSON object here: {}", line);

    auto fields = extractor.extract(element);

    DataTypes data_types;
    data_types.reserve(fields.size());
    for (const auto & field : fields)
        data_types.push_back(getDataTypeFromFieldImpl(field));

    /// TODO: For JSONStringsEachRow/JSONCompactStringsEach all types will be strings.
    /// Should we try to parse data inside strings somehow in this case?

    return data_types;
}
/// Segmentation entry point for JSONEachRow: rows are '{'..'}' objects,
/// and a chunk needs at least one complete row.
std::pair<bool, size_t> fileSegmentationEngineJSONEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size)
{
    return fileSegmentationEngineJSONEachRowImpl<'{', '}'>(in, memory, min_chunk_size, 1);
}
/// Segmentation entry point for JSONCompactEachRow: rows are '['..']' arrays;
/// the caller chooses the minimum number of rows per chunk.
std::pair<bool, size_t>
fileSegmentationEngineJSONCompactEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows)
{
    return fileSegmentationEngineJSONEachRowImpl<'[', ']'>(in, memory, min_chunk_size, min_rows);
}
/// Splits one JSONEachRow object row into its values, recording the key of each
/// value into `column_names` (same order as the returned fields).
struct JSONEachRowFieldsExtractor
{
    /// Extracts values from {..., "<column_name>" : <value>, ...}.
    /// Throws INCORRECT_DATA when the row is not a JSON object.
    template <class Element>
    std::vector<Element> extract(const Element & element)
    {
        /// {..., "<column_name>" : <value>, ...}
        if (!element.isObject())
            throw Exception(ErrorCodes::INCORRECT_DATA, "Root JSON value is not an object");

        auto object = element.getObject();
        std::vector<Element> fields;
        fields.reserve(object.size());
        column_names.reserve(object.size());
        for (const auto & key_value_pair : object)
        {
            column_names.emplace_back(key_value_pair.first);
            fields.push_back(key_value_pair.second);
        }

        return fields;
    }

    /// Keys seen during extract(), parallel to the returned fields vector.
    std::vector<String> column_names;
};
/// Reads one JSONEachRow row and infers a (name, type) pair per key.
/// Types that cannot be determined come back as nullptr entries.
NamesAndTypesList readRowAndGetNamesAndDataTypesForJSONEachRow(ReadBuffer & in, bool json_strings)
{
    JSONEachRowFieldsExtractor fields_extractor;
    auto types
        = determineColumnDataTypesFromJSONEachRowDataImpl<JSONEachRowFieldsExtractor, '{', '}'>(in, json_strings, fields_extractor);

    const auto & names = fields_extractor.column_names;
    NamesAndTypesList names_and_types;
    for (size_t idx = 0; idx < names.size(); ++idx)
        names_and_types.emplace_back(names[idx], types[idx]);
    return names_and_types;
}
/// Splits one JSONCompactEachRow array row [..., <value>, ...] into its values.
struct JSONCompactEachRowFieldsExtractor
{
    /// Throws INCORRECT_DATA when the row is not a JSON array.
    template <class Element>
    std::vector<Element> extract(const Element & element)
    {
        if (!element.isArray())
            throw Exception(ErrorCodes::INCORRECT_DATA, "Root JSON value is not an array");

        auto array = element.getArray();
        std::vector<Element> fields;
        fields.reserve(array.size());
        for (const auto value : array)
            fields.push_back(value);
        return fields;
    }
};
/// Reads one JSONCompactEachRow row and infers a data type per positional value
/// (nullptr entries where the type cannot be determined).
DataTypes readRowAndGetDataTypesForJSONCompactEachRow(ReadBuffer & in, bool json_strings)
{
    JSONCompactEachRowFieldsExtractor extractor;
    return determineColumnDataTypesFromJSONEachRowDataImpl<JSONCompactEachRowFieldsExtractor, '[', ']'>(in, json_strings, extractor);
}
/// Returns true when the remaining data is empty or starts with '['
/// (after whitespace) — i.e. it looks like a whole JSON array rather than
/// bare row-per-line objects. NOTE(review): the exact contract of the
/// "non-trivial prefix/suffix" result is defined by the caller — confirm there.
bool nonTrivialPrefixAndSuffixCheckerJSONEachRowImpl(ReadBuffer & buf)
{
    /// For JSONEachRow we can safely skip whitespace characters
    skipWhitespaceIfAny(buf);
    return buf.eof() || *buf.position() == '[';
}
/// Deserializes one JSON field value from `in` into `column`.
/// When `yield_strings` is set, the value is first read as a JSON string and then
/// parsed as whole text (JSONStringsEachRow-style input); otherwise it is parsed
/// directly as JSON. When null_as_default applies to a non-nullable type, the
/// Nullable deserialization path is used so JSON nulls become default values.
/// Returns true if a real (non-default) value was written; on error, rethrows
/// with the column name appended to the exception message.
bool readField(
    ReadBuffer & in,
    IColumn & column,
    const DataTypePtr & type,
    const SerializationPtr & serialization,
    const String & column_name,
    const FormatSettings & format_settings,
    bool yield_strings)
{
    try
    {
        /// Only wrap in Nullable handling when the type itself is not already nullable.
        bool as_nullable = format_settings.null_as_default && !type->isNullable() && !type->isLowCardinalityNullable();

        if (yield_strings)
        {
            String str;
            readJSONString(str, in);

            /// Re-parse the unescaped string contents as the value's whole-text form.
            ReadBufferFromString buf(str);

            if (as_nullable)
                return SerializationNullable::deserializeWholeTextImpl(column, buf, format_settings, serialization);

            serialization->deserializeWholeText(column, buf, format_settings);
            return true;
        }

        if (as_nullable)
            return SerializationNullable::deserializeTextJSONImpl(column, in, format_settings, serialization);

        serialization->deserializeTextJSON(column, in, format_settings);
        return true;
    }
    catch (Exception & e)
    {
        e.addMessage("(while reading the value of key " + column_name + ")");
        throw;
    }
}
/// Resolves a common type for two inferred JSON types during schema inference.
/// Rules, in order: Bool + Number -> the Number type (when allowed);
/// Map + Object -> the Object type; two different Maps -> Object('json').
/// Returns nullptr when no special rule applies (caller decides the fallback).
DataTypePtr getCommonTypeForJSONFormats(const DataTypePtr & first, const DataTypePtr & second, bool allow_bools_as_numbers)
{
    if (allow_bools_as_numbers)
    {
        /// Compare the underlying types; nullability is irrelevant for this rule.
        auto not_nullable_first = removeNullable(first);
        auto not_nullable_second = removeNullable(second);
        /// Check if we have Bool and Number and if so make the result type Number
        bool bool_type_presents = isBool(not_nullable_first) || isBool(not_nullable_second);
        bool number_type_presents = isNumber(not_nullable_first) || isNumber(not_nullable_second);
        if (bool_type_presents && number_type_presents)
        {
            if (isBool(not_nullable_first))
                return second;
            return first;
        }
    }

    /// If we have Map and Object, make result type Object
    bool object_type_presents = isObject(first) || isObject(second);
    bool map_type_presents = isMap(first) || isMap(second);
    if (object_type_presents && map_type_presents)
    {
        if (isObject(first))
            return first;
        return second;
    }

    /// If we have different Maps, make result type Object
    if (isMap(first) && isMap(second) && !first->equals(*second))
        return std::make_shared<DataTypeObject>("json", true);

    return nullptr;
}
/// Writes a comma followed by `new_lines` newline characters — the separator
/// between fields/objects in pretty JSON output.
void writeFieldDelimiter(WriteBuffer & out, size_t new_lines)
{
    writeChar(',', out);
    writeChar('\n', new_lines, out);
}
/// Separator between fields in compact (single-line) output.
void writeFieldCompactDelimiter(WriteBuffer & out)
{
    writeCString(", ", out);
}
/// Writes an indented JSON key: `"title": ` when with_space, otherwise `"title":` + newline.
/// NOTE(review): `title` is emitted verbatim, not JSON-escaped — callers pass only
/// fixed literals/column names; confirm escaping expectations before reusing.
template <bool with_space>
void writeTitle(const char * title, WriteBuffer & out, size_t indent)
{
    writeChar('\t', indent, out);
    writeChar('"', out);
    writeCString(title, out);
    if constexpr (with_space)
        writeCString("\": ", out);
    else
        writeCString("\":\n", out);
}
/// Opens a JSON object at the given indent level, optionally preceded by its key.
void writeObjectStart(WriteBuffer & out, size_t indent, const char * title)
{
    if (title != nullptr)
        writeTitle<false>(title, out, indent);
    writeChar('\t', indent, out);
    writeChar('{', out);
    writeChar('\n', out);
}
/// Closes a JSON object: the brace goes on its own line at the given indent level.
void writeObjectEnd(WriteBuffer & out, size_t indent)
{
    writeCString("\n", out);
    writeChar('\t', indent, out);
    writeChar('}', out);
}
/// Opens a JSON array at the given indent level, optionally preceded by its key.
void writeArrayStart(WriteBuffer & out, size_t indent, const char * title)
{
    if (title != nullptr)
        writeTitle<false>(title, out, indent);
    writeChar('\t', indent, out);
    writeChar('[', out);
    writeChar('\n', out);
}
/// Opens a single-line JSON array; with a key the indent comes from the title,
/// otherwise the bracket itself is indented.
void writeCompactArrayStart(WriteBuffer & out, size_t indent, const char * title)
{
    if (title != nullptr)
        writeTitle<true>(title, out, indent);
    else
        writeChar('\t', indent, out);
    writeChar('[', out);
}
/// Closes a JSON array: the bracket goes on its own line at the given indent level.
void writeArrayEnd(WriteBuffer & out, size_t indent)
{
    writeCString("\n", out);
    writeChar('\t', indent, out);
    writeChar(']', out);
}
/// Closes a single-line JSON array.
void writeCompactArrayEnd(WriteBuffer & out)
{
    writeChar(']', out);
}
/// Writes a single value from `column[row_num]` as JSON, optionally prefixed by
/// its key (`name`). When `yield_strings` is set, the value is serialized to text
/// first and then emitted as a JSON string literal; otherwise the serialization's
/// native JSON form is used.
void writeFieldFromColumn(
    const IColumn & column,
    const ISerialization & serialization,
    size_t row_num,
    bool yield_strings,
    const FormatSettings & settings,
    WriteBuffer & out,
    const std::optional<String> & name,
    size_t indent)
{
    if (name.has_value())
        writeTitle<true>(name->data(), out, indent);

    if (yield_strings)
    {
        /// Serialize to a temporary buffer so the text form can be JSON-escaped as one string.
        WriteBufferFromOwnString buf;
        serialization.serializeText(column, row_num, buf, settings);
        writeJSONString(buf.str(), out, settings);
    }
    else
        serialization.serializeTextJSON(column, row_num, out, settings);
}
void writeColumns(
const Columns & columns,
const NamesAndTypes & fields,
const Serializations & serializations,
size_t row_num,
bool yield_strings,
const FormatSettings & settings,
WriteBuffer & out,
size_t indent)
{
for (size_t i = 0; i < columns.size(); ++i)
{
if (i != 0)
writeFieldDelimiter(out);
writeFieldFromColumn(*columns[i], *serializations[i], row_num, yield_strings, settings, out, fields[i].name, indent);
}
}
void writeCompactColumns(
const Columns & columns,
const Serializations & serializations,
size_t row_num,
bool yield_strings,
const FormatSettings & settings,
WriteBuffer & out)
{
for (size_t i = 0; i < columns.size(); ++i)
{
if (i != 0)
writeFieldCompactDelimiter(out);
writeFieldFromColumn(*columns[i], *serializations[i], row_num, yield_strings, settings, out);
}
}
void writeMetadata(const NamesAndTypes & fields, const FormatSettings & settings, WriteBuffer & out)
{
writeArrayStart(out, 1, "meta");
for (size_t i = 0; i < fields.size(); ++i)
{
writeObjectStart(out, 2);
writeTitle<true>("name", out, 3);
writeDoubleQuoted(fields[i].name, out);
writeFieldDelimiter(out);
writeTitle<true>("type", out, 3);
writeJSONString(fields[i].type->getName(), out, settings);
writeObjectEnd(out, 2);
if (i + 1 < fields.size())
writeFieldDelimiter(out);
}
writeArrayEnd(out, 1);
}
/// Emit trailing summary fields after the data: "rows", optionally
/// "rows_before_limit_at_least" (when a LIMIT was applied), and optionally a
/// "statistics" object with elapsed time and read counters.
void writeAdditionalInfo(
    size_t rows,
    size_t rows_before_limit,
    bool applied_limit,
    const Stopwatch & watch,
    const Progress & progress,
    bool write_statistics,
    WriteBuffer & out)
{
    writeFieldDelimiter(out, 2);
    writeTitle<true>("rows", out, 1);
    writeIntText(rows, out);

    if (applied_limit)
    {
        writeFieldDelimiter(out, 2);
        writeTitle<true>("rows_before_limit_at_least", out, 1);
        writeIntText(rows_before_limit, out);
    }

    if (write_statistics)
    {
        writeFieldDelimiter(out, 2);
        writeObjectStart(out, 1, "statistics");

        /// Helper for one "key": value entry inside the statistics object.
        auto write_stat = [&](const char * key, auto value, bool is_first)
        {
            if (!is_first)
                writeFieldDelimiter(out);
            writeTitle<true>(key, out, 2);
            writeText(value, out);
        };

        write_stat("elapsed", watch.elapsedSeconds(), true);
        write_stat("rows_read", progress.read_rows.load(), false);
        write_stat("bytes_read", progress.read_bytes.load(), false);

        writeObjectEnd(out, 1);
    }
}
/// Re-encode every field name as a JSON-escaped, valid-UTF-8 string (in place),
/// and set `need_validate_utf8` if any column's text form may contain invalid
/// UTF-8 (so the caller knows to wrap the output in a validating buffer).
void makeNamesAndTypesWithValidUTF8(NamesAndTypes & fields, const FormatSettings & settings, bool & need_validate_utf8)
{
    for (auto & field : fields)
    {
        if (!field.type->textCanContainOnlyValidUTF8())
            need_validate_utf8 = true;

        WriteBufferFromOwnString escaped;
        {
            /// Inner scope so the validating buffer flushes into `escaped`
            /// before we read the result.
            WriteBufferValidUTF8 validating_buf(escaped);
            writeJSONString(field.name, validating_buf, settings);
        }

        /// writeJSONString produced a quoted string; strip the surrounding
        /// double quotes and keep only the escaped payload.
        const auto & quoted = escaped.str();
        field.name = quoted.substr(1, quoted.size() - 2);
    }
}
}
}

109
src/Formats/JSONUtils.h Normal file
View File

@ -0,0 +1,109 @@
#pragma once
#include <DataTypes/IDataType.h>
#include <DataTypes/Serializations/ISerialization.h>
#include <Formats/FormatSettings.h>
#include <IO/BufferWithOwnMemory.h>
#include <IO/ReadBuffer.h>
#include <IO/Progress.h>
#include <Core/NamesAndTypes.h>
#include <Common/Stopwatch.h>
#include <utility>
namespace DB
{
/// Helpers shared by the JSON* input/output formats: file segmentation for
/// parallel parsing, schema inference from JSON values, and low-level writers
/// that emit the pieces of a JSON document into a WriteBuffer.
namespace JSONUtils
{
/// Find a boundary in `in` at which a chunk of at least `min_chunk_size` bytes
/// can be split off for parallel parsing of JSONEachRow data.
/// NOTE(review): the meaning of the returned pair (presumably "more data
/// available" + bytes consumed) is defined in the .cpp — confirm there.
std::pair<bool, size_t> fileSegmentationEngineJSONEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size);
/// Same as above for JSONCompactEachRow; additionally requires at least
/// `min_rows` rows per chunk.
std::pair<bool, size_t>
fileSegmentationEngineJSONCompactEachRow(ReadBuffer & in, DB::Memory<> & memory, size_t min_chunk_size, size_t min_rows);
/// Parse JSON from string and convert its type to ClickHouse type. Make the result type always Nullable.
/// JSON array with different nested types is treated as Tuple.
/// If cannot convert (for example when field contains null), return nullptr.
DataTypePtr getDataTypeFromField(const String & field);
/// Read row in JSONEachRow format and try to determine type for each field.
/// Return list of names and types.
/// If cannot determine the type of some field, return nullptr for it.
NamesAndTypesList readRowAndGetNamesAndDataTypesForJSONEachRow(ReadBuffer & in, bool json_strings);
/// Read row in JSONCompactEachRow format and try to determine type for each field.
/// If cannot determine the type of some field, return nullptr for it.
DataTypes readRowAndGetDataTypesForJSONCompactEachRow(ReadBuffer & in, bool json_strings);
/// Check whether the data in `buf` has a non-trivial prefix/suffix for the
/// JSONEachRow family of formats (used when deciding how to parse the stream).
bool nonTrivialPrefixAndSuffixCheckerJSONEachRowImpl(ReadBuffer & buf);
/// Read a single JSON field from `in` into `column` using the given type and
/// serialization. With `yield_strings`, values are expected as JSON strings.
/// NOTE(review): the bool result presumably indicates whether a value (vs.
/// a default/null) was read — confirm against the .cpp.
bool readField(
    ReadBuffer & in,
    IColumn & column,
    const DataTypePtr & type,
    const SerializationPtr & serialization,
    const String & column_name,
    const FormatSettings & format_settings,
    bool yield_strings);
/// Compute a common type for two inferred types during JSON schema inference;
/// `allow_bools_as_numbers` permits merging Bool with numeric types.
DataTypePtr getCommonTypeForJSONFormats(const DataTypePtr & first, const DataTypePtr & second, bool allow_bools_as_numbers);
/// Re-encode field names as JSON-escaped valid UTF-8 (in place) and set
/// `need_validate_utf8` if any column's text form may contain invalid UTF-8.
void makeNamesAndTypesWithValidUTF8(NamesAndTypes & fields, const FormatSettings & settings, bool & need_validate_utf8);
/// Functions helpers for writing JSON data to WriteBuffer.
/// Delimiter between two fields: "," followed by `new_lines` newlines.
void writeFieldDelimiter(WriteBuffer & out, size_t new_lines = 1);
/// Delimiter between values in compact (single-line) arrays.
void writeFieldCompactDelimiter(WriteBuffer & out);
/// Open a JSON object "{", optionally preceded by a `"title":` key, indented by `indent` tabs.
void writeObjectStart(WriteBuffer & out, size_t indent = 0, const char * title = nullptr);
/// Close a JSON object "}".
void writeObjectEnd(WriteBuffer & out, size_t indent = 0);
/// Open a multi-line JSON array "[", optionally preceded by a `"title":` key.
void writeArrayStart(WriteBuffer & out, size_t indent = 0, const char * title = nullptr);
/// Open a compact (single-line) JSON array "[".
void writeCompactArrayStart(WriteBuffer & out, size_t indent = 0, const char * title = nullptr);
/// Close a multi-line JSON array "]".
void writeArrayEnd(WriteBuffer & out, size_t indent = 0);
/// Close a compact JSON array "]".
void writeCompactArrayEnd(WriteBuffer & out);
/// Write one value (row `row_num` of `column`) as a JSON field; with `name`,
/// emit a `"name": ` key first. `yield_strings` forces quoted-string output.
void writeFieldFromColumn(
    const IColumn & column,
    const ISerialization & serialization,
    size_t row_num,
    bool yield_strings,
    const FormatSettings & settings,
    WriteBuffer & out,
    const std::optional<String> & name = std::nullopt,
    size_t indent = 0);
/// Write one row as named JSON fields, using `fields` for the key names.
void writeColumns(
    const Columns & columns,
    const NamesAndTypes & fields,
    const Serializations & serializations,
    size_t row_num,
    bool yield_strings,
    const FormatSettings & settings,
    WriteBuffer & out,
    size_t indent = 0);
/// Write one row as unnamed JSON values in compact form.
void writeCompactColumns(
    const Columns & columns,
    const Serializations & serializations,
    size_t row_num,
    bool yield_strings,
    const FormatSettings & settings,
    WriteBuffer & out);
/// Write the "meta" array: one {"name", "type"} object per field.
void writeMetadata(const NamesAndTypes & fields, const FormatSettings & settings, WriteBuffer & out);
/// Write trailing summary fields: "rows", optional "rows_before_limit_at_least"
/// (when `applied_limit`), and an optional "statistics" object.
void writeAdditionalInfo(
    size_t rows,
    size_t rows_before_limit,
    bool applied_limit,
    const Stopwatch & watch,
    const Progress & progress,
    bool write_statistics,
    WriteBuffer & out);
}
}

View File

@ -23,6 +23,7 @@ namespace ErrorCodes
extern const int INCORRECT_INDEX;
extern const int LOGICAL_ERROR;
extern const int CANNOT_READ_ALL_DATA;
extern const int INCORRECT_DATA;
}
@ -31,8 +32,8 @@ NativeReader::NativeReader(ReadBuffer & istr_, UInt64 server_revision_)
{
}
NativeReader::NativeReader(ReadBuffer & istr_, const Block & header_, UInt64 server_revision_)
: istr(istr_), header(header_), server_revision(server_revision_)
NativeReader::NativeReader(ReadBuffer & istr_, const Block & header_, UInt64 server_revision_, bool skip_unknown_columns_)
: istr(istr_), header(header_), server_revision(server_revision_), skip_unknown_columns(skip_unknown_columns_)
{
}
@ -186,18 +187,29 @@ Block NativeReader::read()
column.column = std::move(read_column);
bool use_in_result = true;
if (header)
{
/// Support insert from old clients without low cardinality type.
auto & header_column = header.getByName(column.name);
if (!header_column.type->equals(*column.type))
if (header.has(column.name))
{
column.column = recursiveTypeConversion(column.column, column.type, header.safeGetByPosition(i).type);
column.type = header.safeGetByPosition(i).type;
/// Support insert from old clients without low cardinality type.
auto & header_column = header.getByName(column.name);
if (!header_column.type->equals(*column.type))
{
column.column = recursiveTypeConversion(column.column, column.type, header.safeGetByPosition(i).type);
column.type = header.safeGetByPosition(i).type;
}
}
else
{
if (!skip_unknown_columns)
throw Exception(ErrorCodes::INCORRECT_DATA, "Unknown column with name {} found while reading data in Native format", column.name);
use_in_result = false;
}
}
res.insert(std::move(column));
if (use_in_result)
res.insert(std::move(column));
if (use_index)
++index_column_it;

View File

@ -24,7 +24,7 @@ public:
/// For cases when data structure (header) is known in advance.
/// NOTE We may use header for data validation and/or type conversions. It is not implemented.
NativeReader(ReadBuffer & istr_, const Block & header_, UInt64 server_revision_);
NativeReader(ReadBuffer & istr_, const Block & header_, UInt64 server_revision_, bool skip_unknown_columns_ = false);
/// For cases when we have an index. It allows to skip columns. Only columns specified in the index will be read.
NativeReader(ReadBuffer & istr_, UInt64 server_revision_,
@ -43,6 +43,7 @@ private:
ReadBuffer & istr;
Block header;
UInt64 server_revision;
bool skip_unknown_columns;
bool use_index = false;
IndexForNativeFormat::Blocks::const_iterator index_block_it;

View File

@ -38,6 +38,10 @@ void registerInputFormatJSONEachRow(FormatFactory & factory);
void registerOutputFormatJSONEachRow(FormatFactory & factory);
void registerInputFormatJSONCompactEachRow(FormatFactory & factory);
void registerOutputFormatJSONCompactEachRow(FormatFactory & factory);
void registerInputFormatJSONColumns(FormatFactory & factory);
void registerOutputFormatJSONColumns(FormatFactory & factory);
void registerInputFormatJSONCompactColumns(FormatFactory & factory);
void registerOutputFormatJSONCompactColumns(FormatFactory & factory);
void registerInputFormatProtobuf(FormatFactory & factory);
void registerOutputFormatProtobuf(FormatFactory & factory);
void registerInputFormatProtobufList(FormatFactory & factory);
@ -70,6 +74,7 @@ void registerOutputFormatVertical(FormatFactory & factory);
void registerOutputFormatJSON(FormatFactory & factory);
void registerOutputFormatJSONCompact(FormatFactory & factory);
void registerOutputFormatJSONEachRowWithProgress(FormatFactory & factory);
void registerOutputFormatJSONColumnsWithMetadata(FormatFactory & factory);
void registerOutputFormatXML(FormatFactory & factory);
void registerOutputFormatODBCDriver2(FormatFactory & factory);
void registerOutputFormatNull(FormatFactory & factory);
@ -102,14 +107,16 @@ void registerTSVSchemaReader(FormatFactory & factory);
void registerCSVSchemaReader(FormatFactory & factory);
void registerJSONCompactEachRowSchemaReader(FormatFactory & factory);
void registerJSONEachRowSchemaReader(FormatFactory & factory);
void registerJSONAsStringSchemaReader(FormatFactory & factory);
void registerJSONAsObjectSchemaReader(FormatFactory & factory);
void registerJSONColumnsSchemaReader(FormatFactory & factory);
void registerJSONCompactColumnsSchemaReader(FormatFactory & factory);
void registerNativeSchemaReader(FormatFactory & factory);
void registerRowBinaryWithNamesAndTypesSchemaReader(FormatFactory & factory);
void registerAvroSchemaReader(FormatFactory & factory);
void registerProtobufSchemaReader(FormatFactory & factory);
void registerProtobufListSchemaReader(FormatFactory & factory);
void registerLineAsStringSchemaReader(FormatFactory & factory);
void registerJSONAsStringSchemaReader(FormatFactory & factory);
void registerJSONAsObjectSchemaReader(FormatFactory & factory);
void registerRawBLOBSchemaReader(FormatFactory & factory);
void registerMsgPackSchemaReader(FormatFactory & factory);
void registerCapnProtoSchemaReader(FormatFactory & factory);
@ -120,6 +127,7 @@ void registerValuesSchemaReader(FormatFactory & factory);
void registerTemplateSchemaReader(FormatFactory & factory);
void registerMySQLSchemaReader(FormatFactory & factory);
void registerFileExtensions(FormatFactory & factory);
void registerFormats()
@ -128,8 +136,8 @@ void registerFormats()
registerFileSegmentationEngineTabSeparated(factory);
registerFileSegmentationEngineCSV(factory);
registerFileSegmentationEngineJSONEachRow(factory);
registerFileSegmentationEngineRegexp(factory);
registerFileSegmentationEngineJSONEachRow(factory);
registerFileSegmentationEngineJSONAsString(factory);
registerFileSegmentationEngineJSONAsObject(factory);
registerFileSegmentationEngineJSONCompactEachRow(factory);
@ -155,6 +163,10 @@ void registerFormats()
registerOutputFormatJSONEachRow(factory);
registerInputFormatJSONCompactEachRow(factory);
registerOutputFormatJSONCompactEachRow(factory);
registerInputFormatJSONColumns(factory);
registerOutputFormatJSONColumns(factory);
registerInputFormatJSONCompactColumns(factory);
registerOutputFormatJSONCompactColumns(factory);
registerInputFormatProtobuf(factory);
registerOutputFormatProtobufList(factory);
registerInputFormatProtobufList(factory);
@ -184,6 +196,7 @@ void registerFormats()
registerOutputFormatJSON(factory);
registerOutputFormatJSONCompact(factory);
registerOutputFormatJSONEachRowWithProgress(factory);
registerOutputFormatJSONColumnsWithMetadata(factory);
registerOutputFormatXML(factory);
registerOutputFormatODBCDriver2(factory);
registerOutputFormatNull(factory);
@ -195,8 +208,8 @@ void registerFormats()
registerInputFormatRegexp(factory);
registerInputFormatJSONAsString(factory);
registerInputFormatLineAsString(factory);
registerInputFormatJSONAsObject(factory);
registerInputFormatLineAsString(factory);
#if USE_HIVE
registerInputFormatHiveText(factory);
#endif
@ -215,14 +228,16 @@ void registerFormats()
registerCSVSchemaReader(factory);
registerJSONCompactEachRowSchemaReader(factory);
registerJSONEachRowSchemaReader(factory);
registerJSONAsStringSchemaReader(factory);
registerJSONAsObjectSchemaReader(factory);
registerJSONColumnsSchemaReader(factory);
registerJSONCompactColumnsSchemaReader(factory);
registerNativeSchemaReader(factory);
registerRowBinaryWithNamesAndTypesSchemaReader(factory);
registerAvroSchemaReader(factory);
registerProtobufSchemaReader(factory);
registerProtobufListSchemaReader(factory);
registerLineAsStringSchemaReader(factory);
registerJSONAsStringSchemaReader(factory);
registerJSONAsObjectSchemaReader(factory);
registerRawBLOBSchemaReader(factory);
registerMsgPackSchemaReader(factory);
registerCapnProtoSchemaReader(factory);

View File

@ -10,4 +10,10 @@ void registerWithNamesAndTypes(const std::string & base_format_name, RegisterWit
register_func(base_format_name + "WithNamesAndTypes", true, true);
}
void markFormatWithNamesAndTypesSupportsSamplingColumns(const std::string & base_format_name, FormatFactory & factory)
{
factory.markFormatSupportsSubsetOfColumns(base_format_name + "WithNames");
factory.markFormatSupportsSubsetOfColumns(base_format_name + "WithNamesAndTypes");
}
}

View File

@ -2,6 +2,7 @@
#include <string>
#include <functional>
#include <Formats/FormatFactory.h>
namespace DB
{
@ -9,4 +10,6 @@ namespace DB
using RegisterWithNamesAndTypesFunc = std::function<void(const std::string & format_name, bool with_names, bool with_types)>;
void registerWithNamesAndTypes(const std::string & base_format_name, RegisterWithNamesAndTypesFunc register_func);
void markFormatWithNamesAndTypesSupportsSamplingColumns(const std::string & base_format_name, FormatFactory & factory);
}

View File

@ -26,19 +26,21 @@ struct CountSubstringsImpl
static constexpr bool supports_start_pos = true;
static constexpr auto name = Name::name;
static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {};}
using ResultType = UInt64;
/// Count occurrences of one substring in many strings.
static void vectorConstant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const ColumnString::Chars & haystack_data,
const ColumnString::Offsets & haystack_offsets,
const std::string & needle,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res)
{
const UInt8 * begin = data.data();
const UInt8 * const begin = haystack_data.data();
const UInt8 * const end = haystack_data.data() + haystack_data.size();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// FIXME: suboptimal
memset(&res[0], 0, res.size() * sizeof(res[0]));
@ -52,15 +54,15 @@ struct CountSubstringsImpl
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Determine which index it refers to.
while (begin + offsets[i] <= pos)
while (begin + haystack_offsets[i] <= pos)
++i;
auto start = start_pos != nullptr ? start_pos->getUInt(i) : 0;
/// We check that the entry does not pass through the boundaries of strings.
if (pos + needle.size() < begin + offsets[i])
if (pos + needle.size() < begin + haystack_offsets[i])
{
auto res_pos = needle.size() + Impl::countChars(reinterpret_cast<const char *>(begin + offsets[i - 1]), reinterpret_cast<const char *>(pos));
auto res_pos = needle.size() + Impl::countChars(reinterpret_cast<const char *>(begin + haystack_offsets[i - 1]), reinterpret_cast<const char *>(pos));
if (res_pos >= start)
{
++res[i];
@ -69,14 +71,14 @@ struct CountSubstringsImpl
pos += needle.size();
continue;
}
pos = begin + offsets[i];
pos = begin + haystack_offsets[i];
++i;
}
}
/// Count number of occurrences of substring in string.
static void constantConstantScalar(
std::string data,
std::string haystack,
std::string needle,
UInt64 start_pos,
UInt64 & res)
@ -87,9 +89,9 @@ struct CountSubstringsImpl
return;
auto start = std::max(start_pos, UInt64(1));
size_t start_byte = Impl::advancePos(data.data(), data.data() + data.size(), start - 1) - data.data();
size_t start_byte = Impl::advancePos(haystack.data(), haystack.data() + haystack.size(), start - 1) - haystack.data();
size_t new_start_byte;
while ((new_start_byte = data.find(needle, start_byte)) != std::string::npos)
while ((new_start_byte = haystack.find(needle, start_byte)) != std::string::npos)
{
++res;
/// Intersecting substrings in haystack accounted only once
@ -99,21 +101,21 @@ struct CountSubstringsImpl
/// Count number of occurrences of substring in string starting from different positions.
static void constantConstant(
std::string data,
std::string haystack,
std::string needle,
const ColumnPtr & start_pos,
PaddedPODArray<UInt64> & res)
{
Impl::toLowerIfNeed(data);
Impl::toLowerIfNeed(haystack);
Impl::toLowerIfNeed(needle);
if (start_pos == nullptr)
{
constantConstantScalar(data, needle, 0, res[0]);
constantConstantScalar(haystack, needle, 0, res[0]);
return;
}
size_t haystack_size = Impl::countChars(data.data(), data.data() + data.size());
size_t haystack_size = Impl::countChars(haystack.data(), haystack.data() + haystack.size());
size_t size = start_pos != nullptr ? start_pos->size() : 0;
for (size_t i = 0; i < size; ++i)
@ -125,7 +127,7 @@ struct CountSubstringsImpl
res[i] = 0;
continue;
}
constantConstantScalar(data, needle, start, res[i]);
constantConstantScalar(haystack, needle, start, res[i]);
}
}
@ -228,6 +230,12 @@ struct CountSubstringsImpl
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
template <typename... Args>
static void vectorFixedVector(Args &&...)
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
};
}

View File

@ -213,7 +213,7 @@ private:
template <typename Name, Float64(Function)(Float64, Float64)>
struct BinaryFunctionPlain
struct BinaryFunctionVectorized
{
static constexpr auto name = Name::name;
static constexpr auto rows_per_iteration = 1;
@ -225,6 +225,4 @@ struct BinaryFunctionPlain
}
};
#define BinaryFunctionVectorized BinaryFunctionPlain
}

View File

@ -42,9 +42,8 @@ struct UnaryOperationImpl
using ArrayA = typename ColVecA::Container;
using ArrayC = typename ColVecC::Container;
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(vectorImpl,
MULTITARGET_FH(
static void NO_INLINE), /*vectorImpl*/ MULTITARGET_FB((const ArrayA & a, ArrayC & c) /// NOLINT
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(static void NO_INLINE), vectorImpl, MULTITARGET_FUNCTION_BODY((const ArrayA & a, ArrayC & c) /// NOLINT
{
size_t size = a.size();
for (size_t i = 0; i < size; ++i)
@ -79,9 +78,9 @@ struct UnaryOperationImpl
template <typename Op>
struct FixedStringUnaryOperationImpl
{
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(vectorImpl,
MULTITARGET_FH(
static void NO_INLINE), /*vectorImpl*/ MULTITARGET_FB((const ColumnFixedString::Chars & a, ColumnFixedString::Chars & c) /// NOLINT
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(static void NO_INLINE), vectorImpl, MULTITARGET_FUNCTION_BODY((const ColumnFixedString::Chars & a, /// NOLINT
ColumnFixedString::Chars & c)
{
size_t size = a.size();
for (size_t i = 0; i < size; ++i)

View File

@ -253,13 +253,13 @@ struct UnbinImpl
/// Encode number or string to string with binary or hexadecimal representation
template <typename Impl>
class EncodeToBinaryRepr : public IFunction
class EncodeToBinaryRepresentation : public IFunction
{
public:
static constexpr auto name = Impl::name;
static constexpr size_t word_size = Impl::word_size;
static FunctionPtr create(ContextPtr) { return std::make_shared<EncodeToBinaryRepr>(); }
static FunctionPtr create(ContextPtr) { return std::make_shared<EncodeToBinaryRepresentation>(); }
String getName() const override { return name; }
@ -550,12 +550,12 @@ public:
/// Decode number or string from string with binary or hexadecimal representation
template <typename Impl>
class DecodeFromBinaryRepr : public IFunction
class DecodeFromBinaryRepresentation : public IFunction
{
public:
static constexpr auto name = Impl::name;
static constexpr size_t word_size = Impl::word_size;
static FunctionPtr create(ContextPtr) { return std::make_shared<DecodeFromBinaryRepr>(); }
static FunctionPtr create(ContextPtr) { return std::make_shared<DecodeFromBinaryRepresentation>(); }
String getName() const override { return name; }
@ -623,10 +623,10 @@ public:
void registerFunctionsBinaryRepr(FunctionFactory & factory)
{
factory.registerFunction<EncodeToBinaryRepr<HexImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<DecodeFromBinaryRepr<UnhexImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<EncodeToBinaryRepr<BinImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<DecodeFromBinaryRepr<UnbinImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<EncodeToBinaryRepresentation<HexImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<DecodeFromBinaryRepresentation<UnhexImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<EncodeToBinaryRepresentation<BinImpl>>(FunctionFactory::CaseInsensitive);
factory.registerFunction<DecodeFromBinaryRepresentation<UnbinImpl>>(FunctionFactory::CaseInsensitive);
}
}

View File

@ -85,8 +85,9 @@ struct NumComparisonImpl
using ContainerA = PaddedPODArray<A>;
using ContainerB = PaddedPODArray<B>;
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(vectorVectorImpl,
MULTITARGET_FH(static void), /*vectorVectorImpl*/ MULTITARGET_FB((const ContainerA & a, const ContainerB & b, PaddedPODArray<UInt8> & c) /// NOLINT
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(static void), vectorVectorImpl, MULTITARGET_FUNCTION_BODY(( /// NOLINT
const ContainerA & a, const ContainerB & b, PaddedPODArray<UInt8> & c)
{
/** GCC 4.8.2 vectorizes a loop only if it is written in this form.
* In this case, if you loop through the array index (the code will look simpler),
@ -127,8 +128,9 @@ struct NumComparisonImpl
}
MULTITARGET_FUNCTION_WRAPPER_AVX2_SSE42(vectorConstantImpl,
MULTITARGET_FH(static void), /*vectorConstantImpl*/ MULTITARGET_FB((const ContainerA & a, B b, PaddedPODArray<UInt8> & c) /// NOLINT
MULTITARGET_FUNCTION_AVX2_SSE42(
MULTITARGET_FUNCTION_HEADER(static void), vectorConstantImpl, MULTITARGET_FUNCTION_BODY(( /// NOLINT
const ContainerA & a, B b, PaddedPODArray<UInt8> & c)
{
size_t size = a.size();
const A * __restrict a_pos = a.data();

View File

@ -15,7 +15,6 @@
namespace DB
{
/** Search and replace functions in strings:
*
* position(haystack, needle) - the normal search for a substring in a string, returns the position (in bytes) of the found substring starting with 1, or 0 if no substring is found.
* positionUTF8(haystack, needle) - the same, but the position is calculated at code points, provided that the string is encoded in UTF-8.
* positionCaseInsensitive(haystack, needle)
@ -24,13 +23,29 @@ namespace DB
* like(haystack, pattern) - search by the regular expression LIKE; Returns 0 or 1. Case-insensitive, but only for Latin.
* notLike(haystack, pattern)
*
* ilike(haystack, pattern) - like 'like' but case-insensitive
* notIlike(haystack, pattern)
*
* match(haystack, pattern) - search by regular expression re2; Returns 0 or 1.
* multiMatchAny(haystack, [pattern_1, pattern_2, ..., pattern_n]) -- search by re2 regular expressions pattern_i; Returns 0 or 1 if any pattern_i matches.
* multiMatchAnyIndex(haystack, [pattern_1, pattern_2, ..., pattern_n]) -- search by re2 regular expressions pattern_i; Returns index of any match or zero if none;
* multiMatchAllIndices(haystack, [pattern_1, pattern_2, ..., pattern_n]) -- search by re2 regular expressions pattern_i; Returns an array of matched indices in any order;
*
* countSubstrings(haystack, needle) -- count number of occurrences of needle in haystack.
* countSubstringsCaseInsensitive(haystack, needle)
* countSubstringsCaseInsensitiveUTF8(haystack, needle)
*
* hasToken()
* hasTokenCaseInsensitive()
*
* JSON stuff:
* visitParamExtractBool()
* simpleJSONExtractBool()
* visitParamExtractFloat()
* simpleJSONExtractFloat()
* visitParamExtractInt()
* simpleJSONExtractInt()
* visitParamExtractUInt()
* simpleJSONExtractUInt()
* visitParamHas()
* simpleJSONHas()
*
* Applies regexp re2 and pulls:
* - the first subpattern, if the regexp has a subpattern;
@ -70,11 +85,7 @@ public:
ColumnNumbers getArgumentsThatAreAlwaysConstant() const override
{
if (!Impl::use_default_implementation_for_constants)
return ColumnNumbers{};
if (!Impl::supports_start_pos)
return ColumnNumbers{1, 2};
return ColumnNumbers{1, 2, 3};
return Impl::getArgumentsThatAreAlwaysConstant();
}
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
@ -104,8 +115,6 @@ public:
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & result_type, size_t /*input_rows_count*/) const override
{
using ResultType = typename Impl::ResultType;
const ColumnPtr & column_haystack = arguments[0].column;
const ColumnPtr & column_needle = arguments[1].column;
@ -116,6 +125,8 @@ public:
const ColumnConst * col_haystack_const = typeid_cast<const ColumnConst *>(&*column_haystack);
const ColumnConst * col_needle_const = typeid_cast<const ColumnConst *>(&*column_needle);
using ResultType = typename Impl::ResultType;
if constexpr (!Impl::use_default_implementation_for_constants)
{
bool is_col_start_pos_const = column_start_pos == nullptr || isColumnConst(*column_start_pos);
@ -162,6 +173,14 @@ public:
col_needle_const->getValue<String>(),
column_start_pos,
vec_res);
else if (col_haystack_vector_fixed && col_needle_vector)
Impl::vectorFixedVector(
col_haystack_vector_fixed->getChars(),
col_haystack_vector_fixed->getN(),
col_needle_vector->getChars(),
col_needle_vector->getOffsets(),
column_start_pos,
vec_res);
else if (col_haystack_vector_fixed && col_needle_const)
Impl::vectorFixedConstant(
col_haystack_vector_fixed->getChars(),

View File

@ -83,10 +83,12 @@ struct ExtractParamImpl
static constexpr bool supports_start_pos = false;
static constexpr auto name = Name::name;
static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1, 2};}
/// It is assumed that `res` is the correct size and initialized with zeros.
static void vectorConstant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const ColumnString::Chars & haystack_data,
const ColumnString::Offsets & haystack_offsets,
std::string needle,
const ColumnPtr & start_pos,
PaddedPODArray<ResultType> & res)
@ -97,9 +99,9 @@ struct ExtractParamImpl
/// We are looking for a parameter simply as a substring of the form "name"
needle = "\"" + needle + "\":";
const UInt8 * begin = data.data();
const UInt8 * const begin = haystack_data.data();
const UInt8 * const end = haystack_data.data() + haystack_data.size();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// The current index in the string array.
size_t i = 0;
@ -110,19 +112,19 @@ struct ExtractParamImpl
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Let's determine which index it belongs to.
while (begin + offsets[i] <= pos)
while (begin + haystack_offsets[i] <= pos)
{
res[i] = 0;
++i;
}
/// We check that the entry does not pass through the boundaries of strings.
if (pos + needle.size() < begin + offsets[i])
res[i] = ParamExtractor::extract(pos + needle.size(), begin + offsets[i] - 1); /// don't include terminating zero
if (pos + needle.size() < begin + haystack_offsets[i])
res[i] = ParamExtractor::extract(pos + needle.size(), begin + haystack_offsets[i] - 1); /// don't include terminating zero
else
res[i] = 0;
pos = begin + offsets[i];
pos = begin + haystack_offsets[i];
++i;
}
@ -145,6 +147,12 @@ struct ExtractParamImpl
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
template <typename... Args>
static void vectorFixedVector(Args &&...)
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
};
@ -153,20 +161,20 @@ struct ExtractParamImpl
template <typename ParamExtractor>
struct ExtractParamToStringImpl
{
static void vector(const ColumnString::Chars & data, const ColumnString::Offsets & offsets,
static void vector(const ColumnString::Chars & haystack_data, const ColumnString::Offsets & haystack_offsets,
std::string needle,
ColumnString::Chars & res_data, ColumnString::Offsets & res_offsets)
{
/// Constant 5 is taken from a function that performs a similar task FunctionsStringSearch.h::ExtractImpl
res_data.reserve(data.size() / 5);
res_offsets.resize(offsets.size());
res_data.reserve(haystack_data.size() / 5);
res_offsets.resize(haystack_offsets.size());
/// We are looking for a parameter simply as a substring of the form "name"
needle = "\"" + needle + "\":";
const UInt8 * begin = data.data();
const UInt8 * const begin = haystack_data.data();
const UInt8 * const end = haystack_data.data() + haystack_data.size();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// The current index in the string array.
size_t i = 0;
@ -177,7 +185,7 @@ struct ExtractParamToStringImpl
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Determine which index it belongs to.
while (begin + offsets[i] <= pos)
while (begin + haystack_offsets[i] <= pos)
{
res_data.push_back(0);
res_offsets[i] = res_data.size();
@ -185,10 +193,10 @@ struct ExtractParamToStringImpl
}
/// We check that the entry does not pass through the boundaries of strings.
if (pos + needle.size() < begin + offsets[i])
ParamExtractor::extract(pos + needle.size(), begin + offsets[i], res_data);
if (pos + needle.size() < begin + haystack_offsets[i])
ParamExtractor::extract(pos + needle.size(), begin + haystack_offsets[i], res_data);
pos = begin + offsets[i];
pos = begin + haystack_offsets[i];
res_data.push_back(0);
res_offsets[i] = res_data.size();

View File

@ -1,6 +1,7 @@
#pragma once
#include <Columns/ColumnString.h>
#include <Core/ColumnNumbers.h>
namespace DB
@ -14,7 +15,7 @@ namespace ErrorCodes
/** Token search the string, means that needle must be surrounded by some separator chars, like whitespace or puctuation.
*/
template <typename Name, typename TokenSearcher, bool negate_result = false>
template <typename Name, typename TokenSearcher, bool negate>
struct HasTokenImpl
{
using ResultType = UInt8;
@ -23,9 +24,11 @@ struct HasTokenImpl
static constexpr bool supports_start_pos = false;
static constexpr auto name = Name::name;
static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {1, 2};}
static void vectorConstant(
const ColumnString::Chars & data,
const ColumnString::Offsets & offsets,
const ColumnString::Chars & haystack_data,
const ColumnString::Offsets & haystack_offsets,
const std::string & pattern,
const ColumnPtr & start_pos,
PaddedPODArray<UInt8> & res)
@ -33,12 +36,12 @@ struct HasTokenImpl
if (start_pos != nullptr)
throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Function '{}' does not support start_pos argument", name);
if (offsets.empty())
if (haystack_offsets.empty())
return;
const UInt8 * begin = data.data();
const UInt8 * const begin = haystack_data.data();
const UInt8 * const end = haystack_data.data() + haystack_data.size();
const UInt8 * pos = begin;
const UInt8 * end = pos + data.size();
/// The current index in the array of strings.
size_t i = 0;
@ -49,25 +52,25 @@ struct HasTokenImpl
while (pos < end && end != (pos = searcher.search(pos, end - pos)))
{
/// Let's determine which index it refers to.
while (begin + offsets[i] <= pos)
while (begin + haystack_offsets[i] <= pos)
{
res[i] = negate_result;
res[i] = negate;
++i;
}
/// We check that the entry does not pass through the boundaries of strings.
if (pos + pattern.size() < begin + offsets[i])
res[i] = !negate_result;
if (pos + pattern.size() < begin + haystack_offsets[i])
res[i] = !negate;
else
res[i] = negate_result;
res[i] = negate;
pos = begin + offsets[i];
pos = begin + haystack_offsets[i];
++i;
}
/// Tail, in which there can be no substring.
if (i < res.size())
memset(&res[i], negate_result, (res.size() - i) * sizeof(res[0]));
memset(&res[i], negate, (res.size() - i) * sizeof(res[0]));
}
template <typename... Args>
@ -88,6 +91,12 @@ struct HasTokenImpl
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
template <typename... Args>
static void vectorFixedVector(Args &&...)
{
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Function '{}' doesn't support FixedString haystack argument", name);
}
};
}

Some files were not shown because too many files have changed in this diff Show More