mirror of
https://github.com/ClickHouse/ClickHouse.git
synced 2024-11-26 17:41:59 +00:00
Merge branch 'master' of github.com:yandex/ClickHouse into DOCAPI-4293
This commit is contained in:
commit
b6dbe984da
53
CHANGELOG.md
53
CHANGELOG.md
@ -1,8 +1,59 @@
|
||||
## ClickHouse release 19.6.2.11, 2019-05-13
|
||||
|
||||
### New Features
|
||||
* TTL expressions for columns and tables. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
|
||||
* Added support for `brotli` compression for HTTP responses (Accept-Encoding: br) [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
|
||||
* Added new function `isValidUTF8` for checking whether a set of bytes is correctly utf-8 encoded. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Add new load balancing policy `first_or_random` which sends queries to the first specified host and if it's inaccessible send queries to random hosts of shard. Useful for cross-replication topology setups. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
|
||||
|
||||
### Experimental Features
|
||||
* Add setting `index_granularity_bytes` (adaptive index granularity) for MergeTree* tables family. [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
|
||||
|
||||
### Improvements
|
||||
* Added support for non-constant and negative size and length arguments for function `substringUTF8`. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
|
||||
* Disable push-down to right table in left join, left table in right join, and both tables in full join. This fixes wrong JOIN results in some cases. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
|
||||
* `clickhouse-copier`: auto upload task configuration from `--task-file` option [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
|
||||
* Added typos handler for storage factory and table functions factory. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Support asterisks and qualified asterisks for multiple joins without subqueries [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Make missing column error message more user friendly. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
|
||||
### Performance Improvements
|
||||
* Significant speedup of ASOF JOIN [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
|
||||
|
||||
### Backward Incompatible Changes
|
||||
* HTTP header `Query-Id` was renamed to `X-ClickHouse-Query-Id` for consistency. [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
|
||||
|
||||
### Bug Fixes
|
||||
* Fixed potential null pointer dereference in `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
|
||||
* Fixed error on query with JOIN + ARRAY JOIN [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Fixed hanging on start of the server when a dictionary depends on another dictionary via a database with engine=Dictionary. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Partially fix distributed_product_mode = local. It's possible to allow columns of local tables in where/having/order by/... via table aliases. Throw exception if table does not have alias. There's not possible to access to the columns without table aliases yet. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Fix potentially wrong result for `SELECT DISTINCT` with `JOIN` [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
|
||||
### Build/Testing/Packaging Improvements
|
||||
* Fixed test failures when running clickhouse-server on different host [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
|
||||
* clickhouse-test: Disable color control sequences in non tty environment. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
|
||||
* clickhouse-test: Allow use any test database (remove `test.` qualification where it possible) [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
|
||||
* Fix ubsan errors [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Yandex LFAlloc was added to ClickHouse to allocate MarkCache and UncompressedCache data in different ways to catch segfaults more reliable [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Python util to help with backports and changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
|
||||
|
||||
|
||||
## ClickHouse release 19.5.4.22, 2019-05-13
|
||||
|
||||
### Bug fixes
|
||||
* Fixed possible crash in bitmap* functions [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
|
||||
* Fixed very rare data race condition that could happen when executing a query with UNION ALL involving at least two SELECTs from system.columns, system.tables, system.parts, system.parts_tables or tables of Merge family and performing ALTER of columns of the related tables concurrently. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
|
||||
* Fixed error `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. This error happened if LowCardinality column was the part of primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
|
||||
* Modification of retention function: If a row satisfies both the first and NTH condition, only the first satisfied condition is added to the data state. Now all conditions that satisfy in a row of data are added to the data state. [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
|
||||
|
||||
|
||||
## ClickHouse release 19.5.3.8, 2019-04-18
|
||||
|
||||
### Bug fixes
|
||||
* Fixed type of setting `max_partitions_per_insert_block` from boolean to UInt64. [#5028](https://github.com/yandex/ClickHouse/pull/5028) ([Mohammad Hossein Sekhavat](https://github.com/mhsekhavat))
|
||||
|
||||
|
||||
## ClickHouse release 19.5.2.6, 2019-04-15
|
||||
|
||||
### New Features
|
||||
@ -294,7 +345,7 @@
|
||||
* Added support of `Nullable` types in `mysql` table function. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
|
||||
* Support for arbitrary constant expressions in `LIMIT` clause. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
|
||||
* Added `topKWeighted` aggregate function that takes additional argument with (unsigned integer) weight. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
|
||||
* `StorageJoin` now supports `join_overwrite` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird)
|
||||
* `StorageJoin` now supports `join_any_take_last_row` setting that allows overwriting existing values of the same key. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird)
|
||||
* Added function `toStartOfInterval`. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Added `RowBinaryWithNamesAndTypes` format. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))
|
||||
* Added `IPv4` and `IPv6` data types. More effective implementations of `IPv*` functions. [#3669](https://github.com/yandex/ClickHouse/pull/3669) ([Vasily Nemkov](https://github.com/Enmk))
|
||||
|
@ -1,3 +1,53 @@
|
||||
## ClickHouse release 19.6.2.11, 2019-05-13
|
||||
|
||||
### Новые возможности
|
||||
* TTL выражения, позволяющие настроить время жизни и автоматическую очистку данных в таблице или в отдельных её столбцах. [#4212](https://github.com/yandex/ClickHouse/pull/4212) ([Anton Popov](https://github.com/CurtizJ))
|
||||
* Добавлена поддержка алгоритма сжатия `brotli` в HTTP ответах (`Accept-Encoding: br`). Для тела POST запросов, эта возможность уже существовала. [#4388](https://github.com/yandex/ClickHouse/pull/4388) ([Mikhail](https://github.com/fandyushin))
|
||||
* Добавлена функция `isValidUTF8` для проверки, содержит ли строка валидные данные в кодировке UTF-8. [#4934](https://github.com/yandex/ClickHouse/pull/4934) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Добавлены новое правило балансировки (`load_balancing`) `first_or_random` по которому запросы посылаются на первый заданый хост и если он недоступен - на случайные хосты шарда. Полезно для топологий с кросс-репликацией. [#5012](https://github.com/yandex/ClickHouse/pull/5012) ([nvartolomei](https://github.com/nvartolomei))
|
||||
|
||||
### Эксперемннтальные возможности
|
||||
* Добавлена настройка `index_granularity_bytes` (адаптивная гранулярность индекса) для таблиц семейства MergeTree* . [#4826](https://github.com/yandex/ClickHouse/pull/4826) ([alesapin](https://github.com/alesapin))
|
||||
|
||||
### Улучшения
|
||||
* Добавлена поддержка для не константных и отрицательных значений аргументов смещения и длины для функции `substringUTF8`. [#4989](https://github.com/yandex/ClickHouse/pull/4989) ([alexey-milovidov](https://github.com/alexey-milovidov))
|
||||
* Отключение push-down в правую таблицы в left join, левую таблицу в right join, и в обе таблицы в full join. Это исправляет неправильные JOIN результаты в некоторых случаях. [#4846](https://github.com/yandex/ClickHouse/pull/4846) ([Ivan](https://github.com/abyss7))
|
||||
* `clickhouse-copier`: Автоматическая загрузка конфигурации задачи в zookeeper из `--task-file` опции [#4876](https://github.com/yandex/ClickHouse/pull/4876) ([proller](https://github.com/proller))
|
||||
* Добавлены подсказки с учётом опечаток для имён движков таблиц и табличных функций. [#4891](https://github.com/yandex/ClickHouse/pull/4891) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Поддержка выражений `select *` и `select tablename.*` для множественных join без подзапросов [#4898](https://github.com/yandex/ClickHouse/pull/4898) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Сообщения об ошибках об отсутствующих столбцах стали более понятными. [#4915](https://github.com/yandex/ClickHouse/pull/4915) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
|
||||
### Улучшение производительности
|
||||
* Существенное ускорение ASOF JOIN [#4924](https://github.com/yandex/ClickHouse/pull/4924) ([Martijn Bakker](https://github.com/Gladdy))
|
||||
|
||||
### Обратно несовместимые изменения
|
||||
* HTTP заголовок `Query-Id` переименован в `X-ClickHouse-Query-Id` для соответствия. [#4972](https://github.com/yandex/ClickHouse/pull/4972) ([Mikhail](https://github.com/fandyushin))
|
||||
|
||||
### Исправления ошибок
|
||||
* Исправлены возможные разыменования нулевого указателя в `clickhouse-copier`. [#4900](https://github.com/yandex/ClickHouse/pull/4900) ([proller](https://github.com/proller))
|
||||
* Исправлены ошибки в запросах с JOIN + ARRAY JOIN [#4938](https://github.com/yandex/ClickHouse/pull/4938) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Исправлено зависание на старте сервера если внешний словарь зависит от другого словаря через использование таблицы из БД с движком `Dictionary`. [#4962](https://github.com/yandex/ClickHouse/pull/4962) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* При использовании `distributed_product_mode = 'local'` корректно работает использование столбцов локальных таблиц в where/having/order by/... через табличные алиасы. Выкидывает исключение если таблица не имеет алиас. Доступ к столбцам без алиасов пока не возможен. [#4986](https://github.com/yandex/ClickHouse/pull/4986) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
* Исправлен потенциально некорректный результат для `SELECT DISTINCT` с `JOIN` [#5001](https://github.com/yandex/ClickHouse/pull/5001) ([Artem Zuikov](https://github.com/4ertus2))
|
||||
|
||||
### Улучшения сборки/тестирования/пакетирования
|
||||
* Исправлена неработоспособность тестов, если `clickhouse-server` запущен на удалённом хосте [#4713](https://github.com/yandex/ClickHouse/pull/4713) ([Vasily Nemkov](https://github.com/Enmk))
|
||||
* `clickhouse-test`: Отключена раскраска результата, если команда запускается не в терминале. [#4937](https://github.com/yandex/ClickHouse/pull/4937) ([alesapin](https://github.com/alesapin))
|
||||
* `clickhouse-test`: Возможность использования не только базы данных test [#5008](https://github.com/yandex/ClickHouse/pull/5008) ([proller](https://github.com/proller))
|
||||
* Исправлены ошибки при запуске тестов под UBSan [#5037](https://github.com/yandex/ClickHouse/pull/5037) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Добавлен аллокатор Yandex LFAlloc для аллоцирования MarkCache и UncompressedCache данных разными способами для более надежного отлавливания проездов по памяти [#4995](https://github.com/yandex/ClickHouse/pull/4995) ([Danila Kutenin](https://github.com/danlark1))
|
||||
* Утилита для упрощения бэкпортирования изменений в старые релизы и составления changelogs. [#4949](https://github.com/yandex/ClickHouse/pull/4949) ([Ivan](https://github.com/abyss7))
|
||||
|
||||
|
||||
## ClickHouse release 19.5.4.22, 2019-05-13
|
||||
|
||||
### Исправления ошибок
|
||||
* Исправлены возможные падения в bitmap* функциях [#5220](https://github.com/yandex/ClickHouse/pull/5220) [#5228](https://github.com/yandex/ClickHouse/pull/5228) ([Andy Yang](https://github.com/andyyzh))
|
||||
* Исправлен очень редкий data race condition который мог произойти при выполнении запроса с UNION ALL включающего минимум два SELECT из таблиц system.columns, system.tables, system.parts, system.parts_tables или таблиц семейства Merge и одновременно выполняющихся запросов ALTER столбцов соответствующих таблиц. [#5189](https://github.com/yandex/ClickHouse/pull/5189) ([alexey-milovidov](https://github.com/alexey-milovidov))
|
||||
* Исправлена ошибка `Set for IN is not created yet in case of using single LowCardinality column in the left part of IN`. Эта ошибка возникала когда LowCardinality столбец была частью primary key. #5031 [#5154](https://github.com/yandex/ClickHouse/pull/5154) ([Nikolai Kochetov](https://github.com/KochetovNicolai))
|
||||
* Исправление функции retention: только первое соответствующее условие добавлялось в состояние данных. Сейчас все условия которые удовлетворяют в строке данных добавляются в состояние. [#5119](https://github.com/yandex/ClickHouse/pull/5119) ([小路](https://github.com/nicelulu))
|
||||
|
||||
|
||||
## ClickHouse release 19.5.3.8, 2019-04-18
|
||||
|
||||
### Исправления ошибок
|
||||
@ -286,7 +336,7 @@
|
||||
* Добавлена поддержка `Nullable` типов в табличной функции `mysql`. [#4198](https://github.com/yandex/ClickHouse/pull/4198) ([Emmanuel Donin de Rosière](https://github.com/edonin))
|
||||
* Добавлена поддержка произвольных константных выражений в секции `LIMIT`. [#4246](https://github.com/yandex/ClickHouse/pull/4246) ([k3box](https://github.com/k3box))
|
||||
* Добавлена агрегатная функция `topKWeighted` - вариант `topK`, позволяющий задавать (целый неотрицательный) вес добавляемого значения. [#4245](https://github.com/yandex/ClickHouse/pull/4245) ([Andrew Golman](https://github.com/andrewgolman))
|
||||
* Движок `Join` теперь поддерживает настройку `join_overwrite`, которая позволяет перезаписывать значения для существующих ключей. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
|
||||
* Движок `Join` теперь поддерживает настройку `join_any_take_last_row`, которая позволяет перезаписывать значения для существующих ключей. [#3973](https://github.com/yandex/ClickHouse/pull/3973) ([Amos Bird](https://github.com/amosbird))
|
||||
* Добавлена функция `toStartOfInterval`. [#4304](https://github.com/yandex/ClickHouse/pull/4304) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Добавлена функция `toStartOfTenMinutes`. [#4298](https://github.com/yandex/ClickHouse/pull/4298) ([Vitaly Baranov](https://github.com/vitlibar))
|
||||
* Добавлен формат `RowBinaryWithNamesAndTypes`. [#4200](https://github.com/yandex/ClickHouse/pull/4200) ([Oleg V. Kozlyuk](https://github.com/DarkWanderer))
|
||||
|
@ -49,19 +49,20 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
|
||||
endif ()
|
||||
|
||||
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7)
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if (WEVERYTHING)
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-return-std-move-in-c++11")
|
||||
endif ()
|
||||
endif ()
|
||||
|
||||
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8)
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wshadow-field -Wstring-plus-int -Wempty-init-stmt")
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wshadow-field -Wstring-plus-int")
|
||||
if(NOT APPLE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra-semi-stmt -Wempty-init-stmt")
|
||||
endif()
|
||||
endif ()
|
||||
|
||||
if (NOT CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9)
|
||||
if (WEVERYTHING)
|
||||
if (WEVERYTHING AND NOT APPLE)
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-ctad-maybe-unsupported")
|
||||
endif ()
|
||||
endif ()
|
||||
|
@ -1,11 +1,11 @@
|
||||
# This strings autochanged from release_lib.sh:
|
||||
set(VERSION_REVISION 54420)
|
||||
set(VERSION_REVISION 54421)
|
||||
set(VERSION_MAJOR 19)
|
||||
set(VERSION_MINOR 8)
|
||||
set(VERSION_MINOR 9)
|
||||
set(VERSION_PATCH 1)
|
||||
set(VERSION_GITHASH a76e504f45ff4a74e8c492bd269f022352d5f6d9)
|
||||
set(VERSION_DESCRIBE v19.8.1.1-testing)
|
||||
set(VERSION_STRING 19.8.1.1)
|
||||
set(VERSION_GITHASH 0c2aa460651a462f14efc7e995840a244531d373)
|
||||
set(VERSION_DESCRIBE v19.9.1.1-testing)
|
||||
set(VERSION_STRING 19.9.1.1)
|
||||
# end of autochange
|
||||
|
||||
set(VERSION_EXTRA "" CACHE STRING "")
|
||||
|
@ -325,8 +325,8 @@ private:
|
||||
double seconds = watch.elapsedSeconds();
|
||||
|
||||
std::lock_guard lock(mutex);
|
||||
info_per_interval.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
|
||||
info_total.add(seconds, progress.rows, progress.bytes, info.rows, info.bytes);
|
||||
info_per_interval.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
|
||||
info_total.add(seconds, progress.read_rows, progress.read_bytes, info.rows, info.bytes);
|
||||
}
|
||||
|
||||
|
||||
|
@ -435,7 +435,7 @@ private:
|
||||
#if USE_READLINE
|
||||
int res = read_history(history_file.c_str());
|
||||
if (res)
|
||||
throwFromErrno("Cannot read history from file " + history_file, ErrorCodes::CANNOT_READ_HISTORY);
|
||||
std::cerr << "Cannot read history from file " + history_file + ": "+ errnoToString(ErrorCodes::CANNOT_READ_HISTORY);
|
||||
#endif
|
||||
}
|
||||
else /// Create history file.
|
||||
@ -612,7 +612,7 @@ private:
|
||||
|
||||
#if USE_READLINE && HAVE_READLINE_HISTORY
|
||||
if (!history_file.empty() && append_history(1, history_file.c_str()))
|
||||
throwFromErrno("Cannot append history to file " + history_file, ErrorCodes::CANNOT_APPEND_HISTORY);
|
||||
std::cerr << "Cannot append history to file " + history_file + ": " + errnoToString(ErrorCodes::CANNOT_APPEND_HISTORY);
|
||||
#endif
|
||||
|
||||
prev_input = input;
|
||||
@ -866,7 +866,7 @@ private:
|
||||
std::cout << std::endl
|
||||
<< processed_rows << " rows in set. Elapsed: " << watch.elapsedSeconds() << " sec. ";
|
||||
|
||||
if (progress.rows >= 1000)
|
||||
if (progress.read_rows >= 1000)
|
||||
writeFinalProgress();
|
||||
|
||||
std::cout << std::endl << std::endl;
|
||||
@ -1420,23 +1420,23 @@ private:
|
||||
<< " Progress: ";
|
||||
|
||||
message
|
||||
<< formatReadableQuantity(progress.rows) << " rows, "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
|
||||
<< formatReadableQuantity(progress.read_rows) << " rows, "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
|
||||
|
||||
size_t elapsed_ns = watch.elapsed();
|
||||
if (elapsed_ns)
|
||||
message << " ("
|
||||
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
|
||||
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
|
||||
else
|
||||
message << ". ";
|
||||
|
||||
written_progress_chars = message.count() - prefix_size - (increment % 8 == 7 ? 10 : 13); /// Don't count invisible output (escape sequences).
|
||||
|
||||
/// If the approximate number of rows to process is known, we can display a progress bar and percentage.
|
||||
if (progress.total_rows > 0)
|
||||
if (progress.total_rows_to_read > 0)
|
||||
{
|
||||
size_t total_rows_corrected = std::max(progress.rows, progress.total_rows);
|
||||
size_t total_rows_corrected = std::max(progress.read_rows, progress.total_rows_to_read);
|
||||
|
||||
/// To avoid flicker, display progress bar only if .5 seconds have passed since query execution start
|
||||
/// and the query is less than halfway done.
|
||||
@ -1444,7 +1444,7 @@ private:
|
||||
if (elapsed_ns > 500000000)
|
||||
{
|
||||
/// Trigger to start displaying progress bar. If query is mostly done, don't display it.
|
||||
if (progress.rows * 2 < total_rows_corrected)
|
||||
if (progress.read_rows * 2 < total_rows_corrected)
|
||||
show_progress_bar = true;
|
||||
|
||||
if (show_progress_bar)
|
||||
@ -1452,7 +1452,7 @@ private:
|
||||
ssize_t width_of_progress_bar = static_cast<ssize_t>(terminal_size.ws_col) - written_progress_chars - strlen(" 99%");
|
||||
if (width_of_progress_bar > 0)
|
||||
{
|
||||
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.rows, 0, total_rows_corrected, width_of_progress_bar));
|
||||
std::string bar = UnicodeBar::render(UnicodeBar::getWidth(progress.read_rows, 0, total_rows_corrected, width_of_progress_bar));
|
||||
message << "\033[0;32m" << bar << "\033[0m";
|
||||
if (width_of_progress_bar > static_cast<ssize_t>(bar.size() / UNICODE_BAR_CHAR_SIZE))
|
||||
message << std::string(width_of_progress_bar - bar.size() / UNICODE_BAR_CHAR_SIZE, ' ');
|
||||
@ -1461,7 +1461,7 @@ private:
|
||||
}
|
||||
|
||||
/// Underestimate percentage a bit to avoid displaying 100%.
|
||||
message << ' ' << (99 * progress.rows / total_rows_corrected) << '%';
|
||||
message << ' ' << (99 * progress.read_rows / total_rows_corrected) << '%';
|
||||
}
|
||||
|
||||
message << ENABLE_LINE_WRAPPING;
|
||||
@ -1474,14 +1474,14 @@ private:
|
||||
void writeFinalProgress()
|
||||
{
|
||||
std::cout << "Processed "
|
||||
<< formatReadableQuantity(progress.rows) << " rows, "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.bytes);
|
||||
<< formatReadableQuantity(progress.read_rows) << " rows, "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes);
|
||||
|
||||
size_t elapsed_ns = watch.elapsed();
|
||||
if (elapsed_ns)
|
||||
std::cout << " ("
|
||||
<< formatReadableQuantity(progress.rows * 1000000000.0 / elapsed_ns) << " rows/s., "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
|
||||
<< formatReadableQuantity(progress.read_rows * 1000000000.0 / elapsed_ns) << " rows/s., "
|
||||
<< formatReadableSizeWithDecimalSuffix(progress.read_bytes * 1000000000.0 / elapsed_ns) << "/s.) ";
|
||||
else
|
||||
std::cout << ". ";
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include "PerformanceTest.h"
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <Common/CpuId.h>
|
||||
#include <common/getMemoryAmount.h>
|
||||
#include <IO/ReadBufferFromFile.h>
|
||||
#include <IO/ReadHelpers.h>
|
||||
@ -71,6 +72,7 @@ bool PerformanceTest::checkPreconditions() const
|
||||
Strings preconditions;
|
||||
config->keys("preconditions", preconditions);
|
||||
size_t table_precondition_index = 0;
|
||||
size_t cpu_precondition_index = 0;
|
||||
|
||||
for (const std::string & precondition : preconditions)
|
||||
{
|
||||
@ -136,6 +138,30 @@ bool PerformanceTest::checkPreconditions() const
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (precondition == "cpu")
|
||||
{
|
||||
std::string precondition_key = "preconditions.cpu[" + std::to_string(cpu_precondition_index++) + "]";
|
||||
std::string flag_to_check = config->getString(precondition_key);
|
||||
|
||||
#define CHECK_CPU_PRECONDITION(OP) \
|
||||
if (flag_to_check == #OP) \
|
||||
{ \
|
||||
if (!Cpu::CpuFlagsCache::have_##OP) \
|
||||
{ \
|
||||
LOG_WARNING(log, "CPU doesn't support " << #OP); \
|
||||
return false; \
|
||||
} \
|
||||
} else
|
||||
|
||||
CPU_ID_ENUMERATE(CHECK_CPU_PRECONDITION)
|
||||
{
|
||||
LOG_WARNING(log, "CPU doesn't support " << flag_to_check);
|
||||
return false;
|
||||
}
|
||||
|
||||
#undef CHECK_CPU_PRECONDITION
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@ -13,7 +13,7 @@ void checkFulfilledConditionsAndUpdate(
|
||||
TestStats & statistics, TestStopConditions & stop_conditions,
|
||||
InterruptListener & interrupt_listener)
|
||||
{
|
||||
statistics.add(progress.rows, progress.bytes);
|
||||
statistics.add(progress.read_rows, progress.read_bytes);
|
||||
|
||||
stop_conditions.reportRowsRead(statistics.total_rows_read);
|
||||
stop_conditions.reportBytesReadUncompressed(statistics.total_bytes_read);
|
||||
|
@ -8,6 +8,8 @@ set(CLICKHOUSE_SERVER_SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/RootRequestHandler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/Server.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/TCPHandler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandler.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/MySQLHandlerFactory.cpp
|
||||
)
|
||||
|
||||
set(CLICKHOUSE_SERVER_LINK PRIVATE clickhouse_dictionaries clickhouse_common_io PUBLIC daemon PRIVATE clickhouse_storages_system clickhouse_functions clickhouse_aggregate_functions clickhouse_table_functions ${Poco_Net_LIBRARY})
|
||||
|
@ -453,30 +453,6 @@ void HTTPHandler::processQuery(
|
||||
return false;
|
||||
};
|
||||
|
||||
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
|
||||
std::string full_query;
|
||||
|
||||
/// Support for "external data for query processing".
|
||||
if (startsWith(request.getContentType().data(), "multipart/form-data"))
|
||||
{
|
||||
ExternalTablesHandler handler(context, params);
|
||||
params.load(request, istr, handler);
|
||||
|
||||
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
|
||||
reserved_param_suffixes.emplace_back("_format");
|
||||
reserved_param_suffixes.emplace_back("_types");
|
||||
reserved_param_suffixes.emplace_back("_structure");
|
||||
|
||||
/// Params are of both form params POST and uri (GET params)
|
||||
for (const auto & it : params)
|
||||
if (it.first == "query")
|
||||
full_query += it.second;
|
||||
|
||||
in = std::make_unique<ReadBufferFromString>(full_query);
|
||||
}
|
||||
else
|
||||
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
|
||||
|
||||
/// Settings can be overridden in the query.
|
||||
/// Some parameters (database, default_format, everything used in the code above) do not
|
||||
/// belong to the Settings class.
|
||||
@ -497,30 +473,63 @@ void HTTPHandler::processQuery(
|
||||
settings.readonly = 2;
|
||||
}
|
||||
|
||||
SettingsChanges settings_changes;
|
||||
for (auto it = params.begin(); it != params.end(); ++it)
|
||||
bool isExternalData = startsWith(request.getContentType().data(), "multipart/form-data");
|
||||
|
||||
if (isExternalData)
|
||||
{
|
||||
if (it->first == "database")
|
||||
/// Skip unneeded parameters to avoid confusing them later with context settings or query parameters.
|
||||
reserved_param_suffixes.reserve(3);
|
||||
/// It is a bug and ambiguity with `date_time_input_format` and `low_cardinality_allow_in_native_format` formats/settings.
|
||||
reserved_param_suffixes.emplace_back("_format");
|
||||
reserved_param_suffixes.emplace_back("_types");
|
||||
reserved_param_suffixes.emplace_back("_structure");
|
||||
}
|
||||
|
||||
SettingsChanges settings_changes;
|
||||
for (const auto & [key, value] : params)
|
||||
{
|
||||
if (key == "database")
|
||||
{
|
||||
context.setCurrentDatabase(it->second);
|
||||
context.setCurrentDatabase(value);
|
||||
}
|
||||
else if (it->first == "default_format")
|
||||
else if (key == "default_format")
|
||||
{
|
||||
context.setDefaultFormat(it->second);
|
||||
context.setDefaultFormat(value);
|
||||
}
|
||||
else if (param_could_be_skipped(it->first))
|
||||
else if (param_could_be_skipped(key))
|
||||
{
|
||||
}
|
||||
else
|
||||
{
|
||||
/// All other query parameters are treated as settings.
|
||||
settings_changes.push_back({it->first, it->second});
|
||||
settings_changes.push_back({key, value});
|
||||
}
|
||||
}
|
||||
|
||||
/// For external data we also want settings
|
||||
context.checkSettingsConstraints(settings_changes);
|
||||
context.applySettingsChanges(settings_changes);
|
||||
|
||||
/// Used in case of POST request with form-data, but it isn't expected to be deleted after that scope.
|
||||
std::string full_query;
|
||||
|
||||
/// Support for "external data for query processing".
|
||||
if (isExternalData)
|
||||
{
|
||||
ExternalTablesHandler handler(context, params);
|
||||
params.load(request, istr, handler);
|
||||
|
||||
/// Params are of both form params POST and uri (GET params)
|
||||
for (const auto & it : params)
|
||||
if (it.first == "query")
|
||||
full_query += it.second;
|
||||
|
||||
in = std::make_unique<ReadBufferFromString>(full_query);
|
||||
}
|
||||
else
|
||||
in = std::make_unique<ConcatReadBuffer>(*in_param, *in_post_maybe_compressed);
|
||||
|
||||
|
||||
/// HTTP response compression is turned on only if the client signalled that they support it
|
||||
/// (using Accept-Encoding header) and 'enable_http_compression' setting is turned on.
|
||||
used_output.out->setCompression(client_supports_http_compression && settings.enable_http_compression);
|
||||
|
@ -1,17 +1,16 @@
|
||||
#include "InterserverIOHTTPHandler.h"
|
||||
|
||||
#include <Poco/Net/HTTPBasicCredentials.h>
|
||||
#include <Poco/Net/HTTPServerRequest.h>
|
||||
#include <Poco/Net/HTTPServerResponse.h>
|
||||
|
||||
#include <common/logger_useful.h>
|
||||
|
||||
#include <Common/HTMLForm.h>
|
||||
#include <Common/setThreadName.h>
|
||||
#include <Compression/CompressedWriteBuffer.h>
|
||||
#include <IO/ReadBufferFromIStream.h>
|
||||
#include <IO/WriteBufferFromHTTPServerResponse.h>
|
||||
#include <Interpreters/InterserverIOHandler.h>
|
||||
|
||||
#include "InterserverIOHTTPHandler.h"
|
||||
#include "IServer.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -50,7 +49,7 @@ std::pair<String, bool> InterserverIOHTTPHandler::checkAuthentication(Poco::Net:
|
||||
return {"", true};
|
||||
}
|
||||
|
||||
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response)
|
||||
void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output)
|
||||
{
|
||||
HTMLForm params(request);
|
||||
|
||||
@ -61,24 +60,17 @@ void InterserverIOHTTPHandler::processQuery(Poco::Net::HTTPServerRequest & reque
|
||||
|
||||
ReadBufferFromIStream body(request.stream());
|
||||
|
||||
const auto & config = server.config();
|
||||
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
|
||||
|
||||
WriteBufferFromHTTPServerResponse out(request, response, keep_alive_timeout);
|
||||
|
||||
auto endpoint = server.context().getInterserverIOHandler().getEndpoint(endpoint_name);
|
||||
|
||||
if (compress)
|
||||
{
|
||||
CompressedWriteBuffer compressed_out(out);
|
||||
CompressedWriteBuffer compressed_out(*used_output.out.get());
|
||||
endpoint->processQuery(params, body, compressed_out, response);
|
||||
}
|
||||
else
|
||||
{
|
||||
endpoint->processQuery(params, body, out, response);
|
||||
endpoint->processQuery(params, body, *used_output.out.get(), response);
|
||||
}
|
||||
|
||||
out.finalize();
|
||||
}
|
||||
|
||||
|
||||
@ -90,30 +82,30 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
|
||||
if (request.getVersion() == Poco::Net::HTTPServerRequest::HTTP_1_1)
|
||||
response.setChunkedTransferEncoding(true);
|
||||
|
||||
Output used_output;
|
||||
const auto & config = server.config();
|
||||
unsigned keep_alive_timeout = config.getUInt("keep_alive_timeout", 10);
|
||||
used_output.out = std::make_shared<WriteBufferFromHTTPServerResponse>(request, response, keep_alive_timeout);
|
||||
|
||||
try
|
||||
{
|
||||
if (auto [msg, success] = checkAuthentication(request); success)
|
||||
if (auto [message, success] = checkAuthentication(request); success)
|
||||
{
|
||||
processQuery(request, response);
|
||||
processQuery(request, response, used_output);
|
||||
LOG_INFO(log, "Done processing query");
|
||||
}
|
||||
else
|
||||
{
|
||||
response.setStatusAndReason(Poco::Net::HTTPServerResponse::HTTP_UNAUTHORIZED);
|
||||
if (!response.sent())
|
||||
response.send() << msg << std::endl;
|
||||
writeString(message, *used_output.out);
|
||||
LOG_WARNING(log, "Query processing failed request: '" << request.getURI() << "' authentification failed");
|
||||
}
|
||||
}
|
||||
catch (Exception & e)
|
||||
{
|
||||
|
||||
if (e.code() == ErrorCodes::TOO_MANY_SIMULTANEOUS_QUERIES)
|
||||
{
|
||||
if (!response.sent())
|
||||
response.send();
|
||||
return;
|
||||
}
|
||||
|
||||
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
|
||||
|
||||
@ -122,7 +114,7 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
|
||||
|
||||
std::string message = getCurrentExceptionMessage(is_real_error);
|
||||
if (!response.sent())
|
||||
response.send() << message << std::endl;
|
||||
writeString(message, *used_output.out);
|
||||
|
||||
if (is_real_error)
|
||||
LOG_ERROR(log, message);
|
||||
@ -134,7 +126,8 @@ void InterserverIOHTTPHandler::handleRequest(Poco::Net::HTTPServerRequest & requ
|
||||
response.setStatusAndReason(Poco::Net::HTTPResponse::HTTP_INTERNAL_SERVER_ERROR);
|
||||
std::string message = getCurrentExceptionMessage(false);
|
||||
if (!response.sent())
|
||||
response.send() << message << std::endl;
|
||||
writeString(message, *used_output.out);
|
||||
|
||||
LOG_ERROR(log, message);
|
||||
}
|
||||
}
|
||||
|
@ -1,12 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <Poco/Logger.h>
|
||||
#include <Poco/Net/HTTPRequestHandler.h>
|
||||
|
||||
#include <Common/CurrentMetrics.h>
|
||||
|
||||
#include "IServer.h"
|
||||
|
||||
|
||||
namespace CurrentMetrics
|
||||
{
|
||||
@ -16,6 +14,9 @@ namespace CurrentMetrics
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class IServer;
|
||||
class WriteBufferFromHTTPServerResponse;
|
||||
|
||||
class InterserverIOHTTPHandler : public Poco::Net::HTTPRequestHandler
|
||||
{
|
||||
public:
|
||||
@ -28,12 +29,17 @@ public:
|
||||
void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override;
|
||||
|
||||
private:
|
||||
struct Output
|
||||
{
|
||||
std::shared_ptr<WriteBufferFromHTTPServerResponse> out;
|
||||
};
|
||||
|
||||
IServer & server;
|
||||
Poco::Logger * log;
|
||||
|
||||
CurrentMetrics::Increment metric_increment{CurrentMetrics::InterserverConnection};
|
||||
|
||||
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response);
|
||||
void processQuery(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response, Output & used_output);
|
||||
|
||||
std::pair<String, bool> checkAuthentication(Poco::Net::HTTPServerRequest & request) const;
|
||||
};
|
||||
|
370
dbms/programs/server/MySQLHandler.cpp
Normal file
370
dbms/programs/server/MySQLHandler.cpp
Normal file
@ -0,0 +1,370 @@
|
||||
#include <DataStreams/copyData.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <Interpreters/executeQuery.h>
|
||||
#include <Storages/IStorage.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <Core/NamesAndTypes.h>
|
||||
#include <Columns/ColumnVector.h>
|
||||
#include <Common/config_version.h>
|
||||
#include <Common/NetException.h>
|
||||
#include <Common/OpenSSLHelpers.h>
|
||||
#include <Poco/Crypto/RSAKey.h>
|
||||
#include <Poco/Crypto/CipherFactory.h>
|
||||
#include <Poco/Net/SecureStreamSocket.h>
|
||||
#include <Poco/Net/SSLManager.h>
|
||||
#include "MySQLHandler.h"
|
||||
#include <limits>
|
||||
#include <ext/scope_guard.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
using namespace MySQLProtocol;
|
||||
using Poco::Net::SecureStreamSocket;
|
||||
using Poco::Net::SSLManager;
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
|
||||
extern const int OPENSSL_ERROR;
|
||||
}
|
||||
|
||||
MySQLHandler::MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id)
|
||||
: Poco::Net::TCPServerConnection(socket_)
|
||||
, server(server_)
|
||||
, log(&Poco::Logger::get("MySQLHandler"))
|
||||
, connection_context(server.context())
|
||||
, connection_id(connection_id)
|
||||
, public_key(public_key)
|
||||
, private_key(private_key)
|
||||
{
|
||||
server_capability_flags = CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONNECTION | CLIENT_PLUGIN_AUTH | CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA | CLIENT_CONNECT_WITH_DB | CLIENT_DEPRECATE_EOF;
|
||||
if (ssl_enabled)
|
||||
server_capability_flags |= CLIENT_SSL;
|
||||
}
|
||||
|
||||
void MySQLHandler::run()
|
||||
{
|
||||
connection_context = server.context();
|
||||
connection_context.setDefaultFormat("MySQLWire");
|
||||
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(socket());
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(socket());
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
|
||||
|
||||
try
|
||||
{
|
||||
String scramble = generateScramble();
|
||||
|
||||
/** Native authentication sent 20 bytes + '\0' character = 21 bytes.
|
||||
* This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin.
|
||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
|
||||
*/
|
||||
Handshake handshake(server_capability_flags, connection_id, VERSION_STRING + String("-") + VERSION_NAME, scramble + '\0');
|
||||
packet_sender->sendPacket<Handshake>(handshake, true);
|
||||
|
||||
LOG_TRACE(log, "Sent handshake");
|
||||
|
||||
HandshakeResponse handshake_response = finishHandshake();
|
||||
connection_context.client_capabilities = handshake_response.capability_flags;
|
||||
if (handshake_response.max_packet_size)
|
||||
connection_context.max_packet_size = handshake_response.max_packet_size;
|
||||
if (!connection_context.max_packet_size)
|
||||
connection_context.max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
LOG_DEBUG(log, "Capabilities: " << handshake_response.capability_flags
|
||||
<< "\nmax_packet_size: "
|
||||
<< handshake_response.max_packet_size
|
||||
<< "\ncharacter_set: "
|
||||
<< handshake_response.character_set
|
||||
<< "\nuser: "
|
||||
<< handshake_response.username
|
||||
<< "\nauth_response length: "
|
||||
<< handshake_response.auth_response.length()
|
||||
<< "\nauth_response: "
|
||||
<< handshake_response.auth_response
|
||||
<< "\ndatabase: "
|
||||
<< handshake_response.database
|
||||
<< "\nauth_plugin_name: "
|
||||
<< handshake_response.auth_plugin_name);
|
||||
|
||||
client_capability_flags = handshake_response.capability_flags;
|
||||
if (!(client_capability_flags & CLIENT_PROTOCOL_41))
|
||||
throw Exception("Required capability: CLIENT_PROTOCOL_41.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
||||
if (!(client_capability_flags & CLIENT_PLUGIN_AUTH))
|
||||
throw Exception("Required capability: CLIENT_PLUGIN_AUTH.", ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
||||
|
||||
authenticate(handshake_response, scramble);
|
||||
OK_Packet ok_packet(0, handshake_response.capability_flags, 0, 0, 0);
|
||||
packet_sender->sendPacket(ok_packet, true);
|
||||
|
||||
while (true)
|
||||
{
|
||||
packet_sender->resetSequenceId();
|
||||
String payload = packet_sender->receivePacketPayload();
|
||||
int command = payload[0];
|
||||
LOG_DEBUG(log, "Received command: " << std::to_string(command) << ". Connection id: " << connection_id << ".");
|
||||
try
|
||||
{
|
||||
switch (command)
|
||||
{
|
||||
case COM_QUIT:
|
||||
return;
|
||||
case COM_INIT_DB:
|
||||
comInitDB(payload);
|
||||
break;
|
||||
case COM_QUERY:
|
||||
comQuery(payload);
|
||||
break;
|
||||
case COM_FIELD_LIST:
|
||||
comFieldList(payload);
|
||||
break;
|
||||
case COM_PING:
|
||||
comPing();
|
||||
break;
|
||||
default:
|
||||
throw Exception(Poco::format("Command %d is not implemented.", command), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
}
|
||||
catch (const NetException & exc)
|
||||
{
|
||||
log->log(exc);
|
||||
throw;
|
||||
}
|
||||
catch (const Exception & exc)
|
||||
{
|
||||
log->log(exc);
|
||||
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Poco::Exception & exc)
|
||||
{
|
||||
log->log(exc);
|
||||
}
|
||||
}
|
||||
|
||||
/** Reads 3 bytes, finds out whether it is SSLRequest or HandshakeResponse packet, starts secure connection, if it is SSLRequest.
|
||||
* Reading is performed from socket instead of ReadBuffer to prevent reading part of SSL handshake.
|
||||
* If we read it from socket, it will be impossible to start SSL connection using Poco. Size of SSLRequest packet payload is 32 bytes, thus we can read at most 36 bytes.
|
||||
*/
|
||||
MySQLProtocol::HandshakeResponse MySQLHandler::finishHandshake()
|
||||
{
|
||||
HandshakeResponse packet;
|
||||
size_t packet_size = PACKET_HEADER_SIZE + SSL_REQUEST_PAYLOAD_SIZE;
|
||||
|
||||
/// Buffer for SSLRequest or part of HandshakeResponse.
|
||||
char buf[packet_size];
|
||||
size_t pos = 0;
|
||||
|
||||
/// Reads at least count and at most packet_size bytes.
|
||||
auto read_bytes = [this, &buf, &pos, &packet_size](size_t count) -> void {
|
||||
while (pos < count)
|
||||
{
|
||||
int ret = socket().receiveBytes(buf + pos, packet_size - pos);
|
||||
if (ret == 0)
|
||||
{
|
||||
throw Exception("Cannot read all data. Bytes read: " + std::to_string(pos) + ". Bytes expected: 3.", ErrorCodes::CANNOT_READ_ALL_DATA);
|
||||
}
|
||||
pos += ret;
|
||||
}
|
||||
};
|
||||
read_bytes(3); /// We can find out whether it is SSLRequest of HandshakeResponse by first 3 bytes.
|
||||
|
||||
size_t payload_size = unalignedLoad<uint32_t>(buf) & 0xFFFFFFu;
|
||||
LOG_TRACE(log, "payload size: " << payload_size);
|
||||
|
||||
if (payload_size == SSL_REQUEST_PAYLOAD_SIZE)
|
||||
{
|
||||
read_bytes(packet_size); /// Reading rest SSLRequest.
|
||||
SSLRequest ssl_request;
|
||||
ssl_request.readPayload(String(buf + PACKET_HEADER_SIZE, pos - PACKET_HEADER_SIZE));
|
||||
connection_context.client_capabilities = ssl_request.capability_flags;
|
||||
connection_context.max_packet_size = ssl_request.max_packet_size ? ssl_request.max_packet_size : MAX_PACKET_LENGTH;
|
||||
secure_connection = true;
|
||||
ss = std::make_shared<SecureStreamSocket>(SecureStreamSocket::attach(socket(), SSLManager::instance().defaultServerContext()));
|
||||
in = std::make_shared<ReadBufferFromPocoSocket>(*ss);
|
||||
out = std::make_shared<WriteBufferFromPocoSocket>(*ss);
|
||||
connection_context.sequence_id = 2;
|
||||
packet_sender = std::make_shared<PacketSender>(*in, *out, connection_context.sequence_id);
|
||||
packet_sender->max_packet_size = connection_context.max_packet_size;
|
||||
packet_sender->receivePacket(packet); /// Reading HandshakeResponse from secure socket.
|
||||
}
|
||||
else
|
||||
{
|
||||
/// Reading rest of HandshakeResponse.
|
||||
packet_size = PACKET_HEADER_SIZE + payload_size;
|
||||
WriteBufferFromOwnString buf_for_handshake_response;
|
||||
buf_for_handshake_response.write(buf, pos);
|
||||
copyData(*packet_sender->in, buf_for_handshake_response, packet_size - pos);
|
||||
packet.readPayload(buf_for_handshake_response.str().substr(PACKET_HEADER_SIZE));
|
||||
packet_sender->sequence_id++;
|
||||
}
|
||||
return packet;
|
||||
}
|
||||
|
||||
String MySQLHandler::generateScramble()
|
||||
{
|
||||
String scramble(MySQLProtocol::SCRAMBLE_LENGTH, 0);
|
||||
Poco::RandomInputStream generator;
|
||||
for (size_t i = 0; i < scramble.size(); i++)
|
||||
{
|
||||
generator >> scramble[i];
|
||||
}
|
||||
return scramble;
|
||||
}
|
||||
|
||||
void MySQLHandler::authenticate(const HandshakeResponse & handshake_response, const String & scramble)
|
||||
{
|
||||
|
||||
String auth_response;
|
||||
AuthSwitchResponse response;
|
||||
if (handshake_response.auth_plugin_name != Authentication::SHA256)
|
||||
{
|
||||
packet_sender->sendPacket(AuthSwitchRequest(Authentication::SHA256, scramble + '\0'), true);
|
||||
if (in->eof())
|
||||
throw Exception(
|
||||
"Client doesn't support authentication method " + String(Authentication::SHA256) + " used by ClickHouse",
|
||||
ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);
|
||||
packet_sender->receivePacket(response);
|
||||
auth_response = response.value;
|
||||
LOG_TRACE(log, "Authentication method mismatch.");
|
||||
}
|
||||
else
|
||||
{
|
||||
auth_response = handshake_response.auth_response;
|
||||
LOG_TRACE(log, "Authentication method match.");
|
||||
}
|
||||
|
||||
if (auth_response == "\1")
|
||||
{
|
||||
LOG_TRACE(log, "Client requests public key.");
|
||||
|
||||
BIO * mem = BIO_new(BIO_s_mem());
|
||||
SCOPE_EXIT(BIO_free(mem));
|
||||
if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
|
||||
{
|
||||
throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
}
|
||||
char * pem_buf = nullptr;
|
||||
long pem_size = BIO_get_mem_data(mem, &pem_buf);
|
||||
String pem(pem_buf, pem_size);
|
||||
|
||||
LOG_TRACE(log, "Key: " << pem);
|
||||
|
||||
AuthMoreData data(pem);
|
||||
packet_sender->sendPacket(data, true);
|
||||
packet_sender->receivePacket(response);
|
||||
auth_response = response.value;
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_TRACE(log, "Client didn't request public key.");
|
||||
}
|
||||
|
||||
String password;
|
||||
|
||||
/** Decrypt password, if it's not empty.
|
||||
* The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that
|
||||
* an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
|
||||
* https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
|
||||
*/
|
||||
if (!secure_connection && !auth_response.empty() && auth_response != String("\0", 1))
|
||||
{
|
||||
LOG_TRACE(log, "Received nonempty password");
|
||||
auto ciphertext = reinterpret_cast<unsigned char *>(auth_response.data());
|
||||
|
||||
unsigned char plaintext[RSA_size(&private_key)];
|
||||
int plaintext_size = RSA_private_decrypt(auth_response.size(), ciphertext, plaintext, &private_key, RSA_PKCS1_OAEP_PADDING);
|
||||
if (plaintext_size == -1)
|
||||
{
|
||||
throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
}
|
||||
|
||||
password.resize(plaintext_size);
|
||||
for (int i = 0; i < plaintext_size; i++)
|
||||
{
|
||||
password[i] = plaintext[i] ^ static_cast<unsigned char>(scramble[i % scramble.size()]);
|
||||
}
|
||||
}
|
||||
else if (secure_connection)
|
||||
{
|
||||
password = auth_response;
|
||||
}
|
||||
else
|
||||
{
|
||||
LOG_TRACE(log, "Received empty password");
|
||||
}
|
||||
|
||||
if (!password.empty())
|
||||
{
|
||||
password.pop_back(); /// terminating null byte
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
connection_context.setUser(handshake_response.username, password, socket().address(), "");
|
||||
connection_context.setCurrentDatabase(handshake_response.database);
|
||||
connection_context.setCurrentQueryId("");
|
||||
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " succeeded.");
|
||||
}
|
||||
catch (const Exception & exc)
|
||||
{
|
||||
LOG_ERROR(log, "Authentication for user " << handshake_response.username << " failed.");
|
||||
packet_sender->sendPacket(ERR_Packet(exc.code(), "00000", exc.message()), true);
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLHandler::comInitDB(const String & payload)
|
||||
{
|
||||
String database = payload.substr(1);
|
||||
LOG_DEBUG(log, "Setting current database to " << database);
|
||||
connection_context.setCurrentDatabase(database);
|
||||
packet_sender->sendPacket(OK_Packet(0, client_capability_flags, 0, 0, 1), true);
|
||||
}
|
||||
|
||||
void MySQLHandler::comFieldList(const String & payload)
|
||||
{
|
||||
ComFieldList packet;
|
||||
packet.readPayload(payload);
|
||||
String database = connection_context.getCurrentDatabase();
|
||||
StoragePtr tablePtr = connection_context.getTable(database, packet.table);
|
||||
for (const NameAndTypePair & column: tablePtr->getColumns().getAll())
|
||||
{
|
||||
ColumnDefinition column_definition(
|
||||
database, packet.table, packet.table, column.name, column.name, CharacterSet::binary, 100, ColumnType::MYSQL_TYPE_STRING, 0, 0
|
||||
);
|
||||
packet_sender->sendPacket(column_definition);
|
||||
}
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
||||
void MySQLHandler::comPing()
|
||||
{
|
||||
packet_sender->sendPacket(OK_Packet(0x0, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
||||
void MySQLHandler::comQuery(const String & payload)
|
||||
{
|
||||
bool with_output = false;
|
||||
std::function<void(const String &)> set_content_type = [&with_output](const String &) -> void {
|
||||
with_output = true;
|
||||
};
|
||||
|
||||
String query = payload.substr(1);
|
||||
|
||||
// Translate query from MySQL to ClickHouse.
|
||||
// This is a temporary workaround until ClickHouse supports the syntax "@@var_name".
|
||||
if (query == "select @@version_comment limit 1") // MariaDB client starts session with that query
|
||||
query = "select ''";
|
||||
|
||||
ReadBufferFromString buf(query);
|
||||
executeQuery(buf, *out, true, connection_context, set_content_type, nullptr);
|
||||
if (!with_output)
|
||||
packet_sender->sendPacket(OK_Packet(0x00, client_capability_flags, 0, 0, 0), true);
|
||||
}
|
||||
|
||||
}
|
59
dbms/programs/server/MySQLHandler.h
Normal file
59
dbms/programs/server/MySQLHandler.h
Normal file
@ -0,0 +1,59 @@
|
||||
#pragma once
|
||||
|
||||
#include <Poco/Net/TCPServerConnection.h>
|
||||
#include <Poco/Net/SecureStreamSocket.h>
|
||||
#include <Common/getFQDNOrHostName.h>
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <openssl/rsa.h>
|
||||
#include "IServer.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Handler for MySQL wire protocol connections. Allows to connect to ClickHouse using MySQL client.
|
||||
class MySQLHandler : public Poco::Net::TCPServerConnection
|
||||
{
|
||||
public:
|
||||
MySQLHandler(IServer & server_, const Poco::Net::StreamSocket & socket_, RSA & public_key, RSA & private_key, bool ssl_enabled, size_t connection_id);
|
||||
|
||||
void run() final;
|
||||
|
||||
private:
|
||||
/// Enables SSL, if client requested.
|
||||
MySQLProtocol::HandshakeResponse finishHandshake();
|
||||
|
||||
void comQuery(const String & payload);
|
||||
|
||||
void comFieldList(const String & payload);
|
||||
|
||||
void comPing();
|
||||
|
||||
void comInitDB(const String & payload);
|
||||
|
||||
static String generateScramble();
|
||||
|
||||
void authenticate(const MySQLProtocol::HandshakeResponse &, const String & scramble);
|
||||
|
||||
IServer & server;
|
||||
Poco::Logger * log;
|
||||
Context connection_context;
|
||||
|
||||
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
|
||||
|
||||
size_t connection_id = 0;
|
||||
|
||||
size_t server_capability_flags;
|
||||
size_t client_capability_flags;
|
||||
|
||||
RSA & public_key;
|
||||
RSA & private_key;
|
||||
|
||||
std::shared_ptr<ReadBuffer> in;
|
||||
std::shared_ptr<WriteBuffer> out;
|
||||
|
||||
bool secure_connection = false;
|
||||
std::shared_ptr<Poco::Net::SecureStreamSocket> ss;
|
||||
};
|
||||
|
||||
}
|
124
dbms/programs/server/MySQLHandlerFactory.cpp
Normal file
124
dbms/programs/server/MySQLHandlerFactory.cpp
Normal file
@ -0,0 +1,124 @@
|
||||
#include <Common/OpenSSLHelpers.h>
|
||||
#include <Poco/Crypto/X509Certificate.h>
|
||||
#include <Poco/Net/SSLManager.h>
|
||||
#include <Poco/Net/TCPServerConnectionFactory.h>
|
||||
#include <Poco/Util/Application.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <ext/scope_guard.h>
|
||||
#include "IServer.h"
|
||||
#include "MySQLHandler.h"
|
||||
#include "MySQLHandlerFactory.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int CANNOT_OPEN_FILE;
|
||||
extern const int NO_ELEMENTS_IN_CONFIG;
|
||||
extern const int OPENSSL_ERROR;
|
||||
extern const int SYSTEM_ERROR;
|
||||
}
|
||||
|
||||
MySQLHandlerFactory::MySQLHandlerFactory(IServer & server_)
|
||||
: server(server_)
|
||||
, log(&Logger::get("MySQLHandlerFactory"))
|
||||
{
|
||||
try
|
||||
{
|
||||
Poco::Net::SSLManager::instance().defaultServerContext();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_INFO(log, "Failed to create SSL context. SSL will be disabled. Error: " << getCurrentExceptionMessage(false));
|
||||
ssl_enabled = false;
|
||||
}
|
||||
|
||||
/// Reading rsa keys for SHA256 authentication plugin.
|
||||
try
|
||||
{
|
||||
readRSAKeys();
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
LOG_WARNING(log, "Failed to read RSA keys. Error: " << getCurrentExceptionMessage(false));
|
||||
generateRSAKeys();
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLHandlerFactory::readRSAKeys()
|
||||
{
|
||||
const Poco::Util::LayeredConfiguration & config = Poco::Util::Application::instance().config();
|
||||
String certificateFileProperty = "openSSL.server.certificateFile";
|
||||
String privateKeyFileProperty = "openSSL.server.privateKeyFile";
|
||||
|
||||
if (!config.has(certificateFileProperty))
|
||||
throw Exception("Certificate file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
|
||||
|
||||
if (!config.has(privateKeyFileProperty))
|
||||
throw Exception("Private key file is not set.", ErrorCodes::NO_ELEMENTS_IN_CONFIG);
|
||||
|
||||
{
|
||||
String certificateFile = config.getString(certificateFileProperty);
|
||||
FILE * fp = fopen(certificateFile.data(), "r");
|
||||
if (fp == nullptr)
|
||||
throw Exception("Cannot open certificate file: " + certificateFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
|
||||
SCOPE_EXIT(fclose(fp));
|
||||
|
||||
X509 * x509 = PEM_read_X509(fp, nullptr, nullptr, nullptr);
|
||||
SCOPE_EXIT(X509_free(x509));
|
||||
if (x509 == nullptr)
|
||||
throw Exception("Failed to read PEM certificate from " + certificateFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
|
||||
EVP_PKEY * p = X509_get_pubkey(x509);
|
||||
if (p == nullptr)
|
||||
throw Exception("Failed to get RSA key from X509. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
SCOPE_EXIT(EVP_PKEY_free(p));
|
||||
|
||||
public_key.reset(EVP_PKEY_get1_RSA(p));
|
||||
if (public_key.get() == nullptr)
|
||||
throw Exception("Failed to get RSA key from ENV_PKEY. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
}
|
||||
|
||||
{
|
||||
String privateKeyFile = config.getString(privateKeyFileProperty);
|
||||
|
||||
FILE * fp = fopen(privateKeyFile.data(), "r");
|
||||
if (fp == nullptr)
|
||||
throw Exception ("Cannot open private key file " + privateKeyFile + ".", ErrorCodes::CANNOT_OPEN_FILE);
|
||||
SCOPE_EXIT(fclose(fp));
|
||||
|
||||
private_key.reset(PEM_read_RSAPrivateKey(fp, nullptr, nullptr, nullptr));
|
||||
if (!private_key)
|
||||
throw Exception("Failed to read RSA private key from " + privateKeyFile + ". Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLHandlerFactory::generateRSAKeys()
|
||||
{
|
||||
LOG_INFO(log, "Generating new RSA key.");
|
||||
public_key.reset(RSA_new());
|
||||
if (!public_key)
|
||||
throw Exception("Failed to allocate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
|
||||
BIGNUM * e = BN_new();
|
||||
if (!e)
|
||||
throw Exception("Failed to allocate BIGNUM. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
SCOPE_EXIT(BN_free(e));
|
||||
|
||||
if (!BN_set_word(e, 65537) || !RSA_generate_key_ex(public_key.get(), 2048, e, nullptr))
|
||||
throw Exception("Failed to generate RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
|
||||
private_key.reset(RSAPrivateKey_dup(public_key.get()));
|
||||
if (!private_key)
|
||||
throw Exception("Failed to copy RSA key. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
|
||||
}
|
||||
|
||||
Poco::Net::TCPServerConnection * MySQLHandlerFactory::createConnection(const Poco::Net::StreamSocket & socket)
|
||||
{
|
||||
size_t connection_id = last_connection_id++;
|
||||
LOG_TRACE(log, "MySQL connection. Id: " << connection_id << ". Address: " << socket.peerAddress().toString());
|
||||
return new MySQLHandler(server, socket, *public_key, *private_key, ssl_enabled, connection_id);
|
||||
}
|
||||
|
||||
}
|
39
dbms/programs/server/MySQLHandlerFactory.h
Normal file
39
dbms/programs/server/MySQLHandlerFactory.h
Normal file
@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
|
||||
#include <Poco/Net/TCPServerConnectionFactory.h>
|
||||
#include <atomic>
|
||||
#include <openssl/rsa.h>
|
||||
#include "IServer.h"
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
class MySQLHandlerFactory : public Poco::Net::TCPServerConnectionFactory
|
||||
{
|
||||
private:
|
||||
IServer & server;
|
||||
Poco::Logger * log;
|
||||
|
||||
struct RSADeleter
|
||||
{
|
||||
void operator()(RSA * ptr) { RSA_free(ptr); }
|
||||
};
|
||||
using RSAPtr = std::unique_ptr<RSA, RSADeleter>;
|
||||
|
||||
RSAPtr public_key;
|
||||
RSAPtr private_key;
|
||||
|
||||
bool ssl_enabled = true;
|
||||
|
||||
std::atomic<size_t> last_connection_id = 0;
|
||||
public:
|
||||
explicit MySQLHandlerFactory(IServer & server_);
|
||||
|
||||
void readRSAKeys();
|
||||
|
||||
void generateRSAKeys();
|
||||
|
||||
Poco::Net::TCPServerConnection * createConnection(const Poco::Net::StreamSocket & socket) override;
|
||||
};
|
||||
|
||||
}
|
@ -49,6 +49,7 @@
|
||||
#include <Common/StatusFile.h>
|
||||
#include "TCPHandlerFactory.h"
|
||||
#include "Common/config_version.h"
|
||||
#include "MySQLHandlerFactory.h"
|
||||
|
||||
#if defined(__linux__)
|
||||
#include <Common/hasLinuxCapability.h>
|
||||
@ -668,7 +669,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
socket,
|
||||
new Poco::Net::TCPServerParams));
|
||||
|
||||
LOG_INFO(log, "Listening tcp: " + address.toString());
|
||||
LOG_INFO(log, "Listening for connections with native protocol (tcp): " + address.toString());
|
||||
}
|
||||
|
||||
/// TCP with SSL
|
||||
@ -685,7 +686,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
server_pool,
|
||||
socket,
|
||||
new Poco::Net::TCPServerParams));
|
||||
LOG_INFO(log, "Listening tcp_secure: " + address.toString());
|
||||
LOG_INFO(log, "Listening for connections with secure native protocol (tcp_secure): " + address.toString());
|
||||
#else
|
||||
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
|
||||
ErrorCodes::SUPPORT_IS_DISABLED};
|
||||
@ -710,7 +711,7 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
socket,
|
||||
http_params));
|
||||
|
||||
LOG_INFO(log, "Listening interserver http: " + address.toString());
|
||||
LOG_INFO(log, "Listening for replica communication (interserver) http://" + address.toString());
|
||||
}
|
||||
|
||||
if (config().has("interserver_https_port"))
|
||||
@ -727,12 +728,27 @@ int Server::main(const std::vector<std::string> & /*args*/)
|
||||
socket,
|
||||
http_params));
|
||||
|
||||
LOG_INFO(log, "Listening interserver https: " + address.toString());
|
||||
LOG_INFO(log, "Listening for secure replica communication (interserver) https://" + address.toString());
|
||||
#else
|
||||
throw Exception{"SSL support for TCP protocol is disabled because Poco library was built without NetSSL support.",
|
||||
ErrorCodes::SUPPORT_IS_DISABLED};
|
||||
#endif
|
||||
}
|
||||
|
||||
if (config().has("mysql_port"))
|
||||
{
|
||||
Poco::Net::ServerSocket socket;
|
||||
auto address = socket_bind_listen(socket, listen_host, config().getInt("mysql_port"), /* secure = */ true);
|
||||
socket.setReceiveTimeout(Poco::Timespan());
|
||||
socket.setSendTimeout(settings.send_timeout);
|
||||
servers.emplace_back(std::make_unique<Poco::Net::TCPServer>(
|
||||
new MySQLHandlerFactory(*this),
|
||||
server_pool,
|
||||
socket,
|
||||
new Poco::Net::TCPServerParams));
|
||||
|
||||
LOG_INFO(log, "Listening for MySQL compatibility protocol: " + address.toString());
|
||||
}
|
||||
}
|
||||
catch (const Poco::Exception & e)
|
||||
{
|
||||
|
3
dbms/programs/server/data/.gitignore
vendored
3
dbms/programs/server/data/.gitignore
vendored
@ -1,2 +1,5 @@
|
||||
*.bin
|
||||
*.mrk
|
||||
*.txt
|
||||
*.dat
|
||||
*.idx
|
||||
|
@ -308,16 +308,88 @@ public:
|
||||
|
||||
/**
|
||||
* Check whether two bitmaps intersect.
|
||||
* Intersection with an empty set is always 0 (consistent with hasAny).
|
||||
*/
|
||||
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1)
|
||||
UInt8 rb_intersect(const RoaringBitmapWithSmallSet & r1) const
|
||||
{
|
||||
if (isSmall())
|
||||
toLarge();
|
||||
roaring_bitmap_t * rb1 = r1.isSmall() ? r1.getNewRbFromSmall() : r1.getRb();
|
||||
UInt8 is_true = roaring_bitmap_intersect(rb, rb1);
|
||||
if (r1.isSmall())
|
||||
roaring_bitmap_free(rb1);
|
||||
return is_true;
|
||||
{
|
||||
if (r1.isSmall())
|
||||
{
|
||||
for (const auto & x : r1.small)
|
||||
if (small.find(x.getValue()) != small.end())
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
for (const auto & x : small)
|
||||
if (roaring_bitmap_contains(r1.rb, x.getValue()))
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if (r1.isSmall())
|
||||
{
|
||||
for (const auto & x : r1.small)
|
||||
if (roaring_bitmap_contains(rb, x.getValue()))
|
||||
return 1;
|
||||
}
|
||||
else if (roaring_bitmap_intersect(rb, r1.rb))
|
||||
return 1;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether the argument is the subset of this set.
|
||||
* Empty set is a subset of any other set (consistent with hasAll).
|
||||
*/
|
||||
UInt8 rb_is_subset(const RoaringBitmapWithSmallSet & r1) const
|
||||
{
|
||||
if (isSmall())
|
||||
{
|
||||
if (r1.isSmall())
|
||||
{
|
||||
for (const auto & x : r1.small)
|
||||
if (small.find(x.getValue()) == small.end())
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
UInt64 r1_size = r1.size();
|
||||
|
||||
if (r1_size > small.size())
|
||||
return 0; // A bigger set can not be a subset of ours.
|
||||
|
||||
// This is a rare case with a small number of elements on
|
||||
// both sides: r1 was promoted to large for some reason and
|
||||
// it is still not larger than our small set.
|
||||
// If r1 is our subset then our size must be equal to
|
||||
// r1_size + number of not found elements, if this sum becomes
|
||||
// greater then r1 is not a subset.
|
||||
for (const auto & x : small)
|
||||
if (!roaring_bitmap_contains(r1.rb, x.getValue()) && ++r1_size > small.size())
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
else if (r1.isSmall())
|
||||
{
|
||||
for (const auto & x : r1.small)
|
||||
if (!roaring_bitmap_contains(rb, x.getValue()))
|
||||
return 0;
|
||||
}
|
||||
else if (!roaring_bitmap_is_subset(r1.rb, rb))
|
||||
return 0;
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check whether this bitmap contains the argument.
|
||||
*/
|
||||
UInt8 rb_contains(const UInt32 x) const
|
||||
{
|
||||
return isSmall() ? small.find(x) != small.end() :
|
||||
roaring_bitmap_contains(rb, x);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -3,6 +3,8 @@
|
||||
#include <IO/ReadHelpers.h>
|
||||
#include <IO/WriteHelpers.h>
|
||||
#include <Interpreters/castColumn.h>
|
||||
#include <Columns/ColumnArray.h>
|
||||
#include <Columns/ColumnTuple.h>
|
||||
#include <Common/FieldVisitors.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
#include "AggregateFunctionFactory.h"
|
||||
@ -34,7 +36,7 @@ namespace
|
||||
|
||||
for (size_t i = 0; i < argument_types.size(); ++i)
|
||||
{
|
||||
if (!isNumber(argument_types[i]))
|
||||
if (!isNativeNumber(argument_types[i]))
|
||||
throw Exception(
|
||||
"Argument " + std::to_string(i) + " of type " + argument_types[i]->getName()
|
||||
+ " must be numeric for aggregate function " + name,
|
||||
@ -110,8 +112,8 @@ namespace
|
||||
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("LinearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("LogisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
factory.registerFunction("linearRegression", createAggregateFunctionMLMethod<FuncLinearRegression>);
|
||||
factory.registerFunction("logisticRegression", createAggregateFunctionMLMethod<FuncLogisticRegression>);
|
||||
}
|
||||
|
||||
LinearModelData::LinearModelData(
|
||||
@ -149,6 +151,27 @@ void LinearModelData::predict(
|
||||
gradient_computer->predict(container, block, arguments, weights, bias, context);
|
||||
}
|
||||
|
||||
void LinearModelData::returnWeights(IColumn & to) const
|
||||
{
|
||||
size_t size = weights.size() + 1;
|
||||
|
||||
ColumnArray & arr_to = static_cast<ColumnArray &>(to);
|
||||
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
|
||||
|
||||
size_t old_size = offsets_to.back();
|
||||
offsets_to.push_back(old_size + size);
|
||||
|
||||
typename ColumnFloat64::Container & val_to
|
||||
= static_cast<ColumnFloat64 &>(arr_to.getData()).getData();
|
||||
|
||||
val_to.reserve(old_size + size);
|
||||
|
||||
for (size_t i = 0; i + 1 < size; ++i)
|
||||
val_to.push_back(weights[i]);
|
||||
|
||||
val_to.push_back(bias);
|
||||
}
|
||||
|
||||
void LinearModelData::read(ReadBuffer & buf)
|
||||
{
|
||||
readBinary(bias, buf);
|
||||
@ -345,7 +368,7 @@ void LogisticRegression::predict(
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type))
|
||||
if (!isNativeNumber(cur_col.type))
|
||||
{
|
||||
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
@ -418,7 +441,7 @@ void LinearRegression::predict(
|
||||
for (size_t i = 1; i < arguments.size(); ++i)
|
||||
{
|
||||
const ColumnWithTypeAndName & cur_col = block.getByPosition(arguments[i]);
|
||||
if (!isNumber(cur_col.type))
|
||||
if (!isNativeNumber(cur_col.type))
|
||||
{
|
||||
throw Exception("Prediction arguments must have numeric type", ErrorCodes::BAD_ARGUMENTS);
|
||||
}
|
||||
|
@ -4,6 +4,8 @@
|
||||
#include <Columns/ColumnsCommon.h>
|
||||
#include <Columns/ColumnsNumber.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypeTuple.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include "IAggregateFunction.h"
|
||||
|
||||
namespace DB
|
||||
@ -218,6 +220,7 @@ public:
|
||||
void
|
||||
predict(ColumnVector<Float64>::Container & container, Block & block, const ColumnNumbers & arguments, const Context & context) const;
|
||||
|
||||
void returnWeights(IColumn & to) const;
|
||||
private:
|
||||
std::vector<Float64> weights;
|
||||
Float64 bias{0.0};
|
||||
@ -269,7 +272,15 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override { return std::make_shared<DataTypeNumber<Float64>>(); }
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
return std::make_shared<DataTypeArray>(std::make_shared<DataTypeFloat64>());
|
||||
}
|
||||
|
||||
DataTypePtr getReturnTypeToPredict() const override
|
||||
{
|
||||
return std::make_shared<DataTypeNumber<Float64>>();
|
||||
}
|
||||
|
||||
void create(AggregateDataPtr place) const override
|
||||
{
|
||||
@ -301,11 +312,12 @@ public:
|
||||
this->data(place).predict(column.getData(), block, arguments, context);
|
||||
}
|
||||
|
||||
/** This function is called if aggregate function without State modifier is selected in a query.
|
||||
* Inserts all weights of the model into the column 'to', so user may use such information if needed
|
||||
*/
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
std::ignore = place;
|
||||
std::ignore = to;
|
||||
throw std::runtime_error("not implemented");
|
||||
this->data(place).returnWeights(to);
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override { return __FILE__; }
|
||||
@ -321,10 +333,10 @@ private:
|
||||
|
||||
struct NameLinearRegression
|
||||
{
|
||||
static constexpr auto name = "LinearRegression";
|
||||
static constexpr auto name = "linearRegression";
|
||||
};
|
||||
struct NameLogisticRegression
|
||||
{
|
||||
static constexpr auto name = "LogisticRegression";
|
||||
static constexpr auto name = "logisticRegression";
|
||||
};
|
||||
}
|
||||
|
@ -61,10 +61,10 @@ public:
|
||||
AggregateFunctionIntersectionsMax(AggregateFunctionIntersectionsKind kind_, const DataTypes & arguments)
|
||||
: IAggregateFunctionDataHelper<MaxIntersectionsData<PointType>, AggregateFunctionIntersectionsMax<PointType>>(arguments, {}), kind(kind_)
|
||||
{
|
||||
if (!isNumber(arguments[0]))
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception{getName() + ": first argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
if (!isNumber(arguments[1]))
|
||||
if (!isNativeNumber(arguments[1]))
|
||||
throw Exception{getName() + ": second argument must be represented by integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
if (!arguments[0]->equals(*arguments[1]))
|
||||
|
@ -1,6 +1,12 @@
|
||||
#include <AggregateFunctions/Helpers.h>
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSequenceMatch.h>
|
||||
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
#include <DataTypes/DataTypeDateTime.h>
|
||||
|
||||
#include <ext/range.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
@ -12,32 +18,58 @@ namespace ErrorCodes
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceCount(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
template <template <typename, typename> class AggregateFunction, template <typename> class Data>
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceBase(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
|
||||
String pattern = params.front().safeGet<std::string>();
|
||||
return std::make_shared<AggregateFunctionSequenceCount>(argument_types, params, pattern);
|
||||
}
|
||||
const auto arg_count = argument_types.size();
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionSequenceMatch(const std::string & name, const DataTypes & argument_types, const Array & params)
|
||||
{
|
||||
if (params.size() != 1)
|
||||
throw Exception{"Aggregate function " + name + " requires exactly one parameter.",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH};
|
||||
if (arg_count < 3)
|
||||
throw Exception{"Aggregate function " + name + " requires at least 3 arguments.",
|
||||
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
if (arg_count - 1 > max_events)
|
||||
throw Exception{"Aggregate function " + name + " supports up to "
|
||||
+ toString(max_events) + " event arguments.",
|
||||
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
const auto time_arg = argument_types.front().get();
|
||||
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto cond_arg = argument_types[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1)
|
||||
+ " of aggregate function " + name + ", must be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
String pattern = params.front().safeGet<std::string>();
|
||||
return std::make_shared<AggregateFunctionSequenceMatch>(argument_types, params, pattern);
|
||||
|
||||
AggregateFunctionPtr res(createWithUnsignedIntegerType<AggregateFunction, Data>(*argument_types[0], argument_types, params, pattern));
|
||||
if (res)
|
||||
return res;
|
||||
|
||||
WhichDataType which(argument_types.front().get());
|
||||
if (which.isDateTime())
|
||||
return std::make_shared<AggregateFunction<DataTypeDateTime::FieldType, Data<DataTypeDateTime::FieldType>>>(argument_types, params, pattern);
|
||||
else if (which.isDate())
|
||||
return std::make_shared<AggregateFunction<DataTypeDate::FieldType, Data<DataTypeDate::FieldType>>>(argument_types, params, pattern);
|
||||
|
||||
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
|
||||
+ name + ", must be DateTime",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionsSequenceMatch(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceMatch);
|
||||
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceCount);
|
||||
factory.registerFunction("sequenceMatch", createAggregateFunctionSequenceBase<AggregateFunctionSequenceMatch, AggregateFunctionSequenceMatchData>);
|
||||
factory.registerFunction("sequenceCount", createAggregateFunctionSequenceBase<AggregateFunctionSequenceCount, AggregateFunctionSequenceMatchData>);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -36,11 +36,12 @@ struct ComparePairFirst final
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
template <typename T>
|
||||
struct AggregateFunctionSequenceMatchData final
|
||||
{
|
||||
static constexpr auto max_events = 32;
|
||||
|
||||
using Timestamp = std::uint32_t;
|
||||
using Timestamp = T;
|
||||
using Events = std::bitset<max_events>;
|
||||
using TimestampEvents = std::pair<Timestamp, Events>;
|
||||
using Comparator = ComparePairFirst<std::less>;
|
||||
@ -61,6 +62,9 @@ struct AggregateFunctionSequenceMatchData final
|
||||
|
||||
void merge(const AggregateFunctionSequenceMatchData & other)
|
||||
{
|
||||
if (other.events_list.empty())
|
||||
return;
|
||||
|
||||
const auto size = events_list.size();
|
||||
|
||||
events_list.insert(std::begin(other.events_list), std::end(other.events_list));
|
||||
@ -119,7 +123,7 @@ struct AggregateFunctionSequenceMatchData final
|
||||
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
{
|
||||
std::uint32_t timestamp;
|
||||
Timestamp timestamp;
|
||||
readBinary(timestamp, buf);
|
||||
|
||||
UInt64 events;
|
||||
@ -135,48 +139,23 @@ struct AggregateFunctionSequenceMatchData final
|
||||
constexpr auto sequence_match_max_iterations = 1000000;
|
||||
|
||||
|
||||
template <typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>
|
||||
template <typename T, typename Data, typename Derived>
|
||||
class AggregateFunctionSequenceBase : public IAggregateFunctionDataHelper<Data, Derived>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceBase(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: IAggregateFunctionDataHelper<AggregateFunctionSequenceMatchData, Derived>(arguments, params)
|
||||
: IAggregateFunctionDataHelper<Data, Derived>(arguments, params)
|
||||
, pattern(pattern)
|
||||
{
|
||||
arg_count = arguments.size();
|
||||
|
||||
if (!sufficientArgs(arg_count))
|
||||
throw Exception{"Aggregate function " + derived().getName() + " requires at least 3 arguments.",
|
||||
ErrorCodes::TOO_FEW_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
if (arg_count - 1 > AggregateFunctionSequenceMatchData::max_events)
|
||||
throw Exception{"Aggregate function " + derived().getName() + " supports up to " +
|
||||
toString(AggregateFunctionSequenceMatchData::max_events) + " event arguments.",
|
||||
ErrorCodes::TOO_MANY_ARGUMENTS_FOR_FUNCTION};
|
||||
|
||||
const auto time_arg = arguments.front().get();
|
||||
if (!WhichDataType(time_arg).isDateTime())
|
||||
throw Exception{"Illegal type " + time_arg->getName() + " of first argument of aggregate function "
|
||||
+ derived().getName() + ", must be DateTime",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto cond_arg = arguments[i].get();
|
||||
if (!isUInt8(cond_arg))
|
||||
throw Exception{"Illegal type " + cond_arg->getName() + " of argument " + toString(i + 1) +
|
||||
" of aggregate function " + derived().getName() + ", must be UInt8",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
parsePattern();
|
||||
}
|
||||
|
||||
void add(AggregateDataPtr place, const IColumn ** columns, const size_t row_num, Arena *) const override
|
||||
{
|
||||
const auto timestamp = static_cast<const ColumnUInt32 *>(columns[0])->getData()[row_num];
|
||||
const auto timestamp = static_cast<const ColumnVector<T> *>(columns[0])->getData()[row_num];
|
||||
|
||||
AggregateFunctionSequenceMatchData::Events events;
|
||||
typename Data::Events events;
|
||||
for (const auto i : ext::range(1, arg_count))
|
||||
{
|
||||
const auto event = static_cast<const ColumnUInt8 *>(columns[i])->getData()[row_num];
|
||||
@ -218,17 +197,15 @@ private:
|
||||
struct PatternAction final
|
||||
{
|
||||
PatternActionType type;
|
||||
std::uint32_t extra;
|
||||
std::uint64_t extra;
|
||||
|
||||
PatternAction() = default;
|
||||
PatternAction(const PatternActionType type, const std::uint32_t extra = 0) : type{type}, extra{extra} {}
|
||||
PatternAction(const PatternActionType type, const std::uint64_t extra = 0) : type{type}, extra{extra} {}
|
||||
};
|
||||
|
||||
static constexpr size_t bytes_on_stack = 64;
|
||||
using PatternActions = PODArray<PatternAction, bytes_on_stack, AllocatorWithStackMemory<Allocator<false>, bytes_on_stack>>;
|
||||
|
||||
static bool sufficientArgs(const size_t arg_count) { return arg_count >= 3; }
|
||||
|
||||
Derived & derived() { return static_cast<Derived &>(*this); }
|
||||
|
||||
void parsePattern()
|
||||
@ -340,8 +317,8 @@ protected:
|
||||
/// This algorithm performs in O(mn) (with m the number of DFA states and N the number
|
||||
/// of events) with a memory consumption and memory allocations in O(m). It means that
|
||||
/// if n >>> m (which is expected to be the case), this algorithm can be considered linear.
|
||||
template <typename T>
|
||||
bool dfaMatch(T & events_it, const T events_end) const
|
||||
template <typename EventEntry>
|
||||
bool dfaMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
using ActiveStates = std::vector<bool>;
|
||||
|
||||
@ -396,8 +373,8 @@ protected:
|
||||
return active_states.back();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool backtrackingMatch(T & events_it, const T events_end) const
|
||||
template <typename EventEntry>
|
||||
bool backtrackingMatch(EventEntry & events_it, const EventEntry events_end) const
|
||||
{
|
||||
const auto action_begin = std::begin(actions);
|
||||
const auto action_end = std::end(actions);
|
||||
@ -407,7 +384,7 @@ protected:
|
||||
auto base_it = events_it;
|
||||
|
||||
/// an iterator to action plus an iterator to row in events list plus timestamp at the start of sequence
|
||||
using backtrack_info = std::tuple<decltype(action_it), T, T>;
|
||||
using backtrack_info = std::tuple<decltype(action_it), EventEntry, EventEntry>;
|
||||
std::stack<backtrack_info> back_stack;
|
||||
|
||||
/// backtrack if possible
|
||||
@ -458,7 +435,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLessOrEqual)
|
||||
{
|
||||
if (events_it->first - base_it->first <= action_it->extra)
|
||||
if (events_it->first <= base_it->first + action_it->extra)
|
||||
{
|
||||
/// condition satisfied, move onto next action
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
@ -470,7 +447,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeLess)
|
||||
{
|
||||
if (events_it->first - base_it->first < action_it->extra)
|
||||
if (events_it->first < base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -481,7 +458,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreaterOrEqual)
|
||||
{
|
||||
if (events_it->first - base_it->first >= action_it->extra)
|
||||
if (events_it->first >= base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -492,7 +469,7 @@ protected:
|
||||
}
|
||||
else if (action_it->type == PatternActionType::TimeGreater)
|
||||
{
|
||||
if (events_it->first - base_it->first > action_it->extra)
|
||||
if (events_it->first > base_it->first + action_it->extra)
|
||||
{
|
||||
back_stack.emplace(action_it, events_it, base_it);
|
||||
base_it = events_it;
|
||||
@ -575,14 +552,14 @@ private:
|
||||
DFAStates dfa_states;
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceMatch final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceMatch(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>(arguments, params, pattern) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>(arguments, params, pattern) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<AggregateFunctionSequenceMatch>::AggregateFunctionSequenceBase;
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceMatch<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceMatch"; }
|
||||
|
||||
@ -590,27 +567,27 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
const_cast<Data &>(this->data(place)).sort();
|
||||
|
||||
const auto & data_ref = data(place);
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
bool match = pattern_has_time ? backtrackingMatch(events_it, events_end) : dfaMatch(events_it, events_end);
|
||||
bool match = this->pattern_has_time ? this->backtrackingMatch(events_it, events_end) : this->dfaMatch(events_it, events_end);
|
||||
static_cast<ColumnUInt8 &>(to).getData().push_back(match);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>
|
||||
template <typename T, typename Data>
|
||||
class AggregateFunctionSequenceCount final : public AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionSequenceCount(const DataTypes & arguments, const Array & params, const String & pattern)
|
||||
: AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>(arguments, params, pattern) {}
|
||||
: AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>(arguments, params, pattern) {}
|
||||
|
||||
using AggregateFunctionSequenceBase<AggregateFunctionSequenceCount>::AggregateFunctionSequenceBase;
|
||||
using AggregateFunctionSequenceBase<T, Data, AggregateFunctionSequenceCount<T, Data>>::AggregateFunctionSequenceBase;
|
||||
|
||||
String getName() const override { return "sequenceCount"; }
|
||||
|
||||
@ -618,21 +595,21 @@ public:
|
||||
|
||||
void insertResultInto(ConstAggregateDataPtr place, IColumn & to) const override
|
||||
{
|
||||
const_cast<Data &>(data(place)).sort();
|
||||
const_cast<Data &>(this->data(place)).sort();
|
||||
static_cast<ColumnUInt64 &>(to).getData().push_back(count(place));
|
||||
}
|
||||
|
||||
private:
|
||||
UInt64 count(const ConstAggregateDataPtr & place) const
|
||||
{
|
||||
const auto & data_ref = data(place);
|
||||
const auto & data_ref = this->data(place);
|
||||
|
||||
const auto events_begin = std::begin(data_ref.events_list);
|
||||
const auto events_end = std::end(data_ref.events_list);
|
||||
auto events_it = events_begin;
|
||||
|
||||
size_t count = 0;
|
||||
while (events_it != events_end && backtrackingMatch(events_it, events_end))
|
||||
while (events_it != events_end && this->backtrackingMatch(events_it, events_end))
|
||||
++count;
|
||||
|
||||
return count;
|
||||
|
@ -1,8 +1,9 @@
|
||||
#include <AggregateFunctions/AggregateFunctionLeastSqr.h>
|
||||
#include <AggregateFunctions/AggregateFunctionSimpleLinearRegression.h>
|
||||
|
||||
#include <AggregateFunctions/AggregateFunctionFactory.h>
|
||||
#include <AggregateFunctions/FactoryHelpers.h>
|
||||
|
||||
#include <Core/TypeListNumber.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
@ -10,7 +11,7 @@ namespace DB
|
||||
namespace
|
||||
{
|
||||
|
||||
AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
AggregateFunctionPtr createAggregateFunctionSimpleLinearRegression(
|
||||
const String & name,
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
@ -20,16 +21,11 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
assertBinary(name, arguments);
|
||||
|
||||
const IDataType * x_arg = arguments.front().get();
|
||||
|
||||
WhichDataType which_x {
|
||||
x_arg
|
||||
};
|
||||
WhichDataType which_x = x_arg;
|
||||
|
||||
const IDataType * y_arg = arguments.back().get();
|
||||
WhichDataType which_y = y_arg;
|
||||
|
||||
WhichDataType which_y {
|
||||
y_arg
|
||||
};
|
||||
|
||||
#define FOR_LEASTSQR_TYPES_2(M, T) \
|
||||
M(T, UInt8) \
|
||||
@ -55,7 +51,7 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
FOR_LEASTSQR_TYPES_2(M, Float64)
|
||||
#define DISPATCH(T1, T2) \
|
||||
if (which_x.idx == TypeIndex::T1 && which_y.idx == TypeIndex::T2) \
|
||||
return std::make_shared<AggregateFunctionLeastSqr<T1, T2>>( \
|
||||
return std::make_shared<AggregateFunctionSimpleLinearRegression<T1, T2>>( \
|
||||
arguments, \
|
||||
params \
|
||||
);
|
||||
@ -77,9 +73,9 @@ AggregateFunctionPtr createAggregateFunctionLeastSqr(
|
||||
|
||||
}
|
||||
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory & factory)
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory & factory)
|
||||
{
|
||||
factory.registerFunction("leastSqr", createAggregateFunctionLeastSqr);
|
||||
factory.registerFunction("simpleLinearRegression", createAggregateFunctionSimpleLinearRegression);
|
||||
}
|
||||
|
||||
}
|
@ -19,7 +19,7 @@ namespace ErrorCodes
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename Ret>
|
||||
struct AggregateFunctionLeastSqrData final
|
||||
struct AggregateFunctionSimpleLinearRegressionData final
|
||||
{
|
||||
size_t count = 0;
|
||||
Ret sum_x = 0;
|
||||
@ -36,7 +36,7 @@ struct AggregateFunctionLeastSqrData final
|
||||
sum_xy += x * y;
|
||||
}
|
||||
|
||||
void merge(const AggregateFunctionLeastSqrData & other)
|
||||
void merge(const AggregateFunctionSimpleLinearRegressionData & other)
|
||||
{
|
||||
count += other.count;
|
||||
sum_x += other.sum_x;
|
||||
@ -85,19 +85,19 @@ struct AggregateFunctionLeastSqrData final
|
||||
/// Calculates simple linear regression parameters.
|
||||
/// Result is a tuple (k, b) for y = k * x + b equation, solved by least squares approximation.
|
||||
template <typename X, typename Y, typename Ret = Float64>
|
||||
class AggregateFunctionLeastSqr final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
class AggregateFunctionSimpleLinearRegression final : public IAggregateFunctionDataHelper<
|
||||
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
>
|
||||
{
|
||||
public:
|
||||
AggregateFunctionLeastSqr(
|
||||
AggregateFunctionSimpleLinearRegression(
|
||||
const DataTypes & arguments,
|
||||
const Array & params
|
||||
):
|
||||
IAggregateFunctionDataHelper<
|
||||
AggregateFunctionLeastSqrData<X, Y, Ret>,
|
||||
AggregateFunctionLeastSqr<X, Y, Ret>
|
||||
AggregateFunctionSimpleLinearRegressionData<X, Y, Ret>,
|
||||
AggregateFunctionSimpleLinearRegression<X, Y, Ret>
|
||||
> {arguments, params}
|
||||
{
|
||||
// notice: arguments has been checked before
|
||||
@ -105,7 +105,7 @@ public:
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
return "leastSqr";
|
||||
return "simpleLinearRegression";
|
||||
}
|
||||
|
||||
const char * getHeaderFilePath() const override
|
||||
@ -120,12 +120,8 @@ public:
|
||||
Arena *
|
||||
) const override
|
||||
{
|
||||
auto col_x {
|
||||
static_cast<const ColumnVector<X> *>(columns[0])
|
||||
};
|
||||
auto col_y {
|
||||
static_cast<const ColumnVector<Y> *>(columns[1])
|
||||
};
|
||||
auto col_x = static_cast<const ColumnVector<X> *>(columns[0]);
|
||||
auto col_y = static_cast<const ColumnVector<Y> *>(columns[1]);
|
||||
|
||||
X x = col_x->getData()[row_num];
|
||||
Y y = col_y->getData()[row_num];
|
||||
@ -159,12 +155,14 @@ public:
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
{
|
||||
DataTypes types {
|
||||
DataTypes types
|
||||
{
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
std::make_shared<DataTypeNumber<Ret>>(),
|
||||
};
|
||||
|
||||
Strings names {
|
||||
Strings names
|
||||
{
|
||||
"k",
|
||||
"b",
|
||||
};
|
@ -56,6 +56,10 @@ void registerAggregateFunctionsStatisticsSimple(AggregateFunctionFactory & facto
|
||||
factory.registerFunction("varPop", createAggregateFunctionStatisticsUnary<AggregateFunctionVarPopSimple>);
|
||||
factory.registerFunction("stddevSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevSampSimple>);
|
||||
factory.registerFunction("stddevPop", createAggregateFunctionStatisticsUnary<AggregateFunctionStddevPopSimple>);
|
||||
factory.registerFunction("skewSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewSampSimple>);
|
||||
factory.registerFunction("skewPop", createAggregateFunctionStatisticsUnary<AggregateFunctionSkewPopSimple>);
|
||||
factory.registerFunction("kurtSamp", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtSampSimple>);
|
||||
factory.registerFunction("kurtPop", createAggregateFunctionStatisticsUnary<AggregateFunctionKurtPopSimple>);
|
||||
|
||||
factory.registerFunction("covarSamp", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarSampSimple>);
|
||||
factory.registerFunction("covarPop", createAggregateFunctionStatisticsBinary<AggregateFunctionCovarPopSimple>);
|
||||
|
@ -32,30 +32,42 @@ namespace DB
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int LOGICAL_ERROR;
|
||||
extern const int DECIMAL_OVERFLOW;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
/**
|
||||
Calculating univariate central moments
|
||||
Levels:
|
||||
level 2 (pop & samp): var, stddev
|
||||
level 3: skewness
|
||||
level 4: kurtosis
|
||||
References:
|
||||
https://en.wikipedia.org/wiki/Moment_(mathematics)
|
||||
https://en.wikipedia.org/wiki/Skewness
|
||||
https://en.wikipedia.org/wiki/Kurtosis
|
||||
*/
|
||||
template <typename T, size_t _level>
|
||||
struct VarMoments
|
||||
{
|
||||
T m0{};
|
||||
T m1{};
|
||||
T m2{};
|
||||
T m[_level + 1]{};
|
||||
|
||||
void add(T x)
|
||||
{
|
||||
++m0;
|
||||
m1 += x;
|
||||
m2 += x * x;
|
||||
++m[0];
|
||||
m[1] += x;
|
||||
m[2] += x * x;
|
||||
if constexpr (_level >= 3) m[3] += x * x * x;
|
||||
if constexpr (_level >= 4) m[4] += x * x * x * x;
|
||||
}
|
||||
|
||||
void merge(const VarMoments & rhs)
|
||||
{
|
||||
m0 += rhs.m0;
|
||||
m1 += rhs.m1;
|
||||
m2 += rhs.m2;
|
||||
m[0] += rhs.m[0];
|
||||
m[1] += rhs.m[1];
|
||||
m[2] += rhs.m[2];
|
||||
if constexpr (_level >= 3) m[3] += rhs.m[3];
|
||||
if constexpr (_level >= 4) m[4] += rhs.m[4];
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const
|
||||
@ -70,45 +82,90 @@ struct VarMoments
|
||||
|
||||
T NO_SANITIZE_UNDEFINED getPopulation() const
|
||||
{
|
||||
return (m2 - m1 * m1 / m0) / m0;
|
||||
return (m[2] - m[1] * m[1] / m[0]) / m[0];
|
||||
}
|
||||
|
||||
T NO_SANITIZE_UNDEFINED getSample() const
|
||||
{
|
||||
if (m0 == 0)
|
||||
if (m[0] == 0)
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
return (m2 - m1 * m1 / m0) / (m0 - 1);
|
||||
return (m[2] - m[1] * m[1] / m[0]) / (m[0] - 1);
|
||||
}
|
||||
|
||||
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
|
||||
T NO_SANITIZE_UNDEFINED getMoment3() const
|
||||
{
|
||||
// to avoid accuracy problem
|
||||
if (m[0] == 1)
|
||||
return 0;
|
||||
return (m[3]
|
||||
- (3 * m[2]
|
||||
- 2 * m[1] * m[1] / m[0]
|
||||
) * m[1] / m[0]
|
||||
) / m[0];
|
||||
}
|
||||
|
||||
T NO_SANITIZE_UNDEFINED getMoment4() const
|
||||
{
|
||||
// to avoid accuracy problem
|
||||
if (m[0] == 1)
|
||||
return 0;
|
||||
return (m[4]
|
||||
- (4 * m[3]
|
||||
- (6 * m[2]
|
||||
- 3 * m[1] * m[1] / m[0]
|
||||
) * m[1] / m[0]
|
||||
) * m[1] / m[0]
|
||||
) / m[0];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename T, size_t _level>
|
||||
struct VarMomentsDecimal
|
||||
{
|
||||
using NativeType = typename T::NativeType;
|
||||
|
||||
UInt64 m0{};
|
||||
NativeType m1{};
|
||||
NativeType m2{};
|
||||
NativeType m[_level]{};
|
||||
|
||||
NativeType & getM(size_t i)
|
||||
{
|
||||
return m[i - 1];
|
||||
}
|
||||
|
||||
const NativeType & getM(size_t i) const
|
||||
{
|
||||
return m[i - 1];
|
||||
}
|
||||
|
||||
void add(NativeType x)
|
||||
{
|
||||
++m0;
|
||||
m1 += x;
|
||||
getM(1) += x;
|
||||
|
||||
NativeType tmp; /// scale' = 2 * scale
|
||||
if (common::mulOverflow(x, x, tmp) || common::addOverflow(m2, tmp, m2))
|
||||
NativeType tmp;
|
||||
if (common::mulOverflow(x, x, tmp) || common::addOverflow(getM(2), tmp, getM(2)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
if constexpr (_level >= 3)
|
||||
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(3), tmp, getM(3)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
if constexpr (_level >= 4)
|
||||
if (common::mulOverflow(tmp, x, tmp) || common::addOverflow(getM(4), tmp, getM(4)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
}
|
||||
|
||||
void merge(const VarMomentsDecimal & rhs)
|
||||
{
|
||||
m0 += rhs.m0;
|
||||
m1 += rhs.m1;
|
||||
getM(1) += rhs.getM(1);
|
||||
|
||||
if (common::addOverflow(m2, rhs.m2, m2))
|
||||
if (common::addOverflow(getM(2), rhs.getM(2), getM(2)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
if constexpr (_level >= 3)
|
||||
if (common::addOverflow(getM(3), rhs.getM(3), getM(3)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
if constexpr (_level >= 4)
|
||||
if (common::addOverflow(getM(4), rhs.getM(4), getM(4)))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
}
|
||||
|
||||
void write(WriteBuffer & buf) const { writePODBinary(*this, buf); }
|
||||
@ -120,8 +177,8 @@ struct VarMomentsDecimal
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
|
||||
NativeType tmp;
|
||||
if (common::mulOverflow(m1, m1, tmp) ||
|
||||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
|
||||
if (common::mulOverflow(getM(1), getM(1), tmp) ||
|
||||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
|
||||
}
|
||||
@ -134,15 +191,50 @@ struct VarMomentsDecimal
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
|
||||
NativeType tmp;
|
||||
if (common::mulOverflow(m1, m1, tmp) ||
|
||||
common::subOverflow(m2, NativeType(tmp / m0), tmp))
|
||||
if (common::mulOverflow(getM(1), getM(1), tmp) ||
|
||||
common::subOverflow(getM(2), NativeType(tmp / m0), tmp))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / (m0 - 1), scale);
|
||||
}
|
||||
|
||||
Float64 get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
|
||||
Float64 getMoment3(UInt32 scale) const
|
||||
{
|
||||
if (m0 == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
|
||||
NativeType tmp;
|
||||
if (common::mulOverflow(2 * getM(1), getM(1), tmp) ||
|
||||
common::subOverflow(3 * getM(2), NativeType(tmp / m0), tmp) ||
|
||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
||||
common::subOverflow(getM(3), NativeType(tmp / m0), tmp))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
|
||||
}
|
||||
|
||||
Float64 getMoment4(UInt32 scale) const
|
||||
{
|
||||
if (m0 == 0)
|
||||
return std::numeric_limits<Float64>::infinity();
|
||||
|
||||
NativeType tmp;
|
||||
if (common::mulOverflow(3 * getM(1), getM(1), tmp) ||
|
||||
common::subOverflow(6 * getM(2), NativeType(tmp / m0), tmp) ||
|
||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
||||
common::subOverflow(4 * getM(3), NativeType(tmp / m0), tmp) ||
|
||||
common::mulOverflow(tmp, getM(1), tmp) ||
|
||||
common::subOverflow(getM(4), NativeType(tmp / m0), tmp))
|
||||
throw Exception("Decimal math overflow", ErrorCodes::DECIMAL_OVERFLOW);
|
||||
return convertFromDecimal<DataTypeDecimal<T>, DataTypeNumber<Float64>>(tmp / m0, scale);
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
Calculating multivariate central moments
|
||||
Levels:
|
||||
level 2 (pop & samp): covar
|
||||
References:
|
||||
https://en.wikipedia.org/wiki/Moment_(mathematics)
|
||||
*/
|
||||
template <typename T>
|
||||
struct CovarMoments
|
||||
{
|
||||
@ -188,8 +280,6 @@ struct CovarMoments
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
return (xy - x1 * y1 / m0) / (m0 - 1);
|
||||
}
|
||||
|
||||
T get() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -236,9 +326,6 @@ struct CorrMoments
|
||||
{
|
||||
return (m0 * xy - x1 * y1) / sqrt((m0 * x2 - x1 * x1) * (m0 * y2 - y1 * y1));
|
||||
}
|
||||
|
||||
T getPopulation() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
|
||||
T getSample() const { throw Exception("Unexpected call", ErrorCodes::LOGICAL_ERROR); }
|
||||
};
|
||||
|
||||
|
||||
@ -246,18 +333,20 @@ enum class StatisticsFunctionKind
|
||||
{
|
||||
varPop, varSamp,
|
||||
stddevPop, stddevSamp,
|
||||
skewPop, skewSamp,
|
||||
kurtPop, kurtSamp,
|
||||
covarPop, covarSamp,
|
||||
corr
|
||||
};
|
||||
|
||||
|
||||
template <typename T, StatisticsFunctionKind _kind>
|
||||
template <typename T, StatisticsFunctionKind _kind, size_t _level>
|
||||
struct StatFuncOneArg
|
||||
{
|
||||
using Type1 = T;
|
||||
using Type2 = T;
|
||||
using ResultType = std::conditional_t<std::is_same_v<T, Float32>, Float32, Float64>;
|
||||
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128>, VarMoments<ResultType>>;
|
||||
using Data = std::conditional_t<IsDecimalNumber<T>, VarMomentsDecimal<Decimal128, _level>, VarMoments<ResultType, _level>>;
|
||||
|
||||
static constexpr StatisticsFunctionKind kind = _kind;
|
||||
static constexpr UInt32 num_args = 1;
|
||||
@ -300,17 +389,28 @@ public:
|
||||
|
||||
String getName() const override
|
||||
{
|
||||
switch (StatFunc::kind)
|
||||
{
|
||||
case StatisticsFunctionKind::varPop: return "varPop";
|
||||
case StatisticsFunctionKind::varSamp: return "varSamp";
|
||||
case StatisticsFunctionKind::stddevPop: return "stddevPop";
|
||||
case StatisticsFunctionKind::stddevSamp: return "stddevSamp";
|
||||
case StatisticsFunctionKind::covarPop: return "covarPop";
|
||||
case StatisticsFunctionKind::covarSamp: return "covarSamp";
|
||||
case StatisticsFunctionKind::corr: return "corr";
|
||||
}
|
||||
__builtin_unreachable();
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
|
||||
return "varPop";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
|
||||
return "varSamp";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
|
||||
return "stddevPop";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
|
||||
return "stddevSamp";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
|
||||
return "skewPop";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
|
||||
return "skewSamp";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
|
||||
return "kurtPop";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
|
||||
return "kurtSamp";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
|
||||
return "covarPop";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
|
||||
return "covarSamp";
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
|
||||
return "corr";
|
||||
}
|
||||
|
||||
DataTypePtr getReturnType() const override
|
||||
@ -351,28 +451,103 @@ public:
|
||||
|
||||
if constexpr (IsDecimalNumber<T1>)
|
||||
{
|
||||
switch (StatFunc::kind)
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
|
||||
dst.push_back(data.getPopulation(src_scale * 2));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
|
||||
dst.push_back(data.getSample(src_scale * 2));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
|
||||
dst.push_back(sqrt(data.getPopulation(src_scale * 2)));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
|
||||
dst.push_back(sqrt(data.getSample(src_scale * 2)));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
|
||||
{
|
||||
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation(src_scale * 2)); break;
|
||||
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample(src_scale * 2)); break;
|
||||
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation(src_scale * 2))); break;
|
||||
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample(src_scale * 2))); break;
|
||||
default:
|
||||
__builtin_unreachable();
|
||||
Float64 var_value = data.getPopulation(src_scale * 2);
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
|
||||
{
|
||||
Float64 var_value = data.getSample(src_scale * 2);
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment3(src_scale * 3) / pow(var_value, 1.5));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
|
||||
{
|
||||
Float64 var_value = data.getPopulation(src_scale * 2);
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
|
||||
{
|
||||
Float64 var_value = data.getSample(src_scale * 2);
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment4(src_scale * 4) / pow(var_value, 2));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<Float64>::quiet_NaN());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
switch (StatFunc::kind)
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varPop)
|
||||
dst.push_back(data.getPopulation());
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::varSamp)
|
||||
dst.push_back(data.getSample());
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevPop)
|
||||
dst.push_back(sqrt(data.getPopulation()));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::stddevSamp)
|
||||
dst.push_back(sqrt(data.getSample()));
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewPop)
|
||||
{
|
||||
case StatisticsFunctionKind::varPop: dst.push_back(data.getPopulation()); break;
|
||||
case StatisticsFunctionKind::varSamp: dst.push_back(data.getSample()); break;
|
||||
case StatisticsFunctionKind::stddevPop: dst.push_back(sqrt(data.getPopulation())); break;
|
||||
case StatisticsFunctionKind::stddevSamp: dst.push_back(sqrt(data.getSample())); break;
|
||||
case StatisticsFunctionKind::covarPop: dst.push_back(data.getPopulation()); break;
|
||||
case StatisticsFunctionKind::covarSamp: dst.push_back(data.getSample()); break;
|
||||
case StatisticsFunctionKind::corr: dst.push_back(data.get()); break;
|
||||
ResultType var_value = data.getPopulation();
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::skewSamp)
|
||||
{
|
||||
ResultType var_value = data.getSample();
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment3() / pow(var_value, 1.5));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtPop)
|
||||
{
|
||||
ResultType var_value = data.getPopulation();
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment4() / pow(var_value, 2));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::kurtSamp)
|
||||
{
|
||||
ResultType var_value = data.getSample();
|
||||
|
||||
if (var_value > 0)
|
||||
dst.push_back(data.getMoment4() / pow(var_value, 2));
|
||||
else
|
||||
dst.push_back(std::numeric_limits<ResultType>::quiet_NaN());
|
||||
}
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarPop)
|
||||
dst.push_back(data.getPopulation());
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::covarSamp)
|
||||
dst.push_back(data.getSample());
|
||||
if constexpr (StatFunc::kind == StatisticsFunctionKind::corr)
|
||||
dst.push_back(data.get());
|
||||
}
|
||||
}
|
||||
|
||||
@ -383,10 +558,14 @@ private:
|
||||
};
|
||||
|
||||
|
||||
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop>>;
|
||||
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp>>;
|
||||
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop>>;
|
||||
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp>>;
|
||||
template <typename T> using AggregateFunctionVarPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varPop, 2>>;
|
||||
template <typename T> using AggregateFunctionVarSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::varSamp, 2>>;
|
||||
template <typename T> using AggregateFunctionStddevPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevPop, 2>>;
|
||||
template <typename T> using AggregateFunctionStddevSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::stddevSamp, 2>>;
|
||||
template <typename T> using AggregateFunctionSkewPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewPop, 3>>;
|
||||
template <typename T> using AggregateFunctionSkewSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::skewSamp, 3>>;
|
||||
template <typename T> using AggregateFunctionKurtPopSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtPop, 4>>;
|
||||
template <typename T> using AggregateFunctionKurtSampSimple = AggregateFunctionVarianceSimple<StatFuncOneArg<T, StatisticsFunctionKind::kurtSamp, 4>>;
|
||||
template <typename T1, typename T2> using AggregateFunctionCovarPopSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarPop>>;
|
||||
template <typename T1, typename T2> using AggregateFunctionCovarSampSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::covarSamp>>;
|
||||
template <typename T1, typename T2> using AggregateFunctionCorrSimple = AggregateFunctionVarianceSimple<StatFuncTwoArg<T1, T2, StatisticsFunctionKind::corr>>;
|
||||
|
@ -48,6 +48,12 @@ public:
|
||||
/// Get the result type.
|
||||
virtual DataTypePtr getReturnType() const = 0;
|
||||
|
||||
/// Get type which will be used for prediction result in case if function is an ML method.
|
||||
virtual DataTypePtr getReturnTypeToPredict() const
|
||||
{
|
||||
throw Exception("Prediction is not supported for " + getName(), ErrorCodes::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
virtual ~IAggregateFunction() {}
|
||||
|
||||
/** Data manipulating functions. */
|
||||
|
@ -30,7 +30,7 @@ void registerAggregateFunctionsBitmap(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionsMaxIntersections(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionMLMethod(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionEntropy(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionLeastSqr(AggregateFunctionFactory &);
|
||||
void registerAggregateFunctionSimpleLinearRegression(AggregateFunctionFactory &);
|
||||
|
||||
void registerAggregateFunctionCombinatorIf(AggregateFunctionCombinatorFactory &);
|
||||
void registerAggregateFunctionCombinatorArray(AggregateFunctionCombinatorFactory &);
|
||||
@ -73,7 +73,7 @@ void registerAggregateFunctions()
|
||||
registerAggregateFunctionTimeSeriesGroupSum(factory);
|
||||
registerAggregateFunctionMLMethod(factory);
|
||||
registerAggregateFunctionEntropy(factory);
|
||||
registerAggregateFunctionLeastSqr(factory);
|
||||
registerAggregateFunctionSimpleLinearRegression(factory);
|
||||
}
|
||||
|
||||
{
|
||||
|
@ -35,25 +35,6 @@ void ColumnAggregateFunction::addArena(ArenaPtr arena_)
|
||||
arenas.push_back(arena_);
|
||||
}
|
||||
|
||||
/// This function is used in convertToValues() and predictValues()
|
||||
/// and is written here to avoid repetitions
|
||||
bool ColumnAggregateFunction::tryFinalizeAggregateFunction(MutableColumnPtr *res_) const
|
||||
{
|
||||
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
{
|
||||
auto res = createView();
|
||||
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
*res_ = std::move(res);
|
||||
return true;
|
||||
}
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
*res_ = std::move(res);
|
||||
return false;
|
||||
}
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
{
|
||||
/** If the aggregate function returns an unfinalized/unfinished state,
|
||||
@ -86,17 +67,17 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
* AggregateFunction(quantileTiming(0.5), UInt64)
|
||||
* into UInt16 - already finished result of `quantileTiming`.
|
||||
*/
|
||||
|
||||
/** Convertion function is used in convertToValues and predictValues
|
||||
* in the similar part of both functions
|
||||
*/
|
||||
|
||||
MutableColumnPtr res;
|
||||
if (tryFinalizeAggregateFunction(&res))
|
||||
if (const AggregateFunctionState *function_state = typeid_cast<const AggregateFunctionState *>(func.get()))
|
||||
{
|
||||
auto res = createView();
|
||||
res->set(function_state->getNestedFunction());
|
||||
res->data.assign(data.begin(), data.end());
|
||||
return res;
|
||||
}
|
||||
|
||||
MutableColumnPtr res = func->getReturnType()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
for (auto val : data)
|
||||
func->insertResultInto(val, *res);
|
||||
|
||||
@ -105,8 +86,8 @@ MutableColumnPtr ColumnAggregateFunction::convertToValues() const
|
||||
|
||||
MutableColumnPtr ColumnAggregateFunction::predictValues(Block & block, const ColumnNumbers & arguments, const Context & context) const
|
||||
{
|
||||
MutableColumnPtr res;
|
||||
tryFinalizeAggregateFunction(&res);
|
||||
MutableColumnPtr res = func->getReturnTypeToPredict()->createColumn();
|
||||
res->reserve(data.size());
|
||||
|
||||
auto ML_function = func.get();
|
||||
if (ML_function)
|
||||
|
@ -427,6 +427,8 @@ namespace ErrorCodes
|
||||
extern const int BAD_TTL_EXPRESSION = 450;
|
||||
extern const int BAD_TTL_FILE = 451;
|
||||
extern const int SETTING_CONSTRAINT_VIOLATION = 452;
|
||||
extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES = 453;
|
||||
extern const int OPENSSL_ERROR = 454;
|
||||
|
||||
extern const int KEEPER_EXCEPTION = 999;
|
||||
extern const int POCO_EXCEPTION = 1000;
|
||||
|
18
dbms/src/Common/OpenSSLHelpers.cpp
Normal file
18
dbms/src/Common/OpenSSLHelpers.cpp
Normal file
@ -0,0 +1,18 @@
|
||||
#include "OpenSSLHelpers.h"
|
||||
#include <ext/scope_guard.h>
|
||||
#include <openssl/err.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
String getOpenSSLErrors()
|
||||
{
|
||||
BIO * mem = BIO_new(BIO_s_mem());
|
||||
SCOPE_EXIT(BIO_free(mem));
|
||||
ERR_print_errors(mem);
|
||||
char * buf = nullptr;
|
||||
long size = BIO_get_mem_data(mem, &buf);
|
||||
return String(buf, size);
|
||||
}
|
||||
|
||||
}
|
12
dbms/src/Common/OpenSSLHelpers.h
Normal file
12
dbms/src/Common/OpenSSLHelpers.h
Normal file
@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/// Returns concatenation of error strings for all errors that OpenSSL has recorded, emptying the error queue.
|
||||
String getOpenSSLErrors();
|
||||
|
||||
}
|
@ -56,6 +56,8 @@
|
||||
|
||||
#define DBMS_MIN_REVISION_WITH_LOW_CARDINALITY_TYPE 54405
|
||||
|
||||
#define DBMS_MIN_REVISION_WITH_CLIENT_WRITE_INFO 54421
|
||||
|
||||
/// Version of ClickHouse TCP protocol. Set to git tag with latest protocol change.
|
||||
#define DBMS_TCP_PROTOCOL_VERSION 54226
|
||||
|
||||
|
94
dbms/src/Core/MySQLProtocol.cpp
Normal file
94
dbms/src/Core/MySQLProtocol.cpp
Normal file
@ -0,0 +1,94 @@
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/ReadBufferFromString.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <common/logger_useful.h>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
#include "MySQLProtocol.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace MySQLProtocol
|
||||
{
|
||||
|
||||
void PacketSender::resetSequenceId()
|
||||
{
|
||||
sequence_id = 0;
|
||||
}
|
||||
|
||||
String PacketSender::packetToText(String payload)
|
||||
{
|
||||
String result;
|
||||
for (auto c : payload)
|
||||
{
|
||||
result += ' ';
|
||||
result += std::to_string(static_cast<unsigned char>(c));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
uint64_t readLengthEncodedNumber(std::istringstream & ss)
|
||||
{
|
||||
char c;
|
||||
uint64_t buf = 0;
|
||||
ss.get(c);
|
||||
auto cc = static_cast<uint8_t>(c);
|
||||
if (cc < 0xfc)
|
||||
{
|
||||
return cc;
|
||||
}
|
||||
else if (cc < 0xfd)
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 2);
|
||||
}
|
||||
else if (cc < 0xfe)
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
ss.read(reinterpret_cast<char *>(&buf), 8);
|
||||
}
|
||||
return buf;
|
||||
}
|
||||
|
||||
std::string writeLengthEncodedNumber(uint64_t x)
|
||||
{
|
||||
std::string result;
|
||||
if (x < 251)
|
||||
{
|
||||
result.append(1, static_cast<char>(x));
|
||||
}
|
||||
else if (x < (1 << 16))
|
||||
{
|
||||
result.append(1, 0xfc);
|
||||
result.append(reinterpret_cast<char *>(&x), 2);
|
||||
}
|
||||
else if (x < (1 << 24))
|
||||
{
|
||||
result.append(1, 0xfd);
|
||||
result.append(reinterpret_cast<char *>(&x), 3);
|
||||
}
|
||||
else
|
||||
{
|
||||
result.append(1, 0xfe);
|
||||
result.append(reinterpret_cast<char *>(&x), 8);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void writeLengthEncodedString(std::string & payload, const std::string & s)
|
||||
{
|
||||
payload.append(writeLengthEncodedNumber(s.length()));
|
||||
payload.append(s);
|
||||
}
|
||||
|
||||
void writeNulTerminatedString(std::string & payload, const std::string & s)
|
||||
{
|
||||
payload.append(s);
|
||||
payload.append(1, 0);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
649
dbms/src/Core/MySQLProtocol.h
Normal file
649
dbms/src/Core/MySQLProtocol.h
Normal file
@ -0,0 +1,649 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/Types.h>
|
||||
#include <IO/copyData.h>
|
||||
#include <IO/ReadBuffer.h>
|
||||
#include <IO/ReadBufferFromPocoSocket.h>
|
||||
#include <IO/WriteBuffer.h>
|
||||
#include <IO/WriteBufferFromPocoSocket.h>
|
||||
#include <IO/WriteBufferFromString.h>
|
||||
#include <Poco/Net/StreamSocket.h>
|
||||
#include <Poco/RandomStream.h>
|
||||
#include <random>
|
||||
#include <sstream>
|
||||
|
||||
/// Implementation of MySQL wire protocol
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int UNKNOWN_PACKET_FROM_CLIENT;
|
||||
}
|
||||
|
||||
namespace MySQLProtocol
|
||||
{
|
||||
|
||||
const size_t MAX_PACKET_LENGTH = (1 << 24) - 1; // 16 mb
|
||||
const size_t SCRAMBLE_LENGTH = 20;
|
||||
const size_t AUTH_PLUGIN_DATA_PART_1_LENGTH = 8;
|
||||
const size_t MYSQL_ERRMSG_SIZE = 512;
|
||||
const size_t PACKET_HEADER_SIZE = 4;
|
||||
const size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
|
||||
|
||||
namespace Authentication
|
||||
{
|
||||
const String SHA256 = "sha256_password"; /// Caching SHA2 plugin is not used because it would be possible to authenticate knowing hash from users.xml.
|
||||
}
|
||||
|
||||
enum CharacterSet
|
||||
{
|
||||
utf8_general_ci = 33,
|
||||
binary = 63
|
||||
};
|
||||
|
||||
enum StatusFlags
|
||||
{
|
||||
SERVER_SESSION_STATE_CHANGED = 0x4000
|
||||
};
|
||||
|
||||
enum Capability
|
||||
{
|
||||
CLIENT_CONNECT_WITH_DB = 0x00000008,
|
||||
CLIENT_PROTOCOL_41 = 0x00000200,
|
||||
CLIENT_SSL = 0x00000800,
|
||||
CLIENT_TRANSACTIONS = 0x00002000, // TODO
|
||||
CLIENT_SESSION_TRACK = 0x00800000, // TODO
|
||||
CLIENT_SECURE_CONNECTION = 0x00008000,
|
||||
CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 0x00200000,
|
||||
CLIENT_PLUGIN_AUTH = 0x00080000,
|
||||
CLIENT_DEPRECATE_EOF = 0x01000000,
|
||||
};
|
||||
|
||||
enum Command
|
||||
{
|
||||
COM_SLEEP = 0x0,
|
||||
COM_QUIT = 0x1,
|
||||
COM_INIT_DB = 0x2,
|
||||
COM_QUERY = 0x3,
|
||||
COM_FIELD_LIST = 0x4,
|
||||
COM_CREATE_DB = 0x5,
|
||||
COM_DROP_DB = 0x6,
|
||||
COM_REFRESH = 0x7,
|
||||
COM_SHUTDOWN = 0x8,
|
||||
COM_STATISTICS = 0x9,
|
||||
COM_PROCESS_INFO = 0xa,
|
||||
COM_CONNECT = 0xb,
|
||||
COM_PROCESS_KILL = 0xc,
|
||||
COM_DEBUG = 0xd,
|
||||
COM_PING = 0xe,
|
||||
COM_TIME = 0xf,
|
||||
COM_DELAYED_INSERT = 0x10,
|
||||
COM_CHANGE_USER = 0x11,
|
||||
COM_RESET_CONNECTION = 0x1f,
|
||||
COM_DAEMON = 0x1d
|
||||
};
|
||||
|
||||
enum ColumnType
|
||||
{
|
||||
MYSQL_TYPE_DECIMAL = 0x00,
|
||||
MYSQL_TYPE_TINY = 0x01,
|
||||
MYSQL_TYPE_SHORT = 0x02,
|
||||
MYSQL_TYPE_LONG = 0x03,
|
||||
MYSQL_TYPE_FLOAT = 0x04,
|
||||
MYSQL_TYPE_DOUBLE = 0x05,
|
||||
MYSQL_TYPE_NULL = 0x06,
|
||||
MYSQL_TYPE_TIMESTAMP = 0x07,
|
||||
MYSQL_TYPE_LONGLONG = 0x08,
|
||||
MYSQL_TYPE_INT24 = 0x09,
|
||||
MYSQL_TYPE_DATE = 0x0a,
|
||||
MYSQL_TYPE_TIME = 0x0b,
|
||||
MYSQL_TYPE_DATETIME = 0x0c,
|
||||
MYSQL_TYPE_YEAR = 0x0d,
|
||||
MYSQL_TYPE_VARCHAR = 0x0f,
|
||||
MYSQL_TYPE_BIT = 0x10,
|
||||
MYSQL_TYPE_NEWDECIMAL = 0xf6,
|
||||
MYSQL_TYPE_ENUM = 0xf7,
|
||||
MYSQL_TYPE_SET = 0xf8,
|
||||
MYSQL_TYPE_TINY_BLOB = 0xf9,
|
||||
MYSQL_TYPE_MEDIUM_BLOB = 0xfa,
|
||||
MYSQL_TYPE_LONG_BLOB = 0xfb,
|
||||
MYSQL_TYPE_BLOB = 0xfc,
|
||||
MYSQL_TYPE_VAR_STRING = 0xfd,
|
||||
MYSQL_TYPE_STRING = 0xfe,
|
||||
MYSQL_TYPE_GEOMETRY = 0xff
|
||||
};
|
||||
|
||||
|
||||
class ProtocolError : public DB::Exception
|
||||
{
|
||||
public:
|
||||
using Exception::Exception;
|
||||
};
|
||||
|
||||
|
||||
class WritePacket
|
||||
{
|
||||
public:
|
||||
virtual String getPayload() const = 0;
|
||||
|
||||
virtual ~WritePacket() = default;
|
||||
};
|
||||
|
||||
|
||||
class ReadPacket
|
||||
{
|
||||
public:
|
||||
ReadPacket() = default;
|
||||
ReadPacket(const ReadPacket &) = default;
|
||||
virtual void readPayload(String payload) = 0;
|
||||
|
||||
virtual ~ReadPacket() = default;
|
||||
};
|
||||
|
||||
|
||||
/* Writes and reads packets, keeping sequence-id.
|
||||
* Throws ProtocolError, if packet with incorrect sequence-id was received.
|
||||
*/
|
||||
class PacketSender
|
||||
{
|
||||
public:
|
||||
size_t & sequence_id;
|
||||
ReadBuffer * in;
|
||||
WriteBuffer * out;
|
||||
size_t max_packet_size = MAX_PACKET_LENGTH;
|
||||
|
||||
/// For reading and writing.
|
||||
PacketSender(ReadBuffer & in, WriteBuffer & out, size_t & sequence_id)
|
||||
: sequence_id(sequence_id)
|
||||
, in(&in)
|
||||
, out(&out)
|
||||
{
|
||||
}
|
||||
|
||||
/// For writing.
|
||||
PacketSender(WriteBuffer & out, size_t & sequence_id)
|
||||
: sequence_id(sequence_id)
|
||||
, in(nullptr)
|
||||
, out(&out)
|
||||
{
|
||||
}
|
||||
|
||||
String receivePacketPayload()
|
||||
{
|
||||
WriteBufferFromOwnString buf;
|
||||
|
||||
size_t payload_length = 0;
|
||||
size_t packet_sequence_id = 0;
|
||||
|
||||
// packets which are larger than or equal to 16MB are splitted
|
||||
do
|
||||
{
|
||||
in->readStrict(reinterpret_cast<char *>(&payload_length), 3);
|
||||
|
||||
if (payload_length > max_packet_size)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
|
||||
in->readStrict(reinterpret_cast<char *>(&packet_sequence_id), 1);
|
||||
|
||||
if (packet_sequence_id != sequence_id)
|
||||
{
|
||||
std::ostringstream tmp;
|
||||
tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << sequence_id << '.';
|
||||
throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
sequence_id++;
|
||||
|
||||
copyData(*in, static_cast<WriteBuffer &>(buf), payload_length);
|
||||
} while (payload_length == max_packet_size);
|
||||
|
||||
return std::move(buf.str());
|
||||
}
|
||||
|
||||
void receivePacket(ReadPacket & packet)
|
||||
{
|
||||
packet.readPayload(receivePacketPayload());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void sendPacket(const T & packet, bool flush = false)
|
||||
{
|
||||
static_assert(std::is_base_of<WritePacket, T>());
|
||||
String payload = packet.getPayload();
|
||||
size_t pos = 0;
|
||||
do
|
||||
{
|
||||
size_t payload_length = std::min(payload.length() - pos, max_packet_size);
|
||||
|
||||
out->write(reinterpret_cast<const char *>(&payload_length), 3);
|
||||
out->write(reinterpret_cast<const char *>(&sequence_id), 1);
|
||||
out->write(payload.data() + pos, payload_length);
|
||||
|
||||
pos += payload_length;
|
||||
sequence_id++;
|
||||
} while (pos < payload.length());
|
||||
|
||||
if (flush)
|
||||
out->next();
|
||||
}
|
||||
|
||||
/// Sets sequence-id to 0. Must be called before each command phase.
|
||||
void resetSequenceId();
|
||||
|
||||
private:
|
||||
/// Converts packet to text. Is used for debug output.
|
||||
static String packetToText(String payload);
|
||||
};
|
||||
|
||||
|
||||
uint64_t readLengthEncodedNumber(std::istringstream & ss);
|
||||
|
||||
String writeLengthEncodedNumber(uint64_t x);
|
||||
|
||||
void writeLengthEncodedString(String & payload, const String & s);
|
||||
|
||||
void writeNulTerminatedString(String & payload, const String & s);
|
||||
|
||||
|
||||
class Handshake : public WritePacket
|
||||
{
|
||||
int protocol_version = 0xa;
|
||||
String server_version;
|
||||
uint32_t connection_id;
|
||||
uint32_t capability_flags;
|
||||
uint8_t character_set;
|
||||
uint32_t status_flags;
|
||||
String auth_plugin_data;
|
||||
public:
|
||||
explicit Handshake(uint32_t capability_flags, uint32_t connection_id, String server_version, String auth_plugin_data)
|
||||
: protocol_version(0xa)
|
||||
, server_version(std::move(server_version))
|
||||
, connection_id(connection_id)
|
||||
, capability_flags(capability_flags)
|
||||
, character_set(CharacterSet::utf8_general_ci)
|
||||
, status_flags(0)
|
||||
, auth_plugin_data(auth_plugin_data)
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, protocol_version);
|
||||
writeNulTerminatedString(result, server_version);
|
||||
result.append(reinterpret_cast<const char *>(&connection_id), 4);
|
||||
writeNulTerminatedString(result, auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH));
|
||||
result.append(reinterpret_cast<const char *>(&capability_flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&character_set), 1);
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
result.append((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
|
||||
result.append(1, auth_plugin_data.size());
|
||||
result.append(10, 0x0);
|
||||
result.append(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH));
|
||||
result.append(Authentication::SHA256);
|
||||
result.append(1, 0x0);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class SSLRequest : public ReadPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags;
|
||||
uint32_t max_packet_size;
|
||||
uint8_t character_set;
|
||||
|
||||
void readPayload(String s) override
|
||||
{
|
||||
std::istringstream ss(s);
|
||||
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
|
||||
}
|
||||
};
|
||||
|
||||
class HandshakeResponse : public ReadPacket
|
||||
{
|
||||
public:
|
||||
uint32_t capability_flags;
|
||||
uint32_t max_packet_size;
|
||||
uint8_t character_set;
|
||||
String username;
|
||||
String auth_response;
|
||||
String database;
|
||||
String auth_plugin_name;
|
||||
|
||||
HandshakeResponse() = default;
|
||||
|
||||
HandshakeResponse(const HandshakeResponse &) = default;
|
||||
|
||||
void readPayload(String s) override
|
||||
{
|
||||
std::istringstream ss(s);
|
||||
|
||||
ss.readsome(reinterpret_cast<char *>(&capability_flags), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&max_packet_size), 4);
|
||||
ss.readsome(reinterpret_cast<char *>(&character_set), 1);
|
||||
ss.ignore(23);
|
||||
|
||||
std::getline(ss, username, static_cast<char>(0x0));
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||
{
|
||||
auto len = readLengthEncodedNumber(ss);
|
||||
auth_response.resize(len);
|
||||
ss.read(auth_response.data(), static_cast<std::streamsize>(len));
|
||||
}
|
||||
else if (capability_flags & CLIENT_SECURE_CONNECTION)
|
||||
{
|
||||
uint8_t len;
|
||||
ss.read(reinterpret_cast<char *>(&len), 1);
|
||||
auth_response.resize(len);
|
||||
ss.read(auth_response.data(), len);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::getline(ss, auth_response, static_cast<char>(0x0));
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_CONNECT_WITH_DB)
|
||||
{
|
||||
std::getline(ss, database, static_cast<char>(0x0));
|
||||
}
|
||||
|
||||
if (capability_flags & CLIENT_PLUGIN_AUTH)
|
||||
{
|
||||
std::getline(ss, auth_plugin_name, static_cast<char>(0x0));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class AuthSwitchRequest : public WritePacket
|
||||
{
|
||||
String plugin_name;
|
||||
String auth_plugin_data;
|
||||
public:
|
||||
AuthSwitchRequest(String plugin_name, String auth_plugin_data)
|
||||
: plugin_name(std::move(plugin_name)), auth_plugin_data(std::move(auth_plugin_data))
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xfe);
|
||||
writeNulTerminatedString(result, plugin_name);
|
||||
result.append(auth_plugin_data);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class AuthSwitchResponse : public ReadPacket
|
||||
{
|
||||
public:
|
||||
String value;
|
||||
|
||||
void readPayload(String s) override
|
||||
{
|
||||
value = std::move(s);
|
||||
}
|
||||
};
|
||||
|
||||
class AuthMoreData : public WritePacket
|
||||
{
|
||||
String data;
|
||||
public:
|
||||
AuthMoreData(String data): data(std::move(data)) {}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0x01);
|
||||
result.append(data);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
/// Packet with a single null-terminated string. Is used for clear text authentication.
|
||||
class NullTerminatedString : public ReadPacket
|
||||
{
|
||||
public:
|
||||
String value;
|
||||
|
||||
void readPayload(String s) override
|
||||
{
|
||||
if (s.length() == 0 || s.back() != 0)
|
||||
{
|
||||
throw ProtocolError("String is not null terminated.", ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
|
||||
}
|
||||
value = s;
|
||||
value.pop_back();
|
||||
}
|
||||
};
|
||||
|
||||
class OK_Packet : public WritePacket
|
||||
{
|
||||
uint8_t header;
|
||||
uint32_t capabilities;
|
||||
uint64_t affected_rows;
|
||||
int16_t warnings = 0;
|
||||
uint32_t status_flags;
|
||||
String session_state_changes;
|
||||
String info;
|
||||
public:
|
||||
OK_Packet(uint8_t header,
|
||||
uint32_t capabilities,
|
||||
uint64_t affected_rows,
|
||||
uint32_t status_flags,
|
||||
int16_t warnings,
|
||||
String session_state_changes = "",
|
||||
String info = "")
|
||||
: header(header)
|
||||
, capabilities(capabilities)
|
||||
, affected_rows(affected_rows)
|
||||
, warnings(warnings)
|
||||
, status_flags(status_flags)
|
||||
, session_state_changes(std::move(session_state_changes))
|
||||
, info(info)
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, header);
|
||||
result.append(writeLengthEncodedNumber(affected_rows));
|
||||
result.append(writeLengthEncodedNumber(0)); /// last insert-id
|
||||
|
||||
if (capabilities & CLIENT_PROTOCOL_41)
|
||||
{
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&warnings), 2);
|
||||
}
|
||||
else if (capabilities & CLIENT_TRANSACTIONS)
|
||||
{
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
}
|
||||
|
||||
if (capabilities & CLIENT_SESSION_TRACK)
|
||||
{
|
||||
result.append(writeLengthEncodedNumber(info.length()));
|
||||
result.append(info);
|
||||
if (status_flags & SERVER_SESSION_STATE_CHANGED)
|
||||
{
|
||||
result.append(writeLengthEncodedNumber(session_state_changes.length()));
|
||||
result.append(session_state_changes);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
result.append(info);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class EOF_Packet : public WritePacket
|
||||
{
|
||||
int warnings;
|
||||
int status_flags;
|
||||
public:
|
||||
EOF_Packet(int warnings, int status_flags) : warnings(warnings), status_flags(status_flags)
|
||||
{}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xfe); // EOF header
|
||||
result.append(reinterpret_cast<const char *>(&warnings), 2);
|
||||
result.append(reinterpret_cast<const char *>(&status_flags), 2);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class ERR_Packet : public WritePacket
|
||||
{
|
||||
int error_code;
|
||||
String sql_state;
|
||||
String error_message;
|
||||
public:
|
||||
ERR_Packet(int error_code, String sql_state, String error_message)
|
||||
: error_code(error_code), sql_state(std::move(sql_state)), error_message(std::move(error_message))
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
result.append(1, 0xff);
|
||||
result.append(reinterpret_cast<const char *>(&error_code), 2);
|
||||
result.append("#", 1);
|
||||
result.append(sql_state.data(), sql_state.length());
|
||||
result.append(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class ColumnDefinition : public WritePacket
|
||||
{
|
||||
String schema;
|
||||
String table;
|
||||
String org_table;
|
||||
String name;
|
||||
String org_name;
|
||||
size_t next_length = 0x0c;
|
||||
uint16_t character_set;
|
||||
uint32_t column_length;
|
||||
ColumnType column_type;
|
||||
uint16_t flags;
|
||||
uint8_t decimals = 0x00;
|
||||
public:
|
||||
ColumnDefinition(
|
||||
String schema,
|
||||
String table,
|
||||
String org_table,
|
||||
String name,
|
||||
String org_name,
|
||||
uint16_t character_set,
|
||||
uint32_t column_length,
|
||||
ColumnType column_type,
|
||||
uint16_t flags,
|
||||
uint8_t decimals)
|
||||
|
||||
: schema(std::move(schema)), table(std::move(table)), org_table(std::move(org_table)), name(std::move(name)),
|
||||
org_name(std::move(org_name)), character_set(character_set), column_length(column_length), column_type(column_type), flags(flags),
|
||||
decimals(decimals)
|
||||
{
|
||||
}
|
||||
|
||||
/// Should be used when column metadata (original name, table, original table, database) is unknown.
|
||||
ColumnDefinition(
|
||||
String name,
|
||||
uint16_t character_set,
|
||||
uint32_t column_length,
|
||||
ColumnType column_type,
|
||||
uint16_t flags,
|
||||
uint8_t decimals)
|
||||
: ColumnDefinition("", "", "", std::move(name), "", character_set, column_length, column_type, flags, decimals)
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
writeLengthEncodedString(result, "def"); /// always "def"
|
||||
writeLengthEncodedString(result, ""); /// schema
|
||||
writeLengthEncodedString(result, ""); /// table
|
||||
writeLengthEncodedString(result, ""); /// org_table
|
||||
writeLengthEncodedString(result, name);
|
||||
writeLengthEncodedString(result, ""); /// org_name
|
||||
result.append(writeLengthEncodedNumber(next_length));
|
||||
result.append(reinterpret_cast<const char *>(&character_set), 2);
|
||||
result.append(reinterpret_cast<const char *>(&column_length), 4);
|
||||
result.append(reinterpret_cast<const char *>(&column_type), 1);
|
||||
result.append(reinterpret_cast<const char *>(&flags), 2);
|
||||
result.append(reinterpret_cast<const char *>(&decimals), 2);
|
||||
result.append(2, 0x0);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
class ComFieldList : public ReadPacket
|
||||
{
|
||||
public:
|
||||
String table, field_wildcard;
|
||||
|
||||
void readPayload(String payload)
|
||||
{
|
||||
std::istringstream ss(payload);
|
||||
ss.ignore(1); // command byte
|
||||
std::getline(ss, table, static_cast<char>(0x0));
|
||||
field_wildcard = payload.substr(table.length() + 2); // rest of the packet
|
||||
}
|
||||
};
|
||||
|
||||
class LengthEncodedNumber : public WritePacket
|
||||
{
|
||||
uint64_t value;
|
||||
public:
|
||||
LengthEncodedNumber(uint64_t value): value(value)
|
||||
{
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
return writeLengthEncodedNumber(value);
|
||||
}
|
||||
};
|
||||
|
||||
class ResultsetRow : public WritePacket
|
||||
{
|
||||
std::vector<String> columns;
|
||||
public:
|
||||
ResultsetRow()
|
||||
{
|
||||
}
|
||||
|
||||
void appendColumn(String value)
|
||||
{
|
||||
columns.emplace_back(std::move(value));
|
||||
}
|
||||
|
||||
String getPayload() const override
|
||||
{
|
||||
String result;
|
||||
for (const String & column : columns)
|
||||
{
|
||||
writeLengthEncodedString(result, column);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
@ -170,6 +170,46 @@ template <> struct NativeType<Decimal32> { using Type = Int32; };
|
||||
template <> struct NativeType<Decimal64> { using Type = Int64; };
|
||||
template <> struct NativeType<Decimal128> { using Type = Int128; };
|
||||
|
||||
inline const char * getTypeName(TypeIndex idx)
|
||||
{
|
||||
switch (idx)
|
||||
{
|
||||
case TypeIndex::Nothing: return "Nothing";
|
||||
case TypeIndex::UInt8: return TypeName<UInt8>::get();
|
||||
case TypeIndex::UInt16: return TypeName<UInt16>::get();
|
||||
case TypeIndex::UInt32: return TypeName<UInt32>::get();
|
||||
case TypeIndex::UInt64: return TypeName<UInt64>::get();
|
||||
case TypeIndex::UInt128: return "UInt128";
|
||||
case TypeIndex::Int8: return TypeName<Int8>::get();
|
||||
case TypeIndex::Int16: return TypeName<Int16>::get();
|
||||
case TypeIndex::Int32: return TypeName<Int32>::get();
|
||||
case TypeIndex::Int64: return TypeName<Int64>::get();
|
||||
case TypeIndex::Int128: return TypeName<Int128>::get();
|
||||
case TypeIndex::Float32: return TypeName<Float32>::get();
|
||||
case TypeIndex::Float64: return TypeName<Float64>::get();
|
||||
case TypeIndex::Date: return "Date";
|
||||
case TypeIndex::DateTime: return "DateTime";
|
||||
case TypeIndex::String: return TypeName<String>::get();
|
||||
case TypeIndex::FixedString: return "FixedString";
|
||||
case TypeIndex::Enum8: return "Enum8";
|
||||
case TypeIndex::Enum16: return "Enum16";
|
||||
case TypeIndex::Decimal32: return TypeName<Decimal32>::get();
|
||||
case TypeIndex::Decimal64: return TypeName<Decimal64>::get();
|
||||
case TypeIndex::Decimal128: return TypeName<Decimal128>::get();
|
||||
case TypeIndex::UUID: return "UUID";
|
||||
case TypeIndex::Array: return "Array";
|
||||
case TypeIndex::Tuple: return "Tuple";
|
||||
case TypeIndex::Set: return "Set";
|
||||
case TypeIndex::Interval: return "Interval";
|
||||
case TypeIndex::Nullable: return "Nullable";
|
||||
case TypeIndex::Function: return "Function";
|
||||
case TypeIndex::AggregateFunction: return "AggregateFunction";
|
||||
case TypeIndex::LowCardinality: return "LowCardinality";
|
||||
}
|
||||
|
||||
__builtin_unreachable();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/// Specialization of `std::hash` for the Decimal<T> types.
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <Functions/IFunction.h>
|
||||
#include <IO/WriteBufferFromOStream.h>
|
||||
#include <Interpreters/ExpressionAnalyzer.h>
|
||||
#include <Interpreters/ExpressionActions.h>
|
||||
#include <Parsers/IAST.h>
|
||||
#include <Storages/IStorage.h>
|
||||
#include <Common/COW.h>
|
||||
@ -70,7 +71,7 @@ std::ostream & operator<<(std::ostream & stream, const Block & what)
|
||||
|
||||
std::ostream & operator<<(std::ostream & stream, const ColumnWithTypeAndName & what)
|
||||
{
|
||||
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << what.type << ", column = ";
|
||||
stream << "ColumnWithTypeAndName(name = " << what.name << ", type = " << *what.type << ", column = ";
|
||||
return dumpValue(stream, what.column) << ")";
|
||||
}
|
||||
|
||||
@ -109,4 +110,56 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what)
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what)
|
||||
{
|
||||
stream << "ExpressionAction(" << what.toString() << ")";
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what)
|
||||
{
|
||||
stream << "ExpressionActions(" << what.dumpActions() << ")";
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what)
|
||||
{
|
||||
stream << "SyntaxAnalyzerResult{";
|
||||
stream << "storage=" << what.storage << "; ";
|
||||
if (!what.source_columns.empty())
|
||||
{
|
||||
stream << "source_columns=";
|
||||
dumpValue(stream, what.source_columns);
|
||||
stream << "; ";
|
||||
}
|
||||
if (!what.aliases.empty())
|
||||
{
|
||||
stream << "aliases=";
|
||||
dumpValue(stream, what.aliases);
|
||||
stream << "; ";
|
||||
}
|
||||
if (!what.array_join_result_to_source.empty())
|
||||
{
|
||||
stream << "array_join_result_to_source=";
|
||||
dumpValue(stream, what.array_join_result_to_source);
|
||||
stream << "; ";
|
||||
}
|
||||
if (!what.array_join_alias_to_name.empty())
|
||||
{
|
||||
stream << "array_join_alias_to_name=";
|
||||
dumpValue(stream, what.array_join_alias_to_name);
|
||||
stream << "; ";
|
||||
}
|
||||
if (!what.array_join_name_to_alias.empty())
|
||||
{
|
||||
stream << "array_join_name_to_alias=";
|
||||
dumpValue(stream, what.array_join_name_to_alias);
|
||||
stream << "; ";
|
||||
}
|
||||
stream << "rewrite_subqueries=" << what.rewrite_subqueries << "; ";
|
||||
stream << "}";
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -41,6 +41,14 @@ std::ostream & operator<<(std::ostream & stream, const IAST & what);
|
||||
|
||||
std::ostream & operator<<(std::ostream & stream, const Connection::Packet & what);
|
||||
|
||||
struct ExpressionAction;
|
||||
std::ostream & operator<<(std::ostream & stream, const ExpressionAction & what);
|
||||
|
||||
class ExpressionActions;
|
||||
std::ostream & operator<<(std::ostream & stream, const ExpressionActions & what);
|
||||
|
||||
struct SyntaxAnalyzerResult;
|
||||
std::ostream & operator<<(std::ostream & stream, const SyntaxAnalyzerResult & what);
|
||||
}
|
||||
|
||||
/// some operator<< should be declared before operator<<(... std::shared_ptr<>)
|
||||
|
@ -19,8 +19,8 @@ void CountingBlockOutputStream::write(const Block & block)
|
||||
Progress local_progress(block.rows(), block.bytes(), 0);
|
||||
progress.incrementPiecewiseAtomically(local_progress);
|
||||
|
||||
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.rows);
|
||||
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.bytes);
|
||||
ProfileEvents::increment(ProfileEvents::InsertedRows, local_progress.read_rows);
|
||||
ProfileEvents::increment(ProfileEvents::InsertedBytes, local_progress.read_bytes);
|
||||
|
||||
if (process_elem)
|
||||
process_elem->updateProgressOut(local_progress);
|
||||
|
@ -281,7 +281,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
/// The total amount of data processed or intended for processing in all leaf sources, possibly on remote servers.
|
||||
|
||||
ProgressValues progress = process_list_elem->getProgressIn();
|
||||
size_t total_rows_estimate = std::max(progress.rows, progress.total_rows);
|
||||
size_t total_rows_estimate = std::max(progress.read_rows, progress.total_rows_to_read);
|
||||
|
||||
/** Check the restrictions on the amount of data to read, the speed of the query, the quota on the amount of data to read.
|
||||
* NOTE: Maybe it makes sense to have them checked directly in ProcessList?
|
||||
@ -289,7 +289,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
|
||||
if (limits.mode == LIMITS_TOTAL
|
||||
&& ((limits.size_limits.max_rows && total_rows_estimate > limits.size_limits.max_rows)
|
||||
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes)))
|
||||
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes)))
|
||||
{
|
||||
switch (limits.size_limits.overflow_mode)
|
||||
{
|
||||
@ -300,7 +300,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
+ " rows read (or to read), maximum: " + toString(limits.size_limits.max_rows),
|
||||
ErrorCodes::TOO_MANY_ROWS);
|
||||
else
|
||||
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.bytes)
|
||||
throw Exception("Limit for (uncompressed) bytes to read exceeded: " + toString(progress.read_bytes)
|
||||
+ " bytes read, maximum: " + toString(limits.size_limits.max_bytes),
|
||||
ErrorCodes::TOO_MANY_BYTES);
|
||||
}
|
||||
@ -308,8 +308,8 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
case OverflowMode::BREAK:
|
||||
{
|
||||
/// For `break`, we will stop only if so many rows were actually read, and not just supposed to be read.
|
||||
if ((limits.size_limits.max_rows && progress.rows > limits.size_limits.max_rows)
|
||||
|| (limits.size_limits.max_bytes && progress.bytes > limits.size_limits.max_bytes))
|
||||
if ((limits.size_limits.max_rows && progress.read_rows > limits.size_limits.max_rows)
|
||||
|| (limits.size_limits.max_bytes && progress.read_bytes > limits.size_limits.max_bytes))
|
||||
{
|
||||
cancel(false);
|
||||
}
|
||||
@ -322,7 +322,7 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
}
|
||||
}
|
||||
|
||||
size_t total_rows = progress.total_rows;
|
||||
size_t total_rows = progress.total_rows_to_read;
|
||||
|
||||
constexpr UInt64 profile_events_update_period_microseconds = 10 * 1000; // 10 milliseconds
|
||||
UInt64 total_elapsed_microseconds = info.total_stopwatch.elapsedMicroseconds();
|
||||
@ -344,20 +344,20 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
|
||||
if (elapsed_seconds > 0)
|
||||
{
|
||||
if (limits.min_execution_speed && progress.rows / elapsed_seconds < limits.min_execution_speed)
|
||||
throw Exception("Query is executing too slow: " + toString(progress.rows / elapsed_seconds)
|
||||
if (limits.min_execution_speed && progress.read_rows / elapsed_seconds < limits.min_execution_speed)
|
||||
throw Exception("Query is executing too slow: " + toString(progress.read_rows / elapsed_seconds)
|
||||
+ " rows/sec., minimum: " + toString(limits.min_execution_speed),
|
||||
ErrorCodes::TOO_SLOW);
|
||||
|
||||
if (limits.min_execution_speed_bytes && progress.bytes / elapsed_seconds < limits.min_execution_speed_bytes)
|
||||
throw Exception("Query is executing too slow: " + toString(progress.bytes / elapsed_seconds)
|
||||
if (limits.min_execution_speed_bytes && progress.read_bytes / elapsed_seconds < limits.min_execution_speed_bytes)
|
||||
throw Exception("Query is executing too slow: " + toString(progress.read_bytes / elapsed_seconds)
|
||||
+ " bytes/sec., minimum: " + toString(limits.min_execution_speed_bytes),
|
||||
ErrorCodes::TOO_SLOW);
|
||||
|
||||
/// If the predicted execution time is longer than `max_execution_time`.
|
||||
if (limits.max_execution_time != 0 && total_rows)
|
||||
{
|
||||
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.rows);
|
||||
double estimated_execution_time_seconds = elapsed_seconds * (static_cast<double>(total_rows) / progress.read_rows);
|
||||
|
||||
if (estimated_execution_time_seconds > limits.max_execution_time.totalSeconds())
|
||||
throw Exception("Estimated query execution time (" + toString(estimated_execution_time_seconds) + " seconds)"
|
||||
@ -366,17 +366,17 @@ void IBlockInputStream::progressImpl(const Progress & value)
|
||||
ErrorCodes::TOO_SLOW);
|
||||
}
|
||||
|
||||
if (limits.max_execution_speed && progress.rows / elapsed_seconds >= limits.max_execution_speed)
|
||||
limitProgressingSpeed(progress.rows, limits.max_execution_speed, total_elapsed_microseconds);
|
||||
if (limits.max_execution_speed && progress.read_rows / elapsed_seconds >= limits.max_execution_speed)
|
||||
limitProgressingSpeed(progress.read_rows, limits.max_execution_speed, total_elapsed_microseconds);
|
||||
|
||||
if (limits.max_execution_speed_bytes && progress.bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
|
||||
limitProgressingSpeed(progress.bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
|
||||
if (limits.max_execution_speed_bytes && progress.read_bytes / elapsed_seconds >= limits.max_execution_speed_bytes)
|
||||
limitProgressingSpeed(progress.read_bytes, limits.max_execution_speed_bytes, total_elapsed_microseconds);
|
||||
}
|
||||
}
|
||||
|
||||
if (quota != nullptr && limits.mode == LIMITS_TOTAL)
|
||||
{
|
||||
quota->checkAndAddReadRowsBytes(time(nullptr), value.rows, value.bytes);
|
||||
quota->checkAndAddReadRowsBytes(time(nullptr), value.read_rows, value.read_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -36,6 +36,7 @@ public:
|
||||
bool canBeInsideNullable() const override { return false; }
|
||||
|
||||
DataTypePtr getReturnType() const { return function->getReturnType(); }
|
||||
DataTypePtr getReturnTypeToPredict() const { return function->getReturnTypeToPredict(); }
|
||||
DataTypes getArgumentsDataTypes() const { return argument_types; }
|
||||
|
||||
/// NOTE These two functions for serializing single values are incompatible with the functions below.
|
||||
|
@ -821,7 +821,7 @@ MutableColumnUniquePtr DataTypeLowCardinality::createColumnUniqueImpl(const IDat
|
||||
return creator(static_cast<ColumnVector<UInt16> *>(nullptr));
|
||||
if (typeid_cast<const DataTypeDateTime *>(type))
|
||||
return creator(static_cast<ColumnVector<UInt32> *>(nullptr));
|
||||
if (isNumber(type))
|
||||
if (isColumnedAsNumber(type))
|
||||
{
|
||||
MutableColumnUniquePtr column;
|
||||
TypeListNumbers::forEach(CreateColumnVector(column, *type, creator));
|
||||
|
@ -581,11 +581,18 @@ inline bool isFloat(const T & data_type)
|
||||
return which.isFloat();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isNativeNumber(const T & data_type)
|
||||
{
|
||||
WhichDataType which(data_type);
|
||||
return which.isNativeInt() || which.isNativeUInt() || which.isFloat();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline bool isNumber(const T & data_type)
|
||||
{
|
||||
WhichDataType which(data_type);
|
||||
return which.isInt() || which.isUInt() || which.isFloat();
|
||||
return which.isInt() || which.isUInt() || which.isFloat() || which.isDecimal();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -86,7 +86,7 @@ void CacheDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<K
|
||||
{
|
||||
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
|
||||
|
||||
getItemsNumber<UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
|
||||
getItemsNumberImpl<UInt64, UInt64>(*hierarchical_attribute, ids, out, [&](const size_t) { return null_value; });
|
||||
}
|
||||
|
||||
|
||||
@ -207,9 +207,7 @@ void CacheDictionary::isInConstantVector(const Key child_id, const PaddedPODArra
|
||||
void CacheDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
|
||||
|
||||
@ -220,9 +218,7 @@ void CacheDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsString(attribute, ids, out, [&](const size_t row) { return def->getDataAt(row); });
|
||||
}
|
||||
@ -231,9 +227,7 @@ void CacheDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsString(attribute, ids, out, [&](const size_t) { return StringRef{def}; });
|
||||
}
|
||||
|
@ -221,11 +221,6 @@ private:
|
||||
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
|
||||
template <typename OutputType, typename DefaultGetter>
|
||||
void getItemsNumber(
|
||||
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename DefaultGetter>
|
||||
void getItemsNumberImpl(
|
||||
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const;
|
||||
|
@ -34,34 +34,6 @@ namespace ErrorCodes
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
template <typename OutputType, typename DefaultGetter>
|
||||
void CacheDictionary::getItemsNumber(
|
||||
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) \
|
||||
getItemsNumberImpl<TYPE, OutputType>(attribute, ids, out, std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename DefaultGetter>
|
||||
void CacheDictionary::getItemsNumberImpl(
|
||||
Attribute & attribute, const PaddedPODArray<Key> & ids, ResultArrayType<OutputType> & out, DefaultGetter && get_default) const
|
||||
|
@ -12,13 +12,11 @@ using TYPE = @NAME@;
|
||||
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ResultArrayType<TYPE> & out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return null_value; });
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -15,11 +15,9 @@ void CacheDictionary::get@NAME@(const std::string & attribute_name,
|
||||
ResultArrayType<TYPE> & out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t row) { return def[row]; });
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -12,11 +12,9 @@ using TYPE = @NAME@;
|
||||
void CacheDictionary::get@NAME@(const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE def, ResultArrayType<TYPE> & out) const
|
||||
{
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, ids, out, [&](const size_t) { return def; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, ids, out, [&](const size_t) { return def; });
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -77,9 +77,7 @@ void ComplexKeyCacheDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto null_value = StringRef{std::get<String>(attribute.null_values)};
|
||||
|
||||
@ -96,9 +94,7 @@ void ComplexKeyCacheDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsString(attribute, key_columns, out, [&](const size_t row) { return def->getDataAt(row); });
|
||||
}
|
||||
@ -113,9 +109,7 @@ void ComplexKeyCacheDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsString(attribute, key_columns, out, [&](const size_t) { return StringRef{def}; });
|
||||
}
|
||||
|
@ -256,34 +256,6 @@ private:
|
||||
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
template <typename OutputType, typename DefaultGetter>
|
||||
void
|
||||
getItemsNumber(Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) \
|
||||
getItemsNumberImpl<TYPE, OutputType>(attribute, key_columns, out, std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename DefaultGetter>
|
||||
void getItemsNumberImpl(
|
||||
Attribute & attribute, const Columns & key_columns, PaddedPODArray<OutputType> & out, DefaultGetter && get_default) const
|
||||
|
@ -13,12 +13,10 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name, co
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t) { return null_value; });
|
||||
}
|
||||
}
|
||||
|
@ -18,10 +18,8 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name,
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t row) { return def[row]; });
|
||||
}
|
||||
}
|
||||
|
@ -18,10 +18,8 @@ void ComplexKeyCacheDictionary::get@NAME@(const std::string & attribute_name,
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::@NAME@))
|
||||
throw Exception {name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::@NAME@);
|
||||
|
||||
getItemsNumber<TYPE>(attribute, key_columns, out, [&](const size_t) { return def; });
|
||||
getItemsNumberImpl<TYPE, TYPE>(attribute, key_columns, out, [&](const size_t) { return def; });
|
||||
}
|
||||
}
|
||||
|
@ -50,13 +50,11 @@ ComplexKeyHashedDictionary::ComplexKeyHashedDictionary(
|
||||
dict_struct.validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, \
|
||||
key_columns, \
|
||||
[&](const size_t row, const auto value) { out[row] = value; }, \
|
||||
@ -84,9 +82,7 @@ void ComplexKeyHashedDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
|
||||
|
||||
@ -108,11 +104,9 @@ void ComplexKeyHashedDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, \
|
||||
key_columns, \
|
||||
[&](const size_t row, const auto value) { out[row] = value; }, \
|
||||
@ -144,9 +138,7 @@ void ComplexKeyHashedDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -166,11 +158,9 @@ void ComplexKeyHashedDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, key_columns, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -199,9 +189,7 @@ void ComplexKeyHashedDictionary::getString(
|
||||
dict_struct.validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -566,34 +554,6 @@ ComplexKeyHashedDictionary::createAttributeWithType(const AttributeUnderlyingTyp
|
||||
}
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void ComplexKeyHashedDictionary::getItemsNumber(
|
||||
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) getItemsImpl<TYPE, OutputType>( \
|
||||
attribute, key_columns, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void ComplexKeyHashedDictionary::getItemsImpl(
|
||||
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
|
@ -218,16 +218,10 @@ private:
|
||||
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void
|
||||
getItemsNumber(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void
|
||||
getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
|
||||
template <typename T>
|
||||
bool setAttributeValueImpl(Attribute & attribute, const StringRef key, const T value);
|
||||
|
||||
|
@ -40,44 +40,6 @@ namespace
|
||||
} // namespace
|
||||
|
||||
|
||||
bool isAttributeTypeConvertibleTo(AttributeUnderlyingType from, AttributeUnderlyingType to)
|
||||
{
|
||||
if (from == to)
|
||||
return true;
|
||||
|
||||
/** This enum can be somewhat incomplete and the meaning may not coincide with NumberTraits.h.
|
||||
* (for example, because integers can not be converted to floats)
|
||||
* This is normal for a limited usage scope.
|
||||
*/
|
||||
if ((from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt16)
|
||||
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt32)
|
||||
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::UInt64)
|
||||
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::UInt32)
|
||||
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::UInt64)
|
||||
|| (from == AttributeUnderlyingType::UInt32 && to == AttributeUnderlyingType::UInt64)
|
||||
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int16)
|
||||
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int32)
|
||||
|| (from == AttributeUnderlyingType::UInt8 && to == AttributeUnderlyingType::Int64)
|
||||
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::Int32)
|
||||
|| (from == AttributeUnderlyingType::UInt16 && to == AttributeUnderlyingType::Int64)
|
||||
|| (from == AttributeUnderlyingType::UInt32 && to == AttributeUnderlyingType::Int64)
|
||||
|
||||
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int16)
|
||||
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int32)
|
||||
|| (from == AttributeUnderlyingType::Int8 && to == AttributeUnderlyingType::Int64)
|
||||
|| (from == AttributeUnderlyingType::Int16 && to == AttributeUnderlyingType::Int32)
|
||||
|| (from == AttributeUnderlyingType::Int16 && to == AttributeUnderlyingType::Int64)
|
||||
|| (from == AttributeUnderlyingType::Int32 && to == AttributeUnderlyingType::Int64)
|
||||
|
||||
|| (from == AttributeUnderlyingType::Float32 && to == AttributeUnderlyingType::Float64))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
AttributeUnderlyingType getAttributeUnderlyingType(const std::string & type)
|
||||
{
|
||||
static const std::unordered_map<std::string, AttributeUnderlyingType> dictionary{
|
||||
|
@ -13,6 +13,12 @@
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
namespace ErrorCodes
|
||||
{
|
||||
extern const int TYPE_MISMATCH;
|
||||
}
|
||||
|
||||
enum class AttributeUnderlyingType
|
||||
{
|
||||
UInt8,
|
||||
@ -33,14 +39,18 @@ enum class AttributeUnderlyingType
|
||||
};
|
||||
|
||||
|
||||
/** For implicit conversions in dictGet functions.
|
||||
*/
|
||||
bool isAttributeTypeConvertibleTo(AttributeUnderlyingType from, AttributeUnderlyingType to);
|
||||
|
||||
AttributeUnderlyingType getAttributeUnderlyingType(const std::string & type);
|
||||
|
||||
std::string toString(const AttributeUnderlyingType type);
|
||||
|
||||
/// Implicit conversions in dictGet functions is disabled.
|
||||
inline void checkAttributeType(const std::string & dict_name, const std::string & attribute_name,
|
||||
AttributeUnderlyingType attribute_type, AttributeUnderlyingType to)
|
||||
{
|
||||
if (attribute_type != to)
|
||||
throw Exception{dict_name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute_type)
|
||||
+ ", expected " + toString(to), ErrorCodes::TYPE_MISMATCH};
|
||||
}
|
||||
|
||||
/// Min and max lifetimes for a dictionary or it's entry
|
||||
using DictionaryLifetime = ExternalLoadableLifetime;
|
||||
|
@ -55,7 +55,7 @@ void FlatDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<Ke
|
||||
{
|
||||
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
|
||||
|
||||
getItemsNumber<UInt64>(
|
||||
getItemsImpl<UInt64, UInt64>(
|
||||
*hierarchical_attribute,
|
||||
ids,
|
||||
[&](const size_t row, const UInt64 value) { out[row] = value; },
|
||||
@ -117,13 +117,11 @@ void FlatDictionary::isInConstantVector(const Key child_id, const PaddedPODArray
|
||||
void FlatDictionary::get##TYPE(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ResultArrayType<TYPE> & out) const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return null_value; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -145,9 +143,7 @@ DECLARE(Decimal128)
|
||||
void FlatDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto & null_value = std::get<StringRef>(attribute.null_values);
|
||||
|
||||
@ -166,11 +162,9 @@ void FlatDictionary::getString(const std::string & attribute_name, const PaddedP
|
||||
ResultArrayType<TYPE> & out) const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t row) { return def[row]; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -193,9 +187,7 @@ void FlatDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -209,11 +201,9 @@ void FlatDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE def, ResultArrayType<TYPE> & out) const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -236,9 +226,7 @@ void FlatDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
FlatDictionary::getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -580,35 +568,6 @@ FlatDictionary::Attribute FlatDictionary::createAttributeWithType(const Attribut
|
||||
}
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void FlatDictionary::getItemsNumber(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) \
|
||||
getItemsImpl<TYPE, OutputType>(attribute, ids, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void FlatDictionary::getItemsImpl(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
@ -701,10 +660,10 @@ void FlatDictionary::setAttributeValue(Attribute & attribute, const Key id, cons
|
||||
break;
|
||||
|
||||
case AttributeUnderlyingType::Decimal32:
|
||||
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal128>());
|
||||
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
|
||||
break;
|
||||
case AttributeUnderlyingType::Decimal64:
|
||||
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal128>());
|
||||
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
|
||||
break;
|
||||
case AttributeUnderlyingType::Decimal128:
|
||||
setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());
|
||||
|
@ -206,10 +206,6 @@ private:
|
||||
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void getItemsNumber(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void getItemsImpl(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
@ -49,7 +49,7 @@ void HashedDictionary::toParent(const PaddedPODArray<Key> & ids, PaddedPODArray<
|
||||
{
|
||||
const auto null_value = std::get<UInt64>(hierarchical_attribute->null_values);
|
||||
|
||||
getItemsNumber<UInt64>(
|
||||
getItemsImpl<UInt64, UInt64>(
|
||||
*hierarchical_attribute,
|
||||
ids,
|
||||
[&](const size_t row, const UInt64 value) { out[row] = value; },
|
||||
@ -116,13 +116,11 @@ void HashedDictionary::isInConstantVector(const Key child_id, const PaddedPODArr
|
||||
const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return null_value; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -144,9 +142,7 @@ DECLARE(Decimal128)
|
||||
void HashedDictionary::getString(const std::string & attribute_name, const PaddedPODArray<Key> & ids, ColumnString * out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
|
||||
|
||||
@ -165,11 +161,9 @@ void HashedDictionary::getString(const std::string & attribute_name, const Padde
|
||||
ResultArrayType<TYPE> & out) const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t row) { return def[row]; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -192,9 +186,7 @@ void HashedDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const ColumnString * const def, ColumnString * const out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -208,11 +200,9 @@ void HashedDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const TYPE & def, ResultArrayType<TYPE> & out) const \
|
||||
{ \
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, ids, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -235,9 +225,7 @@ void HashedDictionary::getString(
|
||||
const std::string & attribute_name, const PaddedPODArray<Key> & ids, const String & def, ColumnString * const out) const
|
||||
{
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -324,7 +312,6 @@ void HashedDictionary::createAttributes()
|
||||
void HashedDictionary::blockToAttributes(const Block & block)
|
||||
{
|
||||
const auto & id_column = *block.safeGetByPosition(0).column;
|
||||
element_count += id_column.size();
|
||||
|
||||
for (const size_t attribute_idx : ext::range(0, attributes.size()))
|
||||
{
|
||||
@ -332,7 +319,8 @@ void HashedDictionary::blockToAttributes(const Block & block)
|
||||
auto & attribute = attributes[attribute_idx];
|
||||
|
||||
for (const auto row_idx : ext::range(0, id_column.size()))
|
||||
setAttributeValue(attribute, id_column[row_idx].get<UInt64>(), attribute_column[row_idx]);
|
||||
if (setAttributeValue(attribute, id_column[row_idx].get<UInt64>(), attribute_column[row_idx]))
|
||||
++element_count;
|
||||
}
|
||||
}
|
||||
|
||||
@ -567,34 +555,6 @@ HashedDictionary::Attribute HashedDictionary::createAttributeWithType(const Attr
|
||||
}
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void HashedDictionary::getItemsNumber(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) \
|
||||
getItemsImpl<TYPE, OutputType>(attribute, ids, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void HashedDictionary::getItemsImpl(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
@ -613,69 +573,56 @@ void HashedDictionary::getItemsImpl(
|
||||
|
||||
|
||||
template <typename T>
|
||||
void HashedDictionary::setAttributeValueImpl(Attribute & attribute, const Key id, const T value)
|
||||
bool HashedDictionary::setAttributeValueImpl(Attribute & attribute, const Key id, const T value)
|
||||
{
|
||||
auto & map = *std::get<CollectionPtrType<T>>(attribute.maps);
|
||||
map.insert({id, value});
|
||||
return map.insert({id, value}).second;
|
||||
}
|
||||
|
||||
void HashedDictionary::setAttributeValue(Attribute & attribute, const Key id, const Field & value)
|
||||
bool HashedDictionary::setAttributeValue(Attribute & attribute, const Key id, const Field & value)
|
||||
{
|
||||
switch (attribute.type)
|
||||
{
|
||||
case AttributeUnderlyingType::UInt8:
|
||||
setAttributeValueImpl<UInt8>(attribute, id, value.get<UInt64>());
|
||||
break;
|
||||
return setAttributeValueImpl<UInt8>(attribute, id, value.get<UInt64>());
|
||||
case AttributeUnderlyingType::UInt16:
|
||||
setAttributeValueImpl<UInt16>(attribute, id, value.get<UInt64>());
|
||||
break;
|
||||
return setAttributeValueImpl<UInt16>(attribute, id, value.get<UInt64>());
|
||||
case AttributeUnderlyingType::UInt32:
|
||||
setAttributeValueImpl<UInt32>(attribute, id, value.get<UInt64>());
|
||||
break;
|
||||
return setAttributeValueImpl<UInt32>(attribute, id, value.get<UInt64>());
|
||||
case AttributeUnderlyingType::UInt64:
|
||||
setAttributeValueImpl<UInt64>(attribute, id, value.get<UInt64>());
|
||||
break;
|
||||
return setAttributeValueImpl<UInt64>(attribute, id, value.get<UInt64>());
|
||||
case AttributeUnderlyingType::UInt128:
|
||||
setAttributeValueImpl<UInt128>(attribute, id, value.get<UInt128>());
|
||||
break;
|
||||
return setAttributeValueImpl<UInt128>(attribute, id, value.get<UInt128>());
|
||||
case AttributeUnderlyingType::Int8:
|
||||
setAttributeValueImpl<Int8>(attribute, id, value.get<Int64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Int8>(attribute, id, value.get<Int64>());
|
||||
case AttributeUnderlyingType::Int16:
|
||||
setAttributeValueImpl<Int16>(attribute, id, value.get<Int64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Int16>(attribute, id, value.get<Int64>());
|
||||
case AttributeUnderlyingType::Int32:
|
||||
setAttributeValueImpl<Int32>(attribute, id, value.get<Int64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Int32>(attribute, id, value.get<Int64>());
|
||||
case AttributeUnderlyingType::Int64:
|
||||
setAttributeValueImpl<Int64>(attribute, id, value.get<Int64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Int64>(attribute, id, value.get<Int64>());
|
||||
case AttributeUnderlyingType::Float32:
|
||||
setAttributeValueImpl<Float32>(attribute, id, value.get<Float64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Float32>(attribute, id, value.get<Float64>());
|
||||
case AttributeUnderlyingType::Float64:
|
||||
setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Float64>(attribute, id, value.get<Float64>());
|
||||
|
||||
case AttributeUnderlyingType::Decimal32:
|
||||
setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
|
||||
break;
|
||||
return setAttributeValueImpl<Decimal32>(attribute, id, value.get<Decimal32>());
|
||||
case AttributeUnderlyingType::Decimal64:
|
||||
setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
|
||||
break;
|
||||
return setAttributeValueImpl<Decimal64>(attribute, id, value.get<Decimal64>());
|
||||
case AttributeUnderlyingType::Decimal128:
|
||||
setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());
|
||||
break;
|
||||
return setAttributeValueImpl<Decimal128>(attribute, id, value.get<Decimal128>());
|
||||
|
||||
case AttributeUnderlyingType::String:
|
||||
{
|
||||
auto & map = *std::get<CollectionPtrType<StringRef>>(attribute.maps);
|
||||
const auto & string = value.get<String>();
|
||||
const auto string_in_arena = attribute.string_arena->insert(string.data(), string.size());
|
||||
map.insert({id, StringRef{string_in_arena, string.size()}});
|
||||
break;
|
||||
return map.insert({id, StringRef{string_in_arena, string.size()}}).second;
|
||||
}
|
||||
}
|
||||
|
||||
throw Exception{"Invalid attribute type", ErrorCodes::BAD_ARGUMENTS};
|
||||
}
|
||||
|
||||
const HashedDictionary::Attribute & HashedDictionary::getAttribute(const std::string & attribute_name) const
|
||||
|
@ -211,18 +211,14 @@ private:
|
||||
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void getItemsNumber(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void getItemsImpl(
|
||||
const Attribute & attribute, const PaddedPODArray<Key> & ids, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename T>
|
||||
void setAttributeValueImpl(Attribute & attribute, const Key id, const T value);
|
||||
bool setAttributeValueImpl(Attribute & attribute, const Key id, const T value);
|
||||
|
||||
void setAttributeValue(Attribute & attribute, const Key id, const Field & value);
|
||||
bool setAttributeValue(Attribute & attribute, const Key id, const Field & value);
|
||||
|
||||
const Attribute & getAttribute(const std::string & attribute_name) const;
|
||||
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "DictionaryStructure.h"
|
||||
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
namespace ErrorCodes
|
||||
@ -47,7 +48,6 @@ void registerDictionarySourceMysql(DictionarySourceFactory & factory)
|
||||
# include <Formats/MySQLBlockInputStream.h>
|
||||
# include "readInvalidateQuery.h"
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
static const UInt64 max_block_size = 8192;
|
||||
@ -71,6 +71,7 @@ MySQLDictionarySource::MySQLDictionarySource(
|
||||
, query_builder{dict_struct, db, table, where, IdentifierQuotingStyle::Backticks}
|
||||
, load_all_query{query_builder.composeLoadAllQuery()}
|
||||
, invalidate_query{config.getString(config_prefix + ".invalidate_query", "")}
|
||||
, close_connection{config.getBool(config_prefix + ".close_connection", false)}
|
||||
{
|
||||
}
|
||||
|
||||
@ -91,6 +92,7 @@ MySQLDictionarySource::MySQLDictionarySource(const MySQLDictionarySource & other
|
||||
, last_modification{other.last_modification}
|
||||
, invalidate_query{other.invalidate_query}
|
||||
, invalidate_query_response{other.invalidate_query_response}
|
||||
, close_connection{other.close_connection}
|
||||
{
|
||||
}
|
||||
|
||||
@ -117,7 +119,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadAll()
|
||||
last_modification = getLastModification();
|
||||
|
||||
LOG_TRACE(log, load_all_query);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_all_query, sample_block, max_block_size);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_all_query, sample_block, max_block_size, close_connection);
|
||||
}
|
||||
|
||||
BlockInputStreamPtr MySQLDictionarySource::loadUpdatedAll()
|
||||
@ -126,7 +128,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadUpdatedAll()
|
||||
|
||||
std::string load_update_query = getUpdateFieldAndDate();
|
||||
LOG_TRACE(log, load_update_query);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_update_query, sample_block, max_block_size);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), load_update_query, sample_block, max_block_size, close_connection);
|
||||
}
|
||||
|
||||
BlockInputStreamPtr MySQLDictionarySource::loadIds(const std::vector<UInt64> & ids)
|
||||
@ -134,7 +136,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadIds(const std::vector<UInt64> & i
|
||||
/// We do not log in here and do not update the modification time, as the request can be large, and often called.
|
||||
|
||||
const auto query = query_builder.composeLoadIdsQuery(ids);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size, close_connection);
|
||||
}
|
||||
|
||||
BlockInputStreamPtr MySQLDictionarySource::loadKeys(const Columns & key_columns, const std::vector<size_t> & requested_rows)
|
||||
@ -142,7 +144,7 @@ BlockInputStreamPtr MySQLDictionarySource::loadKeys(const Columns & key_columns,
|
||||
/// We do not log in here and do not update the modification time, as the request can be large, and often called.
|
||||
|
||||
const auto query = query_builder.composeLoadKeysQuery(key_columns, requested_rows, ExternalQueryBuilder::AND_OR_CHAIN);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size);
|
||||
return std::make_shared<MySQLBlockInputStream>(pool.Get(), query, sample_block, max_block_size, close_connection);
|
||||
}
|
||||
|
||||
bool MySQLDictionarySource::isModified() const
|
||||
@ -253,7 +255,7 @@ std::string MySQLDictionarySource::doInvalidateQuery(const std::string & request
|
||||
Block invalidate_sample_block;
|
||||
ColumnPtr column(ColumnString::create());
|
||||
invalidate_sample_block.insert(ColumnWithTypeAndName(column, std::make_shared<DataTypeString>(), "Sample Block"));
|
||||
MySQLBlockInputStream block_input_stream(pool.Get(), request, invalidate_sample_block, 1);
|
||||
MySQLBlockInputStream block_input_stream(pool.Get(), request, invalidate_sample_block, 1, close_connection);
|
||||
return readInvalidateQuery(block_input_stream);
|
||||
}
|
||||
|
||||
|
@ -81,6 +81,7 @@ private:
|
||||
LocalDateTime last_modification;
|
||||
std::string invalidate_query;
|
||||
mutable std::string invalidate_query_response;
|
||||
const bool close_connection;
|
||||
};
|
||||
|
||||
}
|
||||
|
@ -75,13 +75,11 @@ TrieDictionary::~TrieDictionary()
|
||||
validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
const auto null_value = std::get<TYPE>(attribute.null_values); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, \
|
||||
key_columns, \
|
||||
[&](const size_t row, const auto value) { out[row] = value; }, \
|
||||
@ -109,9 +107,7 @@ void TrieDictionary::getString(
|
||||
validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
const auto & null_value = StringRef{std::get<String>(attribute.null_values)};
|
||||
|
||||
@ -133,11 +129,9 @@ void TrieDictionary::getString(
|
||||
validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, \
|
||||
key_columns, \
|
||||
[&](const size_t row, const auto value) { out[row] = value; }, \
|
||||
@ -169,9 +163,7 @@ void TrieDictionary::getString(
|
||||
validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -191,11 +183,9 @@ void TrieDictionary::getString(
|
||||
validateKeyTypes(key_types); \
|
||||
\
|
||||
const auto & attribute = getAttribute(attribute_name); \
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::TYPE)) \
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type), \
|
||||
ErrorCodes::TYPE_MISMATCH}; \
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::TYPE); \
|
||||
\
|
||||
getItemsNumber<TYPE>( \
|
||||
getItemsImpl<TYPE, TYPE>( \
|
||||
attribute, key_columns, [&](const size_t row, const auto value) { out[row] = value; }, [&](const size_t) { return def; }); \
|
||||
}
|
||||
DECLARE(UInt8)
|
||||
@ -224,9 +214,7 @@ void TrieDictionary::getString(
|
||||
validateKeyTypes(key_types);
|
||||
|
||||
const auto & attribute = getAttribute(attribute_name);
|
||||
if (!isAttributeTypeConvertibleTo(attribute.type, AttributeUnderlyingType::String))
|
||||
throw Exception{name + ": type mismatch: attribute " + attribute_name + " has type " + toString(attribute.type),
|
||||
ErrorCodes::TYPE_MISMATCH};
|
||||
checkAttributeType(name, attribute_name, attribute.type, AttributeUnderlyingType::String);
|
||||
|
||||
getItemsImpl<StringRef, StringRef>(
|
||||
attribute,
|
||||
@ -507,34 +495,6 @@ TrieDictionary::Attribute TrieDictionary::createAttributeWithType(const Attribut
|
||||
}
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void TrieDictionary::getItemsNumber(
|
||||
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
{
|
||||
if (false)
|
||||
{
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (attribute.type == AttributeUnderlyingType::TYPE) getItemsImpl<TYPE, OutputType>( \
|
||||
attribute, key_columns, std::forward<ValueSetter>(set_value), std::forward<DefaultGetter>(get_default));
|
||||
DISPATCH(UInt8)
|
||||
DISPATCH(UInt16)
|
||||
DISPATCH(UInt32)
|
||||
DISPATCH(UInt64)
|
||||
DISPATCH(UInt128)
|
||||
DISPATCH(Int8)
|
||||
DISPATCH(Int16)
|
||||
DISPATCH(Int32)
|
||||
DISPATCH(Int64)
|
||||
DISPATCH(Float32)
|
||||
DISPATCH(Float64)
|
||||
DISPATCH(Decimal32)
|
||||
DISPATCH(Decimal64)
|
||||
DISPATCH(Decimal128)
|
||||
#undef DISPATCH
|
||||
else throw Exception("Unexpected type of attribute: " + toString(attribute.type), ErrorCodes::LOGICAL_ERROR);
|
||||
}
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void TrieDictionary::getItemsImpl(
|
||||
const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const
|
||||
|
@ -218,10 +218,6 @@ private:
|
||||
Attribute createAttributeWithType(const AttributeUnderlyingType type, const Field & null_value);
|
||||
|
||||
|
||||
template <typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void
|
||||
getItemsNumber(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
||||
template <typename AttributeType, typename OutputType, typename ValueSetter, typename DefaultGetter>
|
||||
void
|
||||
getItemsImpl(const Attribute & attribute, const Columns & key_columns, ValueSetter && set_value, DefaultGetter && get_default) const;
|
||||
|
@ -340,7 +340,7 @@ bool OPTIMIZE(1) CSVRowInputStream::parseRowAndPrintDiagnosticInfo(MutableColumn
|
||||
if (curr_position < prev_position)
|
||||
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (isNumber(current_column_type) || isDateOrDateTime(current_column_type))
|
||||
if (isNativeNumber(current_column_type) || isDateOrDateTime(current_column_type))
|
||||
{
|
||||
/// An empty string instead of a value.
|
||||
if (curr_position == prev_position)
|
||||
|
@ -130,6 +130,7 @@ void registerOutputFormatXML(FormatFactory & factory);
|
||||
void registerOutputFormatODBCDriver(FormatFactory & factory);
|
||||
void registerOutputFormatODBCDriver2(FormatFactory & factory);
|
||||
void registerOutputFormatNull(FormatFactory & factory);
|
||||
void registerOutputFormatMySQLWire(FormatFactory & factory);
|
||||
|
||||
/// Input only formats.
|
||||
|
||||
@ -168,6 +169,7 @@ FormatFactory::FormatFactory()
|
||||
registerOutputFormatODBCDriver(*this);
|
||||
registerOutputFormatODBCDriver2(*this);
|
||||
registerOutputFormatNull(*this);
|
||||
registerOutputFormatMySQLWire(*this);
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -220,10 +220,10 @@ void JSONRowOutputStream::writeStatistics()
|
||||
writeText(watch.elapsedSeconds(), *ostr);
|
||||
writeCString(",\n", *ostr);
|
||||
writeCString("\t\t\"rows_read\": ", *ostr);
|
||||
writeText(progress.rows.load(), *ostr);
|
||||
writeText(progress.read_rows.load(), *ostr);
|
||||
writeCString(",\n", *ostr);
|
||||
writeCString("\t\t\"bytes_read\": ", *ostr);
|
||||
writeText(progress.bytes.load(), *ostr);
|
||||
writeText(progress.read_bytes.load(), *ostr);
|
||||
writeChar('\n', *ostr);
|
||||
|
||||
writeCString("\t}", *ostr);
|
||||
|
@ -20,8 +20,8 @@ namespace ErrorCodes
|
||||
|
||||
|
||||
MySQLBlockInputStream::MySQLBlockInputStream(
|
||||
const mysqlxx::PoolWithFailover::Entry & entry, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size)
|
||||
: entry{entry}, query{this->entry->query(query_str)}, result{query.use()}, max_block_size{max_block_size}
|
||||
const mysqlxx::PoolWithFailover::Entry & entry, const std::string & query_str, const Block & sample_block, const UInt64 max_block_size, const bool auto_close)
|
||||
: entry{entry}, query{this->entry->query(query_str)}, result{query.use()}, max_block_size{max_block_size}, auto_close{auto_close}
|
||||
{
|
||||
if (sample_block.columns() != result.getNumFields())
|
||||
throw Exception{"mysqlxx::UseQueryResult contains " + toString(result.getNumFields()) + " columns while "
|
||||
@ -93,7 +93,11 @@ Block MySQLBlockInputStream::readImpl()
|
||||
{
|
||||
auto row = result.fetch();
|
||||
if (!row)
|
||||
{
|
||||
if (auto_close)
|
||||
entry.disconnect();
|
||||
return {};
|
||||
}
|
||||
|
||||
MutableColumns columns(description.sample_block.columns());
|
||||
for (const auto i : ext::range(0, columns.size()))
|
||||
@ -126,7 +130,8 @@ Block MySQLBlockInputStream::readImpl()
|
||||
|
||||
row = result.fetch();
|
||||
}
|
||||
|
||||
if (auto_close)
|
||||
entry.disconnect();
|
||||
return description.sample_block.cloneWithColumns(std::move(columns));
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,8 @@ public:
|
||||
const mysqlxx::PoolWithFailover::Entry & entry,
|
||||
const std::string & query_str,
|
||||
const Block & sample_block,
|
||||
const UInt64 max_block_size);
|
||||
const UInt64 max_block_size,
|
||||
const bool auto_close = false);
|
||||
|
||||
String getName() const override { return "MySQL"; }
|
||||
|
||||
@ -31,6 +32,7 @@ private:
|
||||
mysqlxx::Query query;
|
||||
mysqlxx::UseQueryResult result;
|
||||
const UInt64 max_block_size;
|
||||
const bool auto_close;
|
||||
ExternalResultDescription description;
|
||||
};
|
||||
|
||||
|
86
dbms/src/Formats/MySQLWireBlockOutputStream.cpp
Normal file
86
dbms/src/Formats/MySQLWireBlockOutputStream.cpp
Normal file
@ -0,0 +1,86 @@
|
||||
#include "MySQLWireBlockOutputStream.h"
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <Interpreters/ProcessList.h>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
using namespace MySQLProtocol;
|
||||
|
||||
MySQLWireBlockOutputStream::MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context)
|
||||
: header(header)
|
||||
, context(context)
|
||||
, packet_sender(std::make_shared<PacketSender>(buf, context.sequence_id))
|
||||
{
|
||||
packet_sender->max_packet_size = context.max_packet_size;
|
||||
}
|
||||
|
||||
void MySQLWireBlockOutputStream::writePrefix()
|
||||
{
|
||||
if (header.columns() == 0)
|
||||
return;
|
||||
|
||||
packet_sender->sendPacket(LengthEncodedNumber(header.columns()));
|
||||
|
||||
for (const ColumnWithTypeAndName & column : header.getColumnsWithTypeAndName())
|
||||
{
|
||||
ColumnDefinition column_definition(column.name, CharacterSet::binary, 0, ColumnType::MYSQL_TYPE_STRING, 0, 0);
|
||||
packet_sender->sendPacket(column_definition);
|
||||
}
|
||||
|
||||
if (!(context.client_capabilities & Capability::CLIENT_DEPRECATE_EOF))
|
||||
{
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLWireBlockOutputStream::write(const Block & block)
|
||||
{
|
||||
size_t rows = block.rows();
|
||||
|
||||
for (size_t i = 0; i < rows; i++)
|
||||
{
|
||||
ResultsetRow row_packet;
|
||||
for (const ColumnWithTypeAndName & column : block)
|
||||
{
|
||||
String column_value;
|
||||
WriteBufferFromString ostr(column_value);
|
||||
column.type->serializeAsText(*column.column.get(), i, ostr, format_settings);
|
||||
ostr.finish();
|
||||
|
||||
row_packet.appendColumn(std::move(column_value));
|
||||
}
|
||||
packet_sender->sendPacket(row_packet);
|
||||
}
|
||||
}
|
||||
|
||||
void MySQLWireBlockOutputStream::writeSuffix()
|
||||
{
|
||||
QueryStatus * process_list_elem = context.getProcessListElement();
|
||||
CurrentThread::finalizePerformanceCounters();
|
||||
QueryStatusInfo info = process_list_elem->getInfo();
|
||||
size_t affected_rows = info.written_rows;
|
||||
|
||||
std::stringstream human_readable_info;
|
||||
human_readable_info << std::fixed << std::setprecision(3)
|
||||
<< "Read " << info.read_rows << " rows, " << formatReadableSizeWithBinarySuffix(info.read_bytes) << " in " << info.elapsed_seconds << " sec., "
|
||||
<< static_cast<size_t>(info.read_rows / info.elapsed_seconds) << " rows/sec., "
|
||||
<< formatReadableSizeWithBinarySuffix(info.read_bytes / info.elapsed_seconds) << "/sec.";
|
||||
|
||||
if (header.columns() == 0)
|
||||
packet_sender->sendPacket(OK_Packet(0x0, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
if (context.client_capabilities & CLIENT_DEPRECATE_EOF)
|
||||
packet_sender->sendPacket(OK_Packet(0xfe, context.client_capabilities, affected_rows, 0, 0, "", human_readable_info.str()), true);
|
||||
else
|
||||
packet_sender->sendPacket(EOF_Packet(0, 0), true);
|
||||
}
|
||||
|
||||
void MySQLWireBlockOutputStream::flush()
|
||||
{
|
||||
packet_sender->out->next();
|
||||
}
|
||||
|
||||
}
|
36
dbms/src/Formats/MySQLWireBlockOutputStream.h
Normal file
36
dbms/src/Formats/MySQLWireBlockOutputStream.h
Normal file
@ -0,0 +1,36 @@
|
||||
#pragma once
|
||||
|
||||
#include <Core/MySQLProtocol.h>
|
||||
#include <DataStreams/IBlockOutputStream.h>
|
||||
#include <Formats/FormatFactory.h>
|
||||
#include <Formats/FormatSettings.h>
|
||||
#include <Interpreters/Context.h>
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
/** Interface for writing rows in MySQL Client/Server Protocol format.
|
||||
*/
|
||||
class MySQLWireBlockOutputStream : public IBlockOutputStream
|
||||
{
|
||||
public:
|
||||
MySQLWireBlockOutputStream(WriteBuffer & buf, const Block & header, Context & context);
|
||||
|
||||
Block getHeader() const { return header; }
|
||||
|
||||
void write(const Block & block);
|
||||
|
||||
void writePrefix();
|
||||
void writeSuffix();
|
||||
|
||||
void flush();
|
||||
private:
|
||||
Block header;
|
||||
Context & context;
|
||||
std::shared_ptr<MySQLProtocol::PacketSender> packet_sender;
|
||||
FormatSettings format_settings;
|
||||
};
|
||||
|
||||
using MySQLWireBlockOutputStreamPtr = std::shared_ptr<MySQLWireBlockOutputStream>;
|
||||
|
||||
}
|
19
dbms/src/Formats/MySQLWireFormat.cpp
Normal file
19
dbms/src/Formats/MySQLWireFormat.cpp
Normal file
@ -0,0 +1,19 @@
|
||||
#include <Formats/MySQLWireBlockOutputStream.h>
|
||||
|
||||
|
||||
namespace DB
|
||||
{
|
||||
|
||||
void registerOutputFormatMySQLWire(FormatFactory & factory)
|
||||
{
|
||||
factory.registerOutputFormat("MySQLWire", [](
|
||||
WriteBuffer & buf,
|
||||
const Block & sample,
|
||||
const Context & context,
|
||||
const FormatSettings &)
|
||||
{
|
||||
return std::make_shared<MySQLWireBlockOutputStream>(buf, sample, const_cast<Context &>(context));
|
||||
});
|
||||
}
|
||||
|
||||
}
|
@ -308,7 +308,7 @@ bool OPTIMIZE(1) TabSeparatedRowInputStream::parseRowAndPrintDiagnosticInfo(
|
||||
if (curr_position < prev_position)
|
||||
throw Exception("Logical error: parsing is non-deterministic.", ErrorCodes::LOGICAL_ERROR);
|
||||
|
||||
if (isNumber(current_column_type) || isDateOrDateTime(current_column_type))
|
||||
if (isNativeNumber(current_column_type) || isDateOrDateTime(current_column_type))
|
||||
{
|
||||
/// An empty string instead of a value.
|
||||
if (curr_position == prev_position)
|
||||
|
@ -215,10 +215,10 @@ void XMLRowOutputStream::writeStatistics()
|
||||
writeText(watch.elapsedSeconds(), *ostr);
|
||||
writeCString("</elapsed>\n", *ostr);
|
||||
writeCString("\t\t<rows_read>", *ostr);
|
||||
writeText(progress.rows.load(), *ostr);
|
||||
writeText(progress.read_rows.load(), *ostr);
|
||||
writeCString("</rows_read>\n", *ostr);
|
||||
writeCString("\t\t<bytes_read>", *ostr);
|
||||
writeText(progress.bytes.load(), *ostr);
|
||||
writeText(progress.read_bytes.load(), *ostr);
|
||||
writeCString("</bytes_read>\n", *ostr);
|
||||
writeCString("\t</statistics>\n", *ostr);
|
||||
}
|
||||
|
@ -263,7 +263,7 @@ public:
|
||||
+ toString(arguments.size()) + ", should be 2 or 3",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (!isNumber(arguments[1].type))
|
||||
if (!isNativeNumber(arguments[1].type))
|
||||
throw Exception("Second argument for function " + getName() + " (delta) must be number",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
|
@ -59,7 +59,7 @@ private:
|
||||
{
|
||||
const auto check_argument_type = [this] (const IDataType * arg)
|
||||
{
|
||||
if (!isNumber(arg))
|
||||
if (!isNativeNumber(arg))
|
||||
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
};
|
||||
|
@ -56,7 +56,7 @@ private:
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
const auto & arg = arguments.front();
|
||||
if (!isNumber(arg) && !isDecimal(arg))
|
||||
if (!isNumber(arg))
|
||||
throw Exception{"Illegal type " + arg->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
|
@ -37,7 +37,7 @@ public:
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
if (!isNumber(arguments.front()))
|
||||
if (!isNativeNumber(arguments.front()))
|
||||
throw Exception{"Argument for function " + getName() + " must be number", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
return std::make_shared<DataTypeUInt8>();
|
||||
|
@ -21,5 +21,7 @@ void registerFunctionsBitmap(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionBitmapXor>();
|
||||
factory.registerFunction<FunctionBitmapAndnot>();
|
||||
|
||||
factory.registerFunction<FunctionBitmapHasAll>();
|
||||
factory.registerFunction<FunctionBitmapHasAny>();
|
||||
}
|
||||
}
|
||||
|
@ -342,7 +342,27 @@ struct BitmapAndnotCardinalityImpl
|
||||
}
|
||||
};
|
||||
|
||||
template <template <typename> class Impl, typename Name>
|
||||
template <typename T>
|
||||
struct BitmapHasAllImpl
|
||||
{
|
||||
using ReturnType = UInt8;
|
||||
static UInt8 apply(const AggregateFunctionGroupBitmapData<T> & bd1, const AggregateFunctionGroupBitmapData<T> & bd2)
|
||||
{
|
||||
return bd1.rbs.rb_is_subset(bd2.rbs);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BitmapHasAnyImpl
|
||||
{
|
||||
using ReturnType = UInt8;
|
||||
static UInt8 apply(const AggregateFunctionGroupBitmapData<T> & bd1, const AggregateFunctionGroupBitmapData<T> & bd2)
|
||||
{
|
||||
return bd1.rbs.rb_intersect(bd2.rbs);
|
||||
}
|
||||
};
|
||||
|
||||
template <template <typename> class Impl, typename Name, typename ToType>
|
||||
class FunctionBitmapCardinality : public IFunction
|
||||
{
|
||||
public:
|
||||
@ -369,6 +389,13 @@ public:
|
||||
throw Exception(
|
||||
"Second argument for function " + getName() + " must be an bitmap but it has type " + arguments[1]->getName() + ".",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (bitmap_type0->getArgumentsDataTypes()[0]->getTypeId() != bitmap_type1->getArgumentsDataTypes()[0]->getTypeId())
|
||||
throw Exception(
|
||||
"The nested type in bitmaps must be the same, but one is " + bitmap_type0->getArgumentsDataTypes()[0]->getName()
|
||||
+ ", and the other is " + bitmap_type1->getArgumentsDataTypes()[0]->getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return std::make_shared<DataTypeNumber<ToType>>();
|
||||
}
|
||||
|
||||
@ -398,8 +425,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
using ToType = UInt64;
|
||||
|
||||
template <typename T>
|
||||
void executeIntType(
|
||||
Block & block, const ColumnNumbers & arguments, size_t input_rows_count, typename ColumnVector<ToType>::Container & vec_to)
|
||||
@ -487,6 +512,13 @@ public:
|
||||
throw Exception(
|
||||
"Second argument for function " + getName() + " must be an bitmap but it has type " + arguments[1]->getName() + ".",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
if (bitmap_type0->getArgumentsDataTypes()[0]->getTypeId() != bitmap_type1->getArgumentsDataTypes()[0]->getTypeId())
|
||||
throw Exception(
|
||||
"The nested type in bitmaps must be the same, but one is " + bitmap_type0->getArgumentsDataTypes()[0]->getName()
|
||||
+ ", and the other is " + bitmap_type1->getArgumentsDataTypes()[0]->getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return arguments[0];
|
||||
}
|
||||
|
||||
@ -571,12 +603,22 @@ struct NameBitmapAndnotCardinality
|
||||
{
|
||||
static constexpr auto name = "bitmapAndnotCardinality";
|
||||
};
|
||||
struct NameBitmapHasAll
|
||||
{
|
||||
static constexpr auto name = "bitmapHasAll";
|
||||
};
|
||||
struct NameBitmapHasAny
|
||||
{
|
||||
static constexpr auto name = "bitmapHasAny";
|
||||
};
|
||||
|
||||
using FunctionBitmapSelfCardinality = FunctionBitmapSelfCardinalityImpl<NameBitmapCardinality>;
|
||||
using FunctionBitmapAndCardinality = FunctionBitmapCardinality<BitmapAndCardinalityImpl, NameBitmapAndCardinality>;
|
||||
using FunctionBitmapOrCardinality = FunctionBitmapCardinality<BitmapOrCardinalityImpl, NameBitmapOrCardinality>;
|
||||
using FunctionBitmapXorCardinality = FunctionBitmapCardinality<BitmapXorCardinalityImpl, NameBitmapXorCardinality>;
|
||||
using FunctionBitmapAndnotCardinality = FunctionBitmapCardinality<BitmapAndnotCardinalityImpl, NameBitmapAndnotCardinality>;
|
||||
using FunctionBitmapAndCardinality = FunctionBitmapCardinality<BitmapAndCardinalityImpl, NameBitmapAndCardinality, UInt64>;
|
||||
using FunctionBitmapOrCardinality = FunctionBitmapCardinality<BitmapOrCardinalityImpl, NameBitmapOrCardinality, UInt64>;
|
||||
using FunctionBitmapXorCardinality = FunctionBitmapCardinality<BitmapXorCardinalityImpl, NameBitmapXorCardinality, UInt64>;
|
||||
using FunctionBitmapAndnotCardinality = FunctionBitmapCardinality<BitmapAndnotCardinalityImpl, NameBitmapAndnotCardinality, UInt64>;
|
||||
using FunctionBitmapHasAll = FunctionBitmapCardinality<BitmapHasAllImpl, NameBitmapHasAll, UInt8>;
|
||||
using FunctionBitmapHasAny = FunctionBitmapCardinality<BitmapHasAnyImpl, NameBitmapHasAny, UInt8>;
|
||||
|
||||
struct NameBitmapAnd
|
||||
{
|
||||
|
@ -20,7 +20,7 @@ void throwExceptionForIncompletelyParsedValue(
|
||||
else
|
||||
message_buf << " at begin of string";
|
||||
|
||||
if (isNumber(to_type))
|
||||
if (isNativeNumber(to_type))
|
||||
message_buf << ". Note: there are to" << to_type.getName() << "OrZero and to" << to_type.getName() << "OrNull functions, which returns zero/NULL instead of throwing exception.";
|
||||
|
||||
throw Exception(message_buf.str(), ErrorCodes::CANNOT_PARSE_TEXT);
|
||||
|
@ -1785,7 +1785,7 @@ private:
|
||||
return createStringToEnumWrapper<ColumnString, EnumType>();
|
||||
else if (checkAndGetDataType<DataTypeFixedString>(from_type.get()))
|
||||
return createStringToEnumWrapper<ColumnFixedString, EnumType>();
|
||||
else if (isNumber(from_type) || isEnum(from_type))
|
||||
else if (isNativeNumber(from_type) || isEnum(from_type))
|
||||
{
|
||||
auto function = Function::create(context);
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <DataTypes/DataTypesDecimal.h>
|
||||
#include <DataTypes/DataTypeArray.h>
|
||||
#include <DataTypes/DataTypeString.h>
|
||||
#include <DataTypes/DataTypeDate.h>
|
||||
@ -661,21 +662,77 @@ DECLARE_DICT_GET_TRAITS(UInt32, DataTypeDateTime)
|
||||
DECLARE_DICT_GET_TRAITS(UInt128, DataTypeUUID)
|
||||
#undef DECLARE_DICT_GET_TRAITS
|
||||
|
||||
template <typename T> struct DictGetTraits<DataTypeDecimal<T>>
|
||||
{
|
||||
static constexpr bool is_dec32 = std::is_same_v<T, Decimal32>;
|
||||
static constexpr bool is_dec64 = std::is_same_v<T, Decimal64>;
|
||||
static constexpr bool is_dec128 = std::is_same_v<T, Decimal128>;
|
||||
|
||||
template <typename DictionaryType>
|
||||
static void get(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
|
||||
DecimalPaddedPODArray<T> & out)
|
||||
{
|
||||
if constexpr (is_dec32) dict->getDecimal32(name, ids, out);
|
||||
if constexpr (is_dec64) dict->getDecimal64(name, ids, out);
|
||||
if constexpr (is_dec128) dict->getDecimal128(name, ids, out);
|
||||
}
|
||||
|
||||
template <typename DictionaryType>
|
||||
static void get(const DictionaryType * dict, const std::string & name, const Columns & key_columns, const DataTypes & key_types,
|
||||
DecimalPaddedPODArray<T> & out)
|
||||
{
|
||||
if constexpr (is_dec32) dict->getDecimal32(name, key_columns, key_types, out);
|
||||
if constexpr (is_dec64) dict->getDecimal64(name, key_columns, key_types, out);
|
||||
if constexpr (is_dec128) dict->getDecimal128(name, key_columns, key_types, out);
|
||||
}
|
||||
|
||||
template <typename DictionaryType>
|
||||
static void get(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
|
||||
const PaddedPODArray<Int64> & dates, DecimalPaddedPODArray<T> & out)
|
||||
{
|
||||
if constexpr (is_dec32) dict->getDecimal32(name, ids, dates, out);
|
||||
if constexpr (is_dec64) dict->getDecimal64(name, ids, dates, out);
|
||||
if constexpr (is_dec128) dict->getDecimal128(name, ids, dates, out);
|
||||
}
|
||||
|
||||
template <typename DictionaryType, typename DefaultsType>
|
||||
static void getOrDefault(const DictionaryType * dict, const std::string & name, const PaddedPODArray<UInt64> & ids,
|
||||
const DefaultsType & def, DecimalPaddedPODArray<T> & out)
|
||||
{
|
||||
if constexpr (is_dec32) dict->getDecimal32(name, ids, def, out);
|
||||
if constexpr (is_dec64) dict->getDecimal64(name, ids, def, out);
|
||||
if constexpr (is_dec128) dict->getDecimal128(name, ids, def, out);
|
||||
}
|
||||
|
||||
template <typename DictionaryType, typename DefaultsType>
|
||||
static void getOrDefault(const DictionaryType * dict, const std::string & name, const Columns & key_columns,
|
||||
const DataTypes & key_types, const DefaultsType & def, DecimalPaddedPODArray<T> & out)
|
||||
{
|
||||
if constexpr (is_dec32) dict->getDecimal32(name, key_columns, key_types, def, out);
|
||||
if constexpr (is_dec64) dict->getDecimal64(name, key_columns, key_types, def, out);
|
||||
if constexpr (is_dec128) dict->getDecimal128(name, key_columns, key_types, def, out);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename DataType, typename Name>
|
||||
class FunctionDictGet final : public IFunction
|
||||
{
|
||||
using Type = typename DataType::FieldType;
|
||||
using ColVec = std::conditional_t<IsDecimalNumber<Type>, ColumnDecimal<Type>, ColumnVector<Type>>;
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
|
||||
static FunctionPtr create(const Context & context)
|
||||
static FunctionPtr create(const Context & context, UInt32 dec_scale = 0)
|
||||
{
|
||||
return std::make_shared<FunctionDictGet>(context.getExternalDictionaries());
|
||||
return std::make_shared<FunctionDictGet>(context.getExternalDictionaries(), dec_scale);
|
||||
}
|
||||
|
||||
FunctionDictGet(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
|
||||
FunctionDictGet(const ExternalDictionaries & dictionaries, UInt32 dec_scale = 0)
|
||||
: dictionaries(dictionaries)
|
||||
, decimal_scale(dec_scale)
|
||||
{}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
@ -719,7 +776,10 @@ private:
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
}
|
||||
|
||||
return std::make_shared<DataType>();
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), decimal_scale);
|
||||
else
|
||||
return std::make_shared<DataType>();
|
||||
}
|
||||
|
||||
bool isDeterministic() const override { return false; }
|
||||
@ -771,7 +831,11 @@ private:
|
||||
const auto id_col_untyped = block.getByPosition(arguments[2]).column.get();
|
||||
if (const auto id_col = checkAndGetColumn<ColumnUInt64>(id_col_untyped))
|
||||
{
|
||||
auto out = ColumnVector<Type>::create(id_col->size());
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(id_col->size(), decimal_scale);
|
||||
else
|
||||
out = ColVec::create(id_col->size());
|
||||
const auto & ids = id_col->getData();
|
||||
auto & data = out->getData();
|
||||
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
|
||||
@ -780,9 +844,21 @@ private:
|
||||
else if (const auto id_col_const = checkAndGetColumnConst<ColumnVector<UInt64>>(id_col_untyped))
|
||||
{
|
||||
const PaddedPODArray<UInt64> ids(1, id_col_const->getValue<UInt64>());
|
||||
PaddedPODArray<Type> data(1);
|
||||
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
|
||||
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col_const->size(), toField(data.front()));
|
||||
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
{
|
||||
DecimalPaddedPODArray<Type> data(1, decimal_scale);
|
||||
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
|
||||
block.getByPosition(result).column =
|
||||
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
|
||||
id_col_const->size(), toField(data.front(), decimal_scale));
|
||||
}
|
||||
else
|
||||
{
|
||||
PaddedPODArray<Type> data(1);
|
||||
DictGetTraits<DataType>::get(dict, attr_name, ids, data);
|
||||
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col_const->size(), toField(data.front()));
|
||||
}
|
||||
}
|
||||
else
|
||||
throw Exception{"Third argument of function " + getName() + " must be UInt64", ErrorCodes::ILLEGAL_COLUMN};
|
||||
@ -818,7 +894,11 @@ private:
|
||||
const auto & key_columns = static_cast<const ColumnTuple &>(*key_col).getColumnsCopy();
|
||||
const auto & key_types = static_cast<const DataTypeTuple &>(*key_col_with_type.type).getElements();
|
||||
|
||||
auto out = ColumnVector<Type>::create(key_columns.front()->size());
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(key_columns.front()->size(), decimal_scale);
|
||||
else
|
||||
out = ColVec::create(key_columns.front()->size());
|
||||
auto & data = out->getData();
|
||||
DictGetTraits<DataType>::get(dict, attr_name, key_columns, key_types, data);
|
||||
block.getByPosition(result).column = std::move(out);
|
||||
@ -855,7 +935,11 @@ private:
|
||||
const auto & id_col_values = getColumnDataAsPaddedPODArray(*id_col_untyped, id_col_values_storage);
|
||||
const auto & range_col_values = getColumnDataAsPaddedPODArray(*range_col_untyped, range_col_values_storage);
|
||||
|
||||
auto out = ColumnVector<Type>::create(id_col_untyped->size());
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(id_col_untyped->size(), decimal_scale);
|
||||
else
|
||||
out = ColVec::create(id_col_untyped->size());
|
||||
auto & data = out->getData();
|
||||
DictGetTraits<DataType>::get(dict, attr_name, id_col_values, range_col_values, data);
|
||||
block.getByPosition(result).column = std::move(out);
|
||||
@ -864,6 +948,7 @@ private:
|
||||
}
|
||||
|
||||
const ExternalDictionaries & dictionaries;
|
||||
UInt32 decimal_scale;
|
||||
};
|
||||
|
||||
struct NameDictGetUInt8 { static constexpr auto name = "dictGetUInt8"; };
|
||||
@ -879,6 +964,9 @@ struct NameDictGetFloat64 { static constexpr auto name = "dictGetFloat64"; };
|
||||
struct NameDictGetDate { static constexpr auto name = "dictGetDate"; };
|
||||
struct NameDictGetDateTime { static constexpr auto name = "dictGetDateTime"; };
|
||||
struct NameDictGetUUID { static constexpr auto name = "dictGetUUID"; };
|
||||
struct NameDictGetDecimal32 { static constexpr auto name = "dictGetDecimal32"; };
|
||||
struct NameDictGetDecimal64 { static constexpr auto name = "dictGetDecimal64"; };
|
||||
struct NameDictGetDecimal128 { static constexpr auto name = "dictGetDecimal128"; };
|
||||
|
||||
using FunctionDictGetUInt8 = FunctionDictGet<DataTypeUInt8, NameDictGetUInt8>;
|
||||
using FunctionDictGetUInt16 = FunctionDictGet<DataTypeUInt16, NameDictGetUInt16>;
|
||||
@ -893,22 +981,29 @@ using FunctionDictGetFloat64 = FunctionDictGet<DataTypeFloat64, NameDictGetFloat
|
||||
using FunctionDictGetDate = FunctionDictGet<DataTypeDate, NameDictGetDate>;
|
||||
using FunctionDictGetDateTime = FunctionDictGet<DataTypeDateTime, NameDictGetDateTime>;
|
||||
using FunctionDictGetUUID = FunctionDictGet<DataTypeUUID, NameDictGetUUID>;
|
||||
using FunctionDictGetDecimal32 = FunctionDictGet<DataTypeDecimal<Decimal32>, NameDictGetDecimal32>;
|
||||
using FunctionDictGetDecimal64 = FunctionDictGet<DataTypeDecimal<Decimal64>, NameDictGetDecimal64>;
|
||||
using FunctionDictGetDecimal128 = FunctionDictGet<DataTypeDecimal<Decimal128>, NameDictGetDecimal128>;
|
||||
|
||||
|
||||
template <typename DataType, typename Name>
|
||||
class FunctionDictGetOrDefault final : public IFunction
|
||||
{
|
||||
using Type = typename DataType::FieldType;
|
||||
using ColVec = std::conditional_t<IsDecimalNumber<Type>, ColumnDecimal<Type>, ColumnVector<Type>>;
|
||||
|
||||
public:
|
||||
static constexpr auto name = Name::name;
|
||||
|
||||
static FunctionPtr create(const Context & context)
|
||||
static FunctionPtr create(const Context & context, UInt32 dec_scale = 0)
|
||||
{
|
||||
return std::make_shared<FunctionDictGetOrDefault>(context.getExternalDictionaries());
|
||||
return std::make_shared<FunctionDictGetOrDefault>(context.getExternalDictionaries(), dec_scale);
|
||||
}
|
||||
|
||||
FunctionDictGetOrDefault(const ExternalDictionaries & dictionaries) : dictionaries(dictionaries) {}
|
||||
FunctionDictGetOrDefault(const ExternalDictionaries & dictionaries, UInt32 dec_scale = 0)
|
||||
: dictionaries(dictionaries)
|
||||
, decimal_scale(dec_scale)
|
||||
{}
|
||||
|
||||
String getName() const override { return name; }
|
||||
|
||||
@ -935,9 +1030,12 @@ private:
|
||||
|
||||
if (!checkAndGetDataType<DataType>(arguments[3].get()))
|
||||
throw Exception{"Illegal type " + arguments[3]->getName() + " of fourth argument of function " + getName()
|
||||
+ ", must be " + String(DataType{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
+ ", must be " + TypeName<Type>::get() + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
return std::make_shared<DataType>();
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
return std::make_shared<DataType>(DataType::maxPrecision(), decimal_scale);
|
||||
else
|
||||
return std::make_shared<DataType>();
|
||||
}
|
||||
|
||||
bool isDeterministic() const override { return false; }
|
||||
@ -999,20 +1097,28 @@ private:
|
||||
{
|
||||
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
|
||||
|
||||
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
|
||||
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
|
||||
{
|
||||
/// vector ids, vector defaults
|
||||
auto out = ColumnVector<Type>::create(id_col->size());
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(id_col->size(), decimal_scale);
|
||||
else
|
||||
out = ColVec::create(id_col->size());
|
||||
const auto & ids = id_col->getData();
|
||||
auto & data = out->getData();
|
||||
const auto & defs = default_col->getData();
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, defs, data);
|
||||
block.getByPosition(result).column = std::move(out);
|
||||
}
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
|
||||
{
|
||||
/// vector ids, const defaults
|
||||
auto out = ColumnVector<Type>::create(id_col->size());
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(id_col->size(), decimal_scale);
|
||||
else
|
||||
out = ColVec::create(id_col->size());
|
||||
const auto & ids = id_col->getData();
|
||||
auto & data = out->getData();
|
||||
const auto def = default_col_const->template getValue<Type>();
|
||||
@ -1020,7 +1126,7 @@ private:
|
||||
block.getByPosition(result).column = std::move(out);
|
||||
}
|
||||
else
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
|
||||
}
|
||||
|
||||
template <typename DictionaryType>
|
||||
@ -1030,7 +1136,7 @@ private:
|
||||
{
|
||||
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
|
||||
|
||||
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
|
||||
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
|
||||
{
|
||||
/// const ids, vector defaults
|
||||
const PaddedPODArray<UInt64> ids(1, id_col->getValue<UInt64>());
|
||||
@ -1038,24 +1144,48 @@ private:
|
||||
dictionary->has(ids, flags);
|
||||
if (flags.front())
|
||||
{
|
||||
PaddedPODArray<Type> data(1);
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
|
||||
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col->size(), toField(data.front()));
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
{
|
||||
DecimalPaddedPODArray<Type> data(1, decimal_scale);
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
|
||||
block.getByPosition(result).column =
|
||||
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
|
||||
id_col->size(), toField(data.front(), decimal_scale));
|
||||
}
|
||||
else
|
||||
{
|
||||
PaddedPODArray<Type> data(1);
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, Type(), data);
|
||||
block.getByPosition(result).column = DataType().createColumnConst(id_col->size(), toField(data.front()));
|
||||
}
|
||||
}
|
||||
else
|
||||
block.getByPosition(result).column = block.getByPosition(arguments[3]).column; // reuse the default column
|
||||
}
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
|
||||
{
|
||||
/// const ids, const defaults
|
||||
const PaddedPODArray<UInt64> ids(1, id_col->getValue<UInt64>());
|
||||
PaddedPODArray<Type> data(1);
|
||||
const auto & def = default_col_const->template getValue<Type>();
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
|
||||
block.getByPosition(result).column = DataTypeNumber<Type>().createColumnConst(id_col->size(), toField(data.front()));
|
||||
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
{
|
||||
DecimalPaddedPODArray<Type> data(1, decimal_scale);
|
||||
const auto & def = default_col_const->template getValue<Type>();
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
|
||||
block.getByPosition(result).column =
|
||||
DataType(DataType::maxPrecision(), decimal_scale).createColumnConst(
|
||||
id_col->size(), toField(data.front(), decimal_scale));
|
||||
}
|
||||
else
|
||||
{
|
||||
PaddedPODArray<Type> data(1);
|
||||
const auto & def = default_col_const->template getValue<Type>();
|
||||
DictGetTraits<DataType>::getOrDefault(dictionary, attr_name, ids, def, data);
|
||||
block.getByPosition(result).column = DataType().createColumnConst(id_col->size(), toField(data.front()));
|
||||
}
|
||||
}
|
||||
else
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
|
||||
}
|
||||
|
||||
template <typename DictionaryType>
|
||||
@ -1082,31 +1212,36 @@ private:
|
||||
|
||||
/// @todo detect when all key columns are constant
|
||||
const auto rows = key_col->size();
|
||||
auto out = ColumnVector<Type>::create(rows);
|
||||
typename ColVec::MutablePtr out;
|
||||
if constexpr (IsDataTypeDecimal<DataType>)
|
||||
out = ColVec::create(rows, decimal_scale);
|
||||
else
|
||||
out = ColVec::create(rows);
|
||||
auto & data = out->getData();
|
||||
|
||||
const auto default_col_untyped = block.getByPosition(arguments[3]).column.get();
|
||||
if (const auto default_col = checkAndGetColumn<ColumnVector<Type>>(default_col_untyped))
|
||||
if (const auto default_col = checkAndGetColumn<ColVec>(default_col_untyped))
|
||||
{
|
||||
/// const defaults
|
||||
const auto & defs = default_col->getData();
|
||||
|
||||
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, defs, data);
|
||||
}
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColumnVector<Type>>(default_col_untyped))
|
||||
else if (const auto default_col_const = checkAndGetColumnConst<ColVec>(default_col_untyped))
|
||||
{
|
||||
const auto def = default_col_const->template getValue<Type>();
|
||||
|
||||
DictGetTraits<DataType>::getOrDefault(dict, attr_name, key_columns, key_types, def, data);
|
||||
}
|
||||
else
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + String(DataType{}.getFamilyName()), ErrorCodes::ILLEGAL_COLUMN};
|
||||
throw Exception{"Fourth argument of function " + getName() + " must be " + TypeName<Type>::get(), ErrorCodes::ILLEGAL_COLUMN};
|
||||
|
||||
block.getByPosition(result).column = std::move(out);
|
||||
return true;
|
||||
}
|
||||
|
||||
const ExternalDictionaries & dictionaries;
|
||||
UInt32 decimal_scale;
|
||||
};
|
||||
|
||||
struct NameDictGetUInt8OrDefault { static constexpr auto name = "dictGetUInt8OrDefault"; };
|
||||
@ -1122,6 +1257,9 @@ struct NameDictGetFloat64OrDefault { static constexpr auto name = "dictGetFloat6
|
||||
struct NameDictGetDateOrDefault { static constexpr auto name = "dictGetDateOrDefault"; };
|
||||
struct NameDictGetDateTimeOrDefault { static constexpr auto name = "dictGetDateTimeOrDefault"; };
|
||||
struct NameDictGetUUIDOrDefault { static constexpr auto name = "dictGetUUIDOrDefault"; };
|
||||
struct NameDictGetDecimal32OrDefault { static constexpr auto name = "dictGetDecimal32OrDefault"; };
|
||||
struct NameDictGetDecimal64OrDefault { static constexpr auto name = "dictGetDecimal64OrDefault"; };
|
||||
struct NameDictGetDecimal128OrDefault { static constexpr auto name = "dictGetDecimal128OrDefault"; };
|
||||
|
||||
using FunctionDictGetUInt8OrDefault = FunctionDictGetOrDefault<DataTypeUInt8, NameDictGetUInt8OrDefault>;
|
||||
using FunctionDictGetUInt16OrDefault = FunctionDictGetOrDefault<DataTypeUInt16, NameDictGetUInt16OrDefault>;
|
||||
@ -1136,21 +1274,10 @@ using FunctionDictGetFloat64OrDefault = FunctionDictGetOrDefault<DataTypeFloat64
|
||||
using FunctionDictGetDateOrDefault = FunctionDictGetOrDefault<DataTypeDate, NameDictGetDateOrDefault>;
|
||||
using FunctionDictGetDateTimeOrDefault = FunctionDictGetOrDefault<DataTypeDateTime, NameDictGetDateTimeOrDefault>;
|
||||
using FunctionDictGetUUIDOrDefault = FunctionDictGetOrDefault<DataTypeUUID, NameDictGetUUIDOrDefault>;
|
||||
using FunctionDictGetDecimal32OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal32>, NameDictGetDecimal32OrDefault>;
|
||||
using FunctionDictGetDecimal64OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal64>, NameDictGetDecimal64OrDefault>;
|
||||
using FunctionDictGetDecimal128OrDefault = FunctionDictGetOrDefault<DataTypeDecimal<Decimal128>, NameDictGetDecimal128OrDefault>;
|
||||
|
||||
#define FOR_DICT_TYPES(M) \
|
||||
M(UInt8) \
|
||||
M(UInt16) \
|
||||
M(UInt32) \
|
||||
M(UInt64) \
|
||||
M(Int8) \
|
||||
M(Int16) \
|
||||
M(Int32) \
|
||||
M(Int64) \
|
||||
M(Float32) \
|
||||
M(Float64) \
|
||||
M(Date) \
|
||||
M(DateTime) \
|
||||
M(UUID)
|
||||
|
||||
/// This variant of function derives the result type automatically.
|
||||
class FunctionDictGetNoType final : public IFunction
|
||||
@ -1225,15 +1352,63 @@ private:
|
||||
if (attribute.name == attr_name)
|
||||
{
|
||||
WhichDataType dt = attribute.type;
|
||||
if (dt.idx == TypeIndex::String)
|
||||
impl = FunctionDictGetString::create(context);
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (dt.idx == TypeIndex::TYPE) \
|
||||
impl = FunctionDictGet<DataType##TYPE, NameDictGet##TYPE>::create(context);
|
||||
FOR_DICT_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
else
|
||||
throw Exception("Unknown dictGet type", ErrorCodes::UNKNOWN_TYPE);
|
||||
switch (dt.idx)
|
||||
{
|
||||
case TypeIndex::String:
|
||||
case TypeIndex::FixedString:
|
||||
impl = FunctionDictGetString::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt8:
|
||||
impl = FunctionDictGetUInt8::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt16:
|
||||
impl = FunctionDictGetUInt16::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt32:
|
||||
impl = FunctionDictGetUInt32::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt64:
|
||||
impl = FunctionDictGetUInt64::create(context);
|
||||
break;
|
||||
case TypeIndex::Int8:
|
||||
impl = FunctionDictGetInt8::create(context);
|
||||
break;
|
||||
case TypeIndex::Int16:
|
||||
impl = FunctionDictGetInt16::create(context);
|
||||
break;
|
||||
case TypeIndex::Int32:
|
||||
impl = FunctionDictGetInt32::create(context);
|
||||
break;
|
||||
case TypeIndex::Int64:
|
||||
impl = FunctionDictGetInt64::create(context);
|
||||
break;
|
||||
case TypeIndex::Float32:
|
||||
impl = FunctionDictGetFloat32::create(context);
|
||||
break;
|
||||
case TypeIndex::Float64:
|
||||
impl = FunctionDictGetFloat64::create(context);
|
||||
break;
|
||||
case TypeIndex::Date:
|
||||
impl = FunctionDictGetDate::create(context);
|
||||
break;
|
||||
case TypeIndex::DateTime:
|
||||
impl = FunctionDictGetDateTime::create(context);
|
||||
break;
|
||||
case TypeIndex::UUID:
|
||||
impl = FunctionDictGetUUID::create(context);
|
||||
break;
|
||||
case TypeIndex::Decimal32:
|
||||
impl = FunctionDictGetDecimal32::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
case TypeIndex::Decimal64:
|
||||
impl = FunctionDictGetDecimal64::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
case TypeIndex::Decimal128:
|
||||
impl = FunctionDictGetDecimal128::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
default:
|
||||
throw Exception("Unknown dictGet type", ErrorCodes::UNKNOWN_TYPE);
|
||||
}
|
||||
return attribute.type;
|
||||
}
|
||||
}
|
||||
@ -1312,26 +1487,70 @@ private:
|
||||
const DictionaryAttribute & attribute = structure.attributes[idx];
|
||||
if (attribute.name == attr_name)
|
||||
{
|
||||
auto arg_type = arguments[3].type;
|
||||
WhichDataType dt = attribute.type;
|
||||
if (dt.idx == TypeIndex::String)
|
||||
|
||||
if ((arg_type->getTypeId() != dt.idx) || (dt.isStringOrFixedString() && !isString(arg_type)))
|
||||
throw Exception{"Illegal type " + arg_type->getName() + " of fourth argument of function " + getName() +
|
||||
", must be " + getTypeName(dt.idx) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
switch (dt.idx)
|
||||
{
|
||||
if (!isString(arguments[3].type))
|
||||
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() +
|
||||
", must be String.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
impl = FunctionDictGetStringOrDefault::create(context);
|
||||
case TypeIndex::String:
|
||||
impl = FunctionDictGetStringOrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt8:
|
||||
impl = FunctionDictGetUInt8OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt16:
|
||||
impl = FunctionDictGetUInt16OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt32:
|
||||
impl = FunctionDictGetUInt32OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::UInt64:
|
||||
impl = FunctionDictGetUInt64OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Int8:
|
||||
impl = FunctionDictGetInt8OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Int16:
|
||||
impl = FunctionDictGetInt16OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Int32:
|
||||
impl = FunctionDictGetInt32OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Int64:
|
||||
impl = FunctionDictGetInt64OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Float32:
|
||||
impl = FunctionDictGetFloat32OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Float64:
|
||||
impl = FunctionDictGetFloat64OrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Date:
|
||||
impl = FunctionDictGetDateOrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::DateTime:
|
||||
impl = FunctionDictGetDateTimeOrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::UUID:
|
||||
impl = FunctionDictGetUUIDOrDefault::create(context);
|
||||
break;
|
||||
case TypeIndex::Decimal32:
|
||||
impl = FunctionDictGetDecimal32OrDefault::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
case TypeIndex::Decimal64:
|
||||
impl = FunctionDictGetDecimal64OrDefault::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
case TypeIndex::Decimal128:
|
||||
impl = FunctionDictGetDecimal128OrDefault::create(context, getDecimalScale(*attribute.type));
|
||||
break;
|
||||
default:
|
||||
throw Exception("Unknown dictGetOrDefault type", ErrorCodes::UNKNOWN_TYPE);
|
||||
}
|
||||
#define DISPATCH(TYPE) \
|
||||
else if (dt.idx == TypeIndex::TYPE) \
|
||||
{ \
|
||||
if (!checkAndGetDataType<DataType##TYPE>(arguments[3].type.get())) \
|
||||
throw Exception{"Illegal type " + arguments[3].type->getName() + " of fourth argument of function " + getName() \
|
||||
+ ", must be " + String(DataType##TYPE{}.getFamilyName()) + ".", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT}; \
|
||||
impl = FunctionDictGetOrDefault<DataType##TYPE, NameDictGet##TYPE ## OrDefault>::create(context); \
|
||||
}
|
||||
FOR_DICT_TYPES(DISPATCH)
|
||||
#undef DISPATCH
|
||||
else
|
||||
throw Exception("Unknown dictGetOrDefault type", ErrorCodes::UNKNOWN_TYPE);
|
||||
|
||||
return attribute.type;
|
||||
}
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ public:
|
||||
|
||||
const auto type_x = arguments[0];
|
||||
|
||||
if (!isNumber(type_x))
|
||||
if (!isNativeNumber(type_x))
|
||||
throw Exception{"Unsupported type " + type_x->getName() + " of first argument of function " + getName() + " must be a numeric type",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
||||
|
@ -143,7 +143,7 @@ public:
|
||||
{
|
||||
const IDataType & type = *arguments[0];
|
||||
|
||||
if (!isNumber(type))
|
||||
if (!isNativeNumber(type))
|
||||
throw Exception("Cannot format " + type.getName() + " as size in bytes", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return std::make_shared<DataTypeString>();
|
||||
|
@ -138,7 +138,7 @@ public:
|
||||
|
||||
for (auto j : ext::range(0, elements.size()))
|
||||
{
|
||||
if (!isNumber(elements[j]))
|
||||
if (!isNativeNumber(elements[j]))
|
||||
{
|
||||
throw Exception(getMsgPrefix(i) + " must contains numeric tuple at position " + toString(j + 1),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
@ -120,11 +120,22 @@ private:
|
||||
/// prepare() does Impl-specific preparation before handling each row.
|
||||
impl.prepare(Name::name, block, arguments, result_pos);
|
||||
|
||||
bool json_parsed_ok = false;
|
||||
if (col_json_const)
|
||||
{
|
||||
StringRef json{reinterpret_cast<const char *>(&chars[0]), offsets[0] - 1};
|
||||
json_parsed_ok = parser.parse(json);
|
||||
}
|
||||
|
||||
for (const auto i : ext::range(0, input_rows_count))
|
||||
{
|
||||
StringRef json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
|
||||
bool ok = parser.parse(json);
|
||||
if (!col_json_const)
|
||||
{
|
||||
StringRef json{reinterpret_cast<const char *>(&chars[offsets[i - 1]]), offsets[i] - offsets[i - 1] - 1};
|
||||
json_parsed_ok = parser.parse(json);
|
||||
}
|
||||
|
||||
bool ok = json_parsed_ok;
|
||||
if (ok)
|
||||
{
|
||||
auto it = parser.getRoot();
|
||||
|
@ -309,8 +309,8 @@ public:
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
for (size_t i = 0; i < arguments.size(); ++i)
|
||||
if (!(isNumber(arguments[i])
|
||||
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || isNumber(removeNullable(arguments[i]))))))
|
||||
if (!(isNativeNumber(arguments[i])
|
||||
|| (Impl::specialImplementationForNulls() && (arguments[i]->onlyNull() || isNativeNumber(removeNullable(arguments[i]))))))
|
||||
throw Exception("Illegal type ("
|
||||
+ arguments[i]->getName()
|
||||
+ ") of " + toString(i + 1) + " argument of function " + getName(),
|
||||
@ -488,7 +488,7 @@ public:
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
if (!isNumber(arguments[0]))
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception("Illegal type ("
|
||||
+ arguments[0]->getName()
|
||||
+ ") of argument of function " + getName(),
|
||||
|
@ -500,7 +500,7 @@ public:
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
for (const auto & type : arguments)
|
||||
if (!isNumber(type) && !isDecimal(type))
|
||||
if (!isNumber(type))
|
||||
throw Exception("Illegal type " + arguments[0]->getName() + " of argument of function " + getName(),
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
@ -588,7 +588,7 @@ public:
|
||||
{
|
||||
const DataTypePtr & type_x = arguments[0];
|
||||
|
||||
if (!(isNumber(type_x) || isDecimal(type_x)))
|
||||
if (!isNumber(type_x))
|
||||
throw Exception{"Unsupported type " + type_x->getName()
|
||||
+ " of first argument of function " + getName()
|
||||
+ ", must be numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
@ -601,7 +601,7 @@ public:
|
||||
|
||||
const auto type_arr_nested = type_arr->getNestedType();
|
||||
|
||||
if (!(isNumber(type_arr_nested) || isDecimal(type_arr_nested)))
|
||||
if (!isNumber(type_arr_nested))
|
||||
{
|
||||
throw Exception{"Elements of array of second argument of function " + getName()
|
||||
+ " must be numeric type.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT};
|
||||
|
@ -32,7 +32,7 @@ namespace DB
|
||||
* calculation. If the right string size is big (more than 2**15 bytes),
|
||||
* the strings are not similar at all and we return 1.
|
||||
*/
|
||||
template <size_t N, class CodePoint, bool UTF8, bool CaseInsensitive>
|
||||
template <size_t N, class CodePoint, bool UTF8, bool CaseInsensitive, bool Symmetric>
|
||||
struct NgramDistanceImpl
|
||||
{
|
||||
using ResultType = Float32;
|
||||
@ -138,6 +138,7 @@ struct NgramDistanceImpl
|
||||
}
|
||||
|
||||
/// This is not a really true case insensitive utf8. We zero the 5-th bit of every byte.
|
||||
/// And first bit of first byte if there are two bytes.
|
||||
/// For ASCII it works https://catonmat.net/ascii-case-conversion-trick. For most cyrrilic letters also does.
|
||||
/// For others, we don't care now. Lowering UTF is not a cheap operation.
|
||||
if constexpr (CaseInsensitive)
|
||||
@ -151,6 +152,7 @@ struct NgramDistanceImpl
|
||||
res &= ~(1u << (5 + 2 * CHAR_BIT));
|
||||
[[fallthrough]];
|
||||
case 2:
|
||||
res &= ~(1u);
|
||||
res &= ~(1u << (5 + CHAR_BIT));
|
||||
[[fallthrough]];
|
||||
default:
|
||||
@ -222,9 +224,10 @@ struct NgramDistanceImpl
|
||||
for (; iter + N <= found; ++iter)
|
||||
{
|
||||
UInt16 hash = hash_functor(cp + iter);
|
||||
/// For symmetric version we should add when we can't subtract to get symmetric difference.
|
||||
if (static_cast<Int16>(ngram_stats[hash]) > 0)
|
||||
--distance;
|
||||
else
|
||||
else if constexpr (Symmetric)
|
||||
++distance;
|
||||
if constexpr (ReuseStats)
|
||||
ngram_storage[ngram_cnt] = hash;
|
||||
@ -267,7 +270,8 @@ struct NgramDistanceImpl
|
||||
if (data_size <= max_string_size)
|
||||
{
|
||||
size_t first_size = dispatchSearcher(calculateHaystackStatsAndMetric<false>, data.data(), data_size, common_stats, distance, nullptr);
|
||||
res = distance * 1.f / std::max(first_size + second_size, size_t(1));
|
||||
/// For !Symmetric version we should not use first_size.
|
||||
res = distance * 1.f / std::max(Symmetric * first_size + second_size, size_t(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -326,7 +330,10 @@ struct NgramDistanceImpl
|
||||
--common_stats[needle_ngram_storage[j]];
|
||||
|
||||
/// For now, common stats is a zero array.
|
||||
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
|
||||
|
||||
|
||||
/// For !Symmetric version we should not use haystack_stats_size.
|
||||
res[i] = distance * 1.f / std::max(Symmetric * haystack_stats_size + needle_stats_size, size_t(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -340,6 +347,71 @@ struct NgramDistanceImpl
|
||||
}
|
||||
}
|
||||
|
||||
static void constant_vector(
|
||||
std::string haystack,
|
||||
const ColumnString::Chars & needle_data,
|
||||
const ColumnString::Offsets & needle_offsets,
|
||||
PaddedPODArray<Float32> & res)
|
||||
{
|
||||
/// For symmetric version it is better to use vector_constant
|
||||
if constexpr (Symmetric)
|
||||
{
|
||||
vector_constant(needle_data, needle_offsets, std::move(haystack), res);
|
||||
}
|
||||
else
|
||||
{
|
||||
const size_t haystack_size = haystack.size();
|
||||
haystack.resize(haystack_size + default_padding);
|
||||
|
||||
/// For logic explanation see vector_vector function.
|
||||
const size_t needle_offsets_size = needle_offsets.size();
|
||||
size_t prev_offset = 0;
|
||||
|
||||
NgramStats common_stats = {};
|
||||
|
||||
std::unique_ptr<UInt16[]> needle_ngram_storage(new UInt16[max_string_size]);
|
||||
std::unique_ptr<UInt16[]> haystack_ngram_storage(new UInt16[max_string_size]);
|
||||
|
||||
for (size_t i = 0; i < needle_offsets_size; ++i)
|
||||
{
|
||||
const char * needle = reinterpret_cast<const char *>(&needle_data[prev_offset]);
|
||||
const size_t needle_size = needle_offsets[i] - prev_offset - 1;
|
||||
|
||||
if (needle_size <= max_string_size && haystack_size <= max_string_size)
|
||||
{
|
||||
const size_t needle_stats_size = dispatchSearcher(
|
||||
calculateNeedleStats<true>,
|
||||
needle,
|
||||
needle_size,
|
||||
common_stats,
|
||||
needle_ngram_storage.get());
|
||||
|
||||
size_t distance = needle_stats_size;
|
||||
|
||||
dispatchSearcher(
|
||||
calculateHaystackStatsAndMetric<true>,
|
||||
haystack.data(),
|
||||
haystack_size,
|
||||
common_stats,
|
||||
distance,
|
||||
haystack_ngram_storage.get());
|
||||
|
||||
for (size_t j = 0; j < needle_stats_size; ++j)
|
||||
--common_stats[needle_ngram_storage[j]];
|
||||
|
||||
res[i] = distance * 1.f / std::max(needle_stats_size, size_t(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
res[i] = 1.f;
|
||||
}
|
||||
|
||||
prev_offset = needle_offsets[i];
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
static void vector_constant(
|
||||
const ColumnString::Chars & data,
|
||||
const ColumnString::Offsets & offsets,
|
||||
@ -373,7 +445,8 @@ struct NgramDistanceImpl
|
||||
haystack_size, common_stats,
|
||||
distance,
|
||||
ngram_storage.get());
|
||||
res[i] = distance * 1.f / std::max(haystack_stats_size + needle_stats_size, size_t(1));
|
||||
/// For !Symmetric version we should not use haystack_stats_size.
|
||||
res[i] = distance * 1.f / std::max(Symmetric * haystack_stats_size + needle_stats_size, size_t(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -391,7 +464,6 @@ struct NameNgramDistance
|
||||
{
|
||||
static constexpr auto name = "ngramDistance";
|
||||
};
|
||||
|
||||
struct NameNgramDistanceCaseInsensitive
|
||||
{
|
||||
static constexpr auto name = "ngramDistanceCaseInsensitive";
|
||||
@ -407,10 +479,34 @@ struct NameNgramDistanceUTF8CaseInsensitive
|
||||
static constexpr auto name = "ngramDistanceCaseInsensitiveUTF8";
|
||||
};
|
||||
|
||||
using FunctionNgramDistance = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false>, NameNgramDistance>;
|
||||
using FunctionNgramDistanceCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true>, NameNgramDistanceCaseInsensitive>;
|
||||
using FunctionNgramDistanceUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false>, NameNgramDistanceUTF8>;
|
||||
using FunctionNgramDistanceCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
|
||||
struct NameNgramSearch
|
||||
{
|
||||
static constexpr auto name = "ngramSearch";
|
||||
};
|
||||
struct NameNgramSearchCaseInsensitive
|
||||
{
|
||||
static constexpr auto name = "ngramSearchCaseInsensitive";
|
||||
};
|
||||
struct NameNgramSearchUTF8
|
||||
{
|
||||
static constexpr auto name = "ngramSearchUTF8";
|
||||
};
|
||||
|
||||
struct NameNgramSearchUTF8CaseInsensitive
|
||||
{
|
||||
static constexpr auto name = "ngramSearchCaseInsensitiveUTF8";
|
||||
};
|
||||
|
||||
using FunctionNgramDistance = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false, true>, NameNgramDistance>;
|
||||
using FunctionNgramDistanceCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true, true>, NameNgramDistanceCaseInsensitive>;
|
||||
using FunctionNgramDistanceUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false, true>, NameNgramDistanceUTF8>;
|
||||
using FunctionNgramDistanceCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true, true>, NameNgramDistanceUTF8CaseInsensitive>;
|
||||
|
||||
using FunctionNgramSearch = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, false, false>, NameNgramSearch>;
|
||||
using FunctionNgramSearchCaseInsensitive = FunctionsStringSimilarity<NgramDistanceImpl<4, UInt8, false, true, false>, NameNgramSearchCaseInsensitive>;
|
||||
using FunctionNgramSearchUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, false, false>, NameNgramSearchUTF8>;
|
||||
using FunctionNgramSearchCaseInsensitiveUTF8 = FunctionsStringSimilarity<NgramDistanceImpl<3, UInt32, true, true, false>, NameNgramSearchUTF8CaseInsensitive>;
|
||||
|
||||
|
||||
void registerFunctionsStringSimilarity(FunctionFactory & factory)
|
||||
{
|
||||
@ -418,6 +514,11 @@ void registerFunctionsStringSimilarity(FunctionFactory & factory)
|
||||
factory.registerFunction<FunctionNgramDistanceCaseInsensitive>();
|
||||
factory.registerFunction<FunctionNgramDistanceUTF8>();
|
||||
factory.registerFunction<FunctionNgramDistanceCaseInsensitiveUTF8>();
|
||||
|
||||
factory.registerFunction<FunctionNgramSearch>();
|
||||
factory.registerFunction<FunctionNgramSearchCaseInsensitive>();
|
||||
factory.registerFunction<FunctionNgramSearchUTF8>();
|
||||
factory.registerFunction<FunctionNgramSearchCaseInsensitiveUTF8>();
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -110,15 +110,15 @@ public:
|
||||
}
|
||||
else if (col_haystack_const && col_needle_vector)
|
||||
{
|
||||
const String & needle = col_haystack_const->getValue<String>();
|
||||
if (needle.size() > Impl::max_string_size)
|
||||
const String & haystack = col_haystack_const->getValue<String>();
|
||||
if (haystack.size() > Impl::max_string_size)
|
||||
{
|
||||
throw Exception(
|
||||
"String size of needle is too big for function " + getName() + ". Should be at most "
|
||||
"String size of haystack is too big for function " + getName() + ". Should be at most "
|
||||
+ std::to_string(Impl::max_string_size),
|
||||
ErrorCodes::TOO_LARGE_STRING_SIZE);
|
||||
}
|
||||
Impl::vector_constant(col_needle_vector->getChars(), col_needle_vector->getOffsets(), needle, vec_res);
|
||||
Impl::constant_vector(haystack, col_needle_vector->getChars(), col_needle_vector->getOffsets(), vec_res);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -595,7 +595,7 @@ inline bool allowArrayIndex(const DataTypePtr & type0, const DataTypePtr & type1
|
||||
DataTypePtr data_type0 = removeNullable(type0);
|
||||
DataTypePtr data_type1 = removeNullable(type1);
|
||||
|
||||
return ((isNumber(data_type0) || isEnum(data_type0)) && isNumber(data_type1))
|
||||
return ((isNativeNumber(data_type0) || isEnum(data_type0)) && isNativeNumber(data_type1))
|
||||
|| data_type0->equals(*data_type1);
|
||||
}
|
||||
|
||||
|
@ -183,7 +183,7 @@ Columns FunctionArrayIntersect::castColumns(
|
||||
auto & type_nested = type_array->getNestedType();
|
||||
auto type_not_nullable_nested = removeNullable(type_nested);
|
||||
|
||||
const bool is_numeric_or_string = isNumber(type_not_nullable_nested)
|
||||
const bool is_numeric_or_string = isNativeNumber(type_not_nullable_nested)
|
||||
|| isDateOrDateTime(type_not_nullable_nested)
|
||||
|| isStringOrFixedString(type_not_nullable_nested);
|
||||
|
||||
|
@ -37,7 +37,7 @@ public:
|
||||
|
||||
DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override
|
||||
{
|
||||
if (!isNumber(arguments[0]))
|
||||
if (!isNativeNumber(arguments[0]))
|
||||
throw Exception("Illegal type " + arguments[0]->getName() +
|
||||
" of argument of function " + getName() +
|
||||
", expected Integer", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
@ -55,8 +55,8 @@ public:
|
||||
+ ".",
|
||||
ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH);
|
||||
|
||||
if (!isNumber(arguments[0]) || !isNumber(arguments[1]) || !isNumber(arguments[2])
|
||||
|| (arguments.size() == 4 && !isNumber(arguments[3])))
|
||||
if (!isNativeNumber(arguments[0]) || !isNativeNumber(arguments[1]) || !isNativeNumber(arguments[2])
|
||||
|| (arguments.size() == 4 && !isNativeNumber(arguments[3])))
|
||||
throw Exception("All arguments for function " + getName() + " must be numeric.", ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return std::make_shared<DataTypeString>();
|
||||
|
@ -2,6 +2,7 @@
|
||||
#include <Functions/FunctionFactory.h>
|
||||
#include <Functions/FunctionHelpers.h>
|
||||
#include <DataTypes/DataTypeAggregateFunction.h>
|
||||
#include <DataTypes/DataTypesNumber.h>
|
||||
#include <Columns/ColumnAggregateFunction.h>
|
||||
#include <Common/typeid_cast.h>
|
||||
|
||||
@ -60,7 +61,7 @@ public:
|
||||
throw Exception("Argument for function " + getName() + " must have type AggregateFunction - state of aggregate function.",
|
||||
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
|
||||
|
||||
return type->getReturnType();
|
||||
return type->getReturnTypeToPredict();
|
||||
}
|
||||
|
||||
void executeImpl(Block & block, const ColumnNumbers & arguments, size_t result, size_t /*input_rows_count*/) override
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user