Merge remote-tracking branch 'origin' into unbundled-packager

This commit is contained in:
Yatsishin Ilya 2020-08-17 11:25:45 +03:00
commit edcd6d5502
197 changed files with 3497 additions and 2603 deletions

View File

@ -6,6 +6,8 @@
#include <string.h>
#include <unistd.h>
#include <iostream>
namespace
{
@ -107,6 +109,8 @@ ReadlineLineReader::ReadlineLineReader(
throw std::runtime_error(std::string("Cannot set signal handler for readline: ") + strerror(errno));
rl_variable_bind("completion-ignore-case", "on");
// TODO: it doesn't work
// history_write_timestamps = 1;
}
ReadlineLineReader::~ReadlineLineReader()
@ -129,6 +133,11 @@ LineReader::InputStatus ReadlineLineReader::readOneLine(const String & prompt)
void ReadlineLineReader::addToHistory(const String & line)
{
add_history(line.c_str());
// Flush changes to the disk
// NOTE readline builds a buffer of all the lines to write, and write them in one syscall.
// Thus there is no need to lock the history file here.
write_history(history_file_path.c_str());
}
#if RL_VERSION_MAJOR >= 7

View File

@ -2,7 +2,7 @@ Go to https://www.monetdb.org/
The graphical design of the website is a bit old-fashioned but I do not afraid.
Dowload now.
Download now.
Latest binary releases.
Ubuntu & Debian.
@ -1103,7 +1103,7 @@ Ok, it's doing something at least for twenty minues...
clk: 28:02 min
```
Finally it has loaded data successfuly in 28 minutes. It's not fast - just below 60 000 rows per second.
Finally it has loaded data successfully in 28 minutes. It's not fast - just below 60 000 rows per second.
But the second query from the test does not work:

View File

@ -1,6 +1,10 @@
option(ENABLE_CASSANDRA "Enable Cassandra" ${ENABLE_LIBRARIES})
if (ENABLE_CASSANDRA)
if (APPLE)
SET(CMAKE_MACOSX_RPATH ON)
endif()
if (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/libuv")
message (ERROR "submodule contrib/libuv is missing. to fix try run: \n git submodule update --init --recursive")
elseif (NOT EXISTS "${ClickHouse_SOURCE_DIR}/contrib/cassandra")

View File

@ -122,11 +122,11 @@ initdb()
CLICKHOUSE_DATADIR_FROM_CONFIG=$CLICKHOUSE_DATADIR
fi
if ! getent group ${CLICKHOUSE_USER} >/dev/null; then
if ! getent passwd ${CLICKHOUSE_USER} >/dev/null; then
echo "Can't chown to non-existing user ${CLICKHOUSE_USER}"
return
fi
if ! getent passwd ${CLICKHOUSE_GROUP} >/dev/null; then
if ! getent group ${CLICKHOUSE_GROUP} >/dev/null; then
echo "Can't chown to non-existing group ${CLICKHOUSE_GROUP}"
return
fi

View File

@ -2,7 +2,6 @@
"docker/packager/deb": {
"name": "yandex/clickhouse-deb-builder",
"dependent": [
"docker/packager/unbundled",
"docker/test/stateless",
"docker/test/stateless_with_coverage",
"docker/test/stateless_pytest",
@ -16,10 +15,6 @@
"docker/test/pvs"
]
},
"docker/packager/unbundled": {
"name": "yandex/clickhouse-unbundled-builder",
"dependent": []
},
"docker/test/coverage": {
"name": "yandex/clickhouse-coverage",
"dependent": []
@ -97,6 +92,10 @@
"name": "yandex/clickhouse-fasttest",
"dependent": []
},
"docker/test/style": {
"name": "yandex/clickhouse-style-test",
"dependent": []
},
"docker/test/integration/s3_proxy": {
"name": "yandex/clickhouse-s3-proxy",
"dependent": []

View File

@ -1,9 +1,9 @@
# docker build -t yandex/clickhouse-deb-builder .
FROM ubuntu:20.04
FROM ubuntu:19.10
RUN apt-get --allow-unauthenticated update -y && apt-get install --yes wget gnupg
RUN wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add -
RUN echo "deb [trusted=yes] http://apt.llvm.org/focal/ llvm-toolchain-focal-11 main" >> /etc/apt/sources.list
RUN echo "deb [trusted=yes] http://apt.llvm.org/eoan/ llvm-toolchain-eoan-10 main" >> /etc/apt/sources.list
# initial packages
RUN apt-get --allow-unauthenticated update -y \
@ -25,17 +25,13 @@ RUN curl -O https://clickhouse-builds.s3.yandex.net/utils/1/dpkg-deb
RUN chmod +x dpkg-deb
RUN cp dpkg-deb /usr/bin
# Libraries from OS are only needed to test the "unbundled" build (that is not used in production).
RUN apt-get --allow-unauthenticated update -y \
&& env DEBIAN_FRONTEND=noninteractive \
apt-get --allow-unauthenticated install --yes --no-install-recommends \
gcc-10 \
g++-10 \
gcc-9 \
g++-9 \
llvm-11 \
clang-11 \
lld-11 \
clang-tidy-11 \
llvm-10 \
clang-10 \
lld-10 \
@ -43,19 +39,54 @@ RUN apt-get --allow-unauthenticated update -y \
clang-9 \
lld-9 \
clang-tidy-9 \
libicu-dev \
libreadline-dev \
gperf \
ninja-build \
perl \
pkg-config \
devscripts \
debhelper \
git \
libc++-dev \
libc++abi-dev \
libboost-program-options-dev \
libboost-system-dev \
libboost-filesystem-dev \
libboost-thread-dev \
libboost-iostreams-dev \
libboost-regex-dev \
zlib1g-dev \
liblz4-dev \
libdouble-conversion-dev \
librdkafka-dev \
libpoconetssl62 \
libpoco-dev \
libgoogle-perftools-dev \
libzstd-dev \
libltdl-dev \
libre2-dev \
libjemalloc-dev \
libmsgpack-dev \
libcurl4-openssl-dev \
opencl-headers \
ocl-icd-libopencl1 \
intel-opencl-icd \
unixodbc-dev \
odbcinst \
tzdata \
gperf \
alien \
libcapnp-dev \
cmake \
gdb \
pigz \
moreutils \
pigz
libcctz-dev \
libldap2-dev \
libsasl2-dev \
heimdal-multidev \
libhyperscan-dev
# This symlink required by gcc to find lld compiler

View File

@ -11,7 +11,6 @@ SCRIPT_PATH = os.path.realpath(__file__)
IMAGE_MAP = {
"deb": "yandex/clickhouse-deb-builder",
"binary": "yandex/clickhouse-binary-builder",
"unbundled": "yandex/clickhouse-unbundled-builder"
}
def check_image_exists_locally(image_name):
@ -177,9 +176,7 @@ if __name__ == "__main__":
parser.add_argument("--clickhouse-repo-path", default=os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir))
parser.add_argument("--output-dir", required=True)
parser.add_argument("--build-type", choices=("debug", ""), default="")
parser.add_argument("--compiler", choices=("clang-10", "clang-10-darwin", "clang-10-aarch64", "clang-10-freebsd",
"clang-11", "clang-11-darwin", "clang-11-aarch64", "clang-11-freebsd",
"gcc-9", "gcc-10"), default="gcc-9")
parser.add_argument("--compiler", choices=("clang-10-darwin", "clang-10-aarch64", "clang-10-freebsd", "gcc-9", "clang-10"), default="gcc-9")
parser.add_argument("--sanitizer", choices=("address", "thread", "memory", "undefined", ""), default="")
parser.add_argument("--unbundled", action="store_true")
parser.add_argument("--split-binary", action="store_true")
@ -200,7 +197,7 @@ if __name__ == "__main__":
if not os.path.isabs(args.output_dir):
args.output_dir = os.path.abspath(os.path.join(os.getcwd(), args.output_dir))
image_type = 'binary' if args.package_type == 'performance' else 'unbundled' if args.unbundled else args.package_type
image_type = 'binary' if args.package_type == 'performance' else args.package_type
image_name = IMAGE_MAP[image_type]
if not os.path.isabs(args.clickhouse_repo_path):

View File

@ -1,56 +0,0 @@
# docker build -t yandex/clickhouse-unbundled-builder .
FROM yandex/clickhouse-deb-builder
# Libraries from OS are only needed to test the "unbundled" build (that is not used in production).
RUN apt-get --allow-unauthenticated update -y \
&& env DEBIAN_FRONTEND=noninteractive \
apt-get --allow-unauthenticated install --yes --no-install-recommends \
libicu-dev \
libreadline-dev \
gperf \
perl \
pkg-config \
devscripts \
libc++-dev \
libc++abi-dev \
libboost-program-options-dev \
libboost-system-dev \
libboost-filesystem-dev \
libboost-thread-dev \
libboost-iostreams-dev \
libboost-regex-dev \
zlib1g-dev \
liblz4-dev \
libdouble-conversion-dev \
librdkafka-dev \
libpoconetssl62 \
libpoco-dev \
libgoogle-perftools-dev \
libzstd-dev \
libltdl-dev \
libre2-dev \
libjemalloc-dev \
libmsgpack-dev \
libcurl4-openssl-dev \
opencl-headers \
ocl-icd-libopencl1 \
intel-opencl-icd \
unixodbc-dev \
odbcinst \
tzdata \
gperf \
alien \
libcapnp-dev \
cmake \
gdb \
pigz \
moreutils \
libcctz-dev \
libldap2-dev \
libsasl2-dev \
heimdal-multidev \
libhyperscan-dev
COPY build.sh /
CMD ["/bin/bash", "/build.sh"]

View File

@ -1,18 +0,0 @@
#!/usr/bin/env bash
set -x -e
# Update tzdata to the latest version. It is embedded into clickhouse binary.
sudo apt-get update && sudo apt-get install tzdata
ccache --show-stats ||:
ccache --zero-stats ||:
build/release --no-pbuilder $ALIEN_PKGS | ts '%Y-%m-%d %H:%M:%S'
mv /*.deb /output
mv *.changes /output
mv *.buildinfo /output
mv /*.rpm /output ||: # if exists
mv /*.tgz /output ||: # if exists
ccache --show-stats ||:
ln -s /usr/lib/x86_64-linux-gnu/libOpenCL.so.1.0.0 /usr/lib/libOpenCL.so ||:

View File

@ -158,6 +158,8 @@ TESTS_TO_SKIP=(
01280_ssd_complex_key_dictionary
00652_replicated_mutations_zookeeper
01411_bayesian_ab_testing
01238_http_memory_tracking # max_memory_usage_for_user can interfere another queries running concurrently
01281_group_by_limit_memory_tracking # max_memory_usage_for_user can interfere another queries running concurrently
)
clickhouse-test -j 4 --no-long --testname --shard --zookeeper --skip ${TESTS_TO_SKIP[*]} 2>&1 | ts '%Y-%m-%d %H:%M:%S' | tee /test_output/test_log.txt

View File

@ -24,20 +24,26 @@ def run_perf_test(cmd, xmls_path, output_folder):
return p
def get_options(i):
options = ""
if 0 < i:
options += " --order=random"
if i == 1:
options += " --atomic-db-engine"
return options
def run_func_test(cmd, output_prefix, num_processes, skip_tests_option):
skip_list_opt = get_skip_list_cmd(cmd)
output_paths = [os.path.join(output_prefix, "stress_test_run_{}.txt".format(i)) for i in range(num_processes)]
f = open(output_paths[0], 'w')
main_command = "{} {} {}".format(cmd, skip_list_opt, skip_tests_option)
logging.info("Run func tests main cmd '%s'", main_command)
pipes = [Popen(main_command, shell=True, stdout=f, stderr=f)]
for output_path in output_paths[1:]:
time.sleep(0.5)
f = open(output_path, 'w')
full_command = "{} {} --order=random {}".format(cmd, skip_list_opt, skip_tests_option)
pipes = []
for i in range(0, len(output_paths)):
f = open(output_paths[i], 'w')
full_command = "{} {} {} {}".format(cmd, skip_list_opt, get_options(i), skip_tests_option)
logging.info("Run func tests '%s'", full_command)
p = Popen(full_command, shell=True, stdout=f, stderr=f)
pipes.append(p)
time.sleep(0.5)
return pipes

View File

@ -0,0 +1,8 @@
# docker build -t yandex/clickhouse-style-test .
FROM ubuntu:20.04
RUN apt-get update && env DEBIAN_FRONTEND=noninteractive apt-get install --yes shellcheck libxml2-utils git python3-pip && pip3 install codespell
CMD cd /ClickHouse/utils/check-style && ./check-style -n | tee /test_output/style_output.txt && \
./check-duplicate-includes.sh | tee /test_output/duplicate_output.txt

View File

@ -31,7 +31,7 @@ For a description of request parameters, see [statement description](../../../sq
**ReplacingMergeTree Parameters**
- `ver` — column with version. Type `UInt*`, `Date` or `DateTime`. Optional parameter.
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` or `DateTime64`. Optional parameter.
When merging, `ReplacingMergeTree` from all the rows with the same sorting key leaves only one:

View File

@ -8,7 +8,7 @@ toc_title: Quotas
Quotas allow you to limit resource usage over a period of time or track the use of resources.
Quotas are set up in the user config, which is usually users.xml.
The system also has a feature for limiting the complexity of a single query. See the section “Restrictions on query complexity”).
The system also has a feature for limiting the complexity of a single query. See the section [Restrictions on query complexity](../operations/settings/query-complexity.md).
In contrast to query complexity restrictions, quotas:

View File

@ -399,7 +399,7 @@ The cache is shared for the server and memory is allocated as needed. The cache
```
## max\_server\_memory\_usage {#max_server_memory_usage}
Limits total RAM usage by the ClickHouse server. You can specify it only for the default profile.
Limits total RAM usage by the ClickHouse server.
Possible values:

View File

@ -0,0 +1,85 @@
---
toc_priority: 114
---
# groupArraySample {#grouparraysample}
Creates an array of sample argument values. The size of the resulting array is limited to `max_size` elements. Argument values are selected and added to the array randomly.
**Syntax**
``` sql
groupArraySample(max_size)(x)
```
or
``` sql
groupArraySample(max_size, seed)(x)
```
**Parameters**
- `max_size` — Maximum size of the resulting array. Positive [UInt64](../../data-types/int-uint.md).
- `seed` — Seed for the random number generator. Optional, can be omitted. Positive [UInt64](../../data-types/int-uint.md). Default value: `123456`.
- `x` — Argument name. [String](../../data-types/string.md).
**Returned values**
- Array of randomly selected `x` arguments.
Type: [Array](../../data-types/array.md).
**Examples**
Consider table `colors`:
``` text
┌─id─┬─color──┐
│ 1 │ red │
│ 2 │ blue │
│ 3 │ green │
│ 4 │ white │
│ 5 │ orange │
└────┴────────┘
```
Select `id`-s query:
``` sql
SELECT groupArraySample(3)(id) FROM colors;
```
Result:
``` text
┌─groupArraySample(3)(id)─┐
│ [1,2,4] │
└─────────────────────────┘
```
Select `color`-s query:
``` sql
SELECT groupArraySample(3)(color) FROM colors;
```
Result:
```text
┌─groupArraySample(3)(color)─┐
│ ['white','blue','green'] │
└────────────────────────────┘
```
Select `color`-s query with different seed:
``` sql
SELECT groupArraySample(3, 987654321)(color) FROM colors;
```
Result:
```text
┌─groupArraySample(3, 987654321)(color)─┐
│ ['red','orange','green'] │
└───────────────────────────────────────┘
```

View File

@ -33,7 +33,7 @@ Para obtener una descripción de los parámetros de solicitud, consulte [descrip
**ReplacingMergeTree Parámetros**
- `ver` — column with version. Type `UInt*`, `Date` o `DateTime`. Parámetro opcional.
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` o `DateTime64`. Parámetro opcional.
Al fusionar, `ReplacingMergeTree` de todas las filas con la misma clave primaria deja solo una:

View File

@ -33,7 +33,7 @@ CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster]
**پارامترهای جایگزین**
- `ver` — column with version. Type `UInt*`, `Date` یا `DateTime`. پارامتر اختیاری.
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` یا `DateTime64`. پارامتر اختیاری.
هنگام ادغام, `ReplacingMergeTree` از تمام ردیف ها با همان کلید اصلی تنها یک برگ دارد:

View File

@ -33,7 +33,7 @@ Pour une description des paramètres de requête, voir [demande de description](
**ReplacingMergeTree Paramètres**
- `ver` — column with version. Type `UInt*`, `Date` ou `DateTime`. Paramètre facultatif.
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` ou `DateTime64`. Paramètre facultatif.
Lors de la fusion, `ReplacingMergeTree` de toutes les lignes avec la même clé primaire ne laisse qu'un:

