Merge remote-tracking branch 'upstream/master' into issue-5286

This commit is contained in:
Ivan Lezhankin 2019-06-03 17:19:09 +03:00
commit b311984879
371 changed files with 10051 additions and 2464 deletions

View File

@ -1,8 +1,59 @@
## ClickHouse release 19.6.2.11, 2019-05-13
### New Features
* TTL expressions for columns and tables. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
* Added support for `brotli` compression for HTTP responses (`Accept-Encoding: br`). [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
* Added new function `isValidUTF8` for checking whether a set of bytes is correctly utf-8 encoded. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
* Added a new load balancing policy `first_or_random`, which sends queries to the first specified host and, if it is inaccessible, to random hosts of the shard. Useful for cross-replication topology setups. A client-side usage sketch follows this list. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
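The following is a minimal client-side sketch, not part of the release itself, that exercises several of the features above over ClickHouse's HTTP interface using Python and `requests`. The host, port, and query are assumptions for illustration; `enable_http_compression` must be enabled for the server to compress the response, and settings such as `load_balancing` can be passed as URL parameters.

```python
# Illustrative sketch only; server address and query are assumed, not taken from the release notes.
import requests

query = "SELECT isValidUTF8('Hello'), isValidUTF8(unhex('FF'))"

resp = requests.post(
    "http://localhost:8123/",                  # default ClickHouse HTTP port, assuming a local server
    params={
        "enable_http_compression": 1,          # allow the server to compress the response body
        "load_balancing": "first_or_random",   # new policy; only meaningful for Distributed queries
    },
    headers={"Accept-Encoding": "br"},         # ask for a brotli-compressed response
    data=query,
)

print(resp.headers.get("Content-Encoding"))    # 'br' is expected when the server compresses with brotli
print(resp.text)                               # isValidUTF8 should return 1 for valid and 0 for invalid input
```

Note that decoding a `br` response body on the client may require brotli support in the HTTP library; checking the `Content-Encoding` header alone is enough to confirm the server-side feature.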
### Experimental Features
* Added the `index_granularity_bytes` setting (adaptive index granularity) for the MergeTree* family of tables, as sketched below. [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
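A hedged sketch of how the experimental setting might be applied when creating a MergeTree table; the table name, schema, and byte threshold below are illustrative assumptions, and the statement is sent through the HTTP interface as in the previous example.

```python
# Illustrative sketch only; table name, schema and threshold are arbitrary examples.
import requests

ddl = """
CREATE TABLE IF NOT EXISTS adaptive_granularity_demo
(
    key UInt64,
    value String
)
ENGINE = MergeTree()
ORDER BY key
SETTINGS index_granularity_bytes = 10485760
"""

# Assumes a local server and a writable default database.
requests.post("http://localhost:8123/", data=ddl).raise_for_status()
```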
### Improvements
* Added support for non-constant and negative offset and length arguments for the `substringUTF8` function. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Disabled push-down to the right table in LEFT JOIN, the left table in RIGHT JOIN, and both tables in FULL JOIN. This fixes wrong JOIN results in some cases. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
* `clickhouse-copier`: automatically upload the task configuration from the `--task-file` option. [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
* Added typo suggestions for the storage factory and the table functions factory. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
* Support asterisks and qualified asterisks for multiple joins without subqueries. [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
* Made the missing-column error message more user friendly. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
### Performance Improvements
* Significant speedup of ASOF JOIN [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
### Backward Incompatible Changes
* The HTTP header `Query-Id` was renamed to `X-ClickHouse-Query-Id` for consistency (see the compatibility example below). [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
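A small compatibility sketch, assumed rather than taken from the change itself: a client can read the renamed header and fall back to the old name when talking to servers released before this change.

```python
# Illustrative sketch; assumes a local server reachable over the HTTP interface.
import requests

resp = requests.post("http://localhost:8123/", data="SELECT 1")
# Prefer the new header name, fall back to the pre-rename one for older servers.
query_id = resp.headers.get("X-ClickHouse-Query-Id") or resp.headers.get("Query-Id")
print("query id assigned by the server:", query_id)
```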
### Bug Fixes
* Fixed potential null pointer dereference in `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
* Fixed an error on queries with JOIN + ARRAY JOIN. [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
* Fixed hanging on start of the server when a dictionary depends on another dictionary via a database with engine=Dictionary. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
* Partially fixed `distributed_product_mode = local`. Columns of local tables can now be used in WHERE/HAVING/ORDER BY/... via table aliases. An exception is thrown if the table does not have an alias; accessing these columns without table aliases is not yet possible. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
* Fixed a potentially wrong result for `SELECT DISTINCT` with `JOIN`. [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
### Build/Testing/Packaging Improvements
* Fixed test failures when running clickhouse-server on a different host. [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
* clickhouse-test: Disable color control sequences in non-tty environments. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
* clickhouse-test: Allow using any test database (remove the `test.` qualification where possible). [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
* Fixed UBSan errors. [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
* Yandex LFAlloc was added to ClickHouse to allocate MarkCache and UncompressedCache data in different ways, to catch segfaults more reliably. [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
* Added a Python utility to help with backports and changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
## ClickHouse release 19.5.4.22, 2019-05-13
### Bug Fixes
* Fixed possible crashes in bitmap* functions. [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
* Fixed a very rare data race condition that could happen when executing a query with UNION ALL involving at least two SELECTs from system.columns, system.tables, system.parts, system.parts_tables or tables of the Merge family, while concurrently performing ALTER of columns of the related tables. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Fixed the error `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. This error happened if a LowCardinality column was part of the primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
* Changed the `retention` function: previously, if a row satisfied both the first and the Nth condition, only the first satisfied condition was added to the data state; now all conditions satisfied by a row are added to the data state (see the illustrative sketch after this list). [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
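The behavioral change in `retention` can be illustrated with a toy Python model of the per-row state update. This is not the actual aggregate-function code, only a sketch of the semantics described above.

```python
# Toy model of the state update described in the changelog entry; not ClickHouse code.
def old_update(state, conds):
    # Before the fix: only the first satisfied condition was recorded for the row.
    for i, satisfied in enumerate(conds):
        if satisfied:
            state[i] = True
            break
    return state

def fixed_update(state, conds):
    # After the fix: every condition satisfied by the row is recorded.
    for i, satisfied in enumerate(conds):
        if satisfied:
            state[i] = True
    return state

row = [True, False, True]              # the row satisfies the 1st and 3rd conditions
print(old_update([False] * 3, row))    # [True, False, False]
print(fixed_update([False] * 3, row))  # [True, False, True]
```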
## ClickHouse release 19.5.3.8, 2019-04-18
### Bug Fixes
* Fixed the type of the `max_partitions_per_insert_block` setting from Boolean to UInt64. [#5028](https://github.com/yandex/ClickHouse/pull/5028) ([Mohammad Hossein Sekhavat](https://github.com/mhsekhavat))
## ClickHouse release 19.5.2.6, 2019-04-15
### New Features
@ -294,7 +345,7 @@
* Added support of `Nullable` types in `mysql` table function. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
* Support for arbitrary constant expressions in `LIMIT` clause. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
* Added `topKWeighted` aggregate function that takes additional argument with (unsigned integer) weight. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
* `StorageJoin` now supports `join_overwrite` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* `StorageJoin` now supports `join_any_take_last_row` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* Added function `toStartOfInterval`. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
* Added `RowBinaryWithNamesAndTypes` format. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))
* Added `IPv4` and `IPv6` data types. More effective implementations of `IPv*` functions. [#3669](https://github.com/yandex/ClickHouse/pull/3669) ([Vasily Nemkov](https://github.com/Enmk))

View File

@ -1,3 +1,53 @@
## ClickHouse release 19.6.2.11, 2019-05-13
### New Features
* TTL expressions that allow configuring the lifetime and automatic cleanup of data in a table or in individual columns. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
* Added support for the `brotli` compression algorithm in HTTP responses (`Accept-Encoding: br`). For POST request bodies, this capability already existed. [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
* Added the `isValidUTF8` function for checking whether a string contains valid UTF-8 data. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
* Added a new load balancing policy (`load_balancing`) `first_or_random`, which sends queries to the first specified host and, if it is unavailable, to random hosts of the shard. Useful for cross-replication topologies. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
### Experimental Features
* Added the `index_granularity_bytes` setting (adaptive index granularity) for tables of the MergeTree* family. [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
### Improvements
* Added support for non-constant and negative offset and length arguments for the `substringUTF8` function. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Disabled push-down to the right table in LEFT JOIN, the left table in RIGHT JOIN, and both tables in FULL JOIN. This fixes wrong JOIN results in some cases. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
* `clickhouse-copier`: automatic upload of the task configuration to ZooKeeper from the `--task-file` option. [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
* Added typo-aware suggestions for table engine names and table functions. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
* Support for `select *` and `select tablename.*` expressions for multiple joins without subqueries. [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
* Error messages about missing columns are now more understandable. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
### Performance Improvements
* Significant speedup of ASOF JOIN. [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
### Backward Incompatible Changes
* The HTTP header `Query-Id` was renamed to `X-ClickHouse-Query-Id` for consistency. [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
### Bug Fixes
* Fixed possible null pointer dereferences in `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
* Fixed errors in queries with JOIN + ARRAY JOIN. [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
* Fixed hanging on server start when an external dictionary depends on another dictionary through a table from a database with the `Dictionary` engine. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
* With `distributed_product_mode = 'local'`, columns of local tables can now be used correctly in where/having/order by/... via table aliases. An exception is thrown if the table does not have an alias; access to the columns without table aliases is not yet possible. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
* Fixed a potentially incorrect result for `SELECT DISTINCT` with `JOIN`. [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
### Build/Testing/Packaging Improvements
* Fixed test failures when `clickhouse-server` is running on a remote host. [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
* `clickhouse-test`: Disabled colored output when the command is not run in a terminal. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
* `clickhouse-test`: Allowed using databases other than the test database. [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
* Fixed errors when running tests under UBSan. [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
* Added the Yandex LFAlloc allocator to allocate MarkCache and UncompressedCache data in different ways in order to catch memory overruns more reliably. [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
* Added a utility that simplifies backporting changes to older releases and composing changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
## ClickHouse release 19.5.4.22, 2019-05-13
### Bug Fixes
* Fixed possible crashes in bitmap* functions. [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
* Fixed a very rare data race condition that could happen when executing a query with UNION ALL involving at least two SELECTs from the tables system.columns, system.tables, system.parts, system.parts_tables or tables of the Merge family, while ALTER queries on columns of the related tables were running concurrently. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
* Fixed the error `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. This error occurred when a LowCardinality column was part of the primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
* Fixed the `retention` function: previously only the first matching condition was added to the data state; now all conditions satisfied by a row of data are added to the state. [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
## ClickHouse release 19.5.3.8, 2019-04-18
### Bug Fixes
@ -286,7 +336,7 @@
* Added support for `Nullable` types in the `mysql` table function. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
* Added support for arbitrary constant expressions in the `LIMIT` clause. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
* Added the `topKWeighted` aggregate function, a variant of `topK` that allows specifying a (non-negative integer) weight for the added value. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
* The `Join` engine now supports the `join_overwrite` setting, which allows overwriting existing values for the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* The `Join` engine now supports the `join_any_take_last_row` setting, which allows overwriting existing values for the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
* Added the `toStartOfInterval` function. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
* Added the `toStartOfTenMinutes` function. [#4298](https://github.com/yandex/ClickHouse/pull/4298) ([Vitaly Baranov](https://github.com/vitlibar))
* Added the `RowBinaryWithNamesAndTypes` format. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))

View File

@ -93,6 +93,11 @@ if (COMPILER_GCC OR COMPILER_CLANG)
set (CXX_WARNING_FLAGS "${CXX_WARNING_FLAGS} -Wnon-virtual-dtor")
endif ()
if (COMPILER_GCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "8.3.0")
# Warnings in protobuf-generated code
set (CXX_WARNING_FLAGS "${CXX_WARNING_FLAGS} -Wno-array-bounds")
endif ()
if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# clang: warning: argument unused during compilation: '-stdlib=libc++'
# clang: warning: argument unused during compilation: '-specs=/usr/share/dpkg/no-pie-compile.specs' [-Wunused-command-line-argument]
@ -226,7 +231,7 @@ if (OS_LINUX AND NOT UNBUNDLED AND (GLIBC_COMPATIBILITY OR USE_LIBCXX))
set (CMAKE_POSTFIX_VARIABLE "CMAKE_${CMAKE_BUILD_TYPE_UC}_POSTFIX")
# FIXME: glibc-compatibility may be non-static in some builds!
set (DEFAULT_LIBS "${DEFAULT_LIBS} libs/libglibc-compatibility/libglibc-compatibility${${CMAKE_POSTFIX_VARIABLE}}.a")
set (DEFAULT_LIBS "${DEFAULT_LIBS} ${ClickHouse_BINARY_DIR}/libs/libglibc-compatibility/libglibc-compatibility${${CMAKE_POSTFIX_VARIABLE}}.a")
endif ()
# Add Libc. GLIBC is actually a collection of interdependent libraries.

View File

@ -5,7 +5,7 @@ endmacro()
macro(add_headers_and_sources prefix common_path)
add_glob(${prefix}_headers RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${common_path}/*.h)
add_glob(${prefix}_sources ${common_path}/*.cpp ${common_path}/*.h)
add_glob(${prefix}_sources ${common_path}/*.cpp ${common_path}/*.c ${common_path}/*.h)
endmacro()
macro(add_headers_only prefix common_path)

View File

@ -1,4 +1,5 @@
if (NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD AND NOT APPLE)
# TODO(danlark1). Disable LFAlloc for a while to fix mmap count problem
if (NOT OS_LINUX AND NOT SANITIZE AND NOT ARCH_ARM AND NOT ARCH_32 AND NOT ARCH_PPC64LE AND NOT OS_FREEBSD AND NOT APPLE)
option (ENABLE_LFALLOC "Set to FALSE to disable the bundled LFAlloc allocator" ${NOT_UNBUNDLED})
endif ()

View File

@ -1,8 +1,8 @@
# Third-party libraries may have substandard code.
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-function -Wno-unused-variable -Wno-unused-but-set-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-maybe-uninitialized -Wno-format -Wno-misleading-indentation -Wno-stringop-overflow -Wno-implicit-function-declaration -Wno-return-type -Wno-array-bounds -Wno-bool-compare -Wno-int-conversion -Wno-switch")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast -Wno-unused-function -Wno-unused-variable -Wno-unused-but-set-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-non-virtual-dtor -Wno-maybe-uninitialized -Wno-format -Wno-misleading-indentation -Wno-implicit-fallthrough -Wno-class-memaccess -Wno-sign-compare -std=c++1z")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-function -Wno-unused-variable -Wno-unused-but-set-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-maybe-uninitialized -Wno-format -Wno-misleading-indentation -Wno-stringop-overflow -Wno-implicit-function-declaration -Wno-return-type -Wno-array-bounds -Wno-bool-compare -Wno-int-conversion -Wno-switch -Wno-stringop-truncation")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast -Wno-unused-function -Wno-unused-variable -Wno-unused-but-set-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-non-virtual-dtor -Wno-maybe-uninitialized -Wno-format -Wno-misleading-indentation -Wno-implicit-fallthrough -Wno-class-memaccess -Wno-sign-compare -Wno-array-bounds -Wno-missing-attributes -Wno-stringop-truncation -std=c++1z")
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-function -Wno-unused-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-format -Wno-parentheses-equality -Wno-tautological-constant-compare -Wno-tautological-constant-out-of-range-compare -Wno-implicit-function-declaration -Wno-return-type -Wno-pointer-bool-conversion -Wno-enum-conversion -Wno-int-conversion -Wno-switch -Wno-string-plus-int")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-old-style-cast -Wno-unused-function -Wno-unused-variable -Wno-unused-result -Wno-deprecated-declarations -Wno-non-virtual-dtor -Wno-format -Wno-inconsistent-missing-override -std=c++1z")

contrib/boost vendored

@ -1 +1 @@
Subproject commit 79bf85ea99c05ba4fb6959474d4464ab126f8973
Subproject commit 8abda007bfe52d78a51548d4594879d6d82a22fa

contrib/hyperscan vendored

@ -1 +1 @@
Subproject commit 05b0f9064cca4bd55548dedb0a32ed9461146c1e
Subproject commit ed17d34a7c786512471946f9105eaa8d925f34c3

View File

@ -40,6 +40,10 @@ if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
list(APPEND SRCS ${JEMALLOC_SOURCE_DIR}/src/zone.c)
endif()
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -w")
endif ()
add_library(jemalloc STATIC ${SRCS})
target_include_directories(jemalloc PUBLIC

contrib/simdjson vendored

@ -1 +1 @@
Subproject commit 14cd1f7a0b0563db78bda8053a9f6ac2ea95a441
Subproject commit 2151ad7f34cf773a23f086e941d661f8a8873144

View File

@ -49,19 +49,20 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
if (WEVERYTHING)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-std-move-in-c++11")
endif ()
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wshadow-field -Wstring-plus-int -Wempty-init-stmt")
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wshadow-field -Wstring-plus-int")
if(NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wempty-init-stmt")
endif()
endif ()
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
if (WEVERYTHING)
if (WEVERYTHING AND NOT APPLE)
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ctad-maybe-unsupported")
endif ()
endif ()

View File

@ -1,11 +1,11 @@
# This strings autochanged from release_lib.sh:
set(VERSION_REVISION 54420)
set(VERSION_REVISION 54421)
set(VERSION_MAJOR 19)
set(VERSION_MINOR 8)
set(VERSION_MINOR 9)
set(VERSION_PATCH 1)
set(VERSION_GITHASH a76e504f45ff4a74e8c492bd269f022352d5f6d9)
set(VERSION_DESCRIBE v19.8.1.1-testing)
set(VERSION_STRING 19.8.1.1)
set(VERSION_GITHASH 0c2aa460651a462f14efc7e995840a244531d373)
set(VERSION_DESCRIBE v19.9.1.1-testing)
set(VERSION_STRING 19.9.1.1)
# end of autochange
set(VERSION_EXTRA "" CACHE STRING "")

View File

@ -325,8 +325,8 @@ private:
double seconds = watch.elapsedSeconds();
std::lock_guard lock(mutex);
info_per_interval.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
info_total.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
info_per_interval.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
info_total.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
}

View File

@ -26,8 +26,8 @@ elseif (EXISTS ${INTERNAL_COMPILER_BIN_ROOT}${INTERNAL_COMPILER_EXECUTABLE})
set (COPY_HEADERS_COMPILER "${INTERNAL_COMPILER_BIN_ROOT}${INTERNAL_COMPILER_EXECUTABLE}")
endif ()
if (COPY_HEADERS_COMPILER AND OS_LINUX)
add_custom_target (copy-headers [ -f ${TMP_HEADERS_DIR}/dbms/src/Interpreters/SpecializedAggregator.h ] || env CLANG=${COPY_HEADERS_COMPILER} BUILD_PATH=${ClickHouse_BINARY_DIR} DESTDIR=${ClickHouse_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/copy_headers.sh ${ClickHouse_SOURCE_DIR} ${TMP_HEADERS_DIR} DEPENDS ${COPY_HEADERS_DEPENDS} WORKING_DIRECTORY ${ClickHouse_SOURCE_DIR} SOURCES copy_headers.sh)
if (COPY_HEADERS_COMPILER)
add_custom_target (copy-headers [ -f ${TMP_HEADERS_DIR}/dbms/src/Interpreters/SpecializedAggregator.h ] || env CLANG=${COPY_HEADERS_COMPILER} BUILD_PATH=${ClickHouse_BINARY_DIR} DESTDIR=${ClickHouse_SOURCE_DIR} CMAKE_CXX_COMPILER_VERSION=${CMAKE_CXX_COMPILER_VERSION} ${CMAKE_CURRENT_SOURCE_DIR}/copy_headers.sh ${ClickHouse_SOURCE_DIR} ${TMP_HEADERS_DIR} DEPENDS ${COPY_HEADERS_DEPENDS} WORKING_DIRECTORY ${ClickHouse_SOURCE_DIR} SOURCES copy_headers.sh)
if (USE_INTERNAL_LLVM_LIBRARY)
set (CLANG_HEADERS_DIR "${ClickHouse_SOURCE_DIR}/contrib/llvm/clang/lib/Headers")

View File

@ -38,26 +38,28 @@ for header in $START_HEADERS; do
START_HEADERS_INCLUDE+="-include $header "
done
# The -mcx16 option is used so that more header files are selected (with some margin).
# The latter options are the same that are added while building packages.
# TODO: Does not work on macos:
GCC_ROOT=`$CLANG -v 2>&1 | grep "Selected GCC installation"| sed -n -e 's/^.*: //p'`
for src_file in $(echo | $CLANG -M -xc++ -std=c++1z -Wall -Werror -msse4 -mcx16 -mpopcnt -O3 -g -fPIC -fstack-protector -D_FORTIFY_SOURCE=2 \
# TODO: Does not work on macos?
GCC_ROOT=${GCC_ROOT:=/usr/lib/clang/${CMAKE_CXX_COMPILER_VERSION}}
# The -mcx16 option is used so that more header files are selected (with some margin).
# The latter options are the same that are added while building packages.
for src_file in $(echo | $CLANG -M -xc++ -std=c++1z -Wall -Werror -msse2 -msse4 -mcx16 -mpopcnt -O3 -g -fPIC -fstack-protector -D_FORTIFY_SOURCE=2 \
-I $GCC_ROOT/include \
-I $GCC_ROOT/include-fixed \
$(cat "$BUILD_PATH/include_directories.txt") \
$START_HEADERS_INCLUDE \
- |
tr -d '\\' |
sed --posix -E -e 's/^-\.o://');
sed -E -e 's/^-\.o://');
do
dst_file=$src_file;
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$DESTDIR!!")
dst_file=$(echo $dst_file | sed --posix -E -e 's/build\///') # for simplicity reasons, will put generated headers near the rest.
mkdir -p "$DST/$(echo $dst_file | sed --posix -E -e 's/\/[^/]*$/\//')";
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed -E -e "s!^$DESTDIR!!")
dst_file=$(echo $dst_file | sed -E -e 's/build\///') # for simplicity reasons, will put generated headers near the rest.
mkdir -p "$DST/$(echo $dst_file | sed -E -e 's/\/[^/]*$/\//')";
cp "$src_file" "$DST/$dst_file";
done
@ -68,9 +70,9 @@ done
for src_file in $(ls -1 $($CLANG -v -xc++ - <<<'' 2>&1 | grep '^ /' | grep 'include' | grep -E '/lib/clang/|/include/clang/')/*.h | grep -vE 'arm|altivec|Intrin');
do
dst_file=$src_file;
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed --posix -E -e 's/\/[^/]*$/\//')";
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed -E -e 's/\/[^/]*$/\//')";
cp "$src_file" "$DST/$dst_file";
done
@ -79,9 +81,9 @@ if [ -d "$SOURCE_PATH/contrib/boost/libs/smart_ptr/include/boost/smart_ptr/detai
for src_file in $(ls -1 $SOURCE_PATH/contrib/boost/libs/smart_ptr/include/boost/smart_ptr/detail/*);
do
dst_file=$src_file;
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed --posix -E -e 's/\/[^/]*$/\//')";
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed -E -e 's/\/[^/]*$/\//')";
cp "$src_file" "$DST/$dst_file";
done
fi
@ -90,9 +92,9 @@ if [ -d "$SOURCE_PATH/contrib/boost/boost/smart_ptr/detail" ]; then
for src_file in $(ls -1 $SOURCE_PATH/contrib/boost/boost/smart_ptr/detail/*);
do
dst_file=$src_file;
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed --posix -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed --posix -E -e 's/\/[^/]*$/\//')";
[ -n $BUILD_PATH ] && dst_file=$(echo $dst_file | sed -E -e "s!^$BUILD_PATH!!")
[ -n $DESTDIR ] && dst_file=$(echo $dst_file | sed -E -e "s!^$DESTDIR!!")
mkdir -p "$DST/$(echo $dst_file | sed -E -e 's/\/[^/]*$/\//')";
cp "$src_file" "$DST/$dst_file";
done
fi

View File

@ -1,6 +1,19 @@
set(CLICKHOUSE_CLIENT_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/Client.cpp)
set(CLICKHOUSE_CLIENT_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/Client.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ConnectionParameters.cpp
)
set(CLICKHOUSE_CLIENT_LINK PRIVATE clickhouse_common_config clickhouse_functions clickhouse_aggregate_functions clickhouse_common_io ${LINE_EDITING_LIBS} ${Boost_PROGRAM_OPTIONS_LIBRARY})
set(CLICKHOUSE_CLIENT_INCLUDE SYSTEM PRIVATE ${READLINE_INCLUDE_DIR})
set(CLICKHOUSE_CLIENT_INCLUDE SYSTEM PRIVATE ${READLINE_INCLUDE_DIR} PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/include)
include(CheckSymbolExists)
check_symbol_exists(readpassphrase readpassphrase.h HAVE_READPASSPHRASE)
configure_file(config_client.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/config_client.h)
if(NOT HAVE_READPASSPHRASE)
add_subdirectory(readpassphrase)
list(APPEND CLICKHOUSE_CLIENT_LINK PRIVATE readpassphrase)
endif()
clickhouse_program_add(client)

View File

@ -435,7 +435,7 @@ private:
#if USE_READLINE
int res = read_history(history_file.c_str());
if (res)
throwFromErrno("Cannot read history from file " + history_file, ErrorCodes::CANNOT_READ_HISTORY);
std::cerr << "Cannot read history from file " + history_file + ": "+ errnoToString(ErrorCodes::CANNOT_READ_HISTORY);
#endif
}
else /// Create history file.
@ -612,7 +612,7 @@ private:
#if USE_READLINE && HAVE_READLINE_HISTORY
if (!history_file.empty() && append_history(1, history_file.c_str()))
throwFromErrno("Cannot append history to file " + history_file, ErrorCodes::CANNOT_APPEND_HISTORY);
std::cerr << "Cannot append history to file " + history_file + ": " + errnoToString(ErrorCodes::CANNOT_APPEND_HISTORY);
#endif
prev_input = input;
@ -866,7 +866,7 @@ private:
std::cout << std::endl
<< processed_rows << " rows in set. Elapsed: " << watch.elapsedSeconds() << " sec. ";
if (progress.rows >= 1000)
if (progress.read_rows >= 1000)
writeFinalProgress();
std::cout << std::endl << std::endl;
@ -1420,23 +1420,23 @@ private:
<< " Progress: ";
message
<< formatReadableQuantity(progress.rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
<< formatReadableQuantity(progress.read_rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
size_t elapsed_ns = watch.elapsed();
if (elapsed_ns)
message << " ("
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
else
message << ". ";
written_progress_chars = message.count() - prefix_size - (increment % 8 == 7 ? 10 : 13); /// Don't count invisible output (escape sequences).
/// If the approximate number of rows to process is known, we can display a progress bar and percentage.
if (progress.total_rows > 0)
if (progress.total_rows_to_read > 0)
{
size_t total_rows_corrected = std::max(progress.rows, progress.total_rows);
size_t total_rows_corrected = std::max(progress.read_rows, progress.total_rows_to_read);
/// To avoid flicker, display progress bar only if .5 seconds have passed since query execution start
/// and the query is less than halfway done.
@ -1444,7 +1444,7 @@ private:
if (elapsed_ns > 500000000)
{
/// Trigger to start displaying progress bar. If query is mostly done, don't display it.
if (progress.rows * 2 < total_rows_corrected)
if (progress.read_rows * 2 < total_rows_corrected)
show_progress_bar = true;
if (show_progress_bar)
@ -1452,7 +1452,7 @@ private:
ssize_t width_of_progress_bar = static_cast<ssize_t>(terminal_size.ws_col) - written_progress_chars - strlen(" 99%");
if (width_of_progress_bar > 0)
{
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.rows, 0, total_rows_corrected, width_of_progress_bar));
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.read_rows, 0, total_rows_corrected, width_of_progress_bar));
message << "\033[0;32m" << bar << "\033[0m";
if (width_of_progress_bar > static_cast<ssize_t>(bar.size() / UNICODE_BAR_CHAR_SIZE))
message << std::string(width_of_progress_bar - bar.size() / UNICODE_BAR_CHAR_SIZE, ' ');
@ -1461,7 +1461,7 @@ private:
}
/// Underestimate percentage a bit to avoid displaying 100%.
message << ' ' << (99 * progress.rows / total_rows_corrected) << '%';
message << ' ' << (99 * progress.read_rows / total_rows_corrected) << '%';
}
message << ENABLE_LINE_WRAPPING;
@ -1474,14 +1474,14 @@ private:
void writeFinalProgress()
{
std::cout << "Processed "
<< formatReadableQuantity(progress.rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
<< formatReadableQuantity(progress.read_rows) << " rows, "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
size_t elapsed_ns = watch.elapsed();
if (elapsed_ns)
std::cout << " ("
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
else
std::cout << ". ";
}

View File

@ -0,0 +1,63 @@
#include "ConnectionParameters.h"
#include <fstream>
#include <iostream>
#include <Core/Defines.h>
#include <Core/Protocol.h>
#include <Core/Types.h>
#include <IO/ConnectionTimeouts.h>
#include <Poco/Util/AbstractConfiguration.h>
#include <Common/Exception.h>
#include <common/setTerminalEcho.h>
#include <ext/scope_guard.h>
#include <readpassphrase.h>
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfiguration & config)
{
bool is_secure = config.getBool("secure", false);
security = is_secure ? Protocol::Secure::Enable : Protocol::Secure::Disable;
host = config.getString("host", "localhost");
port = config.getInt(
"port", config.getInt(is_secure ? "tcp_port_secure" : "tcp_port", is_secure ? DBMS_DEFAULT_SECURE_PORT : DBMS_DEFAULT_PORT));
default_database = config.getString("database", "");
/// changed the default value to "default" to fix the issue when the user in the prompt is blank
user = config.getString("user", "default");
bool password_prompt = false;
if (config.getBool("ask-password", false))
{
if (config.has("password"))
throw Exception("Specified both --password and --ask-password. Remove one of them", ErrorCodes::BAD_ARGUMENTS);
password_prompt = true;
}
else
{
password = config.getString("password", "");
/// if the value of --password is omitted, the password will be set implicitly to "\n"
if (password == "\n")
password_prompt = true;
}
if (password_prompt)
{
std::string prompt{"Password for user (" + user + "): "};
char buf[1000] = {};
if (auto result = readpassphrase(prompt.c_str(), buf, sizeof(buf), 0))
password = result;
}
compression = config.getBool("compression", true) ? Protocol::Compression::Enable : Protocol::Compression::Disable;
timeouts = ConnectionTimeouts(
Poco::Timespan(config.getInt("connect_timeout", DBMS_DEFAULT_CONNECT_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("send_timeout", DBMS_DEFAULT_SEND_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("receive_timeout", DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("tcp_keep_alive_timeout", 0), 0));
}
}

View File

@ -1,90 +1,30 @@
#pragma once
#include <iostream>
#include <Core/Types.h>
#include <string>
#include <Core/Protocol.h>
#include <Core/Defines.h>
#include <Common/Exception.h>
#include <IO/ConnectionTimeouts.h>
#include <common/setTerminalEcho.h>
#include <ext/scope_guard.h>
#include <Poco/Util/AbstractConfiguration.h>
namespace Poco::Util
{
class AbstractConfiguration;
}
namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}
struct ConnectionParameters
{
String host;
std::string host;
UInt16 port{};
String default_database;
String user;
String password;
std::string default_database;
std::string user;
std::string password;
Protocol::Secure security = Protocol::Secure::Disable;
Protocol::Compression compression = Protocol::Compression::Enable;
ConnectionTimeouts timeouts;
ConnectionParameters() {}
ConnectionParameters(const Poco::Util::AbstractConfiguration & config)
{
bool is_secure = config.getBool("secure", false);
security = is_secure
? Protocol::Secure::Enable
: Protocol::Secure::Disable;
host = config.getString("host", "localhost");
port = config.getInt("port",
config.getInt(is_secure ? "tcp_port_secure" : "tcp_port",
is_secure ? DBMS_DEFAULT_SECURE_PORT : DBMS_DEFAULT_PORT));
default_database = config.getString("database", "");
/// changed the default value to "default" to fix the issue when the user in the prompt is blank
user = config.getString("user", "default");
bool password_prompt = false;
if (config.getBool("ask-password", false))
{
if (config.has("password"))
throw Exception("Specified both --password and --ask-password. Remove one of them", ErrorCodes::BAD_ARGUMENTS);
password_prompt = true;
}
else
{
password = config.getString("password", "");
/// if the value of --password is omitted, the password will be set implicitly to "\n"
if (password == "\n")
password_prompt = true;
}
if (password_prompt)
{
std::cout << "Password for user (" << user << "): ";
setTerminalEcho(false);
SCOPE_EXIT({
setTerminalEcho(true);
});
std::getline(std::cin, password);
std::cout << std::endl;
}
compression = config.getBool("compression", true)
? Protocol::Compression::Enable
: Protocol::Compression::Disable;
timeouts = ConnectionTimeouts(
Poco::Timespan(config.getInt("connect_timeout", DBMS_DEFAULT_CONNECT_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("send_timeout", DBMS_DEFAULT_SEND_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("receive_timeout", DBMS_DEFAULT_RECEIVE_TIMEOUT_SEC), 0),
Poco::Timespan(config.getInt("tcp_keep_alive_timeout", 0), 0));
}
ConnectionParameters(const Poco::Util::AbstractConfiguration & config);
};
}

View File

@ -0,0 +1,3 @@
#pragma once
#cmakedefine HAVE_READPASSPHRASE

View File

@ -0,0 +1,10 @@
# wget https://raw.githubusercontent.com/openssh/openssh-portable/master/openbsd-compat/readpassphrase.c
# wget https://raw.githubusercontent.com/openssh/openssh-portable/master/openbsd-compat/readpassphrase.h
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unused-result -Wno-reserved-id-macro")
configure_file(includes.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/includes.h)
add_library(readpassphrase ${CMAKE_CURRENT_SOURCE_DIR}/readpassphrase.c)
# . to allow #include <readpassphrase.h>
target_include_directories(readpassphrase PUBLIC . ${CMAKE_CURRENT_BINARY_DIR}/include ${CMAKE_CURRENT_BINARY_DIR}/../include)

View File

@ -0,0 +1,9 @@
#pragma once
#cmakedefine HAVE_READPASSPHRASE
#if !defined(HAVE_READPASSPHRASE)
# ifndef _PATH_TTY
# define _PATH_TTY "/dev/tty"
# endif
#endif

View File

@ -0,0 +1,211 @@
/* $OpenBSD: readpassphrase.c,v 1.26 2016/10/18 12:47:18 millert Exp $ */
/*
* Copyright (c) 2000-2002, 2007, 2010
* Todd C. Miller <Todd.Miller@courtesan.com>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*
* Sponsored in part by the Defense Advanced Research Projects
* Agency (DARPA) and Air Force Research Laboratory, Air Force
* Materiel Command, USAF, under agreement number F39502-99-1-0512.
*/
/* OPENBSD ORIGINAL: lib/libc/gen/readpassphrase.c */
#include "includes.h"
#ifndef HAVE_READPASSPHRASE
#include <termios.h>
#include <signal.h>
#include <ctype.h>
#include <fcntl.h>
#include <readpassphrase.h>
#include <errno.h>
#include <string.h>
#include <unistd.h>
#ifndef TCSASOFT
/* If we don't have TCSASOFT define it so that ORing it in below is a no-op. */
# define TCSASOFT 0
#endif
/* SunOS 4.x which lacks _POSIX_VDISABLE, but has VDISABLE */
#if !defined(_POSIX_VDISABLE) && defined(VDISABLE)
# define _POSIX_VDISABLE VDISABLE
#endif
static volatile sig_atomic_t signo[_NSIG];
static void handler(int);
char *
readpassphrase(const char *prompt, char *buf, size_t bufsiz, int flags)
{
ssize_t nr;
int input, output, save_errno, i, need_restart;
char ch, *p, *end;
struct termios term, oterm;
struct sigaction sa, savealrm, saveint, savehup, savequit, saveterm;
struct sigaction savetstp, savettin, savettou, savepipe;
/* I suppose we could alloc on demand in this case (XXX). */
if (bufsiz == 0) {
errno = EINVAL;
return(NULL);
}
restart:
for (i = 0; i < _NSIG; i++)
signo[i] = 0;
nr = -1;
save_errno = 0;
need_restart = 0;
/*
* Read and write to /dev/tty if available. If not, read from
* stdin and write to stderr unless a tty is required.
*/
if ((flags & RPP_STDIN) ||
(input = output = open(_PATH_TTY, O_RDWR)) == -1) {
if (flags & RPP_REQUIRE_TTY) {
errno = ENOTTY;
return(NULL);
}
input = STDIN_FILENO;
output = STDERR_FILENO;
}
/*
* Turn off echo if possible.
* If we are using a tty but are not the foreground pgrp this will
* generate SIGTTOU, so do it *before* installing the signal handlers.
*/
if (input != STDIN_FILENO && tcgetattr(input, &oterm) == 0) {
memcpy(&term, &oterm, sizeof(term));
if (!(flags & RPP_ECHO_ON))
term.c_lflag &= ~(ECHO | ECHONL);
#ifdef VSTATUS
if (term.c_cc[VSTATUS] != _POSIX_VDISABLE)
term.c_cc[VSTATUS] = _POSIX_VDISABLE;
#endif
(void)tcsetattr(input, TCSAFLUSH|TCSASOFT, &term);
} else {
memset(&term, 0, sizeof(term));
term.c_lflag |= ECHO;
memset(&oterm, 0, sizeof(oterm));
oterm.c_lflag |= ECHO;
}
/*
* Catch signals that would otherwise cause the user to end
* up with echo turned off in the shell. Don't worry about
* things like SIGXCPU and SIGVTALRM for now.
*/
sigemptyset(&sa.sa_mask);
sa.sa_flags = 0; /* don't restart system calls */
sa.sa_handler = handler;
(void)sigaction(SIGALRM, &sa, &savealrm);
(void)sigaction(SIGHUP, &sa, &savehup);
(void)sigaction(SIGINT, &sa, &saveint);
(void)sigaction(SIGPIPE, &sa, &savepipe);
(void)sigaction(SIGQUIT, &sa, &savequit);
(void)sigaction(SIGTERM, &sa, &saveterm);
(void)sigaction(SIGTSTP, &sa, &savetstp);
(void)sigaction(SIGTTIN, &sa, &savettin);
(void)sigaction(SIGTTOU, &sa, &savettou);
if (!(flags & RPP_STDIN))
(void)write(output, prompt, strlen(prompt));
end = buf + bufsiz - 1;
p = buf;
while ((nr = read(input, &ch, 1)) == 1 && ch != '\n' && ch != '\r') {
if (p < end) {
if ((flags & RPP_SEVENBIT))
ch &= 0x7f;
if (isalpha((unsigned char)ch)) {
if ((flags & RPP_FORCELOWER))
ch = (char)tolower((unsigned char)ch);
if ((flags & RPP_FORCEUPPER))
ch = (char)toupper((unsigned char)ch);
}
*p++ = ch;
}
}
*p = '\0';
save_errno = errno;
if (!(term.c_lflag & ECHO))
(void)write(output, "\n", 1);
/* Restore old terminal settings and signals. */
if (memcmp(&term, &oterm, sizeof(term)) != 0) {
const int sigttou = signo[SIGTTOU];
/* Ignore SIGTTOU generated when we are not the fg pgrp. */
while (tcsetattr(input, TCSAFLUSH|TCSASOFT, &oterm) == -1 &&
errno == EINTR && !signo[SIGTTOU])
continue;
signo[SIGTTOU] = sigttou;
}
(void)sigaction(SIGALRM, &savealrm, NULL);
(void)sigaction(SIGHUP, &savehup, NULL);
(void)sigaction(SIGINT, &saveint, NULL);
(void)sigaction(SIGQUIT, &savequit, NULL);
(void)sigaction(SIGPIPE, &savepipe, NULL);
(void)sigaction(SIGTERM, &saveterm, NULL);
(void)sigaction(SIGTSTP, &savetstp, NULL);
(void)sigaction(SIGTTIN, &savettin, NULL);
(void)sigaction(SIGTTOU, &savettou, NULL);
if (input != STDIN_FILENO)
(void)close(input);
/*
* If we were interrupted by a signal, resend it to ourselves
* now that we have restored the signal handlers.
*/
for (i = 0; i < _NSIG; i++) {
if (signo[i]) {
kill(getpid(), i);
switch (i) {
case SIGTSTP:
case SIGTTIN:
case SIGTTOU:
need_restart = 1;
}
}
}
if (need_restart)
goto restart;
if (save_errno)
errno = save_errno;
return(nr == -1 ? NULL : buf);
}
//DEF_WEAK(readpassphrase);
#if 0
char *
getpass(const char *prompt)
{
static char buf[_PASSWORD_LEN + 1];
return(readpassphrase(prompt, buf, sizeof(buf), RPP_ECHO_OFF));
}
#endif
static void handler(int s)
{
signo[s] = 1;
}
#endif /* HAVE_READPASSPHRASE */

View File

@ -0,0 +1,56 @@
// /* $OpenBSD: readpassphrase.h,v 1.5 2003/06/17 21:56:23 millert Exp $ */
/*
* Copyright (c) 2000, 2002 Todd C. Miller <Todd.Miller@courtesan.com>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*
* Sponsored in part by the Defense Advanced Research Projects
* Agency (DARPA) and Air Force Research Laboratory, Air Force
* Materiel Command, USAF, under agreement number F39502-99-1-0512.
*/
/* OPENBSD ORIGINAL: include/readpassphrase.h */
#pragma once
// #ifndef _READPASSPHRASE_H_
// #define _READPASSPHRASE_H_
//#include "includes.h"
#include "config_client.h"
#ifndef HAVE_READPASSPHRASE
# ifdef __cplusplus
extern "C" {
# endif
# define RPP_ECHO_OFF 0x00 /* Turn off echo (default). */
# define RPP_ECHO_ON 0x01 /* Leave echo on. */
# define RPP_REQUIRE_TTY 0x02 /* Fail if there is no tty. */
# define RPP_FORCELOWER 0x04 /* Force input to lower case. */
# define RPP_FORCEUPPER 0x08 /* Force input to upper case. */
# define RPP_SEVENBIT 0x10 /* Strip the high bit from input. */
# define RPP_STDIN 0x20 /* Read from stdin, not /dev/tty */
char * readpassphrase(const char *, char *, size_t, int);
# ifdef __cplusplus
}
# endif
#endif /* HAVE_READPASSPHRASE */
// #endif /* !_READPASSPHRASE_H_ */

View File

@ -1,6 +1,7 @@
#include "PerformanceTest.h"
#include <Core/Types.h>
#include <Common/CpuId.h>
#include <common/getMemoryAmount.h>
#include <IO/ReadBufferFromFile.h>
#include <IO/ReadHelpers.h>
@ -71,6 +72,7 @@ bool PerformanceTest::checkPreconditions() const
Strings preconditions;
config->keys("preconditions", preconditions);
size_t table_precondition_index = 0;
size_t cpu_precondition_index = 0;
for (const std::string & precondition : preconditions)
{
@ -136,6 +138,30 @@ bool PerformanceTest::checkPreconditions() const
return false;
}
}
if (precondition == "cpu")
{
std::string precondition_key = "preconditions.cpu[" + std::to_string(cpu_precondition_index++) + "]";
std::string flag_to_check = config->getString(precondition_key);
#define CHECK_CPU_PRECONDITION(OP) \
if (flag_to_check == #OP) \
{ \
if (!Cpu::CpuFlagsCache::have_##OP) \
{ \
LOG_WARNING(log, "CPU doesn't support " << #OP); \
return false; \
} \
} else
CPU_ID_ENUMERATE(CHECK_CPU_PRECONDITION)
{
LOG_WARNING(log, "CPU doesn't support " << flag_to_check);
return false;
}
#undef CHECK_CPU_PRECONDITION
}
}
return true;
@ -159,17 +185,9 @@ UInt64 PerformanceTest::calculateMaxExecTime() const
void PerformanceTest::prepare() const
{
for (const auto & query : test_info.create_queries)
for (const auto & query : test_info.create_and_fill_queries)
{
LOG_INFO(log, "Executing create query \"" << query << '\"');
connection.sendQuery(query, "", QueryProcessingStage::Complete, &test_info.settings, nullptr, false);
waitQuery(connection);
LOG_INFO(log, "Query finished");
}
for (const auto & query : test_info.fill_queries)
{
LOG_INFO(log, "Executing fill query \"" << query << '\"');
LOG_INFO(log, "Executing create or fill query \"" << query << '\"');
connection.sendQuery(query, "", QueryProcessingStage::Complete, &test_info.settings, nullptr, false);
waitQuery(connection);
LOG_INFO(log, "Query finished");

View File

@ -30,11 +30,6 @@ public:
std::vector<TestStats> execute();
void finish() const;
const PerformanceTestInfo & getTestInfo() const
{
return test_info;
}
bool checkSIGINT() const
{
return got_SIGINT;

View File

@ -60,10 +60,10 @@ PerformanceTestInfo::PerformanceTestInfo(
applySettings(config);
extractQueries(config);
extractAuxiliaryQueries(config);
processSubstitutions(config);
getExecutionType(config);
getStopConditions(config);
extractAuxiliaryQueries(config);
}
void PerformanceTestInfo::applySettings(XMLConfigurationPtr config)
@ -153,13 +153,29 @@ void PerformanceTestInfo::processSubstitutions(XMLConfigurationPtr config)
ConfigurationPtr substitutions_view(config->createView("substitutions"));
constructSubstitutions(substitutions_view, substitutions);
auto queries_pre_format = queries;
auto create_and_fill_queries_preformat = create_and_fill_queries;
create_and_fill_queries.clear();
for (const auto & query : create_and_fill_queries_preformat)
{
auto formatted = formatQueries(query, substitutions);
create_and_fill_queries.insert(create_and_fill_queries.end(), formatted.begin(), formatted.end());
}
auto queries_preformat = queries;
queries.clear();
for (const auto & query : queries_pre_format)
for (const auto & query : queries_preformat)
{
auto formatted = formatQueries(query, substitutions);
queries.insert(queries.end(), formatted.begin(), formatted.end());
}
auto drop_queries_preformat = drop_queries;
drop_queries.clear();
for (const auto & query : drop_queries_preformat)
{
auto formatted = formatQueries(query, substitutions);
drop_queries.insert(drop_queries.end(), formatted.begin(), formatted.end());
}
}
}
@ -203,13 +219,20 @@ void PerformanceTestInfo::getStopConditions(XMLConfigurationPtr config)
void PerformanceTestInfo::extractAuxiliaryQueries(XMLConfigurationPtr config)
{
if (config->has("create_query"))
create_queries = getMultipleValuesFromConfig(*config, "", "create_query");
{
create_and_fill_queries = getMultipleValuesFromConfig(*config, "", "create_query");
}
if (config->has("fill_query"))
fill_queries = getMultipleValuesFromConfig(*config, "", "fill_query");
{
auto fill_queries = getMultipleValuesFromConfig(*config, "", "fill_query");
create_and_fill_queries.insert(create_and_fill_queries.end(), fill_queries.begin(), fill_queries.end());
}
if (config->has("drop_query"))
{
drop_queries = getMultipleValuesFromConfig(*config, "", "drop_query");
}
}
}

View File

@ -42,8 +42,7 @@ public:
std::vector<TestStopConditions> stop_conditions_by_run;
Strings create_queries;
Strings fill_queries;
Strings create_and_fill_queries;
Strings drop_queries;
private:
@ -52,7 +51,6 @@ private:
void processSubstitutions(XMLConfigurationPtr config);
void getExecutionType(XMLConfigurationPtr config);
void getStopConditions(XMLConfigurationPtr config);
void getMetrics(XMLConfigurationPtr config);
void extractAuxiliaryQueries(XMLConfigurationPtr config);
};

View File

@ -202,8 +202,7 @@ private:
LOG_INFO(log, "Preconditions for test '" << info.test_name << "' are fullfilled");
LOG_INFO(
log,
"Preparing for run, have " << info.create_queries.size() << " create queries and " << info.fill_queries.size()
<< " fill queries");
"Preparing for run, have " << info.create_and_fill_queries.size() << " create and fill queries");
current.prepare();
LOG_INFO(log, "Prepared");
LOG_INFO(log, "Running test '" << info.test_name << "'");

View File

@ -13,7 +13,7 @@ void checkFulfilledConditionsAndUpdate(
TestStats & statistics, TestStopConditions & stop_conditions,
InterruptListener & interrupt_listener)
{
statistics.add(progress.rows, progress.bytes);
statistics.add(progress.read_rows, progress.read_bytes);
stop_conditions.reportRowsRead(statistics.total_rows_read);
stop_conditions.reportBytesReadUncompressed(statistics.total_bytes_read);

View File

@ -8,6 +8,8 @@ set(CLICKHOUSE_SERVER_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/RootRequestHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp
${CMAKE_CURRENT_SOURCE_DIR}/TCPHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandler.cpp
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandlerFactory.cpp
)
set(CLICKHOUSE_SERVER_LINK PRIVATE clickhouse_dictionaries clickhouse_common_io PUBLIC daemon PRIVATE clickhouse_storages_system clickhouse_functions clickhouse_aggregate_functions clickhouse_table_functions ${Poco_Net_LIBRARY})

View File

@ -453,30 +453,6 @@ void HTTPHandler::processQuery(
return false;
};
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
std::string full_query;
/// Support for "external data for query processing".
if (startsWith(request.getContentType().data(), "multipart/form-data"))
{
ExternalTablesHandler handler(context, params);
params.load(request, istr, handler);
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
reserved_param_suffixes.emplace_back("_format");
reserved_param_suffixes.emplace_back("_types");
reserved_param_suffixes.emplace_back("_structure");
/// Params are of both form params POST and uri (GET params)
for (const auto & it : params)
if (it.first == "query")
full_query += it.second;
in = std::make_unique<ReadBufferFromString>(full_query);
}
else
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
/// Settings can be overridden in the query.
/// Some parameters (database, default_format, everything used in the code above) do not
/// belong to the Settings class.
@ -497,30 +473,63 @@ void HTTPHandler::processQuery(
settings.readonly = 2;
}
SettingsChanges settings_changes;
for (auto it = params.begin(); it != params.end(); ++it)
bool isExternalData = startsWith(request.getContentType().data(), "multipart/form-data");
if (isExternalData)
{
if (it->first == "database")
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
reserved_param_suffixes.reserve(3);
/// There is a bug/ambiguity with the `date_time_input_format` and `low_cardinality_allow_in_native_format` formats/settings.
reserved_param_suffixes.emplace_back("_format");
reserved_param_suffixes.emplace_back("_types");
reserved_param_suffixes.emplace_back("_structure");
}
SettingsChanges settings_changes;
for (const auto & [key, value] : params)
{
if (key == "database")
{
context.setCurrentDatabase(it->second);
context.setCurrentDatabase(value);
}
else if (it->first == "default_format")
else if (key == "default_format")
{
context.setDefaultFormat(it->second);
context.setDefaultFormat(value);
}
else if (param_could_be_skipped(it->first))
else if (param_could_be_skipped(key))
{
}
else
{
/// All other query parameters are treated as settings.
settings_changes.push_back({it->first, it->second});
settings_changes.push_back({key, value});
}
}
/// For external data we also want settings
context.checkSettingsConstraints(settings_changes);
context.applySettingsChanges(settings_changes);
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
std::string full_query;
/// Support for "external data for query processing".
if (isExternalData)
{
ExternalTablesHandler handler(context, params);
params.load(request, istr, handler);
/// Params are of both form params POST and uri (GET params)
for (const auto & it : params)
if (it.first == "query")
full_query += it.second;
in = std::make_unique<ReadBufferFromString>(full_query);
}
else
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
/// HTTP response compression is turned on only if the client signalled that they support it
/// (using Accept-Encoding header) and 'enable_http_compression' setting is turned on.
used_output.out->setCompression(client_supports_http_compression && settings.enable_http_compression);

View File

@ -1,17 +1,16 @@
#include "InterserverIOHTTPHandler.h"
#include <Poco/Net/HTTPBasicCredentials.h>
#include <Poco/Net/HTTPServerRequest.h>
#include <Poco/Net/HTTPServerResponse.h>
#include <common/logger_useful.h>
#include <Common/HTMLForm.h>
#include <Common/setThreadName.h>
#include <Compression/CompressedWriteBuffer.h>
#include <IO/ReadBufferFromIStream.h>
#include <IO/WriteBufferFromHTTPServerResponse.h>
#include <Interpreters/InterserverIOHandler.h>
#include "InterserverIOHTTPHandler.h"
#include "IServer.h"
namespace DB
{
@ -50,7 +49,7 @@ std::pair<String, bool> InterserverIOHTTPHandler::checkAuthentication(Poco::Net:
return {"", true};
}
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response)
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output)
{
HTMLForm params(request);
@ -61,24 +60,17 @@ void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & reque
ReadBufferFromIStream body(request.stream());
const auto & config = server.config();
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
WriteBufferFromHTTPServerResponse out(request, response, keep_alive_timeout);
auto endpoint = server.context().getInterserverIOHandler().getEndpoint(endpoint_name);
if (compress)
{
CompressedWriteBuffer compressed_out(out);
CompressedWriteBuffer compressed_out(*used_output.out.get());
endpoint->processQuery(params, body, compressed_out, response);
}
else
{
endpoint->processQuery(params, body, out, response);
endpoint->processQuery(params, body, *used_output.out.get(), response);
}
out.finalize();
}
@ -90,30 +82,30 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
if (request.getVersion() == Poco::Net::HTTPServerRequest::HTTP_1_1)
response.setChunkedTransferEncoding(true);
Output used_output;
const auto & config = server.config();
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(request, response, keep_alive_timeout);
try
{
if (auto [msg, success] = checkAuthentication(request); success)
if (auto [message, success] = checkAuthentication(request); success)
{
processQuery(request, response);
processQuery(request, response, used_output);
LOG_INFO(log, "Done processing query");
}
else
{
response.setStatusAndReason(Poco::Net::HTTPServerResponse::HTTP_UNAUTHORIZED);
if (!response.sent())
response.send() << msg << std::endl;
writeString(message, *used_output.out);
LOG_WARNING(log, "Query processing failed request: '" << request.getURI() << "' authentification failed");
}
}
catch (Exception & e)
{
if (e.code() == ErrorCodes::TOO_MANY_SIMULTANEOUS_QUERIES)
{
if (!response.sent())
response.send();
return;
}
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
@ -122,7 +114,7 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
std::string message = getCurrentExceptionMessage(is_real_error);
if (!response.sent())
response.send() << message << std::endl;
writeString(message, *used_output.out);
if (is_real_error)
LOG_ERROR(log, message);
@ -134,7 +126,8 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
std::string message = getCurrentExceptionMessage(false);
if (!response.sent())
response.send() << message << std::endl;
writeString(message, *used_output.out);
LOG_ERROR(log, message);
}
}

View File

@ -1,12 +1,10 @@
#pragma once
#include <memory>
#include <Poco/Logger.h>
#include <Poco/Net/HTTPRequestHandler.h>
#include <Common/CurrentMetrics.h>
#include "IServer.h"
namespace CurrentMetrics
{
@ -16,6 +14,9 @@ namespace CurrentMetrics
namespace DB
{
class IServer;
class WriteBufferFromHTTPServerResponse;
class InterserverIOHTTPHandler : public Poco::Net::HTTPRequestHandler
{
public:
@ -28,12 +29,17 @@ public:
void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override;
private:
struct Output
{
std::shared_ptr<WriteBufferFromHTTPServerResponse> out;
};
IServer & server;
Poco::Logger * log;
CurrentMetrics::Increment metric_increment{CurrentMetrics::InterserverConnection};
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response);
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output);
std::pair<String, bool> checkAuthentication(Poco::Net::HTTPServerRequest & request) const;
};

View File

@ -0,0 +1,370 @@
#include <DataStreams/copyData.h>
#include <IO/ReadBufferFromString.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <Interpreters/executeQuery.h>
#include <Storages/IStorage.h>
#include <Core/MySQLProtocol.h>
#include <Core/NamesAndTypes.h>
#include <Columns/ColumnVector.h>
#include <Common/config_version.h>
#include <Common/NetException.h>
#include <Common/OpenSSLHelpers.h>
#include <Poco/Crypto/RSAKey.h>
#include <Poco/Crypto/CipherFactory.h>
#include <Poco/Net/SecureStreamSocket.h>
#include <Poco/Net/SSLManager.h>
#include "MySQLHandler.h"
#include <limits>
#include <ext/scope_guard.h>
namespace DB
{
using namespace MySQLProtocol;
using Poco::Net::SecureStreamSocket;
using Poco::Net::SSLManager;
namespace ErrorCodes
{
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
extern const int OPENSSL_ERROR;
}
MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id)
: Poco::Net::TCPServerConnection(socket_)
, server(server_)
, log(&Poco::Logger::get("MySQLHandler"))
, connection_context(server.context())
, connection_id(connection_id)
, public_key(public_key)
, private_key(private_key)
{
server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
if (ssl_enabled)
server_capability_flags |= CLIENT_SSL;
}
void MySQLHandler::run()
{
connection_context = server.context();
connection_context.setDefaultFormat("MySQLWire");
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
try
{
String scramble = generateScramble();
/** Native authentication sends 20 bytes + a '\0' character = 21 bytes.
* This plugin must do the same to stay consistent with historical behavior when it operates as the default plugin.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
*/
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0');
packet_sender->sendPacket<Handshake>(handshake, true);
LOG_TRACE(log, "Sent handshake");
HandshakeResponse handshake_response = finishHandshake();
connection_context.client_capabilities = handshake_response.capability_flags;
if (handshake_response.max_packet_size)
connection_context.max_packet_size = handshake_response.max_packet_size;
if (!connection_context.max_packet_size)
connection_context.max_packet_size = MAX_PACKET_LENGTH;
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
<< "\nmax_packet_size: "
<< handshake_response.max_packet_size
<< "\ncharacter_set: "
<< handshake_response.character_set
<< "\nuser: "
<< handshake_response.username
<< "\nauth_response length: "
<< handshake_response.auth_response.length()
<< "\nauth_response: "
<< handshake_response.auth_response
<< "\ndatabase: "
<< handshake_response.database
<< "\nauth_plugin_name: "
<< handshake_response.auth_plugin_name);
client_capability_flags = handshake_response.capability_flags;
if (!(client_capability_flags & CLIENT_PROTOCOL_41))
throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
if (!(client_capability_flags & CLIENT_PLUGIN_AUTH))
throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
authenticate(handshake_response, scramble);
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
packet_sender->sendPacket(ok_packet, true);
while (true)
{
packet_sender->resetSequenceId();
String payload = packet_sender->receivePacketPayload();
int command = payload[0];
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
try
{
switch (command)
{
case COM_QUIT:
return;
case COM_INIT_DB:
comInitDB(payload);
break;
case COM_QUERY:
comQuery(payload);
break;
case COM_FIELD_LIST:
comFieldList(payload);
break;
case COM_PING:
comPing();
break;
default:
throw Exception(Poco::format("Command %d is not implemented.", command), ErrorCodes::NOT_IMPLEMENTED);
}
}
catch (const NetException & exc)
{
log->log(exc);
throw;
}
catch (const Exception & exc)
{
log->log(exc);
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
}
}
}
catch (Poco::Exception & exc)
{
log->log(exc);
}
}
/** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest.
* Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
* If part of the SSL handshake were consumed through the ReadBuffer, it would be impossible to start an SSL connection using Poco. The size of the SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
*/
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
{
HandshakeResponse packet;
size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
/// Buffer for SSLRequest or part of HandshakeResponse.
char buf[packet_size];
size_t pos = 0;
/// Reads at least count and at most packet_size bytes.
auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void {
while (pos < count)
{
int ret = socket().receiveBytes(buf + pos, packet_size - pos);
if (ret == 0)
{
throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 3.", ErrorCodes::CANNOT_READ_ALL_DATA);
}
pos += ret;
}
};
read_bytes(3); /// We can find out whether it is SSLRequest or HandshakeResponse by the first 3 bytes.
size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu;
LOG_TRACE(log, "payload size: " << payload_size);
if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
{
read_bytes(packet_size); /// Reading the rest of the SSLRequest.
SSLRequest ssl_request;
ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE));
connection_context.client_capabilities = ssl_request.capability_flags;
connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
secure_connection = true;
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
connection_context.sequence_id = 2;
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
packet_sender->max_packet_size = connection_context.max_packet_size;
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
}
else
{
/// Reading the rest of the HandshakeResponse.
packet_size = PACKET_HEADER_SIZE + payload_size;
WriteBufferFromOwnString buf_for_handshake_response;
buf_for_handshake_response.write(buf, pos);
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE));
packet_sender->sequence_id++;
}
return packet;
}
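For reference, a minimal hand-rolled sketch (illustration only) of what `unalignedLoad<uint32_t>(buf) & 0xFFFFFFu` extracts above: a MySQL packet starts with a 3-byte little-endian payload length followed by a 1-byte sequence id, so an SSLRequest is recognized by a payload length equal to SSL_REQUEST_PAYLOAD_SIZE (32).
#include <cstdint>
struct PacketHeader
{
    uint32_t payload_length;    /// 3-byte little-endian payload length
    uint8_t sequence_id;        /// 4th byte of the header
};
/// Decode the 4-byte MySQL packet header by hand.
static PacketHeader decodePacketHeader(const unsigned char * buf)
{
    PacketHeader header;
    header.payload_length = uint32_t(buf[0]) | (uint32_t(buf[1]) << 8) | (uint32_t(buf[2]) << 16);
    header.sequence_id = buf[3];
    return header;
}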
String MySQLHandler::generateScramble()
{
String scramble(MySQLProtocol::SCRAMBLE_LENGTH, 0);
Poco::RandomInputStream generator;
for (size_t i = 0; i < scramble.size(); i++)
{
generator >> scramble[i];
}
return scramble;
}
void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, const String & scramble)
{
String auth_response;
AuthSwitchResponse response;
if (handshake_response.auth_plugin_name != Authentication::SHA256)
{
packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true);
if (in->eof())
throw Exception(
"Client doesn't support authentication method " + String(Authentication::SHA256) + " used by ClickHouse",
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
packet_sender->receivePacket(response);
auth_response = response.value;
LOG_TRACE(log, "Authentication method mismatch.");
}
else
{
auth_response = handshake_response.auth_response;
LOG_TRACE(log, "Authentication method match.");
}
if (auth_response == "\1")
{
LOG_TRACE(log, "Client requests public key.");
BIO * mem = BIO_new(BIO_s_mem());
SCOPE_EXIT(BIO_free(mem));
if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
{
throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
char * pem_buf = nullptr;
long pem_size = BIO_get_mem_data(mem, &pem_buf);
String pem(pem_buf, pem_size);
LOG_TRACE(log, "Key: " << pem);
AuthMoreData data(pem);
packet_sender->sendPacket(data, true);
packet_sender->receivePacket(response);
auth_response = response.value;
}
else
{
LOG_TRACE(log, "Client didn't request public key.");
}
String password;
/** Decrypt password, if it's not empty.
* The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
*/
if (!secure_connection && !auth_response.empty() && auth_response != String("\0", 1))
{
LOG_TRACE(log, "Received nonempty password");
auto ciphertext = reinterpret_cast<unsigned char *>(auth_response.data());
unsigned char plaintext[RSA_size(&private_key)];
int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING);
if (plaintext_size == -1)
{
throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
password.resize(plaintext_size);
for (int i = 0; i < plaintext_size; i++)
{
password[i] = plaintext[i] ^ static_cast<unsigned char>(scramble[i % scramble.size()]);
}
}
else if (secure_connection)
{
password = auth_response;
}
else
{
LOG_TRACE(log, "Received empty password");
}
if (!password.empty())
{
password.pop_back(); /// terminating null byte
}
try
{
connection_context.setUser(handshake_response.username, password, socket().address(), "");
connection_context.setCurrentDatabase(handshake_response.database);
connection_context.setCurrentQueryId("");
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
}
catch (const Exception & exc)
{
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed.");
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
throw;
}
}
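A condensed sketch (illustration only, mirroring the loop above) of the unmasking step: after RSA decryption, the client-side masking is undone by XOR-ing with the scramble repeated cyclically, and the trailing null byte is dropped.
#include <cstddef>
#include <string>
/// Recover the password from the RSA-decrypted auth data.
static std::string unmaskPassword(const std::string & decrypted, const std::string & scramble)
{
    std::string password(decrypted.size(), '\0');
    for (std::size_t i = 0; i < decrypted.size(); ++i)
        password[i] = decrypted[i] ^ scramble[i % scramble.size()];
    if (!password.empty())
        password.pop_back();    /// terminating null byte
    return password;
}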
void MySQLHandler::comInitDB(const String & payload)
{
String database = payload.substr(1);
LOG_DEBUG(log, "Setting current database to " << database);
connection_context.setCurrentDatabase(database);
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
}
void MySQLHandler::comFieldList(const String & payload)
{
ComFieldList packet;
packet.readPayload(payload);
String database = connection_context.getCurrentDatabase();
StoragePtr table_ptr = connection_context.getTable(database, packet.table);
for (const NameAndTypePair & column : table_ptr->getColumns().getAll())
{
ColumnDefinition column_definition(
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
);
packet_sender->sendPacket(column_definition);
}
packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comPing()
{
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
}
void MySQLHandler::comQuery(const String & payload)
{
bool with_output = false;
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
with_output = true;
};
String query = payload.substr(1);
// Translate query from MySQL to ClickHouse.
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
query = "select ''";
ReadBufferFromString buf(query);
executeQuery(buf, *out, true, connection_context, set_content_type, nullptr);
if (!with_output)
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
}
}

View File

@ -0,0 +1,59 @@
#pragma once
#include <Poco/Net/TCPServerConnection.h>
#include <Poco/Net/SecureStreamSocket.h>
#include <Common/getFQDNOrHostName.h>
#include <Core/MySQLProtocol.h>
#include <openssl/rsa.h>
#include "IServer.h"
namespace DB
{
/// Handler for MySQL wire protocol connections. Allows connecting to ClickHouse with a MySQL client.
class MySQLHandler : public Poco::Net::TCPServerConnection
{
public:
MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id);
void run() final;
private:
/// Enables SSL, if client requested.
MySQLProtocol::HandshakeResponse finishHandshake();
void comQuery(const String & payload);
void comFieldList(const String & payload);
void comPing();
void comInitDB(const String & payload);
static String generateScramble();
void authenticate(const MySQLProtocol::HandshakeResponse &, const String & scramble);
IServer & server;
Poco::Logger * log;
Context connection_context;
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
size_t connection_id = 0;
size_t server_capability_flags;
size_t client_capability_flags;
RSA & public_key;
RSA & private_key;
std::shared_ptr<ReadBuffer> in;
std::shared_ptr<WriteBuffer> out;
bool secure_connection = false;
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
};
}

View File

@ -0,0 +1,124 @@
#include <Common/OpenSSLHelpers.h>
#include <Poco/Crypto/X509Certificate.h>
#include <Poco/Net/SSLManager.h>
#include <Poco/Net/TCPServerConnectionFactory.h>
#include <Poco/Util/Application.h>
#include <common/logger_useful.h>
#include <ext/scope_guard.h>
#include "IServer.h"
#include "MySQLHandler.h"
#include "MySQLHandlerFactory.h"
namespace DB
{
namespace ErrorCodes
{
extern const int CANNOT_OPEN_FILE;
extern const int NO_ELEMENTS_IN_CONFIG;
extern const int OPENSSL_ERROR;
extern const int SYSTEM_ERROR;
}
MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_)
: server(server_)
, log(&Logger::get("MySQLHandlerFactory"))
{
try
{
Poco::Net::SSLManager::instance().defaultServerContext();
}
catch (...)
{
LOG_INFO(log, "Failed to create SSL context. SSL will be disabled. Error: " << getCurrentExceptionMessage(false));
ssl_enabled = false;
}
/// Reading RSA keys for the SHA256 authentication plugin.
try
{
readRSAKeys();
}
catch (...)
{
LOG_WARNING(log, "Failed to read RSA keys. Error: " << getCurrentExceptionMessage(false));
generateRSAKeys();
}
}
void MySQLHandlerFactory::readRSAKeys()
{
const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
String certificateFileProperty = "openSSL.server.certificateFile";
String privateKeyFileProperty = "openSSL.server.privateKeyFile";
if (!config.has(certificateFileProperty))
throw Exception("Certificate file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
if (!config.has(privateKeyFileProperty))
throw Exception("Private key file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
{
String certificateFile = config.getString(certificateFileProperty);
FILE * fp = fopen(certificateFile.data(), "r");
if (fp == nullptr)
throw Exception("Cannot open certificate file: " + certificateFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
SCOPE_EXIT(fclose(fp));
X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
SCOPE_EXIT(X509_free(x509));
if (x509 == nullptr)
throw Exception("Failed to read PEM certificate from " + certificateFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
EVP_PKEY * p = X509_get_pubkey(x509);
if (p == nullptr)
throw Exception("Failed to get RSA key from X509. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
SCOPE_EXIT(EVP_PKEY_free(p));
public_key.reset(EVP_PKEY_get1_RSA(p));
if (public_key.get() == nullptr)
throw Exception("Failed to get RSA key from ENV_PKEY. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
{
String privateKeyFile = config.getString(privateKeyFileProperty);
FILE * fp = fopen(privateKeyFile.data(), "r");
if (fp == nullptr)
throw Exception ("Cannot open private key file " + privateKeyFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
SCOPE_EXIT(fclose(fp));
private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr));
if (!private_key)
throw Exception("Failed to read RSA private key from " + privateKeyFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
}
void MySQLHandlerFactory::generateRSAKeys()
{
LOG_INFO(log, "Generating new RSA key.");
public_key.reset(RSA_new());
if (!public_key)
throw Exception("Failed to allocate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
BIGNUM * e = BN_new();
if (!e)
throw Exception("Failed to allocate BIGNUM. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
SCOPE_EXIT(BN_free(e));
if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(public_key.get(), 2048, e, nullptr))
throw Exception("Failed to generate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
private_key.reset(RSAPrivateKey_dup(public_key.get()));
if (!private_key)
throw Exception("Failed to copy RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
}
Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket)
{
size_t connection_id = last_connection_id++;
LOG_TRACE(log, "MySQL connection. Id: " << connection_id << ". Address: " << socket.peerAddress().toString());
return new MySQLHandler(server, socket, *public_key, *private_key, ssl_enabled, connection_id);
}
}

View File

@ -0,0 +1,39 @@
#pragma once
#include <Poco/Net/TCPServerConnectionFactory.h>
#include <atomic>
#include <openssl/rsa.h>
#include "IServer.h"
namespace DB
{
class MySQLHandlerFactory : public Poco::Net::TCPServerConnectionFactory
{
private:
IServer & server;
Poco::Logger * log;
struct RSADeleter
{
void operator()(RSA * ptr) { RSA_free(ptr); }
};
using RSAPtr = std::unique_ptr<RSA, RSADeleter>;
RSAPtr public_key;
RSAPtr private_key;
bool ssl_enabled = true;
std::atomic<size_t> last_connection_id = 0;
public:
explicit MySQLHandlerFactory(IServer & server_);
void readRSAKeys();
void generateRSAKeys();
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override;
};
}

View File

@ -49,6 +49,7 @@
#include <Common/StatusFile.h>
#include "TCPHandlerFactory.h"
#include "Common/config_version.h"
#include "MySQLHandlerFactory.h"
#if defined(__linux__)
#include <Common/hasLinuxCapability.h>
@ -668,7 +669,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening tcp: " + address.toString());
LOG_INFO(log, "Listening for connections with native protocol (tcp): " + address.toString());
}
/// TCP with SSL
@ -685,7 +686,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
server_pool,
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening tcp_secure: " + address.toString());
LOG_INFO(log, "Listening for connections with secure native protocol (tcp_secure): " + address.toString());
#else
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
ErrorCodes::SUPPORT_IS_DISABLED};
@ -710,7 +711,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
http_params));
LOG_INFO(log, "Listening interserver http: " + address.toString());
LOG_INFO(log, "Listening for replica communication (interserver) http://" + address.toString());
}
if (config().has("interserver_https_port"))
@ -727,12 +728,27 @@ int Server::main(const std::vector<std::string> & /*args*/)
socket,
http_params));
LOG_INFO(log, "Listening interserver https: " + address.toString());
LOG_INFO(log, "Listening for secure replica communication (interserver) https://" + address.toString());
#else
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
ErrorCodes::SUPPORT_IS_DISABLED};
#endif
}
if (config().has("mysql_port"))
{
Poco::Net::ServerSocket socket;
auto address = socket_bind_listen(socket, listen_host, config().getInt("mysql_port"), /* secure = */ true);
socket.setReceiveTimeout(Poco::Timespan());
socket.setSendTimeout(settings.send_timeout);
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
new MySQLHandlerFactory(*this),
server_pool,
socket,
new Poco::Net::TCPServerParams));
LOG_INFO(log, "Listening for MySQL compatibility protocol: " + address.toString());
}
}
catch (const Poco::Exception & e)
{

View File

@ -1,2 +1,5 @@
*.bin
*.mrk
*.txt
*.dat
*.idx

View File

@ -308,16 +308,88 @@ public:
/**
* Check whether two bitmaps intersect.
* Intersection with an empty set is always 0 (consistent with hasAny).
*/
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1)
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1) const
{
if (isSmall())
toLarge();
roaring_bitmap_t * rb1 = r1.isSmall() ? r1.getNewRbFromSmall() : r1.getRb();
UInt8 is_true = roaring_bitmap_intersect(rb, rb1);
if (r1.isSmall())
roaring_bitmap_free(rb1);
return is_true;
{
if (r1.isSmall())
{
for (const auto & x : r1.small)
if (small.find(x.getValue()) != small.end())
return 1;
}
else
{
for (const auto & x : small)
if (roaring_bitmap_contains(r1.rb, x.getValue()))
return 1;
}
}
else if (r1.isSmall())
{
for (const auto & x : r1.small)
if (roaring_bitmap_contains(rb, x.getValue()))
return 1;
}
else if (roaring_bitmap_intersect(rb, r1.rb))
return 1;
return 0;
}
/**
* Check whether the argument is a subset of this set.
* Empty set is a subset of any other set (consistent with hasAll).
*/
UInt8 rb_is_subset(const RoaringBitmapWithSmallSet & r1) const
{
if (isSmall())
{
if (r1.isSmall())
{
for (const auto & x : r1.small)
if (small.find(x.getValue()) == small.end())
return 0;
}
else
{
UInt64 r1_size = r1.size();
if (r1_size > small.size())
return 0; // A bigger set cannot be a subset of ours.
// This is a rare case with a small number of elements on
// both sides: r1 was promoted to large for some reason and
// it is still not larger than our small set.
// If r1 is a subset of ours, then our size must equal r1_size plus
// the number of our elements not found in r1; if this sum ever
// becomes greater than our size, r1 is not a subset.
for (const auto & x : small)
if (!roaring_bitmap_contains(r1.rb, x.getValue()) && ++r1_size > small.size())
return 0;
}
}
else if (r1.isSmall())
{
for (const auto & x : r1.small)
if (!roaring_bitmap_contains(rb, x.getValue()))
return 0;
}
else if (!roaring_bitmap_is_subset(r1.rb, rb))
return 0;
return 1;
}
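The counting argument above can be illustrated on plain std::set (a hypothetical standalone check, not the bitmap code): r1 is a subset of s exactly when |r1| does not exceed |s| and |r1| plus the number of elements of s missing from r1 never exceeds |s|.
#include <cstddef>
#include <set>
/// Subset test by counting, mirroring rb_is_subset.
static bool isSubsetByCounting(const std::set<int> & r1, const std::set<int> & s)
{
    std::size_t count = r1.size();
    if (count > s.size())
        return false;    /// a bigger set cannot be a subset
    for (int x : s)
        if (r1.find(x) == r1.end() && ++count > s.size())
            return false;    /// r1 must contain an element outside s
    return true;
}
/// isSubsetByCounting({1, 2}, {1, 2, 3}) == true; isSubsetByCounting({1, 4}, {1, 2, 3}) == false.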
/**
* Check whether this bitmap contains the argument.
*/
UInt8 rb_contains(const UInt32 x) const
{
return isSmall() ? small.find(x) != small.end() :
roaring_bitmap_contains(rb, x);
}
/**

View File

@ -3,6 +3,8 @@
#include <IO/ReadHelpers.h>
#include <IO/WriteHelpers.h>
#include <Interpreters/castColumn.h>
#include <Columns/ColumnArray.h>
#include <Columns/ColumnTuple.h>
#include <Common/FieldVisitors.h>
#include <Common/typeid_cast.h>
#include "AggregateFunctionFactory.h"
@ -34,7 +36,7 @@ namespace
for (size_t i = 0; i < argument_types.size(); ++i)
{
if (!isNumber(argument_types[i]))
if (!isNativeNumber(argument_types[i]))
throw Exception(
"Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
+ " must be numeric for aggregate function " + name,
@ -43,11 +45,11 @@ namespace
/// These default parameters were picked because they performed well in some tests,
/// though tuning them is still required to achieve better results.
auto learning_rate = Float64(0.01);
auto l2_reg_coef = Float64(0.01);
UInt32 batch_size = 1;
auto learning_rate = Float64(0.00001);
auto l2_reg_coef = Float64(0.1);
UInt32 batch_size = 15;
std::shared_ptr<IWeightsUpdater> weights_updater = std::make_shared<StochasticGradientDescent>();
std::string weights_updater_name = "\'SGD\'";
std::shared_ptr<IGradientComputer> gradient_computer;
if (!parameters.empty())
@ -64,19 +66,8 @@ namespace
}
if (parameters.size() > 3)
{
if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'SGD\'")
{
weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Momentum\'")
{
weights_updater = std::make_shared<Momentum>();
}
else if (applyVisitor(FieldVisitorToString(), parameters[3]) == "\'Nesterov\'")
{
weights_updater = std::make_shared<Nesterov>();
}
else
weights_updater_name = applyVisitor(FieldVisitorToString(), parameters[3]);
if (weights_updater_name != "\'SGD\'" && weights_updater_name != "\'Momentum\'" && weights_updater_name != "\'Nesterov\'")
{
throw Exception("Invalid parameter for weights updater", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
}
@ -98,20 +89,19 @@ namespace
return std::make_shared<Method>(
argument_types.size() - 1,
gradient_computer,
weights_updater,
weights_updater_name,
learning_rate,
l2_reg_coef,
batch_size,
argument_types,
parameters);
}
}
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
{
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
factory.registerFunction("stochasticLinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
factory.registerFunction("stochasticLogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
}
LinearModelData::LinearModelData(
@ -149,6 +139,26 @@ void LinearModelData::predict(
gradient_computer->predict(container, block, arguments, weights, bias, context);
}
void LinearModelData::returnWeights(IColumn & to) const
{
size_t size = weights.size() + 1;
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
size_t old_size = offsets_to.back();
offsets_to.push_back(old_size + size);
typename ColumnFloat64::Container & val_to
= static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
val_to.reserve(old_size + size);
for (size_t i = 0; i + 1 < size; ++i)
val_to.push_back(weights[i]);
val_to.push_back(bias);
}
void LinearModelData::read(ReadBuffer & buf)
{
readBinary(bias, buf);
@ -192,7 +202,8 @@ void LinearModelData::merge(const DB::LinearModelData & rhs)
void LinearModelData::add(const IColumn ** columns, size_t row_num)
{
/// first column stores target; features start from (columns + 1)
const auto target = (*columns[0])[row_num].get<Float64>();
Float64 target = (*columns[0]).getFloat64(row_num);
/// Here we pass columns + 1, as the first column corresponds to the target value and the others to features.
weights_updater->add_to_batch(
gradient_batch, *gradient_computer, weights, bias, learning_rate, l2_reg_coef, target, columns + 1, row_num);
@ -345,7 +356,7 @@ void LogisticRegression::predict(
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
if (!isNativeNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
@ -385,7 +396,7 @@ void LogisticRegression::compute(
Float64 derivative = bias;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
auto value = (*columns[i]).getFloat64(row_num);
derivative += weights[i] * value;
}
derivative *= target;
@ -394,8 +405,8 @@ void LogisticRegression::compute(
batch_gradient[weights.size()] += learning_rate * target / (derivative + 1);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += learning_rate * target * value / (derivative + 1) - 2 * l2_reg_coef * weights[i];
auto value = (*columns[i]).getFloat64(row_num);
batch_gradient[i] += learning_rate * target * value / (derivative + 1) - 2 * learning_rate * l2_reg_coef * weights[i];
}
}
@ -418,7 +429,7 @@ void LinearRegression::predict(
for (size_t i = 1; i < arguments.size(); ++i)
{
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
if (!isNumber(cur_col.type))
if (!isNativeNumber(cur_col.type))
{
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
}
@ -458,7 +469,7 @@ void LinearRegression::compute(
Float64 derivative = (target - bias);
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
auto value = (*columns[i]).getFloat64(row_num);
derivative -= weights[i] * value;
}
derivative *= (2 * learning_rate);
@ -466,8 +477,8 @@ void LinearRegression::compute(
batch_gradient[weights.size()] += derivative;
for (size_t i = 0; i < weights.size(); ++i)
{
auto value = (*columns[i])[row_num].get<Float64>();
batch_gradient[i] += derivative * value - 2 * l2_reg_coef * weights[i];
auto value = (*columns[i]).getFloat64(row_num);
batch_gradient[i] += derivative * value - 2 * learning_rate * l2_reg_coef * weights[i];
}
}
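To map the accumulators above to the underlying objective (a hedged note; the weights updater that finally applies the batch is not shown in this hunk): with features x, target y, weights w, bias b and regularization coefficient lambda (l2_reg_coef), the per-row L2-regularized squared loss and its gradients are
\[
L(w, b) = \Bigl(y - b - \sum_i w_i x_i\Bigr)^2 + \lambda \sum_i w_i^2,
\qquad
\frac{\partial L}{\partial w_i} = -2\Bigl(y - b - \sum_j w_j x_j\Bigr) x_i + 2\lambda w_i,
\qquad
\frac{\partial L}{\partial b} = -2\Bigl(y - b - \sum_j w_j x_j\Bigr),
\]
so `batch_gradient[i]` accumulates learning_rate * (-dL/dw_i) and the last entry accumulates learning_rate * (-dL/db); a later SGD step presumably adds the accumulated batch gradient to the weights.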

View File

@ -4,6 +4,8 @@
#include <Columns/ColumnsCommon.h>
#include <Columns/ColumnsNumber.h>
#include <DataTypes/DataTypesNumber.h>
#include <DataTypes/DataTypeTuple.h>
#include <DataTypes/DataTypeArray.h>
#include "IAggregateFunction.h"
namespace DB
@ -12,6 +14,7 @@ namespace ErrorCodes
{
extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH;
extern const int BAD_ARGUMENTS;
extern const int BAD_CAST;
}
/**
@ -218,6 +221,7 @@ public:
void
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
void returnWeights(IColumn & to) const;
private:
std::vector<Float64> weights;
Float64 bias{0.0};
@ -253,7 +257,7 @@ public:
explicit AggregateFunctionMLMethod(
UInt32 param_num,
std::shared_ptr<IGradientComputer> gradient_computer,
std::shared_ptr<IWeightsUpdater> weights_updater,
std::string weights_updater_name,
Float64 learning_rate,
Float64 l2_reg_coef,
UInt32 batch_size,
@ -265,15 +269,39 @@ public:
, l2_reg_coef(l2_reg_coef)
, batch_size(batch_size)
, gradient_computer(std::move(gradient_computer))
, weights_updater(std::move(weights_updater))
, weights_updater_name(std::move(weights_updater_name))
{
}
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
/// This function is called when SELECT linearRegression(...) is called
DataTypePtr getReturnType() const override
{
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
}
/// This function is called from evalMLMethod function for correct predictValues call
DataTypePtr getReturnTypeToPredict() const override
{
return std::make_shared<DataTypeNumber<Float64>>();
}
void create(AggregateDataPtr place) const override
{
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, weights_updater);
std::shared_ptr<IWeightsUpdater> new_weights_updater;
if (weights_updater_name == "\'SGD\'")
{
new_weights_updater = std::make_shared<StochasticGradientDescent>();
}
else if (weights_updater_name == "\'Momentum\'")
{
new_weights_updater = std::make_shared<Momentum>();
}
else if (weights_updater_name == "\'Nesterov\'")
{
new_weights_updater = std::make_shared<Nesterov>();
}
else
{
throw Exception("Illegal name of weights updater (should have been checked earlier)", ErrorCodes::LOGICAL_ERROR);
}
new (place) Data(learning_rate, l2_reg_coef, param_num, batch_size, gradient_computer, new_weights_updater);
}
void add(AggregateDataPtr place, const IColumn ** columns, size_t row_num, Arena *) const override
@ -296,16 +324,26 @@ public:
+ ". Required: " + std::to_string(param_num + 1),
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
auto & column = dynamic_cast<ColumnVector<Float64> &>(to);
/// This cast should be valid because the column type is based on getReturnTypeToPredict.
ColumnVector<Float64> * column;
try
{
column = &dynamic_cast<ColumnVector<Float64> &>(to);
}
catch (const std::bad_cast &)
{
throw Exception("Cast of column of predictions is incorrect: getReturnTypeToPredict must return the same type as the column it is cast to.",
ErrorCodes::BAD_CAST);
}
this->data(place).predict(column.getData(), block, arguments, context);
this->data(place).predict(column->getData(), block, arguments, context);
}
/** This function is called if aggregate function without State modifier is selected in a query.
* Inserts all weights of the model into the column 'to', so user may use such information if needed
*/
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
std::ignore = place;
std::ignore = to;
throw std::runtime_error("not implemented");
this->data(place).returnWeights(to);
}
const char * getHeaderFilePath() const override { return __FILE__; }
@ -316,15 +354,15 @@ private:
Float64 l2_reg_coef;
UInt32 batch_size;
std::shared_ptr<IGradientComputer> gradient_computer;
std::shared_ptr<IWeightsUpdater> weights_updater;
std::string weights_updater_name;
};
struct NameLinearRegression
{
static constexpr auto name = "LinearRegression";
static constexpr auto name = "stochasticLinearRegression";
};
struct NameLogisticRegression
{
static constexpr auto name = "LogisticRegression";
static constexpr auto name = "stochasticLogisticRegression";
};
}

View File

@ -61,10 +61,10 @@ public:
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
{
if (!isNumber(arguments[0]))
if (!isNativeNumber(arguments[0]))
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!isNumber(arguments[1]))
if (!isNativeNumber(arguments[1]))
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
if (!arguments[0]->equals(*arguments[1]))

View File

@ -1,6 +1,12 @@
#include <AggregateFunctions/Helpers.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
#include <DataTypes/DataTypeDate.h>
#include <DataTypes/DataTypeDateTime.h>
#include <ext/range.h>
namespace DB
{
@ -12,32 +18,58 @@ namespace ErrorCodes
namespace
{
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
template <template <typename, typename> class AggregateFunction, template <typename> class Data>
AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
}
const auto arg_count = argument_types.size();
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
{
if (params.size() != 1)
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
if (arg_count < 3)
throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > max_events)
throw Exception{"Aggregate function " + name + " supports up to "
+ toString(max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = argument_types.front().get();
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = argument_types[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
+ " of aggregate function " + name + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
String pattern = params.front().safeGet<std::string>();
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
if (res)
return res;
WhichDataType which(argument_types.front().get());
if (which.isDateTime())
return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
else if (which.isDate())
return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ name + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
}
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
{
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceMatch);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceCount);
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
}
}

View File

@ -36,11 +36,12 @@ struct ComparePairFirst final
}
};
static constexpr auto max_events = 32;
template <typename T>
struct AggregateFunctionSequenceMatchData final
{
static constexpr auto max_events = 32;
using Timestamp = std::uint32_t;
using Timestamp = T;
using Events = std::bitset<max_events>;
using TimestampEvents = std::pair<Timestamp, Events>;
using Comparator = ComparePairFirst<std::less>;
@ -61,6 +62,9 @@ struct AggregateFunctionSequenceMatchData final
void merge(const AggregateFunctionSequenceMatchData & other)
{
if (other.events_list.empty())
return;
const auto size = events_list.size();
events_list.insert(std::begin(other.events_list), std::end(other.events_list));
@ -119,7 +123,7 @@ struct AggregateFunctionSequenceMatchData final
for (size_t i = 0; i < size; ++i)
{
std::uint32_t timestamp;
Timestamp timestamp;
readBinary(timestamp, buf);
UInt64 events;
@ -135,48 +139,23 @@ struct AggregateFunctionSequenceMatchData final
constexpr auto sequence_match_max_iterations = 1000000;
template <typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
template <typename T, typename Data, typename Derived>
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
{
public:
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
, pattern(pattern)
{
arg_count = arguments.size();
if (!sufficientArgs(arg_count))
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
const auto time_arg = arguments.front().get();
if (!WhichDataType(time_arg).isDateTime())
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
+ derived().getName() + ", must be DateTime",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
for (const auto i : ext::range(1, arg_count))
{
const auto cond_arg = arguments[i].get();
if (!isUInt8(cond_arg))
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
" of aggregate function " + derived().getName() + ", must be UInt8",
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
}
parsePattern();
}
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
{
const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
AggregateFunctionSequenceMatchData::Events events;
typename Data::Events events;
for (const auto i : ext::range(1, arg_count))
{
const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
@ -218,17 +197,15 @@ private:
struct PatternAction final
{
PatternActionType type;
std::uint32_t extra;
std::uint64_t extra;
PatternAction() = default;
PatternAction(const PatternActionType type, const std::uint32_t extra = 0) : type{type}, extra{extra} {}
PatternAction(const PatternActionType type, const std::uint64_t extra = 0) : type{type}, extra{extra} {}
};
static constexpr size_t bytes_on_stack = 64;
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
Derived & derived() { return static_cast<Derived &>(*this); }
void parsePattern()
@ -340,8 +317,8 @@ protected:
/// This algorithm runs in O(mn) (with m the number of DFA states and n the number
/// of events), with memory consumption and memory allocations in O(m). It means that
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
template <typename T>
bool dfaMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
{
using ActiveStates = std::vector<bool>;
@ -396,8 +373,8 @@ protected:
return active_states.back();
}
template <typename T>
bool backtrackingMatch(T & events_it, const T events_end) const
template <typename EventEntry>
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
{
const auto action_begin = std::begin(actions);
const auto action_end = std::end(actions);
@ -407,7 +384,7 @@ protected:
auto base_it = events_it;
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
using backtrack_info = std::tuple<decltype(action_it), T, T>;
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
std::stack<backtrack_info> back_stack;
/// backtrack if possible
@ -458,7 +435,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLessOrEqual)
{
if (events_it->first - base_it->first <= action_it->extra)
if (events_it->first <= base_it->first + action_it->extra)
{
/// condition satisfied, move onto next action
back_stack.emplace(action_it, events_it, base_it);
@ -470,7 +447,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeLess)
{
if (events_it->first - base_it->first < action_it->extra)
if (events_it->first < base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -481,7 +458,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
{
if (events_it->first - base_it->first >= action_it->extra)
if (events_it->first >= base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -492,7 +469,7 @@ protected:
}
else if (action_it->type == PatternActionType::TimeGreater)
{
if (events_it->first - base_it->first > action_it->extra)
if (events_it->first > base_it->first + action_it->extra)
{
back_stack.emplace(action_it, events_it, base_it);
base_it = events_it;
@ -575,14 +552,14 @@ private:
DFAStates dfa_states;
};
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
template <typename T, typename Data>
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
{
public:
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceMatch"; }
@ -590,27 +567,27 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
bool match = pattern_has_time ? backtrackingMatch(events_it, events_end) : dfaMatch(events_it, events_end);
bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
static_cast<ColumnUInt8 &>(to).getData().push_back(match);
}
};
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
template <typename T, typename Data>
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
{
public:
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern) {}
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
String getName() const override { return "sequenceCount"; }
@ -618,21 +595,21 @@ public:
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
{
const_cast<Data &>(data(place)).sort();
const_cast<Data &>(this->data(place)).sort();
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
}
private:
UInt64 count(const ConstAggregateDataPtr & place) const
{
const auto & data_ref = data(place);
const auto & data_ref = this->data(place);
const auto events_begin = std::begin(data_ref.events_list);
const auto events_end = std::end(data_ref.events_list);
auto events_it = events_begin;
size_t count = 0;
while (events_it != events_end && backtrackingMatch(events_it, events_end))
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
++count;
return count;

View File

@ -1,8 +1,9 @@
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
#include <AggregateFunctions/AggregateFunctionFactory.h>
#include <AggregateFunctions/FactoryHelpers.h>
#include <Core/TypeListNumber.h>
namespace DB
{
@ -10,7 +11,7 @@ namespace DB
namespace
{
AggregateFunctionPtr createAggregateFunctionLeastSqr(
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
const String & name,
const DataTypes & arguments,
const Array & params
@ -20,16 +21,11 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
assertBinary(name, arguments);
const IDataType * x_arg = arguments.front().get();
WhichDataType which_x {
x_arg
};
WhichDataType which_x = x_arg;
const IDataType * y_arg = arguments.back().get();
WhichDataType which_y = y_arg;
WhichDataType which_y {
y_arg
};
#define FOR_LEASTSQR_TYPES_2(M, T) \
M(T, UInt8) \
@ -55,7 +51,7 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
FOR_LEASTSQR_TYPES_2(M, Float64)
#define DISPATCH(T1, T2) \
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
return std::make_shared<AggregateFunctionLeastSqr<T1, T2>>( \
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
arguments, \
params \
);
@ -77,9 +73,9 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
}
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
{
factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
}
}

View File

@ -19,7 +19,7 @@ namespace ErrorCodes
}
template <typename X, typename Y, typename Ret>
struct AggregateFunctionLeastSqrData final
struct AggregateFunctionSimpleLinearRegressionData final
{
size_t count = 0;
Ret sum_x = 0;
@ -36,7 +36,7 @@ struct AggregateFunctionLeastSqrData final
sum_xy += x * y;
}
void merge(const AggregateFunctionLeastSqrData & other)
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
{
count += other.count;
sum_x += other.sum_x;
@ -85,19 +85,19 @@ struct AggregateFunctionLeastSqrData final
/// Calculates simple linear regression parameters.
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
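For reference, the closed-form solution these accumulators allow (a standard identity, assuming the usual sums of x, y, x*y and x^2 over n points; the accessors computing k and b are outside this hunk):
\[
k = \frac{n \sum x_i y_i - \sum x_i \sum y_i}{n \sum x_i^2 - \bigl(\sum x_i\bigr)^2},
\qquad
b = \frac{\sum y_i - k \sum x_i}{n}.
\]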
template <typename X, typename Y, typename Ret = Float64>
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
>
{
public:
AggregateFunctionLeastSqr(
AggregateFunctionSimpleLinearRegression(
const DataTypes & arguments,
const Array & params
):
IAggregateFunctionDataHelper<
AggregateFunctionLeastSqrData<X, Y, Ret>,
AggregateFunctionLeastSqr<X, Y, Ret>
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
> {arguments, params}
{
// notice: arguments has been checked before
@ -105,7 +105,7 @@ public:
String getName() const override
{
return "leastSqr";
return "simpleLinearRegression";
}
const char * getHeaderFilePath() const override
@ -120,12 +120,8 @@ public:
Arena *
) const override
{
auto col_x {
static_cast<const ColumnVector<X> *>(columns[0])
};
auto col_y {
static_cast<const ColumnVector<Y> *>(columns[1])
};
auto col_x = static_cast<const ColumnVector<X> *>(columns[0]);
auto col_y = static_cast<const ColumnVector<Y> *>(columns[1]);
X x = col_x->getData()[row_num];
Y y = col_y->getData()[row_num];
@ -159,12 +155,14 @@ public:
DataTypePtr getReturnType() const override
{
DataTypes types {
DataTypes types
{
std::make_shared<DataTypeNumber<Ret>>(),
std::make_shared<DataTypeNumber<Ret>>(),
};
Strings names {
Strings names
{
"k",
"b",
};

View File

@ -56,6 +56,10 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
factory.registerFunction("skewSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewSampSimple>);
factory.registerFunction("skewPop", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewPopSimple>);
factory.registerFunction("kurtSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtSampSimple>);
factory.registerFunction("kurtPop", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtPopSimple>);
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);

View File

@ -32,30 +32,42 @@ namespace DB
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
extern const int DECIMAL_OVERFLOW;
}
template <typename T>
/**
Calculating univariate central moments
Levels:
level 2 (pop & samp): var, stddev
level 3: skewness
level 4: kurtosis
References:
https://en.wikipedia.org/wiki/Moment_(mathematics)
https://en.wikipedia.org/wiki/Skewness
https://en.wikipedia.org/wiki/Kurtosis
*/
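/// A sketch of the algebra behind the getters below, assuming m[0] holds the count n and
/// m[k] the raw power sum of x^k:
///
///     population variance = (m[2] - m[1]^2 / n) / n
///     3rd central moment  = (m[3] - 3*m[2]*m[1]/n + 2*m[1]^3/n^2) / n
///     4th central moment  = (m[4] - 4*m[3]*m[1]/n + 6*m[2]*m[1]^2/n^2 - 3*m[1]^4/n^3) / n
///     skewness            = 3rd central moment / variance^1.5
///     kurtosis            = 4th central moment / variance^2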
template <typename T, size_t _level>
struct VarMoments
{
T m0{};
T m1{};
T m2{};
T m[_level + 1]{};
void add(T x)
{
++m0;
m1 += x;
m2 += x * x;
++m[0];
m[1] += x;
m[2] += x * x;
if constexpr (_level >= 3) m[3] += x * x * x;
if constexpr (_level >= 4) m[4] += x * x * x * x;
}
void merge(const VarMoments & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
m2 += rhs.m2;
m[0] += rhs.m[0];
m[1] += rhs.m[1];
m[2] += rhs.m[2];
if constexpr (_level >= 3) m[3] += rhs.m[3];
if constexpr (_level >= 4) m[4] += rhs.m[4];
}
void write(WriteBuffer & buf) const
@ -70,45 +82,90 @@ struct VarMoments
T NO_SANITIZE_UNDEFINED getPopulation() const
{
return (m2 - m1 * m1 / m0) / m0;
return (m[2] - m[1] * m[1] / m[0]) / m[0];
}
T NO_SANITIZE_UNDEFINED getSample() const
{
if (m0 == 0)
if (m[0] == 0)
return std::numeric_limits<T>::quiet_NaN();
return (m2 - m1 * m1 / m0) / (m0 - 1);
return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T NO_SANITIZE_UNDEFINED getMoment3() const
{
// to avoid accuracy problem
if (m[0] == 1)
return 0;
return (m[3]
- (3 * m[2]
- 2 * m[1] * m[1] / m[0]
) * m[1] / m[0]
) / m[0];
}
T NO_SANITIZE_UNDEFINED getMoment4() const
{
// to avoid accuracy problem
if (m[0] == 1)
return 0;
return (m[4]
- (4 * m[3]
- (6 * m[2]
- 3 * m[1] * m[1] / m[0]
) * m[1] / m[0]
) * m[1] / m[0]
) / m[0];
}
};
template <typename T>
template <typename T, size_t _level>
struct VarMomentsDecimal
{
using NativeType = typename T::NativeType;
UInt64 m0{};
NativeType m1{};
NativeType m2{};
NativeType m[_level]{};
NativeType & getM(size_t i)
{
return m[i - 1];
}
const NativeType & getM(size_t i) const
{
return m[i - 1];
}
void add(NativeType x)
{
++m0;
m1 += x;
getM(1) += x;
NativeType tmp; /// scale' = 2 * scale
if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2))
NativeType tmp;
if (common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 3)
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 4)
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void merge(const VarMomentsDecimal & rhs)
{
m0 += rhs.m0;
m1 += rhs.m1;
getM(1) += rhs.getM(1);
if (common::addOverflow(m2, rhs.m2, m2))
if (common::addOverflow(getM(2), rhs.getM(2), getM(2)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 3)
if (common::addOverflow(getM(3), rhs.getM(3), getM(3)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
if constexpr (_level >= 4)
if (common::addOverflow(getM(4), rhs.getM(4), getM(4)))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
}
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
@ -120,8 +177,8 @@ struct VarMomentsDecimal
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
@ -134,15 +191,50 @@ struct VarMomentsDecimal
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(m1, m1, tmp) ||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
if (common::mulOverflow(getM(1), getM(1), tmp) ||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
}
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
Float64 getMoment3(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(2 * getM(1), getM(1), tmp) ||
common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(getM(3), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
Float64 getMoment4(UInt32 scale) const
{
if (m0 == 0)
return std::numeric_limits<Float64>::infinity();
NativeType tmp;
if (common::mulOverflow(3 * getM(1), getM(1), tmp) ||
common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) ||
common::mulOverflow(tmp, getM(1), tmp) ||
common::subOverflow(getM(4), NativeType(tmp / m0), tmp))
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
}
};
/**
Calculating multivariate central moments
Levels:
level 2 (pop & samp): covar
References:
https://en.wikipedia.org/wiki/Moment_(mathematics)
*/
template <typename T>
struct CovarMoments
{
@ -188,8 +280,6 @@ struct CovarMoments
return std::numeric_limits<T>::quiet_NaN();
return (xy - x1 * y1 / m0) / (m0 - 1);
}
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
template <typename T>
@ -236,9 +326,6 @@ struct CorrMoments
{
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
}
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
};
@ -246,18 +333,20 @@ enum class StatisticsFunctionKind
{
varPop, varSamp,
stddevPop, stddevSamp,
skewPop, skewSamp,
kurtPop, kurtSamp,
covarPop, covarSamp,
corr
};
template <typename T, StatisticsFunctionKind _kind>
template <typename T, StatisticsFunctionKind _kind, size_t _level>
struct StatFuncOneArg
{
using Type1 = T;
using Type2 = T;
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128, _level>, VarMoments<ResultType, _level>>;
static constexpr StatisticsFunctionKind kind = _kind;
static constexpr UInt32 num_args = 1;
@ -300,17 +389,28 @@ public:
String getName() const override
{
switch (StatFunc::kind)
{
case StatisticsFunctionKind::varPop: return "varPop";
case StatisticsFunctionKind::varSamp: return "varSamp";
case StatisticsFunctionKind::stddevPop: return "stddevPop";
case StatisticsFunctionKind::stddevSamp: return "stddevSamp";
case StatisticsFunctionKind::covarPop: return "covarPop";
case StatisticsFunctionKind::covarSamp: return "covarSamp";
case StatisticsFunctionKind::corr: return "corr";
}
__builtin_unreachable();
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
return "varPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
return "varSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
return "stddevPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
return "stddevSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
return "skewPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
return "skewSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
return "kurtPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
return "kurtSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
return "covarPop";
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
return "covarSamp";
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
return "corr";
}
DataTypePtr getReturnType() const override
@ -351,28 +451,103 @@ public:
if constexpr (IsDecimalNumber<T1>)
{
switch (StatFunc::kind)
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
dst.push_back(data.getPopulation(src_scale * 2));
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
dst.push_back(data.getSample(src_scale * 2));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
dst.push_back(sqrt(data.getPopulation(src_scale * 2)));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
dst.push_back(sqrt(data.getSample(src_scale * 2)));
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
default:
__builtin_unreachable();
Float64 var_value = data.getPopulation(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
{
Float64 var_value = data.getSample(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
{
Float64 var_value = data.getPopulation(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
{
Float64 var_value = data.getSample(src_scale * 2);
if (var_value > 0)
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
}
}
else
{
switch (StatFunc::kind)
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
dst.push_back(data.getPopulation());
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
dst.push_back(data.getSample());
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
dst.push_back(sqrt(data.getPopulation()));
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
dst.push_back(sqrt(data.getSample()));
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
{
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
ResultType var_value = data.getPopulation();
if (var_value > 0)
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
{
ResultType var_value = data.getSample();
if (var_value > 0)
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
{
ResultType var_value = data.getPopulation();
if (var_value > 0)
dst.push_back(data.getMoment4() / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
{
ResultType var_value = data.getSample();
if (var_value > 0)
dst.push_back(data.getMoment4() / pow(var_value, 2));
else
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
}
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
dst.push_back(data.getPopulation());
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
dst.push_back(data.getSample());
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
dst.push_back(data.get());
}
}
@ -383,10 +558,14 @@ private:
};
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop, 2>>;
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp, 2>>;
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop, 2>>;
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp, 2>>;
template <typename T> using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewPop, 3>>;
template <typename T> using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewSamp, 3>>;
template <typename T> using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtPop, 4>>;
template <typename T> using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtSamp, 4>>;
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;


@ -1,30 +0,0 @@
#include "AggregateFunctionTSGroupSum.h"
#include "AggregateFunctionFactory.h"
#include "FactoryHelpers.h"
#include "Helpers.h"
namespace DB
{
namespace
{
template <bool rate>
AggregateFunctionPtr createAggregateFunctionTSgroupSum(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
if (arguments.size() < 3)
throw Exception("Not enough event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionTSgroupSum<rate>>(arguments);
}
}
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory)
{
factory.registerFunction("TSgroupSum", createAggregateFunctionTSgroupSum<false>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("TSgroupRateSum", createAggregateFunctionTSgroupSum<true>, AggregateFunctionFactory::CaseInsensitive);
}
}


@ -0,0 +1,30 @@
#include "AggregateFunctionTimeSeriesGroupSum.h"
#include "AggregateFunctionFactory.h"
#include "FactoryHelpers.h"
#include "Helpers.h"
namespace DB
{
namespace
{
template <bool rate>
AggregateFunctionPtr createAggregateFunctionTimeSeriesGroupSum(const std::string & name, const DataTypes & arguments, const Array & params)
{
assertNoParameters(name, params);
if (arguments.size() < 3)
throw Exception("Not enough event arguments for aggregate function " + name, ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
return std::make_shared<AggregateFunctionTimeSeriesGroupSum<rate>>(arguments);
}
}
void registerAggregateFunctionTimeSeriesGroupSum(AggregateFunctionFactory & factory)
{
factory.registerFunction("timeSeriesGroupSum", createAggregateFunctionTimeSeriesGroupSum<false>, AggregateFunctionFactory::CaseInsensitive);
factory.registerFunction("timeSeriesGroupRateSum", createAggregateFunctionTimeSeriesGroupSum<true>, AggregateFunctionFactory::CaseInsensitive);
}
}


@ -28,7 +28,7 @@ namespace ErrorCodes
extern const int TOO_MANY_ARGUMENTS_FOR_FUNCTION;
}
template <bool rate>
struct AggregateFunctionTSgroupSumData
struct AggregateFunctionTimeSeriesGroupSumData
{
using DataPoint = std::pair<Int64, Float64>;
struct Points
@ -90,7 +90,7 @@ struct AggregateFunctionTSgroupSumData
it_ss->second.add(t, v);
}
if (result.size() > 0 && t < result.back().first)
throw Exception{"TSgroupSum or TSgroupRateSum must order by timestamp asc!!!", ErrorCodes::LOGICAL_ERROR};
throw Exception{"timeSeriesGroupSum or timeSeriesGroupRateSum must order by timestamp asc!!!", ErrorCodes::LOGICAL_ERROR};
if (result.size() > 0 && t == result.back().first)
{
//do not add new point
@ -119,7 +119,7 @@ struct AggregateFunctionTSgroupSumData
}
}
void merge(const AggregateFunctionTSgroupSumData & other)
void merge(const AggregateFunctionTimeSeriesGroupSumData & other)
{
// if the time series overlap, aggregate the two series by interpolation;
AggSeries tmp;
@ -199,15 +199,15 @@ struct AggregateFunctionTSgroupSumData
}
};
template <bool rate>
class AggregateFunctionTSgroupSum final
: public IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>
class AggregateFunctionTimeSeriesGroupSum final
: public IAggregateFunctionDataHelper<AggregateFunctionTimeSeriesGroupSumData<rate>, AggregateFunctionTimeSeriesGroupSum<rate>>
{
private:
public:
String getName() const override { return rate ? "TSgroupRateSum" : "TSgroupSum"; }
String getName() const override { return rate ? "timeSeriesGroupRateSum" : "timeSeriesGroupSum"; }
AggregateFunctionTSgroupSum(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionTSgroupSumData<rate>, AggregateFunctionTSgroupSum<rate>>(arguments, {})
AggregateFunctionTimeSeriesGroupSum(const DataTypes & arguments)
: IAggregateFunctionDataHelper<AggregateFunctionTimeSeriesGroupSumData<rate>, AggregateFunctionTimeSeriesGroupSum<rate>>(arguments, {})
{
if (!WhichDataType(arguments[0].get()).isUInt64())
throw Exception{"Illegal type " + arguments[0].get()->getName() + " of argument 1 of aggregate function " + getName()


@ -48,6 +48,12 @@ public:
/// Get the result type.
virtual DataTypePtr getReturnType() const = 0;
/// Get type which will be used for prediction result in case if function is an ML method.
virtual DataTypePtr getReturnTypeToPredict() const
{
throw Exception("Prediction is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
virtual ~IAggregateFunction() {}
/** Data manipulating functions. */


@ -30,7 +30,7 @@ void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
@ -41,7 +41,7 @@ void registerAggregateFunctionCombinatorNull(AggregateFunctionCombinatorFactory
void registerAggregateFunctionHistogram(AggregateFunctionFactory & factory);
void registerAggregateFunctionRetention(AggregateFunctionFactory & factory);
void registerAggregateFunctionTSgroupSum(AggregateFunctionFactory & factory);
void registerAggregateFunctionTimeSeriesGroupSum(AggregateFunctionFactory & factory);
void registerAggregateFunctions()
{
{
@ -70,10 +70,10 @@ void registerAggregateFunctions()
registerAggregateFunctionsMaxIntersections(factory);
registerAggregateFunctionHistogram(factory);
registerAggregateFunctionRetention(factory);
registerAggregateFunctionTSgroupSum(factory);
registerAggregateFunctionTimeSeriesGroupSum(factory);
registerAggregateFunctionMLMethod(factory);
registerAggregateFunctionEntropy(factory);
registerAggregateFunctionLeastSqr(factory);
registerAggregateFunctionSimpleLinearRegression(factory);
}
{


@ -35,25 +35,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
arenas.push_back(arena_);
}
/// This function is used in convertToValues() and predictValues()
/// and is written here to avoid repetitions
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
{
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
*res_ = std::move(res);
return true;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
*res_ = std::move(res);
return false;
}
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
{
/** If the aggregate function returns an unfinalized/unfinished state,
@ -86,17 +67,17 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
* AggregateFunction(quantileTiming(0.5), UInt64)
* into UInt16 - already finished result of `quantileTiming`.
*/
/** The conversion function is used in convertToValues and predictValues,
* in the similar parts of both functions.
*/
MutableColumnPtr res;
if (tryFinalizeAggregateFunction(&res))
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
{
auto res = createView();
res->set(function_state->getNestedFunction());
res->data.assign(data.begin(), data.end());
return res;
}
MutableColumnPtr res = func->getReturnType()->createColumn();
res->reserve(data.size());
for (auto val : data)
func->insertResultInto(val, *res);
@ -105,8 +86,8 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
{
MutableColumnPtr res;
tryFinalizeAggregateFunction(&res);
MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
res->reserve(data.size());
auto ML_function = func.get();
if (ML_function)


@ -94,7 +94,7 @@ ColumnPtr ColumnDecimal<T>::permute(const IColumn::Permutation & perm, size_t li
for (size_t i = 0; i < size; ++i)
res_data[i] = data[perm[i]];
return std::move(res);
return res;
}
template <typename T>
@ -117,7 +117,7 @@ MutableColumnPtr ColumnDecimal<T>::cloneResized(size_t size) const
}
}
return std::move(res);
return res;
}
template <typename T>
@ -169,7 +169,7 @@ ColumnPtr ColumnDecimal<T>::filter(const IColumn::Filter & filt, ssize_t result_
++data_pos;
}
return std::move(res);
return res;
}
template <typename T>
@ -202,7 +202,7 @@ ColumnPtr ColumnDecimal<T>::replicate(const IColumn::Offsets & offsets) const
res_data.push_back(data[i]);
}
return std::move(res);
return res;
}
template <typename T>


@ -187,7 +187,7 @@ ColumnPtr ColumnDecimal<T>::indexImpl(const PaddedPODArray<Type> & indexes, size
for (size_t i = 0; i < limit; ++i)
res_data[i] = data[indexes[i]];
return std::move(res);
return res;
}
}


@ -349,7 +349,7 @@ ColumnPtr ColumnLowCardinality::countKeys() const
auto counter = ColumnUInt64::create(dict_size, 0);
idx.countKeys(counter->getData());
return std::move(counter);
return counter;
}


@ -81,9 +81,18 @@ StringRef ColumnNullable::getDataAt(size_t /*n*/) const
throw Exception{"Method getDataAt is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED};
}
void ColumnNullable::insertData(const char * /*pos*/, size_t /*length*/)
void ColumnNullable::insertData(const char * pos, size_t length)
{
throw Exception{"Method insertData is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED};
if (pos == nullptr)
{
getNestedColumn().insertDefault();
getNullMapData().push_back(1);
}
else
{
getNestedColumn().insertData(pos, length);
getNullMapData().push_back(0);
}
}
StringRef ColumnNullable::serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const


@ -51,6 +51,8 @@ public:
bool getBool(size_t n) const override { return isNullAt(n) ? 0 : nested_column->getBool(n); }
UInt64 get64(size_t n) const override { return nested_column->get64(n); }
StringRef getDataAt(size_t n) const override;
/// Will insert null value if pos=nullptr
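/// (e.g. insertData(nullptr, 0) appends NULL, while insertData("abc", 3) appends "abc" with a zero null-map entry)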
void insertData(const char * pos, size_t length) override;
StringRef serializeValueIntoArena(size_t n, Arena & arena, char const *& begin) const override;
const char * deserializeAndInsertFromArena(const char * pos) override;


@ -215,6 +215,12 @@ UInt64 ColumnVector<T>::get64(size_t n) const
return ext::bit_cast<UInt64>(data[n]);
}
template <typename T>
Float64 ColumnVector<T>::getFloat64(size_t n) const
{
return static_cast<Float64>(data[n]);
}
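/// e.g. for a ColumnVector<UInt8> holding {1, 2, 3}, getFloat64(1) returns 2.0.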
template <typename T>
void ColumnVector<T>::insertRangeFrom(const IColumn & src, size_t start, size_t length)
{


@ -202,6 +202,8 @@ public:
UInt64 get64(size_t n) const override;
Float64 getFloat64(size_t n) const override;
UInt64 getUInt(size_t n) const override
{
return UInt64(data[n]);


@ -91,6 +91,13 @@ public:
throw Exception("Method get64 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/// If the column stores a native numeric type, returns the n-th element cast to Float64.
/// Is used in regression methods to cast all features to a uniform type.
virtual Float64 getFloat64(size_t /*n*/) const
{
throw Exception("Method getFloat64 is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
}
/** If column is numeric, return value of n-th element, casted to UInt64.
* For NULL values of Nullable column it is allowed to return arbitrary value.
* Otherwise throw an exception.
@ -141,6 +148,7 @@ public:
/// Appends data located in specified memory chunk if it is possible (throws an exception if it cannot be implemented).
/// Is used to optimize some computations (in aggregation, for example).
/// Parameter length could be ignored if column values have fixed size.
/// All data will be inserted as a single element.
virtual void insertData(const char * pos, size_t length) = 0;
/// Appends "default value".


@ -375,7 +375,7 @@ ColumnUInt64::MutablePtr ReverseIndex<IndexType, ColumnType>::calcHashes() const
for (auto row : ext::range(0, size))
hash->getElement(row) = getHash(column->getDataAt(row));
return std::move(hash);
return hash;
}
template <typename IndexType, typename ColumnType>


@ -0,0 +1,67 @@
#include <Columns/getLeastSuperColumn.h>
#include <Columns/IColumn.h>
#include <Columns/ColumnConst.h>
#include <DataTypes/getLeastSupertype.h>
namespace DB
{
namespace ErrorCodes
{
extern const int LOGICAL_ERROR;
}
static bool sameConstants(const IColumn & a, const IColumn & b)
{
return static_cast<const ColumnConst &>(a).getField() == static_cast<const ColumnConst &>(b).getField();
}
ColumnWithTypeAndName getLeastSuperColumn(std::vector<const ColumnWithTypeAndName *> columns)
{
if (columns.empty())
throw Exception("Logical error: no src columns for supercolumn", ErrorCodes::LOGICAL_ERROR);
ColumnWithTypeAndName result = *columns[0];
/// Determine common type.
size_t num_const = 0;
DataTypes types(columns.size());
for (size_t i = 0; i < columns.size(); ++i)
{
types[i] = columns[i]->type;
if (columns[i]->column->isColumnConst())
++num_const;
}
result.type = getLeastSupertype(types);
/// Create supertype column saving constness if possible.
bool save_constness = false;
if (columns.size() == num_const)
{
save_constness = true;
for (size_t i = 1; i < columns.size(); ++i)
{
const ColumnWithTypeAndName & first = *columns[0];
const ColumnWithTypeAndName & other = *columns[i];
if (!sameConstants(*first.column, *other.column))
{
save_constness = false;
break;
}
}
}
if (save_constness)
result.column = result.type->createColumnConst(0, static_cast<const ColumnConst &>(*columns[0]->column).getField());
else
result.column = result.type->createColumn();
return result;
}
}
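A minimal usage sketch for getLeastSuperColumn (the data type and column helpers are ordinary ClickHouse classes assumed to be available here; the expected result follows from getLeastSupertype and is not stated in this diff):

    ColumnWithTypeAndName a(DataTypeUInt8().createColumnConst(3, Field(UInt64(1))), std::make_shared<DataTypeUInt8>(), "x");
    ColumnWithTypeAndName b(DataTypeInt64().createColumnConst(3, Field(Int64(1))), std::make_shared<DataTypeInt64>(), "x");
    ColumnWithTypeAndName super = getLeastSuperColumn({&a, &b}); /// expected: super.type is Int64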


@ -0,0 +1,12 @@
#pragma once
#include <Core/ColumnWithTypeAndName.h>
namespace DB
{
/// getLeastSupertype + related column changes
ColumnWithTypeAndName getLeastSuperColumn(std::vector<const ColumnWithTypeAndName *> columns);
}


@ -427,6 +427,9 @@ namespace ErrorCodes
extern const int BAD_TTL_EXPRESSION = 450;
extern const int BAD_TTL_FILE = 451;
extern const int SETTING_CONSTRAINT_VIOLATION = 452;
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES = 453;
extern const int OPENSSL_ERROR = 454;
extern const int SUSPICIOUS_TYPE_FOR_LOW_CARDINALITY = 455;
extern const int KEEPER_EXCEPTION = 999;
extern const int POCO_EXCEPTION = 1000;


@ -0,0 +1,18 @@
#include "OpenSSLHelpers.h"
#include <ext/scope_guard.h>
#include <openssl/err.h>
namespace DB
{
String getOpenSSLErrors()
{
BIO * mem = BIO_new(BIO_s_mem());
SCOPE_EXIT(BIO_free(mem));
ERR_print_errors(mem);
char * buf = nullptr;
long size = BIO_get_mem_data(mem, &buf);
return String(buf, size);
}
}
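A hypothetical call site for getOpenSSLErrors (the handshake call and the wrapping exception are assumptions for illustration, not part of this diff):

    if (SSL_do_handshake(ssl) != 1)
        throw Exception("SSL handshake failed: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);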


@ -0,0 +1,12 @@
#pragma once
#include <Core/Types.h>
namespace DB
{
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
String getOpenSSLErrors();
}


@ -63,7 +63,7 @@ template <typename T> bool inline operator> (T a, const UInt128 b) { return UIn
template <typename T> bool inline operator<= (T a, const UInt128 b) { return UInt128(a) <= b; }
template <typename T> bool inline operator< (T a, const UInt128 b) { return UInt128(a) < b; }
template <> constexpr bool IsNumber<UInt128> = true;
template <> inline constexpr bool IsNumber<UInt128> = true;
template <> struct TypeName<UInt128> { static const char * get() { return "UInt128"; } };
template <> struct TypeId<UInt128> { static constexpr const TypeIndex value = TypeIndex::UInt128; };


@ -1,7 +1,9 @@
#include <Common/localBackup.h>
#include "localBackup.h"
#include <Common/createHardLink.h>
#include <Common/Exception.h>
#include <Poco/DirectoryIterator.h>
#include <Poco/Path.h>
#include <Poco/File.h>
#include <string>
#include <iostream>


@ -1,8 +1,8 @@
#pragma once
#include <Poco/Path.h>
#include <optional>
namespace Poco { class Path; }
namespace DB
{


@ -37,7 +37,61 @@ namespace ErrorCodes
extern const int CANNOT_DECOMPRESS;
}
static constexpr auto CHECKSUM_SIZE{sizeof(CityHash_v1_0_2::uint128)};
using Checksum = CityHash_v1_0_2::uint128;
/// Validate checksum of data, and if it mismatches, find out possible reason and throw exception.
static void validateChecksum(char * data, size_t size, const Checksum expected_checksum)
{
auto calculated_checksum = CityHash_v1_0_2::CityHash128(data, size);
if (expected_checksum == calculated_checksum)
return;
std::stringstream message;
/// TODO: the endianness is mixed up in this error message.
message << "Checksum doesn't match: corrupted data."
" Reference: " + getHexUIntLowercase(expected_checksum.first) + getHexUIntLowercase(expected_checksum.second)
+ ". Actual: " + getHexUIntLowercase(calculated_checksum.first) + getHexUIntLowercase(calculated_checksum.second)
+ ". Size of compressed block: " + toString(size);
auto message_hardware_failure = "This is most likely due to hardware failure. If you receive broken data over network and the error does not repeat every time, this can be caused by bad RAM on network interface controller or bad controller itself or bad RAM on network switches or bad CPU on network switches (look at the logs on related network switches; note that TCP checksums don't help) or bad RAM on host (look at dmesg or kern.log for enormous amount of EDAC errors, ECC-related reports, Machine Check Exceptions, mcelog; note that ECC memory can fail if the number of errors is huge) or bad CPU on host. If you read data from disk, this can be caused by disk bit rot. This exception protects ClickHouse from data corruption due to hardware failures.";
auto flip_bit = [](char * buf, size_t pos)
{
buf[pos / 8] ^= 1 << pos % 8;
};
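/// e.g. flip_bit(data, 10) toggles bit 2 of data[1] (10 / 8 == 1, 10 % 8 == 2).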
/// Check if the difference caused by single bit flip in data.
for (size_t bit_pos = 0; bit_pos < size * 8; ++bit_pos)
{
flip_bit(data, bit_pos);
auto checksum_of_data_with_flipped_bit = CityHash_v1_0_2::CityHash128(data, size);
if (expected_checksum == checksum_of_data_with_flipped_bit)
{
message << ". The mismatch is caused by single bit flip in data block at byte " << (bit_pos / 8) << ", bit " << (bit_pos % 8) << ". "
<< message_hardware_failure;
throw Exception(message.str(), ErrorCodes::CHECKSUM_DOESNT_MATCH);
}
flip_bit(data, bit_pos); /// Restore
}
/// Check if the difference caused by single bit flip in stored checksum.
size_t difference = __builtin_popcountll(expected_checksum.first ^ calculated_checksum.first)
+ __builtin_popcountll(expected_checksum.second ^ calculated_checksum.second);
if (difference == 1)
{
message << ". The mismatch is caused by single bit flip in checksum. "
<< message_hardware_failure;
throw Exception(message.str(), ErrorCodes::CHECKSUM_DOESNT_MATCH);
}
throw Exception(message.str(), ErrorCodes::CHECKSUM_DOESNT_MATCH);
}
/// Read compressed data into compressed_buffer. Get size of decompressed data from block header. Checksum if needed.
/// Returns number of compressed bytes read.
@ -46,8 +100,8 @@ size_t CompressedReadBufferBase::readCompressedData(size_t & size_decompressed,
if (compressed_in->eof())
return 0;
CityHash_v1_0_2::uint128 checksum;
compressed_in->readStrict(reinterpret_cast<char *>(&checksum), CHECKSUM_SIZE);
Checksum checksum;
compressed_in->readStrict(reinterpret_cast<char *>(&checksum), sizeof(Checksum));
UInt8 header_size = ICompressionCodec::getHeaderSize();
own_compressed_buffer.resize(header_size);
@ -73,7 +127,7 @@ size_t CompressedReadBufferBase::readCompressedData(size_t & size_decompressed,
+ ". Most likely corrupted data.",
ErrorCodes::TOO_LARGE_SIZE_COMPRESSED);
ProfileEvents::increment(ProfileEvents::ReadCompressedBytes, size_compressed_without_checksum + CHECKSUM_SIZE);
ProfileEvents::increment(ProfileEvents::ReadCompressedBytes, size_compressed_without_checksum + sizeof(Checksum));
/// Is whole compressed block located in 'compressed_in->' buffer?
if (compressed_in->offset() >= header_size &&
@ -91,18 +145,9 @@ size_t CompressedReadBufferBase::readCompressedData(size_t & size_decompressed,
}
if (!disable_checksum)
{
auto checksum_calculated = CityHash_v1_0_2::CityHash128(compressed_buffer, size_compressed_without_checksum);
if (checksum != checksum_calculated)
throw Exception("Checksum doesn't match: corrupted data."
" Reference: " + getHexUIntLowercase(checksum.first) + getHexUIntLowercase(checksum.second)
+ ". Actual: " + getHexUIntLowercase(checksum_calculated.first) + getHexUIntLowercase(checksum_calculated.second)
+ ". Size of compressed block: " + toString(size_compressed_without_checksum),
ErrorCodes::CHECKSUM_DOESNT_MATCH);
}
validateChecksum(compressed_buffer, size_compressed_without_checksum, checksum);
return size_compressed_without_checksum + CHECKSUM_SIZE;
return size_compressed_without_checksum + sizeof(Checksum);
}


@ -56,6 +56,8 @@
#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405
#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54421
/// Version of ClickHouse TCP protocol. Set to git tag with latest protocol change.
#define DBMS_TCP_PROTOCOL_VERSION 54226


@ -0,0 +1,94 @@
#include <IO/WriteBuffer.h>
#include <IO/ReadBufferFromString.h>
#include <IO/WriteBufferFromString.h>
#include <common/logger_useful.h>
#include <random>
#include <sstream>
#include "MySQLProtocol.h"
namespace DB
{
namespace MySQLProtocol
{
void PacketSender::resetSequenceId()
{
sequence_id = 0;
}
String PacketSender::packetToText(String payload)
{
String result;
for (auto c : payload)
{
result += ' ';
result += std::to_string(static_cast<unsigned char>(c));
}
return result;
}
uint64_t readLengthEncodedNumber(std::istringstream & ss)
{
char c;
uint64_t buf = 0;
ss.get(c);
auto cc = static_cast<uint8_t>(c);
if (cc < 0xfc)
{
return cc;
}
else if (cc < 0xfd)
{
ss.read(reinterpret_cast<char *>(&buf), 2);
}
else if (cc < 0xfe)
{
ss.read(reinterpret_cast<char *>(&buf), 3);
}
else
{
ss.read(reinterpret_cast<char *>(&buf), 8);
}
return buf;
}
std::string writeLengthEncodedNumber(uint64_t x)
{
std::string result;
if (x < 251)
{
result.append(1, static_cast<char>(x));
}
else if (x < (1 << 16))
{
result.append(1, 0xfc);
result.append(reinterpret_cast<char *>(&x), 2);
}
else if (x < (1 << 24))
{
result.append(1, 0xfd);
result.append(reinterpret_cast<char *>(&x), 3);
}
else
{
result.append(1, 0xfe);
result.append(reinterpret_cast<char *>(&x), 8);
}
return result;
}
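/// Round-trip sketch (assumed usage on a little-endian host; not shown in this diff):
///
///     std::string encoded = writeLengthEncodedNumber(300); // bytes: 0xfc 0x2c 0x01
///     std::istringstream ss(encoded);
///     uint64_t n = readLengthEncodedNumber(ss); // n == 300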
void writeLengthEncodedString(std::string & payload, const std::string & s)
{
payload.append(writeLengthEncodedNumber(s.length()));
payload.append(s);
}
void writeNulTerminatedString(std::string & payload, const std::string & s)
{
payload.append(s);
payload.append(1, 0);
}
}
}


@ -0,0 +1,649 @@
#pragma once
#include <Core/Types.h>
#include <IO/copyData.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/RandomStream.h>
#include <random>
#include <sstream>
/// Implementation of MySQL wire protocol
namespace DB
{
namespace ErrorCodes
{
extern const int UNKNOWN_PACKET_FROM_CLIENT;
}
namespace MySQLProtocol
{
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 MiB - 1
const size_t SCRAMBLE_LENGTH = 20;
const size_t AUTH_PLUGIN_DATA_PART_1_LENGTH = 8;
const size_t MYSQL_ERRMSG_SIZE = 512;
const size_t PACKET_HEADER_SIZE = 4;
const size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
namespace Authentication
{
const String SHA256 = "sha256_password"; /// Caching SHA2 plugin is not used because it would be possible to authenticate knowing hash from users.xml.
}
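/// Rough connection-phase flow these packet classes are intended for (a sketch of the
/// MySQL wire protocol handshake, not a description of every code path in this diff):
///
///     server -> client  : Handshake
///     client -> server  : SSLRequest (optional), then HandshakeResponse
///     server <-> client : AuthSwitchRequest / AuthSwitchResponse / AuthMoreData as needed
///     server -> client  : OK_Packet on success, ERR_Packet on failure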
enum CharacterSet
{
utf8_general_ci = 33,
binary = 63
};
enum StatusFlags
{
SERVER_SESSION_STATE_CHANGED = 0x4000
};
enum Capability
{
CLIENT_CONNECT_WITH_DB = 0x00000008,
CLIENT_PROTOCOL_41 = 0x00000200,
CLIENT_SSL = 0x00000800,
CLIENT_TRANSACTIONS = 0x00002000, // TODO
CLIENT_SESSION_TRACK = 0x00800000, // TODO
CLIENT_SECURE_CONNECTION = 0x00008000,
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000,
CLIENT_PLUGIN_AUTH = 0x00080000,
CLIENT_DEPRECATE_EOF = 0x01000000,
};
enum Command
{
COM_SLEEP = 0x0,
COM_QUIT = 0x1,
COM_INIT_DB = 0x2,
COM_QUERY = 0x3,
COM_FIELD_LIST = 0x4,
COM_CREATE_DB = 0x5,
COM_DROP_DB = 0x6,
COM_REFRESH = 0x7,
COM_SHUTDOWN = 0x8,
COM_STATISTICS = 0x9,
COM_PROCESS_INFO = 0xa,
COM_CONNECT = 0xb,
COM_PROCESS_KILL = 0xc,
COM_DEBUG = 0xd,
COM_PING = 0xe,
COM_TIME = 0xf,
COM_DELAYED_INSERT = 0x10,
COM_CHANGE_USER = 0x11,
COM_RESET_CONNECTION = 0x1f,
COM_DAEMON = 0x1d
};
enum ColumnType
{
MYSQL_TYPE_DECIMAL = 0x00,
MYSQL_TYPE_TINY = 0x01,
MYSQL_TYPE_SHORT = 0x02,
MYSQL_TYPE_LONG = 0x03,
MYSQL_TYPE_FLOAT = 0x04,
MYSQL_TYPE_DOUBLE = 0x05,
MYSQL_TYPE_NULL = 0x06,
MYSQL_TYPE_TIMESTAMP = 0x07,
MYSQL_TYPE_LONGLONG = 0x08,
MYSQL_TYPE_INT24 = 0x09,
MYSQL_TYPE_DATE = 0x0a,
MYSQL_TYPE_TIME = 0x0b,
MYSQL_TYPE_DATETIME = 0x0c,
MYSQL_TYPE_YEAR = 0x0d,
MYSQL_TYPE_VARCHAR = 0x0f,
MYSQL_TYPE_BIT = 0x10,
MYSQL_TYPE_NEWDECIMAL = 0xf6,
MYSQL_TYPE_ENUM = 0xf7,
MYSQL_TYPE_SET = 0xf8,
MYSQL_TYPE_TINY_BLOB = 0xf9,
MYSQL_TYPE_MEDIUM_BLOB = 0xfa,
MYSQL_TYPE_LONG_BLOB = 0xfb,
MYSQL_TYPE_BLOB = 0xfc,
MYSQL_TYPE_VAR_STRING = 0xfd,
MYSQL_TYPE_STRING = 0xfe,
MYSQL_TYPE_GEOMETRY = 0xff
};
class ProtocolError : public DB::Exception
{
public:
using Exception::Exception;
};
class WritePacket
{
public:
virtual String getPayload() const = 0;
virtual ~WritePacket() = default;
};
class ReadPacket
{
public:
ReadPacket() = default;
ReadPacket(const ReadPacket &) = default;
virtual void readPayload(String payload) = 0;
virtual ~ReadPacket() = default;
};
/* Writes and reads packets, keeping sequence-id.
* Throws ProtocolError if a packet with an incorrect sequence-id was received.
*/
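/// Wire framing handled by this class: every packet is a 4-byte header followed by its payload.
///
///     3 bytes  payload length (little-endian)
///     1 byte   sequence-id
///     n bytes  payload (a payload of exactly max_packet_size is continued in the next packet)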
class PacketSender
{
public:
size_t & sequence_id;
ReadBuffer * in;
WriteBuffer * out;
size_t max_packet_size = MAX_PACKET_LENGTH;
/// For reading and writing.
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id)
: sequence_id(sequence_id)
, in(&in)
, out(&out)
{
}
/// For writing.
PacketSender(WriteBuffer & out, size_t & sequence_id)
: sequence_id(sequence_id)
, in(nullptr)
, out(&out)
{
}
String receivePacketPayload()
{
WriteBufferFromOwnString buf;
size_t payload_length = 0;
size_t packet_sequence_id = 0;
// packets which are larger than or equal to 16 MiB are split
do
{
in->readStrict(reinterpret_cast<char *>(&payload_length), 3);
if (payload_length > max_packet_size)
{
std::ostringstream tmp;
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
in->readStrict(reinterpret_cast<char *>(&packet_sequence_id), 1);
if (packet_sequence_id != sequence_id)
{
std::ostringstream tmp;
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << sequence_id << '.';
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
sequence_id++;
copyData(*in, static_cast<WriteBuffer &>(buf), payload_length);
} while (payload_length == max_packet_size);
return std::move(buf.str());
}
void receivePacket(ReadPacket & packet)
{
packet.readPayload(receivePacketPayload());
}
template<class T>
void sendPacket(const T & packet, bool flush = false)
{
static_assert(std::is_base_of<WritePacket, T>());
String payload = packet.getPayload();
size_t pos = 0;
do
{
size_t payload_length = std::min(payload.length() - pos, max_packet_size);
out->write(reinterpret_cast<const char *>(&payload_length), 3);
out->write(reinterpret_cast<const char *>(&sequence_id), 1);
out->write(payload.data() + pos, payload_length);
pos += payload_length;
sequence_id++;
} while (pos < payload.length());
if (flush)
out->next();
}
/// Sets sequence-id to 0. Must be called before each command phase.
void resetSequenceId();
private:
/// Converts packet to text. Is used for debug output.
static String packetToText(String payload);
};
uint64_t readLengthEncodedNumber(std::istringstream & ss);
String writeLengthEncodedNumber(uint64_t x);
void writeLengthEncodedString(String & payload, const String & s);
void writeNulTerminatedString(String & payload, const String & s);
class Handshake : public WritePacket
{
int protocol_version = 0xa;
String server_version;
uint32_t connection_id;
uint32_t capability_flags;
uint8_t character_set;
uint32_t status_flags;
String auth_plugin_data;
public:
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
: protocol_version(0xa)
, server_version(std::move(server_version))
, connection_id(connection_id)
, capability_flags(capability_flags)
, character_set(CharacterSet::utf8_general_ci)
, status_flags(0)
, auth_plugin_data(auth_plugin_data)
{
}
String getPayload() const override
{
String result;
result.append(1, protocol_version);
writeNulTerminatedString(result, server_version);
result.append(reinterpret_cast<const char *>(&connection_id), 4);
writeNulTerminatedString(result, auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH));
result.append(reinterpret_cast<const char *>(&capability_flags), 2);
result.append(reinterpret_cast<const char *>(&character_set), 1);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
result.append(1, auth_plugin_data.size());
result.append(10, 0x0);
result.append(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH));
result.append(Authentication::SHA256);
result.append(1, 0x0);
return result;
}
};
class SSLRequest : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
void readPayload(String s) override
{
std::istringstream ss(s);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
}
};
class HandshakeResponse : public ReadPacket
{
public:
uint32_t capability_flags;
uint32_t max_packet_size;
uint8_t character_set;
String username;
String auth_response;
String database;
String auth_plugin_name;
HandshakeResponse() = default;
HandshakeResponse(const HandshakeResponse &) = default;
void readPayload(String s) override
{
std::istringstream ss(s);
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
ss.ignore(23);
std::getline(ss, username, static_cast<char>(0x0));
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
{
auto len = readLengthEncodedNumber(ss);
auth_response.resize(len);
ss.read(auth_response.data(), static_cast<std::streamsize>(len));
}
else if (capability_flags & CLIENT_SECURE_CONNECTION)
{
uint8_t len;
ss.read(reinterpret_cast<char *>(&len), 1);
auth_response.resize(len);
ss.read(auth_response.data(), len);
}
else
{
std::getline(ss, auth_response, static_cast<char>(0x0));
}
if (capability_flags & CLIENT_CONNECT_WITH_DB)
{
std::getline(ss, database, static_cast<char>(0x0));
}
if (capability_flags & CLIENT_PLUGIN_AUTH)
{
std::getline(ss, auth_plugin_name, static_cast<char>(0x0));
}
}
};
class AuthSwitchRequest : public WritePacket
{
String plugin_name;
String auth_plugin_data;
public:
AuthSwitchRequest(String plugin_name, String auth_plugin_data)
: plugin_name(std::move(plugin_name)), auth_plugin_data(std::move(auth_plugin_data))
{
}
String getPayload() const override
{
String result;
result.append(1, 0xfe);
writeNulTerminatedString(result, plugin_name);
result.append(auth_plugin_data);
return result;
}
};
class AuthSwitchResponse : public ReadPacket
{
public:
String value;
void readPayload(String s) override
{
value = std::move(s);
}
};
class AuthMoreData : public WritePacket
{
String data;
public:
AuthMoreData(String data): data(std::move(data)) {}
String getPayload() const override
{
String result;
result.append(1, 0x01);
result.append(data);
return result;
}
};
/// Packet with a single null-terminated string. Is used for clear text authentication.
class NullTerminatedString : public ReadPacket
{
public:
String value;
void readPayload(String s) override
{
if (s.length() == 0 || s.back() != 0)
{
throw ProtocolError("String is not null terminated.", ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
}
value = s;
value.pop_back();
}
};
class OK_Packet : public WritePacket
{
uint8_t header;
uint32_t capabilities;
uint64_t affected_rows;
int16_t warnings = 0;
uint32_t status_flags;
String session_state_changes;
String info;
public:
OK_Packet(uint8_t header,
uint32_t capabilities,
uint64_t affected_rows,
uint32_t status_flags,
int16_t warnings,
String session_state_changes = "",
String info = "")
: header(header)
, capabilities(capabilities)
, affected_rows(affected_rows)
, warnings(warnings)
, status_flags(status_flags)
, session_state_changes(std::move(session_state_changes))
, info(info)
{
}
String getPayload() const override
{
String result;
result.append(1, header);
result.append(writeLengthEncodedNumber(affected_rows));
result.append(writeLengthEncodedNumber(0)); /// last insert-id
if (capabilities & CLIENT_PROTOCOL_41)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
result.append(reinterpret_cast<const char *>(&warnings), 2);
}
else if (capabilities & CLIENT_TRANSACTIONS)
{
result.append(reinterpret_cast<const char *>(&status_flags), 2);
}
if (capabilities & CLIENT_SESSION_TRACK)
{
result.append(writeLengthEncodedNumber(info.length()));
result.append(info);
if (status_flags & SERVER_SESSION_STATE_CHANGED)
{
result.append(writeLengthEncodedNumber(session_state_changes.length()));
result.append(session_state_changes);
}
}
else
{
result.append(info);
}
return result;
}
};
class EOF_Packet : public WritePacket
{
int warnings;
int status_flags;
public:
EOF_Packet(int warnings, int status_flags) : warnings(warnings), status_flags(status_flags)
{}
String getPayload() const override
{
String result;
result.append(1, 0xfe); // EOF header
result.append(reinterpret_cast<const char *>(&warnings), 2);
result.append(reinterpret_cast<const char *>(&status_flags), 2);
return result;
}
};
class ERR_Packet : public WritePacket
{
int error_code;
String sql_state;
String error_message;
public:
ERR_Packet(int error_code, String sql_state, String error_message)
: error_code(error_code), sql_state(std::move(sql_state)), error_message(std::move(error_message))
{
}
String getPayload() const override
{
String result;
result.append(1, 0xff);
result.append(reinterpret_cast<const char *>(&error_code), 2);
result.append("#", 1);
result.append(sql_state.data(), sql_state.length());
result.append(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
return result;
}
};
class ColumnDefinition : public WritePacket
{
String schema;
String table;
String org_table;
String name;
String org_name;
size_t next_length = 0x0c;
uint16_t character_set;
uint32_t column_length;
ColumnType column_type;
uint16_t flags;
uint8_t decimals = 0x00;
public:
ColumnDefinition(
String schema,
String table,
String org_table,
String name,
String org_name,
uint16_t character_set,
uint32_t column_length,
ColumnType column_type,
uint16_t flags,
uint8_t decimals)
: schema(std::move(schema)), table(std::move(table)), org_table(std::move(org_table)), name(std::move(name)),
org_name(std::move(org_name)), character_set(character_set), column_length(column_length), column_type(column_type), flags(flags),
decimals(decimals)
{
}
/// Should be used when column metadata (original name, table, original table, database) is unknown.
ColumnDefinition(
String name,
uint16_t character_set,
uint32_t column_length,
ColumnType column_type,
uint16_t flags,
uint8_t decimals)
: ColumnDefinition("", "", "", std::move(name), "", character_set, column_length, column_type, flags, decimals)
{
}
String getPayload() const override
{
String result;
writeLengthEncodedString(result, "def"); /// always "def"
writeLengthEncodedString(result, ""); /// schema
writeLengthEncodedString(result, ""); /// table
writeLengthEncodedString(result, ""); /// org_table
writeLengthEncodedString(result, name);
writeLengthEncodedString(result, ""); /// org_name
result.append(writeLengthEncodedNumber(next_length));
result.append(reinterpret_cast<const char *>(&character_set), 2);
result.append(reinterpret_cast<const char *>(&column_length), 4);
result.append(reinterpret_cast<const char *>(&column_type), 1);
result.append(reinterpret_cast<const char *>(&flags), 2);
result.append(reinterpret_cast<const char *>(&decimals), 2);
result.append(2, 0x0);
return result;
}
};
class ComFieldList : public ReadPacket
{
public:
String table, field_wildcard;
void readPayload(String payload)
{
std::istringstream ss(payload);
ss.ignore(1); // command byte
std::getline(ss, table, static_cast<char>(0x0));
field_wildcard = payload.substr(table.length() + 2); // rest of the packet
}
};
class LengthEncodedNumber : public WritePacket
{
uint64_t value;
public:
LengthEncodedNumber(uint64_t value): value(value)
{
}
String getPayload() const override
{
return writeLengthEncodedNumber(value);
}
};
class ResultsetRow : public WritePacket
{
std::vector<String> columns;
public:
ResultsetRow()
{
}
void appendColumn(String value)
{
columns.emplace_back(std::move(value));
}
String getPayload() const override
{
String result;
for (const String & column : columns)
{
writeLengthEncodedString(result, column);
}
return result;
}
};
}
}


@ -85,6 +85,7 @@ struct Settings : public SettingsCollection<Settings>
M(SettingFloat, totals_auto_threshold, 0.5, "The threshold for totals_mode = 'auto'.") \
\
M(SettingBool, compile, false, "Whether query compilation is enabled.") \
M(SettingBool, allow_suspicious_low_cardinality_types, false, "In CREATE TABLE statement allows specifying LowCardinality modifier for types of small fixed size (8 or less). Enabling this may increase merge times and memory consumption.") \
M(SettingBool, compile_expressions, false, "Compile some scalar functions and operators to native code.") \
M(SettingUInt64, min_count_to_compile, 3, "The number of structurally identical queries before they are compiled.") \
M(SettingUInt64, min_count_to_compile_expression, 3, "The number of identical expressions before they are JIT-compiled") \


@ -32,16 +32,16 @@ using String = std::string;
*/
template <typename T> constexpr bool IsNumber = false;
template <> constexpr bool IsNumber<UInt8> = true;
template <> constexpr bool IsNumber<UInt16> = true;
template <> constexpr bool IsNumber<UInt32> = true;
template <> constexpr bool IsNumber<UInt64> = true;
template <> constexpr bool IsNumber<Int8> = true;
template <> constexpr bool IsNumber<Int16> = true;
template <> constexpr bool IsNumber<Int32> = true;
template <> constexpr bool IsNumber<Int64> = true;
template <> constexpr bool IsNumber<Float32> = true;
template <> constexpr bool IsNumber<Float64> = true;
template <> inline constexpr bool IsNumber<UInt8> = true;
template <> inline constexpr bool IsNumber<UInt16> = true;
template <> inline constexpr bool IsNumber<UInt32> = true;
template <> inline constexpr bool IsNumber<UInt64> = true;
template <> inline constexpr bool IsNumber<Int8> = true;
template <> inline constexpr bool IsNumber<Int16> = true;
template <> inline constexpr bool IsNumber<Int32> = true;
template <> inline constexpr bool IsNumber<Int64> = true;
template <> inline constexpr bool IsNumber<Float32> = true;
template <> inline constexpr bool IsNumber<Float64> = true;
template <typename T> struct TypeName;
@ -109,8 +109,8 @@ using Strings = std::vector<String>;
using Int128 = __int128;
template <> constexpr bool IsNumber<Int128> = true;
template <> struct TypeName<Int128> { static const char * get() { return "Int128"; } };
template <> inline constexpr bool IsNumber<Int128> = true;
template <> struct TypeName<Int128> { static const char * get() { return "Int128"; } };
template <> struct TypeId<Int128> { static constexpr const TypeIndex value = TypeIndex::Int128; };
/// Own FieldType for Decimal.
@ -159,17 +159,56 @@ template <> struct TypeId<Decimal32> { static constexpr const TypeIndex value
template <> struct TypeId<Decimal64> { static constexpr const TypeIndex value = TypeIndex::Decimal64; };
template <> struct TypeId<Decimal128> { static constexpr const TypeIndex value = TypeIndex::Decimal128; };
template <typename T>
constexpr bool IsDecimalNumber = false;
template <> constexpr bool IsDecimalNumber<Decimal32> = true;
template <> constexpr bool IsDecimalNumber<Decimal64> = true;
template <> constexpr bool IsDecimalNumber<Decimal128> = true;
template <typename T> constexpr bool IsDecimalNumber = false;
template <> inline constexpr bool IsDecimalNumber<Decimal32> = true;
template <> inline constexpr bool IsDecimalNumber<Decimal64> = true;
template <> inline constexpr bool IsDecimalNumber<Decimal128> = true;
template <typename T> struct NativeType { using Type = T; };
template <> struct NativeType<Decimal32> { using Type = Int32; };
template <> struct NativeType<Decimal64> { using Type = Int64; };
template <> struct NativeType<Decimal128> { using Type = Int128; };
inline const char * getTypeName(TypeIndex idx)
{
switch (idx)
{
case TypeIndex::Nothing: return "Nothing";
case TypeIndex::UInt8: return TypeName<UInt8>::get();
case TypeIndex::UInt16: return TypeName<UInt16>::get();
case TypeIndex::UInt32: return TypeName<UInt32>::get();
case TypeIndex::UInt64: return TypeName<UInt64>::get();
case TypeIndex::UInt128: return "UInt128";
case TypeIndex::Int8: return TypeName<Int8>::get();
case TypeIndex::Int16: return TypeName<Int16>::get();
case TypeIndex::Int32: return TypeName<Int32>::get();
case TypeIndex::Int64: return TypeName<Int64>::get();
case TypeIndex::Int128: return TypeName<Int128>::get();
case TypeIndex::Float32: return TypeName<Float32>::get();
case TypeIndex::Float64: return TypeName<Float64>::get();
case TypeIndex::Date: return "Date";
case TypeIndex::DateTime: return "DateTime";
case TypeIndex::String: return TypeName<String>::get();
case TypeIndex::FixedString: return "FixedString";
case TypeIndex::Enum8: return "Enum8";
case TypeIndex::Enum16: return "Enum16";
case TypeIndex::Decimal32: return TypeName<Decimal32>::get();
case TypeIndex::Decimal64: return TypeName<Decimal64>::get();
case TypeIndex::Decimal128: return TypeName<Decimal128>::get();
case TypeIndex::UUID: return "UUID";
case TypeIndex::Array: return "Array";
case TypeIndex::Tuple: return "Tuple";
case TypeIndex::Set: return "Set";
case TypeIndex::Interval: return "Interval";
case TypeIndex::Nullable: return "Nullable";
case TypeIndex::Function: return "Function";
case TypeIndex::AggregateFunction: return "AggregateFunction";
case TypeIndex::LowCardinality: return "LowCardinality";
}
__builtin_unreachable();
}
}
/// Specialization of `std::hash` for the Decimal<T> types.
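The recurring change in this header (and in DataTypesDecimal.h / DataTypesNumber.h further down) is adding `inline` to explicit specializations of `constexpr` variable templates. Unlike the primary template, an explicit specialization of a variable template is an ordinary variable definition, so placing it in a header included from several translation units can produce duplicate definitions (an ODR problem); `inline` (C++17) makes the repeated definitions legal. A minimal self-contained sketch of the rule, using standard types rather than ClickHouse's:

```cpp
// Sketch: why `inline` is needed on explicit specializations of a constexpr
// variable template that live in a header shared by many translation units.
#include <cstdint>
#include <iostream>

// Primary template: fine to define in a header as-is.
template <typename T> constexpr bool IsNumber = false;

// Explicit specializations are ordinary variable definitions; without `inline`,
// every translation unit including the header would emit its own definition.
// `inline` (C++17) allows the definition to be repeated safely.
template <> inline constexpr bool IsNumber<uint8_t> = true;
template <> inline constexpr bool IsNumber<int64_t> = true;
template <> inline constexpr bool IsNumber<double>  = true;

static_assert(IsNumber<double>, "double is a number");
static_assert(!IsNumber<bool>, "bool has no specialization here");

int main()
{
    std::cout << std::boolalpha
              << "IsNumber<int64_t> = " << IsNumber<int64_t> << '\n'
              << "IsNumber<bool>    = " << IsNumber<bool> << '\n';
    return 0;
}
```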

View File

@ -10,6 +10,7 @@
#include <Functions/IFunction.h>
#include <IO/WriteBufferFromOStream.h>
#include <Interpreters/ExpressionAnalyzer.h>
#include <Interpreters/ExpressionActions.h>
#include <Parsers/IAST.h>
#include <Storages/IStorage.h>
#include <Common/COW.h>
@ -70,7 +71,7 @@ std::ostream & operator<<(std::ostream & stream, const Block & what)
std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what)
{
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << what.type << ", column = ";
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << *what.type << ", column = ";
return dumpValue(stream, what.column) << ")";
}
@ -109,4 +110,56 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what)
return stream;
}
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what)
{
stream << "ExpressionAction(" << what.toString() << ")";
return stream;
}
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what)
{
stream << "ExpressionActions(" << what.dumpActions() << ")";
return stream;
}
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what)
{
stream << "SyntaxAnalyzerResult{";
stream << "storage=" << what.storage << "; ";
if (!what.source_columns.empty())
{
stream << "source_columns=";
dumpValue(stream, what.source_columns);
stream << "; ";
}
if (!what.aliases.empty())
{
stream << "aliases=";
dumpValue(stream, what.aliases);
stream << "; ";
}
if (!what.array_join_result_to_source.empty())
{
stream << "array_join_result_to_source=";
dumpValue(stream, what.array_join_result_to_source);
stream << "; ";
}
if (!what.array_join_alias_to_name.empty())
{
stream << "array_join_alias_to_name=";
dumpValue(stream, what.array_join_alias_to_name);
stream << "; ";
}
if (!what.array_join_name_to_alias.empty())
{
stream << "array_join_name_to_alias=";
dumpValue(stream, what.array_join_name_to_alias);
stream << "; ";
}
stream << "rewrite_subqueries=" << what.rewrite_subqueries << "; ";
stream << "}";
return stream;
}
}
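The new `operator<<` overloads above extend the debug-dump helpers so analyzer and expression objects can be streamed directly in tests or a debugger, printing only the non-empty parts. A generic sketch of the same pattern follows; `AnalyzerResult` and `dumpValue` here are simplified illustrations, not ClickHouse's actual API.

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-in for an analyzer result with several optional parts.
struct AnalyzerResult
{
    std::string storage;
    std::vector<std::string> source_columns;
    std::map<std::string, std::string> aliases;
    bool rewrite_subqueries = false;
};

// Container dump helpers, mirroring the "print only what is non-empty" style.
template <typename T>
std::ostream & dumpValue(std::ostream & stream, const std::vector<T> & values)
{
    stream << '[';
    for (size_t i = 0; i < values.size(); ++i)
        stream << (i ? ", " : "") << values[i];
    return stream << ']';
}

template <typename K, typename V>
std::ostream & dumpValue(std::ostream & stream, const std::map<K, V> & values)
{
    stream << '{';
    bool first = true;
    for (const auto & [key, value] : values)
    {
        stream << (first ? "" : ", ") << key << " -> " << value;
        first = false;
    }
    return stream << '}';
}

std::ostream & operator<<(std::ostream & stream, const AnalyzerResult & what)
{
    stream << "AnalyzerResult{storage=" << what.storage << "; ";
    if (!what.source_columns.empty())
    {
        stream << "source_columns=";
        dumpValue(stream, what.source_columns) << "; ";
    }
    if (!what.aliases.empty())
    {
        stream << "aliases=";
        dumpValue(stream, what.aliases) << "; ";
    }
    stream << "rewrite_subqueries=" << what.rewrite_subqueries << "}";
    return stream;
}

int main()
{
    AnalyzerResult result{"t", {"a", "b"}, {{"x", "a + 1"}}, true};
    std::cout << result << '\n';  // AnalyzerResult{storage=t; source_columns=[a, b]; aliases={x -> a + 1}; rewrite_subqueries=1}
    return 0;
}
```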

View File

@ -41,6 +41,14 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what);
std::ostream & operator<<(std::ostream & stream, const Connection::Packet & what);
struct ExpressionAction;
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what);
class ExpressionActions;
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what);
struct SyntaxAnalyzerResult;
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what);
}
/// some operator<< should be declared before operator<<(... std::shared_ptr<>)

View File

@ -16,15 +16,15 @@ struct BlockIO
BlockIO(const BlockIO &) = default;
~BlockIO() = default;
BlockOutputStreamPtr out;
BlockInputStreamPtr in;
/** process_list_entry should be destroyed after in and after out,
* since in and out contain pointer to objects inside process_list_entry (query-level MemoryTracker for example),
* which could be used before destroying of in and out.
*/
std::shared_ptr<ProcessListEntry> process_list_entry;
BlockOutputStreamPtr out;
BlockInputStreamPtr in;
/// Callbacks for query logging could be set here.
std::function<void(IBlockInputStream *, IBlockOutputStream *)> finish_callback;
std::function<void()> exception_callback;
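The member reordering above relies on a language guarantee: non-static data members are destroyed in reverse order of declaration, so declaring `process_list_entry` before `out`/`in` ensures the streams (which hold pointers into the process list entry) are destroyed first. A small self-contained sketch of that rule, with illustrative types only:

```cpp
#include <iostream>

struct Tracer
{
    const char * name;
    explicit Tracer(const char * name_) : name(name_) { std::cout << "construct " << name << '\n'; }
    ~Tracer() { std::cout << "destroy " << name << '\n'; }
};

struct QueryIO
{
    // Declared first => destroyed last: members below may hold pointers into it
    // for their entire lifetime.
    Tracer process_list_entry{"process_list_entry"};

    // Declared after => destroyed before process_list_entry.
    Tracer out{"out"};
    Tracer in{"in"};
};

int main()
{
    QueryIO io;
    return 0;
    // Destructor output on scope exit: destroy in, destroy out, destroy process_list_entry.
}
```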

View File

@ -19,8 +19,8 @@ void CountingBlockOutputStream::write(const Block & block)
Progress local_progress(block.rows(), block.bytes(), 0);
progress.incrementPiecewiseAtomically(local_progress);
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.rows);
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.bytes);
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.read_rows);
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.read_bytes);
if (process_elem)
process_elem->updateProgressOut(local_progress);

View File

@ -281,7 +281,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
/// The total amount of data processed or intended for processing in all leaf sources, possibly on remote servers.
ProgressValues progress = process_list_elem->getProgressIn();
size_t total_rows_estimate = std::max(progress.rows, progress.total_rows);
size_t total_rows_estimate = std::max(progress.read_rows, progress.total_rows_to_read);
/** Check the restrictions on the amount of data to read, the speed of the query, the quota on the amount of data to read.
* NOTE: Maybe it makes sense to have them checked directly in ProcessList?
@ -289,7 +289,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
if (limits.mode == LIMITS_TOTAL
&& ((limits.size_limits.max_rows && total_rows_estimate > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes)))
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes)))
{
switch (limits.size_limits.overflow_mode)
{
@ -300,7 +300,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
+ " rows read (or to read), maximum: " + toString(limits.size_limits.max_rows),
ErrorCodes::TOO_MANY_ROWS);
else
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.bytes)
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.read_bytes)
+ " bytes read, maximum: " + toString(limits.size_limits.max_bytes),
ErrorCodes::TOO_MANY_BYTES);
}
@ -308,8 +308,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
case OverflowMode::BREAK:
{
/// For `break`, we will stop only if so many rows were actually read, and not just supposed to be read.
if ((limits.size_limits.max_rows && progress.rows > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes))
if ((limits.size_limits.max_rows && progress.read_rows > limits.size_limits.max_rows)
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes))
{
cancel(false);
}
@ -322,7 +322,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
}
}
size_t total_rows = progress.total_rows;
size_t total_rows = progress.total_rows_to_read;
constexpr UInt64 profile_events_update_period_microseconds = 10 * 1000; // 10 milliseconds
UInt64 total_elapsed_microseconds = info.total_stopwatch.elapsedMicroseconds();
@ -344,20 +344,20 @@ void IBlockInputStream::progressImpl(const Progress & value)
if (elapsed_seconds > 0)
{
if (limits.min_execution_speed && progress.rows / elapsed_seconds < limits.min_execution_speed)
throw Exception("Query is executing too slow: " + toString(progress.rows / elapsed_seconds)
if (limits.min_execution_speed && progress.read_rows / elapsed_seconds < limits.min_execution_speed)
throw Exception("Query is executing too slow: " + toString(progress.read_rows / elapsed_seconds)
+ " rows/sec., minimum: " + toString(limits.min_execution_speed),
ErrorCodes::TOO_SLOW);
if (limits.min_execution_speed_bytes && progress.bytes / elapsed_seconds < limits.min_execution_speed_bytes)
throw Exception("Query is executing too slow: " + toString(progress.bytes / elapsed_seconds)
if (limits.min_execution_speed_bytes && progress.read_bytes / elapsed_seconds < limits.min_execution_speed_bytes)
throw Exception("Query is executing too slow: " + toString(progress.read_bytes / elapsed_seconds)
+ " bytes/sec., minimum: " + toString(limits.min_execution_speed_bytes),
ErrorCodes::TOO_SLOW);
/// If the predicted execution time is longer than `max_execution_time`.
if (limits.max_execution_time != 0 && total_rows)
{
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.rows);
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.read_rows);
if (estimated_execution_time_seconds > limits.max_execution_time.totalSeconds())
throw Exception("Estimated query execution time (" + toString(estimated_execution_time_seconds) + " seconds)"
@ -366,17 +366,17 @@ void IBlockInputStream::progressImpl(const Progress & value)
ErrorCodes::TOO_SLOW);
}
if (limits.max_execution_speed && progress.rows / elapsed_seconds >= limits.max_execution_speed)
limitProgressingSpeed(progress.rows, limits.max_execution_speed, total_elapsed_microseconds);
if (limits.max_execution_speed && progress.read_rows / elapsed_seconds >= limits.max_execution_speed)
limitProgressingSpeed(progress.read_rows, limits.max_execution_speed, total_elapsed_microseconds);
if (limits.max_execution_speed_bytes && progress.bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
limitProgressingSpeed(progress.bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
if (limits.max_execution_speed_bytes && progress.read_bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
limitProgressingSpeed(progress.read_bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
}
}
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
{
quota->checkAndAddReadRowsBytes(time(nullptr), value.rows, value.bytes);
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
}
}
}
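The hunks above reflect the rename of the progress counters (`rows` -> `read_rows`, `bytes` -> `read_bytes`, `total_rows` -> `total_rows_to_read`); the speed checks themselves keep the same shape: rows per second is compared against a minimum, and the remaining time is extrapolated from the fraction of rows already read. A simplified, self-contained sketch of that logic; field names follow the diff, everything else is illustrative.

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>

struct ProgressValues
{
    uint64_t read_rows = 0;
    uint64_t read_bytes = 0;
    uint64_t total_rows_to_read = 0;
};

struct ExecutionSpeedLimits
{
    uint64_t min_execution_speed = 0;       // rows per second, 0 = unlimited
    double max_execution_time_seconds = 0;  // 0 = unlimited
};

// Throws if the query reads too slowly or is predicted to exceed the time limit.
void checkExecutionSpeed(const ProgressValues & progress, const ExecutionSpeedLimits & limits, double elapsed_seconds)
{
    if (elapsed_seconds <= 0)
        return;

    const double rows_per_second = progress.read_rows / elapsed_seconds;

    if (limits.min_execution_speed && rows_per_second < limits.min_execution_speed)
        throw std::runtime_error(
            "Query is executing too slow: " + std::to_string(rows_per_second)
            + " rows/sec., minimum: " + std::to_string(limits.min_execution_speed));

    // Extrapolate total execution time from the fraction of rows read so far.
    if (limits.max_execution_time_seconds > 0 && progress.total_rows_to_read && progress.read_rows)
    {
        const double estimated_seconds
            = elapsed_seconds * (static_cast<double>(progress.total_rows_to_read) / progress.read_rows);
        if (estimated_seconds > limits.max_execution_time_seconds)
            throw std::runtime_error(
                "Estimated query execution time (" + std::to_string(estimated_seconds) + " seconds) is too long");
    }
}

int main()
{
    ProgressValues progress{1000, 64 * 1024, 1'000'000};
    ExecutionSpeedLimits limits{5000, 60.0};
    try
    {
        checkExecutionSpeed(progress, limits, 1.0);  // 1000 rows/sec < 5000 => throws
    }
    catch (const std::exception & e)
    {
        std::cerr << e.what() << '\n';
    }
    return 0;
}
```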

View File

@ -36,6 +36,7 @@ public:
bool canBeInsideNullable() const override { return false; }
DataTypePtr getReturnType() const { return function->getReturnType(); }
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
DataTypes getArgumentsDataTypes() const { return argument_types; }
/// NOTE These two functions for serializing single values are incompatible with the functions below.

View File

@ -821,7 +821,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat
return creator(static_cast<ColumnVector<UInt16> *>(nullptr));
if (typeid_cast<const DataTypeDateTime *>(type))
return creator(static_cast<ColumnVector<UInt32> *>(nullptr));
if (isNumber(type))
if (isColumnedAsNumber(type))
{
MutableColumnUniquePtr column;
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));

View File

@ -111,7 +111,7 @@ ColumnPtr recursiveLowCardinalityConversion(const ColumnPtr & column, const Data
{
auto col = low_cardinality_type->createColumn();
static_cast<ColumnLowCardinality &>(*col).insertRangeFromFullColumn(*column, 0, column->size());
return std::move(col);
return col;
}
}
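The `return std::move(col)` -> `return col` change removes a redundant `std::move`: a local object named in a return statement is already treated as an rvalue by current compilers (implicit move, CWG 1579), and for matching types the explicit move only gets in the way of copy elision, so newer clang versions may flag it (`-Wredundant-move` / `-Wpessimizing-move`). A minimal sketch with standard smart pointers; ClickHouse's `ColumnPtr`/`MutableColumnUniquePtr` are COW pointers, but the return rule is the same.

```cpp
#include <memory>

struct IColumn { virtual ~IColumn() = default; };
struct ColumnLowCardinality : IColumn {};

std::unique_ptr<IColumn> makeColumnGood()
{
    auto col = std::make_unique<ColumnLowCardinality>();
    // Preferred: the local is implicitly moved, and the
    // unique_ptr<Derived> -> unique_ptr<Base> conversion happens from an rvalue.
    return col;
}

std::unique_ptr<IColumn> makeColumnNoisy()
{
    auto col = std::make_unique<ColumnLowCardinality>();
    // Still compiles, but the explicit move adds nothing and newer compilers
    // may warn that it is redundant.
    return std::move(col);
}

int main()
{
    auto a = makeColumnGood();
    auto b = makeColumnNoisy();
    return (a && b) ? 0 : 1;
}
```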

View File

@ -246,9 +246,9 @@ inline UInt32 getDecimalScale(const IDataType & data_type, UInt32 default_value
///
template <typename DataType> constexpr bool IsDataTypeDecimal = false;
template <> constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal32>> = true;
template <> constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal64>> = true;
template <> constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal128>> = true;
template <> inline constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal32>> = true;
template <> inline constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal64>> = true;
template <> inline constexpr bool IsDataTypeDecimal<DataTypeDecimal<Decimal128>> = true;
template <typename DataType> constexpr bool IsDataTypeDecimalOrNumber = IsDataTypeDecimal<DataType> || IsDataTypeNumber<DataType>;

View File

@ -38,15 +38,15 @@ using DataTypeFloat32 = DataTypeNumber<Float32>;
using DataTypeFloat64 = DataTypeNumber<Float64>;
template <typename DataType> constexpr bool IsDataTypeNumber = false;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<UInt8>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<UInt16>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<UInt32>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<UInt64>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Int8>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Int16>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Int32>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Int64>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Float32>> = true;
template <> constexpr bool IsDataTypeNumber<DataTypeNumber<Float64>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<UInt8>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<UInt16>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<UInt32>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<UInt64>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Int8>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Int16>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Int32>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Int64>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Float32>> = true;
template <> inline constexpr bool IsDataTypeNumber<DataTypeNumber<Float64>> = true;
}

View File

@ -581,11 +581,18 @@ inline bool isFloat(const T & data_type)
return which.isFloat();
}
template <typename T>
inline bool isNativeNumber(const T & data_type)
{
WhichDataType which(data_type);
return which.isNativeInt() || which.isNativeUInt() || which.isFloat();
}
template <typename T>
inline bool isNumber(const T & data_type)
{
WhichDataType which(data_type);
return which.isInt() || which.isUInt() || which.isFloat();
return which.isInt() || which.isUInt() || which.isFloat() || which.isDecimal();
}
template <typename T>

View File

@ -86,7 +86,7 @@ void CacheDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<K
{
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
getItemsNumber<UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<UInt64, UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
}
@ -207,9 +207,7 @@ void CacheDictionary::isInConstantVector(const Key child_id, const PaddedPODArra
void CacheDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
@ -220,9 +218,7 @@ void CacheDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, ids, out, [&](const size_t row) { return def->getDataAt(row); });
}
@ -231,9 +227,7 @@ void CacheDictionary::getString(
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, ids, out, [&](const size_t) { return StringRef{def}; });
}
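The repeated `if (!isAttributeTypeConvertibleTo(...)) throw ...` blocks are folded into a single `checkAttributeType(...)` call here and in the generated `.in` sources below. A hedged sketch of what such a helper can look like; the enum, the strict equality check, and the message format are simplified here, whereas the real helper checks convertibility rather than exact equality.

```cpp
#include <stdexcept>
#include <string>

enum class AttributeUnderlyingType { UInt64, Float64, String };

inline std::string toString(AttributeUnderlyingType type)
{
    switch (type)
    {
        case AttributeUnderlyingType::UInt64:  return "UInt64";
        case AttributeUnderlyingType::Float64: return "Float64";
        case AttributeUnderlyingType::String:  return "String";
    }
    return "Unknown";
}

// One place to produce the "type mismatch" error instead of copy-pasting the
// check into every getXXX() method of the dictionary.
inline void checkAttributeType(
    const std::string & dict_name,
    const std::string & attribute_name,
    AttributeUnderlyingType actual,
    AttributeUnderlyingType expected)
{
    if (actual != expected)
        throw std::runtime_error(
            dict_name + ": type mismatch: attribute " + attribute_name
            + " has type " + toString(actual) + ", expected " + toString(expected));
}

int main()
{
    checkAttributeType("my_dict", "region_id", AttributeUnderlyingType::UInt64, AttributeUnderlyingType::UInt64);  // ok
    // checkAttributeType("my_dict", "region_id", AttributeUnderlyingType::String, AttributeUnderlyingType::UInt64);  // would throw
    return 0;
}
```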

View File

@ -221,11 +221,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename DefaultGetter>
void getItemsNumber(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;

View File

@ -34,34 +34,6 @@ namespace ErrorCodes
extern const int TYPE_MISMATCH;
}
template <typename OutputType, typename DefaultGetter>
void CacheDictionary::getItemsNumber(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsNumberImpl<TYPE, OutputType>(attribute, ids, out, std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void CacheDictionary::getItemsNumberImpl(
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
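The removed `getItemsNumber` wrapper used an X-macro `DISPATCH` chain to map the runtime attribute type onto a template instantiation; the call sites now invoke `getItemsNumberImpl<TYPE, TYPE>` directly. A standalone sketch of that dispatch technique, with a shortened type list and illustrative names:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

enum class AttributeType { UInt64, Int64, Float64 };

// Typed worker: in the real code this reads values of the attribute's storage
// type and writes them into the requested output type.
template <typename AttributeT, typename OutputT>
void getItemsNumberImpl(const std::vector<AttributeT> & storage, std::vector<OutputT> & out)
{
    out.assign(storage.begin(), storage.end());
}

// Runtime-to-compile-time dispatch via an X-macro chain, as in the removed code.
template <typename OutputT>
void getItemsNumber(AttributeType type, std::vector<OutputT> & out)
{
    if (false)
    {
    }
#define DISPATCH(TYPE, CPP_TYPE) \
    else if (type == AttributeType::TYPE) \
    { \
        std::vector<CPP_TYPE> storage{1, 2, 3}; /* stand-in for the attribute's column */ \
        getItemsNumberImpl<CPP_TYPE, OutputT>(storage, out); \
    }
    DISPATCH(UInt64, uint64_t)
    DISPATCH(Int64, int64_t)
    DISPATCH(Float64, double)
#undef DISPATCH
    else
        throw std::logic_error("Unexpected type of attribute");
}

int main()
{
    std::vector<double> out;
    getItemsNumber<double>(AttributeType::Float64, out);
    std::cout << out.size() << " values\n";  // 3 values
    return 0;
}
```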

View File

@ -12,13 +12,11 @@ using TYPE = @NAME@;
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
const auto null_value = std::get<TYPE>(attribute.null_values);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
}
}

View File

@ -15,11 +15,9 @@ void CacheDictionary::get@NAME@(const std::string & attribute_name,
ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
}
}

View File

@ -12,11 +12,9 @@ using TYPE = @NAME@;
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE def, ResultArrayType<TYPE> & out) const
{
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return def; });
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return def; });
}
}

View File

@ -77,9 +77,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
@ -96,9 +94,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, key_columns, out, [&](const size_t row) { return def->getDataAt(row); });
}
@ -113,9 +109,7 @@ void ComplexKeyCacheDictionary::getString(
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
getItemsString(attribute, key_columns, out, [&](const size_t) { return StringRef{def}; });
}

View File

@ -256,34 +256,6 @@ private:
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
template <typename OutputType, typename DefaultGetter>
void
getItemsNumber(Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const
{
if (false)
{
}
#define DISPATCH(TYPE) \
else if (attribute.type == AttributeUnderlyingType::TYPE) \
getItemsNumberImpl<TYPE, OutputType>(attribute, key_columns, out, std::forward<DefaultGetter>(get_default));
DISPATCH(UInt8)
DISPATCH(UInt16)
DISPATCH(UInt32)
DISPATCH(UInt64)
DISPATCH(UInt128)
DISPATCH(Int8)
DISPATCH(Int16)
DISPATCH(Int32)
DISPATCH(Int64)
DISPATCH(Float32)
DISPATCH(Float64)
DISPATCH(Decimal32)
DISPATCH(Decimal64)
DISPATCH(Decimal128)
#undef DISPATCH
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
}
template <typename AttributeType, typename OutputType, typename DefaultGetter>
void getItemsNumberImpl(
Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const

View File

@ -13,12 +13,10 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name, co
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
const auto null_value = std::get<TYPE>(attribute.null_values);
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
}
}

View File

@ -18,10 +18,8 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name,
dict_struct.validateKeyTypes(key_types);
auto & attribute = getAttribute(attribute_name);
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
ErrorCodes::TYPE_MISMATCH};
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
}
}

Some files were not shown because too many files have changed in this diff.