View File

@ -33,7 +33,7 @@ CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster]
**ReplacingMergeTreeパラメータ**
- `ver` — column with version. Type `UInt*`, `Date` または `DateTime`. 任意パラメータ。
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` または `DateTime64`. 任意パラメータ。
マージ時, `ReplacingMergeTree` 同じ主キーを持つすべての行から、一つだけを残します:

View File

@ -1,3 +1,5 @@
# Инструкция для разработчиков
Сборка ClickHouse поддерживается на Linux, FreeBSD, Mac OS X.
# Если вы используете Windows {#esli-vy-ispolzuete-windows}

View File

@ -25,7 +25,7 @@ CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster]
**Параметры ReplacingMergeTree**
- `ver` — столбец с версией, тип `UInt*`, `Date` или `DateTime`. Необязательный параметр.
- `ver` — столбец с версией, тип `UInt*`, `Date`, `DateTime` или `DateTime64`. Необязательный параметр.
При слиянии, из всех строк с одинаковым значением ключа сортировки `ReplacingMergeTree` оставляет только одну:

View File

@ -374,7 +374,7 @@ ClickHouse проверит условия `min_part_size` и `min_part_size_rat
## max_server_memory_usage {#max_server_memory_usage}
Ограничивает объём оперативной памяти, используемой сервером ClickHouse. Настройка может быть задана только для профиля `default`.
Ограничивает объём оперативной памяти, используемой сервером ClickHouse.
Возможные значения:

View File

@ -637,7 +637,7 @@ Upd. Готово (все директории кроме contrib).
Требует 7.26. Коллеги начали делать, есть результат.
Upd. В Аркадии частично работает небольшая часть тестов. И этого достаточно.
### 7.29. Опции clickhouse install, stop, start вместо postinst, init.d, systemd скриптов {#optsii-clickhouse-install-stop-start-vmesto-postinst-init-d-systemd-skriptov}
### 7.29. + Опции clickhouse install, stop, start вместо postinst, init.d, systemd скриптов {#optsii-clickhouse-install-stop-start-vmesto-postinst-init-d-systemd-skriptov}
Низкий приоритет.
@ -786,7 +786,7 @@ Upd. Готово.
Павел Круглов, ВШЭ и Яндекс.
Есть pull request. Готово.
### 8.17. ClickHouse как MySQL реплика {#clickhouse-kak-mysql-replika}
### 8.17. + ClickHouse как MySQL реплика {#clickhouse-kak-mysql-replika}
Задачу делает BohuTANG.
@ -1447,11 +1447,11 @@ Upd. Возможно будет отложено на следующий год
Василий Морозов, Арслан Гумеров, Альберт Кидрачев, ВШЭ.
В прошлом году задачу начинал делать другой человек, но не добился достаточного прогресса.
+ 1. Оптимизация top sort.
\+ 1. Оптимизация top sort.
В ClickHouse используется неоптимальный вариант top sort. Суть его в том, что из каждого блока достаётся top N записей, а затем, все блоки мержатся. Но доставание top N записей у каждого следующего блока бессмысленно, если мы знаем, что из них в глобальный top N войдёт меньше. Конечно нужно реализовать вариацию на тему priority queue (heap) с быстрым пропуском целых блоков, если ни одна строка не попадёт в накопленный top.
+ 2. Рекурсивный вариант сортировки по кортежам.
\+ 2. Рекурсивный вариант сортировки по кортежам.
Для сортировки по кортежам используется обычная сортировка с компаратором, который в цикле по элементам кортежа делает виртуальные вызовы `IColumn::compareAt`. Это неоптимально - как из-за короткого цикла по неизвестному в compile-time количеству элементов, так и из-за виртуальных вызовов. Чтобы обойтись без виртуальных вызовов, есть метод `IColumn::getPermutation`. Он используется в случае сортировки по одному столбцу. Есть вариант, что в случае сортировки по кортежу, что-то похожее тоже можно применить… например, сделать метод `updatePermutation`, принимающий аргументы offset и limit, и допереставляющий перестановку в диапазоне значений, в которых предыдущий столбец имел равные значения.
@ -1583,8 +1583,8 @@ Upd. Готово.
После 10.14.
[\#7237](https://github.com/ClickHouse/ClickHouse/issues/7237)
[\#2655](https://github.com/ClickHouse/ClickHouse/issues/2655)
[#7237](https://github.com/ClickHouse/ClickHouse/issues/7237)
[#2655](https://github.com/ClickHouse/ClickHouse/issues/2655)
### 22.23. Правильная обработка Nullable в функциях, которые кидают исключение на default значении: modulo, intDiv {#pravilnaia-obrabotka-nullable-v-funktsiiakh-kotorye-kidaiut-iskliuchenie-na-default-znachenii-modulo-intdiv}
@ -1598,7 +1598,7 @@ Upd. Готово.
### 22.26. Плохая производительность quantileTDigest {#plokhaia-proizvoditelnost-quantiletdigest}
[\#2668](https://github.com/ClickHouse/ClickHouse/issues/2668)
[#2668](https://github.com/ClickHouse/ClickHouse/issues/2668)
Алексей Миловидов или будет переназначено.

View File

@ -33,7 +33,7 @@ CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster]
**ReplacingMergeTree Parametreleri**
- `ver` — column with version. Type `UInt*`, `Date` veya `DateTime`. İsteğe bağlı parametre.
- `ver` — column with version. Type `UInt*`, `Date`, `DateTime` veya `DateTime64`. İsteğe bağlı parametre.
Birleş whenirken, `ReplacingMergeTree` aynı birincil anahtara sahip tüm satırlardan sadece bir tane bırakır:

View File

@ -25,7 +25,7 @@ CREATE TABLE [IF NOT EXISTS] [db.]table_name [ON CLUSTER cluster]
**参数**
- `ver` — 版本列。类型为 `UInt*`, `Date``DateTime`。可选参数。
- `ver` — 版本列。类型为 `UInt*`, `Date`, `DateTime``DateTime64`。可选参数。
合并的时候,`ReplacingMergeTree` 从所有具有相同主键的行中选择一行留下:
- 如果 `ver` 列未指定,选择最后一条。

View File

@ -247,12 +247,15 @@ try
context->setCurrentDatabase(default_database);
applyCmdOptions();
if (!context->getPath().empty())
String path = context->getPath();
if (!path.empty())
{
/// Lock path directory before read
status.emplace(context->getPath() + "status", StatusFile::write_full_info);
LOG_DEBUG(log, "Loading metadata from {}", context->getPath());
LOG_DEBUG(log, "Loading metadata from {}", path);
Poco::File(path + "data/").createDirectories();
Poco::File(path + "metadata/").createDirectories();
loadMetadataSystem(*context);
attachSystemTables(*context);
loadMetadata(*context);

View File

@ -356,7 +356,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
std::string path = getCanonicalPath(config().getString("path", DBMS_DEFAULT_PATH));
std::string default_database = config().getString("default_database", "default");
/// Check that the process' user id matches the owner of the data.
/// Check that the process user id matches the owner of the data.
const auto effective_user_id = geteuid();
struct stat statbuf;
if (stat(path.c_str(), &statbuf) == 0 && effective_user_id != statbuf.st_uid)
@ -468,6 +468,9 @@ int Server::main(const std::vector<std::string> & /*args*/)
}
{
Poco::File(path + "data/").createDirectories();
Poco::File(path + "metadata/").createDirectories();
/// Directory with metadata of tables, which was marked as dropped by Atomic database
Poco::File(path + "metadata_dropped/").createDirectories();
}

View File

@ -388,13 +388,14 @@ public:
{
for (size_t j = 0; j < UNROLL_COUNT; ++j)
{
if (has_data[j * 256 + k])
size_t idx = j * 256 + k;
if (has_data[idx])
{
AggregateDataPtr & place = map[k];
if (unlikely(!place))
init(place);
func.merge(place + place_offset, reinterpret_cast<const char *>(&places[256 * j + k]), arena);
func.merge(place + place_offset, reinterpret_cast<const char *>(&places[idx]), nullptr);
}
}
}

View File

@ -140,6 +140,7 @@ endmacro()
add_object_library(clickhouse_access Access)
add_object_library(clickhouse_core Core)
add_object_library(clickhouse_core_mysql Core/MySQL)
add_object_library(clickhouse_compression Compression)
add_object_library(clickhouse_datastreams DataStreams)
add_object_library(clickhouse_datatypes DataTypes)

View File

@ -31,8 +31,17 @@ namespace ErrorCodes
extern const int PARAMETER_OUT_OF_BOUND;
extern const int SIZES_OF_COLUMNS_DOESNT_MATCH;
extern const int LOGICAL_ERROR;
extern const int TOO_LARGE_ARRAY_SIZE;
}
/** Obtaining array as Field can be slow for large arrays and consume vast amount of memory.
* Just don't allow to do it.
* You can increase the limit if the following query:
* SELECT range(10000000)
* will take less than 500ms on your machine.
*/
static constexpr size_t max_array_size_as_field = 1000000;
ColumnArray::ColumnArray(MutableColumnPtr && nested_column, MutableColumnPtr && offsets_column)
: data(std::move(nested_column)), offsets(std::move(offsets_column))
@ -117,6 +126,11 @@ Field ColumnArray::operator[](size_t n) const
{
size_t offset = offsetAt(n);
size_t size = sizeAt(n);
if (size > max_array_size_as_field)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array of size {} is too large to be manipulated as single field, maximum size {}",
size, max_array_size_as_field);
Array res(size);
for (size_t i = 0; i < size; ++i)
@ -130,6 +144,11 @@ void ColumnArray::get(size_t n, Field & res) const
{
size_t offset = offsetAt(n);
size_t size = sizeAt(n);
if (size > max_array_size_as_field)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Array of size {} is too large to be manipulated as single field, maximum size {}",
size, max_array_size_as_field);
res = Array(size);
Array & res_arr = DB::get<Array &>(res);

View File

@ -1,229 +0,0 @@
#include <Common/typeid_cast.h>
#include <Common/ZooKeeper/ZooKeeper.h>
#include <Common/ZooKeeper/KeeperException.h>
#include <Common/StringUtils/StringUtils.h>
#include <iostream>
#include <chrono>
#include <gtest/gtest.h>
#include <Common/ShellCommand.h>
using namespace DB;
template <typename... Args>
auto getZooKeeper(Args &&... args)
{
/// In our CI infrastructure it is typical that ZooKeeper is unavailable for some amount of time.
size_t i;
for (i = 0; i < 100; ++i)
{
try
{
auto zookeeper = std::make_unique<zkutil::ZooKeeper>("localhost:2181", std::forward<Args>(args)...);
zookeeper->exists("/");
zookeeper->createIfNotExists("/clickhouse_test", "Unit tests of ClickHouse");
return zookeeper;
}
catch (...)
{
std::cerr << "Zookeeper is unavailable, try " << i << std::endl;
sleep(1);
continue;
}
}
std::cerr << "No zookeeper after " << i << " tries. skip tests." << std::endl;
exit(0);
}
TEST(zkutil, MultiNiceExceptionMsg)
{
auto zookeeper = getZooKeeper();
Coordination::Requests ops;
ASSERT_NO_THROW(
zookeeper->tryRemoveRecursive("/clickhouse_test/zkutil_multi");
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi", "_", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "_", zkutil::CreateMode::Persistent));
zookeeper->multi(ops);
);
try
{
ops.clear();
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/c", "_", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeRemoveRequest("/clickhouse_test/zkutil_multi/c", -1));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "BadBoy", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/b", "_", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "_", zkutil::CreateMode::Persistent));
zookeeper->multi(ops);
FAIL();
}
catch (...)
{
zookeeper->tryRemoveRecursive("/clickhouse_test/zkutil_multi");
String msg = getCurrentExceptionMessage(false);
bool msg_has_reqired_patterns = msg.find("#2") != std::string::npos;
EXPECT_TRUE(msg_has_reqired_patterns) << msg;
}
}
TEST(zkutil, MultiAsync)
{
Coordination::Requests ops;
getZooKeeper()->tryRemoveRecursive("/clickhouse_test/zkutil_multi");
{
ops.clear();
auto zookeeper = getZooKeeper();
auto fut = zookeeper->asyncMulti(ops);
}
{
ops.clear();
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi", "", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "", zkutil::CreateMode::Persistent));
auto zookeeper = getZooKeeper();
auto fut = zookeeper->tryAsyncMulti(ops);
ops.clear();
auto res = fut.get();
ASSERT_EQ(res.error, Coordination::Error::ZOK);
ASSERT_EQ(res.responses.size(), 2);
}
EXPECT_ANY_THROW
(
auto zookeeper = getZooKeeper();
std::vector<std::future<Coordination::MultiResponse>> futures;
for (size_t i = 0; i < 10000; ++i)
{
ops.clear();
ops.emplace_back(zkutil::makeRemoveRequest("/clickhouse_test/zkutil_multi", -1));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi", "_", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCheckRequest("/clickhouse_test/zkutil_multi", -1));
ops.emplace_back(zkutil::makeSetRequest("/clickhouse_test/zkutil_multi", "xxx", 42));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "_", zkutil::CreateMode::Persistent));
futures.emplace_back(zookeeper->asyncMulti(ops));
}
futures[0].get();
);
/// Check there are no segfaults for remaining 999 futures
using namespace std::chrono_literals;
std::this_thread::sleep_for(1s);
try
{
ops.clear();
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi", "_", zkutil::CreateMode::Persistent));
ops.emplace_back(zkutil::makeCreateRequest("/clickhouse_test/zkutil_multi/a", "_", zkutil::CreateMode::Persistent));
auto zookeeper = getZooKeeper();
auto fut = zookeeper->tryAsyncMulti(ops);
ops.clear();
auto res = fut.get();
/// The test is quite heavy. It is normal if session is expired during this test.
/// If we don't check that, the test will be flacky.
if (res.error != Coordination::Error::ZSESSIONEXPIRED && res.error != Coordination::Error::ZCONNECTIONLOSS)
{
ASSERT_EQ(res.error, Coordination::Error::ZNODEEXISTS);
ASSERT_EQ(res.responses.size(), 2);
}
}
catch (const Coordination::Exception & e)
{
if (e.code != Coordination::Error::ZSESSIONEXPIRED && e.code != Coordination::Error::ZCONNECTIONLOSS)
throw;
}
}
TEST(zkutil, WatchGetChildrenWithChroot)
{
try
{
const String prefix = "/clickhouse_test/zkutil/watch_get_children_with_chroot";
/// Create chroot node firstly
auto zookeeper = getZooKeeper();
zookeeper->createAncestors(prefix + "/");
zookeeper = getZooKeeper("",
zkutil::DEFAULT_SESSION_TIMEOUT,
zkutil::DEFAULT_OPERATION_TIMEOUT,
prefix);
String queue_path = "/queue";
zookeeper->tryRemoveRecursive(queue_path);
zookeeper->createAncestors(queue_path + "/");
zkutil::EventPtr event = std::make_shared<Poco::Event>();
zookeeper->getChildren(queue_path, nullptr, event);
{
auto zookeeper2 = getZooKeeper("",
zkutil::DEFAULT_SESSION_TIMEOUT,
zkutil::DEFAULT_OPERATION_TIMEOUT,
prefix);
zookeeper2->create(queue_path + "/children-", "", zkutil::CreateMode::PersistentSequential);
}
event->wait();
}
catch (...)
{
std::cerr << getCurrentExceptionMessage(true);
throw;
}
}
TEST(zkutil, MultiCreateSequential)
{
try
{
const String prefix = "/clickhouse_test/zkutil";
/// Create chroot node firstly
auto zookeeper = getZooKeeper();
zookeeper->createAncestors(prefix + "/");
zookeeper = getZooKeeper("",
zkutil::DEFAULT_SESSION_TIMEOUT,
zkutil::DEFAULT_OPERATION_TIMEOUT,
"/clickhouse_test");
String base_path = "/multi_create_sequential";
zookeeper->tryRemoveRecursive(base_path);
zookeeper->createAncestors(base_path + "/");
Coordination::Requests ops;
String sequential_node_prefix = base_path + "/queue-";
ops.emplace_back(zkutil::makeCreateRequest(sequential_node_prefix, "", zkutil::CreateMode::EphemeralSequential));
auto results = zookeeper->multi(ops);
const auto & sequential_node_result_op = dynamic_cast<const Coordination::CreateResponse &>(*results.at(0));
EXPECT_FALSE(sequential_node_result_op.path_created.empty());
EXPECT_GT(sequential_node_result_op.path_created.length(), sequential_node_prefix.length());
EXPECT_EQ(sequential_node_result_op.path_created.substr(0, sequential_node_prefix.length()), sequential_node_prefix);
}
catch (...)
{
std::cerr << getCurrentExceptionMessage(false);
throw;
}
}

View File

@ -9,21 +9,34 @@ static bool check()
{
ThreadPool pool(10);
/// The throwing thread.
pool.scheduleOrThrowOnError([] { throw std::runtime_error("Hello, world!"); });
try
{
for (size_t i = 0; i < 500; ++i)
pool.scheduleOrThrowOnError([] {}); /// An exception will be rethrown from this method.
while (true)
{
/// An exception from the throwing thread will be rethrown from this method
/// as soon as the throwing thread executed.
/// This innocent thread may or may not be executed, the following possibilities exist:
/// 1. The throwing thread has already thrown exception and the attempt to schedule the innocent thread will rethrow it.
/// 2. The throwing thread has not executed, the innocent thread will be scheduled and executed.
/// 3. The throwing thread has not executed, the innocent thread will be scheduled but before it will be executed,
/// the throwing thread will be executed and throw exception and it will prevent starting of execution of the innocent thread
/// the method will return and the exception will be rethrown only on call to "wait" or on next call on next loop iteration as (1).
pool.scheduleOrThrowOnError([]{});
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
catch (const std::runtime_error &)
{
pool.wait();
return true;
}
pool.wait();
return false;
__builtin_unreachable();
}

View File

@ -0,0 +1,242 @@
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Poco/RandomStream.h>
#include <Poco/SHA1Engine.h>
#include <Access/User.h>
#include <Access/AccessControlManager.h>
#include <common/logger_useful.h>
#include <Common/MemoryTracker.h>
#include <Common/OpenSSLHelpers.h>
#include <ext/scope_guard.h>
namespace DB
{
namespace ErrorCodes
{
extern const int OPENSSL_ERROR;
extern const int UNKNOWN_EXCEPTION;
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
}
namespace MySQLProtocol
{
using namespace ConnectionPhase;
namespace Authentication
{
static const size_t SCRAMBLE_LENGTH = 20;
Native41::Native41()
{
scramble.resize(SCRAMBLE_LENGTH + 1, 0);
Poco::RandomInputStream generator;
/** Generate a random string using ASCII characters but avoid separator character,
* produce pseudo random numbers between with about 7 bit worth of entropty between 1-127.
* https://github.com/mysql/mysql-server/blob/8.0/mysys/crypt_genhash_impl.cc#L427
*/
for (size_t i = 0; i < SCRAMBLE_LENGTH; ++i)
{
generator >> scramble[i];
scramble[i] &= 0x7f;
if (scramble[i] == '\0' || scramble[i] == '$')
scramble[i] = scramble[i] + 1;
}
}
Native41::Native41(const String & password, const String & auth_plugin_data)
{
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
/// SHA1( password ) XOR SHA1( "20-bytes random data from server" <concat> SHA1( SHA1( password ) ) )
Poco::SHA1Engine engine1;
engine1.update(password);
const Poco::SHA1Engine::Digest & password_sha1 = engine1.digest();
Poco::SHA1Engine engine2;
engine2.update(password_sha1.data(), password_sha1.size());
const Poco::SHA1Engine::Digest & password_double_sha1 = engine2.digest();
Poco::SHA1Engine engine3;
engine3.update(auth_plugin_data.data(), auth_plugin_data.size());
engine3.update(password_double_sha1.data(), password_double_sha1.size());
const Poco::SHA1Engine::Digest & digest = engine3.digest();
scramble.resize(SCRAMBLE_LENGTH);
for (size_t i = 0; i < SCRAMBLE_LENGTH; i++)
{
scramble[i] = static_cast<unsigned char>(password_sha1[i] ^ digest[i]);
}
}
void Native41::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
{
packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
AuthSwitchResponse response;
packet_endpoint->receivePacket(response);
auth_response = response.value;
}
if (auth_response->empty())
{
context.setUser(user_name, "", address);
return;
}
if (auth_response->size() != Poco::SHA1Engine::DIGEST_SIZE)
throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
ErrorCodes::UNKNOWN_EXCEPTION);
auto user = context.getAccessControlManager().read<User>(user_name);
Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
Poco::SHA1Engine engine;
engine.update(scramble.data(), SCRAMBLE_LENGTH);
engine.update(double_sha1_value.data(), double_sha1_value.size());
String password_sha1(Poco::SHA1Engine::DIGEST_SIZE, 0x0);
const Poco::SHA1Engine::Digest & digest = engine.digest();
for (size_t i = 0; i < password_sha1.size(); i++)
{
password_sha1[i] = digest[i] ^ static_cast<unsigned char>((*auth_response)[i]);
}
context.setUser(user_name, password_sha1, address);
}
#if USE_SSL
Sha256Password::Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logger * log_)
: public_key(public_key_), private_key(private_key_), log(log_)
{
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin. [1]
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
*/
scramble.resize(SCRAMBLE_LENGTH + 1, 0);
Poco::RandomInputStream generator;
for (size_t i = 0; i < SCRAMBLE_LENGTH; ++i)
{
generator >> scramble[i];
scramble[i] &= 0x7f;
if (scramble[i] == '\0' || scramble[i] == '$')
scramble[i] = scramble[i] + 1;
}
}
void Sha256Password::authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address)
{
if (!auth_response)
{
packet_endpoint->sendPacket(AuthSwitchRequest(getName(), scramble), true);
if (packet_endpoint->in->eof())
throw Exception("Client doesn't support authentication method " + getName() + " used by ClickHouse. Specifying user password using 'password_double_sha1_hex' may fix the problem.",
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
AuthSwitchResponse response;
packet_endpoint->receivePacket(response);
auth_response.emplace(response.value);
LOG_TRACE(log, "Authentication method mismatch.");
}
else
{
LOG_TRACE(log, "Authentication method match.");
}
bool sent_public_key = false;
if (auth_response == "\1")
{
LOG_TRACE(log, "Client requests public key.");
BIO * mem = BIO_new(BIO_s_mem());
SCOPE_EXIT(BIO_free(mem));
if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
{
throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
char * pem_buf = nullptr;
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wold-style-cast"
int64_t pem_size = BIO_get_mem_data(mem, &pem_buf);
# pragma GCC diagnostic pop
String pem(pem_buf, pem_size);
LOG_TRACE(log, "Key: {}", pem);
AuthMoreData data(pem);
packet_endpoint->sendPacket(data, true);
sent_public_key = true;
AuthSwitchResponse response;
packet_endpoint->receivePacket(response);
auth_response.emplace(response.value);
}
else
{
LOG_TRACE(log, "Client didn't request public key.");
}
String password;
/** Decrypt password, if it's not empty.
* The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
*/
if (!is_secure_connection && !auth_response->empty() && auth_response != String("\0", 1))
{
LOG_TRACE(log, "Received nonempty password.");
const auto & unpack_auth_response = *auth_response;
const auto * ciphertext = reinterpret_cast<const unsigned char *>(unpack_auth_response.data());
unsigned char plaintext[RSA_size(&private_key)];
int plaintext_size = RSA_private_decrypt(unpack_auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING);
if (plaintext_size == -1)
{
if (!sent_public_key)
LOG_WARNING(log, "Client could have encrypted password with different public key since it didn't request it from server.");
throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
password.resize(plaintext_size);
for (int i = 0; i < plaintext_size; i++)
{
password[i] = plaintext[i] ^ static_cast<unsigned char>(scramble[i % scramble.size()]);
}
}
else if (is_secure_connection)
{
password = *auth_response;
}
else
{
LOG_TRACE(log, "Received empty password");
}
if (!password.empty() && password.back() == 0)
{
password.pop_back();
}
context.setUser(user_name, password, address);
}
#endif
}
}
}

View File

@ -0,0 +1,87 @@
#pragma once
#include <Core/Types.h>
#include <Interpreters/Context.h>
#include <Core/MySQL/PacketEndpoint.h>
#if !defined(ARCADIA_BUILD)
# include "config_core.h"
#endif
#if USE_SSL
# include <openssl/pem.h>
# include <openssl/rsa.h>
#endif
namespace DB
{
namespace MySQLProtocol
{
namespace Authentication
{
class IPlugin
{
public:
virtual ~IPlugin() = default;
virtual String getName() = 0;
virtual String getAuthPluginData() = 0;
virtual void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) = 0;
};
/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
class Native41 : public IPlugin
{
public:
Native41();
Native41(const String & password, const String & auth_plugin_data);
String getName() override { return "mysql_native_password"; }
String getAuthPluginData() override { return scramble; }
void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool /* is_secure_connection */, const Poco::Net::SocketAddress & address) override;
private:
String scramble;
};
#if USE_SSL
/// Caching SHA2 plugin is not used because it would be possible to authenticate knowing hash from users.xml.
/// https://dev.mysql.com/doc/internals/en/sha256.html
class Sha256Password : public IPlugin
{
public:
Sha256Password(RSA & public_key_, RSA & private_key_, Poco::Logger * log_);
String getName() override { return "sha256_password"; }
String getAuthPluginData() override { return scramble; }
void authenticate(
const String & user_name, std::optional<String> auth_response, Context & context,
std::shared_ptr<PacketEndpoint> packet_endpoint, bool is_secure_connection, const Poco::Net::SocketAddress & address) override;
private:
RSA & public_key;
RSA & private_key;
Poco::Logger * log;
String scramble;
};
#endif
}
}
}

View File

@ -0,0 +1,81 @@
#include <Core/MySQL/IMySQLReadPacket.h>
#include <sstream>
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <IO/LimitReadBuffer.h>
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_PACKET_FROM_CLIENT;
}
namespace MySQLProtocol
{
void IMySQLReadPacket::readPayload(ReadBuffer & in, uint8_t & sequence_id)
{
MySQLPacketPayloadReadBuffer payload(in, sequence_id);
payload.next();
readPayloadImpl(payload);
if (!payload.eof())
{
std::stringstream tmp;
tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
throw Exception(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
}
void IMySQLReadPacket::readPayloadWithUnpacked(ReadBuffer & in)
{
readPayloadImpl(in);
}
void LimitedReadPacket::readPayload(ReadBuffer &in, uint8_t &sequence_id)
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
IMySQLReadPacket::readPayload(limited, sequence_id);
}
void LimitedReadPacket::readPayloadWithUnpacked(ReadBuffer & in)
{
LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
IMySQLReadPacket::readPayloadWithUnpacked(limited);
}
uint64_t readLengthEncodedNumber(ReadBuffer & buffer)
{
char c{};
uint64_t buf = 0;
buffer.readStrict(c);
auto cc = static_cast<uint8_t>(c);
if (cc < 0xfc)
{
return cc;
}
else if (cc < 0xfd)
{
buffer.readStrict(reinterpret_cast<char *>(&buf), 2);
}
else if (cc < 0xfe)
{
buffer.readStrict(reinterpret_cast<char *>(&buf), 3);
}
else
{
buffer.readStrict(reinterpret_cast<char *>(&buf), 8);
}
return buf;
}
void readLengthEncodedString(String & s, ReadBuffer & buffer)
{
uint64_t len = readLengthEncodedNumber(buffer);
s.resize(len);
buffer.readStrict(reinterpret_cast<char *>(s.data()), len);
}
}
}

View File

@ -0,0 +1,41 @@
#pragma once
#include <IO/ReadBuffer.h>
namespace DB
{
namespace MySQLProtocol
{
class IMySQLReadPacket
{
public:
IMySQLReadPacket() = default;
virtual ~IMySQLReadPacket() = default;
IMySQLReadPacket(IMySQLReadPacket &&) = default;
virtual void readPayload(ReadBuffer & in, uint8_t & sequence_id);
virtual void readPayloadWithUnpacked(ReadBuffer & in);
protected:
virtual void readPayloadImpl(ReadBuffer & buf) = 0;
};
class LimitedReadPacket : public IMySQLReadPacket
{
public:
void readPayload(ReadBuffer & in, uint8_t & sequence_id) override;
void readPayloadWithUnpacked(ReadBuffer & in) override;
};
uint64_t readLengthEncodedNumber(ReadBuffer & buffer);
void readLengthEncodedString(String & s, ReadBuffer & buffer);
}
}

View File

@ -0,0 +1,86 @@
#include <Core/MySQL/IMySQLWritePacket.h>
#include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <sstream>
namespace DB
{
namespace MySQLProtocol
{
void IMySQLWritePacket::writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
{
MySQLPacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
writePayloadImpl(buf);
buf.next();
if (buf.remainingPayloadSize())
{
std::stringstream ss;
ss << "Incomplete payload. Written " << getPayloadSize() - buf.remainingPayloadSize() << " bytes, expected " << getPayloadSize() << " bytes.";
throw Exception(ss.str(), 0);
}
}
size_t getLengthEncodedNumberSize(uint64_t x)
{
if (x < 251)
{
return 1;
}
else if (x < (1 << 16))
{
return 3;
}
else if (x < (1 << 24))
{
return 4;
}
else
{
return 9;
}
}
size_t getLengthEncodedStringSize(const String & s)
{
return getLengthEncodedNumberSize(s.size()) + s.size();
}
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer)
{
if (x < 251)
{
buffer.write(static_cast<char>(x));
}
else if (x < (1 << 16))
{
buffer.write(0xfc);
buffer.write(reinterpret_cast<char *>(&x), 2);
}
else if (x < (1 << 24))
{
buffer.write(0xfd);
buffer.write(reinterpret_cast<char *>(&x), 3);
}
else
{
buffer.write(0xfe);
buffer.write(reinterpret_cast<char *>(&x), 8);
}
}
void writeLengthEncodedString(const String & s, WriteBuffer & buffer)
{
writeLengthEncodedNumber(s.size(), buffer);
buffer.write(s.data(), s.size());
}
void writeNulTerminatedString(const String & s, WriteBuffer & buffer)
{
buffer.write(s.data(), s.size());
buffer.write(0);
}
}
}

View File

@ -0,0 +1,37 @@
#pragma once
#include <IO/WriteBuffer.h>
namespace DB
{
namespace MySQLProtocol
{
class IMySQLWritePacket
{
public:
IMySQLWritePacket() = default;
virtual ~IMySQLWritePacket() = default;
IMySQLWritePacket(IMySQLWritePacket &&) = default;
virtual void writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const;
protected:
virtual size_t getPayloadSize() const = 0;
virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
};
size_t getLengthEncodedNumberSize(uint64_t x);
size_t getLengthEncodedStringSize(const String & s);
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer);
void writeLengthEncodedString(const String & s, WriteBuffer & buffer);
void writeNulTerminatedString(const String & s, WriteBuffer & buffer);
}
}

View File

@ -1,10 +1,19 @@
#include "MySQLClient.h"
#include <Core/MySQLReplication.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/MySQL/PacketsReplication.h>
#include <Core/MySQL/MySQLReplication.h>
namespace DB
{
using namespace Generic;
using namespace Replication;
using namespace ProtocolText;
using namespace Authentication;
using namespace ConnectionPhase;
namespace ErrorCodes
{
@ -44,7 +53,7 @@ void MySQLClient::connect()
in = std::make_shared<ReadBufferFromPocoSocket>(*socket);
out = std::make_shared<WriteBufferFromPocoSocket>(*socket);
packet_sender = std::make_shared<PacketSender>(*in, *out, seq);
packet_endpoint = std::make_shared<PacketEndpoint>(*in, *out, seq);
handshake();
}
@ -62,10 +71,10 @@ void MySQLClient::disconnect()
void MySQLClient::handshake()
{
Handshake handshake;
packet_sender->receivePacket(handshake);
packet_endpoint->receivePacket(handshake);
if (handshake.auth_plugin_name != mysql_native_password)
{
throw MySQLClientError(
throw Exception(
"Only support " + mysql_native_password + " auth plugin name, but got " + handshake.auth_plugin_name,
ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
@ -74,48 +83,48 @@ void MySQLClient::handshake()
String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse handshake_response(
client_capability_flags, max_packet_size, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_sender->sendPacket<HandshakeResponse>(handshake_response, true);
client_capability_flags, MAX_PACKET_LENGTH, charset_utf8, user, "", auth_plugin_data, mysql_native_password);
packet_endpoint->sendPacket<HandshakeResponse>(handshake_response, true);
PacketResponse packet_response(client_capability_flags, true);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
ResponsePacket packet_response(client_capability_flags, true);
packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR)
throw MySQLClientError(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
else if (packet_response.getType() == PACKET_AUTH_SWITCH)
throw MySQLClientError("Access denied for user " + user, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
throw Exception("Access denied for user " + user, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
void MySQLClient::writeCommand(char command, String query)
{
WriteCommand write_command(command, query);
packet_sender->sendPacket<WriteCommand>(write_command, true);
packet_endpoint->sendPacket<WriteCommand>(write_command, true);
PacketResponse packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
ResponsePacket packet_response(client_capability_flags);
packet_endpoint->receivePacket(packet_response);
switch (packet_response.getType())
{
case PACKET_ERR:
throw MySQLClientError(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
case PACKET_OK:
break;
default:
break;
}
packet_sender->resetSequenceId();
packet_endpoint->resetSequenceId();
}
void MySQLClient::registerSlaveOnMaster(UInt32 slave_id)
{
RegisterSlave register_slave(slave_id);
packet_sender->sendPacket<RegisterSlave>(register_slave, true);
packet_endpoint->sendPacket<RegisterSlave>(register_slave, true);
PacketResponse packet_response(client_capability_flags);
packet_sender->receivePacket(packet_response);
packet_sender->resetSequenceId();
ResponsePacket packet_response(client_capability_flags);
packet_endpoint->receivePacket(packet_response);
packet_endpoint->resetSequenceId();
if (packet_response.getType() == PACKET_ERR)
throw MySQLClientError(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
throw Exception(packet_response.err.error_message, ErrorCodes::UNKNOWN_PACKET_FROM_SERVER);
}
void MySQLClient::ping()
@ -141,12 +150,12 @@ void MySQLClient::startBinlogDump(UInt32 slave_id, String replicate_db, String b
binlog_pos = binlog_pos < 4 ? 4 : binlog_pos;
BinlogDump binlog_dump(binlog_pos, binlog_file_name, slave_id);
packet_sender->sendPacket<BinlogDump>(binlog_dump, true);
packet_endpoint->sendPacket<BinlogDump>(binlog_dump, true);
}
BinlogEventPtr MySQLClient::readOneBinlogEvent(UInt64 milliseconds)
{
if (packet_sender->tryReceivePacket(replication, milliseconds))
if (packet_endpoint->tryReceivePacket(replication, milliseconds))
return replication.readOneEvent();
return {};

View File

@ -1,7 +1,6 @@
#pragma once
#include <Core/MySQLProtocol.h>
#include <Core/MySQLReplication.h>
#include <Core/Types.h>
#include <Core/MySQL/MySQLReplication.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBufferFromPocoSocket.h>
@ -11,19 +10,15 @@
#include <Common/DNSResolver.h>
#include <Common/Exception.h>
#include <Common/NetException.h>
#include <Core/MySQL/IMySQLWritePacket.h>
namespace DB
{
using namespace MySQLProtocol;
using namespace MySQLReplication;
class MySQLClientError : public DB::Exception
{
public:
using Exception::Exception;
};
class MySQLClient
{
public:
@ -49,7 +44,6 @@ private:
uint8_t seq = 0;
const UInt8 charset_utf8 = 33;
const UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
const String mysql_native_password = "mysql_native_password";
MySQLFlavor replication;
@ -57,14 +51,14 @@ private:
std::shared_ptr<WriteBuffer> out;
std::unique_ptr<Poco::Net::StreamSocket> socket;
std::optional<Poco::Net::SocketAddress> address;
std::shared_ptr<PacketSender> packet_sender;
std::shared_ptr<PacketEndpoint> packet_endpoint;
void handshake();
void registerSlaveOnMaster(UInt32 slave_id);
void writeCommand(char command, String query);
};
class WriteCommand : public WritePacket
class WriteCommand : public IMySQLWritePacket
{
public:
char command;

View File

@ -4,6 +4,8 @@
#include <IO/ReadBufferFromString.h>
#include <common/DateLUT.h>
#include <Common/FieldVisitors.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsProtocolText.h>
namespace DB
{
@ -15,6 +17,8 @@ namespace ErrorCodes
namespace MySQLReplication
{
using namespace MySQLProtocol;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
/// https://dev.mysql.com/doc/internals/en/binlog-event-header.html
void EventHeader::parse(ReadBuffer & payload)
@ -59,7 +63,7 @@ namespace MySQLReplication
out << "Binlog Version: " << this->binlog_version << std::endl;
out << "Server Version: " << this->server_version << std::endl;
out << "Create Timestamp: " << this->create_timestamp << std::endl;
out << "Event Header Len: " << this->event_header_length << std::endl;
out << "Event Header Len: " << std::to_string(this->event_header_length) << std::endl;
}
/// https://dev.mysql.com/doc/internals/en/rotate-event.html
@ -119,7 +123,7 @@ namespace MySQLReplication
header.dump(out);
out << "Thread ID: " << this->thread_id << std::endl;
out << "Execution Time: " << this->exec_time << std::endl;
out << "Schema Len: " << this->schema_len << std::endl;
out << "Schema Len: " << std::to_string(this->schema_len) << std::endl;
out << "Error Code: " << this->error_code << std::endl;
out << "Status Len: " << this->status_len << std::endl;
out << "Schema: " << this->schema << std::endl;
@ -239,14 +243,14 @@ namespace MySQLReplication
header.dump(out);
out << "Table ID: " << this->table_id << std::endl;
out << "Flags: " << this->flags << std::endl;
out << "Schema Len: " << this->schema_len << std::endl;
out << "Schema Len: " << std::to_string(this->schema_len) << std::endl;
out << "Schema: " << this->schema << std::endl;
out << "Table Len: " << this->table_len << std::endl;
out << "Table Len: " << std::to_string(this->table_len) << std::endl;
out << "Table: " << this->table << std::endl;
out << "Column Count: " << this->column_count << std::endl;
for (auto i = 0U; i < column_count; i++)
{
out << "Column Type [" << i << "]: " << column_type[i] << ", Meta: " << column_meta[i] << std::endl;
out << "Column Type [" << i << "]: " << std::to_string(column_type[i]) << ", Meta: " << column_meta[i] << std::endl;
}
out << "Null Bitmap: " << this->null_bitmap << std::endl;
}
@ -717,8 +721,8 @@ namespace MySQLReplication
case PACKET_EOF:
throw ReplicationError("Master maybe lost", ErrorCodes::UNKNOWN_EXCEPTION);
case PACKET_ERR:
ERR_Packet err;
err.readPayloadImpl(payload);
ERRPacket err;
err.readPayloadWithUnpacked(payload);
throw ReplicationError(err.error_message, ErrorCodes::UNKNOWN_EXCEPTION);
}
// skip the header flag.

View File

@ -1,6 +1,6 @@
#pragma once
#include <Core/Field.h>
#include <Core/MySQLProtocol.h>
#include <Core/MySQL/PacketsReplication.h>
#include <Core/Types.h>
#include <IO/ReadBuffer.h>
#include <IO/WriteBuffer.h>
@ -13,6 +13,7 @@
namespace DB
{
namespace MySQLReplication
{
static const int EVENT_VERSION_V4 = 4;
@ -465,7 +466,7 @@ namespace MySQLReplication
void updateLogName(String binlog) { binlog_name = std::move(binlog); }
};
class IFlavor : public MySQLProtocol::ReadPacket
class IFlavor : public MySQLProtocol::IMySQLReadPacket
{
public:
virtual String getName() const = 0;

View File

@ -0,0 +1,71 @@
#include <Core/MySQL/PacketEndpoint.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <Common/typeid_cast.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
namespace MySQLProtocol
{
PacketEndpoint::PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_)
: sequence_id(sequence_id_), in(nullptr), out(&out_)
{
}
PacketEndpoint::PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_)
: sequence_id(sequence_id_), in(&in_), out(&out_)
{
}
MySQLPacketPayloadReadBuffer PacketEndpoint::getPayload()
{
return MySQLPacketPayloadReadBuffer(*in, sequence_id);
}
void PacketEndpoint::receivePacket(IMySQLReadPacket & packet)
{
packet.readPayload(*in, sequence_id);
}
bool PacketEndpoint::tryReceivePacket(IMySQLReadPacket & packet, UInt64 millisecond)
{
if (millisecond != 0)
{
ReadBufferFromPocoSocket * socket_in = typeid_cast<ReadBufferFromPocoSocket *>(in);
if (!socket_in)
throw Exception("LOGICAL ERROR: Attempt to pull the duration in a non socket stream", ErrorCodes::LOGICAL_ERROR);
if (!socket_in->poll(millisecond * 1000))
return false;
}
packet.readPayload(*in, sequence_id);
return true;
}
void PacketEndpoint::resetSequenceId()
{
sequence_id = 0;
}
String PacketEndpoint::packetToText(const String & payload)
{
String result;
for (auto c : payload)
{
result += ' ';
result += std::to_string(static_cast<unsigned char>(c));
}
return result;
}
}
}

View File

@ -0,0 +1,55 @@
#pragma once
#include <IO/ReadBuffer.h>
#include <IO/WriteBuffer.h>
#include "IMySQLReadPacket.h"
#include "IMySQLWritePacket.h"
#include "IO/MySQLPacketPayloadReadBuffer.h"
namespace DB
{
namespace MySQLProtocol
{
/* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError, if packet with incorrect sequence-id was received.
*/
class PacketEndpoint
{
public:
uint8_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
/// For writing.
PacketEndpoint(WriteBuffer & out_, uint8_t & sequence_id_);
/// For reading and writing.
PacketEndpoint(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_);
MySQLPacketPayloadReadBuffer getPayload();
void receivePacket(IMySQLReadPacket & packet);
bool tryReceivePacket(IMySQLReadPacket & packet, UInt64 millisecond = 0);
/// Sets sequence-id to 0. Must be called before each command phase.
void resetSequenceId();
template<class T>
void sendPacket(const T & packet, bool flush = false)
{
static_assert(std::is_base_of<IMySQLWritePacket, T>());
packet.writePayload(*out, sequence_id);
if (flush)
out->next();
}
/// Converts packet to text. Is used for debug output.
static String packetToText(const String & payload);
};
}
}

View File

@ -0,0 +1,242 @@
#include <Core/MySQL/PacketsConnection.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Core/MySQL/PacketsGeneric.h>
namespace DB
{
namespace MySQLProtocol
{
using namespace Generic;
namespace ConnectionPhase
{
static const size_t SCRAMBLE_LENGTH = 20;
static const size_t AUTH_PLUGIN_DATA_PART_1_LENGTH = 8;
Handshake::Handshake() : connection_id(0x00), capability_flags(0x00), character_set(0x00), status_flags(0x00)
{
}
Handshake::Handshake(
uint32_t capability_flags_, uint32_t connection_id_,
String server_version_, String auth_plugin_name_, String auth_plugin_data_, uint8_t charset_)
: protocol_version(0xa), server_version(std::move(server_version_)), connection_id(connection_id_), capability_flags(capability_flags_),
character_set(charset_), status_flags(0), auth_plugin_name(std::move(auth_plugin_name_)),
auth_plugin_data(std::move(auth_plugin_data_))
{
}
size_t Handshake::getPayloadSize() const
{
return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
}
void Handshake::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&protocol_version), 1);
readNullTerminated(server_version, payload);
payload.readStrict(reinterpret_cast<char *>(&connection_id), 4);
auth_plugin_data.resize(AUTH_PLUGIN_DATA_PART_1_LENGTH);
payload.readStrict(auth_plugin_data.data(), AUTH_PLUGIN_DATA_PART_1_LENGTH);
payload.ignore(1);
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict((reinterpret_cast<char *>(&capability_flags)) + 2, 2);
UInt8 auth_plugin_data_length = 0;
if (capability_flags & Capability::CLIENT_PLUGIN_AUTH)
{
payload.readStrict(reinterpret_cast<char *>(&auth_plugin_data_length), 1);
}
else
{
payload.ignore(1);
}
payload.ignore(10);
if (capability_flags & Capability::CLIENT_SECURE_CONNECTION)
{
UInt8 part2_length = (SCRAMBLE_LENGTH - AUTH_PLUGIN_DATA_PART_1_LENGTH);
auth_plugin_data.resize(SCRAMBLE_LENGTH);
payload.readStrict(auth_plugin_data.data() + AUTH_PLUGIN_DATA_PART_1_LENGTH, part2_length);
payload.ignore(1);
}
if (capability_flags & Capability::CLIENT_PLUGIN_AUTH)
{
readNullTerminated(auth_plugin_name, payload);
}
}
void Handshake::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(static_cast<char>(protocol_version));
writeNulTerminatedString(server_version, buffer);
buffer.write(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
buffer.write(reinterpret_cast<const char *>(&capability_flags), 2);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
buffer.write(static_cast<char>(auth_plugin_data.size()));
writeChar(0x0, 10, buffer);
writeString(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
writeString(auth_plugin_name, buffer);
writeChar(0x0, 1, buffer);
}
HandshakeResponse::HandshakeResponse() : capability_flags(0x00), max_packet_size(0x00), character_set(0x00)
{
}
HandshakeResponse::HandshakeResponse(
UInt32 capability_flags_, UInt32 max_packet_size_, UInt8 character_set_, const String & username_, const String & database_,
const String & auth_response_, const String & auth_plugin_name_)
: capability_flags(capability_flags_), max_packet_size(max_packet_size_), character_set(character_set_), username(std::move(username_)),
database(std::move(database_)), auth_response(std::move(auth_response_)), auth_plugin_name(std::move(auth_plugin_name_))
{
}
size_t HandshakeResponse::getPayloadSize() const
{
size_t size = 0;
size += 4 + 4 + 1 + 23;
size += username.size() + 1;
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
size += getLengthEncodedStringSize(auth_response);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
size += (1 + auth_response.size());
}
else
{
size += (auth_response.size() + 1);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
size += (database.size() + 1);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
size += (auth_plugin_name.size() + 1);
}
return size;
}
void HandshakeResponse::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
payload.ignore(23);
readNullTerminated(username, payload);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
readLengthEncodedString(auth_response, payload);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
char len;
payload.readStrict(len);
auth_response.resize(static_cast<unsigned int>(len));
payload.readStrict(auth_response.data(), len);
}
else
{
readNullTerminated(auth_response, payload);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
readNullTerminated(database, payload);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
readNullTerminated(auth_plugin_name, payload);
}
}
void HandshakeResponse::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(reinterpret_cast<const char *>(&capability_flags), 4);
buffer.write(reinterpret_cast<const char *>(&max_packet_size), 4);
buffer.write(reinterpret_cast<const char *>(&character_set), 1);
writeChar(0x0, 23, buffer);
writeNulTerminatedString(username, buffer);
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
writeLengthEncodedString(auth_response, buffer);
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
writeChar(auth_response.size(), buffer);
writeString(auth_response.data(), auth_response.size(), buffer);
}
else
{
writeNulTerminatedString(auth_response, buffer);
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
writeNulTerminatedString(database, buffer);
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
writeNulTerminatedString(auth_plugin_name, buffer);
}
}
AuthSwitchRequest::AuthSwitchRequest(String plugin_name_, String auth_plugin_data_)
: plugin_name(std::move(plugin_name_)), auth_plugin_data(std::move(auth_plugin_data_))
{
}
size_t AuthSwitchRequest::getPayloadSize() const
{
return 2 + plugin_name.size() + auth_plugin_data.size();
}
void AuthSwitchRequest::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(0xfe);
writeNulTerminatedString(plugin_name, buffer);
writeString(auth_plugin_data, buffer);
}
void AuthSwitchResponse::readPayloadImpl(ReadBuffer & payload)
{
readStringUntilEOF(value, payload);
}
size_t AuthMoreData::getPayloadSize() const
{
return 1 + data.size();
}
void AuthMoreData::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(0x01);
writeString(data, buffer);
}
}
}
}

View File

@ -0,0 +1,110 @@
#pragma once
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
namespace DB
{
namespace MySQLProtocol
{
namespace ConnectionPhase
{
class Handshake : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
int protocol_version = 0xa;
String server_version;
uint32_t connection_id;
uint32_t capability_flags;
uint8_t character_set;
uint32_t status_flags;
String auth_plugin_name;
String auth_plugin_data;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
Handshake();
Handshake(
uint32_t capability_flags_, uint32_t connection_id_,
String server_version_, String auth_plugin_name_, String auth_plugin_data_, uint8_t charset_);
};
class HandshakeResponse : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
String username;
String database;
String auth_response;
String auth_plugin_name;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
HandshakeResponse();
HandshakeResponse(
UInt32 capability_flags_, UInt32 max_packet_size_, UInt8 character_set_,
const String & username_, const String & database_, const String & auth_response_, const String & auth_plugin_name_);
};
class AuthSwitchRequest : public IMySQLWritePacket
{
public:
String plugin_name;
String auth_plugin_data;
protected:
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
AuthSwitchRequest(String plugin_name_, String auth_plugin_data_);
};
class AuthSwitchResponse : public LimitedReadPacket
{
public:
String value;
protected:
void readPayloadImpl(ReadBuffer & payload) override;
};
class AuthMoreData : public IMySQLWritePacket
{
public:
String data;
protected:
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
explicit AuthMoreData(String data_): data(std::move(data_)) {}
};
}
}
}

View File

@ -0,0 +1,262 @@
#include <Core/MySQL/PacketsGeneric.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace MySQLProtocol
{
namespace Generic
{
static const size_t MYSQL_ERRMSG_SIZE = 512;
void SSLRequest::readPayloadImpl(ReadBuffer & buf)
{
buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
}
OKPacket::OKPacket(uint32_t capabilities_)
: header(0x00), capabilities(capabilities_), affected_rows(0x00), last_insert_id(0x00), status_flags(0x00)
{
}
OKPacket::OKPacket(
uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_, uint32_t status_flags_, int16_t warnings_,
String session_state_changes_, String info_)
: header(header_), capabilities(capabilities_), affected_rows(affected_rows_), last_insert_id(0), warnings(warnings_),
status_flags(status_flags_), session_state_changes(std::move(session_state_changes_)), info(std::move(info_))
{
}
size_t OKPacket::getPayloadSize() const
{
size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
if (capabilities & CLIENT_PROTOCOL_41)
{
result += 4;
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result += 2;
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result += getLengthEncodedStringSize(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
result += getLengthEncodedStringSize(session_state_changes);
}
else
{
result += info.size();
}
return result;
}
void OKPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
affected_rows = readLengthEncodedNumber(payload);
last_insert_id = readLengthEncodedNumber(payload);
if (capabilities & CLIENT_PROTOCOL_41)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
readLengthEncodedString(info, payload);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
readLengthEncodedString(session_state_changes, payload);
}
}
else
{
readString(info, payload);
}
}
void OKPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header);
writeLengthEncodedNumber(affected_rows, buffer);
writeLengthEncodedNumber(last_insert_id, buffer); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
writeLengthEncodedString(info, buffer);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
writeLengthEncodedString(session_state_changes, buffer);
}
else
{
writeString(info, buffer);
}
}
EOFPacket::EOFPacket() : warnings(0x00), status_flags(0x00)
{
}
EOFPacket::EOFPacket(int warnings_, int status_flags_)
: warnings(warnings_), status_flags(status_flags_)
{
}
size_t EOFPacket::getPayloadSize() const
{
return 5;
}
void EOFPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
payload.readStrict(reinterpret_cast<char *>(&warnings), 2);
payload.readStrict(reinterpret_cast<char *>(&status_flags), 2);
}
void EOFPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header); // EOF header
buffer.write(reinterpret_cast<const char *>(&warnings), 2);
buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
}
void AuthSwitchPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xfe);
readStringUntilEOF(plugin_name, payload);
}
ERRPacket::ERRPacket() : error_code(0x00)
{
}
ERRPacket::ERRPacket(int error_code_, String sql_state_, String error_message_)
: error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_))
{
}
size_t ERRPacket::getPayloadSize() const
{
return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
}
void ERRPacket::readPayloadImpl(ReadBuffer & payload)
{
payload.readStrict(reinterpret_cast<char *>(&header), 1);
assert(header == 0xff);
payload.readStrict(reinterpret_cast<char *>(&error_code), 2);
/// SQL State [optional: # + 5bytes string]
UInt8 sharp = static_cast<unsigned char>(*payload.position());
if (sharp == 0x23)
{
payload.ignore(1);
sql_state.resize(5);
payload.readStrict(reinterpret_cast<char *>(sql_state.data()), 5);
}
readString(error_message, payload);
}
void ERRPacket::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(header);
buffer.write(reinterpret_cast<const char *>(&error_code), 2);
buffer.write('#');
buffer.write(sql_state.data(), sql_state.length());
buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_)
: ok(OKPacket(server_capability_flags_))
{
}
ResponsePacket::ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_)
: ok(OKPacket(server_capability_flags_)), is_handshake(is_handshake_)
{
}
void ResponsePacket::readPayloadImpl(ReadBuffer & payload)
{
UInt16 header = static_cast<unsigned char>(*payload.position());
switch (header)
{
case PACKET_OK:
packetType = PACKET_OK;
ok.readPayloadWithUnpacked(payload);
break;
case PACKET_ERR:
packetType = PACKET_ERR;
err.readPayloadWithUnpacked(payload);
break;
case PACKET_EOF:
if (is_handshake)
{
packetType = PACKET_AUTH_SWITCH;
auth_switch.readPayloadWithUnpacked(payload);
}
else
{
packetType = PACKET_EOF;
eof.readPayloadWithUnpacked(payload);
}
break;
case PACKET_LOCALINFILE:
packetType = PACKET_LOCALINFILE;
break;
default:
packetType = PACKET_OK;
column_length = readLengthEncodedNumber(payload);
}
}
LengthEncodedNumber::LengthEncodedNumber(uint64_t value_) : value(value_)
{
}
size_t LengthEncodedNumber::getPayloadSize() const
{
return getLengthEncodedNumberSize(value);
}
void LengthEncodedNumber::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedNumber(value, buffer);
}
}
}
}

View File

@ -0,0 +1,175 @@
#pragma once
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
namespace DB
{
namespace MySQLProtocol
{
namespace Generic
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
enum StatusFlags
{
SERVER_SESSION_STATE_CHANGED = 0x4000
};
enum Capability
{
CLIENT_CONNECT_WITH_DB = 0x00000008,
CLIENT_PROTOCOL_41 = 0x00000200,
CLIENT_SSL = 0x00000800,
CLIENT_TRANSACTIONS = 0x00002000, // TODO
CLIENT_SESSION_TRACK = 0x00800000, // TODO
CLIENT_SECURE_CONNECTION = 0x00008000,
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000,
CLIENT_PLUGIN_AUTH = 0x00080000,
CLIENT_DEPRECATE_EOF = 0x01000000,
};
class SSLRequest : public IMySQLReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
protected:
void readPayloadImpl(ReadBuffer & buf) override;
};
class OKPacket : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
uint8_t header;
uint32_t capabilities;
uint64_t affected_rows;
uint64_t last_insert_id;
int16_t warnings = 0;
uint32_t status_flags;
String session_state_changes;
String info;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
OKPacket(uint32_t capabilities_);
OKPacket(uint8_t header_, uint32_t capabilities_, uint64_t affected_rows_,
uint32_t status_flags_, int16_t warnings_, String session_state_changes_ = "", String info_ = "");
};
class EOFPacket : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
UInt8 header = 0xfe;
int warnings;
int status_flags;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
EOFPacket();
EOFPacket(int warnings_, int status_flags_);
};
class ERRPacket : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
UInt8 header = 0xff;
int error_code;
String sql_state;
String error_message;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
ERRPacket();
ERRPacket(int error_code_, String sql_state_, String error_message_);
};
class AuthSwitchPacket : public IMySQLReadPacket
{
public:
String plugin_name;
AuthSwitchPacket() = default;
protected:
UInt8 header = 0x00;
void readPayloadImpl(ReadBuffer & payload) override;
};
enum ResponsePacketType
{
PACKET_OK = 0x00,
PACKET_ERR = 0xff,
PACKET_EOF = 0xfe,
PACKET_AUTH_SWITCH = 0xfe,
PACKET_LOCALINFILE = 0xfb,
};
/// https://dev.mysql.com/doc/internals/en/generic-response-packets.html
class ResponsePacket : public IMySQLReadPacket
{
public:
OKPacket ok;
ERRPacket err;
EOFPacket eof;
AuthSwitchPacket auth_switch;
UInt64 column_length = 0;
ResponsePacketType getType() { return packetType; }
protected:
bool is_handshake = false;
ResponsePacketType packetType = PACKET_OK;
void readPayloadImpl(ReadBuffer & payload) override;
public:
ResponsePacket(UInt32 server_capability_flags_);
ResponsePacket(UInt32 server_capability_flags_, bool is_handshake_);
};
class LengthEncodedNumber : public IMySQLWritePacket
{
protected:
uint64_t value;
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
explicit LengthEncodedNumber(uint64_t value_);
};
}
}
}

View File

@ -0,0 +1,192 @@
#include <Core/MySQL/PacketsProtocolText.h>
#include <IO/WriteBufferFromString.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
namespace DB
{
namespace MySQLProtocol
{
namespace ProtocolText
{
ResultSetRow::ResultSetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
: columns(columns_), row_num(row_num_)
{
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
{
payload_size += 1;
serialized.emplace_back("\xfb");
}
else
{
WriteBufferFromOwnString ostr;
data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
payload_size += getLengthEncodedStringSize(ostr.str());
serialized.push_back(std::move(ostr.str()));
}
}
}
size_t ResultSetRow::getPayloadSize() const
{
return payload_size;
}
void ResultSetRow::writePayloadImpl(WriteBuffer & buffer) const
{
for (size_t i = 0; i < columns.size(); i++)
{
if (columns[i]->isNullAt(row_num))
buffer.write(serialized[i].data(), 1);
else
writeLengthEncodedString(serialized[i], buffer);
}
}
void ComFieldList::readPayloadImpl(ReadBuffer & payload)
{
// Command byte has been already read from payload.
readNullTerminated(table, payload);
readStringUntilEOF(field_wildcard, payload);
}
ColumnDefinition::ColumnDefinition()
: character_set(0x00), column_length(0), column_type(MYSQL_TYPE_DECIMAL), flags(0x00)
{
}
ColumnDefinition::ColumnDefinition(
String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_,
ColumnType column_type_, uint16_t flags_, uint8_t decimals_)
: schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)),
org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_),
flags(flags_), decimals(decimals_)
{
}
ColumnDefinition::ColumnDefinition(
String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_)
: ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_)
{
}
size_t ColumnDefinition::getPayloadSize() const
{
return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \
getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length);
}
void ColumnDefinition::readPayloadImpl(ReadBuffer & payload)
{
String def;
readLengthEncodedString(def, payload);
assert(def == "def");
readLengthEncodedString(schema, payload);
readLengthEncodedString(table, payload);
readLengthEncodedString(org_table, payload);
readLengthEncodedString(name, payload);
readLengthEncodedString(org_name, payload);
next_length = readLengthEncodedNumber(payload);
payload.readStrict(reinterpret_cast<char *>(&character_set), 2);
payload.readStrict(reinterpret_cast<char *>(&column_length), 4);
payload.readStrict(reinterpret_cast<char *>(&column_type), 1);
payload.readStrict(reinterpret_cast<char *>(&flags), 2);
payload.readStrict(reinterpret_cast<char *>(&decimals), 2);
payload.ignore(2);
}
void ColumnDefinition::writePayloadImpl(WriteBuffer & buffer) const
{
writeLengthEncodedString(std::string("def"), buffer); /// always "def"
writeLengthEncodedString(schema, buffer);
writeLengthEncodedString(table, buffer);
writeLengthEncodedString(org_table, buffer);
writeLengthEncodedString(name, buffer);
writeLengthEncodedString(org_name, buffer);
writeLengthEncodedNumber(next_length, buffer);
buffer.write(reinterpret_cast<const char *>(&character_set), 2);
buffer.write(reinterpret_cast<const char *>(&column_length), 4);
buffer.write(reinterpret_cast<const char *>(&column_type), 1);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&decimals), 2);
writeChar(0x0, 2, buffer);
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
CharacterSet charset = CharacterSet::binary;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
}
return ColumnDefinition(column_name, charset, 0, column_type, flags, 0);
}
}
}
}

View File

@ -0,0 +1,157 @@
#pragma once
#include <Columns/IColumn.h>
#include <DataTypes/IDataType.h>
#include <Core/MySQL/IMySQLReadPacket.h>
#include <Core/MySQL/IMySQLWritePacket.h>
namespace DB
{
namespace MySQLProtocol
{
namespace ProtocolText
{
enum CharacterSet
{
utf8_general_ci = 33,
binary = 63
};
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
enum ColumnDefinitionFlags
{
UNSIGNED_FLAG = 32,
BINARY_FLAG = 128
};
enum ColumnType
{
MYSQL_TYPE_DECIMAL = 0x00,
MYSQL_TYPE_TINY = 0x01,
MYSQL_TYPE_SHORT = 0x02,
MYSQL_TYPE_LONG = 0x03,
MYSQL_TYPE_FLOAT = 0x04,
MYSQL_TYPE_DOUBLE = 0x05,
MYSQL_TYPE_NULL = 0x06,
MYSQL_TYPE_TIMESTAMP = 0x07,
MYSQL_TYPE_LONGLONG = 0x08,
MYSQL_TYPE_INT24 = 0x09,
MYSQL_TYPE_DATE = 0x0a,
MYSQL_TYPE_TIME = 0x0b,
MYSQL_TYPE_DATETIME = 0x0c,
MYSQL_TYPE_YEAR = 0x0d,
MYSQL_TYPE_NEWDATE = 0x0e,
MYSQL_TYPE_VARCHAR = 0x0f,
MYSQL_TYPE_BIT = 0x10,
MYSQL_TYPE_TIMESTAMP2 = 0x11,
MYSQL_TYPE_DATETIME2 = 0x12,
MYSQL_TYPE_TIME2 = 0x13,
MYSQL_TYPE_JSON = 0xf5,
MYSQL_TYPE_NEWDECIMAL = 0xf6,
MYSQL_TYPE_ENUM = 0xf7,
MYSQL_TYPE_SET = 0xf8,
MYSQL_TYPE_TINY_BLOB = 0xf9,
MYSQL_TYPE_MEDIUM_BLOB = 0xfa,
MYSQL_TYPE_LONG_BLOB = 0xfb,
MYSQL_TYPE_BLOB = 0xfc,
MYSQL_TYPE_VAR_STRING = 0xfd,
MYSQL_TYPE_STRING = 0xfe,
MYSQL_TYPE_GEOMETRY = 0xff
};
enum Command
{
COM_SLEEP = 0x0,
COM_QUIT = 0x1,
COM_INIT_DB = 0x2,
COM_QUERY = 0x3,
COM_FIELD_LIST = 0x4,
COM_CREATE_DB = 0x5,
COM_DROP_DB = 0x6,
COM_REFRESH = 0x7,
COM_SHUTDOWN = 0x8,
COM_STATISTICS = 0x9,
COM_PROCESS_INFO = 0xa,
COM_CONNECT = 0xb,
COM_PROCESS_KILL = 0xc,
COM_DEBUG = 0xd,
COM_PING = 0xe,
COM_TIME = 0xf,
COM_DELAYED_INSERT = 0x10,
COM_CHANGE_USER = 0x11,
COM_BINLOG_DUMP = 0x12,
COM_REGISTER_SLAVE = 0x15,
COM_RESET_CONNECTION = 0x1f,
COM_DAEMON = 0x1d
};
class ResultSetRow : public IMySQLWritePacket
{
protected:
const Columns & columns;
int row_num;
size_t payload_size = 0;
std::vector<String> serialized;
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
ResultSetRow(const DataTypes & data_types, const Columns & columns_, int row_num_);
};
class ComFieldList : public LimitedReadPacket
{
public:
String table, field_wildcard;
void readPayloadImpl(ReadBuffer & payload) override;
};
class ColumnDefinition : public IMySQLWritePacket, public IMySQLReadPacket
{
public:
String schema;
String table;
String org_table;
String name;
String org_name;
size_t next_length = 0x0c;
uint16_t character_set;
uint32_t column_length;
ColumnType column_type;
uint16_t flags;
uint8_t decimals = 0x00;
protected:
size_t getPayloadSize() const override;
void readPayloadImpl(ReadBuffer & payload) override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
ColumnDefinition();
ColumnDefinition(
String schema_, String table_, String org_table_, String name_, String org_name_, uint16_t character_set_, uint32_t column_length_,
ColumnType column_type_, uint16_t flags_, uint8_t decimals_);
/// Should be used when column metadata (original name, table, original table, database) is unknown.
ColumnDefinition(
String name_, uint16_t character_set_, uint32_t column_length_, ColumnType column_type_, uint16_t flags_, uint8_t decimals_);
};
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
}
}
}

View File

@ -0,0 +1,62 @@
#include <Core/MySQL/PacketsReplication.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Core/MySQL/PacketsProtocolText.h>
namespace DB
{
namespace MySQLProtocol
{
namespace Replication
{
RegisterSlave::RegisterSlave(UInt32 server_id_)
: server_id(server_id_), slaves_mysql_port(0x00), replication_rank(0x00), master_id(0x00)
{
}
size_t RegisterSlave::getPayloadSize() const
{
return 1 + 4 + getLengthEncodedStringSize(slaves_hostname) + getLengthEncodedStringSize(slaves_users)
+ getLengthEncodedStringSize(slaves_password) + 2 + 4 + 4;
}
void RegisterSlave::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(ProtocolText::COM_REGISTER_SLAVE);
buffer.write(reinterpret_cast<const char *>(&server_id), 4);
writeLengthEncodedString(slaves_hostname, buffer);
writeLengthEncodedString(slaves_users, buffer);
writeLengthEncodedString(slaves_password, buffer);
buffer.write(reinterpret_cast<const char *>(&slaves_mysql_port), 2);
buffer.write(reinterpret_cast<const char *>(&replication_rank), 4);
buffer.write(reinterpret_cast<const char *>(&master_id), 4);
}
BinlogDump::BinlogDump(UInt32 binlog_pos_, String binlog_file_name_, UInt32 server_id_)
: binlog_pos(binlog_pos_), flags(0x00), server_id(server_id_), binlog_file_name(std::move(binlog_file_name_))
{
}
size_t BinlogDump::getPayloadSize() const
{
return 1 + 4 + 2 + 4 + binlog_file_name.size() + 1;
}
void BinlogDump::writePayloadImpl(WriteBuffer & buffer) const
{
buffer.write(ProtocolText::COM_BINLOG_DUMP);
buffer.write(reinterpret_cast<const char *>(&binlog_pos), 4);
buffer.write(reinterpret_cast<const char *>(&flags), 2);
buffer.write(reinterpret_cast<const char *>(&server_id), 4);
buffer.write(binlog_file_name.data(), binlog_file_name.length());
buffer.write(0x00);
}
}
}
}

View File

@ -0,0 +1,61 @@
#pragma once
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <IO/MySQLPacketPayloadWriteBuffer.h>
#include <Core/MySQL/PacketEndpoint.h>
/// Implementation of MySQL wire protocol.
/// Works only on little-endian architecture.
namespace DB
{
namespace MySQLProtocol
{
namespace Replication
{
/// https://dev.mysql.com/doc/internals/en/com-register-slave.html
class RegisterSlave : public IMySQLWritePacket
{
public:
UInt32 server_id;
String slaves_hostname;
String slaves_users;
String slaves_password;
size_t slaves_mysql_port;
UInt32 replication_rank;
UInt32 master_id;
protected:
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
RegisterSlave(UInt32 server_id_);
};
/// https://dev.mysql.com/doc/internals/en/com-binlog-dump.html
class BinlogDump : public IMySQLWritePacket
{
public:
UInt32 binlog_pos;
UInt16 flags;
UInt32 server_id;
String binlog_file_name;
protected:
size_t getPayloadSize() const override;
void writePayloadImpl(WriteBuffer & buffer) const override;
public:
BinlogDump(UInt32 binlog_pos_, String binlog_file_name_, UInt32 server_id_);
};
}
}
}

View File

@ -1,171 +0,0 @@
#include "MySQLProtocol.h"
#include <IO/WriteBuffer.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <common/logger_useful.h>
#include <random>
#include <sstream>
namespace DB::MySQLProtocol
{
void PacketSender::resetSequenceId()
{
sequence_id = 0;
}
String PacketSender::packetToText(const String & payload)
{
String result;
for (auto c : payload)
{
result += ' ';
result += std::to_string(static_cast<unsigned char>(c));
}
return result;
}
uint64_t readLengthEncodedNumber(ReadBuffer & ss)
{
char c{};
uint64_t buf = 0;
ss.readStrict(c);
auto cc = static_cast<uint8_t>(c);
if (cc < 0xfc)
{
return cc;
}
else if (cc < 0xfd)
{
ss.readStrict(reinterpret_cast<char *>(&buf), 2);
}
else if (cc < 0xfe)
{
ss.readStrict(reinterpret_cast<char *>(&buf), 3);
}
else
{
ss.readStrict(reinterpret_cast<char *>(&buf), 8);
}
return buf;
}
void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer)
{
if (x < 251)
{
buffer.write(static_cast<char>(x));
}
else if (x < (1 << 16))
{
buffer.write(0xfc);
buffer.write(reinterpret_cast<char *>(&x), 2);
}
else if (x < (1 << 24))
{
buffer.write(0xfd);
buffer.write(reinterpret_cast<char *>(&x), 3);
}
else
{
buffer.write(0xfe);
buffer.write(reinterpret_cast<char *>(&x), 8);
}
}
size_t getLengthEncodedNumberSize(uint64_t x)
{
if (x < 251)
{
return 1;
}
else if (x < (1 << 16))
{
return 3;
}
else if (x < (1 << 24))
{
return 4;
}
else
{
return 9;
}
}
size_t getLengthEncodedStringSize(const String & s)
{
return getLengthEncodedNumberSize(s.size()) + s.size();
}
ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex type_index)
{
ColumnType column_type;
CharacterSet charset = CharacterSet::binary;
int flags = 0;
switch (type_index)
{
case TypeIndex::UInt8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::UInt64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG | ColumnDefinitionFlags::UNSIGNED_FLAG;
break;
case TypeIndex::Int8:
column_type = ColumnType::MYSQL_TYPE_TINY;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int16:
column_type = ColumnType::MYSQL_TYPE_SHORT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int32:
column_type = ColumnType::MYSQL_TYPE_LONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Int64:
column_type = ColumnType::MYSQL_TYPE_LONGLONG;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float32:
column_type = ColumnType::MYSQL_TYPE_FLOAT;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Float64:
column_type = ColumnType::MYSQL_TYPE_DOUBLE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::Date:
column_type = ColumnType::MYSQL_TYPE_DATE;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::DateTime:
column_type = ColumnType::MYSQL_TYPE_DATETIME;
flags = ColumnDefinitionFlags::BINARY_FLAG;
break;
case TypeIndex::String:
case TypeIndex::FixedString:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
default:
column_type = ColumnType::MYSQL_TYPE_STRING;
charset = CharacterSet::utf8_general_ci;
break;
}
return ColumnDefinition(column_name, charset, 0, column_type, flags, 0);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -347,7 +347,6 @@ class IColumn;
M(UInt64, min_free_disk_space_for_temporary_data, 0, "The minimum disk space to keep while writing temporary data used in external sorting and aggregation.", 0) \
\
M(DefaultDatabaseEngine, default_database_engine, DefaultDatabaseEngine::Ordinary, "Default database engine.", 0) \
M(Bool, allow_experimental_database_atomic, true, "Allow to create database with Engine=Atomic.", 0) \
M(Bool, show_table_uuid_in_table_create_query_if_not_nil, false, "For tables in databases with Engine=Atomic show UUID of the table in its CREATE query.", 0) \
M(Bool, enable_scalar_subquery_optimization, true, "If it is set to true, prevent scalar subqueries from (de)serializing large scalar values and possibly avoid running the same subquery more than once.", 0) \
M(Bool, optimize_trivial_count_query, true, "Process trivial 'SELECT count() FROM table' query from metadata.", 0) \
@ -395,7 +394,8 @@ class IColumn;
M(UInt64, max_memory_usage_for_all_queries, 0, "Obsolete. Will be removed after 2020-10-20", 0) \
\
M(Bool, force_optimize_skip_unused_shards_no_nested, false, "Obsolete setting, does nothing. Will be removed after 2020-12-01. Use force_optimize_skip_unused_shards_nesting instead.", 0) \
M(Bool, experimental_use_processors, true, "Obsolete setting, does nothing. Will be removed after 2020-11-29.", 0)
M(Bool, experimental_use_processors, true, "Obsolete setting, does nothing. Will be removed after 2020-11-29.", 0) \
M(Bool, allow_experimental_database_atomic, true, "Obsolete setting, does nothing. Will be removed after 2021-02-12", 0)
#define FORMAT_FACTORY_SETTINGS(M) \
M(Char, format_csv_delimiter, ',', "The character to be considered as a delimiter in CSV data. If setting with a string, a string has to have a length of 1.", 0) \

View File

@ -1,7 +1,11 @@
#include <string>
#include <Core/MySQLClient.h>
#include <Core/MySQLProtocol.h>
#include <Core/MySQL/MySQLClient.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Core/MySQL/PacketsReplication.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
@ -11,15 +15,19 @@ int main(int argc, char ** argv)
{
using namespace DB;
using namespace MySQLProtocol;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::Authentication;
using namespace MySQLProtocol::ConnectionPhase;
using namespace MySQLProtocol::ProtocolText;
uint8_t sequence_id = 1;
String user = "default";
String password = "123";
String database;
UInt8 charset_utf8 = 33;
UInt32 max_packet_size = MySQLProtocol::MAX_PACKET_LENGTH;
UInt32 max_packet_size = MAX_PACKET_LENGTH;
String mysql_native_password = "mysql_native_password";
UInt32 server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH
@ -34,13 +42,13 @@ int main(int argc, char ** argv)
std::string s0;
WriteBufferFromString out0(s0);
Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa");
server_handshake.writePayloadImpl(out0);
Handshake server_handshake(server_capability_flags, -1, "ClickHouse", "mysql_native_password", "aaaaaaaaaaaaaaaaaaaaa", CharacterSet::utf8_general_ci);
server_handshake.writePayload(out0, sequence_id);
/// 1.2 Client reads the greeting
ReadBufferFromString in0(s0);
Handshake client_handshake;
client_handshake.readPayloadImpl(in0);
client_handshake.readPayload(in0, sequence_id);
/// Check packet
ASSERT(server_handshake.capability_flags == client_handshake.capability_flags)
@ -59,12 +67,12 @@ int main(int argc, char ** argv)
String auth_plugin_data = native41.getAuthPluginData();
HandshakeResponse client_handshake_response(
client_capability_flags, max_packet_size, charset_utf8, user, database, auth_plugin_data, mysql_native_password);
client_handshake_response.writePayloadImpl(out1);
client_handshake_response.writePayload(out1, sequence_id);
/// 2.2 Server reads the response
ReadBufferFromString in1(s1);
HandshakeResponse server_handshake_response;
server_handshake_response.readPayloadImpl(in1);
server_handshake_response.readPayload(in1, sequence_id);
/// Check
ASSERT(server_handshake_response.capability_flags == client_handshake_response.capability_flags)
@ -80,13 +88,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
OK_Packet server(0x00, server_capability_flags, 0, 0, 0, "", "");
server.writePayloadImpl(out0);
OKPacket server(0x00, server_capability_flags, 0, 0, 0, "", "");
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_OK)
@ -100,13 +108,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
ERR_Packet server(123, "12345", "This is the error message");
server.writePayloadImpl(out0);
ERRPacket server(123, "12345", "This is the error message");
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_ERR)
@ -121,13 +129,13 @@ int main(int argc, char ** argv)
// 1. Server writes packet
std::string s0;
WriteBufferFromString out0(s0);
EOF_Packet server(1, 1);
server.writePayloadImpl(out0);
EOFPacket server(1, 1);
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
PacketResponse client(server_capability_flags);
client.readPayloadImpl(in0);
ResponsePacket client(server_capability_flags);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.getType() == PACKET_EOF)
@ -142,12 +150,12 @@ int main(int argc, char ** argv)
std::string s0;
WriteBufferFromString out0(s0);
ColumnDefinition server("schema", "tbl", "org_tbl", "name", "org_name", 33, 0x00, MYSQL_TYPE_STRING, 0x00, 0x00);
server.writePayloadImpl(out0);
server.writePayload(out0, sequence_id);
// 2. Client reads packet
ReadBufferFromString in0(s0);
ColumnDefinition client;
client.readPayloadImpl(in0);
client.readPayload(in0, sequence_id);
// Check
ASSERT(client.column_type == server.column_type)

View File

@ -17,15 +17,22 @@ SRCS(
ExternalTable.cpp
Field.cpp
iostream_debug_helpers.cpp
MySQLClient.cpp
MySQLProtocol.cpp
MySQLReplication.cpp
NamesAndTypes.cpp
PostgreSQLProtocol.cpp
Settings.cpp
SettingsEnums.cpp
SettingsFields.cpp
SortDescription.cpp
MySQL/Authentication.cpp
MySQL/IMySQLReadPacket.cpp
MySQL/IMySQLWritePacket.cpp
MySQL/MySQLClient.cpp
MySQL/MySQLReplication.cpp
MySQL/PacketEndpoint.cpp
MySQL/PacketsConnection.cpp
MySQL/PacketsGeneric.cpp
MySQL/PacketsProtocolText.cpp
MySQL/PacketsReplication.cpp
)

View File

@ -366,6 +366,8 @@ void DatabaseAtomic::loadStoredObjects(Context & context, bool has_force_restore
std::lock_guard lock{mutex};
table_names = table_name_to_path;
}
Poco::File(path_to_table_symlinks).createDirectories();
for (const auto & table : table_names)
tryCreateSymlink(table.first, table.second);
}

View File

@ -18,7 +18,7 @@
#endif
#if USE_MYSQL
# include <Core/MySQLClient.h>
# include <Core/MySQL/MySQLClient.h>
# include <Databases/MySQL/DatabaseConnectionMySQL.h>
# include <Databases/MySQL/MaterializeMySQLSettings.h>
# include <Databases/MySQL/DatabaseMaterializeMySQL.h>
@ -103,16 +103,16 @@ DatabasePtr DatabaseFactory::getImpl(const ASTCreateQuery & create, const String
const ASTFunction * engine = engine_define->engine;
if (!engine->arguments || engine->arguments->children.size() != 4)
throw Exception(
"MySQL Database require mysql_hostname, mysql_database_name, mysql_username, mysql_password arguments.",
engine_name + " Database require mysql_hostname, mysql_database_name, mysql_username, mysql_password arguments.",
ErrorCodes::BAD_ARGUMENTS);
ASTs & arguments = engine->arguments->children;
arguments[1] = evaluateConstantExpressionOrIdentifierAsLiteral(arguments[1], context);
const auto & host_name_and_port = safeGetLiteralValue<String>(arguments[0], "MySQL");
const auto & mysql_database_name = safeGetLiteralValue<String>(arguments[1], "MySQL");
const auto & mysql_user_name = safeGetLiteralValue<String>(arguments[2], "MySQL");
const auto & mysql_user_password = safeGetLiteralValue<String>(arguments[3], "MySQL");
const auto & host_name_and_port = safeGetLiteralValue<String>(arguments[0], engine_name);
const auto & mysql_database_name = safeGetLiteralValue<String>(arguments[1], engine_name);
const auto & mysql_user_name = safeGetLiteralValue<String>(arguments[2], engine_name);
const auto & mysql_user_password = safeGetLiteralValue<String>(arguments[3], engine_name);
try
{

View File

@ -153,7 +153,6 @@ void DatabaseWithDictionaries::createDictionary(const Context & context, const S
if (isTableExist(dictionary_name, global_context))
throw Exception(ErrorCodes::TABLE_ALREADY_EXISTS, "Table {} already exists.", dict_id.getFullTableName());
String dictionary_metadata_path = getObjectMetadataPath(dictionary_name);
String dictionary_metadata_tmp_path = dictionary_metadata_path + ".tmp";
String statement = getObjectDefinitionFromCreateQuery(query);

View File

@ -6,6 +6,7 @@
# include <Databases/MySQL/DatabaseMaterializeMySQL.h>
# include <Interpreters/Context.h>
# include <Databases/DatabaseOrdinary.h>
# include <Databases/MySQL/DatabaseMaterializeTablesIterator.h>
# include <Databases/MySQL/MaterializeMySQLSyncThread.h>

View File

@ -5,7 +5,7 @@
#if USE_MYSQL
#include <mysqlxx/Pool.h>
#include <Core/MySQLClient.h>
#include <Core/MySQL/MySQLClient.h>
#include <Databases/IDatabase.h>
#include <Databases/MySQL/MaterializeMySQLSettings.h>
#include <Databases/MySQL/MaterializeMySQLSyncThread.h>

View File

@ -7,7 +7,7 @@
#if USE_MYSQL
#include <Core/Types.h>
#include <Core/MySQLReplication.h>
#include <Core/MySQL/MySQLReplication.h>
#include <mysqlxx/Connection.h>
#include <mysqlxx/PoolWithFailover.h>

View File

@ -8,7 +8,7 @@
# include <mutex>
# include <Core/BackgroundSchedulePool.h>
# include <Core/MySQLClient.h>
# include <Core/MySQL/MySQLClient.h>
# include <DataStreams/BlockIO.h>
# include <DataTypes/DataTypeString.h>
# include <DataTypes/DataTypesNumber.h>

View File

@ -18,6 +18,7 @@ void registerDictionarySourceCassandra(DictionarySourceFactory & factory)
[[maybe_unused]] const std::string & config_prefix,
[[maybe_unused]] Block & sample_block,
const Context & /* context */,
const std::string & /* default_database */,
bool /*check_config*/) -> DictionarySourcePtr
{
#if USE_CASSANDRA

View File

@ -53,7 +53,8 @@ ClickHouseDictionarySource::ClickHouseDictionarySource(
const std::string & path_to_settings,
const std::string & config_prefix,
const Block & sample_block_,
const Context & context_)
const Context & context_,
const std::string & default_database)
: update_time{std::chrono::system_clock::from_time_t(0)}
, dict_struct{dict_struct_}
, host{config.getString(config_prefix + ".host")}
@ -61,7 +62,7 @@ ClickHouseDictionarySource::ClickHouseDictionarySource(
, secure(config.getBool(config_prefix + ".secure", false))
, user{config.getString(config_prefix + ".user", "")}
, password{config.getString(config_prefix + ".password", "")}
, db{config.getString(config_prefix + ".db", "")}
, db{config.getString(config_prefix + ".db", default_database)}
, table{config.getString(config_prefix + ".table")}
, where{config.getString(config_prefix + ".where", "")}
, update_field{config.getString(config_prefix + ".update_field", "")}
@ -226,9 +227,11 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & default_database,
bool /* check_config */) -> DictionarySourcePtr
{
return std::make_unique<ClickHouseDictionarySource>(dict_struct, config, config_prefix, config_prefix + ".clickhouse", sample_block, context);
return std::make_unique<ClickHouseDictionarySource>(
dict_struct, config, config_prefix, config_prefix + ".clickhouse", sample_block, context, default_database);
};
factory.registerSource("clickhouse", create_table_source);
}

View File

@ -24,7 +24,8 @@ public:
const std::string & path_to_settings,
const std::string & config_prefix,
const Block & sample_block_,
const Context & context);
const Context & context,
const std::string & default_database);
/// copy-constructor is provided in order to support cloneability
ClickHouseDictionarySource(const ClickHouseDictionarySource & other);

View File

@ -42,7 +42,8 @@ DictionaryPtr DictionaryFactory::create(
const DictionaryStructure dict_struct{config, config_prefix + ".structure"};
DictionarySourcePtr source_ptr = DictionarySourceFactory::instance().create(name, config, config_prefix + ".source", dict_struct, context, check_source_config);
DictionarySourcePtr source_ptr = DictionarySourceFactory::instance().create(
name, config, config_prefix + ".source", dict_struct, context, config.getString(config_prefix + ".database", ""), check_source_config);
LOG_TRACE(&Poco::Logger::get("DictionaryFactory"), "Created dictionary source '{}' for dictionary '{}'", source_ptr->toString(), name);
const auto & layout_type = keys.front();

View File

@ -80,6 +80,7 @@ DictionarySourcePtr DictionarySourceFactory::create(
const std::string & config_prefix,
const DictionaryStructure & dict_struct,
const Context & context,
const std::string & default_database,
bool check_config) const
{
Poco::Util::AbstractConfiguration::Keys keys;
@ -96,7 +97,7 @@ DictionarySourcePtr DictionarySourceFactory::create(
{
const auto & create_source = found->second;
auto sample_block = createSampleBlock(dict_struct);
return create_source(dict_struct, config, config_prefix, sample_block, context, check_config);
return create_source(dict_struct, config, config_prefix, sample_block, context, default_database, check_config);
}
throw Exception{name + ": unknown dictionary source type: " + source_type, ErrorCodes::UNKNOWN_ELEMENT_IN_CONFIG};

View File

@ -26,12 +26,16 @@ class DictionarySourceFactory : private boost::noncopyable
public:
static DictionarySourceFactory & instance();
/// 'default_database' - the database where dictionary itself was created.
/// It is used as default_database for ClickHouse dictionary source when no explicit database was specified.
/// Does not make sense for other sources.
using Creator = std::function<DictionarySourcePtr(
const DictionaryStructure & dict_struct,
const Poco::Util::AbstractConfiguration & config,
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & default_database,
bool check_config)>;
DictionarySourceFactory();
@ -44,6 +48,7 @@ public:
const std::string & config_prefix,
const DictionaryStructure & dict_struct,
const Context & context,
const std::string & default_database,
bool check_config) const;
private:

View File

@ -220,6 +220,7 @@ void registerDictionarySourceExecutable(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & /* default_database */,
bool check_config) -> DictionarySourcePtr
{
if (dict_struct.has_expressions)

View File

@ -76,6 +76,7 @@ void registerDictionarySourceFile(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & /* default_database */,
bool check_config) -> DictionarySourcePtr
{
if (dict_struct.has_expressions)

View File

@ -197,6 +197,7 @@ void registerDictionarySourceHTTP(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & /* default_database */,
bool check_config) -> DictionarySourcePtr
{
if (dict_struct.has_expressions)

View File

@ -298,6 +298,7 @@ void registerDictionarySourceLibrary(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & /* default_database */,
bool check_config) -> DictionarySourcePtr
{
return std::make_unique<LibraryDictionarySource>(dict_struct, config, config_prefix + ".library", sample_block, context, check_config);

View File

@ -14,6 +14,7 @@ void registerDictionarySourceMongoDB(DictionarySourceFactory & factory)
const std::string & root_config_prefix,
Block & sample_block,
const Context &,
const std::string & /* default_database */,
bool /* check_config */)
{
const auto config_prefix = root_config_prefix + ".mongodb";

View File

@ -19,6 +19,7 @@ void registerDictionarySourceMysql(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & /* context */,
const std::string & /* default_database */,
bool /* check_config */) -> DictionarySourcePtr {
#if USE_MYSQL
return std::make_unique<MySQLDictionarySource>(dict_struct, config, config_prefix + ".mysql", sample_block);

View File

@ -13,6 +13,7 @@ void registerDictionarySourceRedis(DictionarySourceFactory & factory)
const String & config_prefix,
Block & sample_block,
const Context & /* context */,
const std::string & /* default_database */,
bool /* check_config */) -> DictionarySourcePtr {
return std::make_unique<RedisDictionarySource>(dict_struct, config, config_prefix + ".redis", sample_block);
};

View File

@ -275,6 +275,7 @@ void registerDictionarySourceXDBC(DictionarySourceFactory & factory)
const std::string & config_prefix,
Block & sample_block,
const Context & context,
const std::string & /* default_database */,
bool /* check_config */) -> DictionarySourcePtr {
#if USE_ODBC
BridgeHelperPtr bridge = std::make_shared<XDBCBridgeHelper<ODBCBridgeMixin>>(
@ -300,6 +301,7 @@ void registerDictionarySourceJDBC(DictionarySourceFactory & factory)
const std::string & /* config_prefix */,
Block & /* sample_block */,
const Context & /* context */,
const std::string & /* default_database */,
bool /* check_config */) -> DictionarySourcePtr {
throw Exception{"Dictionary source of type `jdbc` is disabled until consistent support for nullable fields.",
ErrorCodes::SUPPORT_IS_DISABLED};

View File

@ -402,7 +402,11 @@ void buildConfigurationFromFunctionWithKeyValueArguments(
* </mysql>
* </source>
*/
void buildSourceConfiguration(AutoPtr<Document> doc, AutoPtr<Element> root, const ASTFunctionWithKeyValueArguments * source, const ASTDictionarySettings * settings)
void buildSourceConfiguration(
AutoPtr<Document> doc,
AutoPtr<Element> root,
const ASTFunctionWithKeyValueArguments * source,
const ASTDictionarySettings * settings)
{
AutoPtr<Element> outer_element(doc->createElement("source"));
root->appendChild(outer_element);
@ -498,7 +502,9 @@ DictionaryConfigurationPtr getDictionaryConfigurationFromAST(const ASTCreateQuer
bool complex = DictionaryFactory::instance().isComplex(dictionary_layout->layout_type);
auto all_attr_names_and_types = buildDictionaryAttributesConfiguration(xml_document, structure_element, query.dictionary_attributes_list, pk_attrs);
auto all_attr_names_and_types = buildDictionaryAttributesConfiguration(
xml_document, structure_element, query.dictionary_attributes_list, pk_attrs);
checkPrimaryKey(all_attr_names_and_types, pk_attrs);
buildPrimaryKeyConfiguration(xml_document, structure_element, complex, pk_attrs, query.dictionary_attributes_list);

View File

@ -12,11 +12,15 @@
namespace DB::ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int TOO_LARGE_ARRAY_SIZE;
}
namespace DB::GatherUtils
{
inline constexpr size_t MAX_ARRAY_SIZE = 1 << 30;
/// Methods to copy Slice to Sink, overloaded for various combinations of types.
template <typename T>
@ -673,6 +677,10 @@ void resizeDynamicSize(ArraySource && array_source, ValueSource && value_source,
if (size >= 0)
{
auto length = static_cast<size_t>(size);
if (length > MAX_ARRAY_SIZE)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: {}, maximum: {}",
length, MAX_ARRAY_SIZE);
if (array_size <= length)
{
writeSlice(array_source.getWhole(), sink);
@ -685,6 +693,10 @@ void resizeDynamicSize(ArraySource && array_source, ValueSource && value_source,
else
{
auto length = static_cast<size_t>(-size);
if (length > MAX_ARRAY_SIZE)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: {}, maximum: {}",
length, MAX_ARRAY_SIZE);
if (array_size <= length)
{
for (size_t i = array_size; i < length; ++i)
@ -714,6 +726,10 @@ void resizeConstantSize(ArraySource && array_source, ValueSource && value_source
if (size >= 0)
{
auto length = static_cast<size_t>(size);
if (length > MAX_ARRAY_SIZE)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: {}, maximum: {}",
length, MAX_ARRAY_SIZE);
if (array_size <= length)
{
writeSlice(array_source.getWhole(), sink);
@ -726,6 +742,10 @@ void resizeConstantSize(ArraySource && array_source, ValueSource && value_source
else
{
auto length = static_cast<size_t>(-size);
if (length > MAX_ARRAY_SIZE)
throw Exception(ErrorCodes::TOO_LARGE_ARRAY_SIZE, "Too large array size: {}, maximum: {}",
length, MAX_ARRAY_SIZE);
if (array_size <= length)
{
for (size_t i = array_size; i < length; ++i)

View File

@ -57,7 +57,6 @@ void sliceHas(IArraySource & first, IArraySource & second, ArraySearchType searc
void push(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, bool push_front);
void resizeDynamicSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, const IColumn & size_column);
void resizeConstantSize(IArraySource & array_source, IValueSource & value_source, IArraySink & sink, ssize_t size);
}

View File

@ -102,7 +102,8 @@ private:
}
template <typename T>
bool executeConstStartStep(Block & block, const IColumn * end_arg, const T start, const T step, const size_t input_rows_count, const size_t result) const
bool executeConstStartStep(
Block & block, const IColumn * end_arg, const T start, const T step, const size_t input_rows_count, const size_t result) const
{
auto end_column = checkAndGetColumn<ColumnVector<T>>(end_arg);
if (!end_column)
@ -145,8 +146,14 @@ private:
for (size_t row_idx = 0; row_idx < input_rows_count; ++row_idx)
{
for (size_t st = start, ed = end_data[row_idx]; st < ed; st += step)
{
out_data[offset++] = st;
if (st > st + step)
throw Exception{"A call to function " + getName() + " overflows, investigate the values of arguments you are passing",
ErrorCodes::ARGUMENT_OUT_OF_BOUND};
}
out_offsets[row_idx] = offset;
}
@ -155,7 +162,8 @@ private:
}
template <typename T>
bool executeConstStep(Block & block, const IColumn * start_arg, const IColumn * end_arg, const T step, const size_t input_rows_count, const size_t result) const
bool executeConstStep(
Block & block, const IColumn * start_arg, const IColumn * end_arg, const T step, const size_t input_rows_count, const size_t result) const
{
auto start_column = checkAndGetColumn<ColumnVector<T>>(start_arg);
auto end_column = checkAndGetColumn<ColumnVector<T>>(end_arg);
@ -200,8 +208,14 @@ private:
for (size_t row_idx = 0; row_idx < input_rows_count; ++row_idx)
{
for (size_t st = start_data[row_idx], ed = end_data[row_idx]; st < ed; st += step)
{
out_data[offset++] = st;
if (st > st + step)
throw Exception{"A call to function " + getName() + " overflows, investigate the values of arguments you are passing",
ErrorCodes::ARGUMENT_OUT_OF_BOUND};
}
out_offsets[row_idx] = offset;
}
@ -210,7 +224,8 @@ private:
}
template <typename T>
bool executeConstStart(Block & block, const IColumn * end_arg, const IColumn * step_arg, const T start, const size_t input_rows_count, const size_t result) const
bool executeConstStart(
Block & block, const IColumn * end_arg, const IColumn * step_arg, const T start, const size_t input_rows_count, const size_t result) const
{
auto end_column = checkAndGetColumn<ColumnVector<T>>(end_arg);
auto step_column = checkAndGetColumn<ColumnVector<T>>(step_arg);
@ -255,8 +270,14 @@ private:
for (size_t row_idx = 0; row_idx < input_rows_count; ++row_idx)
{
for (size_t st = start, ed = end_data[row_idx]; st < ed; st += step_data[row_idx])
{
out_data[offset++] = st;
if (st > st + step_data[row_idx])
throw Exception{"A call to function " + getName() + " overflows, investigate the values of arguments you are passing",
ErrorCodes::ARGUMENT_OUT_OF_BOUND};
}
out_offsets[row_idx] = offset;
}
@ -265,7 +286,9 @@ private:
}
template <typename T>
bool executeGeneric(Block & block, const IColumn * start_col, const IColumn * end_col, const IColumn * step_col, const size_t input_rows_count, const size_t result) const
bool executeGeneric(
Block & block, const IColumn * start_col, const IColumn * end_col, const IColumn * step_col,
const size_t input_rows_count, const size_t result) const
{
auto start_column = checkAndGetColumn<ColumnVector<T>>(start_col);
auto end_column = checkAndGetColumn<ColumnVector<T>>(end_col);
@ -313,8 +336,14 @@ private:
for (size_t row_idx = 0; row_idx < input_rows_count; ++row_idx)
{
for (size_t st = start_data[row_idx], ed = end_start[row_idx]; st < ed; st += step_data[row_idx])
{
out_data[offset++] = st;
if (st > st + step_data[row_idx])
throw Exception{"A call to function " + getName() + " overflows, investigate the values of arguments you are passing",
ErrorCodes::ARGUMENT_OUT_OF_BOUND};
}
out_offsets[row_idx] = offset;
}

View File

@ -14,10 +14,13 @@
namespace DB
{
namespace ErrorCodes
{
extern const int ILLEGAL_TYPE_OF_ARGUMENT;
extern const int PARAMETER_OUT_OF_BOUND;
}
class FunctionH3KRing : public IFunction
{
public:
@ -65,6 +68,15 @@ public:
const H3Index origin_hindex = col_hindex->getUInt(row);
const int k = col_k->getInt(row);
/// Overflow is possible. The function maxKringSize does not check for overflow.
/// The calculation is similar to square of k but several times more.
/// Let's use huge underestimation as the safe bound. We should not allow to generate too large arrays nevertheless.
constexpr auto max_k = 10000;
if (k > max_k)
throw Exception(ErrorCodes::PARAMETER_OUT_OF_BOUND, "Too large 'k' argument for {} function, maximum {}", getName(), max_k);
if (k < 0)
throw Exception(ErrorCodes::PARAMETER_OUT_OF_BOUND, "Argument 'k' for {} function must be non negative", getName());
const auto vec_size = maxKringSize(k);
hindex_vec.resize(vec_size);
kRing(origin_hindex, k, hindex_vec.data());

View File

@ -0,0 +1,65 @@
#include <IO/MySQLPacketPayloadReadBuffer.h>
#include <sstream>
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_PACKET_FROM_CLIENT;
}
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
MySQLPacketPayloadReadBuffer::MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_)
: ReadBuffer(in_.position(), 0), in(in_), sequence_id(sequence_id_) // not in.buffer().begin(), because working buffer may include previous packet
{
}
bool MySQLPacketPayloadReadBuffer::nextImpl()
{
if (!has_read_header || (payload_length == MAX_PACKET_LENGTH && offset == payload_length))
{
has_read_header = true;
working_buffer.resize(0);
offset = 0;
payload_length = 0;
in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > MAX_PACKET_LENGTH)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
throw Exception(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
size_t packet_sequence_id = 0;
in.read(reinterpret_cast<char &>(packet_sequence_id));
if (packet_sequence_id != sequence_id)
{
std::ostringstream tmp;
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << static_cast<unsigned int>(sequence_id) << '.';
throw Exception(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
sequence_id++;
if (payload_length == 0)
return false;
}
else if (offset == payload_length)
{
return false;
}
in.nextIfAtEnd();
working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end());
size_t count = std::min(in.available(), payload_length - offset);
working_buffer.resize(count);
in.ignore(count);
offset += count;
return true;
}
}

View File

@ -0,0 +1,33 @@
#pragma once
#include <IO/ReadBuffer.h>
namespace DB
{
/** Reading packets.
* Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
*/
class MySQLPacketPayloadReadBuffer : public ReadBuffer
{
private:
ReadBuffer & in;
uint8_t & sequence_id;
bool has_read_header = false;
// Size of packet which is being read now.
size_t payload_length = 0;
// Offset in packet payload.
size_t offset = 0;
protected:
bool nextImpl() override;
public:
MySQLPacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_);
};
}

View File

@ -0,0 +1,61 @@
#include <IO/MySQLPacketPayloadWriteBuffer.h>
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER;
}
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
MySQLPacketPayloadWriteBuffer::MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
: WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_)
{
startNewPacket();
setWorkingBuffer();
pos = out.position();
}
void MySQLPacketPayloadWriteBuffer::startNewPacket()
{
payload_length = std::min(total_left, MAX_PACKET_LENGTH);
bytes_written = 0;
total_left -= payload_length;
out.write(reinterpret_cast<char *>(&payload_length), 3);
out.write(sequence_id++);
bytes += 4;
}
void MySQLPacketPayloadWriteBuffer::setWorkingBuffer()
{
out.nextIfAtEnd();
working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
if (payload_length - bytes_written == 0)
{
/// Finished writing packet. Due to an implementation of WriteBuffer, working_buffer cannot be empty. Further write attempts will throw Exception.
eof = true;
working_buffer.resize(1);
}
}
void MySQLPacketPayloadWriteBuffer::nextImpl()
{
const int written = pos - working_buffer.begin();
if (eof)
throw Exception("Cannot write after end of buffer.", ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER);
out.position() += written;
bytes_written += written;
/// Packets of size greater than MAX_PACKET_LENGTH are split into few packets of size MAX_PACKET_LENGTH and las packet of size < MAX_PACKET_LENGTH.
if (bytes_written == payload_length && (total_left > 0 || payload_length == MAX_PACKET_LENGTH))
startNewPacket();
setWorkingBuffer();
}
}

View File

@ -0,0 +1,36 @@
#pragma once
#include <IO/WriteBuffer.h>
namespace DB
{
/** Writing packets.
* https://dev.mysql.com/doc/internals/en/mysql-packet.html
*/
class MySQLPacketPayloadWriteBuffer : public WriteBuffer
{
public:
MySQLPacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_);
bool remainingPayloadSize() { return total_left; }
protected:
void nextImpl() override;
private:
WriteBuffer & out;
uint8_t & sequence_id;
size_t total_left = 0;
size_t payload_length = 0;
size_t bytes_written = 0;
bool eof = false;
void startNewPacket();
/// Sets working buffer to the rest of current packet payload.
void setWorkingBuffer();
};
}

View File

@ -26,6 +26,8 @@ SRCS(
MemoryReadWriteBuffer.cpp
MMapReadBufferFromFile.cpp
MMapReadBufferFromFileDescriptor.cpp
MySQLPacketPayloadReadBuffer.cpp
MySQLPacketPayloadWriteBuffer.cpp
NullWriteBuffer.cpp
parseDateTimeBestEffort.cpp
PeekableReadBuffer.cpp

View File

@ -449,7 +449,6 @@ void NO_INLINE Aggregator::executeImpl(
typename Method::State state(key_columns, key_sizes, aggregation_state_cache);
if (!no_more_keys)
//executeImplCase<false>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
executeImplBatch(method, state, aggregates_pool, rows, aggregate_instructions);
else
executeImplCase<true>(method, state, aggregates_pool, rows, aggregate_instructions, overflow_row);
@ -533,6 +532,19 @@ void NO_INLINE Aggregator::executeImplBatch(
/// Optimization for special case when aggregating by 8bit key.
if constexpr (std::is_same_v<Method, typename decltype(AggregatedDataVariants::key8)::element_type>)
{
/// We use another method if there are aggregate functions with -Array combinator.
bool has_arrays = false;
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
if (inst->offsets)
{
has_arrays = true;
break;
}
}
if (!has_arrays)
{
for (AggregateFunctionInstruction * inst = aggregate_instructions; inst->that; ++inst)
{
@ -551,6 +563,7 @@ void NO_INLINE Aggregator::executeImplBatch(
}
return;
}
}
/// Generic case.

View File

@ -142,10 +142,6 @@ BlockIO InterpreterCreateQuery::createDatabase(ASTCreateQuery & create)
if (create.storage->engine->name == "Atomic")
{
if (!context.getSettingsRef().allow_experimental_database_atomic && !internal)
throw Exception("Atomic is an experimental database engine. "
"Enable allow_experimental_database_atomic to use it.", ErrorCodes::UNKNOWN_DATABASE_ENGINE);
if (create.attach && create.uuid == UUIDHelpers::Nil)
throw Exception("UUID must be specified for ATTACH", ErrorCodes::INCORRECT_QUERY);
else if (create.uuid == UUIDHelpers::Nil)

View File

@ -28,6 +28,7 @@ namespace ErrorCodes
{
extern const int ALIAS_REQUIRED;
extern const int AMBIGUOUS_COLUMN_NAME;
extern const int LOGICAL_ERROR;
}
namespace
@ -187,7 +188,8 @@ StoragePtr JoinedTables::getLeftTableStorage()
bool JoinedTables::resolveTables()
{
tables_with_columns = getDatabaseAndTablesWithColumns(table_expressions, context);
assert(tables_with_columns.size() == table_expressions.size());
if (tables_with_columns.size() != table_expressions.size())
throw Exception("Unexpected tables count", ErrorCodes::LOGICAL_ERROR);
const auto & settings = context.getSettingsRef();
if (settings.joined_subquery_requires_alias && tables_with_columns.size() > 1)

View File

@ -108,7 +108,7 @@ std::optional<String> findFirstNonDeterministicFunctionName(const MutationComman
ASTPtr prepareQueryAffectedAST(const std::vector<MutationCommand> & commands)
{
/// Execute `SELECT count() FROM storage WHERE predicate1 OR predicate2 OR ...` query.
/// The result can differ from tne number of affected rows (e.g. if there is an UPDATE command that
/// The result can differ from the number of affected rows (e.g. if there is an UPDATE command that
/// changes how many rows satisfy the predicates of the subsequent commands).
/// But we can be sure that if count = 0, then no rows will be touched.

View File

@ -344,14 +344,14 @@ ASTs InterpreterCreateImpl::getRewrittenQueries(
const auto & create_materialized_column_declaration = [&](const String & name, const String & type, const auto & default_value)
{
const auto column_declaration = std::make_shared<ASTColumnDeclaration>();
auto column_declaration = std::make_shared<ASTColumnDeclaration>();
column_declaration->name = name;
column_declaration->type = makeASTFunction(type);
column_declaration->default_specifier = "MATERIALIZED";
column_declaration->default_expression = std::make_shared<ASTLiteral>(default_value);
column_declaration->children.emplace_back(column_declaration->type);
column_declaration->children.emplace_back(column_declaration->default_expression);
return std::move(column_declaration);
return column_declaration;
};
/// Add _sign and _version column.

View File

@ -15,6 +15,7 @@ namespace DB
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
PredicateExpressionsOptimizer::PredicateExpressionsOptimizer(
@ -111,6 +112,10 @@ bool PredicateExpressionsOptimizer::tryRewritePredicatesToTables(ASTs & tables_e
{
bool is_rewrite_tables = false;
if (tables_element.size() != tables_predicates.size())
throw Exception("Unexpected elements count in predicate push down: `set enable_optimize_predicate_expression = 0` to disable",
ErrorCodes::LOGICAL_ERROR);
for (size_t index = tables_element.size(); index > 0; --index)
{
size_t table_pos = index - 1;

View File

@ -474,14 +474,17 @@ static std::tuple<ASTPtr, BlockIO> executeQueryImpl(
bool log_queries = settings.log_queries && !internal;
/// Log into system table start of query execution, if need.
if (log_queries && elem.type >= settings.log_queries_min_type)
if (log_queries)
{
if (settings.log_query_settings)
elem.query_settings = std::make_shared<Settings>(context.getSettingsRef());
if (elem.type >= settings.log_queries_min_type)
{
if (auto query_log = context.getQueryLog())
query_log->add(elem);
}
}
/// Common code for finish and exception callbacks
auto status_info_to_query_log = [ast](QueryLogElement &element, const QueryStatusInfo &info) mutable

View File

@ -84,9 +84,6 @@ static void loadDatabase(
}
#define SYSTEM_DATABASE "system"
void loadMetadata(Context & context, const String & default_database_name)
{
Poco::Logger * log = &Poco::Logger::get("loadMetadata");
@ -114,7 +111,7 @@ void loadMetadata(Context & context, const String & default_database_name)
if (endsWith(it.name(), ".sql"))
{
String db_name = it.name().substr(0, it.name().size() - 4);
if (db_name != SYSTEM_DATABASE)
if (db_name != DatabaseCatalog::SYSTEM_DATABASE)
databases.emplace(unescapeForFileName(db_name), path + "/" + db_name);
}
@ -140,7 +137,7 @@ void loadMetadata(Context & context, const String & default_database_name)
if (it.name().at(0) == '.')
continue;
if (it.name() == SYSTEM_DATABASE)
if (it.name() == DatabaseCatalog::SYSTEM_DATABASE)
continue;
databases.emplace(unescapeForFileName(it.name()), it.path().toString());
@ -172,21 +169,20 @@ void loadMetadata(Context & context, const String & default_database_name)
void loadMetadataSystem(Context & context)
{
String path = context.getPath() + "metadata/" SYSTEM_DATABASE;
if (Poco::File(path).exists())
String path = context.getPath() + "metadata/" + DatabaseCatalog::SYSTEM_DATABASE;
String metadata_file = path + ".sql";
if (Poco::File(path).exists() || Poco::File(metadata_file).exists())
{
/// 'has_force_restore_data_flag' is true, to not fail on loading query_log table, if it is corrupted.
loadDatabase(context, SYSTEM_DATABASE, path, true);
loadDatabase(context, DatabaseCatalog::SYSTEM_DATABASE, path, true);
}
else
{
/// Initialize system database manually
String global_path = context.getPath();
Poco::File(global_path + "data/" SYSTEM_DATABASE).createDirectories();
Poco::File(global_path + "metadata/" SYSTEM_DATABASE).createDirectories();
auto system_database = std::make_shared<DatabaseOrdinary>(SYSTEM_DATABASE, global_path + "metadata/" SYSTEM_DATABASE "/", context);
DatabaseCatalog::instance().attachDatabase(SYSTEM_DATABASE, system_database);
String database_create_query = "CREATE DATABASE ";
database_create_query += DatabaseCatalog::SYSTEM_DATABASE;
database_create_query += " ENGINE=Atomic";
executeCreateQuery(database_create_query, context, DatabaseCatalog::SYSTEM_DATABASE, "<no file>", true);
}
}

View File

@ -1,5 +1,4 @@
#include <Processors/Formats/Impl/MySQLOutputFormat.h>
#include <Core/MySQLProtocol.h>
#include <Interpreters/ProcessList.h>
#include <Formats/FormatFactory.h>
#include <Interpreters/Context.h>
@ -10,6 +9,8 @@ namespace DB
{
using namespace MySQLProtocol;
using namespace MySQLProtocol::Generic;
using namespace MySQLProtocol::ProtocolText;
MySQLOutputFormat::MySQLOutputFormat(WriteBuffer & out_, const Block & header_, const FormatSettings & settings_)
@ -29,17 +30,17 @@ void MySQLOutputFormat::initialize()
if (header.columns())
{
packet_sender->sendPacket(LengthEncodedNumber(header.columns()));
packet_endpoint->sendPacket(LengthEncodedNumber(header.columns()));
for (size_t i = 0; i < header.columns(); i++)
{
const auto & column_name = header.getColumnsWithTypeAndName()[i].name;
packet_sender->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
packet_endpoint->sendPacket(getColumnDefinition(column_name, data_types[i]->getTypeId()));
}
if (!(context->mysql.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
{
packet_sender->sendPacket(EOF_Packet(0, 0));
packet_endpoint->sendPacket(EOFPacket(0, 0));
}
}
}
@ -52,8 +53,8 @@ void MySQLOutputFormat::consume(Chunk chunk)
for (size_t i = 0; i < chunk.getNumRows(); i++)
{
ProtocolText::ResultsetRow row_packet(data_types, chunk.getColumns(), i);
packet_sender->sendPacket(row_packet);
ProtocolText::ResultSetRow row_packet(data_types, chunk.getColumns(), i);
packet_endpoint->sendPacket(row_packet);
}
}
@ -75,17 +76,17 @@ void MySQLOutputFormat::finalize()
const auto & header = getPort(PortKind::Main).getHeader();
if (header.columns() == 0)
packet_sender->sendPacket(OK_Packet(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_endpoint->sendPacket(OKPacket(0x0, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
if (context->mysql.client_capabilities & CLIENT_DEPRECATE_EOF)
packet_sender->sendPacket(OK_Packet(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
packet_endpoint->sendPacket(OKPacket(0xfe, context->mysql.client_capabilities, affected_rows, 0, 0, "", human_readable_info), true);
else
packet_sender->sendPacket(EOF_Packet(0, 0), true);
packet_endpoint->sendPacket(EOFPacket(0, 0), true);
}
void MySQLOutputFormat::flush()
{
packet_sender->out->next();
packet_endpoint->out->next();
}
void registerOutputFormatProcessorMySQLWire(FormatFactory & factory)

View File

@ -3,7 +3,10 @@
#include <Processors/Formats/IRowOutputFormat.h>
#include <Core/Block.h>
#include <Core/MySQLProtocol.h>
#include <Core/MySQL/Authentication.h>
#include <Core/MySQL/PacketsGeneric.h>
#include <Core/MySQL/PacketsConnection.h>
#include <Core/MySQL/PacketsProtocolText.h>
#include <Formats/FormatSettings.h>
namespace DB
@ -26,8 +29,7 @@ public:
void setContext(const Context & context_)
{
context = &context_;
packet_sender = std::make_unique<MySQLProtocol::PacketSender>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it
packet_sender->max_packet_size = context_.mysql.max_packet_size;
packet_endpoint = std::make_unique<MySQLProtocol::PacketEndpoint>(out, const_cast<uint8_t &>(context_.mysql.sequence_id)); /// TODO: fix it
}
void consume(Chunk) override;
@ -42,7 +44,7 @@ private:
bool initialized = false;
const Context * context = nullptr;
std::unique_ptr<MySQLProtocol::PacketSender> packet_sender;
std::unique_ptr<MySQLProtocol::PacketEndpoint> packet_endpoint;
FormatSettings format_settings;
DataTypes data_types;
};

View File

@ -140,16 +140,24 @@ void TabSeparatedRowInputFormat::readPrefix()
if (format_settings.with_names_use_header)
{
String column_name;
do
for (;;)
{
readEscapedString(column_name, in);
if (!checkChar('\t', in))
{
/// Check last column for \r before adding it, otherwise an error will be:
/// "Unknown field found in TSV header"
checkForCarriageReturn(in);
addInputColumn(column_name);
break;
}
else
addInputColumn(column_name);
}
while (checkChar('\t', in));
if (!in.eof())
{
checkForCarriageReturn(in);
assertChar('\n', in);
}
}

Some files were not shown because too many files have changed in this diff Show More